1from copy import copy
2from inspect import isclass, signature, Signature, getmodule
3from typing import (
4 Annotated,
5 AnyStr,
6 Literal,
7 NamedTuple,
8 NewType,
9 Optional,
10 Protocol,
11 TypeGuard,
12 Union,
13 get_args,
14 get_origin,
15 is_typeddict,
16)
17from collections.abc import Callable, Sequence
18import ast
19import builtins
20import collections
21import dataclasses
22import operator
23import sys
24import typing
25import warnings
26from functools import cached_property
27from dataclasses import dataclass, field
28from types import MethodDescriptorType, ModuleType, MethodType
29
30from IPython.utils.decorators import undoc
31
32import types
33from typing import Self, LiteralString, get_type_hints
34
35if sys.version_info < (3, 12):
36 from typing_extensions import TypeAliasType
37else:
38 from typing import TypeAliasType
39
40
@undoc
class HasGetItem(Protocol):
    """Structural type for objects that support subscription (``obj[key]``)."""

    def __getitem__(self, key) -> None:
        """Return whatever is stored under ``key``."""
45
46
@undoc
class InstancesHaveGetItem(Protocol):
    """Structural type for callables (typically classes) producing subscriptable objects."""

    def __call__(self, *args, **kwargs) -> HasGetItem:
        """Create an object that supports ``__getitem__``."""
51
52
@undoc
class HasGetAttr(Protocol):
    """Structural type for objects that explicitly define ``__getattr__``."""

    def __getattr__(self, key) -> None:
        """Fallback lookup for attributes not found the normal way."""
57
58
@undoc
class DoesNotHaveGetAttr(Protocol):
    """Structural type with no required members: objects without a custom ``__getattr__``."""
62
63
# By default `__getattr__` is not explicitly implemented on most objects,
# so this union accepts types both with and without it; it is used as the
# element type of `SelectivePolicy.allowed_getattr`.
MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
66
67
68def _unbind_method(func: Callable) -> Union[Callable, None]:
69 """Get unbound method for given bound method.
70
71 Returns None if cannot get unbound method, or method is already unbound.
72 """
73 owner = getattr(func, "__self__", None)
74 owner_class = type(owner)
75 name = getattr(func, "__name__", None)
76 instance_dict_overrides = getattr(owner, "__dict__", None)
77 if (
78 owner is not None
79 and name
80 and (
81 not instance_dict_overrides
82 or (instance_dict_overrides and name not in instance_dict_overrides)
83 )
84 ):
85 return getattr(owner_class, name)
86 return None
87
88
@undoc
@dataclass
class EvaluationPolicy:
    """Definition of evaluation policy.

    The base implementation decides purely from the ``allow_*`` flags (and,
    for calls, the ``allowed_calls`` allow-list); subclasses may additionally
    inspect the values involved.
    """

    allow_locals_access: bool = False
    allow_globals_access: bool = False
    allow_item_access: bool = False
    allow_attr_access: bool = False
    allow_builtins_access: bool = False
    allow_all_operations: bool = False
    allow_any_calls: bool = False
    allow_auto_import: bool = False
    allowed_calls: set[Callable] = field(default_factory=set)

    def can_get_item(self, value, item):
        """Return whether ``value[item]`` may be evaluated."""
        return self.allow_item_access

    def can_get_attr(self, value, attr):
        """Return whether ``value.attr`` may be evaluated."""
        return self.allow_attr_access

    def can_operate(self, dunders: tuple[str, ...], a, b=None):
        """Return whether the operation implemented by ``dunders`` may run on ``a`` (and ``b``)."""
        if self.allow_all_operations:
            return True
        return False

    def can_call(self, func):
        """Return whether ``func`` may be called.

        ``func`` is allowed when any call is allowed, when it is itself in
        ``allowed_calls``, or when it is a bound method whose class-level
        method is in ``allowed_calls`` (e.g. ``[].append`` when
        ``list.append`` is allowed).
        """
        if self.allow_any_calls:
            return True

        if func in self.allowed_calls:
            return True

        owner_method = _unbind_method(func)

        if owner_method and owner_method in self.allowed_calls:
            return True
        return False
