639 lines
22 KiB
Python
639 lines
22 KiB
Python
|
import json
|
||
|
import warnings
|
||
|
|
||
|
from django import forms
|
||
|
from django.core import checks, exceptions
|
||
|
from django.db import NotSupportedError, connections, router
|
||
|
from django.db.models import expressions, lookups
|
||
|
from django.db.models.constants import LOOKUP_SEP
|
||
|
from django.db.models.fields import TextField
|
||
|
from django.db.models.lookups import (
|
||
|
FieldGetDbPrepValueMixin,
|
||
|
PostgresOperatorLookup,
|
||
|
Transform,
|
||
|
)
|
||
|
from django.utils.deprecation import RemovedInDjango51Warning
|
||
|
from django.utils.translation import gettext_lazy as _
|
||
|
|
||
|
from . import Field
|
||
|
from .mixins import CheckFieldDefaultMixin
|
||
|
|
||
|
# Public API of this module: only the model field is exported by
# ``from ... import *``. The lookups and transforms defined below are reached
# through the field's registered lookups rather than imported directly.
__all__ = ["JSONField"]
|
||
|
|
||
|
|
||
|
class JSONField(CheckFieldDefaultMixin, Field):
    """
    Model field storing JSON-serializable Python values.

    ``encoder`` (optional callable, typically a ``json.JSONEncoder`` subclass)
    is passed as ``cls`` to ``json.dumps()`` in ``validate()`` and forwarded to
    the backend's ``adapt_json_value()`` when preparing values for the
    database. ``decoder`` (optional callable, typically a ``json.JSONDecoder``
    subclass) is passed as ``cls`` to ``json.loads()`` when decoding values.
    """

    empty_strings_allowed = False
    description = _("A JSON object")
    default_error_messages = {
        "invalid": _("Value must be valid JSON."),
    }
    # (callable name, literal) pair -- presumably used by
    # CheckFieldDefaultMixin in its hint about mutable defaults; confirm there.
    _default_hint = ("dict", "{}")

    def __init__(
        self,
        verbose_name=None,
        name=None,
        encoder=None,
        decoder=None,
        **kwargs,
    ):
        """
        Store the optional JSON encoder/decoder and initialize the field.

        Raises ValueError if a truthy ``encoder`` or ``decoder`` is given that
        is not callable.
        """
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        """Run the base field checks plus the JSON-support check per database."""
        errors = super().check(**kwargs)
        databases = kwargs.get("databases") or []
        errors.extend(self._check_supported(databases))
        return errors

    def _check_supported(self, databases):
        """
        Return an ``fields.E180`` error for each database in ``databases``
        that lacks JSON field support.

        Databases the router won't migrate this model to, or whose vendor
        doesn't match the model's ``required_db_vendor``, are skipped. Models
        that list ``supports_json_field`` in ``required_db_features`` are not
        reported.
        """
        errors = []
        for db in databases:
            if not router.allow_migrate_model(db, self.model):
                continue
            connection = connections[db]
            if (
                self.model._meta.required_db_vendor
                and self.model._meta.required_db_vendor != connection.vendor
            ):
                continue
            if not (
                "supports_json_field" in self.model._meta.required_db_features
                or connection.features.supports_json_field
            ):
                errors.append(
                    checks.Error(
                        "%s does not support JSONFields." % connection.display_name,
                        obj=self.model,
                        id="fields.E180",
                    )
                )
        return errors

    def deconstruct(self):
        """Add ``encoder``/``decoder`` to the deconstructed kwargs when set."""
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        """
        Decode a value read from the database.

        None is returned unchanged. Non-string values produced by a
        KeyTransform are returned as-is. Strings that are not valid JSON are
        also returned unchanged.
        """
        if value is None:
            return value
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes.
        if isinstance(expression, KeyTransform) and not isinstance(value, str):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            return value

    def get_internal_type(self):
        return "JSONField"

    def get_db_prep_value(self, value, connection, prepared=False):
        """
        Convert ``value`` into a database parameter via the backend's
        ``adapt_json_value()``.

        Expressions other than ``Value`` are returned untouched. A ``Value``
        whose ``output_field`` is a JSONField is unwrapped. A ``Value`` holding
        a ``str`` with a non-JSON output field is decoded as JSON (deprecated
        behavior, emits RemovedInDjango51Warning on success).
        """
        if not prepared:
            value = self.get_prep_value(value)
        # RemovedInDjango51Warning: When the deprecation ends, replace with:
        # if (
        #     isinstance(value, expressions.Value)
        #     and isinstance(value.output_field, JSONField)
        # ):
        #     value = value.value
        # elif hasattr(value, "as_sql"): ...
        if isinstance(value, expressions.Value):
            if isinstance(value.value, str) and not isinstance(
                value.output_field, JSONField
            ):
                try:
                    value = json.loads(value.value, cls=self.decoder)
                except json.JSONDecodeError:
                    value = value.value
                else:
                    warnings.warn(
                        "Providing an encoded JSON string via Value() is deprecated. "
                        f"Use Value({value!r}, output_field=JSONField()) instead.",
                        category=RemovedInDjango51Warning,
                    )
            elif isinstance(value.output_field, JSONField):
                value = value.value
            else:
                return value
        elif hasattr(value, "as_sql"):
            return value
        return connection.ops.adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value, connection):
        # None is stored as SQL NULL rather than being adapted as JSON.
        if value is None:
            return value
        return self.get_db_prep_value(value, connection)

    def get_transform(self, name):
        """Return a registered transform, else a KeyTransform for key ``name``."""
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name)

    def validate(self, value, model_instance):
        """
        Run base validation, then check that ``value`` is JSON-serializable.

        Raises ValidationError (code "invalid") when ``json.dumps()`` with the
        field's encoder fails.
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except (TypeError, ValueError):
            # TypeError: an object the encoder cannot serialize.
            # ValueError: e.g. a circular reference ("Circular reference
            # detected"), or an encoder that rejects the value. Without this,
            # such values escaped as a raw ValueError instead of a
            # ValidationError.
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            )

    def value_to_string(self, obj):
        # Returns the Python value itself (not a JSON string); callers that
        # need text are expected to serialize it.
        return self.value_from_object(obj)

    def formfield(self, **kwargs):
        """Return a forms.JSONField sharing this field's encoder/decoder."""
        return super().formfield(
            **{
                "form_class": forms.JSONField,
                "encoder": self.encoder,
                "decoder": self.decoder,
                **kwargs,
            }
        )
