Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/django/db/models/constraints.py: 24%

216 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 06:13 +0000

1from enum import Enum 

2 

3from django.core.exceptions import FieldError, ValidationError 

4from django.db import connections 

5from django.db.models.expressions import Exists, ExpressionList, F 

6from django.db.models.indexes import IndexExpression 

7from django.db.models.lookups import Exact 

8from django.db.models.query_utils import Q 

9from django.db.models.sql.query import Query 

10from django.db.utils import DEFAULT_DB_ALIAS 

11from django.utils.translation import gettext_lazy as _ 

12 

13__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"] 

14 

15 

16class BaseConstraint: 

17 default_violation_error_message = _("Constraint “%(name)s” is violated.") 

18 violation_error_message = None 

19 

20 def __init__(self, name, violation_error_message=None): 

21 self.name = name 

22 if violation_error_message is not None: 

23 self.violation_error_message = violation_error_message 

24 else: 

25 self.violation_error_message = self.default_violation_error_message 

26 

27 @property 

28 def contains_expressions(self): 

29 return False 

30 

31 def constraint_sql(self, model, schema_editor): 

32 raise NotImplementedError("This method must be implemented by a subclass.") 

33 

34 def create_sql(self, model, schema_editor): 

35 raise NotImplementedError("This method must be implemented by a subclass.") 

36 

37 def remove_sql(self, model, schema_editor): 

38 raise NotImplementedError("This method must be implemented by a subclass.") 

39 

40 def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS): 

41 raise NotImplementedError("This method must be implemented by a subclass.") 

42 

43 def get_violation_error_message(self): 

44 return self.violation_error_message % {"name": self.name} 

45 

46 def deconstruct(self): 

47 path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__) 

48 path = path.replace("django.db.models.constraints", "django.db.models") 

49 kwargs = {"name": self.name} 

50 if ( 

51 self.violation_error_message is not None 

52 and self.violation_error_message != self.default_violation_error_message 

53 ): 

54 kwargs["violation_error_message"] = self.violation_error_message 

55 return (path, (), kwargs) 

56 

57 def clone(self): 

58 _, args, kwargs = self.deconstruct() 

59 return self.__class__(*args, **kwargs) 

60 

61 

62class CheckConstraint(BaseConstraint): 

63 def __init__(self, *, check, name, violation_error_message=None): 

64 self.check = check 

65 if not getattr(check, "conditional", False): 

66 raise TypeError( 

67 "CheckConstraint.check must be a Q instance or boolean expression." 

68 ) 

69 super().__init__(name, violation_error_message=violation_error_message) 

70 

71 def _get_check_sql(self, model, schema_editor): 

72 query = Query(model=model, alias_cols=False) 

73 where = query.build_where(self.check) 

74 compiler = query.get_compiler(connection=schema_editor.connection) 

75 sql, params = where.as_sql(compiler, schema_editor.connection) 

76 return sql % tuple(schema_editor.quote_value(p) for p in params) 

77 

78 def constraint_sql(self, model, schema_editor): 

79 check = self._get_check_sql(model, schema_editor) 

80 return schema_editor._check_sql(self.name, check) 

81 

82 def create_sql(self, model, schema_editor): 

83 check = self._get_check_sql(model, schema_editor) 

84 return schema_editor._create_check_sql(model, self.name, check) 

85 

86 def remove_sql(self, model, schema_editor): 

87 return schema_editor._delete_check_sql(model, self.name) 

88 

89 def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS): 

90 against = instance._get_field_value_map(meta=model._meta, exclude=exclude) 

91 try: 

92 if not Q(self.check).check(against, using=using): 

93 raise ValidationError(self.get_violation_error_message()) 

94 except FieldError: 

95 pass 

96 

97 def __repr__(self): 

98 return "<%s: check=%s name=%s>" % ( 

99 self.__class__.__qualname__, 

100 self.check, 

101 repr(self.name), 

102 ) 

103 

104 def __eq__(self, other): 

105 if isinstance(other, CheckConstraint): 

106 return ( 

107 self.name == other.name 

108 and self.check == other.check 

109 and self.violation_error_message == other.violation_error_message 

110 ) 

111 return super().__eq__(other) 

112 

113 def deconstruct(self): 

114 path, args, kwargs = super().deconstruct() 

115 kwargs["check"] = self.check 

116 return path, args, kwargs 

117 

118 

119class Deferrable(Enum): 

120 DEFERRED = "deferred" 

121 IMMEDIATE = "immediate" 

122 

123 # A similar format was proposed for Python 3.10. 

124 def __repr__(self): 

125 return f"{self.__class__.__qualname__}.{self._name_}" 

126 

127 