125
126
127def _get_external(module_name: str, access_path: Sequence[str]):
128 """Get value from external module given a dotted access path.
129
130 Only gets value if the module is already imported.
131
132 Raises:
133 * `KeyError` if module is removed not found, and
134 * `AttributeError` if access path does not match an exported object
135 """
136 try:
137 member_type = sys.modules[module_name]
138 # standard module
139 for attr in access_path:
140 member_type = getattr(member_type, attr)
141 return member_type
142 except (KeyError, AttributeError):
143 # handle modules in namespace packages
144 module_path = ".".join([module_name, *access_path])
145 if module_path in sys.modules:
146 return sys.modules[module_path]
147 raise
148
149
150def _has_original_dunder_external(
151 value,
152 module_name: str,
153 access_path: Sequence[str],
154 method_name: str,
155):
156 if module_name not in sys.modules:
157 full_module_path = ".".join([module_name, *access_path])
158 if full_module_path not in sys.modules:
159 # LBYLB as it is faster
160 return False
161 try:
162 member_type = _get_external(module_name, access_path)
163 value_type = type(value)
164 if type(value) == member_type:
165 return True
166 if isinstance(member_type, ModuleType):
167 value_module = getmodule(value_type)
168 if not value_module or not value_module.__name__:
169 return False
170 if (
171 value_module.__name__ == member_type.__name__
172 or value_module.__name__.startswith(member_type.__name__ + ".")
173 ):
174 return True
175 if method_name == "__getattribute__":
176 # we have to short-circuit here due to an unresolved issue in
177 # `isinstance` implementation: https://bugs.python.org/issue32683
178 return False
179 if not isinstance(member_type, ModuleType) and isinstance(value, member_type):
180 method = getattr(value_type, method_name, None)
181 member_method = getattr(member_type, method_name, None)
182 if member_method == method:
183 return True
184 if isinstance(member_type, ModuleType):
185 method = getattr(value_type, method_name, None)
186 for base_class in value_type.__mro__[1:]:
187 base_module = getmodule(base_class)
188 if base_module and (
189 base_module.__name__ == member_type.__name__
190 or base_module.__name__.startswith(member_type.__name__ + ".")
191 ):
192 # Check if the method comes from this trusted base class
193 base_method = getattr(base_class, method_name, None)
194 if base_method is not None and base_method == method:
195 return True
196 except (AttributeError, KeyError):
197 return False
198
199
200def _has_original_dunder(
201 value, allowed_types, allowed_methods, allowed_external, method_name
202):
203 # note: Python ignores `__getattr__`/`__getitem__` on instances,
204 # we only need to check at class level
205 value_type = type(value)
206
207 # strict type check passes → no need to check method
208 if value_type in allowed_types:
209 return True
210
211 method = getattr(value_type, method_name, None)
212
213 if method is None:
214 return None
215
216 if method in allowed_methods:
217 return True
218
219 for module_name, *access_path in allowed_external:
220 if _has_original_dunder_external(value, module_name, access_path, method_name):
221 return True
222
223 return False
224
225
226def _coerce_path_to_tuples(
227 allow_list: set[tuple[str, ...] | str],
228) -> set[tuple[str, ...]]:
229 """Replace dotted paths on the provided allow-list with tuples."""
230 return {
231 path if isinstance(path, tuple) else tuple(path.split("."))
232 for path in allow_list
233 }
234
235
@undoc
@dataclass
class SelectivePolicy(EvaluationPolicy):
    """Policy that only trusts allow-listed types and their unmodified dunders.

    Item access, attribute access and operators are permitted when the
    value's class is allow-listed, when the relevant dunder is the very one
    provided by an allow-listed class, or when an allow-listed external
    type/module (given as a dotted string or a tuple of path parts) vouches
    for it.
    """

    allowed_getitem: set[InstancesHaveGetItem] = field(default_factory=set)
    allowed_getitem_external: set[tuple[str, ...] | str] = field(default_factory=set)

    allowed_getattr: set[MayHaveGetattr] = field(default_factory=set)
    allowed_getattr_external: set[tuple[str, ...] | str] = field(default_factory=set)

    allowed_operations: set = field(default_factory=set)
    allowed_operations_external: set[tuple[str, ...] | str] = field(default_factory=set)

    #: Allow subscripting types themselves, e.g. ``Union[str, int]`` or ``list[str]``.
    allow_getitem_on_types: bool = False

    # Per-dunder cache of allowed operator methods. Each entry remembers the
    # `allowed_operations` set it was computed from: shallow copies of a
    # policy (e.g. `copy(policy)` followed by overrides) share this dict, so
    # an entry built from a different allow-list must not be reused.
    _operation_methods_cache: dict[str, tuple[set, set[Callable]]] = field(
        default_factory=dict, init=False
    )

    def can_get_attr(self, value, attr):
        """Allow ``value.attr`` when attribute lookup on ``value`` is unmodified.

        ``__getattribute__`` and (if defined) ``__getattr__`` must both be
        trusted. Even then a ``property`` may run arbitrary code, so it is
        only accepted when the value's class is allow-listed or the very same
        property is exposed by one of the allow-listed external classes.
        """
        allowed_getattr_external = _coerce_path_to_tuples(self.allowed_getattr_external)

        has_original_attribute = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattribute_methods,
            allowed_external=allowed_getattr_external,
            method_name="__getattribute__",
        )
        has_original_attr = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattr_methods,
            allowed_external=allowed_getattr_external,
            method_name="__getattr__",
        )

        if has_original_attr is None:
            # Many objects do not have `__getattr__`, this is fine.
            accept = has_original_attribute
        else:
            # Accept objects without modifications to `__getattr__` and `__getattribute__`
            accept = has_original_attr and has_original_attribute
        if not accept:
            return False

        # We still need to check for overridden properties.
        value_class = type(value)
        if not hasattr(value_class, attr):
            return True

        class_attr_val = getattr(value_class, attr)
        if not isinstance(class_attr_val, property):
            return True

        # Properties in allowed types are ok (although we do not include any
        # properties in our default allow list currently).
        if value_class in self.allowed_getattr:
            return True  # pragma: no cover

        # Properties in subclasses of allowed types may be ok if not changed:
        # accept when any allow-listed external class exposes the same property.
        for module_name, *access_path in allowed_getattr_external:
            try:
                external_class = _get_external(module_name, access_path)
                external_class_attr_val = getattr(external_class, attr)
            except (KeyError, AttributeError):
                continue
            if class_attr_val == external_class_attr_val:
                return True

        return False

    def can_get_item(self, value, item):
        """Allow accessing `__getitem__` of allow-listed instances as long as it was not modified.

        With ``allow_getitem_on_types`` enabled, subscripting typing
        constructs and classes defining ``__class_getitem__`` is also allowed.
        """
        if self.allow_getitem_on_types:
            # e.g. Union[str, int] or Literal[True, 1]
            if isinstance(value, (typing._SpecialForm, typing._BaseGenericAlias)):
                return True
            # PEP 560 e.g. list[str]
            if isinstance(value, type) and hasattr(value, "__class_getitem__"):
                return True
        return _has_original_dunder(
            value,
            allowed_types=self.allowed_getitem,
            allowed_methods=self._getitem_methods,
            allowed_external=_coerce_path_to_tuples(self.allowed_getitem_external),
            method_name="__getitem__",
        )

    def can_operate(self, dunders: tuple[str, ...], a, b=None):
        """Allow an operation if every dunder in ``dunders`` is trusted on each operand.

        ``b`` is not checked when it is ``None`` (unary operations).
        """
        allowed_operations_external = _coerce_path_to_tuples(
            self.allowed_operations_external
        )
        operands = [a] if b is None else [a, b]
        # Generator: stop at the first untrusted dunder/operand pair.
        return all(
            _has_original_dunder(
                obj,
                allowed_types=self.allowed_operations,
                allowed_methods=self._operator_dunder_methods(dunder),
                allowed_external=allowed_operations_external,
                method_name=dunder,
            )
            for dunder in dunders
            for obj in operands
        )

    def _operator_dunder_methods(self, dunder: str) -> set[Callable]:
        """Return the ``dunder`` implementations of ``allowed_operations`` (memoised)."""
        cached = self._operation_methods_cache.get(dunder)
        if cached is not None and cached[0] is self.allowed_operations:
            return cached[1]
        methods = self._safe_get_methods(self.allowed_operations, dunder)
        self._operation_methods_cache[dunder] = (self.allowed_operations, methods)
        return methods

    @cached_property
    def _getitem_methods(self) -> set[Callable]:
        return self._safe_get_methods(self.allowed_getitem, "__getitem__")

    @cached_property
    def _getattr_methods(self) -> set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattr__")

    @cached_property
    def _getattribute_methods(self) -> set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattribute__")

    def _safe_get_methods(self, classes, name) -> set[Callable]:
        """Collect the truthy ``name`` attributes of ``classes``, skipping classes without one."""
        return {
            method
            for class_ in classes
            for method in [getattr(class_, name, None)]
            if method
        }
375
376
377class _DummyNamedTuple(NamedTuple):
378 """Used internally to retrieve methods of named tuple instance."""
379
380
#: Names of the available evaluation policies. ``"forbidden"`` disables
#: evaluation entirely and ``"dangerous"`` falls back to plain ``eval``.
EvaluationPolicyName = Literal[
    "forbidden",
    "minimal",
    "limited",
    "unsafe",
    "dangerous",
]
382
383
@dataclass
class EvaluationContext:
    """State carried by the guarded evaluator while it walks an AST.

    Use :meth:`replace` to derive a modified copy for a nested scope.
    """

    #: Local namespace used for name lookups
    locals: dict
    #: Global namespace used for name lookups
    globals: dict
    #: Name of the evaluation policy to apply
    evaluation: EvaluationPolicyName = "forbidden"
    #: Whether ``code`` is the inside of a subscript, which makes slice sugar
    #: such as ``:-1, 'col'`` (as in ``df[:-1, 'col']``) evaluable.
    in_subscript: bool = False
    #: Callable used to import modules automatically, if any
    auto_import: Callable[[Sequence[str]], ModuleType] | None = None
    #: Attribute overrides applied on top of the selected evaluation policy
    policy_overrides: dict = field(default_factory=dict)
    #: Names bound during evaluation (assignments, definitions, mocks); kept
    #: apart from ``locals`` so assignments do not write into it.
    transient_locals: dict = field(default_factory=dict)
    #: Instance attributes recorded while evaluating a class body
    #: (``None`` outside of class bodies)
    class_transients: dict | None = None
    #: Name of the instance parameter (usually ``self``) of the method being evaluated
    instance_arg_name: str | None = None
    #: AST of the value currently being assigned; lets annotated assignments
    #: such as ``x: dict = {...}`` add items to a ``_Duck``
    current_value: ast.AST | None = None

    def replace(self, /, **changes):
        """Return a shallow copy of this context with ``changes`` applied."""
        updated = dataclasses.replace(self, **changes)
        return updated
412
413
414class _IdentitySubscript:
415 """Returns the key itself when item is requested via subscript."""
416
417 def __getitem__(self, key):
418 return key
419
420
# Name under which IDENTITY_SUBSCRIPT is injected into the locals so that a
# bare subscript body can be evaluated as `SUBSCRIPT_MARKER[...]`.
SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
IDENTITY_SUBSCRIPT = _IdentitySubscript()
# Empty signature; presumably used where a callable's real signature is unknown.
UNKNOWN_SIGNATURE = Signature()
# Unique sentinel object, distinct from None (presumably marks "not evaluated").
NOT_EVALUATED = object()
425
426
class GuardRejection(Exception):
    """Raised when the active evaluation policy refuses an evaluation attempt."""
