mysteriendrama/lib/python3.11/site-packages/django/db/models/fields/json.py
2023-07-22 12:13:39 +02:00

639 lines
22 KiB
Python

import json
import warnings
from django import forms
from django.core import checks, exceptions
from django.db import NotSupportedError, connections, router
from django.db.models import expressions, lookups
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import TextField
from django.db.models.lookups import (
FieldGetDbPrepValueMixin,
PostgresOperatorLookup,
Transform,
)
from django.utils.deprecation import RemovedInDjango51Warning
from django.utils.translation import gettext_lazy as _
from . import Field
from .mixins import CheckFieldDefaultMixin
__all__ = ["JSONField"]
class JSONField(CheckFieldDefaultMixin, Field):
    """
    Model field holding a JSON-serializable Python value.

    The optional ``encoder`` and ``decoder`` callables are passed as ``cls``
    to ``json.dumps()`` (validation) and ``json.loads()`` (loading), and the
    encoder is also handed to the backend's ``adapt_json_value()`` on save.
    """

    empty_strings_allowed = False
    description = _("A JSON object")
    default_error_messages = {
        "invalid": _("Value must be valid JSON."),
    }
    _default_hint = ("dict", "{}")

    def __init__(
        self,
        verbose_name=None,
        name=None,
        encoder=None,
        decoder=None,
        **kwargs,
    ):
        """
        Store the encoder/decoder and initialise the base field.

        Raises ValueError when a truthy ``encoder`` or ``decoder`` is not
        callable.
        """
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder, self.decoder = encoder, decoder
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        """Run the inherited checks plus the per-database JSON support check."""
        found = super().check(**kwargs)
        found.extend(self._check_supported(kwargs.get("databases") or []))
        return found

    def _check_supported(self, databases):
        """
        Return a ``fields.E180`` error for every alias in ``databases`` whose
        backend does not support JSON fields.

        Aliases are skipped when the router disallows migrating the model
        there, when the model requires a different vendor, or when the model
        itself lists ``supports_json_field`` among its required features.
        """
        problems = []
        for alias in databases:
            if not router.allow_migrate_model(alias, self.model):
                continue
            connection = connections[alias]
            required_vendor = self.model._meta.required_db_vendor
            if required_vendor and required_vendor != connection.vendor:
                continue
            feature_required = (
                "supports_json_field" in self.model._meta.required_db_features
            )
            if feature_required or connection.features.supports_json_field:
                continue
            problems.append(
                checks.Error(
                    "%s does not support JSONFields." % connection.display_name,
                    obj=self.model,
                    id="fields.E180",
                )
            )
        return problems

    def deconstruct(self):
        """Extend the inherited deconstruction with non-None encoder/decoder."""
        name, path, args, kwargs = super().deconstruct()
        for option in ("encoder", "decoder"):
            hook = getattr(self, option)
            if hook is not None:
                kwargs[option] = hook
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        """
        Decode a database value with ``json.loads()``.

        ``None`` is returned unchanged, as is any value that fails to decode.
        """
        if value is None:
            return value
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes.
        already_native = isinstance(expression, KeyTransform) and not isinstance(
            value, str
        )
        if already_native:
            return value
        try:
            decoded = json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            return value
        return decoded

    def get_internal_type(self):
        """Return the internal type name of this field."""
        return "JSONField"

    def get_db_prep_value(self, value, connection, prepared=False):
        """
        Adapt ``value`` for the backend via ``connection.ops.adapt_json_value``.

        Objects exposing ``as_sql`` are returned untouched, except ``Value()``
        wrappers: those with a JSONField output field are unwrapped, and plain
        string ones are decoded as JSON (deprecated) before being adapted.
        """
        if not prepared:
            value = self.get_prep_value(value)
        # RemovedInDjango51Warning: When the deprecation ends, replace with:
        # if (
        #     isinstance(value, expressions.Value)
        #     and isinstance(value.output_field, JSONField)
        # ):
        #     value = value.value
        # elif hasattr(value, "as_sql"): ...
        if isinstance(value, expressions.Value):
            if isinstance(value.output_field, JSONField):
                value = value.value
            elif isinstance(value.value, str):
                try:
                    value = json.loads(value.value, cls=self.decoder)
                except json.JSONDecodeError:
                    # Not JSON after all: adapt the raw string itself.
                    value = value.value
                else:
                    warnings.warn(
                        "Providing an encoded JSON string via Value() is deprecated. "
                        f"Use Value({value!r}, output_field=JSONField()) instead.",
                        category=RemovedInDjango51Warning,
                    )
            else:
                return value
        elif hasattr(value, "as_sql"):
            return value
        return connection.ops.adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value, connection):
        """Prepare ``value`` for saving; ``None`` is passed through unchanged."""
        if value is not None:
            value = self.get_db_prep_value(value, connection)
        return value

    def get_transform(self, name):
        """
        Return the registered transform for ``name`` or, when there is none,
        a factory producing a key transform for that name.
        """
        return super().get_transform(name) or KeyTransformFactory(name)

    def validate(self, value, model_instance):
        """
        Run the inherited validation, then require ``value`` to be encodable.

        Raises ValidationError (code ``invalid``) when ``json.dumps()`` raises
        TypeError.
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except TypeError:
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            )

    def value_to_string(self, obj):
        """Return the field's value on ``obj`` as-is."""
        return self.value_from_object(obj)

    def formfield(self, **kwargs):
        """
        Build the form field, defaulting to ``forms.JSONField`` with this
        field's encoder and decoder; explicit ``kwargs`` take precedence.
        """
        options = {
            "form_class": forms.JSONField,
            "encoder": self.encoder,
            "decoder": self.decoder,
        }
        options.update(kwargs)
        return super().formfield(**options)
