1"""
2PEP 0484 ( https://www.python.org/dev/peps/pep-0484/ ) describes type hints
3through function annotations. There is a strong suggestion in this document
4that only the type of type hinting defined in PEP0484 should be allowed
5as annotations in future python versions.
6"""
7
8import re
9from inspect import Parameter
10
11from parso import ParserSyntaxError, parse
12
13from jedi.inference.cache import inference_state_method_cache
14from jedi.inference.base_value import ValueSet, NO_VALUES
15from jedi.inference.gradual.base import DefineGenericBaseClass, GenericClass
16from jedi.inference.gradual.generics import TupleGenericManager
17from jedi.inference.gradual.type_var import TypeVar
18from jedi.inference.helpers import is_string
19from jedi.inference.compiled import builtin_from_name
20from jedi.inference.param import get_executed_param_names
21from jedi import debug
22from jedi import parser_utils
23
24
def infer_annotation(context, annotation):
    """
    Infers an annotation node. This means that it infers the part of
    `int` here:

        foo: int = 3

    Forward references are handled as well: every inferred value that is a
    string is parsed as an annotation and replaced by whatever that parsed
    node infers to. A string that cannot be parsed is kept unchanged.
    """
    inferred = context.infer_node(annotation)
    if len(inferred) == 0:
        debug.warning(
            "Inferred typing index %s should lead to 1 object, not %s" % (annotation, inferred))
        return inferred

    resolved = NO_VALUES
    for value in inferred:
        forward_node = None
        if is_string(value):
            forward_node = _get_forward_reference_node(context, value.get_safe_value())

        if forward_node is None:
            # Either not a string at all, or a string that is no valid
            # annotation: keep the value itself.
            resolved |= ValueSet([value])
        else:
            resolved |= context.infer_node(forward_node)
    return resolved
49
50
def _infer_annotation_string(context, string, index=None):
    """
    Parses ``string`` as an annotation and infers the resulting node.

    Returns ``NO_VALUES`` if the string cannot be parsed. If ``index`` is
    given, only inferred values that are tuples with an entry at ``index``
    are kept, and the values of that entry are returned instead.
    """
    node = _get_forward_reference_node(context, string)
    if node is None:
        return NO_VALUES

    value_set = context.infer_node(node)
    if index is None:
        return value_set

    def has_entry_at_index(value):
        # The length must be strictly greater than ``index``: a tuple with
        # exactly ``index`` entries has no item at position ``index``.
        return (
            value.array_type == 'tuple'
            and len(list(value.py__iter__())) > index
        )

    return value_set.filter(has_entry_at_index).py__simple_getitem__(index)
65
66
def _get_forward_reference_node(context, string):
    """
    Parses ``string`` (the content of a string annotation) with the grammar
    of the inference state, using the 'eval_input' start symbol and no error
    recovery.

    Returns the parsed node, attached to ``context.tree_node`` as its parent,
    or ``None`` (after logging a warning) if the string is not valid syntax.
    """
    grammar = context.inference_state.grammar
    try:
        forward_node = grammar.parse(
            string,
            start_symbol='eval_input',
            error_recovery=False
        )
    except ParserSyntaxError:
        debug.warning('Annotation not parsed: %s' % string)
        return None

    # Reposition the new tree using the end position of the surrounding
    # module (presumably so it does not overlap with real code positions).
    root = context.tree_node.get_root_node()
    parser_utils.move(forward_node, root.end_pos[0])
    forward_node.parent = context.tree_node
    return forward_node