431
432
def guarded_eval(code: str, context: EvaluationContext):
    """Evaluate ``code`` according to the evaluation policy named in ``context``.

    * ``forbidden`` - nothing is evaluated; :class:`GuardRejection` is raised.
    * ``dangerous`` - the code is handed to the standard :func:`eval`.
    * any other policy - the code is parsed and the AST is evaluated by
      :func:`eval_node`, which enforces the policy.

    When ``context.in_subscript`` is set, ``code`` is evaluated as the key of
    a subscript (so slice syntax like ``:-1, 'col'`` is accepted); empty code
    then yields an empty tuple.
    """
    if context.evaluation == "forbidden":
        raise GuardRejection("Forbidden mode")

    # `ast.literal_eval` is not used because it does not implement
    # getitem at all; for example it fails on a simple `[0][1]`.

    if context.in_subscript:
        if not code:
            return tuple()
        # Slice sugar (`:`) only parses inside a subscript, so wrap the code
        # in a subscript of a marker object. The marker returns the key it is
        # given, which lets the wrapping `__getitem__` be ignored later.
        subscript_locals = context.locals.copy()
        subscript_locals[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
        context = context.replace(locals=subscript_locals)
        code = SUBSCRIPT_MARKER + "[" + code + "]"

    if context.evaluation == "dangerous":
        # Deliberately unrestricted: the "dangerous" policy opts into plain eval.
        return eval(code, context.globals, context.locals)

    tree = ast.parse(code, mode="exec")
    return eval_node(tree, context)
467
468
#: Dunder method(s) implementing each binary operator AST node type.
BINARY_OP_DUNDERS: dict[type[ast.operator], tuple[str, ...]] = {
    ast.Add: ("__add__",),
    ast.Sub: ("__sub__",),
    ast.Mult: ("__mul__",),
    ast.Div: ("__truediv__",),
    ast.FloorDiv: ("__floordiv__",),
    ast.Mod: ("__mod__",),
    ast.Pow: ("__pow__",),
    ast.LShift: ("__lshift__",),
    ast.RShift: ("__rshift__",),
    ast.BitOr: ("__or__",),
    ast.BitXor: ("__xor__",),
    ast.BitAnd: ("__and__",),
    ast.MatMult: ("__matmul__",),
}

#: Dunder methods each comparison may involve (e.g. the reflected
#: ``__gt__`` for ``<``); the first entry is the primary one.
COMP_OP_DUNDERS: dict[type[ast.cmpop], tuple[str, ...]] = {
    ast.Eq: ("__eq__",),
    ast.NotEq: ("__ne__", "__eq__"),
    ast.Lt: ("__lt__", "__gt__"),
    ast.LtE: ("__le__", "__ge__"),
    ast.Gt: ("__gt__", "__lt__"),
    ast.GtE: ("__ge__", "__le__"),
    ast.In: ("__contains__",),
    # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
}

#: Dunder method names checked for each unary operator AST node type.
UNARY_OP_DUNDERS: dict[type[ast.unaryop], tuple[str, ...]] = {
    ast.USub: ("__neg__",),
    ast.UAdd: ("__pos__",),
    # we have to check both __inv__ and __invert__!
    ast.Invert: ("__invert__", "__inv__"),
    ast.Not: ("__not__",),
}

#: Builtin container types; annotated assignments whose annotation resolves
#: to one of these record the assigned value instead of the annotation.
GENERIC_CONTAINER_TYPES = (dict, list, set, tuple, frozenset)
505
506
class ImpersonatingDuck:
    """Placeholder used to create objects of other classes without running their ``__init__``."""

    # Intentionally empty: instances impersonate another class by having
    # their ``__class__`` overridden.
511
512
513class _Duck:
514 """A dummy class used to create objects pretending to have given attributes"""
515
516 def __init__(self, attributes: Optional[dict] = None, items: Optional[dict] = None):
517 self.attributes = attributes if attributes is not None else {}
518 self.items = items if items is not None else {}
519
520 def __getattr__(self, attr: str):
521 return self.attributes[attr]
522
523 def __hasattr__(self, attr: str):
524 return attr in self.attributes
525
526 def __dir__(self):
527 return [*dir(super), *self.attributes]
528
529 def __getitem__(self, key: str):
530 return self.items[key]
531
532 def __hasitem__(self, key: str):
533 return self.items[key]
534
535 def _ipython_key_completions_(self):
536 return self.items.keys()
537
538
539def _find_dunder(node_op, dunders) -> Union[tuple[str, ...], None]:
540 dunder = None
541 for op, candidate_dunder in dunders.items():
542 if isinstance(node_op, op):
543 dunder = candidate_dunder
544 return dunder
545
546
def get_policy(context: EvaluationContext) -> EvaluationPolicy:
    """Return a shallow copy of the context's policy with its overrides applied.

    Override keys that the policy does not define are silently ignored.
    """
    policy = copy(EVALUATION_POLICIES[context.evaluation])
    for name, override in context.policy_overrides.items():
        if not hasattr(policy, name):
            continue
        setattr(policy, name, override)
    return policy
554
555
def _validate_policy_overrides(
    policy_name: EvaluationPolicyName, policy_overrides: dict
) -> bool:
    """Warn about every override key that the named policy does not define.

    Returns:
        ``True`` if all keys name existing attributes of the policy.
    """
    policy = EVALUATION_POLICIES[policy_name]
    unknown_keys = [key for key in policy_overrides if not hasattr(policy, key)]
    for key in unknown_keys:
        warnings.warn(
            f"Override {key!r} is not valid with {policy_name!r} evaluation policy"
        )
    return not unknown_keys
569
570
571def _is_type_annotation(obj) -> bool:
572 """
573 Returns True if obj is a type annotation, False otherwise.
574 """
575 if isinstance(obj, type):
576 return True
577 if isinstance(obj, types.GenericAlias):
578 return True
579 if hasattr(types, "UnionType") and isinstance(obj, types.UnionType):
580 return True
581 if isinstance(obj, (typing._SpecialForm, typing._BaseGenericAlias)):
582 return True
583 if isinstance(obj, typing.TypeVar):
584 return True
585 # Types that support __class_getitem__
586 if isinstance(obj, type) and hasattr(obj, "__class_getitem__"):
587 return True
588 # Fallback: check if get_origin returns something
589 if hasattr(typing, "get_origin") and get_origin(obj) is not None:
590 return True
591
592 return False
593
594
def _handle_assign(node: ast.Assign, context: EvaluationContext):
    """Record the effect of ``target [= target ...] = value`` in the transient namespaces.

    Every target of a chained assignment is handled. Names go to
    ``context.transient_locals``, ``self.attr`` targets inside a method go to
    ``context.class_transients``, tuple/list targets are unpacked (with
    optional ``*starred`` element), and ``container[key]`` targets replace
    the container with a ``_Duck``. Targets with no transient representation
    (e.g. ``obj.attr`` on an arbitrary object) are skipped.

    Returns:
        None
    """
    value = eval_node(node.value, context)
    policy = get_policy(context)
    for target in node.targets:
        _assign_to_target(target, value, context, policy)
    return None


def _assign_to_target(target: ast.AST, value, context: EvaluationContext, policy) -> None:
    """Bind ``value`` to a single assignment target; unsupported targets are ignored."""
    if isinstance(target, (ast.Tuple, ast.List)):
        _assign_unpacked(target, value, context, policy)
    elif isinstance(target, ast.Subscript):
        _assign_subscript(target, value, context, policy)
    elif _is_instance_attribute_assignment(target, context):
        context.class_transients[target.attr] = value
    elif isinstance(target, ast.Name):
        context.transient_locals[target.id] = value


def _assign_unpacked(target, value, context: EvaluationContext, policy) -> None:
    """Distribute the items of ``value`` over a tuple/list target with at most one ``*starred``."""
    values = list(value)
    targets = target.elts
    starred = [i for i, t in enumerate(targets) if isinstance(t, ast.Starred)]

    # Unified handling: treat no starred as starred at end
    star_or_last_idx = starred[0] if starred else len(targets)

    # Before starred
    for i in range(star_or_last_idx):
        _assign_to_target(targets[i], values[i], context, policy)

    if starred:
        end = len(values) - (len(targets) - star_or_last_idx - 1)
        # `*rest` wraps the real target (a name or `self.attr`) in ast.Starred.
        _assign_to_target(
            targets[star_or_last_idx].value,
            values[star_or_last_idx:end],
            context,
            policy,
        )
        # After starred: align the remaining targets with the tail of values.
        for i in range(star_or_last_idx + 1, len(targets)):
            _assign_to_target(
                targets[i], values[len(values) - (len(targets) - i)], context, policy
            )


def _assign_subscript(target: ast.Subscript, value, context: EvaluationContext, policy) -> None:
    """Record ``container[key] = value`` by replacing the container with a ``_Duck``.

    The duck keeps the container's attribute names (if its ``__dir__`` may be
    called), its items (if item access is allowed and ``items()`` works) and
    the newly assigned item. Only containers named directly or reached as
    ``self.attr`` are tracked.

    Raises:
        NameError: if the named container cannot be found.
    """
    container_node = target.value
    if isinstance(container_node, ast.Name):
        name = container_node.id
        container = context.transient_locals.get(name)
        if container is None:
            container = context.locals.get(name)
        if container is None:
            container = context.globals.get(name)
        if container is None:
            raise NameError(f"{name} not found in locals, globals, nor builtins")
        storage_dict = context.transient_locals
        storage_key = name
    elif isinstance(container_node, ast.Attribute) and _is_instance_attribute_assignment(
        container_node, context
    ):
        attr = container_node.attr
        container = context.class_transients.get(attr, None)
        if container is None:
            raise NameError(f"{attr} not found in class transients")
        storage_dict = context.class_transients
        storage_key = attr
    else:
        # Containers reached through other expressions are not tracked.
        return

    key = eval_node(target.slice, context)
    attributes = (
        dict.fromkeys(dir(container)) if policy.can_call(container.__dir__) else {}
    )
    items = {}
    if policy.can_get_item(container, None):
        try:
            items = dict(container.items())
        except Exception:
            # Best effort: containers without a usable `items()` start empty.
            pass

    items[key] = value
    storage_dict[storage_key] = _Duck(attributes=attributes, items=items)
690
691
def _handle_annassign(node, context):
    """Record the effect of an annotated assignment ``target: annotation [= value]``.

    For a simple name or a ``self.attr`` target, the resolved annotation is
    stored as the target's value. When the annotation resolves to a builtin
    container and a value is present, the assignment is handled like a plain
    ``target = value`` instead. Other targets are ignored.
    """
    annotation_context = context.replace(current_value=getattr(node, "value", None))
    annotation = eval_node(node.annotation, annotation_context)

    annotation_value = annotation
    use_value = False
    if _is_type_annotation(annotation):
        annotation_value = _resolve_annotation(annotation, context)
        # Use Value for generic types
        use_value = node.value is not None and isinstance(
            annotation_value, GENERIC_CONTAINER_TYPES
        )

    is_local = getattr(node, "simple", False) and isinstance(node.target, ast.Name)
    if not is_local and not _is_instance_attribute_assignment(node.target, context):
        return None

    if use_value:
        return _handle_assign(
            ast.Assign(targets=[node.target], value=node.value), context
        )

    if is_local:
        # LOCAL VARIABLE
        context.transient_locals[node.target.id] = annotation_value
    else:
        # INSTANCE ATTRIBUTE
        context.class_transients[node.target.attr] = annotation_value
    return None
726
def _extract_args_and_kwargs(node: ast.Call, context: EvaluationContext):
    """Evaluate the positional and keyword arguments of a call node.

    ``**mapping`` arguments are evaluated and merged item by item.

    Returns:
        A ``(args, kwargs)`` pair of evaluated values.
    """
    positional = [eval_node(arg_node, context) for arg_node in node.args]
    keyword = {}
    for keyword_node in node.keywords:
        evaluated = eval_node(keyword_node.value, context)
        if keyword_node.arg:
            keyword[keyword_node.arg] = evaluated
        else:
            # `**mapping` unpacking: merge every item of the mapping.
            for name, item_value in evaluated.items():
                keyword[name] = item_value
    return positional, keyword
739
740
def _is_instance_attribute_assignment(
    target: ast.AST, context: EvaluationContext
) -> bool:
    """Tell whether ``target`` is ``<instance arg>.attr`` inside a class body being evaluated."""
    if context.class_transients is None or context.instance_arg_name is None:
        return False
    if not isinstance(target, ast.Attribute):
        return False
    owner = getattr(target, "value", None)
    return (
        isinstance(owner, ast.Name)
        and getattr(target.value, "id", None) == context.instance_arg_name
    )
752
753
754def _get_coroutine_attributes() -> dict[str, Optional[object]]:
755 async def _dummy():
756 return None
757
758 coro = _dummy()
759 try:
760 return {attr: getattr(coro, attr, None) for attr in dir(coro)}
761 finally:
762 coro.close()
763
764
765def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
766 """Evaluate AST node in provided context.
767
768 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
769
770 Does not evaluate actions that always have side effects:
771
772 - class definitions (``class sth: ...``)
773 - function definitions (``def sth: ...``)
774 - variable assignments (``x = 1``)
775 - augmented assignments (``x += 1``)
776 - deletions (``del x``)
777
778 Does not evaluate operations which do not return values:
779
780 - assertions (``assert x``)
781 - pass (``pass``)
782 - imports (``import x``)
783 - control flow:
784
785 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
786 - loops (``for`` and ``while``)
787 - exception handling
788
789 The purpose of this function is to guard against unwanted side-effects;
790 it does not give guarantees on protection from malicious code execution.
791 """
792 policy = get_policy(context)
793
794 if node is None:
795 return None
796 if isinstance(node, (ast.Interactive, ast.Module)):
797 result = None
798 for child_node in node.body:
799 result = eval_node(child_node, context)
800 return result
801 if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
802 is_async = isinstance(node, ast.AsyncFunctionDef)
803 func_locals = context.transient_locals.copy()
804 func_context = context.replace(transient_locals=func_locals)
805 is_property = False
806 is_static = False
807 is_classmethod = False
808 for decorator_node in node.decorator_list:
809 try:
810 decorator = eval_node(decorator_node, context)
811 except NameError:
812 # if the decorator is not yet defined this is fine
813 # especialy because we don't handle imports yet
814 continue
815 if decorator is property:
816 is_property = True
817 elif decorator is staticmethod:
818 is_static = True
819 elif decorator is classmethod:
820 is_classmethod = True
821
822 if func_context.class_transients is not None:
823 if not is_static and not is_classmethod:
824 func_context.instance_arg_name = (
825 node.args.args[0].arg if node.args.args else None
826 )
827
828 return_type = eval_node(node.returns, context=context)
829
830 for child_node in node.body:
831 eval_node(child_node, func_context)
832
833 if is_property:
834 if return_type is not None:
835 if _is_type_annotation(return_type):
836 context.transient_locals[node.name] = _resolve_annotation(
837 return_type, context
838 )
839 else:
840 context.transient_locals[node.name] = return_type
841 else:
842 return_value = _infer_return_value(node, func_context)
843 context.transient_locals[node.name] = return_value
844
845 return None
846
847 def dummy_function(*args, **kwargs):
848 pass
849
850 if return_type is not None:
851 if _is_type_annotation(return_type):
852 dummy_function.__annotations__["return"] = return_type
853 else:
854 dummy_function.__inferred_return__ = return_type
855 else:
856 inferred_return = _infer_return_value(node, func_context)
857 if inferred_return is not None:
858 dummy_function.__inferred_return__ = inferred_return
859
860 dummy_function.__name__ = node.name
861 dummy_function.__node__ = node
862 dummy_function.__is_async__ = is_async
863 context.transient_locals[node.name] = dummy_function
864 return None
865 if isinstance(node, ast.Lambda):
866
867 def dummy_function(*args, **kwargs):
868 pass
869
870 dummy_function.__inferred_return__ = eval_node(node.body, context)
871 return dummy_function
872 if isinstance(node, ast.ClassDef):
873 # TODO support class decorators?
874 class_locals = {}
875 outer_locals = context.locals.copy()
876 outer_locals.update(context.transient_locals)
877 class_context = context.replace(
878 transient_locals=class_locals, locals=outer_locals
879 )
880 class_context.class_transients = class_locals
881 for child_node in node.body:
882 eval_node(child_node, class_context)
883 bases = tuple([eval_node(base, context) for base in node.bases])
884 dummy_class = type(node.name, bases, class_locals)
885 context.transient_locals[node.name] = dummy_class
886 return None
887 if isinstance(node, ast.Await):
888 value = eval_node(node.value, context)
889 if hasattr(value, "__awaited_type__"):
890 return value.__awaited_type__
891 return value
892 if isinstance(node, ast.While):
893 loop_locals = context.transient_locals.copy()
894 loop_context = context.replace(transient_locals=loop_locals)
895
896 result = None
897 for stmt in node.body:
898 result = eval_node(stmt, loop_context)
899
900 policy = get_policy(context)
901 merged_locals = _merge_dicts_by_key(
902 [loop_locals, context.transient_locals.copy()], policy
903 )
904 context.transient_locals.update(merged_locals)
905
906 return result
907 if isinstance(node, ast.For):
908 try:
909 iterable = eval_node(node.iter, context)
910 except Exception:
911 iterable = None
912
913 sample = None
914 if iterable is not None:
915 try:
916 if policy.can_call(getattr(iterable, "__iter__", None)):
917 sample = next(iter(iterable))
918 except Exception:
919 sample = None
920
921 loop_locals = context.transient_locals.copy()
922 loop_context = context.replace(transient_locals=loop_locals)
923
924 if sample is not None:
925 try:
926 fake_assign = ast.Assign(
927 targets=[node.target], value=ast.Constant(value=sample)
928 )
929 _handle_assign(fake_assign, loop_context)
930 except Exception:
931 pass
932
933 result = None
934 for stmt in node.body:
935 result = eval_node(stmt, loop_context)
936
937 policy = get_policy(context)
938 merged_locals = _merge_dicts_by_key(
939 [loop_locals, context.transient_locals.copy()], policy
940 )
941 context.transient_locals.update(merged_locals)
942
943 return result
944 if isinstance(node, ast.If):
945 branches = []
946 current = node
947 result = None
948 while True:
949 branch_locals = context.transient_locals.copy()
950 branch_context = context.replace(transient_locals=branch_locals)
951 for stmt in current.body:
952 result = eval_node(stmt, branch_context)
953 branches.append(branch_locals)
954 if not current.orelse:
955 break
956 elif len(current.orelse) == 1 and isinstance(current.orelse[0], ast.If):
957 # It's an elif - continue loop
958 current = current.orelse[0]
959 else:
960 # It's an else block - process and break
961 else_locals = context.transient_locals.copy()
962 else_context = context.replace(transient_locals=else_locals)
963 for stmt in current.orelse:
964 result = eval_node(stmt, else_context)
965 branches.append(else_locals)
966 break
967 branches.append(context.transient_locals.copy())
968 policy = get_policy(context)
969 merged_locals = _merge_dicts_by_key(branches, policy)
970 context.transient_locals.update(merged_locals)
971 return result
972 if isinstance(node, ast.Assign):
973 return _handle_assign(node, context)
974 if isinstance(node, ast.AnnAssign):
975 return _handle_annassign(node, context)
976 if isinstance(node, ast.Expression):
977 return eval_node(node.body, context)
978 if isinstance(node, ast.Expr):
979 return eval_node(node.value, context)
980 if isinstance(node, ast.Pass):
981 return None
982 if isinstance(node, ast.Import):
983 # TODO: populate transient_locals
984 return None
985 if isinstance(node, (ast.AugAssign, ast.Delete)):
986 return None
987 if isinstance(node, (ast.Global, ast.Nonlocal)):
988 return None
989 if isinstance(node, ast.BinOp):
990 left = eval_node(node.left, context)
991 right = eval_node(node.right, context)
992 if (
993 isinstance(node.op, ast.BitOr)
994 and _is_type_annotation(left)
995 and _is_type_annotation(right)
996 ):
997 left_duck = (
998 _Duck(dict.fromkeys(dir(left)))
999 if policy.can_call(left.__dir__)
1000 else _Duck()
1001 )
1002 right_duck = (
1003 _Duck(dict.fromkeys(dir(right)))
1004 if policy.can_call(right.__dir__)
1005 else _Duck()
1006 )
1007 value_node = context.current_value
1008 if value_node is not None and isinstance(value_node, ast.Dict):
1009 if dict in [left, right]:
1010 return _merge_values(
1011 [left_duck, right_duck, ast.literal_eval(value_node)],
1012 policy=get_policy(context),
1013 )
1014 return _merge_values([left_duck, right_duck], policy=get_policy(context))
1015 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
1016 if dunders:
1017 if policy.can_operate(dunders, left, right):
1018 return getattr(left, dunders[0])(right)
1019 else:
1020 raise GuardRejection(
1021 f"Operation (`{dunders}`) for",
1022 type(left),
1023 f"not allowed in {context.evaluation} mode",
1024 )
1025 if isinstance(node, ast.Compare):
1026 left = eval_node(node.left, context)
1027 all_true = True
1028 negate = False
1029 for op, right in zip(node.ops, node.comparators):
1030 right = eval_node(right, context)
1031 dunder = None
1032 dunders = _find_dunder(op, COMP_OP_DUNDERS)
1033 if not dunders:
1034 if isinstance(op, ast.NotIn):
1035 dunders = COMP_OP_DUNDERS[ast.In]
1036 negate = True
1037 if isinstance(op, ast.Is):
1038 dunder = "is_"
1039 if isinstance(op, ast.IsNot):
1040 dunder = "is_"
1041 negate = True
1042 if not dunder and dunders:
1043 dunder = dunders[0]
1044 if dunder:
1045 a, b = (right, left) if dunder == "__contains__" else (left, right)
1046 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
1047 result = getattr(operator, dunder)(a, b)
1048 if negate:
1049 result = not result
1050 if not result:
1051 all_true = False
1052 left = right
1053 else:
1054 raise GuardRejection(
1055 f"Comparison (`{dunder}`) for",
1056 type(left),
1057 f"not allowed in {context.evaluation} mode",
1058 )
1059 else:
1060 raise ValueError(
1061 f"Comparison `{dunder}` not supported"
1062 ) # pragma: no cover
1063 return all_true
1064 if isinstance(node, ast.Constant):
1065 return node.value
1066 if isinstance(node, ast.Tuple):
1067 return tuple(eval_node(e, context) for e in node.elts)
1068 if isinstance(node, ast.List):
1069 return [eval_node(e, context) for e in node.elts]
1070 if isinstance(node, ast.Set):
1071 return {eval_node(e, context) for e in node.elts}
1072 if isinstance(node, ast.Dict):
1073 return dict(
1074 zip(
1075 [eval_node(k, context) for k in node.keys],
1076 [eval_node(v, context) for v in node.values],
1077 )
1078 )
1079 if isinstance(node, ast.Slice):
1080 return slice(
1081 eval_node(node.lower, context),
1082 eval_node(node.upper, context),
1083 eval_node(node.step, context),
1084 )
1085 if isinstance(node, ast.UnaryOp):
1086 value = eval_node(node.operand, context)
1087 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
1088 if dunders:
1089 if policy.can_operate(dunders, value):
1090 try:
1091 return getattr(value, dunders[0])()
1092 except AttributeError:
1093 raise TypeError(
1094 f"bad operand type for unary {node.op}: {type(value)}"
1095 )
1096 else:
1097 raise GuardRejection(
1098 f"Operation (`{dunders}`) for",
1099 type(value),
1100 f"not allowed in {context.evaluation} mode",
1101 )
1102 if isinstance(node, ast.Subscript):
1103 value = eval_node(node.value, context)
1104 slice_ = eval_node(node.slice, context)
1105 if policy.can_get_item(value, slice_):
1106 return value[slice_]
1107 raise GuardRejection(
1108 "Subscript access (`__getitem__`) for",
1109 type(value), # not joined to avoid calling `repr`
1110 f" not allowed in {context.evaluation} mode",
1111 )
1112 if isinstance(node, ast.Name):
1113 return _eval_node_name(node.id, context)
1114 if isinstance(node, ast.Attribute):
1115 if (
1116 context.class_transients is not None
1117 and isinstance(node.value, ast.Name)
1118 and node.value.id == context.instance_arg_name
1119 ):
1120 return context.class_transients.get(node.attr)
1121 value = eval_node(node.value, context)
1122 if policy.can_get_attr(value, node.attr):
1123 return getattr(value, node.attr)
1124 try:
1125 cls = (
1126 value if isinstance(value, type) else getattr(value, "__class__", None)
1127 )
1128 if cls is not None:
1129 resolved_hints = get_type_hints(
1130 cls,
1131 globalns=(context.globals or {}),
1132 localns=(context.locals or {}),
1133 )
1134 if node.attr in resolved_hints:
1135 annotated = resolved_hints[node.attr]
1136 return _resolve_annotation(annotated, context)
1137 except Exception:
1138 # Fall through to the guard rejection
1139 pass
1140 raise GuardRejection(
1141 "Attribute access (`__getattr__`) for",
1142 type(value), # not joined to avoid calling `repr`
1143 f"not allowed in {context.evaluation} mode",
1144 )
1145 if isinstance(node, ast.IfExp):
1146 test = eval_node(node.test, context)
1147 if test:
1148 return eval_node(node.body, context)
1149 else:
1150 return eval_node(node.orelse, context)
1151 if isinstance(node, ast.Call):
1152 func = eval_node(node.func, context)
1153 if policy.can_call(func):
1154 args, kwargs = _extract_args_and_kwargs(node, context)
1155 return func(*args, **kwargs)
1156 if isclass(func):
1157 # this code path gets entered when calling class e.g. `MyClass()`
1158 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
1159 # Should return `MyClass` if `__new__` is not overridden,
1160 # otherwise whatever `__new__` return type is.
1161 overridden_return_type = _eval_return_type(func.__new__, node, context)
1162 if overridden_return_type is not NOT_EVALUATED:
1163 return overridden_return_type
1164 return _create_duck_for_heap_type(func)
1165 else:
1166 inferred_return = getattr(func, "__inferred_return__", NOT_EVALUATED)
1167 return_type = _eval_return_type(func, node, context)
1168 if getattr(func, "__is_async__", False):
1169 awaited_type = (
1170 inferred_return if inferred_return is not None else return_type
1171 )
1172 coroutine_duck = _Duck(attributes=_get_coroutine_attributes())
1173 coroutine_duck.__awaited_type__ = awaited_type
1174 return coroutine_duck
1175 if inferred_return is not NOT_EVALUATED:
1176 return inferred_return
1177 if return_type is not NOT_EVALUATED:
1178 return return_type
1179 raise GuardRejection(
1180 "Call for",
1181 func, # not joined to avoid calling `repr`
1182 f"not allowed in {context.evaluation} mode",
1183 )
1184 if isinstance(node, ast.Assert):
1185 # message is always the second item, so if it is defined user would be completing
1186 # on the message, not on the assertion test
1187 if node.msg:
1188 return eval_node(node.msg, context)
1189 return eval_node(node.test, context)
1190 return None
1191
1192
def _merge_dicts_by_key(dicts: list, policy: EvaluationPolicy):
    """Merge multiple dictionaries, combining values for each key.

    Used, for example, to reconcile the ``transient_locals`` snapshots taken
    from alternative branches.  For every key present in at least one
    dictionary, the values from the dictionaries that define it are
    combined with ``_merge_values``.

    Parameters
    ----------
    dicts
        Dictionaries to merge.
    policy
        Evaluation policy forwarded to ``_merge_values``.

    Returns
    -------
    dict
        The single input dictionary itself when exactly one is given,
        otherwise a new dictionary whose keys appear in first-seen order.
    """
    if len(dicts) == 1:
        return dicts[0]

    # `dict.fromkeys` keeps first-seen order; collecting keys in a `set`
    # would make the resulting key order depend on hash randomisation.
    ordered_keys = dict.fromkeys(key for d in dicts for key in d.keys())
    # every key comes from at least one dict, so each value list is non-empty
    return {
        key: _merge_values([d[key] for d in dicts if key in d], policy)
        for key in ordered_keys
    }
1209
1210
def _merge_values(values, policy: EvaluationPolicy):
    """Recursively merge multiple values, combining attributes and dict items.

    Used to summarise values that may come from different code paths
    (several ``return`` statements, alternative branches, ``X | Y`` type
    unions).

    - A single value is returned unchanged.
    - If all values share one type that is neither ``dict`` nor mapping-like
      (``__getitem__`` plus ``items`` or ``keys``), ``list``/``set``/``tuple``
      collapse to the type object itself and any other type to the first
      value.
    - Otherwise a ``_Duck`` is returned. Its attributes are the union of
      ``dir()`` of every value whose ``__dir__`` the policy allows calling.
      Its items merge, key by key, the values reported by ``items()``, or
      ``None`` placeholders for keys known only through ``keys()``.

    Parameters
    ----------
    values
        Sequence of candidate values.
    policy
        Evaluation policy deciding which methods may be called.
    """
    if len(values) == 1:
        return values[0]

    # named `value_types` so the module-level `types` import is not shadowed
    value_types = {type(v) for v in values}
    key_values = {}
    attributes = set()
    for v in values:
        if policy.can_call(v.__dir__):
            attributes.update(dir(v))
        try:
            # `getattr` with a default (rather than `v.items`) so that objects
            # exposing `keys()` but no `items()` (e.g. `sqlite3.Row`) still
            # reach the `keys` branch instead of aborting on AttributeError.
            items_method = getattr(v, "items", None)
            if items_method is not None and policy.can_call(items_method):
                for k, val in items_method():
                    key_values.setdefault(k, []).append(val)
            else:
                keys_method = getattr(v, "keys", None)
                if keys_method is not None and policy.can_call(keys_method):
                    # only the keys are known; `None` marks "value unknown"
                    for k in keys_method():
                        key_values.setdefault(k, []).append(None)
        except Exception:
            # best effort: an object whose mapping protocol misbehaves simply
            # contributes no (further) items
            pass

    merged_items = None
    if key_values:
        # NOTE: when the first value recorded for a key is None (e.g. it came
        # from a keys-only object) the merged item is None as well.
        merged_items = {
            k: _merge_values(vals, policy) if vals[0] is not None else None
            for k, vals in key_values.items()
        }

    if len(value_types) == 1:
        (only_type,) = value_types
        first = values[0]
        mapping_like = hasattr(first, "__getitem__") and (
            hasattr(first, "items") or hasattr(first, "keys")
        )
        if only_type is not dict and not mapping_like:
            # homogeneous sequences/sets collapse to their type object
            if only_type in (list, set, tuple):
                return only_type
            return first

    return _Duck(attributes=dict.fromkeys(attributes), items=merged_items)
1259
1260
def _infer_return_value(node: ast.FunctionDef, context: EvaluationContext):
    """Guess what calling the function defined by ``node`` would return.

    Every ``return`` expression reachable in the function body is evaluated.
    With no usable return the result is ``None``. A single candidate is
    returned as-is. Several candidates are combined by ``_merge_values``
    under the context's policy.
    """
    candidates = _collect_return_values(node.body, context)
    if len(candidates) > 1:
        return _merge_values(candidates, get_policy(context))
    return candidates[0] if candidates else None
1272
1273
1274def _collect_return_values(body, context):
1275 """Recursively collect return values from a list of AST statements."""
1276 return_values = []
1277 for stmt in body:
1278 if isinstance(stmt, ast.Return):
1279 if stmt.value is None:
1280 continue
1281 try:
1282 value = eval_node(stmt.value, context)
1283 if value is not None and value is not NOT_EVALUATED:
1284 return_values.append(value)
1285 except Exception:
1286 pass
1287 if isinstance(
1288 stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)
1289 ):
1290 continue
1291 elif hasattr(stmt, "body") and isinstance(stmt.body, list):
1292 return_values.extend(_collect_return_values(stmt.body, context))
1293 if isinstance(stmt, ast.Try):
1294 for h in stmt.handlers:
1295 if hasattr(h, "body"):
1296 return_values.extend(_collect_return_values(h.body, context))
1297 if hasattr(stmt, "orelse"):
1298 return_values.extend(_collect_return_values(stmt.orelse, context))
1299 if hasattr(stmt, "finalbody"):
1300 return_values.extend(_collect_return_values(stmt.finalbody, context))
1301 if hasattr(stmt, "orelse") and isinstance(stmt.orelse, list):
1302 return_values.extend(_collect_return_values(stmt.orelse, context))
1303 return return_values
1304
1305
def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
    """Evaluate return type of a given callable function.

    Reads the return annotation from ``inspect.signature(func)`` and hands
    it to ``_resolve_annotation`` together with the call ``node``. Some
    annotations, such as ``AnyStr``, depend on the call's arguments.

    Returns the built-in type, a duck or NOT_EVALUATED sentinel.
    NOT_EVALUATED is returned when no return annotation is available.
    """
    try:
        sig = signature(func)
    except (ValueError, TypeError):
        # ValueError: no signature can be provided (e.g. some builtins).
        # TypeError: the object is not callable or not supported by
        # `inspect`; treat it like an unknown signature instead of letting
        # the TypeError escape from the guarded evaluation.
        sig = UNKNOWN_SIGNATURE
    # if annotation was not stringized, or it was stringized
    # but resolved by signature call we know the return type
    if sig.return_annotation is Signature.empty:
        return NOT_EVALUATED
    return _resolve_annotation(sig.return_annotation, context, sig, func, node)
