1import json
2
3from django import forms
4from django.core import checks, exceptions
5from django.db import NotSupportedError, connections, router
6from django.db.models import expressions, lookups
7from django.db.models.constants import LOOKUP_SEP
8from django.db.models.fields import TextField
9from django.db.models.lookups import (
10 FieldGetDbPrepValueMixin,
11 PostgresOperatorLookup,
12 Transform,
13)
14from django.utils.translation import gettext_lazy as _
15
16from . import Field
17from .mixins import CheckFieldDefaultMixin
18
# Only JSONField is exported by ``from <this module> import *``.
__all__ = ["JSONField"]
20
21
class JSONField(CheckFieldDefaultMixin, Field):
    """
    Model field that stores JSON-serializable values.

    ``encoder`` and ``decoder`` are optional callables. The encoder is handed
    to ``json.dumps()`` during validation and to the backend's
    ``adapt_json_value()``; the decoder is handed to ``json.loads()`` when
    values are read back. Both are also forwarded to the form field.
    """

    empty_strings_allowed = False
    description = _("A JSON object")
    default_error_messages = {"invalid": _("Value must be valid JSON.")}
    _default_hint = ("dict", "{}")

    def __init__(
        self,
        verbose_name=None,
        name=None,
        encoder=None,
        decoder=None,
        **kwargs,
    ):
        # Falsy values (e.g. None) mean "not provided" and skip the check.
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder, self.decoder = encoder, decoder
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        """Run the inherited checks plus the JSON support check."""
        errors = super().check(**kwargs)
        # A missing or None "databases" entry means no database is inspected.
        errors += self._check_supported(kwargs.get("databases") or [])
        return errors

    def _check_supported(self, databases):
        """Return a fields.E180 error per database lacking JSONField support."""
        errors = []
        for alias in databases:
            if not router.allow_migrate_model(alias, self.model):
                continue
            connection = connections[alias]
            meta = self.model._meta
            required_vendor = meta.required_db_vendor
            if required_vendor and required_vendor != connection.vendor:
                # The model targets another vendor; nothing to report here.
                continue
            if "supports_json_field" in meta.required_db_features:
                # The model already declares the feature as a requirement.
                continue
            if connection.features.supports_json_field:
                continue
            errors.append(
                checks.Error(
                    "%s does not support JSONFields." % connection.display_name,
                    obj=self.model,
                    id="fields.E180",
                )
            )
        return errors

    def deconstruct(self):
        """Add ``encoder``/``decoder`` to the kwargs when they are set."""
        name, path, args, kwargs = super().deconstruct()
        for option in ("encoder", "decoder"):
            configured = getattr(self, option)
            if configured is not None:
                kwargs[option] = configured
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        """Decode a database value; undecodable values are returned as is."""
        if value is None:
            return None
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes.
        already_native = isinstance(expression, KeyTransform) and not isinstance(
            value, str
        )
        if already_native:
            return value
        try:
            decoded = json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            return value
        return decoded

    def get_internal_type(self):
        return "JSONField"

    def get_db_prep_value(self, value, connection, prepared=False):
        """Adapt ``value`` for the backend; SQL expressions pass through."""
        if not prepared:
            value = self.get_prep_value(value)
        wraps_json = isinstance(value, expressions.Value) and isinstance(
            value.output_field, JSONField
        )
        if not wraps_json and hasattr(value, "as_sql"):
            return value
        if wraps_json:
            # Unwrap Value(..., output_field=JSONField()) to its raw payload.
            value = value.value
        return connection.ops.adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value, connection):
        """Pass None through untouched; adapt every other value."""
        return None if value is None else self.get_db_prep_value(value, connection)

    def get_transform(self, name):
        """Return a registered transform, else treat ``name`` as a JSON key."""
        return super().get_transform(name) or KeyTransformFactory(name)

    def validate(self, value, model_instance):
        """Raise ValidationError when ``json.dumps()`` rejects the value."""
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except TypeError:
            message = self.error_messages["invalid"]
            raise exceptions.ValidationError(
                message,
                code="invalid",
                params={"value": value},
            )

    def value_to_string(self, obj):
        """Return the field's value on ``obj`` unchanged."""
        return self.value_from_object(obj)

    def formfield(self, **kwargs):
        """Build the form field; caller kwargs override the defaults."""
        options = {
            "form_class": forms.JSONField,
            "encoder": self.encoder,
            "decoder": self.decoder,
        }
        options.update(kwargs)
        return super().formfield(**options)
