1# pyright: reportTypedDictNotRequiredAccess=false, reportGeneralTypeIssues=false, reportArgumentType=false, reportAttributeAccessIssue=false
2from __future__ import annotations
3
4from dataclasses import dataclass, field
5from typing import TypedDict
6
7from pydantic_core.core_schema import ComputedField, CoreSchema, DefinitionReferenceSchema, SerSchema
8from typing_extensions import TypeAlias
9
# Union of every kind of schema dict the traversal functions in this module accept:
# core (validation) schemas, serialization schemas and computed field schemas.
AllSchemas: TypeAlias = 'CoreSchema | SerSchema | ComputedField'
11
12
class GatherResult(TypedDict):
    """Information collected while traversing a core schema, used for schema cleaning."""

    # Every definition reference met during traversal, keyed by reference name.
    # When a reference occurs only once in the whole core schema, the value is that
    # single `'definition-ref'` schema (a candidate for inlining); otherwise it is `None`.
    collected_references: dict[str, DefinitionReferenceSchema | None]

    # Core schemas for which applying the discriminator was deferred.
    deferred_discriminator_schemas: list[CoreSchema]
26
27
class MissingDefinitionError(LookupError):
    """A reference was pointing to a non-existing core schema.

    Attributes:
        schema_reference: The reference that could not be resolved.
    """

    def __init__(self, schema_reference: str, /) -> None:
        # Forward the reference to `LookupError.__init__` so that `args` is populated:
        # this gives a meaningful `str()`/`repr()` and lets the exception round-trip
        # through `pickle`, which re-creates exceptions from their `args`.
        super().__init__(schema_reference)
        self.schema_reference = schema_reference
33
34
@dataclass
class GatherContext:
    """Mutable state shared by the traversal functions while walking a core schema.

    An instance is meant to live only for the duration of a single traversal.
    """

    # The definitions that `'definition-ref'` schemas may point to, keyed by reference.
    definitions: dict[str, CoreSchema]

    # Core schemas whose discriminator application was deferred. Such schemas are
    # recognized by a specific key set in their core metadata dict.
    deferred_discriminator_schemas: list[CoreSchema] = field(default_factory=list, init=False)

    # Definition references collected so far. A reference seen once maps to its
    # `'definition-ref'` schema (a candidate to be inlined, since it is the only one
    # in the whole core schema); a reference seen again is demoted to `None`.
    # Entries are added as candidates or demoted while the traversal progresses.
    collected_references: dict[str, DefinitionReferenceSchema | None] = field(default_factory=dict, init=False)
61
62
def traverse_metadata(schema: AllSchemas, ctx: GatherContext) -> None:
    """Register `schema` in `ctx` when its metadata flags a deferred discriminator.

    Schemas without metadata, or whose metadata lacks the
    `'pydantic_internal_union_discriminator'` key, are ignored.
    """
    metadata = schema.get('metadata')
    if metadata is None:
        return
    if 'pydantic_internal_union_discriminator' not in metadata:
        return
    ctx.deferred_discriminator_schemas.append(schema)  # pyright: ignore[reportArgumentType]
67
68
def traverse_definition_ref(def_ref_schema: DefinitionReferenceSchema, ctx: GatherContext) -> None:
    """Record a `'definition-ref'` schema in `ctx` and traverse what it points to.

    The referenced definition is traversed only the first time its reference is met.
    Because the reference is registered before that traversal, recursive references
    end up in the "already seen" branch instead of recursing forever.

    The `'definition-ref'` schema's own `'serialization'` schema and metadata belong
    to this particular occurrence, so they are traversed on every encounter.

    Args:
        def_ref_schema: The `'definition-ref'` schema being visited.
        ctx: The traversal context, updated in place.

    Raises:
        MissingDefinitionError: If the reference is not present in `ctx.definitions`.
    """
    schema_ref = def_ref_schema['schema_ref']

    if schema_ref not in ctx.collected_references:
        definition = ctx.definitions.get(schema_ref)
        if definition is None:
            raise MissingDefinitionError(schema_ref)

        # The `'definition-ref'` schema was only encountered once, make it
        # a candidate to be inlined. Registering it *before* traversing the
        # definition makes self-references hit the `else` branch below.
        ctx.collected_references[schema_ref] = def_ref_schema
        traverse_schema(definition, ctx)
    else:
        # The `'definition-ref'` schema was already encountered, meaning
        # the previously encountered schema (and this one) can't be inlined:
        ctx.collected_references[schema_ref] = None

    # Per-occurrence sub-schemas: a repeated reference may carry its own serialization
    # schema (which can itself contain references that must be collected) or a
    # deferred discriminator marker, so visit them regardless of the branch above.
    if 'serialization' in def_ref_schema:
        traverse_schema(def_ref_schema['serialization'], ctx)
    traverse_metadata(def_ref_schema, ctx)
