1import copy
2import datetime
3import functools
4import inspect
5from collections import defaultdict
6from decimal import Decimal
7from enum import Enum
8from itertools import chain
9from types import NoneType
10from uuid import UUID
11
12from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
13from django.db import DatabaseError, NotSupportedError, connection
14from django.db.models import fields
15from django.db.models.constants import LOOKUP_SEP
16from django.db.models.query_utils import Q
17from django.utils.deconstruct import deconstructible
18from django.utils.functional import cached_property, classproperty
19from django.utils.hashable import make_hashable
20
21
class SQLiteNumericMixin:
    """
    Mixin for expressions whose DecimalField output must be cast to NUMERIC
    on SQLite so that filtering on them behaves correctly.
    """

    def as_sqlite(self, compiler, connection, **extra_context):
        """
        Compile via as_sql() and, when the output field is a DecimalField,
        wrap the SQL in ``CAST(... AS NUMERIC)``. An unresolvable output field
        (FieldError) leaves the SQL untouched.
        """
        sql, params = self.as_sql(compiler, connection, **extra_context)
        try:
            needs_cast = self.output_field.get_internal_type() == "DecimalField"
        except FieldError:
            needs_cast = False
        if needs_cast:
            sql = "(CAST(%s AS NUMERIC))" % sql
        return sql, params
