Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/django/db/models/expressions.py: 33%
988 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 06:13 +0000
1import copy
2import datetime
3import functools
4import inspect
5import warnings
6from collections import defaultdict
7from decimal import Decimal
8from uuid import UUID
10from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
11from django.db import DatabaseError, NotSupportedError, connection
12from django.db.models import fields
13from django.db.models.constants import LOOKUP_SEP
14from django.db.models.query_utils import Q
15from django.utils.deconstruct import deconstructible
16from django.utils.deprecation import RemovedInDjango50Warning
17from django.utils.functional import cached_property
18from django.utils.hashable import make_hashable
class SQLiteNumericMixin:
    """
    Mixin adding an ``as_sqlite()`` compiler hook.

    On SQLite, expressions whose output_field is a DecimalField have their
    SQL wrapped in ``CAST(... AS NUMERIC)`` so they filter correctly.
    """

    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        try:
            needs_cast = self.output_field.get_internal_type() == "DecimalField"
        except FieldError:
            # Output type unknown: emit the SQL unchanged.
            needs_cast = False
        if needs_cast:
            sql = "CAST(%s AS NUMERIC)" % sql
        return sql, params
class Combinable:
    """
    Mixin giving expressions Python operator support.

    Arithmetic operators (``+ - * / % **``) and the ``bit*()`` methods build
    a CombinedExpression, for example ``F('foo') + F('bar')``. ``&``, ``|``
    and ``^`` are reserved for combining two conditional (boolean)
    expressions into a Q object.
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # The % operator is written quoted ("%%") because the connector may end
    # up in SQL strings that also go through parameter substitution.
    MOD = "%%"

    # Bitwise connectors, reachable only through .bitand(), .bitor(), ...;
    # the '&' and '|' Python operators are reserved for boolean usage.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(self, other, connector, reversed):
        """
        Build ``self <connector> other``, or ``other <connector> self`` when
        *reversed* is true. A plain Python *other* is wrapped in Value().
        """
        if not hasattr(other, "resolve_expression"):
            other = Value(other)
        lhs, rhs = (other, self) if reversed else (self, other)
        return CombinedExpression(lhs, connector, rhs)

    @staticmethod
    def _bitwise_error():
        """Build the error raised when a boolean operator is misused."""
        return NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def _both_conditional(self, other):
        """Whether both operands advertise a truthy ``conditional`` flag."""
        return getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )

    #############
    # OPERATORS #
    #############

    def __neg__(self):
        return self._combine(-1, self.MUL, False)

    def __invert__(self):
        return NegatedExpression(self)

    # Arithmetic, expression on the left.

    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    # Arithmetic, expression on the right.

    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    # Explicit bitwise operations.

    def bitand(self, other):
        return self._combine(other, self.BITAND, False)

    def bitor(self, other):
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other):
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other):
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other):
        return self._combine(other, self.BITRIGHTSHIFT, False)

    # Boolean operators: only valid between two conditional expressions.

    def __and__(self, other):
        if self._both_conditional(other):
            return Q(self) & Q(other)
        raise self._bitwise_error()

    def __or__(self, other):
        if self._both_conditional(other):
            return Q(self) | Q(other)
        raise self._bitwise_error()

    def __xor__(self, other):
        if self._both_conditional(other):
            return Q(self) ^ Q(other)
        raise self._bitwise_error()

    def __rand__(self, other):
        raise self._bitwise_error()

    def __ror__(self, other):
        raise self._bitwise_error()

    def __rxor__(self, other):
        raise self._bitwise_error()
class BaseExpression:
    """Base class for all query expressions."""

    empty_result_set_value = NotImplemented
    # aggregate specific fields
    is_summary = False
    # Set by the output_field property when _resolve_output_field() returned
    # None, so _output_field_or_none can tell "unknown type" from an error.
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False

    def __init__(self, output_field=None):
        # An explicit output_field shadows the lazily inferred cached_property.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self):
        # convert_value is a cached_property that may cache a lambda (see
        # convert_value()); lambdas cannot be pickled, so drop the cached
        # value and let it be recomputed on demand.
        state = self.__dict__.copy()
        state.pop("convert_value", None)
        return state

    def get_db_converters(self, connection):
        """
        Return the converters for values read from the database: this
        expression's convert_value (unless it is the no-op) followed by the
        output field's own converters.
        """
        return (
            []
            if self.convert_value is self._convert_value_noop
            else [self.convert_value]
        ) + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        """Return the child expressions; a leaf expression has none."""
        return []

    def set_source_expressions(self, exprs):
        """Replace the child expressions; a leaf expression accepts none."""
        assert not exprs

    def _parse_expressions(self, *expressions):
        """
        Coerce arguments into expressions: objects exposing
        resolve_expression() are kept, strings become F() references and
        anything else is wrapped in Value().
        """
        return [
            arg
            if hasattr(arg, "resolve_expression")
            else (F(arg) if isinstance(arg, str) else Value(arg))
            for arg in expressions
        ]

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        """Whether any (truthy) source expression contains an aggregate."""
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self):
        """Whether any (truthy) source expression contains an OVER clause."""
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self):
        """Whether any (truthy) source expression references a column."""
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression about to be used in a save or update

        Return: an Expression to be added to the query.
        """
        c = self.copy()
        c.is_summary = summarize
        c.set_source_expressions(
            [
                expr.resolve_expression(query, allow_joins, reuse, summarize)
                if expr
                else None
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self):
        """Whether this expression produces a boolean (usable as a filter)."""
        return isinstance(self.output_field, fields.BooleanField)

    @property
    def field(self):
        """Alias of output_field."""
        return self.output_field

    @cached_property
    def output_field(self):
        """
        Return the output type of this expression.

        Raise FieldError if it cannot be inferred from the source
        expressions.
        """
        output_field = self._resolve_output_field()
        if output_field is None:
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self):
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type.
        """
        try:
            return self.output_field
        except FieldError:
            # Only "type could not be inferred" maps to None; any other
            # FieldError raised while resolving is propagated.
            if not self._output_field_resolved_to_none:
                raise
            return None

    def _resolve_output_field(self):
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        sources_iter = (
            source for source in self.get_source_fields() if source is not None
        )
        # The outer loop takes the first non-None source; the inner loop
        # exhausts the same generator, comparing every remaining source
        # against it, so the outer loop body runs at most once.
        for output_field in sources_iter:
            for source in sources_iter:
                if not isinstance(output_field, source.__class__):
                    raise FieldError(
                        "Expression contains mixed types: %s, %s. You must "
                        "set output_field."
                        % (
                            output_field.__class__.__name__,
                            source.__class__.__name__,
                        )
                    )
            return output_field

    @staticmethod
    def _convert_value_noop(value, expression, connection):
        return value

    @cached_property
    def convert_value(self):
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == "FloatField":
            return (
                lambda value, expression, connection: None
                if value is None
                else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return (
                lambda value, expression, connection: None
                if value is None
                else int(value)
            )
        elif internal_type == "DecimalField":
            return (
                lambda value, expression, connection: None
                if value is None
                else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup):
        """Delegate lookup resolution to the output field."""
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        """Delegate transform resolution to the output field."""
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        """Return a copy whose source expressions are relabeled via change_map."""
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(self, replacements):
        """
        Return the replacement mapped to this expression, if any (truthy);
        otherwise a copy with the replacement applied to each source.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.copy()
        source_expressions = clone.get_source_expressions()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self):
        """
        Return the union of get_refs() over all source expressions.

        None source expressions are skipped, as the other traversal helpers
        of this class (resolve_expression, relabeled_clone, flatten) do.
        """
        refs = set()
        for expr in self.get_source_expressions():
            if expr is None:
                continue
            refs |= expr.get_refs()
        return refs

    def copy(self):
        """Return a shallow copy of this expression."""
        return copy.copy(self)

    def prefix_references(self, prefix):
        """
        Return a copy in which every F() source is renamed to
        ``prefix + name`` and other sources are prefixed recursively.
        None source expressions are kept as None.
        """
        clone = self.copy()
        clone.set_source_expressions(
            [
                None
                if expr is None
                else F(f"{prefix}{expr.name}")
                if isinstance(expr, F)
                else expr.prefix_references(prefix)
                for expr in self.get_source_expressions()
            ]
        )
        return clone

    def get_group_by_cols(self):
        """
        Return the expressions to GROUP BY: the expression itself when it
        contains no aggregate, otherwise the group-by columns of its
        (non-None) sources.
        """
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            if source is None:
                continue
            cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self):
        """Return the underlying field types used by this aggregate."""
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self):
        return self

    def flatten(self):
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    yield expr

    def select_format(self, compiler, sql, params):
        """
        Custom format for select clauses. For example, EXISTS expressions need
        to be wrapped in CASE WHEN on Oracle.
        """
        if hasattr(self.output_field, "select_format"):
            return self.output_field.select_format(compiler, sql, params)
        return sql, params
