Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/django/db/models/sql/where.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

218 statements  

"""
Code to manage the creation and SQL rendering of 'where' constraints.

WhereNode is a tree of boolean conditions whose children are joined by one of
the AND, OR, or XOR connectors; it can be split into WHERE, HAVING, and
QUALIFY parts. NothingNode, ExtraWhere, and SubqueryConstraint are
special-purpose conditions that can appear as children of a WhereNode.
"""

4 

5import operator 

6from functools import reduce 

7 

8from django.core.exceptions import EmptyResultSet, FullResultSet 

9from django.db.models.expressions import Case, When 

10from django.db.models.functions import Mod 

11from django.db.models.lookups import Exact 

12from django.utils import tree 

13from django.utils.functional import cached_property 

14 

# Connection types: the SQL keyword placed between the children of a
# WhereNode when it is rendered.
AND, OR, XOR = "AND", "OR", "XOR"

19 

20 

class WhereNode(tree.Node):
    """
    An SQL WHERE clause.

    The class is tied to the Query class that created it (in order to create
    the correct SQL).

    A child is usually an expression producing boolean values. Most likely the
    expression is a Lookup instance.

    However, a child could also be any class with as_sql() and either
    relabeled_clone() method or relabel_aliases() and clone() methods and
    contains_aggregate attribute.
    """

    default = AND
    resolved = False
    conditional = True

    def split_having_qualify(self, negated=False, must_group_by=False):
        """
        Return three possibly None nodes: one for those parts of self that
        should be included in the WHERE clause, one for those parts of self
        that must be included in the HAVING clause, and one for those parts
        that refer to window functions.

        ``negated`` is the negation inherited from enclosing nodes; it is
        XORed with this node's own negation. ``must_group_by`` is forwarded
        to children. When it is true, a disjunctive mix of plain and
        window-function predicates raises NotImplementedError.
        """
        if not self.contains_aggregate and not self.contains_over_clause:
            return self, None, None
        in_negated = negated ^ self.negated
        # Whether or not children must be connected in the same filtering
        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantic.
        must_remain_connected = (
            (in_negated and self.connector == AND)
            or (not in_negated and self.connector == OR)
            or self.connector == XOR
        )
        if (
            must_remain_connected
            and self.contains_aggregate
            and not self.contains_over_clause
        ):
            # It's much cheaper to short-circuit and stash everything in the
            # HAVING clause than split children if possible.
            return None, self, None
        where_parts = []
        having_parts = []
        qualify_parts = []
        for c in self.children:
            if hasattr(c, "split_having_qualify"):
                where_part, having_part, qualify_part = c.split_having_qualify(
                    in_negated, must_group_by
                )
                if where_part is not None:
                    where_parts.append(where_part)
                if having_part is not None:
                    having_parts.append(having_part)
                if qualify_part is not None:
                    qualify_parts.append(qualify_part)
            elif c.contains_over_clause:
                qualify_parts.append(c)
            elif c.contains_aggregate:
                having_parts.append(c)
            else:
                where_parts.append(c)
        if must_remain_connected and qualify_parts:
            # Disjunctive heterogeneous predicates can be pushed down to
            # qualify as long as no conditional aggregation is involved.
            if where_parts and must_group_by:
                # In theory this should only be enforced when dealing with
                # where_parts containing predicates against multi-valued
                # relationships that could affect aggregation results but this
                # is complex to infer properly.
                raise NotImplementedError(
                    "Heterogeneous disjunctive predicates against window functions are "
                    "not implemented when performing conditional aggregation."
                )
            return None, None, self
        where_node = (
            self.create(where_parts, self.connector, self.negated)
            if where_parts
            else None
        )
        having_node = (
            self.create(having_parts, self.connector, self.negated)
            if having_parts
            else None
        )
        qualify_node = (
            self.create(qualify_parts, self.connector, self.negated)
            if qualify_parts
            else None
        )
        return where_node, having_node, qualify_node

    def as_sql(self, compiler, connection):
        """
        Return the SQL version of the where clause and the list of parameters
        to be substituted in.

        Raise FullResultSet if this node matches everything and
        EmptyResultSet if this node can't match anything.
        """
        if self.connector == XOR and not connection.features.supports_logical_xor:
            # Convert if the database doesn't support XOR:
            #   a XOR b XOR c XOR ...
            # to:
            #   (a OR b OR c OR ...) AND MOD(a + b + c + ..., 2) == 1
            # The result of an n-ary XOR is true when an odd number of operands
            # are true.
            lhs = self.__class__(self.children, OR)
            rhs_sum = reduce(
                operator.add,
                (Case(When(c, then=1), default=0) for c in self.children),
            )
            if len(self.children) > 2:
                rhs_sum = Mod(rhs_sum, 2)
            rhs = Exact(1, rhs_sum)
            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
                compiler, connection
            )

        result = []
        result_params = []
        # For AND/OR: how many children must match everything (full) or
        # nothing (empty) before the whole node is known to be full/empty.
        if self.connector == AND:
            full_needed, empty_needed = len(self.children), 1
        else:
            full_needed, empty_needed = 1, len(self.children)
        is_xor = self.connector == XOR
        full_count = empty_count = 0
        for child in self.children:
            try:
                sql, params = compiler.compile(child)
            except EmptyResultSet:
                empty_count += 1
            except FullResultSet:
                full_count += 1
            else:
                if sql:
                    result.append(sql)
                    result_params.extend(params)
                else:
                    # A child that compiles to no SQL matches everything.
                    full_count += 1
            if is_xor:
                # Every XOR operand affects the parity of the result, so a
                # single full or empty child never decides the whole node.
                continue
            if empty_count == empty_needed:
                if self.negated:
                    raise FullResultSet
                raise EmptyResultSet
            if full_count == full_needed:
                if self.negated:
                    raise EmptyResultSet
                raise FullResultSet
        negated = self.negated
        if is_xor and full_count % 2:
            # An always-true operand inverts an XOR (a XOR TRUE == NOT a),
            # while always-false operands are no-ops (a XOR FALSE == a).
            negated = not negated
        if not result:
            if is_xor and self.children:
                # Every operand was constant: the parity alone decides.
                raise FullResultSet if negated else EmptyResultSet
            raise FullResultSet
        sql_string = (" %s " % self.connector).join(result)
        if negated:
            # Some backends (Oracle at least) need parentheses around the inner
            # SQL in the negated case, even if the inner SQL contains just a
            # single expression.
            sql_string = "NOT (%s)" % sql_string
        elif len(result) > 1 or self.resolved:
            sql_string = "(%s)" % sql_string
        return sql_string, result_params

    def get_group_by_cols(self):
        """Return the concatenated get_group_by_cols() of every child."""
        cols = []
        for child in self.children:
            cols.extend(child.get_group_by_cols())
        return cols

    def get_source_expressions(self):
        """Return a shallow copy of the children list."""
        return self.children[:]

    def set_source_expressions(self, children):
        """Replace the children; the new list must have the same length."""
        assert len(children) == len(self.children)
        self.children = children

    def relabel_aliases(self, change_map):
        """
        Relabel the alias values of any children in place. 'change_map' is a
        dictionary mapping old (current) alias values to the new values.

        Children exposing relabel_aliases() (for example another WhereNode)
        are relabeled in place. Children exposing only relabeled_clone() are
        replaced by a relabeled copy. Other children are left untouched.
        Return self.
        """
        if not change_map:
            return self
        for pos, child in enumerate(self.children):
            if hasattr(child, "relabel_aliases"):
                # For example another WhereNode
                child.relabel_aliases(change_map)
            elif hasattr(child, "relabeled_clone"):
                self.children[pos] = child.relabeled_clone(change_map)
        return self

    def clone(self):
        """
        Return a copy of this node with the same connector and negation.
        Children exposing clone() are cloned; others are shared.
        """
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            if hasattr(child, "clone"):
                child = child.clone()
            clone.children.append(child)
        return clone

    def relabeled_clone(self, change_map):
        """Return a clone of this node with aliases relabeled per change_map."""
        clone = self.clone()
        clone.relabel_aliases(change_map)
        return clone

    def replace_expressions(self, replacements):
        """
        Return the replacement mapped to this node in ``replacements`` if
        there is one. Otherwise return a copy whose children have had
        replace_expressions() applied. Return self for an empty mapping.
        """
        if not replacements:
            return self
        # Compare against None rather than relying on truthiness so that a
        # replacement evaluating as false (e.g. a node with no children) is
        # still honored.
        replacement = replacements.get(self)
        if replacement is not None:
            return replacement
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            clone.children.append(child.replace_expressions(replacements))
        return clone

    def get_refs(self):
        """Return the union of get_refs() over all children."""
        refs = set()
        for child in self.children:
            refs |= child.get_refs()
        return refs

    @classmethod
    def _contains_aggregate(cls, obj):
        """Recursively check tree nodes; read the attribute on leaves."""
        if isinstance(obj, tree.Node):
            return any(cls._contains_aggregate(c) for c in obj.children)
        return obj.contains_aggregate

    @cached_property
    def contains_aggregate(self):
        return self._contains_aggregate(self)

    @classmethod
    def _contains_over_clause(cls, obj):
        """Recursively check tree nodes; read the attribute on leaves."""
        if isinstance(obj, tree.Node):
            return any(cls._contains_over_clause(c) for c in obj.children)
        return obj.contains_over_clause

    @cached_property
    def contains_over_clause(self):
        return self._contains_over_clause(self)

    @property
    def is_summary(self):
        """True if any child reports is_summary."""
        return any(child.is_summary for child in self.children)

    @staticmethod
    def _resolve_leaf(expr, query, *args, **kwargs):
        """Resolve expr if it supports resolve_expression(); else return it."""
        if hasattr(expr, "resolve_expression"):
            expr = expr.resolve_expression(query, *args, **kwargs)
        return expr

    @classmethod
    def _resolve_node(cls, node, query, *args, **kwargs):
        """
        Resolve, in place, the lhs/rhs of node and of all its descendants.
        """
        if hasattr(node, "children"):
            for child in node.children:
                cls._resolve_node(child, query, *args, **kwargs)
        if hasattr(node, "lhs"):
            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
        if hasattr(node, "rhs"):
            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)

    def resolve_expression(self, *args, **kwargs):
        """Return a resolved clone of this node, marked as resolved."""
        clone = self.clone()
        clone._resolve_node(clone, *args, **kwargs)
        clone.resolved = True
        return clone

    @cached_property
    def output_field(self):
        from django.db.models import BooleanField

        return BooleanField()

    @property
    def _output_field_or_none(self):
        return self.output_field

    def select_format(self, compiler, sql, params):
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def get_db_converters(self, connection):
        return self.output_field.get_db_converters(connection)

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def leaves(self):
        """Yield, depth-first, every descendant that is not a WhereNode."""
        for child in self.children:
            if isinstance(child, WhereNode):
                yield from child.leaves()
            else:
                yield child

