Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/jedi-0.18.2-py3.8.egg/jedi/api/refactoring/extract.py: 12%

239 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:56 +0000

1from textwrap import dedent 

2 

3from parso import split_lines 

4 

5from jedi import debug 

6from jedi.api.exceptions import RefactoringError 

7from jedi.api.refactoring import Refactoring, EXPRESSION_PARTS 

8from jedi.common import indent_block 

9from jedi.parser_utils import function_is_classmethod, function_is_staticmethod 

10 

11 

# Parent node types that directly hold statements; `_get_parent_definition`
# climbs the tree until a node's parent has one of these types.
_DEFINITION_SCOPES = ('suite', 'file_input')

# Node types accepted as the first selected node by `extract_variable`
# (checked in `_is_expression_with_error`): every expression part plus a few
# extra atom/expression/leaf types.
_VARIABLE_EXCTRACTABLE = EXPRESSION_PARTS + [
    'atom',
    'testlist_star_expr',
    'testlist',
    'test',
    'lambdef',
    'lambdef_nocond',
    'keyword',
    'name',
    'number',
    'string',
    'fstring',
]

16 

17 

def extract_variable(inference_state, path, module_node, name, pos, until_pos):
    """
    Extract the expression selected by ``pos``/``until_pos`` into a new
    variable called ``name``.

    The assignment ``name = <expression>`` is inserted before the enclosing
    statement and the expression is replaced by ``name``.

    Raises ``RefactoringError`` when the selection is not an extractable
    expression. Returns a ``Refactoring`` holding the changes for ``path``.
    """
    selected = _find_nodes(module_node, pos, until_pos)
    debug.dbg('Extracting nodes: %s', selected)

    extractable, error_message = _is_expression_with_error(selected)
    if not extractable:
        raise RefactoringError(error_message)

    assignment = name + ' = ' + _expression_nodes_to_string(selected)
    changes = {path: _replace(selected, name, assignment, pos)}
    return Refactoring(inference_state, changes)

29 

30 

31def _is_expression_with_error(nodes): 

32 """ 

33 Returns a tuple (is_expression, error_string). 

34 """ 

35 if any(node.type == 'name' and node.is_definition() for node in nodes): 

36 return False, 'Cannot extract a name that defines something' 

37 

38 if nodes[0].type not in _VARIABLE_EXCTRACTABLE: 

39 return False, 'Cannot extract a "%s"' % nodes[0].type 

40 return True, '' 

41 

42 

def _find_nodes(module_node, pos, until_pos):
    """
    Looks up a module and tries to find the appropriate amount of nodes that
    are in there.

    Without ``until_pos`` a single node is returned: the expression around
    ``pos``, grown through its trailers and enclosing expression parts. With
    ``until_pos`` the smallest node covering ``pos``..``until_pos`` is found
    and narrowed by ``_remove_unwanted_expression_nodes``.

    Raises ``RefactoringError`` when no leaf precedes ``until_pos``.
    """
    start_node = module_node.get_leaf_for_position(pos, include_prefixes=True)

    if until_pos is None:
        # A cursor directly in front of a leaf (right after an operator) is
        # attributed to that following leaf.
        if start_node.type == 'operator':
            next_leaf = start_node.get_next_leaf()
            if next_leaf is not None and next_leaf.start_pos == pos:
                start_node = next_leaf

        if _is_not_extractable_syntax(start_node):
            start_node = start_node.parent

        # For `foo.bar` / `foo(...)` take the whole trailer-carrying node,
        # then climb while the parent is still part of an expression.
        if start_node.parent.type == 'trailer':
            start_node = start_node.parent.parent
        while start_node.parent.type in EXPRESSION_PARTS:
            start_node = start_node.parent

        nodes = [start_node]
    else:
        # Get the next leaf if we are at the end of a leaf
        if start_node.end_pos == pos:
            next_leaf = start_node.get_next_leaf()
            if next_leaf is not None:
                start_node = next_leaf

        # Some syntax is not exactable, just use its parent
        if _is_not_extractable_syntax(start_node):
            start_node = start_node.parent

        # Find the end
        end_leaf = module_node.get_leaf_for_position(until_pos, include_prefixes=True)
        # With include_prefixes the leaf may start after until_pos (the
        # position lies in its prefix); then the selection really ends at the
        # previous leaf.
        if end_leaf.start_pos > until_pos:
            end_leaf = end_leaf.get_previous_leaf()
            if end_leaf is None:
                raise RefactoringError('Cannot extract anything from that')

        # Climb until a single node spans from the start node to the end leaf.
        parent_node = start_node
        while parent_node.end_pos < end_leaf.end_pos:
            parent_node = parent_node.parent

        nodes = _remove_unwanted_expression_nodes(parent_node, pos, until_pos)

    # If the user marks just a return statement, we return the expression
    # instead of the whole statement, because the user obviously wants to
    # extract that part.
    if len(nodes) == 1 and start_node.type in ('return_stmt', 'yield_expr'):
        return [nodes[0].children[1]]
    return nodes