@deconstructible
class Expression(BaseExpression, Combinable):
    """An expression that can be combined with other expressions."""

    @cached_property
    def identity(self):
        """
        Hashable tuple identifying this expression.

        It holds the class followed by ``(argument_name, value)`` pairs
        covering every __init__() argument, defaults included. Model fields
        bound to a model are represented by ``(model label, field name)``;
        unbound fields are represented by their type. Other values are passed
        through make_hashable().
        """
        init_signature = inspect.signature(self.__init__)
        ctor_args, ctor_kwargs = self._constructor_args
        bound = init_signature.bind_partial(*ctor_args, **ctor_kwargs)
        bound.apply_defaults()

        def normalize(value):
            if not isinstance(value, fields.Field):
                return make_hashable(value)
            if value.name and value.model:
                return (value.model._meta.label, value.name)
            return type(value)

        parts = [self.__class__]
        parts.extend(
            (arg_name, normalize(arg_value))
            for arg_name, arg_value in bound.arguments.items()
        )
        return tuple(parts)

    def __eq__(self, other):
        # Equality is structural: same class and same constructor arguments.
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)
# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# The current approach for NULL is based on lowest common denominator behavior
# i.e. if one of the supported databases is raising an error (rather than
# return NULL) for `val <op> NULL`, then Django raises FieldError.
NoneType = type(None)

