1from copy import copy
2from inspect import isclass, signature, Signature, getmodule
3from typing import (
4 Annotated,
5 AnyStr,
6 Callable,
7 Literal,
8 NamedTuple,
9 NewType,
10 Optional,
11 Protocol,
12 Sequence,
13 TypeGuard,
14 Union,
15 get_args,
16 get_origin,
17 is_typeddict,
18)
19import ast
20import builtins
21import collections
22import dataclasses
23import operator
24import sys
25import typing
26import warnings
27from functools import cached_property
28from dataclasses import dataclass, field
29from types import MethodDescriptorType, ModuleType, MethodType
30
31from IPython.utils.decorators import undoc
32
33
34from typing import Self, LiteralString
35
36if sys.version_info < (3, 12):
37 from typing_extensions import TypeAliasType
38else:
39 from typing import TypeAliasType
40
41
@undoc
class HasGetItem(Protocol):
    """Structural type for objects that define ``__getitem__``."""

    def __getitem__(self, key) -> None: ...
45
46
@undoc
class InstancesHaveGetItem(Protocol):
    """Structural type for callables whose call result defines ``__getitem__``."""

    def __call__(self, *args, **kwargs) -> HasGetItem: ...
50
51
@undoc
class HasGetAttr(Protocol):
    """Structural type for objects that explicitly define ``__getattr__``."""

    def __getattr__(self, key) -> None: ...
55
56
@undoc
class DoesNotHaveGetAttr(Protocol):
    """Structural type with no required members; counterpart of ``HasGetAttr``."""

    pass
60
61
# By default `__getattr__` is not explicitly implemented on most objects,
# so the alias accepts both objects that define it and objects that do not.
MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
64
65
66def _unbind_method(func: Callable) -> Union[Callable, None]:
67 """Get unbound method for given bound method.
68
69 Returns None if cannot get unbound method, or method is already unbound.
70 """
71 owner = getattr(func, "__self__", None)
72 owner_class = type(owner)
73 name = getattr(func, "__name__", None)
74 instance_dict_overrides = getattr(owner, "__dict__", None)
75 if (
76 owner is not None
77 and name
78 and (
79 not instance_dict_overrides
80 or (instance_dict_overrides and name not in instance_dict_overrides)
81 )
82 ):
83 return getattr(owner_class, name)
84 return None
85
86
@undoc
@dataclass
class EvaluationPolicy:
    """Definition of evaluation policy.

    Each ``allow_*`` flag switches one category of operation on or off.
    The ``can_*`` methods are the hooks consulted before an operation is
    performed; every one of them returns a boolean.
    """

    allow_locals_access: bool = False
    allow_globals_access: bool = False
    allow_item_access: bool = False
    allow_attr_access: bool = False
    allow_builtins_access: bool = False
    allow_all_operations: bool = False
    allow_any_calls: bool = False
    allow_auto_import: bool = False
    # Callables that may be invoked even when `allow_any_calls` is off.
    allowed_calls: set[Callable] = field(default_factory=set)

    def can_get_item(self, value, item):
        """Return whether ``value[item]`` may be evaluated."""
        return self.allow_item_access

    def can_get_attr(self, value, attr):
        """Return whether ``value.attr`` may be evaluated."""
        return self.allow_attr_access

    def can_operate(self, dunders: tuple[str, ...], a, b=None):
        """Return whether the operator backed by ``dunders`` may be applied."""
        if self.allow_all_operations:
            return True
        # Explicit `False` instead of falling off the end with `None`.
        return False

    def can_call(self, func):
        """Return whether ``func`` may be called.

        A call is allowed when any calls are allowed, when ``func`` itself
        is in ``allowed_calls``, or when ``func`` is a bound method whose
        unbound counterpart is in ``allowed_calls``.
        """
        if self.allow_any_calls:
            return True

        if func in self.allowed_calls:
            return True

        owner_method = _unbind_method(func)

        if owner_method and owner_method in self.allowed_calls:
            return True

        # Explicit `False` instead of falling off the end with `None`.
        return False
123
124
125def _get_external(module_name: str, access_path: Sequence[str]):
126 """Get value from external module given a dotted access path.
127
128 Only gets value if the module is already imported.
129
130 Raises:
131 * `KeyError` if module is removed not found, and
132 * `AttributeError` if access path does not match an exported object
133 """
134 try:
135 member_type = sys.modules[module_name]
136 # standard module
137 for attr in access_path:
138 member_type = getattr(member_type, attr)
139 return member_type
140 except (KeyError, AttributeError):
141 # handle modules in namespace packages
142 module_path = ".".join([module_name, *access_path])
143 if module_path in sys.modules:
144 return sys.modules[module_path]
145 raise
146
147
def _has_original_dunder_external(
    value,
    module_name: str,
    access_path: Sequence[str],
    method_name: str,
):
    """Check whether ``value`` uses the ``method_name`` of an external member.

    The external member is resolved with ``_get_external`` and may be a
    class or a module. Returns ``True`` on a match and ``False`` when the
    member cannot be resolved or one of the explicit rejections applies;
    when no check matches, the function falls off the end and returns
    ``None`` (which callers treat as falsy).
    """
    if module_name not in sys.modules:
        full_module_path = ".".join([module_name, *access_path])
        if full_module_path not in sys.modules:
            # LBYL as it is faster
            return False
    try:
        member_type = _get_external(module_name, access_path)
        value_type = type(value)
        # exact type match
        if type(value) == member_type:
            return True
        # the member is a module: accept types defined in it or its submodules
        if isinstance(member_type, ModuleType):
            value_module = getmodule(value_type)
            if not value_module or not value_module.__name__:
                return False
            if (
                value_module.__name__ == member_type.__name__
                or value_module.__name__.startswith(member_type.__name__ + ".")
            ):
                return True
        if method_name == "__getattribute__":
            # we have to short-circuit here due to an unresolved issue in
            # `isinstance` implementation: https://bugs.python.org/issue32683
            return False
        # subclass of the member class which did not replace the method
        if not isinstance(member_type, ModuleType) and isinstance(value, member_type):
            method = getattr(value_type, method_name, None)
            member_method = getattr(member_type, method_name, None)
            if member_method == method:
                return True
        # the member is a module: accept if the method is inherited unchanged
        # from a base class defined in that module (or its submodules)
        if isinstance(member_type, ModuleType):
            method = getattr(value_type, method_name, None)
            for base_class in value_type.__mro__[1:]:
                base_module = getmodule(base_class)
                if base_module and (
                    base_module.__name__ == member_type.__name__
                    or base_module.__name__.startswith(member_type.__name__ + ".")
                ):
                    # Check if the method comes from this trusted base class
                    base_method = getattr(base_class, method_name, None)
                    if base_method is not None and base_method == method:
                        return True
    except (AttributeError, KeyError):
        return False
