1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5"""Astroid hooks for typing.py support."""
6
7from __future__ import annotations
8
9import textwrap
10import typing
11from collections.abc import Iterator
12from functools import partial
13from typing import Final
14
15from astroid import context
16from astroid.brain.helpers import register_module_extender
17from astroid.builder import AstroidBuilder, _extract_single_node, extract_node
18from astroid.const import PY312_PLUS, PY313_PLUS
19from astroid.exceptions import (
20 AstroidSyntaxError,
21 AttributeInferenceError,
22 InferenceError,
23 UseInferenceDefault,
24)
25from astroid.inference_tip import inference_tip
26from astroid.manager import AstroidManager
27from astroid.nodes.node_classes import (
28 Assign,
29 AssignName,
30 Attribute,
31 Call,
32 Const,
33 JoinedStr,
34 Name,
35 NodeNG,
36 Subscript,
37)
38from astroid.nodes.scoped_nodes import ClassDef, FunctionDef
39
# Bare callee names that make a Call a candidate for TypeVar/NewType inference
# (checked by looks_like_typing_typevar_or_newtype before any inference runs).
TYPING_TYPEVARS = {"TypeVar", "NewType"}
# Qualified names the *inferred* callee must have before a TypeVar/NewType call
# is replaced by a synthetic class.
TYPING_TYPEVARS_QUALIFIED: Final = {
    "typing.TypeVar",
    "typing.NewType",
    "typing_extensions.TypeVar",
}
# Qualified names of TypedDict definitions that are swapped for a ClassDef.
TYPING_TYPEDDICT_QUALIFIED: Final = {"typing.TypedDict", "typing_extensions.TypedDict"}
# Source for a synthetic class used as an inference result. ``{0}`` is filled
# in with ``str.format`` to give the class its name. The ``Meta`` metaclass makes
# ``Cls[...]`` evaluate to the class itself and gives it an empty ``__args__``.
TYPING_TYPE_TEMPLATE = """
class Meta(type):
    def __getitem__(self, item):
        return self

    @property
    def __args__(self):
        return ()

class {0}(metaclass=Meta):
    pass
"""
# Public names of the ``typing`` module of the interpreter running astroid
# (empty when ``typing.__all__`` is missing). This is used as a cheap, inference-free
# filter for Subscript nodes.
TYPING_MEMBERS = set(getattr(typing, "__all__", []))

# Qualified names of typing aliases. Subscripts of these are not handled by
# infer_typing_attr; they are left to the ``_alias`` call inference instead.
TYPING_ALIAS = frozenset(
    (
        "typing.Hashable",
        "typing.Awaitable",
        "typing.Coroutine",
        "typing.AsyncIterable",
        "typing.AsyncIterator",
        "typing.Iterable",
        "typing.Iterator",
        "typing.Reversible",
        "typing.Sized",
        "typing.Container",
        "typing.Collection",
        "typing.Callable",
        "typing.AbstractSet",
        "typing.MutableSet",
        "typing.Mapping",
        "typing.MutableMapping",
        "typing.Sequence",
        "typing.MutableSequence",
        "typing.ByteString",
        "typing.Tuple",
        "typing.List",
        "typing.Deque",
        "typing.Set",
        "typing.FrozenSet",
        "typing.MappingView",
        "typing.KeysView",
        "typing.ItemsView",
        "typing.ValuesView",
        "typing.ContextManager",
        "typing.AsyncContextManager",
        "typing.Dict",
        "typing.DefaultDict",
        "typing.OrderedDict",
        "typing.Counter",
        "typing.ChainMap",
        "typing.Generator",
        "typing.AsyncGenerator",
        "typing.Type",
        "typing.Pattern",
        "typing.Match",
    )
)
105
# Source for a ``__class_getitem__`` classmethod that returns the class
# unchanged. It is parsed and injected into class locals to make classes
# subscriptable for inference.
CLASS_GETITEM_TEMPLATE = """
@classmethod
def __class_getitem__(cls, item):
    return cls
"""
111
112
def looks_like_typing_typevar_or_newtype(node) -> bool:
    """Return True if *node* is a call whose callee is named TypeVar or NewType.

    Only the spelled name is inspected (``Name.name`` or ``Attribute.attrname``).
    Whether the callee really comes from ``typing`` is checked later, during
    inference.
    """
    callee = node.func
    if isinstance(callee, Name):
        callee_name = callee.name
    elif isinstance(callee, Attribute):
        callee_name = callee.attrname
    else:
        return False
    return callee_name in TYPING_TYPEVARS