|
||
|
|
||
|
|
||
|
def compile_json_path(key_transforms, include_root=True):
    """
    Build a JSON path string such as ``$."a"[0]."b"`` from key names.

    Each key that ``int()`` accepts becomes an array index ``[n]``. Any other
    key becomes an object member ``."<key>"``, quoted with ``json.dumps()``.
    The leading ``$`` root is emitted only when ``include_root`` is true.
    """
    segments = ["$"] if include_root else []
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:
            # Not an integer: treat it as an object member name.
            segments.append("." + json.dumps(key))
        else:
            segments.append("[%s]" % index)
    return "".join(segments)
|
||
|
|
||
|
|
||
|
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """
    ``contains`` lookup: the left-hand JSON value contains the rhs document.

    ``postgres_operator`` (``@>``) is presumably consumed by
    PostgresOperatorLookup on PostgreSQL. ``as_sql()`` emits JSON_CONTAINS()
    and requires ``supports_json_field_contains``.
    """

    postgres_operator = "@>"
    lookup_name = "contains"

    def as_sql(self, compiler, connection):
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        # JSON_CONTAINS(target, candidate): does the column contain the rhs?
        return f"JSON_CONTAINS({lhs_sql}, {rhs_sql})", (*lhs_params, *rhs_params)
|
||
|
|
||
|
|
||
|
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """
    ``contained_by`` lookup: the left-hand JSON value is contained in the rhs.

    This is the mirror of ``contains``. ``as_sql()`` swaps the operands of
    JSON_CONTAINS() and requires ``supports_json_field_contains``.
    """

    postgres_operator = "<@"
    lookup_name = "contained_by"

    def as_sql(self, compiler, connection):
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        # Operands are reversed, and so are their params, to keep placeholders
        # aligned with the SQL.
        return f"JSON_CONTAINS({rhs_sql}, {lhs_sql})", (*rhs_params, *lhs_params)