def compile_json_path(key_transforms, include_root=True):
    """
    Build a JSON path string from a sequence of keys.

    Keys accepted by ``int()`` become array subscripts (``[n]``); all other
    keys become JSON-quoted member accessors (``."key"``). The path starts
    with ``$`` unless ``include_root`` is false.
    """
    segments = ["$"] if include_root else []
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:  # non-integer
            segments.append("." + json.dumps(key))
        else:
            segments.append("[%s]" % index)
    return "".join(segments)
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contains`` lookup: the stored JSON contains the given value."""

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(self, compiler, connection):
        """
        Return ``JSON_CONTAINS(lhs, rhs)`` with lhs params before rhs params.

        Raises NotSupportedError when the backend's
        ``supports_json_field_contains`` feature flag is false.
        """
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        container_sql, container_params = self.process_lhs(compiler, connection)
        needle_sql, needle_params = self.process_rhs(compiler, connection)
        sql = "JSON_CONTAINS(%s, %s)" % (container_sql, needle_sql)
        return sql, (*container_params, *needle_params)
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contained_by`` lookup: the stored JSON is contained in the value."""

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(self, compiler, connection):
        """
        Return ``JSON_CONTAINS(rhs, lhs)`` -- operands swapped relative to
        ``contains`` -- with rhs params before lhs params.

        Raises NotSupportedError when the backend's
        ``supports_json_field_contains`` feature flag is false.
        """
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        member_sql, member_params = self.process_lhs(compiler, connection)
        container_sql, container_params = self.process_rhs(compiler, connection)
        sql = "JSON_CONTAINS(%s, %s)" % (container_sql, member_sql)
        return sql, (*container_params, *member_params)
