Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/pydantic/_internal/_discriminated_union.py: 13%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

198 statements  

1from __future__ import annotations as _annotations 

2 

3from collections.abc import Hashable, Sequence 

4from typing import TYPE_CHECKING, Any, cast 

5 

6from pydantic_core import CoreSchema, core_schema 

7 

8from ..errors import PydanticUserError 

9from . import _core_utils 

10from ._core_utils import ( 

11 CoreSchemaField, 

12) 

13 

14if TYPE_CHECKING: 

15 from ..types import Discriminator 

16 from ._core_metadata import CoreMetadata 

17 

18 

class MissingDefinitionForUnionRef(Exception):
    """Signals that a discriminator could not be applied to a union because a schema it
    references (via a ref) has no definition available yet.

    The offending ref is kept on the `ref` attribute so callers can inspect or retry it.
    """

    def __init__(self, ref: str) -> None:
        message = f'Missing definition for ref {ref!r}'
        super().__init__(message)
        self.ref = ref

27 

28 

def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
    """Store `discriminator` in `schema`'s metadata dict.

    The value is written under the 'pydantic_internal_union_discriminator' key. The schema is
    mutated in place. A metadata dict is created first if the schema has none, and any existing
    metadata entries are kept.
    """
    metadata_key = 'pydantic_internal_union_discriminator'
    cast('CoreMetadata', schema.setdefault('metadata', {}))[metadata_key] = discriminator

32 

33 

def apply_discriminator(
    schema: core_schema.CoreSchema,
    discriminator: str | Discriminator,
    definitions: dict[str, core_schema.CoreSchema] | None = None,
) -> core_schema.CoreSchema:
    """Apply `discriminator` to `schema` and return the resulting (new) core schema.

    Args:
        schema: The input schema.
        discriminator: Either the python name of the field that serves as the discriminator, or a
            `Discriminator` instance. A `Discriminator` whose `.discriminator` is a string is treated
            exactly like that field name; any other `Discriminator` converts the schema itself via
            its `_convert_schema` method.
        definitions: A mapping of schema ref to schema, used to resolve referenced choices.
            `None` is treated as an empty mapping.

    Returns:
        The new core schema.

    Raises:
        TypeError:
            - If `discriminator` is used with invalid union variant.
            - If `discriminator` value mapped to multiple choices.
        MissingDefinitionForUnionRef:
            If the definition for ref is missing.
        PydanticUserError:
            - If a model in union doesn't have a discriminator field.
            - If discriminator field has a non-string alias.
            - If discriminator fields have different aliases.
            - If discriminator field not of type `Literal`.
    """
    from ..types import Discriminator

    if not isinstance(discriminator, Discriminator):
        field_name = discriminator
    elif isinstance(discriminator.discriminator, str):
        field_name = discriminator.discriminator
    else:
        # Non-string discriminators cannot be inferred from the schema; delegate the conversion.
        return discriminator._convert_schema(schema)

    applier = _ApplyInferredDiscriminator(field_name, definitions or {})
    return applier.apply(schema)

71 

72 