128class UniqueConstraint(BaseConstraint): 

129 def __init__( 

130 self, 

131 *expressions, 

132 fields=(), 

133 name=None, 

134 condition=None, 

135 deferrable=None, 

136 include=None, 

137 opclasses=(), 

138 violation_error_message=None, 

139 ): 

140 if not name: 

141 raise ValueError("A unique constraint must be named.") 

142 if not expressions and not fields: 

143 raise ValueError( 

144 "At least one field or expression is required to define a " 

145 "unique constraint." 

146 ) 

147 if expressions and fields: 

148 raise ValueError( 

149 "UniqueConstraint.fields and expressions are mutually exclusive." 

150 ) 

151 if not isinstance(condition, (type(None), Q)): 

152 raise ValueError("UniqueConstraint.condition must be a Q instance.") 

153 if condition and deferrable: 

154 raise ValueError("UniqueConstraint with conditions cannot be deferred.") 

155 if include and deferrable: 

156 raise ValueError("UniqueConstraint with include fields cannot be deferred.") 

157 if opclasses and deferrable: 

158 raise ValueError("UniqueConstraint with opclasses cannot be deferred.") 

159 if expressions and deferrable: 

160 raise ValueError("UniqueConstraint with expressions cannot be deferred.") 

161 if expressions and opclasses: 

162 raise ValueError( 

163 "UniqueConstraint.opclasses cannot be used with expressions. " 

164 "Use django.contrib.postgres.indexes.OpClass() instead." 

165 ) 

166 if not isinstance(deferrable, (type(None), Deferrable)): 

167 raise ValueError( 

168 "UniqueConstraint.deferrable must be a Deferrable instance." 

169 ) 

170 if not isinstance(include, (type(None), list, tuple)): 

171 raise ValueError("UniqueConstraint.include must be a list or tuple.") 

172 if not isinstance(opclasses, (list, tuple)): 

173 raise ValueError("UniqueConstraint.opclasses must be a list or tuple.") 

174 if opclasses and len(fields) != len(opclasses): 

175 raise ValueError( 

176 "UniqueConstraint.fields and UniqueConstraint.opclasses must " 

177 "have the same number of elements." 

178 ) 

179 self.fields = tuple(fields) 

180 self.condition = condition 

181 self.deferrable = deferrable 

182 self.include = tuple(include) if include else () 

183 self.opclasses = opclasses 

184 self.expressions = tuple( 

185 F(expression) if isinstance(expression, str) else expression 

186 for expression in expressions 

187 ) 

188 super().__init__(name, violation_error_message=violation_error_message) 

189 

190 @property 

191 def contains_expressions(self): 

192 return bool(self.expressions) 

193 

194 def _get_condition_sql(self, model, schema_editor): 

195 if self.condition is None: 

196 return None 

197 query = Query(model=model, alias_cols=False) 

198 where = query.build_where(self.condition) 

199 compiler = query.get_compiler(connection=schema_editor.connection) 

200 sql, params = where.as_sql(compiler, schema_editor.connection) 

201 return sql % tuple(schema_editor.quote_value(p) for p in params) 

202 

203 def _get_index_expressions(self, model, schema_editor): 

204 if not self.expressions: 

205 return None 

206 index_expressions = [] 

207 for expression in self.expressions: 

208 index_expression = IndexExpression(expression) 

209 index_expression.set_wrapper_classes(schema_editor.connection) 

210 index_expressions.append(index_expression) 

211 return ExpressionList(*index_expressions).resolve_expression( 

212 Query(model, alias_cols=False), 

213 ) 

214 

215 def constraint_sql(self, model, schema_editor): 

216 fields = [model._meta.get_field(field_name) for field_name in self.fields] 

217 include = [ 

218 model._meta.get_field(field_name).column for field_name in self.include 

219 ] 

220 condition = self._get_condition_sql(model, schema_editor) 

221 expressions = self._get_index_expressions(model, schema_editor) 

222 return schema_editor._unique_sql( 

223 model, 

224 fields, 

225 self.name, 

226 condition=condition, 

227 deferrable=self.deferrable, 

228 include=include, 

229 opclasses=self.opclasses, 

230 expressions=expressions, 

231 ) 

232 

233 def create_sql(self, model, schema_editor): 

234 fields = [model._meta.get_field(field_name) for field_name in self.fields] 

235 include = [ 

236 model._meta.get_field(field_name).column for field_name in self.include 

237 ] 

238 condition = self._get_condition_sql(model, schema_editor) 

239 expressions = self._get_index_expressions(model, schema_editor) 

240 return schema_editor._create_unique_sql( 

241 model, 

242 fields, 

243 self.name, 

244 condition=condition, 

245 deferrable=self.deferrable, 

246 include=include, 

247 opclasses=self.opclasses, 

248 expressions=expressions, 

249 ) 

