Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/jedi-0.18.2-py3.8.egg/jedi/api/refactoring/extract.py: 12%
239 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:56 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:56 +0000
1from textwrap import dedent
3from parso import split_lines
5from jedi import debug
6from jedi.api.exceptions import RefactoringError
7from jedi.api.refactoring import Refactoring, EXPRESSION_PARTS
8from jedi.common import indent_block
9from jedi.parser_utils import function_is_classmethod, function_is_staticmethod
# Parent node types whose direct children are treated as standalone
# statements; extracted code is inserted in front of such a child.
_DEFINITION_SCOPES = ('suite', 'file_input')
# Node types accepted as the first node of an ``extract_variable`` selection:
# every expression part plus a handful of atom- and literal-level types.
_VARIABLE_EXCTRACTABLE = [
    *EXPRESSION_PARTS,
    *('atom testlist_star_expr testlist test lambdef lambdef_nocond '
      'keyword name number string fstring').split(),
]
def extract_variable(inference_state, path, module_node, name, pos, until_pos):
    """
    Build a refactoring that moves the selected expression into a new
    assignment ``name = <expression>`` and replaces the selection by ``name``.

    The assignment is placed in front of the statement that contains the
    selection.

    :raises RefactoringError: If the selection is not an extractable
        expression.
    """
    selected = _find_nodes(module_node, pos, until_pos)
    debug.dbg('Extracting nodes: %s', selected)

    ok, error_message = _is_expression_with_error(selected)
    if not ok:
        raise RefactoringError(error_message)

    assignment = ' = '.join((name, _expression_nodes_to_string(selected)))
    changes = _replace(selected, name, assignment, pos)
    return Refactoring(inference_state, {path: changes})
31def _is_expression_with_error(nodes):
32 """
33 Returns a tuple (is_expression, error_string).
34 """
35 if any(node.type == 'name' and node.is_definition() for node in nodes):
36 return False, 'Cannot extract a name that defines something'
38 if nodes[0].type not in _VARIABLE_EXCTRACTABLE:
39 return False, 'Cannot extract a "%s"' % nodes[0].type
40 return True, ''
def _find_nodes(module_node, pos, until_pos):
    """
    Looks up a module and tries to find the appropriate amount of nodes that
    are in there.

    :param module_node: The parsed module to search in.
    :param pos: ``(line, column)`` where the selection starts.
    :param until_pos: ``(line, column)`` where the selection ends, or
        ``None`` to select the expression found at ``pos``.
    :returns: A list of nodes covering the selection. If the selection starts
        at a ``return_stmt``/``yield_expr`` and results in a single node, only
        that node's second child (the part after the keyword) is returned.
    :raises RefactoringError: If there is no leaf before ``until_pos``.
    """
    start_node = module_node.get_leaf_for_position(pos, include_prefixes=True)

    if until_pos is None:
        # When ``pos`` lies on the boundary between an operator and the next
        # leaf, prefer the leaf that starts exactly at ``pos``.
        if start_node.type == 'operator':
            next_leaf = start_node.get_next_leaf()
            if next_leaf is not None and next_leaf.start_pos == pos:
                start_node = next_leaf

        if _is_not_extractable_syntax(start_node):
            start_node = start_node.parent

        # For something like ``foo.bar`` use the node that owns the trailer.
        if start_node.parent.type == 'trailer':
            start_node = start_node.parent.parent
        # Climb to the outermost enclosing expression part.
        while start_node.parent.type in EXPRESSION_PARTS:
            start_node = start_node.parent

        nodes = [start_node]
    else:
        # Get the next leaf if we are at the end of a leaf
        if start_node.end_pos == pos:
            next_leaf = start_node.get_next_leaf()
            if next_leaf is not None:
                start_node = next_leaf

        # Some syntax is not exactable, just use its parent
        if _is_not_extractable_syntax(start_node):
            start_node = start_node.parent

        # Find the end: the last leaf that starts at or before ``until_pos``.
        end_leaf = module_node.get_leaf_for_position(until_pos, include_prefixes=True)
        if end_leaf.start_pos > until_pos:
            end_leaf = end_leaf.get_previous_leaf()
            if end_leaf is None:
                raise RefactoringError('Cannot extract anything from that')

        # Smallest ancestor of the start node that reaches past the end leaf.
        parent_node = start_node
        while parent_node.end_pos < end_leaf.end_pos:
            parent_node = parent_node.parent

        nodes = _remove_unwanted_expression_nodes(parent_node, pos, until_pos)

    # If the user marks just a return statement, we return the expression
    # instead of the whole statement, because the user obviously wants to
    # extract that part.
    # NOTE(review): a bare ``return`` presumably has a single child, in which
    # case ``children[1]`` raises IndexError -- confirm.
    if len(nodes) == 1 and start_node.type in ('return_stmt', 'yield_expr'):
        return [nodes[0].children[1]]
    return nodes
