Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/django/db/models/aggregates.py: 42%

127 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 06:13 +0000

1""" 

2Classes to represent the definitions of aggregate functions. 

3""" 

4from django.core.exceptions import FieldError, FullResultSet 

5from django.db.models.expressions import Case, Func, Star, Value, When 

6from django.db.models.fields import IntegerField 

7from django.db.models.functions.comparison import Coalesce 

8from django.db.models.functions.mixins import ( 

9 FixDurationInputMixin, 

10 NumericOutputFieldMixin, 

11) 

12 

# Public names exported by ``from django.db.models.aggregates import *``.
__all__ = ["Aggregate", "Avg", "Count", "Max", "Min", "StdDev", "Sum", "Variance"]

23 

24 

class Aggregate(Func):
    """
    Base class for SQL aggregate functions (AVG, COUNT, SUM, ...).

    Adds three options on top of ``Func``:

    * ``distinct`` -- render ``DISTINCT`` inside the call. Only permitted when
      the subclass sets ``allow_distinct = True``.
    * ``filter`` -- a condition limiting which rows are aggregated. It is
      rendered as ``FILTER (WHERE ...)`` when the backend supports that clause.
      Otherwise it is emulated by wrapping the first source expression in
      ``Case(When(filter, then=...))``.
    * ``default`` -- a value or expression to use instead of NULL. It is
      implemented by wrapping the resolved aggregate in ``Coalesce``.
    """

    # Rendered by Func.as_sql(); %(distinct)s is supplied by as_sql() below.
    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    # Display name used in FieldError messages and in default_alias;
    # concrete subclasses set it.
    name = None
    # Wraps the regular template. The first "%s" receives that template, and
    # "%%(filter)s" survives this substitution as "%(filter)s" so that
    # Func.as_sql() can fill it in later.
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    window_compatible = True
    # Subclasses opt in to accepting distinct=True.
    allow_distinct = False
    # A subclass that sets a non-None value here cannot take ``default``
    # (see __init__). Presumably this is the fixed result for empty input,
    # e.g. Count uses 0 -- confirm against Func/compiler usage.
    empty_result_set_value = None

    def __init__(
        self, *expressions, distinct=False, filter=None, default=None, **extra
    ):
        """
        Store the aggregate options and pass the source expressions to Func.

        Raise TypeError if ``distinct`` is truthy but the subclass does not set
        ``allow_distinct``. Also raise TypeError if ``default`` is given to a
        subclass that defines ``empty_result_set_value``.
        """
        if distinct and not self.allow_distinct:
            raise TypeError("%s does not allow distinct." % self.__class__.__name__)
        if default is not None and self.empty_result_set_value is not None:
            raise TypeError(f"{self.__class__.__name__} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self):
        """Return the output fields of the source expressions, excluding the filter."""
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self):
        """Return the Func source expressions, followed by the filter if one is set."""
        source_expressions = super().get_source_expressions()
        if self.filter:
            return source_expressions + [self.filter]
        return source_expressions

    def set_source_expressions(self, exprs):
        """
        Inverse of get_source_expressions().

        When a filter is set, the last item of ``exprs`` is taken as the new
        filter. Note that it is popped from the caller's list in place. The
        remaining items become the Func source expressions.
        """
        # If self.filter is falsy, the short-circuit leaves it unchanged and
        # nothing is popped.
        self.filter = self.filter and exprs.pop()
        return super().set_source_expressions(exprs)

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the aggregate and its filter against ``query``.

        Unless ``summarize`` is true, raise FieldError when any source
        expression (the filter excluded) itself contains an aggregate.

        Return the resolved aggregate. When ``default`` is set, return
        ``Coalesce(aggregate, default)`` instead; the wrapper carries the
        aggregate's output field and ``is_summary`` flag.
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        c.filter = c.filter and c.filter.resolve_expression(
            query, allow_joins, reuse, summarize
        )
        if not summarize:
            # Call Aggregate.get_source_expressions() to avoid
            # returning self.filter and including that in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    # Use the unresolved expression from ``self`` to build a
                    # readable name for the error message.
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        "Cannot compute %s('%s'): '%s' is an aggregate"
                        % (c.name, name, name)
                    )
        if (default := c.default) is None:
            return c
        if hasattr(default, "resolve_expression"):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            # An expression default with no known type inherits the
            # aggregate's output field.
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            # Plain Python values become a typed Value.
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self):
        """
        Return "<source name>__<lowercased aggregate name>", e.g. "price__sum".

        Raise TypeError unless there is exactly one source expression and it
        has a ``name`` attribute. The filter counts as a source expression
        here, because get_source_expressions() appends it. Filtered aggregates
        therefore always need an explicit alias.
        """
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            return "%s__%s" % (expressions[0].name, self.name.lower())
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self):
        """Return no columns: an aggregate contributes nothing to GROUP BY."""
        return []

    def as_sql(self, compiler, connection, **extra_context):
        """
        Compile the aggregate to ``(sql, params)``.

        When a filter is set and the backend supports the aggregate FILTER
        clause, the filter is rendered through ``filter_template``. If
        compiling the filter raises FullResultSet, the filter is dropped and
        the plain aggregate is rendered. Presumably this means the filter
        matches every row.

        On backends without the clause, a filter-less copy is rendered
        instead. In that copy the first source expression is wrapped in
        ``Case(When(filter, then=expr))``.
        """
        # Mutates extra_context so every path below, including the
        # super(Aggregate, copy) call, receives the "distinct" keyword.
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter:
            if connection.features.supports_aggregate_filter_clause:
                try:
                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
                except FullResultSet:
                    # Fall through to the unfiltered rendering at the end.
                    pass
                else:
                    # A caller-supplied template is honoured and wrapped too.
                    template = self.filter_template % extra_context.get(
                        "template", self.template
                    )
                    sql, params = super().as_sql(
                        compiler,
                        connection,
                        template=template,
                        filter=filter_sql,
                        **extra_context,
                    )
                    # The FILTER clause comes after the aggregate's own
                    # expressions, so its params are appended last.
                    return sql, (*params, *filter_params)
            else:
                # Emulate the filter with a conditional expression on a copy,
                # leaving self unchanged.
                copy = self.copy()
                copy.filter = None
                source_expressions = copy.get_source_expressions()
                condition = When(self.filter, then=source_expressions[0])
                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
                # Call Func.as_sql() directly for the copy; Aggregate.as_sql()
                # has already set up extra_context above.
                return super(Aggregate, copy).as_sql(
                    compiler, connection, **extra_context
                )
        return super().as_sql(compiler, connection, **extra_context)

    def _get_repr_options(self):
        """Extend Func's repr options with ``distinct`` and ``filter`` when they are set."""
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options

