1from __future__ import annotations
2
3import functools
4from typing import (
5 TYPE_CHECKING,
6 cast,
7 overload,
8)
9
10import numpy as np
11
12from pandas._libs import (
13 algos as libalgos,
14 lib,
15)
16from pandas._typing import (
17 ArrayLike,
18 AxisInt,
19 npt,
20)
21
22from pandas.core.dtypes.cast import maybe_promote
23from pandas.core.dtypes.common import (
24 ensure_platform_int,
25 is_1d_only_ea_obj,
26)
27from pandas.core.dtypes.missing import na_value_for_dtype
28
29from pandas.core.construction import ensure_wrapped_if_datetimelike
30
31if TYPE_CHECKING:
32 from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
33 from pandas.core.arrays.base import ExtensionArray
34
35
@overload
def take_nd(
    arr: np.ndarray,
    indexer,
    axis: AxisInt = ...,
    fill_value=...,
    allow_fill: bool = ...,
) -> np.ndarray:
    ...


@overload
def take_nd(
    arr: ExtensionArray,
    indexer,
    axis: AxisInt = ...,
    fill_value=...,
    allow_fill: bool = ...,
) -> ArrayLike:
    ...


def take_nd(
    arr: ArrayLike,
    indexer,
    axis: AxisInt = 0,
    fill_value=lib.no_default,
    allow_fill: bool = True,
) -> ArrayLike:
    """
    Take elements of ``arr`` along ``axis``, filling -1 positions in one pass.

    ExtensionArrays are dispatched to their own ``take`` method (``axis`` is
    only forwarded to EAs that are not 1D-only); numpy arrays go through the
    specialized Cython take routines.

    Note: this function assumes that the indexer is a valid(ated) indexer with
    no out of bound indices.

    Parameters
    ----------
    arr : np.ndarray or ExtensionArray
        Input array.
    indexer : ndarray or None
        1-D array of indices to take; subarrays corresponding to -1 value
        indices are filled with ``fill_value``. For numpy input, ``None``
        selects every position along ``axis``.
    axis : int, default 0
        Axis to take from.
    fill_value : any, default lib.no_default
        Fill value to replace -1 values with. When not given, the missing
        value for ``arr.dtype`` (``na_value_for_dtype(..., compat=False)``)
        is used.
    allow_fill : bool, default True
        If False, indexer is assumed to contain no -1 values so no filling
        will be done. This short-circuits computation of a mask. Result is
        undefined if allow_fill == False and -1 is present in indexer.

    Returns
    -------
    subarray : np.ndarray or ExtensionArray
        May be the same type as the input, or cast to an ndarray.
    """
    if fill_value is lib.no_default:
        fill_value = na_value_for_dtype(arr.dtype, compat=False)
    elif isinstance(arr.dtype, np.dtype) and arr.dtype.kind in "mM":
        # datetime64/timedelta64 values: promote for the requested fill value.
        # EA.take is strict about returning a new object of the same type,
        # so for that case cast upfront rather than after the take.
        promoted_dtype, fill_value = maybe_promote(arr.dtype, fill_value)
        if promoted_dtype != arr.dtype:
            arr = arr.astype(promoted_dtype)

    if isinstance(arr, np.ndarray):
        return _take_nd_ndarray(
            np.asarray(arr), indexer, axis, fill_value, allow_fill
        )

    # ExtensionArray from here on (this includes DatetimeArray/TimedeltaArray).
    if is_1d_only_ea_obj(arr):
        return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)

    # i.e. DatetimeArray, TimedeltaArray: these accept an ``axis`` argument.
    backed = cast("NDArrayBackedExtensionArray", arr)
    return backed.take(
        indexer, fill_value=fill_value, allow_fill=allow_fill, axis=axis
    )