def _replace(nodes, expression_replacement, extracted, pos,
             insert_before_leaf=None, remaining_prefix=None):
    """
    Build the node -> replacement-text mapping for an extraction.

    ``extracted`` (the new assignment or function) is inserted in front of
    ``insert_before_leaf``, indented with the last line of that leaf's
    prefix. By default ``insert_before_leaf`` is the first leaf of the
    statement that contains ``nodes[0]``. ``nodes[0]`` is replaced by
    ``expression_replacement``, and every other node in ``nodes`` is replaced
    by an empty string.

    :param pos: Not used by this function.
    :param remaining_prefix: Optional prefix text, computed by the caller,
        that replaces the part of the first node's prefix before its own line.
    :returns: A dict mapping nodes/leaves to their new source text.
    """
    # Now try to replace the nodes found with a variable and move the code
    # before the current statement.
    definition = _get_parent_definition(nodes[0])
    if insert_before_leaf is None:
        insert_before_leaf = definition.get_first_leaf()
    first_node_leaf = nodes[0].get_first_leaf()

    lines = split_lines(insert_before_leaf.prefix, keepends=True)
    if first_node_leaf is insert_before_leaf:
        if remaining_prefix is not None:
            # The remaining prefix has already been calculated.
            # remaining_prefix is a str here, so the slice receives its
            # characters; the ''.join below glues them back together.
            lines[:-1] = remaining_prefix
    # Insert the extracted code just before the last prefix line (the
    # indentation of the insertion leaf), indented the same way.
    lines[-1:-1] = [indent_block(extracted, lines[-1]) + '\n']
    extracted_prefix = ''.join(lines)

    replacement_dct = {}
    if first_node_leaf is insert_before_leaf:
        # The selection starts the statement: one replacement carries both
        # the extracted code and the new expression.
        replacement_dct[nodes[0]] = extracted_prefix + expression_replacement
    else:
        if remaining_prefix is None:
            p = first_node_leaf.prefix
        else:
            p = remaining_prefix + _get_indentation(nodes[0])
        replacement_dct[nodes[0]] = p + expression_replacement
        replacement_dct[insert_before_leaf] = extracted_prefix + insert_before_leaf.value

    for node in nodes[1:]:
        replacement_dct[node] = ''
    return replacement_dct
130def _expression_nodes_to_string(nodes):
131 return ''.join(n.get_code(include_prefix=i != 0) for i, n in enumerate(nodes))
def _suite_nodes_to_string(nodes, pos):
    """
    Convert a run of suite-level nodes into source code.

    The prefix of the first node is split at the line of ``pos``. Prefix
    lines before that line are returned separately; the remaining prefix
    lines stay in front of the code.

    Returns ``(prefix, code)``.
    """
    head = nodes[0]
    prefix, leading = _split_prefix_at(head.get_first_leaf(), pos[0] - 1)
    head_code = head.get_code(include_prefix=False)
    rest_code = ''.join(node.get_code() for node in nodes[1:])
    return prefix, leading + head_code + rest_code
def _split_prefix_at(leaf, until_line):
    """
    Split the prefix of ``leaf`` into two strings.

    The second string holds the last ``leaf.start_pos[0] - until_line``
    lines of the prefix, i.e. the lines after ``until_line`` up to and
    including the leaf's own line. The first string holds the rest.
    """
    tail_size = leaf.start_pos[0] - until_line
    prefix_lines = split_lines(leaf.prefix, keepends=True)
    # NOTE(review): with tail_size == 0 the slices become ``[:-0]`` and
    # ``[-0:]``, which puts the *whole* prefix into the second part.
    # Kept as-is to preserve existing behavior.
    head, tail = prefix_lines[:-tail_size], prefix_lines[-tail_size:]
    return ''.join(head), ''.join(tail)
def _get_indentation(node):
    """Return the last line of the prefix of ``node``'s first leaf (its indentation)."""
    first_leaf = node.get_first_leaf()
    prefix_lines = split_lines(first_leaf.prefix)
    return prefix_lines[-1]
