1from jedi import debug
2from jedi.inference.cache import inference_state_method_cache, CachedMetaClass
3from jedi.inference import compiled
4from jedi.inference import recursion
5from jedi.inference import docstrings
6from jedi.inference import flow_analysis
7from jedi.inference.signature import TreeSignature
8from jedi.inference.filters import ParserTreeFilter, FunctionExecutionFilter, \
9 AnonymousFunctionExecutionFilter
10from jedi.inference.names import ValueName, AbstractNameDefinition, \
11 AnonymousParamName, ParamName, NameWrapper
12from jedi.inference.base_value import ContextualizedNode, NO_VALUES, \
13 ValueSet, TreeValue, ValueWrapper
14from jedi.inference.lazy_value import LazyKnownValues, LazyKnownValue, \
15 LazyTreeValue
16from jedi.inference.context import ValueContext, TreeContextMixin
17from jedi.inference.value import iterable
18from jedi import parser_utils
19from jedi.inference.parser_cache import get_yield_exprs
20from jedi.inference.helpers import values_from_qualified_names
21from jedi.inference.gradual.generics import TupleGenericManager
22
23
class LambdaName(AbstractNameDefinition):
    """
    Synthetic name for a ``lambda`` expression. A lambda has no identifier
    of its own, so this always reports ``'<lambda>'``.
    """
    string_name = '<lambda>'
    api_type = 'function'

    def __init__(self, lambda_value):
        # The name lives in the same context as the lambda it labels.
        self.parent_context = lambda_value.parent_context
        self._lambda_value = lambda_value

    @property
    def start_pos(self):
        # There is no name token, so report where the lambda node starts.
        lambda_node = self._lambda_value.tree_node
        return lambda_node.start_pos

    def infer(self):
        # The name resolves to exactly the lambda value it was built from.
        values = [self._lambda_value]
        return ValueSet(values)
38
39
class FunctionAndClassBase(TreeValue):
    """Common base for tree-backed functions and classes."""

    def get_qualified_names(self):
        """
        Return the dotted path to this value as a tuple of strings.

        Returns ``None`` when the value is not reachable from module level,
        i.e. it (or an enclosing class) is defined inside a function.
        """
        parent = self.parent_context
        if parent.is_class():
            # None here means the enclosing class lives within a function.
            prefix = parent.get_qualified_names()
        elif parent.is_module():
            prefix = ()
        else:
            prefix = None

        if prefix is None:
            return None
        return prefix + (self.py__name__(),)
52
53
class FunctionMixin:
    """
    Behaviour shared by function-like values (plain functions, methods,
    lambdas and overload wrappers).
    """
    api_type = 'function'

    def get_filters(self, origin_scope=None):
        # Attribute lookups on a function object are answered by instances of
        # its class as reported by ``py__class__``.
        function_class = self.py__class__()
        for function_object in function_class.execute_with_values():
            yield from function_object.get_filters(origin_scope=origin_scope)

    def py__get__(self, instance, class_value):
        """Descriptor access: bind to ``instance`` unless accessed on the class."""
        from jedi.inference.value.instance import BoundMethod
        if instance is not None:
            bound = BoundMethod(instance, class_value.as_context(), self)
            return ValueSet([bound])
        # ``Foo.bar`` accessed on the class gives back the original function.
        return ValueSet([self])

    def get_param_names(self):
        """Parameter names not tied to any concrete call arguments."""
        params = self.tree_node.get_params()
        return [AnonymousParamName(self, p.name) for p in params]

    @property
    def name(self):
        node = self.tree_node
        if node.type != 'lambdef':
            return ValueName(self, node.name)
        # Lambdas have no name token of their own.
        return LambdaName(self)

    def is_function(self):
        return True

    def py__name__(self):
        return self.name.string_name

    def get_type_hint(self, add_class_info=True):
        """
        Render a signature string such as ``name(a, b=1) -> int``.

        An explicit return annotation is copied verbatim together with the
        written parameter list. Otherwise the parameters and return type are
        inferred from an argument-less execution, and `` -> ...`` is omitted
        when the return type has no hint.
        """
        annotation = self.tree_node.annotation
        if annotation is not None:
            hint = annotation.get_code(include_prefix=False)
            # children[2] of a funcdef is its parenthesized parameter list.
            signature = self.py__name__() + self.tree_node.children[2].get_code(include_prefix=False)
            return signature + ' -> ' + hint

        def describe_param(param_name):
            text = param_name.string_name
            param_hint = param_name.infer().get_type_hint()
            if param_hint is not None:
                text += ': ' + param_hint
            default = param_name.default_node
            if default is not None:
                text += '=' + default.get_code(include_prefix=False)
            return text

        execution = self.as_context()
        inferred_hint = execution.infer().get_type_hint()
        signature = '%s(%s)' % (
            self.py__name__(),
            ', '.join([describe_param(p) for p in execution.get_param_names()]),
        )
        if inferred_hint is None:
            return signature
        return signature + ' -> ' + inferred_hint

    def py__call__(self, arguments):
        return self.as_context(arguments).infer()

    def _as_context(self, arguments=None):
        # With call arguments we get a concrete execution; without them an
        # anonymous one.
        if arguments is not None:
            return FunctionExecutionContext(self, arguments)
        return AnonymousFunctionExecution(self)

    def get_signatures(self):
        return list(map(TreeSignature, self.get_signature_functions()))
