Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/pure_eval/core.py: 13%
204 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-20 06:09 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-20 06:09 +0000
1import ast
2import builtins
3import operator
4from collections import ChainMap, OrderedDict, deque
5from contextlib import suppress
6from types import FrameType
7from typing import Any, Tuple, Iterable, List, Mapping, Dict, Union, Set
9from pure_eval.my_getattr_static import getattr_static
10from pure_eval.utils import (
11 CannotEval,
12 has_ast_name,
13 copy_ast_without_context,
14 is_standard_types,
15 of_standard_types,
16 is_any,
17 of_type,
18 ensure_dict,
19)
class Evaluator:
    """
    Evaluate AST expression nodes without running arbitrary user code.

    Only a conservative set of operations (mostly on values of standard
    types) is performed; anything else raises `CannotEval`. Every outcome of
    `evaluator[node]` -- a value or a failure -- is cached per node object.
    """

    # Lookup tables mapping AST operator node types to their implementations.
    # Built once when the class is created rather than on every call (the
    # comparison table used to be rebuilt on every loop iteration).
    _COMPARE_OPS = {
        ast.Eq: operator.eq,
        ast.NotEq: operator.ne,
        ast.Lt: operator.lt,
        ast.LtE: operator.le,
        ast.Gt: operator.gt,
        ast.GtE: operator.ge,
        ast.Is: operator.is_,
        ast.IsNot: operator.is_not,
        ast.In: (lambda a, b: a in b),
        ast.NotIn: (lambda a, b: a not in b),
    }

    _BINARY_OPS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
        ast.LShift: operator.lshift,
        ast.RShift: operator.rshift,
        ast.BitOr: operator.or_,
        ast.BitXor: operator.xor,
        ast.BitAnd: operator.and_,
    }

    _UNARY_OPS = {
        ast.USub: operator.neg,
        ast.UAdd: operator.pos,
        ast.Not: operator.not_,
        ast.Invert: operator.invert,
    }

    # ast.Index / ast.ExtSlice are only produced by parsers before Python 3.9
    # and are deprecated since then. Fall back to an empty tuple (which never
    # matches in isinstance) in case a future Python removes them.
    _AST_INDEX = getattr(ast, "Index", ())
    _AST_EXT_SLICE = getattr(ast, "ExtSlice", ())

    def __init__(self, names: Mapping[str, Any]):
        """
        Construct a new evaluator with the given variable names.
        This is a low level API, typically you will use `Evaluator.from_frame(frame)`.

        :param names: a mapping from variable names to their values.
        """
        self.names = names
        # Maps each evaluated node to its value, or to the CannotEval class
        # itself as a marker that evaluating that node failed.
        self._cache = {}  # type: Dict[ast.expr, Any]

    @classmethod
    def from_frame(cls, frame: FrameType) -> 'Evaluator':
        """
        Construct an Evaluator that can look up variables from the given frame.

        :param frame: a frame object, e.g. from a traceback or `inspect.currentframe().f_back`.
        """
        # Locals shadow globals, which shadow builtins -- the same lookup
        # order Python itself uses for names in that frame.
        return cls(ChainMap(
            ensure_dict(frame.f_locals),
            ensure_dict(frame.f_globals),
            ensure_dict(frame.f_builtins),
        ))

    def __getitem__(self, node: ast.expr) -> Any:
        """
        Find the value of the given node.
        If it cannot be evaluated safely, this raises `CannotEval`.
        The result is cached either way.

        :param node: an AST expression to evaluate
        :return: the value of the node
        :raises TypeError: if `node` is not an `ast.expr`
        :raises CannotEval: if the node cannot be evaluated safely
        """
        if not isinstance(node, ast.expr):
            raise TypeError("node should be an ast.expr, not {!r}".format(type(node).__name__))

        try:
            cached = self._cache[node]
        except KeyError:
            pass
        else:
            if cached is CannotEval:
                raise CannotEval
            return cached

        try:
            self._cache[node] = result = self._handle(node)
            return result
        except CannotEval:
            # Remember the failure so the same node isn't re-attempted.
            self._cache[node] = CannotEval
            raise

    @staticmethod
    def _call(func, *args):
        """Return ``func(*args)``, converting any exception into a chained `CannotEval`."""
        try:
            return func(*args)
        except Exception as e:
            raise CannotEval from e

    def _handle(self, node: ast.expr) -> Any:
        """
        This is where the evaluation happens.
        Users should use `__getitem__`, i.e. `evaluator[node]`,
        as it provides caching.

        :param node: an AST expression to evaluate
        :return: the value of the node
        :raises CannotEval: if the node type is unsupported or unsafe to evaluate
        """
        # Plain literals are always safe; any failure here only means the node
        # isn't a pure literal, so fall through to the specific handlers.
        with suppress(Exception):
            return ast.literal_eval(node)

        if isinstance(node, ast.Name):
            try:
                return self.names[node.id]
            except KeyError as e:
                raise CannotEval from e
        elif isinstance(node, ast.Attribute):
            value = self[node.value]
            return getattr_static(value, node.attr)
        elif isinstance(node, ast.Subscript):
            return self._handle_subscript(node)
        elif isinstance(node, (ast.List, ast.Tuple, ast.Set, ast.Dict)):
            return self._handle_container(node)
        elif isinstance(node, ast.UnaryOp):
            return self._handle_unary(node)
        elif isinstance(node, ast.BinOp):
            return self._handle_binop(node)
        elif isinstance(node, ast.BoolOp):
            return self._handle_boolop(node)
        elif isinstance(node, ast.Compare):
            return self._handle_compare(node)
        elif isinstance(node, ast.Call):
            return self._handle_call(node)
        raise CannotEval

    def _handle_call(self, node: ast.Call) -> Any:
        """Evaluate a call to one of a fixed allow-list of builtins with positional args."""
        if node.keywords:
            raise CannotEval
        func = self[node.func]
        args = [self[arg] for arg in node.args]

        if (
            is_any(
                func,
                slice,
                int,
                range,
                round,
                complex,
                list,
                tuple,
                abs,
                hex,
                bin,
                oct,
                bool,
                ord,
                float,
                len,
                chr,
            )
            or (
                len(args) == 0
                and is_any(func, set, dict, str, frozenset, bytes, bytearray, object)
            )
            or (
                len(args) >= 2
                and is_any(func, str, divmod, bytes, bytearray, pow)
            )
        ):
            args = [
                of_standard_types(arg, check_dict_values=False, deep=False)
                for arg in args
            ]
            return self._call(func, *args)

        if len(args) == 1:
            arg = args[0]
            if is_any(func, id, type):
                return self._call(func, arg)
            if is_any(func, all, any, sum):
                # These iterate over their argument: require a builtin
                # container whose items are standard types.
                of_type(arg, tuple, frozenset, list, set, dict, OrderedDict, deque)
                for item in arg:
                    of_standard_types(item, check_dict_values=False, deep=False)
                return self._call(func, arg)
            if is_any(
                func, sorted, min, max, hash, set, dict, ascii, str, repr, frozenset
            ):
                # These may compare, hash or repr nested contents, so check deeply.
                of_standard_types(arg, check_dict_values=True, deep=True)
                return self._call(func, arg)

        raise CannotEval

    def _handle_compare(self, node: ast.Compare) -> Any:
        """Evaluate a (possibly chained) comparison, short-circuiting like Python."""
        left = self[node.left]
        result = True

        for op, right_node in zip(node.ops, node.comparators):
            right = self[right_node]

            op_type = type(op)
            op_func = self._COMPARE_OPS.get(op_type)
            if op_func is None:
                raise CannotEval

            # `is` / `is not` never invoke user code; every other comparison
            # may call __eq__/__lt__/__contains__, so require standard types.
            if op_type not in (ast.Is, ast.IsNot):
                of_standard_types(left, check_dict_values=False, deep=True)
                of_standard_types(right, check_dict_values=False, deep=True)

            result = self._call(op_func, left, right)
            if not result:
                return result
            left = right

        return result

    def _handle_boolop(self, node: ast.BoolOp) -> Any:
        """Evaluate `and` / `or`, returning the deciding operand like Python does."""
        left = of_standard_types(
            self[node.values[0]], check_dict_values=False, deep=False
        )

        for right in node.values[1:]:
            # We need short circuiting so that the whole operation can be evaluated
            # even if the right operand can't
            if isinstance(node.op, ast.Or):
                left = left or of_standard_types(
                    self[right], check_dict_values=False, deep=False
                )
            else:
                assert isinstance(node.op, ast.And)
                left = left and of_standard_types(
                    self[right], check_dict_values=False, deep=False
                )
        return left

    def _handle_binop(self, node: ast.BinOp) -> Any:
        """Evaluate a binary arithmetic/bitwise operation on standard-type operands."""
        op_type = type(node.op)
        op = self._BINARY_OPS.get(op_type)
        if not op:
            raise CannotEval
        left = self[node.left]
        # Set/dict operations hash and compare the elements, so check them deeply.
        hash_type = is_any(type(left), set, frozenset, dict, OrderedDict)
        left = of_standard_types(left, check_dict_values=False, deep=hash_type)
        # `str % values` / `bytes % values` formats the right operand's
        # contents, so that operand (including dict values) is checked deeply.
        formatting = type(left) in (str, bytes) and op_type is ast.Mod

        right = of_standard_types(
            self[node.right],
            check_dict_values=formatting,
            deep=formatting or hash_type,
        )
        return self._call(op, left, right)

    def _handle_unary(self, node: ast.UnaryOp) -> Any:
        """Evaluate `-x`, `+x`, `not x` or `~x` on a standard-type operand."""
        value = of_standard_types(
            self[node.operand], check_dict_values=False, deep=False
        )
        op = self._UNARY_OPS.get(type(node.op))
        if op is None:
            raise CannotEval
        return self._call(op, value)

    def _handle_subscript(self, node: ast.Subscript) -> Any:
        """Evaluate `value[index]` / `value[lower:upper:step]` on standard types."""
        value = self[node.value]
        of_standard_types(
            value, check_dict_values=False, deep=is_any(type(value), dict, OrderedDict)
        )
        index = node.slice
        if isinstance(index, ast.Slice):
            index = slice(
                *[
                    None if part is None else self[part]
                    for part in [index.lower, index.upper, index.step]
                ]
            )
        elif isinstance(index, self._AST_EXT_SLICE):
            raise CannotEval
        else:
            if isinstance(index, self._AST_INDEX):
                index = index.value
            index = self[index]
        # Validate the key (for slices: the slice object) before using it.
        of_standard_types(index, check_dict_values=False, deep=True)

        return self._call(operator.getitem, value, index)

    def _handle_container(
        self,
        node: Union[ast.List, ast.Tuple, ast.Set, ast.Dict]
    ) -> Union[List, Tuple, Set, Dict]:
        """Handle container nodes, including List, Set, Tuple and Dict"""
        if isinstance(node, ast.Dict):
            elts = node.keys
            if any(key is None for key in elts):  # ** unpacking inside {}, not yet supported
                raise CannotEval
        else:
            elts = node.elts
        elts = [self[elt] for elt in elts]
        if isinstance(node, ast.List):
            return elts
        if isinstance(node, ast.Tuple):
            return tuple(elts)

        # Building a set or dict hashes and compares the elements/keys, which
        # could run arbitrary __hash__/__eq__ on custom objects.
        if not all(
            is_standard_types(elt, check_dict_values=False, deep=True) for elt in elts
        ):
            raise CannotEval

        if isinstance(node, ast.Set):
            try:
                return set(elts)
            except TypeError as e:
                raise CannotEval from e

        assert isinstance(node, ast.Dict)

        pairs = [(elt, self[val]) for elt, val in zip(elts, node.values)]
        try:
            return dict(pairs)
        except TypeError as e:
            raise CannotEval from e

    def find_expressions(self, root: ast.AST) -> Iterable[Tuple[ast.expr, Any]]:
        """
        Find all expressions in the given tree that can be safely evaluated.
        This is a low level API, typically you will use `interesting_expressions_grouped`.

        :param root: any AST node
        :return: generator of pairs (tuples) of expression nodes and their corresponding values.
        """
        for node in ast.walk(root):
            if not isinstance(node, ast.expr):
                continue

            try:
                value = self[node]
            except CannotEval:
                continue

            yield node, value

    def interesting_expressions_grouped(self, root: ast.AST) -> List[Tuple[List[ast.expr], Any]]:
        """
        Find all interesting expressions in the given tree that can be safely evaluated,
        grouping equivalent nodes together.

        For more control and details, see:
         - Evaluator.find_expressions
         - is_expression_interesting
         - group_expressions

        :param root: any AST node
        :return: A list of pairs (tuples) containing:
                    - A list of equivalent AST expressions
                    - The value of the first expression node
                       (which should be the same for all nodes, unless threads are involved)
        """
        return group_expressions(
            pair
            for pair in self.find_expressions(root)
            if is_expression_interesting(*pair)
        )
