1"""
2Expressions
3-----------
4
5Offer fast expression evaluation through numexpr
6
7"""
8from __future__ import annotations
9
10import operator
11import warnings
12
13import numpy as np
14
15from pandas._config import get_option
16
17from pandas._typing import FuncType
18from pandas.util._exceptions import find_stack_level
19
20from pandas.core.computation.check import NUMEXPR_INSTALLED
21from pandas.core.ops import roperator
22
23if NUMEXPR_INSTALLED:
24 import numexpr as ne
25
# When truthy, the evaluation helpers report to ``_store_test_result`` whether
# numexpr produced the result (see ``set_test_mode``).
_TEST_MODE: bool | None = None
# Accumulates one ``True`` per numexpr-backed evaluation recorded while in test
# mode; handed out and reset by ``get_test_result``.
_TEST_RESULT: list[bool] = []
# Whether numexpr dispatch is enabled; only changed by ``set_use_numexpr`` and
# only when numexpr is installed.
USE_NUMEXPR = NUMEXPR_INSTALLED
# Dispatch targets used by ``evaluate`` and ``where``; bound to either the
# numexpr or the standard implementation by ``set_use_numexpr``.
_evaluate: FuncType | None = None
_where: FuncType | None = None

# the set of dtype names that we will allow to be passed to numexpr,
# keyed by the kind of check ("evaluate" or "where")
_ALLOWED_DTYPES = {
    "evaluate": {"int64", "int32", "float64", "float32", "bool"},
    "where": {"int64", "float64", "bool"},
}

# the number of elements the left operand must exceed (strictly) before we
# will use numexpr
_MIN_ELEMENTS = 1_000_000
40
41
def set_use_numexpr(v: bool = True) -> None:
    """
    Enable or disable numexpr and rebind the module-level dispatchers.

    Parameters
    ----------
    v : bool, default True
        Requested state. It is only honoured when numexpr is installed;
        otherwise ``USE_NUMEXPR`` keeps its current value.
    """
    global USE_NUMEXPR, _evaluate, _where

    # set/unset to use numexpr
    if NUMEXPR_INSTALLED:
        USE_NUMEXPR = v

    # choose what we are going to do
    if USE_NUMEXPR:
        _evaluate, _where = _evaluate_numexpr, _where_numexpr
    else:
        _evaluate, _where = _evaluate_standard, _where_standard
53
54
def set_numexpr_threads(n=None) -> None:
    """
    Set the number of threads numexpr uses.

    Does nothing unless numexpr is both installed and enabled.

    Parameters
    ----------
    n : int, optional
        Thread count to hand to numexpr. When ``None``, the value reported by
        ``ne.detect_number_of_cores()`` is used instead.
    """
    if not (NUMEXPR_INSTALLED and USE_NUMEXPR):
        return

    thread_count = ne.detect_number_of_cores() if n is None else n
    ne.set_num_threads(thread_count)
62
63
def _evaluate_standard(op, op_str, a, b):
    """
    Evaluate ``op(a, b)`` directly, without numexpr.

    Parameters
    ----------
    op : callable
        Binary callable applied to the operands.
    op_str : str or None
        Unused here; accepted so the signature matches the numexpr variant.
    a, b : operands
        Passed to ``op`` in this order.

    Returns
    -------
    The value returned by ``op(a, b)``.
    """
    if _TEST_MODE:
        # In test mode, report that numexpr was not used for this call.
        _store_test_result(False)

    outcome = op(a, b)
    return outcome
71
72
def _can_use_numexpr(op, op_str, a, b, dtype_check) -> bool:
    """
    Return whether the operation WILL be handed to numexpr.

    Parameters
    ----------
    op : object
        Not inspected by this check.
    op_str : str or None
        Operator string; ``None`` rules numexpr out immediately.
    a, b : operands
        ``a`` must expose ``.size``. Each operand that exposes ``.dtype`` has
        its dtype name checked; operands without one are ignored.
    dtype_check : str
        Key into ``_ALLOWED_DTYPES`` selecting the permitted dtype names.

    Returns
    -------
    bool
    """
    if op_str is None:
        return False

    # required min elements (otherwise we are adding overhead)
    if not (a.size > _MIN_ELEMENTS):
        return False

    # check for dtype compatibility (ndarray and Series case)
    operand_dtypes: set[str] = {
        operand.dtype.name for operand in (a, b) if hasattr(operand, "dtype")
    }

    # usable when no operand carried a dtype, or when the allowed names are a
    # superset of the ones we saw
    return not operand_dtypes or _ALLOWED_DTYPES[dtype_check] >= operand_dtypes
