1"""
2Expressions
3-----------
4
5Offer fast expression evaluation through numexpr
6
7"""
8from __future__ import annotations
9
10import operator
11from typing import TYPE_CHECKING
12import warnings
13
14import numpy as np
15
16from pandas._config import get_option
17
18from pandas.util._exceptions import find_stack_level
19
20from pandas.core import roperator
21from pandas.core.computation.check import NUMEXPR_INSTALLED
22
23if NUMEXPR_INSTALLED:
24 import numexpr as ne
25
26if TYPE_CHECKING:
27 from pandas._typing import FuncType
28
# Test-mode bookkeeping: while _TEST_MODE is truthy, _evaluate_numexpr reports
# whether numexpr produced each result; see set_test_mode / get_test_result.
_TEST_MODE: bool | None = None
_TEST_RESULT: list[bool] = []
# Whether numexpr should be used. set_use_numexpr only changes this when
# numexpr is installed, so it can never become True without numexpr.
USE_NUMEXPR = NUMEXPR_INSTALLED
# Dispatch targets bound by set_use_numexpr to either the numexpr-backed or the
# standard implementations; None only until set_use_numexpr first runs.
_evaluate: FuncType | None = None
_where: FuncType | None = None

# dtypes (by numpy dtype ``.name``) that may be passed to numexpr, keyed by the
# kind of operation ("evaluate" for binary ops, "where" for np.where-style
# selection); checked in _can_use_numexpr.
_ALLOWED_DTYPES = {
    "evaluate": {"int64", "int32", "float64", "float32", "bool"},
    "where": {"int64", "float64", "bool"},
}

# numexpr is only used when the left operand has strictly MORE than this many
# elements; below that its overhead outweighs the speed-up.
_MIN_ELEMENTS = 1_000_000
43
44
def set_use_numexpr(v: bool = True) -> None:
    """
    Enable or disable numexpr and rebind the module-level dispatchers.

    The requested flag is only stored when numexpr is installed; otherwise
    ``USE_NUMEXPR`` keeps its current value. Afterwards ``_evaluate`` and
    ``_where`` point at the numexpr-backed implementations when
    ``USE_NUMEXPR`` is truthy, and at the standard ones otherwise.
    """
    global USE_NUMEXPR, _evaluate, _where

    if NUMEXPR_INSTALLED:
        USE_NUMEXPR = v

    if USE_NUMEXPR:
        _evaluate, _where = _evaluate_numexpr, _where_numexpr
    else:
        _evaluate, _where = _evaluate_standard, _where_standard
56
57
def set_numexpr_threads(n=None) -> None:
    """
    Configure how many threads numexpr may use.

    This is a no-op unless numexpr is both installed and enabled. Passing
    ``n=None`` resets the count to ``ne.detect_number_of_cores()``.
    """
    if not (NUMEXPR_INSTALLED and USE_NUMEXPR):
        return
    ne.set_num_threads(ne.detect_number_of_cores() if n is None else n)
65
66
def _evaluate_standard(op, op_str, a, b):
    """
    Compute ``op(a, b)`` directly, without involving numexpr.

    ``op_str`` is accepted only so the signature matches
    ``_evaluate_numexpr``; it is not used here. When test mode is on,
    ``_store_test_result`` is told that numexpr was not used.
    """
    if _TEST_MODE:
        # report "numexpr not used" before computing, as the caller expects
        _store_test_result(False)
    return op(a, b)
74
75
def _can_use_numexpr(op, op_str, a, b, dtype_check) -> bool:
    """
    Decide whether numexpr WILL be used for this operation.

    All of the following must hold:

    - ``op_str`` is not None (numexpr has an expression for the op);
    - ``a`` has more than ``_MIN_ELEMENTS`` elements, because smaller inputs
      only add overhead;
    - every operand that has a ``dtype`` has a dtype name listed in
      ``_ALLOWED_DTYPES[dtype_check]``. Operands without a dtype, such as
      Python scalars, do not restrict the choice.

    ``op`` is not inspected.
    """
    if op_str is None:
        return False
    if a.size <= _MIN_ELEMENTS:
        return False

    # ndarray and Series expose .dtype; plain scalars do not
    observed = {operand.dtype.name for operand in (a, b) if hasattr(operand, "dtype")}
    return not observed or observed <= _ALLOWED_DTYPES[dtype_check]
