1# coding=utf-8
2"""Annotate python syntax trees with formatting from the source file."""
3# Copyright 2021 Google LLC
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import ast
23import contextlib
24import functools
25import itertools
26import numbers
27import six
28from six.moves import zip
29import sys
30import token
31
32from pasta.base import ast_constants
33from pasta.base import ast_utils
34from pasta.base import formatting as fmt
35from pasta.base import token_generator
36
37
38# ==============================================================================
39# == Helper functions for decorating nodes with prefix + suffix ==
40# ==============================================================================
41
def _gen_wrapper(f, scope=True, prefix=True, suffix=True, max_suffix_lines=None,
                 semicolon=False, comment=False, statement=False):
    """Wrap a visit method so the node's prefix and suffix are handled around it.

    Arguments:
      f: visit method called as `f(self, node, *args, **kwargs)`.
      scope: (boolean) If True, run everything inside
        `self.scope(node, trailing_comma=False)`; otherwise inside a no-op
        context.
      prefix: (boolean) If True, call `self.prefix(node, ...)` before `f`.
      suffix: (boolean) If True, call `self.suffix(node, ...)` after `f`.
      max_suffix_lines: forwarded to `self.suffix` as `max_lines`.
      semicolon: forwarded to `self.suffix`.
      comment: forwarded to `self.suffix`.
      statement: (boolean) If True, the prefix defaults to `self._indent` and
        the suffix defaults to a newline; otherwise both default to ''.

    Returns:
      The wrapping function. It discards the return value of `f`.
    """
    # functools.wraps is the documented API; `contextlib.wraps` only resolves
    # because contextlib happens to import the name for its own use.
    @functools.wraps(f)
    def wrapped(self, node, *args, **kwargs):
        with (self.scope(node, trailing_comma=False) if scope
              else _noop_context()):
            if prefix:
                self.prefix(node, default=self._indent if statement else '')
            f(self, node, *args, **kwargs)
            if suffix:
                self.suffix(node, max_lines=max_suffix_lines,
                            semicolon=semicolon, comment=comment,
                            default='\n' if statement else '')
    return wrapped
54
55
@contextlib.contextmanager
def _noop_context():
    """Context manager that does nothing on entry or exit."""
    yield None
59
60
def expression(f):
    """Wrap a visit method whose node is an expression.

    The wrapper is built by `_gen_wrapper` with `max_suffix_lines=0`.
    """
    options = {'max_suffix_lines': 0}
    return _gen_wrapper(f, **options)
64
65
def fstring_expression(f):
    """Wrap a visit method whose node is a FormattedValue in an fstring.

    The wrapper is built by `_gen_wrapper` with `scope=False`.
    """
    options = {'scope': False}
    return _gen_wrapper(f, **options)
69
70
def space_around(f):
    """Wrap a visit method whose node has a whitespace prefix and suffix.

    The wrapper is built by `_gen_wrapper` with `scope=False`.
    """
    options = {'scope': False}
    return _gen_wrapper(f, **options)
74
75
def space_left(f):
    """Wrap a visit method whose node has only a whitespace prefix.

    The wrapper is built by `_gen_wrapper` with `scope=False` and
    `suffix=False`.
    """
    options = {'scope': False, 'suffix': False}
    return _gen_wrapper(f, **options)
79
80
def statement(f):
    """Wrap a visit method whose node is a statement.

    The wrapper is built by `_gen_wrapper` without a scope, allowing a
    semicolon and a comment in a suffix of at most one line.
    """
    options = {
        'scope': False,
        'max_suffix_lines': 1,
        'semicolon': True,
        'comment': True,
        'statement': True,
    }
    return _gen_wrapper(f, **options)
85
86
def module(f):
    """Special wrapper for the visit method of the module node.

    The wrapper is built by `_gen_wrapper` with `scope=False` and
    `comment=True`.
    """
    options = {'scope': False, 'comment': True}
    return _gen_wrapper(f, **options)
90
91
def block_statement(f):
    """Decorates a function where the node is a statement with children.

    The wrapper handles the node's prefix (defaulting to `self._indent`), calls
    `f`, and then handles the end of the block:

    * If the visitor defines `block_suffix`, it is called with the indentation
      taken from the last line of the last child's stored prefix, but only when
      the last child is on a different line than the node itself.
    * Otherwise the node's suffix is handled with `comment=True`.

    The wrapper discards the return value of `f`.
    """
    # functools.wraps is the documented API; `contextlib.wraps` only resolves
    # because contextlib happens to import the name for its own use.
    @functools.wraps(f)
    def wrapped(self, node, *args, **kwargs):
        self.prefix(node, default=self._indent)
        f(self, node, *args, **kwargs)
        if hasattr(self, 'block_suffix'):
            last_child = ast_utils.get_last_child(node)
            # Workaround for ast.Module which does not have a lineno
            if last_child and last_child.lineno != getattr(node, 'lineno', 0):
                indent = (fmt.get(last_child, 'prefix') or '\n').splitlines()[-1]
                self.block_suffix(node, indent)
        else:
            self.suffix(node, comment=True)
    return wrapped
107
108
109# ==============================================================================
110# == NodeVisitors for annotating an AST ==
111# ==============================================================================
112
113class BaseVisitor(ast.NodeVisitor):
114 """Walks a syntax tree in the order it appears in code.
115
This class has a dual purpose. It is implemented (in this file) for annotating
an AST with formatting information needed to reconstruct the source code, and
it is also implemented in pasta.base.codegen to reconstruct the source code.
119
120 Each visit method in this class specifies the order in which both child nodes
121 and syntax tokens appear, plus where to account for whitespace, commas,
122 parentheses, etc.
123 """
124
125 __metaclass__ = abc.ABCMeta
126
def __init__(self):
    """Initialize the traversal state: node stack and indentation tracking."""
    # NOTE(review): runs of whitespace appear collapsed in this copy of the
    # file; confirm the intended width of the default indent literal.
    self._default_indent_diff = ' '
    # Indentation added by the innermost block currently being visited.
    self._indent_diff = ''
    # Full indentation of the block currently being visited.
    self._indent = ''
    # Nodes currently being visited, outermost first.
    self._stack = []
132
def visit(self, node):
    """Visit `node`, keeping it on `self._stack` for the duration of the visit.

    Arguments:
      node: (ast.AST) Node to dispatch through `ast.NodeVisitor.visit`.

    Returns:
      None; the result of the dispatched visit method is discarded.
    """
    self._stack.append(node)
    super(BaseVisitor, self).visit(node)
    # Pop outside the assert: `python -O` strips assert statements, and the
    # pop must still happen or the stack would grow on every visit.
    finished = self._stack.pop()
    assert node is finished
137
def prefix(self, node, default=''):
    """Account for some amount of whitespace as the prefix to a node.

    Arguments:
      node: node the 'prefix' attribute is handled on.
      default: value passed to `self.attr` as the default for the prefix.
    """
    def _leading_whitespace():
        return self.ws(comment=True)

    self.attr(node, 'prefix', [_leading_whitespace], default=default)
141
def suffix(self, node, max_lines=None, semicolon=False, comment=False,
           default=''):
    """Account for some amount of whitespace as the suffix to a node.

    Arguments:
      node: node the 'suffix' attribute is handled on.
      max_lines: forwarded to `self.ws`.
      semicolon: forwarded to `self.ws`.
      comment: forwarded to `self.ws`.
      default: value passed to `self.attr` as the default for the suffix.
    """
    self.attr(
        node, 'suffix',
        [lambda: self.ws(max_lines=max_lines, semicolon=semicolon,
                         comment=comment)],
        default=default)
148
def indented(self, node, children_attr):
    """Yield the children stored in `children_attr`, one indent level deeper.

    While the children are being yielded, `self._indent` is the previous
    indentation plus the block's step: the 'indent_diff' formatting of the
    first child, or `self._default_indent_diff` when none is stored. After the
    last child, the 'block_suffix_<children_attr>' attribute is handled and
    the previous indentation state is restored.

    Arguments:
      node: node owning the block.
      children_attr: (string) name of the attribute holding the child list.
    """
    block = getattr(node, children_attr)
    saved_indent, saved_diff = self._indent, self._indent_diff
    step = fmt.get(block[0], 'indent_diff')
    self._indent_diff = self._default_indent_diff if step is None else step
    self._indent = saved_indent + self._indent_diff
    for stmt in block:
        yield stmt
    self.attr(node, 'block_suffix_%s' % children_attr, [])
    self._indent, self._indent_diff = saved_indent, saved_diff
163
def set_default_indent_diff(self, indent):
    """Set the indentation step used when a block has none stored.

    Arguments:
      indent: value stored as `self._default_indent_diff`.
    """
    self._default_indent_diff = indent
166
@contextlib.contextmanager
def scope(self, node, attr=None, trailing_comma=False, default_parens=False):
    """Context manager to handle a parenthesized scope.

    Arguments:
      node: (ast.AST) Node to store the scope prefix and suffix on.
      attr: (string, optional) Attribute of the node contained in the scope,
        if any. For example, as `None`, the scope would wrap the entire node,
        but as 'bases', the scope might wrap only the bases of a class.
      trailing_comma: (boolean) If True, allow a trailing comma at the end.
        Not used by this base implementation.
      default_parens: (boolean) If True and no formatting information is
        present, the scope would be assumed to be parenthesized.
    """
    opening, closing = ('(', ')') if default_parens else ('', '')
    if not attr:
        # Nothing to record for a scope that wraps the whole node.
        yield
        return
    self.attr(node, attr + '_prefix', [], default=opening)
    yield
    self.attr(node, attr + '_suffix', [], default=closing)