144
145
def compile_json_path(key_transforms, include_root=True):
    """
    Build a JSON path string from a sequence of keys.

    Keys accepted by ``int()`` become array indexes (``[n]``); every other
    key becomes a JSON-quoted member access (``."key"``). The path starts
    with ``$`` unless ``include_root`` is false.
    """
    segments = []
    if include_root:
        segments.append("$")
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:  # non-integer
            segments.append("." + json.dumps(key))
            continue
        segments.append("[%s]" % index)
    return "".join(segments)
157
158
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contains`` lookup: the stored JSON contains the given value."""

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(self, compiler, connection):
        """Render JSON_CONTAINS(lhs, rhs) or raise NotSupportedError."""
        supported = connection.features.supports_json_field_contains
        if not supported:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        sql = "JSON_CONTAINS(%s, %s)" % (lhs_sql, rhs_sql)
        return sql, (*lhs_params, *rhs_params)
172
173
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contained_by`` lookup: the stored JSON is contained in the value."""

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(self, compiler, connection):
        """Render JSON_CONTAINS(rhs, lhs) or raise NotSupportedError."""
        supported = connection.features.supports_json_field_contains
        if not supported:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        # Operands are swapped relative to "contains", and so are the params.
        sql = "JSON_CONTAINS(%s, %s)" % (rhs_sql, lhs_sql)
        return sql, (*rhs_params, *lhs_params)
187
188
class HasKeyLookup(PostgresOperatorLookup):
    """
    Shared implementation for lookups that test for the presence of keys.

    Subclasses set ``lookup_name`` and ``postgres_operator``; lookups taking
    several keys also set ``logical_operator`` to the SQL fragment used to
    join the per-key conditions.
    """

    # None means the per-key conditions are concatenated without a separator.
    logical_operator = None

    def compile_json_path_final_key(self, key_transform):
        # Compile the final key without interpreting ints as array elements.
        return ".%s" % json.dumps(key_transform)

    def _as_sql_parts(self, compiler, connection):
        """
        Yield ``(lhs_sql, lhs_params, json_path)`` for each right-hand key.

        ``json_path`` is the left-hand path followed by the path of the key.
        """
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs_sql, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs_sql, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        if not isinstance(rhs, (list, tuple)):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            # Only the last key is compiled by compile_json_path_final_key().
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            yield lhs_sql, lhs_params, lhs_json_path + rhs_json_path

    def _combine_sql_parts(self, parts):
        # Add condition for each key.
        if self.logical_operator:
            return "(%s)" % self.logical_operator.join(parts)
        return "".join(parts)

    def as_sql(self, compiler, connection, template=None):
        """
        Render one ``template`` condition per key, binding the JSON path.

        ``template`` is a %-format string with two placeholders (left-hand
        SQL and the path placeholder); the vendor-specific methods supply it.
        Return ``(sql, params)`` with ``params`` as a tuple.
        """
        sql_parts = []
        params = []
        for lhs_sql, lhs_params, rhs_json_path in self._as_sql_parts(
            compiler, connection
        ):
            sql_parts.append(template % (lhs_sql, "%s"))
            # lhs_params may be a list or a tuple; extend/append works for
            # both, whereas ``lhs_params + [path]`` fails for a tuple.
            params.extend(lhs_params)
            params.append(rhs_json_path)
        return self._combine_sql_parts(sql_parts), tuple(params)

    def as_mysql(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %s)"
        )

    def as_oracle(self, compiler, connection):
        # Use a custom delimiter to prevent the JSON path from escaping the SQL
        # literal. See comment in KeyTransform.
        template = "JSON_EXISTS(%s, q'\uffff%s\uffff')"
        sql_parts = []
        params = []
        for lhs_sql, lhs_params, rhs_json_path in self._as_sql_parts(
            compiler, connection
        ):
            # Add right-hand-side directly into SQL because it cannot be passed
            # as bind variables to JSON_EXISTS. It might result in invalid
            # queries but it is assumed that it cannot be evaded because the
            # path is JSON serialized.
            sql_parts.append(template % (lhs_sql, rhs_json_path))
            params.extend(lhs_params)
        return self._combine_sql_parts(sql_parts), tuple(params)

    def as_postgresql(self, compiler, connection):
        # A KeyTransform on the right is folded into the left-hand side so
        # that only its last key remains as the right-hand value.
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %s) IS NOT NULL"
        )
