Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/django/db/models/constraints.py: 20%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

313 statements  

1from enum import Enum 

2from types import NoneType 

3 

4from django.core import checks 

5from django.core.exceptions import FieldDoesNotExist, FieldError, ValidationError 

6from django.db import connections 

7from django.db.models.constants import LOOKUP_SEP 

8from django.db.models.expressions import Exists, ExpressionList, F, RawSQL 

9from django.db.models.indexes import IndexExpression 

10from django.db.models.lookups import Exact, IsNull 

11from django.db.models.query_utils import Q 

12from django.db.models.sql.query import Query 

13from django.db.utils import DEFAULT_DB_ALIAS 

14from django.utils.translation import gettext_lazy as _ 

15 

# Names exported by a star-import of this module.
__all__ = [
    "BaseConstraint",
    "CheckConstraint",
    "Deferrable",
    "UniqueConstraint",
]

17 

18 

class BaseConstraint:
    """Common behavior shared by all model constraints.

    Stores the constraint ``name`` plus the optional violation error code and
    message, and provides system-check, deconstruction and cloning helpers.
    Subclasses must provide ``constraint_sql()``, ``create_sql()``,
    ``remove_sql()`` and ``validate()``; the versions here raise
    ``NotImplementedError``.
    """

    # Message used when no custom violation_error_message is supplied; the
    # "%(name)s" placeholder is filled by get_violation_error_message().
    default_violation_error_message = _("Constraint “%(name)s” is violated.")
    # Class-level fallbacks. __init__ overrides violation_error_code only when
    # one is given, and always sets violation_error_message on the instance.
    violation_error_code = None
    violation_error_message = None

    # Attribute names that are not part of the database-level definition.
    # NOTE(review): not referenced in this module; presumably consulted by
    # schema/migration code elsewhere — confirm before relying on it.
    non_db_attrs = ("violation_error_code", "violation_error_message")

    def __init__(
        self, *, name, violation_error_code=None, violation_error_message=None
    ):
        self.name = name
        if violation_error_code is not None:
            self.violation_error_code = violation_error_code
        # Fall back to the class's default message when none is given.
        self.violation_error_message = (
            self.default_violation_error_message
            if violation_error_message is None
            else violation_error_message
        )

    @property
    def contains_expressions(self):
        """Whether the constraint is defined by expressions (always False here)."""
        return False

    def constraint_sql(self, model, schema_editor):
        """Return SQL for this constraint; must be overridden by subclasses."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def create_sql(self, model, schema_editor):
        """Return SQL that adds this constraint; must be overridden."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def remove_sql(self, model, schema_editor):
        """Return SQL that drops this constraint; must be overridden."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    @classmethod
    def _expression_refs_exclude(cls, model, expression, exclude):
        """Return True if ``expression`` references any field in ``exclude``.

        A reference to a generated field is followed recursively into that
        field's own ``expression``, so indirect references count too.
        """
        resolve = model._meta.get_field
        for ref_name, *_ignored in model._get_expr_references(expression):
            if ref_name in exclude:
                return True
            ref_field = resolve(ref_name)
            if not ref_field.generated:
                continue
            if cls._expression_refs_exclude(model, ref_field.expression, exclude):
                return True
        return False

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """Validate ``instance`` against this constraint; must be overridden."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def get_violation_error_message(self):
        """Return the violation message with ``%(name)s`` interpolated."""
        return self.violation_error_message % {"name": self.name}

    def _check(self, model, connection):
        """Return system-check messages for this constraint (none by default)."""
        return []

    def _check_references(self, model, references):
        """Return check errors for the field references used by a constraint.

        ``references`` yields tuples whose first item is a field name and
        whose remaining items are lookups. Two kinds of error are produced:
        - models.E041 for a reference whose first lookup on a forward
          relation is neither a transform nor a lookup of that field, i.e.
          one that would need a JOIN;
        - whatever ``model._check_local_fields()`` reports for the collected
          field names.
        """
        from django.db.models.fields.composite import CompositePrimaryKey

        errors = []
        local_fields = set()
        for field_name, *lookups in references:
            is_pk_alias = field_name == "pk"
            # "pk" is an alias that opts.get_field() can't resolve, unless
            # the primary key is a CompositePrimaryKey.
            if not is_pk_alias or isinstance(model._meta.pk, CompositePrimaryKey):
                local_fields.add(field_name)
            # A bare field name can never introduce a JOIN.
            if not lookups:
                continue
            try:
                field = (
                    model._meta.pk if is_pk_alias else model._meta.get_field(field_name)
                )
                joins = field.is_relation and not (
                    field.many_to_many or field.one_to_many
                )
            except FieldDoesNotExist:
                continue
            if not joins:
                continue
            # Any JOIN would happen at the very first lookup.
            head = lookups[0]
            if (
                hasattr(field, "get_transform")
                and hasattr(field, "get_lookup")
                and field.get_transform(head) is None
                and field.get_lookup(head) is None
            ):
                errors.append(
                    checks.Error(
                        "'constraints' refers to the joined field '%s'."
                        % LOOKUP_SEP.join([field_name, *lookups]),
                        obj=model,
                        id="models.E041",
                    )
                )
        errors.extend(model._check_local_fields(local_fields, "constraints"))
        return errors

    def deconstruct(self):
        """Return ``(path, args, kwargs)`` describing how to rebuild this object.

        The import path is shortened from ``django.db.models.constraints`` to
        ``django.db.models``. Violation message and code are included only
        when they differ from the defaults.
        """
        klass = self.__class__
        path = f"{klass.__module__}.{klass.__name__}".replace(
            "django.db.models.constraints", "django.db.models"
        )
        kwargs = {"name": self.name}
        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            kwargs["violation_error_message"] = message
        if self.violation_error_code is not None:
            kwargs["violation_error_code"] = self.violation_error_code
        return path, (), kwargs

    def clone(self):
        """Return a new instance rebuilt from ``deconstruct()``."""
        _path, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)