118
119
def _take_nd_ndarray(
    arr: np.ndarray,
    indexer: npt.NDArray[np.intp] | None,
    axis: AxisInt,
    fill_value,
    allow_fill: bool,
) -> np.ndarray:
    """
    ndarray-only implementation behind ``take_nd``.

    ``indexer=None`` selects every position along ``axis``. In that case a
    dummy fill value of ``arr``'s own scalar type is used, so no dtype
    promotion is triggered.
    """
    if indexer is None:
        indexer = np.arange(arr.shape[axis], dtype=np.intp)
        fill_value = arr.dtype.type()
    else:
        indexer = ensure_platform_int(indexer)

    dtype, fill_value, mask_info = _take_preprocess_indexer_and_fill_value(
        arr, indexer, fill_value, allow_fill
    )

    # For F-contiguous 2D input, operate on the transpose (flipping the axis)
    # and transpose the result back at the end.
    transposed = arr.ndim == 2 and arr.flags.f_contiguous
    if transposed:
        arr = arr.T
        axis = arr.ndim - axis - 1

    # At this point ``dtype`` can hold both the values of arr and fill_value.
    shape = list(arr.shape)
    shape[axis] = len(indexer)

    # Minor tweak that can make an order-of-magnitude difference for
    # dataframes initialized directly from 2-d ndarrays (s.t. df.values is
    # c-contiguous and df._mgr.blocks[0] is its f-contiguous transpose).
    layout = "F" if arr.flags.f_contiguous and axis == arr.ndim - 1 else "C"
    out = np.empty(tuple(shape), dtype=dtype, order=layout)

    take_func = _get_take_nd_function(
        arr.ndim, arr.dtype, out.dtype, axis=axis, mask_info=mask_info
    )
    take_func(arr, indexer, out, fill_value)

    return out.T if transposed else out
167
168
def take_1d(
    arr: ArrayLike,
    indexer: npt.NDArray[np.intp],
    fill_value=None,
    allow_fill: bool = True,
    mask: npt.NDArray[np.bool_] | None = None,
) -> ArrayLike:
    """
    Low-overhead 1D variant of `take_nd`.

    Compared to `take_nd` this function:

    - assumes the input has already been converted to numpy array / EA,
    - assumes the indexer is already an intp-dtype ndarray,
    - only works for 1D arrays.

    Note: like `take_nd`, it assumes that the indexer is a valid(ated)
    indexer with no out of bound indices.

    Parameters
    ----------
    arr : np.ndarray or ExtensionArray
        Input array.
    indexer : ndarray
        1-D array of indices to take (validated indices, intp dtype).
    fill_value : any, default None
        Fill value to replace -1 values with. For ndarrays it is passed
        through ``maybe_promote``; for EAs it is forwarded to ``EA.take``.
    allow_fill : bool, default True
        If False, indexer is assumed to contain no -1 values so no filling
        will be done. This short-circuits computation of a mask. Result is
        undefined if allow_fill == False and -1 is present in indexer.
    mask : np.ndarray, optional, default None
        If `allow_fill` is True, and the mask (where indexer == -1) is already
        known, it can be passed to avoid recomputation.
    """
    if isinstance(arr, np.ndarray):
        if not allow_fill:
            # No -1 entries are expected, so numpy's own take is enough.
            return arr.take(indexer)
    else:
        # ExtensionArray -> dispatch to its own method
        return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)

    dtype, fill_value, mask_info = _take_preprocess_indexer_and_fill_value(
        arr, indexer, fill_value, True, mask
    )

    # ``dtype`` can hold both the values of arr and the fill_value.
    out = np.empty(indexer.shape, dtype=dtype)
    take_func = _get_take_nd_function(
        arr.ndim, arr.dtype, out.dtype, axis=0, mask_info=mask_info
    )
    take_func(arr, indexer, out, fill_value)
    return out
