Coverage for /pythoncovmergedfiles/medio/medio/src/pydantic/pydantic/_internal/_core_utils.py: 49%
226 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-27 07:38 +0000
# TODO: Should we move WalkCoreSchema into pydantic_core proper?
3from __future__ import annotations
5from typing import Any, Callable, Union, cast
7from pydantic_core import CoreSchema, CoreSchemaType, core_schema
8from typing_extensions import TypeGuard, get_args
10from . import _repr
# Function-validator schemas that wrap an inner ``schema`` (the before / after / wrap modes).
FunctionSchemaWithInnerSchema = Union[
    core_schema.AfterValidatorFunctionSchema,
    core_schema.BeforeValidatorFunctionSchema,
    core_schema.WrapValidatorFunctionSchema,
]

# Every function-validator schema; the "plain" mode has no inner schema.
# ``Union`` flattens nested unions, so this has the same ordered members as listing all four directly.
AnyFunctionSchema = Union[
    FunctionSchemaWithInnerSchema,
    core_schema.PlainValidatorFunctionSchema,
]

# ``type`` tags of field entries (typed-dict / dataclass fields), which are not standalone core schemas.
_CORE_SCHEMA_FIELD_TYPES = {'dataclass-field', 'typed-dict-field'}
def is_typed_dict_field(
    schema: CoreSchema | core_schema.TypedDictField | core_schema.DataclassField,
) -> TypeGuard[core_schema.TypedDictField]:
    """Narrow ``schema`` to a ``TypedDictField`` when its ``type`` tag is ``'typed-dict-field'``."""
    tag = schema['type']
    return tag == 'typed-dict-field'
def is_dataclass_field(
    schema: CoreSchema | core_schema.TypedDictField | core_schema.DataclassField,
) -> TypeGuard[core_schema.DataclassField]:
    """Narrow ``schema`` to a ``DataclassField`` when its ``type`` tag is ``'dataclass-field'``."""
    tag = schema['type']
    return tag == 'dataclass-field'
def is_core_schema(
    schema: CoreSchema | core_schema.TypedDictField | core_schema.DataclassField,
) -> TypeGuard[CoreSchema]:
    """Return True unless ``schema`` is a field entry (typed-dict or dataclass field) rather than a core schema."""
    field_tags = _CORE_SCHEMA_FIELD_TYPES
    return schema['type'] not in field_tags
def is_function_with_inner_schema(
    schema: CoreSchema | core_schema.TypedDictField,
) -> TypeGuard[FunctionSchemaWithInnerSchema]:
    """Return True for the function-validator schemas that wrap an inner ``schema`` (before / after / wrap)."""
    if not is_core_schema(schema):
        return False
    return schema['type'] in {'function-before', 'function-after', 'function-wrap'}
def is_list_like_schema_with_items_schema(
    schema: CoreSchema,
) -> TypeGuard[
    core_schema.ListSchema | core_schema.TupleVariableSchema | core_schema.SetSchema | core_schema.FrozenSetSchema
]:
    """Return True for the list-like schema types (list, variable tuple, set, frozenset) checked here."""
    list_like_tags = {'list', 'tuple-variable', 'set', 'frozenset'}
    return schema['type'] in list_like_tags
