1from django.core import checks
2from django.db import connections, router
3from django.db.models.sql import Query
4from django.utils.functional import cached_property
5
6from . import NOT_PROVIDED, Field
7
# Public names exported by ``from <this module> import *``.
__all__ = ["GeneratedField"]
9
10
class GeneratedField(Field):
    """
    A model field whose value is computed by the database from ``expression``.

    ``output_field`` supplies the column's type, internal type and lookups.
    ``db_persist`` must be given explicitly: True requests a persisted (stored)
    generated column, False a non-persisted (virtual) one.
    """

    generated = True
    db_returning = True

    # Query used to resolve and compile ``expression``; it is built in
    # contribute_to_class() once the owning model is known.
    _query = None
    output_field = None

    def __init__(self, *, expression, output_field, db_persist=None, **kwargs):
        """
        Create a generated field.

        :param expression: expression the database evaluates for this column.
        :param output_field: field instance describing the resulting value.
        :param db_persist: True or False; the None default is rejected so the
            caller must choose persistence explicitly.
        :raises ValueError: if the field is made editable or non-blank, is
            given a ``default`` or ``db_default``, or ``db_persist`` is not
            True/False.
        """
        # setdefault() both validates the caller's value and fills in the
        # only value allowed for a generated field.
        if kwargs.setdefault("editable", False):
            raise ValueError("GeneratedField cannot be editable.")
        if not kwargs.setdefault("blank", True):
            raise ValueError("GeneratedField must be blank.")
        if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
            raise ValueError("GeneratedField cannot have a default.")
        if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
            raise ValueError("GeneratedField cannot have a database default.")
        if db_persist not in (True, False):
            raise ValueError("GeneratedField.db_persist must be True or False.")

        self.expression = expression
        self.output_field = output_field
        self.db_persist = db_persist
        super().__init__(**kwargs)

    @cached_property
    def cached_col(self):
        """Column on the model's own table, typed by ``output_field``."""
        from django.db.models.expressions import Col

        return Col(self.model._meta.db_table, self, self.output_field)

    def get_col(self, alias, output_field=None):
        """
        Return a column reference for ``alias``.

        When referenced through an alias other than the model's own table and
        no other output field was requested, the column is typed with
        ``self.output_field`` instead of this field.
        """
        if alias != self.model._meta.db_table and output_field in (None, self):
            output_field = self.output_field
        return super().get_col(alias, output_field)

    def contribute_to_class(self, *args, **kwargs):
        """Attach to the model, build the compile query and copy lookups."""
        super().contribute_to_class(*args, **kwargs)

        self._query = Query(model=self.model, alias_cols=False)
        # Register lookups from the output_field class so filters behave as
        # they would on a regular field of that type.
        for lookup_name, lookup in self.output_field.get_class_lookups().items():
            self.register_lookup(lookup, lookup_name=lookup_name)

    def generated_sql(self, connection):
        """
        Return ``(sql, params)`` for the generation expression on ``connection``.

        Conditional (boolean) expressions are wrapped in
        ``CASE WHEN ... THEN 1 ELSE 0 END`` on backends whose features report
        no support for boolean expressions in the SELECT clause.
        """
        compiler = connection.ops.compiler("SQLCompiler")(
            self._query, connection=connection, using=None
        )
        # Joins are disallowed: the expression may only reference columns of
        # the model's own table.
        resolved_expression = self.expression.resolve_expression(
            self._query, allow_joins=False
        )
        sql, params = compiler.compile(resolved_expression)
        if (
            getattr(self.expression, "conditional", False)
            and not connection.features.supports_boolean_expr_in_select_clause
        ):
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def check(self, **kwargs):
        """
        Run the standard field checks plus generated-column checks.

        Problems reported by a clone of ``output_field`` are folded into a
        single fields.E223 error (errors and critical messages) and a single
        fields.W224 warning (warnings).
        """
        databases = kwargs.get("databases") or []
        errors = [
            *super().check(**kwargs),
            *self._check_supported(databases),
            *self._check_persistence(databases),
        ]
        # Check a clone so the real output_field is never bound to a model.
        output_field_clone = self.output_field.clone()
        output_field_clone.model = self.model
        output_field_checks = output_field_clone.check(databases=databases)
        if output_field_checks:
            separator = "\n "
            # is_serious() matches both Error and Critical; Critical is not a
            # subclass of Error, so an isinstance() filter would drop it.
            error_messages = separator.join(
                f"{output_check.msg} ({output_check.id})"
                for output_check in output_field_checks
                if output_check.is_serious()
            )
            if error_messages:
                errors.append(
                    checks.Error(
                        "GeneratedField.output_field has errors:"
                        f"{separator}{error_messages}",
                        obj=self,
                        id="fields.E223",
                    )
                )
            warning_messages = separator.join(
                f"{output_check.msg} ({output_check.id})"
                for output_check in output_field_checks
                if isinstance(output_check, checks.Warning)
            )
            if warning_messages:
                errors.append(
                    checks.Warning(
                        "GeneratedField.output_field has warnings:"
                        f"{separator}{warning_messages}",
                        obj=self,
                        id="fields.W224",
                    )
                )
        return errors

    def _checked_connections(self, databases):
        """
        Yield the connection for each alias in ``databases`` that is relevant.

        An alias is skipped when the router disallows migrating this model
        there, or when the model declares a ``required_db_vendor`` that
        differs from the connection's vendor.
        """
        for db in databases:
            if not router.allow_migrate_model(db, self.model):
                continue
            connection = connections[db]
            if (
                self.model._meta.required_db_vendor
                and self.model._meta.required_db_vendor != connection.vendor
            ):
                continue
            yield connection

    def _check_supported(self, databases):
        """Return fields.E220 errors for backends with no generated columns."""
        errors = []
        for connection in self._checked_connections(databases):
            # Report only when the backend supports neither kind of generated
            # column AND the model does not declare (via
            # required_db_features) that it needs one of them.
            if not (
                connection.features.supports_virtual_generated_columns
                or "supports_stored_generated_columns"
                in self.model._meta.required_db_features
            ) and not (
                connection.features.supports_stored_generated_columns
                or "supports_virtual_generated_columns"
                in self.model._meta.required_db_features
            ):
                errors.append(
                    checks.Error(
                        f"{connection.display_name} does not support GeneratedFields.",
                        obj=self,
                        id="fields.E220",
                    )
                )
        return errors

    def _check_persistence(self, databases):
        """
        Return errors when the backend lacks the requested column kind.

        fields.E221 is reported for ``db_persist=False`` without virtual
        column support, and fields.E222 for ``db_persist=True`` without stored
        column support. Declaring the matching feature in the model's
        ``required_db_features`` suppresses the error.
        """
        errors = []
        for connection in self._checked_connections(databases):
            if not self.db_persist and not (
                connection.features.supports_virtual_generated_columns
                or "supports_virtual_generated_columns"
                in self.model._meta.required_db_features
            ):
                errors.append(
                    checks.Error(
                        f"{connection.display_name} does not support non-persisted "
                        "GeneratedFields.",
                        obj=self,
                        id="fields.E221",
                        hint="Set db_persist=True on the field.",
                    )
                )
            if self.db_persist and not (
                connection.features.supports_stored_generated_columns
                or "supports_stored_generated_columns"
                in self.model._meta.required_db_features
            ):
                errors.append(
                    checks.Error(
                        f"{connection.display_name} does not support persisted "
                        "GeneratedFields.",
                        obj=self,
                        id="fields.E222",
                        hint="Set db_persist=False on the field.",
                    )
                )
        return errors

    def deconstruct(self):
        """
        Return the field's constructor arguments for migrations.

        ``blank`` and ``editable`` are dropped because __init__ forces them;
        the generated-column arguments are added explicitly.
        """
        name, path, args, kwargs = super().deconstruct()
        del kwargs["blank"]
        del kwargs["editable"]
        kwargs["db_persist"] = self.db_persist
        kwargs["expression"] = self.expression
        kwargs["output_field"] = self.output_field
        return name, path, args, kwargs

    def get_internal_type(self):
        """Delegate to ``output_field``."""
        return self.output_field.get_internal_type()

    def db_parameters(self, connection):
        """Delegate to ``output_field``."""
        return self.output_field.db_parameters(connection)

    def db_type_parameters(self, connection):
        """Delegate to ``output_field``."""
        return self.output_field.db_type_parameters(connection)