1"""
2Operator classes for eval.
3"""
4
5from __future__ import annotations
6
7from datetime import datetime
8from functools import partial
9import operator
10from typing import (
11 TYPE_CHECKING,
12 Callable,
13 Literal,
14)
15
16import numpy as np
17
18from pandas._libs.tslibs import Timestamp
19
20from pandas.core.dtypes.common import (
21 is_list_like,
22 is_scalar,
23)
24
25import pandas.core.common as com
26from pandas.core.computation.common import (
27 ensure_decoded,
28 result_type_many,
29)
30from pandas.core.computation.scope import DEFAULT_GLOBALS
31
32from pandas.io.formats.printing import (
33 pprint_thing,
34 pprint_thing_encoded,
35)
36
37if TYPE_CHECKING:
38 from collections.abc import (
39 Iterable,
40 Iterator,
41 )
42
# Names of the reduction functions listed for eval expressions.
REDUCTIONS = ("sum", "prod", "min", "max")

# Single-argument math function names; FuncNode resolves each one on numpy.
_unary_math_ops = (
    # trigonometric / exponential / logarithmic
    "sin", "cos", "exp", "log", "expm1", "log1p", "sqrt",
    # hyperbolic
    "sinh", "cosh", "tanh",
    # inverse trigonometric
    "arcsin", "arccos", "arctan",
    # inverse hyperbolic
    "arccosh", "arcsinh", "arctanh",
    # miscellaneous
    "abs", "log10", "floor", "ceil",
)  # fmt: skip

# Two-argument math function names.
_binary_math_ops = ("arctan2",)

# Every supported math function name: unary names first, then binary.
MATHOPS = (*_unary_math_ops, *_binary_math_ops)


# Prefix that marks a name as a local variable; Term strips it in local_name.
LOCAL_TAG = "__pd_eval_local_"
73
74
class Term:
    """
    Leaf node of an expression: a name resolved against a scope.

    Parameters
    ----------
    name : object
        Variable name. When it is not a ``str`` the constructor returns a
        ``Constant`` instance instead of a ``Term``.
    env : object
        Scope-like object; this class uses its ``scope`` mapping and its
        ``resolve`` and ``swapkey`` methods.
    side : object, optional
        Stored unchanged on the instance.
    encoding : object, optional
        Stored unchanged on the instance.
    """

    def __new__(cls, name, env, side=None, encoding=None):
        # Anything that is not a string is a literal value, not a variable.
        klass = cls if isinstance(name, str) else Constant
        # error: Argument 2 for "super" not an instance of argument 1
        allocate = super(Term, klass).__new__  # type: ignore[misc]
        return allocate(klass)

    is_local: bool

    def __init__(self, name, env, side=None, encoding=None) -> None:
        # a str for Term itself; subclasses may keep something else here
        self._name = name
        self.env = env
        self.side = side
        text = str(name)
        self.is_local = text.startswith(LOCAL_TAG) or text in DEFAULT_GLOBALS
        # resolution happens before ``encoding`` is stored
        self._value = self._resolve_name()
        self.encoding = encoding

    @property
    def local_name(self) -> str:
        """Return the name with every ``LOCAL_TAG`` occurrence removed."""
        return self.name.replace(LOCAL_TAG, "")

    def __repr__(self) -> str:
        return pprint_thing(self.name)

    def __call__(self, *args, **kwargs):
        """Return the current value; all arguments are ignored."""
        return self.value

    def evaluate(self, *args, **kwargs) -> Term:
        """Return ``self``; all arguments are ignored."""
        return self

    def _resolve_name(self):
        """
        Look the name up in ``env`` and record the result via ``update``.

        Raises
        ------
        NotImplementedError
            If the resolved object has an ``ndim`` greater than 2.
        """
        key = str(self.local_name)
        treat_as_local = self.is_local
        # a name bound to a class in the scope is never resolved as a local
        if key in self.env.scope and isinstance(self.env.scope[key], type):
            treat_as_local = False

        resolved = self.env.resolve(key, is_local=treat_as_local)
        self.update(resolved)

        if hasattr(resolved, "ndim") and resolved.ndim > 2:
            raise NotImplementedError(
                "N-dimensional objects, where N > 2, are not supported with eval"
            )
        return resolved

    def update(self, value) -> None:
        """
        Store ``value`` on this term and, for string names, in ``env``.

        search order for local (i.e., @variable) variables:

        scope, key_variable
        [('locals', 'local_name'),
         ('globals', 'local_name'),
         ('locals', 'key'),
         ('globals', 'key')]
        """
        name = self.name

        # only variable names are swapped in the scope; constants are not
        if isinstance(name, str):
            self.env.swapkey(self.local_name, name, new_value=value)

        self.value = value

    @property
    def is_scalar(self) -> bool:
        return is_scalar(self._value)

    @property
    def type(self):
        """
        Return the dtype of the value, or its Python type for plain objects.
        """
        current = self._value
        try:
            # potentially very slow for large, mixed dtype frames
            return current.values.dtype
        except AttributeError:
            pass
        try:
            # ndarray
            return current.dtype
        except AttributeError:
            # scalar
            return type(current)

    # alias of the ``type`` property defined just above
    return_type = type

    @property
    def raw(self) -> str:
        return f"{type(self).__name__}(name={repr(self.name)}, type={self.type})"

    @property
    def is_datetime(self) -> bool:
        """Whether the value's type is a datetime or numpy datetime64 type."""
        kind = self.type
        # unwrap a dtype to its scalar type when there is one
        kind = getattr(kind, "type", kind)
        return issubclass(kind, (datetime, np.datetime64))

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value) -> None:
        self._value = new_value

    @property
    def name(self):
        return self._name

    @property
    def ndim(self) -> int:
        return self._value.ndim