225
226
def take_2d_multi(
    arr: np.ndarray,
    indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]],
    fill_value=np.nan,
) -> np.ndarray:
    """
    Take rows and columns of a 2D array at once, filling -1 positions.

    Parameters
    ----------
    arr : np.ndarray
        2D input array.
    indexer : tuple of (row_indexer, col_indexer)
        Positions to take along axis 0 and axis 1; -1 entries are filled.
    fill_value : scalar, default np.nan
        Value used for rows/columns whose indexer entry is -1.

    Returns
    -------
    np.ndarray
        Array of shape ``(len(row_indexer), len(col_indexer))``. Its dtype is
        promoted (via ``maybe_promote``) only when a -1 is actually present.
    """
    # This is only called from one place in DataFrame._reindex_multi,
    # so we know indexer is well-behaved.
    assert indexer is not None
    assert indexer[0] is not None
    assert indexer[1] is not None

    rows, cols = indexer
    rows = ensure_platform_int(rows)
    cols = ensure_platform_int(cols)
    indexer = rows, cols
    mask_info = None

    # Type-based promotion check first: it is cheaper than building masks.
    dtype, fill_value = maybe_promote(arr.dtype, fill_value)
    if dtype != arr.dtype:
        row_mask = rows == -1
        col_mask = cols == -1
        row_needs = row_mask.any()
        col_needs = col_mask.any()
        mask_info = (row_mask, col_mask), (row_needs, col_needs)
        if not row_needs and not col_needs:
            # No -1 anywhere: keep the original dtype. The dummy fill value
            # is never written, but the cython code still casts it to dtype.
            dtype, fill_value = arr.dtype, arr.dtype.type()

    # ``dtype`` can now hold both the values of arr and the fill_value.
    out = np.empty((len(rows), len(cols)), dtype=dtype)

    take_func = _take_2d_multi_dict.get((arr.dtype.name, out.dtype.name), None)
    if take_func is None and arr.dtype != out.dtype:
        # Fall back to the same-dtype routine for out's dtype, casting arr.
        same_dtype_func = _take_2d_multi_dict.get(
            (out.dtype.name, out.dtype.name), None
        )
        if same_dtype_func is not None:
            take_func = _convert_wrapper(same_dtype_func, out.dtype)

    if take_func is None:
        # test_reindex_multi
        _take_2d_multi_object(
            arr, indexer, out, fill_value=fill_value, mask_info=mask_info
        )
    else:
        take_func(arr, indexer, out=out, fill_value=fill_value)
    return out
285
286
@functools.lru_cache(maxsize=128)
def _get_take_nd_function_cached(
    ndim: int, arr_dtype: np.dtype, out_dtype: np.dtype, axis: AxisInt
):
    """
    Look up a specialized Cython take routine for ``ndim``/``axis``/dtypes.

    This is the part of ``_get_take_nd_function`` that doesn't need
    ``mask_info`` and thus can be cached (mask_info potentially contains a
    numpy ndarray, which is not hashable and cannot be a cache key).

    Returns
    -------
    callable or None
        A routine with signature ``(arr, indexer, out, fill_value)``, or None
        when no specialized routine exists. None is also returned for any
        ``ndim`` other than 1 or 2, so the caller can fall back to the
        generic implementation.
    """
    if ndim == 1:
        table = _take_1d_dict
    elif ndim == 2:
        table = _take_2d_axis0_dict if axis == 0 else _take_2d_axis1_dict
    else:
        # Only 1D and 2D routines exist. Returning None explicitly avoids an
        # unbound lookup result (callers pass any ndim <= 2, including 0).
        return None

    func = table.get((arr_dtype.name, out_dtype.name), None)
    if func is not None:
        return func

    # We get here with string, uint, float16, and complex dtypes that could
    # potentially be handled in algos_take_helper.
    # Also a couple with (M8[ns], object) and (m8[ns], object).
    # Use the same-dtype routine for out_dtype and cast arr before calling it.
    func = table.get((out_dtype.name, out_dtype.name), None)
    if func is not None:
        return _convert_wrapper(func, out_dtype)

    return None
323
324
def _get_take_nd_function(
    ndim: int,
    arr_dtype: np.dtype,
    out_dtype: np.dtype,
    axis: AxisInt = 0,
    mask_info=None,
):
    """
    Return a "take" implementation for the given dimension, axis and dtypes.

    A specialized routine is looked up first (only for ``ndim <= 2``). If
    none exists, a generic numpy-based implementation is returned, bound to
    ``axis`` and ``mask_info``.
    """
    if ndim <= 2:
        # The lookup itself doesn't need `mask_info`, so use the cached helper.
        specialized = _get_take_nd_function_cached(ndim, arr_dtype, out_dtype, axis)
        if specialized is not None:
            return specialized

    def func(arr, indexer, out, fill_value=np.nan) -> None:
        indexer = ensure_platform_int(indexer)
        _take_nd_object(
            arr, indexer, out, axis=axis, fill_value=fill_value, mask_info=mask_info
        )

    return func
