1from enum import Enum
2from types import NoneType
3
4from django.core import checks
5from django.core.exceptions import FieldDoesNotExist, FieldError, ValidationError
6from django.db import connections
7from django.db.models.constants import LOOKUP_SEP
8from django.db.models.expressions import Exists, ExpressionList, F, RawSQL
9from django.db.models.indexes import IndexExpression
10from django.db.models.lookups import Exact, IsNull
11from django.db.models.query_utils import Q
12from django.db.models.sql.query import Query
13from django.db.utils import DEFAULT_DB_ALIAS
14from django.utils.translation import gettext_lazy as _
15
# Names exported by ``from <this module> import *``.
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
17
18
class BaseConstraint:
    """
    Base class for model constraints.

    Stores the constraint name together with the error code and message used
    when validation fails. The SQL-producing methods and ``validate()`` raise
    NotImplementedError here and must be provided by subclasses.
    """

    default_violation_error_message = _("Constraint “%(name)s” is violated.")
    violation_error_code = None
    violation_error_message = None

    # NOTE(review): not read anywhere in this module; presumably consulted by
    # other code to recognize attributes that don't affect the database
    # schema -- confirm against callers.
    non_db_attrs = ("violation_error_code", "violation_error_message")

    def __init__(
        self, *, name, violation_error_code=None, violation_error_message=None
    ):
        """
        Store the constraint name and the optional violation error details.

        A ``None`` code leaves the class-level ``violation_error_code`` in
        effect; a ``None`` message falls back to
        ``default_violation_error_message``.
        """
        self.name = name
        if violation_error_code is not None:
            self.violation_error_code = violation_error_code
        if violation_error_message is not None:
            self.violation_error_message = violation_error_message
        else:
            self.violation_error_message = self.default_violation_error_message

    @property
    def contains_expressions(self):
        """Return False; subclasses that accept expressions override this."""
        return False

    def constraint_sql(self, model, schema_editor):
        """Must be implemented by subclasses."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def create_sql(self, model, schema_editor):
        """Must be implemented by subclasses."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def remove_sql(self, model, schema_editor):
        """Must be implemented by subclasses."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    @classmethod
    def _expression_refs_exclude(cls, model, expression, exclude):
        """
        Return True if ``expression`` references a name listed in ``exclude``.

        References are taken from ``model._get_expr_references()``. When a
        referenced field is generated, its own expression is inspected
        recursively so indirect references are detected too.
        """
        opts = model._meta
        for field_name, *__ in model._get_expr_references(expression):
            if field_name in exclude:
                return True
            # pk is an alias that won't be found by opts.get_field(), so
            # resolve it to the actual primary key field, as
            # _check_references() does, instead of letting
            # FieldDoesNotExist propagate.
            if field_name == "pk":
                field = opts.pk
            else:
                field = opts.get_field(field_name)
            if field.generated and cls._expression_refs_exclude(
                model, field.expression, exclude
            ):
                return True
        return False

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """Must be implemented by subclasses."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def get_violation_error_message(self):
        """Return the violation message with ``%(name)s`` interpolated."""
        return self.violation_error_message % {"name": self.name}

    def _check(self, model, connection):
        """Return a list of check messages; the base class reports none."""
        return []

    def _check_references(self, model, references):
        """
        Return check errors for the field references used by a constraint.

        ``references`` is an iterable of ``(field_name, *lookups)`` tuples.
        A reference whose first lookup is neither a transform nor a lookup of
        a forward relation field is reported as models.E041 (joined field).
        The collected field names are then passed to
        ``model._check_local_fields()`` and its errors are appended.
        """
        from django.db.models.fields.composite import CompositePrimaryKey

        errors = []
        fields = set()
        for field_name, *lookups in references:
            # pk is an alias that won't be found by opts.get_field().
            if field_name != "pk" or isinstance(model._meta.pk, CompositePrimaryKey):
                fields.add(field_name)
            if not lookups:
                # If it has no lookups it cannot result in a JOIN.
                continue
            try:
                if field_name == "pk":
                    field = model._meta.pk
                else:
                    field = model._meta.get_field(field_name)
                if not field.is_relation or field.many_to_many or field.one_to_many:
                    continue
            except FieldDoesNotExist:
                # Unknown names are left to the _check_local_fields() call
                # below.
                continue
            # JOIN must happen at the first lookup.
            first_lookup = lookups[0]
            if (
                hasattr(field, "get_transform")
                and hasattr(field, "get_lookup")
                and field.get_transform(first_lookup) is None
                and field.get_lookup(first_lookup) is None
            ):
                errors.append(
                    checks.Error(
                        "'constraints' refers to the joined field '%s'."
                        % LOOKUP_SEP.join([field_name] + lookups),
                        obj=model,
                        id="models.E041",
                    )
                )
        errors.extend(model._check_local_fields(fields, "constraints"))
        return errors

    def deconstruct(self):
        """
        Return a ``(path, args, kwargs)`` triple describing this constraint.

        The module part of the path is shortened from
        ``django.db.models.constraints`` to ``django.db.models``. The
        violation message is included only when it differs from the default,
        and the violation code only when it is set.
        """
        path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
        path = path.replace("django.db.models.constraints", "django.db.models")
        kwargs = {"name": self.name}
        if (
            self.violation_error_message is not None
            and self.violation_error_message != self.default_violation_error_message
        ):
            kwargs["violation_error_message"] = self.violation_error_message
        if self.violation_error_code is not None:
            kwargs["violation_error_code"] = self.violation_error_code
        return (path, (), kwargs)

    def clone(self):
        """Return a new instance built from this constraint's deconstruction."""
        _, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)
