1"""
2Classes to represent the definitions of aggregate functions.
3"""
4
5from django.core.exceptions import FieldError, FullResultSet
6from django.db.models.expressions import Case, ColPairs, Func, Star, Value, When
7from django.db.models.fields import IntegerField
8from django.db.models.functions import Coalesce
9from django.db.models.functions.mixins import (
10 FixDurationInputMixin,
11 NumericOutputFieldMixin,
12)
13
# Public names exported by this module: the Aggregate base class followed by
# the built-in aggregate functions, in alphabetical order.
__all__ = [
    "Aggregate",
    "Avg",
    "Count",
    "Max",
    "Min",
    "StdDev",
    "Sum",
    "Variance",
]
24
25
class Aggregate(Func):
    """
    Base class for aggregate functions (COUNT(), SUM(), AVG(), ...).

    Subclasses set ``function`` (the SQL function name) and ``name`` (used in
    default aliases and error messages). Besides the wrapped expressions,
    every aggregate accepts:

    * ``distinct`` -- render ``DISTINCT`` inside the call. Only allowed when
      the subclass sets ``allow_distinct = True``.
    * ``filter`` -- a condition restricting the aggregated rows. Rendered as
      a ``FILTER (WHERE ...)`` clause on backends that support it, otherwise
      as a ``CASE WHEN`` around the first argument.
    * ``default`` -- a value (or expression) to use instead of NULL,
      implemented by wrapping the resolved aggregate in ``Coalesce``. It is
      rejected when the subclass sets a non-None ``empty_result_set_value``.
    """

    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    name = None
    # Wraps the regular template when the backend supports FILTER clauses.
    # The "%%" survives the first %-interpolation in as_sql() so that
    # "%(filter)s" is left for the second interpolation.
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    window_compatible = True
    allow_distinct = False
    # A non-None value disables the ``default`` argument (see __init__()).
    # Presumably this is the value reported when the query is known to match
    # no rows -- TODO confirm against the SQL compiler.
    empty_result_set_value = None

    def __init__(
        self, *expressions, distinct=False, filter=None, default=None, **extra
    ):
        """
        Store the aggregate options and forward the expressions to Func.

        Raise TypeError if ``distinct`` is requested on an aggregate that
        doesn't allow it, or if ``default`` is given for an aggregate whose
        ``empty_result_set_value`` is not None.
        """
        if distinct and not self.allow_distinct:
            raise TypeError("%s does not allow distinct." % self.__class__.__name__)
        if default is not None and self.empty_result_set_value is not None:
            raise TypeError(f"{self.__class__.__name__} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self):
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self):
        # The filter (possibly None) is appended as the last item so that
        # generic expression traversal sees it; set_source_expressions()
        # strips it back off.
        source_expressions = super().get_source_expressions()
        return source_expressions + [self.filter]

    def set_source_expressions(self, exprs):
        # The last item is the filter, mirroring get_source_expressions().
        *exprs, self.filter = exprs
        return super().set_source_expressions(exprs)

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the arguments and the filter against ``query``.

        Raise FieldError when the aggregate would aggregate another
        aggregate: for summarized aggregates, a reference to a summarized
        annotation; otherwise, an argument that contains an aggregate.

        Without a ``default``, return the resolved aggregate. With one,
        return ``Coalesce(aggregate, default)`` so that a NULL result is
        replaced by the default.
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        # The filter is resolved explicitly here with the same arguments.
        c.filter = (
            c.filter.resolve_expression(query, allow_joins, reuse, summarize)
            if c.filter
            else None
        )
        if summarize:
            # Summarized aggregates cannot refer to summarized aggregates.
            for ref in c.get_refs():
                if query.annotations[ref].is_summary:
                    raise FieldError(
                        f"Cannot compute {c.name}('{ref}'): '{ref}' is an aggregate"
                    )
        elif not self.is_summary:
            # Call Aggregate.get_source_expressions() to avoid
            # returning self.filter and including that in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    # Report the unresolved expression; the filter is last in
                    # get_source_expressions(), so the indexes line up.
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        "Cannot compute %s('%s'): '%s' is an aggregate"
                        % (c.name, name, name)
                    )
        if (default := c.default) is None:
            return c
        # Expression defaults are resolved too and, when they have no output
        # field of their own, inherit the aggregate's. Plain values are wrapped
        # in Value() with the aggregate's output field.
        if hasattr(default, "resolve_expression"):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        # The wrapper takes the aggregate's place, so it keeps its summary flag.
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self):
        """
        Return "<name of the argument>__<lowercased aggregate name>".

        Only a single argument that has a ``name`` attribute qualifies. A
        non-None filter counts as an extra expression here, so filtered
        aggregates (like multi-argument ones) raise TypeError and need an
        explicit alias.
        """
        expressions = [
            expr for expr in self.get_source_expressions() if expr is not None
        ]
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            return "%s__%s" % (expressions[0].name, self.name.lower())
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self):
        # An aggregate never adds columns to GROUP BY.
        return []

    def as_sql(self, compiler, connection, **extra_context):
        """
        Compile the aggregate to SQL, returning ``(sql, params)``.

        A filter is emitted as a FILTER (WHERE ...) clause when
        ``connection.features.supports_aggregate_filter_clause`` is true.
        Otherwise the first argument is rewritten to
        ``CASE WHEN <filter> THEN <argument> END`` on a filter-less copy.
        """
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter:
            if connection.features.supports_aggregate_filter_clause:
                try:
                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
                except FullResultSet:
                    # Treated as "no filter": fall through to the unfiltered
                    # rendering at the end of this method.
                    pass
                else:
                    # Embed the caller's template (or the default) in the
                    # FILTER clause template.
                    template = self.filter_template % extra_context.get(
                        "template", self.template
                    )
                    sql, params = super().as_sql(
                        compiler,
                        connection,
                        template=template,
                        filter=filter_sql,
                        **extra_context,
                    )
                    # The filter's params follow the aggregate's own params,
                    # matching their order in the rendered SQL.
                    return sql, (*params, *filter_params)
            else:
                copy = self.copy()
                copy.filter = None
                source_expressions = copy.get_source_expressions()
                condition = When(self.filter, then=source_expressions[0])
                # The trailing None from get_source_expressions() is kept, so
                # set_source_expressions() leaves the copy's filter as None.
                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
                return super(Aggregate, copy).as_sql(
                    compiler, connection, **extra_context
                )
        return super().as_sql(compiler, connection, **extra_context)

    def _get_repr_options(self):
        # Include distinct/filter in the repr only when they are set.
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options
155
156
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """AVG() of a single expression; DISTINCT is supported."""

    name = "Avg"
    function = "AVG"
    arity = 1
    allow_distinct = True