# Each dict maps a connector to (lhs_type, rhs_type, result_type) triples.
# Entries are registered in list order, and lookups return the first entry
# whose lhs/rhs types are superclasses of the operands, so order matters.
_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    # Every numeric field type gets a (type, NULL) and a (NULL, type) entry
    # for each connector. The field types are iterated inside the value list
    # on purpose: as a second ``for`` clause of the dict comprehension each
    # iteration would overwrite the same ``connector`` key, leaving only the
    # last field type (FloatField) registered.
    {
        connector: [
            combination
            for field_type in (
                fields.IntegerField,
                fields.DecimalField,
                fields.FloatField,
            )
            for combination in (
                (field_type, NoneType, field_type),
                (NoneType, field_type, field_type),
            )
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]

# Registry of (lhs_type, rhs_type, result_type) triples keyed by connector;
# filled by register_combinable_fields().
_connector_combinators = defaultdict(list)
def register_combinable_fields(lhs, connector, rhs, result):
    """
    Declare that ``lhs <connector> rhs`` yields a field of type ``result``.

    The triple is appended to the per-connector registry, after any earlier
    registrations for the same connector. Example::

        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )
    """
    registered = _connector_combinators[connector]
    registered.append((lhs, rhs, result))
# Populate the registry from the _connector_combinations table, preserving
# its order. The loop variables (d, connector, field_types, lhs, rhs, result)
# stay bound at module level once the loop finishes.
for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
    """
    Return the result field class for ``lhs_type <connector> rhs_type``.

    The first registered combination whose types are superclasses of
    *lhs_type* and *rhs_type* wins. Return None when nothing matches.

    NOTE: results are memoized by lru_cache, so a combination registered
    after a given (connector, lhs_type, rhs_type) key has been looked up is
    not seen for that key.
    """
    candidates = _connector_combinators.get(connector, ())
    return next(
        (
            combined_type
            for lhs_base, rhs_base, combined_type in candidates
            if issubclass(lhs_type, lhs_base) and issubclass(rhs_type, rhs_base)
        ),
        None,
    )
class CombinedExpression(SQLiteNumericMixin, Expression):
    """
    Binary expression ``lhs <connector> rhs``, as built by the Combinable
    operators (e.g. ``F("a") + F("b")``).
    """

    def __init__(self, lhs, connector, rhs, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self):
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self):
        # Deliberately not delegating to super(); see the comment in
        # BaseExpression._resolve_output_field(). The result type comes from
        # the connector/type registry instead.
        lhs_field_type = type(self.lhs._output_field_or_none)
        rhs_field_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(
            self.connector, lhs_field_type, rhs_field_type
        )
        if combined_type is not None:
            return combined_type()
        raise FieldError(
            f"Cannot infer type of {self.connector!r} expression involving these "
            f"types: {self.lhs.output_field.__class__.__name__}, "
            f"{self.rhs.output_field.__class__.__name__}. You must set "
            f"output_field."
        )

    def as_sql(self, compiler, connection):
        sql_parts = []
        all_params = []
        for operand in (self.lhs, self.rhs):
            operand_sql, operand_params = compiler.compile(operand)
            sql_parts.append(operand_sql)
            all_params.extend(operand_params)
        combined_sql = connection.ops.combine_expression(self.connector, sql_parts)
        # Parenthesize so the combination keeps its precedence when nested.
        return "(%s)" % combined_sql, all_params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve both operands. A plain CombinedExpression is re-dispatched
        in two cases:
          * exactly one operand is a DurationField: DurationExpression;
          * SUB of two operands sharing the same Date/DateTime/Time type:
            TemporalSubtraction.
        """
        resolve_args = (query, allow_joins, reuse, summarize, for_save)
        lhs = self.lhs.resolve_expression(*resolve_args)
        rhs = self.rhs.resolve_expression(*resolve_args)

        def internal_type(expr):
            try:
                return expr.output_field.get_internal_type()
            except (AttributeError, FieldError):
                return None

        if not isinstance(self, (DurationExpression, TemporalSubtraction)):
            lhs_type = internal_type(lhs)
            rhs_type = internal_type(rhs)
            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                duration = DurationExpression(self.lhs, self.connector, self.rhs)
                return duration.resolve_expression(*resolve_args)
            is_temporal = lhs_type in {"DateField", "DateTimeField", "TimeField"}
            if self.connector == self.SUB and is_temporal and lhs_type == rhs_type:
                subtraction = TemporalSubtraction(self.lhs, self.rhs)
                return subtraction.resolve_expression(*resolve_args)
        resolved = self.copy()
        resolved.is_summary = summarize
        resolved.lhs = lhs
        resolved.rhs = rhs
        return resolved
class DurationExpression(CombinedExpression):
    """
    Combination where exactly one operand is a DurationField. On backends
    without a native duration type, it uses the backend's duration
    arithmetic helpers.
    """

    def compile(self, side, compiler, connection):
        """
        Compile one operand, passing DurationField operands through
        connection.ops.format_for_duration_arithmetic().
        """
        try:
            output = side.output_field
        except FieldError:
            return compiler.compile(side)
        if output.get_internal_type() != "DurationField":
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(self, compiler, connection):
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        sql_parts = []
        all_params = []
        for operand in (self.lhs, self.rhs):
            operand_sql, operand_params = self.compile(operand, compiler, connection)
            sql_parts.append(operand_sql)
            all_params.extend(operand_params)
        combined_sql = connection.ops.combine_duration_expression(
            self.connector, sql_parts
        )
        # Parenthesize so the combination keeps its precedence when nested.
        return "(%s)" % combined_sql, all_params

    def as_sqlite(self, compiler, connection, **extra_context):
        """
        Compile as usual, then reject MUL/DIV when an operand's internal type
        is not one of Decimal/Duration/Float/IntegerField.
        """
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            # Operand types unknown: skip the check.
            return sql, params
        allowed_fields = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        if lhs_type not in allowed_fields or rhs_type not in allowed_fields:
            raise DatabaseError(f"Invalid arguments for operator {self.connector}.")
        return sql, params
class TemporalSubtraction(CombinedExpression):
    """
    ``lhs - rhs`` between two temporal values, producing a DurationField.
    The SQL comes from connection.ops.subtract_temporals().
    """

    output_field = fields.DurationField()

    def __init__(self, lhs, rhs):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        lhs_internal_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            lhs_internal_type, compiled_lhs, compiled_rhs
        )
@deconstructible(path="django.db.models.F")
class F(Combinable):
    """An object capable of resolving references to existing query objects."""

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name})"

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Resolution is delegated entirely to the query; for_save is unused.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements):
        return replacements.get(self, self)

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other):
        same_class = self.__class__ == other.__class__
        if not same_class:
            return same_class
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def copy(self):
        return copy.copy(self)
class ResolvedOuterRef(F):
    """
    An object that contains a reference to an outer query.

    In this case, the reference to the outer query has been resolved because
    the inner query has been used as a subquery.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args, **kwargs):
        raise ValueError(
            "This queryset contains a reference to an outer query and may only "
            "be used in a subquery."
        )

    def resolve_expression(self, *args, **kwargs):
        resolved = super().resolve_expression(*args, **kwargs)
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels):
        # Outer references are never relabeled.
        return self

    def get_group_by_cols(self):
        return []