82
83
def _split_comment_param_declaration(decl_text):
    """
    Split decl_text on commas, but group generic expressions
    together.

    For example, given "foo, Bar[baz, biz]" we return
    ['foo', 'Bar[baz, biz]'].

    An empty list is returned when decl_text is not valid Python (a warning
    is logged in that case) or when the parsed node is neither a single
    expression of a supported type nor a node with children.
    """
    supported_types = ['name', 'atom_expr', 'power']

    try:
        first = parse(decl_text, error_recovery=False).children[0]
    except ParserSyntaxError:
        debug.warning('Comment annotation is not valid Python: %s' % decl_text)
        return []

    # A single declaration, e.g. "foo" or "Bar[baz, biz]".
    if first.type in supported_types:
        return [first.get_code().strip()]

    try:
        parts = first.children
    except AttributeError:
        # A leaf of an unsupported type, nothing to split.
        return []

    # Children of other types (e.g. the separating commas) are dropped.
    return [
        part.get_code().strip()
        for part in parts
        if part.type in supported_types
    ]
113
114
@inference_state_method_cache()
def infer_param(function_value, param, ignore_stars=False):
    """
    Infers the annotated values of ``param``.

    Unless ``ignore_stars`` is set (or nothing could be inferred), the values
    of starred parameters are wrapped: for ``*args`` in a generic ``tuple``
    of the values, for ``**kwargs`` in a generic ``dict`` from ``str`` to the
    values.
    """
    annotated = _infer_param(function_value, param)
    if ignore_stars or not annotated:
        return annotated

    inference_state = function_value.inference_state
    stars = param.star_count
    if stars == 1:
        base_class = builtin_from_name(inference_state, 'tuple')
        generics = (annotated,)
    elif stars == 2:
        base_class = builtin_from_name(inference_state, 'dict')
        generics = (
            ValueSet([builtin_from_name(inference_state, 'str')]),
            annotated
        )
    else:
        return annotated

    return ValueSet([GenericClass(
        base_class,
        TupleGenericManager(generics),
    )])
138
139
def _infer_param(function_value, param):
    """
    Infers the type of a function parameter, using type annotations.

    A Python 3 style annotation takes precedence. Without one, a type
    comment of the form ``# type: (...) -> ...`` on the same line as the
    function header is consulted. ``NO_VALUES`` is returned if neither
    yields anything for this parameter.
    """
    annotation = param.annotation
    if annotation is not None:
        # Annotations are like default params and resolve in the same way.
        return infer_annotation(
            function_value.get_default_param_context(),
            annotation
        )

    # No Python 3-style annotation, so look for a comment annotation.
    # Collect the parameters of the function in the same sequence as they
    # would appear in a type comment.
    siblings = [
        child for child in param.parent.children
        if child.type == 'param'
    ]

    funcdef = param.parent.parent
    comment = parser_utils.get_following_comment_same_line(funcdef)
    if comment is None:
        return NO_VALUES

    found = re.match(r"^#\s*type:\s*\(([^#]*)\)\s*->", comment)
    if not found:
        return NO_VALUES
    declared = _split_comment_param_declaration(found.group(1))

    # Locate the parameter being investigated within the comment.
    position = siblings.index(param)
    if len(declared) != len(siblings):
        # Only logged; the lookup below is still attempted.
        debug.warning(
            "Comments length != Params length %s %s",
            declared, siblings
        )
    if function_value.is_bound_method():
        if position == 0:
            # Assume it's self, which is already handled
            return NO_VALUES
        # The first parameter has no entry in the comment, shift by one.
        position -= 1
    if position >= len(declared):
        return NO_VALUES

    declaration = declared[position]
    return _infer_annotation_string(
        function_value.get_default_param_context(),
        declaration
    )
187
188
def py__annotations__(funcdef):
    """
    Collects the annotation nodes of a function definition.

    Returns a dict that maps the name of every annotated parameter to its
    annotation node, plus the key 'return' for the return annotation if
    there is one.
    """
    annotations = {}
    for function_param in funcdef.get_params():
        node = function_param.annotation
        if node is None:
            continue
        annotations[function_param.name.value] = node

    returns = funcdef.annotation
    if returns:
        annotations['return'] = returns
    return annotations
