1from __future__ import annotations
2
3from functools import wraps
4from typing import (
5 TYPE_CHECKING,
6 Any,
7 Literal,
8 cast,
9 overload,
10)
11
12import numpy as np
13
14from pandas._libs import lib
15from pandas._libs.arrays import NDArrayBacked
16from pandas._libs.tslibs import is_supported_dtype
17from pandas._typing import (
18 ArrayLike,
19 AxisInt,
20 Dtype,
21 F,
22 FillnaOptions,
23 PositionalIndexer2D,
24 PositionalIndexerTuple,
25 ScalarIndexer,
26 Self,
27 SequenceIndexer,
28 Shape,
29 TakeIndexer,
30 npt,
31)
32from pandas.errors import AbstractMethodError
33from pandas.util._decorators import doc
34from pandas.util._validators import (
35 validate_bool_kwarg,
36 validate_fillna_kwargs,
37 validate_insert_loc,
38)
39
40from pandas.core.dtypes.common import pandas_dtype
41from pandas.core.dtypes.dtypes import (
42 DatetimeTZDtype,
43 ExtensionDtype,
44 PeriodDtype,
45)
46from pandas.core.dtypes.missing import array_equivalent
47
48from pandas.core import missing
49from pandas.core.algorithms import (
50 take,
51 unique,
52 value_counts_internal as value_counts,
53)
54from pandas.core.array_algos.quantile import quantile_with_mask
55from pandas.core.array_algos.transforms import shift
56from pandas.core.arrays.base import ExtensionArray
57from pandas.core.construction import extract_array
58from pandas.core.indexers import check_array_indexer
59from pandas.core.sorting import nargminmax
60
61if TYPE_CHECKING:
62 from collections.abc import Sequence
63
64 from pandas._typing import (
65 NumpySorter,
66 NumpyValueArrayLike,
67 )
68
69 from pandas import Series
70
71
def ravel_compat(meth: F) -> F:
    """
    Decorator allowing a 1D-only cython operation to be used on 2D arrays.

    1D callers go straight through. 2D callers have their array flattened
    in memory order ("K") before the call, and the result is reshaped back
    to the caller's shape with the original contiguity preserved.
    """

    @wraps(meth)
    def method(self, *args, **kwargs):
        if self.ndim == 1:
            # Nothing to do for 1D input.
            return meth(self, *args, **kwargs)

        # Record layout before flattening so we can restore it afterwards.
        was_f_contig = self._ndarray.flags.f_contiguous
        result = meth(self.ravel("K"), *args, **kwargs)
        return result.reshape(self.shape, order="F" if was_f_contig else "C")

    return cast(F, method)
90
91
92class NDArrayBackedExtensionArray(NDArrayBacked, ExtensionArray):
93 """
94 ExtensionArray that is backed by a single NumPy ndarray.
95 """
96
97 _ndarray: np.ndarray
98
99 # scalar used to denote NA value inside our self._ndarray, e.g. -1
100 # for Categorical, iNaT for Period. Outside of object dtype,
101 # self.isna() should be exactly locations in self._ndarray with
102 # _internal_fill_value.
103 _internal_fill_value: Any
104
105 def _box_func(self, x):
106 """
107 Wrap numpy type in our dtype.type if necessary.
108 """
109 return x
110
    def _validate_scalar(self, value):
        # Convert/validate a single scalar for storage in self._ndarray.
        # Abstract here: concrete subclasses must override.
        # used by NDArrayBackedExtensionIndex.insert; also called by
        # take(allow_fill=True), shift, and insert in this class.
        raise AbstractMethodError(self)