36
37
class Combinable:
    """
    Give expressions Python operator support.

    Arithmetic operators build a CombinedExpression that joins ``self`` and
    the other operand with one of the connector strings below, for example
    ``F('foo') + F('bar')``. The ``&``, ``|`` and ``^`` operators are reserved
    for combining conditional (boolean) expressions into Q objects. Bitwise
    arithmetic goes through the explicit ``bit*()`` methods instead.
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # The % operator is doubled ("quoted") because the connector can end up in
    # strings that also go through parameter substitution.
    MOD = "%%"

    # Bitwise connectors. These are produced only by .bitand(), .bitor() and
    # friends; the '&' and '|' operators are kept for boolean composition.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(self, other, connector, reversed):
        """
        Return a CombinedExpression joining ``self`` and ``other`` with
        ``connector``. When ``reversed`` is true, ``other`` becomes the
        left-hand side. Operands without resolve_expression() are wrapped in
        Value() so that both sides are expressions.
        """
        operand = other if hasattr(other, "resolve_expression") else Value(other)
        if reversed:
            lhs, rhs = operand, self
        else:
            lhs, rhs = self, operand
        return CombinedExpression(lhs, connector, rhs)

    #############
    # OPERATORS #
    #############

    def __neg__(self):
        # Negation is expressed as multiplication by -1.
        return self._combine(-1, self.MUL, False)

    # Forward arithmetic: ``self <op> other``.

    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    # Reflected arithmetic: ``other <op> self``.

    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    # Explicit bitwise arithmetic.

    def bitand(self, other):
        return self._combine(other, self.BITAND, False)

    def bitor(self, other):
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other):
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other):
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other):
        return self._combine(other, self.BITRIGHTSHIFT, False)

    # Boolean composition: allowed only when both operands are conditional,
    # in which case the result is a Q object.

    def __and__(self, other):
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) & Q(other)

    def __or__(self, other):
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) | Q(other)

    def __xor__(self, other):
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) ^ Q(other)

    # Reflected boolean operators are never supported.

    def __rand__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __ror__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __rxor__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __invert__(self):
        return NegatedExpression(self)
168
169
class BaseExpression:
    """
    Base class for all query expressions.

    An expression exposes its child expressions through
    get_source_expressions()/set_source_expressions() and renders itself in
    as_sql(). The class attributes below are capability flags consulted by
    the query machinery. Many helpers here (relabeled_clone,
    replace_expressions, resolve_expression, ...) work generically by copying
    the expression and rebuilding its source expressions.
    """

    # Value used in place of this expression when compiling it raises
    # EmptyResultSet (Func.as_sql() reads it); NotImplemented means "no
    # substitute, re-raise".
    empty_result_set_value = NotImplemented
    # aggregate specific fields
    is_summary = False
    # Set by the output_field property when _resolve_output_field() returns
    # None, so _output_field_or_none can tell "no inferable type" apart from
    # other FieldErrors.
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False
    # Can the expression be used as a database default value?
    allowed_default = False
    # Can the expression be used during a constraint validation?
    constraint_validation_compatible = True
    # Does the expression possibly return more than one row?
    set_returning = False
    # Does the expression allow composite expressions?
    allows_composite_expressions = False

    def __init__(self, output_field=None):
        # An explicit output_field is stored on the instance, where it is used
        # instead of the inferred output_field cached_property below.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self):
        # Drop the cached convert_value before pickling: it may be one of the
        # lambdas built in convert_value, and lambdas cannot be pickled.
        state = self.__dict__.copy()
        state.pop("convert_value", None)
        return state

    def get_db_converters(self, connection):
        """
        Return the converters applied to values read from the database. This
        expression's own convert_value comes first and is omitted when it is
        the no-op. The output field's converters follow.
        """
        return (
            []
            if self.convert_value is self._convert_value_noop
            else [self.convert_value]
        ) + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        """Return the child expressions; the base class has none."""
        return []

    def set_source_expressions(self, exprs):
        """Replace the child expressions; the base class accepts only an empty list."""
        # NOTE(review): enforced with assert, which is stripped under -O.
        assert not exprs

    def _parse_expressions(self, *expressions):
        """
        Coerce raw constructor arguments into expressions. Objects with
        resolve_expression() are kept as-is, strings become F() field
        references, and any other value is wrapped in Value().
        """
        return [
            (
                arg
                if hasattr(arg, "resolve_expression")
                else (F(arg) if isinstance(arg, str) else Value(arg))
            )
            for arg in expressions
        ]

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql placeholders to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        # True if any (non-None) source expression contains an aggregate.
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self):
        # True if any (non-None) source expression contains an OVER clause.
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self):
        # True if any (non-None) source expression references a column.
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_subquery(self):
        # True if any source is itself a subquery (truthy `subquery`
        # attribute) or contains one.
        return any(
            expr and (getattr(expr, "subquery", False) or expr.contains_subquery)
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression about to be used in a save or update

        Return: an Expression to be added to the query.

        This default implementation returns a copy whose non-None source
        expressions are resolved in turn. Note that ``for_save`` is not
        forwarded to them.
        """
        c = self.copy()
        c.is_summary = summarize
        c.set_source_expressions(
            [
                (
                    expr.resolve_expression(query, allow_joins, reuse, summarize)
                    if expr
                    else None
                )
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self):
        # Boolean-typed expressions can be combined with &, | and ^ (see
        # Combinable) and used as conditions.
        return isinstance(self.output_field, fields.BooleanField)

    @property
    def field(self):
        return self.output_field

    @cached_property
    def output_field(self):
        """Return the output type of this expressions."""
        output_field = self._resolve_output_field()
        if output_field is None:
            # Record why resolution failed before raising; read by
            # _output_field_or_none.
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self):
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type. Any other
        FieldError (e.g. mixed source types) is re-raised.
        """
        try:
            return self.output_field
        except FieldError:
            if not self._output_field_resolved_to_none:
                raise

    def _resolve_output_field(self):
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        #
        # Both loops share one generator: the outer loop takes the first
        # non-None source field, the inner loop consumes all the remaining
        # ones and checks each against it, and then the outer body returns.
        # With no non-None sources the loop never runs and None is returned.
        sources_iter = (
            source for source in self.get_source_fields() if source is not None
        )
        for output_field in sources_iter:
            for source in sources_iter:
                if not isinstance(output_field, source.__class__):
                    raise FieldError(
                        "Expression contains mixed types: %s, %s. You must "
                        "set output_field."
                        % (
                            output_field.__class__.__name__,
                            source.__class__.__name__,
                        )
                    )
            return output_field

    @staticmethod
    def _convert_value_noop(value, expression, connection):
        # Identity converter; get_db_converters() skips it.
        return value

    @cached_property
    def convert_value(self):
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.

        Float, *IntegerField and Decimal output fields get a converter that
        coerces non-None values to float, int or Decimal respectively. Every
        other type gets _convert_value_noop.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == "FloatField":
            return lambda value, expression, connection: (
                None if value is None else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return lambda value, expression, connection: (
                None if value is None else int(value)
            )
        elif internal_type == "DecimalField":
            return lambda value, expression, connection: (
                None if value is None else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup):
        # Lookups and transforms are delegated to the output field.
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        """Return a copy with every non-None source relabeled via change_map."""
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(self, replacements):
        """
        Return this expression with subexpressions substituted according to
        the ``replacements`` mapping (expression -> replacement).

        A direct match replaces the whole expression. Otherwise a copy is
        returned whose sources have been processed recursively. ``self`` is
        returned unchanged when there is nothing to replace or no sources.
        """
        if not replacements:
            return self
        # NOTE: a falsy replacement value is treated as "no replacement".
        if replacement := replacements.get(self):
            return replacement
        if not (source_expressions := self.get_source_expressions()):
            return self
        clone = self.copy()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self):
        """Return the union of get_refs() over all non-None sources."""
        refs = set()
        for expr in self.get_source_expressions():
            if expr is None:
                continue
            refs |= expr.get_refs()
        return refs

    def copy(self):
        # Shallow copy; callers replace source expressions on the copy.
        return copy.copy(self)

    def prefix_references(self, prefix):
        """
        Return a copy in which every F() source has ``prefix`` prepended to
        its name. Other sources are processed recursively.
        """
        # NOTE(review): unlike relabeled_clone(), None sources are not skipped
        # here; a None source would raise AttributeError.
        clone = self.copy()
        clone.set_source_expressions(
            [
                (
                    F(f"{prefix}{expr.name}")
                    if isinstance(expr, F)
                    else expr.prefix_references(prefix)
                )
                for expr in self.get_source_expressions()
            ]
        )
        return clone

    def get_group_by_cols(self):
        """
        Return the expressions to GROUP BY. A non-aggregate expression groups
        by itself. An aggregate contributes its sources' GROUP BY columns.
        """
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self):
        """Return the underlying field types used by this aggregate."""
        # Each entry is a source's output field, or None when it has no type.
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self):
        # Plain expressions carry no ordering direction, so there is nothing
        # to reverse.
        return self

    def flatten(self):
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    yield expr

    def select_format(self, compiler, sql, params):
        """
        Custom format for select clauses. For example, EXISTS expressions need
        to be wrapped in CASE WHEN on Oracle.

        Delegates to the output field's select_format() when it has one.
        """
        if hasattr(self.output_field, "select_format"):
            return self.output_field.select_format(compiler, sql, params)
        return sql, params

    def get_expression_for_validation(self):
        """
        Return the expression to use when validating constraints. Expressions
        flagged constraint_validation_compatible=False are replaced by their
        single source expression. A ValueError is raised if they do not have
        exactly one.
        """
        # Ignore expressions that cannot be used during a constraint validation.
        if not getattr(self, "constraint_validation_compatible", True):
            try:
                (expression,) = self.get_source_expressions()
            except ValueError as e:
                raise ValueError(
                    "Expressions with constraint_validation_compatible set to False "
                    "must have only one source expression."
                ) from e
            else:
                return expression
        return self
506
507
@deconstructible
class Expression(BaseExpression, Combinable):
    """An expression that can be combined with other expressions."""

    @classproperty
    @functools.lru_cache(maxsize=128)
    def _constructor_signature(cls):
        # Memoized per class: the __init__ signature used to normalize the
        # constructor arguments when building ``identity``.
        return inspect.signature(cls.__init__)

    @cached_property
    def identity(self):
        """
        Hashable description of this expression: its class followed by
        (parameter name, value) pairs for every constructor argument, with
        defaults applied.

        Field arguments are reduced to (model label, field name) when the
        field is bound to a model, or to the field's class otherwise. All other
        values are passed through make_hashable().
        """
        # _constructor_args is presumably recorded by @deconstructible.
        args, kwargs = self._constructor_args
        bound = self._constructor_signature.bind_partial(self, *args, **kwargs)
        bound.apply_defaults()
        # The first bound argument is ``self``; skip it.
        _, *named_arguments = bound.arguments.items()
        parts = [self.__class__]
        for name, value in named_arguments:
            if not isinstance(value, fields.Field):
                value = make_hashable(value)
            elif value.name and value.model:
                value = (value.model._meta.label, value.name)
            else:
                value = type(value)
            parts.append((name, value))
        return tuple(parts)

    def __eq__(self, other):
        # Two expressions are equal when their identity tuples match. Unknown
        # types defer to the other operand.
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)
543
544
# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# The current approach for NULL is based on lowest common denominator behavior
# i.e. if one of the supported databases is raising an error (rather than
# return NULL) for `val <op> NULL`, then Django raises FieldError.
#
# Each dict below maps a connector to a list of
# (lhs_field_type, rhs_field_type, result_field_type) triples. The triples are
# registered in list order, and _resolve_combined_type() returns the result of
# the FIRST triple for which issubclass(operand_type, registered_type) holds on
# both sides. More specific entries must therefore come before more general
# ones. A NULL operand has no output field; its type is NoneType, which only
# matches the NoneType entries.

_connector_combinations = [
    # Numeric operations - operands of same type.
    # PositiveIntegerField should take precedence over IntegerField (except
    # subtraction).
    # (This relies on PositiveIntegerField being registered before the
    # IntegerField entries below; presumably it subclasses IntegerField.)
    {
        connector: [
            (
                fields.PositiveIntegerField,
                fields.PositiveIntegerField,
                fields.PositiveIntegerField,
            ),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Other numeric operands.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    # (POW is deliberately absent here.)
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    # chain.from_iterable flattens the two (NULL on either side) triples of
    # every numeric field type into one list, so each connector gets all of
    # them; the result keeps the non-NULL operand's type.
    {
        connector: list(
            chain.from_iterable(
                [(field_type, NoneType, field_type), (NoneType, field_type, field_type)]
                for field_type in (
                    fields.IntegerField,
                    fields.DecimalField,
                    fields.FloatField,
                )
            )
        )
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]
672
# Registry: connector -> list of (lhs_type, rhs_type, result_type) triples.
# Filled by register_combinable_fields() and scanned in order by
# _resolve_combined_type().
_connector_combinators = defaultdict(list)
674
675
def register_combinable_fields(lhs, connector, rhs, result):
    """
    Register that ``lhs <connector> rhs`` produces a ``result`` field, e.g.:

        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )

    ``lhs``, ``rhs`` and ``result`` are field classes. Registrations are
    appended, so earlier ones take precedence when the output type of a
    combined expression is inferred.
    """
    combination = (lhs, rhs, result)
    _connector_combinators[connector].append(combination)
686
687
# Populate _connector_combinators from the static table above. Table order is
# preserved, and it is the match priority used by _resolve_combined_type().
# NOTE(review): as a module-level loop, its variables (d, connector,
# field_types, lhs, rhs, result) remain bound as module globals afterwards.
for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)
692
693
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
    """
    Return the field class produced by ``lhs_type <connector> rhs_type``, or
    None when no registered combination matches.

    Registered combinations are tried in registration order. The first one
    whose registered operand classes are superclasses of (or equal to)
    ``lhs_type`` and ``rhs_type`` wins. Results are memoized by lru_cache, so
    a cached lookup does not see combinations registered after it was first
    computed.
    """
    candidates = _connector_combinators.get(connector, ())
    return next(
        (
            combined_type
            for registered_lhs, registered_rhs, combined_type in candidates
            if issubclass(lhs_type, registered_lhs)
            and issubclass(rhs_type, registered_rhs)
        ),
        None,
    )
702
703
class CombinedExpression(SQLiteNumericMixin, Expression):
    """
    A binary expression ``lhs <connector> rhs``, where ``connector`` is one of
    the Combinable connector strings.

    Without an explicit output_field, the result type is inferred from the
    operand types via the registered connector combinations. During
    resolution, mixed duration arithmetic is delegated to DurationExpression
    and same-type temporal subtraction to TemporalSubtraction.
    """

    def __init__(self, lhs, connector, rhs, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self):
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self):
        # Deliberately not delegating to super(): see the caveat in
        # BaseExpression._resolve_output_field(). The result type comes from
        # the connector combination table instead.
        lhs_field = self.lhs._output_field_or_none
        rhs_field = self.rhs._output_field_or_none
        combined_type = _resolve_combined_type(
            self.connector, type(lhs_field), type(rhs_field)
        )
        if combined_type is not None:
            return combined_type()
        raise FieldError(
            f"Cannot infer type of {self.connector!r} expression involving these "
            f"types: {self.lhs.output_field.__class__.__name__}, "
            f"{self.rhs.output_field.__class__.__name__}. You must set "
            f"output_field."
        )

    def as_sql(self, compiler, connection):
        sql_parts = []
        params = []
        for operand in (self.lhs, self.rhs):
            operand_sql, operand_params = compiler.compile(operand)
            sql_parts.append(operand_sql)
            params.extend(operand_params)
        combined = connection.ops.combine_expression(self.connector, sql_parts)
        # Parenthesize to keep operator precedence when nested.
        return "(%s)" % combined, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolve_args = (query, allow_joins, reuse, summarize, for_save)
        lhs = self.lhs.resolve_expression(*resolve_args)
        rhs = self.rhs.resolve_expression(*resolve_args)
        if any(isinstance(side, ColPairs) for side in (lhs, rhs)):
            raise ValueError("CompositePrimaryKey is not combinable.")
        if not isinstance(self, (DurationExpression, TemporalSubtraction)):

            def internal_type(expression):
                # None when the operand has no (resolvable) output field.
                try:
                    return expression.output_field.get_internal_type()
                except (AttributeError, FieldError):
                    return None

            lhs_type = internal_type(lhs)
            rhs_type = internal_type(rhs)
            # A duration combined with a non-duration operand needs
            # backend-specific duration arithmetic. The unresolved operands
            # are handed over so the replacement resolves them itself.
            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                return DurationExpression(
                    self.lhs, self.connector, self.rhs
                ).resolve_expression(*resolve_args)
            # Subtracting two values of the same temporal type yields a
            # duration.
            if (
                self.connector == self.SUB
                and lhs_type in {"DateField", "DateTimeField", "TimeField"}
                and lhs_type == rhs_type
            ):
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    *resolve_args
                )
        clone = self.copy()
        clone.is_summary = summarize
        clone.lhs = lhs
        clone.rhs = rhs
        return clone

    @cached_property
    def allowed_default(self):
        # Usable as a database default only if both operands are.
        return self.lhs.allowed_default and self.rhs.allowed_default
806
807
class DurationExpression(CombinedExpression):
    """
    A CombinedExpression that mixes a DurationField operand with an operand of
    another type.

    On backends with a native duration type it compiles like any combined
    expression. Otherwise duration operands are formatted for arithmetic and
    joined with the backend's duration-specific SQL.
    """

    def compile(self, side, compiler, connection):
        """Compile one operand, formatting it for arithmetic if it is a duration."""
        try:
            output = side.output_field
        except FieldError:
            return compiler.compile(side)
        if output.get_internal_type() != "DurationField":
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(self, compiler, connection):
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        sql_parts, params = [], []
        for side in (self.lhs, self.rhs):
            side_sql, side_params = self.compile(side, compiler, connection)
            sql_parts.append(side_sql)
            params.extend(side_params)
        combined = connection.ops.combine_duration_expression(self.connector, sql_parts)
        # Parenthesize to keep operator precedence when nested.
        return "(%s)" % combined, params

    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            return sql, params
        # On SQLite, multiplying or dividing a duration is only accepted
        # against numeric or duration operands.
        numeric_or_duration = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        if not (lhs_type in numeric_or_duration and rhs_type in numeric_or_duration):
            raise DatabaseError(
                f"Invalid arguments for operator {self.connector}."
            )
        return sql, params
857
858
class TemporalSubtraction(CombinedExpression):
    """
    Subtraction of two temporal values. The result is always a duration, and
    the SQL is chosen by the backend.
    """

    output_field = fields.DurationField()

    def __init__(self, lhs, rhs):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        # The left operand's internal type selects the backend's SQL.
        lhs_internal_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            lhs_internal_type, compiled_lhs, compiled_rhs
        )
872
873
@deconstructible(path="django.db.models.F")
class F(Combinable):
    """An object capable of resolving references to existing query objects."""

    allowed_default = False

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name})"

    def __getitem__(self, subscript):
        # F("field")[...] produces a Sliced expression.
        return Sliced(self, subscript)

    def __contains__(self, other):
        # Defining __getitem__() would otherwise enable the legacy sequence
        # iteration protocol, and ``x in F(...)`` would never terminate.
        raise TypeError(f"argument of type '{self.__class__.__name__}' is not iterable")

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # The query resolves the name; for_save is not used here.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements):
        return replacements.get(self, self)

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other):
        # Same concrete class and same referenced name.
        if self.__class__ != other.__class__:
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def copy(self):
        return copy.copy(self)
920
921
class ResolvedOuterRef(F):
    """
    A reference to a column of an outer query.

    It is created when an OuterRef is resolved because the inner query is
    being used as a subquery. It can only be compiled as part of that
    subquery.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args, **kwargs):
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args, **kwargs):
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels):
        # Relabeling leaves outer references untouched.
        return self

    def get_group_by_cols(self):
        # Outer references contribute no GROUP BY columns.
        return []
