1"""
2Various data structures used in query construction.
3
4Factored out from django.db.models.query to avoid making the main module very
5large and/or so that they can be used by other modules without getting into
6circular import difficulties.
7"""
8
9import functools
10import inspect
11import logging
12from collections import namedtuple
13from contextlib import nullcontext
14
15from django.core.exceptions import FieldError
16from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, transaction
17from django.db.models.constants import LOOKUP_SEP
18from django.utils import tree
19from django.utils.functional import cached_property
20from django.utils.hashable import make_hashable
21
logger = logging.getLogger("django.db.models")

# PathInfo is used when converting lookups (fk__somecol). Its fields describe
# the relation in model terms: the model Options and Fields for both sides of
# the relation. join_field is the field backing the relation.
PathInfo = namedtuple(
    "PathInfo",
    [
        "from_opts",
        "to_opts",
        "target_fields",
        "join_field",
        "m2m",
        "direct",
        "filtered_relation",
    ],
)
31
32
def subclasses(cls):
    """
    Yield cls followed by every class inheriting from it, walking the
    subclass tree depth-first in pre-order. A class reachable through several
    parents is yielded once per path.
    """
    pending = [cls]
    while pending:
        current = pending.pop()
        yield current
        # Push children reversed so the first direct subclass is visited next.
        pending.extend(reversed(current.__subclasses__()))
37
38
class Q(tree.Node):
    """
    Wrap query filters in an object that can be combined with other
    conditional expressions using ``&``, ``|`` and ``^``, and negated with
    ``~``.
    """

    # Connection types
    AND = "AND"
    OR = "OR"
    XOR = "XOR"
    default = AND
    conditional = True

    def __init__(self, *args, _connector=None, _negated=False, **kwargs):
        # Keyword lookups are stored sorted by name, so the children (and
        # hence identity) don't depend on keyword argument order.
        children = list(args)
        children.extend(sorted(kwargs.items()))
        super().__init__(children=children, connector=_connector, negated=_negated)

    def _combine(self, other, conn):
        """
        Return a new Q joining self and other with the connector conn.

        Raise TypeError if other isn't flagged as conditional. If self is
        empty, return a copy of other. If other is an empty Q, return a copy
        of self.
        """
        if getattr(other, "conditional", False) is False:
            raise TypeError(other)
        if not self:
            return other.copy()
        if not other and isinstance(other, Q):
            return self.copy()
        combined = self.create(connector=conn)
        for operand in (self, other):
            combined.add(operand, conn)
        return combined

    def __or__(self, other):
        return self._combine(other, self.OR)

    def __and__(self, other):
        return self._combine(other, self.AND)

    def __xor__(self, other):
        return self._combine(other, self.XOR)

    def __invert__(self):
        inverted = self.copy()
        inverted.negate()
        return inverted

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # The joins returned by _add_q are promoted to LEFT OUTER joins so
        # that, when this Q is used as an expression, they don't drop rows.
        where_clause, new_joins = query._add_q(
            self,
            reuse,
            allow_joins=allow_joins,
            split_subq=False,
            check_filterable=False,
            summarize=summarize,
        )
        query.promote_joins(new_joins)
        return where_clause

    def flatten(self):
        """
        Recursively yield this Q and every nested subexpression, depth-first.
        For (lookup, value) children only the value is considered.
        """
        yield self
        for child in self.children:
            node = child[1] if isinstance(child, tuple) else child
            if hasattr(node, "flatten"):
                yield from node.flatten()
            else:
                yield node

    def check(self, against, using=DEFAULT_DB_ALIAS):
        """
        Run a database query to tell whether this Q holds for the values in
        ``against``, a mapping of names to values or expressions.

        Return True if the query produces a row. A DatabaseError is logged
        and treated as a pass (True). A FieldError is raised if ``against``
        lacks a field that this Q references.
        """
        # Avoid circular imports.
        from django.db.models import BooleanField, Value
        from django.db.models.functions import Coalesce
        from django.db.models.sql import Query
        from django.db.models.sql.constants import SINGLE

        query = Query(None)
        for name, value in against.items():
            expression = (
                value if hasattr(value, "resolve_expression") else Value(value)
            )
            query.add_annotation(expression, name, select=False)
        query.add_annotation(Value(1), "_check")
        connection = connections[using]
        # Where supported, a NULL condition is coalesced to True.
        if connection.features.supports_comparing_boolean_expr:
            condition = Q(Coalesce(self, True, output_field=BooleanField()))
        else:
            condition = self
        query.add_q(condition)
        compiler = query.get_compiler(using=using)
        # Inside an existing transaction, run in a savepoint.
        if connection.in_atomic_block:
            context_manager = transaction.atomic(using=using)
        else:
            context_manager = nullcontext()
        try:
            with context_manager:
                return compiler.execute_sql(SINGLE) is not None
        except DatabaseError as e:
            logger.warning("Got a database error calling check() on %r: %s", self, e)
            return True

    def deconstruct(self):
        """
        Return (path, args, kwargs) describing how to rebuild this Q.

        Paths under django.db.models.query_utils are shortened to
        django.db.models. Only non-default connector/negation are included
        in kwargs.
        """
        cls = self.__class__
        path = f"{cls.__module__}.{cls.__name__}"
        if path.startswith("django.db.models.query_utils"):
            path = path.replace("django.db.models.query_utils", "django.db.models")
        kwargs = {}
        if self.connector != self.default:
            kwargs["_connector"] = self.connector
        if self.negated:
            kwargs["_negated"] = True
        return path, tuple(self.children), kwargs

    @cached_property
    def identity(self):
        """
        Hashable tuple built from deconstruct(). Lookup values pass through
        make_hashable().
        """
        path, args, kwargs = self.deconstruct()
        parts = [path, *kwargs.items()]
        for child in args:
            if isinstance(child, tuple):
                lookup, value = child
                parts.append((lookup, make_hashable(value)))
            else:
                parts.append(child)
        return tuple(parts)

    def __eq__(self, other):
        if isinstance(other, Q):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)

    @cached_property
    def referenced_base_fields(self):
        """
        Names of the base fields referenced directly or through F()
        expressions. Fields reached through joins are excluded.
        """
        # Avoid circular imports.
        from django.db.models.sql import query

        return {
            name.partition(LOOKUP_SEP)[0] for name in query.get_children_from_q(self)
        }
