Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/django/db/models/fields/json.py: 40%

384 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 06:13 +0000

1import json 

2import warnings 

3 

4from django import forms 

5from django.core import checks, exceptions 

6from django.db import NotSupportedError, connections, router 

7from django.db.models import expressions, lookups 

8from django.db.models.constants import LOOKUP_SEP 

9from django.db.models.fields import TextField 

10from django.db.models.lookups import ( 

11 FieldGetDbPrepValueMixin, 

12 PostgresOperatorLookup, 

13 Transform, 

14) 

15from django.utils.deprecation import RemovedInDjango51Warning 

16from django.utils.translation import gettext_lazy as _ 

17 

18from . import Field 

19from .mixins import CheckFieldDefaultMixin 

20 

# Star-imports export only the model field; the lookups, transforms and the
# KT helper defined below must be imported explicitly by name.
__all__ = ["JSONField"]

22 

23 

class JSONField(CheckFieldDefaultMixin, Field):
    """
    Model field storing any JSON-serializable Python value.

    Values are serialized through ``connection.ops.adapt_json_value`` (with
    the optional ``encoder``) when sent to the database, and decoded with
    ``json.loads`` (with the optional ``decoder``) when read back.
    """

    empty_strings_allowed = False
    description = _("A JSON object")
    default_error_messages = {
        "invalid": _("Value must be valid JSON."),
    }
    # Presumably consumed by CheckFieldDefaultMixin to suggest a callable
    # default (``dict``) instead of a shared mutable ``{}`` — TODO confirm.
    _default_hint = ("dict", "{}")

    def __init__(
        self,
        verbose_name=None,
        name=None,
        encoder=None,
        decoder=None,
        **kwargs,
    ):
        """
        Create the field.

        ``encoder`` is passed as ``cls`` to ``json.dumps`` (and to the
        backend's ``adapt_json_value``); ``decoder`` is passed as ``cls`` to
        ``json.loads``. Raise ValueError if either is given but not callable.
        """
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        """Run the base field checks plus the per-database support check."""
        errors = super().check(**kwargs)
        databases = kwargs.get("databases") or []
        errors.extend(self._check_supported(databases))
        return errors

    def _check_supported(self, databases):
        """
        Return a fields.E180 error for every database in ``databases`` that
        this model migrates to but whose backend lacks JSONField support.
        """
        errors = []
        for db in databases:
            if not router.allow_migrate_model(db, self.model):
                continue
            connection = connections[db]
            # Models pinned to a different vendor never use this connection.
            if (
                self.model._meta.required_db_vendor
                and self.model._meta.required_db_vendor != connection.vendor
            ):
                continue
            # A model that declares "supports_json_field" as a required
            # feature is exempt from this error.
            if not (
                "supports_json_field" in self.model._meta.required_db_features
                or connection.features.supports_json_field
            ):
                errors.append(
                    checks.Error(
                        "%s does not support JSONFields." % connection.display_name,
                        obj=self.model,
                        id="fields.E180",
                    )
                )
        return errors

    def deconstruct(self):
        """Add non-default ``encoder``/``decoder`` to the deconstructed kwargs."""
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        """
        Decode a database value with ``json.loads``.

        None is returned as-is. If decoding fails, the raw value is returned.
        """
        if value is None:
            return value
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes.
        if isinstance(expression, KeyTransform) and not isinstance(value, str):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            return value

    def get_internal_type(self):
        return "JSONField"

    def get_db_prep_value(self, value, connection, prepared=False):
        """
        Convert ``value`` into the backend's JSON parameter representation.

        Unless ``prepared`` is true, ``get_prep_value()`` runs first, so that
        subclasses overriding it take effect. A ``Value()`` typed as JSONField
        is unwrapped. Other expressions are returned unchanged so they can be
        compiled as SQL.
        """
        if not prepared:
            value = self.get_prep_value(value)
        # RemovedInDjango51Warning: When the deprecation ends, replace with:
        # if (
        #     isinstance(value, expressions.Value)
        #     and isinstance(value.output_field, JSONField)
        # ):
        #     value = value.value
        # elif hasattr(value, "as_sql"): ...
        if isinstance(value, expressions.Value):
            if isinstance(value.value, str) and not isinstance(
                value.output_field, JSONField
            ):
                # Legacy: a str in an untyped Value() was treated as already
                # encoded JSON.
                try:
                    value = json.loads(value.value, cls=self.decoder)
                except json.JSONDecodeError:
                    value = value.value
                else:
                    warnings.warn(
                        "Providing an encoded JSON string via Value() is deprecated. "
                        f"Use Value({value!r}, output_field=JSONField()) instead.",
                        category=RemovedInDjango51Warning,
                    )
            elif isinstance(value.output_field, JSONField):
                value = value.value
            else:
                return value
        elif hasattr(value, "as_sql"):
            return value
        return connection.ops.adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value, connection):
        """Store Python None as SQL NULL; prepare everything else as JSON."""
        if value is None:
            return value
        return self.get_db_prep_value(value, connection)

    def get_transform(self, name):
        """
        Return a registered transform for ``name``. Any other name is treated
        as a JSON key.
        """
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name)

    def validate(self, value, model_instance):
        """
        Run base validation, then check that ``value`` serializes with the
        configured encoder.
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except (TypeError, ValueError):
            # TypeError: an object the encoder cannot serialize.
            # ValueError: e.g. a circular reference, which json.dumps reports
            # as ValueError. Report both as validation errors instead of
            # letting the ValueError escape as a crash.
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            )

    def value_to_string(self, obj):
        return self.value_from_object(obj)

    def formfield(self, **kwargs):
        """Use forms.JSONField, forwarding encoder/decoder unless overridden."""
        return super().formfield(
            **{
                "form_class": forms.JSONField,
                "encoder": self.encoder,
                "decoder": self.decoder,
                **kwargs,
            }
        )

165 

166 

def compile_json_path(key_transforms, include_root=True):
    """
    Build a JSON path string such as ``$."a"[0]."b"`` from a sequence of keys.

    Keys that ``int()`` accepts become array subscripts (``[n]``). Every other
    key becomes a member access whose name is quoted and escaped by
    ``json.dumps``. The leading ``$`` root is emitted only when
    ``include_root`` is true.
    """
    segments = ["$"] if include_root else []
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:  # not an array index
            segments.append(f".{json.dumps(key)}")
        else:
            segments.append(f"[{index}]")
    return "".join(segments)

178 

179 

class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """
    ``contains`` lookup: the left JSON document contains the right one.

    PostgreSQL uses the ``@>`` operator. Other backends that report
    ``supports_json_field_contains`` compile to ``JSON_CONTAINS(lhs, rhs)``.
    """

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(self, compiler, connection):
        supported = connection.features.supports_json_field_contains
        if not supported:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        target_sql, target_params = self.process_lhs(compiler, connection)
        candidate_sql, candidate_params = self.process_rhs(compiler, connection)
        sql = "JSON_CONTAINS(%s, %s)" % (target_sql, candidate_sql)
        return sql, (*target_params, *candidate_params)

193 

194 

class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """
    ``contained_by`` lookup: the left JSON document is contained in the right.

    PostgreSQL uses ``<@``. Elsewhere this is ``contains`` with the operands
    swapped: ``JSON_CONTAINS(rhs, lhs)``, with the rhs params bound first.
    """

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(self, compiler, connection):
        supported = connection.features.supports_json_field_contains
        if not supported:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        inner_sql, inner_params = self.process_lhs(compiler, connection)
        outer_sql, outer_params = self.process_rhs(compiler, connection)
        sql = "JSON_CONTAINS(%s, %s)" % (outer_sql, inner_sql)
        return sql, (*outer_params, *inner_params)

208 

209 

class HasKeyLookup(PostgresOperatorLookup):
    """
    Shared implementation of the has_key / has_keys / has_any_keys lookups.

    On non-PostgreSQL backends, ``as_sql()`` builds the SQL from a
    vendor-specific ``template``. The template is %-formatted once with the
    compiled left-hand side. Each requested key then contributes one JSON
    path parameter: the lhs path prefix plus the key's own path.
    """

    # Operator used to join the per-key conditions (" AND " / " OR " in
    # subclasses). None leaves the single template condition unjoined.
    logical_operator = None

    def compile_json_path_final_key(self, key_transform):
        # Compile the final key without interpreting ints as array elements.
        return ".%s" % json.dumps(key_transform)

    def as_sql(self, compiler, connection, template=None):
        """
        Return ``(sql, params)`` testing each key in ``self.rhs``.

        ``template`` must contain a ``%s`` for the compiled left-hand side and
        an escaped ``%%s`` for the JSON path parameter. Every vendor method
        in this class supplies one.
        """
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            # Nested key transforms on the lhs become a JSON path prefix
            # instead of separate SQL extraction calls.
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        sql = template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params = []
        if not isinstance(rhs, (list, tuple)):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        # Add condition for each key.
        if self.logical_operator:
            sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
        return sql, tuple(lhs_params) + tuple(rhs_params)

    def as_mysql(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_oracle(self, compiler, connection):
        # Oracle cannot take path expressions as bind variables, so the paths
        # are inlined into the SQL text below. A plain '...' literal can be
        # escaped by a key that contains a single quote (json.dumps does not
        # escape "'"). Oracle's alternative quoting q'<d>...<d>' is used
        # instead, with U+FFFF as the delimiter <d>. Every non-constant part
        # of the path comes from json.dumps() (default ensure_ascii=True, which
        # writes U+FFFF as the escape "\uffff") or from int(), so the delimiter
        # cannot occur inside the path.
        template = "JSON_EXISTS(%s, q'\uffff%%s\uffff')"
        sql, params = self.as_sql(compiler, connection, template=template)
        return sql % tuple(params), []

    def as_postgresql(self, compiler, connection):
        # A KeyTransform rhs is split. All keys but the last are pushed onto
        # the lhs as KeyTransforms; the last key becomes a plain rhs for the
        # PostgreSQL key operator. Note: this mutates self.lhs / self.rhs.
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )

272 

273 

class HasKey(HasKeyLookup):
    """``has_key``: match rows whose JSON document contains the given key."""

    postgres_operator = "?"
    lookup_name = "has_key"
    # The key is used as given rather than being prepared through the field.
    prepare_rhs = False

278 

279 

class HasKeys(HasKeyLookup):
    """``has_keys``: match rows containing every one of the given keys."""

    postgres_operator = "?&"
    lookup_name = "has_keys"
    logical_operator = " AND "

    def get_prep_lookup(self):
        # Keys are always compared as strings, whatever type was passed in.
        return list(map(str, self.rhs))

287 

288 

class HasAnyKeys(HasKeys):
    """``has_any_keys``: match rows containing at least one of the given keys."""

    postgres_operator = "?|"
    lookup_name = "has_any_keys"
    logical_operator = " OR "

293 

294 

class HasKeyOrArrayIndex(HasKey):
    """
    has_key variant whose final key may address an array element: an
    integer-like key compiles to ``[n]`` instead of a quoted member name.
    """

    def compile_json_path_final_key(self, key_transform):
        path = compile_json_path([key_transform], include_root=False)
        return path

297 return compile_json_path([key_transform], include_root=False) 

298 

299 

class CaseInsensitiveMixin:
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.

    MySQL compares strings in JSON context with the binary utf8mb4_bin
    collation, which makes JSON comparisons case-sensitive. On MySQL both
    sides are therefore wrapped in LOWER(); other vendors are left untouched.
    """

    def process_lhs(self, compiler, connection):
        sql, params = super().process_lhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return f"LOWER({sql})", params

    def process_rhs(self, compiler, connection):
        sql, params = super().process_rhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return f"LOWER({sql})", params

