1# pyright: reportTypedDictNotRequiredAccess=false, reportGeneralTypeIssues=false, reportArgumentType=false, reportAttributeAccessIssue=false
2from __future__ import annotations
3
4from dataclasses import dataclass, field
5from typing import TypedDict
6
7from pydantic_core.core_schema import (
8 ComputedField,
9 CoreSchema,
10 DefinitionReferenceSchema,
11 SerSchema,
12 iter_union_choices,
13)
14from typing_extensions import TypeAlias
15
# Every kind of schema the traversal functions below accept: core schemas, serialization
# schemas and computed fields. The alias value is a string (a forward reference), so the
# union itself is not evaluated at runtime.
AllSchemas: TypeAlias = 'CoreSchema | SerSchema | ComputedField'
17
18
class GatherResult(TypedDict):
    """Outcome of a core schema traversal, as returned by `gather_schemas_for_cleaning()`."""

    collected_references: dict[str, DefinitionReferenceSchema | None]
    """Every definition reference seen during the traversal, keyed by reference.

    The value is the `'definition-ref'` schema itself when the reference was encountered
    exactly once in the whole core schema (making it a candidate for inlining), and `None`
    when it was encountered more than once.
    """

    deferred_discriminator_schemas: list[CoreSchema]
    """Core schemas whose discriminator application was deferred."""
32
33
class MissingDefinitionError(LookupError):
    """A reference was pointing to a non-existing core schema.

    Attributes:
        schema_reference: The reference that could not be resolved to a definition.
    """

    def __init__(self, schema_reference: str, /) -> None:
        # Forward the reference to the base class so that `args` is populated: this gives the
        # exception a meaningful `str()`/`repr()` (useful when it is chained into another error)
        # and lets it round-trip through `pickle`, which re-creates exceptions via `cls(*args)`.
        super().__init__(schema_reference)
        self.schema_reference = schema_reference
39
40
@dataclass
class GatherContext:
    """Mutable state shared by the functions performing a core schema traversal.

    A context is meant to live only for the duration of a single traversal.
    """

    definitions: dict[str, CoreSchema]
    """Definitions available to resolve references, keyed by reference."""

    deferred_discriminator_schemas: list[CoreSchema] = field(default_factory=list, init=False)
    """Core schemas whose discriminator application is deferred.

    These are the schemas whose core metadata holds the internal union discriminator key.
    """

    collected_references: dict[str, DefinitionReferenceSchema | None] = field(default_factory=dict, init=False)
    """Definition references encountered so far, keyed by reference.

    A reference seen exactly once in the whole core schema maps to its `'definition-ref'`
    schema, which is then a candidate for inlining. As soon as the same reference is seen
    again, the value is replaced by `None` to mark it as non-inlinable.
    """
67
68
def traverse_metadata(schema: AllSchemas, ctx: GatherContext) -> None:
    """Collect `schema` into `ctx` if its discriminator application was deferred.

    The schema is appended to `ctx.deferred_discriminator_schemas` when it has a
    non-`None` `'metadata'` entry containing the `'pydantic_internal_union_discriminator'` key.
    """
    metadata = schema.get('metadata')
    if metadata is None:
        return
    if 'pydantic_internal_union_discriminator' not in metadata:
        return
    ctx.deferred_discriminator_schemas.append(schema)  # pyright: ignore[reportArgumentType]