350
351
def _view_wrapper(f, arr_dtype=None, out_dtype=None, fill_wrap=None):
    """
    Wrap take routine ``f`` so that it operates on reinterpreted views.

    ``arr`` and ``out`` are viewed as ``arr_dtype`` / ``out_dtype`` when
    those are given. When ``fill_wrap`` is given, the fill value is first
    cast to ``m8[ns]`` (if its dtype kind is "m") or ``M8[ns]`` otherwise,
    and then converted with ``fill_wrap``.
    """

    def wrapper(
        arr: np.ndarray, indexer: np.ndarray, out: np.ndarray, fill_value=np.nan
    ) -> None:
        arr_view = arr if arr_dtype is None else arr.view(arr_dtype)
        out_view = out if out_dtype is None else out.view(out_dtype)
        if fill_wrap is not None:
            # FIXME: if we get here with dt64/td64 we need to be sure we have
            # matching resos
            target_unit = "m8[ns]" if fill_value.dtype.kind == "m" else "M8[ns]"
            fill_value = fill_wrap(fill_value.astype(target_unit))
        f(arr_view, indexer, out_view, fill_value=fill_value)

    return wrapper
372
373
def _convert_wrapper(f, conv_dtype):
    """
    Wrap take routine ``f`` so that ``arr`` is cast to ``conv_dtype`` first.

    When converting to object, ``arr`` is first passed through
    ``ensure_wrapped_if_datetimelike`` so dt64/td64 values are not cast to
    integers (GH#39755).
    """

    def wrapper(
        arr: np.ndarray, indexer: np.ndarray, out: np.ndarray, fill_value=np.nan
    ) -> None:
        source = ensure_wrapped_if_datetimelike(arr) if conv_dtype == object else arr
        f(source.astype(conv_dtype), indexer, out, fill_value=fill_value)

    return wrapper