196
197
198def _has_original_dunder(
199 value, allowed_types, allowed_methods, allowed_external, method_name
200):
201 # note: Python ignores `__getattr__`/`__getitem__` on instances,
202 # we only need to check at class level
203 value_type = type(value)
204
205 # strict type check passes → no need to check method
206 if value_type in allowed_types:
207 return True
208
209 method = getattr(value_type, method_name, None)
210
211 if method is None:
212 return None
213
214 if method in allowed_methods:
215 return True
216
217 for module_name, *access_path in allowed_external:
218 if _has_original_dunder_external(value, module_name, access_path, method_name):
219 return True
220
221 return False
222
223
224def _coerce_path_to_tuples(
225 allow_list: set[tuple[str, ...] | str],
226) -> set[tuple[str, ...]]:
227 """Replace dotted paths on the provided allow-list with tuples."""
228 return {
229 path if isinstance(path, tuple) else tuple(path.split("."))
230 for path in allow_list
231 }
232
233
@undoc
@dataclass
class SelectivePolicy(EvaluationPolicy):
    """Evaluation policy driven by allow-lists.

    Item access, attribute access and operators are permitted only for
    values whose type is allow-listed, or whose relevant dunder method is
    the one of an allow-listed type. The ``*_external`` sets name types or
    modules by dotted path (string or tuple of parts); these are resolved
    only if the module is already imported.
    """

    allowed_getitem: set[InstancesHaveGetItem] = field(default_factory=set)
    allowed_getitem_external: set[tuple[str, ...] | str] = field(default_factory=set)

    allowed_getattr: set[MayHaveGetattr] = field(default_factory=set)
    allowed_getattr_external: set[tuple[str, ...] | str] = field(default_factory=set)

    allowed_operations: set = field(default_factory=set)
    allowed_operations_external: set[tuple[str, ...] | str] = field(default_factory=set)

    allow_getitem_on_types: bool = field(default_factory=bool)

    # Per-dunder cache of the methods found on `allowed_operations`.
    _operation_methods_cache: dict[str, set[Callable]] = field(
        default_factory=dict, init=False
    )

    def can_get_attr(self, value, attr):
        """Return whether ``value.attr`` may be evaluated.

        The type of ``value`` must not have modified ``__getattribute__``
        nor ``__getattr__``. If ``attr`` resolves to a ``property`` on the
        class, it is accepted only when the type is allow-listed or when
        the property is the very one exposed by an allow-listed external
        class.
        """
        allowed_getattr_external = _coerce_path_to_tuples(self.allowed_getattr_external)

        has_original_attribute = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattribute_methods,
            allowed_external=allowed_getattr_external,
            method_name="__getattribute__",
        )
        has_original_attr = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattr_methods,
            allowed_external=allowed_getattr_external,
            method_name="__getattr__",
        )

        # Many objects do not have `__getattr__`, this is fine.
        if has_original_attr is None and has_original_attribute:
            accept = True
        else:
            # Accept objects without modifications to `__getattr__` and `__getattribute__`
            accept = has_original_attr and has_original_attribute

        if not accept:
            return False

        # We still need to check for overridden properties.
        value_class = type(value)
        if not hasattr(value_class, attr):
            return True

        class_attr_val = getattr(value_class, attr)
        if not isinstance(class_attr_val, property):
            return True

        # Properties in allowed types are ok (although we do not include any
        # properties in our default allow list currently).
        if value_class in self.allowed_getattr:
            return True  # pragma: no cover

        # Properties in subclasses of allowed types may be ok if not changed.
        # Every external entry has to be consulted: returning after the first
        # one would make the answer depend on the iteration order of a set.
        for module_name, *access_path in allowed_getattr_external:
            try:
                external_class = _get_external(module_name, access_path)
                external_class_attr_val = getattr(external_class, attr)
            except (KeyError, AttributeError):
                continue
            if class_attr_val == external_class_attr_val:
                return True

        return False

    def can_get_item(self, value, item):
        """Allow accessing `__getitem__` of allow-listed instances unless it was modified."""
        allowed_getitem_external = _coerce_path_to_tuples(self.allowed_getitem_external)
        if self.allow_getitem_on_types:
            # e.g. Union[str, int] or Literal[True, 1]
            if isinstance(value, (typing._SpecialForm, typing._BaseGenericAlias)):
                return True
            # PEP 560 e.g. list[str]
            if isinstance(value, type) and hasattr(value, "__class_getitem__"):
                return True
        return _has_original_dunder(
            value,
            allowed_types=self.allowed_getitem,
            allowed_methods=self._getitem_methods,
            allowed_external=allowed_getitem_external,
            method_name="__getitem__",
        )

    def can_operate(self, dunders: tuple[str, ...], a, b=None):
        """Return whether every dunder in ``dunders`` is original on ``a`` (and ``b``)."""
        allowed_operations_external = _coerce_path_to_tuples(
            self.allowed_operations_external
        )
        objects = [a]
        if b is not None:
            objects.append(b)
        return all(
            _has_original_dunder(
                obj,
                allowed_types=self.allowed_operations,
                allowed_methods=self._operator_dunder_methods(dunder),
                allowed_external=allowed_operations_external,
                method_name=dunder,
            )
            for dunder in dunders
            for obj in objects
        )

    def _operator_dunder_methods(self, dunder: str) -> set[Callable]:
        """Return (and cache) the ``dunder`` methods of ``allowed_operations``."""
        if dunder not in self._operation_methods_cache:
            self._operation_methods_cache[dunder] = self._safe_get_methods(
                self.allowed_operations, dunder
            )
        return self._operation_methods_cache[dunder]

    @cached_property
    def _getitem_methods(self) -> set[Callable]:
        return self._safe_get_methods(self.allowed_getitem, "__getitem__")

    @cached_property
    def _getattr_methods(self) -> set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattr__")

    @cached_property
    def _getattribute_methods(self) -> set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattribute__")

    def _safe_get_methods(self, classes, name) -> set[Callable]:
        """Collect the truthy ``name`` attributes of ``classes``; missing ones are skipped."""
        return {
            method
            for class_ in classes
            for method in [getattr(class_, name, None)]
            if method
        }
373
374
class _DummyNamedTuple(NamedTuple):
    """Used internally to retrieve methods of named tuple instance."""

    # Declares no fields of its own.
377
378
# Identifiers accepted as `EvaluationContext.evaluation`.
EvaluationPolicyName = Literal["forbidden", "minimal", "limited", "unsafe", "dangerous"]
380
381
@dataclass
class EvaluationContext:
    """Namespaces and settings that a guarded evaluation runs against."""

    #: Local namespace
    locals: dict
    #: Global namespace
    globals: dict
    #: Evaluation policy identifier
    evaluation: EvaluationPolicyName = "forbidden"
    #: Whether the evaluation of code takes place inside of a subscript.
    #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
    in_subscript: bool = False
    #: Auto import method
    # NOTE(review): `Callable` normally takes a list of argument types, i.e.
    # `Callable[[list[str]], ModuleType]` - confirm the intended signature.
    auto_import: Callable[list[str], ModuleType] | None = None
    #: Overrides for evaluation policy
    policy_overrides: dict = field(default_factory=dict)
    #: Transient local namespace used to store mocks
    transient_locals: dict = field(default_factory=dict)
    #: Transients of class level
    class_transients: dict | None = None
    #: Instance variable name used in the method definition
    instance_arg_name: str | None = None

    def replace(self, /, **changes):
        """Return a new copy of the context, with specified changes"""
        # Shallow: fields that are not listed in `changes` are shared with
        # the original context.
        return dataclasses.replace(self, **changes)