128
129
class CheckConstraint(BaseConstraint):
    """Constraint requiring a Q object or boolean expression to hold."""

    def __init__(
        self,
        *,
        condition,
        name,
        violation_error_code=None,
        violation_error_message=None,
    ):
        """
        Store ``condition`` and delegate the remaining arguments to the base.

        Raises TypeError when ``condition`` has no truthy ``conditional``
        attribute.
        """
        self.condition = condition
        is_conditional = getattr(condition, "conditional", False)
        if not is_conditional:
            raise TypeError(
                "CheckConstraint.condition must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _check(self, model, connection):
        """
        Return check messages for this constraint on ``connection``.

        When the backend supports table check constraints, the references of
        the condition are checked and RawSQL() usage is reported (W045).
        Otherwise W027 is reported unless the model lists the feature in
        ``required_db_features``.
        """
        issues = []
        if connection.features.supports_table_check_constraints:
            condition = self.condition
            referenced = set()
            if isinstance(condition, Q):
                referenced.update(model._get_expr_references(condition))
            for node in condition.flatten():
                if isinstance(node, RawSQL):
                    issues.append(
                        checks.Warning(
                            f"Check constraint {self.name!r} contains RawSQL() "
                            "expression and won't be validated during the model "
                            "full_clean().",
                            hint="Silence this warning if you don't care about it.",
                            obj=model,
                            id="models.W045",
                        ),
                    )
                    # One warning is enough, however many RawSQL() nodes exist.
                    break
            issues.extend(self._check_references(model, referenced))
        elif (
            "supports_table_check_constraints"
            not in model._meta.required_db_features
        ):
            issues.append(
                checks.Warning(
                    f"{connection.display_name} does not support check constraints.",
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id="models.W027",
                )
            )
        return issues

    def _get_check_sql(self, model, schema_editor):
        """Return the condition compiled to SQL with its parameters quoted."""
        query = Query(model=model, alias_cols=False)
        where_node = query.build_where(self.condition)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where_node.as_sql(compiler, schema_editor.connection)
        quoted = tuple(schema_editor.quote_value(param) for param in params)
        return sql % quoted

    def constraint_sql(self, model, schema_editor):
        """Return the result of ``schema_editor._check_sql()`` for this check."""
        return schema_editor._check_sql(
            self.name, self._get_check_sql(model, schema_editor)
        )

    def create_sql(self, model, schema_editor):
        """Return the result of ``schema_editor._create_check_sql()``."""
        return schema_editor._create_check_sql(
            model, self.name, self._get_check_sql(model, schema_editor)
        )

    def remove_sql(self, model, schema_editor):
        """Return the result of ``schema_editor._delete_check_sql()``."""
        return schema_editor._delete_check_sql(model, self.name)

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """
        Raise ValidationError when ``instance`` does not satisfy the condition.

        The condition is checked against the instance's field expression map.
        A FieldError raised during the check is swallowed, so validation is
        skipped in that case.
        """
        against = instance._get_field_expression_map(meta=model._meta, exclude=exclude)
        try:
            if Q(self.condition).check(against, using=using):
                return
            raise ValidationError(
                self.get_violation_error_message(), code=self.violation_error_code
            )
        except FieldError:
            return

    def __repr__(self):
        """Return a representation listing condition, name and custom errors."""
        if self.violation_error_code is None:
            code_part = ""
        else:
            code_part = " violation_error_code=%r" % self.violation_error_code
        if (
            self.violation_error_message is None
            or self.violation_error_message == self.default_violation_error_message
        ):
            message_part = ""
        else:
            message_part = (
                " violation_error_message=%r" % self.violation_error_message
            )
        return "<%s: condition=%s name=%s%s%s>" % (
            self.__class__.__qualname__,
            self.condition,
            repr(self.name),
            code_part,
            message_part,
        )

    def __eq__(self, other):
        """Compare name, condition and violation details with another check."""
        if not isinstance(other, CheckConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.condition == other.condition
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self):
        """Extend the base deconstruction with the ``condition`` kwarg."""
        path, args, kwargs = super().deconstruct()
        kwargs["condition"] = self.condition
        return path, args, kwargs
249
250
class Deferrable(Enum):
    """Deferral modes accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    # A similar format was proposed for Python 3.10.
    def __repr__(self):
        """Return ``<class qualname>.<member name>``."""
        return "%s.%s" % (type(self).__qualname__, self.name)
258
259
class UniqueConstraint(BaseConstraint):
    """Uniqueness constraint over a set of fields or a set of expressions."""

    def __init__(
        self,
        *expressions,
        fields=(),
        name=None,
        condition=None,
        deferrable=None,
        include=None,
        opclasses=(),
        nulls_distinct=None,
        violation_error_code=None,
        violation_error_message=None,
    ):
        """
        Validate the option combination and store the normalized values.

        Raises ValueError for missing or conflicting options and TypeError
        for options of the wrong type. String expressions are wrapped in F().
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not (expressions or fields):
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, (NoneType, Q)):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        # Options that cannot be combined with a deferrable constraint, in the
        # order in which they are reported.
        deferral_conflicts = (
            (condition, "UniqueConstraint with conditions cannot be deferred."),
            (include, "UniqueConstraint with include fields cannot be deferred."),
            (opclasses, "UniqueConstraint with opclasses cannot be deferred."),
            (expressions, "UniqueConstraint with expressions cannot be deferred."),
        )
        for option, message in deferral_conflicts:
            if option and deferrable:
                raise ValueError(message)
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use django.contrib.postgres.indexes.OpClass() instead."
            )
        type_rules = (
            (
                deferrable,
                (NoneType, Deferrable),
                "UniqueConstraint.deferrable must be a Deferrable instance.",
            ),
            (
                include,
                (NoneType, list, tuple),
                "UniqueConstraint.include must be a list or tuple.",
            ),
            (
                opclasses,
                (list, tuple),
                "UniqueConstraint.opclasses must be a list or tuple.",
            ),
            (
                nulls_distinct,
                (NoneType, bool),
                "UniqueConstraint.nulls_distinct must be a bool.",
            ),
        )
        for value, allowed_types, message in type_rules:
            if not isinstance(value, allowed_types):
                raise TypeError(message)
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        self.opclasses = opclasses
        self.nulls_distinct = nulls_distinct
        self.expressions = tuple(
            F(item) if isinstance(item, str) else item for item in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self):
        """Return True when the constraint is defined through expressions."""
        return True if self.expressions else False

    def _check(self, model, connection):
        """
        Return check messages for this constraint on ``connection``.

        Each option the backend doesn't support yields a warning unless the
        model lists the matching feature in ``required_db_features``. Field
        references of the condition and expressions are then checked.
        """
        issues = model._check_local_fields(
            set(self.fields).union(self.include), "constraints"
        )
        required = model._meta.required_db_features

        def unsupported(feature):
            # Neither supported by the backend nor declared as required.
            return not (
                getattr(connection.features, feature) or feature in required
            )

        def may_be_created(feature):
            # Supported by the backend, or at least not declared as required.
            return getattr(connection.features, feature) or feature not in required

        def warn(message, check_id):
            issues.append(
                checks.Warning(
                    message,
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id=check_id,
                )
            )

        if self.condition is not None and unsupported("supports_partial_indexes"):
            warn(
                f"{connection.display_name} does not support unique constraints "
                "with conditions.",
                "models.W036",
            )
        if self.deferrable is not None and unsupported(
            "supports_deferrable_unique_constraints"
        ):
            warn(
                f"{connection.display_name} does not support deferrable unique "
                "constraints.",
                "models.W038",
            )
        if self.include and unsupported("supports_covering_indexes"):
            warn(
                f"{connection.display_name} does not support unique constraints "
                "with non-key columns.",
                "models.W039",
            )
        if self.contains_expressions and unsupported("supports_expression_indexes"):
            warn(
                f"{connection.display_name} does not support unique constraints on "
                "expressions.",
                "models.W044",
            )
        if self.nulls_distinct is not None and unsupported(
            "supports_nulls_distinct_unique_constraints"
        ):
            warn(
                f"{connection.display_name} does not support unique constraints "
                "with nulls distinct.",
                "models.W047",
            )

        referenced = set()
        if may_be_created("supports_partial_indexes") and isinstance(
            self.condition, Q
        ):
            referenced.update(model._get_expr_references(self.condition))
        if self.contains_expressions and may_be_created(
            "supports_expression_indexes"
        ):
            for expression in self.expressions:
                referenced.update(model._get_expr_references(expression))
        issues.extend(self._check_references(model, referenced))
        return issues

    def _get_condition_sql(self, model, schema_editor):
        """Return the condition compiled to SQL, or None without a condition."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where_node = query.build_where(self.condition)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where_node.as_sql(compiler, schema_editor.connection)
        quoted = tuple(schema_editor.quote_value(param) for param in params)
        return sql % quoted

    def _get_index_expressions(self, model, schema_editor):
        """
        Return the expressions as a resolved ExpressionList.

        Each expression is wrapped in an IndexExpression configured for the
        schema editor's connection. Returns None when there are no
        expressions.
        """
        if not self.expressions:
            return None
        wrapped = []
        for expression in self.expressions:
            wrapper = IndexExpression(expression)
            wrapper.set_wrapper_classes(schema_editor.connection)
            wrapped.append(wrapper)
        expression_list = ExpressionList(*wrapped)
        return expression_list.resolve_expression(Query(model, alias_cols=False))

    def constraint_sql(self, model, schema_editor):
        """Return the result of ``schema_editor._unique_sql()``."""
        opts = model._meta
        field_objs = [opts.get_field(field_name) for field_name in self.fields]
        include_columns = [
            opts.get_field(field_name).column for field_name in self.include
        ]
        condition_sql = self._get_condition_sql(model, schema_editor)
        index_expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            field_objs,
            self.name,
            condition=condition_sql,
            deferrable=self.deferrable,
            include=include_columns,
            opclasses=self.opclasses,
            expressions=index_expressions,
            nulls_distinct=self.nulls_distinct,
        )

    def create_sql(self, model, schema_editor):
        """Return the result of ``schema_editor._create_unique_sql()``."""
        opts = model._meta
        field_objs = [opts.get_field(field_name) for field_name in self.fields]
        include_columns = [
            opts.get_field(field_name).column for field_name in self.include
        ]
        condition_sql = self._get_condition_sql(model, schema_editor)
        index_expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            field_objs,
            self.name,
            condition=condition_sql,
            deferrable=self.deferrable,
            include=include_columns,
            opclasses=self.opclasses,
            expressions=index_expressions,
            nulls_distinct=self.nulls_distinct,
        )

    def remove_sql(self, model, schema_editor):
        """Return the result of ``schema_editor._delete_unique_sql()``."""
        condition_sql = self._get_condition_sql(model, schema_editor)
        include_columns = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        index_expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition_sql,
            deferrable=self.deferrable,
            include=include_columns,
            opclasses=self.opclasses,
            expressions=index_expressions,
            nulls_distinct=self.nulls_distinct,
        )

    def __repr__(self):
        """Return a representation listing only the options that are set."""
        fields_part = " fields=%s" % repr(self.fields) if self.fields else ""
        expressions_part = (
            " expressions=%s" % repr(self.expressions) if self.expressions else ""
        )
        name_part = " name=%s" % repr(self.name)
        condition_part = (
            " condition=%s" % self.condition if self.condition is not None else ""
        )
        deferrable_part = (
            " deferrable=%r" % self.deferrable if self.deferrable is not None else ""
        )
        include_part = " include=%s" % repr(self.include) if self.include else ""
        opclasses_part = (
            " opclasses=%s" % repr(self.opclasses) if self.opclasses else ""
        )
        if self.nulls_distinct is None:
            nulls_part = ""
        else:
            nulls_part = " nulls_distinct=%r" % self.nulls_distinct
        if self.violation_error_code is None:
            code_part = ""
        else:
            code_part = " violation_error_code=%r" % self.violation_error_code
        if (
            self.violation_error_message is None
            or self.violation_error_message == self.default_violation_error_message
        ):
            message_part = ""
        else:
            message_part = (
                " violation_error_message=%r" % self.violation_error_message
            )
        return "<%s:%s%s%s%s%s%s%s%s%s%s>" % (
            self.__class__.__qualname__,
            fields_part,
            expressions_part,
            name_part,
            condition_part,
            deferrable_part,
            include_part,
            opclasses_part,
            nulls_part,
            code_part,
            message_part,
        )

    def __eq__(self, other):
        """Compare every option with another UniqueConstraint."""
        if not isinstance(other, UniqueConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.fields == other.fields
            and self.condition == other.condition
            and self.deferrable == other.deferrable
            and self.include == other.include
            and self.opclasses == other.opclasses
            and self.expressions == other.expressions
            and self.nulls_distinct is other.nulls_distinct
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self):
        """
        Extend the base deconstruction with the options that are set.

        Expressions are returned as the positional arguments.
        """
        path, _args, kwargs = super().deconstruct()
        for option in ("fields", "condition", "deferrable", "include", "opclasses"):
            value = getattr(self, option)
            if value:
                kwargs[option] = value
        if self.nulls_distinct is not None:
            kwargs["nulls_distinct"] = self.nulls_distinct
        return path, self.expressions, kwargs

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """
        Raise ValidationError when another row conflicts with ``instance``.

        Validation is skipped (returns None) when a constrained field, or a
        field referenced by a constrained expression, is listed in
        ``exclude``, and when a constrained value is NULL-like. With a
        condition, a FieldError raised while checking it is swallowed.
        """
        queryset = model._default_manager.using(using)
        if self.fields:
            lookup_kwargs = {}
            generated_names = []
            for name in self.fields:
                if exclude and name in exclude:
                    return
                field = model._meta.get_field(name)
                if field.generated:
                    if exclude and self._expression_refs_exclude(
                        model, field.expression, exclude
                    ):
                        return
                    generated_names.append(field.name)
                    continue
                value = getattr(instance, field.attname)
                null_counts_as_distinct = (
                    self.nulls_distinct is not False and value is None
                )
                if null_counts_as_distinct or (
                    value == ""
                    and connections[using].features.interprets_empty_strings_as_nulls
                ):
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = value
            lookup_args = []
            if generated_names:
                expression_map = instance._get_field_expression_map(
                    meta=model._meta, exclude=exclude
                )
                for name in generated_names:
                    expression = expression_map[name]
                    if self.nulls_distinct is not False:
                        lookup_kwargs[name] = expression
                        continue
                    # NULLS NOT DISTINCT: equal values or both NULL conflict.
                    lhs = F(name)
                    lookup_args.append(
                        Q(Exact(lhs, expression))
                        | Q(IsNull(lhs, True), IsNull(expression, True))
                    )
            queryset = queryset.filter(*lookup_args, **lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude and any(
                self._expression_refs_exclude(model, expression, exclude)
                for expression in self.expressions
            ):
                return
            expression_map = instance._get_field_expression_map(
                meta=model._meta, exclude=exclude
            )
            replacements = {
                F(field_name): value for field_name, value in expression_map.items()
            }
            filters = []
            for expression in self.expressions:
                if hasattr(expression, "get_expression_for_validation"):
                    expression = expression.get_expression_for_validation()
                rhs = expression.replace_expressions(replacements)
                match = Exact(expression, rhs)
                if self.nulls_distinct is False:
                    match = Q(match) | Q(
                        IsNull(expression, True), IsNull(rhs, True)
                    )
                filters.append(match)
            queryset = queryset.filter(*filters)

        # Don't let a saved instance conflict with its own row.
        model_class_pk = instance._get_pk_val(model._meta)
        if not instance._state.adding and instance._is_pk_set(model._meta):
            queryset = queryset.exclude(pk=model_class_pk)

        if self.condition:
            against = instance._get_field_expression_map(
                meta=model._meta, exclude=exclude
            )
            try:
                conflicting = Exists(queryset.filter(self.condition))
                if (self.condition & conflicting).check(against, using=using):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                pass
            return

        if not queryset.exists():
            return
        if (
            self.fields
            and self.violation_error_message == self.default_violation_error_message
        ):
            # When fields are defined, use the unique_error_message() as
            # a default for backward compatibility.
            validation_error_message = instance.unique_error_message(
                model, self.fields
            )
            raise ValidationError(
                validation_error_message,
                code=validation_error_message.code,
            )
        raise ValidationError(
            self.get_violation_error_message(),
            code=self.violation_error_code,
        )