1"""
2Functions for arithmetic and comparison operations on NumPy arrays and
3ExtensionArrays.
4"""
5from __future__ import annotations
6
7import datetime
8from functools import partial
9import operator
10from typing import Any
11
12import numpy as np
13
14from pandas._libs import (
15 NaT,
16 Timedelta,
17 Timestamp,
18 lib,
19 ops as libops,
20)
21from pandas._libs.tslibs import (
22 BaseOffset,
23 get_supported_reso,
24 get_unit_from_dtype,
25 is_supported_unit,
26 is_unitless,
27 npy_unit_to_abbrev,
28)
29from pandas._typing import (
30 ArrayLike,
31 Shape,
32)
33
34from pandas.core.dtypes.cast import (
35 construct_1d_object_array_from_listlike,
36 find_common_type,
37)
38from pandas.core.dtypes.common import (
39 ensure_object,
40 is_bool_dtype,
41 is_integer_dtype,
42 is_list_like,
43 is_numeric_v_string_like,
44 is_object_dtype,
45 is_scalar,
46)
47from pandas.core.dtypes.generic import (
48 ABCExtensionArray,
49 ABCIndex,
50 ABCSeries,
51)
52from pandas.core.dtypes.missing import (
53 isna,
54 notna,
55)
56
57from pandas.core.computation import expressions
58from pandas.core.construction import ensure_wrapped_if_datetimelike
59from pandas.core.ops import (
60 missing,
61 roperator,
62)
63from pandas.core.ops.dispatch import should_extension_dispatch
64from pandas.core.ops.invalid import invalid_comparison
65
66
def comp_method_OBJECT_ARRAY(op, x, y):
    """
    Apply the comparison ``op`` elementwise to object-dtype ``x`` and ``y``.

    Parameters
    ----------
    op : comparison operator, e.g. ``operator.eq``
    x : np.ndarray
        Left operand. It is raveled for the comparison and the result is
        reshaped back to ``x.shape``.
    y : list, np.ndarray, Series, Index or scalar
        A list is first converted to a 1D object ndarray. Array-likes are
        cast to object dtype (Series/Index are unwrapped to ``_values``) and
        compared element by element; anything else is treated as a scalar.

    Returns
    -------
    np.ndarray
        Comparison results reshaped to ``x.shape``.

    Raises
    ------
    ValueError
        If an array-like ``y`` does not have the same shape as ``x``.
    """
    if isinstance(y, list):
        y = construct_1d_object_array_from_listlike(y)

    if not isinstance(y, (np.ndarray, ABCSeries, ABCIndex)):
        # Scalar right operand: compare every element of x against it.
        flat = libops.scalar_compare(x.ravel(), y, op)
        return flat.reshape(x.shape)

    if not is_object_dtype(y.dtype):
        y = y.astype(np.object_)
    if isinstance(y, (ABCSeries, ABCIndex)):
        # The lib comparison routines take raw arrays, not pandas containers.
        y = y._values

    if x.shape != y.shape:
        raise ValueError("Shapes must match", x.shape, y.shape)
    flat = libops.vec_compare(x.ravel(), y.ravel(), op)
    return flat.reshape(x.shape)
84
85
def _masked_arith_op(x: np.ndarray, y, op):
    """
    If the given arithmetic operation fails, attempt it again on
    only the non-null elements of the input array(s).

    Parameters
    ----------
    x : np.ndarray
    y : np.ndarray or scalar
        Any other operand type (including Series/Index) raises TypeError.
    op : binary operator

    Returns
    -------
    np.ndarray
        Same shape as ``x``; positions that were not computed hold NaN.

    Raises
    ------
    ValueError
        If ``y`` is an ndarray whose length differs from ``x``.
    TypeError
        If ``y`` is neither an ndarray nor a scalar.
    """
    assert isinstance(x, np.ndarray), type(x)
    # For Series `x` is 1D so ravel() is a no-op; calling it anyway makes
    # the logic valid for both Series and DataFrame ops.
    xrav = x.ravel()
    if isinstance(y, np.ndarray):
        dtype = find_common_type([x.dtype, y.dtype])
        result = np.empty(x.size, dtype=dtype)

        if len(x) != len(y):
            raise ValueError(x.shape, y.shape)
        ymask = notna(y)

        # NB: ravel() is only safe since y is ndarray; for e.g. PeriodIndex
        # we would get int64 dtype, see GH#19956
        yrav = y.ravel()
        mask = notna(xrav) & ymask.ravel()

        # See GH#5284, GH#5035, GH#19448 for historical reference
        if mask.any():
            result[mask] = op(xrav[mask], yrav[mask])

    else:
        if not is_scalar(y):
            raise TypeError(
                f"Cannot broadcast np.ndarray with operand of type {type(y)}"
            )

        # mask is only meaningful for x
        result = np.empty(x.size, dtype=x.dtype)
        mask = notna(xrav)

        # 1 ** np.nan is 1. So we have to unmask those.
        # Compare against the raveled `xrav` so the condition has the same
        # 1D shape as `mask` even when `x` is 2D (DataFrame blocks).
        # NOTE(review): `op is pow` tests for the builtin ``pow``, not
        # ``operator.pow``; and setting mask to False leaves those positions
        # as NaN (see putmask below), which appears to contradict the comment
        # above. Behavior kept as-is -- confirm intent.
        if op is pow:
            mask = np.where(xrav == 1, False, mask)
        elif op is roperator.rpow:
            mask = np.where(y == 1, False, mask)

        if mask.any():
            result[mask] = op(xrav[mask], y)

    np.putmask(result, ~mask, np.nan)
    result = result.reshape(x.shape)  # 2D compat
    return result
