1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5"""Astroid hooks for typing.py support."""
6
7from __future__ import annotations
8
9import textwrap
10import typing
11from collections.abc import Iterator
12from functools import partial
13from typing import Final
14
15from astroid import context, nodes
16from astroid.brain.helpers import register_module_extender
17from astroid.builder import AstroidBuilder, _extract_single_node, extract_node
18from astroid.const import PY312_PLUS, PY313_PLUS, PY314_PLUS
19from astroid.exceptions import (
20 AstroidSyntaxError,
21 AttributeInferenceError,
22 InferenceError,
23 UseInferenceDefault,
24)
25from astroid.inference_tip import inference_tip
26from astroid.manager import AstroidManager
27
# Callee names (bare ``TypeVar(...)`` or attribute ``typing.TypeVar(...)``) that
# make a Call a candidate for the TypeVar/NewType inference tip.
TYPING_TYPEVARS = {"TypeVar", "NewType"}
# Qualified names the *inferred* callee must have for that tip to apply.
TYPING_TYPEVARS_QUALIFIED: Final = {
    "typing.TypeVar",
    "typing.NewType",
    "typing_extensions.TypeVar",
}
# Qualified names of the TypedDict definitions that are replaced by a ClassDef.
TYPING_TYPEDDICT_QUALIFIED: Final = {"typing.TypedDict", "typing_extensions.TypedDict"}
# Source of a synthetic class; ``{0}`` is filled with the class name.  The
# metaclass makes ``Cls[...]`` evaluate to ``Cls`` itself and gives the class an
# empty ``__args__``.
TYPING_TYPE_TEMPLATE = """
class Meta(type):
    def __getitem__(self, item):
        return self

    @property
    def __args__(self):
        return ()

class {0}(metaclass=Meta):
    pass
"""
# Public names of the running interpreter's ``typing`` module (empty if it has
# no ``__all__``); used as a cheap syntactic filter for subscripts.
TYPING_MEMBERS = set(getattr(typing, "__all__", []))
48
# Qualified names of typing aliases.  ``infer_typing_attr`` raises
# UseInferenceDefault for these so that they are handled separately (by the
# ``_alias`` call inference tips) instead of by the generic subscript template.
TYPING_ALIAS = frozenset(
    (
        "typing.Hashable",
        "typing.Awaitable",
        "typing.Coroutine",
        "typing.AsyncIterable",
        "typing.AsyncIterator",
        "typing.Iterable",
        "typing.Iterator",
        "typing.Reversible",
        "typing.Sized",
        "typing.Container",
        "typing.Collection",
        "typing.Callable",
        "typing.AbstractSet",
        "typing.MutableSet",
        "typing.Mapping",
        "typing.MutableMapping",
        "typing.Sequence",
        "typing.MutableSequence",
        "typing.ByteString",  # scheduled for removal in 3.17
        "typing.Tuple",
        "typing.List",
        "typing.Deque",
        "typing.Set",
        "typing.FrozenSet",
        "typing.MappingView",
        "typing.KeysView",
        "typing.ItemsView",
        "typing.ValuesView",
        "typing.ContextManager",
        "typing.AsyncContextManager",
        "typing.Dict",
        "typing.DefaultDict",
        "typing.OrderedDict",
        "typing.Counter",
        "typing.ChainMap",
        "typing.Generator",
        "typing.AsyncGenerator",
        "typing.Type",
        "typing.Pattern",
        "typing.Match",
    )
)

# Source of a ``__class_getitem__`` classmethod that returns the class unchanged.
# It is parsed and inserted into a class's locals to make that class subscriptable.
CLASS_GETITEM_TEMPLATE = """
@classmethod
def __class_getitem__(cls, item):
    return cls
"""
99
100
def looks_like_typing_typevar_or_newtype(node) -> bool:
    """Cheaply tell whether a Call *might* be ``TypeVar(...)`` or ``NewType(...)``.

    Only the syntactic callee name is examined, whether written bare
    (``TypeVar(...)``) or as an attribute (``typing.TypeVar(...)``).  The
    inference tip confirms the real target afterwards.
    """
    callee = node.func
    if isinstance(callee, nodes.Name):
        return callee.name in TYPING_TYPEVARS
    if isinstance(callee, nodes.Attribute):
        return callee.attrname in TYPING_TYPEVARS
    return False
