1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5"""Astroid hooks for understanding functools library module."""
6
7from __future__ import annotations
8
9from collections.abc import Iterator
10from functools import partial
11from itertools import chain
12
13from astroid import BoundMethod, arguments, nodes, objects
14from astroid.builder import extract_node
15from astroid.context import InferenceContext
16from astroid.exceptions import InferenceError, UseInferenceDefault
17from astroid.inference_tip import inference_tip
18from astroid.interpreter import objectmodel
19from astroid.manager import AstroidManager
20from astroid.typing import InferenceResult, SuccessfulInferenceResult
21from astroid.util import UninferableBase, safe_infer
22
# Qualified name of the ``lru_cache`` decorator.
# NOTE(review): no code visible in this module references this constant --
# confirm whether anything outside the module imports it before removing it.
LRU_CACHE = "functools.lru_cache"
24
25
class LruWrappedModel(objectmodel.FunctionModel):
    """Special attribute model for functions decorated with functools.lru_cache.

    The decorator attaches a few extra attributes to the decorated function
    at decoration time. This model provides the ``attr___wrapped__``,
    ``attr_cache_info`` and ``attr_cache_clear`` properties for them.
    """

    @property
    def attr___wrapped__(self):
        """Return the instance this model is bound to."""
        return self._instance

    @property
    def attr_cache_info(self):
        """Return a bound method whose call result is a ``_CacheInfo`` instance."""
        # Node for a ``_CacheInfo(0, 0, 0, 0)`` call; it is inferred lazily
        # each time the bound method's call result is requested.
        cache_info_call = extract_node(
            """
        from functools import _CacheInfo
        _CacheInfo(0, 0, 0, 0)
        """
        )

        class CacheInfoBoundMethod(BoundMethod):
            def infer_call_result(
                self,
                caller: SuccessfulInferenceResult | None,
                context: InferenceContext | None = None,
            ) -> Iterator[InferenceResult]:
                # ``caller`` and ``context`` are ignored: the result is always
                # whatever the extracted ``_CacheInfo`` call infers to.
                inferred = safe_infer(cache_info_call)
                assert inferred is not None
                yield inferred

        decorated = self._instance
        return CacheInfoBoundMethod(proxy=decorated, bound=decorated)

    @property
    def attr_cache_clear(self):
        """Return a bound method wrapping a synthetic ``cache_clear`` function."""
        cache_clear_def = extract_node("""def cache_clear(self): pass""")
        # The method is bound to the scope that encloses the decorated function.
        enclosing_scope = self._instance.parent.scope()
        return BoundMethod(proxy=cache_clear_def, bound=enclosing_scope)
60
61
def _transform_lru_cache(node, context: InferenceContext | None = None) -> None:
    """Install the lru_cache attribute model on ``node``, in place.

    ``context`` is accepted but not used. Nothing is returned; the only
    effect is the reassignment of ``node.special_attributes``.
    """
    # TODO: mutating the node is not ideal, since nodes are supposed to be
    # immutable, but because of https://github.com/pylint-dev/astroid/issues/354
    # there is not much else we can do for now.
    # Replacing the node would only partially work: in pylint the old node
    # would still be reachable, which leads to spurious false positives.
    lru_model = LruWrappedModel()
    node.special_attributes = lru_model(node)