385
386
# 1D take routines keyed by (input dtype name, output dtype name).
_take_1d_dict = {
    # integer inputs, kept as-is, widened, or promoted to float64
    ("int8", "int8"): libalgos.take_1d_int8_int8,
    ("int8", "int32"): libalgos.take_1d_int8_int32,
    ("int8", "int64"): libalgos.take_1d_int8_int64,
    ("int8", "float64"): libalgos.take_1d_int8_float64,
    ("int16", "int16"): libalgos.take_1d_int16_int16,
    ("int16", "int32"): libalgos.take_1d_int16_int32,
    ("int16", "int64"): libalgos.take_1d_int16_int64,
    ("int16", "float64"): libalgos.take_1d_int16_float64,
    ("int32", "int32"): libalgos.take_1d_int32_int32,
    ("int32", "int64"): libalgos.take_1d_int32_int64,
    ("int32", "float64"): libalgos.take_1d_int32_float64,
    ("int64", "int64"): libalgos.take_1d_int64_int64,
    ("int64", "float64"): libalgos.take_1d_int64_float64,
    # floating point and object inputs
    ("float32", "float32"): libalgos.take_1d_float32_float32,
    ("float32", "float64"): libalgos.take_1d_float32_float64,
    ("float64", "float64"): libalgos.take_1d_float64_float64,
    ("object", "object"): libalgos.take_1d_object_object,
    # bool data is handled through uint8 views
    ("bool", "bool"): _view_wrapper(
        libalgos.take_1d_bool_bool, arr_dtype=np.uint8, out_dtype=np.uint8
    ),
    ("bool", "object"): _view_wrapper(
        libalgos.take_1d_bool_object, arr_dtype=np.uint8, out_dtype=None
    ),
    # ns-resolution datetime64/timedelta64 reuse the int64 routine on int64
    # views, with the fill value converted through np.int64
    ("datetime64[ns]", "datetime64[ns]"): _view_wrapper(
        libalgos.take_1d_int64_int64,
        arr_dtype=np.int64,
        out_dtype=np.int64,
        fill_wrap=np.int64,
    ),
    ("timedelta64[ns]", "timedelta64[ns]"): _view_wrapper(
        libalgos.take_1d_int64_int64,
        arr_dtype=np.int64,
        out_dtype=np.int64,
        fill_wrap=np.int64,
    ),
}
414
# 2D take-along-axis-0 routines keyed by (input dtype name, output dtype name).
_take_2d_axis0_dict = {
    # integer inputs, kept as-is, widened, or promoted to float64
    ("int8", "int8"): libalgos.take_2d_axis0_int8_int8,
    ("int8", "int32"): libalgos.take_2d_axis0_int8_int32,
    ("int8", "int64"): libalgos.take_2d_axis0_int8_int64,
    ("int8", "float64"): libalgos.take_2d_axis0_int8_float64,
    ("int16", "int16"): libalgos.take_2d_axis0_int16_int16,
    ("int16", "int32"): libalgos.take_2d_axis0_int16_int32,
    ("int16", "int64"): libalgos.take_2d_axis0_int16_int64,
    ("int16", "float64"): libalgos.take_2d_axis0_int16_float64,
    ("int32", "int32"): libalgos.take_2d_axis0_int32_int32,
    ("int32", "int64"): libalgos.take_2d_axis0_int32_int64,
    ("int32", "float64"): libalgos.take_2d_axis0_int32_float64,
    ("int64", "int64"): libalgos.take_2d_axis0_int64_int64,
    ("int64", "float64"): libalgos.take_2d_axis0_int64_float64,
    # floating point and object inputs
    ("float32", "float32"): libalgos.take_2d_axis0_float32_float32,
    ("float32", "float64"): libalgos.take_2d_axis0_float32_float64,
    ("float64", "float64"): libalgos.take_2d_axis0_float64_float64,
    ("object", "object"): libalgos.take_2d_axis0_object_object,
    # bool data is handled through uint8 views
    ("bool", "bool"): _view_wrapper(
        libalgos.take_2d_axis0_bool_bool, arr_dtype=np.uint8, out_dtype=np.uint8
    ),
    ("bool", "object"): _view_wrapper(
        libalgos.take_2d_axis0_bool_object, arr_dtype=np.uint8, out_dtype=None
    ),
    # ns-resolution datetime64/timedelta64 reuse the int64 routine on int64
    # views, with the fill value converted through np.int64
    ("datetime64[ns]", "datetime64[ns]"): _view_wrapper(
        libalgos.take_2d_axis0_int64_int64,
        arr_dtype=np.int64,
        out_dtype=np.int64,
        fill_wrap=np.int64,
    ),
    ("timedelta64[ns]", "timedelta64[ns]"): _view_wrapper(
        libalgos.take_2d_axis0_int64_int64,
        arr_dtype=np.int64,
        out_dtype=np.int64,
        fill_wrap=np.int64,
    ),
}
446
# 2D take-along-axis-1 routines keyed by (input dtype name, output dtype name).
_take_2d_axis1_dict = {
    # integer inputs, kept as-is, widened, or promoted to float64
    ("int8", "int8"): libalgos.take_2d_axis1_int8_int8,
    ("int8", "int32"): libalgos.take_2d_axis1_int8_int32,
    ("int8", "int64"): libalgos.take_2d_axis1_int8_int64,
    ("int8", "float64"): libalgos.take_2d_axis1_int8_float64,
    ("int16", "int16"): libalgos.take_2d_axis1_int16_int16,
    ("int16", "int32"): libalgos.take_2d_axis1_int16_int32,
    ("int16", "int64"): libalgos.take_2d_axis1_int16_int64,
    ("int16", "float64"): libalgos.take_2d_axis1_int16_float64,
    ("int32", "int32"): libalgos.take_2d_axis1_int32_int32,
    ("int32", "int64"): libalgos.take_2d_axis1_int32_int64,
    ("int32", "float64"): libalgos.take_2d_axis1_int32_float64,
    ("int64", "int64"): libalgos.take_2d_axis1_int64_int64,
    ("int64", "float64"): libalgos.take_2d_axis1_int64_float64,
    # floating point and object inputs
    ("float32", "float32"): libalgos.take_2d_axis1_float32_float32,
    ("float32", "float64"): libalgos.take_2d_axis1_float32_float64,
    ("float64", "float64"): libalgos.take_2d_axis1_float64_float64,
    ("object", "object"): libalgos.take_2d_axis1_object_object,
    # bool data is handled through uint8 views
    ("bool", "bool"): _view_wrapper(
        libalgos.take_2d_axis1_bool_bool, arr_dtype=np.uint8, out_dtype=np.uint8
    ),
    ("bool", "object"): _view_wrapper(
        libalgos.take_2d_axis1_bool_object, arr_dtype=np.uint8, out_dtype=None
    ),
    # ns-resolution datetime64/timedelta64 reuse the int64 routine on int64
    # views, with the fill value converted through np.int64
    ("datetime64[ns]", "datetime64[ns]"): _view_wrapper(
        libalgos.take_2d_axis1_int64_int64,
        arr_dtype=np.int64,
        out_dtype=np.int64,
        fill_wrap=np.int64,
    ),
    ("timedelta64[ns]", "timedelta64[ns]"): _view_wrapper(
        libalgos.take_2d_axis1_int64_int64,
        arr_dtype=np.int64,
        out_dtype=np.int64,
        fill_wrap=np.int64,
    ),
}
478
# Row-and-column 2D take routines (used by take_2d_multi), keyed by
# (input dtype name, output dtype name).
_take_2d_multi_dict = {
    # integer inputs, kept as-is, widened, or promoted to float64
    ("int8", "int8"): libalgos.take_2d_multi_int8_int8,
    ("int8", "int32"): libalgos.take_2d_multi_int8_int32,
    ("int8", "int64"): libalgos.take_2d_multi_int8_int64,
    ("int8", "float64"): libalgos.take_2d_multi_int8_float64,
    ("int16", "int16"): libalgos.take_2d_multi_int16_int16,
    ("int16", "int32"): libalgos.take_2d_multi_int16_int32,
    ("int16", "int64"): libalgos.take_2d_multi_int16_int64,
    ("int16", "float64"): libalgos.take_2d_multi_int16_float64,
    ("int32", "int32"): libalgos.take_2d_multi_int32_int32,
    ("int32", "int64"): libalgos.take_2d_multi_int32_int64,
    ("int32", "float64"): libalgos.take_2d_multi_int32_float64,
    ("int64", "int64"): libalgos.take_2d_multi_int64_int64,
    ("int64", "float64"): libalgos.take_2d_multi_int64_float64,
    # floating point and object inputs
    ("float32", "float32"): libalgos.take_2d_multi_float32_float32,
    ("float32", "float64"): libalgos.take_2d_multi_float32_float64,
    ("float64", "float64"): libalgos.take_2d_multi_float64_float64,
    ("object", "object"): libalgos.take_2d_multi_object_object,
    # bool data is handled through uint8 views
    ("bool", "bool"): _view_wrapper(
        libalgos.take_2d_multi_bool_bool, arr_dtype=np.uint8, out_dtype=np.uint8
    ),
    ("bool", "object"): _view_wrapper(
        libalgos.take_2d_multi_bool_object, arr_dtype=np.uint8, out_dtype=None
    ),
    # ns-resolution datetime64/timedelta64 reuse the int64 routine on int64
    # views, with the fill value converted through np.int64
    ("datetime64[ns]", "datetime64[ns]"): _view_wrapper(
        libalgos.take_2d_multi_int64_int64,
        arr_dtype=np.int64,
        out_dtype=np.int64,
        fill_wrap=np.int64,
    ),
    ("timedelta64[ns]", "timedelta64[ns]"): _view_wrapper(
        libalgos.take_2d_multi_int64_int64,
        arr_dtype=np.int64,
        out_dtype=np.int64,
        fill_wrap=np.int64,
    ),
}
510
511
def _take_nd_object(
    arr: np.ndarray,
    indexer: npt.NDArray[np.intp],
    out: np.ndarray,
    axis: AxisInt,
    fill_value,
    mask_info,
) -> None:
    """
    Generic numpy-based take into the preallocated ``out`` along ``axis``.

    ``arr`` is cast to ``out.dtype`` if needed. Positions where ``indexer``
    is -1 are overwritten with ``fill_value``. ``mask_info`` may supply a
    precomputed ``(mask, needs_masking)`` pair; otherwise it is computed here.
    """
    if mask_info is None:
        mask = indexer == -1
        needs_masking = mask.any()
    else:
        mask, needs_masking = mask_info

    source = arr if arr.dtype == out.dtype else arr.astype(out.dtype)

    # With an empty source axis there is nothing to take; every output
    # position along that axis is expected to come from the fill step.
    if source.shape[axis] > 0:
        source.take(indexer, axis=axis, out=out)

    if needs_masking:
        selector = [slice(None)] * source.ndim
        selector[axis] = mask
        out[tuple(selector)] = fill_value
