1from copy import copy
2from inspect import isclass, signature, Signature, getmodule
3from typing import (
4 Annotated,
5 AnyStr,
6 Callable,
7 Literal,
8 NamedTuple,
9 NewType,
10 Optional,
11 Protocol,
12 Sequence,
13 TypeGuard,
14 Union,
15 get_args,
16 get_origin,
17 is_typeddict,
18)
19import ast
20import builtins
21import collections
22import dataclasses
23import operator
24import sys
25import typing
26import warnings
27from functools import cached_property
28from dataclasses import dataclass, field
29from types import MethodDescriptorType, ModuleType, MethodType
30
31from IPython.utils.decorators import undoc
32
33
34from typing import Self, LiteralString
35
36if sys.version_info < (3, 12):
37 from typing_extensions import TypeAliasType
38else:
39 from typing import TypeAliasType
40
41
@undoc
class HasGetItem(Protocol):
    """Structural type matching any object that defines ``__getitem__``."""

    def __getitem__(self, key) -> None:
        """Subscription hook; the protocol only declares its presence."""
45
46
@undoc
class InstancesHaveGetItem(Protocol):
    """Structural type for callables (e.g. classes) producing `HasGetItem` objects."""

    def __call__(self, *args, **kwargs) -> HasGetItem:
        """Call hook; declared to return an object supporting ``__getitem__``."""
50
51
@undoc
class HasGetAttr(Protocol):
    """Structural type matching any object that defines ``__getattr__``."""

    def __getattr__(self, key) -> None:
        """Attribute fallback hook; the protocol only declares its presence."""
55
56
@undoc
class DoesNotHaveGetAttr(Protocol):
    """Empty protocol standing for objects without an explicit ``__getattr__``."""
60
61
# By default `__getattr__` is not explicitly implemented on most objects,
# so the alias below accepts both types that define it and types that do not.
MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
64
65
def _unbind_method(func: Callable) -> Union[Callable, None]:
    """Get unbound method for given bound method.

    The lookup is performed by ``func.__name__`` on the type of
    ``func.__self__``.

    Returns None if cannot get unbound method, or method is already unbound.
    This includes the cases where:

    - ``func`` has no ``__self__`` or no ``__name__``,
    - the owner's ``__dict__`` contains an entry under the same name
      (the attribute is defined on the owner itself), or
    - the owner's type has no attribute of that name (e.g. an inherited
      class method, whose owner is a class and whose type is ``type``).
    """
    owner = getattr(func, "__self__", None)
    name = getattr(func, "__name__", None)
    if owner is None or not name:
        return None

    instance_dict_overrides = getattr(owner, "__dict__", None)
    if instance_dict_overrides and name in instance_dict_overrides:
        return None

    # Use a default: the owner's type is not guaranteed to expose `name`,
    # and this helper promises `None` rather than an `AttributeError`.
    return getattr(type(owner), name, None)
85
86
@undoc
@dataclass
class EvaluationPolicy:
    """Definition of evaluation policy.

    Each ``allow_*`` flag enables one category of operations; all of them
    default to ``False``. ``allowed_calls`` lists the callables that may be
    invoked even when ``allow_any_calls`` is disabled.
    """

    allow_locals_access: bool = False
    allow_globals_access: bool = False
    allow_item_access: bool = False
    allow_attr_access: bool = False
    allow_builtins_access: bool = False
    allow_all_operations: bool = False
    allow_any_calls: bool = False
    allow_auto_import: bool = False
    allowed_calls: set[Callable] = field(default_factory=set)

    def can_get_item(self, value, item) -> bool:
        """Tell whether ``value[item]`` may be evaluated."""
        return self.allow_item_access

    def can_get_attr(self, value, attr) -> bool:
        """Tell whether ``value.attr`` may be evaluated."""
        return self.allow_attr_access

    def can_operate(self, dunders: tuple[str, ...], a, b=None) -> bool:
        """Tell whether the operation implemented by ``dunders`` may be run."""
        # explicit boolean instead of falling through to an implicit `None`
        return bool(self.allow_all_operations)

    def can_call(self, func) -> bool:
        """Tell whether ``func`` may be called.

        A call is allowed when any call is allowed, when ``func`` itself is
        in ``allowed_calls``, or when the unbound counterpart of a bound
        method is in ``allowed_calls``.
        """
        if self.allow_any_calls:
            return True

        if func in self.allowed_calls:
            return True

        owner_method = _unbind_method(func)

        # explicit boolean instead of falling through to an implicit `None`
        return bool(owner_method and owner_method in self.allowed_calls)
123
124
def _get_external(module_name: str, access_path: Sequence[str]):
    """Get value from external module given a dotted access path.

    Only gets value if the module is already imported (it is looked up in
    ``sys.modules`` and never imported here).

    Raises:
     * `KeyError` if the module is not found, and
     * `AttributeError` if access path does not match an exported object
    """
    try:
        # standard module: walk the attributes one by one
        resolved = sys.modules[module_name]
        for part in access_path:
            resolved = getattr(resolved, part)
    except (KeyError, AttributeError):
        # handle modules in namespace packages: the whole dotted path
        # may itself be registered as a module
        dotted_path = ".".join([module_name, *access_path])
        if dotted_path in sys.modules:
            return sys.modules[dotted_path]
        raise
    else:
        return resolved