157def _get_parent_definition(node):
158 """
159 Returns the statement where a node is defined.
160 """
161 while node is not None:
162 if node.parent.type in _DEFINITION_SCOPES:
163 return node
164 node = node.parent
165 raise NotImplementedError('We should never even get here')
def _remove_unwanted_expression_nodes(parent_node, pos, until_pos):
    """
    This function makes it so for `1 * 2 + 3` you can extract `2 + 3`, even
    though it is not part of the expression.

    For an expression part, suite or file_input, only the children of
    ``parent_node`` that overlap ``pos``..``until_pos`` are kept. For
    expressions the first and last kept children are trimmed recursively in
    the same way. Any other node is returned unchanged as ``[parent_node]``.

    :returns: A list of nodes.
    """
    typ = parent_node.type
    is_suite_part = typ in ('suite', 'file_input')
    if typ in EXPRESSION_PARTS or is_suite_part:
        nodes = parent_node.children
        # First child ending after ``pos``; if it is an operator, also keep
        # the child before it.
        for i, n in enumerate(nodes):
            if n.end_pos > pos:
                start_index = i
                if n.type == 'operator':
                    start_index -= 1
                break
        # Last child starting before ``until_pos``; if it is an operator,
        # also keep the child after it.
        for i, n in reversed(list(enumerate(nodes))):
            if n.start_pos < until_pos:
                end_index = i
                if n.type == 'operator':
                    end_index += 1

                # Something like `not foo or bar` should not be cut after not
                # NOTE(review): this loop starts at nodes[i] itself, so an
                # operator at i extends end_index twice -- confirm intended.
                for n2 in nodes[i:]:
                    if _is_not_extractable_syntax(n2):
                        end_index += 1
                    else:
                        break
                break
        # NOTE(review): if no child matches either loop, start_index or
        # end_index is unbound here (UnboundLocalError).
        nodes = nodes[start_index:end_index + 1]
        if not is_suite_part:
            nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos)
            nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos)
        return nodes
    return [parent_node]
204def _is_not_extractable_syntax(node):
205 return node.type == 'operator' \
206 or node.type == 'keyword' and node.value not in ('None', 'True', 'False')
def extract_function(inference_state, path, module_context, name, pos, until_pos):
    """
    Build a refactoring that moves the selection into a new function ``name``.

    If the selection is an expression, the new function returns it and the
    expression is replaced by a call. Otherwise the selected statements
    become the function body. They are replaced either by
    ``return name(...)``, when the selection ends with a return statement,
    or by ``<outputs> = name(...)``. Read-only inputs become parameters.
    Inside a bound method the call goes through the method's first parameter
    (e.g. ``self``), and for a classmethod an ``@classmethod`` decorator is
    added.

    :raises RefactoringError: If the selection contains a ``return`` that is
        not at the end, or any ``yield``.
    """
    nodes = _find_nodes(module_context.tree_node, pos, until_pos)
    # NOTE(review): ``assert`` is skipped under ``python -O``.
    assert len(nodes)

    is_expression, _ = _is_expression_with_error(nodes)
    context = module_context.create_context(nodes[0])
    is_bound_method = context.is_bound_method()
    params, return_variables = list(_find_inputs_and_outputs(module_context, context, nodes))

    # Decide where the new function is inserted: at module level, it is
    # determined later by _replace; otherwise in front of the node chosen
    # by _get_code_insertion_node.
    if context.is_module():
        insert_before_leaf = None  # Leaf will be determined later
    else:
        node = _get_code_insertion_node(context.tree_node, is_bound_method)
        insert_before_leaf = node.get_first_leaf()
    if is_expression:
        code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n'
        remaining_prefix = None
        has_ending_return_stmt = False
    else:
        has_ending_return_stmt = _is_node_ending_return_stmt(nodes[-1])
        if not has_ending_return_stmt:
            # Find the actually used variables (of the defined ones). If none are
            # used (e.g. if the range covers the whole function), return the last
            # defined variable.
            # Precedence: ``(found or [last]) if return_variables else []``.
            return_variables = list(_find_needed_output_variables(
                context,
                nodes[0].parent,
                nodes[-1].end_pos,
                return_variables
            )) or [return_variables[-1]] if return_variables else []

        remaining_prefix, code_block = _suite_nodes_to_string(nodes, pos)
        # Prefix lines of the following leaf up to ``until_pos`` belong to
        # the extracted body; the rest stays in front of that leaf.
        after_leaf = nodes[-1].get_next_leaf()
        first, second = _split_prefix_at(after_leaf, until_pos[0])
        code_block += first

        code_block = dedent(code_block)
        if not has_ending_return_stmt:
            output_var_str = ', '.join(return_variables)
            code_block += 'return ' + output_var_str + '\n'

    # Check if we have to raise RefactoringError
    _check_for_non_extractables(nodes[:-1] if has_ending_return_stmt else nodes)

    decorator = ''
    self_param = None
    if is_bound_method:
        if not function_is_staticmethod(context.tree_node):
            # The enclosing method's first parameter (e.g. ``self``/``cls``)
            # is passed implicitly, so it is not a regular parameter.
            function_param_names = context.get_value().get_param_names()
            if len(function_param_names):
                self_param = function_param_names[0].string_name
                params = [p for p in params if p != self_param]

        if function_is_classmethod(context.tree_node):
            decorator = '@classmethod\n'
    else:
        code_block += '\n'

    function_code = '%sdef %s(%s):\n%s' % (
        decorator,
        name,
        ', '.join(params if self_param is None else [self_param] + params),
        indent_block(code_block)
    )

    function_call = '%s(%s)' % (
        ('' if self_param is None else self_param + '.') + name,
        ', '.join(params)
    )
    if is_expression:
        replacement = function_call
    else:
        if has_ending_return_stmt:
            replacement = 'return ' + function_call + '\n'
        else:
            replacement = output_var_str + ' = ' + function_call + '\n'

    replacement_dct = _replace(nodes, replacement, function_code, pos,
                               insert_before_leaf, remaining_prefix)
    if not is_expression:
        replacement_dct[after_leaf] = second + after_leaf.value
    file_to_node_changes = {path: replacement_dct}
    return Refactoring(inference_state, file_to_node_changes)
