Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/django/db/models/aggregates.py: 42%
127 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 06:13 +0000
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 06:13 +0000
"""
Classes to represent the definitions of aggregate functions.

``Aggregate`` is the shared base class. ``Avg``, ``Count``, ``Max``, ``Min``,
``StdDev``, ``Sum`` and ``Variance`` each render the corresponding SQL
aggregate function (``AVG``, ``COUNT``, ``MAX``, ``MIN``,
``STDDEV_POP``/``STDDEV_SAMP``, ``SUM``, ``VAR_POP``/``VAR_SAMP``).
"""
4from django.core.exceptions import FieldError, FullResultSet
5from django.db.models.expressions import Case, Func, Star, Value, When
6from django.db.models.fields import IntegerField
7from django.db.models.functions.comparison import Coalesce
8from django.db.models.functions.mixins import (
9 FixDurationInputMixin,
10 NumericOutputFieldMixin,
11)
# Public API of this module: the base Aggregate class plus every concrete aggregate.
__all__ = ["Aggregate", "Avg", "Count", "Max", "Min", "StdDev", "Sum", "Variance"]
class Aggregate(Func):
    """
    Base class for SQL aggregate functions.

    On top of ``Func`` this adds:

    * ``distinct`` -- renders ``DISTINCT`` inside the call; only allowed when
      the subclass sets ``allow_distinct = True``.
    * ``filter`` -- a condition restricting the aggregated rows. It is rendered
      as ``FILTER (WHERE ...)`` when the backend supports that clause, and
      otherwise emulated with ``Case(When(filter, then=<first argument>))``.
    * ``default`` -- when set, the resolved aggregate is wrapped in
      ``Coalesce(aggregate, default)``. It is rejected by subclasses that
      define a non-None ``empty_result_set_value``.
    """

    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    name = None
    # Wraps the regular template. "%%(filter)s" survives the first
    # %-interpolation as "%(filter)s" so it can be filled in later.
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    window_compatible = True
    # Subclasses opt in to accepting distinct=True.
    allow_distinct = False
    # A non-None value in a subclass makes ``default=`` invalid (see __init__).
    empty_result_set_value = None

    def __init__(
        self, *expressions, distinct=False, filter=None, default=None, **extra
    ):
        """
        Store the aggregate-specific options and delegate the rest to Func.

        Raise TypeError if ``distinct`` is requested but the class does not
        allow it, or if ``default`` is given for a class that defines
        ``empty_result_set_value``.
        """
        if distinct and not self.allow_distinct:
            raise TypeError("%s does not allow distinct." % self.__class__.__name__)
        if default is not None and self.empty_result_set_value is not None:
            raise TypeError(f"{self.__class__.__name__} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self):
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self):
        """Return the aggregate's arguments, followed by the filter if set."""
        source_expressions = super().get_source_expressions()
        if self.filter:
            return source_expressions + [self.filter]
        return source_expressions

    def set_source_expressions(self, exprs):
        """
        Inverse of get_source_expressions().

        When a filter is set, the last item of ``exprs`` is taken as the new
        filter and the rest become the arguments. The caller's list is not
        modified.
        """
        if self.filter:
            # Star-unpacking builds a new list instead of popping from the
            # caller's sequence.
            *exprs, self.filter = exprs
        return super().set_source_expressions(exprs)

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the arguments and the filter against ``query``.

        Unless ``summarize`` is true, raise FieldError when an argument is
        itself an aggregate. If ``default`` is set, return
        ``Coalesce(aggregate, default)`` instead of the aggregate itself.
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save.
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        c.filter = c.filter and c.filter.resolve_expression(
            query, allow_joins, reuse, summarize
        )
        if not summarize:
            # Bypass Aggregate.get_source_expressions() (call the parent
            # implementation) so the filter is not included in this check.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        "Cannot compute %s('%s'): '%s' is an aggregate"
                        % (c.name, name, name)
                    )
        if (default := c.default) is None:
            return c
        if hasattr(default, "resolve_expression"):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self):
        """
        Return "<name>__<aggregate name lowercased>" for a single named source.

        A filter counts as a source expression, so filtered aggregates (and
        any with more than one or an unnamed argument) raise TypeError and
        need an explicit alias.
        """
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            return "%s__%s" % (expressions[0].name, self.name.lower())
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self):
        """Return no GROUP BY columns for the aggregate itself."""
        return []

    def as_sql(self, compiler, connection, **extra_context):
        """
        Compile the aggregate, rendering DISTINCT and the filter if any.

        If compiling the filter raises FullResultSet, the filter is dropped
        and the plain aggregate is rendered.
        """
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter:
            if connection.features.supports_aggregate_filter_clause:
                try:
                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
                except FullResultSet:
                    # The filter is a no-op; fall through to the plain call.
                    pass
                else:
                    # Take any caller-supplied template out of extra_context
                    # so it is not passed twice (explicitly and through
                    # **extra_context), which would raise TypeError. Fall back
                    # to a template given at construction time (self.extra),
                    # then to the class template.
                    # NOTE(review): keep this lookup order in sync with
                    # Func.as_sql.
                    base_template = extra_context.pop("template", None) or (
                        self.extra.get("template", self.template)
                    )
                    sql, params = super().as_sql(
                        compiler,
                        connection,
                        template=self.filter_template % base_template,
                        filter=filter_sql,
                        **extra_context,
                    )
                    return sql, (*params, *filter_params)
            else:
                # No FILTER clause support: emulate it by wrapping the first
                # argument in Case(When(filter, then=<argument>)) on a copy.
                copy = self.copy()
                copy.filter = None
                source_expressions = copy.get_source_expressions()
                condition = When(self.filter, then=source_expressions[0])
                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
                # Skip Aggregate.as_sql on the copy; DISTINCT is already in
                # extra_context.
                return super(Aggregate, copy).as_sql(
                    compiler, connection, **extra_context
                )
        return super().as_sql(compiler, connection, **extra_context)

    def _get_repr_options(self):
        """Add ``distinct`` and ``filter`` to the repr when they are set."""
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """SQL ``AVG()`` aggregate; accepts ``distinct=True``."""

    allow_distinct = True
    name = "Avg"
    function = "AVG"
