1from typing import Any
2
3from jedi import debug
4from jedi.inference.cache import inference_state_method_cache, CachedMetaClass
5from jedi.inference import compiled
6from jedi.inference import recursion
7from jedi.inference import docstrings
8from jedi.inference import flow_analysis
9from jedi.inference.signature import TreeSignature
10from jedi.inference.filters import ParserTreeFilter, FunctionExecutionFilter, \
11 AnonymousFunctionExecutionFilter
12from jedi.inference.names import ValueName, AbstractNameDefinition, \
13 AnonymousParamName, ParamName, NameWrapper
14from jedi.inference.base_value import ContextualizedNode, NO_VALUES, \
15 ValueSet, TreeValue, ValueWrapper
16from jedi.inference.lazy_value import LazyKnownValues, LazyKnownValue, \
17 LazyTreeValue
18from jedi.inference.context import ValueContext, TreeContextMixin
19from jedi.inference.value import iterable
20from jedi import parser_utils
21from jedi.inference.parser_cache import get_yield_exprs
22from jedi.inference.helpers import values_from_qualified_names
23from jedi.inference.gradual.generics import TupleGenericManager
24
25
class LambdaName(AbstractNameDefinition):
    """Name object standing in for an anonymous ``lambda`` expression.

    A lambda has no identifier in the source, so this name always reports
    ``'<lambda>'``. Its position and parent context come from the wrapped
    lambda value.
    """
    string_name = '<lambda>'
    api_type = 'function'

    def __init__(self, lambda_value):
        self._lambda_value = lambda_value
        # The name belongs to the same context as the lambda itself.
        self.parent_context = lambda_value.parent_context

    @property
    def start_pos(self):
        """Start position of the lambda's syntax-tree node."""
        lambda_node = self._lambda_value.tree_node
        return lambda_node.start_pos

    def infer(self):
        """Inferring this name gives back exactly the lambda value."""
        only_value = self._lambda_value
        return ValueSet([only_value])
40
41
class FunctionAndClassBase(TreeValue):
    """Common base for tree-backed function and class values."""

    def get_qualified_names(self):
        """Return the tuple of names leading to this definition, or ``None``.

        Definitions at module level yield a one-element tuple. Definitions
        directly in a class extend the class's qualified names. Anything
        else (e.g. a definition nested in a function body) yields ``None``.
        """
        owner = self.parent_context
        if owner.is_class():
            owner_names = owner.get_qualified_names()
            # ``None`` here means the enclosing class itself lives within a
            # function, so there is no stable qualified path.
            return None if owner_names is None else owner_names + (self.py__name__(),)
        if owner.is_module():
            return (self.py__name__(),)
        return None
