1"""
2Operator classes for eval.
3"""
4
5from __future__ import annotations
6
7from datetime import datetime
8from functools import partial
9import operator
10from typing import (
11 Callable,
12 Iterable,
13 Iterator,
14 Literal,
15)
16
17import numpy as np
18
19from pandas._libs.tslibs import Timestamp
20
21from pandas.core.dtypes.common import (
22 is_list_like,
23 is_scalar,
24)
25
26import pandas.core.common as com
27from pandas.core.computation.common import (
28 ensure_decoded,
29 result_type_many,
30)
31from pandas.core.computation.scope import DEFAULT_GLOBALS
32
33from pandas.io.formats.printing import (
34 pprint_thing,
35 pprint_thing_encoded,
36)
37
# Reduction function names known to the expression machinery.
REDUCTIONS = ("sum", "prod", "min", "max")

# NumPy functions taking a single argument that may be called inside an
# expression (order is significant: it is preserved in MATHOPS).
_unary_math_ops = tuple(
    """
    sin cos exp log expm1 log1p sqrt
    sinh cosh tanh
    arcsin arccos arctan arccosh arcsinh arctanh
    abs log10 floor ceil
    """.split()
)

# NumPy functions taking two arguments.
_binary_math_ops = ("arctan2",)

# Every function name FuncNode accepts: unary names first, then binary.
MATHOPS = (*_unary_math_ops, *_binary_math_ops)


# Prefix marking a name as a local (``@variable``) reference.
LOCAL_TAG = "__pd_eval_local_"
68
69
class Term:
    """
    Leaf node of an eval expression: a name resolved against a scope.

    Instantiating ``Term`` with a non-string ``name`` actually produces a
    ``Constant`` (see ``__new__``). Because ``Constant`` subclasses ``Term``,
    Python then runs ``Constant.__init__`` on that instance.

    Parameters
    ----------
    name : str or object
        Name to look up in ``env``. For subclasses this may be something
        other than a str.
    env : object
        Presumably a ``Scope``. It must provide ``scope``, ``resolve`` and
        ``swapkey``, which are used below.
    side : optional
        Stored on the instance; not read anywhere in this class.
    encoding : optional
        Stored on the instance; not read anywhere in this class.
    """

    def __new__(cls, name, env, side=None, encoding=None):
        # Non-string "names" are literals, so build a Constant instead of cls.
        klass = Constant if not isinstance(name, str) else cls
        # error: Argument 2 for "super" not an instance of argument 1
        supr_new = super(Term, klass).__new__  # type: ignore[misc]
        return supr_new(klass)

    # True when str(name) starts with LOCAL_TAG or is in DEFAULT_GLOBALS
    # (computed in __init__).
    is_local: bool

    def __init__(self, name, env, side=None, encoding=None) -> None:
        # name is a str for Term, but may be something else for subclasses
        self._name = name
        self.env = env
        self.side = side
        tname = str(name)
        self.is_local = tname.startswith(LOCAL_TAG) or tname in DEFAULT_GLOBALS
        # is_local must be set first: _resolve_name reads it.
        self._value = self._resolve_name()
        self.encoding = encoding

    @property
    def local_name(self) -> str:
        """The name with every occurrence of LOCAL_TAG removed."""
        return self.name.replace(LOCAL_TAG, "")

    def __repr__(self) -> str:
        return pprint_thing(self.name)

    def __call__(self, *args, **kwargs):
        """Return the resolved value; all arguments are ignored."""
        return self.value

    def evaluate(self, *args, **kwargs) -> Term:
        """A leaf needs no pre-evaluation; return ``self`` unchanged."""
        return self

    def _resolve_name(self):
        """
        Resolve ``local_name`` through ``self.env`` and return the result.

        The resolved object is also passed to ``update``, which may re-key
        the entry in ``env``.

        Raises
        ------
        NotImplementedError
            If the resolved object has an ``ndim`` greater than 2.
        """
        local_name = str(self.local_name)
        is_local = self.is_local
        # If the scope maps this name to a class object, resolve it
        # with is_local=False regardless of the tag.
        if local_name in self.env.scope and isinstance(
            self.env.scope[local_name], type
        ):
            is_local = False

        res = self.env.resolve(local_name, is_local=is_local)
        self.update(res)

        if hasattr(res, "ndim") and res.ndim > 2:
            raise NotImplementedError(
                "N-dimensional objects, where N > 2, are not supported with eval"
            )
        return res

    def update(self, value) -> None:
        """
        Set this term's value.

        When the term's name is a str, the entry in ``env`` is also re-keyed
        via ``env.swapkey(local_name, name, new_value=value)``.

        search order for local (i.e., @variable) variables:

        scope, key_variable
        [('locals', 'local_name'),
         ('globals', 'local_name'),
         ('locals', 'key'),
         ('globals', 'key')]
        """
        key = self.name

        # if it's a variable name (otherwise a constant)
        if isinstance(key, str):
            self.env.swapkey(self.local_name, key, new_value=value)

        self.value = value

    @property
    def is_scalar(self) -> bool:
        """Whether the resolved value is a scalar (per pandas ``is_scalar``)."""
        return is_scalar(self._value)

    @property
    def type(self):
        """
        Type of the resolved value.

        Tries ``value.values.dtype``, then ``value.dtype``, and finally falls
        back to the Python ``type`` of the value.
        """
        try:
            # potentially very slow for large, mixed dtype frames
            return self._value.values.dtype
        except AttributeError:
            try:
                # ndarray
                return self._value.dtype
            except AttributeError:
                # scalar
                return type(self._value)

    # Terms report their own type as their return type.
    return_type = type

    @property
    def raw(self) -> str:
        """Debug string with the class name, the term name and its type."""
        return f"{type(self).__name__}(name={repr(self.name)}, type={self.type})"

    @property
    def is_datetime(self) -> bool:
        """
        True if the term's type (or, for a dtype, its ``.type`` scalar class)
        subclasses ``datetime`` or ``np.datetime64``.
        """
        try:
            t = self.type.type
        except AttributeError:
            t = self.type

        return issubclass(t, (datetime, np.datetime64))

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value) -> None:
        self._value = new_value

    @property
    def name(self):
        return self._name

    @property
    def ndim(self) -> int:
        """``ndim`` of the value; raises AttributeError if it has none."""
        return self._value.ndim
