1"""
2Code to manage the creation and SQL rendering of 'where' constraints.
3"""
4
5import operator
6from functools import reduce
7
8from django.core.exceptions import EmptyResultSet, FullResultSet
9from django.db.models.expressions import Case, When
10from django.db.models.functions import Mod
11from django.db.models.lookups import Exact
12from django.utils import tree
13from django.utils.functional import cached_property
14
# Connection types: the SQL keywords placed between a WhereNode's children.
AND, OR, XOR = "AND", "OR", "XOR"
19
20
class WhereNode(tree.Node):
    """
    An SQL WHERE clause.

    The class is tied to the Query class that created it (in order to create
    the correct SQL).

    A child is usually an expression producing boolean values. Most likely the
    expression is a Lookup instance.

    However, a child could also be any class with as_sql() and either
    relabeled_clone() method or relabel_aliases() and clone() methods and
    contains_aggregate and contains_over_clause attributes.
    """

    default = AND
    # True on the clone returned by resolve_expression(); as_sql() then wraps
    # even a single condition in parentheses.
    resolved = False
    conditional = True

    def split_having_qualify(self, negated=False, must_group_by=False):
        """
        Return three possibly None nodes: one for those parts of self that
        should be included in the WHERE clause, one for those parts of self
        that must be included in the HAVING clause, and one for those parts
        that refer to window functions.

        ``negated`` is the negation inherited from enclosing nodes, and
        ``must_group_by`` is forwarded unchanged to child nodes.

        Raise NotImplementedError when a disjunction mixes window-function
        predicates with plain WHERE predicates while ``must_group_by`` is set
        (conditional aggregation).
        """
        if not self.contains_aggregate and not self.contains_over_clause:
            return self, None, None
        in_negated = negated ^ self.negated
        # Whether or not children must be connected in the same filtering
        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantics.
        # A negated AND behaves like an OR (De Morgan), hence the symmetry.
        must_remain_connected = (
            (in_negated and self.connector == AND)
            or (not in_negated and self.connector == OR)
            or self.connector == XOR
        )
        if (
            must_remain_connected
            and self.contains_aggregate
            and not self.contains_over_clause
        ):
            # It's much cheaper to short-circuit and stash everything in the
            # HAVING clause than split children if possible.
            return None, self, None
        where_parts = []
        having_parts = []
        qualify_parts = []
        for c in self.children:
            if hasattr(c, "split_having_qualify"):
                where_part, having_part, qualify_part = c.split_having_qualify(
                    in_negated, must_group_by
                )
                if where_part is not None:
                    where_parts.append(where_part)
                if having_part is not None:
                    having_parts.append(having_part)
                if qualify_part is not None:
                    qualify_parts.append(qualify_part)
            elif c.contains_over_clause:
                qualify_parts.append(c)
            elif c.contains_aggregate:
                having_parts.append(c)
            else:
                where_parts.append(c)
        if must_remain_connected and qualify_parts:
            # Disjunctive heterogeneous predicates can be pushed down to
            # qualify as long as no conditional aggregation is involved.
            if not where_parts or not must_group_by:
                return None, None, self
            # In theory this should only be enforced when dealing with
            # where_parts containing predicates against multi-valued
            # relationships that could affect aggregation results but this
            # is complex to infer properly.
            raise NotImplementedError(
                "Heterogeneous disjunctive predicates against window functions are "
                "not implemented when performing conditional aggregation."
            )
        where_node = (
            self.create(where_parts, self.connector, self.negated)
            if where_parts
            else None
        )
        having_node = (
            self.create(having_parts, self.connector, self.negated)
            if having_parts
            else None
        )
        qualify_node = (
            self.create(qualify_parts, self.connector, self.negated)
            if qualify_parts
            else None
        )
        return where_node, having_node, qualify_node

    def as_sql(self, compiler, connection):
        """
        Return the SQL version of the where clause and the parameters to be
        substituted in, as a ``(sql, params)`` tuple.

        Raise FullResultSet if this node matches everything and
        EmptyResultSet if it can't match anything (with the node's negation
        taken into account). A node without children raises FullResultSet.
        """
        if self.connector == XOR and not connection.features.supports_logical_xor:
            # Convert if the database doesn't support XOR:
            #   a XOR b XOR c XOR ...
            # to:
            #   (a OR b OR c OR ...) AND MOD(a + b + c + ..., 2) == 1
            # The result of an n-ary XOR is true when an odd number of operands
            # are true. With only two operands the sum is at most 2, so
            # comparing it to 1 directly suffices and MOD() is skipped.
            lhs = self.__class__(self.children, OR)
            rhs_sum = reduce(
                operator.add,
                (Case(When(c, then=1), default=0) for c in self.children),
            )
            if len(self.children) > 2:
                rhs_sum = Mod(rhs_sum, 2)
            rhs = Exact(1, rhs_sum)
            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
                compiler, connection
            )

        is_xor = self.connector == XOR
        # Children that match nothing (EmptyResultSet) or everything
        # (FullResultSet, or empty SQL) are left out of the rendered SQL.
        # empty_needed/full_needed count how many such children settle the
        # whole node: a single empty child empties an AND, and only all of
        # them fill it; OR is the mirror image. A (native) XOR is empty only
        # when every child is; full children are resolved after the loop.
        if self.connector == AND:
            full_needed, empty_needed = len(self.children), 1
        else:
            full_needed, empty_needed = 1, len(self.children)
        full_children = 0
        result = []
        result_params = []
        for child in self.children:
            try:
                sql, params = compiler.compile(child)
            except EmptyResultSet:
                empty_needed -= 1
            except FullResultSet:
                full_needed -= 1
                full_children += 1
            else:
                if sql:
                    result.append(sql)
                    result_params.extend(params)
                else:
                    full_needed -= 1
                    full_children += 1
            # Stop compiling as soon as the outcome of the whole node is known.
            if empty_needed == 0:
                if self.negated:
                    raise FullResultSet
                else:
                    raise EmptyResultSet
            if full_needed == 0 and not is_xor:
                if self.negated:
                    raise EmptyResultSet
                else:
                    raise FullResultSet

        negated = self.negated
        if is_xor and full_children % 2:
            # TRUE XOR x == NOT x: an odd number of children that match
            # everything negates the XOR of the remaining operands, whereas
            # children that match nothing (FALSE) leave it unchanged.
            negated = not negated
        if is_xor and self.children and not result:
            # Only constant operands were present. With their parity folded
            # into `negated`, what remains is an XOR of nothing, i.e. FALSE.
            if negated:
                raise FullResultSet
            raise EmptyResultSet

        conn = " %s " % self.connector
        sql_string = conn.join(result)
        if not sql_string:
            raise FullResultSet
        if negated:
            # Some backends (Oracle at least) need parentheses around the inner
            # SQL in the negated case, even if the inner SQL contains just a
            # single expression.
            sql_string = "NOT (%s)" % sql_string
        elif len(result) > 1 or self.resolved:
            sql_string = "(%s)" % sql_string
        return sql_string, result_params

    def get_group_by_cols(self):
        """Return the concatenated get_group_by_cols() of all children."""
        cols = []
        for child in self.children:
            cols.extend(child.get_group_by_cols())
        return cols

    def get_source_expressions(self):
        """Return a shallow copy of the children list."""
        return self.children[:]

    def set_source_expressions(self, children):
        """Replace the children; the new list must have the same length."""
        assert len(children) == len(self.children)
        self.children = children

    def relabel_aliases(self, change_map):
        """
        Relabel the alias values of any children in place. 'change_map' is a
        dictionary mapping old (current) alias values to the new values.
        Return self.
        """
        if not change_map:
            return self
        for pos, child in enumerate(self.children):
            if hasattr(child, "relabel_aliases"):
                # For example another WhereNode, relabeled in place.
                child.relabel_aliases(change_map)
            elif hasattr(child, "relabeled_clone"):
                self.children[pos] = child.relabeled_clone(change_map)
        return self

    def clone(self):
        """
        Return a copy of this node. Children that provide clone() are cloned;
        the others are shared with the original.
        """
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            if hasattr(child, "clone"):
                child = child.clone()
            clone.children.append(child)
        return clone

    def relabeled_clone(self, change_map):
        """Return a clone of this node with aliases relabeled per change_map."""
        clone = self.clone()
        clone.relabel_aliases(change_map)
        return clone

    def replace_expressions(self, replacements):
        """
        Apply the ``replacements`` mapping to this node. Return self if the
        mapping is empty, the node's own replacement if it maps to a truthy
        value, or else a new node whose children were processed recursively.
        """
        if not replacements:
            return self
        if replacement := replacements.get(self):
            return replacement
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            clone.children.append(child.replace_expressions(replacements))
        return clone

    def get_refs(self):
        """Return the union of the refs reported by all children."""
        refs = set()
        for child in self.children:
            refs |= child.get_refs()
        return refs

    @classmethod
    def _contains_aggregate(cls, obj):
        # Recurse through nested tree nodes; leaves report their own flag.
        if isinstance(obj, tree.Node):
            return any(cls._contains_aggregate(c) for c in obj.children)
        return obj.contains_aggregate

    @cached_property
    def contains_aggregate(self):
        # Computed once per instance on first access.
        return self._contains_aggregate(self)

    @classmethod
    def _contains_over_clause(cls, obj):
        # Recurse through nested tree nodes; leaves report their own flag.
        if isinstance(obj, tree.Node):
            return any(cls._contains_over_clause(c) for c in obj.children)
        return obj.contains_over_clause

    @cached_property
    def contains_over_clause(self):
        # Computed once per instance on first access.
        return self._contains_over_clause(self)

    @property
    def is_summary(self):
        return any(child.is_summary for child in self.children)

    @staticmethod
    def _resolve_leaf(expr, query, *args, **kwargs):
        # Values without resolve_expression() are returned unchanged.
        if hasattr(expr, "resolve_expression"):
            expr = expr.resolve_expression(query, *args, **kwargs)
        return expr

    @classmethod
    def _resolve_node(cls, node, query, *args, **kwargs):
        # Resolve, in place, the lhs/rhs of every node in the subtree.
        if hasattr(node, "children"):
            for child in node.children:
                cls._resolve_node(child, query, *args, **kwargs)
        if hasattr(node, "lhs"):
            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
        if hasattr(node, "rhs"):
            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)

    def resolve_expression(self, *args, **kwargs):
        """Return a resolved clone of this node, flagged as ``resolved``."""
        clone = self.clone()
        clone._resolve_node(clone, *args, **kwargs)
        clone.resolved = True
        return clone

    @cached_property
    def output_field(self):
        # Imported locally, presumably to avoid a circular import.
        from django.db.models import BooleanField

        return BooleanField()

    @property
    def _output_field_or_none(self):
        return self.output_field

    def select_format(self, compiler, sql, params):
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def get_db_converters(self, connection):
        return self.output_field.get_db_converters(connection)

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def leaves(self):
        """Yield the non-WhereNode descendants of this node, depth first."""
        for child in self.children:
            if isinstance(child, WhereNode):
                yield from child.leaves()
            else:
                yield child
