1import itertools
2
3from django.core.exceptions import EmptyResultSet
4from django.db.models import Field
5from django.db.models.expressions import ColPairs, Func, ResolvedOuterRef, Value
6from django.db.models.lookups import (
7 Exact,
8 GreaterThan,
9 GreaterThanOrEqual,
10 In,
11 IsNull,
12 LessThan,
13 LessThanOrEqual,
14)
15from django.db.models.sql import Query
16from django.db.models.sql.where import AND, OR, WhereNode
17
18
class Tuple(Func):
    """Func expression grouping several expressions into one tuple value.

    ``function`` is set to the empty string. Instances support ``len()`` and
    iteration, both of which delegate to ``source_expressions``.
    """

    function = ""
    output_field = Field()
    allows_composite_expressions = True

    def __len__(self):
        """Return how many source expressions the tuple holds."""
        return len(self.source_expressions)

    def __iter__(self):
        """Return an iterator over the tuple's source expressions."""
        return iter(self.source_expressions)
29
30
class TupleLookupMixin:
    """Mixin adding tuple-aware validation and SQL processing to a lookup.

    It relies on the class it is mixed into for ``lhs``, ``rhs``,
    ``lookup_name``, ``rhs_is_direct_value()`` and the ``super()``
    implementations of ``get_prep_lhs()`` and ``process_lhs()``.
    """

    allows_composite_expressions = True

    def get_prep_lookup(self):
        """Validate the right-hand side and return it unchanged.

        A direct value must be a tuple or list with as many elements as the
        left-hand side; any other value must be a ``ResolvedOuterRef``.
        A ``ValueError`` is raised by the individual checks on failure.
        """
        if not self.rhs_is_direct_value():
            self.check_rhs_is_outer_ref()
            return self.rhs
        self.check_rhs_is_tuple_or_list()
        self.check_rhs_length_equals_lhs_length()
        return self.rhs

    def check_rhs_is_tuple_or_list(self):
        """Raise ``ValueError`` unless the right-hand side is a tuple or list."""
        if isinstance(self.rhs, (tuple, list)):
            return
        lhs_str = self.get_lhs_str()
        raise ValueError(
            f"{self.lookup_name!r} lookup of {lhs_str} must be a tuple or a list"
        )

    def check_rhs_length_equals_lhs_length(self):
        """Raise ``ValueError`` unless both sides have the same length."""
        expected = len(self.lhs)
        if expected == len(self.rhs):
            return
        lhs_str = self.get_lhs_str()
        raise ValueError(
            f"{self.lookup_name!r} lookup of {lhs_str} must have {expected} elements"
        )

    def check_rhs_is_outer_ref(self):
        """Raise ``ValueError`` unless the right-hand side is a ResolvedOuterRef."""
        if isinstance(self.rhs, ResolvedOuterRef):
            return
        lhs_str = self.get_lhs_str()
        rhs_cls = self.rhs.__class__.__name__
        raise ValueError(
            f"{self.lookup_name!r} subquery lookup of {lhs_str} "
            f"only supports OuterRef objects (received {rhs_cls!r})"
        )

    def get_lhs_str(self):
        """Return a printable description of the left-hand side.

        For ``ColPairs`` this is the repr of its field name; otherwise it is
        a parenthesised, comma-separated list of each member's ``name`` repr.
        """
        if isinstance(self.lhs, ColPairs):
            return repr(self.lhs.field.name)
        names = ", ".join(repr(member.name) for member in self.lhs)
        return f"({names})"

    def get_prep_lhs(self):
        """Wrap a plain tuple/list left-hand side in a ``Tuple`` expression.

        Any other left-hand side is handled by the parent implementation.
        """
        if not isinstance(self.lhs, (tuple, list)):
            return super().get_prep_lhs()
        return Tuple(*self.lhs)

    def process_lhs(self, compiler, connection, lhs=None):
        """Return ``(sql, params)`` for the left-hand side.

        The SQL produced by the parent class is wrapped in parentheses unless
        the left-hand side is already a ``Tuple``.
        """
        lhs_sql, lhs_params = super().process_lhs(compiler, connection, lhs)
        if isinstance(self.lhs, Tuple):
            return lhs_sql, lhs_params
        return f"({lhs_sql})", lhs_params

    def process_rhs(self, compiler, connection):
        """Return ``(sql, params)`` for the right-hand side.

        Direct values are paired with the left-hand side members, each value
        becoming a ``Value`` typed with the matching member's output field,
        and compiled together as a ``Tuple``. Any other right-hand side is
        compiled as is and must be a ``ColPairs``, otherwise ``ValueError``
        is raised.
        """
        if not self.rhs_is_direct_value():
            # The right-hand side is compiled before its type is checked.
            rhs_sql, rhs_params = compiler.compile(self.rhs)
            if not isinstance(self.rhs, ColPairs):
                raise ValueError(
                    "Composite field lookups only work with composite expressions."
                )
            return "(%s)" % rhs_sql, rhs_params

        typed_values = [
            Value(val, output_field=col.output_field)
            for col, val in zip(self.lhs, self.rhs)
        ]
        return Tuple(*typed_values).as_sql(compiler, connection)
