Coverage for /pythoncovmergedfiles/medio/medio/src/pydantic/pydantic/_internal/_discriminated_union.py: 12%

168 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-27 07:38 +0000

1from __future__ import annotations as _annotations 

2 

3from enum import Enum 

4from typing import Sequence 

5 

6from pydantic_core import core_schema 

7 

8from ..errors import PydanticUserError 

9from . import _core_utils 

10from ._core_utils import collect_definitions 

11 

12 

def apply_discriminator(
    schema: core_schema.CoreSchema,
    discriminator: str,
    definitions: dict[str, core_schema.CoreSchema] | None = None,
) -> core_schema.CoreSchema:
    """
    Convert `schema` using a fresh `_ApplyInferredDiscriminator` built for `discriminator`.

    `definitions` is handed to the applier as its ref -> schema lookup; when it is `None`
    (or otherwise falsy) an empty dict is used instead. The value returned is whatever the
    applier's `apply` method produces for `schema`.
    """
    known_definitions = definitions if definitions else {}
    applier = _ApplyInferredDiscriminator(discriminator, known_definitions)
    return applier.apply(schema)

17 

18 

class _ApplyInferredDiscriminator:
    """
    This class is used to convert an input schema containing a union schema into one where that union is
    replaced with a tagged-union, with all the associated debugging and performance benefits.

    This is done by:
    * Validating that the input schema is compatible with the provided discriminator
    * Introspecting the schema to determine which discriminator values should map to which union choices
    * Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more

    I have chosen to implement the conversion algorithm in this class, rather than a function,
    to make it easier to maintain state while recursively walking the provided CoreSchema.
    """

    def __init__(self, discriminator: str, definitions: dict[str, core_schema.CoreSchema]):
        # `discriminator` should be the name of the field which will serve as the discriminator.
        # It must be the python name of the field, and *not* the field's alias. Note that as of now,
        # all members of a discriminated union _must_ use a field with the same name as the discriminator.
        # This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices.
        self.discriminator = discriminator

        # `definitions` should contain a mapping of schema ref to schema for all schemas which might
        # be referenced by some choice
        self.definitions = definitions

        # `_discriminator_alias` will hold the value, if present, of the alias for the discriminator
        #
        # Note: following the v1 implementation, we currently disallow the use of different aliases
        # for different choices. This is not a limitation of pydantic_core, but if we try to handle
        # this, the inference logic gets complicated very quickly, and could result in confusing
        # debugging challenges for users making subtle mistakes.
        #
        # Rather than trying to do the most powerful inference possible, I think we should eventually
        # expose a way to more-manually control the way the TaggedUnionSchema is constructed through
        # the use of a new type which would be placed as an Annotation on the Union type. This would
        # provide the full flexibility/power of pydantic_core's TaggedUnionSchema where necessary for
        # more complex cases, without over-complicating the inference logic for the common cases.
        self._discriminator_alias: str | None = None

        # `_should_be_nullable` indicates whether the converted union has `None` as an allowed value.
        # If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while
        # constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True.
        # Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure
        # that the final schema gets wrapped as a NullableSchema. This has the same semantics on the
        # python side, but resolves the issue that `None` cannot correspond to any discriminator values.
        self._should_be_nullable = False

        # `_is_nullable` is used to track if the final produced schema will definitely be nullable;
        # we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved
        # as an indication that, even if None is discovered as one of the union choices, we will not need to wrap
        # the final value in another nullable schema.
        #
        # This is more complicated than just checking for the final outermost schema having type 'nullable' thanks
        # to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc.
        self._is_nullable = False

        # `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices
        # from the union in the wrapped schema will be appended to this list, and the recursive choice-handling
        # algorithm may add more choices to this stack as (nested) unions are encountered.
        self._choices_to_handle: list[core_schema.CoreSchema] = []

        # `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included
        # in the output TaggedUnionSchema that will replace the union from the input schema
        self._tagged_union_choices: dict[str | int, str | int | core_schema.CoreSchema] = {}

        # `_used` is changed to True after applying the discriminator to prevent accidental re-use
        self._used = False

    def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """
        Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided
        to this class.

        An instance may only be applied once: the accumulated choices, alias and nullability flags are
        per-conversion state, so a second call trips the `_used` assertion.

        Any definitions collected from the input schema that can no longer be collected from the converted
        schema are re-attached by wrapping the result in a definitions schema.
        """
        # Check the single-use guard first, before any work is done on `schema`.
        assert not self._used
        old_definitions = collect_definitions(schema)
        schema = self._apply_to_root(schema)
        if self._should_be_nullable and not self._is_nullable:
            # `None` was found among the choices and no preserved wrapper already makes the result nullable
            schema = core_schema.nullable_schema(schema)
        self._used = True
        new_definitions = collect_definitions(schema)

        missing_definitions = [v for k, v in old_definitions.items() if k not in new_definitions]
        if missing_definitions:
            schema = core_schema.definitions_schema(schema, missing_definitions)

        return schema

    def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """
        This method handles the outer-most stage of recursion over the input schema:
        unwrapping nullable or definitions schemas, and calling the `_handle_choice`
        method iteratively on the choices extracted (recursively) from the possibly-wrapped union.

        Raises `TypeError` if the (unwrapped) schema is not a union, or is a union with fewer than two choices.
        """
        if schema['type'] == 'nullable':
            # The nullable wrapper is preserved around the converted schema, so the result is known to be nullable
            self._is_nullable = True
            wrapped = self._apply_to_root(schema['schema'])
            nullable_wrapper = schema.copy()
            nullable_wrapper['schema'] = wrapped
            return nullable_wrapper

        if schema['type'] == 'definitions':
            wrapped = self._apply_to_root(schema['schema'])
            definitions_wrapper = schema.copy()
            definitions_wrapper['schema'] = wrapped
            return definitions_wrapper

        # `or` short-circuits, so 'choices' is only read once the schema is known to be a union
        if schema['type'] != 'union' or len(schema['choices']) < 2:
            raise TypeError('`discriminator` can only be used with `Union` type with more than one variant')

        # Reverse the choices list before extending the stack so that they get handled in the order they occur
        self._choices_to_handle.extend(schema['choices'][::-1])
        while self._choices_to_handle:
            choice = self._choices_to_handle.pop()
            self._handle_choice(choice)

        if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator:
            # * We need to annotate `discriminator` as a union here to handle both branches of this conditional
            # * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the
            #   invariance of list, and because list[list[str | int]] is the type of the discriminator argument
            #   to tagged_union_schema below
            # * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to
            #   interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here
            #   is the appropriate way to provide a list of fallback attributes to check for a discriminator value.)
            discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]]
        else:
            discriminator = self.discriminator
        return core_schema.tagged_union_schema(
            choices=self._tagged_union_choices,
            discriminator=discriminator,
            custom_error_type=schema.get('custom_error_type'),
            custom_error_message=schema.get('custom_error_message'),
            custom_error_context=schema.get('custom_error_context'),
            strict=False,
            from_attributes=True,
            ref=schema.get('ref'),
            metadata=schema.get('metadata'),
            serialization=schema.get('serialization'),
        )

    def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
        """
        This method handles the "middle" stage of recursion over the input schema.
        Specifically, it is responsible for handling each choice of the outermost union
        (and any "coalesced" choices obtained from inner unions).

        Here, "handling" entails:
        * Coalescing nested unions and compatible tagged-unions
        * Tracking the presence of 'none' and 'nullable' schemas occurring as choices
        * Validating that each allowed discriminator value maps to a unique choice
        * Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.

        Raises `ValueError` for a 'definition-ref' choice whose ref is not in `self.definitions`, and
        `TypeError` for a choice type that cannot be a discriminated union variant.
        """
        if choice['type'] == 'none':
            self._should_be_nullable = True
        elif choice['type'] == 'definitions':
            self._handle_choice(choice['schema'])
        elif choice['type'] == 'nullable':
            self._should_be_nullable = True
            self._handle_choice(choice['schema'])  # unwrap the nullable schema
        elif choice['type'] == 'union':
            # Reverse the choices list before extending the stack so that they get handled in the order they occur
            self._choices_to_handle.extend(choice['choices'][::-1])
        elif choice['type'] == 'definition-ref':
            if choice['schema_ref'] not in self.definitions:
                raise ValueError(f"Missing definition for ref {choice['schema_ref']!r}")
            # Resolve the reference and handle the schema it points to
            self._handle_choice(self.definitions[choice['schema_ref']])
        elif choice['type'] not in {
            'model',
            'typed-dict',
            'tagged-union',
            'lax-or-strict',
            'dataclass',
            'dataclass-args',
        } and not _core_utils.is_function_with_inner_schema(choice):
            raise TypeError(
                f'{choice["type"]!r} is not a valid discriminated union variant;'
                ' should be a `BaseModel` or `dataclass`'
            )
        else:
            if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
                # In this case, this inner tagged-union is compatible with the outer tagged-union,
                # and its choices can be coalesced into the outer TaggedUnionSchema.
                subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
                # Reverse the choices list before extending the stack so that they get handled in the order they occur
                self._choices_to_handle.extend(subchoices[::-1])
                return

            inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None)
            self._set_unique_choice_for_values(choice, inferred_discriminator_values)

    def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool:
        """
        This method returns a boolean indicating whether the discriminator for the `choice`
        is the same as that being used for the outermost tagged union. This is used to
        determine whether this TaggedUnionSchema choice should be "coalesced" into the top level,
        or whether it should be treated as a separate (nested) choice.
        """
        inner_discriminator = choice['discriminator']
        # A list-valued inner discriminator matches if it contains our discriminator either directly
        # or as a single-element path (`[self.discriminator]`).
        return inner_discriminator == self.discriminator or (
            isinstance(inner_discriminator, list)
            and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator)
        )

    @staticmethod
    def _sorted_unique_values(values: Sequence[str | int]) -> list[str | int]:
        """
        Return the distinct entries of `values` in a deterministic order.

        Integers come first (ascending), followed by strings (ascending). For input containing only one of
        the two types this is identical to `sorted(set(values))`; for mixed input it avoids the `TypeError`
        that comparing `str` with `int` would raise.
        """
        return sorted(set(values), key=lambda value: (isinstance(value, str), value))

    def _infer_discriminator_values_for_choice(
        self, choice: core_schema.CoreSchema, source_name: str | None
    ) -> list[str | int]:
        """
        This function recurses over `choice`, extracting all discriminator values that should map to this choice.

        `source_name` is accepted for the purpose of producing useful error messages.

        Raises `TypeError` if `choice` (or a schema nested in it) is of a type that cannot be a
        discriminated union variant.
        """
        if choice['type'] == 'definitions':
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
        elif choice['type'] == 'function-plain':
            raise TypeError(
                f'{choice["type"]!r} is not a valid discriminated union variant;'
                ' should be a `BaseModel` or `dataclass`'
            )
        elif _core_utils.is_function_with_inner_schema(choice):
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
        elif choice['type'] == 'lax-or-strict':
            # Both branches describe the same choice, so merge their values without duplicates.
            # A type-aware ordering is used because the values may mix `int` and `str`.
            return self._sorted_unique_values(
                self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None)
                + self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None)
            )

        elif choice['type'] == 'tagged-union':
            values: list[str | int] = []
            # Ignore str/int "choices" since these are just references to other choices
            subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
            for subchoice in subchoices:
                subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None)
                values.extend(subchoice_values)
            return values

        elif choice['type'] == 'union':
            values = []
            for subchoice in choice['choices']:
                subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None)
                values.extend(subchoice_values)
            return values

        elif choice['type'] == 'nullable':
            self._should_be_nullable = True
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None)

        elif choice['type'] == 'model':
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)

        elif choice['type'] == 'dataclass':
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)

        elif choice['type'] == 'dataclass-args':
            return self._infer_discriminator_values_for_dataclass_choice(choice, source_name=source_name)

        elif choice['type'] == 'typed-dict':
            return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name)

        else:
            raise TypeError(
                f'{choice["type"]!r} is not a valid discriminated union variant;'
                ' should be a `BaseModel` or `dataclass`'
            )

    def _infer_discriminator_values_for_typed_dict_choice(
        self, choice: core_schema.TypedDictSchema, source_name: str | None = None
    ) -> list[str | int]:
        """
        This method just extracts the _infer_discriminator_values_for_choice logic specific to TypedDictSchema
        for the sake of readability.

        Raises `PydanticUserError` (code 'discriminator-no-field') if `choice` has no field named after
        the discriminator.
        """
        source = 'TypedDict' if source_name is None else f'Model {source_name!r}'
        field = choice['fields'].get(self.discriminator)
        if field is None:
            raise PydanticUserError(
                f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
            )
        return self._infer_discriminator_values_for_field(field, source)

    def _infer_discriminator_values_for_dataclass_choice(
        self, choice: core_schema.DataclassArgsSchema, source_name: str | None = None
    ) -> list[str | int]:
        """
        This method extracts the _infer_discriminator_values_for_choice logic specific to DataclassArgsSchema.

        The fields are searched in order for one whose 'name' equals the discriminator.

        Raises `PydanticUserError` (code 'discriminator-no-field') if no such field exists.
        """
        source = 'DataclassArgs' if source_name is None else f'Dataclass {source_name!r}'
        for field in choice['fields']:
            if field['name'] == self.discriminator:
                break
        else:
            # The loop finished without `break`, i.e. no field matched the discriminator name
            raise PydanticUserError(
                f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
            )
        return self._infer_discriminator_values_for_field(field, source)

    def _infer_discriminator_values_for_field(
        self, field: core_schema.TypedDictField | core_schema.DataclassField, source: str
    ) -> list[str | int]:
        """
        Record the validation alias of the discriminator `field` and return the discriminator values
        extracted from the field's schema.

        The alias defaults to the discriminator name when the field has no 'validation_alias'. The first
        alias seen is stored in `self._discriminator_alias`; every later field must use the same one.

        Raises `PydanticUserError` with code 'discriminator-alias-type' if the alias is not a string, or
        with code 'discriminator-alias' if it differs from the alias recorded for an earlier choice.
        """
        alias = field.get('validation_alias', self.discriminator)
        if not isinstance(alias, str):
            raise PydanticUserError(
                f'Alias {alias!r} is not supported in a discriminated union', code='discriminator-alias-type'
            )
        if self._discriminator_alias is None:
            self._discriminator_alias = alias
        elif self._discriminator_alias != alias:
            raise PydanticUserError(
                f'Aliases for discriminator {self.discriminator!r} must be the same '
                f'(got {alias}, {self._discriminator_alias})',
                code='discriminator-alias',
            )
        return self._infer_discriminator_values_for_inner_schema(field['schema'], source)

    def _infer_discriminator_values_for_inner_schema(
        self, schema: core_schema.CoreSchema, source: str
    ) -> list[str | int]:
        """
        When inferring discriminator values for a field, we typically extract the expected values from a literal schema.
        This function does that, but also handles nested unions and defaults.

        Raises `TypeError` if a literal value (after unwrapping `Enum` members to their `.value`) is neither
        `str` nor `int`, and `PydanticUserError` (code 'discriminator-needs-literal') for any other schema type.
        """
        if schema['type'] == 'literal':
            values = []
            for v in schema['expected']:
                if isinstance(v, Enum):
                    v = v.value
                if not isinstance(v, (str, int)):
                    raise TypeError(f'Unsupported value for discriminator field: {v!r}')
                values.append(v)
            return values

        elif schema['type'] == 'union':
            # Generally when multiple values are allowed they should be placed in a single `Literal`, but
            # we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s.
            # For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]`
            values = []
            for choice in schema['choices']:
                choice_values = self._infer_discriminator_values_for_inner_schema(choice, source)
                values.extend(choice_values)
            return values

        elif schema['type'] == 'default':
            # This will happen if the field has a default value; we ignore it while extracting the discriminator values
            return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)

        else:
            raise PydanticUserError(
                f'{source} needs field {self.discriminator!r} to be of type `Literal`',
                code='discriminator-needs-literal',
            )

    def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None:
        """
        This method updates `self._tagged_union_choices` so that all provided (discriminator) `values` map to the
        provided `choice`, validating that none of these values already map to another (different) choice.

        The first value not yet present is mapped directly to `choice`; each further new value is mapped to
        that first value, i.e. stored as a reference rather than as another copy of the schema.

        Raises `TypeError` if a value is already mapped to a different choice.
        """
        primary_value: str | int | None = None
        for discriminator_value in values:
            if discriminator_value in self._tagged_union_choices:
                # It is okay if `value` is already in tagged_union_choices as long as it maps to the same value.
                # Because tagged_union_choices may map values to other values, we need to walk the choices dict
                # until we get to a "real" choice, and confirm that is equal to the one assigned.
                existing_choice = self._tagged_union_choices[discriminator_value]
                while isinstance(existing_choice, (str, int)):
                    existing_choice = self._tagged_union_choices[existing_choice]
                if existing_choice != choice:
                    raise TypeError(
                        f'Value {discriminator_value!r} for discriminator '
                        f'{self.discriminator!r} mapped to multiple choices'
                    )
            elif primary_value is None:
                self._tagged_union_choices[discriminator_value] = choice
                primary_value = discriminator_value
            else:
                self._tagged_union_choices[discriminator_value] = primary_value

392 primary_value = discriminator_value 

393 else: 

394 self._tagged_union_choices[discriminator_value] = primary_value