187
def token(self, token_val):
    """Account for a specific token.

    The base implementation does nothing.
    """
    return None
190
def attr(self, node, attr_name, attr_vals, deps=None, default=None):
    """Handles an attribute on the given node.

    The base implementation does nothing.
    """
    return None
193
def ws(self, max_lines=None, semicolon=False, comment=True):
    """Account for some amount of whitespace.

    Arguments:
      max_lines: (int) Maximum number of newlines to consider.
      semicolon: (boolean) If True, parse up to the next semicolon (if
        present).
      comment: (boolean) If True, look for a trailing comment even when not
        in a parenthesized scope.

    Returns:
      The empty string in this base implementation.
    """
    no_whitespace = ''
    return no_whitespace
204
def dots(self, num_dots):
    """Account for a number of dots.

    Returns:
      A string made of `num_dots` '.' characters.
    """
    return num_dots * '.'
208
def ws_oneline(self):
    """Account for up to one line of whitespace.

    Returns:
      Whatever `self.ws(max_lines=1)` returns.
    """
    line_limit = 1
    return self.ws(max_lines=line_limit)
212
def optional_token(self, node, attr_name, token_val, default=False,
                   allow_whitespace_prefix=False):
    """Account for a suffix that may or may not occur.

    The base implementation does nothing.

    Arguments:
      node: node the optional token belongs to.
      attr_name: (string) name under which the token is tracked.
      token_val: (string) the token itself, e.g. ','.
      default: not used by this base implementation.
      allow_whitespace_prefix: accepted because the visit methods for dict,
        list and set literals pass it; not used by this base implementation.
    """
    return None
215
def one_of_symbols(self, *symbols):
    """Account for one of the given symbols.

    Returns:
      The first of `symbols` in this base implementation.
    """
    preferred = symbols[0]
    return preferred
219
220 # ============================================================================
221 # == BLOCK STATEMENTS: Statements that contain a list of statements ==
222 # ============================================================================
223
224 # Keeps the entire suffix, so @block_statement is not useful here.
@module
def visit_Module(self, node):
    """Handle the module node: an optional 'bom' attribute, then its children.

    The 'bom' attribute is filled from any leading ERRORTOKEN tokens and
    defaults to ''. Handling it is best-effort: an ordinary error there is
    ignored so the children are still visited.
    """
    try:
        self.attr(
            node, 'bom',
            [lambda: self.tokens.eat_tokens(
                lambda t: t.type == token.ERRORTOKEN)],
            default='')
    except Exception:
        # Stay best-effort, but do not swallow KeyboardInterrupt or
        # SystemExit the way a bare `except:` would.
        pass
    self.generic_visit(node)
235
@block_statement
def visit_If(self, node):
    """Handle an `if`/`elif` statement, its body and any else branch.

    A node flagged with the 'is_elif' formatting is opened with `elif`. A
    single nested `If` in `orelse` that `check_is_elif` accepts is flagged
    and visited as an `elif`; any other `orelse` is emitted as an `else`
    block.
    """
    tok = 'elif' if fmt.get(node, 'is_elif') else 'if'
    self.attr(node, 'open_if', [tok, self.ws], default=tok + ' ')
    self.visit(node.test)
    self.attr(node, 'open_block', [self.ws, ':', self.ws_oneline],
              default=':\n')

    for stmt in self.indented(node, 'body'):
        self.visit(stmt)

    if node.orelse:
        if (len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If) and
                self.check_is_elif(node.orelse[0])):
            fmt.set(node.orelse[0], 'is_elif', True)
            self.visit(node.orelse[0])
        else:
            # Default to the current indentation, as the `else` of while/for
            # does, so an unformatted `else` lines up with its `if`.
            self.attr(node, 'elseprefix', [self.ws], default=self._indent)
            self.token('else')
            self.attr(node, 'open_else', [self.ws, ':', self.ws_oneline],
                      default=':\n')
            for stmt in self.indented(node, 'orelse'):
                self.visit(stmt)
259
@abc.abstractmethod
def check_is_elif(self, node):
    """Return True if the node continues a previous `if` statement as `elif`.

    In python 2.x, `elif` statements get parsed as If nodes. E.g, the
    following two syntax forms are indistinguishable in the ast in python 2.

      if a:
        do_something()
      elif b:
        do_something_else()

      if a:
        do_something()
      else:
        if b:
          do_something_else()

    This method should return True for the 'if b' node if it has the first
    form.
    """
    return None
280
@block_statement
def visit_While(self, node):
    """Handle a `while` loop: keyword, test, body and optional else block."""
    ws = self.ws
    self.attr(node, 'while_keyword', ['while', ws], default='while ')
    self.visit(node.test)
    self.attr(node, 'open_block', [ws, ':', self.ws_oneline], default=':\n')
    for body_stmt in self.indented(node, 'body'):
        self.visit(body_stmt)

    if not node.orelse:
        return
    else_default = self._indent + 'else:\n'
    self.attr(node, 'else', [ws, 'else', ws, ':', self.ws_oneline],
              default=else_default)
    for else_stmt in self.indented(node, 'orelse'):
        self.visit(else_stmt)
295
@block_statement
def visit_For(self, node, is_async=False):
    """Handle a `for` loop: keyword, target, iterable, body, optional else.

    Arguments:
      node: the loop node.
      is_async: (boolean) If True, the loop is opened with `async for`.
    """
    ws = self.ws
    if is_async:
        keyword_tokens, keyword_default = ['async', ws, 'for', ws], 'async for '
    else:
        keyword_tokens, keyword_default = ['for', ws], 'for '
    self.attr(node, 'for_keyword', keyword_tokens, default=keyword_default)
    self.visit(node.target)
    self.attr(node, 'for_in', [ws, 'in', ws], default=' in ')
    self.visit(node.iter)
    self.attr(node, 'open_block', [ws, ':', self.ws_oneline], default=':\n')
    for body_stmt in self.indented(node, 'body'):
        self.visit(body_stmt)

    if not node.orelse:
        return
    else_default = self._indent + 'else:\n'
    self.attr(node, 'else', [ws, 'else', ws, ':', self.ws_oneline],
              default=else_default)
    for else_stmt in self.indented(node, 'orelse'):
        self.visit(else_stmt)
317
def visit_AsyncFor(self, node):
    """Handle an `async for` loop by delegating to `visit_For`."""
    result = self.visit_For(node, is_async=True)
    return result
320
@block_statement
def visit_With(self, node, is_async=False):
    """Handle a `with` statement.

    Nodes that have an `items` attribute are delegated to `visit_With_3`.
    Otherwise the node carries a single context expression; a sole body
    statement accepted by `check_is_continued_with` is marked as continued
    and joined with a comma instead of opening a new block.
    """
    if hasattr(node, 'items'):
        return self.visit_With_3(node, is_async)

    ws = self.ws
    continued = getattr(node, 'is_continued', False)
    if not continued:
        self.attr(node, 'with', ['with', ws], default='with ')
    self.visit(node.context_expr)
    if node.optional_vars:
        self.attr(node, 'with_as', [ws, 'as', ws], default=' as ')
        self.visit(node.optional_vars)

    if len(node.body) == 1 and self.check_is_continued_with(node.body[0]):
        node.body[0].is_continued = True
        self.attr(node, 'with_comma', [ws, ',', ws], default=', ')
    else:
        self.attr(node, 'open_block', [ws, ':', self.ws_oneline],
                  default=':\n')
    for body_stmt in self.indented(node, 'body'):
        self.visit(body_stmt)
340
def visit_AsyncWith(self, node):
    """Handle an `async with` statement by delegating to `visit_With`."""
    result = self.visit_With(node, is_async=True)
    return result
343
@abc.abstractmethod
def check_is_continued_try(self, node):
    """Tell whether `node` continues an enclosing `try` statement.

    `visit_TryFinally` calls this on the only statement of its body; when the
    result is true, that statement is marked as continued and visited directly
    instead of as an indented block.
    """
    return None
347
@abc.abstractmethod
def check_is_continued_with(self, node):
    """Return True if the node continues a previous `with` statement.

    In python 2.x, `with` statements with many context expressions get parsed
    as a tree of With nodes. E.g, the following two syntax forms are
    indistinguishable in the ast in python 2.

      with a, b, c:
        do_something()

      with a:
        with b:
          with c:
            do_something()

    This method should return True for the `with b` and `with c` nodes.
    """
    return None
366
def visit_With_3(self, node, is_async=False):
    """Handle a `with` statement whose node stores its managers in `items`.

    Arguments:
      node: the with node.
      is_async: (boolean) If True, the statement is opened with `async with`.
    """
    ws = self.ws
    if is_async:
        keyword_tokens, keyword_default = ['async', ws, 'with', ws], 'async with '
    else:
        keyword_tokens, keyword_default = ['with', ws], 'with '
    self.attr(node, 'with', keyword_tokens, default=keyword_default)

    for position, item in enumerate(node.items):
        self.visit(item)
        is_last = position == len(node.items) - 1
        if not is_last:
            self.token(',')

    self.attr(node, 'with_body_open', [':', self.ws_oneline], default=':\n')
    for body_stmt in self.indented(node, 'body'):
        self.visit(body_stmt)