162
163
class Count(Aggregate):
    """
    COUNT() of a single expression; DISTINCT is supported.

    The string "*" is converted to Star(), i.e. COUNT(*). Star cannot be
    combined with a filter. Because ``empty_result_set_value`` is 0, the
    ``default`` argument is rejected by Aggregate.__init__().
    """

    function = "COUNT"
    name = "Count"
    output_field = IntegerField()
    allow_distinct = True
    empty_result_set_value = 0
    arity = 1
    allows_composite_expressions = True

    def __init__(self, expression, filter=None, **extra):
        """
        Raise ValueError if ``expression`` is "*" or Star() and ``filter`` is
        given.
        """
        if expression == "*":
            expression = Star()
        if isinstance(expression, Star) and filter is not None:
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(expression, filter=filter, **extra)

    def resolve_expression(self, *args, **kwargs):
        """
        Resolve like any aggregate, then handle composite primary keys.

        A composite key resolves to ColPairs, which COUNT() cannot take as a
        single argument, so only the key's first column is counted. Raise
        ValueError for COUNT(DISTINCT) on a composite key.
        """
        result = super().resolve_expression(*args, **kwargs)
        # get_source_expressions() returns a new list ending with the filter.
        source_expressions = result.get_source_expressions()
        expr = source_expressions[0]

        # In case of composite primary keys, count the first column.
        if isinstance(expr, ColPairs):
            if self.distinct:
                raise ValueError(
                    "COUNT(DISTINCT) doesn't support composite primary keys"
                )
            # Swap only the argument on a copy of the *resolved* aggregate.
            # Building a fresh Count() here would drop state set during
            # resolution, such as is_summary, and any extra options.
            source_expressions[0] = expr.get_cols()[0]
            result = result.copy()
            result.set_source_expressions(source_expressions)

        return result
195
196
class Max(Aggregate):
    """MAX() of a single expression."""

    name = "Max"
    function = "MAX"
    arity = 1
201
202
class Min(Aggregate):
    """MIN() of a single expression."""

    name = "Min"
    function = "MIN"
    arity = 1
207
208
class StdDev(NumericOutputFieldMixin, Aggregate):
    """
    Standard deviation of a single expression.

    A truthy ``sample`` selects STDDEV_SAMP; otherwise STDDEV_POP is used.
    """

    arity = 1
    name = "StdDev"

    def __init__(self, expression, sample=False, **extra):
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Report the sample/population choice, which is encoded in the
        # selected SQL function name.
        is_sample = self.function == "STDDEV_SAMP"
        return dict(super()._get_repr_options(), sample=is_sample)
219
220
class Sum(FixDurationInputMixin, Aggregate):
    """SUM() of a single expression; DISTINCT is supported."""

    name = "Sum"
    function = "SUM"
    arity = 1
    allow_distinct = True
226
227
class Variance(NumericOutputFieldMixin, Aggregate):
    """
    Variance of a single expression.

    A truthy ``sample`` selects VAR_SAMP; otherwise VAR_POP is used.
    """

    arity = 1
    name = "Variance"

    def __init__(self, expression, sample=False, **extra):
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Report the sample/population choice, which is encoded in the
        # selected SQL function name.
        is_sample = self.function == "VAR_SAMP"
        return dict(super()._get_repr_options(), sample=is_sample)