1"""Helpers for AST (Abstract Syntax Tree)."""
2
3from __future__ import annotations
4
5import ast
6from typing import NoReturn, overload
7
# Source-level spelling of every operator node the unparser understands.
# Keys are the operator *classes* (not instances) found in ``op`` fields of
# ``ast.BinOp``, ``ast.BoolOp`` and ``ast.UnaryOp`` nodes.
OPERATORS: dict[type[ast.AST], str] = {
    ast.Add: "+",         # BinOp
    ast.And: "and",       # BoolOp
    ast.BitAnd: "&",      # BinOp
    ast.BitOr: "|",       # BinOp
    ast.BitXor: "^",      # BinOp
    ast.Div: "/",         # BinOp
    ast.FloorDiv: "//",   # BinOp
    ast.Invert: "~",      # UnaryOp
    ast.LShift: "<<",     # BinOp
    ast.MatMult: "@",     # BinOp
    ast.Mult: "*",        # BinOp
    ast.Mod: "%",         # BinOp
    ast.Not: "not",       # UnaryOp
    ast.Pow: "**",        # BinOp
    ast.Or: "or",         # BoolOp
    ast.RShift: ">>",     # BinOp
    ast.Sub: "-",         # BinOp
    ast.UAdd: "+",        # UnaryOp
    ast.USub: "-",        # UnaryOp
}
29
30
@overload
def unparse(node: None, code: str = '') -> None: ...


@overload
def unparse(node: ast.AST, code: str = '') -> str: ...


def unparse(node: ast.AST | None, code: str = '') -> str | None:
    """Turn *node* back into Python source text.

    :param node: the AST node to render. ``None`` yields ``None``, and a
        ``str`` is handed back unchanged.
    :param code: optional source text the node was parsed from. It is
        forwarded to the visitor.
    :return: the rendered source, or ``None`` when *node* is ``None``.
    """
    # Both pass-through cases return the argument itself.
    if node is None or isinstance(node, str):
        return node
    visitor = _UnparseVisitor(code)
    return visitor.visit(node)