184
185
class Constant(Term):
    """
    A literal value appearing in an expression.

    ``Term.__new__`` returns an instance of this class whenever it is given a
    non-string ``name``. The literal then serves as both the term's name and
    its value, and no scope lookup happens.
    """

    def __init__(self, value, env, side=None, encoding=None) -> None:
        super().__init__(value, env, side=side, encoding=encoding)

    def _resolve_name(self):
        # Nothing to look up: the stored "name" already is the value.
        return self._name

    @property
    def name(self):
        return self.value

    def __repr__(self) -> str:
        # repr() rather than str() keeps the literal's exact spelling
        # (e.g. quotes around strings).
        return f"{self.name!r}"
201
202
# Python keyword spellings of boolean operators -> the symbol stored on Op.op.
_bool_op_map = {"not": "~", "and": "&", "or": "|"}


class Op:
    """
    Hold an operator of arbitrary arity.

    Parameters
    ----------
    op : str
        Operator token. ``not``/``and``/``or`` are normalized to
        ``~``/``&``/``|`` before being stored on ``self.op``.
    operands : iterable of Term or Op
        Stored as given; iterating the Op iterates these operands.
    encoding : optional
        Stored on the instance for use by subclasses.
    """

    op: str

    def __init__(self, op: str, operands: Iterable[Term | Op], encoding=None) -> None:
        self.op = _bool_op_map.get(op, op)
        self.operands = operands
        self.encoding = encoding

    def __iter__(self) -> Iterator:
        return iter(self.operands)

    def __repr__(self) -> str:
        """
        Print a generic n-ary operator and its operands using infix notation.
        """
        # recurse over the operands
        parened = (f"({pprint_thing(opr)})" for opr in self.operands)
        return pprint_thing(f" {self.op} ".join(parened))

    @property
    def return_type(self):
        """
        Result type of the operation.

        Comparison and boolean operators always yield ``np.bool_``. Otherwise
        the type is ``result_type_many`` over the types of the flattened
        operand terms.
        """
        # clobber types to bool if the op is a boolean operator
        if self.op in (CMP_OPS_SYMS + BOOL_OPS_SYMS):
            return np.bool_
        return result_type_many(*(term.type for term in com.flatten(self)))

    @property
    def has_invalid_return_type(self) -> bool:
        """
        True when the result type is object but some operand is not object.

        Always returns a real ``bool``. The previous expression leaked the
        intermediate ``frozenset`` to callers.
        """
        if self.return_type != object:
            # Cheap check first: avoid flattening the operand tree.
            return False
        non_object_types = self.operand_types - frozenset([np.dtype("object")])
        return bool(non_object_types)

    @property
    def operand_types(self):
        """Frozenset of the ``type`` of every flattened operand term."""
        return frozenset(term.type for term in com.flatten(self))

    @property
    def is_scalar(self) -> bool:
        """True when every direct operand is scalar."""
        return all(operand.is_scalar for operand in self.operands)

    @property
    def is_datetime(self) -> bool:
        """
        True if the return type (or, for a dtype, its ``.type`` scalar class)
        subclasses ``datetime`` or ``np.datetime64``.
        """
        try:
            t = self.return_type.type
        except AttributeError:
            t = self.return_type

        return issubclass(t, (datetime, np.datetime64))