90
91
def _evaluate_numexpr(op, op_str, a, b):
    """
    Evaluate ``op`` on ``a`` and ``b`` with numexpr when eligible.

    Falls back to ``_evaluate_standard`` when the operation is not eligible,
    when numexpr raises ``TypeError``, or when it raises
    ``NotImplementedError`` and ``_bool_arith_fallback`` approves the fallback.

    Parameters
    ----------
    op : callable
        Binary callable; its ``__name__`` is used to detect reversed ops.
    op_str : str or None
        Operator string interpolated into the numexpr expression.
    a, b : operands

    Raises
    ------
    NotImplementedError
        Re-raised from numexpr when ``_bool_arith_fallback`` returns False.
    """
    outcome = None

    if _can_use_numexpr(op, op_str, a, b, "evaluate"):
        # A name starting with "r" (ignoring underscores) means we were called
        # by a reversed op method, so numexpr must see the operands swapped.
        # ``a`` and ``b`` themselves stay in caller order for the fallback.
        swapped = op.__name__.strip("_").startswith("r")
        lhs, rhs = (b, a) if swapped else (a, b)

        try:
            outcome = ne.evaluate(
                f"a_value {op_str} b_value",
                local_dict={"a_value": lhs, "b_value": rhs},
                casting="safe",
            )
        except TypeError:
            # numexpr raises eg for array ** array with integers
            # (https://github.com/pydata/numexpr/issues/379)
            pass
        except NotImplementedError:
            if not _bool_arith_fallback(op_str, lhs, rhs):
                raise

    if _TEST_MODE:
        _store_test_result(outcome is not None)

    if outcome is None:
        outcome = _evaluate_standard(op, op_str, a, b)

    return outcome
131
132
# Maps each supported operator callable to the operator string interpolated
# into the numexpr expression. ``evaluate`` looks the callable up here; a value
# of None means the operation always goes through ``_evaluate_standard``.
_op_str_mapping = {
    operator.add: "+",
    roperator.radd: "+",
    operator.mul: "*",
    roperator.rmul: "*",
    operator.sub: "-",
    roperator.rsub: "-",
    operator.truediv: "/",
    roperator.rtruediv: "/",
    # floordiv not supported by numexpr 2.x
    operator.floordiv: None,
    roperator.rfloordiv: None,
    # we require Python semantics for mod of negative for backwards compatibility
    # see https://github.com/pydata/numexpr/issues/365
    # so sticking with unaccelerated for now GH#36552
    operator.mod: None,
    roperator.rmod: None,
    operator.pow: "**",
    roperator.rpow: "**",
    operator.eq: "==",
    operator.ne: "!=",
    operator.le: "<=",
    operator.lt: "<",
    operator.ge: ">=",
    operator.gt: ">",
    operator.and_: "&",
    roperator.rand_: "&",
    operator.or_: "|",
    roperator.ror_: "|",
    operator.xor: "^",
    roperator.rxor: "^",
    # divmod is never handed to numexpr
    divmod: None,
    roperator.rdivmod: None,
}
167
168
def _where_standard(cond, a, b):
    """
    Select from ``a`` where ``cond`` is true and from ``b`` elsewhere.

    Thin wrapper around ``np.where``. The caller is responsible for
    extracting ndarrays if necessary.
    """
    selected = np.where(cond, a, b)
    return selected
172
173
def _where_numexpr(cond, a, b):
    """
    Evaluate ``where(cond, a, b)`` with numexpr when eligible.

    Uses ``_where_standard`` when ``_can_use_numexpr`` rejects the operands
    or when numexpr yields ``None``. The caller is responsible for extracting
    ndarrays if necessary.
    """
    if not _can_use_numexpr(None, "where", a, b, "where"):
        return _where_standard(cond, a, b)

    selected = ne.evaluate(
        "where(cond_value, a_value, b_value)",
        local_dict={"cond_value": cond, "a_value": a, "b_value": b},
        casting="safe",
    )
    if selected is None:
        return _where_standard(cond, a, b)

    return selected