382
@space_around
def visit_withitem(self, node):
    """Handle one context manager of a `with`, including its `as` target."""
    self.visit(node.context_expr)
    target = node.optional_vars
    if not target:
        return
    self.attr(node, 'as', [self.ws, 'as', self.ws], default=' as ')
    self.visit(node.optional_vars)
389
@block_statement
def visit_ClassDef(self, node):
    """Handle a class definition: decorators, header, bases/keywords, body.

    Bases and keywords are handled inside the 'bases' scope, which defaults
    to parentheses when no formatting is stored.
    """
    ws = self.ws
    for index, decorator in enumerate(node.decorator_list):
        self.attr(node, 'decorator_prefix_%d' % index, [ws, '@'], default='@')
        self.visit(decorator)
        self.attr(node, 'decorator_suffix_%d' % index, [ws],
                  default='\n' + self._indent)

    self.attr(node, 'class_def', ['class', ws, node.name, ws],
              default='class %s' % node.name, deps=('name',))

    class_args = getattr(node, 'bases', []) + getattr(node, 'keywords', [])
    has_args = bool(class_args)
    with self.scope(node, 'bases', trailing_comma=has_args,
                    default_parens=True):
        for index, base in enumerate(node.bases):
            self.visit(base)
            self.attr(node, 'base_suffix_%d' % index, [ws])
            if base != class_args[-1]:
                self.attr(node, 'base_sep_%d' % index, [',', ws], default=', ')
        if hasattr(node, 'keywords'):
            for index, keyword in enumerate(node.keywords):
                self.visit(keyword)
                self.attr(node, 'keyword_suffix_%d' % index, [ws])
                if keyword != node.keywords[-1]:
                    self.attr(node, 'keyword_sep_%d' % index, [',', ws],
                              default=', ')

    self.attr(node, 'open_block', [ws, ':', self.ws_oneline], default=':\n')
    for body_stmt in self.indented(node, 'body'):
        self.visit(body_stmt)
417
@block_statement
def visit_FunctionDef(self, node, is_async=False):
    """Handle a function definition: decorators, header, arguments, body.

    Arguments:
      node: the function definition node.
      is_async: (boolean) If True, the header is opened with `async def`.
    """
    ws = self.ws
    for index, decorator in enumerate(node.decorator_list):
        self.attr(node, 'decorator_symbol_%d' % index, [ws, '@', ws],
                  default='@')
        self.visit(decorator)
        self.attr(node, 'decorator_suffix_%d' % index, [self.ws_oneline],
                  default='\n' + self._indent)

    if is_async:
        header_tokens = [ws, 'async', ws, 'def', ws, node.name, ws]
        header_default = 'async def %s' % node.name
    else:
        header_tokens = [ws, 'def', ws, node.name, ws]
        header_default = 'def %s' % node.name
    self.attr(node, 'function_def', header_tokens, deps=('name',),
              default=header_default)

    # In Python 3, there can be extra args in kwonlyargs
    arguments = node.args
    keyword_only = getattr(arguments, 'kwonlyargs', [])
    args_count = len(arguments.args + keyword_only)
    if arguments.vararg:
        args_count += 1
    if arguments.kwarg:
        args_count += 1
    with self.scope(node, 'args', trailing_comma=args_count > 0,
                    default_parens=True):
        self.visit(node.args)

    if getattr(node, 'returns', None):
        self.attr(node, 'returns_prefix', [ws, '->', ws],
                  deps=('returns',), default=' -> ')
        self.visit(node.returns)

    self.attr(node, 'open_block', [ws, ':', self.ws_oneline], default=':\n')
    for body_stmt in self.indented(node, 'body'):
        self.visit(body_stmt)
452
def visit_AsyncFunctionDef(self, node):
    """Handle an `async def` by delegating to `visit_FunctionDef`."""
    result = self.visit_FunctionDef(node, is_async=True)
    return result
455
@block_statement
def visit_TryFinally(self, node):
    """Handle a `try` ... `finally` statement (Python 2 node type).

    Try with except and finally is a TryFinally with the first statement as a
    TryExcept in Python2. When `check_is_continued_try` accepts the only body
    statement, it is marked as continued and visited directly instead of as
    an indented block.
    """
    self.attr(node, 'open_try', ['try', self.ws, ':', self.ws_oneline],
              default='try:\n')
    # TODO(soupytwist): Find a cleaner solution for differentiating this.
    if len(node.body) == 1 and self.check_is_continued_try(node.body[0]):
        node.body[0].is_continued = True
        self.visit(node.body[0])
    else:
        for stmt in self.indented(node, 'body'):
            self.visit(stmt)
    # Default to the current indentation, as the `else` of while/for does, so
    # an unformatted `finally` lines up with its `try`.
    self.attr(node, 'open_finally',
              [self.ws, 'finally', self.ws, ':', self.ws_oneline],
              default=self._indent + 'finally:\n')
    for stmt in self.indented(node, 'finalbody'):
        self.visit(stmt)
474
@block_statement
def visit_TryExcept(self, node):
    """Handle a `try` ... `except` statement (Python 2 node type).

    The `try:` opener is skipped when the node is marked `is_continued`,
    which `visit_TryFinally` sets after it has already handled the opener.
    """
    if not getattr(node, 'is_continued', False):
        self.attr(node, 'open_try', ['try', self.ws, ':', self.ws_oneline],
                  default='try:\n')
    for stmt in self.indented(node, 'body'):
        self.visit(stmt)
    for handler in node.handlers:
        self.visit(handler)
    if node.orelse:
        # Default to the current indentation, as the `else` of while/for
        # does, so an unformatted `else` lines up with its `try`.
        self.attr(node, 'open_else',
                  [self.ws, 'else', self.ws, ':', self.ws_oneline],
                  default=self._indent + 'else:\n')
        for stmt in self.indented(node, 'orelse'):
            self.visit(stmt)
490
@block_statement
def visit_Try(self, node):
    """Handle a Python 3 `try` statement.

    Order: the `try:` opener, the body, each handler, then the optional
    `else` and `finally` blocks.
    """
    self.attr(node, 'open_try',
              [self.ws, 'try', self.ws, ':', self.ws_oneline],
              default='try:\n')
    for stmt in self.indented(node, 'body'):
        self.visit(stmt)
    for handler in node.handlers:
        self.visit(handler)
    # The `else` and `finally` openers default to the current indentation, as
    # the `else` of while/for does, so unformatted clauses line up with the
    # `try`.
    if node.orelse:
        self.attr(node, 'open_else',
                  [self.ws, 'else', self.ws, ':', self.ws_oneline],
                  default=self._indent + 'else:\n')
        for stmt in self.indented(node, 'orelse'):
            self.visit(stmt)
    if node.finalbody:
        self.attr(node, 'open_finally',
                  [self.ws, 'finally', self.ws, ':', self.ws_oneline],
                  default=self._indent + 'finally:\n')
        for stmt in self.indented(node, 'finalbody'):
            self.visit(stmt)
512
@block_statement
def visit_ExceptHandler(self, node):
    """Handle an `except` clause: optional type, optional name, and body.

    The name is visited when it is an AST node and emitted as a plain token
    otherwise.
    """
    self.token('except')
    exc_type, exc_name = node.type, node.name
    if exc_type:
        self.visit(exc_type)
    if exc_type and exc_name:
        self.attr(node, 'as',
                  [self.ws, self.one_of_symbols("as", ","), self.ws],
                  default=' as ')
    if exc_name:
        if isinstance(exc_name, ast.AST):
            self.visit(exc_name)
        else:
            self.token(exc_name)
    self.attr(node, 'open_block', [self.ws, ':', self.ws_oneline],
              default=':\n')
    for body_stmt in self.indented(node, 'body'):
        self.visit(body_stmt)
530
@statement
def visit_Raise(self, node):
    """Handle a `raise` statement.

    Nodes that have a `cause` attribute are delegated to `visit_Raise_3`;
    otherwise the node's `type`, `inst` and `tback` parts are handled.
    """
    if hasattr(node, 'cause'):
        return self.visit_Raise_3(node)

    ws = self.ws
    self.token('raise')
    if node.type:
        self.attr(node, 'type_prefix', [ws], default=' ')
        self.visit(node.type)
    if node.inst:
        self.attr(node, 'inst_prefix', [ws, ',', ws], default=', ')
        self.visit(node.inst)
    if node.tback:
        self.attr(node, 'tback_prefix', [ws, ',', ws], default=', ')
        self.visit(node.tback)
546
def visit_Raise_3(self, node):
    """Handle a `raise` node that has `exc` and `cause` fields.

    A bare `raise` emits only the keyword token; otherwise the exception is
    visited, followed by `from <cause>` when a cause is present.
    """
    if not node.exc:
        self.token('raise')
        return
    ws = self.ws
    self.attr(node, 'open_raise', ['raise', ws], default='raise ')
    self.visit(node.exc)
    if node.cause:
        self.attr(node, 'cause_prefix', [ws, 'from', ws], default=' from ')
        self.visit(node.cause)
