# Copyright 2016 Grist Labs, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import numbers
import sys
import token
from ast import Module
from typing import Callable, List, Union, cast, Optional, Tuple, TYPE_CHECKING

from . import util
from .asttokens import ASTTokens
from .astroid_compat import astroid_node_classes as nc, BaseContainer as AstroidBaseContainer

if TYPE_CHECKING:
  from .util import AstNode


# Mapping of matching braces. To find a token here, look up token[:2].
_matching_pairs_left = {
  (token.OP, '('): (token.OP, ')'),
  (token.OP, '['): (token.OP, ']'),
  (token.OP, '{'): (token.OP, '}'),
}

_matching_pairs_right = {
  (token.OP, ')'): (token.OP, '('),
  (token.OP, ']'): (token.OP, '['),
  (token.OP, '}'): (token.OP, '{'),
}


class MarkTokens:
  """
  Helper that visits all nodes in the AST tree and assigns .first_token and .last_token attributes
  to each of them. This is the heart of the token-marking logic.
  """
  def __init__(self, code):
    # type: (ASTTokens) -> None
    self._code = code
    self._methods = util.NodeMethods()
    self._iter_children = None  # type: Optional[Callable]

  def visit_tree(self, node):
    # type: (Module) -> None
    self._iter_children = util.iter_children_func(node)
    util.visit_tree(node, self._visit_before_children, self._visit_after_children)

  def _visit_before_children(self, node, parent_token):
    # type: (AstNode, Optional[util.Token]) -> Tuple[Optional[util.Token], Optional[util.Token]]
    col = getattr(node, 'col_offset', None)
    token = self._code.get_token_from_utf8(node.lineno, col) if col is not None else None

    if not token and util.is_module(node):
      # We'll assume that a Module node starts at the start of the source code.
      token = self._code.get_token(1, 0)

    # Use our own token, or our parent's if we don't have one, to pass to child calls as
    # parent_token argument. The second value becomes the token argument of _visit_after_children.
    return (token or parent_token, token)

  def _visit_after_children(self, node, parent_token, token):
    # type: (AstNode, Optional[util.Token], Optional[util.Token]) -> None
    # This processes the node generically first, after all children have been processed.

    # Get the first and last tokens that belong to children. Note how this doesn't assume that we
    # iterate through children in order that corresponds to occurrence in source code. This
    # assumption can fail (e.g. with return annotations).
    first = token
    last = None
    for child in cast(Callable, self._iter_children)(node):
      # Astroid slices have especially wrong positions; we don't want them to corrupt their parents.
      if util.is_empty_astroid_slice(child):
        continue
      if not first or child.first_token.index < first.index:
        first = child.first_token
      if not last or child.last_token.index > last.index:
        last = child.last_token

    # If we don't have a first token from _visit_before_children, and there were no children, then
    # use the parent's token as the first token.
    first = first or parent_token

    # If no children, set last token to the first one.
    last = last or first

    # Statements continue up to just before the NEWLINE token. This helps cover a few different
    # cases at once.
    if util.is_stmt(node):
      last = self._find_last_in_stmt(cast(util.Token, last))

    # Capture any unmatched brackets.
    first, last = self._expand_to_matching_pairs(cast(util.Token, first), cast(util.Token, last), node)

    # Give a chance to node-specific methods to adjust.
    nfirst, nlast = self._methods.get(self, node.__class__)(node, first, last)

    if (nfirst, nlast) != (first, last):
      # If anything changed, expand again to capture any unmatched brackets.
      nfirst, nlast = self._expand_to_matching_pairs(nfirst, nlast, node)

    node.first_token = nfirst
    node.last_token = nlast

  def _find_last_in_stmt(self, start_token):
    # type: (util.Token) -> util.Token
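    # E.g. for "a = 1; b = 2\n", scanning forward from the current last token of "a = 1" stops at
    # the ';' (and at the NEWLINE for the second statement); the token just before the stopping
    # point is what gets returned.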
    t = start_token
    while (not util.match_token(t, token.NEWLINE) and
           not util.match_token(t, token.OP, ';') and
           not token.ISEOF(t.type)):
      t = self._code.next_token(t, include_extra=True)
    return self._code.prev_token(t)

  def _expand_to_matching_pairs(self, first_token, last_token, node):
    # type: (util.Token, util.Token, AstNode) -> Tuple[util.Token, util.Token]
    """
    Scan tokens in [first_token, last_token] range that are between node's children, and for any
    unmatched brackets, adjust first/last tokens to include the closing pair.
    """
    # We look for opening parens/braces among non-child tokens (i.e. tokens between our actual
    # child nodes). If we find any closing ones, we match them to the opens.
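    # For example, for the Call node in "foo(a, b)" the child nodes only cover "foo" through "b";
    # the unmatched '(' seen by the scan below causes last_token to be extended to the matching ')'.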
    to_match_right = []  # type: List[Tuple[int, str]]
    to_match_left = []
    for tok in self._code.token_range(first_token, last_token):
      tok_info = tok[:2]
      if to_match_right and tok_info == to_match_right[-1]:
        to_match_right.pop()
      elif tok_info in _matching_pairs_left:
        to_match_right.append(_matching_pairs_left[tok_info])
      elif tok_info in _matching_pairs_right:
        to_match_left.append(_matching_pairs_right[tok_info])

    # Once done, extend `last_token` to match any unclosed parens/braces.
    for match in reversed(to_match_right):
      last = self._code.next_token(last_token)
      # Allow for trailing commas or colons (allowed in subscripts) before the closing delimiter.
      while any(util.match_token(last, token.OP, x) for x in (',', ':')):
        last = self._code.next_token(last)
      # Now check for the actual closing delimiter.
      if util.match_token(last, *match):
        last_token = last

    # And extend `first_token` to match any unclosed opening parens/braces.
    for match in to_match_left:
      first = self._code.prev_token(first_token)
      if util.match_token(first, *match):
        first_token = first

    return (first_token, last_token)

  #----------------------------------------------------------------------
  # Node visitors. Each takes a preliminary first and last tokens, and returns the adjusted pair
  # that will actually be assigned.

  def visit_default(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    # pylint: disable=no-self-use
    # By default, we don't need to adjust the token we computed earlier.
    return (first_token, last_token)

  def handle_comp(self, open_brace, node, first_token, last_token):
    # type: (str, AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    # For list/set/dict comprehensions, we only get the token of the first child, so adjust it to
    # include the opening brace (the closing brace will be matched automatically).
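    # E.g. for "[x for x in y]" on versions where the node's position points at "x" rather than at
    # the "[", stepping back one token yields the opening "["; the closing "]" is then matched
    # automatically by _expand_to_matching_pairs.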
    before = self._code.prev_token(first_token)
    util.expect_token(before, token.OP, open_brace)
    return (before, last_token)

  def visit_comprehension(self,
                          node,  # type: AstNode
                          first_token,  # type: util.Token
                          last_token,  # type: util.Token
                          ):
    # type: (...) -> Tuple[util.Token, util.Token]
    # The 'comprehension' node starts with 'for' but we only get the first child; we search
    # backwards to find the 'for' keyword.
    first = self._code.find_token(first_token, token.NAME, 'for', reverse=True)
    return (first, last_token)

  def visit_if(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    while first_token.string not in ('if', 'elif'):
      first_token = self._code.prev_token(first_token)
    return first_token, last_token

  def handle_attr(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    # Attribute node has ".attr" (2 tokens) after the last child.
    dot = self._code.find_token(last_token, token.OP, '.')
    name = self._code.next_token(dot)
    util.expect_token(name, token.NAME)
    return (first_token, name)

  visit_attribute = handle_attr
  visit_assignattr = handle_attr
  visit_delattr = handle_attr

  def handle_def(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    # With astroid, nodes that start with a doc-string can have an empty body, in which case we
    # need to adjust the last token to include the doc string.
    if not node.body and (getattr(node, 'doc_node', None) or getattr(node, 'doc', None)):  # type: ignore[union-attr]
      last_token = self._code.find_token(last_token, token.STRING)

    # Include @ from decorator
    if first_token.index > 0:
      prev = self._code.prev_token(first_token)
      if util.match_token(prev, token.OP, '@'):
        first_token = prev
    return (first_token, last_token)

  visit_classdef = handle_def
  visit_functiondef = handle_def

  def handle_following_brackets(self, node, last_token, opening_bracket):
    # type: (AstNode, util.Token, str) -> util.Token
    # This is for calls and subscripts, which have a pair of brackets
    # at the end which may contain no nodes, e.g. foo() or bar[:].
    # We look for the opening bracket and then let the matching pair be found automatically.
    # Remember that last_token is at the end of all children,
    # so we are not worried about encountering a bracket that belongs to a child.
    first_child = next(cast(Callable, self._iter_children)(node))
    call_start = self._code.find_token(first_child.last_token, token.OP, opening_bracket)
    if call_start.index > last_token.index:
      last_token = call_start
    return last_token

  def visit_call(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    last_token = self.handle_following_brackets(node, last_token, '(')

    # Handle a Python bug with decorators with empty parens, e.g.
    # @deco()
    # def ...
    if util.match_token(first_token, token.OP, '@'):
      first_token = self._code.next_token(first_token)
    return (first_token, last_token)

  def visit_matchclass(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    last_token = self.handle_following_brackets(node, last_token, '(')
    return (first_token, last_token)

  def visit_subscript(self,
                      node,  # type: AstNode
                      first_token,  # type: util.Token
                      last_token,  # type: util.Token
                      ):
    # type: (...) -> Tuple[util.Token, util.Token]
    last_token = self.handle_following_brackets(node, last_token, '[')
    return (first_token, last_token)

  def visit_slice(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    # Consume `:` tokens to the left and right. In Python 3.9, Slice nodes are
    # given a col_offset (and end_col_offset), so this will always start inside
    # the slice, even if it is the empty slice. However, in 3.8 and below, this
    # will only expand to the full slice if the slice contains a node with a
    # col_offset. So x[:] will only get the correct tokens in 3.9, but x[1:] and
    # x[:1] will get them even on earlier versions of Python.
    while True:
      prev = self._code.prev_token(first_token)
      if prev.string != ':':
        break
      first_token = prev
    while True:
      next_ = self._code.next_token(last_token)
      if next_.string != ':':
        break
      last_token = next_
    return (first_token, last_token)

  def handle_bare_tuple(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    # A bare tuple doesn't include parens; if there is a trailing comma, make it part of the tuple.
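    # E.g. in "x = 1, 2," the last child ends at "2"; including the trailing comma makes the
    # tuple cover "1, 2,".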
    maybe_comma = self._code.next_token(last_token)
    if util.match_token(maybe_comma, token.OP, ','):
      last_token = maybe_comma
    return (first_token, last_token)

  # In Python 3.8+, parsed tuples include parentheses when present.
  def handle_tuple_nonempty(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    assert isinstance(node, ast.Tuple) or isinstance(node, AstroidBaseContainer)
    # It's a bare tuple if the first token belongs to the first child. The first child may
    # include extraneous parentheses (which don't create new nodes), so account for those too.
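    # E.g. for "(1, 2)" first_token is the "(" while the first child starts at "1", so the tuple
    # is not bare; for "1, 2" the two coincide and handle_bare_tuple applies.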
    child = node.elts[0]
    if TYPE_CHECKING:
      child = cast(AstNode, child)
    child_first, child_last = self._gobble_parens(child.first_token, child.last_token, True)
    if first_token == child_first:
      return self.handle_bare_tuple(node, first_token, last_token)
    return (first_token, last_token)

  def visit_tuple(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    assert isinstance(node, ast.Tuple) or isinstance(node, AstroidBaseContainer)
    if not node.elts:
      # An empty tuple is just "()", and we need no further info.
      return (first_token, last_token)
    return self.handle_tuple_nonempty(node, first_token, last_token)

  def _gobble_parens(self, first_token, last_token, include_all=False):
    # type: (util.Token, util.Token, bool) -> Tuple[util.Token, util.Token]
    # Expands a range of tokens to include one or all pairs of surrounding parentheses, and
    # returns (first, last) tokens that include these parens.
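    # E.g. given the tokens for "x" within "((x))", include_all=True expands to the outermost
    # parens, while include_all=False stops after the innermost pair.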
    while first_token.index > 0:
      prev = self._code.prev_token(first_token)
      next = self._code.next_token(last_token)
      if util.match_token(prev, token.OP, '(') and util.match_token(next, token.OP, ')'):
        first_token, last_token = prev, next
        if include_all:
          continue
      break
    return (first_token, last_token)

  def visit_str(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    return self.handle_str(first_token, last_token)

  def visit_joinedstr(self,
                      node,  # type: AstNode
                      first_token,  # type: util.Token
                      last_token,  # type: util.Token
                      ):
    # type: (...) -> Tuple[util.Token, util.Token]
    if sys.version_info < (3, 12):
      # Older versions don't tokenize the contents of f-strings.
      return self.handle_str(first_token, last_token)

    last = first_token
    while True:
      if util.match_token(last, getattr(token, "FSTRING_START")):
        # Python 3.12+ has tokens for the start (e.g. `f"`) and end (`"`)
        # of the f-string. We can't just look for the next FSTRING_END
        # because f-strings can be nested, e.g. f"{f'{x}'}", so we need
        # to treat this like matching balanced parentheses.
        count = 1
        while count > 0:
          last = self._code.next_token(last)
          # mypy complains about token.FSTRING_START and token.FSTRING_END.
          if util.match_token(last, getattr(token, "FSTRING_START")):
            count += 1
          elif util.match_token(last, getattr(token, "FSTRING_END")):
            count -= 1
        last_token = last
        last = self._code.next_token(last_token)
      elif util.match_token(last, token.STRING):
        # Similar to handle_str, we also need to handle adjacent strings.
        last_token = last
        last = self._code.next_token(last_token)
      else:
        break
    return (first_token, last_token)

  def visit_bytes(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    return self.handle_str(first_token, last_token)

  def handle_str(self, first_token, last_token):
    # type: (util.Token, util.Token) -> Tuple[util.Token, util.Token]
    # Multiple adjacent STRING tokens form a single string.
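    # E.g. "'ab' 'cd'" is a single string node; we keep consuming STRING tokens so that
    # last_token ends up on the 'cd' token.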
    last = self._code.next_token(last_token)
    while util.match_token(last, token.STRING):
      last_token = last
      last = self._code.next_token(last_token)
    return (first_token, last_token)

  def handle_num(self,
                 node,  # type: AstNode
                 value,  # type: Union[complex, int, numbers.Number]
                 first_token,  # type: util.Token
                 last_token,  # type: util.Token
                 ):
    # type: (...) -> Tuple[util.Token, util.Token]
    # A constant like '-1' gets turned into two tokens; this will skip the '-'.
    while util.match_token(last_token, token.OP):
      last_token = self._code.next_token(last_token)

    if isinstance(value, complex):
      # A complex number like -2j cannot be compared directly to 0.
      # A complex number like 1-2j is expressed as a binary operation,
      # so we don't need to worry about it.
      value = value.imag

    # This makes sure that the '-' is included.
    if value < 0 and first_token.type == token.NUMBER:  # type: ignore[operator]
      first_token = self._code.prev_token(first_token)
    return (first_token, last_token)

  def visit_num(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    return self.handle_num(node, cast(ast.Num, node).n, first_token, last_token)

  def visit_const(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    assert isinstance(node, ast.Constant) or isinstance(node, nc.Const)
    if isinstance(node.value, numbers.Number):
      return self.handle_num(node, node.value, first_token, last_token)
    elif isinstance(node.value, (str, bytes)):
      return self.visit_str(node, first_token, last_token)
    return (first_token, last_token)

  visit_constant = visit_const

  def visit_keyword(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    # Until Python 3.9 (https://bugs.python.org/issue40141),
    # ast.keyword nodes didn't have line info. Astroid has lineno None.
    assert isinstance(node, ast.keyword) or isinstance(node, nc.Keyword)
    if node.arg is not None and getattr(node, 'lineno', None) is None:
      equals = self._code.find_token(first_token, token.OP, '=', reverse=True)
      name = self._code.prev_token(equals)
      util.expect_token(name, token.NAME, node.arg)
      first_token = name
    return (first_token, last_token)

  def visit_starred(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    # Astroid has 'Starred' nodes (for "foo(*bar)" type args), but they need to be adjusted.
    if not util.match_token(first_token, token.OP, '*'):
      star = self._code.prev_token(first_token)
      if util.match_token(star, token.OP, '*'):
        first_token = star
    return (first_token, last_token)

  def visit_assignname(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    # Astroid may turn an 'except' clause into an AssignName, but we need to adjust it.
    if util.match_token(first_token, token.NAME, 'except'):
      colon = self._code.find_token(last_token, token.OP, ':')
      first_token = last_token = self._code.prev_token(colon)
    return (first_token, last_token)

  # Async nodes should typically start with the word 'async',
  # but Python < 3.7 doesn't put the col_offset there.
  # AsyncFunctionDef is slightly different because it might have
  # decorators before that, which visit_functiondef handles.
  def handle_async(self, node, first_token, last_token):
    # type: (AstNode, util.Token, util.Token) -> Tuple[util.Token, util.Token]
    if first_token.string != 'async':
      first_token = self._code.prev_token(first_token)
    return (first_token, last_token)

  visit_asyncfor = handle_async
  visit_asyncwith = handle_async

  def visit_asyncfunctiondef(self,
                             node,  # type: AstNode
                             first_token,  # type: util.Token
                             last_token,  # type: util.Token
                             ):
    # type: (...) -> Tuple[util.Token, util.Token]
    if util.match_token(first_token, token.NAME, 'def'):
      # Include the 'async' token.
      first_token = self._code.prev_token(first_token)
    return self.visit_functiondef(node, first_token, last_token)