1321
1322
def _eval_annotation(
    annotation: object | str,
    context: EvaluationContext,
):
    """Turn a possibly-stringified annotation into the object it names.

    Non-string annotations are returned unchanged. A string that is a plain
    identifier is looked up with ``_eval_node_name``. Any other string (for
    example ``"list[int]"`` or ``"Optional[Foo]"``) is parsed as an
    expression and evaluated with ``eval_node`` under the current policy.
    Such strings are written by hand or produced by
    ``from __future__ import annotations``. Strings that do not parse fall
    back to a plain name lookup, so they fail exactly as a missing name
    would.
    """
    if not isinstance(annotation, str):
        return annotation
    if annotation.isidentifier():
        return _eval_node_name(annotation, context)
    try:
        expression = ast.parse(annotation, mode="eval")
    except (SyntaxError, ValueError):
        # ValueError: null bytes in the source on older Python versions
        return _eval_node_name(annotation, context)
    # evaluation is still subject to the policy checks in `eval_node`
    return eval_node(expression, context)
1332
1333
1334class _GetItemDuck(dict):
1335 """A dict subclass that always returns the factory instance and claims to have any item."""
1336
1337 def __init__(self, factory, *args, **kwargs):
1338 super().__init__(*args, **kwargs)
1339 self._factory = factory
1340
1341 def __getitem__(self, key):
1342 return self._factory()
1343
1344 def __contains__(self, key):
1345 return True
1346
1347
def _resolve_annotation(
    annotation: object | str,
    context: EvaluationContext,
    sig: Signature | None = None,
    func: Callable | None = None,
    node: ast.Call | None = None,
):
    """Resolve annotation created by user with `typing` module and custom objects.

    Converts a type annotation into an object that can stand in for a value
    of that type during guarded evaluation: a literal value, a ``_Duck``,
    a tuple of stand-ins, or the result of ``_eval_or_create_duck``.

    Parameters
    ----------
    annotation
        The annotation object, or a string naming it.
    context
        Evaluation context used to look up names and evaluate arguments.
    sig
        Signature of ``func``. Used to locate an ``AnyStr`` parameter when
        ``func`` carries no ``__node__``.
    func
        The callable whose return annotation is being resolved, if any.
    node
        The call expression. Its positional arguments are used to resolve
        ``AnyStr``.

    Returns
    -------
    The stand-in value, or ``None`` when nothing useful can be derived. For
    example, ``AnyStr`` without a matching positional argument, or a
    ``Literal`` with several members.
    """
    if annotation is None:
        return None
    annotation = _eval_annotation(annotation, context)
    origin = get_origin(annotation)
    if annotation is Self and func and hasattr(func, "__self__"):
        return func.__self__
    elif origin is Literal:
        type_args = get_args(annotation)
        if len(type_args) == 1:
            return type_args[0]
        return None
    elif annotation is LiteralString:
        return ""
    elif annotation is AnyStr:
        # AnyStr resolves to whatever was passed for the AnyStr parameter
        index = None
        if func and hasattr(func, "__node__"):
            def_node = func.__node__
            # positional-only parameters are stored separately in the AST
            # but occupy the leading call positions
            positional = [*def_node.args.posonlyargs, *def_node.args.args]
            for i, arg in enumerate(positional):
                arg_annotation = arg.annotation
                if not arg_annotation:
                    continue
                try:
                    if isinstance(arg_annotation, ast.Name):
                        resolved = _eval_annotation(arg_annotation.id, context)
                    else:
                        # e.g. `typing.AnyStr` or a string annotation
                        resolved = _eval_annotation(
                            eval_node(arg_annotation, context), context
                        )
                except Exception:
                    # an unrelated parameter's annotation failing to resolve
                    # must not prevent finding the AnyStr one
                    continue
                if resolved is AnyStr:
                    index = i
                    break
            is_bound_method = (
                isinstance(func, MethodType) and getattr(func, "__self__") is not None
            )
            # `self` is in the definition but not among the call arguments
            if index and is_bound_method:
                index -= 1
        elif sig:
            for i, parameter in enumerate(sig.parameters.values()):
                if parameter.annotation is AnyStr:
                    index = i
                    break
        if index is None or node is None:
            return None
        if index < 0 or index >= len(node.args):
            return None
        return eval_node(node.args[index], context)
    elif origin is TypeGuard:
        return False
    elif origin is set or origin is list or origin is frozenset:
        # only one type argument allowed; bare `typing.List` / `typing.Set`
        # have an origin but no arguments
        type_args = get_args(annotation)
        if type_args:
            element = _resolve_annotation(type_args[0], context, sig, func, node)
            element_duck = _Duck(attributes=dict.fromkeys(dir(element)))
        else:
            element_duck = _Duck()
        return _Duck(
            attributes=dict.fromkeys(dir(origin())),
            # items are not strictly needed for sets
            items=_GetItemDuck(lambda: element_duck),
        )
    elif origin is tuple:
        # multiple type arguments
        return tuple(
            _resolve_annotation(arg, context, sig, func, node)
            for arg in get_args(annotation)
        )
    elif origin is Union or origin is types.UnionType:
        # multiple type arguments; `types.UnionType` covers PEP 604 `X | Y`
        attributes = [
            attr
            for type_arg in get_args(annotation)
            for attr in dir(_resolve_annotation(type_arg, context, sig, func, node))
        ]
        return _Duck(attributes=dict.fromkeys(attributes))
    elif is_typeddict(annotation):
        return _Duck(
            attributes=dict.fromkeys(dir(dict())),
            items={
                k: _resolve_annotation(v, context, sig, func, node)
                for k, v in annotation.__annotations__.items()
            },
        )
    elif hasattr(annotation, "_is_protocol"):
        return _Duck(attributes=dict.fromkeys(dir(annotation)))
    elif origin is Annotated:
        type_arg = get_args(annotation)[0]
        return _resolve_annotation(type_arg, context, sig, func, node)
    elif isinstance(annotation, NewType):
        return _eval_or_create_duck(annotation.__supertype__, context)
    elif isinstance(annotation, TypeAliasType):
        return _eval_or_create_duck(annotation.__value__, context)
    else:
        return _eval_or_create_duck(annotation, context)