319 

320 

class JSONExact(lookups.Exact):
    """
    ``exact`` lookup on a JSONField. A Python None on the right-hand side
    matches the JSON ``null`` literal rather than SQL NULL.
    """

    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        sql, params = super().process_rhs(compiler, connection)
        # A lone None placeholder is rewritten to JSON null.
        is_bare_none = sql == "%s" and params == [None]
        if is_bare_none:
            params = ["null"]
        if connection.vendor == "mysql":
            # Parse each bound value as JSON so it compares as a JSON value.
            sql %= tuple("JSON_EXTRACT(%s, '$')" for _ in params)
        return sql, params

333 

334 

class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    """``icontains`` on a JSONField; lower-cases both sides on MySQL."""

337 

338 

# Lookups available directly on a JSONField, registered in this order.
for _lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    JSONField.register_lookup(_lookup)
del _lookup

346 

347 

class KeyTransform(Transform):
    """
    Extract the value stored under ``key_name`` from a JSON expression.

    Chained KeyTransforms (e.g. ``field__a__b``) are collapsed by
    ``preprocess_lhs()`` into one list of keys, which each backend compiles
    into a single path extraction.
    """

    # PostgreSQL: "->" with a single key, "#>" with a key-path array.
    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keys are normalized to str. compile_json_path() later treats
        # integer-like strings as array indexes.
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        """
        Return ``(lhs_sql, lhs_params, key_transforms)``.

        ``lhs_sql`` compiles the innermost non-KeyTransform expression.
        ``key_transforms`` lists the chained key names, root-most first,
        ending with this transform's own key.
        """
        key_transforms = [self.key_name]
        previous = self.lhs
        while isinstance(previous, KeyTransform):
            key_transforms.insert(0, previous.key_name)
            previous = previous.lhs
        lhs, params = compiler.compile(previous)
        if connection.vendor == "oracle":
            # Escape string-formatting. Oracle paths are inlined into the SQL
            # text, which presumably goes through another %-formatting pass.
            key_transforms = [key.replace("%", "%%") for key in key_transforms]
        return lhs, params, key_transforms

    def as_mysql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)

    def as_oracle(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        # Oracle cannot take path expressions as bind variables, so the path
        # is inlined into the SQL text. A plain '...' literal could be escaped
        # by a key containing a single quote, so Oracle's alternative quoting
        # q'<d>...<d>' is used with U+FFFF as the delimiter. compile_json_path()
        # builds keys with json.dumps() (ensure_ascii=True writes U+FFFF as
        # the escape "\uffff") or int(), so the delimiter cannot appear in
        # json_path.
        sql = (
            "COALESCE(JSON_QUERY(%s, q'\uffff%s\uffff'), "
            "JSON_VALUE(%s, q'\uffff%s\uffff'))"
        )
        return sql % ((lhs, json_path) * 2), tuple(params) * 2

    def as_postgresql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        if len(key_transforms) > 1:
            # Whole key path bound as one (array) parameter for "#>".
            sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
            return sql, tuple(params) + (key_transforms,)
        # An integer-like key is bound as an int so "->" indexes an array.
        try:
            lookup = int(self.key_name)
        except ValueError:
            lookup = self.key_name
        return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)

    def as_sqlite(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        # When JSON_TYPE() reports one of jsonfield_datatype_values, return
        # that type name; otherwise return JSON_EXTRACT(). Presumably this
        # keeps JSON true/false/null distinguishable from SQL values — TODO
        # confirm against connection.ops.
        datatype_values = ",".join(
            [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
        )
        return (
            "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
            "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
        ) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3

402 

403 

class KeyTextTransform(KeyTransform):
    """
    KeyTransform that yields the extracted value as text (``->>`` / ``#>>``
    on PostgreSQL) with a TextField output.
    """

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(self, compiler, connection):
        if not connection.mysql_is_mariadb:
            lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
            path = compile_json_path(key_transforms)
            return "(%s ->> %%s)" % lhs, (*params, path)
        # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
        extracted_sql, extracted_params = super().as_mysql(compiler, connection)
        return "JSON_UNQUOTE(%s)" % extracted_sql, extracted_params

    @classmethod
    def from_lookup(cls, lookup):
        """
        Build a chain of transforms from a LOOKUP_SEP-separated string. The
        first segment is the base expression; each remaining segment is a key.
        """
        expression, *key_names = lookup.split(LOOKUP_SEP)
        if not key_names:
            raise ValueError("Lookup must contain key or index transforms.")
        for key_name in key_names:
            expression = cls(key_name, expression)
        return expression

427 

428 

# Shortcut that builds a KeyTextTransform chain from a LOOKUP_SEP-separated
# lookup string, e.g. KT("data__owner__name").
KT = KeyTextTransform.from_lookup

430 

431 

class KeyTransformTextLookupMixin:
    """
    Mixin for lookups that expect a text left-hand side from a JSONField key
    lookup.

    The incoming KeyTransform is replaced by an equivalent KeyTextTransform
    (same key, source expressions and extra). On PostgreSQL the value is then
    extracted as text with ``->>``, instead of casting the key value to text
    and running the lookup on that representation.
    """

    def __init__(self, key_transform, *args, **kwargs):
        if isinstance(key_transform, KeyTransform):
            as_text = KeyTextTransform(
                key_transform.key_name,
                *key_transform.source_expressions,
                **key_transform.extra,
            )
            super().__init__(as_text, *args, **kwargs)
        else:
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )

452 

453 

class KeyTransformIsNull(lookups.IsNull):
    """
    ``isnull`` on a JSON key. ``key__isnull=False`` is equivalent to
    ``has_key='key'``. ``key__isnull=True`` means the key is missing, or (on
    Oracle) the extracted value is NULL.
    """

    def as_oracle(self, compiler, connection):
        has_key = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        has_key_sql, has_key_params = has_key.as_oracle(compiler, connection)
        if self.rhs:
            # Column doesn't have a key or IS NULL.
            lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
            sql = "(NOT %s OR %s IS NULL)" % (has_key_sql, lhs)
            return sql, (*has_key_params, *lhs_params)
        return has_key_sql, has_key_params

    def as_sqlite(self, compiler, connection):
        template = (
            "JSON_TYPE(%s, %%s) IS NULL"
            if self.rhs
            else "JSON_TYPE(%s, %%s) IS NOT NULL"
        )
        has_key = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        return has_key.as_sql(compiler, connection, template=template)

476 

477 

class KeyTransformIn(lookups.In):
    """``__in`` lookup applied to a JSON key transform."""

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """
        Compile one IN-list element.

        On backends without a native JSON column type, plain (non-expression)
        values are wrapped so they compare as JSON values. On MariaDB every
        element is additionally wrapped in JSON_UNQUOTE().
        """
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        if (
            not hasattr(param, "as_sql")
            and not connection.features.has_native_json_field
        ):
            if connection.vendor == "oracle":
                # param is assumed to be JSON text here. Objects/arrays need
                # JSON_QUERY; scalars need JSON_VALUE.
                value = json.loads(param)
                sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
                if isinstance(value, (list, dict)):
                    sql %= "JSON_QUERY"
                else:
                    sql %= "JSON_VALUE"
            elif connection.vendor == "mysql" or (
                connection.vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            ):
                # Parse the bound value as JSON. On SQLite, values listed in
                # jsonfield_datatype_values are left bound as-is.
                sql = "JSON_EXTRACT(%s, '$')"
        if connection.vendor == "mysql" and connection.mysql_is_mariadb:
            sql = "JSON_UNQUOTE(%s)" % sql
        return sql, params

505 

506 

class KeyTransformExact(JSONExact):
    """``exact`` lookup applied to a JSON key transform."""

    def process_rhs(self, compiler, connection):
        """
        Compile the right-hand side. On Oracle and SQLite, each param is
        wrapped so it compares as a JSON value against the extracted key.
        """
        if isinstance(self.rhs, KeyTransform):
            # Key-to-key comparison: skip JSONExact's and Exact's rhs
            # handling and use the next implementation after lookups.Exact in
            # the MRO.
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "oracle":
            func = []
            sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
            for value in rhs_params:
                # Params are assumed to be JSON text here. Objects/arrays need
                # JSON_QUERY; scalars need JSON_VALUE.
                value = json.loads(value)
                if isinstance(value, (list, dict)):
                    func.append(sql % "JSON_QUERY")
                else:
                    func.append(sql % "JSON_VALUE")
            rhs %= tuple(func)
        elif connection.vendor == "sqlite":
            func = []
            for value in rhs_params:
                # Values listed in jsonfield_datatype_values are bound as-is;
                # anything else is parsed with JSON_EXTRACT(..., '$').
                if value in connection.ops.jsonfield_datatype_values:
                    func.append("%s")
                else:
                    func.append("JSON_EXTRACT(%s, '$')")
            rhs %= tuple(func)
        return rhs, rhs_params

    def as_oracle(self, compiler, connection):
        # JSONExact.process_rhs (not this class's override) is used only to
        # detect a None / JSON-null right-hand side.
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if rhs_params == ["null"]:
            # Field has key and it's NULL.
            has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
            has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
            is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
            is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
            return (
                "%s AND %s" % (has_key_sql, is_null_sql),
                tuple(has_key_params) + tuple(is_null_params),
            )
        return super().as_sql(compiler, connection)

545 

546 

class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    """Case-insensitive ``iexact`` on a JSON key's text value."""

551 

552 

class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    """Case-insensitive ``icontains`` on a JSON key's text value."""

557 

558 

class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``startswith`` on a JSON key's text value."""

561 

562 

class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    """Case-insensitive ``istartswith`` on a JSON key's text value."""

567 

568 

class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``endswith`` on a JSON key's text value."""

571 

572 

class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    """Case-insensitive ``iendswith`` on a JSON key's text value."""

577 

578 

class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """``regex`` on a JSON key's text value."""

581 

582 

class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    """Case-insensitive ``iregex`` on a JSON key's text value."""

587 

588 

class KeyTransformNumericLookupMixin:
    """
    Mixin for the lt/lte/gt/gte lookups on JSON key transforms.

    On backends without a native JSON column type, each right-hand side param
    is decoded with ``json.loads``. This assumes the params arrive
    JSON-encoded from the parent ``process_rhs()``.
    """

    def process_rhs(self, compiler, connection):
        sql, params = super().process_rhs(compiler, connection)
        if connection.features.has_native_json_field:
            return sql, params
        return sql, list(map(json.loads, params))

595 

596 

class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """``lt`` comparison on a JSON key's value."""

599 

600 

class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """``lte`` comparison on a JSON key's value."""

603 

604 

class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """``gt`` comparison on a JSON key's value."""

607 

608 

class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """``gte`` comparison on a JSON key's value."""

611 

612 

# Lookups available on a JSON key transform, registered in this order:
# membership/equality and text lookups first, then numeric comparisons.
for _key_lookup in (
    KeyTransformIn,
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_key_lookup)
del _key_lookup

629 

630 

class KeyTransformFactory:
    """
    Callable bound to one key name. Calling it builds a KeyTransform for that
    key, passing all call arguments through to KeyTransform.
    """

    def __init__(self, key_name):
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        key = self.key_name
        return KeyTransform(key, *args, **kwargs)