def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = None) -> str:
    """
    Produce the ref used for ``type_`` by pydantic_core's core schemas.

    The ref has the form ``<module>.<qualname>:<id(origin)>``. If there are type arguments, it is followed
    by ``[<arg_ref>,...]``.

    When ``type_`` has a truthy ``__pydantic_generic_metadata__``, its truthy ``'origin'`` replaces ``type_``
    as the origin. Its truthy ``'args'`` take precedence over ``args_override``.

    The ``args_override`` argument was added so that valid recursive references can be created when
    building generic models, without needing to create a concrete class.
    """
    generic_metadata = getattr(type_, '__pydantic_generic_metadata__', None)
    if generic_metadata:
        origin = generic_metadata['origin'] or type_
        args = generic_metadata['args'] or args_override or ()
    else:
        origin = type_
        args = args_override or ()

    module_name = getattr(origin, '__module__', '<No __module__>')
    qualname = getattr(origin, '__qualname__', f'<No __qualname__: {origin}>')
    type_ref = f'{module_name}.{qualname}:{id(origin)}'

    def _arg_ref(arg: Any) -> str:
        if isinstance(arg, str):
            # String literals are special-cased; wrapping them in a ForwardRef at some point might let us
            # remove this branch.
            return f'{arg}:str-{id(arg)}'
        return f'{_repr.display_as_type(arg)}:{id(arg)}'

    arg_refs = [_arg_ref(arg) for arg in args]
    if not arg_refs:
        return type_ref
    return f'{type_ref}[{",".join(arg_refs)}]'
def consolidate_refs(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
    """
    Walk ``schema`` and replace every occurrence of a ref after the first with a definition-ref to that ref.

    The fundamental assumption is that, whenever two schemas share a ref, the later occurrences can
    safely be replaced.

    Normally schemas with the same ref should not be produced. However, recursive models that refer to
    themselves more than once somewhere in the field hierarchy make it hard to avoid emitting several
    (identical) copies of the same schema under one ref. This function removes those copies, which is
    safe because the "duplicates" describe the same schema.

    There is one case where multiple (different) schemas with the same ref are emitted on purpose:
    recursive generic models. As an implementation detail, they emit a _non_-identical schema deeper in
    the tree with a re-used ref. The intent is that _that_ schema will be replaced with a recursive
    reference once the specific generic parametrization to use can be determined.
    """
    seen_refs: set[str] = set()

    def _dedupe_ref(s: core_schema.CoreSchema) -> core_schema.CoreSchema:
        ref: str | None = s.get('ref')  # type: ignore[assignment]
        if not ref:
            return s
        if ref in seen_refs:
            return {'type': 'definition-ref', 'schema_ref': ref}
        seen_refs.add(ref)
        return s

    # Apply before recursing so that a replaced node is not descended into.
    walker = WalkCoreSchema(_dedupe_ref, apply_before_recurse=True)
    return walker.walk(schema)
def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema.CoreSchema]:
    """
    Map each ref found while walking ``schema`` to the (sub)schema that carries it, ignoring invalid ones.

    A schema counts as invalid when its ``metadata`` is a dict containing an ``'invalid'`` key. Only valid
    definitions are collected. For valid schemas this is the same as collecting every definition, and it
    lets this logic be reused when removing "invalid" definitions. If a ref appears on several valid
    schemas, the one visited last wins.
    """
    valid_definitions: dict[str, core_schema.CoreSchema] = {}

    def _record_valid_refs(s: core_schema.CoreSchema) -> core_schema.CoreSchema:
        ref: str | None = s.get('ref')  # type: ignore[assignment]
        if not ref:
            return s
        metadata = s.get('metadata')
        marked_invalid = isinstance(metadata, dict) and 'invalid' in metadata
        if not marked_invalid:
            valid_definitions[ref] = s
        return s

    WalkCoreSchema(_record_valid_refs).walk(schema)
    return valid_definitions