198
199
class DeferredAttribute:
    """
    Descriptor for a field whose value may not have been loaded yet. The
    first read of the attribute on an instance fetches the value and stores
    it on that instance.
    """

    def __init__(self, field):
        self.field = field

    def __get__(self, instance, cls=None):
        """
        Return the field's value for instance, loading and caching it on
        first access.

        When accessed on the class, return the descriptor itself. Otherwise
        read the value from instance.__dict__. If it is missing, take it from
        an already loaded parent link when possible, or else fetch it with
        instance.refresh_from_db(). If a fetch is needed, the instance has no
        primary key set and the field is generated, raise AttributeError.
        """
        if instance is None:
            return self
        values = instance.__dict__
        attname = self.field.attname
        if attname in values:
            return values[attname]
        # The value may already be available through the parent chain, which
        # avoids a query. Refs #18343.
        inherited = self._check_parent_chain(instance)
        if inherited is not None:
            values[attname] = inherited
        elif not instance._is_pk_set() and self.field.generated:
            raise AttributeError(
                "Cannot read a generated field from an unsaved model."
            )
        else:
            instance.refresh_from_db(fields=[attname])
        return values[attname]

    def _check_parent_chain(self, instance):
        """
        Return the value of the ancestor link already loaded on instance, or
        None. This only applies when the field is a primary key other than
        the ancestor link itself.
        """
        link_field = instance._meta.get_ancestor_link(self.field.model)
        reusable = self.field.primary_key and self.field != link_field
        return getattr(instance, link_field.attname) if reusable else None
243
244
class class_or_instance_method:
    """
    Descriptor used by RegisterLookupMixin to expose one attribute name with
    two implementations. Accessed on a class, it returns class_method
    partially applied to the owner class. Accessed on an instance, it returns
    instance_method partially applied to that instance.
    """

    def __init__(self, class_method, instance_method):
        self.class_method = class_method
        self.instance_method = instance_method

    def __get__(self, instance, owner):
        if instance is not None:
            return functools.partial(self.instance_method, instance)
        return functools.partial(self.class_method, owner)