140
141
def _na_arithmetic_op(left: np.ndarray, right, op, is_cmp: bool = False):
    """
    Evaluate ``op(left, right)``, retrying on non-missing values if needed.

    The direct evaluation goes through numexpr unless ``right`` is a str.
    If it raises TypeError, the operation is not a comparison, and either
    operand is object dtype, it is redone via ``_masked_arith_op`` on the
    non-missing entries only; otherwise the TypeError propagates.

    Parameters
    ----------
    left : np.ndarray
    right : np.ndarray or scalar
        Excludes DataFrame, Series, Index, ExtensionArray.
    op : binary operator
    is_cmp : bool, default False
        Whether ``op`` is a comparison operation.

    Returns
    -------
    array-like

    Raises
    ------
    TypeError : invalid operation
    """
    # A str operand can never go through numexpr, so call op directly.
    func = op if isinstance(right, str) else partial(expressions.evaluate, op)

    try:
        result = func(left, right)
    except TypeError:
        # Comparisons never take the masked fallback, as it would handle
        # complex numbers incorrectly, see GH#32047.
        if is_cmp or not (is_object_dtype(left.dtype) or is_object_dtype(right)):
            raise
        # Object dtype: operate only on the non-missing values.
        result = _masked_arith_op(left, right, op)

    if is_cmp and (is_scalar(result) or result is NotImplemented):
        # numpy returned a scalar instead of operating element-wise,
        # e.g. numeric array vs str
        # TODO: can remove this after dropping some future numpy version?
        return invalid_comparison(left, right, op)

    return missing.dispatch_fill_zeros(op, left, right, result)
189
190
def arithmetic_op(left: ArrayLike, right: Any, op):
    """
    Evaluate an arithmetic operation `+`, `-`, `*`, `/`, `//`, `%`, `**`, ...

    This function does not silence numpy warnings; callers that need that
    must wrap the call (e.g. in ``np.errstate(all="ignore")``).

    Parameters
    ----------
    left : np.ndarray or ExtensionArray
    right : object
        Cannot be a DataFrame or Index. Series is *not* excluded.
    op : {operator.add, operator.sub, ...}
        Or one of the reversed variants from roperator.

    Returns
    -------
    ndarray or ExtensionArray
        Or a 2-tuple of these in the case of divmod or rdivmod.
    """
    # Preconditions, handled by the caller: extract_array and
    # ensure_wrapped_if_datetimelike have been applied to `left` and `right`,
    # and maybe_prepare_scalar_for_op has been applied to `right`.
    # datetime64/timedelta64 dtypes need special-casing (e.g. numpy casts
    # integer dtypes to timedelta64 when operating with timedelta64 - GH#22390).
    defer_to_operands = (
        should_extension_dispatch(left, right)
        or isinstance(right, (Timedelta, BaseOffset, Timestamp))
        or right is NaT
    )
    if defer_to_operands:
        # Timedelta/Timestamp and other custom scalars are routed here
        # because numexpr would fail on them, see GH#31457.
        return op(left, right)

    # TODO we should handle EAs consistently and move this check before the
    # dispatch above (https://github.com/pandas-dev/pandas/issues/41165)
    _bool_arith_check(op, left, right)

    # error: Argument 1 to "_na_arithmetic_op" has incompatible type
    # "Union[ExtensionArray, ndarray[Any, Any]]"; expected "ndarray[Any, Any]"
    return _na_arithmetic_op(left, right, op)  # type: ignore[arg-type]
