1from __future__ import annotations
2
3import dataclasses
4from inspect import Parameter, Signature
5from typing import TYPE_CHECKING, Any, Callable
6
7from pydantic_core import PydanticUndefined
8
9from ._typing_extra import signature_no_eval
10from ._utils import is_valid_identifier
11
12if TYPE_CHECKING:
13 from ..config import ExtraValues
14 from ..fields import FieldInfo
15
16
17# Copied over from stdlib dataclasses
18class _HAS_DEFAULT_FACTORY_CLASS:
19 def __repr__(self):
20 return '<factory>'
21
22
23_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
24
25
def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str:
    """Pick the name under which a field should appear in a generated signature.

    Candidates are tried in priority order: `alias`, then `validation_alias`. The first one that is
    a string and a valid Python identifier wins. If neither qualifies (missing, not a string,
    e.g. an `AliasPath`/`AliasChoices`, or not a valid identifier), the field name itself is used.

    Args:
        field_name: The name of the field
        field_info: The corresponding FieldInfo object.

    Returns:
        The correct name to use when generating a signature.
    """
    for candidate in (field_info.alias, field_info.validation_alias):
        if isinstance(candidate, str) and is_valid_identifier(candidate):
            return candidate
    return field_name
45
46
def _process_param_defaults(param: Parameter) -> Parameter:
    """Rewrite a dataclass `__init__` parameter whose default value is a `FieldInfo` instance.

    Parameters whose default is anything other than a `FieldInfo` are returned unchanged. Otherwise a
    new parameter is built with:

    - the string annotation `'Any'` replaced by the real `typing.Any`;
    - the name replaced by the field's signature name (see `_field_name_for_signature`);
    - the default replaced by the `FieldInfo`'s own default. If that is `PydanticUndefined`, the
      default becomes `Signature.empty` when there is no `default_factory`, or the stdlib dataclasses
      "has default factory" sentinel when there is one.

    Args:
        param (Parameter): The parameter

    Returns:
        Parameter: The custom processed parameter
    """
    from ..fields import FieldInfo

    field_info = param.default
    if not isinstance(field_info, FieldInfo):
        return param

    # inspect does "clever" things to show annotations as strings because we have
    # `from __future__ import annotations` in main, we don't want that
    annotation = Any if param.annotation == 'Any' else param.annotation

    if field_info.default is not PydanticUndefined:
        default = field_info.default
    elif field_info.default_factory is None:
        default = Signature.empty
    else:
        # this is used by dataclasses to indicate a factory exists:
        default = dataclasses._HAS_DEFAULT_FACTORY  # type: ignore

    return param.replace(
        annotation=annotation,
        name=_field_name_for_signature(param.name, field_info),
        default=default,
    )