48
49
50# a greatly cut-down version of `ast._Unparser`
class _UnparseVisitor(ast.NodeVisitor):
    """Render a supported subset of expression ASTs back into source text.

    Every ``visit_*`` method returns a ``str``. Node types without a
    dedicated visitor reach :meth:`generic_visit`, which raises
    :exc:`NotImplementedError`.
    """

    def __init__(self, code: str = '') -> None:
        # Source the AST was parsed from (may be empty). Only
        # ``visit_Constant`` reads it, to copy numeric literals verbatim.
        self.code = code

    def _visit_op(self, node: ast.AST) -> str:
        """Return the textual symbol for an operator node, e.g. ``+``."""
        return OPERATORS[node.__class__]

    # Install ``_visit_op`` as ``visit_Add``, ``visit_Not``, ... for every
    # operator class in OPERATORS. Writing into ``locals()`` here adds
    # entries to the class namespace being built (CPython class-body
    # behaviour this code relies on).
    for _op in OPERATORS:
        locals()[f'visit_{_op.__name__}'] = _visit_op
    # Drop the loop variable so it does not linger as a class attribute.
    del _op

    def visit_arg(self, node: ast.arg) -> str:
        """Unparse a single parameter name with its optional annotation."""
        if node.annotation:
            return f"{node.arg}: {self.visit(node.annotation)}"
        return node.arg

    def _visit_arg_with_default(self, arg: ast.arg, default: ast.AST | None) -> str:
        """Unparse a single argument, appending its default value if any."""
        name = self.visit(arg)
        if default is not None:
            # PEP 8 style: spaces around ``=`` only for annotated parameters.
            sep = " = " if arg.annotation else "="
            name += sep + self.visit(default)
        return name

    def visit_arguments(self, node: ast.arguments) -> str:
        """Unparse a parameter list, without the surrounding parentheses."""
        posonlyargs = len(node.posonlyargs)
        positionals = posonlyargs + len(node.args)

        # ``node.defaults`` belongs to the *last* positional parameters;
        # left-pad it with None so index i lines up with positional i.
        defaults: list[ast.expr | None] = list(node.defaults)
        for _ in range(len(defaults), positionals):
            defaults.insert(0, None)

        # Pad keyword-only defaults the same way in case the list is shorter
        # than ``kwonlyargs``.
        kw_defaults: list[ast.expr | None] = list(node.kw_defaults)
        for _ in range(len(kw_defaults), len(node.kwonlyargs)):
            kw_defaults.insert(0, None)

        args: list[str] = [
            self._visit_arg_with_default(arg, defaults[i])
            for i, arg in enumerate(node.posonlyargs)
        ]
        if node.posonlyargs:
            args.append('/')

        for i, arg in enumerate(node.args):
            args.append(self._visit_arg_with_default(arg, defaults[i + posonlyargs]))

        if node.vararg:
            args.append("*" + self.visit(node.vararg))

        if node.kwonlyargs and not node.vararg:
            # A bare ``*`` marks the start of keyword-only parameters.
            args.append('*')
        for i, arg in enumerate(node.kwonlyargs):
            args.append(self._visit_arg_with_default(arg, kw_defaults[i]))

        if node.kwarg:
            args.append("**" + self.visit(node.kwarg))

        return ", ".join(args)

    def visit_Attribute(self, node: ast.Attribute) -> str:
        return f"{self.visit(node.value)}.{node.attr}"

    def visit_BinOp(self, node: ast.BinOp) -> str:
        # NOTE(review): operands are not parenthesised by precedence, so
        # ``(a + b) * c`` is rendered as ``a + b * c``.
        # ``**`` is written without surrounding spaces.
        sep = "" if isinstance(node.op, ast.Pow) else " "
        return sep.join(map(self.visit, (node.left, node.op, node.right)))

    def visit_BoolOp(self, node: ast.BoolOp) -> str:
        op = f" {self.visit(node.op)} "
        return op.join(self.visit(e) for e in node.values)

    def visit_Call(self, node: ast.Call) -> str:
        args = [self.visit(e) for e in node.args]
        for kw in node.keywords:
            if kw.arg is None:
                # ``f(**mapping)``: an unpacked keyword has no name.
                args.append(f"**{self.visit(kw.value)}")
            else:
                args.append(f"{kw.arg}={self.visit(kw.value)}")
        return f"{self.visit(node.func)}({', '.join(args)})"

    def visit_Constant(self, node: ast.Constant) -> str:
        if node.value is Ellipsis:
            return "..."
        if self.code and isinstance(node.value, (int, float, complex)):
            # Prefer the literal as written in the source (e.g. ``0x10``);
            # fall back to ``repr`` if the segment cannot be recovered.
            return ast.get_source_segment(self.code, node) or repr(node.value)
        return repr(node.value)

    def visit_Dict(self, node: ast.Dict) -> str:
        items: list[str] = []
        for key, value in zip(node.keys, node.values):
            if key is None:
                # ``{**mapping}``: dict unpacking is stored with a None key.
                items.append(f"**{self.visit(value)}")
            else:
                items.append(f"{self.visit(key)}: {self.visit(value)}")
        return "{" + ", ".join(items) + "}"

    def visit_Lambda(self, node: ast.Lambda) -> str:
        # The body is deliberately elided.
        return f"lambda {self.visit(node.args)}: ..."

    def visit_List(self, node: ast.List) -> str:
        return "[" + ", ".join(self.visit(e) for e in node.elts) + "]"

    def visit_Name(self, node: ast.Name) -> str:
        return node.id

    def visit_Set(self, node: ast.Set) -> str:
        return "{" + ", ".join(self.visit(e) for e in node.elts) + "}"

    def visit_Slice(self, node: ast.Slice) -> str:
        start = self.visit(node.lower) if node.lower else ""
        stop = self.visit(node.upper) if node.upper else ""
        if not node.step:
            # ``[:]`` or ``[start:stop]`` with the default step.
            return f"{start}:{stop}"
        return f"{start}:{stop}:{self.visit(node.step)}"

    def visit_Starred(self, node: ast.Starred) -> str:
        # ``*iterable`` inside calls and list/tuple/set displays.
        return f"*{self.visit(node.value)}"

    def visit_Subscript(self, node: ast.Subscript) -> str:
        slice_ = node.slice
        if (
            isinstance(slice_, ast.Tuple)
            and slice_.elts
            and not any(isinstance(elt, ast.Starred) for elt in slice_.elts)
        ):
            # Render ``x[a, b]`` rather than ``x[(a, b)]``.
            inner = ", ".join(self.visit(e) for e in slice_.elts)
        else:
            inner = self.visit(slice_)
        return f"{self.visit(node.value)}[{inner}]"

    def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
        # UnaryOp is one of {UAdd, USub, Invert, Not}, which refer to ``+x``,
        # ``-x``, ``~x``, and ``not x``. Only Not needs a space.
        if isinstance(node.op, ast.Not):
            return f"{self.visit(node.op)} {self.visit(node.operand)}"
        return f"{self.visit(node.op)}{self.visit(node.operand)}"

    def visit_Tuple(self, node: ast.Tuple) -> str:
        if not node.elts:
            return "()"
        if len(node.elts) == 1:
            # A one-element tuple needs the trailing comma.
            return f"({self.visit(node.elts[0])},)"
        return "(" + ", ".join(self.visit(e) for e in node.elts) + ")"

    def generic_visit(self, node: ast.AST) -> NoReturn:
        raise NotImplementedError('Unable to parse %s object' % type(node).__name__)