54
55
class FunctionMixin:
    """Behaviour shared by function-like values (functions, methods,
    lambdas, overload wrappers).

    Concrete classes must provide ``tree_node``, ``py__class__``,
    ``as_context`` and ``get_signature_functions``. The annotations below
    only declare those members for type checkers.
    """
    api_type = 'function'
    tree_node: Any
    py__class__: Any
    as_context: Any
    get_signature_functions: Any

    def get_filters(self, origin_scope=None):
        """Yield the attribute filters of the function object itself.

        They are taken from every instance produced by executing the class
        returned by ``py__class__()``.
        """
        function_class = self.py__class__()
        for function_object in function_class.execute_with_values():
            yield from function_object.get_filters(origin_scope=origin_scope)

    def py__get__(self, instance, class_value):
        """Descriptor access: bind to ``instance`` when there is one."""
        from jedi.inference.value.instance import BoundMethod
        if instance is not None:
            bound = BoundMethod(instance, class_value.as_context(), self)
            return ValueSet([bound])
        # Accessed through the class (``Foo.bar``): the plain function.
        return ValueSet([self])

    def get_param_names(self):
        """Return anonymous (argument-less) parameter names from the tree."""
        params = self.tree_node.get_params()
        return [AnonymousParamName(self, p.name) for p in params]

    @property
    def name(self):
        node = self.tree_node
        if node.type != 'lambdef':
            return ValueName(self, node.name)
        # Lambdas have no name leaf in the tree.
        return LambdaName(self)

    def is_function(self):
        return True

    def py__name__(self):
        return self.name.string_name

    def get_type_hint(self, add_class_info=True):
        """Return a readable signature string such as ``f(a: int) -> str``.

        If the ``def`` has an explicit return annotation, the parameter list
        and the annotation are copied from the source text. Otherwise the
        parameter and return hints are inferred from an argument-less
        execution. In that case ``-> ...`` is omitted when no return hint
        can be inferred.
        """
        annotation_node = self.tree_node.annotation
        if annotation_node is not None:
            return_hint = annotation_node.get_code(include_prefix=False)
            func_name = self.py__name__()
            # children[2] is the parenthesised parameter list of the ``def``.
            params_code = self.tree_node.children[2].get_code(include_prefix=False)
            return func_name + params_code + ' -> ' + return_hint

        def describe(param):
            text = param.string_name
            inferred = param.infer().get_type_hint()
            if inferred is not None:
                text += ': ' + inferred
            default = param.default_node
            if default is not None:
                text += '=' + default.get_code(include_prefix=False)
            return text

        execution = self.as_context()
        return_hint = execution.infer().get_type_hint()
        func_name = self.py__name__()
        params = ', '.join([describe(p) for p in execution.get_param_names()])
        signature = func_name + '(' + params + ')'
        if return_hint is None:
            return signature
        return signature + ' -> ' + return_hint

    def py__call__(self, arguments):
        """Infer the result of calling this function with ``arguments``."""
        return self.as_context(arguments).infer()

    def _as_context(self, arguments=None):
        # Factory for execution contexts (presumably invoked through the
        # inherited ``as_context``): anonymous without arguments, concrete
        # otherwise.
        if arguments is not None:
            return FunctionExecutionContext(self, arguments)
        return AnonymousFunctionExecution(self)

    def get_signatures(self):
        return list(map(TreeSignature, self.get_signature_functions()))
129
130
class FunctionValue(FunctionMixin, FunctionAndClassBase, metaclass=CachedMetaClass):
    """A function defined by ``def`` or ``lambda`` in parsed source."""

    @classmethod
    def from_context(cls, context, tree_node):
        """Create the value for ``tree_node`` as seen from ``context``.

        Functions inside a class body become ``MethodValue`` objects. When
        ``_find_overload_functions`` reports any ``@overload`` definitions
        for this node, an ``OverloadedFunctionValue`` wrapping them is
        returned instead of the bare function.
        """
        overload_nodes = list(_find_overload_functions(context, tree_node))

        # The parent context skips any enclosing classes and instances.
        outer = context
        while outer.is_class() or outer.is_instance():
            outer = outer.parent_context

        def build(node):
            if not context.is_class():
                return cls(
                    context.inference_state,
                    parent_context=outer,
                    tree_node=node,
                )
            return MethodValue(
                context.inference_state,
                context,
                parent_context=outer,
                tree_node=node,
            )

        function = build(tree_node)
        if not overload_nodes:
            return function

        overloads = [build(node) for node in overload_nodes]
        # Get them into the correct order: lower line first.
        overloads.reverse()
        return OverloadedFunctionValue(function, overloads)

    def py__class__(self):
        """Plain functions are instances of ``types.FunctionType``."""
        (function_type,) = values_from_qualified_names(
            self.inference_state, 'types', 'FunctionType'
        )
        return function_type

    def get_default_param_context(self):
        return self.parent_context

    def get_signature_functions(self):
        # A plain function has exactly one signature: its own.
        return [self]
174
175
class FunctionNameInClass(NameWrapper):
    """Wrap a method's name so that it remembers its defining class context."""

    def __init__(self, class_context, name):
        super().__init__(name)
        self._class_context = class_context

    def get_defining_qualified_value(self):
        """Return the value of the defining class context (may be ``None``)."""
        defining_context = self._class_context
        return defining_context.get_value()