class OuterRef(F):
    """
    Unresolved reference to a field of an outer query. Resolving it yields a
    ResolvedOuterRef. If the wrapped name is itself an OuterRef, resolving
    returns that inner reference unchanged (presumably for references that
    reach more than one query level out).
    """

    contains_aggregate = False

    def resolve_expression(self, *args, **kwargs):
        target = self.name
        if isinstance(target, self.__class__):
            return target
        return ResolvedOuterRef(target)

    def relabeled_clone(self, relabels):
        return self
@deconstructible(path="django.db.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """An SQL function call."""

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        if self.arity is not None and len(expressions) != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                "'%s' takes exactly %s %s (%s given)"
                % (self.__class__.__name__, self.arity, noun, len(expressions))
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        class_name = self.__class__.__name__
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{class_name}({args})"
        rendered = ", ".join(
            f"{key!s}={val!s}" for key, val in sorted(options.items())
        )
        return f"{class_name}({args}, {rendered})"

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        c = self.copy()
        c.is_summary = summarize
        # copy() gave c its own list, so updating it in place leaves self
        # untouched.
        c.source_expressions[:] = [
            arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for arg in c.source_expressions
        ]
        return c

    def as_sql(
        self,
        compiler,
        connection,
        function=None,
        template=None,
        arg_joiner=None,
        **extra_context,
    ):
        connection.ops.check_expression_support(self)

        def compile_arg(arg):
            # An argument that matches nothing falls back to its
            # empty_result_set_value (if it defines one); an argument that
            # matches everything is compiled as TRUE.
            try:
                return compiler.compile(arg)
            except EmptyResultSet:
                fallback = getattr(arg, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                return compiler.compile(Value(fallback))
            except FullResultSet:
                return compiler.compile(Value(True))

        compiled_args = []
        params = []
        for arg in self.source_expressions:
            arg_sql, arg_params = compile_arg(arg)
            compiled_args.append(arg_sql)
            params.extend(arg_params)

        data = {**self.extra, **extra_context}
        # Precedence: the argument passed to this method, then a value from
        # __init__()'s **extra (already in `data`), then the class attribute.
        if function is None:
            data.setdefault("function", self.function)
        else:
            data["function"] = function
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        data["expressions"] = data["field"] = arg_joiner.join(compiled_args)
        return template % data, params

    def copy(self):
        clone = super().copy()
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone
@deconstructible(path="django.db.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """Represent a wrapped value as a node within an expression."""

    # Provide a default value for `for_save` in order to allow unresolved
    # instances to be compiled until a decision is taken in #25425.
    for_save = False

    def __init__(self, value, output_field=None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self):
        return f"{self.__class__.__name__}({self.value!r})"

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        val = self.value
        field = self._output_field_or_none
        if field is not None:
            prepare = field.get_db_prep_save if self.for_save else field.get_db_prep_value
            val = prepare(val, connection=connection)
            if hasattr(field, "get_placeholder"):
                return field.get_placeholder(val, compiler, connection), [val]
        if val is None:
            # cx_Oracle does not always convert None to the appropriate
            # NULL type (like in case expressions using numbers), so we
            # use a literal SQL NULL
            return "NULL", []
        return "%s", [val]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self):
        return []

    def _resolve_output_field(self):
        """Infer a model field from the Python type of the wrapped value."""
        # Order matters: bool is checked before int (bool subclasses int) and
        # datetime before date (datetime subclasses date).
        for python_type, field_class in (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        ):
            if isinstance(self.value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self):
        return self.value
class RawSQL(Expression):
    """
    A raw SQL fragment with its own parameters.

    Arguments:
     * sql: the SQL fragment; it is wrapped in parentheses when compiled.
     * params: the parameters passed along with ``sql``.
     * output_field: the field type of the result. Defaults to a generic
       ``Field()`` when omitted.
    """

    def __init__(self, sql, params, output_field=None):
        if output_field is None:
            output_field = fields.Field()
        self.sql, self.params = sql, params
        super().__init__(output_field=output_field)

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)

    def as_sql(self, compiler, connection):
        return "(%s)" % self.sql, self.params

    def get_group_by_cols(self):
        # The fragment is opaque, so the whole expression is grouped by.
        return [self]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Resolve fields of parent models that the raw SQL mentions. A field
        # counts as mentioned when its column name appears, case-insensitively,
        # as a substring of the SQL. Resolving it through the query presumably
        # sets up the join to that parent.
        #
        # ``query`` defaults to None in the signature, so guard against it
        # before dereferencing ``query.model``.
        if query is not None and query.model:
            for parent in query.model._meta.get_parent_list():
                for parent_field in parent._meta.local_fields:
                    _, column_name = parent_field.get_attname_column()
                    if column_name.lower() in self.sql.lower():
                        query.resolve_ref(
                            parent_field.name, allow_joins, reuse, summarize
                        )
                        # One matching field per parent model is resolved.
                        break
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
class Star(Expression):
    """Render the SQL ``*`` wildcard."""

    def __repr__(self):
        return "'*'"

    def as_sql(self, compiler, connection):
        # A bare asterisk carries no parameters.
        sql, params = "*", []
        return sql, params
class Col(Expression):
    """
    A reference to the column of a field, optionally qualified by a table alias.

    ``alias`` is the table alias (falsy for an unqualified column), ``target``
    is the field whose ``column`` is referenced, and ``output_field`` defaults
    to ``target``.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(self, alias, target, output_field=None):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias, self.target = alias, target

    def __repr__(self):
        if self.alias:
            identifiers = (self.alias, str(self.target))
        else:
            identifiers = (str(self.target),)
        return "{}({})".format(self.__class__.__name__, ", ".join(identifiers))

    def as_sql(self, compiler, connection):
        alias = self.alias
        column = self.target.column
        parts = [alias, column] if alias else [column]
        quoted = [compiler.quote_name_unless_alias(part) for part in parts]
        return ".".join(quoted), []

    def relabeled_clone(self, relabels):
        # Unaliased columns have nothing to relabel.
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        # When the output field differs from the target field, the output
        # field's converters run first, followed by the target's.
        same_field = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if same_field:
            return converters
        return converters + self.target.get_db_converters(connection)
class Ref(Expression):
    """
    A reference to a column alias of the query.

    For example, ``Ref('sum_cost')`` refers to the alias introduced by
    ``qs.annotate(sum_cost=Sum('cost'))``. ``refs`` is the alias name and
    ``source`` is the already-resolved expression it names.
    """

    def __init__(self, refs, source):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        # Exactly one source expression is expected.
        [self.source] = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # ``source`` is already resolved; this node only names it.
        return self

    def get_refs(self):
        return {self.refs}

    def relabeled_clone(self, relabels):
        return self

    def as_sql(self, compiler, connection):
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self):
        return [self]
