Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/asteval/asteval.py: 68%
625 statements
« prev ^ index » next coverage.py v7.2.6, created at 2023-05-28 06:19 +0000
« prev ^ index » next coverage.py v7.2.6, created at 2023-05-28 06:19 +0000
1#!/usr/bin/env python
2"""
3Safe(ish) evaluation of minimal Python code using Python's ast module.
5This module provides an Interpreter class that compiles a restricted set of
6Python expressions and statements to Python's AST representation, and then
7executes that representation using values held in a symbol table.
9The symbol table is a simple dictionary, giving a flat namespace. This comes
10pre-loaded with many functions from Python's builtin and math module. If numpy
11is installed, many numpy functions are also included. Additional symbols can
12be added when an Interpreter is created, but the user of that interpreter will
13not be able to import additional modules.
15Expressions, including loops, conditionals, and function definitions can be
16compiled into AST nodes and then evaluated later, using the current values
17in the symbol table.
19The result is a restricted, simplified version of Python meant for
20numerical calculations that is somewhat safer than 'eval' because many
21unsafe operations (such as 'import' and 'eval') are simply not allowed.
23Many parts of Python syntax are supported, including:
24 for loops, while loops, if-then-elif-else conditionals, with,
25 try-except-finally
26 function definitions with def
27 advanced slicing: a[::-1], array[-3:, :, ::2]
28 if-expressions: out = one_thing if TEST else other
29 list, dict, and set comprehension
31The following Python syntax elements are not supported:
32 Import, Exec, Lambda, Class, Global, Generators,
33 Yield, Decorators
35In addition, while many builtin functions are supported, several builtin
36functions that are considered unsafe are missing ('eval', 'exec', and
37'getattr' for example)
38"""
39import ast
40import sys
41import copy
42import inspect
43import time
44from sys import exc_info, stderr, stdout
46from .astutils import (HAS_NUMPY, UNSAFE_ATTRS, ExceptionHolder, ReturnedNone,
47 make_symbol_table, numpy, op2func, valid_symbol_name,
48 Procedure)
# Lower-cased ast class names the interpreter knows about; each one is
# dispatched to an ``on_<name>`` method (or to ``unimplemented``).
ALL_NODES = [
    'arg', 'assert', 'assign', 'attribute', 'augassign', 'binop', 'boolop',
    'break', 'bytes', 'call', 'compare', 'constant', 'continue', 'delete',
    'dict', 'dictcomp', 'ellipsis', 'excepthandler', 'expr', 'extslice',
    'for', 'functiondef', 'if', 'ifexp', 'import', 'importfrom', 'index',
    'interrupt', 'list', 'listcomp', 'module', 'name', 'nameconstant',
    'num', 'pass', 'print', 'raise', 'repr', 'return', 'set', 'setcomp',
    'slice', 'str', 'subscript', 'try', 'tuple', 'unaryop', 'while',
    'with', 'formattedvalue', 'joinedstr',
]

# Nodes disabled by default.
DEF_DISABLED = ('import', 'importfrom')

