1from __future__ import annotations as _annotations
2
3from collections.abc import Hashable, Sequence
4from typing import TYPE_CHECKING, Any, cast
5
6from pydantic_core import CoreSchema, core_schema
7
8from ..errors import PydanticUserError
9from . import _core_utils
10from ._core_utils import (
11 CoreSchemaField,
12)
13
14if TYPE_CHECKING:
15 from ..types import Discriminator
16 from ._core_metadata import CoreMetadata
17
18
class MissingDefinitionForUnionRef(Exception):
    """Signal that a discriminator cannot be applied because a referenced definition is not available yet.

    The unresolved schema reference is stored on the `ref` attribute.
    """

    def __init__(self, ref: str) -> None:
        message = f'Missing definition for ref {ref!r}'
        super().__init__(message)
        self.ref = ref
27
28
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
    """Record `discriminator` in the metadata of `schema`, mutating `schema` in place.

    A `'metadata'` mapping is created on the schema when it has none; the discriminator is then
    stored under the `'pydantic_internal_union_discriminator'` key, replacing any previous value.

    Args:
        schema: The core schema to annotate.
        discriminator: The discriminator to record.
    """
    if 'metadata' not in schema:
        schema['metadata'] = {}
    core_metadata = cast('CoreMetadata', schema['metadata'])
    core_metadata['pydantic_internal_union_discriminator'] = discriminator
32
33
def apply_discriminator(
    schema: core_schema.CoreSchema,
    discriminator: str | Discriminator,
    definitions: dict[str, core_schema.CoreSchema] | None = None,
) -> core_schema.CoreSchema:
    """Build a new core schema in which `discriminator` has been applied to `schema`.

    Args:
        schema: The input schema.
        discriminator: Either the name of the field which will serve as the discriminator, or a
            `Discriminator` instance. When the instance wraps a field name, that name is used;
            otherwise the conversion is delegated to the `Discriminator` instance itself.
        definitions: A mapping of schema ref to schema.

    Returns:
        The new core schema.

    Raises:
        TypeError:
            - If `discriminator` is used with invalid union variant.
            - If `discriminator` is used with `Union` type with one variant.
            - If `discriminator` value mapped to multiple choices.
        MissingDefinitionForUnionRef:
            If the definition for ref is missing.
        PydanticUserError:
            - If a model in union doesn't have a discriminator field.
            - If discriminator field has a non-string alias.
            - If discriminator fields have different aliases.
            - If discriminator field not of type `Literal`.
    """
    # Imported at call time, not at module level (the module-level import is guarded by `TYPE_CHECKING`).
    from ..types import Discriminator

    if isinstance(discriminator, Discriminator):
        wrapped = discriminator.discriminator
        if not isinstance(wrapped, str):
            # Not a plain field name: let the `Discriminator` object convert the schema itself.
            return discriminator._convert_schema(schema)
        discriminator = wrapped

    inferrer = _ApplyInferredDiscriminator(discriminator, definitions or {})
    return inferrer.apply(schema)
