1"""Bucket of reusable internal utilities.
2
This should be reduced as much as possible, with functions that are only used in one place being moved to that place.
4"""
5
6from __future__ import annotations as _annotations
7
8import dataclasses
9import keyword
10import sys
11import typing
12import warnings
13import weakref
14from collections import OrderedDict, defaultdict, deque
15from collections.abc import Mapping
16from copy import deepcopy
17from functools import cached_property
18from inspect import Parameter
19from itertools import zip_longest
20from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
21from typing import Any, Callable, Generic, TypeVar, overload
22
23from typing_extensions import TypeAlias, TypeGuard, deprecated
24
25from pydantic import PydanticDeprecatedSince211
26
27from . import _repr, _typing_extra
28from ._import_utils import import_cached_base_model
29
if typing.TYPE_CHECKING:
    # Everything in this branch exists for static type checkers only and is not executed at runtime.
    # The two aliases annotate the `items` argument accepted by `ValueItems`:
    # a mapping keyed by int or str, or a set of ints or strs.
    MappingIntStrAny: TypeAlias = 'typing.Mapping[int, Any] | typing.Mapping[str, Any]'
    AbstractSetIntStr: TypeAlias = 'typing.AbstractSet[int] | typing.AbstractSet[str]'
    # Needed for annotations only; at runtime `is_model_class` obtains the class through
    # `import_cached_base_model()` (presumably to avoid a circular import -- confirm).
    from ..main import BaseModel
34
35
# Types whose instances are returned unchanged by deepcopy, so copying them can be skipped entirely.
IMMUTABLE_NON_COLLECTIONS_TYPES: set[type[Any]] = {
    # numbers, text and singleton-like scalars
    bool,
    int,
    float,
    complex,
    str,
    bytes,
    _typing_extra.NoneType,
    type(NotImplemented),
    type(Ellipsis),
    # classes, callables and code objects
    type,
    FunctionType,
    LambdaType,  # same object as `FunctionType`; listed for explicitness
    BuiltinFunctionType,
    CodeType,
    weakref.ref,
    # note: including ModuleType will differ from behaviour of deepcopy by not producing error.
    # It might be not a good idea in general, but considering that this function used only internally
    # against default values of fields, this will allow to actually have a field with module as default value
    ModuleType,
}
58
# Collection types for which an *empty* instance might be duplicated with a simple `copy()`
# instead of `deepcopy()`, since there are no members that would need copying.
BUILTIN_COLLECTIONS: set[type[Any]] = {
    # sequences
    list,
    tuple,
    deque,
    # sets
    set,
    frozenset,
    # mappings
    dict,
    OrderedDict,
    defaultdict,
}
70
71
def can_be_positional(param: Parameter) -> bool:
    """Tell whether `param` may be supplied as a positional argument.

    This is the case for positional-only and positional-or-keyword parameters.

    ```python {test="skip" lint="skip"}
    def func(a, /, b, *, c):
        pass

    params = inspect.signature(func).parameters
    can_be_positional(params['a'])
    #> True
    can_be_positional(params['b'])
    #> True
    can_be_positional(params['c'])
    #> False
    ```
    """
    kind = param.kind
    return kind == Parameter.POSITIONAL_ONLY or kind == Parameter.POSITIONAL_OR_KEYWORD
89
90
def sequence_like(v: Any) -> bool:
    """Return whether `v` is a list, tuple, set, frozenset, generator or deque."""
    sequence_types = (list, tuple, set, frozenset, GeneratorType, deque)
    return isinstance(v, sequence_types)
93
94
def lenient_isinstance(o: Any, class_or_tuple: type[Any] | tuple[type[Any], ...] | None) -> bool:  # pragma: no cover
    """`isinstance()` variant that yields `False` when the check itself raises `TypeError`.

    `isinstance` raises `TypeError` when `class_or_tuple` is not a valid second argument
    (e.g. `None`); that case is reported as "not an instance" instead of propagating.
    """
    try:
        outcome = isinstance(o, class_or_tuple)  # type: ignore[arg-type]
    except TypeError:
        outcome = False
    return outcome
100
101
def lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool:  # pragma: no cover
    """`issubclass()` variant that first checks that `cls` is actually a class.

    Returns `False` when `cls` is not an instance of `type`. A `TypeError` raised by the check
    is swallowed (returning `False`) only when `cls` is an instance of
    `_typing_extra.WithArgsTypes`; any other `TypeError` is re-raised.
    """
    try:
        if not isinstance(cls, type):
            return False
        return issubclass(cls, class_or_tuple)
    except TypeError:
        if not isinstance(cls, _typing_extra.WithArgsTypes):
            raise  # pragma: no cover
        return False
