Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/django/db/models/aggregates.py: 42%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

146 statements  

1""" 

2Classes to represent the definitions of aggregate functions. 

3""" 

4 

5from django.core.exceptions import FieldError, FullResultSet 

6from django.db.models.expressions import Case, ColPairs, Func, Star, Value, When 

7from django.db.models.fields import IntegerField 

8from django.db.models.functions import Coalesce 

9from django.db.models.functions.mixins import ( 

10 FixDurationInputMixin, 

11 NumericOutputFieldMixin, 

12) 

13 

# Public API of this module, in the order it is exported.
__all__ = [
    "Aggregate",
    "Avg",
    "Count",
    "Max",
    "Min",
    "StdDev",
    "Sum",
    "Variance",
]

24 

25 

class Aggregate(Func):
    """
    Base class for SQL aggregate functions (COUNT, SUM, AVG, ...).

    On top of ``Func`` it supports:

    * ``distinct`` -- renders ``DISTINCT`` inside the call; only permitted
      when the subclass sets ``allow_distinct``.
    * ``filter`` -- a condition restricting the aggregated rows. It is
      rendered as ``FILTER (WHERE ...)`` when the backend supports that
      clause, and emulated with ``CASE WHEN`` otherwise.
    * ``default`` -- after resolution the aggregate is wrapped as
      ``Coalesce(aggregate, default)``. This is rejected for classes that
      define a non-None ``empty_result_set_value``.
    """

    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    name = None
    # Wraps the effective template. The doubled "%%" survives the first
    # interpolation, leaving a %(filter)s placeholder for Func.as_sql().
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    window_compatible = True
    allow_distinct = False
    # A non-None value here (e.g. Count uses 0) makes ``default`` invalid.
    empty_result_set_value = None

    def __init__(
        self, *expressions, distinct=False, filter=None, default=None, **extra
    ):
        """
        :param expressions: the expressions to aggregate.
        :param distinct: render DISTINCT. Raises TypeError unless the class
            sets ``allow_distinct``.
        :param filter: optional condition restricting the aggregated rows.
        :param default: wrapped as ``Coalesce(aggregate, default)`` on
            resolution. Raises TypeError if the class defines a non-None
            ``empty_result_set_value``.
        :param extra: forwarded to ``Func``.
        """
        if distinct and not self.allow_distinct:
            raise TypeError("%s does not allow distinct." % self.__class__.__name__)
        if default is not None and self.empty_result_set_value is not None:
            raise TypeError(f"{self.__class__.__name__} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self):
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self):
        # The filter (possibly None) is always the last element;
        # set_source_expressions() relies on this layout.
        source_expressions = super().get_source_expressions()
        return source_expressions + [self.filter]

    def set_source_expressions(self, exprs):
        # Inverse of get_source_expressions(): the last item is the filter.
        *exprs, self.filter = exprs
        return super().set_source_expressions(exprs)

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the arguments and the filter against ``query``.

        Raise FieldError when the aggregate refers to another aggregate:

        * when ``summarize`` is true, for any referenced annotation that is
          itself a summary;
        * otherwise (unless this aggregate is already marked ``is_summary``),
          for any argument containing an aggregate.

        If a ``default`` was given, return ``Coalesce(aggregate, default)``
        instead of the resolved aggregate.
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        c.filter = (
            c.filter.resolve_expression(query, allow_joins, reuse, summarize)
            if c.filter
            else None
        )
        if summarize:
            # Summarized aggregates cannot refer to summarized aggregates.
            for ref in c.get_refs():
                if query.annotations[ref].is_summary:
                    raise FieldError(
                        f"Cannot compute {c.name}('{ref}'): '{ref}' is an aggregate"
                    )
        elif not self.is_summary:
            # Call Aggregate.get_source_expressions() to avoid
            # returning self.filter and including that in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    # Report the expression as the user wrote it, i.e. before
                    # resolution.
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        "Cannot compute %s('%s'): '%s' is an aggregate"
                        % (c.name, name, name)
                    )
        if (default := c.default) is None:
            return c
        if hasattr(default, "resolve_expression"):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            # Let an untyped default take the aggregate's output field.
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        # The wrapper stands in for the aggregate, so it inherits its flag.
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self):
        """
        Return "<field>__<name in lowercase>" when there is exactly one
        non-None source expression and it has a ``name``.

        Raise TypeError otherwise. This includes the case where a filter is
        set, since the filter also counts as a source expression.
        """
        expressions = [
            expr for expr in self.get_source_expressions() if expr is not None
        ]
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            return "%s__%s" % (expressions[0].name, self.name.lower())
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self):
        # An aggregate contributes no GROUP BY columns.
        return []

    def as_sql(self, compiler, connection, **extra_context):
        """
        Compile the aggregate, rendering ``filter`` as described below.

        * FILTER clause supported: a ``FILTER (WHERE ...)`` clause is
          appended to the template.
        * FILTER clause not supported: the first argument is wrapped in
          ``CASE WHEN <filter> THEN <arg>``.

        A caller-supplied ``template`` in ``extra_context`` is honoured on
        every path.
        """
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter:
            if connection.features.supports_aggregate_filter_clause:
                try:
                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
                except FullResultSet:
                    # The filter matches every row: fall through and render the
                    # plain aggregate without a FILTER clause.
                    pass
                else:
                    # pop(), not get(): ``template`` is passed explicitly below.
                    # Leaving it in extra_context as well would raise
                    # "got multiple values for keyword argument 'template'".
                    template = self.filter_template % extra_context.pop(
                        "template", self.template
                    )
                    sql, params = super().as_sql(
                        compiler,
                        connection,
                        template=template,
                        filter=filter_sql,
                        **extra_context,
                    )
                    return sql, (*params, *filter_params)
            else:
                # Emulate FILTER with CASE WHEN on a filter-less copy.
                emulated = self.copy()
                emulated.filter = None
                source_expressions = emulated.get_source_expressions()
                condition = When(self.filter, then=source_expressions[0])
                emulated.set_source_expressions(
                    [Case(condition)] + source_expressions[1:]
                )
                # Skip Aggregate.as_sql() on the copy; extra_context already
                # carries "distinct".
                return super(Aggregate, emulated).as_sql(
                    compiler, connection, **extra_context
                )
        return super().as_sql(compiler, connection, **extra_context)

    def _get_repr_options(self):
        # Include distinct/filter in repr() only when they are set.
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options