557
558 # ============================================================================
559 # == STATEMENTS: Instructions without a return value ==
560 # ============================================================================
561
@statement
def visit_Assert(self, node):
    """Handle an `assert` statement and its optional message."""
    self.attr(node, 'assert_open', ['assert', self.ws], default='assert ')
    self.visit(node.test)
    message = node.msg
    if not message:
        return
    self.attr(node, 'msg_prefix', [',', self.ws], default=', ')
    self.visit(node.msg)
569
@statement
def visit_Assign(self, node):
    """Handle an assignment: each target followed by `=`, then the value."""
    ws = self.ws
    for index, target in enumerate(node.targets):
        self.visit(target)
        self.attr(node, 'equal_%d' % index, [ws, '=', ws], default=' = ')
    self.visit(node.value)
576
@statement
def visit_AugAssign(self, node):
    """Handle an augmented assignment such as `target op= value`.

    The operator token is the first token listed for the operator's node type
    in `ast_constants.NODE_TYPE_TO_TOKENS`, followed by '='.
    """
    self.visit(node.target)
    op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node.op)][0]
    op_token = '%s=' % op_symbol
    spaced_default = ' %s ' % op_token
    self.attr(node, 'operator', [self.ws, op_token, self.ws],
              default=spaced_default)
    self.visit(node.value)
584
@statement
def visit_AnnAssign(self, node):
    """Handle an annotated assignment: target, annotation, optional value."""
    # TODO: Check default formatting for different values of "simple"
    ws = self.ws
    self.visit(node.target)
    self.attr(node, 'colon', [ws, ':', ws], default=': ')
    self.visit(node.annotation)
    if not node.value:
        return
    self.attr(node, 'equal', [ws, '=', ws], default=' = ')
    self.visit(node.value)
594
@expression
def visit_Await(self, node):
    """Handle an `await` expression: the keyword, then the awaited value."""
    keyword_tokens = ['await', self.ws]
    self.attr(node, 'await', keyword_tokens, default='await ')
    self.visit(node.value)
599
@statement
def visit_Break(self, node):
    """Handle a `break` statement by emitting its keyword token."""
    keyword = 'break'
    self.token(keyword)
603
@statement
def visit_Continue(self, node):
    """Handle a `continue` statement by emitting its keyword token."""
    keyword = 'continue'
    self.token(keyword)
607
@statement
def visit_Delete(self, node):
    """Handle a `del` statement with comma-separated targets."""
    ws = self.ws
    self.attr(node, 'del', ['del', ws], default='del ')
    for index, target in enumerate(node.targets):
        self.visit(target)
        if target is node.targets[-1]:
            continue
        self.attr(node, 'comma_%d' % index, [ws, ',', ws], default=', ')
615
@statement
def visit_Exec(self, node):
    """Handle an `exec` statement with optional globals and locals.

    If no formatting info is present, the parenthesized style is used.
    """
    ws = self.ws
    self.attr(node, 'exec', ['exec', ws], default='exec')
    with self.scope(node, 'body', trailing_comma=False, default_parens=True):
        self.visit(node.body)
        if node.globals:
            globals_tokens = [ws, self.one_of_symbols('in', ','), ws]
            self.attr(node, 'in_globals', globals_tokens, default=', ')
            self.visit(node.globals)
            if node.locals:
                self.attr(node, 'in_locals', [ws, ',', ws], default=', ')
                self.visit(node.locals)
630
@statement
def visit_Expr(self, node):
    """Handle an expression statement by visiting the wrapped expression."""
    wrapped_expression = node.value
    self.visit(wrapped_expression)
634
@statement
def visit_Global(self, node):
    """Handle a `global` statement and its comma-separated names.

    Separators are chosen by position rather than by comparing name values,
    so a name that repeats the first one (`global a, a`) still gets its
    comma. When no formatting is stored, the names default to a single space
    followed by the names joined with ', '.
    """
    self.token('global')
    identifiers = []
    for i, ident in enumerate(node.names):
        if i > 0:
            identifiers.extend([self.ws, ','])
        identifiers.extend([self.ws, ident])
    self.attr(node, 'names', identifiers,
              default=' ' + ', '.join(node.names))
644
@statement
def visit_Import(self, node):
    """Handle an `import` statement with comma-separated aliases."""
    ws = self.ws
    self.token('import')
    for index, alias in enumerate(node.names):
        self.attr(node, 'alias_prefix_%d' % index, [ws], default=' ')
        self.visit(alias)
        if alias != node.names[-1]:
            self.attr(node, 'alias_sep_%d' % index, [ws, ','], default=',')
653
@statement
def visit_ImportFrom(self, node):
    """Handle a `from ... import ...` statement.

    The module part is made of the leading dots for relative imports followed
    by the dotted module path; the imported aliases are handled inside the
    'names' scope, which allows a trailing comma.
    """
    ws = self.ws
    self.token('from')
    self.attr(node, 'module_prefix', [ws], default=' ')

    pattern = []
    if node.level > 0:
        pattern += [self.dots(node.level), ws]
    if node.module:
        components = node.module.split('.')
        for component in components[:-1]:
            pattern.extend([ws, component, ws, '.'])
        pattern.extend([ws, components[-1]])

    module_default = '.' * node.level + (node.module or '')
    self.attr(node, 'module', pattern, deps=('level', 'module'),
              default=module_default)
    self.attr(node, 'module_suffix', [ws], default=' ')

    self.token('import')
    with self.scope(node, 'names', trailing_comma=True):
        for index, alias in enumerate(node.names):
            self.attr(node, 'alias_prefix_%d' % index, [ws], default=' ')
            self.visit(alias)
            if alias is node.names[-1]:
                continue
            self.attr(node, 'alias_sep_%d' % index, [ws, ','], default=',')
680
@expression
def visit_NamedExpr(self, node):
    """Handle an assignment expression: target, `:=`, then value."""
    ws = self.ws
    self.visit(node.target)
    self.attr(node, 'equal', [ws, ':=', ws], default=' := ')
    self.visit(node.value)
686
@statement
def visit_Nonlocal(self, node):
    """Handle a `nonlocal` statement and its comma-separated names.

    Separators are chosen by position rather than by comparing name values,
    so a name that repeats the first one (`nonlocal a, a`) still gets its
    comma. When no formatting is stored, the names default to a single space
    followed by the names joined with ', '.
    """
    self.token('nonlocal')
    identifiers = []
    for i, ident in enumerate(node.names):
        if i > 0:
            identifiers.extend([self.ws, ','])
        identifiers.extend([self.ws, ident])
    self.attr(node, 'names', identifiers,
              default=' ' + ', '.join(node.names))
696
@statement
def visit_Pass(self, node):
    """Handle a `pass` statement by emitting its keyword token."""
    keyword = 'pass'
    self.token(keyword)
700
@statement
def visit_Print(self, node):
    """Handle a `print` statement node.

    Covers the optional `>>dest` redirection, the comma-separated values, and
    the trailing comma that is handled when `node.nl` is false.
    """
    ws = self.ws
    self.attr(node, 'print_open', ['print', ws], default='print ')
    if node.dest:
        self.attr(node, 'redirection', ['>>', ws], default='>>')
        self.visit(node.dest)
        if node.values:
            self.attr(node, 'values_prefix', [ws, ',', ws], default=', ')
        elif not node.nl:
            self.attr(node, 'trailing_comma', [ws, ','], default=',')

    for index, value in enumerate(node.values):
        self.visit(value)
        if value is not node.values[-1]:
            self.attr(node, 'comma_%d' % index, [ws, ',', ws], default=', ')
        elif not node.nl:
            self.attr(node, 'trailing_comma', [ws, ','], default=',')
718
@statement
def visit_Return(self, node):
    """Handle a `return` statement and its optional value."""
    self.token('return')
    if not node.value:
        return
    self.attr(node, 'return_value_prefix', [self.ws], default=' ')
    self.visit(node.value)
725
@expression
def visit_Yield(self, node):
    """Handle a `yield` expression and its optional value."""
    self.token('yield')
    if not node.value:
        return
    self.attr(node, 'yield_value_prefix', [self.ws], default=' ')
    self.visit(node.value)
732
@expression
def visit_YieldFrom(self, node):
    """Handle a `yield from` expression: both keywords, then the value."""
    ws = self.ws
    keyword_tokens = ['yield', ws, 'from', ws]
    self.attr(node, 'yield_from', keyword_tokens, default='yield from ')
    self.visit(node.value)
738
739 # ============================================================================
740 # == EXPRESSIONS: Anything that evaluates and can be in parens ==
741 # ============================================================================
742
@expression
def visit_Attribute(self, node):
    """Handle an attribute access: the value, the dot, the attribute name."""
    ws = self.ws
    self.visit(node.value)
    self.attr(node, 'dot', [ws, '.', ws], default='.')
    self.token(node.attr)
748
@expression
def visit_BinOp(self, node):
    """Handle a binary operation: left operand, operator, right operand.

    The operator symbol is the first token listed for the operator's node
    type in `ast_constants.NODE_TYPE_TO_TOKENS`.
    """
    symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node.op)][0]
    spaced_symbol = ' %s ' % symbol
    self.visit(node.left)
    self.attr(node, 'op', [self.ws, symbol, self.ws], default=spaced_symbol,
              deps=('op',))
    self.visit(node.right)