88
89
def traverse_schema(schema: AllSchemas, context: GatherContext) -> None:
    """Recursively traverse `schema`, updating `context` in place.

    The sub-schemas to visit are looked up according to `schema['type']`. Any type not
    listed explicitly falls back to visiting its `'schema'` key, if present. After the
    type-specific handling, the schema's own `'serialization'` schema (if any) and its
    metadata are visited. `'definition-ref'` schemas are delegated to
    `traverse_definition_ref` and return early.

    NOTE: the visiting order matters. The first `'definition-ref'` schema met for a
    reference is the one stored as the inline candidate, and deferred discriminator
    schemas are appended in visiting order.

    Args:
        schema: The schema to traverse (core, serialization or computed field schema).
        context: The traversal context collecting references and deferred discriminator schemas.

    Raises:
        MissingDefinitionError: Propagated from `traverse_definition_ref` when a
            `'definition-ref'` schema points to a missing definition.
    """
    # TODO When we drop 3.9, use a match statement to get better type checking and remove
    # file-level type ignore.
    # (the `'type'` could also be fetched in every `if/elif` statement, but this alters performance).
    schema_type = schema['type']

    if schema_type == 'definition-ref':
        traverse_definition_ref(schema, context)
        # `traverse_definition_ref` handles the possible serialization and metadata schemas:
        return
    elif schema_type == 'definitions':
        # Visit the wrapped schema first, then every attached definition.
        traverse_schema(schema['schema'], context)
        for definition in schema['definitions']:
            traverse_schema(definition, context)
    elif schema_type in {'list', 'set', 'frozenset', 'generator'}:
        if 'items_schema' in schema:
            traverse_schema(schema['items_schema'], context)
    elif schema_type == 'tuple':
        # For tuples, `'items_schema'` is a sequence of schemas, one per position.
        if 'items_schema' in schema:
            for s in schema['items_schema']:
                traverse_schema(s, context)
    elif schema_type == 'dict':
        if 'keys_schema' in schema:
            traverse_schema(schema['keys_schema'], context)
        if 'values_schema' in schema:
            traverse_schema(schema['values_schema'], context)
    elif schema_type == 'union':
        # A union choice is either a schema or a tuple whose first item is the schema.
        for choice in schema['choices']:
            if isinstance(choice, tuple):
                traverse_schema(choice[0], context)
            else:
                traverse_schema(choice, context)
    elif schema_type == 'tagged-union':
        # Only the choice schemas (the mapping values) need traversing.
        for v in schema['choices'].values():
            traverse_schema(v, context)
    elif schema_type == 'chain':
        for step in schema['steps']:
            traverse_schema(step, context)
    elif schema_type == 'lax-or-strict':
        traverse_schema(schema['lax_schema'], context)
        traverse_schema(schema['strict_schema'], context)
    elif schema_type == 'json-or-python':
        traverse_schema(schema['json_schema'], context)
        traverse_schema(schema['python_schema'], context)
    elif schema_type in {'model-fields', 'typed-dict'}:
        # Here `'fields'` is a mapping; its values are the field schemas.
        if 'extras_schema' in schema:
            traverse_schema(schema['extras_schema'], context)
        if 'computed_fields' in schema:
            for s in schema['computed_fields']:
                traverse_schema(s, context)
        for s in schema['fields'].values():
            traverse_schema(s, context)
    elif schema_type == 'dataclass-args':
        # Here `'fields'` is a sequence of field schemas.
        if 'computed_fields' in schema:
            for s in schema['computed_fields']:
                traverse_schema(s, context)
        for s in schema['fields']:
            traverse_schema(s, context)
    elif schema_type == 'arguments':
        # Each parameter entry wraps its schema under a `'schema'` key.
        for s in schema['arguments_schema']:
            traverse_schema(s['schema'], context)
        if 'var_args_schema' in schema:
            traverse_schema(schema['var_args_schema'], context)
        if 'var_kwargs_schema' in schema:
            traverse_schema(schema['var_kwargs_schema'], context)
    elif schema_type == 'arguments-v3':
        for s in schema['arguments_schema']:
            traverse_schema(s['schema'], context)
    elif schema_type == 'call':
        traverse_schema(schema['arguments_schema'], context)
        if 'return_schema' in schema:
            traverse_schema(schema['return_schema'], context)
    elif schema_type == 'computed-field':
        traverse_schema(schema['return_schema'], context)
    elif schema_type == 'function-before':
        if 'schema' in schema:
            traverse_schema(schema['schema'], context)
        if 'json_schema_input_schema' in schema:
            traverse_schema(schema['json_schema_input_schema'], context)
    elif schema_type == 'function-plain':
        # TODO duplicate schema types for serializers and validators, needs to be deduplicated.
        if 'return_schema' in schema:
            traverse_schema(schema['return_schema'], context)
        if 'json_schema_input_schema' in schema:
            traverse_schema(schema['json_schema_input_schema'], context)
    elif schema_type == 'function-wrap':
        # TODO duplicate schema types for serializers and validators, needs to be deduplicated.
        if 'return_schema' in schema:
            traverse_schema(schema['return_schema'], context)
        if 'schema' in schema:
            traverse_schema(schema['schema'], context)
        if 'json_schema_input_schema' in schema:
            traverse_schema(schema['json_schema_input_schema'], context)
    else:
        # Any other schema type: visit its inner `'schema'`, if it has one.
        if 'schema' in schema:
            traverse_schema(schema['schema'], context)

    # Common to every schema type handled above (not reached for `'definition-ref'`).
    if 'serialization' in schema:
        traverse_schema(schema['serialization'], context)
    traverse_metadata(schema, context)
190
191
def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult:
    """Walk a core schema and collect the information needed to clean it.

    Every `'definition-ref'` schema met along the way is:

    - Validated: its reference must exist in `definitions`, otherwise a
      `MissingDefinitionError` is raised.
    - Recorded: the reference is stored in the result. When the reference is met
      only once, the `'definition-ref'` schema itself is kept as well, so that it
      can later be inlined (i.e. replaced by the definition it points to). When it
      is met more than once, the stored value is `None`.

    Schemas whose discriminator application was deferred are collected as well.

    Args:
        schema: The core schema to traverse.
        definitions: The available definitions, keyed by reference.

    Returns:
        The collected references and the deferred discriminator schemas.
    """
    traversal_state = GatherContext(definitions)
    traverse_schema(schema, traversal_state)
    return GatherResult(
        collected_references=traversal_state.collected_references,
        deferred_discriminator_schemas=traversal_state.deferred_discriminator_schemas,
    )