145 

146 

class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """AVG() aggregate; accepts ``distinct=True``."""

    name = "Avg"
    allow_distinct = True
    function = "AVG"

151 

152 

class Count(Aggregate):
    """
    COUNT() aggregate.

    The string ``"*"`` is accepted as shorthand for ``Star()``, i.e. COUNT(*).
    A ``Star`` expression cannot be combined with ``filter``.
    """

    name = "Count"
    function = "COUNT"
    allow_distinct = True
    # The result type is fixed, whatever expression is counted.
    output_field = IntegerField()
    # Non-None, so Aggregate.__init__ rejects a ``default`` argument.
    empty_result_set_value = 0

    def __init__(self, expression, filter=None, **extra):
        expression = Star() if expression == "*" else expression
        star_with_filter = filter is not None and isinstance(expression, Star)
        if star_with_filter:
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(expression, filter=filter, **extra)

166 

167 

class Max(Aggregate):
    """MAX() aggregate; ``distinct`` is not allowed."""

    name = "Max"
    function = "MAX"

171 

172 

class Min(Aggregate):
    """MIN() aggregate; ``distinct`` is not allowed."""

    name = "Min"
    function = "MIN"

176 

177 

class StdDev(NumericOutputFieldMixin, Aggregate):
    """
    Standard-deviation aggregate.

    ``sample=True`` selects STDDEV_SAMP (sample); the default selects
    STDDEV_POP (population).
    """

    name = "StdDev"

    def __init__(self, expression, sample=False, **extra):
        # The SQL function is chosen per instance, before Func is initialised.
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Report ``sample`` so that repr() round-trips the constructor argument.
        return dict(
            super()._get_repr_options(), sample=self.function == "STDDEV_SAMP"
        )

187 

188 

class Sum(FixDurationInputMixin, Aggregate):
    """SUM() aggregate; accepts ``distinct=True``."""

    name = "Sum"
    allow_distinct = True
    function = "SUM"

194 

class Variance(NumericOutputFieldMixin, Aggregate):
    """
    Variance aggregate.

    ``sample=True`` selects VAR_SAMP (sample); the default selects VAR_POP
    (population).
    """

    name = "Variance"

    def __init__(self, expression, sample=False, **extra):
        # The SQL function is chosen per instance, before Func is initialised.
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Report ``sample`` so that repr() round-trips the constructor argument.
        return dict(super()._get_repr_options(), sample=self.function == "VAR_SAMP")