407
408
class _IdentitySubscript:
    """Returns the key itself when item is requested via subscript."""

    def __getitem__(self, key):
        # `obj[a:b, c]` therefore evaluates to the tuple `(slice(a, b), c)`.
        return key
414
415
# Shared instance which `guarded_eval` binds to `SUBSCRIPT_MARKER`.
IDENTITY_SUBSCRIPT = _IdentitySubscript()
# Name under which the identity subscript is injected into locals when the
# evaluated code is the inside of a subscript.
SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
# Empty signature; presumably a placeholder for callables whose signature
# cannot be determined - usage is outside of this section, verify there.
UNKNOWN_SIGNATURE = Signature()
# Sentinel meaning "no value could be evaluated" (distinct from `None`).
NOT_EVALUATED = object()
420
421
class GuardRejection(Exception):
    """Raised whenever the guard refuses to carry out an evaluation."""
426
427
def guarded_eval(code: str, context: EvaluationContext):
    """Evaluate ``code`` under the restrictions described by ``context``.

    * ``forbidden`` policy: nothing is evaluated, :class:`GuardRejection`
      is raised;
    * ``dangerous`` policy: the code is handed to the built-in :func:`eval`;
    * any other policy: the code is parsed and the resulting AST is
      evaluated by :func:`eval_node`.

    When ``context.in_subscript`` is set, ``code`` is treated as the inside
    of a subscript (empty code yields an empty tuple).
    """
    if context.evaluation == "forbidden":
        raise GuardRejection("Forbidden mode")

    # note: not using `ast.literal_eval` as it does not implement
    # getitem at all, for example it fails on simple `[0][1]`

    if context.in_subscript:
        # slice syntax (`:`) is only valid inside of a subscript, so the
        # code gets wrapped into a subscript of a marker object which
        # hands the key back untouched instead of performing a real
        # `__getitem__` lookup
        if not code:
            return ()
        subscript_locals = context.locals.copy()
        subscript_locals[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
        code = "".join([SUBSCRIPT_MARKER, "[", code, "]"])
        context = context.replace(locals=subscript_locals)

    if context.evaluation == "dangerous":
        return eval(code, context.globals, context.locals)

    return eval_node(ast.parse(code, mode="exec"), context)
462
463
# Binary operator AST node type -> dunder methods implementing it.
# The first entry of each tuple is the method which gets invoked.
BINARY_OP_DUNDERS: dict[type[ast.operator], tuple[str]] = {
    ast.Add: ("__add__",),
    ast.Sub: ("__sub__",),
    ast.Mult: ("__mul__",),
    ast.Div: ("__truediv__",),
    ast.FloorDiv: ("__floordiv__",),
    ast.Mod: ("__mod__",),
    ast.Pow: ("__pow__",),
    ast.LShift: ("__lshift__",),
    ast.RShift: ("__rshift__",),
    ast.BitOr: ("__or__",),
    ast.BitXor: ("__xor__",),
    ast.BitAnd: ("__and__",),
    ast.MatMult: ("__matmul__",),
}

# Comparison AST node type -> dunder methods checked by the policy.
# The first entry names the function looked up on the `operator` module.
COMP_OP_DUNDERS: dict[type[ast.cmpop], tuple[str, ...]] = {
    ast.Eq: ("__eq__",),
    ast.NotEq: ("__ne__", "__eq__"),
    ast.Lt: ("__lt__", "__gt__"),
    ast.LtE: ("__le__", "__ge__"),
    ast.Gt: ("__gt__", "__lt__"),
    ast.GtE: ("__ge__", "__le__"),
    ast.In: ("__contains__",),
    # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
}

# Unary operator AST node type -> dunder methods checked by the policy.
# The first entry is the method which gets invoked on the operand.
UNARY_OP_DUNDERS: dict[type[ast.unaryop], tuple[str, ...]] = {
    ast.USub: ("__neg__",),
    ast.UAdd: ("__pos__",),
    # we have to check both __inv__ and __invert__!
    ast.Invert: ("__invert__", "__inv__"),
    ast.Not: ("__not__",),
}
498
499
class ImpersonatingDuck:
    """A dummy class used to create objects of other classes without calling their ``__init__``"""

    # no-op: override __class__ to impersonate
    # (the class body is intentionally empty apart from the docstring)
504
505
506class _Duck:
507 """A dummy class used to create objects pretending to have given attributes"""
508
509 def __init__(self, attributes: Optional[dict] = None, items: Optional[dict] = None):
510 self.attributes = attributes if attributes is not None else {}
511 self.items = items if items is not None else {}
512
513 def __getattr__(self, attr: str):
514 return self.attributes[attr]
515
516 def __hasattr__(self, attr: str):
517 return attr in self.attributes
518
519 def __dir__(self):
520 return [*dir(super), *self.attributes]
521
522 def __getitem__(self, key: str):
523 return self.items[key]
524
525 def __hasitem__(self, key: str):
526 return self.items[key]
527
528 def _ipython_key_completions_(self):
529 return self.items.keys()
530
531
532def _find_dunder(node_op, dunders) -> Union[tuple[str, ...], None]:
533 dunder = None
534 for op, candidate_dunder in dunders.items():
535 if isinstance(node_op, op):
536 dunder = candidate_dunder
537 return dunder
538
539
def get_policy(context: EvaluationContext) -> EvaluationPolicy:
    """Build the effective policy for ``context``.

    A shallow copy of the policy registered under ``context.evaluation``
    is taken and the entries of ``context.policy_overrides`` are applied
    on top; overrides naming an attribute that the policy does not have
    are silently ignored.
    """
    base_policy = EVALUATION_POLICIES[context.evaluation]
    effective = copy(base_policy)

    for name, override in context.policy_overrides.items():
        if not hasattr(effective, name):
            continue
        setattr(effective, name, override)
    return effective
547
548
def _validate_policy_overrides(
    policy_name: EvaluationPolicyName, policy_overrides: dict
) -> bool:
    """Warn about overrides which the named policy does not support.

    Returns ``True`` if every key of ``policy_overrides`` is an attribute
    of the policy registered under ``policy_name``, ``False`` otherwise.
    One warning is issued per unsupported key.
    """
    policy = EVALUATION_POLICIES[policy_name]

    unsupported = [key for key in policy_overrides if not hasattr(policy, key)]
    for key in unsupported:
        warnings.warn(
            f"Override {key!r} is not valid with {policy_name!r} evaluation policy"
        )
    return not unsupported
562
563
def _handle_assign(node: ast.Assign, context: EvaluationContext):
    """Record the effect of an assignment statement in the transient namespaces.

    The right-hand side is evaluated once and then stored for every target:

    - tuple/list targets are unpacked (a single starred target is supported),
    - subscript targets (``name[key] = value`` or ``self.attr[key] = value``)
      replace the container with a ``_Duck`` which knows the new item,
    - attribute targets on the instance argument go to ``class_transients``,
    - any other target is stored in ``transient_locals`` under ``target.id``.

    Always returns ``None``.
    """
    value = eval_node(node.value, context)
    transient_locals = context.transient_locals
    policy = get_policy(context)
    class_transients = context.class_transients
    for target in node.targets:
        if isinstance(target, (ast.Tuple, ast.List)):
            # Handle unpacking assignment
            values = list(value)
            targets = target.elts
            starred = [i for i, t in enumerate(targets) if isinstance(t, ast.Starred)]

            # Unified handling: treat no starred as starred at end
            star_or_last_idx = starred[0] if starred else len(targets)

            # Before starred
            for i in range(star_or_last_idx):
                # Check for self.x assignment
                if _is_instance_attribute_assignment(targets[i], context):
                    class_transients[targets[i].attr] = values[i]
                else:
                    # NOTE(review): assumes a plain name target; nested
                    # tuples or subscripts have no `.id` - confirm.
                    transient_locals[targets[i].id] = values[i]

            # Starred if exists
            if starred:
                # the starred target takes everything which is not needed
                # by the targets following it
                end = len(values) - (len(targets) - star_or_last_idx - 1)
                if _is_instance_attribute_assignment(
                    targets[star_or_last_idx], context
                ):
                    class_transients[targets[star_or_last_idx].attr] = values[
                        star_or_last_idx:end
                    ]
                else:
                    transient_locals[targets[star_or_last_idx].value.id] = values[
                        star_or_last_idx:end
                    ]

            # After starred
            for i in range(star_or_last_idx + 1, len(targets)):
                # values for these targets are counted from the end
                if _is_instance_attribute_assignment(targets[i], context):
                    class_transients[targets[i].attr] = values[
                        len(values) - (len(targets) - i)
                    ]
                else:
                    transient_locals[targets[i].id] = values[
                        len(values) - (len(targets) - i)
                    ]
        elif isinstance(target, ast.Subscript):
            if isinstance(target.value, ast.Name):
                name = target.value.id
                # look the container up in transient locals, then locals,
                # then globals
                container = transient_locals.get(name)
                if container is None:
                    container = context.locals.get(name)
                if container is None:
                    container = context.globals.get(name)
                if container is None:
                    raise NameError(
                        f"{name} not found in locals, globals, nor builtins"
                    )
                storage_dict = transient_locals
                storage_key = name
            elif isinstance(
                target.value, ast.Attribute
            ) and _is_instance_attribute_assignment(target.value, context):
                attr = target.value.attr
                container = class_transients.get(attr, None)
                if container is None:
                    raise NameError(f"{attr} not found in class transients")
                storage_dict = class_transients
                storage_key = attr
            else:
                # other subscript targets are not tracked; note that this
                # also skips any remaining targets of the statement
                return

            key = eval_node(target.slice, context)
            # keep the attribute names of the container if we may list them
            attributes = (
                dict.fromkeys(dir(container))
                if policy.can_call(container.__dir__)
                else {}
            )
            items = {}

            if policy.can_get_item(container, None):
                try:
                    items = dict(container.items())
                except Exception:
                    # best effort: the container may have no usable `items()`
                    pass

            items[key] = value
            # the original container is not mutated; a stand-in is stored
            duck_container = _Duck(attributes=attributes, items=items)
            storage_dict[storage_key] = duck_container
        elif _is_instance_attribute_assignment(target, context):
            class_transients[target.attr] = value
        else:
            transient_locals[target.id] = value
    return None
659
660
def _extract_args_and_kwargs(node: ast.Call, context: EvaluationContext):
    """Evaluate the arguments of a call node.

    Returns a ``(args, kwargs)`` pair: a list with the evaluated positional
    arguments and a dict with the evaluated keyword arguments. Keywords
    without a name (``**mapping``) are expanded through ``.items()`` of
    the evaluated value.
    """
    positional = [eval_node(arg, context) for arg in node.args]

    keyword = {}
    for kw in node.keywords:
        evaluated = eval_node(kw.value, context)
        if kw.arg:
            keyword[kw.arg] = evaluated
            continue
        # `**mapping` unpacking
        for name, item in evaluated.items():
            keyword[name] = item
    return positional, keyword
673
674
def _is_instance_attribute_assignment(
    target: ast.AST, context: EvaluationContext
) -> bool:
    """Tell whether ``target`` is ``<instance arg>.<attr>`` (e.g. ``self.x``).

    This can only be the case while class transients are being tracked and
    the name of the instance argument is known.
    """
    if context.class_transients is None or context.instance_arg_name is None:
        return False
    if not isinstance(target, ast.Attribute):
        return False
    owner = getattr(target, "value", None)
    if not isinstance(owner, ast.Name):
        return False
    return getattr(owner, "id", None) == context.instance_arg_name
686
687
def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
    """Evaluate AST node in provided context.

    Applies evaluation restrictions defined in the context: item access,
    attribute access, operators and calls are performed only if the policy
    obtained from the context allows them, otherwise :class:`GuardRejection`
    is raised.

    Statements with side effects are not executed against the real
    namespaces; instead their effect is recorded in the transient
    namespaces of the context:

    - class definitions (``class sth: ...``) store a stand-in class,
    - function definitions (``def sth: ...``) store a stand-in function
      carrying the return annotation or an inferred return value,
    - variable assignments (``x = 1``) and annotated assignments.

    The following nodes evaluate to ``None`` without any effect:

    - augmented assignments (``x += 1``)
    - deletions (``del x``)
    - pass (``pass``)
    - imports (``import x``)
    - ``global`` and ``nonlocal`` declarations
    - any node type which is not handled explicitly

    For assertions (``assert x``) the message is evaluated if present,
    otherwise the test is.

    The purpose of this function is to guard against unwanted side-effects;
    it does not give guarantees on protection from malicious code execution.
    """
    policy = get_policy(context)

    if node is None:
        return None
    if isinstance(node, (ast.Interactive, ast.Module)):
        # evaluate all statements, the value of the last one is the result
        result = None
        for child_node in node.body:
            result = eval_node(child_node, context)
        return result
    if isinstance(node, ast.FunctionDef):
        # the body is evaluated against a copy of the transient locals
        func_locals = context.transient_locals.copy()
        func_context = context.replace(transient_locals=func_locals)
        is_property = False
        is_static = False
        is_classmethod = False
        for decorator_node in node.decorator_list:
            try:
                decorator = eval_node(decorator_node, context)
            except NameError:
                # if the decorator is not yet defined this is fine
                # especially because we don't handle imports yet
                continue
            if decorator is property:
                is_property = True
            elif decorator is staticmethod:
                is_static = True
            elif decorator is classmethod:
                is_classmethod = True

        if func_context.class_transients is not None:
            if not is_static and not is_classmethod:
                # the first positional argument names the instance
                func_context.instance_arg_name = (
                    node.args.args[0].arg if node.args.args else None
                )

        return_type = eval_node(node.returns, context=context)

        for child_node in node.body:
            eval_node(child_node, func_context)

        if is_property:
            # a property is stored as its value rather than as a function
            if return_type is not None:
                context.transient_locals[node.name] = _resolve_annotation(
                    return_type, context
                )
            else:
                return_value = _infer_return_value(node, func_context)
                context.transient_locals[node.name] = return_value

            return None

        # stand-in for the function; it is never meant to do any work
        def dummy_function(*args, **kwargs):
            pass

        if return_type is not None:
            dummy_function.__annotations__["return"] = return_type
        else:
            inferred_return = _infer_return_value(node, func_context)
            if inferred_return is not None:
                dummy_function.__inferred_return__ = inferred_return

        dummy_function.__name__ = node.name
        dummy_function.__node__ = node
        context.transient_locals[node.name] = dummy_function
        return None
    if isinstance(node, ast.ClassDef):
        # TODO support class decorators?
        class_locals = {}
        # names recorded so far are visible to the class body as locals
        outer_locals = context.locals.copy()
        outer_locals.update(context.transient_locals)
        class_context = context.replace(
            transient_locals=class_locals, locals=outer_locals
        )
        # the same dict collects class level names and `self.x` assignments
        class_context.class_transients = class_locals
        for child_node in node.body:
            eval_node(child_node, class_context)
        bases = tuple([eval_node(base, context) for base in node.bases])
        dummy_class = type(node.name, bases, class_locals)
        context.transient_locals[node.name] = dummy_class
        return None
    if isinstance(node, ast.Assign):
        return _handle_assign(node, context)
    if isinstance(node, ast.AnnAssign):
        # the annotation, not the assigned value, determines what is stored
        if node.simple:
            value = _resolve_annotation(eval_node(node.annotation, context), context)
            context.transient_locals[node.target.id] = value
        # Handle non-simple annotated assignments only for self.x: type = value
        if _is_instance_attribute_assignment(node.target, context):
            value = _resolve_annotation(eval_node(node.annotation, context), context)
            context.class_transients[node.target.attr] = value
        return None
    if isinstance(node, ast.Expression):
        return eval_node(node.body, context)
    if isinstance(node, ast.Expr):
        return eval_node(node.value, context)
    if isinstance(node, ast.Pass):
        return None
    if isinstance(node, ast.Import):
        # TODO: populate transient_locals
        return None
    if isinstance(node, (ast.AugAssign, ast.Delete)):
        return None
    if isinstance(node, (ast.Global, ast.Nonlocal)):
        return None
    if isinstance(node, ast.BinOp):
        left = eval_node(node.left, context)
        right = eval_node(node.right, context)
        dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
        # an operator without known dunders falls through to the checks below
        if dunders:
            if policy.can_operate(dunders, left, right):
                return getattr(left, dunders[0])(right)
            else:
                raise GuardRejection(
                    f"Operation (`{dunders}`) for",
                    type(left),
                    f"not allowed in {context.evaluation} mode",
                )
    if isinstance(node, ast.Compare):
        left = eval_node(node.left, context)
        all_true = True
        negate = False
        # chained comparison: every pair is evaluated, `left` moves along
        for op, right in zip(node.ops, node.comparators):
            right = eval_node(right, context)
            dunder = None
            dunders = _find_dunder(op, COMP_OP_DUNDERS)
            if not dunders:
                if isinstance(op, ast.NotIn):
                    dunders = COMP_OP_DUNDERS[ast.In]
                    negate = True
                if isinstance(op, ast.Is):
                    dunder = "is_"
                if isinstance(op, ast.IsNot):
                    dunder = "is_"
                    negate = True
            if not dunder and dunders:
                dunder = dunders[0]
            if dunder:
                # containment is implemented by the right-hand operand
                a, b = (right, left) if dunder == "__contains__" else (left, right)
                # identity checks call no user code, hence need no permission
                if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
                    result = getattr(operator, dunder)(a, b)
                    if negate:
                        result = not result
                    if not result:
                        all_true = False
                    left = right
                else:
                    raise GuardRejection(
                        f"Comparison (`{dunder}`) for",
                        type(left),
                        f"not allowed in {context.evaluation} mode",
                    )
            else:
                raise ValueError(
                    f"Comparison `{dunder}` not supported"
                )  # pragma: no cover
        return all_true
    if isinstance(node, ast.Constant):
        return node.value
    if isinstance(node, ast.Tuple):
        return tuple(eval_node(e, context) for e in node.elts)
    if isinstance(node, ast.List):
        return [eval_node(e, context) for e in node.elts]
    if isinstance(node, ast.Set):
        return {eval_node(e, context) for e in node.elts}
    if isinstance(node, ast.Dict):
        return dict(
            zip(
                [eval_node(k, context) for k in node.keys],
                [eval_node(v, context) for v in node.values],
            )
        )
    if isinstance(node, ast.Slice):
        return slice(
            eval_node(node.lower, context),
            eval_node(node.upper, context),
            eval_node(node.step, context),
        )
    if isinstance(node, ast.UnaryOp):
        value = eval_node(node.operand, context)
        dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
        if dunders:
            if policy.can_operate(dunders, value):
                try:
                    return getattr(value, dunders[0])()
                except AttributeError:
                    raise TypeError(
                        f"bad operand type for unary {node.op}: {type(value)}"
                    )
            else:
                raise GuardRejection(
                    f"Operation (`{dunders}`) for",
                    type(value),
                    f"not allowed in {context.evaluation} mode",
                )
    if isinstance(node, ast.Subscript):
        value = eval_node(node.value, context)
        slice_ = eval_node(node.slice, context)
        if policy.can_get_item(value, slice_):
            return value[slice_]
        raise GuardRejection(
            "Subscript access (`__getitem__`) for",
            type(value),  # not joined to avoid calling `repr`
            f" not allowed in {context.evaluation} mode",
        )
    if isinstance(node, ast.Name):
        return _eval_node_name(node.id, context)
    if isinstance(node, ast.Attribute):
        # `self.x` inside of a class being evaluated: use the recorded value
        if (
            context.class_transients is not None
            and isinstance(node.value, ast.Name)
            and node.value.id == context.instance_arg_name
        ):
            return context.class_transients.get(node.attr)
        value = eval_node(node.value, context)
        if policy.can_get_attr(value, node.attr):
            return getattr(value, node.attr)
        raise GuardRejection(
            "Attribute access (`__getattr__`) for",
            type(value),  # not joined to avoid calling `repr`
            f"not allowed in {context.evaluation} mode",
        )
    if isinstance(node, ast.IfExp):
        test = eval_node(node.test, context)
        if test:
            return eval_node(node.body, context)
        else:
            return eval_node(node.orelse, context)
    if isinstance(node, ast.Call):
        func = eval_node(node.func, context)
        if policy.can_call(func):
            args, kwargs = _extract_args_and_kwargs(node, context)
            return func(*args, **kwargs)
        # the call itself is not allowed: try to work out what it would
        # return without calling it
        if isclass(func):
            # this code path gets entered when calling class e.g. `MyClass()`
            # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
            # Should return `MyClass` if `__new__` is not overridden,
            # otherwise whatever `__new__` return type is.
            overridden_return_type = _eval_return_type(func.__new__, node, context)
            if overridden_return_type is not NOT_EVALUATED:
                return overridden_return_type
            return _create_duck_for_heap_type(func)
        else:
            if hasattr(func, "__inferred_return__"):
                return func.__inferred_return__
            return_type = _eval_return_type(func, node, context)
            if return_type is not NOT_EVALUATED:
                return return_type
        raise GuardRejection(
            "Call for",
            func,  # not joined to avoid calling `repr`
            f"not allowed in {context.evaluation} mode",
        )
    if isinstance(node, ast.Assert):
        # message is always the second item, so if it is defined user would be completing
        # on the message, not on the assertion test
        if node.msg:
            return eval_node(node.msg, context)
        return eval_node(node.test, context)
    return None
974
975
def _merge_values(values, policy: EvaluationPolicy):
    """Combine several candidate values into a single representative one.

    A single value is returned as is. Otherwise attribute names (via
    ``dir``) and mapping items (via ``items()``, or ``keys()`` as a
    fallback) are gathered from every value the policy allows to be
    inspected; items found under the same key are merged recursively.

    If all values share one type which is not mapping-like, the first
    value is returned - or the type itself for ``list``, ``set`` and
    ``tuple``. In every other case a ``_Duck`` exposing the gathered
    attributes and items is returned.
    """
    if len(values) == 1:
        return values[0]

    distinct_types = {type(candidate) for candidate in values}
    collected_attrs = set()
    per_key = {}
    for candidate in values:
        if policy.can_call(candidate.__dir__):
            collected_attrs.update(dir(candidate))
        try:
            if policy.can_call(candidate.items):
                try:
                    for key, item in candidate.items():
                        per_key.setdefault(key, []).append(item)
                except Exception:
                    pass
            elif policy.can_call(candidate.keys):
                try:
                    # values are unknown here, `None` marks the key as present
                    for key in candidate.keys():
                        per_key.setdefault(key, []).append(None)
                except Exception:
                    pass
        except Exception:
            # e.g. the candidate has neither `items` nor `keys`
            pass

    merged_items = None
    if per_key:
        merged_items = {
            key: None if found[0] is None else _merge_values(found, policy)
            for key, found in per_key.items()
        }

    if len(distinct_types) == 1:
        (only_type,) = distinct_types
        if only_type not in (dict,):
            first = next(iter(values))
            mapping_like = hasattr(first, "__getitem__") and (
                hasattr(first, "items") or hasattr(first, "keys")
            )
            if not mapping_like:
                if only_type in (list, set, tuple):
                    return only_type
                return values[0]

    return _Duck(attributes=dict.fromkeys(collected_attrs), items=merged_items)
1024
1025
def _infer_return_value(node: ast.FunctionDef, context: EvaluationContext):
    """Infer what a function returns from the ``return`` statements in its body.

    Returns ``None`` when no return value could be collected, the value
    itself when exactly one was collected, and otherwise the result of
    merging all collected values under the policy active for ``context``.
    """
    collected = _collect_return_values(node.body, context)
    if not collected:
        return None
    if len(collected) > 1:
        return _merge_values(collected, get_policy(context))
    return collected[0]
1037
1038
1039def _collect_return_values(body, context):
1040 """Recursively collect return values from a list of AST statements."""
1041 return_values = []
1042 for stmt in body:
1043 if isinstance(stmt, ast.Return):
1044 if stmt.value is None:
1045 continue
1046 try:
1047 value = eval_node(stmt.value, context)
1048 if value is not None and value is not NOT_EVALUATED:
1049 return_values.append(value)
1050 except Exception:
1051 pass
1052 if isinstance(
1053 stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)
1054 ):
1055 continue
1056 elif hasattr(stmt, "body") and isinstance(stmt.body, list):
1057 return_values.extend(_collect_return_values(stmt.body, context))
1058 if isinstance(stmt, ast.Try):
1059 for h in stmt.handlers:
1060 if hasattr(h, "body"):
1061 return_values.extend(_collect_return_values(h.body, context))
1062 if hasattr(stmt, "orelse"):
1063 return_values.extend(_collect_return_values(stmt.orelse, context))
1064 if hasattr(stmt, "finalbody"):
1065 return_values.extend(_collect_return_values(stmt.finalbody, context))
1066 if hasattr(stmt, "orelse") and isinstance(stmt.orelse, list):
1067 return_values.extend(_collect_return_values(stmt.orelse, context))
1068 return return_values
1069
1070
def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
    """Evaluate return type of a given callable function.

    The return annotation is read from ``inspect.signature(func)`` and handed
    to ``_resolve_annotation``.

    Parameters
    ----------
    func
        Callable whose return annotation should be resolved.
    node
        The call node being evaluated; forwarded to ``_resolve_annotation``.
    context
        Current evaluation context.

    Returns
    -------
    The result of ``_resolve_annotation`` for the return annotation, or the
    ``NOT_EVALUATED`` sentinel when ``func`` has no return annotation or no
    signature can be obtained for it.
    """
    try:
        sig = signature(func)
    except (ValueError, TypeError):
        # `inspect.signature` raises ValueError when no signature can be
        # provided and TypeError when the kind of object is not supported;
        # neither should escape, so fall back to the placeholder signature.
        sig = UNKNOWN_SIGNATURE
    # if annotation was not stringized, or it was stringized
    # but resolved by signature call we know the return type
    return_annotation = sig.return_annotation
    if return_annotation is Signature.empty:
        return NOT_EVALUATED
    return _resolve_annotation(return_annotation, context, sig, func, node)