def is_expression_interesting(node: ast.expr, value: Any) -> bool:
    """
    Determines if an expression is potentially interesting, at least in my opinion.
    Returns False for the following expressions whose value is generally obvious:
        - Literals (e.g. 123, 'abc', [1, 2, 3], {'a': (), 'b': ([1, 2], [3])})
        - Variables or attributes whose name is equal to the value's __name__.
            For example, a function `def foo(): ...` is not interesting when referred to
            as `foo` as it usually would, but `bar` can be interesting if `bar is foo`.
            Similarly the method `self.foo` is not interesting.
        - Builtins (e.g. `len`) referred to by their usual name.

    This is a low level API, typically you will use `interesting_expressions_grouped`.

    :param node: an AST expression
    :param value: the value of the node
    :return: a boolean: True if the expression is interesting, False otherwise
    """
    # Literals are never interesting. ast.literal_eval raises ValueError for
    # non-literal nodes, and TypeError for literal-shaped nodes that cannot be
    # built, e.g. `{[1]: 2}` or `{[1]}` (unhashable key/element). Neither is a
    # usable literal, so both mean "keep checking" rather than crashing.
    with suppress(ValueError, TypeError):
        ast.literal_eval(node)
        return False

    # TODO exclude inner modules, e.g. numpy.random.__name__ == 'numpy.random' != 'random'
    # TODO exclude common module abbreviations, e.g. numpy as np, pandas as pd
    if has_ast_name(value, node):
        return False

    # A bare name that refers to the builtin of the same name (e.g. `len`).
    if (
        isinstance(node, ast.Name)
        and getattr(builtins, node.id, object()) is value
    ):
        return False

    return True