258
259
def _in(x, y):
    """
    Membership test for ``x in y``, vectorized where possible.

    The lookup order is:

    1. ``x.isin(y)``, when ``x`` supports ``isin``.
    2. ``y.isin(x)``, when ``x`` is list-like and ``y`` supports ``isin``.
       Note that this reports which elements of ``y`` occur in ``x``.
    3. Plain Python ``x in y``.
    """
    try:
        return x.isin(y)
    except AttributeError:
        pass

    if is_list_like(x):
        try:
            return y.isin(x)
        except AttributeError:
            pass

    return x in y
274
275
def _not_in(x, y):
    """
    Negated membership test for ``x not in y``, vectorized where possible.

    The lookup order mirrors ``_in``:

    1. ``~x.isin(y)``.
    2. ``~y.isin(x)`` when ``x`` is list-like.
    3. Plain Python ``x not in y``.
    """
    try:
        return ~x.isin(y)
    except AttributeError:
        pass

    if is_list_like(x):
        try:
            return ~y.isin(x)
        except AttributeError:
            pass

    return x not in y
290
291
# Comparison operators. ``in`` / ``not in`` dispatch to ``_in`` / ``_not_in``,
# which prefer a vectorized ``isin`` over Python's ``in``.
CMP_OPS_SYMS = (">", "<", ">=", "<=", "==", "!=", "in", "not in")
_cmp_ops_funcs = (
    operator.gt,
    operator.lt,
    operator.ge,
    operator.le,
    operator.eq,
    operator.ne,
    _in,
    _not_in,
)
_cmp_ops_dict = dict(zip(CMP_OPS_SYMS, _cmp_ops_funcs))

# Boolean operators. The keyword spellings share the bitwise implementations.
BOOL_OPS_SYMS = ("&", "|", "and", "or")
_bool_ops_funcs = (operator.and_, operator.or_, operator.and_, operator.or_)
_bool_ops_dict = dict(zip(BOOL_OPS_SYMS, _bool_ops_funcs))

# Arithmetic operators.
ARITH_OPS_SYMS = ("+", "-", "*", "/", "**", "//", "%")
_arith_ops_funcs = (
    operator.add,
    operator.sub,
    operator.mul,
    operator.truediv,
    operator.pow,
    operator.floordiv,
    operator.mod,
)
_arith_ops_dict = dict(zip(ARITH_OPS_SYMS, _arith_ops_funcs))