756
@expression
def visit_BoolOp(self, node):
    """Handle a boolean operation: operands joined by the operator.

    The operator symbol is the first token listed for the operator's node
    type in `ast_constants.NODE_TYPE_TO_TOKENS`.
    """
    symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node.op)][0]
    for index, operand in enumerate(node.values):
        self.visit(operand)
        if operand is node.values[-1]:
            continue
        self.attr(node, 'op_%d' % index, [self.ws, symbol, self.ws],
                  default=' %s ' % symbol, deps=('op',))
765
@expression
def visit_Call(self, node):
    """Handle a call: the callee, then the arguments inside their scope.

    An optional trailing comma is accounted for only when the call has at
    least one argument.
    """
    self.visit(node.func)

    # python <3.5: starargs and kwargs are in separate fields
    # python 3.5+: starargs args included as a Starred nodes in the arguments
    #              and kwargs are included as keywords with no argument name.
    with self.scope(node, 'arguments', default_parens=True):
        if sys.version_info[:2] >= (3, 5):
            visit_arguments = self.visit_Call_arguments35
        else:
            visit_arguments = self.visit_Call_arguments
        if visit_arguments(node):
            self.optional_token(node, 'trailing_comma', ',')
780
def visit_Call_arguments(self, node):
    """Handle call arguments for nodes with `starargs` and `kwargs` fields.

    Positional arguments come first, then the keywords (with `*starargs`
    merged in by source position when present), then `**kwargs`.

    Returns:
      True if the call has any argument at all, False otherwise.
    """
    def _location(entry):
        target = entry[1]
        if isinstance(target, ast.keyword):
            target = target.value
        return (getattr(target, "lineno", 0), getattr(target, "col_offset", 0))

    keyword_entries = [(None, kw) for kw in node.keywords]
    if node.starargs:
        keyword_entries = sorted(keyword_entries + [('*', node.starargs)],
                                 key=_location)
    entries = [(None, positional) for positional in node.args]
    entries += keyword_entries
    if node.kwargs:
        entries.append(('**', node.kwargs))

    for index, (marker, arg) in enumerate(entries):
        if marker is not None:
            self.attr(node, '%s_prefix' % marker, [self.ws, marker],
                      default=marker)
        self.visit(arg)
        if arg is not entries[-1][1]:
            self.attr(node, 'comma_%d' % index, [self.ws, ',', self.ws],
                      default=', ')
    return bool(entries)
805
def visit_Call_arguments35(self, node):
    """Handle call arguments for nodes that store `*`/`**` inline.

    Arguments and keywords are visited in source order, separated by commas.

    Returns:
      True if the call has any argument at all, False otherwise.
    """
    def _is_positional(arg):
        return not isinstance(arg, (ast.keyword, ast.Starred))

    def _position(arg):
        target = arg.value if isinstance(arg, ast.keyword) else arg
        return (getattr(target, 'lineno', None),
                getattr(target, 'col_offset', None))

    def _compare(left, right):
        """Old-style comparator for sorting args."""
        # No kwarg can come before a regular arg (but Starred can be wherever)
        if isinstance(right, ast.keyword) and _is_positional(left):
            return -1
        if isinstance(left, ast.keyword) and _is_positional(right):
            return 1

        # If no lineno or col_offset on one of the args, they compare as
        # equal (since sorting is stable, this should leave them mostly where
        # they were in the initial list).
        left_pos, right_pos = _position(left), _position(right)
        if None in left_pos or None in right_pos:
            return 0

        # If both have lineno/col_offset set, use that to sort them
        if left_pos < right_pos:
            return -1
        if left_pos == right_pos:
            return 0
        return 1

    # Note that this always sorts keywords identically to just sorting by
    # lineno/col_offset, except in cases where that ordering would have been
    # a syntax error (named arg before unnamed arg).
    ordered = list(node.args + node.keywords)
    ordered.sort(key=functools.cmp_to_key(_compare))

    for index, arg in enumerate(ordered):
        self.visit(arg)
        if arg is not ordered[-1]:
            self.attr(node, 'comma_%d' % index, [self.ws, ',', self.ws],
                      default=', ')
    return bool(ordered)
846
def visit_Starred(self, node):
    """Handle a starred expression: the `*`, then the starred value."""
    star_tokens = ['*', self.ws]
    self.attr(node, 'star', star_tokens, default='*')
    self.visit(node.value)
850
@expression
def visit_Compare(self, node):
    """Handle a comparison chain: left operand, then operator/operand pairs."""
    ws = self.ws
    self.visit(node.left)
    pairs = zip(node.ops, node.comparators)
    for index, (operator, operand) in enumerate(pairs):
        self.attr(node, 'op_prefix_%d' % index, [ws], default=' ')
        self.visit(operator)
        self.attr(node, 'op_suffix_%d' % index, [ws], default=' ')
        self.visit(operand)
859
@expression
def visit_Constant(self, node):
    """Handle a constant, choosing the handling from the kind of its value.

    Checked in order: constants with a truthy `kind`, booleans, Ellipsis,
    numbers, text/bytes strings, and finally anything else, which is emitted
    as the token `str(node.value)`.
    """
    value = node.value
    if getattr(node, 'kind', None):
        self.attr(node, 'content', [self.tokens.str],
                  default='%s"%s"' % (node.kind, value), deps=('value',))
    elif isinstance(value, bool):
        text = str(value)
        self.attr(node, 'content', [text], default=str(value),
                  deps=('value',))
    elif value is Ellipsis:
        self.token('...')
    elif isinstance(value, numbers.Number):
        number_type = token_generator.TOKENS.NUMBER
        self.attr(node, 'content',
                  [lambda: self.tokens.next_of_type(number_type).src],
                  deps=('value',), default=str(value))
    elif isinstance(value, (six.text_type, bytes)):
        self.attr(node, 'content', [self.tokens.str], deps=('value',),
                  default=value)
    else:
        self.token(str(value))
880
@expression
def visit_Dict(self, node):
    """Handle a dict literal: braces, `key: value` pairs and `**` unpacking.

    A key of None marks a `**value` entry (PEP-448). An optional extra comma
    is accounted for before the closing brace.
    """
    ws = self.ws
    self.token('{')
    for index, (key, value) in enumerate(zip(node.keys, node.values)):
        if key is not None:
            self.visit(key)
            self.attr(node, 'key_val_sep_%d' % index, [ws, ':', ws],
                      default=': ')
        else:
            # Handle Python 3.5+ dict unpacking syntax (PEP-448)
            self.attr(node, 'starstar_%d' % index, [ws, '**'], default='**')
        self.visit(value)
        if value is not node.values[-1]:
            self.attr(node, 'comma_%d' % index, [ws, ',', ws], default=', ')
    self.optional_token(node, 'extracomma', ',', allow_whitespace_prefix=True)
    self.attr(node, 'close_prefix', [ws, '}'], default='}')
897
@expression
def visit_DictComp(self, node):
    """Handle a dict comprehension: `{key: value <generators>}`."""
    ws = self.ws
    self.attr(node, 'open_dict', ['{', ws], default='{')
    self.visit(node.key)
    self.attr(node, 'key_val_sep', [ws, ':', ws], default=': ')
    self.visit(node.value)
    for generator in node.generators:
        self.visit(generator)
    self.attr(node, 'close_dict', [ws, '}'], default='}')
907
@expression
def visit_GeneratorExp(self, node):
    """Handle a generator expression via `_comp_exp`, without braces."""
    comprehension = node
    self._comp_exp(comprehension)
911
@expression
def visit_IfExp(self, node):
    """Handle a conditional expression: `body if test else orelse`."""
    ws = self.ws
    self.visit(node.body)
    self.attr(node, 'if', [ws, 'if', ws], default=' if ')
    self.visit(node.test)
    self.attr(node, 'else', [ws, 'else', ws], default=' else ')
    self.visit(node.orelse)
919
@expression
def visit_Lambda(self, node):
    """Handle a lambda: keyword, arguments, colon, then the body."""
    ws = self.ws
    self.attr(node, 'lambda_def', ['lambda', ws], default='lambda ')
    self.visit(node.args)
    self.attr(node, 'open_lambda', [ws, ':', ws], default=': ')
    self.visit(node.body)
926
@expression
def visit_List(self, node):
  """Process a list literal: brackets, elements and separating commas.

  A trailing comma is only looked for when the list has at least one element.
  """
  self.attr(node, 'list_open', ['[', self.ws], default='[')

  last_index = len(node.elts) - 1
  for i, elt in enumerate(node.elts):
    self.visit(elt)
    # Decide by position rather than identity: in a hand-built tree the same
    # node object may appear more than once in `elts`.
    if i < last_index:
      self.attr(node, 'comma_%d' % i, [self.ws, ',', self.ws], default=', ')
  if node.elts:
    self.optional_token(node, 'extracomma', ',', allow_whitespace_prefix=True)

  self.attr(node, 'list_close', [self.ws, ']'], default=']')
939
@expression
def visit_ListComp(self, node):
  """Process a list comprehension, delimited by square brackets."""
  brackets = dict(open_brace='[', close_brace=']')
  self._comp_exp(node, **brackets)