189
190
class Constant(Term):
    """
    Term whose "name" is the literal value itself.

    ``Term.__new__`` builds this class when the given name is not a string.
    """

    def _resolve_name(self):
        # nothing to look up: the stored name is already the value
        return self._name

    @property
    def name(self):
        # the name of a constant is its current value
        return self.value

    def __repr__(self) -> str:
        # repr() rather than str(), so the printed literal is not
        # shortened (e.g. for floats)
        return repr(self.name)
203
204
# Python boolean keywords and the operator symbol each one is stored as;
# Op.__init__ passes every operator token through this mapping.
_bool_op_map = {"not": "~", "and": "&", "or": "|"}
206
207
class Op:
    """
    Hold an operator of arbitrary arity.

    Parameters
    ----------
    op : str
        Operator token. ``"not"``, ``"and"`` and ``"or"`` are stored as
        ``"~"``, ``"&"`` and ``"|"``; any other token is stored unchanged.
    operands : iterable of Term or Op
        Operands of the operator, stored as given.
    encoding : object, optional
        Stored unchanged on the instance.
    """

    op: str

    def __init__(self, op: str, operands: Iterable[Term | Op], encoding=None) -> None:
        self.op = _bool_op_map.get(op, op)
        self.operands = operands
        self.encoding = encoding

    def __iter__(self) -> Iterator:
        """Iterate over the operands."""
        return iter(self.operands)

    def __repr__(self) -> str:
        """
        Print a generic n-ary operator and its operands using infix notation.
        """
        # recurse over the operands
        parened = (f"({pprint_thing(opr)})" for opr in self.operands)
        return pprint_thing(f" {self.op} ".join(parened))

    @property
    def return_type(self):
        """
        Result type of the operation.

        Comparison and boolean operators always give ``np.bool_``; otherwise
        the type is computed from the types of all flattened terms.
        """
        # clobber types to bool if the op is a boolean operator
        if self.op in (CMP_OPS_SYMS + BOOL_OPS_SYMS):
            return np.bool_
        return result_type_many(*(term.type for term in com.flatten(self)))

    @property
    def has_invalid_return_type(self) -> bool:
        """
        Whether the result is object-typed although some operand is not.

        Returns
        -------
        bool
            True when ``return_type`` equals ``object`` and at least one
            operand type differs from the object dtype.
        """
        types = self.operand_types
        obj_dtype_set = frozenset([np.dtype("object")])
        # bool() so that a real bool is returned, as annotated, instead of
        # the frozenset left over from the set difference
        return bool(self.return_type == object and types - obj_dtype_set)

    @property
    def operand_types(self):
        """Return the frozenset of types of all flattened terms."""
        return frozenset(term.type for term in com.flatten(self))

    @property
    def is_scalar(self) -> bool:
        """Whether every operand reports itself as scalar."""
        return all(operand.is_scalar for operand in self.operands)

    @property
    def is_datetime(self) -> bool:
        """Whether the return type is a datetime or numpy datetime64 type."""
        try:
            # unwrap a dtype to its scalar type
            t = self.return_type.type
        except AttributeError:
            t = self.return_type

        return issubclass(t, (datetime, np.datetime64))