250 

251 def remove_sql(self, model, schema_editor): 

252 condition = self._get_condition_sql(model, schema_editor) 

253 include = [ 

254 model._meta.get_field(field_name).column for field_name in self.include 

255 ] 

256 expressions = self._get_index_expressions(model, schema_editor) 

257 return schema_editor._delete_unique_sql( 

258 model, 

259 self.name, 

260 condition=condition, 

261 deferrable=self.deferrable, 

262 include=include, 

263 opclasses=self.opclasses, 

264 expressions=expressions, 

265 ) 

266 

267 def __repr__(self): 

268 return "<%s:%s%s%s%s%s%s%s>" % ( 

269 self.__class__.__qualname__, 

270 "" if not self.fields else " fields=%s" % repr(self.fields), 

271 "" if not self.expressions else " expressions=%s" % repr(self.expressions), 

272 " name=%s" % repr(self.name), 

273 "" if self.condition is None else " condition=%s" % self.condition, 

274 "" if self.deferrable is None else " deferrable=%r" % self.deferrable, 

275 "" if not self.include else " include=%s" % repr(self.include), 

276 "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses), 

277 ) 

278 

279 def __eq__(self, other): 

280 if isinstance(other, UniqueConstraint): 

281 return ( 

282 self.name == other.name 

283 and self.fields == other.fields 

284 and self.condition == other.condition 

285 and self.deferrable == other.deferrable 

286 and self.include == other.include 

287 and self.opclasses == other.opclasses 

288 and self.expressions == other.expressions 

289 and self.violation_error_message == other.violation_error_message 

290 ) 

291 return super().__eq__(other) 

292 

293 def deconstruct(self): 

294 path, args, kwargs = super().deconstruct() 

295 if self.fields: 

296 kwargs["fields"] = self.fields 

297 if self.condition: 

298 kwargs["condition"] = self.condition 

299 if self.deferrable: 

300 kwargs["deferrable"] = self.deferrable 

301 if self.include: 

302 kwargs["include"] = self.include 

303 if self.opclasses: 

304 kwargs["opclasses"] = self.opclasses 

305 return path, self.expressions, kwargs 

306 

307 def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS): 

308 queryset = model._default_manager.using(using) 

309 if self.fields: 

310 lookup_kwargs = {} 

311 for field_name in self.fields: 

312 if exclude and field_name in exclude: 

313 return 

314 field = model._meta.get_field(field_name) 

315 lookup_value = getattr(instance, field.attname) 

316 if lookup_value is None or ( 

317 lookup_value == "" 

318 and connections[using].features.interprets_empty_strings_as_nulls 

319 ): 

320 # A composite constraint containing NULL value cannot cause 

321 # a violation since NULL != NULL in SQL. 

322 return 

323 lookup_kwargs[field.name] = lookup_value 

324 queryset = queryset.filter(**lookup_kwargs) 

325 else: 

326 # Ignore constraints with excluded fields. 

327 if exclude: 

328 for expression in self.expressions: 

329 if hasattr(expression, "flatten"): 

330 for expr in expression.flatten(): 

331 if isinstance(expr, F) and expr.name in exclude: 

332 return 

333 elif isinstance(expression, F) and expression.name in exclude: 

334 return 

335 replacements = { 

336 F(field): value 

337 for field, value in instance._get_field_value_map( 

338 meta=model._meta, exclude=exclude 

339 ).items() 

340 } 

341 expressions = [ 

342 Exact(expr, expr.replace_expressions(replacements)) 

343 for expr in self.expressions 

344 ] 

345 queryset = queryset.filter(*expressions) 

346 model_class_pk = instance._get_pk_val(model._meta) 

347 if not instance._state.adding and model_class_pk is not None: 

348 queryset = queryset.exclude(pk=model_class_pk) 

349 if not self.condition: 

350 if queryset.exists(): 

351 if self.expressions: 

352 raise ValidationError(self.get_violation_error_message()) 

353 # When fields are defined, use the unique_error_message() for 

354 # backward compatibility. 

355 for model, constraints in instance.get_constraints(): 

356 for constraint in constraints: 

357 if constraint is self: 

358 raise ValidationError( 

359 instance.unique_error_message(model, self.fields) 

360 ) 

361 else: 

362 against = instance._get_field_value_map(meta=model._meta, exclude=exclude) 

363 try: 

364 if (self.condition & Exists(queryset.filter(self.condition))).check( 

365 against, using=using 

366 ): 

367 raise ValidationError(self.get_violation_error_message()) 

368 except FieldError: 

369 pass