120
121
def infer_typing_typevar_or_newtype(
    node: Call, context_itton: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """Infer a ``typing.TypeVar(...)`` or ``typing.NewType(...)`` call.

    The call is replaced by a synthetic class built from
    ``TYPING_TYPE_TEMPLATE``. The class name comes from the source text of the
    first positional argument, with surrounding single quotes stripped.

    :param node: the Call node being inferred.
    :param context_itton: inference context forwarded to sub-inferences.
    :raises UseInferenceDefault: if the callee cannot be inferred, is not one of
        ``TYPING_TYPEVARS_QUALIFIED``, the call has no positional arguments, or
        the name argument is an f-string.
    :raises InferenceError: if the derived name does not produce valid source.
    """
    try:
        callee = next(node.func.infer(context=context_itton))
    except (InferenceError, StopIteration) as exc:
        raise UseInferenceDefault from exc

    is_supported = callee.qname() in TYPING_TYPEVARS_QUALIFIED
    # A dynamically built name (f-string) cannot be resolved statically.
    if not is_supported or not node.args or isinstance(node.args[0], JoinedStr):
        raise UseInferenceDefault

    class_name = node.args[0].as_string().strip("'")
    try:
        synthetic = extract_node(TYPING_TYPE_TEMPLATE.format(class_name))
    except AstroidSyntaxError as exc:
        raise InferenceError from exc
    return synthetic.infer(context=context_itton)
145
146
def _looks_like_typing_subscript(node) -> bool:
    """Cheaply decide whether *node* could be a subscript of a typing member.

    Nested subscripts such as ``X[a][b]`` are unwrapped down to their innermost
    value. The check passes when that value is a ``Name`` or ``Attribute``
    whose name is in ``TYPING_MEMBERS``. No inference is performed.
    """
    current = node
    while isinstance(current, Subscript):
        current = current.value
    if isinstance(current, Name):
        return current.name in TYPING_MEMBERS
    if isinstance(current, Attribute):
        return current.attrname in TYPING_MEMBERS
    return False
156
157
def infer_typing_attr(
    node: Subscript, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """Infer a typing.X[...] subscript.

    The subscripted value is inferred first. Values whose qualified name does
    not start with ``typing.``, or that are listed in ``TYPING_ALIAS``, are left
    to the default inference. ``Generic``/``Annotated`` resolve to the original
    definition, patched with ``__class_getitem__`` when it is a class. Every
    other member resolves to a synthetic class built from ``TYPING_TYPE_TEMPLATE``.

    :param node: the Subscript node being inferred.
    :param ctx: inference context; only forwarded on the synthetic-class path.
    :raises UseInferenceDefault: if the value cannot be inferred or is not a
        handled ``typing`` member.
    """
    try:
        # NOTE: ``ctx`` is not forwarded when inferring the subscripted value.
        value = next(node.value.infer())  # type: ignore[union-attr] # value shouldn't be None for Subscript.
    except (InferenceError, StopIteration) as exc:
        raise UseInferenceDefault from exc

    if not value.qname().startswith("typing.") or value.qname() in TYPING_ALIAS:
        # If typing subscript belongs to an alias handle it separately.
        raise UseInferenceDefault

    if (
        PY313_PLUS
        and isinstance(value, FunctionDef)
        and value.qname() == "typing.Annotated"
    ):
        # typing.Annotated is a FunctionDef on 3.13+
        # Pin the inferred value on the node so later inferences reuse it.
        node._explicit_inference = lambda node, context: iter([value])
        return iter([value])

    # NOTE(review): "typing_extensions.Annotated" below looks unreachable, since
    # qualified names not starting with "typing." were rejected above — confirm.
    if isinstance(value, ClassDef) and value.qname() in {
        "typing.Generic",
        "typing.Annotated",
        "typing_extensions.Annotated",
    }:
        # typing.Generic and typing.Annotated (PY39) are subscriptable
        # through __class_getitem__. Since astroid can't easily
        # infer the native methods, replace them for an easy inference tip
        func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
        value.locals["__class_getitem__"] = [func_to_add]
        if (
            isinstance(node.parent, ClassDef)
            and node in node.parent.bases
            and getattr(node.parent, "__cache", None)
        ):
            # node.parent.slots is evaluated and cached before the inference tip
            # is first applied. Remove the last result to allow a recalculation of slots
            # This is a module-level function, so "__cache" is not name-mangled.
            cache = node.parent.__cache  # type: ignore[attr-defined] # Unrecognized getattr
            if cache.get(node.parent.slots) is not None:
                del cache[node.parent.slots]
        # Avoid re-instantiating this class every time it's seen
        node._explicit_inference = lambda node, context: iter([value])
        return iter([value])

    # Any other typing member becomes a fresh synthetic class named after the
    # last component of its qualified name.
    node = extract_node(TYPING_TYPE_TEMPLATE.format(value.qname().split(".")[-1]))
    return node.infer(context=ctx)
206
207
208def _looks_like_generic_class_pep695(node: ClassDef) -> bool:
209 """Check if class is using type parameter. Python 3.12+."""
210 return len(node.type_params) > 0
211
212
def infer_typing_generic_class_pep695(
    node: ClassDef, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """Make a PEP 695 generic class subscriptable for inference (Python 3.12+).

    A freshly parsed ``__class_getitem__`` (from ``CLASS_GETITEM_TEMPLATE``)
    overwrites the class's ``__class_getitem__`` locals entry. The class itself
    is then returned as the single inference result.
    """
    getitem_def = _extract_single_node(CLASS_GETITEM_TEMPLATE)
    node.locals["__class_getitem__"] = [getitem_def]
    inferred = [node]
    return iter(inferred)
220
221
def _looks_like_typedDict(  # pylint: disable=invalid-name
    node: FunctionDef | ClassDef,
) -> bool:
    """Return True if *node* is the ``TypedDict`` of typing or typing_extensions."""
    qualified_name = node.qname()
    return qualified_name in TYPING_TYPEDDICT_QUALIFIED
227
228
def infer_typedDict(  # pylint: disable=invalid-name
    node: FunctionDef, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """Stand in a ``TypedDict`` ClassDef, based on ``dict``, for the FunctionDef.

    The synthetic class copies the parent and source position of *node*. Its
    ``__call__`` local is bound to a parsed ``dict`` expression node.
    """
    position = {
        "lineno": node.lineno,
        "col_offset": node.col_offset,
        "end_lineno": node.end_lineno,
        "end_col_offset": node.end_col_offset,
    }
    typed_dict = ClassDef(name="TypedDict", parent=node.parent, **position)
    typed_dict.postinit(bases=[extract_node("dict")], body=[], decorators=None)
    typed_dict.locals["__call__"] = [_extract_single_node("dict")]
    return iter([typed_dict])
245
246
def _looks_like_typing_alias(node: Call) -> bool:
    """
    Return True if *node* is a two-argument call to typing's ``_alias`` helper.

    For example, as written in the typing module::

        MutableSet = _alias(collections.abc.MutableSet, T)

    The first argument must be a plain name or a dotted attribute.

    :param node: call node
    """
    callee = node.func
    if not isinstance(callee, Name):
        return False
    # TODO: remove _DeprecatedGenericAlias when Py3.14 min
    if callee.name not in {"_alias", "_DeprecatedGenericAlias"}:
        return False
    if len(node.args) != 2:
        return False
    # _alias function works also for builtins object such as list and dict
    return isinstance(node.args[0], (Attribute, Name))
267
268
def _forbid_class_getitem_access(node: ClassDef) -> None:
    """Make ``__class_getitem__`` lookups on *node* raise AttributeInferenceError.

    The patch is only applied when the attribute is currently reachable. In that
    case the node's ``getattr`` is replaced by a wrapper that rejects that one
    name and delegates every other lookup to the previous ``getattr``.
    """

    def guarded_getattr(origin_func, attr, *args, **kwargs):
        """Reject ``__class_getitem__``; forward anything else to origin_func."""
        if attr == "__class_getitem__":
            raise AttributeInferenceError("__class_getitem__ access is not allowed")
        return origin_func(attr, *args, **kwargs)

    try:
        node.getattr("__class_getitem__")
    except AttributeInferenceError:
        # Nothing to hide: the attribute is not reachable on this class.
        return

    # The attribute is reachable (e.g. inherited from the collections protocol
    # the alias points at), whereas the typing module considers it absent.
    node.getattr = partial(guarded_getattr, node.getattr)
291
292
def infer_typing_alias(
    node: Call, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """
    Infer a call to the ``_alias`` function.

    A ClassDef named after the assignment target is created. When the first
    argument infers to a ClassDef, it is used as the only base, which puts the
    aliased class in the mro to simulate ``_GenericAlias``. The new class is
    made subscriptable only when the second argument is a positive numeric
    constant. Otherwise ``__class_getitem__`` access is forbidden on it.

    :param node: call node
    :param ctx: inference context
    :raises UseInferenceDefault: if the call is not assigned to exactly one name.
    :raises InferenceError: if the first argument yields no inference result.

    # TODO: evaluate if still necessary when Py3.12 is minimum
    """
    if (
        not isinstance(node.parent, Assign)
        or len(node.parent.targets) != 1
        or not isinstance(node.parent.targets[0], AssignName)
    ):
        raise UseInferenceDefault
    try:
        res = next(node.args[0].infer(context=ctx))
    except StopIteration as e:
        raise InferenceError(node=node.args[0], context=ctx) from e

    assign_name = node.parent.targets[0]

    class_def = ClassDef(
        name=assign_name.name,
        lineno=assign_name.lineno,
        col_offset=assign_name.col_offset,
        parent=node.parent,
        end_lineno=assign_name.end_lineno,
        end_col_offset=assign_name.end_col_offset,
    )
    if isinstance(res, ClassDef):
        # Only add `res` as base if it's a `ClassDef`
        # This isn't the case for `typing.Pattern` and `typing.Match`
        class_def.postinit(bases=[res], body=[], decorators=None)

    maybe_type_var = node.args[1]
    type_var_value = maybe_type_var.value if isinstance(maybe_type_var, Const) else None
    # Only compare real numbers (bool/int/float). A non-numeric constant, e.g.
    # ``_alias(x, "T")`` or ``_alias(x, None)`` in user code that matches the
    # predicate, would otherwise raise TypeError and abort inference.
    is_subscriptable = isinstance(type_var_value, (int, float)) and type_var_value > 0
    if is_subscriptable:
        # If typing alias is subscriptable, add `__class_getitem__` to ClassDef
        func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
        class_def.locals["__class_getitem__"] = [func_to_add]
    else:
        # If not, make sure that `__class_getitem__` access is forbidden.
        # This is an issue in cases where the aliased class implements it,
        # but the typing alias isn't subscriptable. E.g., `typing.ByteString` for PY39+
        _forbid_class_getitem_access(class_def)

    # Avoid re-instantiating this class every time it's seen
    node._explicit_inference = lambda node, context: iter([class_def])
    return iter([class_def])
346
347
def _looks_like_special_alias(node: Call) -> bool:
    """Return True if call is for the Tuple or Callable special alias.

    Only the Python 3.9+ spellings from the typing module are recognized:

        Tuple = _TupleType(tuple, -1, inst=False, name='Tuple')
        Callable = _CallableType(collections.abc.Callable, 2)

    (Older versions used ``_VariadicGenericAlias``, which is not matched.)
    Calls without positional arguments, e.g. ``_TupleType()`` in user code,
    simply return False instead of raising IndexError.
    """
    callee = node.func
    if not isinstance(callee, Name) or not node.args:
        return False
    first_arg = node.args[0]
    if callee.name == "_TupleType":
        return isinstance(first_arg, Name) and first_arg.name == "tuple"
    if callee.name == "_CallableType":
        return (
            isinstance(first_arg, Attribute)
            and first_arg.as_string() == "collections.abc.Callable"
        )
    return False
372
373
def infer_special_alias(
    node: Call, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
    """Infer a ``Tuple``/``Callable`` special-alias call as a subscriptable class.

    The call must be the value of a single-name assignment. The synthetic class
    takes that name and position, uses the first inferred value of the first
    argument as its only base, and receives a ``__class_getitem__`` built from
    ``CLASS_GETITEM_TEMPLATE``.

    :raises UseInferenceDefault: if the call is not assigned to exactly one name.
    :raises InferenceError: if the first argument yields no inference result.
    """
    parent = node.parent
    if not isinstance(parent, Assign):
        raise UseInferenceDefault
    if len(parent.targets) != 1 or not isinstance(parent.targets[0], AssignName):
        raise UseInferenceDefault

    origin_arg = node.args[0]
    try:
        base = next(origin_arg.infer(context=ctx))
    except StopIteration as e:
        raise InferenceError(node=origin_arg, context=ctx) from e

    target = parent.targets[0]
    alias_class = ClassDef(
        name=target.name,
        parent=parent,
        lineno=target.lineno,
        col_offset=target.col_offset,
        end_lineno=target.end_lineno,
        end_col_offset=target.end_col_offset,
    )
    alias_class.postinit(bases=[base], body=[], decorators=None)
    getitem_def = _extract_single_node(CLASS_GETITEM_TEMPLATE)
    alias_class.locals["__class_getitem__"] = [getitem_def]

    # Store the result on the call node so the class is not rebuilt on every visit.
    node._explicit_inference = lambda node, context: iter([alias_class])
    return iter([alias_class])
404
405
def _looks_like_typing_cast(node: Call) -> bool:
    """Return True if the callee is spelled ``cast``, i.e. ``cast(...)`` or ``x.cast(...)``."""
    callee = node.func
    if isinstance(callee, Attribute):
        return callee.attrname == "cast"
    return isinstance(callee, Name) and callee.name == "cast"
410
411
def infer_typing_cast(
    node: Call, ctx: context.InferenceContext | None = None
) -> Iterator[NodeNG]:
    """Infer ``typing.cast(typ, val)`` as whatever *val* itself infers to.

    :raises UseInferenceDefault: if the callee is not a Name/Attribute, cannot be
        inferred, is not the ``typing.cast`` function, or the call does not have
        exactly two positional arguments.
    """
    if not isinstance(node.func, (Name, Attribute)):
        raise UseInferenceDefault

    try:
        callee = next(node.func.infer(context=ctx))
    except (InferenceError, StopIteration) as exc:
        raise UseInferenceDefault from exc

    is_typing_cast = isinstance(callee, FunctionDef) and callee.qname() == "typing.cast"
    if not is_typing_cast or len(node.args) != 2:
        raise UseInferenceDefault

    _target_type, casted_value = node.args
    return casted_value.infer(context=ctx)
431
432
def _typing_transform():
    """Build the stub module registered as the ``typing`` extender (Python 3.12+).

    The stubs give astroid plain-Python definitions of ``Generic``, ``ParamSpec``
    (whose ``args``/``kwargs`` properties return ``ParamSpecArgs``/
    ``ParamSpecKwargs`` instances), ``TypeAlias`` and ``TypeVarTuple``. Several
    classes are made subscriptable through a ``__class_getitem__`` that returns
    the class itself.
    """
    builder = AstroidBuilder(AstroidManager())
    stub_source = textwrap.dedent(
        """
    class Generic:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class ParamSpec:
        @property
        def args(self):
            return ParamSpecArgs(self)
        @property
        def kwargs(self):
            return ParamSpecKwargs(self)
    class ParamSpecArgs: ...
    class ParamSpecKwargs: ...
    class TypeAlias: ...
    class Type:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class TypeVar:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class TypeVarTuple: ...
    class ContextManager:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class AsyncContextManager:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class Pattern:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class Match:
        @classmethod
        def __class_getitem__(cls, item): return cls
    """
    )
    return builder.string_build(stub_source)
472
473
def register(manager: AstroidManager) -> None:
    """Register the typing inference tips on *manager*.

    On Python 3.12+ this also registers the ``typing`` module extender and the
    PEP 695 generic-class tip.
    """
    # (node class, inference function, predicate), listed in registration order.
    tip_table = (
        (Call, infer_typing_typevar_or_newtype, looks_like_typing_typevar_or_newtype),
        (Subscript, infer_typing_attr, _looks_like_typing_subscript),
        (Call, infer_typing_cast, _looks_like_typing_cast),
        (FunctionDef, infer_typedDict, _looks_like_typedDict),
        (Call, infer_typing_alias, _looks_like_typing_alias),
        (Call, infer_special_alias, _looks_like_special_alias),
    )
    for node_class, infer_func, predicate in tip_table:
        manager.register_transform(node_class, inference_tip(infer_func), predicate)

    if not PY312_PLUS:
        return
    register_module_extender(manager, "typing", _typing_transform)
    manager.register_transform(
        ClassDef,
        inference_tip(infer_typing_generic_class_pep695),
        _looks_like_generic_class_pep695,
    )