260
261
def _in(x, y):
    """
    Evaluate ``x in y``, vectorized through ``isin`` when available.

    ``x.isin(y)`` is tried first. If ``x`` has no usable ``isin`` and is
    list-like, ``y.isin(x)`` is tried next. Plain Python ``x in y`` is the
    final fallback.
    """
    try:
        membership = x.isin(y)
    except AttributeError:
        if not is_list_like(x):
            return x in y
        try:
            membership = y.isin(x)
        except AttributeError:
            return x in y
    return membership
276
277
def _not_in(x, y):
    """
    Evaluate ``x not in y``, vectorized through ``isin`` when available.

    ``~x.isin(y)`` is tried first. If ``x`` has no usable ``isin`` and is
    list-like, ``~y.isin(x)`` is tried next. Plain Python ``x not in y`` is
    the final fallback.
    """
    try:
        excluded = ~x.isin(y)
    except AttributeError:
        if not is_list_like(x):
            return x not in y
        try:
            excluded = ~y.isin(x)
        except AttributeError:
            return x not in y
    return excluded
292
293
# Each *_SYMS tuple is zipped with the *_funcs tuple that follows it, so the
# two must list their entries in the same order.

# Comparison operators, including the membership tests.
CMP_OPS_SYMS = (">", "<", ">=", "<=", "==", "!=", "in", "not in")
_cmp_ops_funcs = (
    operator.gt,
    operator.lt,
    operator.ge,
    operator.le,
    operator.eq,
    operator.ne,
    _in,
    _not_in,
)
_cmp_ops_dict = dict(zip(CMP_OPS_SYMS, _cmp_ops_funcs))

# Boolean operators; "and"/"or" use the same functions as "&"/"|".
BOOL_OPS_SYMS = ("&", "|", "and", "or")
_bool_ops_funcs = (operator.and_, operator.or_, operator.and_, operator.or_)
_bool_ops_dict = dict(zip(BOOL_OPS_SYMS, _bool_ops_funcs))

# Arithmetic operators.
ARITH_OPS_SYMS = ("+", "-", "*", "/", "**", "//", "%")
_arith_ops_funcs = (
    operator.add,
    operator.sub,
    operator.mul,
    operator.truediv,
    operator.pow,
    operator.floordiv,
    operator.mod,
)
_arith_ops_dict = dict(zip(ARITH_OPS_SYMS, _arith_ops_funcs))

# Subset of the arithmetic operators that is kept in its own table.
SPECIAL_CASE_ARITH_OPS_SYMS = ("**", "//", "%")
_special_case_arith_ops_funcs = (operator.pow, operator.floordiv, operator.mod)
_special_case_arith_ops_dict = dict(
    zip(SPECIAL_CASE_ARITH_OPS_SYMS, _special_case_arith_ops_funcs)
)

# Lookup table used by BinOp: comparison, boolean and arithmetic operators
# merged into one symbol -> function mapping.
_binary_ops_dict = {}

for d in (_cmp_ops_dict, _bool_ops_dict, _arith_ops_dict):
    _binary_ops_dict.update(d)
333
334
def is_term(obj) -> bool:
    """Return True if ``obj`` is an instance of ``Term`` or a subclass."""
    return isinstance(obj, Term)