146
147
def _has_original_dunder_external(
    value,
    module_name: str,
    access_path: Sequence[str],
    method_name: str,
):
    """Check ``value`` against one allow-listed external module or class.

    The external object is resolved with `_get_external` and therefore only
    considered when it is already present in ``sys.modules``.

    Returns ``True`` when:

    - the type of ``value`` is the external object itself,
    - the external object is a module and the type of ``value`` is defined
      in that module or in one of its submodules, or
    - ``value`` is an instance of the external class and its type exposes
      the same ``method_name`` attribute as the external class
      (not attempted for ``__getattribute__``).

    Returns ``False`` otherwise, including when the external object
    cannot be resolved.
    """
    if module_name not in sys.modules:
        full_module_path = ".".join([module_name, *access_path])
        if full_module_path not in sys.modules:
            # LBYL as it is faster
            return False
    try:
        member_type = _get_external(module_name, access_path)
        value_type = type(value)
        if value_type == member_type:
            return True
        if isinstance(member_type, ModuleType):
            value_module = getmodule(value_type)
            if not value_module or not value_module.__name__:
                return False
            allowed_name = member_type.__name__
            defining_name = value_module.__name__
            # Match the module itself or a true submodule; a bare prefix
            # test would also accept unrelated modules that merely share
            # a prefix (e.g. `numpy_evil` for `numpy`).
            if defining_name == allowed_name or defining_name.startswith(
                allowed_name + "."
            ):
                return True
        if method_name == "__getattribute__":
            # we have to short-circuit here due to an unresolved issue in
            # `isinstance` implementation: https://bugs.python.org/issue32683
            return False
        if not isinstance(member_type, ModuleType) and isinstance(value, member_type):
            method = getattr(value_type, method_name, None)
            member_method = getattr(member_type, method_name, None)
            if member_method == method:
                return True
    except (AttributeError, KeyError):
        return False
    return False
181
182
def _has_original_dunder(
    value, allowed_types, allowed_methods, allowed_external, method_name
):
    """Check whether ``method_name`` on the type of ``value`` is allow-listed.

    Returns ``True`` when the type itself is in ``allowed_types``, when the
    method found on the type is in ``allowed_methods``, or when any entry of
    ``allowed_external`` accepts the value. Returns ``None`` when the type
    does not have the method at all, and ``False`` otherwise.
    """
    # note: Python ignores `__getattr__`/`__getitem__` on instances,
    # we only need to check at class level
    cls = type(value)

    # strict type check passes → no need to check method
    if cls in allowed_types:
        return True

    dunder = getattr(cls, method_name, None)
    if dunder is None:
        return None

    if dunder in allowed_methods:
        return True

    return any(
        _has_original_dunder_external(value, module_name, access_path, method_name)
        for module_name, *access_path in allowed_external
    )
207
208
209def _coerce_path_to_tuples(
210 allow_list: set[tuple[str, ...] | str],
211) -> set[tuple[str, ...]]:
212 """Replace dotted paths on the provided allow-list with tuples."""
213 return {
214 path if isinstance(path, tuple) else tuple(path.split("."))
215 for path in allow_list
216 }
217
218
@undoc
@dataclass
class SelectivePolicy(EvaluationPolicy):
    """Evaluation policy allowing operations only on allow-listed types.

    For each kind of access (item, attribute, operator) there is a set of
    allowed types and a set of "external" paths, given as dotted strings or
    tuples, which are resolved lazily from already imported modules.
    """

    allowed_getitem: set[InstancesHaveGetItem] = field(default_factory=set)
    allowed_getitem_external: set[tuple[str, ...] | str] = field(default_factory=set)

    allowed_getattr: set[MayHaveGetattr] = field(default_factory=set)
    allowed_getattr_external: set[tuple[str, ...] | str] = field(default_factory=set)

    allowed_operations: set = field(default_factory=set)
    allowed_operations_external: set[tuple[str, ...] | str] = field(default_factory=set)

    allow_getitem_on_types: bool = field(default_factory=bool)

    # per-dunder cache of the methods collected from `allowed_operations`
    _operation_methods_cache: dict[str, set[Callable]] = field(
        default_factory=dict, init=False
    )

    def can_get_attr(self, value, attr):
        """Allow attribute access if attribute lookup of ``value`` was not modified.

        Both ``__getattribute__`` and (if present) ``__getattr__`` of the
        type of ``value`` need to be allow-listed. Properties are accepted
        only if the type is in ``allowed_getattr`` or if the property is the
        same object as the one found on an allow-listed external class.
        """
        allowed_getattr_external = _coerce_path_to_tuples(self.allowed_getattr_external)

        has_original_attribute = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattribute_methods,
            allowed_external=allowed_getattr_external,
            method_name="__getattribute__",
        )
        has_original_attr = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattr_methods,
            allowed_external=allowed_getattr_external,
            method_name="__getattr__",
        )

        accept = False

        # Many objects do not have `__getattr__`, this is fine.
        if has_original_attr is None and has_original_attribute:
            accept = True
        else:
            # Accept objects without modifications to `__getattr__` and `__getattribute__`
            accept = has_original_attr and has_original_attribute

        if accept:
            # We still need to check for overridden properties.

            value_class = type(value)
            if not hasattr(value_class, attr):
                return True

            class_attr_val = getattr(value_class, attr)
            is_property = isinstance(class_attr_val, property)

            if not is_property:
                return True

            # Properties in allowed types are ok (although we do not include any
            # properties in our default allow list currently).
            if type(value) in self.allowed_getattr:
                return True  # pragma: no cover

            # Properties in subclasses of allowed types may be ok if not changed.
            # Every allow-listed entry is consulted: deciding on the first one
            # would make the result depend on the iteration order of the set.
            for module_name, *access_path in allowed_getattr_external:
                try:
                    external_class = _get_external(module_name, access_path)
                    external_class_attr_val = getattr(external_class, attr)
                except (KeyError, AttributeError):
                    # not imported, or has no such attribute: try the next one
                    continue
                if class_attr_val == external_class_attr_val:
                    return True

        return False

    def can_get_item(self, value, item):
        """Allow accessing `__getitem__` of allow-listed instances unless it was modified."""
        allowed_getitem_external = _coerce_path_to_tuples(self.allowed_getitem_external)
        if self.allow_getitem_on_types:
            # e.g. Union[str, int] or Literal[True, 1]
            if isinstance(value, (typing._SpecialForm, typing._BaseGenericAlias)):
                return True
            # PEP 560 e.g. list[str]
            if isinstance(value, type) and hasattr(value, "__class_getitem__"):
                return True
        return _has_original_dunder(
            value,
            allowed_types=self.allowed_getitem,
            allowed_methods=self._getitem_methods,
            allowed_external=allowed_getitem_external,
            method_name="__getitem__",
        )

    def can_operate(self, dunders: tuple[str, ...], a, b=None):
        """Allow the operation if every dunder is allow-listed for each operand."""
        allowed_operations_external = _coerce_path_to_tuples(
            self.allowed_operations_external
        )
        objects = [a]
        if b is not None:
            objects.append(b)
        # generator: stop at the first operand/dunder pair which is rejected
        return all(
            _has_original_dunder(
                obj,
                allowed_types=self.allowed_operations,
                allowed_methods=self._operator_dunder_methods(dunder),
                allowed_external=allowed_operations_external,
                method_name=dunder,
            )
            for dunder in dunders
            for obj in objects
        )

    def _operator_dunder_methods(self, dunder: str) -> set[Callable]:
        """Return (and cache) implementations of ``dunder`` on allowed operation types."""
        if dunder not in self._operation_methods_cache:
            self._operation_methods_cache[dunder] = self._safe_get_methods(
                self.allowed_operations, dunder
            )
        return self._operation_methods_cache[dunder]

    @cached_property
    def _getitem_methods(self) -> set[Callable]:
        return self._safe_get_methods(self.allowed_getitem, "__getitem__")

    @cached_property
    def _getattr_methods(self) -> set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattr__")

    @cached_property
    def _getattribute_methods(self) -> set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattribute__")

    def _safe_get_methods(self, classes, name) -> set[Callable]:
        """Collect attribute ``name`` from ``classes``, skipping missing/falsy ones."""
        return {
            method
            for class_ in classes
            for method in [getattr(class_, name, None)]
            if method
        }