189
190
# turn myself on: bind the _evaluate/_where dispatchers at import time
# according to the "compute.use_numexpr" option
set_use_numexpr(get_option("compute.use_numexpr"))
193
194
def _has_bool_dtype(x):
    """
    Return whether ``x`` is boolean.

    Objects exposing ``.dtype`` are judged by comparing that dtype with
    ``bool``; anything else counts only if it is a ``bool`` or ``np.bool_``
    instance.
    """
    try:
        is_bool = x.dtype == bool
    except AttributeError:
        # no usable dtype attribute: fall back to a scalar type check
        is_bool = isinstance(x, (bool, np.bool_))
    return is_bool
200
201
# Arithmetic operator strings that trigger the Python-space fallback for bool
# operands, mapped to the boolean operator suggested in the emitted warning.
_BOOL_OP_UNSUPPORTED = {"+": "|", "*": "&", "-": "^"}
203
204
def _bool_arith_fallback(op_str, a, b) -> bool:
    """
    Decide whether to fall back to the python `_evaluate_standard`.

    The fallback applies when both operands are boolean and ``op_str`` is one
    of the keys of ``_BOOL_OP_UNSUPPORTED``. In that case a warning naming
    the suggested replacement operator is emitted.

    Returns
    -------
    bool
        True if the caller should fall back, False otherwise.
    """
    operands_are_bool = _has_bool_dtype(a) and _has_bool_dtype(b)
    if not (operands_are_bool and op_str in _BOOL_OP_UNSUPPORTED):
        return False

    warnings.warn(
        f"evaluating in Python space because the {repr(op_str)} "
        "operator is not supported by numexpr for the bool dtype, "
        f"use {repr(_BOOL_OP_UNSUPPORTED[op_str])} instead.",
        stacklevel=find_stack_level(),
    )
    return True
221
222
def evaluate(op, a, b, use_numexpr: bool = True):
    """
    Evaluate and return the result of applying op to a and b.

    Parameters
    ----------
    op : the actual operator
        Must be a key of ``_op_str_mapping``.
    a : left operand
    b : right operand
    use_numexpr : bool, default True
        Whether to try to use numexpr.

    Raises
    ------
    KeyError
        If ``op`` is not a key of ``_op_str_mapping``.
    """
    op_str = _op_str_mapping[op]

    # Only ops with an operator string may go through the configurable
    # dispatcher; everything else is evaluated directly.
    if op_str is None or not use_numexpr:
        return _evaluate_standard(op, op_str, a, b)

    # error: "None" not callable
    return _evaluate(op, op_str, a, b)  # type: ignore[misc]
241
242
def where(cond, a, b, use_numexpr: bool = True):
    """
    Evaluate the where condition cond on a and b.

    Parameters
    ----------
    cond : np.ndarray[bool]
    a : return if cond is True
    b : return if cond is False
    use_numexpr : bool, default True
        Whether to go through the configured ``_where`` dispatcher; when
        False, ``_where_standard`` is called directly.
    """
    assert _where is not None
    if use_numexpr:
        return _where(cond, a, b)
    return _where_standard(cond, a, b)
257
258
def set_test_mode(v: bool = True) -> None:
    """
    Switch tracking of numexpr usage on or off.

    While enabled, an additional ``True`` is stored for every successful use
    of evaluate with numexpr since the last ``get_test_result``. Calling this
    also discards any results recorded so far.

    Parameters
    ----------
    v : bool, default True
        New value for the test-mode flag.
    """
    global _TEST_MODE
    global _TEST_RESULT

    _TEST_RESULT = []
    _TEST_MODE = v
269
270
def _store_test_result(used_numexpr: bool) -> None:
    """
    Record a numexpr-backed evaluation.

    Only truthy values are appended to ``_TEST_RESULT``; a falsy value
    leaves it untouched.
    """
    if not used_numexpr:
        return
    _TEST_RESULT.append(used_numexpr)
274
275
def get_test_result() -> list[bool]:
    """
    Return the recorded test results and start a fresh, empty record.

    Returns
    -------
    list[bool]
        The list that was accumulated up to this call.
    """
    global _TEST_RESULT
    collected, _TEST_RESULT = _TEST_RESULT, []
    return collected