337
338
class BinOp(Op):
    """
    Expression node for a binary operator and its two operands.

    Parameters
    ----------
    op : str
    lhs : Term or Op
    rhs : Term or Op

    Raises
    ------
    NotImplementedError
        If a boolean operator is applied where an operand is scalar and the
        operands are not both boolean.
    ValueError
        If ``op`` is not a known binary operator.
    """

    def __init__(self, op: str, lhs, rhs) -> None:
        super().__init__(op, (lhs, rhs))
        self.lhs = lhs
        self.rhs = rhs

        # both checks run before the operator token itself is validated
        self._disallow_scalar_only_bool_ops()
        self.convert_values()

        try:
            self.func = _binary_ops_dict[op]
        except KeyError as err:
            keys = list(_binary_ops_dict)
            raise ValueError(
                f"Invalid binary operator {repr(op)}, valid operators are {keys}"
            ) from err

    def __call__(self, env):
        """
        Recursively evaluate an expression in Python space.

        Parameters
        ----------
        env : Scope

        Returns
        -------
        object
            The result of an evaluated expression.
        """
        # the left operand is evaluated before the right one
        lhs_result = self.lhs(env)
        rhs_result = self.rhs(env)
        return self.func(lhs_result, rhs_result)

    def evaluate(self, env, engine: str, parser, term_type, eval_in_python):
        """
        Evaluate a binary operation *before* being passed to the engine.

        Parameters
        ----------
        env : Scope
        engine : str
        parser : str
        term_type : type
        eval_in_python : list

        Returns
        -------
        term_type
            The "pre-evaluated" expression as an instance of ``term_type``
        """
        if engine == "python":
            res = self(env)
        else:
            # pre-evaluate both children with the same settings
            child_kwargs = {
                "engine": engine,
                "parser": parser,
                "term_type": term_type,
                "eval_in_python": eval_in_python,
            }
            left = self.lhs.evaluate(env, **child_kwargs)
            right = self.rhs.evaluate(env, **child_kwargs)

            if self.op in eval_in_python:
                # base case: apply the Python function to the raw values
                res = self.func(left.value, right.value)
            else:
                from pandas.core.computation.eval import eval

                res = eval(self, local_dict=env, engine=engine, parser=parser)

        # store the result under a temporary name and wrap it as a term
        tmp_name = env.add_tmp(res)
        return term_type(tmp_name, env=env)

    def convert_values(self) -> None:
        """
        Convert datetimes to a comparable value in an expression.

        When one operand is a datetime term and the other a scalar term,
        the scalar is replaced by a ``Timestamp`` (converted to UTC when it
        is timezone-aware).
        """

        def stringify(value):
            if self.encoding is None:
                return pprint_thing(value)
            return pprint_thing_encoded(value, encoding=self.encoding)

        def coerce_to_timestamp(term) -> None:
            raw = term.value
            if isinstance(raw, (int, float)):
                # numbers are turned into text before being parsed
                raw = stringify(raw)
            stamp = Timestamp(ensure_decoded(raw))
            if stamp.tz is not None:
                stamp = stamp.tz_convert("UTC")
            term.update(stamp)

        lhs, rhs = self.lhs, self.rhs

        # the right side is handled first; the second test then sees the
        # possibly updated right-hand term
        if is_term(lhs) and lhs.is_datetime and is_term(rhs) and rhs.is_scalar:
            coerce_to_timestamp(rhs)

        if is_term(rhs) and rhs.is_datetime and is_term(lhs) and lhs.is_scalar:
            coerce_to_timestamp(lhs)

    def _disallow_scalar_only_bool_ops(self):
        """
        Reject boolean operators on scalar operands that are not both bool.

        Raises
        ------
        NotImplementedError
            If the operator is a boolean operator, at least one operand is
            scalar, and the two return types are not both boolean.
        """
        lhs, rhs = self.lhs, self.rhs

        # GH#24883 unwrap dtype if necessary to ensure we have a type object
        rhs_rt = rhs.return_type
        rhs_rt = getattr(rhs_rt, "type", rhs_rt)
        lhs_rt = lhs.return_type
        lhs_rt = getattr(lhs_rt, "type", lhs_rt)

        if not (lhs.is_scalar or rhs.is_scalar):
            return
        if self.op not in _bool_ops_dict:
            return

        bool_types = (bool, np.bool_)
        if issubclass(rhs_rt, bool_types) and issubclass(lhs_rt, bool_types):
            return
        raise NotImplementedError("cannot evaluate scalar only bool ops")
