1from __future__ import annotations
2
3from functools import wraps
4from typing import (
5 TYPE_CHECKING,
6 Any,
7 Literal,
8 Sequence,
9 TypeVar,
10 cast,
11 overload,
12)
13
14import numpy as np
15
16from pandas._libs import lib
17from pandas._libs.arrays import NDArrayBacked
18from pandas._typing import (
19 ArrayLike,
20 AxisInt,
21 Dtype,
22 F,
23 PositionalIndexer2D,
24 PositionalIndexerTuple,
25 ScalarIndexer,
26 SequenceIndexer,
27 Shape,
28 TakeIndexer,
29 npt,
30 type_t,
31)
32from pandas.errors import AbstractMethodError
33from pandas.util._decorators import doc
34from pandas.util._validators import (
35 validate_bool_kwarg,
36 validate_fillna_kwargs,
37 validate_insert_loc,
38)
39
40from pandas.core.dtypes.common import (
41 is_dtype_equal,
42 pandas_dtype,
43)
44from pandas.core.dtypes.dtypes import (
45 DatetimeTZDtype,
46 ExtensionDtype,
47 PeriodDtype,
48)
49from pandas.core.dtypes.missing import array_equivalent
50
51from pandas.core import missing
52from pandas.core.algorithms import (
53 take,
54 unique,
55 value_counts,
56)
57from pandas.core.array_algos.quantile import quantile_with_mask
58from pandas.core.array_algos.transforms import shift
59from pandas.core.arrays.base import ExtensionArray
60from pandas.core.construction import extract_array
61from pandas.core.indexers import check_array_indexer
62from pandas.core.sorting import nargminmax
63
# TypeVar bound to NDArrayBackedExtensionArray; used in this module to
# annotate methods whose return type matches the type of ``self`` / ``cls``.
NDArrayBackedExtensionArrayT = TypeVar(
    "NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray"
)
67
68if TYPE_CHECKING:
69 from pandas._typing import (
70 NumpySorter,
71 NumpyValueArrayLike,
72 )
73
74 from pandas import Series
75
76
def ravel_compat(meth: F) -> F:
    """
    Decorator that lets a method written for flat data run on 2D arrays.

    When ``self`` is 1D the wrapped method is called unchanged.  Otherwise
    ``self`` is raveled in "K" order, the method is applied to the flat
    array, and the result is reshaped back to ``self.shape`` -- in "F" order
    if the backing ndarray is F-contiguous, in "C" order otherwise.
    """

    @wraps(meth)
    def wrapper(self, *args, **kwargs):
        if self.ndim != 1:
            # grab the layout flags of the backing ndarray before raveling
            layout = self._ndarray.flags
            flattened = self.ravel("K")
            flat_result = meth(flattened, *args, **kwargs)
            target_order = "C" if not layout.f_contiguous else "F"
            return flat_result.reshape(self.shape, order=target_order)

        return meth(self, *args, **kwargs)

    return cast(F, wrapper)
95
96
class NDArrayBackedExtensionArray(NDArrayBacked, ExtensionArray):
    """
    ExtensionArray that is backed by a single NumPy ndarray.

    The backing data is stored in ``self._ndarray``; most methods here
    operate on that ndarray and re-wrap the result with
    ``self._from_backing_data``.
    """

    # the single ndarray holding this array's data
    _ndarray: np.ndarray

    # scalar used to denote NA value inside our self._ndarray, e.g. -1
    #  for Categorical, iNaT for Period. Outside of object dtype,
    #  self.isna() should be exactly locations in self._ndarray with
    #  _internal_fill_value.
    _internal_fill_value: Any