943
944 def _comp_exp(self, node, open_brace=None, close_brace=None):
945 if open_brace:
946 self.attr(node, 'compexp_open', [open_brace, self.ws], default=open_brace)
947 self.visit(node.elt)
948 for i, comp in enumerate(node.generators):
949 self.visit(comp)
950 if close_brace:
951 self.attr(node, 'compexp_close', [self.ws, close_brace],
952 default=close_brace)
953
@expression
def visit_Name(self, node):
  """Process an identifier as a single token equal to `node.id`."""
  identifier = node.id
  self.token(identifier)
957
@expression
def visit_NameConstant(self, node):
  """Process a named constant as the token given by `str(node.value)`."""
  constant_text = str(node.value)
  self.token(constant_text)
961
@expression
def visit_Repr(self, node):
  """Process a backtick-delimited repr expression."""
  opening = ['`', self.ws]
  self.attr(node, 'repr_open', opening, default='`')
  self.visit(node.value)
  closing = [self.ws, '`']
  self.attr(node, 'repr_close', closing, default='`')
967
@expression
def visit_Set(self, node):
  """Process a set literal: braces, elements and separating commas.

  After the final element an optional trailing comma is looked for instead
  of a separator.
  """
  self.attr(node, 'set_open', ['{', self.ws], default='{')

  last_index = len(node.elts) - 1
  for i, elt in enumerate(node.elts):
    self.visit(elt)
    # Decide by position rather than identity: in a hand-built tree the same
    # node object may appear more than once in `elts`.
    if i < last_index:
      self.attr(node, 'comma_%d' % i, [self.ws, ',', self.ws], default=', ')
    else:
      self.optional_token(node, 'extracomma', ',',
                          allow_whitespace_prefix=True)

  self.attr(node, 'set_close', [self.ws, '}'], default='}')
981
@expression
def visit_SetComp(self, node):
  """Process a set comprehension, delimited by curly braces."""
  braces = dict(open_brace='{', close_brace='}')
  self._comp_exp(node, **braces)
985
@expression
def visit_Subscript(self, node):
  """Process a subscript: the value followed by its bracketed slice."""
  self.visit(node.value)

  opening = [self.ws, '[', self.ws]
  self.attr(node, 'slice_open', opening, default='[')
  self.visit(node.slice)
  closing = [self.ws, ']']
  self.attr(node, 'slice_close', closing, default=']')
992
@expression
def visit_Tuple(self, node):
  """Process a tuple: its elements, separating commas and trailing comma.

  The elements are handled inside a scope over `elts` that defaults to being
  parenthesized. The trailing comma defaults to present only for a
  one-element tuple.
  """
  with self.scope(node, 'elts', default_parens=True):
    last_index = len(node.elts) - 1
    for i, elt in enumerate(node.elts):
      self.visit(elt)
      # Decide by position rather than identity: in a hand-built tree the
      # same node object may appear more than once in `elts`.
      if i < last_index:
        self.attr(node, 'comma_%d' % i, [self.ws, ',', self.ws],
                  default=', ')
      else:
        self.optional_token(node, 'extracomma', ',',
                            allow_whitespace_prefix=True,
                            default=len(node.elts) == 1)
1005
@expression
def visit_UnaryOp(self, node):
  """Process a unary operation: the operator token, then its operand."""
  token_choices = ast_constants.NODE_TYPE_TO_TOKENS[type(node.op)]
  symbol = token_choices[0]
  self.attr(node, 'op', [symbol, self.ws], default=symbol, deps=('op',))
  self.visit(node.operand)
1011
1012 # ============================================================================
1013 # == OPERATORS AND TOKENS: Anything that's just whitespace and tokens ==
1014 # ============================================================================
1015
@space_around
def visit_Ellipsis(self, node):
  """Process an ellipsis as the single token `...`."""
  ellipsis_token = '...'
  self.token(ellipsis_token)