95 

96 

def _replace(nodes, expression_replacement, extracted, pos,
             insert_before_leaf=None, remaining_prefix=None):
    """
    Build the node -> new source text mapping for an extraction.

    ``nodes[0]`` is replaced by ``expression_replacement``, every other node
    in ``nodes`` by ``''``, and ``extracted`` is inserted into the prefix of
    ``insert_before_leaf``, indented with that prefix's last line.

    :param nodes: non-empty list of selected tree nodes.
    :param expression_replacement: text that takes the place of ``nodes[0]``.
    :param extracted: code (assignment or function definition) to insert.
    :param pos: not used in this body.
    :param insert_before_leaf: leaf in front of which ``extracted`` goes;
        defaults to the first leaf of the statement containing ``nodes[0]``.
    :param remaining_prefix: an already-computed prefix to use in place of the
        original prefix text (see ``_suite_nodes_to_string``).
    :returns: dict mapping nodes/leaves to their replacement source text.
    """
    # Now try to replace the nodes found with a variable and move the code
    # before the current statement.
    definition = _get_parent_definition(nodes[0])
    if insert_before_leaf is None:
        insert_before_leaf = definition.get_first_leaf()
    first_node_leaf = nodes[0].get_first_leaf()

    lines = split_lines(insert_before_leaf.prefix, keepends=True)
    if first_node_leaf is insert_before_leaf:
        if remaining_prefix is not None:
            # The remaining prefix has already been calculated.
            # NOTE: assigning a str to a list slice splices in its individual
            # characters; joined below, this yields the same text.
            lines[:-1] = remaining_prefix
    # Insert the extracted code just before the last prefix line, i.e. before
    # the indentation of insert_before_leaf, indented by that same text
    # (presumably indent_block prefixes each line with it — see jedi.common).
    lines[-1:-1] = [indent_block(extracted, lines[-1]) + '\n']
    extracted_prefix = ''.join(lines)

    replacement_dct = {}
    if first_node_leaf is insert_before_leaf:
        # The selection starts the statement: the extracted code and the
        # replacement both go into the first selected node.
        replacement_dct[nodes[0]] = extracted_prefix + expression_replacement
    else:
        if remaining_prefix is None:
            p = first_node_leaf.prefix
        else:
            p = remaining_prefix + _get_indentation(nodes[0])
        replacement_dct[nodes[0]] = p + expression_replacement
        replacement_dct[insert_before_leaf] = extracted_prefix + insert_before_leaf.value

    # All remaining selected nodes disappear.
    for node in nodes[1:]:
        replacement_dct[node] = ''
    return replacement_dct

128 

129 

130def _expression_nodes_to_string(nodes): 

131 return ''.join(n.get_code(include_prefix=i != 0) for i, n in enumerate(nodes)) 

132 

133 

