1"""
ExtensionArray (EA)-compatible analogue to np.putmask.
3"""
4from __future__ import annotations
5
6from typing import (
7 TYPE_CHECKING,
8 Any,
9)
10
11import numpy as np
12
13from pandas._libs import lib
14from pandas._typing import (
15 ArrayLike,
16 npt,
17)
18from pandas.compat import np_version_under1p21
19
20from pandas.core.dtypes.cast import infer_dtype_from
21from pandas.core.dtypes.common import is_list_like
22
23from pandas.core.arrays import ExtensionArray
24
25if TYPE_CHECKING:
26 from pandas import MultiIndex
27
28
29def putmask_inplace(values: ArrayLike, mask: npt.NDArray[np.bool_], value: Any) -> None:
30 """
31 ExtensionArray-compatible implementation of np.putmask. The main
32 difference is we do not handle repeating or truncating like numpy.
33
34 Parameters
35 ----------
36 values: np.ndarray or ExtensionArray
37 mask : np.ndarray[bool]
38 We assume extract_bool_array has already been called.
39 value : Any
40 """
41
42 if (
43 not isinstance(values, np.ndarray)
44 or (values.dtype == object and not lib.is_scalar(value))
45 # GH#43424: np.putmask raises TypeError if we cannot cast between types with
46 # rule = "safe", a stricter guarantee we may not have here
47 or (
48 isinstance(value, np.ndarray) and not np.can_cast(value.dtype, values.dtype)
49 )
50 ):
51 # GH#19266 using np.putmask gives unexpected results with listlike value
52 # along with object dtype
53 if is_list_like(value) and len(value) == len(values):
54 values[mask] = value[mask]
55 else:
56 values[mask] = value
57 else:
58 # GH#37833 np.putmask is more performant than __setitem__
59 np.putmask(values, mask, value)
60
61
def putmask_without_repeat(
    values: np.ndarray, mask: npt.NDArray[np.bool_], new: Any
) -> None:
    """
    Write ``new`` into ``values`` at ``mask`` without repeating or truncating.

    np.putmask will truncate or repeat if `new` is a listlike with
    len(new) != len(values). We require an exact match.

    Parameters
    ----------
    values : np.ndarray
        Modified in place.
    mask : np.ndarray[bool]
    new : Any

    Raises
    ------
    ValueError
        If ``new`` is a 1-dimensional listlike whose length matches neither
        the number of masked positions, nor the last dimension of ``mask``,
        nor 1.
    """
    if np_version_under1p21:
        new = setitem_datetimelike_compat(values, mask.sum(), new)

    if getattr(new, "ndim", 0) >= 1:
        new = new.astype(values.dtype, copy=False)

    # TODO: this prob needs some better checking for 2D cases
    num_masked = mask.sum()
    must_check_length = (
        num_masked > 0 and is_list_like(new) and getattr(new, "ndim", 1) == 1
    )
    if not must_check_length:
        np.putmask(values, mask, new)
        return

    # np.shape (rather than new.shape) for compat if
    #  setitem_datetimelike_compat changed arraylike to list
    #  e.g. test_where_dt64_2d
    new_len = np.shape(new)[-1]

    if num_masked == new_len:
        # GH#30567
        # If length of ``new`` is less than the length of ``values``,
        # `np.putmask` would first repeat the ``new`` array and then
        # assign the masked values hence produces incorrect result.
        # `np.place` on the other hand uses the ``new`` values as they are
        # to place in the masked locations of ``values``
        # i.e. values[mask] = new
        np.place(values, mask, new)
    elif mask.shape[-1] == new_len or new_len == 1:
        np.putmask(values, mask, new)
    else:
        raise ValueError("cannot assign mismatch length to masked array")
102
103
def validate_putmask(
    values: ArrayLike | MultiIndex, mask: np.ndarray
) -> tuple[npt.NDArray[np.bool_], bool]:
    """
    Validate mask and check if this putmask operation is a no-op.

    Parameters
    ----------
    values : np.ndarray, ExtensionArray or MultiIndex
        Only its ``shape`` is inspected.
    mask : array-like
        Passed through extract_bool_array before validation.

    Returns
    -------
    mask : np.ndarray[bool]
        The mask as a boolean ndarray.
    noop : bool
        True if no element of the mask is set.

    Raises
    ------
    ValueError
        If the shape of the converted mask differs from ``values.shape``.
    """
    bool_mask = extract_bool_array(mask)

    if values.shape != bool_mask.shape:
        raise ValueError("putmask: mask and data must be the same size")

    # Nothing selected means callers can skip the assignment entirely.
    return bool_mask, not bool_mask.any()
116
117
118def extract_bool_array(mask: ArrayLike) -> npt.NDArray[np.bool_]:
119 """
120 If we have a SparseArray or BooleanArray, convert it to ndarray[bool].
121 """
122 if isinstance(mask, ExtensionArray):
123 # We could have BooleanArray, Sparse[bool], ...
124 # Except for BooleanArray, this is equivalent to just
125 # np.asarray(mask, dtype=bool)
126 mask = mask.to_numpy(dtype=bool, na_value=False)
127
128 mask = np.asarray(mask, dtype=bool)
129 return mask
130
131
def setitem_datetimelike_compat(values: np.ndarray, num_set: int, other):
    """
    Convert a datetime64/timedelta64 ``other`` to a list before it is set
    into an object-dtype ndarray.

    Parameters
    ----------
    values : np.ndarray
    num_set : int
        For putmask, this is mask.sum()
    other : Any

    Returns
    -------
    Any
        ``other`` unchanged, unless ``values`` is object dtype and ``other``
        is inferred to have a numpy datetime64/timedelta64 dtype. In that
        case a list is returned: ``list(other)`` for a listlike, or the
        scalar repeated ``num_set`` times.
    """
    if values.dtype != object:
        return other

    inferred, _ = infer_dtype_from(other, pandas_dtype=True)
    is_np_datetimelike = isinstance(inferred, np.dtype) and inferred.kind in (
        "m",
        "M",
    )
    if not is_np_datetimelike:
        return other

    # https://github.com/numpy/numpy/issues/12550
    #  timedelta64 will incorrectly cast to int
    if is_list_like(other):
        return list(other)
    return [other] * num_set