1086
1087
def _eval_annotation(
    annotation: str,
    context: EvaluationContext,
):
    """Look up a string annotation by name; return any other object unchanged."""
    if not isinstance(annotation, str):
        return annotation
    return _eval_node_name(annotation, context)
1097
1098
1099class _GetItemDuck(dict):
1100 """A dict subclass that always returns the factory instance and claims to have any item."""
1101
1102 def __init__(self, factory, *args, **kwargs):
1103 super().__init__(*args, **kwargs)
1104 self._factory = factory
1105
1106 def __getitem__(self, key):
1107 return self._factory()
1108
1109 def __contains__(self, key):
1110 return True
1111
1112
def _resolve_annotation(
    annotation: object | str,
    context: EvaluationContext,
    sig: Signature | None = None,
    func: Callable | None = None,
    node: ast.Call | None = None,
):
    """Resolve annotation created by user with `typing` module and custom objects.

    String annotations are first looked up by name.  The resolved annotation
    is then mapped to a stand-in value:

    - ``None`` resolves to ``None``;
    - ``Self`` resolves to ``func.__self__`` when ``func`` is bound;
    - ``Literal[x]`` with a single argument resolves to ``x``;
    - ``LiteralString`` resolves to an empty string;
    - ``AnyStr`` resolves to the evaluated call argument found at the
      position of the first ``AnyStr``-annotated parameter;
    - ``TypeGuard[...]`` resolves to ``False``;
    - ``list[...]``/``set[...]`` resolve to a duck of the container whose
      items are ducks of the element type;
    - ``tuple[...]`` resolves to a tuple of resolved members;
    - unions (``Union[...]`` and ``X | Y``) resolve to a duck with the
      attributes of all members;
    - typed dicts and protocols resolve to ducks;
    - ``Annotated``, ``NewType`` and type aliases resolve via the type they wrap;
    - anything else is instantiated if allowed by policy, or imitated.

    Parameters
    ----------
    annotation
        The annotation object, or its name as a string.
    context
        Current evaluation context.
    sig
        Signature of ``func``; used to locate ``AnyStr`` parameters when
        ``func`` carries no ``__node__``.
    func
        The callable the annotation belongs to, if any.
    node
        The call node being evaluated, if any; required to resolve ``AnyStr``.

    Returns
    -------
    The stand-in value, or ``None`` when nothing can be inferred.
    """
    # local import: `UnionType` is not among the names this module imports
    # at the top of the file
    from types import UnionType

    if annotation is None:
        return None
    annotation = _eval_annotation(annotation, context)
    origin = get_origin(annotation)
    if annotation is Self and func and hasattr(func, "__self__"):
        return func.__self__
    elif origin is Literal:
        type_args = get_args(annotation)
        if len(type_args) == 1:
            return type_args[0]
        # literals with several alternatives are not resolved (returns None)
    elif annotation is LiteralString:
        return ""
    elif annotation is AnyStr:
        index = None
        if func and hasattr(func, "__node__"):
            def_node = func.__node__
            for i, arg in enumerate(def_node.args.args):
                if not arg.annotation:
                    continue
                if isinstance(arg.annotation, ast.Name):
                    arg_annotation = _eval_annotation(arg.annotation.id, context)
                else:
                    # annotations such as `typing.AnyStr` or a quoted name are
                    # not plain names; evaluate them on a best-effort basis and
                    # skip the parameter if that is not possible
                    try:
                        arg_annotation = _eval_annotation(
                            eval_node(arg.annotation, context), context
                        )
                    except Exception:
                        continue
                if arg_annotation is AnyStr:
                    index = i
                    break
            is_bound_method = (
                isinstance(func, MethodType) and getattr(func, "__self__") is not None
            )
            if index and is_bound_method:
                # the definition lists `self` first, the call does not pass it
                index -= 1
        elif sig:
            for i, (key, value) in enumerate(sig.parameters.items()):
                if value.annotation is AnyStr:
                    index = i
                    break
        if index is None:
            return None
        # without a call node there are no arguments to inspect
        if node is None or index < 0 or index >= len(node.args):
            return None
        return eval_node(node.args[index], context)
    elif origin is TypeGuard:
        return False
    elif origin is set or origin is list:
        type_args = get_args(annotation)
        if not type_args:
            # unparameterised alias such as `typing.List`: element type unknown
            return _eval_or_create_duck(origin, context)
        # only one type argument allowed
        attributes = list(
            dir(_resolve_annotation(type_args[0], context, sig, func, node))
        )
        duck = _Duck(attributes=dict.fromkeys(attributes))
        return _Duck(
            attributes=dict.fromkeys(dir(origin())),
            # items are not strictly needed for set
            items=_GetItemDuck(lambda: duck),
        )
    elif origin is tuple:
        # multiple type arguments
        return tuple(
            _resolve_annotation(arg, context, sig, func, node)
            for arg in get_args(annotation)
        )
    elif origin is Union or origin is UnionType:
        # multiple type arguments; `X | Y` may report `types.UnionType`
        # as its origin rather than `typing.Union`
        attributes = [
            attr
            for type_arg in get_args(annotation)
            for attr in dir(_resolve_annotation(type_arg, context, sig, func, node))
        ]
        return _Duck(attributes=dict.fromkeys(attributes))
    elif is_typeddict(annotation):
        return _Duck(
            attributes=dict.fromkeys(dir(dict())),
            items={
                k: _resolve_annotation(v, context, sig, func, node)
                for k, v in annotation.__annotations__.items()
            },
        )
    elif hasattr(annotation, "_is_protocol"):
        return _Duck(attributes=dict.fromkeys(dir(annotation)))
    elif origin is Annotated:
        # metadata is irrelevant here, only the annotated type matters
        type_arg = get_args(annotation)[0]
        return _resolve_annotation(type_arg, context, sig, func, node)
    elif isinstance(annotation, NewType):
        return _eval_or_create_duck(annotation.__supertype__, context)
    elif isinstance(annotation, TypeAliasType):
        return _eval_or_create_duck(annotation.__value__, context)
    else:
        return _eval_or_create_duck(annotation, context)