class HasKeyLookup(PostgresOperatorLookup):
    """
    Shared implementation of the key-existence lookups.

    For every requested key a full JSON path is built (left-hand path plus
    the key) and passed as a parameter to a backend-specific SQL template.
    """

    # String used to join one condition per key; None means the SQL is emitted
    # once without joining.
    logical_operator = None

    def compile_json_path_final_key(self, key_transform):
        """Return the path segment for the last key of a right-hand side."""
        # Compile the final key without interpreting ints as array elements.
        return ".%s" % json.dumps(key_transform)

    def as_sql(self, compiler, connection, template=None):
        """
        Build the SQL and params for the lookup from ``template``.

        ``template`` is %-formatted with the left-hand SQL; the vendor methods
        below pass templates that keep an escaped ``%%s`` placeholder for the
        JSON path. The default of None cannot be formatted, so callers must
        supply a template.

        Returns ``(sql, params)`` where params are the left-hand params
        followed by one JSON path per key.
        """
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        sql = template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params = []
        if not isinstance(rhs, (list, tuple)):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            # Every key but the last goes through compile_json_path(); the
            # last one is compiled by compile_json_path_final_key().
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        # Add condition for each key.
        if self.logical_operator:
            sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
        # NOTE(review): with several keys the lhs SQL is repeated above but
        # lhs_params is appended only once; this lines up only while the lhs
        # compiles without params -- confirm that always holds.
        return sql, tuple(lhs_params) + tuple(rhs_params)

    def as_mysql(self, compiler, connection):
        """Compile with a ``JSON_CONTAINS_PATH`` template."""
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_oracle(self, compiler, connection):
        """Compile with ``JSON_EXISTS``, inlining every param into the SQL."""
        sql, params = self.as_sql(
            compiler, connection, template="JSON_EXISTS(%s, '%%s')"
        )
        # Add paths directly into SQL because path expressions cannot be passed
        # as bind variables on Oracle.
        return sql % tuple(params), []

    def as_postgresql(self, compiler, connection):
        """
        Compile through the inherited PostgreSQL implementation.

        When the right-hand side is a KeyTransform, all of its keys except the
        last are wrapped around ``self.lhs`` as KeyTransforms and ``self.rhs``
        becomes the last key; both attributes are reassigned in place.
        """
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        """Compile with a ``JSON_TYPE(...) IS NOT NULL`` template."""
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )
class HasKey(HasKeyLookup):
    """``has_key`` lookup: a single key must be present."""

    lookup_name = "has_key"
    postgres_operator = "?"
    # NOTE(review): presumably tells the base lookup to leave the rhs
    # unprepared; the flag is consumed outside this module -- confirm.
    prepare_rhs = False
class HasKeys(HasKeyLookup):
    """``has_keys`` lookup: per-key conditions are joined with ``AND``."""

    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self):
        """Return every right-hand item converted with ``str()``."""
        return list(map(str, self.rhs))
class HasAnyKeys(HasKeys):
    """``has_any_keys`` lookup: per-key conditions are joined with ``OR``."""

    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "
class HasKeyOrArrayIndex(HasKey):
    """``HasKey`` variant whose final key may also be an array index."""

    def compile_json_path_final_key(self, key_transform):
        """Compile the final key with ``compile_json_path()`` (no ``$`` root)."""
        final_keys = [key_transform]
        return compile_json_path(final_keys, include_root=False)
class CaseInsensitiveMixin:
    """
    Make JSON value comparisons case-insensitive on MySQL.

    MySQL compares strings in a JSON context with the binary utf8mb4_bin
    collation, which is case-sensitive, so both sides of the lookup are
    wrapped in LOWER() on that vendor. Other vendors are left untouched.
    """

    def process_lhs(self, compiler, connection):
        """Return the inherited lhs, wrapped in ``LOWER()`` on MySQL."""
        sql, params = super().process_lhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return "LOWER(%s)" % sql, params

    def process_rhs(self, compiler, connection):
        """Return the inherited rhs, wrapped in ``LOWER()`` on MySQL."""
        sql, params = super().process_rhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return "LOWER(%s)" % sql, params