class ExpressionList(Func):
    """
    Several expressions rendered as one, joined by ``arg_joiner``.

    Useful for passing a list of expressions as a single argument to another
    expression, such as a window's partition clause.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if not expressions:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self):
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(self, compiler, connection, **extra_context):
        # Bypass SQLiteNumericMixin: a NUMERIC cast is unnecessary here.
        return self.as_sql(compiler, connection, **extra_context)
class OrderByList(Func):
    """An ``ORDER BY`` clause built from ordering expressions or field names."""

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions, **extra):
        def _to_order_by(expr):
            # "-name" is shorthand for a descending ordering on field "name".
            if isinstance(expr, str) and expr[0] == "-":
                return OrderBy(F(expr[1:]), descending=True)
            return expr

        super().__init__(*map(_to_order_by, expressions), **extra)

    def as_sql(self, *args, **kwargs):
        if self.source_expressions:
            return super().as_sql(*args, **kwargs)
        # Nothing to order by: render no clause at all.
        return "", ()

    def get_group_by_cols(self):
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]
@deconstructible(path="django.db.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    Wrap another expression to give it extra context, such as an explicit
    output_field. The wrapped expression compiles unchanged.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def get_group_by_cols(self):
        if not isinstance(self.expression, Expression):
            # A non-expression (e.g. an SQL WHERE clause) must be included in
            # the GROUP BY clause as a whole.
            return super().get_group_by_cols()
        # Group by a copy of the inner expression carrying this wrapper's
        # output field.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(self, compiler, connection):
        return compiler.compile(self.expression)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.expression)
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression."""

    def __init__(self, expression):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self):
        # Negating a negation yields a copy of the original expression.
        return self.expression.copy()

    def as_sql(self, compiler, connection):
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The inner expression compiled to an empty result set; render
            # its negation as an always-true condition.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        ops = compiler.connection.ops
        if ops.conditional_expression_supported_in_where_clause(self.expression):
            return f"NOT {sql}", params
        # Some backends (e.g. Oracle) only allow EXISTS() and filters to be
        # compared to another expression when wrapped in CASE WHEN.
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if getattr(resolved.expression, "conditional", False):
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(self, compiler, sql, params):
        conn = compiler.connection
        if conn.features.supports_boolean_expr_in_select_clause:
            return sql, params
        # The backend (e.g. Oracle) can't use a boolean expression in SELECT
        # or GROUP BY, so wrap it in CASE WHEN. Skip wrapping when as_sql()
        # already produced a CASE WHEN, to avoid double wrapping.
        if conn.ops.conditional_expression_supported_in_where_clause(self.expression):
            sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql)
        return sql, params