|
||
|
|
||
|
|
||
|
class HasKeyLookup(PostgresOperatorLookup):
    """
    Base lookup testing whether a JSON document contains one or more keys.

    Subclasses set ``lookup_name``, ``postgres_operator`` and, for multi-key
    lookups, ``logical_operator`` (" AND " / " OR "). That operator combines
    the per-key conditions built by ``as_sql()``.
    """

    logical_operator = None

    def compile_json_path_final_key(self, key_transform):
        # Compile the final key without interpreting ints as array elements.
        return ".%s" % json.dumps(key_transform)

    def as_sql(self, compiler, connection, template=None):
        """
        Build SQL from a backend-specific ``template``.

        ``template`` must be supplied by the caller. It needs one ``%s`` slot
        for the lhs SQL and an escaped ``%%s`` placeholder for each JSON path.
        One condition is produced per rhs key. When ``logical_operator`` is
        set, the conditions are joined with it. Returns ``(sql, params)`` with
        lhs params first, then one JSON path per key.
        """
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        sql = template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params = []
        if not isinstance(rhs, (list, tuple)):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        # Add condition for each key.
        if self.logical_operator:
            sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
        return sql, tuple(lhs_params) + tuple(rhs_params)

    def as_mysql(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_oracle(self, compiler, connection):
        # The JSON path is inlined into the SQL below, so it must not be able
        # to escape its string literal. json.dumps() does not escape "'", so a
        # plain '...' literal could be terminated by a key name containing a
        # quote. Use Oracle's alternative quoting mechanism q'<d>...<d>' with
        # U+FFFF as the delimiter <d> instead. json.dumps() (default
        # ensure_ascii=True) always emits U+FFFF as the escape \uffff, and
        # array indexes are plain digits, so the delimiter never occurs in a
        # compiled path.
        sql, params = self.as_sql(
            compiler, connection, template="JSON_EXISTS(%s, q'\uffff%%s\uffff')"
        )
        # Add paths directly into SQL because path expressions cannot be passed
        # as bind variables on Oracle.
        # NOTE(review): ``params`` also contains the lhs params, which are
        # inlined unquoted as well. This is only safe while the lhs compiles
        # without params (e.g. a plain column); confirm for other lhs types.
        return sql % tuple(params), []

    def as_postgresql(self, compiler, connection):
        # When the rhs is a KeyTransform, rewrite this lookup in place. All but
        # its last key are pushed onto self.lhs as nested KeyTransforms, and
        # the last key becomes the plain rhs for the PostgreSQL operator.
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )
|
||
|
|
||
|
|
||
|
class HasKey(HasKeyLookup):
    """``has_key`` lookup: the JSON document contains the given key."""

    # The key name is used verbatim; no get_prep_value() conversion.
    prepare_rhs = False
    postgres_operator = "?"
    lookup_name = "has_key"
|
||
|
|
||
|
|
||
|
class HasKeys(HasKeyLookup):
    """``has_keys`` lookup: every key in the given sequence must be present."""

    logical_operator = " AND "
    postgres_operator = "?&"
    lookup_name = "has_keys"

    def get_prep_lookup(self):
        # Coerce every requested key to str.
        return list(map(str, self.rhs))
|
||
|
|
||
|
|
||
|
class HasAnyKeys(HasKeys):
    """``has_any_keys`` lookup: at least one of the given keys is present."""

    logical_operator = " OR "
    postgres_operator = "?|"
    lookup_name = "has_any_keys"
|
||
|
|
||
|
|
||
|
class HasKeyOrArrayIndex(HasKey):
    """HasKey variant whose final key may address a JSON array element."""

    def compile_json_path_final_key(self, key_transform):
        # Unlike HasKey, an integer-like final key compiles to "[n]".
        return compile_json_path((key_transform,), include_root=False)