533
534
def _take_2d_multi_object(
    arr: np.ndarray,
    indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]],
    out: np.ndarray,
    fill_value,
    mask_info,
) -> None:
    """
    Generic fallback for ``take_2d_multi`` when no Cython routine matches.

    Writes ``arr[row_idx[i], col_idx[j]]`` into ``out[i, j]`` for every pair
    where neither index is -1. Rows/columns whose index is -1 are set to
    ``fill_value``, unless it is None (then those cells are left untouched).

    Parameters
    ----------
    arr : np.ndarray
        2D source array.
    indexer : tuple of (row_idx, col_idx)
        intp indexers; -1 marks positions to fill.
    out : np.ndarray
        Preallocated ``(len(row_idx), len(col_idx))`` output, filled in place.
    fill_value : scalar or None
        Value for -1 rows/columns.
    mask_info : tuple or None
        Precomputed ``((row_mask, col_mask), (row_needs, col_needs))``; when
        None the masks are computed here.
    """
    row_idx, col_idx = indexer  # both np.intp
    if mask_info is not None:
        (row_mask, col_mask), (row_needs, col_needs) = mask_info
    else:
        row_mask = row_idx == -1
        col_mask = col_idx == -1
        row_needs = row_mask.any()
        col_needs = col_mask.any()

    if fill_value is not None:
        if row_needs:
            out[row_mask, :] = fill_value
        if col_needs:
            out[:, col_mask] = fill_value

    # Copy every valid (row, column) cell with a single fancy-indexing
    # assignment rather than a Python-level loop over each output cell.
    row_keep = ~row_mask
    col_keep = ~col_mask
    out[np.ix_(row_keep, col_keep)] = arr[
        np.ix_(row_idx[row_keep], col_idx[col_keep])
    ]