358
359
class _DummyNamedTuple(NamedTuple):
    """Field-less named tuple; internal handle on the methods of named tuple instances."""
362
363
# Accepted values of `EvaluationContext.evaluation`; `guarded_eval` treats
# "forbidden" (no evaluation) and "dangerous" (plain `eval`) specially.
EvaluationPolicyName = Literal["forbidden", "minimal", "limited", "unsafe", "dangerous"]
365
366
@dataclass
class EvaluationContext:
    """Namespaces and settings under which a piece of code is evaluated."""

    #: Local namespace
    locals: dict
    #: Global namespace
    globals: dict
    #: Evaluation policy identifier
    evaluation: EvaluationPolicyName = "forbidden"
    #: Whether the evaluation of code takes place inside of a subscript.
    #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
    in_subscript: bool = False
    #: Auto import method
    # The parameter list of `Callable` has to be wrapped in a list.
    # NOTE(review): `_eval_node_name` passes a single name (``str``) to this
    # callable - confirm whether the argument type should be ``str``.
    auto_import: Callable[[list[str]], ModuleType] | None = None
    #: Overrides for evaluation policy
    policy_overrides: dict = field(default_factory=dict)
    #: Transient local namespace used to store mocks
    transient_locals: dict = field(default_factory=dict)

    def replace(self, /, **changes):
        """Return a new copy of the context, with specified changes"""
        return dataclasses.replace(self, **changes)
388
389
class _IdentitySubscript:
    """Subscriptable helper whose ``obj[key]`` evaluates to ``key`` itself."""

    def __getitem__(self, key):
        """Hand back whatever was passed between the brackets."""
        return key
395
396
# Shared instance which `guarded_eval` exposes under `SUBSCRIPT_MARKER`
# when evaluating code inside of a subscript.
IDENTITY_SUBSCRIPT = _IdentitySubscript()
# Name under which `IDENTITY_SUBSCRIPT` is injected into the local namespace.
SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
# Empty signature used when `inspect.signature` cannot provide one.
UNKNOWN_SIGNATURE = Signature()
# Sentinel meaning that a value (return type or duck) could not be produced.
NOT_EVALUATED = object()
401
402
class GuardRejection(Exception):
    """Raised when the guard refuses to evaluate the requested code."""