# Arithmetic operators kept in a separate "special case" table.
# NOTE(review): consumers live outside this section; confirm their role.
SPECIAL_CASE_ARITH_OPS_SYMS = ("**", "//", "%")
_special_case_arith_ops_funcs = (operator.pow, operator.floordiv, operator.mod)
_special_case_arith_ops_dict = dict(
    zip(SPECIAL_CASE_ARITH_OPS_SYMS, _special_case_arith_ops_funcs)
)

# Every binary operator token mapped to its implementation. The three tables
# have disjoint keys, so merging by unpacking keeps both content and order.
# Unpacking also avoids a module-level loop variable leaking into the
# module namespace.
_binary_ops_dict = {**_cmp_ops_dict, **_bool_ops_dict, **_arith_ops_dict}
331
332
def _cast_inplace(terms, acceptable_dtypes, dtype) -> None:
    """
    Cast every term of an expression to ``dtype``, in place.

    Parameters
    ----------
    terms : Op
        The expression that should cast.
    acceptable_dtypes : list of acceptable numpy.dtype
        Terms whose ``type`` is in this list are left untouched.
    dtype : str or numpy.dtype
        The dtype to cast to.

    Notes
    -----
    Values exposing ``astype`` are cast with it. Anything else (e.g. a plain
    Python scalar) is converted through the target dtype's scalar type. The
    new value is written back with ``term.update``.
    """
    target = np.dtype(dtype)
    for term in terms:
        if term.type not in acceptable_dtypes:
            try:
                casted = term.value.astype(target)
            except AttributeError:
                casted = target.type(term.value)
            term.update(casted)
356
357
def is_term(obj) -> bool:
    """Return True if *obj* is a ``Term``, including subclasses such as ``Constant``."""
    is_a_term = isinstance(obj, Term)
    return is_a_term