1208
1209
def _eval_node_name(node_id: str, context: EvaluationContext):
    """Look up the name ``node_id`` in the namespaces the active policy permits.

    Lookup order: transient locals, locals, globals, builtins, and finally
    the context's auto-import hook.  Raises ``GuardRejection`` when the name
    was not found and the policy allows neither locals nor globals access,
    and ``NameError`` otherwise.
    """
    active_policy = get_policy(context)
    if node_id in context.transient_locals:
        return context.transient_locals[node_id]
    if active_policy.allow_locals_access and node_id in context.locals:
        return context.locals[node_id]
    if active_policy.allow_globals_access and node_id in context.globals:
        return context.globals[node_id]
    if active_policy.allow_builtins_access and hasattr(builtins, node_id):
        # note: do not use __builtins__, it is implementation detail of cPython
        return getattr(builtins, node_id)
    if active_policy.allow_auto_import and context.auto_import:
        return context.auto_import(node_id)

    if not (
        active_policy.allow_globals_access or active_policy.allow_locals_access
    ):
        raise GuardRejection(
            f"Namespace access not allowed in {context.evaluation} mode"
        )
    raise NameError(f"{node_id} not found in locals, globals, nor builtins")
1229
1230
def _eval_or_create_duck(duck_type, context: EvaluationContext):
    """Return an instance of ``duck_type`` if it may be called, else an imitation."""
    if get_policy(context).can_call(duck_type):
        # allow-listed builtin used as a type annotation: instantiate it
        return duck_type()
    # custom class used as a type annotation: mock it
    return _create_duck_for_heap_type(duck_type)