407
408
def guarded_eval(code: str, context: EvaluationContext):
    """Evaluate provided code in the evaluation context.

    Depending on the evaluation policy named by the context:

    - ``forbidden``: nothing is evaluated, :class:`GuardRejection` is raised,
    - ``dangerous``: the code is passed to the standard :func:`eval`,
    - any other policy: the code is parsed and handed to :func:`eval_node`.

    When ``context.in_subscript`` is set, empty code evaluates to an empty
    tuple; otherwise the code is evaluated as the content of a subscript.
    """
    mode = context.evaluation
    if mode == "forbidden":
        raise GuardRejection("Forbidden mode")

    # note: not using `ast.literal_eval` as it does not implement
    # getitem at all, for example it fails on simple `[0][1]`

    if context.in_subscript:
        # slice syntax (`:`) is only valid inside of a subscript, so the code
        # gets wrapped in a subscript of a marker object whose `__getitem__`
        # returns the key unchanged; the marker lives in a copy of locals
        if not code:
            return ()
        namespace = context.locals.copy()
        namespace[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
        code = "".join((SUBSCRIPT_MARKER, "[", code, "]"))
        context = context.replace(locals=namespace)

    if mode == "dangerous":
        # unrestricted evaluation, explicitly requested by the policy name
        return eval(code, context.globals, context.locals)

    tree = ast.parse(code, mode="exec")
    return eval_node(tree, context)
443
444
# Binary operator AST node type → dunder methods implementing the operator;
# `eval_node` calls the first listed dunder on the left operand.
BINARY_OP_DUNDERS: dict[type[ast.operator], tuple[str, ...]] = {
    ast.Add: ("__add__",),
    ast.Sub: ("__sub__",),
    ast.Mult: ("__mul__",),
    ast.Div: ("__truediv__",),
    ast.FloorDiv: ("__floordiv__",),
    ast.Mod: ("__mod__",),
    ast.Pow: ("__pow__",),
    ast.LShift: ("__lshift__",),
    ast.RShift: ("__rshift__",),
    ast.BitOr: ("__or__",),
    ast.BitXor: ("__xor__",),
    ast.BitAnd: ("__and__",),
    ast.MatMult: ("__matmul__",),
}

# Comparison AST node type → dunders checked by the policy; `eval_node`
# runs the first one through the function of that name in `operator`.
COMP_OP_DUNDERS: dict[type[ast.cmpop], tuple[str, ...]] = {
    ast.Eq: ("__eq__",),
    ast.NotEq: ("__ne__", "__eq__"),
    ast.Lt: ("__lt__", "__gt__"),
    ast.LtE: ("__le__", "__ge__"),
    ast.Gt: ("__gt__", "__lt__"),
    ast.GtE: ("__ge__", "__le__"),
    ast.In: ("__contains__",),
    # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
}

# Unary operator AST node type → dunders checked by the policy; `eval_node`
# calls the first listed dunder on the operand.
UNARY_OP_DUNDERS: dict[type[ast.unaryop], tuple[str, ...]] = {
    ast.USub: ("__neg__",),
    ast.UAdd: ("__pos__",),
    # we have to check both __inv__ and __invert__!
    ast.Invert: ("__invert__", "__inv__"),
    # NOTE(review): `__not__` does not look like a method objects normally
    # define - confirm that `not x` can be evaluated through this entry.
    ast.Not: ("__not__",),
}
479
480
class ImpersonatingDuck:
    """Placeholder whose instances get re-labelled as another class.

    Lets code obtain an object of a given class without running that class'
    ``__init__``: an instance is created and its ``__class__`` reassigned.
    """

    # intentionally empty: impersonation happens by overriding `__class__`
485
486
class _Duck:
    """A dummy class used to create objects pretending to have given attributes

    ``attributes`` maps attribute names to the values returned on attribute
    access; ``items`` maps keys to the values returned on subscription.
    """

    def __init__(self, attributes: Optional[dict] = None, items: Optional[dict] = None):
        self.attributes = attributes if attributes is not None else {}
        self.items = items if items is not None else {}

    def __getattr__(self, attr: str):
        # Read the mapping from the instance dict: going through
        # `self.attributes` would re-enter `__getattr__` forever on an
        # instance which was created without running `__init__`.
        attributes = self.__dict__.get("attributes", {})
        try:
            return attributes[attr]
        except KeyError:
            # `__getattr__` must signal a missing attribute with
            # `AttributeError`, otherwise `hasattr` and three-argument
            # `getattr` do not work on ducks.
            raise AttributeError(attr) from None

    def __hasattr__(self, attr: str):
        return attr in self.attributes

    def __dir__(self):
        # NOTE(review): `dir(super)` lists attributes of the `super` type
        # itself - confirm whether `super().__dir__()` was intended.
        return [*dir(super), *self.attributes]

    def __getitem__(self, key: str):
        return self.items[key]

    def __hasitem__(self, key: str):
        # membership test, consistent with `__hasattr__` (previously this
        # returned the item itself and raised `KeyError` for missing keys)
        return key in self.items

    def _ipython_key_completions_(self):
        return self.items.keys()
511
512
def _find_dunder(node_op, dunders) -> Union[tuple[str, ...], None]:
    """Look up dunder names for an operator node.

    ``dunders`` maps AST operator types to tuples of dunder names. Returns
    the tuple of the last entry that ``node_op`` is an instance of, or
    ``None`` if there is no such entry.
    """
    matching = [
        names for op_type, names in dunders.items() if isinstance(node_op, op_type)
    ]
    return matching[-1] if matching else None
519
520
def get_policy(context: EvaluationContext) -> EvaluationPolicy:
    """Return a copy of the context's policy with its overrides applied.

    The policy registered under ``context.evaluation`` is shallow-copied;
    overrides naming an attribute which the policy does not have are
    ignored.
    """
    base_policy = EVALUATION_POLICIES[context.evaluation]
    policy = copy(base_policy)

    applicable = (
        (name, override)
        for name, override in context.policy_overrides.items()
        if hasattr(policy, name)
    )
    for name, override in applicable:
        setattr(policy, name, override)
    return policy
528
529
def _validate_policy_overrides(
    policy_name: EvaluationPolicyName, policy_overrides: dict
) -> bool:
    """Warn about overrides which the named policy does not support.

    Returns ``True`` if every key of ``policy_overrides`` is an attribute of
    the policy registered under ``policy_name``, ``False`` otherwise.
    """
    policy = EVALUATION_POLICIES[policy_name]

    unsupported = [key for key in policy_overrides if not hasattr(policy, key)]
    for key in unsupported:
        warnings.warn(
            f"Override {key!r} is not valid with {policy_name!r} evaluation policy"
        )
    return not unsupported
543
544
def _assign_to_target(target: ast.expr, value, transient_locals: dict) -> None:
    """Bind ``value`` to one assignment target inside ``transient_locals``.

    Plain names are stored directly; tuple/list targets are unpacked
    (recursively, with support for one starred element). Any other target,
    such as an attribute or a subscript, is skipped: binding it would
    modify a real object instead of the transient namespace.
    """
    if isinstance(target, ast.Name):
        transient_locals[target.id] = value
        return
    if not isinstance(target, (ast.Tuple, ast.List)):
        return

    # Handle unpacking assignment
    values = list(value)
    targets = target.elts
    starred = [i for i, t in enumerate(targets) if isinstance(t, ast.Starred)]

    # Unified handling: treat no starred as starred at end
    star_or_last_idx = starred[0] if starred else len(targets)

    # Before starred
    for i in range(star_or_last_idx):
        _assign_to_target(targets[i], values[i], transient_locals)

    if starred:
        # Starred element takes everything not claimed by trailing targets
        end = len(values) - (len(targets) - star_or_last_idx - 1)
        _assign_to_target(
            targets[star_or_last_idx].value,
            values[star_or_last_idx:end],
            transient_locals,
        )

        # After starred
        for i in range(star_or_last_idx + 1, len(targets)):
            _assign_to_target(
                targets[i],
                values[len(values) - (len(targets) - i)],
                transient_locals,
            )


def _handle_assign(node: ast.Assign, context: EvaluationContext):
    """Record the result of an assignment in ``context.transient_locals``.

    The right-hand side is evaluated once with `eval_node` and bound to every
    target of the statement. Returns ``None``.
    """
    value = eval_node(node.value, context)
    transient_locals = context.transient_locals
    for target in node.targets:
        _assign_to_target(target, value, transient_locals)
    return None
577
578
def _extract_args_and_kwargs(node: ast.Call, context: EvaluationContext):
    """Evaluate the arguments of a call node.

    Returns a ``(args, kwargs)`` pair: the list of evaluated positional
    arguments and the dict of evaluated keyword arguments. Keywords without
    a name (``**mapping``) contribute the items of the evaluated mapping.
    """
    args = [eval_node(positional, context) for positional in node.args]

    kwargs = {}
    for keyword in node.keywords:
        evaluated = eval_node(keyword.value, context)
        if keyword.arg:
            kwargs[keyword.arg] = evaluated
        else:
            # `**mapping`: merge the items of the evaluated mapping
            for key, item in evaluated.items():
                kwargs[key] = item
    return args, kwargs
591
592
def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
    """Evaluate AST node in provided context.

    Applies evaluation restrictions defined in the context.

    Definitions and assignments are not executed; instead a stand-in is
    recorded in ``context.transient_locals``:

    - class definitions (``class sth: ...``): a dummy class built from the
      evaluated bases and the mocked class body
    - function definitions (``def sth: ...``): a dummy function carrying
      only the return annotation (or, for properties, the resolved
      return annotation itself)
    - variable assignments (``x = 1``) and simple annotated assignments

    The following evaluate to ``None`` without any effect:

    - augmented assignments (``x += 1``)
    - deletions (``del x``)
    - pass (``pass``)
    - imports (``import x``)
    - ``global`` and ``nonlocal`` statements
    - any node type not handled here, including control flow:

      - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
      - loops (``for`` and ``while``)
      - exception handling

    For assertions (``assert x, msg``) the message is evaluated if present,
    otherwise the test.

    Raises :class:`GuardRejection` when the policy does not allow
    an operation, subscript, attribute access or call.

    The purpose of this function is to guard against unwanted side-effects;
    it does not give guarantees on protection from malicious code execution.
    """
    policy = get_policy(context)

    if node is None:
        return None
    if isinstance(node, (ast.Interactive, ast.Module)):
        result = None
        for child_node in node.body:
            result = eval_node(child_node, context)
        return result
    if isinstance(node, ast.FunctionDef):
        # we ignore body and only extract the return type
        is_property = False

        for decorator_node in node.decorator_list:
            try:
                decorator = eval_node(decorator_node, context)
            except NameError:
                # if the decorator is not yet defined this is fine
                # especially because we don't handle imports yet
                continue
            if decorator is property:
                is_property = True

        return_type = eval_node(node.returns, context=context)

        if is_property:
            context.transient_locals[node.name] = _resolve_annotation(
                return_type, context
            )
            return None

        def dummy_function(*args, **kwargs):
            pass

        dummy_function.__annotations__["return"] = return_type
        dummy_function.__name__ = node.name
        # keep the definition node: argument annotations are read from it later
        dummy_function.__node__ = node
        context.transient_locals[node.name] = dummy_function
        return None
    if isinstance(node, ast.ClassDef):
        # TODO support class decorators?
        class_locals = {}
        class_context = context.replace(transient_locals=class_locals)
        for child_node in node.body:
            eval_node(child_node, class_context)
        bases = tuple([eval_node(base, context) for base in node.bases])
        dummy_class = type(node.name, bases, class_locals)
        context.transient_locals[node.name] = dummy_class
        return None
    if isinstance(node, ast.Assign):
        return _handle_assign(node, context)
    if isinstance(node, ast.AnnAssign):
        if not node.simple:
            # for now only handle simple annotations
            return None
        context.transient_locals[node.target.id] = _resolve_annotation(
            eval_node(node.annotation, context), context
        )
        return None
    if isinstance(node, ast.Expression):
        return eval_node(node.body, context)
    if isinstance(node, ast.Expr):
        return eval_node(node.value, context)
    if isinstance(node, ast.Pass):
        return None
    if isinstance(node, ast.Import):
        # TODO: populate transient_locals
        return None
    if isinstance(node, (ast.AugAssign, ast.Delete)):
        return None
    if isinstance(node, (ast.Global, ast.Nonlocal)):
        return None
    if isinstance(node, ast.BinOp):
        left = eval_node(node.left, context)
        right = eval_node(node.right, context)
        dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
        if dunders:
            if policy.can_operate(dunders, left, right):
                return getattr(left, dunders[0])(right)
            else:
                raise GuardRejection(
                    f"Operation (`{dunders}`) for",
                    type(left),
                    f"not allowed in {context.evaluation} mode",
                )
    if isinstance(node, ast.Compare):
        left = eval_node(node.left, context)
        all_true = True
        for op, right in zip(node.ops, node.comparators):
            right = eval_node(right, context)
            # Negation applies to a single operator; it has to be reset for
            # every link of a chain such as ``a not in b == c``.
            negate = False
            dunder = None
            dunders = _find_dunder(op, COMP_OP_DUNDERS)
            if not dunders:
                if isinstance(op, ast.NotIn):
                    dunders = COMP_OP_DUNDERS[ast.In]
                    negate = True
                if isinstance(op, ast.Is):
                    dunder = "is_"
                if isinstance(op, ast.IsNot):
                    dunder = "is_"
                    negate = True
            if not dunder and dunders:
                dunder = dunders[0]
            if dunder:
                # `x in container` is `operator.__contains__(container, x)`
                a, b = (right, left) if dunder == "__contains__" else (left, right)
                if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
                    result = getattr(operator, dunder)(a, b)
                    if negate:
                        result = not result
                    if not result:
                        all_true = False
                    left = right
                else:
                    raise GuardRejection(
                        f"Comparison (`{dunder}`) for",
                        type(left),
                        f"not allowed in {context.evaluation} mode",
                    )
            else:
                raise ValueError(
                    f"Comparison `{dunder}` not supported"
                )  # pragma: no cover
        return all_true
    if isinstance(node, ast.Constant):
        return node.value
    if isinstance(node, ast.Tuple):
        return tuple(eval_node(e, context) for e in node.elts)
    if isinstance(node, ast.List):
        return [eval_node(e, context) for e in node.elts]
    if isinstance(node, ast.Set):
        return {eval_node(e, context) for e in node.elts}
    if isinstance(node, ast.Dict):
        return dict(
            zip(
                [eval_node(k, context) for k in node.keys],
                [eval_node(v, context) for v in node.values],
            )
        )
    if isinstance(node, ast.Slice):
        return slice(
            eval_node(node.lower, context),
            eval_node(node.upper, context),
            eval_node(node.step, context),
        )
    if isinstance(node, ast.UnaryOp):
        value = eval_node(node.operand, context)
        dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
        if dunders:
            if policy.can_operate(dunders, value):
                try:
                    return getattr(value, dunders[0])()
                except AttributeError:
                    raise TypeError(
                        f"bad operand type for unary {node.op}: {type(value)}"
                    )
            else:
                raise GuardRejection(
                    f"Operation (`{dunders}`) for",
                    type(value),
                    f"not allowed in {context.evaluation} mode",
                )
    if isinstance(node, ast.Subscript):
        value = eval_node(node.value, context)
        slice_ = eval_node(node.slice, context)
        if policy.can_get_item(value, slice_):
            return value[slice_]
        raise GuardRejection(
            "Subscript access (`__getitem__`) for",
            type(value),  # not joined to avoid calling `repr`
            f" not allowed in {context.evaluation} mode",
        )
    if isinstance(node, ast.Name):
        return _eval_node_name(node.id, context)
    if isinstance(node, ast.Attribute):
        value = eval_node(node.value, context)
        if policy.can_get_attr(value, node.attr):
            return getattr(value, node.attr)
        raise GuardRejection(
            "Attribute access (`__getattr__`) for",
            type(value),  # not joined to avoid calling `repr`
            f"not allowed in {context.evaluation} mode",
        )
    if isinstance(node, ast.IfExp):
        test = eval_node(node.test, context)
        if test:
            return eval_node(node.body, context)
        else:
            return eval_node(node.orelse, context)
    if isinstance(node, ast.Call):
        func = eval_node(node.func, context)
        if policy.can_call(func):
            args, kwargs = _extract_args_and_kwargs(node, context)
            return func(*args, **kwargs)
        if isclass(func):
            # this code path gets entered when calling class e.g. `MyClass()`
            # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
            # Should return `MyClass` if `__new__` is not overridden,
            # otherwise whatever `__new__` return type is.
            overridden_return_type = _eval_return_type(func.__new__, node, context)
            if overridden_return_type is not NOT_EVALUATED:
                return overridden_return_type
            return _create_duck_for_heap_type(func)
        else:
            return_type = _eval_return_type(func, node, context)
            if return_type is not NOT_EVALUATED:
                return return_type
        raise GuardRejection(
            "Call for",
            func,  # not joined to avoid calling `repr`
            f"not allowed in {context.evaluation} mode",
        )
    if isinstance(node, ast.Assert):
        # message is always the second item, so if it is defined user would be completing
        # on the message, not on the assertion test
        if node.msg:
            return eval_node(node.msg, context)
        return eval_node(node.test, context)
    return None
839
840
def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
    """Evaluate return type of a given callable function.

    The return annotation is read from the signature of ``func`` and
    resolved with `_resolve_annotation`.

    Returns the built-in type, a duck or NOT_EVALUATED sentinel. The sentinel
    is returned when there is no return annotation or when no signature can
    be obtained for ``func``.
    """
    try:
        sig = signature(func)
    except (ValueError, TypeError):
        # `ValueError`: no signature can be provided;
        # `TypeError`: the object is of a kind `signature` does not
        # support (e.g. it is not callable at all)
        sig = UNKNOWN_SIGNATURE
    # if annotation was not stringized, or it was stringized
    # but resolved by signature call we know the return type
    not_empty = sig.return_annotation is not Signature.empty
    if not_empty:
        return _resolve_annotation(sig.return_annotation, context, sig, func, node)
    return NOT_EVALUATED
856
857
def _eval_annotation(
    annotation: str,
    context: EvaluationContext,
):
    """Resolve a string annotation by name; return other annotations as they are."""
    if not isinstance(annotation, str):
        return annotation
    return _eval_node_name(annotation, context)
867
868
class _GetItemDuck(dict):
    """Dictionary pretending to contain every key.

    Membership tests always succeed and subscription returns the result
    of calling the factory, no matter which key was requested.
    """

    def __init__(self, factory, *args, **kwargs):
        # remaining arguments are forwarded to the `dict` constructor
        self._factory = factory
        super().__init__(*args, **kwargs)

    def __getitem__(self, key):
        make_item = self._factory
        return make_item()

    def __contains__(self, key):
        return True
881
882
def _resolve_annotation(
    annotation: object | str,
    context: EvaluationContext,
    sig: Signature | None = None,
    func: Callable | None = None,
    node: ast.Call | None = None,
):
    """Resolve annotation created by user with `typing` module and custom objects.

    ``sig``, ``func`` and ``node`` describe the call whose return annotation
    is being resolved; they are needed for ``Self`` and ``AnyStr`` only.

    Returns a value (or a duck) standing for an object of the annotated type.
    ``None`` is returned for a ``None`` annotation, for a ``Literal`` with
    more than one argument and for an ``AnyStr`` which cannot be linked to
    an argument of the call.
    """
    # function-scope import: `UnionType` is not among the module-level imports
    from types import UnionType

    if annotation is None:
        return None
    annotation = _eval_annotation(annotation, context)
    origin = get_origin(annotation)
    if annotation is Self and func and hasattr(func, "__self__"):
        return func.__self__
    elif origin is Literal:
        type_args = get_args(annotation)
        if len(type_args) == 1:
            return type_args[0]
    elif annotation is LiteralString:
        return ""
    elif annotation is AnyStr:
        # the result has the type of the first argument annotated as `AnyStr`
        index = None
        if func and hasattr(func, "__node__"):
            def_node = func.__node__
            for i, arg in enumerate(def_node.args.args):
                # only annotations given as plain names can be looked up;
                # missing and compound annotations (e.g. `list[str]`) are skipped
                if not isinstance(arg.annotation, ast.Name):
                    continue
                arg_annotation = _eval_annotation(arg.annotation.id, context)
                if arg_annotation is AnyStr:
                    index = i
                    break
            is_bound_method = (
                isinstance(func, MethodType) and getattr(func, "__self__") is not None
            )
            if index and is_bound_method:
                # the first parameter of the definition is not part of the call
                index -= 1
        elif sig:
            for i, (key, value) in enumerate(sig.parameters.items()):
                if value.annotation is AnyStr:
                    index = i
                    break
        if index is None:
            return None
        if node is None or index < 0 or index >= len(node.args):
            return None
        return eval_node(node.args[index], context)
    elif origin is TypeGuard:
        return False
    elif origin is set or origin is list:
        # only one type argument allowed
        attributes = [
            attr
            for attr in dir(
                _resolve_annotation(get_args(annotation)[0], context, sig, func, node)
            )
        ]
        duck = _Duck(attributes=dict.fromkeys(attributes))
        return _Duck(
            attributes=dict.fromkeys(dir(origin())),
            # items are not strictly needed for set
            items=_GetItemDuck(lambda: duck),
        )
    elif origin is tuple:
        # multiple type arguments
        return tuple(
            _resolve_annotation(arg, context, sig, func, node)
            for arg in get_args(annotation)
        )
    elif origin is Union or origin is UnionType:
        # multiple type arguments; `UnionType` is the origin of `X | Y`
        attributes = [
            attr
            for type_arg in get_args(annotation)
            for attr in dir(_resolve_annotation(type_arg, context, sig, func, node))
        ]
        return _Duck(attributes=dict.fromkeys(attributes))
    elif is_typeddict(annotation):
        return _Duck(
            attributes=dict.fromkeys(dir(dict())),
            items={
                k: _resolve_annotation(v, context, sig, func, node)
                for k, v in annotation.__annotations__.items()
            },
        )
    elif hasattr(annotation, "_is_protocol"):
        return _Duck(attributes=dict.fromkeys(dir(annotation)))
    elif origin is Annotated:
        type_arg = get_args(annotation)[0]
        return _resolve_annotation(type_arg, context, sig, func, node)
    elif isinstance(annotation, NewType):
        return _eval_or_create_duck(annotation.__supertype__, context)
    elif isinstance(annotation, TypeAliasType):
        return _eval_or_create_duck(annotation.__value__, context)
    else:
        return _eval_or_create_duck(annotation, context)
978
979
def _eval_node_name(node_id: str, context: EvaluationContext):
    """Look up a name in the namespaces which the policy permits.

    Lookup order: transient locals, locals, globals, builtins, and finally
    the auto-import hook of the context.

    Raises:
     * `GuardRejection` if the name was not found and the policy allows
       access to neither locals nor globals, and
     * `NameError` if the name was not found otherwise.
    """
    policy = get_policy(context)

    if node_id in context.transient_locals:
        return context.transient_locals[node_id]

    guarded_namespaces = (
        (policy.allow_locals_access, context.locals),
        (policy.allow_globals_access, context.globals),
    )
    for is_allowed, namespace in guarded_namespaces:
        if is_allowed and node_id in namespace:
            return namespace[node_id]

    if policy.allow_builtins_access and hasattr(builtins, node_id):
        # note: do not use __builtins__, it is implementation detail of cPython
        return getattr(builtins, node_id)

    if policy.allow_auto_import and context.auto_import:
        return context.auto_import(node_id)

    if policy.allow_globals_access or policy.allow_locals_access:
        raise NameError(f"{node_id} not found in locals, globals, nor builtins")
    raise GuardRejection(
        f"Namespace access not allowed in {context.evaluation} mode"
    )
999
1000
def _eval_or_create_duck(duck_type, context: EvaluationContext):
    """Produce an object standing for an instance of ``duck_type``.

    The type is called without arguments when the policy allows calling it;
    otherwise an imitation is created with `_create_duck_for_heap_type`.
    """
    may_instantiate = get_policy(context).can_call(duck_type)
    if not may_instantiate:
        # if custom class is in type annotation, mock it
        return _create_duck_for_heap_type(duck_type)
    # if allow-listed builtin is on type annotation, instantiate it
    return duck_type()
1008
1009
def _create_duck_for_heap_type(duck_type):
    """Create an imitation of an object of a given type (a duck).

    An `ImpersonatingDuck` is instantiated and its ``__class__`` is replaced
    with ``duck_type``, so ``__init__`` of ``duck_type`` is never called.

    Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
    """
    impostor = ImpersonatingDuck()
    try:
        # this only works for heap types, not builtins
        impostor.__class__ = duck_type
    except TypeError:
        return NOT_EVALUATED
    return impostor
1023
1024
# Third-party types identified by path, each stored as a tuple of path
# components, e.g. ``("pandas", "DataFrame")``. Passed as
# ``allowed_getitem_external`` to the "limited" policy.
SUPPORTED_EXTERNAL_GETITEM = {
    tuple(dotted_path.split("."))
    for dotted_path in (
        "pandas.core.indexing._iLocIndexer",
        "pandas.core.indexing._LocIndexer",
        "pandas.DataFrame",
        "pandas.Series",
        "numpy.ndarray",
        "numpy.void",
    )
}
1033
1034
# Types passed as ``allowed_getitem`` to the "limited" policy.
BUILTIN_GETITEM: set[InstancesHaveGetItem] = {
    # --- builtin containers and text types ---
    dict,
    list,
    tuple,
    str,  # type: ignore[arg-type]
    bytes,  # type: ignore[arg-type]
    # --- subscripted types ---
    type,  # for type annotations like list[str]
    # --- containers from ``collections`` ---
    collections.defaultdict,
    collections.deque,
    collections.OrderedDict,
    collections.ChainMap,
    collections.UserDict,
    collections.UserList,
    collections.UserString,  # type: ignore[arg-type]
    # --- helper types defined elsewhere in this module ---
    _Duck,
    _DummyNamedTuple,
    _IdentitySubscript,
}
1053
1054
def _list_methods(cls, source=None):
    """Return attributes of ``cls`` looked up by name, for building allow-lists.

    For use on immutable objects or with methods returning a copy.

    Parameters
    ----------
    cls
        The class whose attributes are fetched with ``getattr``.
    source
        Iterable of attribute names to fetch. When ``None`` (the default),
        every name reported by ``dir(cls)`` is used.

    Returns
    -------
    list
        ``getattr(cls, name)`` for each name, in iteration order.

    Raises
    ------
    AttributeError
        If a requested name does not exist on ``cls``.
    """
    # Compare against ``None`` rather than relying on truthiness: an
    # explicitly empty ``source`` must yield an empty list, not silently
    # fall back to *every* attribute of ``cls`` (which would be fail-open
    # for an allow-list).
    names = dir(cls) if source is None else source
    return [getattr(cls, name) for name in names]
1058
1059
# Method names hand-picked per container family.
dict_non_mutating_methods = (
    "copy",
    "keys",
    "values",
    "items",
)
list_non_mutating_methods = (
    "copy",
    "index",
    "count",
)
# Every attribute name that ``set`` has in common with ``frozenset``.
set_non_mutating_methods = set(dir(set)).intersection(dir(frozenset))
1063
1064
# The concrete class of the view object returned by ``dict.keys()``,
# obtained from an instance at runtime.
dict_keys: type[collections.abc.KeysView] = type(dict().keys())

# Builtin numeric types.
NUMERICS = {
    int,
    float,
    complex,
}
1068
def _build_allowed_calls() -> set:
    """Assemble the set of callables passed as ``allowed_calls`` to the "limited" policy.

    The set contains classes (i.e. their constructors) together with
    attributes fetched from them through ``_list_methods``.
    """
    # Classes allow-listed together with every attribute that ``dir()``
    # reports for them.
    fully_allowed = (bytes, frozenset, str, tuple, bool, *NUMERICS)

    # Classes allow-listed together with only the named subset of attributes.
    partially_allowed = (
        (dict, dict_non_mutating_methods),
        (list, list_non_mutating_methods),
        (set, set_non_mutating_methods),
        (collections.deque, list_non_mutating_methods),
        (collections.defaultdict, dict_non_mutating_methods),
        (collections.OrderedDict, dict_non_mutating_methods),
        (collections.UserDict, dict_non_mutating_methods),
        (collections.UserList, list_non_mutating_methods),
        # attribute names are taken from ``str``, looked up on ``UserString``
        (collections.UserString, dir(str)),
        (collections.Counter, dict_non_mutating_methods),
    )

    # Entries that do not follow the "class + its methods" pattern.
    allowed = {
        range,
        dict_keys.isdisjoint,
        collections.Counter.elements,
        collections.Counter.most_common,
    }

    for cls in fully_allowed:
        allowed.add(cls)
        allowed.update(_list_methods(cls))

    for cls, method_names in partially_allowed:
        allowed.add(cls)
        allowed.update(_list_methods(cls, method_names))

    return allowed


ALLOWED_CALLS = _build_allowed_calls()
1107
# Types passed as ``allowed_getattr`` to the "limited" policy: everything
# in ``BUILTIN_GETITEM`` plus the types listed below.
BUILTIN_GETATTR: set[MayHaveGetattr] = {
    # everything that is subscriptable is also attribute-accessible
    *BUILTIN_GETITEM,
    # numeric types
    *NUMERICS,
    # set-like types and dictionary key views
    set,
    frozenset,
    dict_keys,
    # generic fallbacks
    object,
    type,  # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
    # descriptors and modules
    MethodDescriptorType,
    ModuleType,
}
1119
1120
# Independent copy of ``BUILTIN_GETATTR`` (a new set object with the same
# members); passed as ``allowed_operations`` to the "limited" policy.
BUILTIN_OPERATIONS = set(BUILTIN_GETATTR)
1122
# Registry of evaluation policies, keyed by policy name.
EVALUATION_POLICIES = dict(
    # Only builtins are reachable; no namespaces, attributes, items,
    # calls or operations.
    minimal=EvaluationPolicy(
        allow_builtins_access=True,
        allow_locals_access=False,
        allow_globals_access=False,
        allow_attr_access=False,
        allow_item_access=False,
        allow_any_calls=False,
        allowed_calls=set(),
        allow_all_operations=False,
    ),
    # Namespaces are reachable; attribute/item access, operations and
    # calls are restricted to the allow-lists given here.
    limited=SelectivePolicy(
        allow_builtins_access=True,
        allow_locals_access=True,
        allow_globals_access=True,
        allow_getitem_on_types=True,
        allowed_getitem=BUILTIN_GETITEM,
        allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
        allowed_getattr=BUILTIN_GETATTR,
        allowed_getattr_external={
            # pandas Series/Frame implements custom `__getattr__`
            ("pandas", "DataFrame"),
            ("pandas", "Series"),
        },
        allowed_operations=BUILTIN_OPERATIONS,
        allowed_calls=ALLOWED_CALLS,
    ),
    # Every switch is turned on.
    unsafe=EvaluationPolicy(
        allow_builtins_access=True,
        allow_locals_access=True,
        allow_globals_access=True,
        allow_attr_access=True,
        allow_item_access=True,
        allow_any_calls=True,
        allow_all_operations=True,
    ),
)
1160
1161
# Names exported by ``from <this module> import *``.
# NOTE(review): ``_unbind_method`` is underscore-prefixed yet exported here;
# presumably deliberate (e.g. for use by tests) - confirm before removing.
__all__ = [
    "guarded_eval",
    "eval_node",
    "GuardRejection",
    "EvaluationContext",
    "_unbind_method",
]