360
361
class BinOp(Op):
    """
    Hold a binary operator and its operands.

    On construction, the operands are checked for unsupported scalar-only
    boolean operations. Where one side is a datetime Term and the other a
    scalar Term, the scalar is converted to a Timestamp (``convert_values``).

    Parameters
    ----------
    op : str
        Binary operator token; must be a key of ``_binary_ops_dict``.
    lhs : Term or Op
    rhs : Term or Op

    Raises
    ------
    NotImplementedError
        From ``_disallow_scalar_only_bool_ops``.
    ValueError
        If ``op`` is not a known binary operator.
    """

    def __init__(self, op: str, lhs, rhs) -> None:
        super().__init__(op, (lhs, rhs))
        self.lhs = lhs
        self.rhs = rhs

        # Validation and datetime conversion run before the operator lookup,
        # so an unknown op is only reported once these steps succeed.
        self._disallow_scalar_only_bool_ops()

        self.convert_values()

        try:
            # Looked up with the raw token. Op.__init__ may have normalized
            # self.op (e.g. "and" -> "&"), but both spellings are keys here.
            self.func = _binary_ops_dict[op]
        except KeyError as err:
            # list() so the message shows the plain keys, not a dict_keys view
            keys = list(_binary_ops_dict.keys())
            raise ValueError(
                f"Invalid binary operator {repr(op)}, valid operators are {keys}"
            ) from err

    def __call__(self, env):
        """
        Recursively evaluate an expression in Python space.

        Parameters
        ----------
        env : Scope

        Returns
        -------
        object
            The result of an evaluated expression.
        """
        # recurse over the left/right nodes
        left = self.lhs(env)
        right = self.rhs(env)

        return self.func(left, right)

    def evaluate(self, env, engine: str, parser, term_type, eval_in_python):
        """
        Evaluate a binary operation *before* being passed to the engine.

        Parameters
        ----------
        env : Scope
        engine : str
        parser : str
        term_type : type
        eval_in_python : list
            Operators to compute directly in Python on the child values
            (only consulted when ``engine`` is not ``"python"``).

        Returns
        -------
        term_type
            The "pre-evaluated" expression as an instance of ``term_type``
        """
        if engine == "python":
            # Pure-Python engine: evaluate the whole subtree directly.
            res = self(env)
        else:
            # recurse over the left/right nodes

            left = self.lhs.evaluate(
                env,
                engine=engine,
                parser=parser,
                term_type=term_type,
                eval_in_python=eval_in_python,
            )

            right = self.rhs.evaluate(
                env,
                engine=engine,
                parser=parser,
                term_type=term_type,
                eval_in_python=eval_in_python,
            )

            # base cases
            # left/right are read only in the eval_in_python branch. Every
            # other operator is re-evaluated through pandas' eval with env
            # as the local dict.
            if self.op in eval_in_python:
                res = self.func(left.value, right.value)
            else:
                # Function-level import; presumably avoids a circular
                # import at module load (TODO confirm).
                from pandas.core.computation.eval import eval

                res = eval(self, local_dict=env, engine=engine, parser=parser)

        # Store the result as a temporary in env and wrap its generated name.
        name = env.add_tmp(res)
        return term_type(name, env=env)

    def convert_values(self) -> None:
        """
        Convert datetimes to a comparable value in an expression.

        When one operand is a datetime Term and the other a scalar Term, the
        scalar is replaced (via ``Term.update``) with a ``Timestamp``.
        int/float scalars are stringified first. Timezone-aware results are
        converted to UTC.
        """

        def stringify(value):
            # BinOp passes no encoding to Op.__init__, so self.encoding is
            # None here unless a subclass sets it.
            encoder: Callable
            if self.encoding is not None:
                encoder = partial(pprint_thing_encoded, encoding=self.encoding)
            else:
                encoder = pprint_thing
            return encoder(value)

        lhs, rhs = self.lhs, self.rhs

        if is_term(lhs) and lhs.is_datetime and is_term(rhs) and rhs.is_scalar:
            v = rhs.value
            # Numbers are turned into strings before Timestamp parsing,
            # presumably so e.g. 20130101 is read as a date string rather
            # than an epoch offset (TODO confirm).
            if isinstance(v, (int, float)):
                v = stringify(v)
            v = Timestamp(ensure_decoded(v))
            if v.tz is not None:
                v = v.tz_convert("UTC")
            self.rhs.update(v)

        # Mirror image of the branch above, with the datetime on the right.
        if is_term(rhs) and rhs.is_datetime and is_term(lhs) and lhs.is_scalar:
            v = lhs.value
            if isinstance(v, (int, float)):
                v = stringify(v)
            v = Timestamp(ensure_decoded(v))
            if v.tz is not None:
                v = v.tz_convert("UTC")
            self.lhs.update(v)

    def _disallow_scalar_only_bool_ops(self):
        """
        Reject boolean operators involving a scalar unless both sides are bool.

        Raises
        ------
        NotImplementedError
            If either operand is scalar, the operator is a boolean operator,
            and the two return types are not both bool / ``np.bool_``.
        """
        rhs = self.rhs
        lhs = self.lhs

        # GH#24883 unwrap dtype if necessary to ensure we have a type object
        rhs_rt = rhs.return_type
        rhs_rt = getattr(rhs_rt, "type", rhs_rt)
        lhs_rt = lhs.return_type
        lhs_rt = getattr(lhs_rt, "type", lhs_rt)
        if (
            (lhs.is_scalar or rhs.is_scalar)
            and self.op in _bool_ops_dict
            and (
                not (
                    issubclass(rhs_rt, (bool, np.bool_))
                    and issubclass(lhs_rt, (bool, np.bool_))
                )
            )
        ):
            raise NotImplementedError("cannot evaluate scalar only bool ops")
512
513
def isnumeric(dtype) -> bool:
    """Return True if *dtype* (anything ``np.dtype`` accepts) is a NumPy number type."""
    scalar_type = np.dtype(dtype).type
    return issubclass(scalar_type, np.number)