73
74
def traverse_definition_ref(def_ref_schema: DefinitionReferenceSchema, ctx: GatherContext) -> None:
    """Record a `'definition-ref'` schema and traverse what it points to.

    The first time a given reference is encountered, the referenced definition is checked to
    exist and is traversed, and the `'definition-ref'` schema is stored in
    `ctx.collected_references` as an inlining candidate. Any later encounter of the same
    reference sets the stored value to `None` (not inlinable) without traversing the
    definition again, which also stops the recursion on self-referencing definitions.

    The `'serialization'` schema and the metadata are read from the `'definition-ref'` schema
    itself (not from the definition it points to), so they are traversed on every encounter.

    Args:
        def_ref_schema: The `'definition-ref'` schema being visited.
        ctx: The traversal context, mutated in place.

    Raises:
        MissingDefinitionError: If the reference does not point to any definition in `ctx.definitions`.
    """
    schema_ref = def_ref_schema['schema_ref']

    if schema_ref not in ctx.collected_references:
        definition = ctx.definitions.get(schema_ref)
        if definition is None:
            raise MissingDefinitionError(schema_ref)

        # The `'definition-ref'` schema was only encountered once, make it
        # a candidate to be inlined:
        ctx.collected_references[schema_ref] = def_ref_schema
        traverse_schema(definition, ctx)
    else:
        # The `'definition-ref'` schema was already encountered, meaning
        # the previously encountered schema (and this one) can't be inlined:
        ctx.collected_references[schema_ref] = None

    # These belong to this particular `'definition-ref'` schema: skipping them on repeated
    # encounters would miss references nested in the serialization schema (leaving them
    # uncounted and unvalidated) and any deferred discriminator stored in the metadata.
    if 'serialization' in def_ref_schema:
        traverse_schema(def_ref_schema['serialization'], ctx)
    traverse_metadata(def_ref_schema, ctx)
