Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/django/db/models/sql/where.py: 29%

211 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 06:13 +0000

"""
Code to manage the creation and SQL rendering of 'where' constraints.
"""

4import operator 

5from functools import reduce 

6 

7from django.core.exceptions import EmptyResultSet, FullResultSet 

8from django.db.models.expressions import Case, When 

9from django.db.models.lookups import Exact 

10from django.utils import tree 

11from django.utils.functional import cached_property 

12 

# Connector names used to combine the children of a WhereNode; they are also
# emitted verbatim between child conditions in the rendered SQL.
AND, OR, XOR = "AND", "OR", "XOR"

17 

18 

class WhereNode(tree.Node):
    """
    An SQL WHERE clause.

    The class is tied to the Query class that created it (in order to create
    the correct SQL).

    A child is usually an expression producing boolean values. Most likely the
    expression is a Lookup instance.

    However, a child could also be any class with as_sql() and either
    relabeled_clone() method or relabel_aliases() and clone() methods and
    contains_aggregate attribute.
    """

    default = AND
    # Set to True on the clone returned by resolve_expression(); as_sql() then
    # parenthesizes the rendered SQL even when there is a single child.
    resolved = False
    conditional = True

    def split_having_qualify(self, negated=False, must_group_by=False):
        """
        Return three possibly None nodes: one for those parts of self that
        should be included in the WHERE clause, one for those parts of self
        that must be included in the HAVING clause, and one for those parts
        that refer to window functions.

        Raise NotImplementedError for disjunctive predicates that mix window
        functions with WHERE-level predicates when must_group_by is true.
        """
        if not self.contains_aggregate and not self.contains_over_clause:
            return self, None, None
        in_negated = negated ^ self.negated
        # Whether or not children must be connected in the same filtering
        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantics.
        must_remain_connected = (
            (in_negated and self.connector == AND)
            or (not in_negated and self.connector == OR)
            or self.connector == XOR
        )
        if (
            must_remain_connected
            and self.contains_aggregate
            and not self.contains_over_clause
        ):
            # It's much cheaper to short-circuit and stash everything in the
            # HAVING clause than split children if possible.
            return None, self, None
        where_parts = []
        having_parts = []
        qualify_parts = []
        for c in self.children:
            if hasattr(c, "split_having_qualify"):
                where_part, having_part, qualify_part = c.split_having_qualify(
                    in_negated, must_group_by
                )
                if where_part is not None:
                    where_parts.append(where_part)
                if having_part is not None:
                    having_parts.append(having_part)
                if qualify_part is not None:
                    qualify_parts.append(qualify_part)
            elif c.contains_over_clause:
                qualify_parts.append(c)
            elif c.contains_aggregate:
                having_parts.append(c)
            else:
                where_parts.append(c)
        if must_remain_connected and qualify_parts:
            # Disjunctive heterogeneous predicates can be pushed down to
            # qualify as long as no conditional aggregation is involved.
            if not where_parts or not must_group_by:
                return None, None, self
            # In theory this should only be enforced when dealing with
            # where_parts containing predicates against multi-valued
            # relationships that could affect aggregation results but this
            # is complex to infer properly.
            raise NotImplementedError(
                "Heterogeneous disjunctive predicates against window functions are "
                "not implemented when performing conditional aggregation."
            )
        where_node = (
            self.create(where_parts, self.connector, self.negated)
            if where_parts
            else None
        )
        having_node = (
            self.create(having_parts, self.connector, self.negated)
            if having_parts
            else None
        )
        qualify_node = (
            self.create(qualify_parts, self.connector, self.negated)
            if qualify_parts
            else None
        )
        return where_node, having_node, qualify_node

    def as_sql(self, compiler, connection):
        """
        Return the SQL version of the where clause and the params to be
        substituted in, as an (sql, params) tuple.

        Raise EmptyResultSet if this node can't match anything and
        FullResultSet if it matches everything (self.negated is taken into
        account when deciding which one applies).
        """
        result = []
        result_params = []
        # How many children must turn out full (match everything) or empty
        # (match nothing) before the outcome of the whole node is decided.
        if self.connector == AND:
            full_needed, empty_needed = len(self.children), 1
        else:
            full_needed, empty_needed = 1, len(self.children)

        if self.connector == XOR and not connection.features.supports_logical_xor:
            # Convert if the database doesn't support XOR:
            #   a XOR b XOR c XOR ...
            # to:
            #   (a OR b OR c OR ...) AND MOD(a + b + c + ..., 2) == 1
            # An n-ary XOR is true when an odd number of operands are true.
            # Requiring the sum to equal exactly 1 is only equivalent for two
            # operands, so MOD is applied when there are more than two.
            # Imported locally, like BooleanField in output_field.
            from django.db.models.functions import Mod

            lhs = self.__class__(self.children, OR)
            rhs_sum = reduce(
                operator.add,
                (Case(When(c, then=1), default=0) for c in self.children),
            )
            if len(self.children) > 2:
                rhs_sum = Mod(rhs_sum, 2)
            rhs = Exact(1, rhs_sum)
            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
                compiler, connection
            )

        for child in self.children:
            try:
                sql, params = compiler.compile(child)
            except EmptyResultSet:
                empty_needed -= 1
            except FullResultSet:
                full_needed -= 1
            else:
                if sql:
                    result.append(sql)
                    result_params.extend(params)
                else:
                    # A child rendering no SQL imposes no restriction.
                    full_needed -= 1
            # Stop as soon as enough children are known to be empty/full to
            # decide the outcome of the whole node; negation inverts it.
            if empty_needed == 0:
                if self.negated:
                    raise FullResultSet
                else:
                    raise EmptyResultSet
            if full_needed == 0:
                if self.negated:
                    raise EmptyResultSet
                else:
                    raise FullResultSet
        conn = " %s " % self.connector
        sql_string = conn.join(result)
        if not sql_string:
            raise FullResultSet
        if self.negated:
            # Some backends (Oracle at least) need parentheses around the inner
            # SQL in the negated case, even if the inner SQL contains just a
            # single expression.
            sql_string = "NOT (%s)" % sql_string
        elif len(result) > 1 or self.resolved:
            sql_string = "(%s)" % sql_string
        return sql_string, result_params

    def get_group_by_cols(self):
        """Return the GROUP BY columns of every child, concatenated in order."""
        cols = []
        for child in self.children:
            cols.extend(child.get_group_by_cols())
        return cols

    def get_source_expressions(self):
        """Return a shallow copy of the children list."""
        return self.children[:]

    def set_source_expressions(self, children):
        """Replace the children; the new list must have the same length."""
        assert len(children) == len(self.children)
        self.children = children

    def relabel_aliases(self, change_map):
        """
        Relabel the alias values of any children. 'change_map' is a dictionary
        mapping old (current) alias values to the new values.
        """
        for pos, child in enumerate(self.children):
            if hasattr(child, "relabel_aliases"):
                # For example another WhereNode
                child.relabel_aliases(change_map)
            elif hasattr(child, "relabeled_clone"):
                self.children[pos] = child.relabeled_clone(change_map)

    def clone(self):
        """
        Return a copy of this node with the same connector and negation.

        Children that provide clone() are cloned; others are shared.
        """
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            if hasattr(child, "clone"):
                child = child.clone()
            clone.children.append(child)
        return clone

    def relabeled_clone(self, change_map):
        """Return a clone of this node with aliases relabeled by change_map."""
        clone = self.clone()
        clone.relabel_aliases(change_map)
        return clone

    def replace_expressions(self, replacements):
        """
        Return replacements[self] if present and truthy; otherwise return a
        copy of this node whose children had replace_expressions() applied.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            clone.children.append(child.replace_expressions(replacements))
        return clone

    def get_refs(self):
        """Return the union of get_refs() over all children."""
        refs = set()
        for child in self.children:
            refs |= child.get_refs()
        return refs

    @classmethod
    def _contains_aggregate(cls, obj):
        # Recurse through tree nodes; leaves report contains_aggregate.
        if isinstance(obj, tree.Node):
            return any(cls._contains_aggregate(c) for c in obj.children)
        return obj.contains_aggregate

    @cached_property
    def contains_aggregate(self):
        return self._contains_aggregate(self)

    @classmethod
    def _contains_over_clause(cls, obj):
        # Recurse through tree nodes; leaves report contains_over_clause.
        if isinstance(obj, tree.Node):
            return any(cls._contains_over_clause(c) for c in obj.children)
        return obj.contains_over_clause

    @cached_property
    def contains_over_clause(self):
        return self._contains_over_clause(self)

    @property
    def is_summary(self):
        return any(child.is_summary for child in self.children)

    @staticmethod
    def _resolve_leaf(expr, query, *args, **kwargs):
        # Only objects exposing resolve_expression() are resolved; plain
        # values are returned unchanged.
        if hasattr(expr, "resolve_expression"):
            expr = expr.resolve_expression(query, *args, **kwargs)
        return expr

    @classmethod
    def _resolve_node(cls, node, query, *args, **kwargs):
        # Resolve in place: recurse into children, then the lhs/rhs of
        # lookup-like nodes.
        if hasattr(node, "children"):
            for child in node.children:
                cls._resolve_node(child, query, *args, **kwargs)
        if hasattr(node, "lhs"):
            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
        if hasattr(node, "rhs"):
            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)

    def resolve_expression(self, *args, **kwargs):
        """Return a resolved clone of this node; self is left untouched."""
        clone = self.clone()
        clone._resolve_node(clone, *args, **kwargs)
        clone.resolved = True
        return clone

    @cached_property
    def output_field(self):
        from django.db.models import BooleanField

        return BooleanField()

    @property
    def _output_field_or_none(self):
        return self.output_field

    def select_format(self, compiler, sql, params):
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def get_db_converters(self, connection):
        return self.output_field.get_db_converters(connection)

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def leaves(self):
        """Yield every non-WhereNode descendant, depth first."""
        for child in self.children:
            if isinstance(child, WhereNode):
                yield from child.leaves()
            else:
                yield child

312 

313 

class NothingNode:
    """
    Condition that can never be satisfied.

    Compiling it always raises EmptyResultSet.
    """

    # Holds neither aggregates nor window expressions.
    contains_over_clause = False
    contains_aggregate = False

    def as_sql(self, compiler=None, connection=None):
        raise EmptyResultSet()

322 

323 

class ExtraWhere:
    """
    Condition built from raw SQL fragments and their parameters.

    Each fragment is wrapped in parentheses and the fragments are joined
    with AND; the parameters are returned as a list.
    """

    # The raw SQL is opaque, so assume no aggregates or windows are used.
    contains_over_clause = False
    contains_aggregate = False

    def __init__(self, sqls, params):
        self.sqls, self.params = sqls, params

    def as_sql(self, compiler=None, connection=None):
        joined = " AND ".join("(%s)" % fragment for fragment in self.sqls)
        param_list = list(self.params) if self.params else []
        return joined, param_list

336 

337 

class SubqueryConstraint:
    """
    Condition whose SQL is produced by a subquery.

    The wrapped query is restricted to ``targets`` via set_values(), and its
    compiler's as_subquery_condition() renders the SQL from ``alias`` and
    ``columns``.
    """

    # Even if aggregates or windows are used inside the subquery, the outer
    # query isn't interested in them.
    contains_over_clause = False
    contains_aggregate = False

    def __init__(self, alias, columns, targets, query_object):
        # Ordering is cleared on the subquery (clear_default=True) before it
        # is stored.
        query_object.clear_ordering(clear_default=True)
        self.alias = alias
        self.columns = columns
        self.targets = targets
        self.query_object = query_object

    def as_sql(self, compiler, connection):
        inner = self.query_object
        inner.set_values(self.targets)
        inner_compiler = inner.get_compiler(connection=connection)
        return inner_compiler.as_subquery_condition(
            self.alias, self.columns, compiler
        )