1from __future__ import annotations
2
3import dataclasses
4from inspect import Parameter, Signature, signature
5from typing import TYPE_CHECKING, Any, Callable
6
7from pydantic_core import PydanticUndefined
8
9from ._utils import is_valid_identifier
10
11if TYPE_CHECKING:
12 from ..config import ExtraValues
13 from ..fields import FieldInfo
14
15
16# Copied over from stdlib dataclasses
17class _HAS_DEFAULT_FACTORY_CLASS:
18 def __repr__(self):
19 return '<factory>'
20
21
22_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
23
24
def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str:
    """Pick the name under which a field should appear in a generated signature.

    The candidates are tried in priority order: `alias`, then `validation_alias`.
    The first one that is a string and a valid identifier wins. When neither
    qualifies, the plain field name is used.

    Args:
        field_name: The name of the field.
        field_info: The corresponding FieldInfo object.

    Returns:
        The name to use for the field's signature parameter.
    """
    for candidate in (field_info.alias, field_info.validation_alias):
        # Non-string aliases (e.g. alias paths/choices) can never be parameter names.
        if isinstance(candidate, str) and is_valid_identifier(candidate):
            return candidate
    return field_name
44
45
def _process_param_defaults(param: Parameter) -> Parameter:
    """Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.

    Parameters whose default is not a `FieldInfo` are returned unchanged. Otherwise the
    returned parameter:

    * uses the real `typing.Any` object instead of the string `'Any'` as its annotation,
    * is renamed via `_field_name_for_signature` (alias / validation alias when usable),
    * has as default the field's own default. It falls back to `Signature.empty` when the
      field has neither a default nor a default factory. When the field only has a default
      factory, it uses the stdlib dataclasses `<factory>` marker.

    Args:
        param: The parameter.

    Returns:
        The custom processed parameter.
    """
    from ..fields import FieldInfo

    field_info = param.default
    if not isinstance(field_info, FieldInfo):
        return param

    # inspect does "clever" things to show annotations as strings because we have
    # `from __future__ import annotations` in main, we don't want that
    annotation = Any if param.annotation == 'Any' else param.annotation

    # Replace the field default
    default = field_info.default
    if default is PydanticUndefined:
        factory = field_info.default_factory
        # `None` means "no factory"; this matches the `default_factory is not None` check
        # in `_generate_signature_parameters`. `PydanticUndefined` is accepted as well, so
        # the previous behaviour is kept for that value.
        if factory is None or factory is PydanticUndefined:
            default = Signature.empty
        else:
            # this is used by dataclasses to indicate a factory exists:
            default = dataclasses._HAS_DEFAULT_FACTORY  # type: ignore
    return param.replace(
        annotation=annotation, name=_field_name_for_signature(param.name, field_info), default=default
    )
78
79
def _generate_signature_parameters(  # noqa: C901 (ignore complexity, could use a refactor)
    init: Callable[..., None],
    fields: dict[str, FieldInfo],
    validate_by_name: bool,
    extra: ExtraValues | None,
) -> dict[str, Parameter]:
    """Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass.

    How the mapping is built:

    * Parameters declared by `init` (after the first one, `self`) are kept in order.
      Those matching a field are renamed via `_field_name_for_signature`, and fields
      with `init=False` are dropped.
    * If `init` accepts `**kwargs`, every field not already represented is appended as
      a keyword-only parameter.
    * The `**kwargs` parameter itself is only kept in two cases: some field's name
      cannot be expressed as an identifier (and `validate_by_name` is false), or
      `extra == 'allow'`. It is then renamed so it cannot clash with a field name or
      with any parameter already collected.

    Args:
        init: The `__init__` whose declared parameters seed the signature.
        fields: The model fields, keyed by field name.
        validate_by_name: When true, a field whose alias is not a valid identifier is
            exposed under its field name instead.
        extra: The `extra` config value; `'allow'` keeps the `**kwargs` parameter.

    Returns:
        An ordered mapping of parameter name to `Parameter`.
    """
    from itertools import islice

    present_params = signature(init).parameters.values()
    merged_params: dict[str, Parameter] = {}
    var_kw = None
    use_var_kw = False

    for param in islice(present_params, 1, None):  # skip self arg
        field_info = fields.get(param.name)
        if field_info:
            # exclude params with init=False
            if getattr(field_info, 'init', True) is False:
                continue
            param = param.replace(name=_field_name_for_signature(param.name, field_info))
        # inspect does "clever" things to show annotations as strings because we have
        # `from __future__ import annotations` in main, we don't want that
        if param.annotation == 'Any':
            param = param.replace(annotation=Any)
        if param.kind is param.VAR_KEYWORD:
            var_kw = param
            continue
        merged_params[param.name] = param

    if var_kw:  # if custom init has no var_kw, fields which are not declared in it cannot be passed through
        allow_names = validate_by_name
        for field_name, field in fields.items():
            # when alias is a str it should be used for signature generation
            param_name = _field_name_for_signature(field_name, field)

            if field_name in merged_params or param_name in merged_params:
                continue

            if not is_valid_identifier(param_name):
                if allow_names:
                    param_name = field_name
                else:
                    # Only reachable through **kwargs, so the var-kw parameter must be kept.
                    use_var_kw = True
                    continue

            if field.is_required():
                default = Parameter.empty
            elif field.default_factory is not None:
                # Mimics stdlib dataclasses:
                default = _HAS_DEFAULT_FACTORY
            else:
                default = field.default
            merged_params[param_name] = Parameter(
                param_name,
                Parameter.KEYWORD_ONLY,
                annotation=field.rebuild_annotation(),
                default=default,
            )

    if extra == 'allow':
        use_var_kw = True

    if var_kw and use_var_kw:
        # Make sure the parameter for extra kwargs
        # does not have the same name as a field
        default_model_signature = [
            ('self', Parameter.POSITIONAL_ONLY),
            ('data', Parameter.VAR_KEYWORD),
        ]
        if [(p.name, p.kind) for p in present_params] == default_model_signature:
            # if this is the standard model signature, use extra_data as the extra args name
            var_kw_name = 'extra_data'
        else:
            # else start from var_kw
            var_kw_name = var_kw.name

        # Generate a name that's definitely unique. It must not equal a field name, nor
        # a parameter already collected: collected keys may be aliases, e.g. a field
        # aliased `extra_data`. Reusing such a key would overwrite that parameter.
        while var_kw_name in fields or var_kw_name in merged_params:
            var_kw_name += '_'
        merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)

    return merged_params
162
163
def generate_pydantic_signature(
    init: Callable[..., None],
    fields: dict[str, FieldInfo],
    validate_by_name: bool,
    extra: ExtraValues | None,
    is_dataclass: bool = False,
) -> Signature:
    """Build the `inspect.Signature` exposed by a pydantic BaseModel or dataclass.

    Args:
        init: The class init, whose declared parameters seed the signature.
        fields: The model fields, keyed by field name.
        validate_by_name: The `validate_by_name` value of the config.
        extra: The `extra` value of the config.
        is_dataclass: Whether the model is a dataclass. If so, each parameter is
            additionally passed through `_process_param_defaults`.

    Returns:
        The dataclass/BaseModel subclass signature, with `None` as its return annotation.
    """
    parameters = list(_generate_signature_parameters(init, fields, validate_by_name, extra).values())

    if is_dataclass:
        # Dataclass __init__ defaults may be FieldInfo instances; resolve them into real defaults.
        parameters = [_process_param_defaults(parameter) for parameter in parameters]

    return Signature(parameters=parameters, return_annotation=None)