70
71
def _functools_partial_inference(
    node: nodes.Call, context: InferenceContext | None = None
) -> Iterator[objects.PartialFunction]:
    """Infer a ``functools.partial(...)`` call as a ``PartialFunction``.

    :param node: The call node to infer.
    :param context: Optional inference context, forwarded to the call site
        construction and to the inference of the wrapped callable.
    :returns: An iterator over exactly one ``PartialFunction``. It is built
        from the call site and reuses the wrapped function's name, position,
        arguments, body, decorators, return annotation, type comments and
        doc node, with ``node.parent`` as its parent.
    :raises UseInferenceDefault: If the call has no positional argument, if
        it only passes the callable without any keyword, if the first
        argument cannot be inferred or is not inferred as a function
        definition, or if a passed keyword does not name one of the wrapped
        function's regular, positional-only or keyword-only parameters.
    """
    call_site = arguments.CallSite.from_call(node, context=context)
    positional_count = len(call_site.positional_arguments)
    if positional_count == 0:
        raise UseInferenceDefault("functools.partial takes at least one argument")
    if positional_count == 1 and not call_site.keyword_arguments:
        raise UseInferenceDefault(
            "functools.partial needs at least to have some filled arguments"
        )

    # The first positional argument is the callable being wrapped; only its
    # first inferred value is considered.
    wrapped_expr = call_site.positional_arguments[0]
    try:
        wrapped = next(wrapped_expr.infer(context=context))
    except (InferenceError, StopIteration) as exc:
        raise UseInferenceDefault from exc
    if isinstance(wrapped, UninferableBase):
        raise UseInferenceDefault("Cannot infer the wrapped function")
    if not isinstance(wrapped, nodes.FunctionDef):
        raise UseInferenceDefault("The wrapped function is not a function")

    # Determine if the passed keywords into the callsite are supported
    # by the wrapped function.
    wrapped_args = wrapped.args
    if wrapped_args:
        declared_params = chain(
            wrapped_args.args or (),
            wrapped_args.posonlyargs or (),
            wrapped_args.kwonlyargs or (),
        )
    else:
        declared_params = []
    known_names = {
        param.name for param in declared_params if isinstance(param, nodes.AssignName)
    }
    if any(keyword not in known_names for keyword in call_site.keyword_arguments):
        raise UseInferenceDefault("wrapped function received unknown parameters")

    inferred_partial = objects.PartialFunction(
        call_site,
        name=wrapped.name,
        lineno=wrapped.lineno,
        col_offset=wrapped.col_offset,
        end_lineno=wrapped.end_lineno,
        end_col_offset=wrapped.end_col_offset,
        parent=node.parent,
    )
    inferred_partial.postinit(
        args=wrapped_args,
        body=wrapped.body,
        decorators=wrapped.decorators,
        returns=wrapped.returns,
        type_comment_returns=wrapped.type_comment_returns,
        type_comment_args=wrapped.type_comment_args,
        doc_node=wrapped.doc_node,
    )
    # Returned as an eager iterator (not a generator) so that the checks
    # above raise at call time rather than on first iteration.
    return iter((inferred_partial,))
131
132
def _looks_like_lru_cache(node) -> bool:
    """Tell whether any decorator of ``node`` looks like ``lru_cache``.

    Only decorators that are attribute accesses or calls are examined;
    every other decorator form is ignored. A node without decorators
    never matches.
    """
    decorators = node.decorators
    if not decorators:
        return False
    return any(
        isinstance(decorator, (nodes.Attribute, nodes.Call))
        and _looks_like_functools_member(decorator, "lru_cache")
        for decorator in decorators.nodes
    )
143
144
def _looks_like_functools_member(
    node: nodes.Attribute | nodes.Call, member: str
) -> bool:
    """Check whether ``node`` syntactically refers to ``member`` of functools.

    The check is purely name based, no inference is performed:

    * an attribute access matches when its attribute name is ``member``;
    * a call of a bare name matches when that name is ``member``;
    * a call of an attribute matches when the attribute name is ``member``
      and it is accessed on a name spelled exactly ``functools``.

    Any other kind of call does not match.
    """
    if isinstance(node, nodes.Attribute):
        return node.attrname == member

    callee = node.func
    if isinstance(callee, nodes.Name):
        return callee.name == member
    if not isinstance(callee, nodes.Attribute):
        return False

    owner = callee.expr
    return (
        callee.attrname == member
        and isinstance(owner, nodes.Name)
        and owner.name == "functools"
    )
160
161
# Predicate matching nodes that look like ``functools.partial``. ``register``
# passes it to ``register_transform`` as the filter for the partial
# inference tip.
_looks_like_partial = partial(_looks_like_functools_member, member="partial")
163
164
def register(manager: AstroidManager) -> None:
    """Register the functools transforms of this module on ``manager``.

    Two transforms are registered:

    * function definitions that look decorated with ``lru_cache`` receive
      the lru_cache attribute model;
    * calls that look like ``functools.partial`` get an inference tip that
      infers them as partial functions.
    """
    manager.register_transform(
        nodes.FunctionDef, _transform_lru_cache, _looks_like_lru_cache
    )

    partial_tip = inference_tip(_functools_partial_inference)
    manager.register_transform(nodes.Call, partial_tip, _looks_like_partial)