489
490
def isnumeric(dtype) -> bool:
    """
    Return whether ``dtype`` is a numeric numpy dtype.

    ``dtype`` may be anything ``np.dtype`` accepts. The result is True when
    the dtype's scalar type is a subclass of ``np.number``.
    """
    scalar_type = np.dtype(dtype).type
    return issubclass(scalar_type, np.number)
493
494
# Unary operator symbols and their functions, paired positionally by zip;
# "~" and "not" both map to operator.invert.
UNARY_OPS_SYMS = ("+", "-", "~", "not")
_unary_ops_funcs = (operator.pos, operator.neg, operator.invert, operator.invert)
_unary_ops_dict = dict(zip(UNARY_OPS_SYMS, _unary_ops_funcs))
498
499
class UnaryOp(Op):
    """
    Hold a unary operator and its operands.

    Parameters
    ----------
    op : str
        The token used to represent the operator.
    operand : Term or Op
        The Term or Op operand to the operator.

    Raises
    ------
    ValueError
        * If no function associated with the passed operator token is found.
    """

    def __init__(self, op: Literal["+", "-", "~", "not"], operand) -> None:
        super().__init__(op, (operand,))
        self.operand = operand

        try:
            self.func = _unary_ops_dict[op]
        except KeyError as err:
            raise ValueError(
                f"Invalid unary operator {repr(op)}, "
                f"valid operators are {UNARY_OPS_SYMS}"
            ) from err

    def __call__(self, env):
        """
        Evaluate the operand in ``env`` and apply the operator to it.

        Returns
        -------
        object
            Whatever the operator function returns for the evaluated
            operand. This is not a ``MathCall``, so no such return
            annotation is declared.
        """
        operand = self.operand(env)
        # error: Cannot call function of unknown type
        return self.func(operand)  # type: ignore[operator]

    def __repr__(self) -> str:
        return pprint_thing(f"{self.op}({self.operand})")

    @property
    def return_type(self) -> np.dtype:
        """
        Result dtype of the operation.

        Returns
        -------
        np.dtype
            ``bool`` when the operand's return type is bool or the operand
            is an ``Op`` with a comparison or boolean operator; ``int``
            otherwise.
        """
        operand = self.operand
        if operand.return_type == np.dtype("bool"):
            return np.dtype("bool")
        if isinstance(operand, Op) and (
            operand.op in _cmp_ops_dict or operand.op in _bool_ops_dict
        ):
            return np.dtype("bool")
        return np.dtype("int")
547
548
class MathCall(Op):
    """
    Expression node for a math function applied to its arguments.

    Parameters
    ----------
    func : FuncNode
        Supplies the operator name (``func.name``) and the callable to
        apply (``func.func``).
    args : iterable of Term or Op
        Operands passed to the function.
    """

    def __init__(self, func, args) -> None:
        super().__init__(func.name, args)
        self.func = func

    def __call__(self, env):
        """Evaluate every operand in ``env`` and apply the function."""
        # error: "Op" not callable
        evaluated = [
            operand(env) for operand in self.operands  # type: ignore[operator]
        ]
        return self.func.func(*evaluated)

    def __repr__(self) -> str:
        arg_list = ",".join(str(operand) for operand in self.operands)
        return pprint_thing(f"{self.op}({arg_list})")
562
563
class FuncNode:
    """
    Callable wrapper around a supported numpy math function.

    Parameters
    ----------
    name : str
        Function name; it must be listed in ``MATHOPS``.

    Raises
    ------
    ValueError
        If ``name`` is not listed in ``MATHOPS``.
    """

    def __init__(self, name: str) -> None:
        supported = name in MATHOPS
        if not supported:
            raise ValueError(f'"{name}" is not a supported function')
        # the function of the same name is taken from numpy
        self.func = getattr(np, name)
        self.name = name

    def __call__(self, *args) -> MathCall:
        """Build a ``MathCall`` node applying this function to ``args``."""
        return MathCall(self, args)