123
124
class FunctionValue(FunctionMixin, FunctionAndClassBase, metaclass=CachedMetaClass):
    """A function (or lambda) defined in the parse tree."""

    @classmethod
    def from_context(cls, context, tree_node):
        """
        Build the value for the function ``tree_node`` defined in ``context``.

        Functions defined directly in a class body become ``MethodValue``
        instances. If ``@overload`` definitions are found for the function,
        an ``OverloadedFunctionValue`` is returned that wraps the function
        together with those overloads, lowest line first.
        """
        overload_nodes = list(_find_overload_functions(context, tree_node))

        # Class and instance scopes are skipped: the function's parent
        # context is the nearest enclosing scope that is neither.
        outer_context = context
        while outer_context.is_class() or outer_context.is_instance():
            outer_context = outer_context.parent_context

        def build(node):
            state = context.inference_state
            if not context.is_class():
                return cls(state, parent_context=outer_context, tree_node=node)
            return MethodValue(state, context,
                               parent_context=outer_context, tree_node=node)

        main_function = build(tree_node)
        if not overload_nodes:
            return main_function

        overloads = [build(node) for node in overload_nodes]
        # Overloads are discovered from the current definition backwards, so
        # flip them to put the lowest line first.
        overloads.reverse()
        return OverloadedFunctionValue(main_function, overloads)

    def py__class__(self):
        # Exactly one value is expected for ``types.FunctionType``.
        (function_type,) = values_from_qualified_names(
            self.inference_state, 'types', 'FunctionType')
        return function_type

    def get_default_param_context(self):
        # Default parameter expressions are inferred in the parent context.
        return self.parent_context

    def get_signature_functions(self):
        # A plain function has a single signature: its own.
        return [self]
168
169
class FunctionNameInClass(NameWrapper):
    """
    Name of a function defined in a class body. It wraps the plain name and
    remembers the class context so the defining class can be reported.
    """

    def __init__(self, class_context, name):
        self._class_context = class_context
        super().__init__(name)

    def get_defining_qualified_value(self):
        # ``get_value()`` on the class context may return None.
        class_context = self._class_context
        return class_context.get_value()
177
178
class MethodValue(FunctionValue):
    """A function defined directly in a class body."""

    def __init__(self, inference_state, class_context, *args, **kwargs):
        super().__init__(inference_state, *args, **kwargs)
        # Stored separately because ``parent_context`` skips class scopes.
        self.class_context = class_context

    def get_default_param_context(self):
        # Defaults of a method are inferred within the class body.
        return self.class_context

    def get_qualified_names(self):
        # The parent context of a method is not the class but the nearest
        # non-class scope, so build the path from the class context instead.
        class_names = self.class_context.get_qualified_names()
        if class_names is not None:
            return class_names + (self.py__name__(),)
        return None

    @property
    def name(self):
        plain_name = super().name
        return FunctionNameInClass(self.class_context, plain_name)