270
271
class HasKey(HasKeyLookup):
    """``has_key`` lookup: a single key is present."""

    lookup_name = "has_key"
    postgres_operator = "?"
    prepare_rhs = False
276
277
class HasKeys(HasKeyLookup):
    """``has_keys`` lookup: per-key conditions are joined with AND."""

    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self):
        """Return the right-hand keys converted to strings."""
        return list(map(str, self.rhs))
285
286
class HasAnyKeys(HasKeys):
    """``has_any_keys`` lookup: per-key conditions are joined with OR."""

    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "
291
292
class HasKeyOrArrayIndex(HasKey):
    """HasKey variant whose final key may compile to an array index."""

    def compile_json_path_final_key(self, key_transform):
        # Delegate to compile_json_path() so integer-like keys become [n].
        final_key = (key_transform,)
        return compile_json_path(final_key, include_root=False)
296
297
class CaseInsensitiveMixin:
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.
    MySQL handles strings used in JSON context using the utf8mb4_bin collation.
    Because utf8mb4_bin is a binary collation, comparison of JSON values is
    case-sensitive.
    """

    def process_lhs(self, compiler, connection):
        """Wrap the left-hand SQL in LOWER() on MySQL only."""
        sql, params = super().process_lhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return "LOWER(%s)" % sql, params

    def process_rhs(self, compiler, connection):
        """Wrap the right-hand SQL in LOWER() on MySQL only."""
        sql, params = super().process_rhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return "LOWER(%s)" % sql, params
317
318
class JSONExact(lookups.Exact):
    """``exact`` lookup for JSONField, where a None value means JSON null."""

    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        """
        Return the right-hand SQL and params, mapping None to JSON ``null``.

        On MySQL every placeholder is wrapped in ``JSON_EXTRACT(%s, '$')``.
        """
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Treat None lookup values as null. Compare as a list so the check
        # matches whether the params arrive as a list or as a tuple.
        if rhs == "%s" and list(rhs_params) == [None]:
            rhs_params = ["null"]
        if connection.vendor == "mysql":
            func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
            rhs %= tuple(func)
        return rhs, rhs_params

    def as_oracle(self, compiler, connection):
        """Compare both sides with JSON_EQUAL."""
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        if connection.features.supports_primitives_in_json_field:
            lhs = f"JSON({lhs})"
            rhs = f"JSON({rhs})"
        return f"JSON_EQUAL({lhs}, {rhs} ERROR ON ERROR)", (*lhs_params, *rhs_params)
339
340
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    """``icontains`` lookup for JSONField with MySQL case folding."""
343
344
# Register the JSON-specific lookups on JSONField, in their original order.
for _lookup_class in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    JSONField.register_lookup(_lookup_class)
# Keep the module namespace free of the loop variable.
del _lookup_class
352
353
class KeyTransform(Transform):
    """Transform that extracts the value at a key or index of a JSON value."""

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keys are always stored as strings, even when given as integers.
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        """
        Collapse a chain of nested KeyTransforms.

        Return ``(lhs_sql, params, keys)`` where ``lhs_sql``/``params`` come
        from compiling the innermost non-KeyTransform expression and ``keys``
        is the list of key names ordered from outermost JSON level to this one.
        """
        keys = [self.key_name]
        node = self.lhs
        while isinstance(node, KeyTransform):
            keys.append(node.key_name)
            node = node.lhs
        # Keys were gathered innermost-last; flip them into path order.
        keys.reverse()
        lhs_sql, params = compiler.compile(node)
        if connection.vendor == "oracle":
            # Escape string-formatting.
            keys = [key.replace("%", "%%") for key in keys]
        return lhs_sql, params, keys

    def as_mysql(self, compiler, connection):
        lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(keys)
        sql = "JSON_EXTRACT(%s, %%s)" % lhs_sql
        return sql, (*params, path)

    def as_oracle(self, compiler, connection):
        lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(keys)
        if connection.features.supports_primitives_in_json_field:
            template = (
                "COALESCE("
                "JSON_VALUE(%s, q'\uffff%s\uffff'),"
                "JSON_QUERY(%s, q'\uffff%s\uffff' DISALLOW SCALARS)"
                ")"
            )
        else:
            template = (
                "COALESCE("
                "JSON_QUERY(%s, q'\uffff%s\uffff'),"
                "JSON_VALUE(%s, q'\uffff%s\uffff')"
                ")"
            )
        # Add paths directly into SQL because path expressions cannot be passed
        # as bind variables on Oracle. Use a custom delimiter to prevent the
        # JSON path from escaping the SQL literal. Each key in the JSON path is
        # passed through json.dumps() with ensure_ascii=True (the default),
        # which converts the delimiter into the escaped \uffff format. This
        # ensures that the delimiter is not present in the JSON path.
        sql = template % (lhs_sql, path, lhs_sql, path)
        # The left-hand SQL appears twice, so its params are repeated.
        return sql, (*params, *params)

    def as_postgresql(self, compiler, connection):
        lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
        template = "(%s %s %%s)"
        if len(keys) > 1:
            # Nested access passes the whole key list as a single parameter.
            sql = template % (lhs_sql, self.postgres_nested_operator)
            return sql, (*params, keys)
        try:
            key = int(self.key_name)
        except ValueError:
            key = self.key_name
        return template % (lhs_sql, self.postgres_operator), (*params, key)

    def as_sqlite(self, compiler, connection):
        lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(keys)
        datatype_values = ",".join(
            map(repr, connection.ops.jsonfield_datatype_values)
        )
        sql = (
            "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
            "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
        ) % (lhs_sql, datatype_values, lhs_sql, lhs_sql)
        # One (params..., path) group per occurrence of the left-hand SQL.
        group = (*params, path)
        return sql, group * 3
425
426
class KeyTextTransform(KeyTransform):
    """KeyTransform variant whose output field is a TextField."""

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(self, compiler, connection):
        if not connection.mysql_is_mariadb:
            lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
            path = compile_json_path(keys)
            return "(%s ->> %%s)" % lhs_sql, (*params, path)
        # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
        sql, params = super().as_mysql(compiler, connection)
        return "JSON_UNQUOTE(%s)" % sql, params

    @classmethod
    def from_lookup(cls, lookup):
        """
        Build nested transforms from a LOOKUP_SEP-separated lookup string.

        The first component is the innermost expression; each following
        component wraps the previous result. Raise ValueError when the string
        has no component after the first.
        """
        expression, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        for key in keys:
            expression = cls(key, expression)
        return expression
450
451
# Short alias for KeyTextTransform.from_lookup.
KT = KeyTextTransform.from_lookup
453
454
class KeyTransformTextLookupMixin:
    """
    Mixin for combining with a lookup expecting a text lhs from a JSONField
    key lookup. On PostgreSQL, make use of the ->> operator instead of casting
    key values to text and performing the lookup on the resulting
    representation.
    """

    def __init__(self, key_transform, *args, **kwargs):
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        # Rebuild the transform as its text-returning counterpart, reusing
        # the same key, source expressions and extra options.
        key_name = key_transform.key_name
        sources = key_transform.source_expressions
        text_lhs = KeyTextTransform(key_name, *sources, **key_transform.extra)
        super().__init__(text_lhs, *args, **kwargs)
475
476
class KeyTransformIsNull(lookups.IsNull):
    """``isnull`` lookup on a JSON key, based on key presence."""

    # key__isnull=False is the same as has_key='key'
    def as_oracle(self, compiler, connection):
        has_key = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        has_key_sql, has_key_params = has_key.as_oracle(compiler, connection)
        if not self.rhs:
            return has_key_sql, has_key_params
        # Column doesn't have a key or IS NULL.
        column_sql, column_params, _keys = self.lhs.preprocess_lhs(
            compiler, connection
        )
        sql = "(NOT %s OR %s IS NULL)" % (has_key_sql, column_sql)
        return sql, (*has_key_params, *column_params)

    def as_sqlite(self, compiler, connection):
        template = (
            "JSON_TYPE(%s, %s) IS NULL"
            if self.rhs
            else "JSON_TYPE(%s, %s) IS NOT NULL"
        )
        has_key = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        return has_key.as_sql(compiler, connection, template=template)
499
500
class KeyTransformIn(lookups.In):
    """``in`` lookup on a JSON key with per-backend value wrapping."""

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """Wrap a plain parameter's placeholder in vendor JSON functions."""
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        vendor = connection.vendor
        # Only plain values on backends without a native JSON type need help.
        needs_wrapping = not (
            hasattr(param, "as_sql") or connection.features.has_native_json_field
        )
        if needs_wrapping:
            if vendor == "oracle":
                is_container = isinstance(json.loads(param), (list, dict))
                function = "JSON_QUERY" if is_container else "JSON_VALUE"
                sql = (
                    "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
                    % function
                )
            elif vendor == "mysql" or (
                vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            ):
                sql = "JSON_EXTRACT(%s, '$')"
        if vendor == "mysql" and connection.mysql_is_mariadb:
            sql = "JSON_UNQUOTE(%s)" % sql
        return sql, params