109
110
def is_model_class(cls: Any) -> TypeGuard[type[BaseModel]]:
    """Return whether `cls` is a _proper_ subclass of `BaseModel` (i.e. not `BaseModel` itself).

    Unlike a raw call to `lenient_issubclass`, the `TypeGuard` return type lets type checkers
    narrow `cls` to `type[BaseModel]`.
    """
    base_model_cls = import_cached_base_model()

    if cls is base_model_cls:
        # `BaseModel` itself is excluded: only strict subclasses count
        return False
    return lenient_issubclass(cls, base_model_cls)
118
119
def is_valid_identifier(identifier: str) -> bool:
    """Check that a string can be used as a Python identifier, i.e. it is not a keyword either.

    :param identifier: The identifier to test.
    :return: True if the identifier is valid.
    """
    if not identifier.isidentifier():
        return False
    return not keyword.iskeyword(identifier)
126
127
KeyType = TypeVar('KeyType')


def deep_update(mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]) -> dict[KeyType, Any]:
    """Return a copy of `mapping` with every mapping of `updating_mappings` applied in order.

    When a key holds a dict on both sides, the two dicts are merged recursively;
    in every other case the updating value replaces the existing one.
    `mapping` itself is not modified: the work is done on a shallow copy.
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            both_dicts = isinstance(new_value, dict) and key in merged and isinstance(merged[key], dict)
            merged[key] = deep_update(merged[key], new_value) if both_dicts else new_value
    return merged
140
141
def update_not_none(mapping: dict[Any, Any], **update: Any) -> None:
    """Update `mapping` in place with the keyword arguments whose value is not `None`."""
    present = {key: value for key, value in update.items() if value is not None}
    mapping.update(present)
144
145
T = TypeVar('T')


def unique_list(
    input_list: list[T] | tuple[T, ...],
    *,
    name_factory: typing.Callable[[T], str] = str,
) -> list[T]:
    """Make a list unique while maintaining order.

    Items are identified by `name_factory(item)`. The first occurrence of a name fixes the
    position in the result; a later item with the same name replaces the earlier one in place
    (e.g. model validator overridden in subclass).

    :param input_list: the items to de-duplicate
    :param name_factory: callable producing the (hashable) name identifying an item, `str` by default
    :return: a new list holding one item per distinct name
    """
    result: list[T] = []
    # name -> position in `result`; a dict keeps each lookup O(1) instead of scanning a list of names
    position_by_name: dict[str, int] = {}
    for v in input_list:
        v_name = name_factory(v)
        position = position_by_name.get(v_name)
        if position is None:
            position_by_name[v_name] = len(result)
            result.append(v)
        else:
            result[position] = v

    return result
169
170
class ValueItems(_repr.Representation):
    """Class for more convenient calculation of excluded or included fields on values.

    The specification is stored as a mapping from key (or index) to either a "true" marker
    (`True` or `...`, see `is_true`) or a nested specification for that element.
    """

    # NOTE(review): `_type` is declared as a slot but never assigned anywhere in this class.
    __slots__ = ('_items', '_type')

    def __init__(self, value: Any, items: AbstractSetIntStr | MappingIntStrAny) -> None:
        """:param value: the value the specification applies to; lists and tuples get index normalization
        :param items: a set of keys/indexes or a mapping from key/index to a nested specification
        :raises TypeError: if `items` is neither a mapping nor a set
        """
        items = self._coerce_items(items)

        if isinstance(value, (list, tuple)):
            # sequences are addressed by index: resolve negative indexes and expand '__all__'
            items = self._normalize_indexes(items, len(value))  # type: ignore

        self._items: MappingIntStrAny = items  # type: ignore

    def is_excluded(self, item: Any) -> bool:
        """Check if item is fully excluded.

        :param item: key or index of a value
        """
        return self.is_true(self._items.get(item))

    def is_included(self, item: Any) -> bool:
        """Check if value is contained in self._items.

        :param item: key or index of value
        """
        return item in self._items

    def for_element(self, e: int | str) -> AbstractSetIntStr | MappingIntStrAny | None:
        """:param e: key or index of element on value
        :return: the nested specification stored for `e`, or `None` when `e` is absent
            or is mapped to a "true" marker (`True` / `...`)
        """
        item = self._items.get(e)  # type: ignore
        return item if not self.is_true(item) else None

    def _normalize_indexes(self, items: MappingIntStrAny, v_length: int) -> dict[int | str, Any]:
        """Resolve negative indexes and expand the `'__all__'` key for a sequence of length `v_length`.

        :param items: dict or set of indexes which will be normalized
        :param v_length: length of the sequence whose indexes are being normalized
        :raises TypeError: if a value is not a mapping, a set or a "true" marker, or if a key
            is neither an integer nor `'__all__'`

        >>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
        {0: True, 2: True, 3: True}
        >>> self._normalize_indexes({'__all__': True}, 4)
        {0: True, 1: True, 2: True, 3: True}
        """
        normalized_items: dict[int | str, Any] = {}
        all_items = None
        for i, v in items.items():
            if not (isinstance(v, (typing.Mapping, typing.AbstractSet)) or self.is_true(v)):
                raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
            if i == '__all__':
                # remembered and applied to every index once the explicit indexes are collected
                all_items = self._coerce_value(v)
                continue
            if not isinstance(i, int):
                raise TypeError(
                    'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
                    'expected integer keys or keyword "__all__"'
                )
            normalized_i = v_length + i if i < 0 else i
            # two spellings of the same position (e.g. `-1` and `v_length - 1`) are merged
            normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))

        if not all_items:
            return normalized_items
        if self.is_true(all_items):
            # '__all__' is a "true" marker: every index without an explicit entry gets `...`
            for i in range(v_length):
                normalized_items.setdefault(i, ...)
            return normalized_items
        # '__all__' is a nested specification: merge it into every index not already marked "true"
        for i in range(v_length):
            normalized_item = normalized_items.setdefault(i, {})
            if not self.is_true(normalized_item):
                normalized_items[i] = self.merge(all_items, normalized_item)
        return normalized_items

    @classmethod
    def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
        """Merge a `base` item with an `override` item.

        Both `base` and `override` are converted to dictionaries if possible.
        Sets are converted to dictionaries with the sets entries as keys and
        Ellipsis as values.

        Each key-value pair existing in `base` is merged with `override`,
        while the rest of the key-value pairs are updated recursively with this function.

        Merging takes place based on the "union" of keys if `intersect` is
        set to `False` (default) and on the intersection of keys if
        `intersect` is set to `True`.
        """
        override = cls._coerce_value(override)
        base = cls._coerce_value(base)
        if override is None:
            return base
        if cls.is_true(base) or base is None:
            return override
        if cls.is_true(override):
            return base if intersect else override

        # intersection or union of keys while preserving ordering:
        if intersect:
            # each shared key is listed once, in `base` order; listing it a second time would only
            # recompute the same sub-merge and rewrite the same entry
            merge_keys = [k for k in base if k in override]
        else:
            merge_keys = list(base) + [k for k in override if k not in base]

        merged: dict[int | str, Any] = {}
        for k in merge_keys:
            merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
            if merged_item is not None:
                merged[k] = merged_item

        return merged

    @staticmethod
    def _coerce_items(items: AbstractSetIntStr | MappingIntStrAny) -> MappingIntStrAny:
        """Return `items` as a mapping: mappings pass through, sets become `{key: ...}`.

        :raises TypeError: if `items` is neither a mapping nor a set
        """
        if isinstance(items, typing.Mapping):
            pass
        elif isinstance(items, typing.AbstractSet):
            items = dict.fromkeys(items, ...)  # type: ignore
        else:
            class_name = getattr(items, '__class__', '???')
            raise TypeError(f'Unexpected type of exclude value {class_name}')
        return items  # type: ignore

    @classmethod
    def _coerce_value(cls, value: Any) -> Any:
        """Like `_coerce_items`, but lets `None` and "true" markers through unchanged."""
        if value is None or cls.is_true(value):
            return value
        return cls._coerce_items(value)

    @staticmethod
    def is_true(v: Any) -> bool:
        """Return whether `v` is one of the "true" markers: `True` or `...` (identity check)."""
        return v is True or v is ...

    def __repr_args__(self) -> _repr.ReprArgs:
        return [(None, self._items)]
303
304
if typing.TYPE_CHECKING:

    # For type checkers, the descriptor is presented as a function returning the value's type.
    def LazyClassAttribute(name: str, get_value: Callable[[], T]) -> T: ...

else:

    class LazyClassAttribute:
        """A descriptor exposing an attribute only accessible on a class (hidden from instances).

        The attribute is lazily computed and cached during the first access.
        """

        def __init__(self, name: str, get_value: Callable[[], Any]) -> None:
            """:param name: attribute name, used in the error raised on instance access
            :param get_value: zero-argument callable producing the attribute value
            """
            self.name = name
            self.get_value = get_value

        @cached_property
        def value(self) -> Any:
            # `cached_property` stores the result on this descriptor, so `get_value` runs only once
            return self.get_value()

        def __get__(self, instance: Any, owner: type[Any]) -> Any:
            """Return the cached value on class access.

            :raises AttributeError: when accessed through an instance
            """
            if instance is None:
                return self.value
            raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')
329
330
Obj = TypeVar('Obj')


def smart_deepcopy(obj: Obj) -> Obj:
    """Copy `obj` using the cheapest strategy that is safe for its type.

    - objects whose type is in `IMMUTABLE_NON_COLLECTIONS_TYPES` are returned as is;
    - empty built-in collections (`BUILTIN_COLLECTIONS`) are duplicated with `obj.copy()`
      (an empty tuple is returned as is);
    - everything else goes through `copy.deepcopy()`.
    """
    obj_type = obj.__class__
    if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
        # fastest case: obj is immutable and not collection therefore will not be copied anyway
        return obj
    try:
        if not obj:
            if obj_type in BUILTIN_COLLECTIONS:
                # faster way for empty collections, no need to copy its members
                if obj_type is tuple:
                    # tuple doesn't have copy method
                    return obj
                return obj.copy()  # type: ignore
    except (TypeError, ValueError, RuntimeError):
        # do we really dare to catch ALL errors? Seems a bit risky
        pass

    # slowest way when we actually might need a deepcopy
    return deepcopy(obj)
351
352
_SENTINEL = object()


def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool:
    """Check that the items of `left` are the same objects as those in `right`.

    Items are compared pairwise with `is`; iterables of different lengths are never identical,
    because the shorter one is padded with a private sentinel object.

    >>> a, b = object(), object()
    >>> all_identical([a, b, a], [a, b, a])
    True
    >>> all_identical([a, b, [a]], [a, b, [a]])  # new list object, while "equal" is not "identical"
    False
    """
    pairs = zip_longest(left, right, fillvalue=_SENTINEL)
    return all(left_item is right_item for left_item, right_item in pairs)
369
370
@dataclasses.dataclass(frozen=True)
class SafeGetItemProxy:
    """Wrapper redirecting `__getitem__` to `get` with a sentinel value as default.

    This makes it safe to use in `operator.itemgetter` when some keys may be missing:
    a missing key yields the module-level `_SENTINEL` object instead of raising `KeyError`.
    """

    # `__slots__` is defined manually for performance:
    # `@dataclasses.dataclass()` only supports `slots=True` in python>=3.10
    __slots__ = ('wrapped',)

    wrapped: Mapping[str, Any]

    def __getitem__(self, key: str, /) -> Any:
        mapping = self.wrapped
        return mapping.get(key, _SENTINEL)

    # required to pass the object to operator.itemgetter() instances due to a quirk of typeshed
    # https://github.com/python/mypy/issues/13713
    # https://github.com/python/typeshed/pull/8785
    # Since this is typing-only, hide it in a typing.TYPE_CHECKING block
    if typing.TYPE_CHECKING:

        def __contains__(self, key: str, /) -> bool:
            return key in self.wrapped
395
396
_ModelT = TypeVar('_ModelT', bound='BaseModel')
_RT = TypeVar('_RT')


class deprecated_instance_property(Generic[_ModelT, _RT]):
    """A decorator exposing the decorated class method as a property, with a warning on instance access.

    The wrapped class method (defined on the `BaseModel` class) becomes an attribute that can be
    read from both the class and its instances. Reading it through an instance still returns the
    value, but emits a `PydanticDeprecatedSince211` warning, as instance access is meant to be
    removed in V3.
    """

    def __init__(self, fget: Callable[[type[_ModelT]], _RT], /) -> None:
        # Note: fget should be a classmethod:
        self.fget = fget

    @overload
    def __get__(self, instance: None, objtype: type[_ModelT]) -> _RT: ...
    @overload
    @deprecated(
        'Accessing this attribute on the instance is deprecated, and will be removed in Pydantic V3. '
        'Instead, you should access this attribute from the model class.',
        category=None,
    )
    def __get__(self, instance: _ModelT, objtype: type[_ModelT]) -> _RT: ...
    def __get__(self, instance: _ModelT | None, objtype: type[_ModelT]) -> _RT:
        if instance is not None:
            # classmethod objects only expose `__name__` directly from Python 3.10 onwards;
            # on older versions the name is read from the underlying function
            if sys.version_info >= (3, 10):
                attr_name = self.fget.__name__
            else:
                attr_name = self.fget.__func__.__name__
            message = (
                f'Accessing the {attr_name!r} attribute on the instance is deprecated. '
                'Instead, you should access this attribute from the model class.'
            )
            warnings.warn(message, category=PydanticDeprecatedSince211, stacklevel=2)
        # bind the classmethod to the owner class, then call it to produce the attribute value
        bound_getter = self.fget.__get__(instance, objtype)
        return bound_getter()