1019
def visit_And(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1022
def visit_Or(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1025
def visit_Add(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1028
def visit_Sub(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1031
def visit_Mult(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1034
def visit_Div(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1037
def visit_MatMult(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1040
def visit_Mod(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1043
def visit_Pow(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1046
def visit_LShift(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1049
def visit_RShift(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1052
def visit_BitAnd(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1055
def visit_BitOr(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1058
def visit_BitXor(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1061
def visit_FloorDiv(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1064
def visit_Invert(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1067
def visit_Not(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1070
def visit_UAdd(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1073
def visit_USub(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1076
def visit_Eq(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1079
def visit_NotEq(self, node):
  """Process an inequality operator, spelled either `!=` or `<>`."""
  match_not_equal = self.one_of_symbols('!=', '<>')
  self.attr(node, 'operator', [match_not_equal])
1082
def visit_Lt(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1085
def visit_LtE(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1088
def visit_Gt(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1091
def visit_GtE(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1094
def visit_Is(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1097
def visit_IsNot(self, node):
  """Process the two-word `is not` operator, keeping the inner whitespace."""
  pattern = ['is', self.ws, 'not']
  self.attr(node, 'content', pattern, default='is not')
1100
def visit_In(self, node):
  """Process the first token NODE_TYPE_TO_TOKENS lists for this node type."""
  op_symbol = ast_constants.NODE_TYPE_TO_TOKENS[type(node)][0]
  self.token(op_symbol)
1103
def visit_NotIn(self, node):
  """Process the two-word `not in` operator, keeping the inner whitespace."""
  pattern = ['not', self.ws, 'in']
  self.attr(node, 'content', pattern, default='not in')
1106
1107 # ============================================================================
1108 # == MISC NODES: Nodes which are neither statements nor expressions ==
1109 # ============================================================================
1110
def visit_alias(self, node):
  """Process an import alias: a dotted name and an optional `as` clause.

  Whitespace is allowed before every name segment and before every dot.
  """
  segments = node.name.split('.')
  pattern = []
  for parent in segments[:-1]:
    pattern.extend((self.ws, parent, self.ws, '.'))
  pattern.extend((self.ws, segments[-1]))
  self.attr(node, 'name', pattern, deps=('name',), default=node.name)

  if node.asname is None:
    return
  self.attr(node, 'asname', [self.ws, 'as', self.ws], default=' as ')
  self.token(node.asname)
1123
@space_around
def visit_arg(self, node):
  """Process a single parameter: its name and optional `: annotation`."""
  self.token(node.arg)
  if node.annotation is None:
    return
  colon_pattern = [self.ws, ':', self.ws]
  self.attr(node, 'annotation_prefix', colon_pattern, default=': ')
  self.visit(node.annotation)
1131
@space_around
def visit_arguments(self, node):
  """Process a function's parameter list.

  Parameters are handled in source order: positional parameters without
  defaults, positional parameters with defaults, `*vararg` (or a bare `*`
  when only keyword-only parameters follow), keyword-only parameters and
  finally `**kwarg`. Separating commas are stored under `comma_<n>`, where
  `<n>` is the number of parameters processed so far.

  Note: the `/` marker is matched together with the comma that follows it,
  so it is only handled when another parameter comes after it.
  """
  # In Python 3, args appearing after *args must be kwargs
  kwonlyargs = getattr(node, 'kwonlyargs', [])
  kw_defaults = getattr(node, 'kw_defaults', [])
  assert len(kwonlyargs) == len(kw_defaults)
  posonlyargs = getattr(node, 'posonlyargs', [])

  total_args = sum((len(node.args + kwonlyargs),
                    len(posonlyargs),
                    1 if node.vararg else 0,
                    1 if node.kwarg else 0))
  arg_i = 0

  # `defaults` belongs to the tail of posonlyargs + args taken together, so
  # both groups are sliced from the combined list. Slicing `node.args` alone
  # pairs the defaults with the wrong parameters when a positional-only
  # parameter has a default.
  pos_args = posonlyargs + node.args
  if node.defaults:
    positional = pos_args[:-len(node.defaults)]
    keyword = pos_args[-len(node.defaults):]
  else:
    positional = pos_args
    keyword = []

  for arg in positional:
    self.visit(arg)
    arg_i += 1
    if arg_i < total_args:
      self.attr(node, 'comma_%d' % arg_i, [self.ws, ',', self.ws],
                default=', ')
    if arg_i == len(posonlyargs):
      self.attr(node, 'posonly_sep', [self.ws, '/', self.ws, ',', self.ws],
                default='/, ')

  for i, (arg, default) in enumerate(zip(keyword, node.defaults)):
    self.visit(arg)
    self.attr(node, 'default_%d' % i, [self.ws, '=', self.ws],
              default='=')
    self.visit(default)
    arg_i += 1
    if arg_i < total_args:
      self.attr(node, 'comma_%d' % arg_i, [self.ws, ',', self.ws],
                default=', ')
    # The last positional-only parameter may itself carry a default.
    if arg_i == len(posonlyargs):
      self.attr(node, 'posonly_sep', [self.ws, '/', self.ws, ',', self.ws],
                default='/, ')

  if node.vararg:
    self.attr(node, 'vararg_prefix', [self.ws, '*', self.ws], default='*')
    # The vararg is either a node or a plain name string.
    if isinstance(node.vararg, ast.AST):
      self.visit(node.vararg)
    else:
      self.token(node.vararg)
    self.attr(node, 'vararg_suffix', [self.ws])
    arg_i += 1
    if arg_i < total_args:
      self.token(',')
  elif kwonlyargs:
    # If no vararg, but we have kwonlyargs, insert a naked *, which will
    # definitely not be the last arg.
    # The default mirrors `posonly_sep`, so the marker has a fallback text
    # like every other separator here.
    self.attr(node, 'kwonly_sep', [self.ws, '*', self.ws, ',', self.ws],
              default='*, ')

  for i, (arg, default) in enumerate(zip(kwonlyargs, kw_defaults)):
    self.visit(arg)
    # A `None` entry means this keyword-only parameter has no default.
    if default is not None:
      self.attr(node, 'kw_default_%d' % i, [self.ws, '=', self.ws],
                default='=')
      self.visit(default)
    arg_i += 1
    if arg_i < total_args:
      self.attr(node, 'comma_%d' % arg_i, [self.ws, ',', self.ws],
                default=', ')

  if node.kwarg:
    self.attr(node, 'kwarg_prefix', [self.ws, '**', self.ws], default='**')
    # The kwarg is either a node or a plain name string.
    if isinstance(node.kwarg, ast.AST):
      self.visit(node.kwarg)
    else:
      self.token(node.kwarg)
    self.attr(node, 'kwarg_suffix', [self.ws])
1202
@space_around
def visit_comprehension(self, node):
  """Process one `for ... in ...` clause of a comprehension and its `if`s.

  When the node has a truthy `is_async` attribute the clause is matched as
  `async for`.
  """
  if getattr(node, 'is_async', False):
    for_pattern = [self.ws, 'async', self.ws, 'for', self.ws]
    for_default = ' async for '
  else:
    for_pattern = [self.ws, 'for', self.ws]
    for_default = ' for '
  self.attr(node, 'for', for_pattern, default=for_default)

  self.visit(node.target)
  self.attr(node, 'in', [self.ws, 'in', self.ws], default=' in ')
  self.visit(node.iter)

  for index, condition in enumerate(node.ifs):
    self.attr(node, 'if_%d' % index, [self.ws, 'if', self.ws],
              default=' if ')
    self.visit(condition)
1216
@space_around
def visit_keyword(self, node):
  """Process a keyword argument: `name=value`, or `**value` if unnamed."""
  if node.arg is not None:
    self.token(node.arg)
    self.attr(node, 'eq', [self.ws, '='], default='=')
  else:
    # No name: this is a `**mapping` argument.
    self.attr(node, 'stars', ['**', self.ws], default='**')
  self.visit(node.value)
1225
@space_left
def visit_Index(self, node):
  """Process an index by visiting the value it wraps."""
  wrapped = node.value
  self.visit(wrapped)
1229
@space_left
def visit_ExtSlice(self, node):
  """Process an extended slice: comma-separated dims and a trailing comma."""
  last_index = len(node.dims) - 1
  for i, dim in enumerate(node.dims):
    self.visit(dim)
    # Decide by position rather than identity: in a hand-built tree the same
    # node object may appear more than once in `dims`.
    if i < last_index:
      self.attr(node, 'dim_sep_%d' % i, [self.ws, ',', self.ws], default=', ')
  self.optional_token(node, 'trailing_comma', ',', default=False)
1237
@space_left
def visit_Slice(self, node):
  """Process a slice: `lower:upper` with an optional `:step` part.

  When the step is visited it is first tagged with `is_explicit_step = True`.
  """
  if node.lower:
    self.visit(node.lower)
  self.attr(node, 'lowerspace', [self.ws, ':', self.ws], default=':')
  if node.upper:
    self.visit(node.upper)

  # The second colon may be present even when there is no step to visit.
  self.attr(node, 'stepspace1', [self.ws])
  self.optional_token(node, 'step_colon', ':')
  self.attr(node, 'stepspace2', [self.ws])

  if not (node.step and self.check_slice_includes_step(node)):
    return
  self.optional_token(node, 'step_colon_2', ':', default=True)
  node.step.is_explicit_step = True
  self.visit(node.step)
1253
def check_slice_includes_step(self, node):
  """Decide whether the step of a Slice node should be visited.

  Returns False when there is no step, True when the step was tagged with
  `is_explicit_step`, and otherwise True unless the step is the bare name
  `None`.
  """
  # This is needed because of a bug in the 2.7 parser which treats
  # a[::] as Slice(lower=None, upper=None, step=Name(id='None'))
  # but also treats a[::None] exactly the same.
  step = node.step
  if not step:
    return False
  if getattr(step, 'is_explicit_step', False):
    return True
  is_bare_none = isinstance(step, ast.Name) and step.id == 'None'
  return not is_bare_none
1264
@fstring_expression
def visit_FormattedValue(self, node):
  """Process an f-string replacement field: value, conversion, format spec.

  A conversion of -1 means no `!x` part is present.
  """
  self.visit(node.value)

  has_conversion = node.conversion != -1
  if has_conversion:
    # `conversion` holds a character code; the source spells it as a letter.
    conversion_pattern = [self.ws, '!', chr(node.conversion)]
    self.attr(node, 'conversion', conversion_pattern, deps=('conversion',),
              default='!%c' % node.conversion)

  if node.format_spec:
    spec_pattern = [self.ws, ':', self.ws]
    self.attr(node, 'format_spec_prefix', spec_pattern, default=':')
    self.visit(node.format_spec)
1276
1277
class AnnotationError(Exception):
  """Raised when the syntax tree could not be annotated from its source."""
1280
1281
class AstAnnotator(BaseVisitor):
  """Visitor that records source formatting on the nodes of a syntax tree.

  Walks the tree while consuming tokens from a `TokenGenerator` built over
  the source, and stores the matched text on each node through `fmt.set` and
  `fmt.append`.
  """

  def __init__(self, source):
    """Create an annotator that reads tokens from `source`.

    Arguments:
      source: Source code handed to `token_generator.TokenGenerator`.
    """
    super(AstAnnotator, self).__init__()
    self.tokens = token_generator.TokenGenerator(source)

  def visit(self, node):
    """Visit `node`, recording its indentation and start/end positions.

    Raises:
      AnnotationError: If a TypeError, ValueError, IndexError or KeyError is
        raised while the node is being annotated.
    """
    try:
      fmt.set(node, 'indent', self._indent)
      fmt.set(node, 'indent_diff', self._indent_diff)
      fmt.set(node, 'start_line', self.tokens.peek().start[0])
      fmt.set(node, 'start_col', self.tokens.peek().start[1])
      super(AstAnnotator, self).visit(node)
      fmt.set(node, 'end_line', self.tokens.peek().end[0])
      fmt.set(node, 'end_col', self.tokens.peek().end[1])
    except (TypeError, ValueError, IndexError, KeyError) as e:
      raise AnnotationError(e)

  def indented(self, node, children_attr):
    """Generator which annotates child nodes with their indentation level."""
    children = getattr(node, children_attr)
    cur_loc = self.tokens._loc
    next_loc = self.tokens.peek_non_whitespace().start
    # Special case: if the children are on the same line, then there is no
    # indentation level to track.
    if cur_loc[0] == next_loc[0]:
      indent_diff = self._indent_diff
      self._indent_diff = None
      for child in children:
        yield child
      self._indent_diff = indent_diff
      return

    prev_indent = self._indent
    prev_indent_diff = self._indent_diff

    # Find the indent level of the first child
    indent_token = self.tokens.peek_conditional(
        lambda t: t.type == token_generator.TOKENS.INDENT)
    new_indent = indent_token.src
    new_diff = _get_indent_diff(prev_indent, new_indent)
    if not new_diff:
      new_diff = ' ' * 4  # Sensible default
      print('Indent detection failed (line %d); inner indentation level is not '
            'more than the outer indentation.' % cur_loc[0], file=sys.stderr)

    # Set the indent level to the child's indent and iterate over the children
    self._indent = new_indent
    self._indent_diff = new_diff
    for child in children:
      yield child
    # Store the suffix at this indentation level, which could be many lines
    fmt.set(node, 'block_suffix_%s' % children_attr,
            self.tokens.block_whitespace(self._indent))

    # Dedent back to the previous level
    self._indent = prev_indent
    self._indent_diff = prev_indent_diff

  @expression
  def visit_Num(self, node):
    """Annotate a Num node with the exact number format."""
    token_number_type = token_generator.TOKENS.NUMBER
    contentargs = [lambda: self.tokens.next_of_type(token_number_type).src]
    # A leading minus sign is stored as part of the number's content.
    if self.tokens.peek().src == '-':
      contentargs.insert(0, '-')
    self.attr(node, 'content', contentargs, deps=('n',), default=str(node.n))

  @expression
  def visit_Str(self, node):
    """Annotate a Str node with the exact string format."""
    self.attr(node, 'content', [self.tokens.str], deps=('s',), default=node.s)

  @expression
  def visit_JoinedStr(self, node):
    """Annotate a JoinedStr node with the fstr formatting metadata."""
    fstr_iter = self.tokens.fstr()()
    res = ''
    values = (v for v in node.values if isinstance(v, ast.FormattedValue))
    while True:
      res_part, tg = next(fstr_iter)
      res += res_part
      if tg is None:
        break
      # Visit the next FormattedValue with the token generator yielded for
      # it, and always switch back, even if that visit raises.
      prev_tokens = self.tokens
      self.tokens = tg
      try:
        self.visit(next(values))
      finally:
        self.tokens = prev_tokens

    self.attr(node, 'content', [lambda: res], default=res)

  @expression
  def visit_Bytes(self, node):
    """Annotate a Bytes node with the exact string format."""
    self.attr(node, 'content', [self.tokens.str], deps=('s',), default=node.s)

  @space_around
  def visit_Ellipsis(self, node):
    """Annotate an Ellipsis node from either one `...` or three `.` tokens."""
    # Ellipsis is sometimes split into 3 tokens and other times a single token
    # Account for both forms when parsing the input.
    if self.tokens.peek().src == '...':
      self.token('...')
    else:
      for _ in range(3):
        self.token('.')

  def check_is_elif(self, node):
    """Return True iff the If node is an `elif` in the source."""
    next_tok = self.tokens.next_name()
    return isinstance(node, ast.If) and next_tok.src == 'elif'

  def check_is_continued_try(self, node):
    """Return True iff the TryExcept node is a continued `try` in the source."""
    return (isinstance(node, ast.TryExcept) and
            self.tokens.peek_non_whitespace().src != 'try')

  def check_is_continued_with(self, node):
    """Return True iff the With node is a continued `with` in the source."""
    return isinstance(node, ast.With) and self.tokens.peek().src == ','

  def check_slice_includes_step(self, node):
    """Helper function for Slice node to determine whether to visit its step."""
    # This is needed because of a bug in the 2.7 parser which treats
    # a[::] as Slice(lower=None, upper=None, step=Name(id='None'))
    # but also treats a[::None] exactly the same.
    return self.tokens.peek_non_whitespace().src not in '],'

  def ws(self, max_lines=None, semicolon=False, comment=True):
    """Parse some whitespace from the source tokens and return it.

    Arguments:
      max_lines: Passed through to `TokenGenerator.whitespace`.
      semicolon: (boolean) If True and the next token is `;`, the whitespace
        before it, the `;` itself and, when a newline token follows, up to
        one line of whitespace are parsed instead.
      comment: Passed through to `TokenGenerator.whitespace`.
    """
    next_token = self.tokens.peek()
    if semicolon and next_token and next_token.src == ';':
      result = self.tokens.whitespace() + self.token(';')
      next_token = self.tokens.peek()
      # peek() may have nothing left to return after the semicolon.
      if next_token and next_token.type in (token_generator.TOKENS.NL,
                                            token_generator.TOKENS.NEWLINE):
        result += self.tokens.whitespace(max_lines=1)
      return result
    return self.tokens.whitespace(max_lines=max_lines, comment=comment)

  def dots(self, num_dots):
    """Parse a number of dots."""
    def _parse_dots():
      return self.tokens.dots(num_dots)
    return _parse_dots

  def block_suffix(self, node, indent_level):
    """Store the block whitespace at `indent_level` as the node's suffix."""
    fmt.set(node, 'suffix', self.tokens.block_whitespace(indent_level))

  def token(self, token_val):
    """Parse a single token with exactly the given value.

    Returns:
      The source text of the parsed token.

    Raises:
      AnnotationError: If the next token's text differs from `token_val`.
    """
    token = self.tokens.next()
    if token.src != token_val:
      raise AnnotationError("Expected %r but found %r\nline %d: %s" % (
          token_val, token.src, token.start[0], token.line))

    # If the token opens or closes a parentheses scope, keep track of it.
    # Tuples are used so that only the exact bracket characters match; a
    # substring test would also accept an empty string.
    if token.src in ('(', '{', '['):
      self.tokens.hint_open()
    elif token.src in (')', '}', ']'):
      self.tokens.hint_closed()

    return token.src

  def optional_token(self, node, attr_name, token_val,
                     allow_whitespace_prefix=False, default=False):
    """Try to parse a token and attach it to the node.

    Arguments:
      node: (ast.AST) Node to attach the parsed text to.
      attr_name: (string) Name to append the parsed text under.
      token_val: (string) Exact token text to look for.
      allow_whitespace_prefix: (boolean) If True, whitespace before the token
        is skipped when looking for it and stored together with it.
      default: Unused here.
    """
    del default
    # Register the attribute even when the token turns out to be absent.
    fmt.append(node, attr_name, '')
    token = (self.tokens.peek_non_whitespace()
             if allow_whitespace_prefix else self.tokens.peek())
    if token and token.src == token_val:
      parsed = ''
      if allow_whitespace_prefix:
        parsed += self.ws()
      fmt.append(node, attr_name,
                 parsed + self.tokens.next().src + self.ws())

  def one_of_symbols(self, *symbols):
    """Account for one of the given symbols."""
    def _one_of_symbols():
      next_token = self.tokens.next()
      found = next((s for s in symbols if s == next_token.src), None)
      if found is None:
        raise AnnotationError(
            'Expected one of: %r, but found: %r' % (symbols, next_token.src))
      return found
    return _one_of_symbols

  def attr(self, node, attr_name, attr_vals, deps=None, default=None):
    """Parses some source and sets an attribute on the given node.

    Stores some arbitrary formatting information on the node. This takes a list
    attr_vals which tell what parts of the source to parse. The result of each
    function is concatenated onto the formatting data, and strings in this list
    are a shorthand to look for an exactly matching token.

    For example:
      self.attr(node, 'foo', ['(', self.ws, 'Hello, world!', self.ws, ')'],
                deps=('s',), default=node.s)

    is a rudimentary way to parse a parenthesized string. After running this,
    the matching source code for this node will be stored in its formatting
    dict under the key 'foo'. The result might be `(\n  'Hello, world!'\n)`.

    This also keeps track of the current value of each of the dependencies.
    In the above example, we would have looked for the string 'Hello, world!'
    because that's the value of node.s, however, when we print this back, we
    want to know if the value of node.s has changed since this time. If any of
    the dependent values has changed, the default would be used instead.

    Arguments:
      node: (ast.AST) An AST node to attach formatting information to.
      attr_name: (string) Name to store the formatting information under.
      attr_vals: (list of functions/strings) Each item is either a function
        that parses some source and returns a string OR a string to match
        exactly (as a token).
      deps: (optional, set of strings) Attributes of the node which attr_vals
        depends on.
      default: (string) Unused here.
    """
    del default  # unused
    if deps:
      for dep in deps:
        fmt.set(node, dep + '__src', getattr(node, dep, None))
    attr_parts = []
    for attr_val in attr_vals:
      if isinstance(attr_val, six.string_types):
        attr_parts.append(self.token(attr_val))
      else:
        attr_parts.append(attr_val())
    fmt.set(node, attr_name, ''.join(attr_parts))

  def scope(self, node, attr=None, trailing_comma=False, default_parens=False):
    """Return a context manager to handle a parenthesized scope.

    Arguments:
      node: (ast.AST) Node to store the scope prefix and suffix on.
      attr: (string, optional) Attribute of the node contained in the scope, if
        any. For example, as `None`, the scope would wrap the entire node, but
        as 'bases', the scope might wrap only the bases of a class.
      trailing_comma: (boolean) If True, allow a trailing comma at the end.
      default_parens: (boolean) If True and no formatting information is
        present, the scope would be assumed to be parenthesized.
    """
    del default_parens
    return self.tokens.scope(node, attr=attr, trailing_comma=trailing_comma)

  def _optional_token(self, token_type, token_val):
    """Parse the next token if it has the given type and text.

    Returns:
      The token text plus the whitespace after it, or '' if it did not match.
    """
    token = self.tokens.peek()
    if not token or token.type != token_type or token.src != token_val:
      return ''
    else:
      self.tokens.next()
      return token.src + self.ws()
1537
1538
1539def _get_indent_width(indent):
1540 width = 0
1541 for c in indent:
1542 if c == ' ':
1543 width += 1
1544 elif c == '\t':
1545 width += 8 - (width % 8)
1546 return width
1547
1548
1549def _ltrim_indent(indent, remove_width):
1550 width = 0
1551 for i, c in enumerate(indent):
1552 if width == remove_width:
1553 break
1554 if c == ' ':
1555 width += 1
1556 elif c == '\t':
1557 if width + 8 - (width % 8) <= remove_width:
1558 width += 8 - (width % 8)
1559 else:
1560 return ' ' * (width + 8 - remove_width) + indent[i + 1:]
1561 return indent[i:]
1562
1563
def _get_indent_diff(outer, inner):
  """Return the whitespace that `inner` adds on top of `outer`.

  The inner indent usually starts with the outer indent, but it does not
  have to: the outer block could use four spaces while its body uses a
  single tab (eight columns). The comparison is therefore done on column
  widths rather than on string prefixes.

  Arguments:
    outer: (string) Indentation of the outer block.
    inner: (string) Indentation of the inner block.
  Returns:
    The whitespace gained when moving from outer to inner, or None when the
    inner indent is not wider than the outer one.
  """
  outer_width = _get_indent_width(outer)
  added_width = _get_indent_width(inner) - outer_width
  if added_width > 0:
    # Dropping the outer block's width from the inner indent leaves the part
    # that was added.
    return _ltrim_indent(inner, outer_width)
  return None