114
115 # ------------------------------------------------------------------------
116
    def view(self, dtype: Dtype | None = None) -> ArrayLike:
        """
        Reinterpret the backing data as ``dtype`` without copying.

        We handle datetime64, datetime64tz, timedelta64, and period
        dtypes here. Everything else we pass through to the underlying
        ndarray.
        """
        if dtype is None or dtype is self.dtype:
            # no-op view: same dtype, shared backing data
            return self._from_backing_data(self._ndarray)

        if isinstance(dtype, type):
            # we sometimes pass non-dtype objects, e.g np.ndarray;
            # pass those through to the underlying ndarray
            return self._ndarray.view(dtype)

        dtype = pandas_dtype(dtype)
        arr = self._ndarray

        if isinstance(dtype, PeriodDtype):
            cls = dtype.construct_array_type()
            # PeriodArray is constructed from i8 ordinals
            return cls(arr.view("i8"), dtype=dtype)
        elif isinstance(dtype, DatetimeTZDtype):
            dt_cls = dtype.construct_array_type()
            # view as naive M8 with the target dtype's unit; tz lives on dtype
            dt64_values = arr.view(f"M8[{dtype.unit}]")
            return dt_cls._simple_new(dt64_values, dtype=dtype)
        elif lib.is_np_dtype(dtype, "M") and is_supported_dtype(dtype):
            from pandas.core.arrays import DatetimeArray

            dt64_values = arr.view(dtype)
            return DatetimeArray._simple_new(dt64_values, dtype=dtype)

        elif lib.is_np_dtype(dtype, "m") and is_supported_dtype(dtype):
            from pandas.core.arrays import TimedeltaArray

            td64_values = arr.view(dtype)
            return TimedeltaArray._simple_new(td64_values, dtype=dtype)

        # Fallback: defer to ndarray.view for all remaining dtypes.
        # error: Argument "dtype" to "view" of "_ArrayOrScalarCommon" has incompatible
        # type "Union[ExtensionDtype, dtype[Any]]"; expected "Union[dtype[Any], None,
        # type, _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any, Union[int,
        # Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]"
        return arr.view(dtype=dtype)  # type: ignore[arg-type]
156
157 def take(
158 self,
159 indices: TakeIndexer,
160 *,
161 allow_fill: bool = False,
162 fill_value: Any = None,
163 axis: AxisInt = 0,
164 ) -> Self:
165 if allow_fill:
166 fill_value = self._validate_scalar(fill_value)
167
168 new_data = take(
169 self._ndarray,
170 indices,
171 allow_fill=allow_fill,
172 fill_value=fill_value,
173 axis=axis,
174 )
175 return self._from_backing_data(new_data)
176
177 # ------------------------------------------------------------------------
178
179 def equals(self, other) -> bool:
180 if type(self) is not type(other):
181 return False
182 if self.dtype != other.dtype:
183 return False
184 return bool(array_equivalent(self._ndarray, other._ndarray, dtype_equal=True))
185
186 @classmethod
187 def _from_factorized(cls, values, original):
188 assert values.dtype == original._ndarray.dtype
189 return original._from_backing_data(values)
190
191 def _values_for_argsort(self) -> np.ndarray:
192 return self._ndarray
193
194 def _values_for_factorize(self):
195 return self._ndarray, self._internal_fill_value
196
197 def _hash_pandas_object(
198 self, *, encoding: str, hash_key: str, categorize: bool
199 ) -> npt.NDArray[np.uint64]:
200 from pandas.core.util.hashing import hash_array
201
202 values = self._ndarray
203 return hash_array(
204 values, encoding=encoding, hash_key=hash_key, categorize=categorize
205 )
206
207 # Signature of "argmin" incompatible with supertype "ExtensionArray"
208 def argmin(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override]
209 # override base class by adding axis keyword
210 validate_bool_kwarg(skipna, "skipna")
211 if not skipna and self._hasna:
212 raise NotImplementedError
213 return nargminmax(self, "argmin", axis=axis)
214
215 # Signature of "argmax" incompatible with supertype "ExtensionArray"
216 def argmax(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override]
217 # override base class by adding axis keyword
218 validate_bool_kwarg(skipna, "skipna")
219 if not skipna and self._hasna:
220 raise NotImplementedError
221 return nargminmax(self, "argmax", axis=axis)
222
223 def unique(self) -> Self:
224 new_data = unique(self._ndarray)
225 return self._from_backing_data(new_data)
226
227 @classmethod
228 @doc(ExtensionArray._concat_same_type)
229 def _concat_same_type(
230 cls,
231 to_concat: Sequence[Self],
232 axis: AxisInt = 0,
233 ) -> Self:
234 if not lib.dtypes_all_equal([x.dtype for x in to_concat]):
235 dtypes = {str(x.dtype) for x in to_concat}
236 raise ValueError("to_concat must have the same dtype", dtypes)
237
238 return super()._concat_same_type(to_concat, axis=axis)
239
240 @doc(ExtensionArray.searchsorted)
241 def searchsorted(
242 self,
243 value: NumpyValueArrayLike | ExtensionArray,
244 side: Literal["left", "right"] = "left",
245 sorter: NumpySorter | None = None,
246 ) -> npt.NDArray[np.intp] | np.intp:
247 npvalue = self._validate_setitem_value(value)
248 return self._ndarray.searchsorted(npvalue, side=side, sorter=sorter)
249
250 @doc(ExtensionArray.shift)
251 def shift(self, periods: int = 1, fill_value=None):
252 # NB: shift is always along axis=0
253 axis = 0
254 fill_value = self._validate_scalar(fill_value)
255 new_values = shift(self._ndarray, periods, axis, fill_value)
256
257 return self._from_backing_data(new_values)
258
259 def __setitem__(self, key, value) -> None:
260 key = check_array_indexer(self, key)
261 value = self._validate_setitem_value(value)
262 self._ndarray[key] = value
263
264 def _validate_setitem_value(self, value):
265 return value
266
267 @overload
268 def __getitem__(self, key: ScalarIndexer) -> Any:
269 ...
270
271 @overload
272 def __getitem__(
273 self,
274 key: SequenceIndexer | PositionalIndexerTuple,
275 ) -> Self:
276 ...
277
278 def __getitem__(
279 self,
280 key: PositionalIndexer2D,
281 ) -> Self | Any:
282 if lib.is_integer(key):
283 # fast-path
284 result = self._ndarray[key]
285 if self.ndim == 1:
286 return self._box_func(result)
287 return self._from_backing_data(result)
288
289 # error: Incompatible types in assignment (expression has type "ExtensionArray",
290 # variable has type "Union[int, slice, ndarray]")
291 key = extract_array(key, extract_numpy=True) # type: ignore[assignment]
292 key = check_array_indexer(self, key)
293 result = self._ndarray[key]
294 if lib.is_scalar(result):
295 return self._box_func(result)
296
297 result = self._from_backing_data(result)
298 return result
299
300 def _fill_mask_inplace(
301 self, method: str, limit: int | None, mask: npt.NDArray[np.bool_]
302 ) -> None:
303 # (for now) when self.ndim == 2, we assume axis=0
304 func = missing.get_fill_func(method, ndim=self.ndim)
305 func(self._ndarray.T, limit=limit, mask=mask.T)
306
    def _pad_or_backfill(
        self,
        *,
        method: FillnaOptions,
        limit: int | None = None,
        limit_area: Literal["inside", "outside"] | None = None,
        copy: bool = True,
    ) -> Self:
        """
        Fill NA positions by propagating values ("pad"/"backfill").

        When ``copy=False`` the fill mutates ``self._ndarray`` and ``self``
        is returned; otherwise a new array is returned and ``self`` is left
        untouched.
        """
        mask = self.isna()
        if mask.any():
            # (for now) when self.ndim == 2, we assume axis=0
            func = missing.get_fill_func(method, ndim=self.ndim)

            # transposed to match the axis=0 assumption above
            npvalues = self._ndarray.T
            if copy:
                npvalues = npvalues.copy()
            func(npvalues, limit=limit, limit_area=limit_area, mask=mask.T)
            npvalues = npvalues.T

            if copy:
                new_values = self._from_backing_data(npvalues)
            else:
                # the fill above already mutated self._ndarray in place
                new_values = self

        else:
            # nothing to fill
            if copy:
                new_values = self.copy()
            else:
                new_values = self
        return new_values