class JSONExact(lookups.Exact):
    """``exact`` lookup for JSON fields; a ``None`` value means JSON null."""

    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        """
        Return the inherited rhs with a lone ``None`` param replaced by the
        JSON literal ``"null"``; on MySQL every placeholder is additionally
        wrapped in ``JSON_EXTRACT(..., '$')``.
        """
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Treat None lookup values as null.
        if rhs == "%s" and rhs_params == [None]:
            rhs_params = ["null"]
        if connection.vendor == "mysql":
            wrappers = ("JSON_EXTRACT(%s, '$')",) * len(rhs_params)
            rhs = rhs % wrappers
        return rhs, rhs_params
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    """``lookups.IContains`` for JSON fields, lower-casing both sides on MySQL."""
# Register the field-level lookups defined above on JSONField.
JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)
class KeyTransform(Transform):
    """
    Transform that extracts one key (or array index) from a JSON value.

    Chained instances describe a nested path; the vendor-specific ``as_*``
    methods collapse the whole chain into a single extraction expression.
    """

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        """Initialise the transform; ``key_name`` is stored as a string."""
        super().__init__(*args, **kwargs)
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        """
        Collapse a chain of KeyTransforms.

        Returns ``(lhs_sql, params, keys)`` where ``lhs_sql``/``params`` come
        from compiling the innermost non-KeyTransform expression and ``keys``
        lists the key names from outermost container to this key. On Oracle
        every ``%`` in a key is doubled.
        """
        keys = [self.key_name]
        node = self.lhs
        while isinstance(node, KeyTransform):
            keys.append(node.key_name)
            node = node.lhs
        # Collected innermost-last while walking; callers expect root-first.
        keys.reverse()
        lhs, params = compiler.compile(node)
        if connection.vendor == "oracle":
            # Escape string-formatting.
            keys = [key.replace("%", "%%") for key in keys]
        return lhs, params, keys

    def as_mysql(self, compiler, connection):
        """Compile to ``JSON_EXTRACT`` with the path as the last param."""
        lhs, params, keys = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(keys)
        return "JSON_EXTRACT(%s, %%s)" % lhs, (*params, path)

    def as_oracle(self, compiler, connection):
        """
        Compile to ``COALESCE(JSON_QUERY(...), JSON_VALUE(...))``.

        The path is written into the SQL text itself and the lhs appears
        twice, so its params are doubled.
        """
        lhs, params, keys = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(keys)
        sql = "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" % (
            lhs,
            path,
            lhs,
            path,
        )
        return sql, tuple(params) * 2

    def as_postgresql(self, compiler, connection):
        """
        Compile with the nested operator and the key list for multi-key
        chains, otherwise with the plain operator and the single key (as an
        int when ``int()`` accepts it).
        """
        lhs, params, keys = self.preprocess_lhs(compiler, connection)
        if len(keys) > 1:
            operator, operand = self.postgres_nested_operator, keys
        else:
            operator = self.postgres_operator
            try:
                operand = int(self.key_name)
            except ValueError:
                operand = self.key_name
        return "(%s %s %%s)" % (lhs, operator), (*params, operand)

    def as_sqlite(self, compiler, connection):
        """
        Compile to a CASE returning ``JSON_TYPE`` when it is one of the
        backend's ``jsonfield_datatype_values`` and ``JSON_EXTRACT``
        otherwise. The lhs and path occur three times, so params are tripled.
        """
        lhs, params, keys = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(keys)
        datatype_values = ",".join(map(repr, connection.ops.jsonfield_datatype_values))
        sql = (
            "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
            "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
        ) % (lhs, datatype_values, lhs, lhs)
        return sql, (*params, path) * 3
class KeyTextTransform(KeyTransform):
    """KeyTransform variant whose output field is a TextField."""

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(self, compiler, connection):
        """
        Compile with the ``->>`` operator, or with ``JSON_UNQUOTE`` around
        the inherited extraction on MariaDB.
        """
        if not connection.mysql_is_mariadb:
            lhs, params, keys = self.preprocess_lhs(compiler, connection)
            path = compile_json_path(keys)
            return "(%s ->> %%s)" % lhs, (*params, path)
        # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
        extract_sql, extract_params = super().as_mysql(compiler, connection)
        return "JSON_UNQUOTE(%s)" % extract_sql, extract_params

    @classmethod
    def from_lookup(cls, lookup):
        """
        Build nested transforms from a ``LOOKUP_SEP``-separated string.

        The first segment is the innermost expression; each following segment
        wraps the result in another instance of ``cls``. Raises ValueError
        when no key segment follows the first one.
        """
        root, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        nested = root
        for key in keys:
            nested = cls(key, nested)
        return nested