def _suite_nodes_to_string(nodes, pos):
    """
    Render the selected statements ``nodes`` as source code.

    The prefix of the first statement's first leaf is split at the line
    before ``pos``. Prefix lines above ``pos``'s line are returned separately.
    The prefix part from ``pos``'s line onward is kept in front of the code.

    Returns ``(prefix, code)``.
    """
    head = nodes[0]
    tail = nodes[1:]
    kept_prefix, moved_prefix = _split_prefix_at(head.get_first_leaf(), pos[0] - 1)
    head_code = head.get_code(include_prefix=False)
    tail_code = ''.join(node.get_code() for node in tail)
    return kept_prefix, moved_prefix + head_code + tail_code

140 

141 

def _split_prefix_at(leaf, until_line):
    """
    Split ``leaf.prefix`` into two strings at ``until_line``.

    The second string holds the last ``leaf.start_pos[0] - until_line`` lines
    of the prefix. This count includes the partial line the leaf itself
    starts on. The first string holds everything before them.
    """
    # Number of prefix lines, counted from the end, that form the second part.
    tail_count = leaf.start_pos[0] - until_line
    prefix_lines = split_lines(leaf.prefix, keepends=True)
    # NOTE: a count of 0 slices as ``[:0]`` / ``[0:]``, so the whole prefix
    # ends up in the second part.
    split_at = -tail_count
    return ''.join(prefix_lines[:split_at]), ''.join(prefix_lines[split_at:])

151 

152 

def _get_indentation(node):
    """
    Return the last line of the prefix of ``node``'s first leaf.

    This is the text between the prefix's final line break and the node
    itself, which is normally its indentation.
    """
    prefix = node.get_first_leaf().prefix
    indentation = split_lines(prefix)[-1]
    return indentation

155 

156 

157def _get_parent_definition(node): 

158 """ 

159 Returns the statement where a node is defined. 

160 """ 

161 while node is not None: 

162 if node.parent.type in _DEFINITION_SCOPES: 

163 return node 

164 node = node.parent 

165 raise NotImplementedError('We should never even get here') 

166 

167 

def _remove_unwanted_expression_nodes(parent_node, pos, until_pos):
    """
    This function makes it so for `1 * 2 + 3` you can extract `2 + 3`, even
    though it is not part of the expression.

    For an expression part, suite or module ``parent_node`` this returns the
    slice of its children that overlaps ``pos``..``until_pos``. The slice is
    widened by one child when a boundary child is an operator, and it also
    takes in non-extractable syntax (operators/keywords) right after the
    selection. For expression parts the first and last selected children are
    narrowed further by recursion. For suites and modules whole statements
    are kept. Any other node type is returned as ``[parent_node]``.

    NOTE(review): if no child ends after ``pos`` (or none starts before
    ``until_pos``), ``start_index`` / ``end_index`` appear never to be bound
    and an UnboundLocalError would follow — confirm callers rule this out.
    """
    typ = parent_node.type
    is_suite_part = typ in ('suite', 'file_input')
    if typ in EXPRESSION_PARTS or is_suite_part:
        nodes = parent_node.children
        # First child that ends after the selection start.
        for i, n in enumerate(nodes):
            if n.end_pos > pos:
                start_index = i
                if n.type == 'operator':
                    start_index -= 1
                break
        # Last child that starts before the selection end.
        for i, n in reversed(list(enumerate(nodes))):
            if n.start_pos < until_pos:
                end_index = i
                if n.type == 'operator':
                    end_index += 1

                # Something like `not foo or bar` should not be cut after not
                for n2 in nodes[i:]:
                    if _is_not_extractable_syntax(n2):
                        end_index += 1
                    else:
                        break
                break
        # A fresh slice: the splices below do not modify parent_node.children.
        nodes = nodes[start_index:end_index + 1]
        if not is_suite_part:
            nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos)
            nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos)
        return nodes
    return [parent_node]

202 

203 

204def _is_not_extractable_syntax(node): 

205 return node.type == 'operator' \ 

206 or node.type == 'keyword' and node.value not in ('None', 'True', 'False') 