class Count(Aggregate):
    """
    SQL ``COUNT()`` aggregate with an ``IntegerField`` result.

    The string ``"*"`` is shorthand for ``Star()``. A star count cannot be
    combined with ``filter``. Because ``empty_result_set_value`` is set,
    ``default=`` is rejected by ``Aggregate.__init__``.
    """

    allow_distinct = True
    empty_result_set_value = 0
    function = "COUNT"
    name = "Count"
    output_field = IntegerField()

    def __init__(self, expression, filter=None, **extra):
        if expression == "*":
            expression = Star()
        is_star = isinstance(expression, Star)
        if is_star and filter is not None:
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(expression, filter=filter, **extra)
class Max(Aggregate):
    """SQL ``MAX()`` aggregate."""

    name = "Max"
    function = "MAX"
class Min(Aggregate):
    """SQL ``MIN()`` aggregate."""

    name = "Min"
    function = "MIN"
class StdDev(NumericOutputFieldMixin, Aggregate):
    """
    Standard-deviation aggregate.

    Renders ``STDDEV_SAMP`` when ``sample`` is truthy, else ``STDDEV_POP``.
    """

    name = "StdDev"

    def __init__(self, expression, sample=False, **extra):
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Report the sample flag, derived from the chosen SQL function.
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "STDDEV_SAMP"
        return options
class Sum(FixDurationInputMixin, Aggregate):
    """SQL ``SUM()`` aggregate; accepts ``distinct=True``."""

    allow_distinct = True
    name = "Sum"
    function = "SUM"
class Variance(NumericOutputFieldMixin, Aggregate):
    """
    Variance aggregate.

    Renders ``VAR_SAMP`` when ``sample`` is truthy, else ``VAR_POP``.
    """

    name = "Variance"

    def __init__(self, expression, sample=False, **extra):
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Report the sample flag, derived from the chosen SQL function.
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "VAR_SAMP"
        return options