98
99
class TupleExact(TupleLookupMixin, Exact):
    """Equality comparison between two tuples."""

    def as_oracle(self, compiler, connection):
        """Compile the equality as one ``Exact`` per column joined with AND.

        Returns the result of ``WhereNode.as_sql()`` for that conjunction.
        """
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # e.g.: (a, b, c) == (x, y, z) as SQL:
        # WHERE a = x AND b = y AND c = z
        column_matches = list(map(Exact, self.lhs, self.rhs))
        conjunction = WhereNode(column_matches, connector=AND)
        return conjunction.as_sql(compiler, connection)
110
111
class TupleIsNull(TupleLookupMixin, IsNull):
    """``isnull`` lookup applied to every member of a tuple."""

    def get_prep_lookup(self):
        """Return the boolean right-hand side, unwrapping a one-item tuple/list.

        Raises ``ValueError`` when the resulting value is not a ``bool``.
        """
        candidate = self.rhs
        is_singleton = isinstance(candidate, (tuple, list)) and len(candidate) == 1
        if is_singleton:
            candidate = candidate[0]
        if not isinstance(candidate, bool):
            raise ValueError(
                "The QuerySet value for an isnull lookup must be True or False."
            )
        return candidate

    def as_sql(self, compiler, connection):
        """Compile one ``IsNull`` per column, joined with OR or AND.

        A truthy right-hand side joins the per-column checks with OR, a falsy
        one joins them with AND.
        """
        # e.g.: (a, b, c) is None as SQL:
        # WHERE a IS NULL OR b IS NULL OR c IS NULL
        # e.g.: (a, b, c) is not None as SQL:
        # WHERE a IS NOT NULL AND b IS NOT NULL AND c IS NOT NULL
        expect_null = self.rhs
        column_checks = [IsNull(col, expect_null) for col in self.lhs]
        connector = OR if expect_null else AND
        return WhereNode(column_checks, connector=connector).as_sql(
            compiler, connection
        )
132
133
class TupleGreaterThan(TupleLookupMixin, GreaterThan):
    """Strict "greater than" comparison between two tuples."""

    def as_oracle(self, compiler, connection):
        """Compile the comparison as nested single-column conditions.

        Returns the result of ``WhereNode.as_sql()`` for the nested tree.
        """
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # e.g.: (a, b, c) > (x, y, z) as SQL:
        # WHERE a > x OR (a = x AND (b > y OR (b = y AND c > z)))
        #
        # Every column contributes a "greater than" step held in an OR node,
        # followed by an "equals" step held in an AND node. The trailing
        # "equals" step is dropped so the last column is compared strictly.
        steps = [
            (lookup_cls, connector, col, val)
            for col, val in zip(self.lhs, self.rhs)
            for lookup_cls, connector in ((GreaterThan, OR), (Exact, AND))
        ]
        remaining = iter(steps[:-1])
        lookup_cls, connector, col, val = next(remaining)
        root = innermost = WhereNode([lookup_cls(col, val)], connector=connector)

        # Each further step is nested as the last child of the previous node.
        for lookup_cls, connector, col, val in remaining:
            nested = WhereNode([lookup_cls(col, val)], connector=connector)
            innermost.children.append(nested)
            innermost = nested

        return root.as_sql(compiler, connection)