class _ApplyInferredDiscriminator:
    """Convert an input schema containing a union schema into one where that union is replaced with a
    tagged-union, with all the associated debugging and performance benefits.

    This is done by:
    * Validating that the input schema is compatible with the provided discriminator
    * Introspecting the schema to determine which discriminator values should map to which union choices
    * Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more

    The conversion algorithm is implemented as a class, rather than a function, to make it easier to
    maintain state while recursively walking the provided CoreSchema. Instances are single-use: `apply`
    asserts that it has not been called before.
    """

    def __init__(self, discriminator: str, definitions: dict[str, core_schema.CoreSchema]):
        # `discriminator` should be the name of the field which will serve as the discriminator.
        # It must be the python name of the field, and *not* the field's alias. Note that as of now,
        # all members of a discriminated union _must_ use a field with the same name as the discriminator.
        # This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices.
        self.discriminator = discriminator

        # `definitions` should contain a mapping of schema ref to schema for all schemas which might
        # be referenced by some choice
        self.definitions = definitions

        # `_discriminator_alias` will hold the value, if present, of the alias for the discriminator
        #
        # Note: following the v1 implementation, we currently disallow the use of different aliases
        # for different choices. This is not a limitation of pydantic_core, but if we try to handle
        # this, the inference logic gets complicated very quickly, and could result in confusing
        # debugging challenges for users making subtle mistakes.
        #
        # Rather than trying to do the most powerful inference possible, a more manual way to control
        # how the TaggedUnionSchema is constructed (e.g. a new type placed as an Annotation on the Union
        # type) would give the full flexibility of pydantic_core's TaggedUnionSchema for complex cases
        # without over-complicating the inference logic for the common cases.
        self._discriminator_alias: str | None = None

        # `_should_be_nullable` indicates whether the converted union has `None` as an allowed value.
        # If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while
        # constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True.
        # Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure
        # that the final schema gets wrapped as a NullableSchema. This has the same semantics on the
        # python side, but resolves the issue that `None` cannot correspond to any discriminator values.
        self._should_be_nullable = False

        # `_is_nullable` is used to track if the final produced schema will definitely be nullable;
        # we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved
        # as an indication that, even if None is discovered as one of the union choices, we will not need to wrap
        # the final value in another nullable schema.
        #
        # This is more complicated than just checking for the final outermost schema having type 'nullable' thanks
        # to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc.
        self._is_nullable = False

        # `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices
        # from the union in the wrapped schema will be appended to this list, and the recursive choice-handling
        # algorithm may add more choices to this stack as (nested) unions are encountered.
        self._choices_to_handle: list[core_schema.CoreSchema] = []

        # `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included
        # in the output TaggedUnionSchema that will replace the union from the input schema
        self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}

        # `_used` is changed to True after applying the discriminator to prevent accidental reuse
        self._used = False

    def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided
        to this class.

        Args:
            schema: The input schema.

        Returns:
            The new core schema.

        Raises:
            TypeError:
                - If `discriminator` is used with invalid union variant.
                - If `discriminator` value mapped to multiple choices.
            MissingDefinitionForUnionRef:
                If the definition for ref is missing.
            PydanticUserError:
                - If a model in union doesn't have a discriminator field.
                - If discriminator field has a non-string alias.
                - If discriminator fields have different aliases.
                - If discriminator field not of type `Literal`.
        """
        assert not self._used
        schema = self._apply_to_root(schema)
        if self._should_be_nullable and not self._is_nullable:
            # `None` was found among the choices but has no discriminator value, so it cannot be a
            # tagged-union member; allow it by wrapping the whole result in a nullable schema instead.
            schema = core_schema.nullable_schema(schema)
        self._used = True
        return schema

    def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """Handle the outer-most stage of recursion over the input schema.

        Nullable and definitions wrappers are unwrapped (and re-created around the result), a
        `definition-ref` that points at a union is resolved to that union, and `_handle_choice` is
        then called iteratively on the choices extracted (recursively) from the possibly-wrapped union.
        """
        if schema['type'] == 'nullable':
            # This outer nullable wrapper is preserved, so no extra nullable wrapping is needed later.
            self._is_nullable = True
            wrapped = self._apply_to_root(schema['schema'])
            nullable_wrapper = schema.copy()
            nullable_wrapper['schema'] = wrapped
            return nullable_wrapper

        if schema['type'] == 'definitions':
            wrapped = self._apply_to_root(schema['schema'])
            definitions_wrapper = schema.copy()
            definitions_wrapper['schema'] = wrapped
            return definitions_wrapper

        if schema['type'] == 'definition-ref':
            schema_ref = schema['schema_ref']
            if schema_ref not in self.definitions:  # pragma: no cover
                raise MissingDefinitionForUnionRef(schema_ref)

            def_schema = self.definitions[schema_ref]
            # If using a referenceable union as discriminated (e.g. `type Pet = Cat | Dog; field: Pet = Field(discriminator=...)`):
            if def_schema['type'] == 'union':
                # Work on a copy and drop its ref, so the tagged-union built below (which copies `ref`
                # from `schema`) does not reuse the shared definition's ref.
                schema = def_schema.copy()
                schema.pop('ref', None)

        if schema['type'] != 'union':
            # If the schema is not a union, it probably means it just had a single member and
            # was flattened by pydantic_core.
            # However, it still may make sense to apply the discriminator to this schema,
            # as a way to get discriminated-union-style error messages, so we allow this here.
            schema = core_schema.union_schema([schema])

        # Union choices may be tuples whose first element is the schema; only the schema is needed here.
        # Reverse the choices list before extending the stack so that they get handled in the order they occur
        choices_schemas = [v[0] if isinstance(v, tuple) else v for v in schema['choices'][::-1]]
        self._choices_to_handle.extend(choices_schemas)
        while self._choices_to_handle:
            choice = self._choices_to_handle.pop()
            self._handle_choice(choice)

        if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator:
            # * We need to annotate `discriminator` as a union here to handle both branches of this conditional
            # * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the
            #   invariance of list, and because list[list[str | int]] is the type of the discriminator argument
            #   to tagged_union_schema below
            # * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to
            #   interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here
            #   is the appropriate way to provide a list of fallback attributes to check for a discriminator value.)
            discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]]
        else:
            discriminator = self.discriminator
        return core_schema.tagged_union_schema(
            choices=self._tagged_union_choices,
            discriminator=discriminator,
            custom_error_type=schema.get('custom_error_type'),
            custom_error_message=schema.get('custom_error_message'),
            custom_error_context=schema.get('custom_error_context'),
            strict=False,
            from_attributes=True,
            ref=schema.get('ref'),
            metadata=schema.get('metadata'),
            serialization=schema.get('serialization'),
        )

    def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
        """Handle the "middle" stage of recursion over the input schema.

        Specifically, this is responsible for handling each choice of the outermost union
        (and any "coalesced" choices obtained from inner unions).

        Here, "handling" entails:
        * Coalescing nested unions and compatible tagged-unions
        * Tracking the presence of 'none' and 'nullable' schemas occurring as choices
        * Validating that each allowed discriminator value maps to a unique choice
        * Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.

        Raises:
            MissingDefinitionForUnionRef: If a `definition-ref` choice points at an unknown ref.
            TypeError: If the choice's schema type cannot be a discriminated union variant.
        """
        if choice['type'] == 'definition-ref':
            if choice['schema_ref'] not in self.definitions:
                raise MissingDefinitionForUnionRef(choice['schema_ref'])

        if choice['type'] == 'none':
            self._should_be_nullable = True
        elif choice['type'] == 'definitions':
            self._handle_choice(choice['schema'])
        elif choice['type'] == 'nullable':
            self._should_be_nullable = True
            self._handle_choice(choice['schema'])  # unwrap the nullable schema
        elif choice['type'] == 'union':
            # Reverse the choices list before extending the stack so that they get handled in the order they occur
            choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
            self._choices_to_handle.extend(choices_schemas)
        elif choice['type'] not in {
            'model',
            'typed-dict',
            'tagged-union',
            'lax-or-strict',
            'dataclass',
            'dataclass-args',
            'definition-ref',
        } and not _core_utils.is_function_with_inner_schema(choice):
            raise self._invalid_variant_error(choice)
        else:
            if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
                # In this case, this inner tagged-union is compatible with the outer tagged-union,
                # and its choices can be coalesced into the outer TaggedUnionSchema.
                subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
                # Reverse the choices list before extending the stack so that they get handled in the order they occur
                self._choices_to_handle.extend(subchoices[::-1])
                return

            inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None)
            self._set_unique_choice_for_values(choice, inferred_discriminator_values)

    @staticmethod
    def _invalid_variant_error(choice: core_schema.CoreSchema) -> TypeError:
        """Build (but do not raise) the error for a `choice` that cannot be a discriminated union variant."""
        err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
        if choice['type'] == 'list':
            # Hint for the common mistake of attaching the discriminator to a list of unions.
            err_str += (
                ' If you are making use of a list of union types, make sure the discriminator is applied to the '
                'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
            )
        return TypeError(err_str)

    def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool:
        """Return whether the discriminator for the `choice` is the same as that being used for the
        outermost tagged union.

        This is used to determine whether this TaggedUnionSchema choice should be "coalesced" into the
        top level, or whether it should be treated as a separate (nested) choice.
        """
        inner_discriminator = choice['discriminator']
        return inner_discriminator == self.discriminator or (
            isinstance(inner_discriminator, list)
            and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator)
        )

    def _infer_discriminator_values_for_choice(  # noqa C901
        self, choice: core_schema.CoreSchema, source_name: str | None
    ) -> list[str | int]:
        """Recurse over `choice`, extracting all discriminator values that should map to this choice.

        `source_name` is accepted for the purpose of producing useful error messages. It is set to the
        class name when a 'model' or 'dataclass' schema is unwrapped, and reset to `None` when recursing
        into nested choices.

        Raises:
            MissingDefinitionForUnionRef: If a `definition-ref` points at an unknown ref.
            TypeError: If the schema type cannot be a discriminated union variant.
        """
        choice_type = choice['type']

        if choice_type == 'definitions':
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)

        elif choice_type == 'lax-or-strict':
            lax_values = self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None)
            strict_values = self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None)
            # De-duplicate (keeping first-seen order), then sort for a deterministic order when possible.
            # Mixed literal types such as `Literal[1, 'a']` are not mutually orderable; in that case keep
            # the first-seen order rather than failing with a TypeError from `sorted`.
            unique_values = list(dict.fromkeys(lax_values + strict_values))
            try:
                return sorted(unique_values)
            except TypeError:
                return unique_values

        elif choice_type == 'tagged-union':
            values: list[str | int] = []
            # Ignore str/int "choices" since these are just references to other choices
            subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
            for subchoice in subchoices:
                values.extend(self._infer_discriminator_values_for_choice(subchoice, source_name=None))
            return values

        elif choice_type == 'union':
            values = []
            for subchoice in choice['choices']:
                subchoice_schema = subchoice[0] if isinstance(subchoice, tuple) else subchoice
                values.extend(self._infer_discriminator_values_for_choice(subchoice_schema, source_name=None))
            return values

        elif choice_type == 'nullable':
            self._should_be_nullable = True
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None)

        elif choice_type in ('model', 'dataclass'):
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)

        elif choice_type == 'model-fields':
            return self._infer_discriminator_values_for_model_choice(choice, source_name=source_name)

        elif choice_type == 'dataclass-args':
            return self._infer_discriminator_values_for_dataclass_choice(choice, source_name=source_name)

        elif choice_type == 'typed-dict':
            return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name)

        elif choice_type == 'definition-ref':
            schema_ref = choice['schema_ref']
            if schema_ref not in self.definitions:
                raise MissingDefinitionForUnionRef(schema_ref)
            return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)

        elif _core_utils.is_function_with_inner_schema(choice):
            # Function schemas that wrap an inner 'schema': the discriminator values come from that inner schema.
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)

        raise self._invalid_variant_error(choice)

    def _infer_discriminator_values_for_typed_dict_choice(
        self, choice: core_schema.TypedDictSchema, source_name: str | None = None
    ) -> list[str | int]:
        """Extract the discriminator values of a TypedDictSchema choice.

        Raises:
            PydanticUserError: If the typed dict has no field named after the discriminator.
        """
        source = 'TypedDict' if source_name is None else f'TypedDict {source_name!r}'
        field = choice['fields'].get(self.discriminator)
        if field is None:
            raise PydanticUserError(
                f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
            )
        return self._infer_discriminator_values_for_field(field, source)

    def _infer_discriminator_values_for_model_choice(
        self, choice: core_schema.ModelFieldsSchema, source_name: str | None = None
    ) -> list[str | int]:
        """Extract the discriminator values of a ModelFieldsSchema choice.

        Raises:
            PydanticUserError: If the model has no field named after the discriminator.
        """
        source = 'ModelFields' if source_name is None else f'Model {source_name!r}'
        field = choice['fields'].get(self.discriminator)
        if field is None:
            raise PydanticUserError(
                f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
            )
        return self._infer_discriminator_values_for_field(field, source)

    def _infer_discriminator_values_for_dataclass_choice(
        self, choice: core_schema.DataclassArgsSchema, source_name: str | None = None
    ) -> list[str | int]:
        """Extract the discriminator values of a DataclassArgsSchema choice.

        Raises:
            PydanticUserError: If the dataclass has no field named after the discriminator.
        """
        source = 'DataclassArgs' if source_name is None else f'Dataclass {source_name!r}'
        # Dataclass fields are a list (not a dict keyed by name), so search for the matching name.
        for field in choice['fields']:
            if field['name'] == self.discriminator:
                break
        else:
            raise PydanticUserError(
                f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
            )
        return self._infer_discriminator_values_for_field(field, source)

    def _infer_discriminator_values_for_field(self, field: CoreSchemaField, source: str) -> list[str | int]:
        """Record the field's validation alias and extract the discriminator values from its schema.

        Raises:
            PydanticUserError: If the alias is not a string, or differs from the alias seen on other choices.
        """
        if field['type'] == 'computed-field':
            # This should never occur as a discriminator, as it is only relevant to serialization
            return []
        alias = field.get('validation_alias', self.discriminator)
        if not isinstance(alias, str):
            raise PydanticUserError(
                f'Alias {alias!r} is not supported in a discriminated union', code='discriminator-alias-type'
            )
        if self._discriminator_alias is None:
            self._discriminator_alias = alias
        elif self._discriminator_alias != alias:
            raise PydanticUserError(
                f'Aliases for discriminator {self.discriminator!r} must be the same '
                f'(got {alias}, {self._discriminator_alias})',
                code='discriminator-alias',
            )
        return self._infer_discriminator_values_for_inner_schema(field['schema'], source)

    def _infer_discriminator_values_for_inner_schema(
        self, schema: core_schema.CoreSchema, source: str
    ) -> list[str | int]:
        """When inferring discriminator values for a field, we typically extract the expected values from a literal
        schema. This function does that, but also handles nested unions and defaults.

        Raises:
            PydanticUserError: If the field uses a before/wrap/plain validator, or is not a `Literal`.
        """
        if schema['type'] == 'literal':
            return schema['expected']

        elif schema['type'] == 'union':
            # Generally when multiple values are allowed they should be placed in a single `Literal`, but
            # we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s.
            # For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]`
            values: list[Any] = []
            for choice in schema['choices']:
                choice_schema = choice[0] if isinstance(choice, tuple) else choice
                values.extend(self._infer_discriminator_values_for_inner_schema(choice_schema, source))
            return values

        elif schema['type'] == 'default':
            # This will happen if the field has a default value; we ignore it while extracting the discriminator values
            return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)

        elif schema['type'] == 'function-after':
            # After validators don't affect the discriminator values
            return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)

        elif schema['type'] == 'model' and schema.get('root_model'):
            # Support RootModel[Literal[...]] as discriminator field type
            return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)

        elif schema['type'] in {'function-before', 'function-wrap', 'function-plain'}:
            validator_type = repr(schema['type'].split('-')[1])
            raise PydanticUserError(
                f'Cannot use a mode={validator_type} validator in the'
                f' discriminator field {self.discriminator!r} of {source}',
                code='discriminator-validator',
            )

        else:
            raise PydanticUserError(
                f'{source} needs field {self.discriminator!r} to be of type `Literal`',
                code='discriminator-needs-literal',
            )

    def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None:
        """Update `self._tagged_union_choices` so that all provided (discriminator) `values` map to the
        provided `choice`, validating that none of these values already map to another (different) choice.

        Raises:
            TypeError: If a value is already mapped to a choice that is not equal to `choice`.
        """
        for discriminator_value in values:
            if discriminator_value in self._tagged_union_choices:
                # It is okay if the value is already present as long as it maps to an equal choice.
                # Values are only ever mapped directly to choice schemas (see the `else` branch), so a
                # plain equality check against the stored schema is sufficient.
                existing_choice = self._tagged_union_choices[discriminator_value]
                if existing_choice != choice:
                    raise TypeError(
                        f'Value {discriminator_value!r} for discriminator '
                        f'{self.discriminator!r} mapped to multiple choices'
                    )
            else:
                self._tagged_union_choices[discriminator_value] = choice