200
201
def resolve_forward_references(context, all_annotations):
    """
    Returns a copy of ``all_annotations`` in which every string annotation
    node is replaced by the node parsed from the content of that string.

    Nodes that are not strings (and ``None``) are passed through untouched.
    A string that is not a valid annotation is mapped to ``None``.
    """
    def resolve(node):
        is_forward_reference = node is not None and node.type == 'string'
        if not is_forward_reference:
            return node

        content = context.inference_state.compiled_subprocess.safe_literal_eval(
            node.value,
        )
        parsed = _get_forward_reference_node(context, content)
        if parsed is None:
            # There was a string, but it's not a valid annotation
            return None

        # The forward reference tree has an additional root node ('eval_input')
        # that we don't want. Extract the node we do want, that is equivalent to
        # the nodes returned by `py__annotations__` for a non-quoted node.
        return parsed.children[0]

    resolved = {}
    for name, node in all_annotations.items():
        resolved[name] = resolve(node)
    return resolved
226
227
@inference_state_method_cache()
def infer_return_types(function, arguments):
    """
    Infers the type of a function's return value,
    according to type annotations.

    The return annotation is used if present, otherwise the part after
    ``->`` of a ``# type: (...) -> ...`` comment on the same line as the
    function header. Type vars found in the return annotation are replaced
    by what can be inferred for them from ``arguments``.
    """
    context = function.get_default_param_context()
    all_annotations = resolve_forward_references(
        context,
        py__annotations__(function.tree_node),
    )
    annotation = all_annotations.get("return")

    if annotation is None:
        # If there is no Python 3-type annotation, look for an annotation
        # comment.
        comment = parser_utils.get_following_comment_same_line(function.tree_node)
        if comment is None:
            return NO_VALUES

        found = re.match(r"^#\s*type:\s*\([^#]*\)\s*->\s*([^#]*)", comment)
        if not found:
            return NO_VALUES

        hinted = _infer_annotation_string(context, found.group(1).strip())
        return hinted.execute_annotation(context)

    unknown_type_vars = find_unknown_type_vars(context, annotation)
    annotation_values = infer_annotation(context, annotation)
    if not unknown_type_vars:
        return annotation_values.execute_annotation(context)

    type_var_dict = infer_type_vars_for_execution(function, arguments, all_annotations)

    def bind_type_vars(value):
        # Only these kinds of values accept a type var mapping, everything
        # else is passed on as it is.
        if isinstance(value, (DefineGenericBaseClass, TypeVar)):
            return value.define_generics(type_var_dict)
        return ValueSet({value})

    bound = ValueSet.from_sets(bind_type_vars(value) for value in annotation_values)
    return bound.execute_annotation(context)
269
270
def infer_type_vars_for_execution(function, arguments, annotation_dict):
    """
    Some functions use type vars that are not defined by the class, but rather
    only defined in the function. See for example `iter`. In those cases we
    want to:

    1. Search for undefined type vars.
    2. Infer type vars with the execution state we have.
    3. Return the union of all type vars that have been found.

    ``annotation_dict`` maps parameter names to annotation nodes. Parameters
    without an entry, or whose entry is ``None``, are skipped.

    Returns a dict that maps type var names to the values inferred for them.
    """
    context = function.get_default_param_context()

    annotation_variable_results = {}
    executed_param_names = get_executed_param_names(function, arguments)
    for executed_param_name in executed_param_names:
        annotation_node = annotation_dict.get(executed_param_name.string_name)
        if annotation_node is None:
            # Either the parameter is not annotated, or its annotation was a
            # string that is no valid annotation (those are stored as None).
            # There is no node to search for type vars in both cases.
            continue

        annotation_variables = find_unknown_type_vars(context, annotation_node)
        if not annotation_variables:
            continue

        # Infer unknown type var
        annotation_value_set = context.infer_node(annotation_node)
        kind = executed_param_name.get_kind()
        actual_value_set = executed_param_name.infer()
        if kind is Parameter.VAR_POSITIONAL:
            # *args: match the annotation against the contained values.
            actual_value_set = actual_value_set.merge_types_of_iterate()
        elif kind is Parameter.VAR_KEYWORD:
            # **kwargs: match the annotation against the dict values.
            # TODO _dict_values is not public.
            actual_value_set = actual_value_set.try_merge('_dict_values')
        merge_type_var_dicts(
            annotation_variable_results,
            annotation_value_set.infer_type_vars(actual_value_set),
        )
    return annotation_variable_results