562
563
def _take_preprocess_indexer_and_fill_value(
    arr: np.ndarray,
    indexer: npt.NDArray[np.intp],
    fill_value,
    allow_fill: bool,
    mask: npt.NDArray[np.bool_] | None = None,
):
    """
    Decide the output dtype and effective fill value for an ndarray take.

    Returns
    -------
    dtype : np.dtype
        Output dtype. It is promoted only when filling is actually needed.
    fill_value : scalar
        Fill value to pass to the take routine. A dummy of arr's own type is
        used whenever nothing will be filled.
    mask_info : tuple or None
        ``(None, False)`` when ``allow_fill`` is False; ``None`` when the fill
        value needs no dtype promotion; otherwise ``(mask, needs_masking)``.
    """
    if not allow_fill:
        # Nothing will be filled: keep arr's dtype and use a dummy fill value.
        return arr.dtype, arr.dtype.type(), (None, False)

    # Check for promotion based on types only first, because that is faster
    # than computing a mask.
    dtype, fill_value = maybe_promote(arr.dtype, fill_value)
    if dtype == arr.dtype:
        return dtype, fill_value, None

    # Promotion would be needed: check whether the indexer really has a -1.
    if mask is None:
        mask = indexer == -1
        needs_masking = bool(mask.any())
    else:
        # A caller-supplied mask is taken to mean filling is required.
        needs_masking = True

    mask_info: tuple[np.ndarray | None, bool] = (mask, needs_masking)
    if not needs_masking:
        # Depromote and set fill_value to a dummy. It won't be used, but the
        # cython code must not crash when trying to cast it to dtype.
        dtype, fill_value = arr.dtype, arr.dtype.type()
    return dtype, fill_value, mask_info