# Nodes disabled for a "minimal" interpreter: the defaults plus control
# flow, function definitions, comprehensions and a few statements.
MIN_DISABLED = DEF_DISABLED + (
    'if', 'for', 'while', 'try', 'with', 'functiondef', 'ifexp',
    'listcomp', 'dictcomp', 'setcomp', 'augassign', 'assert', 'delete',
    'raise', 'print',
)
65class Interpreter:
66 """create an asteval Interpreter: a restricted, simplified interpreter
67 of mathematical expressions using Python syntax.
69 Parameters
70 ----------
71 symtable : dict or `None`
72 dictionary to use as symbol table (if `None`, one will be created).
73 usersyms : dict or `None`
74 dictionary of user-defined symbols to add to symbol table.
75 writer : file-like or `None`
76 file-like object (with write() and flush()) where standard output will be sent.
77 err_writer : file-like or `None`
78 file-like object (with write()) where standard error will be sent.
79 use_numpy : bool
80 whether to use functions from numpy.
81 max_statement_length : int
82 maximum length of expression allowed [50,000 characters]
83 readonly_symbols : iterable or `None`
84 symbols that the user can not assign to
85 builtins_readonly : bool
86 whether to make all symbols in the initial symtable read-only
87 minimal : bool
88 create a minimal interpreter: disable many nodes (see Note 1).
89 config : dict
90 dictionary mapping node names to whether they are supported (see Note 2).
92 Notes
93 -----
94 1. setting `minimal=True` is equivalent to setting a config with the following
95 nodes disabled: ('import', 'importfrom', 'if', 'for', 'while', 'try', 'with',
96 'functiondef', 'ifexp', 'listcomp', 'dictcomp', 'setcomp', 'augassign',
97 'assert', 'delete', 'raise', 'print')
98 2. by default 'import' and 'importfrom' are disabled, though they can be enabled.
99 """
100 def __init__(self, symtable=None, usersyms=None, writer=None,
101 err_writer=None, use_numpy=True, max_statement_length=50000,
102 minimal=False, readonly_symbols=None, builtins_readonly=False,
103 config=None, **kws):
105 self.config = {}
106 disabled = MIN_DISABLED if minimal else DEF_DISABLED
107 for node in ALL_NODES:
108 self.config[node] = node not in disabled
110 if config is not None:
111 self.config.update(config)
113 if len(kws) > 0:
114 for key, val in kws.items():
115 if key.startswith('no_'):
116 node = key[3:]
117 if node in self.config:
118 self.config[node] = not val
119 elif key.startswith('with_'):
120 node = key[5:]
121 if node in self.config:
122 self.config[node] = val
124 self.writer = writer or stdout
125 self.err_writer = err_writer or stderr
126 self.max_statement_length = max(1, min(1.e8, max_statement_length))
128 if symtable is None:
129 if usersyms is None:
130 usersyms = {}
131 symtable = make_symbol_table(use_numpy=use_numpy, **usersyms)
133 symtable['print'] = self._printer
134 self.symtable = symtable
135 self._interrupt = None
136 self.error = []
137 self.error_msg = None
138 self.expr = None
139 self.retval = None
140 self._calldepth = 0
141 self.lineno = 0
142 self.start_time = time.time()
143 self.use_numpy = HAS_NUMPY and use_numpy
145 self.node_handlers = {}
146 for node, use in self.config.items():
147 handler = getattr(self, f"on_{node}", self.unimplemented)
148 if not use:
149 handler = self.unimplemented
150 self.node_handlers[node] = handler
152 # to rationalize try/except try/finally
153 if 'try' in self.node_handlers:
154 self.node_handlers['tryexcept'] = self.node_handlers['try']
155 self.node_handlers['tryfinally'] = self.node_handlers['try']
157 if readonly_symbols is None:
158 self.readonly_symbols = set()
159 else:
160 self.readonly_symbols = set(readonly_symbols)
162 if builtins_readonly:
163 self.readonly_symbols |= set(self.symtable)
165 self.no_deepcopy = [key for key, val in symtable.items()
166 if (callable(val)
167 or inspect.ismodule(val)
168 or 'numpy.lib.index_tricks' in repr(type(val)))]
170 def remove_nodehandler(self, node):
171 """remove support for a node
172 returns current node handler, so that it
173 might be re-added with add_nodehandler()
174 """
175 out = None
176 if node in self.node_handlers:
177 out = self.node_handlers.pop(node)
178 return out
180 def set_nodehandler(self, node, handler=None):
181 """set node handler or use current built-in default"""
182 if handler is None:
183 handler = getattr(self, f"on_{node}", self.unimplemented)
184 self.node_handlers[node] = handler
185 return handler
187 def user_defined_symbols(self):
188 """Return a set of symbols that have been added to symtable after
189 construction.
191 I.e., the symbols from self.symtable that are not in
192 self.no_deepcopy.
194 Returns
195 -------
196 unique_symbols : set
197 symbols in symtable that are not in self.no_deepcopy
199 """
200 sym_in_current = set(self.symtable.keys())
201 sym_from_construction = set(self.no_deepcopy)
202 unique_symbols = sym_in_current.difference(sym_from_construction)
203 return unique_symbols
205 def unimplemented(self, node):
206 """Unimplemented nodes."""
207 msg = f"{node.__class__.__name__} not supported"
208 self.raise_exception(node, exc=NotImplementedError, msg=msg)
210 def raise_exception(self, node, exc=None, msg='', expr=None,
211 lineno=None):
212 """Add an exception."""
213 if self.error is None:
214 self.error = []
215 if expr is None:
216 expr = self.expr
217 if len(self.error) > 0 and not isinstance(node, ast.Module):
218 msg = f'{msg!s}'
219 err = ExceptionHolder(node, exc=exc, msg=msg, expr=expr, lineno=lineno)
220 self._interrupt = ast.Raise()
221 self.error.append(err)
222 if self.error_msg is None:
223 self.error_msg = (' '.join([msg, f"at expr='{self.expr}'"])).strip()
224 elif len(msg) > 0:
225 self.error_msg = msg
226 if exc is None:
227 try:
228 exc = self.error[0].exc
229 except:
230 exc = RuntimeError
231 raise exc(self.error_msg)
233 # main entry point for Ast node evaluation
234 # parse: text of statements -> ast
235 # run: ast -> result
236 # eval: string statement -> result = run(parse(statement))
237 def parse(self, text):
238 """Parse statement/expression to Ast representation."""
239 if len(text) > self.max_statement_length:
240 msg = f'length of text exceeds {self.max_statement_length:d} characters'
241 self.raise_exception(None, exc=RuntimeError, expr=msg)
242 self.expr = text
243 try:
244 out = ast.parse(text)
245 except SyntaxError:
246 self.raise_exception(None, exc=SyntaxError, expr=text)
247 except:
248 self.raise_exception(None, exc=RuntimeError, expr=text)
250 return out
252 def run(self, node, expr=None, lineno=None, with_raise=True):
253 """Execute parsed Ast representation for an expression."""
254 # Note: keep the 'node is None' test: internal code here may run
255 # run(None) and expect a None in return.
256 out = None
257 if len(self.error) > 0:
258 return out
259 if self.retval is not None:
260 return self.retval
261 if isinstance(self._interrupt, (ast.Break, ast.Continue)):
262 return self._interrupt
263 if node is None:
264 return out
265 if isinstance(node, str):
266 node = self.parse(node)
267 if lineno is not None:
268 self.lineno = lineno
269 if expr is not None:
270 self.expr = expr
272 # get handler for this node:
273 # on_xxx with handle nodes of type 'xxx', etc
274 try:
275 handler = self.node_handlers[node.__class__.__name__.lower()]
276 except KeyError:
277 self.raise_exception(None, exc=NotImplementedError, expr=expr)
279 # run the handler: this will likely generate
280 # recursive calls into this run method.
281 try:
282 ret = handler(node)
283 if isinstance(ret, enumerate):
284 ret = list(ret)
285 return ret
286 except:
287 if with_raise:
288 if len(self.error) == 0:
289 # Unhandled exception that didn't use raise_exception
290 self.raise_exception(node, expr=expr)
291 raise
292 return None
294 def __call__(self, expr, **kw):
295 """Call class instance as function."""
296 return self.eval(expr, **kw)
298 def eval(self, expr, lineno=0, show_errors=True, raise_errors=False):
299 """Evaluate a single statement."""
300 self.lineno = lineno
301 self.error = []
302 self.start_time = time.time()
303 if isinstance(expr, str):
304 try:
305 node = self.parse(expr)
306 except Exception:
307 errmsg = exc_info()[1]
308 if len(self.error) > 0:
309 errmsg = "\n".join(self.error[0].get_error())
310 if raise_errors:
311 try:
312 exc = self.error[0].exc
313 except Exception:
314 exc = RuntimeError
315 raise exc(errmsg)
316 if show_errors:
317 print(errmsg, file=self.err_writer)
318 return None
319 else:
320 node = expr
321 try:
322 return self.run(node, expr=expr, lineno=lineno)
323 except:
324 errmsg = exc_info()[1]
325 if len(self.error) > 0:
326 errmsg = "\n".join(self.error[0].get_error())
327 if raise_errors:
328 try:
329 exc = self.error[0].exc
330 except Exception:
331 exc = RuntimeError
332 raise exc(errmsg)
333 if show_errors:
334 print(errmsg, file=self.err_writer)
337 @staticmethod
338 def dump(node, **kw):
339 """Simple ast dumper."""
340 return ast.dump(node, **kw)
342 # handlers for ast components
343 def on_expr(self, node):
344 """Expression."""
345 return self.run(node.value) # ('value',)
347 # imports
348 def on_import(self, node): # ('names',)
349 "simple import"
350 for tnode in node.names:
351 self.import_module(tnode.name, tnode.asname)
353 def on_importfrom(self, node): # ('module', 'names', 'level')
354 "import/from"
355 fromlist, asname = [], []
356 for tnode in node.names:
357 fromlist.append(tnode.name)
358 asname.append(tnode.asname)
359 self.import_module(node.module, asname, fromlist=fromlist)
361 def import_module(self, name, asname, fromlist=None):
362 """import a python module, installing it into the symbol table.
363 options:
364 name name of module to import 'foo' in 'import foo'
365 asname alias for imported name(s)
366 'bar' in 'import foo as bar'
367 or
368 ['s','t'] in 'from foo import x as s, y as t'
369 fromlist list of symbols to import with 'from-import'
370 ['x','y'] in 'from foo import x, y'
371 """
372 # find module in sys.modules or import to it
373 if name in sys.modules:
374 thismod = sys.modules[name]
375 else:
376 try:
377 __import__(name)
378 thismod = sys.modules[name]
379 except:
380 self.raise_exception(None, exc=ImportError, msg='Import Error')
382 if fromlist is None:
383 if asname is not None:
384 self.symtable[asname] = sys.modules[name]
385 else:
386 mparts = []
387 parts = name.split('.')
388 while len(parts) > 0:
389 mparts.append(parts.pop(0))
390 modname = '.'.join(mparts)
391 inname = name if (len(parts) == 0) else modname
392 self.symtable[inname] = sys.modules[modname]
393 else: # import-from construct
394 if asname is None:
395 asname = [None]*len(fromlist)
396 for sym, alias in zip(fromlist, asname):
397 if alias is None:
398 alias = sym
399 self.symtable[alias] = getattr(thismod, sym)
401 def on_index(self, node):
402 """Index."""
403 return self.run(node.value) # ('value',)
405 def on_return(self, node): # ('value',)
406 """Return statement: look for None, return special sentinel."""
407 if self._calldepth == 0:
408 raise SyntaxError('cannot return at top level')
409 self.retval = self.run(node.value)
410 if self.retval is None:
411 self.retval = ReturnedNone
413 def on_repr(self, node):
414 """Repr."""
415 return repr(self.run(node.value)) # ('value',)
417 def on_module(self, node): # ():('body',)
418 """Module def."""
419 out = None
420 for tnode in node.body:
421 out = self.run(tnode)
422 return out
424 def on_expression(self, node):
425 "basic expression"
426 return self.on_module(node) # ():('body',)
428 def on_pass(self, node):
429 """Pass statement."""
430 return None # ()
432 def on_ellipsis(self, node):
433 """Ellipses. deprecated in 3.8"""
434 return Ellipsis
436 # for break and continue: set the instance variable _interrupt
437 def on_interrupt(self, node): # ()
438 """Interrupt handler."""
439 self._interrupt = node
440 return node
442 def on_break(self, node):
443 """Break."""
444 return self.on_interrupt(node)
446 def on_continue(self, node):
447 """Continue."""
448 return self.on_interrupt(node)
450 def on_assert(self, node): # ('test', 'msg')
451 """Assert statement."""
452 if not self.run(node.test):
453 msg = node.msg.s if node.msg else ""
454 self.raise_exception(node, exc=AssertionError, msg=msg)
455 return True
457 def on_list(self, node): # ('elt', 'ctx')
458 """List."""
459 return [self.run(e) for e in node.elts]
461 def on_tuple(self, node): # ('elts', 'ctx')
462 """Tuple."""
463 return tuple(self.on_list(node))
465 def on_set(self, node): # ('elts')
466 """Set."""
467 return set([self.run(k) for k in node.elts])
469 def on_dict(self, node): # ('keys', 'values')
470 """Dictionary."""
471 return {self.run(k): self.run(v) for k, v in
472 zip(node.keys, node.values)}
474 def on_constant(self, node): # ('value', 'kind')
475 """Return constant value."""
476 return node.value
478 def on_num(self, node): # ('n',)
479 """Return number. deprecated in 3.8"""
480 return node.n
482 def on_str(self, node): # ('s',)
483 """Return string. deprecated in 3.8"""
484 return node.s
486 def on_bytes(self, node):
487 """return bytes. deprecated in 3.8"""
488 return node.s # ('s',)
490 def on_joinedstr(self, node): # ('values',)
491 "join strings, used in f-strings"
492 return ''.join([self.run(k) for k in node.values])
494 def on_formattedvalue(self, node): # ('value', 'conversion', 'format_spec')
495 "formatting used in f-strings"
496 val = self.run(node.value)
497 fstring_converters = {115: str, 114: repr, 97: ascii}
498 if node.conversion in fstring_converters:
499 val = fstring_converters[node.conversion](val)
500 fmt = '{0}'
501 if node.format_spec is not None:
502 fmt = f'{{0:{self.run(node.format_spec)}}}'
503 return fmt.format(val)
505 def on_name(self, node): # ('id', 'ctx')
506 """Name node."""
507 ctx = node.ctx.__class__
508 if ctx in (ast.Param, ast.Del):
509 return str(node.id)
510 if node.id in self.symtable:
511 return self.symtable[node.id]
512 msg = f"name '{node.id}' is not defined"
513 self.raise_exception(node, exc=NameError, msg=msg)
515 def on_nameconstant(self, node):
516 """True, False, or None deprecated in 3.8"""
517 return node.value
519 def node_assign(self, node, val):
520 """Assign a value (not the node.value object) to a node.
522 This is used by on_assign, but also by for, list comprehension,
523 etc.
525 """
526 if node.__class__ == ast.Name:
527 if (not valid_symbol_name(node.id) or
528 node.id in self.readonly_symbols):
529 errmsg = f"invalid symbol name (reserved word?) {node.id}"
530 self.raise_exception(node, exc=NameError, msg=errmsg)
531 self.symtable[node.id] = val
532 if node.id in self.no_deepcopy:
533 self.no_deepcopy.remove(node.id)
535 elif node.__class__ == ast.Attribute:
536 if node.ctx.__class__ == ast.Load:
537 msg = f"cannot assign to attribute {node.attr}"
538 self.raise_exception(node, exc=AttributeError, msg=msg)
540 setattr(self.run(node.value), node.attr, val)
542 elif node.__class__ == ast.Subscript:
543 self.run(node.value)[self.run(node.slice)] = val
545 elif node.__class__ in (ast.Tuple, ast.List):
546 if len(val) == len(node.elts):
547 for telem, tval in zip(node.elts, val):
548 self.node_assign(telem, tval)
549 else:
550 raise ValueError('too many values to unpack')
552 def on_attribute(self, node): # ('value', 'attr', 'ctx')
553 """Extract attribute."""
554 ctx = node.ctx.__class__
555 if ctx == ast.Store:
556 msg = "attribute for storage: shouldn't be here!"
557 self.raise_exception(node, exc=RuntimeError, msg=msg)
559 sym = self.run(node.value)
560 if ctx == ast.Del:
561 return delattr(sym, node.attr)
563 # ctx is ast.Load
564 if not (node.attr in UNSAFE_ATTRS or
565 (node.attr.startswith('__') and
566 node.attr.endswith('__'))):
567 try:
568 return getattr(sym, node.attr)
569 except AttributeError:
570 pass
572 # AttributeError or accessed unsafe attribute
573 msg = f"no attribute '{node.attr}' for {self.run(node.value)}"
574 self.raise_exception(node, exc=AttributeError, msg=msg)
576 def on_assign(self, node): # ('targets', 'value')
577 """Simple assignment."""
578 val = self.run(node.value)
579 for tnode in node.targets:
580 self.node_assign(tnode, val)
582 def on_augassign(self, node): # ('target', 'op', 'value')
583 """Augmented assign."""
584 return self.on_assign(ast.Assign(targets=[node.target],
585 value=ast.BinOp(left=node.target,
586 op=node.op,
587 right=node.value)))
589 def on_slice(self, node): # ():('lower', 'upper', 'step')
590 """Simple slice."""
591 return slice(self.run(node.lower),
592 self.run(node.upper),
593 self.run(node.step))
595 def on_extslice(self, node): # ():('dims',)
596 """Extended slice."""
597 return tuple([self.run(tnode) for tnode in node.dims])
599 def on_subscript(self, node): # ('value', 'slice', 'ctx')
600 """Subscript handling -- one of the tricky parts."""
601 val = self.run(node.value)
602 nslice = self.run(node.slice)
603 ctx = node.ctx.__class__
604 if ctx in (ast.Load, ast.Store):
605 return val[nslice]
606 msg = "subscript with unknown context"
607 self.raise_exception(node, msg=msg)
609 def on_delete(self, node): # ('targets',)
610 """Delete statement."""
611 for tnode in node.targets:
612 if tnode.ctx.__class__ != ast.Del:
613 break
614 children = []
615 while tnode.__class__ == ast.Attribute:
616 children.append(tnode.attr)
617 tnode = tnode.value
618 if (tnode.__class__ == ast.Name and
619 tnode.id not in self.readonly_symbols):
620 children.append(tnode.id)
621 children.reverse()
622 self.symtable.pop('.'.join(children))
623 else:
624 msg = "could not delete symbol"
625 self.raise_exception(node, msg=msg)
627 def on_unaryop(self, node): # ('op', 'operand')
628 """Unary operator."""
629 return op2func(node.op)(self.run(node.operand))
631 def on_binop(self, node): # ('left', 'op', 'right')
632 """Binary operator."""
633 return op2func(node.op)(self.run(node.left),
634 self.run(node.right))
636 def on_boolop(self, node): # ('op', 'values')
637 """Boolean operator."""
638 val = self.run(node.values[0])
639 is_and = ast.And == node.op.__class__
640 if (is_and and val) or (not is_and and not val):
641 for nodeval in node.values[1:]:
642 val = op2func(node.op)(val, self.run(nodeval))
643 if (is_and and not val) or (not is_and and val):
644 break
645 return val
647 def on_compare(self, node): # ('left', 'ops', 'comparators')
648 """comparison operators, including chained comparisons (a<b<c)"""
649 lval = self.run(node.left)
650 results = []
651 for oper, rnode in zip(node.ops, node.comparators):
652 rval = self.run(rnode)
653 ret = op2func(oper)(lval, rval)
654 results.append(ret)
655 if ((self.use_numpy and not isinstance(ret, numpy.ndarray)) and
656 not ret):
657 break
658 lval = rval
659 if len(results) == 1:
660 return results[0]
661 out = True
662 for ret in results:
663 out = out and ret
664 return out
666 def _printer(self, *out, **kws):
667 """Generic print function."""
668 if self.config['print']:
669 flush = kws.pop('flush', True)
670 fileh = kws.pop('file', self.writer)
671 sep = kws.pop('sep', ' ')
672 end = kws.pop('sep', '\n')
673 print(*out, file=fileh, sep=sep, end=end)
674 if flush:
675 fileh.flush()
677 def on_if(self, node): # ('test', 'body', 'orelse')
678 """Regular if-then-else statement."""
679 block = node.body
680 if not self.run(node.test):
681 block = node.orelse
682 for tnode in block:
683 self.run(tnode)
685 def on_ifexp(self, node): # ('test', 'body', 'orelse')
686 """If expressions."""
687 expr = node.orelse
688 if self.run(node.test):
689 expr = node.body
690 return self.run(expr)
692 def on_while(self, node): # ('test', 'body', 'orelse')
693 """While blocks."""
694 while self.run(node.test):
695 self._interrupt = None
696 for tnode in node.body:
697 self.run(tnode)
698 if self._interrupt is not None:
699 break
700 if isinstance(self._interrupt, ast.Break):
701 break
702 else:
703 for tnode in node.orelse:
704 self.run(tnode)
705 self._interrupt = None
707 def on_for(self, node): # ('target', 'iter', 'body', 'orelse')
708 """For blocks."""
709 for val in self.run(node.iter):
710 self.node_assign(node.target, val)
711 self._interrupt = None
712 for tnode in node.body:
713 self.run(tnode)
714 if self._interrupt is not None:
715 break
716 if isinstance(self._interrupt, ast.Break):
717 break
718 else:
719 for tnode in node.orelse:
720 self.run(tnode)
721 self._interrupt = None
723 def on_with(self, node): # ('items', 'body', 'type_comment')
724 """with blocks."""
725 contexts = []
726 for item in node.items:
727 ctx = self.run(item.context_expr)
728 contexts.append(ctx)
729 if hasattr(ctx, '__enter__'):
730 result = ctx.__enter__()
731 if item.optional_vars is not None:
732 self.node_assign(item.optional_vars, result)
733 else:
734 msg = "object does not support the context manager protocol"
735 raise TypeError(f"'{type(ctx)}' {msg}")
736 for bnode in node.body:
737 self.run(bnode)
738 if self._interrupt is not None:
739 break
741 for ctx in contexts:
742 if hasattr(ctx, '__exit__'):
743 ctx.__exit__()
746 def comprehension_data(self, node): # ('elt', 'generators')
747 "data for comprehensions"
748 mylocals = {}
749 saved_syms = {}
751 for tnode in node.generators:
752 if tnode.__class__ == ast.comprehension:
753 if tnode.target.__class__ == ast.Name:
754 if (not valid_symbol_name(tnode.target.id) or
755 tnode.target.id in self.readonly_symbols):
756 errmsg = f"invalid symbol name (reserved word?) {tnode.target.id}"
757 self.raise_exception(tnode.target, exc=NameError, msg=errmsg)
758 mylocals[tnode.target.id] = []
759 if tnode.target.id in self.symtable:
760 saved_syms[tnode.target.id] = copy.deepcopy(self.symtable[tnode.target.id])
762 elif tnode.target.__class__ == ast.Tuple:
763 target = []
764 for tval in tnode.target.elts:
765 mylocals[tval.id] = []
766 if tval.id in self.symtable:
767 saved_syms[tval.id] = copy.deepcopy(self.symtable[tval.id])
769 for tnode in node.generators:
770 if tnode.__class__ == ast.comprehension:
771 ttype = 'name'
772 if tnode.target.__class__ == ast.Name:
773 if (not valid_symbol_name(tnode.target.id) or
774 tnode.target.id in self.readonly_symbols):
775 errmsg = f"invalid symbol name (reserved word?) {tnode.target.id}"
776 self.raise_exception(tnode.target, exc=NameError, msg=errmsg)
777 ttype, target = 'name', tnode.target.id
778 elif tnode.target.__class__ == ast.Tuple:
779 ttype = 'tuple'
780 target =tuple([tval.id for tval in tnode.target.elts])
782 for val in self.run(tnode.iter):
783 if ttype == 'name':
784 self.symtable[target] = val
785 else:
786 for telem, tval in zip(target, val):
787 self.symtable[target] = val
789 add = True
790 for cond in tnode.ifs:
791 add = add and self.run(cond)
792 if add:
793 if ttype == 'name':
794 mylocals[target].append(val)
795 else:
796 for telem, tval in zip(target, val):
797 mylocals[telem].append(tval)
798 return mylocals, saved_syms
800 def on_listcomp(self, node):
801 """List comprehension"""
802 mylocals, saved_syms = self.comprehension_data(node)
804 names = list(mylocals.keys())
805 data = list(mylocals.values())
806 def listcomp_recurse(out, i, names, data):
807 if i == len(names):
808 out.append(self.run(node.elt))
809 return
811 for val in data[i]:
812 self.symtable[names[i]] = val
813 listcomp_recurse(out, i+1, names, data)
815 out = []
816 listcomp_recurse(out, 0, names, data)
817 for name, val in saved_syms.items():
818 self.symtable[name] = val
819 return out
821 def on_setcomp(self, node):
822 """Set comprehension"""
823 return set(self.on_listcomp(node))
825 def on_dictcomp(self, node):
826 """Dictionary comprehension"""
827 mylocals, saved_syms = self.comprehension_data(node)
829 names = list(mylocals.keys())
830 data = list(mylocals.values())
832 def dictcomp_recurse(out, i, names, data):
833 if i == len(names):
834 out[self.run(node.key)] = self.run(node.value)
835 return
837 for val in data[i]:
838 self.symtable[names[i]] = val
839 dictcomp_recurse(out, i+1, names, data)
841 out = {}
842 dictcomp_recurse(out, 0, names, data)
843 for name, val in saved_syms.items():
844 self.symtable[name] = val
845 return out
848 def on_excepthandler(self, node): # ('type', 'name', 'body')
849 """Exception handler..."""
850 return (self.run(node.type), node.name, node.body)
852 def on_try(self, node): # ('body', 'handlers', 'orelse', 'finalbody')
853 """Try/except/else/finally blocks."""
854 no_errors = True
855 for tnode in node.body:
856 self.run(tnode, with_raise=False)
857 no_errors = no_errors and len(self.error) == 0
858 if len(self.error) > 0:
859 e_type, e_value, _ = self.error[-1].exc_info
860 for hnd in node.handlers:
861 htype = None
862 if hnd.type is not None:
863 htype = __builtins__.get(hnd.type.id, None)
864 if htype is None or isinstance(e_type(), htype):
865 self.error = []
866 if hnd.name is not None:
867 self.node_assign(hnd.name, e_value)
868 for tline in hnd.body:
869 self.run(tline)
870 break
871 break
872 if no_errors and hasattr(node, 'orelse'):
873 for tnode in node.orelse:
874 self.run(tnode)
876 if hasattr(node, 'finalbody'):
877 for tnode in node.finalbody:
878 self.run(tnode)
880 def on_raise(self, node): # ('type', 'inst', 'tback')
881 """Raise statement: note difference for python 2 and 3."""
882 excnode = node.exc
883 msgnode = node.cause
884 out = self.run(excnode)
885 msg = ' '.join(out.args)
886 msg2 = self.run(msgnode)
887 if msg2 not in (None, 'None'):
888 msg = "%s: %s" % (msg, msg2)
889 self.raise_exception(None, exc=out.__class__, msg=msg, expr='')
891 def on_call(self, node):
892 """Function execution."""
893 # ('func', 'args', 'keywords'. Py<3.5 has 'starargs' and 'kwargs' too)
894 func = self.run(node.func)
895 if not hasattr(func, '__call__') and not isinstance(func, type):
896 msg = f"'{func}' is not callable!!"
897 self.raise_exception(node, exc=TypeError, msg=msg)
899 args = [self.run(targ) for targ in node.args]
900 starargs = getattr(node, 'starargs', None)
901 if starargs is not None:
902 args = args + self.run(starargs)
904 keywords = {}
905 if func == print:
906 keywords['file'] = self.writer
907 for key in node.keywords:
908 if not isinstance(key, ast.keyword):
909 msg = f"keyword error in function call '{func}'"
910 self.raise_exception(node, msg=msg)
911 if key.arg is None:
912 keywords.update(self.run(key.value))
913 elif key.arg in keywords:
914 self.raise_exception(node, exc=SyntaxError,
915 msg=f"keyword argument repeated: {key.arg}")
916 else:
917 keywords[key.arg] = self.run(key.value)
919 kwargs = getattr(node, 'kwargs', None)
920 if kwargs is not None:
921 keywords.update(self.run(kwargs))
923 if isinstance(func, Procedure):
924 self._calldepth += 1
925 try:
926 out = func(*args, **keywords)
927 except Exception as ex:
928 out = None
929 func_name = getattr(func, '__name__', str(func))
930 msg = f"Error running function '{func_name}' with args '{args}'"
931 msg = f"{msg} and kwargs {keywords}: {ex}"
932 self.raise_exception(node, msg=msg)
933 finally:
934 if isinstance(func, Procedure):
935 self._calldepth -= 1
936 return out
938 def on_arg(self, node): # ('test', 'msg')
939 """Arg for function definitions."""
940 return node.arg
942 def on_functiondef(self, node):
943 """Define procedures."""
944 # ('name', 'args', 'body', 'decorator_list')
945 if node.decorator_list:
946 raise Warning("decorated procedures not supported!")
947 kwargs = []
949 if (not valid_symbol_name(node.name) or
950 node.name in self.readonly_symbols):
951 errmsg = f"invalid function name (reserved word?) {node.name}"
952 self.raise_exception(node, exc=NameError, msg=errmsg)
954 offset = len(node.args.args) - len(node.args.defaults)
955 for idef, defnode in enumerate(node.args.defaults):
956 defval = self.run(defnode)
957 keyval = self.run(node.args.args[idef+offset])
958 kwargs.append((keyval, defval))
960 args = [tnode.arg for tnode in node.args.args[:offset]]
961 doc = None
962 nb0 = node.body[0]
963 if isinstance(nb0, ast.Expr) and isinstance(nb0.value, ast.Str):
964 doc = nb0.value.s
966 varkws = node.args.kwarg
967 vararg = node.args.vararg
968 if isinstance(vararg, ast.arg):
969 vararg = vararg.arg
970 if isinstance(varkws, ast.arg):
971 varkws = varkws.arg
973 self.symtable[node.name] = Procedure(node.name, self, doc=doc,
974 lineno=self.lineno,
975 body=node.body,
976 args=args, kwargs=kwargs,
977 vararg=vararg, varkws=varkws)
978 if node.name in self.no_deepcopy:
979 self.no_deepcopy.remove(node.name)