1443
1444
def _eval_node_name(node_id: str, context: EvaluationContext):
    """Resolve a bare identifier the way the guarded evaluator sees it.

    Lookup order is: transient locals, then (each only when the policy
    permits) user locals, user globals, builtins, and the auto-import hook.

    Raises
    ------
    GuardRejection
        If the name was not found and the policy forbids both locals and
        globals access.
    NameError
        If the name was not found although at least one of those
        namespaces was readable.
    """
    policy = get_policy(context)
    transient_locals = context.transient_locals
    if node_id in transient_locals:
        return transient_locals[node_id]

    locals_allowed = policy.allow_locals_access
    globals_allowed = policy.allow_globals_access
    if locals_allowed and node_id in context.locals:
        return context.locals[node_id]
    if globals_allowed and node_id in context.globals:
        return context.globals[node_id]
    if policy.allow_builtins_access and hasattr(builtins, node_id):
        # use the `builtins` module: `__builtins__` is a CPython detail
        return getattr(builtins, node_id)
    if policy.allow_auto_import and context.auto_import:
        return context.auto_import(node_id)

    if locals_allowed or globals_allowed:
        raise NameError(f"{node_id} not found in locals, globals, nor builtins")
    raise GuardRejection(
        f"Namespace access not allowed in {context.evaluation} mode"
    )