# Short alias for building a KeyTextTransform chain from a lookup string.
KT = KeyTextTransform.from_lookup
class KeyTransformTextLookupMixin:
    """
    Mixin for lookups that need a text left-hand side from a JSON key.

    The incoming KeyTransform is replaced by an equivalent KeyTextTransform
    so that, on PostgreSQL, the ->> operator is used instead of casting the
    key value to text and running the lookup on that representation.
    """

    def __init__(self, key_transform, *args, **kwargs):
        """
        Swap ``key_transform`` for a KeyTextTransform with the same key name,
        source expressions and extra options, then initialise the lookup.

        Raises TypeError when ``key_transform`` is not a KeyTransform.
        """
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        key_name = key_transform.key_name
        sources = key_transform.source_expressions
        options = key_transform.extra
        text_transform = KeyTextTransform(key_name, *sources, **options)
        super().__init__(text_transform, *args, **kwargs)
class KeyTransformIsNull(lookups.IsNull):
    """``lookups.IsNull`` for JSON keys with Oracle- and SQLite-specific SQL."""

    # key__isnull=False is the same as has_key='key'
    def as_oracle(self, compiler, connection):
        """
        For a false rhs return the key-existence SQL; for a true rhs return
        its negation OR-ed with an ``IS NULL`` test on the column.
        """
        key_exists = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        exists_sql, exists_params = key_exists.as_oracle(compiler, connection)
        if self.rhs:
            # Column doesn't have a key or IS NULL.
            column_sql, column_params, _ = self.lhs.preprocess_lhs(
                compiler, connection
            )
            return (
                "(NOT %s OR %s IS NULL)" % (exists_sql, column_sql),
                (*exists_params, *column_params),
            )
        return exists_sql, exists_params

    def as_sqlite(self, compiler, connection):
        """Compile as ``JSON_TYPE(...) IS NULL`` or ``IS NOT NULL``."""
        template = (
            "JSON_TYPE(%s, %%s) IS NULL"
            if self.rhs
            else "JSON_TYPE(%s, %%s) IS NOT NULL"
        )
        key_exists = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        return key_exists.as_sql(compiler, connection, template=template)
class KeyTransformIn(lookups.In):
    """``lookups.In`` for JSON keys with vendor-specific parameter SQL."""

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """
        Return ``(sql, params)`` for one right-hand parameter.

        Starting from the inherited result, the placeholder SQL of a
        non-expression ``param`` is rewritten on backends without a native
        JSON field; on MariaDB the final SQL is wrapped in ``JSON_UNQUOTE``.
        """
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        if (
            not hasattr(param, "as_sql")
            and not connection.features.has_native_json_field
        ):
            if connection.vendor == "oracle":
                # param is decoded with json.loads() only to pick the
                # extraction function: JSON_QUERY for lists/dicts, JSON_VALUE
                # for everything else.
                value = json.loads(param)
                sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
                if isinstance(value, (list, dict)):
                    sql %= "JSON_QUERY"
                else:
                    sql %= "JSON_VALUE"
            elif connection.vendor == "mysql" or (
                connection.vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            ):
                sql = "JSON_EXTRACT(%s, '$')"
        # Applied to whatever SQL was produced above, expressions included.
        if connection.vendor == "mysql" and connection.mysql_is_mariadb:
            sql = "JSON_UNQUOTE(%s)" % sql
        return sql, params