93
94
def _evaluate_numexpr(op, op_str, a, b):
    """
    Evaluate ``a <op_str> b`` with numexpr, falling back to ``op(a, b)``.

    numexpr is only attempted when ``_can_use_numexpr`` approves the inputs.
    The standard path (``_evaluate_standard``) is used when any of these holds:

    - numexpr was not attempted;
    - numexpr raised ``TypeError``;
    - numexpr raised ``NotImplementedError`` for an op that
      ``_bool_arith_fallback`` accepts.

    Any other ``NotImplementedError`` propagates. In test mode, whether
    numexpr produced the result is recorded.
    """
    outcome = None

    if _can_use_numexpr(op, op_str, a, b, "evaluate"):
        # A reversed op (radd, rsub, ...) was called as op(a, b) but means
        # "b <op> a", so hand numexpr the operands in natural order.
        if op.__name__.strip("_").startswith("r"):
            left, right = b, a
        else:
            left, right = a, b

        try:
            outcome = ne.evaluate(
                f"a_value {op_str} b_value",
                local_dict={"a_value": left, "b_value": right},
                casting="safe",
            )
        except TypeError:
            # numexpr raises eg for array ** array with integers
            # (https://github.com/pydata/numexpr/issues/379)
            pass
        except NotImplementedError:
            # some bool arithmetic is unsupported; fall back (with a warning)
            # only for those, re-raise anything else
            if not _bool_arith_fallback(op_str, left, right):
                raise

    if _TEST_MODE:
        _store_test_result(outcome is not None)

    if outcome is None:
        # a and b are still in the caller's original order here
        outcome = _evaluate_standard(op, op_str, a, b)

    return outcome
134
135
# Map each supported binary operator (and its reversed counterpart from
# pandas.core.roperator) to the infix string used to build the numexpr
# expression. Reversed ops share the forward op's string because
# _evaluate_numexpr swaps the operands back before evaluating. A value of None
# means evaluate() never sends the op to numexpr and always uses the plain
# operator via _evaluate_standard. Ops missing from this dict make evaluate()
# raise KeyError.
_op_str_mapping = {
    operator.add: "+",
    roperator.radd: "+",
    operator.mul: "*",
    roperator.rmul: "*",
    operator.sub: "-",
    roperator.rsub: "-",
    operator.truediv: "/",
    roperator.rtruediv: "/",
    # floordiv not supported by numexpr 2.x
    operator.floordiv: None,
    roperator.rfloordiv: None,
    # we require Python semantics for mod of negative for backwards compatibility
    # see https://github.com/pydata/numexpr/issues/365
    # so sticking with unaccelerated for now GH#36552
    operator.mod: None,
    roperator.rmod: None,
    operator.pow: "**",
    roperator.rpow: "**",
    operator.eq: "==",
    operator.ne: "!=",
    operator.le: "<=",
    operator.lt: "<",
    operator.ge: ">=",
    operator.gt: ">",
    operator.and_: "&",
    roperator.rand_: "&",
    operator.or_: "|",
    roperator.ror_: "|",
    operator.xor: "^",
    roperator.rxor: "^",
    # divmod returns a pair, which a single numexpr expression cannot produce
    divmod: None,
    roperator.rdivmod: None,
}
170
171
def _where_standard(cond, a, b):
    """
    Pick elements from ``a`` where ``cond`` is True and from ``b`` elsewhere.

    This is a thin wrapper over ``np.where``. Any unwrapping of pandas
    objects into arrays is the caller's job.
    """
    selected = np.where(cond, a, b)
    return selected
175
176
def _where_numexpr(cond, a, b):
    """
    ``where`` implementation that tries numexpr first.

    numexpr is used only when ``_can_use_numexpr`` approves ``a`` and ``b``
    under the "where" dtype rules. Otherwise, or if numexpr yields None, this
    falls back to ``_where_standard``. Any unwrapping of pandas objects into
    arrays is the caller's job.
    """
    if _can_use_numexpr(None, "where", a, b, "where"):
        accelerated = ne.evaluate(
            "where(cond_value, a_value, b_value)",
            local_dict={"cond_value": cond, "a_value": a, "b_value": b},
            casting="safe",
        )
        if accelerated is not None:
            return accelerated

    return _where_standard(cond, a, b)