160
161
class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual):
    """"Greater than or equal" comparison between two tuples."""

    def as_oracle(self, compiler, connection):
        """Compile the comparison as nested single-column conditions.

        Returns the result of ``WhereNode.as_sql()`` for the nested tree.
        """
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # e.g.: (a, b, c) >= (x, y, z) as SQL:
        # WHERE a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z))))
        #
        # Every column contributes a "greater than" step held in an OR node,
        # followed by an "equals" step held in an AND node. The final
        # "equals" step is kept, which is what admits equal tuples.
        steps = [
            (lookup_cls, connector, col, val)
            for col, val in zip(self.lhs, self.rhs)
            for lookup_cls, connector in ((GreaterThan, OR), (Exact, AND))
        ]
        remaining = iter(steps)
        lookup_cls, connector, col, val = next(remaining)
        root = innermost = WhereNode([lookup_cls(col, val)], connector=connector)

        # Each further step is nested as the last child of the previous node.
        for lookup_cls, connector, col, val in remaining:
            nested = WhereNode([lookup_cls(col, val)], connector=connector)
            innermost.children.append(nested)
            innermost = nested

        return root.as_sql(compiler, connection)
188
189
class TupleLessThan(TupleLookupMixin, LessThan):
    """Strict "less than" comparison between two tuples."""

    def as_oracle(self, compiler, connection):
        """Compile the comparison as nested single-column conditions.

        Returns the result of ``WhereNode.as_sql()`` for the nested tree.
        """
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # e.g.: (a, b, c) < (x, y, z) as SQL:
        # WHERE a < x OR (a = x AND (b < y OR (b = y AND c < z)))
        #
        # Every column contributes a "less than" step held in an OR node,
        # followed by an "equals" step held in an AND node. The trailing
        # "equals" step is dropped so the last column is compared strictly.
        steps = [
            (lookup_cls, connector, col, val)
            for col, val in zip(self.lhs, self.rhs)
            for lookup_cls, connector in ((LessThan, OR), (Exact, AND))
        ]
        remaining = iter(steps[:-1])
        lookup_cls, connector, col, val = next(remaining)
        root = innermost = WhereNode([lookup_cls(col, val)], connector=connector)

        # Each further step is nested as the last child of the previous node.
        for lookup_cls, connector, col, val in remaining:
            nested = WhereNode([lookup_cls(col, val)], connector=connector)
            innermost.children.append(nested)
            innermost = nested

        return root.as_sql(compiler, connection)
216
217
class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
    """"Less than or equal" comparison between two tuples."""

    def as_oracle(self, compiler, connection):
        """Compile the comparison as nested single-column conditions.

        Returns the result of ``WhereNode.as_sql()`` for the nested tree.
        """
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # e.g.: (a, b, c) <= (x, y, z) as SQL:
        # WHERE a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z))))
        #
        # Every column contributes a "less than" step held in an OR node,
        # followed by an "equals" step held in an AND node. The final
        # "equals" step is kept, which is what admits equal tuples.
        steps = [
            (lookup_cls, connector, col, val)
            for col, val in zip(self.lhs, self.rhs)
            for lookup_cls, connector in ((LessThan, OR), (Exact, AND))
        ]
        remaining = iter(steps)
        lookup_cls, connector, col, val = next(remaining)
        root = innermost = WhereNode([lookup_cls(col, val)], connector=connector)

        # Each further step is nested as the last child of the previous node.
        for lookup_cls, connector, col, val in remaining:
            nested = WhereNode([lookup_cls(col, val)], connector=connector)
            innermost.children.append(nested)
            innermost = nested

        return root.as_sql(compiler, connection)