957
958
class OuterRef(F):
    """
    A reference to a field of the outer query, for use inside a subquery.

    Resolving it yields a ResolvedOuterRef. When it wraps another OuterRef
    (``OuterRef(OuterRef(...))``), one level of nesting is peeled off per
    resolution instead.
    """

    contains_aggregate = False
    contains_over_clause = False

    def resolve_expression(self, *args, **kwargs):
        target = self.name
        if not isinstance(target, self.__class__):
            return ResolvedOuterRef(target)
        return target

    def relabeled_clone(self, relabels):
        return self
970
971
class Sliced(F):
    """
    An object that contains a slice of an F expression.

    Object resolves the column on which the slicing is applied, and then
    applies the slicing if possible.

    The Python subscript is stored as a 1-based ``start`` position and a
    ``length``; ``length`` is None when the slice runs to the end. Both are
    passed to the output field's slice_expression().
    """

    def __init__(self, obj, subscript):
        """
        Arguments:
         * obj: the F (or OuterRef/Sliced) expression being sliced
         * subscript: a non-negative int index, or a slice without a step and
           with non-negative bounds where stop >= start

        Raises ValueError for unsupported subscripts and TypeError for
        subscripts that are neither int nor slice.
        """
        super().__init__(obj.name)
        self.obj = obj
        if isinstance(subscript, int):
            if subscript < 0:
                raise ValueError("Negative indexing is not supported.")
            # A single index selects exactly one element.
            self.start = subscript + 1
            self.length = 1
        elif isinstance(subscript, slice):
            start, stop = subscript.start, subscript.stop
            if (start is not None and start < 0) or (stop is not None and stop < 0):
                raise ValueError("Negative indexing is not supported.")
            if subscript.step is not None:
                raise ValueError("Step argument is not supported.")
            # Compare against None explicitly: a truthiness test lets e.g.
            # [5:0] through (stop == 0 is falsy), which would yield a negative
            # length.
            if start is not None and stop is not None and stop < start:
                raise ValueError("Slice stop must be greater than slice start.")
            self.start = 1 if start is None else start + 1
            self.length = None if stop is None else stop - (start or 0)
        else:
            raise TypeError("Argument to slice must be either int or slice instance.")

    def __repr__(self):
        # Reconstruct the 0-based Python slice from start/length.
        start = self.start - 1
        stop = None if self.length is None else start + self.length
        subscript = slice(start, stop)
        return f"{self.__class__.__qualname__}({self.obj!r}, {subscript!r})"

    def resolve_expression(
        self,
        query=None,
        allow_joins=True,
        reuse=None,
        summarize=False,
        for_save=False,
    ):
        """
        Resolve the referenced column and return its slice expression.

        When the sliced object is itself an OuterRef or a Sliced, that object
        is resolved and sliced. Otherwise the resolved column is sliced. In
        both cases the column's output field builds the slice expression.
        """
        resolved = query.resolve_ref(self.name, allow_joins, reuse, summarize)
        if isinstance(self.obj, (OuterRef, self.__class__)):
            expr = self.obj.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        else:
            expr = resolved
        return resolved.output_field.slice_expression(expr, self.start, self.length)