192
193
# turn myself on: bind _evaluate/_where at import time according to the
# "compute.use_numexpr" option
set_use_numexpr(get_option("compute.use_numexpr"))
196
197
def _has_bool_dtype(x):
    """
    Report whether ``x`` is boolean.

    Objects that expose a ``dtype`` are judged by comparing it with ``bool``.
    Objects without one, such as plain Python scalars, are judged by whether
    they are a ``bool`` or ``np.bool_`` instance.
    """
    try:
        is_boolean = x.dtype == bool
    except AttributeError:
        is_boolean = isinstance(x, (bool, np.bool_))
    return is_boolean
203
204
# Arithmetic operators numexpr does not support for bool operands, mapped to
# the logical operator suggested to the user in _bool_arith_fallback's warning.
_BOOL_OP_UNSUPPORTED = {"+": "|", "*": "&", "-": "^"}
206
207
def _bool_arith_fallback(op_str, a, b) -> bool:
    """
    Decide whether a numexpr ``NotImplementedError`` may fall back to Python.

    A fallback is allowed only when both operands are boolean and ``op_str``
    is one of the arithmetic operators in ``_BOOL_OP_UNSUPPORTED``. In that
    case a warning naming the suggested logical operator is emitted and True
    is returned. Every other case returns False.
    """
    both_boolean = _has_bool_dtype(a) and _has_bool_dtype(b)
    if not both_boolean or op_str not in _BOOL_OP_UNSUPPORTED:
        return False

    warnings.warn(
        f"evaluating in Python space because the {repr(op_str)} "
        "operator is not supported by numexpr for the bool dtype, "
        f"use {repr(_BOOL_OP_UNSUPPORTED[op_str])} instead.",
        stacklevel=find_stack_level(),
    )
    return True
224
225
def evaluate(op, a, b, use_numexpr: bool = True):
    """
    Evaluate and return the expression of the op on a and b.

    Parameters
    ----------
    op : callable
        The binary operator to apply, e.g. ``operator.add`` or
        ``roperator.radd``. Must be a key of ``_op_str_mapping``.
    a : left operand
    b : right operand
    use_numexpr : bool, default True
        Whether to try to use numexpr.

    Returns
    -------
    The result of ``op(a, b)``. It is computed through the current ``_evaluate``
    dispatcher when numexpr is requested and the op has a numexpr expression,
    and through ``_evaluate_standard`` otherwise.

    Raises
    ------
    KeyError
        If ``op`` is not a key of ``_op_str_mapping``.
    """
    op_str = _op_str_mapping[op]
    if op_str is not None and use_numexpr:
        # _evaluate is bound by set_use_numexpr at import time; narrow it for
        # type checkers the same way ``where`` narrows ``_where``.
        assert _evaluate is not None
        return _evaluate(op, op_str, a, b)
    # ops without a numexpr expression (op_str is None), or callers opting
    # out of numexpr, go straight to plain evaluation
    return _evaluate_standard(op, op_str, a, b)
244
245
def where(cond, a, b, use_numexpr: bool = True):
    """
    Evaluate the where condition cond on a and b.

    Parameters
    ----------
    cond : np.ndarray[bool]
    a : value(s) taken where ``cond`` is True
    b : value(s) taken where ``cond`` is False
    use_numexpr : bool, default True
        Whether to try to use numexpr. When False, ``_where_standard`` is
        called directly instead of the ``_where`` dispatcher.
    """
    assert _where is not None
    if not use_numexpr:
        return _where_standard(cond, a, b)
    return _where(cond, a, b)
260
261
def set_test_mode(v: bool = True) -> None:
    """
    Turn numexpr-usage tracking on or off and clear the recorded results.

    While test mode is on, every evaluation that numexpr actually performed
    appends ``True`` to the record, which ``get_test_result`` returns and
    resets.
    """
    global _TEST_MODE, _TEST_RESULT
    _TEST_MODE, _TEST_RESULT = v, []
272
273
def _store_test_result(used_numexpr: bool) -> None:
    """Append ``used_numexpr`` to the test record, but only when it is truthy."""
    if not used_numexpr:
        return
    _TEST_RESULT.append(used_numexpr)
277
278
def get_test_result() -> list[bool]:
    """
    Return the recorded numexpr-usage flags and start a fresh, empty record.
    """
    global _TEST_RESULT
    collected, _TEST_RESULT = _TEST_RESULT, []
    return collected