235
236
def comparison_op(left: ArrayLike, right: Any, op) -> ArrayLike:
    """
    Evaluate a comparison operation `==`, `!=`, `>=`, `>`, `<=`, or `<`.

    This function does not silence numpy warnings; callers that need that
    must wrap the call (e.g. in ``np.errstate(all="ignore")``).

    Parameters
    ----------
    left : np.ndarray or ExtensionArray
    right : object
        Cannot be a DataFrame, Series, or Index.
    op : {operator.eq, operator.ne, operator.gt, operator.ge, operator.lt, operator.le}

    Returns
    -------
    ndarray or ExtensionArray

    Raises
    ------
    ValueError
        If ``right`` is (or, as a list, becomes) an ndarray/ExtensionArray
        whose length differs from ``left``.
    """
    # NB: We assume extract_array has already been called on left and right
    lvalues = ensure_wrapped_if_datetimelike(left)
    rvalues = lib.item_from_zerodim(ensure_wrapped_if_datetimelike(right))

    if isinstance(rvalues, list):
        # Tuples are deliberately left alone: we may be comparing e.g. a
        # MultiIndex to a tuple that represents a single entry,
        # see test_compare_tuple_strs
        rvalues = np.asarray(rvalues)

    is_array = isinstance(rvalues, (np.ndarray, ABCExtensionArray))
    if is_array and len(lvalues) != len(rvalues):
        # TODO: make this treatment consistent across ops and classes.
        # Not all listlikes are caught here (e.g. frozenset, tuple);
        # the ambiguous case is object-dtype. See GH#27803
        raise ValueError(
            "Lengths must match to compare", lvalues.shape, rvalues.shape
        )

    if should_extension_dispatch(lvalues, rvalues) or (
        (isinstance(rvalues, (Timedelta, BaseOffset, Timestamp)) or right is NaT)
        and not is_object_dtype(lvalues.dtype)
    ):
        # Let the left operand's own comparison method handle it.
        return op(lvalues, rvalues)

    if is_scalar(rvalues) and isna(rvalues):  # TODO: but not pd.NA?
        # numpy does not like comparisons vs None; only `!=` yields True.
        return np.full(lvalues.shape, op is operator.ne, dtype=bool)

    if is_numeric_v_string_like(lvalues, rvalues):
        # GH#36377 going through the numexpr path would incorrectly raise
        return invalid_comparison(lvalues, rvalues, op)

    if is_object_dtype(lvalues.dtype) or isinstance(rvalues, str):
        return comp_method_OBJECT_ARRAY(op, lvalues, rvalues)

    return _na_arithmetic_op(lvalues, rvalues, op, is_cmp=True)
299
300
def na_logical_op(x: np.ndarray, y, op):
    """
    Apply the logical operator ``op`` to ``x`` and ``y``.

    The operator is first applied directly. If that raises TypeError, the
    operation is retried elementwise on object dtype: pairwise via
    ``libops.vec_binop`` when ``y`` is an ndarray, otherwise against the
    scalar ``y`` (coerced with ``bool()`` unless it is missing) via
    ``libops.scalar_binop``.

    Parameters
    ----------
    x : np.ndarray
    y : np.ndarray or scalar
    op : binary logical operator, e.g. ``operator.and_``

    Returns
    -------
    np.ndarray
        Reshaped to ``x.shape``.

    Raises
    ------
    TypeError
        If the scalar fallback fails.
    """
    try:
        # For exposition, write:
        #     yarr = isinstance(y, np.ndarray)
        #     yint = is_integer(y) or (yarr and y.dtype.kind == "i")
        #     ybool = is_bool(y) or (yarr and y.dtype.kind == "b")
        #     xint = x.dtype.kind == "i"
        #     xbool = x.dtype.kind == "b"
        # Then cases where this goes through without raising include:
        #     (xint or xbool) and (yint or ybool)
        result = op(x, y)
    except TypeError:
        if isinstance(y, np.ndarray):
            # bool-bool dtype operations should be OK, should not get here
            assert not (is_bool_dtype(x.dtype) and is_bool_dtype(y.dtype))
            x = ensure_object(x)
            y = ensure_object(y)
            result = libops.vec_binop(x.ravel(), y.ravel(), op)
        else:
            # let null fall thru
            assert lib.is_scalar(y)
            if not isna(y):
                y = bool(y)
            try:
                # Ravel exactly like the ndarray branch above; the result is
                # reshaped back to x.shape below. Without this, a 2D `x`
                # (e.g. a DataFrame block) surfaces as the TypeError below.
                # NOTE(review): assumes scalar_binop only accepts 1D input,
                # as the raveled vec_binop call above suggests.
                result = libops.scalar_binop(x.ravel(), y, op)
            except (
                TypeError,
                ValueError,
                AttributeError,
                OverflowError,
                NotImplementedError,
            ) as err:
                typ = type(y).__name__
                raise TypeError(
                    f"Cannot perform '{op.__name__}' with a dtyped [{x.dtype}] array "
                    f"and scalar of type [{typ}]"
                ) from err

    return result.reshape(x.shape)