class KeyTransformExact(JSONExact):
    """``exact`` lookup for JSON keys with Oracle- and SQLite-specific SQL."""

    def process_rhs(self, compiler, connection):
        """
        Return ``(rhs_sql, rhs_params)`` with each placeholder rewritten for
        the connection's vendor.
        """
        if isinstance(self.rhs, KeyTransform):
            # super(lookups.Exact, self) resumes the MRO search after
            # lookups.Exact, so the process_rhs() of both JSONExact and
            # lookups.Exact is bypassed when comparing against another key.
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "oracle":
            func = []
            sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
            for value in rhs_params:
                # Decode only to choose JSON_QUERY (list/dict) or JSON_VALUE.
                value = json.loads(value)
                if isinstance(value, (list, dict)):
                    func.append(sql % "JSON_QUERY")
                else:
                    func.append(sql % "JSON_VALUE")
            rhs %= tuple(func)
        elif connection.vendor == "sqlite":
            func = []
            for value in rhs_params:
                # Values listed in jsonfield_datatype_values keep a bare
                # placeholder; all others are wrapped in JSON_EXTRACT.
                if value in connection.ops.jsonfield_datatype_values:
                    func.append("%s")
                else:
                    func.append("JSON_EXTRACT(%s, '$')")
            rhs %= tuple(func)
        return rhs, rhs_params

    def as_oracle(self, compiler, connection):
        """
        Compile for Oracle, special-casing a JSON ``null`` right-hand side as
        "key exists AND value IS NULL".
        """
        # JSONExact.process_rhs() is called here, not the override above.
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if rhs_params == ["null"]:
            # Field has key and it's NULL.
            has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
            has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
            is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
            is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
            return (
                "%s AND %s" % (has_key_sql, is_null_sql),
                tuple(has_key_params) + tuple(is_null_params),
            )
        return super().as_sql(compiler, connection)
class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    """``lookups.IExact`` on the text value of a JSON key."""
class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    """``lookups.IContains`` on the text value of a JSON key."""
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``lookups.StartsWith`` on the text value of a JSON key."""
class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    """``lookups.IStartsWith`` on the text value of a JSON key."""
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``lookups.EndsWith`` on the text value of a JSON key."""
class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    """``lookups.IEndsWith`` on the text value of a JSON key."""
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """``lookups.Regex`` on the text value of a JSON key."""
class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    """``lookups.IRegex`` on the text value of a JSON key."""
class KeyTransformNumericLookupMixin:
    """
    Mixin for comparison lookups on JSON keys.

    On backends without a native JSON field the right-hand params are decoded
    with ``json.loads()`` before being returned.
    """

    def process_rhs(self, compiler, connection):
        """Return the inherited rhs, decoding params on non-native backends."""
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.features.has_native_json_field:
            return rhs, rhs_params
        return rhs, [json.loads(raw) for raw in rhs_params]
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """``lookups.LessThan`` for JSON key transforms."""
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """``lookups.LessThanOrEqual`` for JSON key transforms."""
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """``lookups.GreaterThan`` for JSON key transforms."""
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """``lookups.GreaterThanOrEqual`` for JSON key transforms."""
# Register the key-level lookups defined above on KeyTransform.
KeyTransform.register_lookup(KeyTransformIn)
KeyTransform.register_lookup(KeyTransformExact)
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIsNull)
KeyTransform.register_lookup(KeyTransformIContains)
KeyTransform.register_lookup(KeyTransformStartsWith)
KeyTransform.register_lookup(KeyTransformIStartsWith)
KeyTransform.register_lookup(KeyTransformEndsWith)
KeyTransform.register_lookup(KeyTransformIEndsWith)
KeyTransform.register_lookup(KeyTransformRegex)
KeyTransform.register_lookup(KeyTransformIRegex)
KeyTransform.register_lookup(KeyTransformLt)
KeyTransform.register_lookup(KeyTransformLte)
KeyTransform.register_lookup(KeyTransformGt)
KeyTransform.register_lookup(KeyTransformGte)
class KeyTransformFactory:
    """Callable that creates KeyTransform instances for one fixed key name."""

    def __init__(self, key_name):
        """Remember the key name to bind into every created transform."""
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        """Return ``KeyTransform(key_name, *args, **kwargs)``."""
        bound_key = self.key_name
        return KeyTransform(bound_key, *args, **kwargs)