528
529
class KeyTransformExact(JSONExact):
    """``exact`` lookup on a JSON key with per-backend value wrapping."""

    def process_rhs(self, compiler, connection):
        """
        Return the right-hand SQL and params for the key comparison.

        A KeyTransform on the right bypasses the Exact/JSONExact processing.
        Otherwise, on Oracle and SQLite each placeholder is replaced by a
        JSON function call chosen from the decoded parameter value.
        """
        if isinstance(self.rhs, KeyTransform):
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "oracle":
            func = []
            sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
            for value in rhs_params:
                value = json.loads(value)
                if isinstance(value, (list, dict)):
                    func.append(sql % "JSON_QUERY")
                else:
                    func.append(sql % "JSON_VALUE")
            rhs %= tuple(func)
        elif connection.vendor == "sqlite":
            func = []
            for value in rhs_params:
                if value in connection.ops.jsonfield_datatype_values:
                    func.append("%s")
                else:
                    func.append("JSON_EXTRACT(%s, '$')")
            rhs %= tuple(func)
        return rhs, rhs_params

    def as_oracle(self, compiler, connection):
        """Render the comparison, special-casing a JSON ``null`` value."""
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Compare as a list so the check matches whether the params arrive
        # as a list or as a tuple.
        if list(rhs_params) == ["null"]:
            # Field has key and it's NULL.
            has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
            has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
            is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
            is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
            return (
                "%s AND %s" % (has_key_sql, is_null_sql),
                tuple(has_key_params) + tuple(is_null_params),
            )
        return super().as_sql(compiler, connection)