108
109
def infer_typing_typevar_or_newtype(
    node: nodes.Call, context_itton: context.InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
    """Infer a ``typing.TypeVar(...)`` or ``typing.NewType(...)`` call.

    The call is replaced by a synthetic class built from
    ``TYPING_TYPE_TEMPLATE``, named after the source text of the first
    positional argument with surrounding single quotes stripped.

    :param node: the call node.
    :param context_itton: inference context, forwarded to every inference.
    :raises UseInferenceDefault: if the callee cannot be inferred, is not a
        supported TypeVar/NewType qualified name, there is no positional
        argument, or the first argument is an f-string.
    :raises InferenceError: if the derived name does not produce valid source.
    """
    try:
        callee = next(node.func.infer(context=context_itton))
    except (InferenceError, StopIteration) as exc:
        raise UseInferenceDefault from exc

    is_supported = callee.qname() in TYPING_TYPEVARS_QUALIFIED
    if not is_supported or not node.args:
        raise UseInferenceDefault

    name_arg = node.args[0]
    # A dynamically built name (f-string) has no static text to use.
    if isinstance(name_arg, nodes.JoinedStr):
        raise UseInferenceDefault

    class_name = name_arg.as_string().strip("'")
    try:
        synthetic = extract_node(TYPING_TYPE_TEMPLATE.format(class_name))
    except AstroidSyntaxError as exc:
        raise InferenceError from exc
    return synthetic.infer(context=context_itton)
133
134
def _looks_like_typing_subscript(node) -> bool:
    """Cheaply tell whether a Subscript *might* be typing-related.

    Nested subscripts (``X[a][b]``) are unwrapped down to the innermost
    subscripted value.  That value matches when it is a Name or an Attribute
    whose (attribute) name is one of ``TYPING_MEMBERS``.
    """
    current = node
    while isinstance(current, nodes.Subscript):
        current = current.value
    if isinstance(current, nodes.Name):
        return current.name in TYPING_MEMBERS
    if isinstance(current, nodes.Attribute):
        return current.attrname in TYPING_MEMBERS
    return False
144
145
def infer_typing_attr(
    node: nodes.Subscript, ctx: context.InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
    """Infer a typing.X[...] subscript.

    :param node: the subscript node, e.g. ``typing.List[int]``.
    :param ctx: inference context; only forwarded to the inference of the
        synthetic template class at the end.
    :returns: an iterator over one inferred node: the subscripted object
        itself for ``Generic``/``Annotated``, otherwise a synthetic
        subscriptable class named after the typing member.
    :raises UseInferenceDefault: if the subscripted value cannot be inferred,
        its qualified name does not start with ``"typing."``, or it is one of
        ``TYPING_ALIAS``.
    """
    try:
        # NOTE(review): ``ctx`` is not forwarded here, unlike the final
        # ``node.infer(context=ctx)`` below -- confirm this is intentional.
        value = next(node.value.infer())  # type: ignore[union-attr] # value shouldn't be None for Subscript.
    except (InferenceError, StopIteration) as exc:
        raise UseInferenceDefault from exc

    if not value.qname().startswith("typing.") or value.qname() in TYPING_ALIAS:
        # If typing subscript belongs to an alias handle it separately.
        raise UseInferenceDefault

    if (
        PY313_PLUS
        and isinstance(value, nodes.FunctionDef)
        and value.qname() == "typing.Annotated"
    ):
        # typing.Annotated is a FunctionDef on 3.13+
        # Pin the result on the node so later inferences skip this tip.
        node._explicit_inference = lambda node, context: iter([value])
        return iter([value])

    # NOTE(review): "typing_extensions.Annotated" can never match here, because
    # the guard above already rejected every qname not starting with "typing.".
    if isinstance(value, nodes.ClassDef) and value.qname() in {
        "typing.Generic",
        "typing.Annotated",
        "typing_extensions.Annotated",
    }:
        # typing.Generic and typing.Annotated (PY39) are subscriptable
        # through __class_getitem__. Since astroid can't easily
        # infer the native methods, replace them for an easy inference tip
        func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
        value.locals["__class_getitem__"] = [func_to_add]
        if (
            isinstance(node.parent, nodes.ClassDef)
            and node in node.parent.bases
            and getattr(node.parent, "__cache", None)
        ):
            # node.parent.slots is evaluated and cached before the inference tip
            # is first applied. Remove the last result to allow a recalculation of slots
            # Presumably the cache is keyed by the ``slots`` method object --
            # confirm against astroid's caching decorator.
            cache = node.parent.__cache  # type: ignore[attr-defined] # Unrecognized getattr
            if cache.get(node.parent.slots) is not None:
                del cache[node.parent.slots]
        # Avoid re-instantiating this class every time it's seen
        node._explicit_inference = lambda node, context: iter([value])
        return iter([value])

    # Any other typing member: synthesize a subscriptable class named after
    # the last dotted component of its qualified name.
    node = extract_node(TYPING_TYPE_TEMPLATE.format(value.qname().split(".")[-1]))
    return node.infer(context=ctx)
194
195
196def _looks_like_generic_class_pep695(node: nodes.ClassDef) -> bool:
197 """Check if class is using type parameter. Python 3.12+."""
198 return len(node.type_params) > 0
199
200
def infer_typing_generic_class_pep695(
    node: nodes.ClassDef, ctx: context.InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
    """Make a PEP 695 generic class subscriptable (Python 3.12+).

    A parsed ``__class_getitem__`` classmethod returning the class itself is
    installed in the class locals.  The class is then yielded as its own
    inference result.
    """
    getitem_method = _extract_single_node(CLASS_GETITEM_TEMPLATE)
    node.locals["__class_getitem__"] = [getitem_method]
    return iter((node,))
208
209
def _looks_like_typedDict(  # pylint: disable=invalid-name
    node: nodes.FunctionDef | nodes.ClassDef,
) -> bool:
    """Return True if *node* is the TypedDict definition of typing or typing_extensions."""
    qualified_name = node.qname()
    return qualified_name in TYPING_TYPEDDICT_QUALIFIED
215
216
def infer_typedDict(  # pylint: disable=invalid-name
    node: nodes.FunctionDef, ctx: context.InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
    """Substitute a synthetic ``TypedDict`` ClassDef for the TypedDict FunctionDef.

    The replacement has these properties:

    * It takes the function's parent and source position.
    * It has a ``dict`` name node as its only base.
    * It has a ``__call__`` local set to a parsed ``dict`` name node.
    """
    replacement = nodes.ClassDef(
        name="TypedDict",
        parent=node.parent,
        lineno=node.lineno,
        col_offset=node.col_offset,
        end_lineno=node.end_lineno,
        end_col_offset=node.end_col_offset,
    )
    dict_base = extract_node("dict")
    replacement.postinit(bases=[dict_base], body=[], decorators=None)
    replacement.locals["__call__"] = [_extract_single_node("dict")]
    return iter([replacement])
233
234
def _looks_like_typing_alias(node: nodes.Call) -> bool:
    """
    Tell whether *node* is a call to typing's private ``_alias`` helper.

    A matching assignment looks like::

        MutableSet = _alias(collections.abc.MutableSet, T)

    :param node: call node
    """
    callee = node.func
    if not isinstance(callee, nodes.Name):
        return False
    # TODO: remove _DeprecatedGenericAlias when Py3.14 min
    if callee.name not in {"_alias", "_DeprecatedGenericAlias"}:
        return False
    if len(node.args) != 2:
        return False
    # _alias also accepts builtins such as list and dict, hence Name as well
    # as a dotted Attribute path.
    return isinstance(node.args[0], (nodes.Attribute, nodes.Name))
255
256
def _forbid_class_getitem_access(node: nodes.ClassDef) -> None:
    """Make ``node.getattr("__class_getitem__")`` raise AttributeInferenceError.

    The patch is only installed when the class currently resolves
    ``__class_getitem__``.  Every other attribute lookup is delegated to the
    class's original ``getattr``.
    """

    def guarded_getattr(original_getattr, attr, *args, **kwargs):
        """Refuse ``__class_getitem__``; forward any other lookup unchanged."""
        if attr != "__class_getitem__":
            return original_getattr(attr, *args, **kwargs)
        raise AttributeInferenceError("__class_getitem__ access is not allowed")

    try:
        node.getattr("__class_getitem__")
    except AttributeInferenceError:
        # The attribute is not reachable, so there is nothing to hide.
        return
    # The lookup succeeded.  The method originates from the protocol defined in
    # the collections module, whereas the typing module considers the alias
    # not subscriptable.  Shadow ``getattr`` on this instance to hide it.
    node.getattr = partial(guarded_getattr, node.getattr)
279
280
def infer_typing_alias(
    node: nodes.Call, ctx: context.InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
    """
    Infers the call to _alias function
    Insert ClassDef, with same name as aliased class,
    in mro to simulate _GenericAlias.

    The synthetic class is built as follows:

    * It is named after the single assignment target.
    * It takes the inferred first argument as its only base when that is a
      ClassDef.
    * It is made subscriptable only when the second argument is a positive
      numeric constant.
    * In every other case, ``__class_getitem__`` access is forbidden on it.

    :param node: call node
    :param ctx: inference context
    :returns: an iterator yielding the synthetic ClassDef.
    :raises UseInferenceDefault: if the call is not the value of an assignment
        to exactly one plain name.
    :raises InferenceError: if the first argument yields no inference result
        (errors raised by the inference itself propagate unchanged).

    # TODO: evaluate if still necessary when Py3.12 is minimum
    """
    if not (
        isinstance(node.parent, nodes.Assign)
        and len(node.parent.targets) == 1
        and isinstance(node.parent.targets[0], nodes.AssignName)
    ):
        raise UseInferenceDefault
    try:
        res = next(node.args[0].infer(context=ctx))
    except StopIteration as e:
        raise InferenceError(node=node.args[0], context=ctx) from e

    assign_name = node.parent.targets[0]

    class_def = nodes.ClassDef(
        name=assign_name.name,
        lineno=assign_name.lineno,
        col_offset=assign_name.col_offset,
        parent=node.parent,
        end_lineno=assign_name.end_lineno,
        end_col_offset=assign_name.end_col_offset,
    )
    if isinstance(res, nodes.ClassDef):
        # Only add `res` as base if it's a `ClassDef`
        # This isn't the case for `typing.Pattern` and `typing.Match`
        class_def.postinit(bases=[res], body=[], decorators=None)

    maybe_type_var = node.args[1]
    if (
        isinstance(maybe_type_var, nodes.Const)
        # The predicate only checks the call's shape, so user code calling its
        # own ``_alias(x, "spam")`` or ``_alias(x, None)`` lands here too;
        # ordering a non-number against 0 would raise TypeError.
        and isinstance(maybe_type_var.value, (int, float))
        and maybe_type_var.value > 0
    ):
        # If typing alias is subscriptable, add `__class_getitem__` to ClassDef
        func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
        class_def.locals["__class_getitem__"] = [func_to_add]
    else:
        # If not, make sure that `__class_getitem__` access is forbidden.
        # This is an issue in cases where the aliased class implements it,
        # but the typing alias isn't subscriptable. E.g., `typing.ByteString` for PY39+
        _forbid_class_getitem_access(class_def)

    # Avoid re-instantiating this class every time it's seen
    node._explicit_inference = lambda node, context: iter([class_def])
    return iter([class_def])
334
335
def _looks_like_special_alias(node: nodes.Call) -> bool:
    """Return True if call is for Tuple or Callable alias.

    Since Python 3.9 the typing module defines these aliases as::

        Tuple = _TupleType(tuple, -1, inst=False, name='Tuple')
        Callable = _CallableType(collections.abc.Callable, 2)

    Only these two call shapes are matched.  The older
    ``_VariadicGenericAlias`` form is not.

    :param node: call node
    :returns: a real ``bool``.  The previous chained expression returned the
        empty ``node.args`` list instead of False for argument-less calls.
    """
    callee = node.func
    if not isinstance(callee, nodes.Name) or not node.args:
        return False
    first_arg = node.args[0]
    if callee.name == "_TupleType":
        return isinstance(first_arg, nodes.Name) and first_arg.name == "tuple"
    if callee.name == "_CallableType":
        return (
            isinstance(first_arg, nodes.Attribute)
            and first_arg.as_string() == "collections.abc.Callable"
        )
    return False
364
365
def infer_special_alias(
    node: nodes.Call, ctx: context.InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
    """Infer a ``_TupleType``/``_CallableType`` call as a new subscriptable class.

    The synthetic class has these properties:

    * It is named after the single assignment target (e.g. ``Tuple`` or
      ``Callable``).
    * Its only base is the inferred first argument.
    * It has a ``__class_getitem__`` returning the class itself.

    :param node: call node
    :param ctx: inference context
    :raises UseInferenceDefault: if the call is not the value of an assignment
        to exactly one plain name.
    :raises InferenceError: if the first argument yields no inference result.
    """
    parent = node.parent
    is_simple_assignment = (
        isinstance(parent, nodes.Assign)
        and len(parent.targets) == 1
        and isinstance(parent.targets[0], nodes.AssignName)
    )
    if not is_simple_assignment:
        raise UseInferenceDefault

    aliased_arg = node.args[0]
    try:
        aliased = next(aliased_arg.infer(context=ctx))
    except StopIteration as e:
        raise InferenceError(node=aliased_arg, context=ctx) from e

    target = parent.targets[0]
    class_def = nodes.ClassDef(
        name=target.name,
        parent=parent,
        lineno=target.lineno,
        col_offset=target.col_offset,
        end_lineno=target.end_lineno,
        end_col_offset=target.end_col_offset,
    )
    class_def.postinit(bases=[aliased], body=[], decorators=None)
    getitem_method = _extract_single_node(CLASS_GETITEM_TEMPLATE)
    class_def.locals["__class_getitem__"] = [getitem_method]
    # Pin the result so the class is not rebuilt on every inference.
    node._explicit_inference = lambda node, context: iter([class_def])
    return iter([class_def])
396
397
def _looks_like_typing_cast(node: nodes.Call) -> bool:
    """Return True if the callee is syntactically named ``cast`` (bare or as an attribute)."""
    callee = node.func
    if isinstance(callee, nodes.Attribute):
        return callee.attrname == "cast"
    if isinstance(callee, nodes.Name):
        return callee.name == "cast"
    return False
402
403
def infer_typing_cast(
    node: nodes.Call, ctx: context.InferenceContext | None = None
) -> Iterator[nodes.NodeNG]:
    """Infer a ``typing.cast(T, value)`` call as whatever *value* infers to.

    :param node: call node
    :param ctx: inference context, forwarded to both inferences.
    :raises UseInferenceDefault: if the callee is neither a Name nor an
        Attribute, cannot be inferred, is not the ``typing.cast`` function,
        or the call does not have exactly two positional arguments.
    """
    callee = node.func
    if not isinstance(callee, (nodes.Name, nodes.Attribute)):
        raise UseInferenceDefault

    try:
        target = next(callee.infer(context=ctx))
    except (InferenceError, StopIteration) as exc:
        raise UseInferenceDefault from exc

    is_typing_cast = (
        isinstance(target, nodes.FunctionDef) and target.qname() == "typing.cast"
    )
    if not is_typing_cast or len(node.args) != 2:
        raise UseInferenceDefault

    _, casted_value = node.args
    return casted_value.infer(context=ctx)
423
424
def _typing_transform():
    """Build an astroid module of stub definitions for the ``typing`` module.

    The stubs cover the following members:

    * ``Generic``, ``Type``, ``TypeVar``, ``ContextManager``,
      ``AsyncContextManager``, ``Pattern`` and ``Match`` each get a
      ``__class_getitem__`` classmethod that returns the class itself.
    * ``ParamSpec`` gets ``args``/``kwargs`` properties returning
      ``ParamSpecArgs``/``ParamSpecKwargs`` instances.
    * ``TypeAlias`` and ``TypeVarTuple`` are empty classes.

    Presumably these definitions are merged into the real ``typing`` module by
    ``register_module_extender`` -- the merge semantics live there, confirm.

    :returns: the module built from the stub source.
    """
    code = textwrap.dedent(
        """
    class Generic:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class ParamSpec:
        @property
        def args(self):
            return ParamSpecArgs(self)
        @property
        def kwargs(self):
            return ParamSpecKwargs(self)
    class ParamSpecArgs: ...
    class ParamSpecKwargs: ...
    class TypeAlias: ...
    class Type:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class TypeVar:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class TypeVarTuple: ...
    class ContextManager:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class AsyncContextManager:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class Pattern:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class Match:
        @classmethod
        def __class_getitem__(cls, item): return cls
    """
    )
    if PY314_PLUS:
        # On 3.14+, additionally import ForwardRef from annotationlib into the
        # stub module and add a subscriptable ``Union`` stub.
        code += textwrap.dedent(
            """
        from annotationlib import ForwardRef
        class Union:
            @classmethod
            def __class_getitem__(cls, item): return cls
        """
        )
    return AstroidBuilder(AstroidManager()).string_build(code)
472
473
def register(manager: AstroidManager) -> None:
    """Register the typing-related transforms and inference tips on *manager*.

    The transforms are registered in a fixed order:

    1. TypeVar/NewType calls.
    2. typing subscripts.
    3. ``cast`` calls.
    4. TypedDict definitions.
    5. ``_alias`` calls.
    6. Tuple/Callable special aliases.

    On Python 3.12+, the ``typing`` module extender and the PEP 695 generic
    class tip are added afterwards.
    """
    tips = (
        (
            nodes.Call,
            infer_typing_typevar_or_newtype,
            looks_like_typing_typevar_or_newtype,
        ),
        (nodes.Subscript, infer_typing_attr, _looks_like_typing_subscript),
        (nodes.Call, infer_typing_cast, _looks_like_typing_cast),
        (nodes.FunctionDef, infer_typedDict, _looks_like_typedDict),
        (nodes.Call, infer_typing_alias, _looks_like_typing_alias),
        (nodes.Call, infer_special_alias, _looks_like_special_alias),
    )
    for node_class, infer_func, predicate in tips:
        manager.register_transform(node_class, inference_tip(infer_func), predicate)

    if not PY312_PLUS:
        return
    register_module_extender(manager, "typing", _typing_transform)
    manager.register_transform(
        nodes.ClassDef,
        inference_tip(infer_typing_generic_class_pep695),
        _looks_like_generic_class_pep695,
    )