@deconstructible(path="django.db.models.When")
class When(Expression):
    """
    One ``WHEN <condition> THEN <result>`` branch of a Case() expression.

    The condition may be a Q object, a conditional expression, or keyword
    lookups (optionally combined with a conditional expression). ``then`` is
    the branch's result.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # Not a complete conditional expression on its own: it must be used in Case().
    conditional = False

    def __init__(self, condition=None, then=None, **lookups):
        if lookups and condition is None:
            # Only keyword lookups were given: they form the condition.
            condition, lookups = Q(**lookups), None
        elif lookups and getattr(condition, "conditional", False):
            # Combine a conditional expression with the extra lookups.
            condition, lookups = Q(condition, **lookups), None
        if (
            condition is None
            or not getattr(condition, "conditional", False)
            or lookups
        ):
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        self.condition, self.result = exprs

    def get_source_fields(self):
        # Only the result's field matters for output-field resolution.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        clone = self.copy()
        clone.is_summary = summarize
        if hasattr(clone.condition, "resolve_expression"):
            # The condition is always resolved with for_save=False; only the
            # result receives the caller's for_save.
            clone.condition = clone.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        clone.result = clone.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        extra_context["condition"] = condition_sql
        extra_context["result"] = result_sql
        sql = (template or self.template) % extra_context
        return sql, (*condition_params, *result_params)

    def get_group_by_cols(self):
        # Not a complete expression, so it can't be grouped by itself; collect
        # the columns of its condition and result instead.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
@deconstructible(path="django.db.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression built from When() branches and a default:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        if any(not isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        cases_str = ", ".join(map(str, self.cases))
        return f"CASE {cases_str}, ELSE {self.default!r}"

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        return [*self.cases, self.default]

    def set_source_expressions(self, exprs):
        *self.cases, self.default = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        clone = self.copy()
        clone.is_summary = summarize
        clone.cases = [
            case.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for case in clone.cases
        ]
        clone.default = clone.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def copy(self):
        clone = super().copy()
        # Give the copy its own list so resolving it doesn't touch the original.
        clone.cases = clone.cases[:]
        return clone

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        connection.ops.check_expression_support(self)
        if not self.cases:
            return compiler.compile(self.default)
        template_params = {**self.extra, **extra_context}
        default_sql, default_params = compiler.compile(self.default)
        when_sqls = []
        params = []
        for case in self.cases:
            try:
                when_sql, when_params = compiler.compile(case)
            except EmptyResultSet:
                # Drop branches whose condition compiles to an empty result set.
                continue
            except FullResultSet:
                # The condition matches everything: its result becomes the
                # ELSE value and later branches are not compiled.
                default_sql, default_params = compiler.compile(case.result)
                break
            when_sqls.append(when_sql)
            params.extend(when_params)
        if not when_sqls:
            return default_sql, default_params
        joiner = case_joiner or self.case_joiner
        template_params["cases"] = joiner.join(when_sqls)
        template_params["default"] = default_sql
        params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, params

    def get_group_by_cols(self):
        if self.cases:
            return super().get_group_by_cols()
        return self.default.get_group_by_cols()
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery embedded in an outer query.

    ``queryset`` may be a QuerySet or an sql.Query; it is cloned and flagged
    as a subquery. The subquery may contain OuterRef() references to the
    outer query, which are resolved when it is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(self, queryset, output_field=None, **extra):
        # Accept both QuerySet (which has .query) and sql.Query objects.
        source_query = getattr(queryset, "query", queryset)
        self.query = source_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self):
        return [self.query]

    def set_source_expressions(self, exprs):
        self.query = exprs[0]

    def _resolve_output_field(self):
        return self.query.output_field

    def copy(self):
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self):
        return self.query.external_aliases

    def get_external_cols(self):
        return self.query.get_external_cols()

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        template_params = {**self.extra, **extra_context}
        query_sql, query_params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters of the compiled query (expected to
        # be its enclosing parentheses); the template supplies its own.
        template_params["subquery"] = query_sql[1:-1]
        template = template or template_params.get("template", self.template)
        return template % template_params, query_params

    def get_group_by_cols(self):
        return self.query.get_group_by_cols(wrapper=self)
class Exists(Subquery):
    """An ``EXISTS(...)`` test over a subquery, producing a boolean."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()

    def __init__(self, queryset, **kwargs):
        super().__init__(queryset, **kwargs)
        # Replace the cloned query with its exists() variant.
        self.query = self.query.exists()

    def select_format(self, compiler, sql, params):
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        # Backends (e.g. Oracle) without boolean expressions in SELECT or
        # GROUP BY need EXISTS() wrapped in CASE WHEN.
        return "CASE WHEN {} THEN 1 ELSE 0 END".format(sql), params