307
308
def infer_return_for_callable(arguments, param_values, result_values):
    """
    Applies the type vars that can be inferred from ``arguments`` to
    ``result_values``.

    Every entry of ``param_values`` whose ``array_type`` is 'list' is matched
    against the arguments. The findings of all entries are united per type
    var, so a later entry does not discard what an earlier one found.
    """
    all_type_vars = {}
    for pv in param_values:
        if pv.array_type == 'list':
            merge_type_var_dicts(
                all_type_vars,
                _infer_type_vars_for_callable(arguments, pv.py__iter__()),
            )

    return ValueSet.from_sets(
        v.define_generics(all_type_vars)
        if isinstance(v, (DefineGenericBaseClass, TypeVar))
        else ValueSet({v})
        for v in result_values
    ).execute_annotation(arguments.context)
322
323
def _infer_type_vars_for_callable(arguments, lazy_params):
    """
    Infers type vars for the Callable class:

        def x() -> Callable[[Callable[..., _T]], _T]: ...

    The unpacked ``arguments`` are paired positionally with ``lazy_params``;
    surplus entries on either side are ignored. Returns a dict from type var
    names to the values found for them.
    """
    collected = {}
    for argument, lazy_callable_param in zip(arguments.unpack(), lazy_params):
        _, lazy_value = argument
        declared_values = lazy_callable_param.infer()
        # Infer unknown type var
        passed_values = lazy_value.infer()
        merge_type_var_dicts(collected, declared_values.infer_type_vars(passed_values))
    return collected
340
341
def merge_type_var_dicts(base_dict, new_dict):
    """
    Merges the entries of ``new_dict`` into ``base_dict`` in place.

    Values of a name that exists already are united with ``|=``, other names
    are added. Entries of ``new_dict`` with a falsy value are ignored.
    """
    for type_var_name, values in new_dict.items():
        if not values:
            continue
        if type_var_name in base_dict:
            base_dict[type_var_name] |= values
        else:
            base_dict[type_var_name] = values
349
350
def merge_pairwise_generics(annotation_value, annotated_argument_class):
    """
    Match up the generic parameters from the given argument class to the
    target annotation.

    This walks the generic parameters immediately within the annotation and
    argument's type, in order to determine the concrete values of the
    annotation's parameters for the current case.

    For example, given the following code:

        def values(mapping: Mapping[K, V]) -> List[V]: ...

        for val in values({1: 'a'}):
            val

    Then this function should be given representations of `Mapping[K, V]`
    and `Mapping[int, str]`, so that it can determine that `K` is `int` and
    `V` is `str`.

    Note that it is responsibility of the caller to traverse the MRO of the
    argument type as needed in order to find the type matching the
    annotation (in this case finding `Mapping[int, str]` as a parent of
    `Dict[int, str]`).

    Parameters
    ----------

    `annotation_value`: represents the annotation to infer the concrete
        parameter types of.

    `annotated_argument_class`: represents the annotated class of the
        argument being passed to the object annotated by `annotation_value`.

    Returns
    -------

    A dict from type var names to the values found for them. It is empty if
    `annotated_argument_class` is not a `DefineGenericBaseClass`.
    """
    if not isinstance(annotated_argument_class, DefineGenericBaseClass):
        return {}

    type_var_dict = {}
    generic_pairs = zip(
        annotation_value.get_generics(),
        annotated_argument_class.get_generics(),
    )
    for wanted_set, given_set in generic_pairs:
        # NOTE(review): execute_annotation receives None here, while other
        # call sites in this file pass a context - confirm this is intended.
        given_values = given_set.execute_annotation(None)
        merge_type_var_dicts(type_var_dict, wanted_set.infer_type_vars(given_values))

    return type_var_dict
