1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5"""Astroid hooks for typing.py support."""
6
7from __future__ import annotations
8
9import textwrap
10import typing
11from collections.abc import Iterator
12from functools import partial
13from typing import Final
14
15from astroid import context, extract_node, inference_tip
16from astroid.brain.helpers import register_module_extender
17from astroid.builder import AstroidBuilder, _extract_single_node
18from astroid.const import PY39_PLUS, PY312_PLUS
19from astroid.exceptions import (
20 AstroidSyntaxError,
21 AttributeInferenceError,
22 InferenceError,
23 UseInferenceDefault,
24)
25from astroid.manager import AstroidManager
26from astroid.nodes.node_classes import (
27 Assign,
28 AssignName,
29 Attribute,
30 Call,
31 Const,
32 JoinedStr,
33 Name,
34 NodeNG,
35 Subscript,
36 Tuple,
37)
38from astroid.nodes.scoped_nodes import ClassDef, FunctionDef
39
# Bare callee spellings that may denote a TypeVar/NewType call; used as a
# cheap syntactic pre-filter before any inference is attempted.
TYPING_TYPEVARS: Final = {"TypeVar", "NewType"}
# Fully qualified names the *inferred* callee must have for the
# TypeVar/NewType inference tip to apply.
TYPING_TYPEVARS_QUALIFIED: Final = {
    "typing.TypeVar",
    "typing.NewType",
    "typing_extensions.TypeVar",
}
# Qualified names under which a TypedDict definition is recognised.
TYPING_TYPEDDICT_QUALIFIED: Final = {"typing.TypedDict", "typing_extensions.TypedDict"}
# Source of a stand-in class; ``{0}`` is replaced by the class name.
# The metaclass makes ``Cls[...]`` return the class itself and gives the
# class an empty ``__args__``.
TYPING_TYPE_TEMPLATE: Final = """
class Meta(type):
    def __getitem__(self, item):
        return self

    @property
    def __args__(self):
        return ()

class {0}(metaclass=Meta):
    pass
"""
# Names exported by the running interpreter's ``typing`` module
# (``typing.__all__``, or empty if it is absent); used to pre-filter
# subscripts that might be typing-related.
TYPING_MEMBERS: Final = set(getattr(typing, "__all__", []))
60
# Qualified names of typing aliases for builtin/collections classes.
# ``infer_typing_attr`` declines subscripts whose value infers to one of
# these names; they are handled separately by the alias inference tips.
TYPING_ALIAS: Final = frozenset(
    (
        "typing.Hashable",
        "typing.Awaitable",
        "typing.Coroutine",
        "typing.AsyncIterable",
        "typing.AsyncIterator",
        "typing.Iterable",
        "typing.Iterator",
        "typing.Reversible",
        "typing.Sized",
        "typing.Container",
        "typing.Collection",
        "typing.Callable",
        "typing.AbstractSet",
        "typing.MutableSet",
        "typing.Mapping",
        "typing.MutableMapping",
        "typing.Sequence",
        "typing.MutableSequence",
        "typing.ByteString",
        "typing.Tuple",
        "typing.List",
        "typing.Deque",
        "typing.Set",
        "typing.FrozenSet",
        "typing.MappingView",
        "typing.KeysView",
        "typing.ItemsView",
        "typing.ValuesView",
        "typing.ContextManager",
        "typing.AsyncContextManager",
        "typing.Dict",
        "typing.DefaultDict",
        "typing.OrderedDict",
        "typing.Counter",
        "typing.ChainMap",
        "typing.Generator",
        "typing.AsyncGenerator",
        "typing.Type",
        "typing.Pattern",
        "typing.Match",
    )
)
105
# Source of a ``__class_getitem__`` classmethod stub that returns the class
# unchanged. It is parsed and injected into class locals so that
# subscripting those classes infers to the class itself.
CLASS_GETITEM_TEMPLATE: Final = """
@classmethod
def __class_getitem__(cls, item):
    return cls
"""
111
112
def looks_like_typing_typevar_or_newtype(node) -> bool:
    """Cheap syntactic pre-filter for ``TypeVar(...)`` / ``NewType(...)`` calls.

    Only the spelling of the callee (a bare name or an attribute access) is
    examined; whether it truly refers to ``typing`` is decided later, during
    inference.
    """
    callee = node.func
    if isinstance(callee, Name):
        return callee.name in TYPING_TYPEVARS
    if isinstance(callee, Attribute):
        return callee.attrname in TYPING_TYPEVARS
    return False