109
110 def _box_func(self, x):
111 """
112 Wrap numpy type in our dtype.type if necessary.
113 """
114 return x
115
116 def _validate_scalar(self, value):
117 # used by NDArrayBackedExtensionIndex.insert
118 raise AbstractMethodError(self)
119
120 # ------------------------------------------------------------------------
121
122 def view(self, dtype: Dtype | None = None) -> ArrayLike:
123 # We handle datetime64, datetime64tz, timedelta64, and period
124 # dtypes here. Everything else we pass through to the underlying
125 # ndarray.
126 if dtype is None or dtype is self.dtype:
127 return self._from_backing_data(self._ndarray)
128
129 if isinstance(dtype, type):
130 # we sometimes pass non-dtype objects, e.g np.ndarray;
131 # pass those through to the underlying ndarray
132 return self._ndarray.view(dtype)
133
134 dtype = pandas_dtype(dtype)
135 arr = self._ndarray
136
137 if isinstance(dtype, (PeriodDtype, DatetimeTZDtype)):
138 cls = dtype.construct_array_type()
139 return cls(arr.view("i8"), dtype=dtype)
140 elif dtype == "M8[ns]":
141 from pandas.core.arrays import DatetimeArray
142
143 return DatetimeArray(arr.view("i8"), dtype=dtype)
144 elif dtype == "m8[ns]":
145 from pandas.core.arrays import TimedeltaArray
146
147 return TimedeltaArray(arr.view("i8"), dtype=dtype)
148
149 # error: Argument "dtype" to "view" of "_ArrayOrScalarCommon" has incompatible
150 # type "Union[ExtensionDtype, dtype[Any]]"; expected "Union[dtype[Any], None,
151 # type, _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any, Union[int,
152 # Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]"
153 return arr.view(dtype=dtype) # type: ignore[arg-type]
154
155 def take(
156 self: NDArrayBackedExtensionArrayT,
157 indices: TakeIndexer,
158 *,
159 allow_fill: bool = False,
160 fill_value: Any = None,
161 axis: AxisInt = 0,
162 ) -> NDArrayBackedExtensionArrayT:
163 if allow_fill:
164 fill_value = self._validate_scalar(fill_value)
165
166 new_data = take(
167 self._ndarray,
168 indices,
169 allow_fill=allow_fill,
170 fill_value=fill_value,
171 axis=axis,
172 )
173 return self._from_backing_data(new_data)
174
175 # ------------------------------------------------------------------------
176
177 def equals(self, other) -> bool:
178 if type(self) is not type(other):
179 return False
180 if not is_dtype_equal(self.dtype, other.dtype):
181 return False
182 return bool(array_equivalent(self._ndarray, other._ndarray))
183
184 @classmethod
185 def _from_factorized(cls, values, original):
186 assert values.dtype == original._ndarray.dtype
187 return original._from_backing_data(values)
188
189 def _values_for_argsort(self) -> np.ndarray:
190 return self._ndarray
191
192 def _values_for_factorize(self):
193 return self._ndarray, self._internal_fill_value
194
195 # Signature of "argmin" incompatible with supertype "ExtensionArray"
196 def argmin(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override]
197 # override base class by adding axis keyword
198 validate_bool_kwarg(skipna, "skipna")
199 if not skipna and self._hasna:
200 raise NotImplementedError
201 return nargminmax(self, "argmin", axis=axis)
202
203 # Signature of "argmax" incompatible with supertype "ExtensionArray"
204 def argmax(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override]
205 # override base class by adding axis keyword
206 validate_bool_kwarg(skipna, "skipna")
207 if not skipna and self._hasna:
208 raise NotImplementedError
209 return nargminmax(self, "argmax", axis=axis)
210
211 def unique(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
212 new_data = unique(self._ndarray)
213 return self._from_backing_data(new_data)
214
215 @classmethod
216 @doc(ExtensionArray._concat_same_type)
217 def _concat_same_type(
218 cls: type[NDArrayBackedExtensionArrayT],
219 to_concat: Sequence[NDArrayBackedExtensionArrayT],
220 axis: AxisInt = 0,
221 ) -> NDArrayBackedExtensionArrayT:
222 dtypes = {str(x.dtype) for x in to_concat}
223 if len(dtypes) != 1:
224 raise ValueError("to_concat must have the same dtype (tz)", dtypes)
225
226 new_values = [x._ndarray for x in to_concat]
227 new_arr = np.concatenate(new_values, axis=axis)
228 return to_concat[0]._from_backing_data(new_arr)
229
230 @doc(ExtensionArray.searchsorted)
231 def searchsorted(
232 self,
233 value: NumpyValueArrayLike | ExtensionArray,
234 side: Literal["left", "right"] = "left",
235 sorter: NumpySorter = None,
236 ) -> npt.NDArray[np.intp] | np.intp:
237 npvalue = self._validate_setitem_value(value)
238 return self._ndarray.searchsorted(npvalue, side=side, sorter=sorter)
239
240 @doc(ExtensionArray.shift)
241 def shift(self, periods: int = 1, fill_value=None, axis: AxisInt = 0):
242 fill_value = self._validate_scalar(fill_value)
243 new_values = shift(self._ndarray, periods, axis, fill_value)
244
245 return self._from_backing_data(new_values)
246
247 def __setitem__(self, key, value) -> None:
248 key = check_array_indexer(self, key)
249 value = self._validate_setitem_value(value)
250 self._ndarray[key] = value
251
252 def _validate_setitem_value(self, value):
253 return value
254
255 @overload
256 def __getitem__(self, key: ScalarIndexer) -> Any:
257 ...
258
259 @overload
260 def __getitem__(
261 self: NDArrayBackedExtensionArrayT,
262 key: SequenceIndexer | PositionalIndexerTuple,
263 ) -> NDArrayBackedExtensionArrayT:
264 ...
265
266 def __getitem__(
267 self: NDArrayBackedExtensionArrayT,
268 key: PositionalIndexer2D,
269 ) -> NDArrayBackedExtensionArrayT | Any:
270 if lib.is_integer(key):
271 # fast-path
272 result = self._ndarray[key]
273 if self.ndim == 1:
274 return self._box_func(result)
275 return self._from_backing_data(result)
276
277 # error: Incompatible types in assignment (expression has type "ExtensionArray",
278 # variable has type "Union[int, slice, ndarray]")
279 key = extract_array(key, extract_numpy=True) # type: ignore[assignment]
280 key = check_array_indexer(self, key)
281 result = self._ndarray[key]
282 if lib.is_scalar(result):
283 return self._box_func(result)
284
285 result = self._from_backing_data(result)
286 return result
287
288 def _fill_mask_inplace(
289 self, method: str, limit, mask: npt.NDArray[np.bool_]
290 ) -> None:
291 # (for now) when self.ndim == 2, we assume axis=0
292 func = missing.get_fill_func(method, ndim=self.ndim)
293 func(self._ndarray.T, limit=limit, mask=mask.T)
294
295 @doc(ExtensionArray.fillna)
296 def fillna(
297 self: NDArrayBackedExtensionArrayT, value=None, method=None, limit=None
298 ) -> NDArrayBackedExtensionArrayT:
299 value, method = validate_fillna_kwargs(
300 value, method, validate_scalar_dict_value=False
301 )
302
303 mask = self.isna()
304 # error: Argument 2 to "check_value_size" has incompatible type
305 # "ExtensionArray"; expected "ndarray"
306 value = missing.check_value_size(
307 value, mask, len(self) # type: ignore[arg-type]
308 )
309
310 if mask.any():
311 if method is not None:
312 # TODO: check value is None
313 # (for now) when self.ndim == 2, we assume axis=0
314 func = missing.get_fill_func(method, ndim=self.ndim)
315 npvalues = self._ndarray.T.copy()
316 func(npvalues, limit=limit, mask=mask.T)
317 npvalues = npvalues.T
318
319 # TODO: PandasArray didn't used to copy, need tests for this
320 new_values = self._from_backing_data(npvalues)
321 else:
322 # fill with value
323 new_values = self.copy()
324 new_values[mask] = value
325 else:
326 # We validate the fill_value even if there is nothing to fill
327 if value is not None:
328 self._validate_setitem_value(value)
329
330 new_values = self.copy()
331 return new_values
332
333 # ------------------------------------------------------------------------
334 # Reductions
335
336 def _wrap_reduction_result(self, axis: AxisInt | None, result):
337 if axis is None or self.ndim == 1:
338 return self._box_func(result)
339 return self._from_backing_data(result)
340
341 # ------------------------------------------------------------------------
342 # __array_function__ methods
343
344 def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
345 """
346 Analogue to np.putmask(self, mask, value)
347
348 Parameters
349 ----------
350 mask : np.ndarray[bool]
351 value : scalar or listlike
352
353 Raises
354 ------
355 TypeError
356 If value cannot be cast to self.dtype.
357 """
358 value = self._validate_setitem_value(value)
359
360 np.putmask(self._ndarray, mask, value)
361
362 def _where(
363 self: NDArrayBackedExtensionArrayT, mask: npt.NDArray[np.bool_], value
364 ) -> NDArrayBackedExtensionArrayT:
365 """
366 Analogue to np.where(mask, self, value)
367
368 Parameters
369 ----------
370 mask : np.ndarray[bool]
371 value : scalar or listlike
372
373 Raises
374 ------
375 TypeError
376 If value cannot be cast to self.dtype.
377 """
378 value = self._validate_setitem_value(value)
379
380 res_values = np.where(mask, self._ndarray, value)
381 return self._from_backing_data(res_values)
382
383 # ------------------------------------------------------------------------
384 # Index compat methods
385
386 def insert(
387 self: NDArrayBackedExtensionArrayT, loc: int, item
388 ) -> NDArrayBackedExtensionArrayT:
389 """
390 Make new ExtensionArray inserting new item at location. Follows
391 Python list.append semantics for negative values.
392
393 Parameters
394 ----------
395 loc : int
396 item : object
397
398 Returns
399 -------
400 type(self)
401 """
402 loc = validate_insert_loc(loc, len(self))
403
404 code = self._validate_scalar(item)
405
406 new_vals = np.concatenate(
407 (
408 self._ndarray[:loc],
409 np.asarray([code], dtype=self._ndarray.dtype),
410 self._ndarray[loc:],
411 )
412 )
413 return self._from_backing_data(new_vals)
414
415 # ------------------------------------------------------------------------
416 # Additional array methods
417 # These are not part of the EA API, but we implement them because
418 # pandas assumes they're there.
419
420 def value_counts(self, dropna: bool = True) -> Series:
421 """
422 Return a Series containing counts of unique values.
423
424 Parameters
425 ----------
426 dropna : bool, default True
427 Don't include counts of NA values.
428
429 Returns
430 -------
431 Series
432 """
433 if self.ndim != 1:
434 raise NotImplementedError
435
436 from pandas import (
437 Index,
438 Series,
439 )
440
441 if dropna:
442 # error: Unsupported operand type for ~ ("ExtensionArray")
443 values = self[~self.isna()]._ndarray # type: ignore[operator]
444 else:
445 values = self._ndarray
446
447 result = value_counts(values, sort=False, dropna=dropna)
448
449 index_arr = self._from_backing_data(np.asarray(result.index._data))
450 index = Index(index_arr, name=result.index.name)
451 return Series(result._values, index=index, name=result.name, copy=False)
452
453 def _quantile(
454 self: NDArrayBackedExtensionArrayT,
455 qs: npt.NDArray[np.float64],
456 interpolation: str,
457 ) -> NDArrayBackedExtensionArrayT:
458 # TODO: disable for Categorical if not ordered?
459
460 mask = np.asarray(self.isna())
461 arr = self._ndarray
462 fill_value = self._internal_fill_value
463
464 res_values = quantile_with_mask(arr, mask, fill_value, qs, interpolation)
465
466 res_values = self._cast_quantile_result(res_values)
467 return self._from_backing_data(res_values)
468
469 # TODO: see if we can share this with other dispatch-wrapping methods
470 def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
471 """
472 Cast the result of quantile_with_mask to an appropriate dtype
473 to pass to _from_backing_data in _quantile.
474 """
475 return res_values
476
477 # ------------------------------------------------------------------------
478 # numpy-like methods
479
480 @classmethod
481 def _empty(
482 cls: type_t[NDArrayBackedExtensionArrayT], shape: Shape, dtype: ExtensionDtype
483 ) -> NDArrayBackedExtensionArrayT:
484 """
485 Analogous to np.empty(shape, dtype=dtype)
486
487 Parameters
488 ----------
489 shape : tuple[int]
490 dtype : ExtensionDtype
491 """
492 # The base implementation uses a naive approach to find the dtype
493 # for the backing ndarray
494 arr = cls._from_sequence([], dtype=dtype)
495 backing = np.empty(shape, dtype=arr._ndarray.dtype)
496 return arr._from_backing_data(backing)