|
||
|
|
||
|
|
||
|
class CaseInsensitiveMixin:
    """
    Make JSON value comparisons case-insensitive on MySQL.

    MySQL compares strings in a JSON context with the binary utf8mb4_bin
    collation, so comparisons are case-sensitive there. On MySQL both sides of
    the lookup are wrapped in LOWER(). Other vendors are left untouched.
    """

    def process_lhs(self, compiler, connection):
        lhs_sql, lhs_params = super().process_lhs(compiler, connection)
        if connection.vendor == "mysql":
            lhs_sql = "LOWER(%s)" % lhs_sql
        return lhs_sql, lhs_params

    def process_rhs(self, compiler, connection):
        rhs_sql, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "mysql":
            rhs_sql = "LOWER(%s)" % rhs_sql
        return rhs_sql, rhs_params
|
||
|
|
||
|
|
||
|
class JSONExact(lookups.Exact):
    """``exact`` on a JSONField, where a None rhs means the JSON ``null``."""

    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        rhs_sql, rhs_params = super().process_rhs(compiler, connection)
        # A bare None placeholder is compared against the JSON literal null.
        if rhs_sql == "%s" and rhs_params == [None]:
            rhs_params = ["null"]
        if connection.vendor == "mysql":
            # Wrap each placeholder so MySQL parses the param as JSON.
            rhs_sql %= ("JSON_EXTRACT(%s, '$')",) * len(rhs_params)
        return rhs_sql, rhs_params