296def _check_for_non_extractables(nodes):
297 for n in nodes:
298 try:
299 children = n.children
300 except AttributeError:
301 if n.value == 'return':
302 raise RefactoringError(
303 'Can only extract return statements if they are at the end.')
304 if n.value == 'yield':
305 raise RefactoringError('Cannot extract yield statements.')
306 else:
307 _check_for_non_extractables(children)
310def _is_name_input(module_context, names, first, last):
311 for name in names:
312 if name.api_type == 'param' or not name.parent_context.is_module():
313 if name.get_root_context() is not module_context:
314 return True
315 if name.start_pos is None or not (first <= name.start_pos < last):
316 return True
317 return False
def _find_inputs_and_outputs(module_context, context, nodes):
    """
    Collect the variable names flowing into and out of ``nodes``.

    Returns ``(inputs, outputs)``, two lists of name strings:

    * ``inputs`` -- names read inside ``nodes`` that either cannot be
      resolved via ``context.goto`` or that ``_is_name_input`` reports as
      coming from outside the selected range, in order of first use.
    * ``outputs`` -- names defined inside ``nodes``, without duplicates and
      ordered by their *last* definition, so ``outputs[-1]`` is the name
      that was assigned last.
    """
    first = nodes[0].start_pos
    last = nodes[-1].end_pos

    inputs = []
    outputs = []
    for name in _find_non_global_names(nodes):
        if name.is_definition():
            # ``outputs`` stores strings (``name.value``), so membership must
            # be tested with the string rather than the leaf itself.
            # Re-appending keeps the list ordered by the most recent
            # definition, so ``outputs[-1]`` stays the last-defined name.
            if name.value in outputs:
                outputs.remove(name.value)
            outputs.append(name.value)
        elif name.value not in inputs:
            name_definitions = context.goto(name, name.start_pos)
            if not name_definitions \
                    or _is_name_input(module_context, name_definitions, first, last):
                inputs.append(name.value)
    return inputs, outputs
341def _find_non_global_names(nodes):
342 for node in nodes:
343 try:
344 children = node.children
345 except AttributeError:
346 if node.type == 'name':
347 yield node
348 else:
349 # We only want to check foo in foo.bar
350 if node.type == 'trailer' and node.children[0] == '.':
351 continue
353 yield from _find_non_global_names(children)
356def _get_code_insertion_node(node, is_bound_method):
357 if not is_bound_method or function_is_staticmethod(node):
358 while node.parent.type != 'file_input':
359 node = node.parent
361 while node.parent.type in ('async_funcdef', 'decorated', 'async_stmt'):
362 node = node.parent
363 return node
def _find_needed_output_variables(context, search_node, at_least_pos, return_variables):
    """
    Searches everything after at_least_pos in a node and checks if any of the
    return_variables are used in there and returns those.

    Only the children of ``search_node`` starting at or after
    ``at_least_pos`` are searched. Each variable is yielded at most once, in
    the order its first read is encountered. ``context`` is not used.
    """
    still_unused = set(return_variables)
    later_children = (
        child for child in search_node.children
        if child.start_pos >= at_least_pos
    )
    for child in later_children:
        for name in _find_non_global_names([child]):
            if not name.is_definition() and name.value in still_unused:
                still_unused.remove(name.value)
                yield name.value
382def _is_node_ending_return_stmt(node):
383 t = node.type
384 if t == 'simple_stmt':
385 return _is_node_ending_return_stmt(node.children[0])
386 return t == 'return_stmt'