1464
1465
def _eval_or_create_duck(duck_type, context: EvaluationContext):
    """Produce a stand-in value for an object of ``duck_type``.

    Types the policy allows calling (allow-listed builtins) are instantiated
    with no arguments. Anything else, typically a user-defined class, is
    imitated by ``_create_duck_for_heap_type``.
    """
    if not get_policy(context).can_call(duck_type):
        return _create_duck_for_heap_type(duck_type)
    return duck_type()
1473
1474
def _create_duck_for_heap_type(duck_type):
    """Imitate an instance of ``duck_type`` without running its constructor.

    An ``ImpersonatingDuck`` has its ``__class__`` reassigned to
    ``duck_type``. Python only permits this for heap types (classes defined
    in Python), so for builtins and other static types the reassignment
    raises TypeError.

    Returns the impersonating duck, or the NOT_EVALUATED sentinel when the
    class could not be assigned.
    """
    impostor = ImpersonatingDuck()
    try:
        impostor.__class__ = duck_type
    except TypeError:
        return NOT_EVALUATED
    return impostor
1488
1489
# Third-party classes whose `__getitem__` may be used in the "limited" policy
# (passed as `allowed_getitem_external` in EVALUATION_POLICIES). Each entry is
# a tuple of module path components followed by a class name; presumably
# matched against an object's module and class - confirm in SelectivePolicy.
SUPPORTED_EXTERNAL_GETITEM = {
    ("pandas", "core", "indexing", "_iLocIndexer"),
    ("pandas", "core", "indexing", "_LocIndexer"),
    ("pandas", "DataFrame"),
    ("pandas", "Series"),
    ("numpy", "ndarray"),
    ("numpy", "void"),
}