322 

323 

class NothingNode:
    """A condition that can never be satisfied."""

    contains_over_clause = False
    contains_aggregate = False

    def as_sql(self, compiler=None, connection=None):
        # Compiling this node always raises EmptyResultSet.
        raise EmptyResultSet

332 

333 

class ExtraWhere:
    """A condition made of raw SQL fragments and their parameters."""

    # The fragments are opaque, so assume no aggregates or window functions.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, sqls, params):
        self.sqls = sqls
        self.params = params

    def as_sql(self, compiler=None, connection=None):
        """
        Return each fragment wrapped in parentheses and joined with AND,
        together with the parameters as a new list (empty when params is
        falsy).
        """
        joined = " AND ".join("(%s)" % fragment for fragment in self.sqls)
        params = list(self.params) if self.params else []
        return joined, params

346 

347 

class SubqueryConstraint:
    """
    A condition rendered by delegating to the compiler of ``query_object``
    via its as_subquery_condition() method.
    """

    # Even if aggregates or windows would be used in a subquery,
    # the outer query isn't interested in those.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, alias, columns, targets, query_object):
        self.alias, self.columns, self.targets = alias, columns, targets
        # Clear any ordering (including the default) on the passed-in query;
        # note this mutates the caller's query object.
        query_object.clear_ordering(clear_default=True)
        self.query_object = query_object

    def as_sql(self, compiler, connection):
        subquery = self.query_object
        subquery.set_values(self.targets)
        subquery_compiler = subquery.get_compiler(connection=connection)
        return subquery_compiler.as_subquery_condition(
            self.alias, self.columns, compiler
        )