1027
1028
@deconstructible(path="django.db.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """
    An SQL function call.

    Subclasses normally set ``function`` (the SQL function name) and may
    override ``template``, ``arg_joiner`` and ``arity``. Extra keyword
    arguments passed to __init__() are stored in ``self.extra`` and merged
    into the template context by as_sql().
    """

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        expected = self.arity
        received = len(expressions)
        if expected is not None and received != expected:
            noun = "argument" if expected == 1 else "arguments"
            raise TypeError(
                "'%s' takes exactly %s %s (%s given)"
                % (self.__class__.__name__, expected, noun, received)
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        class_name = self.__class__.__name__
        rendered_args = self.arg_joiner.join(map(str, self.source_expressions))
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return "{}({})".format(class_name, rendered_args)
        rendered_options = ", ".join(
            f"{key!s}={val!s}" for key, val in sorted(options.items())
        )
        return "{}({}, {})".format(class_name, rendered_args, rendered_options)

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Return a copy of this function with every argument resolved.

        Raise ValueError if an argument resolved to ColPairs (a composite
        primary key) and this function doesn't allow composite expressions.
        """
        clone = self.copy()
        clone.is_summary = summarize
        # Replace the clone's (already copied) argument list in place.
        clone.source_expressions[:] = [
            source.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for source in clone.source_expressions
        ]
        if not self.allows_composite_expressions:
            if any(isinstance(e, ColPairs) for e in clone.get_source_expressions()):
                name = self.__class__.__name__
                raise ValueError(f"{name} does not support composite primary keys.")
        return clone

    def as_sql(
        self,
        compiler,
        connection,
        function=None,
        template=None,
        arg_joiner=None,
        **extra_context,
    ):
        """
        Compile the arguments and interpolate them into the template.

        An argument raising EmptyResultSet is replaced by its
        ``empty_result_set_value`` when it defines one (otherwise the
        exception propagates); an argument raising FullResultSet is replaced
        by a literal True.
        """
        connection.ops.check_expression_support(self)
        compiled_args = []
        params = []
        for arg in self.source_expressions:
            try:
                fragment, fragment_params = compiler.compile(arg)
            except EmptyResultSet:
                fallback = getattr(arg, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                fragment, fragment_params = compiler.compile(Value(fallback))
            except FullResultSet:
                fragment, fragment_params = compiler.compile(Value(True))
            compiled_args.append(fragment)
            params.extend(fragment_params)
        data = {**self.extra, **extra_context}
        # Precedence for function/template/arg_joiner: the argument passed to
        # this method, then a value from __init__()'s **extra (held in
        # `data`), then the class attribute.
        if function is None:
            data.setdefault("function", self.function)
        else:
            data["function"] = function
        template = template or data.get("template", self.template)
        joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        data["expressions"] = data["field"] = joiner.join(compiled_args)
        return template % data, params

    def copy(self):
        clone = super().copy()
        # Give the clone its own containers so it can be modified (e.g. by
        # resolve_expression(), which assigns into source_expressions)
        # without affecting this instance.
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone

    @cached_property
    def allowed_default(self):
        return all(source.allowed_default for source in self.source_expressions)
1138
1139
@deconstructible(path="django.db.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """Represent a wrapped value as a node within an expression."""

    # Provide a default value for `for_save` in order to allow unresolved
    # instances to be compiled until a decision is taken in #25425.
    for_save = False
    allowed_default = True

    def __init__(self, value, output_field=None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.value)

    def as_sql(self, compiler, connection):
        """
        Return a single placeholder bound to the (prepared) value, or a
        literal NULL when the prepared value is None.
        """
        connection.ops.check_expression_support(self)
        value = self.value
        field = self._output_field_or_none
        if field is not None:
            prepare = (
                field.get_db_prep_save if self.for_save else field.get_db_prep_value
            )
            value = prepare(value, connection=connection)
            # Fields may customize the placeholder SQL for their values.
            if hasattr(field, "get_placeholder"):
                return field.get_placeholder(value, compiler, connection), [value]
        if value is None:
            # oracledb does not always convert None to the appropriate
            # NULL type (like in case expressions using numbers), so we
            # use a literal SQL NULL
            return "NULL", []
        return "%s", [value]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        # Remember whether the value is being prepared for a save.
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self):
        return []

    def _resolve_output_field(self):
        """Infer a model field from the Python type of the wrapped value."""
        # Order matters: bool must precede int (bool subclasses int) and
        # datetime must precede date (datetime subclasses date).
        candidates = (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        )
        for python_type, field_class in candidates:
            if isinstance(self.value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self):
        return self.value
1219
1220
class RawSQL(Expression):
    """A raw SQL fragment with its own parameters, wrapped in parentheses."""

    allowed_default = True

    def __init__(self, sql, params, output_field=None):
        self.sql = sql
        self.params = params
        super().__init__(
            output_field=fields.Field() if output_field is None else output_field
        )

    def __repr__(self):
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(self, compiler, connection):
        return "(%s)" % self.sql, self.params

    def get_group_by_cols(self):
        return [self]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Resolve parents fields used in raw SQL: for each parent model, the
        # first local field whose column name appears (case-insensitively)
        # in the SQL is resolved on the query.
        model = query.model
        if model:
            lowered_sql = self.sql.lower()
            for parent in model._meta.all_parents:
                match = next(
                    (
                        parent_field
                        for parent_field in parent._meta.local_fields
                        if parent_field.column.lower() in lowered_sql
                    ),
                    None,
                )
                if match is not None:
                    query.resolve_ref(match.name, allow_joins, reuse, summarize)
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
1254
1255
class Star(Expression):
    """Compile to the bare SQL ``*`` token, with no parameters."""

    def __repr__(self):
        return "'*'"

    def as_sql(self, compiler, connection):
        sql, params = "*", []
        return sql, params
1262
1263
class DatabaseDefault(Expression):
    """
    Compile to the DEFAULT keyword when saving (INSERT) on backends that
    support it; everywhere else, behave as the wrapped expression.
    """

    def __init__(self, expression, output_field=None):
        super().__init__(output_field)
        self.expression = expression

    def get_source_expressions(self):
        return [self.expression]

    def set_source_expressions(self, exprs):
        (self.expression,) = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        inner = self.expression.resolve_expression(
            query=query,
            allow_joins=allow_joins,
            reuse=reuse,
            summarize=summarize,
            for_save=for_save,
        )
        if for_save:
            return DatabaseDefault(inner, output_field=self._output_field_or_none)
        # Outside an INSERT context the default is just its expression.
        return inner

    def as_sql(self, compiler, connection):
        if connection.features.supports_default_keyword_in_insert:
            return "DEFAULT", []
        return compiler.compile(self.expression)
1301
1302
class Col(Expression):
    """A reference to a column, optionally qualified by a table alias."""

    contains_column_references = True
    possibly_multivalued = False

    def __init__(self, alias, target, output_field=None):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias, self.target = alias, target

    def __repr__(self):
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return "%s(%s)" % (self.__class__.__name__, ", ".join(parts))

    def as_sql(self, compiler, connection):
        quote = compiler.quote_name_unless_alias
        column = self.target.column
        if self.alias:
            return "%s.%s" % (quote(self.alias), quote(column)), []
        return quote(column), []

    def relabeled_clone(self, relabels):
        # An unqualified column has nothing to relabel.
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        converters = self.output_field.get_db_converters(connection)
        if self.target == self.output_field:
            return converters
        # A distinct output field: apply its converters, then the target's.
        return converters + self.target.get_db_converters(connection)
1340
1341
class ColPairs(Expression):
    """
    A group of columns under one table alias, each pairing a target field
    with a source (output) field; compiles to a comma-separated column list.
    """

    def __init__(self, alias, targets, sources, output_field):
        super().__init__(output_field=output_field)
        self.alias = alias
        self.targets = targets
        self.sources = sources

    def __len__(self):
        return len(self.targets)

    def __iter__(self):
        return iter(self.get_cols())

    def __repr__(self):
        name = self.__class__.__name__
        return (
            f"{name}({self.alias!r}, {self.targets!r}, "
            f"{self.sources!r}, {self.output_field!r})"
        )

    def get_cols(self):
        alias = self.alias
        return [
            Col(alias, target_field, source_field)
            for target_field, source_field in zip(self.targets, self.sources)
        ]

    def get_source_expressions(self):
        return self.get_cols()

    def set_source_expressions(self, exprs):
        assert all(isinstance(expr, Col) and expr.alias == self.alias for expr in exprs)
        self.targets = [column.target for column in exprs]
        self.sources = [column.field for column in exprs]

    def as_sql(self, compiler, connection):
        compiled = [column.as_sql(compiler, connection) for column in self.get_cols()]
        sql = ", ".join(fragment for fragment, _ in compiled)
        params = [param for _, column_params in compiled for param in column_params]
        return sql, params

    def relabeled_clone(self, relabels):
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.targets, self.sources, self.field)

    def resolve_expression(self, *args, **kwargs):
        # Already a concrete column reference; nothing to resolve.
        return self
1394
1395
class Ref(Expression):
    """
    Reference to column alias of the query. For example, Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')) query.
    """

    def __init__(self, refs, source):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self):
        return f"{self.__class__.__name__}({self.refs}, {self.source})"

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        (self.source,) = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # `source` was resolved when the alias was defined; a Ref only names
        # it, so there is nothing further to resolve.
        return self

    def get_refs(self):
        return {self.refs}

    def relabeled_clone(self, relabels):
        relabeled = self.copy()
        relabeled.source = self.source.relabeled_clone(relabels)
        return relabeled

    def as_sql(self, compiler, connection):
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self):
        return [self]
1435
1436
class ExpressionList(Func):
    """
    An expression containing multiple expressions. Can be used to provide a
    list of expressions as an argument to another expression, like a partition
    clause.
    """

    template = "%(expressions)s"

    def __str__(self):
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sql(self, *args, **kwargs):
        # An empty list compiles to no SQL at all.
        if self.source_expressions:
            return super().as_sql(*args, **kwargs)
        return "", ()

    def as_sqlite(self, compiler, connection, **extra_context):
        # Casting to numeric is unnecessary.
        return self.as_sql(compiler, connection, **extra_context)

    def get_group_by_cols(self):
        return [
            column
            for source in self.get_source_expressions()
            for column in source.get_group_by_cols()
        ]
1463
1464
class OrderByList(ExpressionList):
    """
    An ``ORDER BY`` clause built from a sequence of ordering items.

    String items prefixed with "-" (e.g. "-name") become a descending
    OrderBy of the referenced field; every other item is passed through to
    ExpressionList unchanged.
    """

    allowed_default = False
    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions, **extra):
        # startswith() rather than expr[0]: indexing an empty string raised
        # IndexError here; "" is now passed through like any other string.
        expressions = (
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions, **extra)
1479
1480
@deconstructible(path="django.db.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    An expression that can wrap another expression so that it can provide
    extra context to the inner expression, such as the output_field.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def get_group_by_cols(self):
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Group by a copy of the inner expression carrying this wrapper's
        # output_field.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(self, compiler, connection):
        return compiler.compile(self.expression)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.expression})"

    @property
    def allowed_default(self):
        return self.expression.allowed_default
1516
1517
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression."""

    def __init__(self, expression):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self):
        # Double negation yields (a copy of) the original expression.
        return self.expression.copy()

    def as_sql(self, compiler, connection):
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # Negating a condition that matches nothing matches everything.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        ops = compiler.connection.ops
        if ops.conditional_expression_supported_in_where_clause(self.expression):
            return f"NOT {sql}", params
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN.
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        is_conditional = getattr(resolved.expression, "conditional", False)
        if is_conditional:
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(self, compiler, sql, params):
        # Wrap boolean expressions with a CASE WHEN expression if a database
        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
        # GROUP BY list.
        connection = compiler.connection
        if connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        # Avoid double wrapping: as_sql() has already produced a CASE WHEN for
        # expressions the backend can't use directly as a condition.
        if not connection.ops.conditional_expression_supported_in_where_clause(
            self.expression
        ):
            return sql, params
        return "CASE WHEN {} THEN 1 ELSE 0 END".format(sql), params
1567
1568
@deconstructible(path="django.db.models.When")
class When(Expression):
    """
    One ``WHEN <condition> THEN <result>`` branch of a Case() expression.

    The condition may be a Q object, a conditional (boolean) expression, or
    field lookups passed as keyword arguments; ``then`` is the branch result.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False

    def __init__(self, condition=None, then=None, **lookups):
        if lookups:
            # Fold keyword lookups into a Q object, combining them with a
            # conditional `condition` when one was also supplied.
            if condition is None:
                condition, lookups = Q(**lookups), None
            elif getattr(condition, "conditional", False):
                condition, lookups = Q(condition, **lookups), None
        has_condition = condition is not None and getattr(
            condition, "conditional", False
        )
        # Leftover lookups mean they could not be combined with `condition`.
        if not has_condition or lookups:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        self.condition, self.result = exprs

    def get_source_fields(self):
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        clone = self.copy()
        clone.is_summary = summarize
        if hasattr(clone.condition, "resolve_expression"):
            # The condition is always resolved with for_save=False.
            clone.condition = clone.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        clone.result = clone.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        extra_context["condition"] = condition_sql
        extra_context["result"] = result_sql
        sql = (template or self.template) % extra_context
        return sql, (*condition_params, *result_params)

    def get_group_by_cols(self):
        # This is not a complete expression and cannot be used in GROUP BY.
        return [
            column
            for source in self.get_source_expressions()
            for column in source.get_group_by_cols()
        ]

    @cached_property
    def allowed_default(self):
        return self.condition.allowed_default and self.result.allowed_default
1647
1648
@deconstructible(path="django.db.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END

    Positional arguments must be When() instances. ``default`` supplies the
    ELSE branch (wrapped by _parse_expressions()), and any extra keyword
    arguments are merged into the SQL template context in as_sql().
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        return "CASE %s, ELSE %r" % (
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_source_expressions(self):
        # All When() branches followed by the default, in that order.
        return self.cases + [self.default]

    def set_source_expressions(self, exprs):
        # Mirrors get_source_expressions(): the last item is the default.
        *self.cases, self.default = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Resolve every branch and the default on a copy; copy() gives the
        # clone its own `cases` list so this instance is left untouched.
        c = self.copy()
        c.is_summary = summarize
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self):
        c = super().copy()
        c.cases = c.cases[:]
        return c

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        """
        Compile to CASE ... END, pruning branches whose condition compiles to
        "never" (EmptyResultSet) and short-circuiting at the first branch
        whose condition compiles to "always" (FullResultSet).
        """
        connection.ops.check_expression_support(self)
        # No branches at all: the expression is just its default.
        if not self.cases:
            return compiler.compile(self.default)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # The condition can never match; drop this branch.
                continue
            except FullResultSet:
                # The condition always matches: its result becomes the ELSE
                # value and any later branches are unreachable.
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        else:
            # No always-true branch was found; use the declared default.
            default_sql, default_params = compiler.compile(self.default)
        # Every remaining branch was pruned, so only the ELSE value is needed.
        if not case_parts:
            return default_sql, default_params
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        # The default's params follow the WHEN params, matching their order
        # in the template.
        sql_params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        # With a known output field, let the backend wrap the SQL in its
        # unification cast template.
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self):
        if not self.cases:
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()

    @cached_property
    def allowed_default(self):
        return self.default.allowed_default and all(
            case_.allowed_default for case_ in self.cases
        )
1751
1752
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None
    subquery = True

    def __init__(self, queryset, output_field=None, **extra):
        # Accept either a QuerySet (use its .query) or a bare sql.Query, and
        # work on a private clone marked as a subquery.
        source_query = getattr(queryset, "query", queryset)
        self.query = source_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self):
        return [self.query]

    def set_source_expressions(self, exprs):
        self.query = exprs[0]

    def _resolve_output_field(self):
        return self.query.output_field

    def copy(self):
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self):
        return self.query.external_aliases

    def get_external_cols(self):
        return self.query.get_external_cols()

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        subquery_sql, sql_params = self.query.as_sql(compiler, connection)
        # Strip the outermost characters of the compiled query (its
        # surrounding parentheses, presumably) before templating.
        context = {**self.extra, **extra_context, "subquery": subquery_sql[1:-1]}
        template = template or context.get("template", self.template)
        return template % context, sql_params

    def get_group_by_cols(self):
        return self.query.get_group_by_cols(wrapper=self)
1804
1805
class Exists(Subquery):
    """An ``EXISTS(<subquery>)`` boolean expression."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, queryset, **kwargs):
        super().__init__(queryset, **kwargs)
        # Replace the cloned query with its exists() form.
        self.query = self.query.exists()

    def select_format(self, compiler, sql, params):
        # Wrap EXISTS() with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        return "CASE WHEN {} THEN 1 ELSE 0 END".format(sql), params

    def as_sql(self, compiler, *args, **kwargs):
        try:
            return super().as_sql(compiler, *args, **kwargs)
        except EmptyResultSet:
            # A subquery that can match nothing never EXISTS.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(False))
            return "1=0", ()
1831
1832
@deconstructible(path="django.db.models.OrderBy")
class OrderBy(Expression):
    """
    A single ORDER BY term: an expression plus a direction and an optional
    NULLS FIRST / NULLS LAST placement.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False
    constraint_validation_compatible = False
    allows_composite_expressions = True

    def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
        # nulls_first / nulls_last are tri-state in spirit but only True or
        # None is accepted; False is rejected explicitly below.
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self):
        return "{}({}, descending={})".format(
            self.__class__.__name__, self.expression, self.descending
        )

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        # A composite column group orders by each of its columns in turn,
        # each with this term's direction and NULLS settings.
        if isinstance(self.expression, ColPairs):
            sql_parts = []
            params = []
            for col in self.expression.get_cols():
                copy = self.copy()
                copy.set_source_expressions([col])
                sql, col_params = compiler.compile(copy)
                sql_parts.append(sql)
                params.extend(col_params)
            return ", ".join(sql_parts), params
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = "%s NULLS LAST" % template
            elif self.nulls_first:
                template = "%s NULLS FIRST" % template
        else:
            # Emulate NULLS LAST/FIRST with a leading "<expr> IS [NOT] NULL"
            # sort key. The emulation is skipped when the backend's default
            # placement already matches (order_by_nulls_first presumably
            # means NULLs sort first in ascending order -- confirm).
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = "%%(expression)s IS NULL, %s" % template
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = "%%(expression)s IS NOT NULL, %s" % template
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # The expression may appear more than once in the template (see the
        # NULLS emulation above), so repeat its params once per occurrence.
        params *= template.count("%(expression)s")
        # rstrip() drops trailing whitespace, e.g. if extra_context supplies
        # an empty "ordering".
        return (template % placeholders).rstrip(), params

    def as_oracle(self, compiler, connection):
        # Oracle < 23c doesn't allow ORDER BY EXISTS() or filters unless it's
        # wrapped in a CASE WHEN.
        if (
            not connection.features.supports_boolean_expr_in_select_clause
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.expression
            )
        ):
            copy = self.copy()
            copy.expression = Case(
                When(self.expression, then=True),
                default=False,
            )
            return copy.as_sql(compiler, connection)
        return self.as_sql(compiler, connection)

    def get_group_by_cols(self):
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self):
        # Flip the direction in place and swap any NULLS placement so the
        # reversed ordering mirrors the original; returns self.
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self):
        # Mutates in place; returns None.
        self.descending = False

    def desc(self):
        # Mutates in place; returns None.
        self.descending = True
1937
1938
class Window(SQLiteNumericMixin, Expression):
    """
    An expression evaluated over a window: ``<expression> OVER (<window>)``.

    ``partition_by`` accepts a single item or a list/tuple and is wrapped in
    an ExpressionList; ``order_by`` accepts a string, an expression, or a
    list/tuple of them and is wrapped in an OrderByList. ``frame`` is
    compiled as given.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True

    def __init__(
        self,
        expression,
        partition_by=None,
        order_by=None,
        frame=None,
        output_field=None,
    ):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                "Expression '%s' isn't compatible with OVER clauses."
                % expression.__class__.__name__
            )

        if self.partition_by is not None:
            if not isinstance(self.partition_by, (tuple, list)):
                self.partition_by = (self.partition_by,)
            self.partition_by = ExpressionList(*self.partition_by)

        if self.order_by is not None:
            if isinstance(self.order_by, (list, tuple)):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, (BaseExpression, str)):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        return self.source_expression.output_field

    def get_source_expressions(self):
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs):
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(self, compiler, connection, template=None):
        """
        Compile to ``<expression> OVER (<partition> <order> <frame>)``; each
        window part is emitted only when present.

        Raise NotSupportedError if the backend has no OVER clause support.
        """
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            # Compiled directly so the PARTITION BY template can be supplied.
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def as_sqlite(self, compiler, connection):
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression.
            copy = self.copy()
            source_expressions = copy.get_source_expressions()
            # Window has no copy() override, so the copy above still shares
            # the wrapped expression with self (compare Func.copy()/Case.copy(),
            # which clone their containers explicitly). Copy the wrapped
            # expression before overriding its output_field so that compiling
            # doesn't mutate the expression tree of the original Window.
            source_expressions[0] = source_expressions[0].copy()
            source_expressions[0].output_field = fields.FloatField()
            copy.set_source_expressions(source_expressions)
            return super(Window, copy).as_sqlite(compiler, connection)
        return self.as_sql(compiler, connection)

    def __str__(self):
        return "{} OVER ({}{}{})".format(
            str(self.source_expression),
            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
            str(self.order_by or ""),
            str(self.frame or ""),
        )

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_group_by_cols(self):
        # Only the partition and ordering expressions contribute; the window
        # function itself must not introduce a GROUP BY (see
        # contains_aggregate above).
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
2053
2054
class WindowFrameExclusion(Enum):
    """The EXCLUDE options of a window frame; each value is the SQL keyword."""

    CURRENT_ROW = "CURRENT ROW"
    GROUP = "GROUP"
    TIES = "TIES"
    NO_OTHERS = "NO OTHERS"

    def __repr__(self):
        # Render as "<ClassQualname>.<MEMBER>", e.g.
        # "WindowFrameExclusion.TIES".
        return "%s.%s" % (type(self).__qualname__, self.name)