def remove_unnecessary_invalid_definitions(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
    """
    Drop invalid entries from every ``definitions`` schema when a valid schema with the same ref exists.

    An entry is invalid when its ``metadata`` is a dict containing an ``'invalid'`` key. Invalid entries
    whose ref has no valid counterpart are kept.
    """
    valid_refs = collect_definitions(schema).keys()

    def _is_redundant_invalid(definition: CoreSchema) -> bool:
        metadata = definition.get('metadata')
        if not isinstance(metadata, dict) or 'invalid' not in metadata:
            return False
        return definition['ref'] in valid_refs  # type: ignore

    def _remove_invalid_defs(s: core_schema.CoreSchema) -> core_schema.CoreSchema:
        if s['type'] != 'definitions':
            return s
        new_schema = s.copy()
        kept: list[CoreSchema] = [d for d in s['definitions'] if not _is_redundant_invalid(d)]
        new_schema['definitions'] = kept
        return new_schema

    return WalkCoreSchema(_remove_invalid_defs).walk(schema)
def define_expected_missing_refs(
    schema: core_schema.CoreSchema, allowed_missing_refs: set[str]
) -> core_schema.CoreSchema:
    """
    Add placeholder definitions for allowed refs that no schema inside ``schema`` carries.

    Each placeholder is a ``none_schema`` with that ref and metadata
    ``{'pydantic_debug_missing_ref': True, 'invalid': True}``. The placeholders are attached by wrapping
    ``schema`` in a definitions schema. If nothing is missing, ``schema`` is returned unchanged.
    """
    # No missing refs can be substituted, so there is no need to walk the schema. This is the common
    # case (hit for all non-generic models), so it is worth optimizing for.
    if not allowed_missing_refs:
        return schema

    present_refs = set()

    def _record_refs(s: core_schema.CoreSchema) -> core_schema.CoreSchema:
        ref: str | None = s.get('ref')  # type: ignore[assignment]
        if ref:
            present_refs.add(ref)
        return s

    WalkCoreSchema(_record_refs).walk(schema)

    expected_missing_refs = allowed_missing_refs.difference(present_refs)
    if not expected_missing_refs:
        return schema

    placeholder_definitions: list[core_schema.CoreSchema] = [
        # TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail
        core_schema.none_schema(ref=ref, metadata={'pydantic_debug_missing_ref': True, 'invalid': True})
        for ref in expected_missing_refs
    ]
    return core_schema.definitions_schema(schema, placeholder_definitions)
def collect_invalid_schemas(schema: core_schema.CoreSchema) -> list[core_schema.CoreSchema]:
    """
    Return every (sub)schema of ``schema`` whose ``metadata`` dict has a truthy ``'invalid'`` entry.

    Schemas are returned in the order the walk visits them. ``metadata`` is only inspected when it is a
    dict, matching how ``collect_definitions`` and ``remove_unnecessary_invalid_definitions`` treat it.
    A missing, ``None`` or otherwise non-dict ``metadata`` leaves the schema out of the result instead of
    raising ``AttributeError``.
    """
    invalid_schemas: list[core_schema.CoreSchema] = []

    def _record_invalid(s: core_schema.CoreSchema) -> core_schema.CoreSchema:
        metadata = s.get('metadata')
        # The sibling helpers guard with isinstance(metadata, dict), so do not assume a dict here: calling
        # `.get` on an explicit `None` (or any non-dict) would raise.
        if isinstance(metadata, dict) and metadata.get('invalid'):
            invalid_schemas.append(s)
        return s

    WalkCoreSchema(_record_invalid).walk(schema)
    return invalid_schemas
class WalkCoreSchema:
    """
    Transforms a CoreSchema by recursively calling the provided function on all (nested) fields of type CoreSchema.

    The provided function need not actually modify the schema in any way, but will still be called on all nested
    fields with type CoreSchema. (This can be useful for collecting information about refs, etc.)

    Dispatch is by schema type. A node whose ``type`` is ``'foo-bar'`` is handled by ``handle_foo_bar_schema``
    if that method exists. Otherwise it goes to ``_handle_other_schemas``, which only recurses into a
    ``'schema'`` key. Each node is shallow-copied (``dict.copy``) before it is passed to ``f``.
    """

    def __init__(
        self, f: Callable[[core_schema.CoreSchema], core_schema.CoreSchema], apply_before_recurse: bool = True
    ):
        """
        Args:
            f: Called on every visited schema node; the node is replaced by whatever ``f`` returns.
            apply_before_recurse: If True (the default), ``f`` runs before a node's children are walked, and
                the children walked are those of the schema ``f`` returned. If False, the children are walked
                first and ``f`` receives the node with its already-transformed children.
        """
        self.f = f

        self.apply_before_recurse = apply_before_recurse

        # Resolved once per walker so that `_walk` needs only one dict lookup per node.
        self._schema_type_to_method = self._build_schema_type_to_method()

    def _build_schema_type_to_method(self) -> dict[CoreSchemaType, Callable[[CoreSchema], CoreSchema]]:
        """Map every ``CoreSchemaType`` to its ``handle_<type>_schema`` method, or ``_handle_other_schemas``."""
        mapping: dict[CoreSchemaType, Callable[[CoreSchema], CoreSchema]] = {}
        for key in get_args(CoreSchemaType):
            method_name = f"handle_{key.replace('-', '_')}_schema"
            mapping[key] = getattr(self, method_name, self._handle_other_schemas)
        return mapping

    def walk(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """Walk ``schema`` and return the transformed schema."""
        return self._walk(schema)

    def _walk(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """Copy one node, apply ``f`` before or after its handler (per ``apply_before_recurse``), and return it."""
        schema = schema.copy()
        if self.apply_before_recurse:
            schema = self.f(schema)
        # Dispatch on the type of the (possibly replaced) node, e.g. a 'definition-ref' returned by `f`.
        method = self._schema_type_to_method[schema['type']]
        schema = method(schema)
        if not self.apply_before_recurse:
            schema = self.f(schema)
        return schema

    def _handle_other_schemas(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """Fallback handler: recurse into a nested ``'schema'`` key if there is one."""
        if 'schema' in schema:
            schema['schema'] = self._walk(schema['schema'])  # type: ignore
        return schema

    def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema) -> CoreSchema:
        # Definitions are walked before the inner schema.
        new_definitions = []
        for definition in schema['definitions']:
            updated_definition = self._walk(definition)
            if 'ref' in updated_definition:
                # If the updated definition schema doesn't have a 'ref', it shouldn't go in the definitions
                # This is most likely to happen due to replacing something with a definition reference, in
                # which case it should certainly not go in the definitions list
                new_definitions.append(updated_definition)
        new_inner_schema = self._walk(schema['schema'])

        if not new_definitions and len(schema) == 3:
            # Only 'type', 'schema' and 'definitions' are present, so we'd be returning a "trivial"
            # definitions schema that just wrapped the inner schema: return the inner schema itself.
            return new_inner_schema

        new_schema = schema.copy()
        new_schema['schema'] = new_inner_schema
        new_schema['definitions'] = new_definitions
        return new_schema

    def handle_list_schema(self, schema: core_schema.ListSchema) -> CoreSchema:
        if 'items_schema' in schema:
            schema['items_schema'] = self._walk(schema['items_schema'])
        return schema

    def handle_set_schema(self, schema: core_schema.SetSchema) -> CoreSchema:
        if 'items_schema' in schema:
            schema['items_schema'] = self._walk(schema['items_schema'])
        return schema

    def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema) -> CoreSchema:
        if 'items_schema' in schema:
            schema['items_schema'] = self._walk(schema['items_schema'])
        return schema

    def handle_generator_schema(self, schema: core_schema.GeneratorSchema) -> CoreSchema:
        if 'items_schema' in schema:
            schema['items_schema'] = self._walk(schema['items_schema'])
        return schema

    def handle_tuple_variable_schema(
        self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema
    ) -> CoreSchema:
        schema = cast(core_schema.TupleVariableSchema, schema)
        if 'items_schema' in schema:
            schema['items_schema'] = self._walk(schema['items_schema'])
        return schema

    def handle_tuple_positional_schema(
        self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema
    ) -> CoreSchema:
        schema = cast(core_schema.TuplePositionalSchema, schema)
        schema['items_schema'] = [self._walk(v) for v in schema['items_schema']]
        if 'extra_schema' in schema:
            schema['extra_schema'] = self._walk(schema['extra_schema'])
        return schema

    def handle_dict_schema(self, schema: core_schema.DictSchema) -> CoreSchema:
        if 'keys_schema' in schema:
            schema['keys_schema'] = self._walk(schema['keys_schema'])
        if 'values_schema' in schema:
            schema['values_schema'] = self._walk(schema['values_schema'])
        return schema

    def handle_function_schema(
        self,
        schema: AnyFunctionSchema,
    ) -> CoreSchema:
        # 'function-plain' schemas have no inner schema to recurse into.
        if not is_function_with_inner_schema(schema):
            return schema
        schema['schema'] = self._walk(schema['schema'])
        return schema

    # Dispatch looks methods up as `handle_<type>_schema`. Without these aliases, `handle_function_schema`
    # would never be reached for the 'function-*' schema types.
    handle_function_before_schema = handle_function_schema
    handle_function_after_schema = handle_function_schema
    handle_function_wrap_schema = handle_function_schema
    handle_function_plain_schema = handle_function_schema

    def handle_union_schema(self, schema: core_schema.UnionSchema) -> CoreSchema:
        schema['choices'] = [self._walk(v) for v in schema['choices']]
        return schema

    def handle_tagged_union_schema(self, schema: core_schema.TaggedUnionSchema) -> CoreSchema:
        new_choices: dict[str | int, str | int | CoreSchema] = {}
        for k, v in schema['choices'].items():
            # str/int choice values are kept as-is; only schema values are walked.
            new_choices[k] = v if isinstance(v, (str, int)) else self._walk(v)
        schema['choices'] = new_choices
        return schema

    def handle_chain_schema(self, schema: core_schema.ChainSchema) -> CoreSchema:
        schema['steps'] = [self._walk(v) for v in schema['steps']]
        return schema

    def handle_lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema) -> CoreSchema:
        schema['lax_schema'] = self._walk(schema['lax_schema'])
        schema['strict_schema'] = self._walk(schema['strict_schema'])
        return schema

    def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> CoreSchema:
        if 'extra_validator' in schema:
            schema['extra_validator'] = self._walk(schema['extra_validator'])
        replaced_fields: dict[str, core_schema.TypedDictField] = {}
        for k, v in schema['fields'].items():
            # Field entries are not core schemas themselves; copy them and walk their nested 'schema'.
            replaced_field = v.copy()
            replaced_field['schema'] = self._walk(v['schema'])
            replaced_fields[k] = replaced_field
        schema['fields'] = replaced_fields
        return schema

    def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema) -> CoreSchema:
        replaced_fields: list[core_schema.DataclassField] = []
        for field in schema['fields']:
            replaced_field = field.copy()
            replaced_field['schema'] = self._walk(field['schema'])
            replaced_fields.append(replaced_field)
        schema['fields'] = replaced_fields
        return schema

    def handle_arguments_schema(self, schema: core_schema.ArgumentsSchema) -> CoreSchema:
        replaced_arguments_schema = []
        for param in schema['arguments_schema']:
            replaced_param = param.copy()
            replaced_param['schema'] = self._walk(param['schema'])
            replaced_arguments_schema.append(replaced_param)
        schema['arguments_schema'] = replaced_arguments_schema
        if 'var_args_schema' in schema:
            schema['var_args_schema'] = self._walk(schema['var_args_schema'])
        if 'var_kwargs_schema' in schema:
            schema['var_kwargs_schema'] = self._walk(schema['var_kwargs_schema'])
        return schema

    def handle_call_schema(self, schema: core_schema.CallSchema) -> CoreSchema:
        schema['arguments_schema'] = self._walk(schema['arguments_schema'])
        if 'return_schema' in schema:
            schema['return_schema'] = self._walk(schema['return_schema'])
        return schema