# Built-in and module-internal types whose instances may be subscripted in the
# "limited" policy (passed as `allowed_getitem`); also reused as the base of
# BUILTIN_GETATTR.
BUILTIN_GETITEM: set[InstancesHaveGetItem] = {
    dict,
    str,  # type: ignore[arg-type]
    bytes,  # type: ignore[arg-type]
    list,
    tuple,
    type,  # for type annotations like list[str]
    _Duck,
    collections.defaultdict,
    collections.deque,
    collections.OrderedDict,
    collections.ChainMap,
    collections.UserDict,
    collections.UserList,
    collections.UserString,  # type: ignore[arg-type]
    _DummyNamedTuple,
    _IdentitySubscript,
}
1518
1519
1520def _list_methods(cls, source=None):
1521 """For use on immutable objects or with methods returning a copy"""
1522 return [getattr(cls, k) for k in (source if source else dir(cls))]
1523
1524
# Names of methods that only read from, or return a copy of, their receiver.
# They are passed to `_list_methods` to build ALLOWED_CALLS for mutable
# containers, so that only these methods (and not e.g. `append`) are callable.
dict_non_mutating_methods = ("copy", "keys", "values", "items")
list_non_mutating_methods = ("copy", "index", "count")
# Everything `set` shares with the immutable `frozenset`.
# NOTE(review): the intersection also contains dunders inherited from
# `object`, such as `__init__`, `__setattr__` and `__delattr__`, and
# `set.__init__` re-initialises (mutates) a set in place; confirm these are
# excluded before the resulting methods are allow-listed.
set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))