207 

208 

def extract_function(inference_state, path, module_context, name, pos, until_pos):
    """
    Extract the code between ``pos`` and ``until_pos`` into a new function
    called ``name`` and replace it with a call to that function.

    For an expression selection, the new function returns the expression and
    the call replaces it in place.

    For a statement selection, the new function returns the defined
    variables that are read afterwards, or else the last defined one. The
    selection is replaced by an assignment of the call result, or by
    ``return <call>`` when the selection ends with a return statement.

    Inside a bound, non-static method, the first parameter (e.g. ``self``)
    becomes the first parameter of the new function and is used as the call
    receiver. Classmethods get an ``@classmethod`` decorator.

    Returns a ``Refactoring`` with the node replacements for ``path``.
    ``_check_for_non_extractables`` raises ``RefactoringError`` for selections
    containing ``yield`` or a ``return`` that is not at the end.
    """
    nodes = _find_nodes(module_context.tree_node, pos, until_pos)
    # NOTE(review): assert is stripped under `python -O`.
    assert len(nodes)

    is_expression, _ = _is_expression_with_error(nodes)
    context = module_context.create_context(nodes[0])
    is_bound_method = context.is_bound_method()
    # params: names read in the selection that come from outside it;
    # return_variables: names defined inside the selection.
    params, return_variables = list(_find_inputs_and_outputs(module_context, context, nodes))

    # Find variables
    # Is a class method / method
    if context.is_module():
        insert_before_leaf = None  # Leaf will be determined later
    else:
        node = _get_code_insertion_node(context.tree_node, is_bound_method)
        insert_before_leaf = node.get_first_leaf()
    if is_expression:
        code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n'
        remaining_prefix = None
        has_ending_return_stmt = False
    else:
        has_ending_return_stmt = _is_node_ending_return_stmt(nodes[-1])
        if not has_ending_return_stmt:
            # Find the actually used variables (of the defined ones). If none are
            # used (e.g. if the range covers the whole function), return the last
            # defined variable.
            # Precedence: ``(list(...) or [last]) if return_variables else []``.
            return_variables = list(_find_needed_output_variables(
                context,
                nodes[0].parent,
                nodes[-1].end_pos,
                return_variables
            )) or [return_variables[-1]] if return_variables else []

        remaining_prefix, code_block = _suite_nodes_to_string(nodes, pos)
        # The prefix of the leaf after the selection is split at until_pos:
        # ``first`` moves into the new function, ``second`` stays in place.
        # after_leaf / second are only bound on this branch and only used
        # below under ``not is_expression``.
        after_leaf = nodes[-1].get_next_leaf()
        first, second = _split_prefix_at(after_leaf, until_pos[0])
        code_block += first

        code_block = dedent(code_block)
        if not has_ending_return_stmt:
            output_var_str = ', '.join(return_variables)
            code_block += 'return ' + output_var_str + '\n'

    # Check if we have to raise RefactoringError
    # (a trailing return statement is allowed, so it is not checked).
    _check_for_non_extractables(nodes[:-1] if has_ending_return_stmt else nodes)

    decorator = ''
    self_param = None
    if is_bound_method:
        if not function_is_staticmethod(context.tree_node):
            function_param_names = context.get_value().get_param_names()
            if len(function_param_names):
                # The receiver (e.g. ``self``) is passed implicitly, so drop
                # it from the explicit arguments.
                self_param = function_param_names[0].string_name
                params = [p for p in params if p != self_param]

        if function_is_classmethod(context.tree_node):
            decorator = '@classmethod\n'
    else:
        code_block += '\n'

    function_code = '%sdef %s(%s):\n%s' % (
        decorator,
        name,
        ', '.join(params if self_param is None else [self_param] + params),
        indent_block(code_block)
    )

    function_call = '%s(%s)' % (
        ('' if self_param is None else self_param + '.') + name,
        ', '.join(params)
    )
    if is_expression:
        replacement = function_call
    else:
        if has_ending_return_stmt:
            replacement = 'return ' + function_call + '\n'
        else:
            replacement = output_var_str + ' = ' + function_call + '\n'

    replacement_dct = _replace(nodes, replacement, function_code, pos,
                               insert_before_leaf, remaining_prefix)
    if not is_expression:
        replacement_dct[after_leaf] = second + after_leaf.value
    file_to_node_changes = {path: replacement_dct}
    return Refactoring(inference_state, file_to_node_changes)