340
341
def logical_op(left: ArrayLike, right: Any, op) -> ArrayLike:
    """
    Evaluate a logical operation `|`, `&`, or `^`.

    When both operands are integer (an integer-dtype ``left`` and an
    integer-dtype array or integer scalar ``right``) the result is the
    bitwise integer result. Otherwise the result is passed through a
    boolean fill: missing values in complex/float/object results become
    False and the result is cast to bool.

    Parameters
    ----------
    left : np.ndarray or ExtensionArray
    right : object
        Cannot be a DataFrame, Series, or Index.
    op : {operator.and_, operator.or_, operator.xor}
        Or one of the reversed variants from roperator.

    Returns
    -------
    ndarray or ExtensionArray
    """

    def fill_bool(values, left=None):
        # Replace missing entries of NA-capable dtypes (complex, float,
        # object) with False, then cast to bool -- unless `left` is given
        # and is specifically not boolean, in which case no cast happens.
        if values.dtype.kind in "cfO":
            na_mask = isna(values)
            if na_mask.any():
                values = values.astype(object)
                values[na_mask] = False

        if left is None or is_bool_dtype(left.dtype):
            values = values.astype(bool)
        return values

    self_is_int = is_integer_dtype(left.dtype)

    right = lib.item_from_zerodim(right)
    if is_list_like(right) and not hasattr(right, "dtype"):
        # e.g. list, tuple
        right = construct_1d_object_array_from_listlike(right)

    # NB: We assume extract_array has already been called on left and right
    lvalues = ensure_wrapped_if_datetimelike(left)
    rvalues = right

    if should_extension_dispatch(lvalues, rvalues):
        # Let the left operand's own method handle it.
        return op(lvalues, rvalues)

    if isinstance(rvalues, np.ndarray):
        other_is_int = is_integer_dtype(rvalues.dtype)
        if not other_is_int:
            rvalues = fill_bool(rvalues, lvalues)
    else:
        # i.e. scalar
        other_is_int = lib.is_integer(rvalues)

    res_values = na_logical_op(lvalues, rvalues, op)

    if self_is_int and other_is_int:
        # int vs int: `^`, `|`, `&` are bitwise and keep the integer dtype
        return res_values
    return fill_bool(res_values)
406
407
def get_array_op(op):
    """
    Return a binary array operation corresponding to the given operator op.

    Parameters
    ----------
    op : function
        Binary operator from operator or roperator module.

    Returns
    -------
    functools.partial or callable
        ``comparison_op``, ``logical_op`` or ``arithmetic_op`` with ``op``
        bound as a keyword. ``op`` itself is returned unchanged when it is
        already a ``functools.partial`` or a function named ``arith_op``.

    Raises
    ------
    NotImplementedError
        If ``op`` is not a recognized operator.
    """
    if isinstance(op, partial):
        # We get here via dispatch_to_series in DataFrame case
        # e.g. test_rolling_consistency_var_debiasing_factors
        return op

    # Normalize e.g. "and_" and the reversed "rand_" both to "and". Because
    # lstrip("r") already removes the reversed-op prefix, the name sets below
    # only need the base operator names.
    op_name = op.__name__.strip("_").lstrip("r")
    if op_name == "arith_op":
        # Reached via DataFrame._combine_frame i.e. flex methods
        # e.g. test_df_add_flex_filled_mixed_dtypes
        return op

    if op_name in {"eq", "ne", "lt", "le", "gt", "ge"}:
        return partial(comparison_op, op=op)
    if op_name in {"and", "or", "xor"}:
        return partial(logical_op, op=op)
    if op_name in {
        "add",
        "sub",
        "mul",
        "truediv",
        "floordiv",
        "mod",
        "divmod",
        "pow",
    }:
        return partial(arithmetic_op, op=op)
    raise NotImplementedError(op_name)