71
72
73class _ApplyInferredDiscriminator:
74 """This class is used to convert an input schema containing a union schema into one where that union is
75 replaced with a tagged-union, with all the associated debugging and performance benefits.
76
77 This is done by:
78 * Validating that the input schema is compatible with the provided discriminator
79 * Introspecting the schema to determine which discriminator values should map to which union choices
80 * Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more
81
82 I have chosen to implement the conversion algorithm in this class, rather than a function,
83 to make it easier to maintain state while recursively walking the provided CoreSchema.
84 """
85
86 def __init__(self, discriminator: str, definitions: dict[str, core_schema.CoreSchema]):
87 # `discriminator` should be the name of the field which will serve as the discriminator.
88 # It must be the python name of the field, and *not* the field's alias. Note that as of now,
89 # all members of a discriminated union _must_ use a field with the same name as the discriminator.
90 # This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices.
91 self.discriminator = discriminator
92
93 # `definitions` should contain a mapping of schema ref to schema for all schemas which might
94 # be referenced by some choice
95 self.definitions = definitions
96
97 # `_discriminator_alias` will hold the value, if present, of the alias for the discriminator
98 #
99 # Note: following the v1 implementation, we currently disallow the use of different aliases
100 # for different choices. This is not a limitation of pydantic_core, but if we try to handle
101 # this, the inference logic gets complicated very quickly, and could result in confusing
102 # debugging challenges for users making subtle mistakes.
103 #
104 # Rather than trying to do the most powerful inference possible, I think we should eventually
105 # expose a way to more-manually control the way the TaggedUnionSchema is constructed through
106 # the use of a new type which would be placed as an Annotation on the Union type. This would
107 # provide the full flexibility/power of pydantic_core's TaggedUnionSchema where necessary for
108 # more complex cases, without over-complicating the inference logic for the common cases.
109 self._discriminator_alias: str | None = None
110
111 # `_should_be_nullable` indicates whether the converted union has `None` as an allowed value.
112 # If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while
113 # constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True.
114 # Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure
115 # that the final schema gets wrapped as a NullableSchema. This has the same semantics on the
116 # python side, but resolves the issue that `None` cannot correspond to any discriminator values.
117 self._should_be_nullable = False
118
119 # `_is_nullable` is used to track if the final produced schema will definitely be nullable;
120 # we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved
121 # as an indication that, even if None is discovered as one of the union choices, we will not need to wrap
122 # the final value in another nullable schema.
123 #
124 # This is more complicated than just checking for the final outermost schema having type 'nullable' thanks
125 # to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc.
126 self._is_nullable = False
127
128 # `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices
129 # from the union in the wrapped schema will be appended to this list, and the recursive choice-handling
130 # algorithm may add more choices to this stack as (nested) unions are encountered.
131 self._choices_to_handle: list[core_schema.CoreSchema] = []
132
133 # `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included
134 # in the output TaggedUnionSchema that will replace the union from the input schema
135 self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}
136
137 # `_used` is changed to True after applying the discriminator to prevent accidental reuse
138 self._used = False
139
140 def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
141 """Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided
142 to this class.
143
144 Args:
145 schema: The input schema.
146
147 Returns:
148 The new core schema.
149
150 Raises:
151 TypeError:
152 - If `discriminator` is used with invalid union variant.
153 - If `discriminator` is used with `Union` type with one variant.
154 - If `discriminator` value mapped to multiple choices.
155 ValueError:
156 If the definition for ref is missing.
157 PydanticUserError:
158 - If a model in union doesn't have a discriminator field.
159 - If discriminator field has a non-string alias.
160 - If discriminator fields have different aliases.
161 - If discriminator field not of type `Literal`.
162 """
163 assert not self._used
164 schema = self._apply_to_root(schema)
165 if self._should_be_nullable and not self._is_nullable:
166 schema = core_schema.nullable_schema(schema)
167 self._used = True
168 return schema
169
170 def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
171 """This method handles the outer-most stage of recursion over the input schema:
172 unwrapping nullable or definitions schemas, and calling the `_handle_choice`
173 method iteratively on the choices extracted (recursively) from the possibly-wrapped union.
174 """
175 if schema['type'] == 'nullable':
176 self._is_nullable = True
177 wrapped = self._apply_to_root(schema['schema'])
178 nullable_wrapper = schema.copy()
179 nullable_wrapper['schema'] = wrapped
180 return nullable_wrapper
181
182 if schema['type'] == 'definitions':
183 wrapped = self._apply_to_root(schema['schema'])
184 definitions_wrapper = schema.copy()
185 definitions_wrapper['schema'] = wrapped
186 return definitions_wrapper
187
188 if schema['type'] != 'union':
189 # If the schema is not a union, it probably means it just had a single member and
190 # was flattened by pydantic_core.
191 # However, it still may make sense to apply the discriminator to this schema,
192 # as a way to get discriminated-union-style error messages, so we allow this here.
193 schema = core_schema.union_schema([schema])
194
195 # Reverse the choices list before extending the stack so that they get handled in the order they occur
196 choices_schemas = [v[0] if isinstance(v, tuple) else v for v in schema['choices'][::-1]]
197 self._choices_to_handle.extend(choices_schemas)
198 while self._choices_to_handle:
199 choice = self._choices_to_handle.pop()
200 self._handle_choice(choice)
201
202 if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator:
203 # * We need to annotate `discriminator` as a union here to handle both branches of this conditional
204 # * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the
205 # invariance of list, and because list[list[str | int]] is the type of the discriminator argument
206 # to tagged_union_schema below
207 # * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to
208 # interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here
209 # is the appropriate way to provide a list of fallback attributes to check for a discriminator value.)
210 discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]]
211 else:
212 discriminator = self.discriminator
213 return core_schema.tagged_union_schema(
214 choices=self._tagged_union_choices,
215 discriminator=discriminator,
216 custom_error_type=schema.get('custom_error_type'),
217 custom_error_message=schema.get('custom_error_message'),
218 custom_error_context=schema.get('custom_error_context'),
219 strict=False,
220 from_attributes=True,
221 ref=schema.get('ref'),
222 metadata=schema.get('metadata'),
223 serialization=schema.get('serialization'),
224 )
225
226 def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
227 """This method handles the "middle" stage of recursion over the input schema.
228 Specifically, it is responsible for handling each choice of the outermost union
229 (and any "coalesced" choices obtained from inner unions).
230
231 Here, "handling" entails:
232 * Coalescing nested unions and compatible tagged-unions
233 * Tracking the presence of 'none' and 'nullable' schemas occurring as choices
234 * Validating that each allowed discriminator value maps to a unique choice
235 * Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
236 """
237 if choice['type'] == 'definition-ref':
238 if choice['schema_ref'] not in self.definitions:
239 raise MissingDefinitionForUnionRef(choice['schema_ref'])
240
241 if choice['type'] == 'none':
242 self._should_be_nullable = True
243 elif choice['type'] == 'definitions':
244 self._handle_choice(choice['schema'])
245 elif choice['type'] == 'nullable':
246 self._should_be_nullable = True
247 self._handle_choice(choice['schema']) # unwrap the nullable schema
248 elif choice['type'] == 'union':
249 # Reverse the choices list before extending the stack so that they get handled in the order they occur
250 choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
251 self._choices_to_handle.extend(choices_schemas)
252 elif choice['type'] not in {
253 'model',
254 'typed-dict',
255 'tagged-union',
256 'lax-or-strict',
257 'dataclass',
258 'dataclass-args',
259 'definition-ref',
260 } and not _core_utils.is_function_with_inner_schema(choice):
261 # We should eventually handle 'definition-ref' as well
262 err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
263 if choice['type'] == 'list':
264 err_str += (
265 ' If you are making use of a list of union types, make sure the discriminator is applied to the '
266 'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
267 )
268 raise TypeError(err_str)
269 else:
270 if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
271 # In this case, this inner tagged-union is compatible with the outer tagged-union,
272 # and its choices can be coalesced into the outer TaggedUnionSchema.
273 subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
274 # Reverse the choices list before extending the stack so that they get handled in the order they occur
275 self._choices_to_handle.extend(subchoices[::-1])
276 return
277
278 inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None)
279 self._set_unique_choice_for_values(choice, inferred_discriminator_values)
280
281 def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool:
282 """This method returns a boolean indicating whether the discriminator for the `choice`
283 is the same as that being used for the outermost tagged union. This is used to
284 determine whether this TaggedUnionSchema choice should be "coalesced" into the top level,
285 or whether it should be treated as a separate (nested) choice.
286 """
287 inner_discriminator = choice['discriminator']
288 return inner_discriminator == self.discriminator or (
289 isinstance(inner_discriminator, list)
290 and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator)
291 )
292
293 def _infer_discriminator_values_for_choice( # noqa C901
294 self, choice: core_schema.CoreSchema, source_name: str | None
295 ) -> list[str | int]:
296 """This function recurses over `choice`, extracting all discriminator values that should map to this choice.
297
298 `model_name` is accepted for the purpose of producing useful error messages.
299 """
300 if choice['type'] == 'definitions':
301 return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
302
303 elif _core_utils.is_function_with_inner_schema(choice):
304 return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
305
306 elif choice['type'] == 'lax-or-strict':
307 return sorted(
308 set(
309 self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None)
310 + self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None)
311 )
312 )
313
314 elif choice['type'] == 'tagged-union':
315 values: list[str | int] = []
316 # Ignore str/int "choices" since these are just references to other choices
317 subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
318 for subchoice in subchoices:
319 subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None)
320 values.extend(subchoice_values)
321 return values
322
323 elif choice['type'] == 'union':
324 values = []
325 for subchoice in choice['choices']:
326 subchoice_schema = subchoice[0] if isinstance(subchoice, tuple) else subchoice
327 subchoice_values = self._infer_discriminator_values_for_choice(subchoice_schema, source_name=None)
328 values.extend(subchoice_values)
329 return values
330
331 elif choice['type'] == 'nullable':
332 self._should_be_nullable = True
333 return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None)
334
335 elif choice['type'] == 'model':
336 return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
337
338 elif choice['type'] == 'dataclass':
339 return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
340
341 elif choice['type'] == 'model-fields':
342 return self._infer_discriminator_values_for_model_choice(choice, source_name=source_name)
343
344 elif choice['type'] == 'dataclass-args':
345 return self._infer_discriminator_values_for_dataclass_choice(choice, source_name=source_name)
346
347 elif choice['type'] == 'typed-dict':
348 return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name)
349
350 elif choice['type'] == 'definition-ref':
351 schema_ref = choice['schema_ref']
352 if schema_ref not in self.definitions:
353 raise MissingDefinitionForUnionRef(schema_ref)
354 return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
355 else:
356 err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
357 if choice['type'] == 'list':
358 err_str += (
359 ' If you are making use of a list of union types, make sure the discriminator is applied to the '
360 'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
361 )
362 raise TypeError(err_str)
363
364 def _infer_discriminator_values_for_typed_dict_choice(
365 self, choice: core_schema.TypedDictSchema, source_name: str | None = None
366 ) -> list[str | int]:
367 """This method just extracts the _infer_discriminator_values_for_choice logic specific to TypedDictSchema
368 for the sake of readability.
369 """
370 source = 'TypedDict' if source_name is None else f'TypedDict {source_name!r}'
371 field = choice['fields'].get(self.discriminator)
372 if field is None:
373 raise PydanticUserError(
374 f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
375 )
376 return self._infer_discriminator_values_for_field(field, source)
377
378 def _infer_discriminator_values_for_model_choice(
379 self, choice: core_schema.ModelFieldsSchema, source_name: str | None = None
380 ) -> list[str | int]:
381 source = 'ModelFields' if source_name is None else f'Model {source_name!r}'
382 field = choice['fields'].get(self.discriminator)
383 if field is None:
384 raise PydanticUserError(
385 f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
386 )
387 return self._infer_discriminator_values_for_field(field, source)
388
389 def _infer_discriminator_values_for_dataclass_choice(
390 self, choice: core_schema.DataclassArgsSchema, source_name: str | None = None
391 ) -> list[str | int]:
392 source = 'DataclassArgs' if source_name is None else f'Dataclass {source_name!r}'
393 for field in choice['fields']:
394 if field['name'] == self.discriminator:
395 break
396 else:
397 raise PydanticUserError(
398 f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
399 )
400 return self._infer_discriminator_values_for_field(field, source)
401
402 def _infer_discriminator_values_for_field(self, field: CoreSchemaField, source: str) -> list[str | int]:
403 if field['type'] == 'computed-field':
404 # This should never occur as a discriminator, as it is only relevant to serialization
405 return []
406 alias = field.get('validation_alias', self.discriminator)
407 if not isinstance(alias, str):
408 raise PydanticUserError(
409 f'Alias {alias!r} is not supported in a discriminated union', code='discriminator-alias-type'
410 )
411 if self._discriminator_alias is None:
412 self._discriminator_alias = alias
413 elif self._discriminator_alias != alias:
414 raise PydanticUserError(
415 f'Aliases for discriminator {self.discriminator!r} must be the same '
416 f'(got {alias}, {self._discriminator_alias})',
417 code='discriminator-alias',
418 )
419 return self._infer_discriminator_values_for_inner_schema(field['schema'], source)
420
421 def _infer_discriminator_values_for_inner_schema(
422 self, schema: core_schema.CoreSchema, source: str
423 ) -> list[str | int]:
424 """When inferring discriminator values for a field, we typically extract the expected values from a literal
425 schema. This function does that, but also handles nested unions and defaults.
426 """
427 if schema['type'] == 'literal':
428 return schema['expected']
429
430 elif schema['type'] == 'union':
431 # Generally when multiple values are allowed they should be placed in a single `Literal`, but
432 # we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s.
433 # For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]`
434 values: list[Any] = []
435 for choice in schema['choices']:
436 choice_schema = choice[0] if isinstance(choice, tuple) else choice
437 choice_values = self._infer_discriminator_values_for_inner_schema(choice_schema, source)
438 values.extend(choice_values)
439 return values
440
441 elif schema['type'] == 'default':
442 # This will happen if the field has a default value; we ignore it while extracting the discriminator values
443 return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
444
445 elif schema['type'] == 'function-after':
446 # After validators don't affect the discriminator values
447 return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
448
449 elif schema['type'] in {'function-before', 'function-wrap', 'function-plain'}:
450 validator_type = repr(schema['type'].split('-')[1])
451 raise PydanticUserError(
452 f'Cannot use a mode={validator_type} validator in the'
453 f' discriminator field {self.discriminator!r} of {source}',
454 code='discriminator-validator',
455 )
456
457 else:
458 raise PydanticUserError(
459 f'{source} needs field {self.discriminator!r} to be of type `Literal`',
460 code='discriminator-needs-literal',
461 )
462
463 def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None:
464 """This method updates `self.tagged_union_choices` so that all provided (discriminator) `values` map to the
465 provided `choice`, validating that none of these values already map to another (different) choice.
466 """
467 for discriminator_value in values:
468 if discriminator_value in self._tagged_union_choices:
469 # It is okay if `value` is already in tagged_union_choices as long as it maps to the same value.
470 # Because tagged_union_choices may map values to other values, we need to walk the choices dict
471 # until we get to a "real" choice, and confirm that is equal to the one assigned.
472 existing_choice = self._tagged_union_choices[discriminator_value]
473 if existing_choice != choice:
474 raise TypeError(
475 f'Value {discriminator_value!r} for discriminator '
476 f'{self.discriminator!r} mapped to multiple choices'
477 )
478 else:
479 self._tagged_union_choices[discriminator_value] = choice