2063
2064
class WindowFrame(Expression):
    """
    Base class for the frame clause of a window expression.

    Concrete frames (RowRange, ValueRange) supply only ``frame_type`` and
    ``window_frame_start_end()``. The remaining processing, and the light
    validation done here (not meant to be exhaustive), lives in this class.
    Both bounds are optional. In ``__str__`` a missing start is shown as
    UNBOUNDED PRECEDING and a missing end as UNBOUNDED FOLLOWING. In
    ``as_sql`` the subclass's ``window_frame_start_end()`` decides how they
    are rendered.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s%(exclude)s"

    def __init__(self, start=None, end=None, exclusion=None):
        # Bounds are wrapped in Value so they act as source expressions; only
        # their raw .value is read back when rendering.
        self.start = Value(start)
        self.end = Value(end)
        if exclusion is not None and not isinstance(exclusion, WindowFrameExclusion):
            raise TypeError(
                f"{self.__class__.__qualname__}.exclusion must be a "
                "WindowFrameExclusion instance."
            )
        self.exclusion = exclusion

    def set_source_expressions(self, exprs):
        start, end = exprs
        self.start = start
        self.end = end

    def get_source_expressions(self):
        return [self.start, self.end]

    def get_exclusion(self):
        """Return the " EXCLUDE ..." suffix, or "" when no exclusion is set."""
        if self.exclusion is not None:
            return f" EXCLUDE {self.exclusion.value}"
        return ""

    def _format_frame(self, start, end):
        # Fill the frame template; shared by as_sql() and __str__().
        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
            "exclude": self.get_exclusion(),
        }

    @staticmethod
    def _format_bound(value, ops, unbounded):
        # None means "no bound". Negative values are shown as PRECEDING,
        # zero as CURRENT ROW and positive values as FOLLOWING. Anything that
        # compares as neither falls back to the unbounded keyword.
        if value is None:
            return unbounded
        if value < 0:
            return "%d %s" % (abs(value), ops.PRECEDING)
        if value == 0:
            return ops.CURRENT_ROW
        if value > 0:
            return "%d %s" % (value, ops.FOLLOWING)
        return unbounded

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        # The subclass/backend turns the raw bound values into SQL before
        # exclusion support is checked.
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        if self.exclusion and not connection.features.supports_frame_exclusion:
            raise NotSupportedError(
                "This backend does not support window frame exclusions."
            )
        return self._format_frame(start, end), []

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        # A frame contributes no GROUP BY columns.
        return []

    def __str__(self):
        # No connection is passed to __str__, so the module-level
        # ``connection`` imported from django.db supplies the keywords.
        ops = connection.ops
        start = self._format_bound(self.start.value, ops, ops.UNBOUNDED_PRECEDING)
        end = self._format_bound(self.end.value, ops, ops.UNBOUNDED_FOLLOWING)
        return self._format_frame(start, end)

    def window_frame_start_end(self, connection, start, end):
        raise NotImplementedError("Subclasses must implement window_frame_start_end().")
2150
2151
class RowRange(WindowFrame):
    """Window frame rendered with the ``ROWS`` frame type."""

    frame_type = "ROWS"

    def window_frame_start_end(self, connection, start, end):
        # Bound rendering is delegated to the backend's operations.
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)
2157
2158
class ValueRange(WindowFrame):
    """Window frame rendered with the ``RANGE`` frame type."""

    frame_type = "RANGE"

    def window_frame_start_end(self, connection, start, end):
        # Bound rendering is delegated to the backend's operations.
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)