183
184
class MethodValue(FunctionValue):
    """A function defined directly in a class body.

    Its ``parent_context`` is not the class, so the owning class context is
    kept separately in ``class_context``.
    """

    def __init__(self, inference_state, class_context, *args, **kwargs):
        super().__init__(inference_state, *args, **kwargs)
        self.class_context = class_context

    def get_default_param_context(self):
        # Methods use their class context for parameter defaults.
        return self.class_context

    def get_qualified_names(self):
        # Overridden because the inherited version inspects
        # ``parent_context``, which for a method is not the class and would
        # therefore miss the class name.
        class_names = self.class_context.get_qualified_names()
        if class_names is not None:
            return class_names + (self.py__name__(),)
        return None

    @property
    def name(self):
        plain_name = super().name
        return FunctionNameInClass(self.class_context, plain_name)
204
205
class BaseFunctionExecutionContext(ValueContext, TreeContextMixin):
    """Context representing one execution of a function body.

    Subclasses decide how parameters are bound (``get_param_names``), which
    filters apply (``get_filters``) and how return annotations are inferred
    (``infer_annotations``). This base class infers the returned values, the
    yielded values and the overall result of the call (``infer``).
    """
    def infer_annotations(self):
        """Return the values implied by the return annotation.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    @inference_state_method_cache(default=NO_VALUES)
    @recursion.execution_recursion_decorator()
    def get_return_values(self, check_yields=False):
        """Infer the values this function returns, or yields.

        :param check_yields: If true, collect the values of all ``yield``
            expressions instead of inspecting ``return`` statements.
        :returns: A ``ValueSet``. For lambdas this is the inferred body
            expression. Without ``check_yields``, annotation-derived values
            are returned directly when present. Otherwise docstring-derived
            return types are merged with the values of every ``return``
            statement not proven unreachable.
        """
        funcdef = self.tree_node
        if funcdef.type == 'lambdef':
            # A lambda's body expression is its last child.
            return self.infer_node(funcdef.children[-1])

        if check_yields:
            value_set = NO_VALUES
            returns = get_yield_exprs(self.inference_state, funcdef)
        else:
            value_set = self.infer_annotations()
            if value_set:
                # If there are annotations, prefer them over anything else.
                # This will make it faster.
                return value_set
            value_set |= docstrings.infer_return_types(self._value)
            returns = funcdef.iter_return_stmts()

        for r in returns:
            if check_yields:
                value_set |= ValueSet.from_sets(
                    lazy_value.infer()
                    for lazy_value in self._get_yield_lazy_value(r)
                )
            else:
                check = flow_analysis.reachability_check(self, funcdef, r)
                if check is flow_analysis.UNREACHABLE:
                    debug.dbg('Return unreachable: %s', r)
                else:
                    try:
                        children = r.children
                    except AttributeError:
                        # A node without children (presumably a bare
                        # ``return`` keyword) is treated as returning None.
                        ctx = compiled.builtin_from_name(self.inference_state, 'None')
                        value_set |= ValueSet([ctx])
                    else:
                        # children[1] is the expression after ``return``.
                        value_set |= self.infer_node(children[1])
                if check is flow_analysis.REACHABLE:
                    # A return that is definitely reached ends the scan;
                    # remaining return statements are not inspected.
                    debug.dbg('Return reachable: %s', r)
                    break
        return value_set

    def _get_yield_lazy_value(self, yield_expr):
        """Yield lazy values for what one ``yield`` expression produces.

        A bare ``yield`` produces None, ``yield from x`` produces the items
        obtained by iterating ``x``, and ``yield x`` produces ``x``.
        """
        if yield_expr.type == 'keyword':
            # `yield` just yields None.
            ctx = compiled.builtin_from_name(self.inference_state, 'None')
            yield LazyKnownValue(ctx)
            return

        node = yield_expr.children[1]
        if node.type == 'yield_arg':  # It must be a yield from.
            cn = ContextualizedNode(self, node.children[1])
            yield from cn.infer().iterate(cn)
        else:
            yield LazyTreeValue(self, node)

    @recursion.execution_recursion_decorator(default=iter([]))
    def get_yield_lazy_values(self, is_async=False):
        """Yield lazy values for everything this generator may yield.

        If every ``yield`` sits either directly in the function body or in a
        simple ``for`` loop (one loop variable) directly in the body, the
        values are produced in execution order, unrolling each loop over its
        iterable. Otherwise a single ``LazyKnownValues`` with all yielded
        values (``min=0``, ``max=inf``) is produced instead.

        ``is_async`` is currently unused.
        """
        # TODO: if is_async, wrap yield statements in Awaitable/async_generator_asend
        # Pair every yield with its nearest enclosing for/while/if statement
        # or function definition.
        for_parents = [(y, y.search_ancestor('for_stmt', 'funcdef',
                                             'while_stmt', 'if_stmt'))
                       for y in get_yield_exprs(self.inference_state, self.tree_node)]

        # Calculate if the yields are placed within the same for loop.
        # yields_order holds (for_stmt or None, [yields]) groups; consecutive
        # yields in the same loop are grouped together.
        yields_order = []
        last_for_stmt = None
        for yield_, for_stmt in for_parents:
            # For really simple for loops we can predict the order. Otherwise
            # we just ignore it.
            parent = for_stmt.parent
            if parent.type == 'suite':
                parent = parent.parent
            if for_stmt.type == 'for_stmt' and parent == self.tree_node \
                    and parser_utils.for_stmt_defines_one_name(for_stmt):  # Simplicity for now.
                if for_stmt == last_for_stmt:
                    yields_order[-1][1].append(yield_)
                else:
                    yields_order.append((for_stmt, [yield_]))
            elif for_stmt == self.tree_node:
                # The yield is directly in the function body.
                yields_order.append((None, [yield_]))
            else:
                # Order can't be predicted: fall back to an unordered set
                # of all yielded values.
                types = self.get_return_values(check_yields=True)
                if types:
                    yield LazyKnownValues(types, min=0, max=float('inf'))
                return
            last_for_stmt = for_stmt

        for for_stmt, yields in yields_order:
            if for_stmt is None:
                # No for_stmt, just normal yields.
                for yield_ in yields:
                    yield from self._get_yield_lazy_value(yield_)
            else:
                # Unroll the loop: bind the loop variable to each item of the
                # iterable and re-infer the loop's yields under that binding.
                input_node = for_stmt.get_testlist()
                cn = ContextualizedNode(self, input_node)
                ordered = cn.infer().iterate(cn)
                ordered = list(ordered)
                for lazy_value in ordered:
                    dct = {str(for_stmt.children[1].value): lazy_value.infer()}
                    with self.predefine_names(for_stmt, dct):
                        for yield_in_same_for_stmt in yields:
                            yield from self._get_yield_lazy_value(yield_in_same_for_stmt)

    def merge_yield_values(self, is_async=False):
        """Return the union of all values this generator may yield.

        ``is_async`` is accepted but not forwarded to
        ``get_yield_lazy_values``.
        """
        return ValueSet.from_sets(
            lazy_value.infer()
            for lazy_value in self.get_yield_lazy_values()
        )

    def is_generator(self):
        """Return True if the function body contains any yield expression."""
        return bool(get_yield_exprs(self.inference_state, self.tree_node))

    def infer(self):
        """
        Created to be used by inheritance.

        Infer the result of calling the function: an ``AsyncGenerator`` or
        ``CoroutineType`` instance for ``async def``, a ``Generator`` for an
        unannotated generator, and the return values otherwise.
        """
        inference_state = self.inference_state
        # ``async def`` is detected through the parent node of the funcdef.
        is_coroutine = self.tree_node.parent.type in ('async_stmt', 'async_funcdef')
        from jedi.inference.gradual.base import GenericClass

        if is_coroutine:
            if self.is_generator():
                async_generator_classes = inference_state.typing_module \
                    .py__getattribute__('AsyncGenerator')

                yield_values = self.merge_yield_values(is_async=True)
                # The contravariant doesn't seem to be defined.
                generics = (yield_values.py__class__(), NO_VALUES)
                return ValueSet(
                    GenericClass(c, TupleGenericManager(generics))
                    for c in async_generator_classes
                ).execute_annotation(None)
            else:
                async_classes = inference_state.types_module.py__getattribute__('CoroutineType')
                return_values = self.get_return_values()
                # Only the return type is filled in. It goes in the last of
                # the three generic slots; the first two stay empty.
                generics = (NO_VALUES, NO_VALUES, return_values.py__class__())
                return ValueSet(
                    GenericClass(c, TupleGenericManager(generics)) for c in async_classes
                ).execute_annotation(None)
        else:
            # If there are annotations, prefer them over anything else.
            if self.is_generator() and not self.infer_annotations():
                return ValueSet([iterable.Generator(inference_state, self)])
            else:
                return self.get_return_values()
356
357
class FunctionExecutionContext(BaseFunctionExecutionContext):
    """Execution of a function with concrete call ``arguments``."""

    def __init__(self, function_value, arguments):
        super().__init__(function_value)
        self._arguments = arguments

    def get_filters(self, until_position=None, origin_scope=None):
        execution_filter = FunctionExecutionFilter(
            self,
            self._value,
            until_position=until_position,
            origin_scope=origin_scope,
            arguments=self._arguments,
        )
        yield execution_filter

    def infer_annotations(self):
        """Infer return types from annotations, given the call arguments."""
        from jedi.inference.gradual.annotation import infer_return_types
        function, arguments = self._value, self._arguments
        return infer_return_types(function, arguments)

    def get_param_names(self):
        """Parameter names bound to this execution's arguments."""
        function = self._value
        params = function.tree_node.get_params()
        return [ParamName(function, p.name, self._arguments) for p in params]