337
    @doc(ExtensionArray.fillna)
    def fillna(
        self, value=None, method=None, limit: int | None = None, copy: bool = True
    ) -> Self:
        # Normalize the mutually-exclusive value/method arguments.
        value, method = validate_fillna_kwargs(
            value, method, validate_scalar_dict_value=False
        )

        mask = self.isna()
        # Broadcast/length-check an array-like fill value against the mask.
        # error: Argument 2 to "check_value_size" has incompatible type
        # "ExtensionArray"; expected "ndarray"
        value = missing.check_value_size(
            value, mask, len(self)  # type: ignore[arg-type]
        )

        if mask.any():
            if method is not None:
                # Propagation fill (pad/backfill).
                # (for now) when self.ndim == 2, we assume axis=0
                func = missing.get_fill_func(method, ndim=self.ndim)
                npvalues = self._ndarray.T
                if copy:
                    npvalues = npvalues.copy()
                func(npvalues, limit=limit, mask=mask.T)
                npvalues = npvalues.T

                # TODO: NumpyExtensionArray didn't used to copy, need tests
                # for this
                new_values = self._from_backing_data(npvalues)
            else:
                # fill with value via __setitem__, which validates it
                if copy:
                    new_values = self.copy()
                else:
                    new_values = self[:]
                new_values[mask] = value
        else:
            # We validate the fill_value even if there is nothing to fill
            if value is not None:
                self._validate_setitem_value(value)

            if not copy:
                new_values = self[:]
            else:
                new_values = self.copy()
        return new_values
