Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/django/db/models/fields/json.py: 40%
384 statements
coverage.py v7.0.5, created at 2023-01-17 06:13 +0000
import json
import warnings

from django import forms
from django.core import checks, exceptions
from django.db import NotSupportedError, connections, router
from django.db.models import expressions, lookups
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import TextField
from django.db.models.lookups import (
    FieldGetDbPrepValueMixin,
    PostgresOperatorLookup,
    Transform,
)
from django.utils.deprecation import RemovedInDjango51Warning
from django.utils.translation import gettext_lazy as _

from . import Field
from .mixins import CheckFieldDefaultMixin

__all__ = ["JSONField"]


class JSONField(CheckFieldDefaultMixin, Field):
    empty_strings_allowed = False
    description = _("A JSON object")
    default_error_messages = {
        "invalid": _("Value must be valid JSON."),
    }
    _default_hint = ("dict", "{}")

    def __init__(
        self,
        verbose_name=None,
        name=None,
        encoder=None,
        decoder=None,
        **kwargs,
    ):
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        errors = super().check(**kwargs)
        databases = kwargs.get("databases") or []
        errors.extend(self._check_supported(databases))
        return errors

    def _check_supported(self, databases):
        errors = []
        for db in databases:
            if not router.allow_migrate_model(db, self.model):
                continue
            connection = connections[db]
            if (
                self.model._meta.required_db_vendor
                and self.model._meta.required_db_vendor != connection.vendor
            ):
                continue
            if not (
                "supports_json_field" in self.model._meta.required_db_features
                or connection.features.supports_json_field
            ):
                errors.append(
                    checks.Error(
                        "%s does not support JSONFields." % connection.display_name,
                        obj=self.model,
                        id="fields.E180",
                    )
                )
        return errors

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        if value is None:
            return value
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes.
        if isinstance(expression, KeyTransform) and not isinstance(value, str):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            return value

    def get_internal_type(self):
        return "JSONField"

    def get_db_prep_value(self, value, connection, prepared=False):
        # RemovedInDjango51Warning: When the deprecation ends, replace with:
        # if (
        #     isinstance(value, expressions.Value)
        #     and isinstance(value.output_field, JSONField)
        # ):
        #     value = value.value
        # elif hasattr(value, "as_sql"): ...
        if isinstance(value, expressions.Value):
            if isinstance(value.value, str) and not isinstance(
                value.output_field, JSONField
            ):
                try:
                    value = json.loads(value.value, cls=self.decoder)
                except json.JSONDecodeError:
                    value = value.value
                else:
                    warnings.warn(
                        "Providing an encoded JSON string via Value() is deprecated. "
                        f"Use Value({value!r}, output_field=JSONField()) instead.",
                        category=RemovedInDjango51Warning,
                    )
            elif isinstance(value.output_field, JSONField):
                value = value.value
            else:
                return value
        elif hasattr(value, "as_sql"):
            return value
        return connection.ops.adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value, connection):
        if value is None:
            return value
        return self.get_db_prep_value(value, connection)

    def get_transform(self, name):
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name)

    def validate(self, value, model_instance):
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except TypeError:
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            )

    def value_to_string(self, obj):
        return self.value_from_object(obj)

    def formfield(self, **kwargs):
        return super().formfield(
            **{
                "form_class": forms.JSONField,
                "encoder": self.encoder,
                "decoder": self.decoder,
                **kwargs,
            }
        )


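# Illustrative usage sketch (not part of Django): ``Document`` is a
# hypothetical model with a ``data = JSONField()`` column. Values round-trip
# through json.dumps()/json.loads(), and, per the deprecation handled in
# get_db_prep_value() above, pre-encoded JSON strings passed through Value()
# should be replaced by Value(python_obj, output_field=JSONField()).
def _example_jsonfield_usage(Document):
    from django.db.models import Value  # local import to keep the sketch self-contained

    doc = Document.objects.create(data={"owner": {"name": "Bob"}, "tags": ["a"]})
    # Preferred replacement for the deprecated Value('{"tags": []}'):
    Document.objects.filter(pk=doc.pk).update(
        data=Value({"tags": []}, output_field=JSONField())
    )
    return Document.objects.get(pk=doc.pk).data  # -> {"tags": []}

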
def compile_json_path(key_transforms, include_root=True):
    path = ["$"] if include_root else []
    for key_transform in key_transforms:
        try:
            num = int(key_transform)
        except ValueError:  # non-integer
            path.append(".")
            path.append(json.dumps(key_transform))
        else:
            path.append("[%s]" % num)
    return "".join(path)


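# A small sketch of what compile_json_path() produces: string keys become
# quoted member accessors, while integer-like keys become array indexes.
def _example_compile_json_path():
    assert compile_json_path(["owner", "name"]) == '$."owner"."name"'
    assert compile_json_path(["items", "0", "id"]) == '$."items"[0]."id"'
    assert compile_json_path(["0"], include_root=False) == "[0]"

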
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(self, compiler, connection):
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = tuple(lhs_params) + tuple(rhs_params)
        return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params


class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(self, compiler, connection):
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = tuple(rhs_params) + tuple(lhs_params)
        return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params


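# Illustrative sketch of the containment lookups above; they raise
# NotSupportedError on backends without supports_json_field_contains.
# ``Document`` is a hypothetical model with a ``data`` JSONField.
def _example_containment_lookups(Document):
    # Rows whose data contains (is a superset of) the given fragment.
    contains = Document.objects.filter(data__contains={"owner": {"name": "Bob"}})
    # Rows whose data is contained by (is a subset of) the given document.
    contained_by = Document.objects.filter(
        data__contained_by={"owner": {"name": "Bob"}, "breed": "collie"}
    )
    return contains, contained_by

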
class HasKeyLookup(PostgresOperatorLookup):
    logical_operator = None

    def compile_json_path_final_key(self, key_transform):
        # Compile the final key without interpreting ints as array elements.
        return ".%s" % json.dumps(key_transform)

    def as_sql(self, compiler, connection, template=None):
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        sql = template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params = []
        if not isinstance(rhs, (list, tuple)):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        # Add condition for each key.
        if self.logical_operator:
            sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
        return sql, tuple(lhs_params) + tuple(rhs_params)

    def as_mysql(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_oracle(self, compiler, connection):
        sql, params = self.as_sql(
            compiler, connection, template="JSON_EXISTS(%s, '%%s')"
        )
        # Add paths directly into SQL because path expressions cannot be passed
        # as bind variables on Oracle.
        return sql % tuple(params), []

    def as_postgresql(self, compiler, connection):
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )


class HasKey(HasKeyLookup):
    lookup_name = "has_key"
    postgres_operator = "?"
    prepare_rhs = False


class HasKeys(HasKeyLookup):
    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self):
        return [str(item) for item in self.rhs]


class HasAnyKeys(HasKeys):
    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "


class HasKeyOrArrayIndex(HasKey):
    def compile_json_path_final_key(self, key_transform):
        return compile_json_path([key_transform], include_root=False)


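# Illustrative sketch of the has_key family defined above (``Document`` is a
# hypothetical model with a ``data`` JSONField).
def _example_has_key_lookups(Document):
    has_key = Document.objects.filter(data__has_key="owner")
    # All of the listed keys must be present.
    has_keys = Document.objects.filter(data__has_keys=["owner", "breed"])
    # At least one of the listed keys must be present.
    has_any = Document.objects.filter(data__has_any_keys=["owner", "breed"])
    return has_key, has_keys, has_any

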
class CaseInsensitiveMixin:
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.
    MySQL handles strings used in JSON context using the utf8mb4_bin collation.
    Because utf8mb4_bin is a binary collation, comparison of JSON values is
    case-sensitive.
    """

    def process_lhs(self, compiler, connection):
        lhs, lhs_params = super().process_lhs(compiler, connection)
        if connection.vendor == "mysql":
            return "LOWER(%s)" % lhs, lhs_params
        return lhs, lhs_params

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "mysql":
            return "LOWER(%s)" % rhs, rhs_params
        return rhs, rhs_params


class JSONExact(lookups.Exact):
    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Treat None lookup values as null.
        if rhs == "%s" and rhs_params == [None]:
            rhs_params = ["null"]
        if connection.vendor == "mysql":
            func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
            rhs %= tuple(func)
        return rhs, rhs_params


class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    pass


JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)


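# Illustrative sketch of JSONExact's None handling registered above: filtering
# the column against None matches a stored JSON null, while __isnull targets
# SQL NULL. ``Document`` is a hypothetical model with a nullable ``data``
# JSONField.
def _example_exact_null_handling(Document):
    sql_null = Document.objects.filter(data__isnull=True)  # column IS NULL
    json_null = Document.objects.filter(data=None)  # JSON null value
    return sql_null, json_null

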
class KeyTransform(Transform):
    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        key_transforms = [self.key_name]
        previous = self.lhs
        while isinstance(previous, KeyTransform):
            key_transforms.insert(0, previous.key_name)
            previous = previous.lhs
        lhs, params = compiler.compile(previous)
        if connection.vendor == "oracle":
            # Escape string-formatting.
            key_transforms = [key.replace("%", "%%") for key in key_transforms]
        return lhs, params, key_transforms

    def as_mysql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)

    def as_oracle(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        return (
            "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
            % ((lhs, json_path) * 2)
        ), tuple(params) * 2

    def as_postgresql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        if len(key_transforms) > 1:
            sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
            return sql, tuple(params) + (key_transforms,)
        try:
            lookup = int(self.key_name)
        except ValueError:
            lookup = self.key_name
        return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)

    def as_sqlite(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        datatype_values = ",".join(
            [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
        )
        return (
            "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
            "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
        ) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3


class KeyTextTransform(KeyTransform):
    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(self, compiler, connection):
        if connection.mysql_is_mariadb:
            # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
            sql, params = super().as_mysql(compiler, connection)
            return "JSON_UNQUOTE(%s)" % sql, params
        else:
            lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
            json_path = compile_json_path(key_transforms)
            return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)

    @classmethod
    def from_lookup(cls, lookup):
        transform, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        for key in keys:
            transform = cls(key, transform)
        return transform


KT = KeyTextTransform.from_lookup


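# Illustrative sketch: KT() builds a KeyTextTransform chain from a
# double-underscore lookup string, which is useful in annotate()/order_by()
# where key traversal syntax is unavailable. ``Document`` is a hypothetical
# model with a ``data`` JSONField.
def _example_kt(Document):
    # Extract data["owner"]["name"] as text and order by it.
    return Document.objects.annotate(
        owner_name=KT("data__owner__name")
    ).order_by("owner_name")

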
class KeyTransformTextLookupMixin:
    """
    Mixin for combining with a lookup expecting a text lhs from a JSONField
    key lookup. On PostgreSQL, make use of the ->> operator instead of casting
    key values to text and performing the lookup on the resulting
    representation.
    """

    def __init__(self, key_transform, *args, **kwargs):
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        key_text_transform = KeyTextTransform(
            key_transform.key_name,
            *key_transform.source_expressions,
            **key_transform.extra,
        )
        super().__init__(key_text_transform, *args, **kwargs)


class KeyTransformIsNull(lookups.IsNull):
    # key__isnull=False is the same as has_key='key'
    def as_oracle(self, compiler, connection):
        sql, params = HasKeyOrArrayIndex(
            self.lhs.lhs,
            self.lhs.key_name,
        ).as_oracle(compiler, connection)
        if not self.rhs:
            return sql, params
        # Column doesn't have a key or IS NULL.
        lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
        return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)

    def as_sqlite(self, compiler, connection):
        template = "JSON_TYPE(%s, %%s) IS NULL"
        if not self.rhs:
            template = "JSON_TYPE(%s, %%s) IS NOT NULL"
        return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
            compiler,
            connection,
            template=template,
        )


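# Illustrative sketch of KeyTransformIsNull semantics noted above:
# key__isnull=False behaves like has_key, so a key stored as JSON null still
# counts as present. ``Document`` is a hypothetical model.
def _example_key_isnull(Document):
    missing_key = Document.objects.filter(data__owner__isnull=True)
    present_key = Document.objects.filter(data__owner__isnull=False)
    return missing_key, present_key

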
class KeyTransformIn(lookups.In):
    def resolve_expression_parameter(self, compiler, connection, sql, param):
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        if (
            not hasattr(param, "as_sql")
            and not connection.features.has_native_json_field
        ):
            if connection.vendor == "oracle":
                value = json.loads(param)
                sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
                if isinstance(value, (list, dict)):
                    sql %= "JSON_QUERY"
                else:
                    sql %= "JSON_VALUE"
            elif connection.vendor == "mysql" or (
                connection.vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            ):
                sql = "JSON_EXTRACT(%s, '$')"
        if connection.vendor == "mysql" and connection.mysql_is_mariadb:
            sql = "JSON_UNQUOTE(%s)" % sql
        return sql, params


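# Illustrative sketch of KeyTransformIn: __in on a key compares the extracted
# JSON value against each candidate (``Document`` is a hypothetical model).
def _example_key_in(Document):
    return Document.objects.filter(data__breed__in=["collie", "labrador"])

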
class KeyTransformExact(JSONExact):
    def process_rhs(self, compiler, connection):
        if isinstance(self.rhs, KeyTransform):
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "oracle":
            func = []
            sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
            for value in rhs_params:
                value = json.loads(value)
                if isinstance(value, (list, dict)):
                    func.append(sql % "JSON_QUERY")
                else:
                    func.append(sql % "JSON_VALUE")
            rhs %= tuple(func)
        elif connection.vendor == "sqlite":
            func = []
            for value in rhs_params:
                if value in connection.ops.jsonfield_datatype_values:
                    func.append("%s")
                else:
                    func.append("JSON_EXTRACT(%s, '$')")
            rhs %= tuple(func)
        return rhs, rhs_params

    def as_oracle(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if rhs_params == ["null"]:
            # Field has key and it's NULL.
            has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
            has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
            is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
            is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
            return (
                "%s AND %s" % (has_key_sql, is_null_sql),
                tuple(has_key_params) + tuple(is_null_params),
            )
        return super().as_sql(compiler, connection)


class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    pass


class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    pass


class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    pass


class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    pass


class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    pass


class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    pass


class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    pass


class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    pass


class KeyTransformNumericLookupMixin:
    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if not connection.features.has_native_json_field:
            rhs_params = [json.loads(value) for value in rhs_params]
        return rhs, rhs_params


class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    pass


class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    pass


class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    pass


class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    pass


KeyTransform.register_lookup(KeyTransformIn)
KeyTransform.register_lookup(KeyTransformExact)
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIsNull)
KeyTransform.register_lookup(KeyTransformIContains)
KeyTransform.register_lookup(KeyTransformStartsWith)
KeyTransform.register_lookup(KeyTransformIStartsWith)
KeyTransform.register_lookup(KeyTransformEndsWith)
KeyTransform.register_lookup(KeyTransformIEndsWith)
KeyTransform.register_lookup(KeyTransformRegex)
KeyTransform.register_lookup(KeyTransformIRegex)

KeyTransform.register_lookup(KeyTransformLt)
KeyTransform.register_lookup(KeyTransformLte)
KeyTransform.register_lookup(KeyTransformGt)
KeyTransform.register_lookup(KeyTransformGte)


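# Illustrative sketch: with the lookups registered above, key transforms chain
# with ordinary lookups in filters (``Document`` is a hypothetical model with
# a ``data`` JSONField).
def _example_key_transform_lookups(Document):
    # data["owner"]["name"] == "Bob" (KeyTransformExact).
    exact = Document.objects.filter(data__owner__name="Bob")
    # Case-insensitive containment on the text value (KeyTransformIContains).
    icontains = Document.objects.filter(data__owner__name__icontains="bob")
    # Numeric comparison on the extracted value (KeyTransformGt).
    greater = Document.objects.filter(data__price__gt=10)
    return exact, icontains, greater

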
class KeyTransformFactory:
    def __init__(self, key_name):
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        return KeyTransform(self.key_name, *args, **kwargs)
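

# Illustrative sketch: JSONField.get_transform() falls back to
# KeyTransformFactory for unknown names, so each segment of a lookup such as
# data__owner__name becomes a nested KeyTransform. A rough manual equivalent,
# using F() to reference a hypothetical ``data`` column:
def _example_key_transform_factory():
    from django.db.models import F  # local import; avoids circular imports here

    make_owner = KeyTransformFactory("owner")
    make_name = KeyTransformFactory("name")
    # KeyTransform("name", KeyTransform("owner", F("data")))
    return make_name(make_owner(F("data")))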