120
121
def infer_typing_typevar_or_newtype(
    node: Call, context_itton: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """Infer a typing.TypeVar(...) or typing.NewType(...) call.

    The call is replaced by a freshly parsed class named after the first
    positional argument; its metaclass makes the class subscriptable.

    :param node: the call node.
    :param context_itton: inference context, forwarded to every inference.
    :raises UseInferenceDefault: if the callee cannot be inferred, is not a
        qualified TypeVar/NewType, or the name argument is missing or an
        f-string.
    :raises InferenceError: if the generated class source fails to parse.
    """
    try:
        callee = next(node.func.infer(context=context_itton))
    except (InferenceError, StopIteration) as exc:
        raise UseInferenceDefault from exc

    is_supported = callee.qname() in TYPING_TYPEVARS_QUALIFIED
    if not is_supported or not node.args:
        raise UseInferenceDefault

    name_arg = node.args[0]
    # A dynamic (f-string) class name cannot be known statically.
    if isinstance(name_arg, JoinedStr):
        raise UseInferenceDefault

    class_name = name_arg.as_string().strip("'")
    class_source = TYPING_TYPE_TEMPLATE.format(class_name)
    try:
        generated = extract_node(class_source)
    except AstroidSyntaxError as exc:
        raise InferenceError from exc
    return generated.infer(context=context_itton)
145
146
def _looks_like_typing_subscript(node) -> bool:
    """Try to figure out if a Subscript node *might* be a typing-related subscript.

    Nested subscripts (``X[a][b]``) are unwrapped down to their innermost
    value, which must be a name or an attribute whose identifier appears in
    ``typing.__all__``.
    """
    current = node
    while isinstance(current, Subscript):
        current = current.value
    if isinstance(current, Attribute):
        return current.attrname in TYPING_MEMBERS
    if isinstance(current, Name):
        return current.name in TYPING_MEMBERS
    return False
156
157
def infer_typing_attr(
    node: Subscript, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """Infer a typing.X[...] subscript.

    :param node: the subscript whose value may be a ``typing`` member.
    :param ctx: inference context; only forwarded when inferring the
        generated stand-in class at the end.
    :returns: an iterator over a single inferred class.
    :raises UseInferenceDefault: if the subscripted value cannot be inferred,
        its qualified name does not start with ``typing.``, or it is listed
        in ``TYPING_ALIAS``.
    """
    try:
        # NOTE(review): ``ctx`` is not forwarded here, unlike the other
        # inference tips in this module; confirm this is intentional.
        value = next(node.value.infer())  # type: ignore[union-attr] # value shouldn't be None for Subscript.
    except (InferenceError, StopIteration) as exc:
        raise UseInferenceDefault from exc

    if not value.qname().startswith("typing.") or value.qname() in TYPING_ALIAS:
        # If typing subscript belongs to an alias handle it separately.
        raise UseInferenceDefault

    if isinstance(value, ClassDef) and value.qname() in {
        "typing.Generic",
        "typing.Annotated",
        "typing_extensions.Annotated",
    }:
        # typing.Generic and typing.Annotated (PY39) are subscriptable
        # through __class_getitem__. Since astroid can't easily
        # infer the native methods, replace them for an easy inference tip
        # NOTE(review): "typing_extensions.Annotated" looks unreachable here,
        # because qnames not starting with "typing." were rejected above.
        func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
        value.locals["__class_getitem__"] = [func_to_add]
        if (
            isinstance(node.parent, ClassDef)
            and node in node.parent.bases
            and getattr(node.parent, "__cache", None)
        ):
            # node.parent.slots is evaluated and cached before the inference tip
            # is first applied. Remove the last result to allow a recalculation of slots
            # NOTE(review): assumes the cache dict is keyed by
            # ``node.parent.slots``; verify against astroid's caching decorator.
            cache = node.parent.__cache  # type: ignore[attr-defined] # Unrecognized getattr
            if cache.get(node.parent.slots) is not None:
                del cache[node.parent.slots]
        # Avoid re-instantiating this class every time it's seen
        node._explicit_inference = lambda node, context: iter([value])
        return iter([value])

    # Any other typing member: infer a stand-in class named after the last
    # dotted component of its qualified name (e.g. "Union" for "typing.Union").
    node = extract_node(TYPING_TYPE_TEMPLATE.format(value.qname().split(".")[-1]))
    return node.infer(context=ctx)
197
198
199def _looks_like_generic_class_pep695(node: ClassDef) -> bool:
200 """Check if class is using type parameter. Python 3.12+."""
201 return len(node.type_params) > 0
202
203
def infer_typing_generic_class_pep695(
    node: ClassDef, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """Add __class_getitem__ for generic classes. Python 3.12+.

    The class is mutated in place: a stub ``__class_getitem__`` that returns
    the class is installed in its locals, so ``C[int]`` infers to ``C``.
    The (mutated) class itself is the sole inference result.
    """
    stub_method = _extract_single_node(CLASS_GETITEM_TEMPLATE)
    node.locals["__class_getitem__"] = [stub_method]
    return iter((node,))
211
212
def _looks_like_typedDict(  # pylint: disable=invalid-name
    node: FunctionDef | ClassDef,
) -> bool:
    """Return whether *node* is the TypedDict definition from typing or typing_extensions."""
    qualified_name = node.qname()
    return qualified_name in TYPING_TYPEDDICT_QUALIFIED
218
219
def infer_old_typedDict(  # pylint: disable=invalid-name
    node: ClassDef, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """Make the pre-3.9 TypedDict class callable like ``dict``.

    A ``dict`` name node is installed as the class's ``__call__`` local; the
    mutated class itself is then the only inference result.
    """
    dict_name = _extract_single_node("dict")
    node.locals["__call__"] = [dict_name]
    return iter((node,))
226
227
def infer_typedDict(  # pylint: disable=invalid-name
    node: FunctionDef, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """Replace TypedDict FunctionDef with ClassDef.

    A synthetic ``TypedDict`` class deriving from ``dict`` is built at the
    same source position and parent as the function, and a ``dict`` name is
    installed as its ``__call__`` local.
    """
    typed_dict_cls = ClassDef(
        name="TypedDict",
        parent=node.parent,
        lineno=node.lineno,
        col_offset=node.col_offset,
        end_lineno=node.end_lineno,
        end_col_offset=node.end_col_offset,
    )
    dict_base = extract_node("dict")
    typed_dict_cls.postinit(bases=[dict_base], body=[], decorators=None)
    call_target = _extract_single_node("dict")
    typed_dict_cls.locals["__call__"] = [call_target]
    return iter((typed_dict_cls,))
244
245
def _looks_like_typing_alias(node: Call) -> bool:
    """
    Returns True if the node corresponds to a call to _alias function.

    For example :

    MutableSet = _alias(collections.abc.MutableSet, T)

    The first positional argument must be a ``Name`` or an ``Attribute``
    (``_alias`` also wraps builtins such as ``list`` and ``dict``). Calls
    without positional arguments never match, so an unrelated user-defined
    ``_alias()`` cannot raise ``IndexError`` from this predicate.

    :param node: call node
    """
    return (
        isinstance(node.func, Name)
        # TODO: remove _DeprecatedGenericAlias when Py3.14 min
        and node.func.name in {"_alias", "_DeprecatedGenericAlias"}
        # Guard before indexing: a bare ``_alias()`` has no args[0].
        and len(node.args) > 0
        and (
            # _alias function works also for builtins object such as list and dict
            isinstance(node.args[0], (Attribute, Name))
        )
    )
265
266
def _forbid_class_getitem_access(node: ClassDef) -> None:
    """Disable the access to __class_getitem__ method for the node in parameters.

    If *node* can currently resolve ``__class_getitem__`` (its origin being
    the protocol from the collections module, which typing considers not
    subscriptable), the instance's ``getattr`` is replaced by a wrapper. The
    wrapper raises AttributeInferenceError for that single attribute and
    delegates every other lookup to the original ``getattr``.
    """
    original_getattr = node.getattr

    def guarded_getattr(attr, *args, **kwargs):
        """Reject ``__class_getitem__``; defer everything else to the original."""
        if attr == "__class_getitem__":
            raise AttributeInferenceError("__class_getitem__ access is not allowed")
        return original_getattr(attr, *args, **kwargs)

    try:
        node.getattr("__class_getitem__")
    except AttributeInferenceError:
        # Nothing to hide: the class does not expose __class_getitem__.
        return
    node.getattr = guarded_getattr
289
290
def infer_typing_alias(
    node: Call, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """
    Infers the call to _alias function
    Insert ClassDef, with same name as aliased class,
    in mro to simulate _GenericAlias.

    The call must be the whole value of a single-name assignment and must
    pass at least two positional arguments ``(origin, params)``. The synthetic
    class gets ``__class_getitem__`` when the second argument says the alias
    takes parameters:
    - before Python 3.9, anything but an empty tuple literal;
    - on Python 3.9+, an integer constant greater than 0.

    :param node: call node
    :param ctx: inference context
    :raises UseInferenceDefault: if the call does not have the expected shape.
    :raises InferenceError: if the aliased origin yields no inference result.

    # TODO: evaluate if still necessary when Py3.12 is minimum
    """
    if (
        not isinstance(node.parent, Assign)
        or len(node.parent.targets) != 1
        or not isinstance(node.parent.targets[0], AssignName)
        # Both ``origin`` and ``params`` are read below; an unrelated
        # user-defined ``_alias`` with fewer arguments must not crash.
        or len(node.args) < 2
    ):
        raise UseInferenceDefault
    try:
        res = next(node.args[0].infer(context=ctx))
    except StopIteration as e:
        raise InferenceError(node=node.args[0], context=ctx) from e

    assign_name = node.parent.targets[0]

    class_def = ClassDef(
        name=assign_name.name,
        lineno=assign_name.lineno,
        col_offset=assign_name.col_offset,
        parent=node.parent,
        end_lineno=assign_name.end_lineno,
        end_col_offset=assign_name.end_col_offset,
    )
    if isinstance(res, ClassDef):
        # Only add `res` as base if it's a `ClassDef`
        # This isn't the case for `typing.Pattern` and `typing.Match`
        class_def.postinit(bases=[res], body=[], decorators=None)

    maybe_type_var = node.args[1]
    if PY39_PLUS:
        # ``_alias(origin, nparams)``: subscriptable iff nparams > 0. Check the
        # constant is an int first; comparing e.g. a str or None with 0 would
        # raise TypeError.
        is_subscriptable = (
            isinstance(maybe_type_var, Const)
            and isinstance(maybe_type_var.value, int)
            and maybe_type_var.value > 0
        )
    else:
        # ``_alias(origin, params)``: an empty tuple means "no parameters".
        is_subscriptable = not (
            isinstance(maybe_type_var, Tuple) and not maybe_type_var.elts
        )

    if is_subscriptable:
        # If typing alias is subscriptable, add `__class_getitem__` to ClassDef
        func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
        class_def.locals["__class_getitem__"] = [func_to_add]
    else:
        # If not, make sure that `__class_getitem__` access is forbidden.
        # This is an issue in cases where the aliased class implements it,
        # but the typing alias isn't subscriptable. E.g., `typing.ByteString` for PY39+
        _forbid_class_getitem_access(class_def)

    # Avoid re-instantiating this class every time it's seen
    node._explicit_inference = lambda node, context: iter([class_def])
    return iter([class_def])
350
351
def _looks_like_special_alias(node: Call) -> bool:
    """Return True if call is for Tuple or Callable alias.

    In PY37 and PY38 the call is to '_VariadicGenericAlias' with 'tuple' as
    first argument. In PY39+ it is replaced by a call to '_TupleType'.

    PY37: Tuple = _VariadicGenericAlias(tuple, (), inst=False, special=True)
    PY39: Tuple = _TupleType(tuple, -1, inst=False, name='Tuple')

    PY37: Callable = _VariadicGenericAlias(collections.abc.Callable, (), special=True)
    PY39: Callable = _CallableType(collections.abc.Callable, 2)

    Calls without positional arguments never match, so unrelated user code
    such as ``_TupleType()`` cannot raise ``IndexError`` from this predicate.
    """
    if not isinstance(node.func, Name) or not node.args:
        return False
    func_name = node.func.name
    first_arg = node.args[0]

    if PY39_PLUS:
        if func_name == "_TupleType":
            return isinstance(first_arg, Name) and first_arg.name == "tuple"
        if func_name == "_CallableType":
            # as_string() is only computed once the callee name matched,
            # since this predicate runs for every Call node.
            return (
                isinstance(first_arg, Attribute)
                and first_arg.as_string() == "collections.abc.Callable"
            )
        return False

    if func_name != "_VariadicGenericAlias":
        return False
    return (isinstance(first_arg, Name) and first_arg.name == "tuple") or (
        isinstance(first_arg, Attribute)
        and first_arg.as_string() == "collections.abc.Callable"
    )
383
384
def infer_special_alias(
    node: Call, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """Infer call to tuple alias as new subscriptable class typing.Tuple.

    The call must be the whole value of a single-name assignment. A class
    named after the assignment target is built with the inferred first
    argument as its base and a stub ``__class_getitem__``. The result is
    memoised on the call node.

    :raises UseInferenceDefault: if the call is not a single-name assignment.
    :raises InferenceError: if the first argument yields no inference result.
    """
    parent = node.parent
    is_simple_assignment = (
        isinstance(parent, Assign)
        and len(parent.targets) == 1
        and isinstance(parent.targets[0], AssignName)
    )
    if not is_simple_assignment:
        raise UseInferenceDefault

    origin_arg = node.args[0]
    try:
        origin = next(origin_arg.infer(context=ctx))
    except StopIteration as e:
        raise InferenceError(node=origin_arg, context=ctx) from e

    target = parent.targets[0]
    alias_cls = ClassDef(
        name=target.name,
        parent=parent,
        lineno=target.lineno,
        col_offset=target.col_offset,
        end_lineno=target.end_lineno,
        end_col_offset=target.end_col_offset,
    )
    alias_cls.postinit(bases=[origin], body=[], decorators=None)
    getitem_stub = _extract_single_node(CLASS_GETITEM_TEMPLATE)
    alias_cls.locals["__class_getitem__"] = [getitem_stub]
    # Memoise on the call node so the class is only built once.
    node._explicit_inference = lambda node, context: iter([alias_cls])
    return iter([alias_cls])
415
416
def _looks_like_typing_cast(node: Call) -> bool:
    """Return True for calls spelled ``cast(...)`` or ``<expr>.cast(...)``.

    Purely syntactic; whether the callee really is ``typing.cast`` is
    checked later by ``infer_typing_cast``.
    """
    if not isinstance(node, Call):
        return False
    callee = node.func
    if isinstance(callee, Attribute):
        return callee.attrname == "cast"
    if isinstance(callee, Name):
        return callee.name == "cast"
    return False
424
425
def infer_typing_cast(
    node: Call, ctx: context.InferenceContext | None = None
) -> Iterator[NodeNG]:
    """Infer call to cast() returning same type as casted-from var.

    A ``typing.cast(typ, val)`` call with exactly two positional arguments
    is inferred as whatever ``val`` infers to.

    :raises UseInferenceDefault: if the callee is not a name/attribute that
        infers to the ``typing.cast`` function, or the positional argument
        count is not two.
    """
    if not isinstance(node.func, (Name, Attribute)):
        raise UseInferenceDefault

    try:
        callee = next(node.func.infer(context=ctx))
    except (InferenceError, StopIteration) as exc:
        raise UseInferenceDefault from exc

    is_typing_cast = isinstance(callee, FunctionDef) and callee.qname() == "typing.cast"
    if not is_typing_cast or len(node.args) != 2:
        raise UseInferenceDefault

    _, casted_value = node.args
    return casted_value.infer(context=ctx)
445
446
def _typing_transform():
    """Build stub definitions that extend the ``typing`` module.

    Registered as a module extender on Python 3.12+. It provides simple,
    inferable stand-ins for ``Generic``, ``ParamSpec`` (with ``args`` /
    ``kwargs`` properties), ``TypeAlias``, ``Type``, ``TypeVar`` and
    ``TypeVarTuple``.
    """
    stub_source = textwrap.dedent(
        """
        class Generic:
            @classmethod
            def __class_getitem__(cls, item): return cls
        class ParamSpec:
            @property
            def args(self):
                return ParamSpecArgs(self)
            @property
            def kwargs(self):
                return ParamSpecKwargs(self)
        class ParamSpecArgs: ...
        class ParamSpecKwargs: ...
        class TypeAlias: ...
        class Type:
            @classmethod
            def __class_getitem__(cls, item): return cls
        class TypeVar:
            @classmethod
            def __class_getitem__(cls, item): return cls
        class TypeVarTuple: ...
        """
    )
    builder = AstroidBuilder(AstroidManager())
    return builder.string_build(stub_source)
474
475
def register(manager: AstroidManager) -> None:
    """Register the typing-related inference tips and transforms on *manager*.

    Hooks are registered in the listed order. The TypedDict hook depends on
    the interpreter version. The ``typing`` module extender and the PEP 695
    generic-class hook are added only on Python 3.12+.
    """
    typed_dict_hook = (
        (FunctionDef, infer_typedDict, _looks_like_typedDict)
        if PY39_PLUS
        else (ClassDef, infer_old_typedDict, _looks_like_typedDict)
    )
    ordered_hooks = (
        (Call, infer_typing_typevar_or_newtype, looks_like_typing_typevar_or_newtype),
        (Subscript, infer_typing_attr, _looks_like_typing_subscript),
        (Call, infer_typing_cast, _looks_like_typing_cast),
        typed_dict_hook,
        (Call, infer_typing_alias, _looks_like_typing_alias),
        (Call, infer_special_alias, _looks_like_special_alias),
    )
    for node_class, infer_func, predicate in ordered_hooks:
        manager.register_transform(node_class, inference_tip(infer_func), predicate)

    if PY312_PLUS:
        register_module_extender(manager, "typing", _typing_transform)
        manager.register_transform(
            ClassDef,
            inference_tip(infer_typing_generic_class_pep695),
            _looks_like_generic_class_pep695,
        )