@deconstructible(path="django.db.models.OrderBy")
class OrderBy(Expression):
    """
    An ordering term: an expression plus direction and optional NULLS placement.

    ``nulls_first`` and ``nulls_last`` are mutually exclusive. Passing
    ``False`` for either is deprecated in favour of ``None``.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            # When the deprecation ends, replace with:
            # raise ValueError(
            #     "nulls_first and nulls_last values must be True or None."
            # )
            warnings.warn(
                "Passing nulls_first=False or nulls_last=False is deprecated, use None "
                "instead.",
                RemovedInDjango50Warning,
                stacklevel=2,
            )
        self.nulls_first, self.nulls_last = nulls_first, nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.expression}, "
            f"descending={self.descending})"
        )

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        template = template or self.template
        features = connection.features
        if features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        # Without a NULLS modifier, emulate it with a leading IS [NOT] NULL
        # sort key, unless the backend's default NULL placement already gives
        # the requested result.
        elif self.nulls_last and not (
            self.descending and features.order_by_nulls_first
        ):
            template = f"%(expression)s IS NULL, {template}"
        elif self.nulls_first and not (
            not self.descending and features.order_by_nulls_first
        ):
            template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # The expression may appear more than once in the template; repeat its
        # params to match.
        params *= template.count("%(expression)s")
        # Strip trailing whitespace, e.g. left by an empty %(ordering)s.
        return (template % placeholders).rstrip(), params

    def as_oracle(self, compiler, connection):
        if not connection.ops.conditional_expression_supported_in_where_clause(
            self.expression
        ):
            return self.as_sql(compiler, connection)
        # Oracle doesn't allow ORDER BY EXISTS() or filters unless wrapped in
        # a CASE WHEN.
        wrapped = self.copy()
        wrapped.expression = Case(
            When(self.expression, then=True),
            default=False,
        )
        return wrapped.as_sql(compiler, connection)

    def get_group_by_cols(self):
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self):
        """Flip the direction and swap any NULLS placement; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_first, self.nulls_last = None, True
        elif self.nulls_last:
            self.nulls_first, self.nulls_last = True, None
        return self

    def asc(self):
        self.descending = False

    def desc(self):
        self.descending = True