401
402
def find_type_from_comment_hint_for(context, node, name):
    """
    Looks up the type comment hint for ``name``, taking the second child of
    ``node`` as the list of targets (presumably ``node`` is a for statement).
    """
    targets = node.children[1]
    return _find_type_from_comment_hint(context, node, targets, name)
405
406
def find_type_from_comment_hint_with(context, node, name):
    """
    Looks up the type comment hint for ``name`` in a with statement, taking
    the part after ``as`` of its only with_item as the list of targets.

    Returns an empty list if the statement has more than four children.
    """
    children = node.children
    if len(children) > 4:
        # In case there are multiple with_items, we do not want a type hint for
        # now.
        return []

    with_item = children[1]
    assert len(with_item.children) == 3, \
        "Can only be here when children[1] is 'foo() as f'"
    targets = with_item.children[2]
    return _find_type_from_comment_hint(context, node, targets, name)
416
417
def find_type_from_comment_hint_assign(context, node, name):
    """
    Looks up the type comment hint for ``name``, taking the first child of
    ``node`` as the list of targets (presumably ``node`` is an assignment).
    """
    targets = node.children[0]
    return _find_type_from_comment_hint(context, node, targets, name)
420
421
def _find_type_from_comment_hint(context, node, varlist, name):
    """
    Infers the type of ``name`` from a ``# type: ...`` comment on the same
    line as ``node``.

    If ``varlist`` is a list of targets, the position of ``name`` in it
    selects the matching entry of the hint. Otherwise the whole hint is used.

    An empty list is returned if ``name`` is not part of the target list, if
    there is no comment, or if the comment is not a type comment.
    """
    position = None
    if varlist.type in ("testlist_star_expr", "exprlist", "testlist"):
        # something like "a, b = 1, 2"
        position = 0
        for entry in varlist.children:
            if entry == name:
                break
            # Operators (the separators) do not count as a position.
            if entry.type != "operator":
                position += 1
        else:
            # The loop did not find the name.
            return []

    comment = parser_utils.get_following_comment_same_line(node)
    if comment is None:
        return []

    found = re.match(r"^#\s*type:\s*([^#]*)", comment)
    if found is None:
        return []

    hinted = _infer_annotation_string(context, found.group(1).strip(), position)
    return hinted.execute_annotation(context)
445
446
def find_unknown_type_vars(context, node):
    """
    Returns the list of type vars that the annotation ``node`` refers to, in
    the order they are found and without duplicates.

    For an 'atom_expr' or 'power' node that ends with a subscript, the
    entries of the subscript are searched recursively; such a node without a
    trailing subscript contributes nothing. Any other node is inferred and
    the type vars among its values are collected.
    """
    found = []  # We're not using a set, because the order matters.

    def visit(current):
        if current.type not in ('atom_expr', 'power'):
            found[:] = _filter_type_vars(context.infer_node(current), found)
            return

        last = current.children[-1]
        if not (last.type == 'trailer' and last.children[0] == '['):
            return
        for subscript_node in _unpack_subscriptlist(last.children[1]):
            visit(subscript_node)

    visit(node)
    return found
460
461
def _filter_type_vars(value_set, found=()):
    """
    Returns a new list with the entries of ``found``, followed by every
    ``TypeVar`` of ``value_set`` that is not contained in ``found``.
    """
    collected = list(found)
    collected.extend(
        value for value in value_set
        if isinstance(value, TypeVar) and value not in found
    )
    return collected
468
469
def _unpack_subscriptlist(subscriptlist):
    """
    Yields the entries of a subscript: every second child of a
    'subscriptlist' node (skipping the separators in between), or the node
    itself if it is anything else. Nodes of type 'subscript' are left out.
    """
    if subscriptlist.type == 'subscriptlist':
        candidates = subscriptlist.children[::2]
    else:
        candidates = [subscriptlist]

    for candidate in candidates:
        if candidate.type != 'subscript':
            yield candidate