198
199
class BaseFunctionExecutionContext(ValueContext, TreeContextMixin):
    """
    Shared logic for an execution of a function: inferring its return
    values, its yielded values, and the overall result of calling it.

    Subclasses must implement ``infer_annotations``.
    """
    def infer_annotations(self):
        """Return the values implied by the return annotation (subclass hook)."""
        raise NotImplementedError

    @inference_state_method_cache(default=NO_VALUES)
    @recursion.execution_recursion_decorator()
    def get_return_values(self, check_yields=False):
        """
        Infer the values this execution produces.

        :param check_yields: If true, collect the values of the function's
            ``yield`` expressions instead of its ``return`` statements.
        :returns: A ``ValueSet``. For a lambda this is the inferred body
            expression. Otherwise, when not checking yields and the
            annotations infer to a non-empty set, that set is returned as-is;
            else docstring return types are merged with the inferred values
            of every ``return`` statement not found to be unreachable.
        """
        funcdef = self.tree_node
        if funcdef.type == 'lambdef':
            # A lambda's value is its body, the last child of the node.
            return self.infer_node(funcdef.children[-1])

        if check_yields:
            value_set = NO_VALUES
            returns = get_yield_exprs(self.inference_state, funcdef)
        else:
            value_set = self.infer_annotations()
            if value_set:
                # If there are annotations, prefer them over anything else.
                # This will make it faster.
                return value_set
            value_set |= docstrings.infer_return_types(self._value)
            returns = funcdef.iter_return_stmts()

        for r in returns:
            if check_yields:
                value_set |= ValueSet.from_sets(
                    lazy_value.infer()
                    for lazy_value in self._get_yield_lazy_value(r)
                )
            else:
                check = flow_analysis.reachability_check(self, funcdef, r)
                if check is flow_analysis.UNREACHABLE:
                    debug.dbg('Return unreachable: %s', r)
                else:
                    try:
                        children = r.children
                    except AttributeError:
                        # A bare ``return`` is a leaf without children; it
                        # returns None.
                        ctx = compiled.builtin_from_name(self.inference_state, 'None')
                        value_set |= ValueSet([ctx])
                    else:
                        # children[1] is the returned expression.
                        value_set |= self.infer_node(children[1])
                if check is flow_analysis.REACHABLE:
                    # Stop at the first return known to be reached; later
                    # returns are not considered.
                    debug.dbg('Return reachable: %s', r)
                    break
        return value_set

    def _get_yield_lazy_value(self, yield_expr):
        """Yield lazy values for what a single yield expression produces."""
        if yield_expr.type == 'keyword':
            # `yield` just yields None.
            ctx = compiled.builtin_from_name(self.inference_state, 'None')
            yield LazyKnownValue(ctx)
            return

        node = yield_expr.children[1]
        if node.type == 'yield_arg':  # It must be a yield from.
            # ``yield from x`` produces whatever iterating ``x`` produces.
            cn = ContextualizedNode(self, node.children[1])
            yield from cn.infer().iterate(cn)
        else:
            yield LazyTreeValue(self, node)

    @recursion.execution_recursion_decorator(default=iter([]))
    def get_yield_lazy_values(self, is_async=False):
        """
        Yield lazy values for everything this generator yields.

        Yields are kept in order when each one either sits directly in the
        function body or in a simple ``for`` loop at the top level of the
        body that binds a single name. In that case the loop's iterable is
        iterated and the loop variable is predefined for each element. In
        every other case a single ``LazyKnownValues`` (``min=0``,
        ``max=inf``) holding the union of all yielded values is produced.

        ``is_async`` is currently unused (see the TODO). The recursion
        decorator's ``default`` is presumably what is returned when the
        recursion limit is hit.
        """
        # TODO: if is_async, wrap yield statements in Awaitable/async_generator_asend
        # Pair every yield with its nearest enclosing loop, ``if`` or function.
        for_parents = [(y, y.search_ancestor('for_stmt', 'funcdef',
                                             'while_stmt', 'if_stmt'))
                       for y in get_yield_exprs(self.inference_state, self.tree_node)]

        # Calculate if the yields are placed within the same for loop.
        # Entries are (for_stmt or None, [yields]); consecutive yields in the
        # same for loop are grouped together.
        yields_order = []
        last_for_stmt = None
        for yield_, for_stmt in for_parents:
            # For really simple for loops we can predict the order. Otherwise
            # we just ignore it.
            parent = for_stmt.parent
            if parent.type == 'suite':
                parent = parent.parent
            if for_stmt.type == 'for_stmt' and parent == self.tree_node \
                    and parser_utils.for_stmt_defines_one_name(for_stmt):  # Simplicity for now.
                if for_stmt == last_for_stmt:
                    yields_order[-1][1].append(yield_)
                else:
                    yields_order.append((for_stmt, [yield_]))
            elif for_stmt == self.tree_node:
                # The yield sits directly in the function body.
                yields_order.append((None, [yield_]))
            else:
                # Order is not predictable: fall back to an unordered union
                # of every yielded value.
                types = self.get_return_values(check_yields=True)
                if types:
                    yield LazyKnownValues(types, min=0, max=float('inf'))
                return
            last_for_stmt = for_stmt

        for for_stmt, yields in yields_order:
            if for_stmt is None:
                # No for_stmt, just normal yields.
                for yield_ in yields:
                    yield from self._get_yield_lazy_value(yield_)
            else:
                input_node = for_stmt.get_testlist()
                cn = ContextualizedNode(self, input_node)
                ordered = cn.infer().iterate(cn)
                ordered = list(ordered)
                for lazy_value in ordered:
                    # Bind the (single) loop variable to the current element
                    # while inferring the yields inside the loop.
                    dct = {str(for_stmt.children[1].value): lazy_value.infer()}
                    with self.predefine_names(for_stmt, dct):
                        for yield_in_same_for_stmt in yields:
                            yield from self._get_yield_lazy_value(yield_in_same_for_stmt)

    def merge_yield_values(self, is_async=False):
        """
        Return the union of all values this generator yields.

        NOTE(review): ``is_async`` is accepted but not forwarded to
        ``get_yield_lazy_values``, which ignores it anyway.
        """
        return ValueSet.from_sets(
            lazy_value.infer()
            for lazy_value in self.get_yield_lazy_values()
        )

    def is_generator(self):
        """True if the function body contains any yield expression."""
        return bool(get_yield_exprs(self.inference_state, self.tree_node))

    def infer(self):
        """
        Created to be used by inheritance.

        Infer the result of calling the function:

        * ``async def`` containing yields -> ``typing.AsyncGenerator[...]``
          parametrised with the classes of the yielded values;
        * other ``async def`` -> ``typing.Coroutine[...]`` parametrised with
          the classes of the return values;
        * generator without (inferable) return annotations -> ``Generator``;
        * anything else -> ``get_return_values()``.
        """
        inference_state = self.inference_state
        # An ``async def`` appears as a funcdef whose parent is one of these.
        is_coroutine = self.tree_node.parent.type in ('async_stmt', 'async_funcdef')
        from jedi.inference.gradual.base import GenericClass

        if is_coroutine:
            if self.is_generator():
                async_generator_classes = inference_state.typing_module \
                    .py__getattribute__('AsyncGenerator')

                yield_values = self.merge_yield_values(is_async=True)
                # The contravariant doesn't seem to be defined.
                generics = (yield_values.py__class__(), NO_VALUES)
                return ValueSet(
                    GenericClass(c, TupleGenericManager(generics))
                    for c in async_generator_classes
                ).execute_annotation()
            else:
                async_classes = inference_state.typing_module.py__getattribute__('Coroutine')
                return_values = self.get_return_values()
                # Only the first generic is relevant.
                generics = (return_values.py__class__(), NO_VALUES, NO_VALUES)
                return ValueSet(
                    GenericClass(c, TupleGenericManager(generics)) for c in async_classes
                ).execute_annotation()
        else:
            # If there are annotations, prefer them over anything else.
            if self.is_generator() and not self.infer_annotations():
                return ValueSet([iterable.Generator(inference_state, self)])
            else:
                return self.get_return_values()