128 

129 

class CheckConstraint(BaseConstraint):
    """A CHECK constraint defined by a ``Q`` object or boolean expression."""

    def __init__(
        self,
        *,
        condition,
        name,
        violation_error_code=None,
        violation_error_message=None,
    ):
        """Store ``condition`` and the common constraint options.

        Raises TypeError if ``condition`` is not conditional (it lacks a truthy
        ``conditional`` attribute).
        """
        self.condition = condition
        if not getattr(condition, "conditional", False):
            raise TypeError(
                "CheckConstraint.condition must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _check(self, model, connection):
        """Return system-check messages for this constraint on ``connection``.

        - models.W027 if the backend lacks check constraints and the model
          doesn't list them in ``required_db_features``.
        - If the backend supports check constraints: models.W045 when the
          condition contains ``RawSQL``, plus reference errors from
          ``_check_references()``.
        """
        errors = []
        supports_checks = connection.features.supports_table_check_constraints
        if not (
            supports_checks
            or "supports_table_check_constraints" in model._meta.required_db_features
        ):
            errors.append(
                checks.Warning(
                    f"{connection.display_name} does not support check constraints.",
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id="models.W027",
                )
            )
        elif supports_checks:
            # The previous form of this test also accepted the feature being
            # absent from required_db_features. This branch is only reached
            # when the backend supports check constraints or the model
            # requires them, so that extra alternative could never apply.
            # The only remaining case — feature required but unsupported —
            # is deliberately left unchecked.
            references = set()
            condition = self.condition
            if isinstance(condition, Q):
                references.update(model._get_expr_references(condition))
            if any(isinstance(expr, RawSQL) for expr in condition.flatten()):
                errors.append(
                    checks.Warning(
                        f"Check constraint {self.name!r} contains RawSQL() expression "
                        "and won't be validated during the model full_clean().",
                        hint="Silence this warning if you don't care about it.",
                        obj=model,
                        id="models.W045",
                    ),
                )
            errors.extend(self._check_references(model, references))
        return errors

    def _get_check_sql(self, model, schema_editor):
        """Compile ``condition`` to SQL with parameters inlined as literals.

        Parameters are quoted via ``schema_editor.quote_value()`` because DDL
        is emitted as a plain string here.
        """
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def constraint_sql(self, model, schema_editor):
        """Return the CHECK clause SQL built by ``schema_editor._check_sql()``."""
        check = self._get_check_sql(model, schema_editor)
        return schema_editor._check_sql(self.name, check)

    def create_sql(self, model, schema_editor):
        """Return SQL that adds this CHECK constraint to ``model``'s table."""
        check = self._get_check_sql(model, schema_editor)
        return schema_editor._create_check_sql(model, self.name, check)

    def remove_sql(self, model, schema_editor):
        """Return SQL that drops this CHECK constraint."""
        return schema_editor._delete_check_sql(model, self.name)

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """Raise ValidationError if ``instance`` doesn't satisfy ``condition``.

        The condition is evaluated against the instance's field values via
        ``Q.check()`` on database ``using``. A FieldError during evaluation
        skips validation; this is deliberate best-effort behavior.
        NOTE(review): presumably this covers conditions referencing fields
        missing from the expression map (e.g. excluded ones) — confirm.
        """
        against = instance._get_field_expression_map(meta=model._meta, exclude=exclude)
        try:
            if not Q(self.condition).check(against, using=using):
                raise ValidationError(
                    self.get_violation_error_message(), code=self.violation_error_code
                )
        except FieldError:
            pass

    def __repr__(self):
        return "<%s: condition=%s name=%s%s%s>" % (
            self.__class__.__qualname__,
            self.condition,
            repr(self.name),
            (
                ""
                if self.violation_error_code is None
                else " violation_error_code=%r" % self.violation_error_code
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else " violation_error_message=%r" % self.violation_error_message
            ),
        )

    def __eq__(self, other):
        if isinstance(other, CheckConstraint):
            return (
                self.name == other.name
                and self.condition == other.condition
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        """Extend the base deconstruction with the ``condition`` kwarg."""
        path, args, kwargs = super().deconstruct()
        kwargs["condition"] = self.condition
        return path, args, kwargs

249 

250 

class Deferrable(Enum):
    """Deferral modes accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    # Show members as ``Deferrable.DEFERRED`` instead of Enum's default
    # ``<Deferrable.DEFERRED: 'deferred'>`` form.
    def __repr__(self):
        return "%s.%s" % (type(self).__qualname__, self.name)

258 

259 

class UniqueConstraint(BaseConstraint):
    """A unique constraint over model fields or expressions.

    Exactly one of ``fields`` (field names) or positional ``expressions``
    must be given. Optional settings:
    - ``condition``: a ``Q`` restricting the rows the constraint applies to;
    - ``deferrable``: a ``Deferrable`` member;
    - ``include``: extra non-key field names;
    - ``opclasses``: one per entry in ``fields``;
    - ``nulls_distinct``: a bool.
    Invalid combinations are rejected in ``__init__``.
    """

    def __init__(
        self,
        *expressions,
        fields=(),
        name=None,
        condition=None,
        deferrable=None,
        include=None,
        opclasses=(),
        nulls_distinct=None,
        violation_error_code=None,
        violation_error_message=None,
    ):
        """Validate the argument combination and store the options.

        Raises ValueError when:
        - ``name`` is missing;
        - neither or both of fields/expressions are given;
        - ``condition`` isn't a Q;
        - ``deferrable`` is combined with condition, include, opclasses or
          expressions;
        - opclasses are used with expressions;
        - opclasses and fields differ in length.

        Raises TypeError when ``deferrable``, ``include``, ``opclasses`` or
        ``nulls_distinct`` has the wrong type.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, (NoneType, Q)):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use django.contrib.postgres.indexes.OpClass() instead."
            )
        if not isinstance(deferrable, (NoneType, Deferrable)):
            raise TypeError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, (NoneType, list, tuple)):
            raise TypeError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, (list, tuple)):
            raise TypeError("UniqueConstraint.opclasses must be a list or tuple.")
        if not isinstance(nulls_distinct, (NoneType, bool)):
            raise TypeError("UniqueConstraint.nulls_distinct must be a bool.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        # Stored as given (list or tuple); note this affects __repr__/__eq__.
        self.opclasses = opclasses
        self.nulls_distinct = nulls_distinct
        # Bare strings are shorthand for field references.
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self):
        """Whether the constraint is defined by expressions rather than fields."""
        return bool(self.expressions)

    def _check(self, model, connection):
        """Return system-check messages for this constraint on ``connection``.

        Each optional feature in use (condition, deferrable, include,
        expressions, nulls_distinct) yields a warning (W036/W038/W039/W044/
        W047) when the backend lacks it and the model doesn't list it in
        ``required_db_features``. Local-field and reference errors are
        appended as well.
        """
        errors = model._check_local_fields({*self.fields, *self.include}, "constraints")
        required_db_features = model._meta.required_db_features
        if self.condition is not None and not (
            connection.features.supports_partial_indexes
            or "supports_partial_indexes" in required_db_features
        ):
            errors.append(
                checks.Warning(
                    f"{connection.display_name} does not support unique constraints "
                    "with conditions.",
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id="models.W036",
                )
            )
        if self.deferrable is not None and not (
            connection.features.supports_deferrable_unique_constraints
            or "supports_deferrable_unique_constraints" in required_db_features
        ):
            errors.append(
                checks.Warning(
                    f"{connection.display_name} does not support deferrable unique "
                    "constraints.",
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id="models.W038",
                )
            )
        if self.include and not (
            connection.features.supports_covering_indexes
            or "supports_covering_indexes" in required_db_features
        ):
            errors.append(
                checks.Warning(
                    f"{connection.display_name} does not support unique constraints "
                    "with non-key columns.",
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id="models.W039",
                )
            )
        if self.contains_expressions and not (
            connection.features.supports_expression_indexes
            or "supports_expression_indexes" in required_db_features
        ):
            errors.append(
                checks.Warning(
                    f"{connection.display_name} does not support unique constraints on "
                    "expressions.",
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id="models.W044",
                )
            )
        if self.nulls_distinct is not None and not (
            connection.features.supports_nulls_distinct_unique_constraints
            or "supports_nulls_distinct_unique_constraints" in required_db_features
        ):
            errors.append(
                checks.Warning(
                    f"{connection.display_name} does not support unique constraints "
                    "with nulls distinct.",
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id="models.W047",
                )
            )
        # Collect references from the condition/expressions unless the model
        # requires a feature this backend lacks (then they're skipped).
        references = set()
        if (
            connection.features.supports_partial_indexes
            or "supports_partial_indexes" not in required_db_features
        ) and isinstance(self.condition, Q):
            references.update(model._get_expr_references(self.condition))
        if self.contains_expressions and (
            connection.features.supports_expression_indexes
            or "supports_expression_indexes" not in required_db_features
        ):
            for expression in self.expressions:
                references.update(model._get_expr_references(expression))
        errors.extend(self._check_references(model, references))
        return errors

    def _get_condition_sql(self, model, schema_editor):
        """Compile ``condition`` to SQL with parameters inlined, or None.

        Parameters are quoted via ``schema_editor.quote_value()``.
        """
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(self, model, schema_editor):
        """Return resolved index expressions, or None for field-based constraints.

        Each expression is wrapped in ``IndexExpression`` (with wrapper
        classes set for the connection), and the list is resolved against a
        ``Query`` for ``model``.
        """
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(self, model, schema_editor):
        """Return the inline unique-constraint SQL via ``schema_editor._unique_sql()``."""
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        # include is passed to the schema editor as column names.
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
            nulls_distinct=self.nulls_distinct,
        )

    def create_sql(self, model, schema_editor):
        """Return SQL that adds this constraint via ``_create_unique_sql()``."""
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
            nulls_distinct=self.nulls_distinct,
        )

    def remove_sql(self, model, schema_editor):
        """Return SQL that drops this constraint via ``_delete_unique_sql()``."""
        condition = self._get_condition_sql(model, schema_editor)
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
            nulls_distinct=self.nulls_distinct,
        )

    def __repr__(self):
        # Only non-default options are shown.
        return "<%s:%s%s%s%s%s%s%s%s%s%s>" % (
            self.__class__.__qualname__,
            "" if not self.fields else " fields=%s" % repr(self.fields),
            "" if not self.expressions else " expressions=%s" % repr(self.expressions),
            " name=%s" % repr(self.name),
            "" if self.condition is None else " condition=%s" % self.condition,
            "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
            "" if not self.include else " include=%s" % repr(self.include),
            "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
            (
                ""
                if self.nulls_distinct is None
                else " nulls_distinct=%r" % self.nulls_distinct
            ),
            (
                ""
                if self.violation_error_code is None
                else " violation_error_code=%r" % self.violation_error_code
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else " violation_error_message=%r" % self.violation_error_message
            ),
        )

    def __eq__(self, other):
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                # Identity comparison: nulls_distinct is None/True/False.
                and self.nulls_distinct is other.nulls_distinct
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        """Extend the base deconstruction with the non-default options.

        ``expressions`` are returned as positional args.
        """
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        if self.nulls_distinct is not None:
            kwargs["nulls_distinct"] = self.nulls_distinct
        return path, self.expressions, kwargs

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """Raise ValidationError if another row in ``using`` conflicts with ``instance``.

        Returns without checking when:
        - a constrained field, or an expression/generated field referencing
          one, is in ``exclude``;
        - a field value is None and ``nulls_distinct`` is not False;
        - a field value is "" on a backend that interprets empty strings as
          NULL.

        The instance's own row is excluded when it isn't being added. With a
        ``condition``, an error is raised only if the instance satisfies the
        condition and a conflicting row matching it exists.
        """
        queryset = model._default_manager.using(using)
        if self.fields:
            # Field-based constraint: filter on the instance's field values.
            lookup_kwargs = {}
            generated_field_names = []
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._meta.get_field(field_name)
                if field.generated:
                    if exclude and self._expression_refs_exclude(
                        model, field.expression, exclude
                    ):
                        return
                    # Generated values are compared via their expressions below.
                    generated_field_names.append(field.name)
                else:
                    lookup_value = getattr(instance, field.attname)
                    # Parsed as (nulls_distinct is not False and value is None)
                    # or ("" on a backend treating "" as NULL).
                    if (
                        self.nulls_distinct is not False
                        and lookup_value is None
                        or (
                            lookup_value == ""
                            and connections[
                                using
                            ].features.interprets_empty_strings_as_nulls
                        )
                    ):
                        # A composite constraint containing NULL value cannot cause
                        # a violation since NULL != NULL in SQL.
                        return
                    lookup_kwargs[field.name] = lookup_value
            lookup_args = []
            if generated_field_names:
                field_expression_map = instance._get_field_expression_map(
                    meta=model._meta, exclude=exclude
                )
                for field_name in generated_field_names:
                    expression = field_expression_map[field_name]
                    if self.nulls_distinct is False:
                        # NULLs are not distinct: match equal values or both NULL.
                        lhs = F(field_name)
                        condition = Q(Exact(lhs, expression)) | Q(
                            IsNull(lhs, True), IsNull(expression, True)
                        )
                        lookup_args.append(condition)
                    else:
                        lookup_kwargs[field_name] = expression
            queryset = queryset.filter(*lookup_args, **lookup_kwargs)
        else:
            # Expression-based constraint.
            # Ignore constraints with excluded fields.
            if exclude and any(
                self._expression_refs_exclude(model, expression, exclude)
                for expression in self.expressions
            ):
                return
            # Map F(field) -> the instance's value so each expression can be
            # rewritten into its value for this instance.
            replacements = {
                F(field): value
                for field, value in instance._get_field_expression_map(
                    meta=model._meta, exclude=exclude
                ).items()
            }
            filters = []
            for expr in self.expressions:
                if hasattr(expr, "get_expression_for_validation"):
                    expr = expr.get_expression_for_validation()
                rhs = expr.replace_expressions(replacements)
                condition = Exact(expr, rhs)
                if self.nulls_distinct is False:
                    condition = Q(condition) | Q(IsNull(expr, True), IsNull(rhs, True))
                filters.append(condition)
            queryset = queryset.filter(*filters)
        model_class_pk = instance._get_pk_val(model._meta)
        # Don't let an existing row conflict with itself.
        if not instance._state.adding and instance._is_pk_set(model._meta):
            queryset = queryset.exclude(pk=model_class_pk)
        if not self.condition:
            if queryset.exists():
                if (
                    self.fields
                    and self.violation_error_message
                    == self.default_violation_error_message
                ):
                    # When fields are defined, use the unique_error_message() as
                    # a default for backward compatibility.
                    validation_error_message = instance.unique_error_message(
                        model, self.fields
                    )
                    raise ValidationError(
                        validation_error_message,
                        code=validation_error_message.code,
                    )
                raise ValidationError(
                    self.get_violation_error_message(),
                    code=self.violation_error_code,
                )
        else:
            against = instance._get_field_expression_map(
                meta=model._meta, exclude=exclude
            )
            try:
                # True when the instance itself matches the condition and a
                # conflicting row that also matches the condition exists.
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against, using=using
                ):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                # Deliberate best-effort skip. NOTE(review): presumably raised
                # when the condition references fields absent from ``against``
                # (e.g. excluded ones) — confirm.
                pass