1# pyright: reportTypedDictNotRequiredAccess=false, reportGeneralTypeIssues=false, reportArgumentType=false, reportAttributeAccessIssue=false 
    2from __future__ import annotations 
    3 
    4from dataclasses import dataclass, field 
    5from typing import TypedDict 
    6 
    7from pydantic_core.core_schema import ComputedField, CoreSchema, DefinitionReferenceSchema, SerSchema 
    8from typing_extensions import TypeAlias 
    9 
# Union of the schema kinds accepted by the traversal functions of this module:
# core schemas, serialization schemas and computed fields.
AllSchemas: TypeAlias = 'CoreSchema | SerSchema | ComputedField'
    11 
    12 
    13class GatherResult(TypedDict): 
    14    """Schema traversing result.""" 
    15 
    16    collected_references: dict[str, DefinitionReferenceSchema | None] 
    17    """The collected definition references. 
    18 
    19    If a definition reference schema can be inlined, it means that there is 
    20    only one in the whole core schema. As such, it is stored as the value. 
    21    Otherwise, the value is set to `None`. 
    22    """ 
    23 
    24    deferred_discriminator_schemas: list[CoreSchema] 
    25    """The list of core schemas having the discriminator application deferred.""" 
    26 
    27 
    28class MissingDefinitionError(LookupError): 
    29    """A reference was pointing to a non-existing core schema.""" 
    30 
    31    def __init__(self, schema_reference: str, /) -> None: 
    32        self.schema_reference = schema_reference 
    33 
    34 
    35@dataclass 
    36class GatherContext: 
    37    """The current context used during core schema traversing. 
    38 
    39    Context instances should only be used during schema traversing. 
    40    """ 
    41 
    42    definitions: dict[str, CoreSchema] 
    43    """The available definitions.""" 
    44 
    45    deferred_discriminator_schemas: list[CoreSchema] = field(init=False, default_factory=list) 
    46    """The list of core schemas having the discriminator application deferred. 
    47 
    48    Internally, these core schemas have a specific key set in the core metadata dict. 
    49    """ 
    50 
    51    collected_references: dict[str, DefinitionReferenceSchema | None] = field(init=False, default_factory=dict) 
    52    """The collected definition references. 
    53 
    54    If a definition reference schema can be inlined, it means that there is 
    55    only one in the whole core schema. As such, it is stored as the value. 
    56    Otherwise, the value is set to `None`. 
    57 
    58    During schema traversing, definition reference schemas can be added as candidates, or removed 
    59    (by setting the value to `None`). 
    60    """ 
    61 
    62 
    63def traverse_metadata(schema: AllSchemas, ctx: GatherContext) -> None: 
    64    meta = schema.get('metadata') 
    65    if meta is not None and 'pydantic_internal_union_discriminator' in meta: 
    66        ctx.deferred_discriminator_schemas.append(schema)  # pyright: ignore[reportArgumentType] 
    67 
    68 
    69def traverse_definition_ref(def_ref_schema: DefinitionReferenceSchema, ctx: GatherContext) -> None: 
    70    schema_ref = def_ref_schema['schema_ref'] 
    71 
    72    if schema_ref not in ctx.collected_references: 
    73        definition = ctx.definitions.get(schema_ref) 
    74        if definition is None: 
    75            raise MissingDefinitionError(schema_ref) 
    76 
    77        # The `'definition-ref'` schema was only encountered once, make it 
    78        # a candidate to be inlined: 
    79        ctx.collected_references[schema_ref] = def_ref_schema 
    80        traverse_schema(definition, ctx) 
    81        if 'serialization' in def_ref_schema: 
    82            traverse_schema(def_ref_schema['serialization'], ctx) 
    83        traverse_metadata(def_ref_schema, ctx) 
    84    else: 
    85        # The `'definition-ref'` schema was already encountered, meaning 
    86        # the previously encountered schema (and this one) can't be inlined: 
    87        ctx.collected_references[schema_ref] = None 
    88 
    89 
    90def traverse_schema(schema: AllSchemas, context: GatherContext) -> None: 
    91    # TODO When we drop 3.9, use a match statement to get better type checking and remove 
    92    # file-level type ignore. 
    93    # (the `'type'` could also be fetched in every `if/elif` statement, but this alters performance). 
    94    schema_type = schema['type'] 
    95 
    96    if schema_type == 'definition-ref': 
    97        traverse_definition_ref(schema, context) 
    98        # `traverse_definition_ref` handles the possible serialization and metadata schemas: 
    99        return 
    100    elif schema_type == 'definitions': 
    101        traverse_schema(schema['schema'], context) 
    102        for definition in schema['definitions']: 
    103            traverse_schema(definition, context) 
    104    elif schema_type in {'list', 'set', 'frozenset', 'generator'}: 
    105        if 'items_schema' in schema: 
    106            traverse_schema(schema['items_schema'], context) 
    107    elif schema_type == 'tuple': 
    108        if 'items_schema' in schema: 
    109            for s in schema['items_schema']: 
    110                traverse_schema(s, context) 
    111    elif schema_type == 'dict': 
    112        if 'keys_schema' in schema: 
    113            traverse_schema(schema['keys_schema'], context) 
    114        if 'values_schema' in schema: 
    115            traverse_schema(schema['values_schema'], context) 
    116    elif schema_type == 'union': 
    117        for choice in schema['choices']: 
    118            if isinstance(choice, tuple): 
    119                traverse_schema(choice[0], context) 
    120            else: 
    121                traverse_schema(choice, context) 
    122    elif schema_type == 'tagged-union': 
    123        for v in schema['choices'].values(): 
    124            traverse_schema(v, context) 
    125    elif schema_type == 'chain': 
    126        for step in schema['steps']: 
    127            traverse_schema(step, context) 
    128    elif schema_type == 'lax-or-strict': 
    129        traverse_schema(schema['lax_schema'], context) 
    130        traverse_schema(schema['strict_schema'], context) 
    131    elif schema_type == 'json-or-python': 
    132        traverse_schema(schema['json_schema'], context) 
    133        traverse_schema(schema['python_schema'], context) 
    134    elif schema_type in {'model-fields', 'typed-dict'}: 
    135        if 'extras_schema' in schema: 
    136            traverse_schema(schema['extras_schema'], context) 
    137        if 'computed_fields' in schema: 
    138            for s in schema['computed_fields']: 
    139                traverse_schema(s, context) 
    140        for s in schema['fields'].values(): 
    141            traverse_schema(s, context) 
    142    elif schema_type == 'dataclass-args': 
    143        if 'computed_fields' in schema: 
    144            for s in schema['computed_fields']: 
    145                traverse_schema(s, context) 
    146        for s in schema['fields']: 
    147            traverse_schema(s, context) 
    148    elif schema_type == 'arguments': 
    149        for s in schema['arguments_schema']: 
    150            traverse_schema(s['schema'], context) 
    151        if 'var_args_schema' in schema: 
    152            traverse_schema(schema['var_args_schema'], context) 
    153        if 'var_kwargs_schema' in schema: 
    154            traverse_schema(schema['var_kwargs_schema'], context) 
    155    elif schema_type == 'arguments-v3': 
    156        for s in schema['arguments_schema']: 
    157            traverse_schema(s['schema'], context) 
    158    elif schema_type == 'call': 
    159        traverse_schema(schema['arguments_schema'], context) 
    160        if 'return_schema' in schema: 
    161            traverse_schema(schema['return_schema'], context) 
    162    elif schema_type == 'computed-field': 
    163        traverse_schema(schema['return_schema'], context) 
    164    elif schema_type == 'function-before': 
    165        if 'schema' in schema: 
    166            traverse_schema(schema['schema'], context) 
    167        if 'json_schema_input_schema' in schema: 
    168            traverse_schema(schema['json_schema_input_schema'], context) 
    169    elif schema_type == 'function-plain': 
    170        # TODO duplicate schema types for serializers and validators, needs to be deduplicated. 
    171        if 'return_schema' in schema: 
    172            traverse_schema(schema['return_schema'], context) 
    173        if 'json_schema_input_schema' in schema: 
    174            traverse_schema(schema['json_schema_input_schema'], context) 
    175    elif schema_type == 'function-wrap': 
    176        # TODO duplicate schema types for serializers and validators, needs to be deduplicated. 
    177        if 'return_schema' in schema: 
    178            traverse_schema(schema['return_schema'], context) 
    179        if 'schema' in schema: 
    180            traverse_schema(schema['schema'], context) 
    181        if 'json_schema_input_schema' in schema: 
    182            traverse_schema(schema['json_schema_input_schema'], context) 
    183    else: 
    184        if 'schema' in schema: 
    185            traverse_schema(schema['schema'], context) 
    186 
    187    if 'serialization' in schema: 
    188        traverse_schema(schema['serialization'], context) 
    189    traverse_metadata(schema, context) 
    190 
    191 
    192def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult: 
    193    """Traverse the core schema and definitions and return the necessary information for schema cleaning. 
    194 
    195    During the core schema traversing, any `'definition-ref'` schema is: 
    196 
    197    - Validated: the reference must point to an existing definition. If this is not the case, a 
    198      `MissingDefinitionError` exception is raised. 
    199    - Stored in the context: the actual reference is stored in the context. Depending on whether 
    200      the `'definition-ref'` schema is encountered more that once, the schema itself is also 
    201      saved in the context to be inlined (i.e. replaced by the definition it points to). 
    202    """ 
    203    context = GatherContext(definitions) 
    204    traverse_schema(schema, context) 
    205 
    206    return { 
    207        'collected_references': context.collected_references, 
    208        'deferred_discriminator_schemas': context.deferred_discriminator_schemas, 
    209    }