class Window(SQLiteNumericMixin, Expression):
    """
    An expression evaluated over a window: ``<expression> OVER (<window>)``.

    Arguments:
     * expression: an expression whose ``window_compatible`` attribute is
       truthy; otherwise ValueError is raised.
     * partition_by: an expression, or a tuple/list of them, wrapped in an
       ExpressionList and rendered as ``PARTITION BY ...``.
     * order_by: a string reference, an expression, or a list/tuple of them,
       wrapped in an OrderByList. Any other type raises ValueError.
     * frame: an optional frame clause, compiled after the ordering.
     * output_field: defaults to the wrapped expression's output field.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True

    def __init__(
        self,
        expression,
        partition_by=None,
        order_by=None,
        frame=None,
        output_field=None,
    ):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                "Expression '%s' isn't compatible with OVER clauses."
                % expression.__class__.__name__
            )

        if self.partition_by is not None:
            # Accept a single expression as well as a tuple/list of them.
            if not isinstance(self.partition_by, (tuple, list)):
                self.partition_by = (self.partition_by,)
            self.partition_by = ExpressionList(*self.partition_by)

        if self.order_by is not None:
            if isinstance(self.order_by, (list, tuple)):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, (BaseExpression, str)):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        return self.source_expression.output_field

    def get_source_expressions(self):
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs):
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(self, compiler, connection, template=None):
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def as_sqlite(self, compiler, connection):
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression.
            copy = self.copy()
            source_expressions = copy.get_source_expressions()
            source_expressions[0].output_field = fields.FloatField()
            copy.set_source_expressions(source_expressions)
            return super(Window, copy).as_sqlite(compiler, connection)
        return self.as_sql(compiler, connection)

    def __str__(self):
        window_parts = [
            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
            str(self.order_by or ""),
            str(self.frame or ""),
        ]
        # Separate the clauses with a space so that e.g. the PARTITION BY and
        # ordering parts don't run together; empty parts are skipped.
        return "{} OVER ({})".format(
            str(self.source_expression),
            " ".join(part for part in window_parts if part),
        )

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_group_by_cols(self):
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
class WindowFrame(Expression):
    """
    The frame clause of a window expression.

    Concrete subclasses set ``frame_type`` and implement
    window_frame_start_end(); the shared processing and (deliberately
    incomplete) validation live here. ``start`` and ``end`` are optional; a
    missing end renders as UNBOUNDED FOLLOWING.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"

    def __init__(self, start=None, end=None):
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs):
        self.start, self.end = exprs

    def get_source_expressions(self):
        return [self.start, self.end]

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        frame_params = {"frame_type": self.frame_type, "start": start, "end": end}
        return self.template % frame_params, []

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_group_by_cols(self):
        return []

    def __str__(self):
        # Rendered with the default database connection's operations, since no
        # connection is passed in here.
        ops = connection.ops
        start_value = self.start.value
        if start_value is None:
            start = ops.UNBOUNDED_PRECEDING
        elif start_value < 0:
            start = "%d %s" % (abs(start_value), ops.PRECEDING)
        elif start_value == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        end_value = self.end.value
        if end_value is None:
            end = ops.UNBOUNDED_FOLLOWING
        elif end_value > 0:
            end = "%d %s" % (end_value, ops.FOLLOWING)
        elif end_value == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    def window_frame_start_end(self, connection, start, end):
        """Return the (start, end) SQL for this frame type; subclass hook."""
        raise NotImplementedError("Subclasses must implement window_frame_start_end().")
class RowRange(WindowFrame):
    """A ROWS frame; bounds are rendered by the backend's operations."""

    frame_type = "ROWS"

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)
class ValueRange(WindowFrame):
    """A RANGE frame; bounds are rendered by the backend's operations."""

    frame_type = "RANGE"

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)