1238
1239
def _create_duck_for_heap_type(duck_type):
    """Create an imitation of an object of a given type (a duck).

    A fresh ``ImpersonatingDuck`` has its ``__class__`` reassigned to
    ``duck_type``.

    Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
    """
    impostor = ImpersonatingDuck()
    try:
        # this only works for heap types, not builtins
        impostor.__class__ = duck_type
    except TypeError:
        return NOT_EVALUATED
    return impostor
1253
1254
# Third-party types, written as tuples of name parts, that the "limited"
# policy receives as `allowed_getitem_external`.
SUPPORTED_EXTERNAL_GETITEM = {
    # numpy
    ("numpy", "ndarray"),
    ("numpy", "void"),
    # pandas containers
    ("pandas", "DataFrame"),
    ("pandas", "Series"),
    # pandas indexers
    ("pandas", "core", "indexing", "_LocIndexer"),
    ("pandas", "core", "indexing", "_iLocIndexer"),
}
1263
1264
# Types that the "limited" policy receives as `allowed_getitem`.
BUILTIN_GETITEM: set[InstancesHaveGetItem] = {
    # built-in containers and strings
    bytes,  # type: ignore[arg-type]
    dict,
    list,
    str,  # type: ignore[arg-type]
    tuple,
    type,  # for type annotations like list[str]
    # containers from `collections`
    collections.ChainMap,
    collections.OrderedDict,
    collections.UserDict,
    collections.UserList,
    collections.UserString,  # type: ignore[arg-type]
    collections.defaultdict,
    collections.deque,
    # stand-in types defined in this module
    _Duck,
    _DummyNamedTuple,
    _IdentitySubscript,
}
1283
1284
1285def _list_methods(cls, source=None):
1286 """For use on immutable objects or with methods returning a copy"""
1287 return [getattr(cls, k) for k in (source if source else dir(cls))]
1288
1289
# Method names passed to `_list_methods` when building `ALLOWED_CALLS`
# for mapping-like and sequence-like types respectively.
dict_non_mutating_methods = ("copy", "keys", "values", "items")
list_non_mutating_methods = ("copy", "index", "count")
# Attribute names available on both `set` and `frozenset`.
set_non_mutating_methods = set(dir(set)).intersection(dir(frozenset))