449
450
def _cast_nat_to_supported_unit(obj, base: str):
    """
    Cast a NaT ``np.datetime64``/``np.timedelta64`` scalar to a unit pandas supports.

    Parameters
    ----------
    obj : np.datetime64 or np.timedelta64
        A missing (NaT) scalar.
    base : {"datetime64", "timedelta64"}
        Prefix of the numpy dtype string to cast to.

    Returns
    -------
    np.datetime64 or np.timedelta64
        ``obj`` cast to nanoseconds if it is unitless, to the closest
        supported unit if its unit is unsupported, otherwise unchanged.
    """
    if is_unitless(obj.dtype):
        return obj.astype(f"{base}[ns]")
    unit = get_unit_from_dtype(obj.dtype)
    if not is_supported_unit(unit):
        closest_unit = npy_unit_to_abbrev(get_supported_reso(unit))
        return obj.astype(f"{base}[{closest_unit}]")
    return obj


def maybe_prepare_scalar_for_op(obj, shape: Shape):
    """
    Cast non-pandas objects to pandas types to unify behavior of arithmetic
    and comparison operations.

    Parameters
    ----------
    obj : object
    shape : tuple[int]
        Shape to broadcast a datetime64/timedelta64 NaT scalar to.

    Returns
    -------
    out : object
        A Timedelta/Timestamp for stdlib timedelta/datetime and non-missing
        numpy datetime64/timedelta64 scalars; a DatetimeArray/TimedeltaArray
        of ``shape`` for numpy NaT scalars; otherwise ``obj`` unchanged.

    Notes
    -----
    Be careful to call this *after* determining the `name` attribute to be
    attached to the result of the arithmetic operation.
    """
    if type(obj) is datetime.timedelta:
        # GH#22390 cast up to Timedelta to rely on Timedelta
        # implementation; otherwise operation against numeric-dtype
        # raises TypeError
        return Timedelta(obj)
    elif type(obj) is datetime.datetime:
        # cast up to Timestamp to rely on Timestamp implementation, see Timedelta above
        return Timestamp(obj)
    elif isinstance(obj, np.datetime64):
        # GH#28080 numpy casts integer-dtype to datetime64 when doing
        # array[int] + datetime64, which we do not allow
        if isna(obj):
            from pandas.core.arrays import DatetimeArray

            # Avoid possible ambiguities with pd.NaT
            # GH 52295
            obj = _cast_nat_to_supported_unit(obj, "datetime64")
            right = np.broadcast_to(obj, shape)
            return DatetimeArray(right)

        return Timestamp(obj)

    elif isinstance(obj, np.timedelta64):
        if isna(obj):
            from pandas.core.arrays import TimedeltaArray

            # wrapping timedelta64("NaT") in Timedelta returns NaT,
            # which would incorrectly be treated as a datetime-NaT, so
            # we broadcast and wrap in a TimedeltaArray
            # GH 52295
            obj = _cast_nat_to_supported_unit(obj, "timedelta64")
            right = np.broadcast_to(obj, shape)
            return TimedeltaArray(right)

        # In particular non-nanosecond timedelta64 needs to be cast to
        # nanoseconds, or else we get undesired behavior like
        # np.timedelta64(3, 'D') / 2 == np.timedelta64(1, 'D')
        return Timedelta(obj)

    return obj
520
521
# Division and power operators (and their reversed variants) that pandas
# refuses to apply when both operands are boolean.
_BOOL_OP_NOT_ALLOWED = {
    operator.truediv,
    operator.floordiv,
    operator.pow,
    roperator.rtruediv,
    roperator.rfloordiv,
    roperator.rpow,
}


def _bool_arith_check(op, a, b):
    """
    Raise if ``op`` is a disallowed division/power op on boolean operands.

    In contrast to numpy, pandas raises an error for these operations when
    ``a`` has bool dtype and ``b`` is a bool-dtype array or a bool scalar.

    Raises
    ------
    NotImplementedError
        Naming the operator (with any reversed-op "r" prefix stripped).
    """
    if op not in _BOOL_OP_NOT_ALLOWED:
        return
    if not is_bool_dtype(a.dtype):
        return
    if is_bool_dtype(b) or isinstance(b, (bool, np.bool_)):
        op_name = op.__name__.strip("_").lstrip("r")
        raise NotImplementedError(
            f"operator '{op_name}' not implemented for bool dtypes"
        )