259
260
class RegisterLookupMixin:
    """
    Mixin for registering lookups and transforms by name. They can be
    registered on a class, where subclasses see them through the MRO, or on
    a single instance.
    """

    def _get_lookup(self, lookup_name):
        # Return the object registered under lookup_name (presumably a Lookup
        # or Transform subclass), or None when nothing is registered.
        return self.get_lookups().get(lookup_name, None)

    @functools.cache
    def get_class_lookups(cls):
        """
        Return the lookups registered on cls and its ancestors.

        Each class in the MRO contributes only the ``class_lookups`` dict
        defined directly on it (read from its ``__dict__``). On name clashes,
        classes earlier in the MRO win (see merge_dicts()). Results are
        cached by functools.cache and invalidated by
        _clear_cached_class_lookups().
        """
        class_lookups = [
            parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
        ]
        return cls.merge_dicts(class_lookups)

    def get_instance_lookups(self):
        """
        Return the class-level lookups overlaid with any lookups registered
        on this instance. Instance registrations win on name clashes.
        """
        class_lookups = self.get_class_lookups()
        if instance_lookups := getattr(self, "instance_lookups", None):
            return {**class_lookups, **instance_lookups}
        # NOTE: this is the cached dict from get_class_lookups() itself, so
        # callers must not mutate it.
        return class_lookups

    # The order of the next two statements matters. get_class_lookups is
    # still the plain cached function when it is passed to
    # class_or_instance_method, and only afterwards is it rebound as a
    # classmethod. As a result, get_lookups() returns the class-level lookups
    # when called on the class and calls get_instance_lookups() when called
    # on an instance.
    get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
    get_class_lookups = classmethod(get_class_lookups)

    def get_lookup(self, lookup_name):
        """
        Return the Lookup subclass registered as lookup_name.

        If nothing is registered here and this object has an output_field,
        defer to output_field.get_lookup(). Return None when the registered
        object isn't a Lookup subclass, or when nothing matches.
        """
        from django.db.models.lookups import Lookup

        found = self._get_lookup(lookup_name)
        if found is None and hasattr(self, "output_field"):
            return self.output_field.get_lookup(lookup_name)
        if found is not None and not issubclass(found, Lookup):
            return None
        return found

    def get_transform(self, lookup_name):
        """
        Return the Transform subclass registered as lookup_name.

        If nothing is registered here and this object has an output_field,
        defer to output_field.get_transform(). Return None when the
        registered object isn't a Transform subclass, or when nothing
        matches.
        """
        from django.db.models.lookups import Transform

        found = self._get_lookup(lookup_name)
        if found is None and hasattr(self, "output_field"):
            return self.output_field.get_transform(lookup_name)
        if found is not None and not issubclass(found, Transform):
            return None
        return found

    @staticmethod
    def merge_dicts(dicts):
        """
        Merge dicts so that keys from earlier dicts take precedence. For
        example, merge_dicts([a, b]) prefers the values in 'a' over those in
        'b'.
        """
        # Updating in reverse order lets earlier dicts overwrite later ones.
        merged = {}
        for d in reversed(dicts):
            merged.update(d)
        return merged

    @classmethod
    def _clear_cached_class_lookups(cls):
        # Invalidate the get_class_lookups() cache after registrations change.
        # The cache belongs to the single decorated function shared by every
        # class, so each cache_clear() call clears entries for all classes.
        for subclass in subclasses(cls):
            subclass.get_class_lookups.cache_clear()

    def register_class_lookup(cls, lookup, lookup_name=None):
        """
        Register lookup on cls under lookup_name (defaults to
        lookup.lookup_name) and return lookup, which allows using this as a
        decorator.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        # Give cls its own dict so a parent's inherited class_lookups isn't
        # modified.
        if "class_lookups" not in cls.__dict__:
            cls.class_lookups = {}
        cls.class_lookups[lookup_name] = lookup
        cls._clear_cached_class_lookups()
        return lookup

    def register_instance_lookup(self, lookup, lookup_name=None):
        """
        Register lookup on this instance only, under lookup_name (defaults to
        lookup.lookup_name), and return lookup. The class-level cache is not
        touched.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        if "instance_lookups" not in self.__dict__:
            self.instance_lookups = {}
        self.instance_lookups[lookup_name] = lookup
        return lookup

    # Called on a class, register_lookup registers for the class. Called on
    # an instance, it registers for that instance only. The classmethod
    # rebinding must come after the class_or_instance_method wrapping.
    register_lookup = class_or_instance_method(
        register_class_lookup, register_instance_lookup
    )
    register_class_lookup = classmethod(register_class_lookup)

    def _unregister_class_lookup(cls, lookup, lookup_name=None):
        """
        Remove given lookup from cls lookups. For use in tests only as it's
        not thread-safe.
        """
        # NOTE(review): if cls has no class_lookups of its own, attribute
        # lookup resolves to an ancestor's dict and the entry is deleted
        # there. A missing lookup_name raises KeyError.
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        del cls.class_lookups[lookup_name]
        cls._clear_cached_class_lookups()

    def _unregister_instance_lookup(self, lookup, lookup_name=None):
        """
        Remove given lookup from instance lookups. For use in tests only as
        it's not thread-safe.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        del self.instance_lookups[lookup_name]

    # Same class/instance dispatch as register_lookup; order matters here too.
    _unregister_lookup = class_or_instance_method(
        _unregister_class_lookup, _unregister_instance_lookup
    )
    _unregister_class_lookup = classmethod(_unregister_class_lookup)
362
363
def select_related_descend(field, restricted, requested, select_mask):
    """
    Decide whether select_related() should follow `field` to a deeper level.

    Arguments:
     * `field` - the field to inspect, either a `Field` or a
       `ForeignObjectRel` instance.
     * `restricted` - True when select_related() was given an explicit list
       of fields.
     * `requested` - the select_related() dictionary.
     * `select_mask` - the dictionary of selected fields.

    Raise FieldError when the field is requested but deferred by
    `select_mask`.
    """
    remote = field.remote_field
    # Only relations can be followed. Forward MTI parent links are never
    # descended explicitly because they are always joined (unless
    # `select_mask` excludes them).
    if not remote or getattr(remote, "parent_link", False):
        return False
    # Plain select_related() follows every non-nullable relation.
    if not restricted:
        return not field.null
    # select_related(*requested) only follows the requested fields.
    if field.name not in requested:
        return False
    # Reject combinations such as `select_related("a").only("b")` or
    # `select_related("a").defer("a")`.
    if select_mask and field not in select_mask:
        raise FieldError(
            f"Field {field.model._meta.object_name}.{field.name} cannot be both "
            "deferred and traversed using select_related at the same time."
        )
    return True
400
401
def refs_expression(lookup_parts, annotations):
    """
    Find the shortest prefix of lookup_parts that names an annotation.

    Annotation names may themselves contain LOOKUP_SEP, so each prefix of
    lookup_parts, joined with LOOKUP_SEP, is tried in order of increasing
    length. Return (prefix, remaining_parts) for the first prefix whose
    entry in annotations is truthy, or (None, ()) if none matches.
    """
    prefixes = (
        (LOOKUP_SEP.join(lookup_parts[:size]), size)
        for size in range(1, len(lookup_parts) + 1)
    )
    for prefix, size in prefixes:
        if annotations.get(prefix):
            return prefix, lookup_parts[size:]
    return None, ()
413
414
def check_rel_lookup_compatibility(model, target_opts, field):
    """
    Check that model is compatible with target_opts. Compatibility holds
    when:
      1) model and opts match (ignoring proxy inheritance), or
      2) one of them is a parent of the other.

    When field is a primary key, compatibility with field.model's options is
    accepted as well. For example, given::

        class Restaurant(models.Model):
            place = OneToOneField(Place, primary_key=True)

    Restaurant.objects.filter(pk__in=Restaurant.objects.all()) turns pk__in
    into place__in, whose target options are Place's. Restaurant isn't
    compatible with those, so without the primary key check the lookup would
    be rejected. Only primary keys get this treatment, because ``__in=qs`` is
    later rewritten to ``__in=qs.values('pk')``.
    """
    model_opts = model._meta

    def is_compatible(opts):
        return (
            model_opts.concrete_model == opts.concrete_model
            or opts.concrete_model in model_opts.all_parents
            or model in opts.all_parents
        )

    compatible = is_compatible(target_opts)
    if compatible:
        return compatible
    return getattr(field, "primary_key", False) and is_compatible(field.model._meta)
442
443
class FilteredRelation:
    """Specify custom filtering in the ON clause of SQL joins."""

    def __init__(self, relation_name, *, condition=Q()):
        """
        Args:
            relation_name: name of the relation to join; must be non-empty.
            condition: Q object used as extra filtering for the join.

        Raise ValueError if relation_name is empty or condition isn't a Q.
        """
        if not relation_name:
            raise ValueError("relation_name cannot be empty.")
        self.relation_name = relation_name
        self.alias = None
        if not isinstance(condition, Q):
            raise ValueError("condition argument must be a Q() instance.")
        # condition must stay unchanged so Join.__eq__ remains stable and
        # joins remain reusable once their filtered relation is resolved.
        # The resolved form therefore lives separately in resolved_condition.
        self.condition = condition
        self.resolved_condition = None

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        mine = (self.relation_name, self.alias, self.condition)
        theirs = (other.relation_name, other.alias, other.condition)
        return mine == theirs

    def clone(self):
        """
        Return a copy sharing the same condition. Any resolved condition is
        cloned rather than shared.
        """
        duplicate = FilteredRelation(self.relation_name, condition=self.condition)
        duplicate.alias = self.alias
        resolved = self.resolved_condition
        if resolved is not None:
            duplicate.resolved_condition = resolved.clone()
        return duplicate

    def relabeled_clone(self, change_map):
        """Return a clone whose resolved condition is relabeled via change_map."""
        duplicate = self.clone()
        resolved = duplicate.resolved_condition
        if resolved:
            duplicate.resolved_condition = resolved.relabeled_clone(change_map)
        return duplicate

    def resolve_expression(self, query, reuse, *args, **kwargs):
        """
        Return a clone whose resolved_condition is the filter built from
        self.condition by query.build_filter(). Join types are not updated.
        """
        resolved = self.clone()
        built = query.build_filter(
            self.condition,
            can_reuse=reuse,
            allow_joins=True,
            split_subq=False,
            update_join_types=False,
        )
        resolved.resolved_condition = built[0]
        return resolved

    def as_sql(self, compiler, connection):
        return compiler.compile(self.resolved_condition)