# Concrete type of the view object returned by `dict.keys()`.
dict_keys: type[collections.abc.KeysView] = type(dict().keys())

# Built-in number types.
NUMERICS = {complex, float, int}
1298
# Callables that the "limited" policy receives as `allowed_calls`.
ALLOWED_CALLS = {
    # --- constructors of built-in types ---
    bool,
    bytes,
    dict,
    frozenset,
    list,
    range,
    set,
    str,
    tuple,
    *NUMERICS,
    # --- constructors of `collections` types ---
    collections.Counter,
    collections.OrderedDict,
    collections.UserDict,
    collections.UserList,
    collections.UserString,
    collections.defaultdict,
    collections.deque,
    # --- types for which every attribute from `dir` is listed ---
    *_list_methods(bool),
    *_list_methods(bytes),
    *_list_methods(frozenset),
    *_list_methods(str),
    *_list_methods(tuple),
    *(method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)),
    # --- mapping types: only the names in `dict_non_mutating_methods` ---
    *_list_methods(dict, dict_non_mutating_methods),
    *_list_methods(collections.Counter, dict_non_mutating_methods),
    *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
    *_list_methods(collections.UserDict, dict_non_mutating_methods),
    *_list_methods(collections.defaultdict, dict_non_mutating_methods),
    collections.Counter.elements,
    collections.Counter.most_common,
    dict_keys.isdisjoint,
    # --- sequence types: only the names in `list_non_mutating_methods` ---
    *_list_methods(list, list_non_mutating_methods),
    *_list_methods(collections.UserList, list_non_mutating_methods),
    *_list_methods(collections.deque, list_non_mutating_methods),
    # --- set: only the names in `set_non_mutating_methods` ---
    *_list_methods(set, set_non_mutating_methods),
    # --- UserString: the names that `str` itself provides ---
    *_list_methods(collections.UserString, dir(str)),
    # --- attribute listing ---
    object.__dir__,
    type.__dir__,
}
1339
# Types that the "limited" policy receives as `allowed_getattr`.
BUILTIN_GETATTR: set[MayHaveGetattr] = {
    # every type listed for item access is listed for attribute access too
    *BUILTIN_GETITEM,
    # further built-in containers and numbers
    frozenset,
    set,
    dict_keys,
    *NUMERICS,
    # generic base types
    object,
    type,  # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
    # method descriptors and modules
    MethodDescriptorType,
    ModuleType,
}
1351
1352
# Independent copy of `BUILTIN_GETATTR`; the "limited" policy receives it as
# `allowed_operations`.
BUILTIN_OPERATIONS = set(BUILTIN_GETATTR)
1354
# Evaluation policies keyed by mode name.
EVALUATION_POLICIES = {
    # every flag set here is disabled, except access to builtins
    "minimal": EvaluationPolicy(
        # namespaces
        allow_builtins_access=True,
        allow_globals_access=False,
        allow_locals_access=False,
        # attribute and item access
        allow_attr_access=False,
        allow_item_access=False,
        # calls and operations
        allow_all_operations=False,
        allow_any_calls=False,
        allowed_calls=set(),
    ),
    # namespaces are readable; item/attribute access, operations and calls
    # are restricted to the allow-lists passed in
    "limited": SelectivePolicy(
        # namespaces
        allow_builtins_access=True,
        allow_globals_access=True,
        allow_locals_access=True,
        # item access
        allow_getitem_on_types=True,
        allowed_getitem=BUILTIN_GETITEM,
        allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
        # attribute access
        allowed_getattr=BUILTIN_GETATTR,
        allowed_getattr_external={
            # pandas Series/Frame implements custom `__getattr__`
            ("pandas", "DataFrame"),
            ("pandas", "Series"),
        },
        # calls and operations
        allowed_calls=ALLOWED_CALLS,
        allowed_operations=BUILTIN_OPERATIONS,
    ),
    # every flag set here is enabled
    "unsafe": EvaluationPolicy(
        # namespaces
        allow_builtins_access=True,
        allow_globals_access=True,
        allow_locals_access=True,
        # attribute and item access
        allow_attr_access=True,
        allow_item_access=True,
        # calls and operations
        allow_all_operations=True,
        allow_any_calls=True,
    ),
}
1392
1393
# Names exported by `from <this module> import *`.
# Note that `_unbind_method` is exported despite its leading underscore.
__all__ = [
    "guarded_eval",
    "eval_node",
    "GuardRejection",
    "EvaluationContext",
    "_unbind_method",
]