79
80
def _generate_signature_parameters(  # noqa: C901 (ignore complexity, could use a refactor)
    init: Callable[..., None],
    fields: dict[str, FieldInfo],
    validate_by_name: bool,
    extra: ExtraValues | None,
) -> dict[str, Parameter]:
    """Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass.

    Steps:

    1. The parameters of `init` are walked in order, skipping the first one (`self`).
       - A parameter that matches a field name is dropped if that field has `init=False`.
         Otherwise it is renamed to the field's signature name (see `_field_name_for_signature`).
       - A string annotation `'Any'` is replaced by `typing.Any`.
       - A variadic keyword parameter (`**kwargs`) is set aside rather than added immediately.
    2. If `init` has a variadic keyword parameter, every field not yet represented is appended as a
       keyword-only parameter.
       - If a field's alias is not a valid identifier, the field name is used when `validate_by_name`
         is set. Otherwise the field is omitted and marked as reachable only through `**`.
    3. The variadic keyword parameter is re-added last if extra data can reach it. That happens when
       `extra == 'allow'` or when some field could only be passed through it. Its name is chosen so it
       collides with no field name and no parameter already in the mapping.

    Args:
        init: The `__init__` whose signature is the starting point.
        fields: The class fields, keyed by field name.
        validate_by_name: The `validate_by_name` value of the config.
        extra: The `extra` value of the config.

    Returns:
        An insertion-ordered mapping of parameter name to `Parameter`.
    """
    from itertools import islice

    present_params = signature_no_eval(init).parameters.values()
    merged_params: dict[str, Parameter] = {}
    var_kw: Parameter | None = None
    use_var_kw = False

    for param in islice(present_params, 1, None):  # skip self arg
        field_info = fields.get(param.name)
        if field_info is not None:
            # exclude params with init=False
            if getattr(field_info, 'init', True) is False:
                continue
            param = param.replace(name=_field_name_for_signature(param.name, field_info))
        # inspect does "clever" things to show annotations as strings because we have
        # `from __future__ import annotations` in main, we don't want that
        if param.annotation == 'Any':
            param = param.replace(annotation=Any)
        if param.kind is param.VAR_KEYWORD:
            var_kw = param
            continue
        merged_params[param.name] = param

    if var_kw is not None:  # if custom init has no var_kw, fields which are not declared in it cannot be passed through
        allow_names = validate_by_name
        for field_name, field in fields.items():
            # when alias is a str it should be used for signature generation
            param_name = _field_name_for_signature(field_name, field)

            if field_name in merged_params or param_name in merged_params:
                continue

            if not is_valid_identifier(param_name):
                if allow_names:
                    param_name = field_name
                else:
                    # the field can only be supplied through `**`, so it must stay in the signature
                    use_var_kw = True
                    continue

            if field.is_required():
                default = Parameter.empty
            elif field.default_factory is not None:
                # Mimics stdlib dataclasses:
                default = _HAS_DEFAULT_FACTORY
            else:
                default = field.default
            merged_params[param_name] = Parameter(
                param_name,
                Parameter.KEYWORD_ONLY,
                annotation=field.rebuild_annotation(),
                default=default,
            )

    if extra == 'allow':
        use_var_kw = True

    if var_kw is not None and use_var_kw:
        # Make sure the parameter for extra kwargs
        # does not have the same name as a field
        default_model_signature = [
            ('self', Parameter.POSITIONAL_ONLY),
            ('data', Parameter.VAR_KEYWORD),
        ]
        if [(p.name, p.kind) for p in present_params] == default_model_signature:
            # if this is the standard model signature, use extra_data as the extra args name
            var_kw_name = 'extra_data'
        else:
            # else start from var_kw
            var_kw_name = var_kw.name

        # Generate a name that's definitely unique. Checking `fields` alone is not enough: parameters
        # in `merged_params` may be keyed by an alias (or a custom-init parameter name). Reusing such a
        # key would overwrite that parameter and leave `**` at its (non-final) position, which
        # `Signature` rejects.
        while var_kw_name in fields or var_kw_name in merged_params:
            var_kw_name += '_'
        merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)

    return merged_params
163
164
def generate_pydantic_signature(
    init: Callable[..., None],
    fields: dict[str, FieldInfo],
    validate_by_name: bool,
    extra: ExtraValues | None,
    is_dataclass: bool = False,
) -> Signature:
    """Build the public `inspect.Signature` for a pydantic BaseModel or dataclass.

    The parameters come from `_generate_signature_parameters`. For dataclasses, each parameter is also
    passed through `_process_param_defaults`, which unwraps `FieldInfo` defaults. The resulting
    signature always has `None` as its return annotation.

    Args:
        init: The class init.
        fields: The model fields.
        validate_by_name: The `validate_by_name` value of the config.
        extra: The `extra` value of the config.
        is_dataclass: Whether the model is a dataclass.

    Returns:
        The dataclass/BaseModel subclass signature.
    """
    parameters = list(_generate_signature_parameters(init, fields, validate_by_name, extra).values())

    if is_dataclass:
        parameters = [_process_param_defaults(parameter) for parameter in parameters]

    return Signature(parameters=parameters, return_annotation=None)