1"""
This module contains utility functions and classes to inject simple AST
transformations based on code strings into IPython. While this is already
possible with AST transformers, it is not easy to directly manipulate the AST.
5
6
IPython has pre-code and post-code hooks, but they are run from within the
IPython machinery, so they may be inappropriate, for example for performance
measurement.
9
This module gives you tools to simplify this, and exposes two classes:
11
12- `ReplaceCodeTransformer` which is a simple ast transformer based on code
13 template,
14
and for advanced cases:
16
- `Mangler` which is a simple ast transformer that mangles names in the ast.
18
19
Example: let's try to make a simple version of the ``timeit`` magic that runs a
code snippet 10 times and prints the average time taken.
22
23Basically we want to run :
24
25.. code-block:: python
26
27 from time import perf_counter
28 now = perf_counter()
29 for i in range(10):
30 __code__ # our code
31 print(f"Time taken: {(perf_counter() - now)/10}")
32 __ret__ # the result of the last statement
33
Where ``__code__`` is the code snippet we want to run, and ``__ret__`` is the
result, so that if we run, for example, `dataframe.head()`, IPython still
displays the head of the dataframe instead of nothing.
37
Here is a complete example of a file `timit2.py` that defines such a magic:
39
40.. code-block:: python
41
42 from IPython.core.magic import (
43 Magics,
44 magics_class,
45 line_cell_magic,
46 )
47 from IPython.core.magics.ast_mod import ReplaceCodeTransformer
48 from textwrap import dedent
49 import ast
50
 template = dedent('''
52 from time import perf_counter
53 now = perf_counter()
54 for i in range(10):
55 __code__
56 print(f"Time taken: {(perf_counter() - now)/10}")
57 __ret__
58 '''
59 )
60
61
62 @magics_class
63 class AstM(Magics):
64 @line_cell_magic
65 def t2(self, line, cell):
66 transformer = ReplaceCodeTransformer.from_string(template)
67 transformer.debug = True
68 transformer.mangler.debug = True
69 new_code = transformer.visit(ast.parse(cell))
70 return exec(compile(new_code, "<ast>", "exec"))
71
72
73 def load_ipython_extension(ip):
74 ip.register_magics(AstM)
75
76
77
78.. code-block:: python
79
80 In [1]: %load_ext timit2
81
82 In [2]: %%t2
83 ...: import time
84 ...: time.sleep(0.05)
85 ...:
86 ...:
87 Time taken: 0.05435649999999441
88
89
If you wish to run all the code entered in IPython through an ast transformer,
you can do so as well:
92
93.. code-block:: python
94
95 In [1]: from IPython.core.magics.ast_mod import ReplaceCodeTransformer
96 ...:
97 ...: template = '''
98 ...: from time import perf_counter
99 ...: now = perf_counter()
100 ...: __code__
101 ...: print(f"Code ran in {perf_counter()-now}")
102 ...: __ret__'''
103 ...:
104 ...: get_ipython().ast_transformers.append(ReplaceCodeTransformer.from_string(template))
105
106 In [2]: 1+1
107 Code ran in 3.40410006174352e-05
108 Out[2]: 2
109
110
111
112Hygiene and Mangling
113--------------------
114
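The ast transformer above is not hygienic: it may not work if the user code uses
the same variable names as the ones used in the template. For example, the
templates above bind ``now`` and ``perf_counter``, which would collide with user
variables of the same name.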
117
To help with this, by default the `ReplaceCodeTransformer` will mangle all names
starting with 3 underscores. This is a simple heuristic that should work in most
cases, but can be cumbersome in some cases. We provide a `Mangler` class that can
be overridden to change the mangling heuristic, or you can simply use the
`mangle_all` utility function. It will _try_ to mangle all names (except
`__ret__` and `__code__`), but this includes builtins (``print``, ``range``,
``type``) and replaces those by invalid identifiers by prepending ``mangle-``:
``mangle-print``, ``mangle-range``, ``mangle-type`` etc. This is not a problem
as currently the Python AST supports invalid identifiers, but it may not be the
case in the future.
128
129You can set `ReplaceCodeTransformer.debug=True` and
130`ReplaceCodeTransformer.mangler.debug=True` to see the code after mangling and
131transforming:
132
133.. code-block:: python
134
135
136 In [1]: from IPython.core.magics.ast_mod import ReplaceCodeTransformer, mangle_all
137 ...:
138 ...: template = '''
139 ...: from builtins import type, print
140 ...: from time import perf_counter
141 ...: now = perf_counter()
142 ...: __code__
143 ...: print(f"Code ran in {perf_counter()-now}")
144 ...: __ret__'''
145 ...:
146 ...: transformer = ReplaceCodeTransformer.from_string(template, mangling_predicate=mangle_all)
147
148
149 In [2]: transformer.debug = True
150 ...: transformer.mangler.debug = True
151 ...: get_ipython().ast_transformers.append(transformer)
152
153 In [3]: 1+1
154 Mangling Alias mangle-type
155 Mangling Alias mangle-print
156 Mangling Alias mangle-perf_counter
157 Mangling now
158 Mangling perf_counter
159 Not mangling __code__
160 Mangling print
161 Mangling perf_counter
162 Mangling now
163 Not mangling __ret__
164 ---- Transformed code ----
165 from builtins import type as mangle-type, print as mangle-print
166 from time import perf_counter as mangle-perf_counter
167 mangle-now = mangle-perf_counter()
168 ret-tmp = 1 + 1
169 mangle-print(f'Code ran in {mangle-perf_counter() - mangle-now}')
170 ret-tmp
171 ---- ---------------- ----
172 Code ran in 0.00013654199938173406
173 Out[3]: 2
174
175
176"""
177
178__skip_doctest__ = True
179
180
181from ast import (
182 NodeTransformer,
183 Store,
184 Load,
185 Name,
186 Expr,
187 Assign,
188 Module,
189 Import,
190 ImportFrom,
191)
192import ast
193import copy
194
195from typing import Dict, Optional, Union
196
197
def mangle_all(name):
    """
    Mangling predicate that accepts every name except the template placeholders.

    Parameters
    ----------
    name : str
        Identifier found in the template.

    Returns
    -------
    bool
        ``False`` for ``"__ret__"`` and ``"__code__"``, ``True`` for any other
        name.
    """
    return name not in ("__ret__", "__code__")
199
200
class Mangler(NodeTransformer):
    """
    Mangle given names in an ast tree to make sure they do not conflict with
    user code.

    Every name for which ``predicate(name)`` is true is rewritten in place to
    ``"mangle-" + name``. The following occurrences are handled: ``Name``
    nodes, function names (``def`` and ``async def``), every kind of function
    and lambda parameter, and the names bound by ``import`` /
    ``from ... import``.

    Other places where an identifier is stored as a plain string (for example
    class names or keyword names at call sites) have no handler in this class
    and are left untouched.
    """

    # Not read by any method of this class.
    enabled: bool = True
    # When true, `log` prints every mangling decision.
    debug: bool = False

    def log(self, *args, **kwargs):
        """Forward the arguments to ``print`` when ``debug`` is set."""
        if self.debug:
            print(*args, **kwargs)

    def __init__(self, predicate=None):
        """
        Parameters
        ----------
        predicate : callable, optional
            Called with a name, returns a true value when the name must be
            mangled. Defaults to mangling names starting with three
            underscores.
        """
        super().__init__()
        if predicate is None:

            def predicate(name):
                return name.startswith("___")

        self.predicate = predicate

    def visit_Name(self, node):
        """Rename ``node.id`` when the predicate selects it; return the node."""
        if self.predicate(node.id):
            self.log("Mangling", node.id)
            # Once in the ast we do not need
            # names to be valid identifiers.
            node.id = "mangle-" + node.id
        else:
            self.log("Not mangling", node.id)
        return node

    def visit_FunctionDef(self, node):
        """Rename the function and its parameters, then visit its children."""
        if self.predicate(node.name):
            self.log("Mangling", node.name)
            node.name = "mangle-" + node.name
        else:
            self.log("Not mangling", node.name)

        self._mangle_params(node.args)
        return self.generic_visit(node)

    # ``async def`` nodes have the same ``name`` / ``args`` fields.
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Lambda(self, node):
        """Rename the lambda parameters, then visit its children."""
        self._mangle_params(node.args)
        return self.generic_visit(node)

    def _mangle_params(self, arguments):
        """Rename every parameter of an ``ast.arguments`` node selected by the predicate."""
        # References in the body are renamed by `visit_Name`, so every kind of
        # parameter must be renamed too or the body would refer to names that
        # are no longer bound.
        params = [*arguments.posonlyargs, *arguments.args]
        if arguments.vararg is not None:
            params.append(arguments.vararg)
        params.extend(arguments.kwonlyargs)
        if arguments.kwarg is not None:
            params.append(arguments.kwarg)

        for param in params:
            if self.predicate(param.arg):
                self.log("Mangling function arg", param.arg)
                param.arg = "mangle-" + param.arg
            else:
                self.log("Not mangling function arg", param.arg)

    def visit_ImportFrom(self, node: ImportFrom):
        """Mangle the names bound by a ``from ... import`` statement."""
        return self._visit_Import_and_ImportFrom(node)

    def visit_Import(self, node: Import):
        """Mangle the names bound by an ``import`` statement."""
        return self._visit_Import_and_ImportFrom(node)

    def _visit_Import_and_ImportFrom(self, node: Union[Import, ImportFrom]):
        """Give each selected alias an ``asname`` carrying the mangled name."""
        for alias in node.names:
            # The bound name is the ``as`` name when present, else the
            # imported name itself.
            asname = alias.name if alias.asname is None else alias.asname
            if self.predicate(asname):
                new_name: str = "mangle-" + asname
                self.log("Mangling Alias", new_name)
                alias.asname = new_name
            else:
                # Log the effective bound name; ``alias.asname`` is None for a
                # plain ``import x``.
                self.log("Not mangling Alias", asname)
        return node
260
261
class ReplaceCodeTransformer(NodeTransformer):
    """
    AST transformer that embeds the visited module into a code template.

    The template is a module in which the bare expression statements
    ``__code__`` and ``__ret__`` act as placeholders. When a module is
    visited:

    - ``__code__`` is replaced by the statements of the visited module;
    - ``__ret__`` is replaced by an expression statement evaluating to the
      value of the module's last statement when that statement is an
      expression, and to ``None`` otherwise.

    Additional placeholders can be supplied through ``mapping`` (placeholder
    name -> node or list of nodes). Before substitution a copy of the template
    is passed through ``self.mangler``.
    """

    # When false, `visit_Module` returns its input unchanged.
    enabled: bool = True
    # When true, the transformed code is printed with `ast.unparse`.
    debug: bool = False
    mangler: "Mangler"

    def __init__(
        self, template: Module, mapping: Optional[Dict] = None, mangling_predicate=None
    ):
        """
        Parameters
        ----------
        template : ast.Module
            Parsed template containing the placeholders.
        mapping : dict, optional
            Extra placeholder-name -> replacement entries. The dict is stored
            as is and updated with ``__ret__`` / ``__code__`` on each visit.
        mangling_predicate : callable, optional
            Passed to `Mangler` as its ``predicate``; any callable is accepted.
        """
        super().__init__()
        assert isinstance(mapping, (dict, type(None)))
        assert mangling_predicate is None or callable(mangling_predicate)
        assert isinstance(template, ast.Module)
        self.template = template
        self.mangler = Mangler(predicate=mangling_predicate)
        if mapping is None:
            mapping = {}
        self.mapping = mapping

    @classmethod
    def from_string(
        cls, template: str, mapping: Optional[Dict] = None, mangling_predicate=None
    ):
        """Build a transformer from template source code instead of a parsed module."""
        return cls(
            ast.parse(template), mapping=mapping, mangling_predicate=mangling_predicate
        )

    def visit_Module(self, code):
        """
        Return a new module made of the template with ``code`` substituted in.

        ``code`` is modified in place when its last statement is an
        expression: that statement is turned into an assignment to
        ``ret-tmp`` so that its value can be re-emitted where the template
        says ``__ret__``.
        """
        if not self.enabled:
            return code

        # A module can have an empty body (e.g. source made only of comments).
        last = code.body[-1] if code.body else None
        if isinstance(last, Expr):
            code.body.pop()
            code.body.append(Assign([Name("ret-tmp", ctx=Store())], value=last.value))
            ast.fix_missing_locations(code)
            ret = Expr(value=Name("ret-tmp", ctx=Load()))
            ret = ast.fix_missing_locations(ret)
            self.mapping["__ret__"] = ret
        else:
            self.mapping["__ret__"] = ast.parse("None").body[0]
        # Substitute a single ``pass`` for empty code so that a template using
        # ``__code__`` as the only statement of a block stays compilable.
        self.mapping["__code__"] = code.body or [ast.Pass()]
        tpl = ast.fix_missing_locations(self.template)

        # Work on a copy: the template is reused for every visited module.
        tx = copy.deepcopy(tpl)
        tx = self.mangler.visit(tx)
        node = self.generic_visit(tx)
        node_2 = ast.fix_missing_locations(node)
        if self.debug:
            print("---- Transformed code ----")
            print(ast.unparse(node_2))
            print("---- ---------------- ----")
        return node_2

    def visit_Expr(self, expr):
        """
        Replace a placeholder expression statement by a copy of its mapping.

        The substitution is done at the statement level rather than on the
        ``Name`` node because a replacement may be a list of statements that
        has to be spliced into the enclosing body.
        """
        if isinstance(expr.value, Name):
            replacement = self.mapping.get(expr.value.id)
            if replacement is not None:
                return copy.deepcopy(replacement)
        return self.generic_visit(expr)