350
351
class FunctionExecutionContext(BaseFunctionExecutionContext):
    """Execution of a function with concrete call ``arguments``."""

    def __init__(self, function_value, arguments):
        super().__init__(function_value)
        self._arguments = arguments

    def get_filters(self, until_position=None, origin_scope=None):
        # A single filter over the function body; it is handed the call
        # arguments of this execution.
        execution_filter = FunctionExecutionFilter(
            self,
            self._value,
            until_position=until_position,
            origin_scope=origin_scope,
            arguments=self._arguments,
        )
        yield execution_filter

    def infer_annotations(self):
        from jedi.inference.gradual.annotation import infer_return_types
        function_value, arguments = self._value, self._arguments
        return infer_return_types(function_value, arguments)

    def get_param_names(self):
        # One ParamName per declared parameter, tied to this call's arguments.
        params = self._value.tree_node.get_params()
        return [ParamName(self._value, p.name, self._arguments) for p in params]
374
375
class AnonymousFunctionExecution(BaseFunctionExecutionContext):
    """Execution of a function without known call arguments."""

    def infer_annotations(self):
        # Return annotations are deliberately not inferred here: anonymous
        # executions mostly exist so the user can work inside the function
        # body.
        return NO_VALUES

    def get_filters(self, until_position=None, origin_scope=None):
        anonymous_filter = AnonymousFunctionExecutionFilter(
            self,
            self._value,
            until_position=until_position,
            origin_scope=origin_scope,
        )
        yield anonymous_filter

    def get_param_names(self):
        # Without call arguments, fall back to the function's own parameters.
        function_value = self._value
        return function_value.get_param_names()