322
323
class NothingNode:
    """A where-clause child that can never match any row."""

    contains_aggregate = contains_over_clause = False

    def as_sql(self, compiler=None, connection=None):
        # Rendering is impossible by design: tell the caller the condition
        # excludes every row.
        raise EmptyResultSet
332
333
class ExtraWhere:
    """
    Raw SQL fragments plus their parameters, used as a where-clause child.
    The fragments are rendered parenthesized and joined with AND.
    """

    # The contents are a black box - assume no aggregates or windows are used.
    contains_aggregate = contains_over_clause = False

    def __init__(self, sqls, params):
        self.sqls, self.params = sqls, params

    def as_sql(self, compiler=None, connection=None):
        clause = " AND ".join("(%s)" % fragment for fragment in self.sqls)
        # params may be None or any iterable; always return a fresh list.
        bound = list(self.params) if self.params else []
        return clause, bound
346
347
class SubqueryConstraint:
    """
    A where-clause child comparing ``columns`` of ``alias`` against a
    subquery; rendering is delegated to the subquery compiler's
    as_subquery_condition().
    """

    # Even if aggregates or windows would be used in a subquery,
    # the outer query isn't interested about those.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, alias, columns, targets, query_object):
        self.alias, self.columns, self.targets = alias, columns, targets
        # Strip ordering from the subquery (including defaults) before keeping
        # it; presumably ordering has no effect on a subquery condition.
        query_object.clear_ordering(clear_default=True)
        self.query_object = query_object

    def as_sql(self, compiler, connection):
        subquery = self.query_object
        subquery.set_values(self.targets)
        subquery_compiler = subquery.get_compiler(connection=connection)
        return subquery_compiler.as_subquery_condition(
            self.alias, self.columns, compiler
        )