|
||
|
|
||
|
|
||
|
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    """``icontains`` on a JSONField; both sides are LOWER()-ed on MySQL."""
|
||
|
|
||
|
|
||
|
# Lookups available directly on a JSONField (``field__<lookup>=...``),
# registered in this order.
for _lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    JSONField.register_lookup(_lookup)
del _lookup
|
||
|
|
||
|
|
||
|
class KeyTransform(Transform):
    """
    Transform extracting the value stored under ``key_name`` from a JSON
    expression.

    Chained KeyTransforms are flattened by ``preprocess_lhs()`` into a single
    key path for each backend.
    """

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        """
        Flatten nested KeyTransforms.

        Returns ``(lhs_sql, params, key_transforms)``. ``lhs_sql``/``params``
        compile the innermost non-KeyTransform expression, and
        ``key_transforms`` lists the key names from outermost JSON level to
        this transform's key.
        """
        key_transforms = [self.key_name]
        previous = self.lhs
        while isinstance(previous, KeyTransform):
            key_transforms.insert(0, previous.key_name)
            previous = previous.lhs
        lhs, params = compiler.compile(previous)
        if connection.vendor == "oracle":
            # Escape string-formatting: on Oracle the keys end up inlined
            # into SQL that is %-formatted again later.
            key_transforms = [key.replace("%", "%%") for key in key_transforms]
        return lhs, params, key_transforms

    def as_mysql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)

    def as_oracle(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        # Path expressions can't be bind variables on Oracle, so the path is
        # inlined into the SQL. json.dumps() does not escape "'", so a plain
        # '...' literal could be terminated by a key name containing a quote.
        # Use Oracle's alternative quoting q'<d>...<d>' with U+FFFF as the
        # delimiter <d> instead. json.dumps() (default ensure_ascii=True)
        # always emits U+FFFF as the escape \uffff, and array indexes are plain
        # digits, so the delimiter never occurs in a compiled path.
        sql = (
            "COALESCE(JSON_QUERY(%s, q'\uffff%s\uffff'), "
            "JSON_VALUE(%s, q'\uffff%s\uffff'))"
        )
        return sql % ((lhs, json_path) * 2), tuple(params) * 2

    def as_postgresql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        if len(key_transforms) > 1:
            # Nested path: pass the whole key list as a single parameter.
            sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
            return sql, tuple(params) + (key_transforms,)
        try:
            # Integer-like keys are sent as ints (array index).
            lookup = int(self.key_name)
        except ValueError:
            lookup = self.key_name
        return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)

    def as_sqlite(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        # For the JSON types listed in connection.ops.jsonfield_datatype_values
        # (defined by the backend, not visible here), return JSON_TYPE();
        # otherwise return JSON_EXTRACT(). Each of the three placeholders needs
        # its own copy of the params.
        datatype_values = ",".join(
            [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
        )
        return (
            "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
            "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
        ) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
|
||
|
|
||
|
|
||
|
class KeyTextTransform(KeyTransform):
    """
    KeyTransform whose result is the key's value as text (``TextField``).

    PostgreSQL uses the ``->>``/``#>>`` operators. MySQL uses ``->>``, and
    MariaDB uses JSON_UNQUOTE() around the JSON_EXTRACT() built by
    KeyTransform.
    """

    output_field = TextField()
    postgres_nested_operator = "#>>"
    postgres_operator = "->>"

    def as_mysql(self, compiler, connection):
        if not connection.mysql_is_mariadb:
            lhs_sql, lhs_params, keys = self.preprocess_lhs(compiler, connection)
            path = compile_json_path(keys)
            return "(%s ->> %%s)" % lhs_sql, (*lhs_params, path)
        # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
        extract_sql, extract_params = super().as_mysql(compiler, connection)
        return "JSON_UNQUOTE(%s)" % extract_sql, extract_params

    @classmethod
    def from_lookup(cls, lookup):
        """
        Build nested transforms from a LOOKUP_SEP-separated string.

        The first segment is the base expression; each following segment
        wraps the previous result in another ``cls`` transform. Raises
        ValueError when no key segment follows the base.
        """
        base, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        result = base
        for key in keys:
            result = cls(key, result)
        return result
|
||
|
|
||
|
|
||
|
# Shorthand for KeyTextTransform.from_lookup: KT("field__key__subkey") builds
# nested KeyTextTransforms from a LOOKUP_SEP-separated lookup string.
KT = KeyTextTransform.from_lookup
|
||
|
|
||
|
|
||
|
class KeyTransformTextLookupMixin:
    """
    Mixin for text lookups applied to a JSONField key.

    The incoming KeyTransform is replaced by an equivalent KeyTextTransform,
    so the lookup operates on the key's text value. On PostgreSQL this means
    using ``->>`` instead of casting the JSON value to text.
    """

    def __init__(self, key_transform, *args, **kwargs):
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        # Rebuild the same key access, but with text output.
        text_transform = KeyTextTransform(
            key_transform.key_name,
            *key_transform.source_expressions,
            **key_transform.extra,
        )
        super().__init__(text_transform, *args, **kwargs)
|
||
|
|
||
|
|
||
|
class KeyTransformIsNull(lookups.IsNull):
    # key__isnull=False is the same as has_key='key'

    def as_oracle(self, compiler, connection):
        has_key = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        has_key_sql, has_key_params = has_key.as_oracle(compiler, connection)
        if self.rhs:
            # isnull=True: the key is missing, or the underlying column
            # itself IS NULL.
            column_sql, column_params, _ = self.lhs.preprocess_lhs(
                compiler, connection
            )
            sql = "(NOT %s OR %s IS NULL)" % (has_key_sql, column_sql)
            return sql, (*has_key_params, *column_params)
        return has_key_sql, has_key_params

    def as_sqlite(self, compiler, connection):
        # JSON_TYPE() is NULL exactly when the path does not exist.
        if self.rhs:
            template = "JSON_TYPE(%s, %%s) IS NULL"
        else:
            template = "JSON_TYPE(%s, %%s) IS NOT NULL"
        lookup = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        return lookup.as_sql(compiler, connection, template=template)
|
||
|
|
||
|
|
||
|
class KeyTransformIn(lookups.In):
    """``in`` lookup on a JSON key, parsing each literal rhs item as JSON."""

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        sql, params = super().resolve_expression_parameter(
            compiler, connection, sql, param
        )
        is_plain_value = not hasattr(param, "as_sql")
        if is_plain_value and not connection.features.has_native_json_field:
            if connection.vendor == "oracle":
                decoded = json.loads(param)
                # Containers need JSON_QUERY; scalars need JSON_VALUE.
                func = (
                    "JSON_QUERY" if isinstance(decoded, (list, dict)) else "JSON_VALUE"
                )
                sql = (
                    "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')" % func
                )
            elif connection.vendor == "mysql" or (
                connection.vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            ):
                sql = "JSON_EXTRACT(%s, '$')"
        if connection.vendor == "mysql" and connection.mysql_is_mariadb:
            sql = "JSON_UNQUOTE(%s)" % sql
        return sql, params
|
||
|
|
||
|
|
||
|
class KeyTransformExact(JSONExact):
    """``exact`` lookup on a JSON key, with per-backend rhs wrapping."""

    def process_rhs(self, compiler, connection):
        if isinstance(self.rhs, KeyTransform):
            # Key-to-key comparison: use the implementation after
            # lookups.Exact in the MRO, skipping JSONExact/Exact rhs handling.
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs_sql, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "oracle":
            template = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
            wrappers = []
            for param in rhs_params:
                decoded = json.loads(param)
                func = (
                    "JSON_QUERY" if isinstance(decoded, (list, dict)) else "JSON_VALUE"
                )
                wrappers.append(template % func)
            rhs_sql %= tuple(wrappers)
        elif connection.vendor == "sqlite":
            rhs_sql %= tuple(
                "%s"
                if param in connection.ops.jsonfield_datatype_values
                else "JSON_EXTRACT(%s, '$')"
                for param in rhs_params
            )
        return rhs_sql, rhs_params

    def as_oracle(self, compiler, connection):
        _, rhs_params = super().process_rhs(compiler, connection)
        if rhs_params != ["null"]:
            return super().as_sql(compiler, connection)
        # Field has key and it's NULL.
        has_key_sql, has_key_params = HasKeyOrArrayIndex(
            self.lhs.lhs, self.lhs.key_name
        ).as_oracle(compiler, connection)
        is_null_sql, is_null_params = self.lhs.get_lookup("isnull")(
            self.lhs, True
        ).as_sql(compiler, connection)
        return (
            "%s AND %s" % (has_key_sql, is_null_sql),
            (*has_key_params, *is_null_params),
        )
|
||
|
|
||
|
|
||
|
class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    """Case-insensitive ``iexact`` on a JSON key's text value."""
|
||
|
|
||
|
|
||
|
class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    """Case-insensitive ``icontains`` on a JSON key's text value."""
|
||
|
|
||
|
|
||
|
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``startswith`` on a JSON key's text value."""
|
||
|
|
||
|
|
||
|
class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    """Case-insensitive ``istartswith`` on a JSON key's text value."""
|
||
|
|
||
|
|
||
|
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``endswith`` on a JSON key's text value."""
|
||
|
|
||
|
|
||
|
class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    """Case-insensitive ``iendswith`` on a JSON key's text value."""
|
||
|
|
||
|
|
||
|
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """``regex`` on a JSON key's text value."""
|
||
|
|
||
|
|
||
|
class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    """Case-insensitive ``iregex`` on a JSON key's text value."""
|
||
|
|
||
|
|
||
|
class KeyTransformNumericLookupMixin:
    """
    Mixin for comparison lookups on JSON keys.

    On backends without a native JSON field, each rhs param is decoded from
    JSON before being bound.
    """

    def process_rhs(self, compiler, connection):
        rhs_sql, rhs_params = super().process_rhs(compiler, connection)
        if connection.features.has_native_json_field:
            return rhs_sql, rhs_params
        decoded = [json.loads(param) for param in rhs_params]
        return rhs_sql, decoded
|
||
|
|
||
|
|
||
|
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """``lt`` comparison on a JSON key."""
|
||
|
|
||
|
|
||
|
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """``lte`` comparison on a JSON key."""
|
||
|
|
||
|
|
||
|
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """``gt`` comparison on a JSON key."""
|
||
|
|
||
|
|
||
|
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """``gte`` comparison on a JSON key."""
|
||
|
|
||
|
|
||
|
# Lookups available on a JSON key (``field__key__<lookup>=...``), registered
# in this order.
for _lookup in (
    KeyTransformIn,
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    # Comparison lookups.
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_lookup)
del _lookup
|
||
|
|
||
|
|
||
|
class KeyTransformFactory:
    """
    Callable that builds a KeyTransform for a fixed key name.

    JSONField.get_transform() returns one of these for names that are not
    registered transforms.
    """

    def __init__(self, key_name):
        # Stored as given; KeyTransform itself converts it to str.
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        key = self.key_name
        return KeyTransform(key, *args, **kwargs)
|