391
392
class OverloadedFunctionValue(FunctionMixin, ValueWrapper):
    """
    Wraps a function together with its ``@overload`` variants. Calls are
    resolved against the overloads' signatures.
    """

    def __init__(self, function, overloaded_functions):
        super().__init__(function)
        self._overloaded_functions = overloaded_functions

    def py__call__(self, arguments):
        """
        Return the result of the first overload whose signature matches
        ``arguments``. If none matches, return ``NO_VALUES`` in analysis mode
        and the union of every overload's result otherwise.
        """
        debug.dbg("Execute overloaded function %s", self._wrapped_value, color='BLUE')
        tried = []
        for signature in self.get_signatures():
            execution = signature.value.as_context(arguments)
            tried.append(execution)
            if signature.matches_signature(arguments):
                return execution.infer()

        if not self.inference_state.is_analysis:
            return ValueSet.from_sets(candidate.infer() for candidate in tried)
        # Analysis mode prefers precision: no match means no result.
        return NO_VALUES

    def get_signature_functions(self):
        return self._overloaded_functions

    def get_type_hint(self, add_class_info=True):
        hints = ', '.join(f.get_type_hint() for f in self._overloaded_functions)
        return 'Union[' + hints + ']'
417
418
419def _find_overload_functions(context, tree_node):
420 def _is_overload_decorated(funcdef):
421 if funcdef.parent.type == 'decorated':
422 decorators = funcdef.parent.children[0]
423 if decorators.type == 'decorator':
424 decorators = [decorators]
425 else:
426 decorators = decorators.children
427 for decorator in decorators:
428 dotted_name = decorator.children[1]
429 if dotted_name.type == 'name' and dotted_name.value == 'overload':
430 # TODO check with values if it's the right overload
431 return True
432 return False
433
434 if tree_node.type == 'lambdef':
435 return
436
437 if _is_overload_decorated(tree_node):
438 yield tree_node
439
440 while True:
441 filter = ParserTreeFilter(
442 context,
443 until_position=tree_node.start_pos
444 )
445 names = filter.get(tree_node.name.value)
446 assert isinstance(names, list)
447 if not names:
448 break
449
450 found = False
451 for name in names:
452 funcdef = name.tree_name.parent
453 if funcdef.type == 'funcdef' and _is_overload_decorated(funcdef):
454 tree_node = funcdef
455 found = True
456 yield funcdef
457
458 if not found:
459 break