294 

295 

296def _check_for_non_extractables(nodes): 

297 for n in nodes: 

298 try: 

299 children = n.children 

300 except AttributeError: 

301 if n.value == 'return': 

302 raise RefactoringError( 

303 'Can only extract return statements if they are at the end.') 

304 if n.value == 'yield': 

305 raise RefactoringError('Cannot extract yield statements.') 

306 else: 

307 _check_for_non_extractables(children) 

308 

309 

310def _is_name_input(module_context, names, first, last): 

311 for name in names: 

312 if name.api_type == 'param' or not name.parent_context.is_module(): 

313 if name.get_root_context() is not module_context: 

314 return True 

315 if name.start_pos is None or not (first <= name.start_pos < last): 

316 return True 

317 return False 

318 

319 

def _find_inputs_and_outputs(module_context, context, nodes):
    """
    Collect the names read (inputs) and defined (outputs) within ``nodes``.

    Returns ``(inputs, outputs)``, two lists of name strings without
    duplicates.

    ``inputs`` are read names, in order of first read, that cannot be
    resolved or whose definitions lie outside the selection
    (see ``_is_name_input``).

    ``outputs`` are defined names ordered by their *latest* definition, so
    ``outputs[-1]`` is the most recently defined name.
    """
    first = nodes[0].start_pos
    last = nodes[-1].end_pos

    inputs = []
    outputs = []
    for name in _find_non_global_names(nodes):
        if name.is_definition():
            # ``outputs`` holds strings, so membership must be tested on the
            # name's value (testing the leaf object itself never matched).
            # A redefinition moves the name to the end so that the last
            # element stays the most recently defined variable.
            if name.value in outputs:
                outputs.remove(name.value)
            outputs.append(name.value)
        elif name.value not in inputs:
            name_definitions = context.goto(name, name.start_pos)
            if not name_definitions \
                    or _is_name_input(module_context, name_definitions, first, last):
                inputs.append(name.value)

    return inputs, outputs

339 

340 

341def _find_non_global_names(nodes): 

342 for node in nodes: 

343 try: 

344 children = node.children 

345 except AttributeError: 

346 if node.type == 'name': 

347 yield node 

348 else: 

349 # We only want to check foo in foo.bar 

350 if node.type == 'trailer' and node.children[0] == '.': 

351 continue 

352 

353 yield from _find_non_global_names(children) 

354 

355 

356def _get_code_insertion_node(node, is_bound_method): 

357 if not is_bound_method or function_is_staticmethod(node): 

358 while node.parent.type != 'file_input': 

359 node = node.parent 

360 

361 while node.parent.type in ('async_funcdef', 'decorated', 'async_stmt'): 

362 node = node.parent 

363 return node 

364 

365 

366def _find_needed_output_variables(context, search_node, at_least_pos, return_variables): 

367 """ 

368 Searches everything after at_least_pos in a node and checks if any of the 

369 return_variables are used in there and returns those. 

370 """ 

371 for node in search_node.children: 

372 if node.start_pos < at_least_pos: 

373 continue 

374 

375 return_variables = set(return_variables) 

376 for name in _find_non_global_names([node]): 

377 if not name.is_definition() and name.value in return_variables: 

378 return_variables.remove(name.value) 

379 yield name.value 

380 

381 

382def _is_node_ending_return_stmt(node): 

383 t = node.type 

384 if t == 'simple_stmt': 

385 return _is_node_ending_return_stmt(node.children[0]) 

386 return t == 'return_stmt'