516
517
class Div(BinOp):
    """
    Div operator to special case casting.

    Both operands must have numeric return types. Terms that are not already
    float32/float64 are cast to float64 in place.

    Parameters
    ----------
    lhs, rhs : Term or Op
        The Terms or Ops in the ``/`` expression.

    Raises
    ------
    TypeError
        If either operand's return type is not numeric.
    """

    def __init__(self, lhs, rhs) -> None:
        super().__init__("/", lhs, rhs)

        if not isnumeric(lhs.return_type) or not isnumeric(rhs.return_type):
            raise TypeError(
                f"unsupported operand type(s) for {self.op}: "
                f"'{lhs.return_type}' and '{rhs.return_type}'"
            )

        # do not upcast float32s to float64 un-necessarily.
        # np.float64 is spelled out: the np.float_ alias was removed in
        # NumPy 2.0 and raises AttributeError there.
        acceptable_dtypes = [np.float32, np.float64]
        _cast_inplace(com.flatten(self), acceptable_dtypes, np.float64)
540
541
# Unary operator token -> implementation. "not" shares ``~``'s bitwise invert.
_unary_ops_dict = {
    "+": operator.pos,
    "-": operator.neg,
    "~": operator.invert,
    "not": operator.invert,
}
UNARY_OPS_SYMS = tuple(_unary_ops_dict)
_unary_ops_funcs = tuple(_unary_ops_dict.values())
545
546
class UnaryOp(Op):
    """
    Hold a unary operator and its operands.

    Parameters
    ----------
    op : str
        The token used to represent the operator.
    operand : Term or Op
        The Term or Op operand to the operator.

    Raises
    ------
    ValueError
        * If no function associated with the passed operator token is found.
    """

    def __init__(self, op: Literal["+", "-", "~", "not"], operand) -> None:
        super().__init__(op, (operand,))
        self.operand = operand

        try:
            self.func = _unary_ops_dict[op]
        except KeyError as err:
            raise ValueError(
                f"Invalid unary operator {repr(op)}, "
                f"valid operators are {UNARY_OPS_SYMS}"
            ) from err

    def __call__(self, env):
        """
        Evaluate the operand in *env* and apply the operator to it.

        The result is whatever the operator function returns. The former
        ``-> MathCall`` annotation was incorrect and has been dropped.
        """
        operand = self.operand(env)
        # error: Cannot call function of unknown type
        return self.func(operand)  # type: ignore[operator]

    def __repr__(self) -> str:
        return pprint_thing(f"{self.op}({self.operand})")

    @property
    def return_type(self) -> np.dtype:
        """
        ``bool`` if the operand returns bool or is itself a comparison/boolean
        Op; otherwise ``int``.
        """
        operand = self.operand
        if operand.return_type == np.dtype("bool"):
            return np.dtype("bool")
        if isinstance(operand, Op) and (
            operand.op in _cmp_ops_dict or operand.op in _bool_ops_dict
        ):
            return np.dtype("bool")
        return np.dtype("int")
594
595
class MathCall(Op):
    """
    Call of a supported NumPy math function on one or more operands.

    ``func`` must expose ``name`` (used as this node's ``op``) and ``func``
    (the callable applied on evaluation), as ``FuncNode`` does.
    """

    def __init__(self, func, args) -> None:
        super().__init__(func.name, args)
        self.func = func

    def __call__(self, env):
        # Evaluate every argument node first. NumPy floating-point error
        # reporting is switched off only around the function call itself.
        # error: "Op" not callable
        evaluated = [arg(env) for arg in self.operands]  # type: ignore[operator]
        with np.errstate(all="ignore"):
            return self.func.func(*evaluated)

    def __repr__(self) -> str:
        arg_text = ",".join(str(arg) for arg in self.operands)
        return pprint_thing(f"{self.op}({arg_text})")
610
611
class FuncNode:
    """
    A named NumPy math function usable inside an expression.

    Calling the node with operands builds a ``MathCall`` rather than computing
    anything immediately.

    Raises
    ------
    ValueError
        If ``name`` is not listed in ``MATHOPS``.
    """

    def __init__(self, name: str) -> None:
        if name in MATHOPS:
            self.name = name
            self.func = getattr(np, name)
        else:
            raise ValueError(f'"{name}" is not a supported function')

    def __call__(self, *args):
        return MathCall(self, args)