568
569
class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    """``iexact`` lookup on the text form of a JSON key."""
574
575
class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    """``icontains`` lookup on the text form of a JSON key."""
580
581
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``startswith`` lookup on the text form of a JSON key."""
584
585
class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    """``istartswith`` lookup on the text form of a JSON key."""
590
591
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``endswith`` lookup on the text form of a JSON key."""
594
595
class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    """``iendswith`` lookup on the text form of a JSON key."""
600
601
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """``regex`` lookup on the text form of a JSON key."""
604
605
class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    """``iregex`` lookup on the text form of a JSON key."""
610
611
class KeyTransformNumericLookupMixin:
    """Mixin that JSON-decodes rhs params on non-native JSON backends."""

    def process_rhs(self, compiler, connection):
        """Return the rhs SQL with params decoded by ``json.loads()``."""
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.features.has_native_json_field:
            return rhs, rhs_params
        return rhs, list(map(json.loads, rhs_params))
618
619
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """``lt`` lookup on a JSON key."""
622
623
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """``lte`` lookup on a JSON key."""
626
627
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """``gt`` lookup on a JSON key."""
630
631
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """``gte`` lookup on a JSON key."""
634
635
# Register the key-level lookups on KeyTransform, in their original order:
# membership/equality/text lookups first, then the numeric comparisons.
for _lookup_class in (
    KeyTransformIn,
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_lookup_class)
# Keep the module namespace free of the loop variable.
del _lookup_class
652
653
class KeyTransformFactory:
    """Callable that builds a KeyTransform for a fixed key name."""

    def __init__(self, key_name):
        # Stored as given; KeyTransform is only built when called.
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        """Return ``KeyTransform(key_name, *args, **kwargs)``."""
        transform = KeyTransform(self.key_name, *args, **kwargs)
        return transform