380
381
class AnonymousFunctionExecution(BaseFunctionExecutionContext):
    """Execution of a function without known call arguments.

    Such contexts exist mostly so the user can work inside the function
    body.
    """

    def infer_annotations(self):
        # Return annotations are deliberately not inferred for anonymous
        # executions.
        return NO_VALUES

    def get_filters(self, until_position=None, origin_scope=None):
        anonymous_filter = AnonymousFunctionExecutionFilter(
            self,
            self._value,
            until_position=until_position,
            origin_scope=origin_scope,
        )
        yield anonymous_filter

    def get_param_names(self):
        # Without arguments, defer to the function's own parameter names.
        function = self._value
        return function.get_param_names()
397
398
class OverloadedFunctionValue(FunctionMixin, ValueWrapper):
    """Function value backed by a list of ``@overload`` signatures.

    Calls go to the first signature that matches the arguments.
    """

    def __init__(self, function, overloaded_functions):
        super().__init__(function)
        self._overloaded_functions = overloaded_functions

    def py__call__(self, arguments):
        debug.dbg("Execute overloaded function %s", self._wrapped_value, color='BLUE')
        attempted = []
        for signature in self.get_signatures():
            execution = signature.value.as_context(arguments)
            attempted.append(execution)
            if signature.matches_signature(arguments):
                return execution.infer()

        # No signature matched the arguments.
        if self.inference_state.is_analysis:
            # Analysis wants precision rather than a guess.
            return NO_VALUES
        # Otherwise, merge what every overload would have returned.
        return ValueSet.from_sets(ctx.infer() for ctx in attempted)

    def get_signature_functions(self):
        return self._overloaded_functions

    def get_type_hint(self, add_class_info=True):
        hints = [func.get_type_hint() for func in self._overloaded_functions]
        return 'Union[%s]' % ', '.join(hints)
423
424
425def _find_overload_functions(context, tree_node):
426 def _is_overload_decorated(funcdef):
427 if funcdef.parent.type == 'decorated':
428 decorators = funcdef.parent.children[0]
429 if decorators.type == 'decorator':
430 decorators = [decorators]
431 else:
432 decorators = decorators.children
433 for decorator in decorators:
434 dotted_name = decorator.children[1]
435 if dotted_name.type == 'name' and dotted_name.value == 'overload':
436 # TODO check with values if it's the right overload
437 return True
438 return False
439
440 if tree_node.type == 'lambdef':
441 return
442
443 if _is_overload_decorated(tree_node):
444 yield tree_node
445
446 while True:
447 filter = ParserTreeFilter(
448 context,
449 until_position=tree_node.start_pos
450 )
451 names = filter.get(tree_node.name.value)
452 assert isinstance(names, list)
453 if not names:
454 break
455
456 found = False
457 for name in names:
458 funcdef = name.tree_name.parent
459 if funcdef.type == 'funcdef' and _is_overload_decorated(funcdef):
460 tree_node = funcdef
461 found = True
462 yield funcdef
463
464 if not found:
465 break