# Concrete view types returned by `dict.keys()`, `.values()` and `.items()`;
# they are not exposed under a builtin name, so take them from an empty dict.
dict_keys: type[collections.abc.KeysView] = type({}.keys())
dict_values: type = type({}.values())
dict_items: type = type({}.items())

# Numeric builtin types; their constructors and methods are allow-listed
# together in ALLOWED_CALLS, and attribute access on them in BUILTIN_GETATTR.
NUMERICS = {int, float, complex}
1535
# Callables the "limited" policy permits (passed as `allowed_calls`).
# - Constructors of builtin containers and scalars, plus their iterators.
# - Mutable containers (dict, list, set, deque, the collections mappings)
#   contribute only the methods named in the `*_non_mutating_methods` tuples.
# - Immutable types (bytes, frozenset, str, tuple, bool and NUMERICS) are
#   passed to `_list_methods` without a name list, so their attributes are
#   taken from `dir()`.
# - The `__dir__` implementations are included so `dir()` can be used on
#   plain objects, types and ducks.
ALLOWED_CALLS = {
    bytes,
    *_list_methods(bytes),
    bytes.__iter__,
    dict,
    *_list_methods(dict, dict_non_mutating_methods),
    dict.__iter__,
    dict_keys.__iter__,
    dict_values.__iter__,
    dict_items.__iter__,
    dict_keys.isdisjoint,
    list,
    *_list_methods(list, list_non_mutating_methods),
    list.__iter__,
    set,
    *_list_methods(set, set_non_mutating_methods),
    set.__iter__,
    frozenset,
    *_list_methods(frozenset),
    frozenset.__iter__,
    range,
    range.__iter__,
    str,
    *_list_methods(str),
    str.__iter__,
    tuple,
    *_list_methods(tuple),
    tuple.__iter__,
    bool,
    *_list_methods(bool),
    *NUMERICS,
    *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
    collections.deque,
    *_list_methods(collections.deque, list_non_mutating_methods),
    collections.deque.__iter__,
    collections.defaultdict,
    *_list_methods(collections.defaultdict, dict_non_mutating_methods),
    collections.defaultdict.__iter__,
    collections.OrderedDict,
    *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
    collections.OrderedDict.__iter__,
    collections.UserDict,
    *_list_methods(collections.UserDict, dict_non_mutating_methods),
    collections.UserDict.__iter__,
    collections.UserList,
    *_list_methods(collections.UserList, list_non_mutating_methods),
    collections.UserList.__iter__,
    collections.UserString,
    # UserString mirrors the `str` API, so use the names from `dir(str)`
    *_list_methods(collections.UserString, dir(str)),
    collections.UserString.__iter__,
    collections.Counter,
    *_list_methods(collections.Counter, dict_non_mutating_methods),
    collections.Counter.__iter__,
    collections.Counter.elements,
    collections.Counter.most_common,
    object.__dir__,
    type.__dir__,
    _Duck.__dir__,
}
1595
# Types whose instances may have attributes read in the "limited" policy
# (passed as `allowed_getattr`): every subscriptable type from BUILTIN_GETITEM
# plus other builtins, the dict key view, method descriptors and modules.
BUILTIN_GETATTR: set[MayHaveGetattr] = {
    *BUILTIN_GETITEM,
    set,
    frozenset,
    object,
    type,  # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
    *NUMERICS,
    dict_keys,
    MethodDescriptorType,
    ModuleType,
}


# Types on which operators may be applied in the "limited" policy (passed as
# `allowed_operations`); currently a copy of BUILTIN_GETATTR.
BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}

# Named evaluation policies, from most to least restrictive:
# - "minimal": builtins may be looked up, but no locals/globals access, no
#   item or attribute access, no calls and no operations.
# - "limited": allow-list based, using the tables above; locals, globals
#   and builtins are readable and `allow_getitem_on_types` is enabled.
# - "unsafe": everything is allowed, including arbitrary calls.
EVALUATION_POLICIES = {
    "minimal": EvaluationPolicy(
        allow_builtins_access=True,
        allow_locals_access=False,
        allow_globals_access=False,
        allow_item_access=False,
        allow_attr_access=False,
        allowed_calls=set(),
        allow_any_calls=False,
        allow_all_operations=False,
    ),
    "limited": SelectivePolicy(
        allowed_getitem=BUILTIN_GETITEM,
        allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
        allowed_getattr=BUILTIN_GETATTR,
        allowed_getattr_external={
            # pandas Series/Frame implements custom `__getattr__`
            ("pandas", "DataFrame"),
            ("pandas", "Series"),
        },
        allowed_operations=BUILTIN_OPERATIONS,
        allow_builtins_access=True,
        allow_locals_access=True,
        allow_globals_access=True,
        allow_getitem_on_types=True,
        allowed_calls=ALLOWED_CALLS,
    ),
    "unsafe": EvaluationPolicy(
        allow_builtins_access=True,
        allow_locals_access=True,
        allow_globals_access=True,
        allow_attr_access=True,
        allow_item_access=True,
        allow_any_calls=True,
        allow_all_operations=True,
    ),
}


# Names exported by `from ... import *`.
__all__ = [
    "guarded_eval",
    "eval_node",
    "GuardRejection",
    "EvaluationContext",
    "_unbind_method",
]