94
95
def traverse_schema(schema: AllSchemas, context: GatherContext) -> None:
    """Recursively traverse `schema`, recording references and deferred discriminators in `context`.

    Nested schemas are visited depth-first: the type-specific sub-schemas first, then the
    `'serialization'` schema (if any), then the metadata. `'definition-ref'` schemas are
    delegated entirely to `traverse_definition_ref()`.

    Args:
        schema: The core schema, serialization schema or computed field to traverse.
        context: The traversal context, mutated in place.

    Raises:
        MissingDefinitionError: Propagated from `traverse_definition_ref()` when a reference
            does not point to an existing definition.
    """
    # TODO When we drop 3.9, use a match statement to get better type checking and remove
    # file-level type ignore.
    # (the `'type'` could also be fetched in every `if/elif` statement, but this alters performance).
    # NOTE: the visiting order below is significant: it decides which `'definition-ref'` schema
    # is stored first in `context.collected_references`, and the order of
    # `context.deferred_discriminator_schemas`.
    schema_type = schema['type']

    if schema_type == 'definition-ref':
        traverse_definition_ref(schema, context)
        # `traverse_definition_ref` handles the possible serialization and metadata schemas:
        return
    elif schema_type == 'definitions':
        # The wrapped schema first, then every definition listed alongside it:
        traverse_schema(schema['schema'], context)
        for definition in schema['definitions']:
            traverse_schema(definition, context)
    elif schema_type in {'list', 'set', 'frozenset', 'generator'}:
        # Homogeneous containers: a single (optional) items schema.
        if 'items_schema' in schema:
            traverse_schema(schema['items_schema'], context)
    elif schema_type == 'tuple':
        # Tuples hold a list of item schemas (one per position).
        if 'items_schema' in schema:
            for s in schema['items_schema']:
                traverse_schema(s, context)
    elif schema_type == 'dict':
        if 'keys_schema' in schema:
            traverse_schema(schema['keys_schema'], context)
        if 'values_schema' in schema:
            traverse_schema(schema['values_schema'], context)
    elif schema_type == 'union':
        # NOTE(review): `iter_union_choices` is assumed to yield the choice schemas themselves
        # (unwrapping any labelled choices) — confirm against pydantic_core.
        for choice in iter_union_choices(schema):
            traverse_schema(choice, context)
    elif schema_type == 'tagged-union':
        # Choices are stored in a mapping; only the values are schemas.
        for v in schema['choices'].values():
            traverse_schema(v, context)
    elif schema_type == 'chain':
        for step in schema['steps']:
            traverse_schema(step, context)
    elif schema_type == 'lax-or-strict':
        traverse_schema(schema['lax_schema'], context)
        traverse_schema(schema['strict_schema'], context)
    elif schema_type == 'json-or-python':
        traverse_schema(schema['json_schema'], context)
        traverse_schema(schema['python_schema'], context)
    elif schema_type in {'model-fields', 'typed-dict'}:
        if 'extras_schema' in schema:
            traverse_schema(schema['extras_schema'], context)
        if 'computed_fields' in schema:
            for s in schema['computed_fields']:
                traverse_schema(s, context)
        # Fields are a mapping here; each value is a field schema.
        for s in schema['fields'].values():
            traverse_schema(s, context)
    elif schema_type == 'dataclass-args':
        if 'computed_fields' in schema:
            for s in schema['computed_fields']:
                traverse_schema(s, context)
        # Unlike `'model-fields'`/`'typed-dict'`, fields are a list here.
        for s in schema['fields']:
            traverse_schema(s, context)
    elif schema_type == 'arguments':
        # Each argument entry wraps its schema under the `'schema'` key.
        for s in schema['arguments_schema']:
            traverse_schema(s['schema'], context)
        if 'var_args_schema' in schema:
            traverse_schema(schema['var_args_schema'], context)
        if 'var_kwargs_schema' in schema:
            traverse_schema(schema['var_kwargs_schema'], context)
    elif schema_type == 'arguments-v3':
        for s in schema['arguments_schema']:
            traverse_schema(s['schema'], context)
    elif schema_type == 'call':
        # Here `'arguments_schema'` is a single schema, not a list.
        traverse_schema(schema['arguments_schema'], context)
        if 'return_schema' in schema:
            traverse_schema(schema['return_schema'], context)
    elif schema_type == 'computed-field':
        traverse_schema(schema['return_schema'], context)
    elif schema_type == 'function-before':
        if 'schema' in schema:
            traverse_schema(schema['schema'], context)
        if 'json_schema_input_schema' in schema:
            traverse_schema(schema['json_schema_input_schema'], context)
    elif schema_type == 'function-plain':
        # TODO duplicate schema types for serializers and validators, needs to be deduplicated.
        if 'return_schema' in schema:
            traverse_schema(schema['return_schema'], context)
        if 'json_schema_input_schema' in schema:
            traverse_schema(schema['json_schema_input_schema'], context)
    elif schema_type == 'function-wrap':
        # TODO duplicate schema types for serializers and validators, needs to be deduplicated.
        if 'return_schema' in schema:
            traverse_schema(schema['return_schema'], context)
        if 'schema' in schema:
            traverse_schema(schema['schema'], context)
        if 'json_schema_input_schema' in schema:
            traverse_schema(schema['json_schema_input_schema'], context)
    else:
        # Any other schema type: only a possibly wrapped `'schema'` needs to be visited.
        if 'schema' in schema:
            traverse_schema(schema['schema'], context)

    # Common to every schema type (except `'definition-ref'`, handled above):
    if 'serialization' in schema:
        traverse_schema(schema['serialization'], context)
    traverse_metadata(schema, context)
193
194
def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult:
    """Traverse a core schema (and the definitions it references) and collect what schema cleaning needs.

    While traversing, each `'definition-ref'` schema is:

    - Validated: its reference must resolve to an entry of `definitions`; otherwise a
      `MissingDefinitionError` exception is raised.
    - Recorded: the reference is stored in the result. If the reference is encountered exactly
      once, the `'definition-ref'` schema itself is stored as the value so it can later be
      inlined (i.e. replaced by the definition it points to). If it is encountered more than
      once, the value is `None`.

    Schemas whose discriminator application was deferred are collected as well.

    Args:
        schema: The core schema to traverse.
        definitions: The available definitions, keyed by reference.

    Returns:
        The collected references and the deferred discriminator schemas.
    """
    gather_context = GatherContext(definitions)
    traverse_schema(schema, gather_context)
    return GatherResult(
        collected_references=gather_context.collected_references,
        deferred_discriminator_schemas=gather_context.deferred_discriminator_schemas,
    )