383
384 # ------------------------------------------------------------------------
385 # Reductions
386
387 def _wrap_reduction_result(self, axis: AxisInt | None, result):
388 if axis is None or self.ndim == 1:
389 return self._box_func(result)
390 return self._from_backing_data(result)
391
392 # ------------------------------------------------------------------------
393 # __array_function__ methods
394
395 def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
396 """
397 Analogue to np.putmask(self, mask, value)
398
399 Parameters
400 ----------
401 mask : np.ndarray[bool]
402 value : scalar or listlike
403
404 Raises
405 ------
406 TypeError
407 If value cannot be cast to self.dtype.
408 """
409 value = self._validate_setitem_value(value)
410
411 np.putmask(self._ndarray, mask, value)
412
413 def _where(self: Self, mask: npt.NDArray[np.bool_], value) -> Self:
414 """
415 Analogue to np.where(mask, self, value)
416
417 Parameters
418 ----------
419 mask : np.ndarray[bool]
420 value : scalar or listlike
421
422 Raises
423 ------
424 TypeError
425 If value cannot be cast to self.dtype.
426 """
427 value = self._validate_setitem_value(value)
428
429 res_values = np.where(mask, self._ndarray, value)
430 if res_values.dtype != self._ndarray.dtype:
431 raise AssertionError(
432 # GH#56410
433 "Something has gone wrong, please report a bug at "
434 "github.com/pandas-dev/pandas/"
435 )
436 return self._from_backing_data(res_values)
437
438 # ------------------------------------------------------------------------
439 # Index compat methods
440
441 def insert(self, loc: int, item) -> Self:
442 """
443 Make new ExtensionArray inserting new item at location. Follows
444 Python list.append semantics for negative values.
445
446 Parameters
447 ----------
448 loc : int
449 item : object
450
451 Returns
452 -------
453 type(self)
454 """
455 loc = validate_insert_loc(loc, len(self))
456
457 code = self._validate_scalar(item)
458
459 new_vals = np.concatenate(
460 (
461 self._ndarray[:loc],
462 np.asarray([code], dtype=self._ndarray.dtype),
463 self._ndarray[loc:],
464 )
465 )
466 return self._from_backing_data(new_vals)
467
468 # ------------------------------------------------------------------------
469 # Additional array methods
470 # These are not part of the EA API, but we implement them because
471 # pandas assumes they're there.
472
    def value_counts(self, dropna: bool = True) -> Series:
        """
        Return a Series containing counts of unique values.

        Parameters
        ----------
        dropna : bool, default True
            Don't include counts of NA values.

        Returns
        -------
        Series
            Indexed by the unique values, re-wrapped as this array's type.
        """
        # 2D arrays are not supported here.
        if self.ndim != 1:
            raise NotImplementedError

        from pandas import (
            Index,
            Series,
        )

        if dropna:
            # drop NA positions before counting
            # error: Unsupported operand type for ~ ("ExtensionArray")
            values = self[~self.isna()]._ndarray  # type: ignore[operator]
        else:
            values = self._ndarray

        # sort=False keeps the order of first appearance
        result = value_counts(values, sort=False, dropna=dropna)

        # Re-wrap the ndarray-backed result index in our own EA type so the
        # returned Series is indexed by properly-typed values.
        index_arr = self._from_backing_data(np.asarray(result.index._data))
        index = Index(index_arr, name=result.index.name)
        return Series(result._values, index=index, name=result.name, copy=False)
505
506 def _quantile(
507 self,
508 qs: npt.NDArray[np.float64],
509 interpolation: str,
510 ) -> Self:
511 # TODO: disable for Categorical if not ordered?
512
513 mask = np.asarray(self.isna())
514 arr = self._ndarray
515 fill_value = self._internal_fill_value
516
517 res_values = quantile_with_mask(arr, mask, fill_value, qs, interpolation)
518
519 res_values = self._cast_quantile_result(res_values)
520 return self._from_backing_data(res_values)
521
522 # TODO: see if we can share this with other dispatch-wrapping methods
523 def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
524 """
525 Cast the result of quantile_with_mask to an appropriate dtype
526 to pass to _from_backing_data in _quantile.
527 """
528 return res_values
529
530 # ------------------------------------------------------------------------
531 # numpy-like methods
532
533 @classmethod
534 def _empty(cls, shape: Shape, dtype: ExtensionDtype) -> Self:
535 """
536 Analogous to np.empty(shape, dtype=dtype)
537
538 Parameters
539 ----------
540 shape : tuple[int]
541 dtype : ExtensionDtype
542 """
543 # The base implementation uses a naive approach to find the dtype
544 # for the backing ndarray
545 arr = cls._from_sequence([], dtype=dtype)
546 backing = np.empty(shape, dtype=arr._ndarray.dtype)
547 return arr._from_backing_data(backing)