def group_expressions(expressions: Iterable[Tuple[ast.expr, Any]]) -> List[Tuple[List[ast.expr], Any]]:
    """
    Collect expression nodes into groups of structurally equivalent nodes.

    Nodes are equivalent when their ASTs match once context (Load, Store,
    Delete) and source location (lineno, col_offset) are disregarded. For
    example, a variable name mentioned several times in one expression ends
    up in a single group.

    Values are not compared: each group keeps the value seen with its first
    node. Equivalent nodes should share a value unless threads are involved.

    This is a low level API, typically you will use `interesting_expressions_grouped`.

    :param expressions: pairs of AST expressions and their values, as obtained from
                        `Evaluator.find_expressions`, or `(node, evaluator[node])`.
    :return: A list of pairs (tuples), in order of first appearance, each containing:
                - A list of equivalent AST expressions
                - The value of the first expression node
                   (which should be the same for all nodes, unless threads are involved)
    """
    groups = {}  # type: Dict[str, Tuple[List[ast.expr], Any]]
    for expr_node, expr_value in expressions:
        # The dump of a context-free copy is the equivalence key.
        key = ast.dump(copy_ast_without_context(expr_node))
        if key not in groups:
            groups[key] = ([], expr_value)
        groups[key][0].append(expr_node)
    return list(groups.values())