244
245
class TupleIn(TupleLookupMixin, In):
    """Membership test of a tuple against a collection of tuples or a query."""

    def get_prep_lookup(self):
        """Validate the right-hand side and return it unchanged.

        A direct value must be a tuple or list whose items are tuples or
        lists, each as long as the left-hand side. Any other value must be a
        ``Query`` selecting as many fields as the left-hand side has members.
        The individual checks raise ``ValueError`` on failure.
        """
        if not self.rhs_is_direct_value():
            self.check_rhs_is_query()
            self.check_rhs_select_length_equals_lhs_length()
        else:
            self.check_rhs_is_tuple_or_list()
            self.check_rhs_is_collection_of_tuples_or_lists()
            self.check_rhs_elements_length_equals_lhs_length()

        return self.rhs  # skip checks from mixin

    def check_rhs_is_collection_of_tuples_or_lists(self):
        """Raise ``ValueError`` if any right-hand item is not a tuple or list."""
        if any(not isinstance(vals, (tuple, list)) for vals in self.rhs):
            lhs_str = self.get_lhs_str()
            raise ValueError(
                f"{self.lookup_name!r} lookup of {lhs_str} "
                "must be a collection of tuples or lists"
            )

    def check_rhs_elements_length_equals_lhs_length(self):
        """Raise ``ValueError`` if any right-hand item has the wrong length."""
        expected = len(self.lhs)
        if any(expected != len(vals) for vals in self.rhs):
            lhs_str = self.get_lhs_str()
            raise ValueError(
                f"{self.lookup_name!r} lookup of {lhs_str} "
                f"must have {expected} elements each"
            )

    def check_rhs_is_query(self):
        """Raise ``ValueError`` unless the right-hand side is a ``Query``."""
        if isinstance(self.rhs, Query):
            return
        lhs_str = self.get_lhs_str()
        rhs_cls = self.rhs.__class__.__name__
        raise ValueError(
            f"{self.lookup_name!r} subquery lookup of {lhs_str} "
            f"must be a Query object (received {rhs_cls!r})"
        )

    def check_rhs_select_length_equals_lhs_length(self):
        """Raise ``ValueError`` unless the query selects one field per member.

        When the query selects a single ``ColPairs``, the length of that
        ``ColPairs`` is used as the number of selected fields.
        """
        select = self.rhs.select
        if len(select) == 1 and isinstance(select[0], ColPairs):
            received = len(select[0])
        else:
            received = len(select)
        expected = len(self.lhs)
        if received == expected:
            return
        lhs_str = self.get_lhs_str()
        raise ValueError(
            f"{self.lookup_name!r} subquery lookup of {lhs_str} "
            f"must have {expected} fields (received {received})"
        )

    def process_rhs(self, compiler, connection):
        """Return ``(sql, params)`` for a collection of candidate tuples.

        Raises ``EmptyResultSet`` when the right-hand side is falsy.
        """
        candidates = self.rhs
        if not candidates:
            raise EmptyResultSet

        # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
        # WHERE (a, b, c) IN ((x1, y1, z1), (x2, y2, z2))
        columns = self.lhs
        # Each value is typed with the output field of the column it is
        # paired with.
        rows = [
            Tuple(
                *[
                    Value(val, output_field=col.output_field)
                    for col, val in zip(columns, vals)
                ]
            )
            for vals in candidates
        ]

        return Tuple(*rows).as_sql(compiler, connection)

    def as_sql(self, compiler, connection):
        """Compile the lookup, using ``as_subquery()`` for non-direct values."""
        if self.rhs_is_direct_value():
            return super().as_sql(compiler, connection)
        return self.as_subquery(compiler, connection)

    def as_sqlite(self, compiler, connection):
        """Compile the lookup as an OR of per-candidate AND conditions.

        Raises ``EmptyResultSet`` when the right-hand side is falsy; a
        non-direct right-hand side is handed to ``as_subquery()``.
        """
        candidates = self.rhs
        if not candidates:
            raise EmptyResultSet
        if not self.rhs_is_direct_value():
            return self.as_subquery(compiler, connection)

        # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
        # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
        columns = self.lhs
        alternatives = [
            WhereNode(
                [Exact(col, val) for col, val in zip(columns, vals)],
                connector=AND,
            )
            for vals in candidates
        ]

        return WhereNode(alternatives, connector=OR).as_sql(compiler, connection)

    def as_subquery(self, compiler, connection):
        """Compile the lookup as a regular ``In`` against the right-hand side.

        For a ``ColPairs`` left-hand side, the right-hand side is cloned and
        ``set_values()`` is called on the clone with the names of the
        left-hand side's sources; the left-hand side is then wrapped in a
        ``Tuple``. The original right-hand side object is left untouched.
        """
        columns, subquery = self.lhs, self.rhs
        if isinstance(columns, ColPairs):
            subquery = subquery.clone()
            source_names = [source.name for source in columns.sources]
            subquery.set_values(source_names)
            columns = Tuple(columns)
        return compiler.compile(In(columns, subquery))
349
350
# Maps each supported lookup name to the tuple-aware lookup class that
# implements it.
tuple_lookups = {
    "exact": TupleExact,
    "gt": TupleGreaterThan,
    "gte": TupleGreaterThanOrEqual,
    "lt": TupleLessThan,
    "lte": TupleLessThanOrEqual,
    "in": TupleIn,
    "isnull": TupleIsNull,
}