155 

156 

class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """AVG over a single expression; ``distinct=True`` is allowed."""

    arity = 1
    allow_distinct = True
    name = "Avg"
    function = "AVG"

162 

163 

class Count(Aggregate):
    """
    COUNT over a single expression, with an IntegerField output.

    The string ``"*"`` is converted to ``Star()`` to count rows. A Star
    cannot be combined with ``filter``. Passing ``default`` is rejected by
    Aggregate because ``empty_result_set_value`` is 0.
    """

    function = "COUNT"
    name = "Count"
    output_field = IntegerField()
    allow_distinct = True
    empty_result_set_value = 0
    arity = 1
    allows_composite_expressions = True

    def __init__(self, expression, filter=None, **extra):
        if expression == "*":
            expression = Star()
        if isinstance(expression, Star) and filter is not None:
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(expression, filter=filter, **extra)

    def resolve_expression(self, *args, **kwargs):
        """
        Resolve as usual, then replace a composite-key column pair (ColPairs)
        with its first column.

        Raise ValueError for COUNT(DISTINCT) over a composite primary key.
        """
        result = super().resolve_expression(*args, **kwargs)
        expr = result.source_expressions[0]

        # In case of composite primary keys, count the first column.
        if isinstance(expr, ColPairs):
            if self.distinct:
                raise ValueError(
                    "COUNT(DISTINCT) doesn't support composite primary keys"
                )

            first_col = expr.get_cols()[0]
            # Swap the argument on the resolved expression rather than
            # returning a fresh, unresolved Count. A fresh Count would lose
            # is_summary, which the summarize check in
            # Aggregate.resolve_expression reads on referenced annotations,
            # and it would also lose any extra options.
            result.set_source_expressions([first_col, result.filter])

        return result

195 

196 

class Max(Aggregate):
    """MAX over a single expression."""

    arity = 1
    name = "Max"
    function = "MAX"

201 

202 

class Min(Aggregate):
    """MIN over a single expression."""

    arity = 1
    name = "Min"
    function = "MIN"

207 

208 

class StdDev(NumericOutputFieldMixin, Aggregate):
    """
    Standard deviation of a single expression.

    ``sample=True`` selects STDDEV_SAMP; otherwise STDDEV_POP is used.
    """

    arity = 1
    name = "StdDev"

    def __init__(self, expression, sample=False, **extra):
        # The SQL function name is chosen per instance from ``sample``.
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Recover the ``sample`` flag from the chosen SQL function for repr().
        return dict(
            super()._get_repr_options(), sample=self.function == "STDDEV_SAMP"
        )

219 

220 

class Sum(FixDurationInputMixin, Aggregate):
    """SUM over a single expression; ``distinct=True`` is allowed."""

    arity = 1
    allow_distinct = True
    name = "Sum"
    function = "SUM"

226 

227 

class Variance(NumericOutputFieldMixin, Aggregate):
    """
    Variance of a single expression.

    ``sample=True`` selects VAR_SAMP; otherwise VAR_POP is used.
    """

    arity = 1
    name = "Variance"

    def __init__(self, expression, sample=False, **extra):
        # The SQL function name is chosen per instance from ``sample``.
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Recover the ``sample`` flag from the chosen SQL function for repr().
        return dict(super()._get_repr_options(), sample=self.function == "VAR_SAMP")