1"""
2Provide classes to perform the groupby aggregate operations.
3
4These are not exposed to the user and provide implementations of the grouping
5operations, primarily in cython. These classes (BaseGrouper and BinGrouper)
6are contained *in* the SeriesGroupBy and DataFrameGroupBy objects.
7"""
8from __future__ import annotations
9
10import collections
11import functools
12from typing import (
13 TYPE_CHECKING,
14 Callable,
15 Generic,
16 Hashable,
17 Iterator,
18 Sequence,
19 final,
20)
21
22import numpy as np
23
24from pandas._libs import (
25 NaT,
26 lib,
27)
28import pandas._libs.groupby as libgroupby
29import pandas._libs.reduction as libreduction
30from pandas._typing import (
31 ArrayLike,
32 AxisInt,
33 DtypeObj,
34 NDFrameT,
35 Shape,
36 npt,
37)
38from pandas.errors import AbstractMethodError
39from pandas.util._decorators import cache_readonly
40
41from pandas.core.dtypes.cast import (
42 maybe_cast_pointwise_result,
43 maybe_downcast_to_dtype,
44)
45from pandas.core.dtypes.common import (
46 ensure_float64,
47 ensure_int64,
48 ensure_platform_int,
49 ensure_uint64,
50 is_1d_only_ea_dtype,
51 is_bool_dtype,
52 is_complex_dtype,
53 is_datetime64_any_dtype,
54 is_float_dtype,
55 is_integer_dtype,
56 is_numeric_dtype,
57 is_period_dtype,
58 is_sparse,
59 is_timedelta64_dtype,
60 needs_i8_conversion,
61)
62from pandas.core.dtypes.dtypes import CategoricalDtype
63from pandas.core.dtypes.missing import (
64 isna,
65 maybe_fill,
66)
67
68from pandas.core.arrays import (
69 Categorical,
70 DatetimeArray,
71 ExtensionArray,
72 PeriodArray,
73 TimedeltaArray,
74)
75from pandas.core.arrays.masked import (
76 BaseMaskedArray,
77 BaseMaskedDtype,
78)
79from pandas.core.arrays.string_ import StringDtype
80from pandas.core.frame import DataFrame
81from pandas.core.groupby import grouper
82from pandas.core.indexes.api import (
83 CategoricalIndex,
84 Index,
85 MultiIndex,
86 ensure_index,
87)
88from pandas.core.series import Series
89from pandas.core.sorting import (
90 compress_group_index,
91 decons_obs_group_ids,
92 get_flattened_list,
93 get_group_index,
94 get_group_index_sorter,
95 get_indexer_dict,
96)
97
98if TYPE_CHECKING:
99 from pandas.core.generic import NDFrame
100
101
102class WrappedCythonOp:
103 """
104 Dispatch logic for functions defined in _libs.groupby
105
106 Parameters
107 ----------
108 kind: str
109 Whether the operation is an aggregate or transform.
110 how: str
111 Operation name, e.g. "mean".
112 has_dropped_na: bool
113 True precisely when dropna=True and the grouper contains a null value.
114 """

    # Functions for which we do _not_ attempt to cast the cython result
    # back to the original dtype.
    cast_blocklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])

    def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
        self.kind = kind
        self.how = how
        self.has_dropped_na = has_dropped_na

    _CYTHON_FUNCTIONS = {
        "aggregate": {
            "sum": "group_sum",
            "prod": "group_prod",
            "min": "group_min",
            "max": "group_max",
            "mean": "group_mean",
            "median": "group_median_float64",
            "var": "group_var",
            "first": "group_nth",
            "last": "group_last",
            "ohlc": "group_ohlc",
        },
        "transform": {
            "cumprod": "group_cumprod",
            "cumsum": "group_cumsum",
            "cummin": "group_cummin",
            "cummax": "group_cummax",
            "rank": "group_rank",
        },
    }
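
    # For example, ("aggregate", "sum") maps to "group_sum", so
    # _get_cython_function("aggregate", "sum", np.dtype("int64"), True)
    # resolves to libgroupby.group_sum.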

    _cython_arity = {"ohlc": 4}  # OHLC

    # Note: we make this a classmethod and pass kind+how so that caching
    # works at the class level and not the instance level
    @classmethod
    @functools.lru_cache(maxsize=None)
    def _get_cython_function(
        cls, kind: str, how: str, dtype: np.dtype, is_numeric: bool
    ):
        dtype_str = dtype.name
        ftype = cls._CYTHON_FUNCTIONS[kind][how]

        # see if there is a fused-type version of function
        # only valid for numeric
        f = getattr(libgroupby, ftype)
        if is_numeric:
            return f
        elif dtype == np.dtype(object):
            if how in ["median", "cumprod"]:
                # no fused types -> no __signatures__
                raise NotImplementedError(
                    f"function is not implemented for this dtype: "
                    f"[how->{how},dtype->{dtype_str}]"
                )
            if "object" not in f.__signatures__:
                # raise NotImplementedError here rather than TypeError later
                raise NotImplementedError(
                    f"function is not implemented for this dtype: "
                    f"[how->{how},dtype->{dtype_str}]"
                )
            return f
        else:
            raise NotImplementedError(
                "This should not be reached. Please report a bug at "
                "github.com/pandas-dev/pandas/",
                dtype,
            )

    def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
        """
        Cast numeric dtypes to float64 for functions that only support that.

        Parameters
        ----------
        values : np.ndarray

        Returns
        -------
        values : np.ndarray
        """
        how = self.how

        if how == "median":
            # median only has a float64 implementation
            # We should only get here with is_numeric, as non-numeric cases
            # should raise in _get_cython_function
            values = ensure_float64(values)

        elif values.dtype.kind in ["i", "u"]:
            if how in ["var", "mean"] or (
                self.kind == "transform" and self.has_dropped_na
            ):
                # has_dropped_na check needed for test_null_group_str_transformer
                # result may still include NaN, so we have to cast
                values = ensure_float64(values)

            elif how in ["sum", "ohlc", "prod", "cumsum", "cumprod"]:
                # Avoid overflow during group op
                if values.dtype.kind == "i":
                    values = ensure_int64(values)
                else:
                    values = ensure_uint64(values)

        return values
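
    # Illustrative casting rules from above:
    #   int32 values, how="mean"  -> float64 (result may be fractional or NaN)
    #   int32 values, how="sum"   -> int64   (widen to avoid overflow)
    #   uint32 values, how="prod" -> uint64  (widen, staying unsigned)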

    # TODO: general case implementation overridable by EAs.
    def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
        """
        Check if we can do this operation with our cython functions.

        Raises
        ------
        TypeError
            This is not a valid operation for this dtype.
        NotImplementedError
            This may be a valid operation, but does not have a cython implementation.
        """
        how = self.how

        if is_numeric:
            # never an invalid op for those dtypes, so return early as fastpath
            return

        if isinstance(dtype, CategoricalDtype):
            if how in ["sum", "prod", "cumsum", "cumprod"]:
                raise TypeError(f"{dtype} type does not support {how} operations")
            if how in ["min", "max", "rank"] and not dtype.ordered:
                # raise TypeError instead of NotImplementedError to ensure we
                # don't go down a group-by-group path, since in the empty-groups
                # case that would fail to raise
                raise TypeError(f"Cannot perform {how} with non-ordered Categorical")
            if how not in ["rank"]:
                # only "rank" is implemented in cython
                raise NotImplementedError(f"{dtype} dtype not supported")

        elif is_sparse(dtype):
            raise NotImplementedError(f"{dtype} dtype not supported")
        elif is_datetime64_any_dtype(dtype):
            # Adding/multiplying datetimes is not valid
            if how in ["sum", "prod", "cumsum", "cumprod"]:
                raise TypeError(f"datetime64 type does not support {how} operations")
        elif is_period_dtype(dtype):
            # Adding/multiplying Periods is not valid
            if how in ["sum", "prod", "cumsum", "cumprod"]:
                raise TypeError(f"Period type does not support {how} operations")
        elif is_timedelta64_dtype(dtype):
            # timedeltas we can add but not multiply
            if how in ["prod", "cumprod"]:
                raise TypeError(f"timedelta64 type does not support {how} operations")

    def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
        how = self.how
        kind = self.kind

        arity = self._cython_arity.get(how, 1)

        out_shape: Shape
        if how == "ohlc":
            out_shape = (ngroups, arity)
        elif arity > 1:
            raise NotImplementedError(
                "arity of more than 1 is not supported for the 'how' argument"
            )
        elif kind == "transform":
            out_shape = values.shape
        else:
            out_shape = (ngroups,) + values.shape[1:]
        return out_shape
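
    # Shape examples (illustrative; values here is the 2D array after the
    # transpose in _call_cython_op, i.e. (nrows, ncols)):
    #   how="ohlc", ngroups=3                            -> (3, 4)  open/high/low/close
    #   kind="transform", values.shape=(5, 2)            -> (5, 2)  transforms keep shape
    #   kind="aggregate", values.shape=(5, 2), ngroups=3 -> (3, 2)  one row per group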

    def _get_out_dtype(self, dtype: np.dtype) -> np.dtype:
        how = self.how

        if how == "rank":
            out_dtype = "float64"
        else:
            if is_numeric_dtype(dtype):
                out_dtype = f"{dtype.kind}{dtype.itemsize}"
            else:
                out_dtype = "object"
        return np.dtype(out_dtype)

    def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
        """
        Get the desired dtype of a result based on the
        input dtype and how it was computed.

        Parameters
        ----------
        dtype : np.dtype

        Returns
        -------
        np.dtype
            The desired dtype of the result.
        """
        how = self.how

        if how in ["sum", "cumsum", "prod", "cumprod"]:
            if dtype == np.dtype(bool):
                return np.dtype(np.int64)
        elif how in ["mean", "median", "var"]:
            if is_float_dtype(dtype) or is_complex_dtype(dtype):
                return dtype
            elif is_numeric_dtype(dtype):
                return np.dtype(np.float64)
        return dtype
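
    # For example: summing a bool column yields int64, the mean of an int64
    # column yields float64, and the mean of a float32 column stays float32.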

    @final
    def _ea_wrap_cython_operation(
        self,
        values: ExtensionArray,
        min_count: int,
        ngroups: int,
        comp_ids: np.ndarray,
        **kwargs,
    ) -> ArrayLike:
        """
        If we have an ExtensionArray, unwrap, call _cython_operation, and
        re-wrap if appropriate.
        """
        if isinstance(values, BaseMaskedArray):
            return self._masked_ea_wrap_cython_operation(
                values,
                min_count=min_count,
                ngroups=ngroups,
                comp_ids=comp_ids,
                **kwargs,
            )

        elif isinstance(values, Categorical):
            assert self.how == "rank"  # the only one implemented ATM
            assert values.ordered  # checked earlier
            mask = values.isna()
            npvalues = values._ndarray

            res_values = self._cython_op_ndim_compat(
                npvalues,
                min_count=min_count,
                ngroups=ngroups,
                comp_ids=comp_ids,
                mask=mask,
                **kwargs,
            )

            # If we ever have more than just "rank" here, we'll need to do
            # `if self.how in self.cast_blocklist` like we do for other dtypes.
            return res_values

        npvalues = self._ea_to_cython_values(values)

        res_values = self._cython_op_ndim_compat(
            npvalues,
            min_count=min_count,
            ngroups=ngroups,
            comp_ids=comp_ids,
            mask=None,
            **kwargs,
        )

        if self.how in self.cast_blocklist:
            # i.e. how in ["rank"], since other cast_blocklist methods don't go
            # through cython_operation
            return res_values

        return self._reconstruct_ea_result(values, res_values)

    # TODO: general case implementation overridable by EAs.
    def _ea_to_cython_values(self, values: ExtensionArray) -> np.ndarray:
        # GH#43682
        if isinstance(values, (DatetimeArray, PeriodArray, TimedeltaArray)):
            # All of the functions implemented here are ordinal, so we can
            # operate on the tz-naive equivalents
            npvalues = values._ndarray.view("M8[ns]")
        elif isinstance(values.dtype, StringDtype):
            # StringArray
            npvalues = values.to_numpy(object, na_value=np.nan)
        else:
            raise NotImplementedError(
                f"function is not implemented for this dtype: {values.dtype}"
            )
        return npvalues

    # TODO: general case implementation overridable by EAs.
    def _reconstruct_ea_result(
        self, values: ExtensionArray, res_values: np.ndarray
    ) -> ExtensionArray:
        """
        Construct an ExtensionArray result from an ndarray result.
        """
        dtype: BaseMaskedDtype | StringDtype

        if isinstance(values.dtype, StringDtype):
            dtype = values.dtype
            string_array_cls = dtype.construct_array_type()
            return string_array_cls._from_sequence(res_values, dtype=dtype)

        elif isinstance(values, (DatetimeArray, TimedeltaArray, PeriodArray)):
            # In to_cython_values we took a view as M8[ns]
            assert res_values.dtype == "M8[ns]"
            res_values = res_values.view(values._ndarray.dtype)
            return values._from_backing_data(res_values)

        raise NotImplementedError

    @final
    def _masked_ea_wrap_cython_operation(
        self,
        values: BaseMaskedArray,
        min_count: int,
        ngroups: int,
        comp_ids: np.ndarray,
        **kwargs,
    ) -> BaseMaskedArray:
        """
        Equivalent of `_ea_wrap_cython_operation`, but optimized for masked EAs
        and cython algorithms which accept a mask.
        """
        orig_values = values

        # libgroupby functions are responsible for NOT altering mask
        mask = values._mask
        if self.kind != "aggregate":
            result_mask = mask.copy()
        else:
            result_mask = np.zeros(ngroups, dtype=bool)

        arr = values._data

        res_values = self._cython_op_ndim_compat(
            arr,
            min_count=min_count,
            ngroups=ngroups,
            comp_ids=comp_ids,
            mask=mask,
            result_mask=result_mask,
            **kwargs,
        )

        if self.how == "ohlc":
            arity = self._cython_arity.get(self.how, 1)
            result_mask = np.tile(result_mask, (arity, 1)).T

        # res_values should already have the correct dtype, we just need to
        # wrap in a MaskedArray
        return orig_values._maybe_mask_result(res_values, result_mask)

    @final
    def _cython_op_ndim_compat(
        self,
        values: np.ndarray,
        *,
        min_count: int,
        ngroups: int,
        comp_ids: np.ndarray,
        mask: npt.NDArray[np.bool_] | None = None,
        result_mask: npt.NDArray[np.bool_] | None = None,
        **kwargs,
    ) -> np.ndarray:
        if values.ndim == 1:
            # expand to 2d, dispatch, then squeeze if appropriate
            values2d = values[None, :]
            if mask is not None:
                mask = mask[None, :]
            if result_mask is not None:
                result_mask = result_mask[None, :]
            res = self._call_cython_op(
                values2d,
                min_count=min_count,
                ngroups=ngroups,
                comp_ids=comp_ids,
                mask=mask,
                result_mask=result_mask,
                **kwargs,
            )
            if res.shape[0] == 1:
                return res[0]

            # otherwise we have OHLC
            return res.T

        return self._call_cython_op(
            values,
            min_count=min_count,
            ngroups=ngroups,
            comp_ids=comp_ids,
            mask=mask,
            result_mask=result_mask,
            **kwargs,
        )
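
    # Shape sketch for the 1D path above: values (n,) -> values2d (1, n);
    # _call_cython_op returns (1, ngroups) for most aggregations, which is
    # squeezed back to (ngroups,).  Only "ohlc" returns (4, ngroups), which
    # is transposed to (ngroups, 4) instead.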

    @final
    def _call_cython_op(
        self,
        values: np.ndarray,  # np.ndarray[ndim=2]
        *,
        min_count: int,
        ngroups: int,
        comp_ids: np.ndarray,
        mask: npt.NDArray[np.bool_] | None,
        result_mask: npt.NDArray[np.bool_] | None,
        **kwargs,
    ) -> np.ndarray:  # np.ndarray[ndim=2]
        orig_values = values

        dtype = values.dtype
        is_numeric = is_numeric_dtype(dtype)

        is_datetimelike = needs_i8_conversion(dtype)

        if is_datetimelike:
            values = values.view("int64")
            is_numeric = True
        elif is_bool_dtype(dtype):
            values = values.view("uint8")
        if values.dtype == "float16":
            values = values.astype(np.float32)

        values = values.T
        if mask is not None:
            mask = mask.T
            if result_mask is not None:
                result_mask = result_mask.T

        out_shape = self._get_output_shape(ngroups, values)
        func = self._get_cython_function(self.kind, self.how, values.dtype, is_numeric)
        values = self._get_cython_vals(values)
        out_dtype = self._get_out_dtype(values.dtype)

        result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
        if self.kind == "aggregate":
            counts = np.zeros(ngroups, dtype=np.int64)
            if self.how in ["min", "max", "mean", "last", "first", "sum"]:
                func(
                    out=result,
                    counts=counts,
                    values=values,
                    labels=comp_ids,
                    min_count=min_count,
                    mask=mask,
                    result_mask=result_mask,
                    is_datetimelike=is_datetimelike,
                )
            elif self.how in ["var", "ohlc", "prod", "median"]:
                func(
                    result,
                    counts,
                    values,
                    comp_ids,
                    min_count=min_count,
                    mask=mask,
                    result_mask=result_mask,
                    **kwargs,
                )
            else:
                raise NotImplementedError(f"{self.how} is not implemented")
        else:
            # TODO: min_count
            if self.how != "rank":
                # TODO: should rank take result_mask?
                kwargs["result_mask"] = result_mask
            func(
                out=result,
                values=values,
                labels=comp_ids,
                ngroups=ngroups,
                is_datetimelike=is_datetimelike,
                mask=mask,
                **kwargs,
            )

        if self.kind == "aggregate":
            # i.e. counts is defined. Locations where count<min_count
            # need to have the result set to np.nan, which may require casting,
            # see GH#40767
            if is_integer_dtype(result.dtype) and not is_datetimelike:
                # if the op keeps the int dtypes, we have to use 0
                cutoff = max(0 if self.how in ["sum", "prod"] else 1, min_count)
                empty_groups = counts < cutoff
                if empty_groups.any():
                    if result_mask is not None:
                        assert result_mask[empty_groups].all()
                    else:
                        # Note: this conversion could be lossy, see GH#40767
                        result = result.astype("float64")
                        result[empty_groups] = np.nan

        result = result.T

        if self.how not in self.cast_blocklist:
            # e.g. if we are int64 and need to restore to datetime64/timedelta64
            # "rank" is the only member of cast_blocklist we get here
            # Casting only needed for float16, bool, datetimelike,
            # and self.how in ["sum", "prod", "ohlc", "cumprod"]
            res_dtype = self._get_result_dtype(orig_values.dtype)
            op_result = maybe_downcast_to_dtype(result, res_dtype)
        else:
            op_result = result

        return op_result

    @final
    def cython_operation(
        self,
        *,
        values: ArrayLike,
        axis: AxisInt,
        min_count: int = -1,
        comp_ids: np.ndarray,
        ngroups: int,
        **kwargs,
    ) -> ArrayLike:
        """
        Call our cython function, with appropriate pre- and post- processing.
        """
        if values.ndim > 2:
            raise NotImplementedError("number of dimensions is currently limited to 2")
        if values.ndim == 2:
            assert axis == 1, axis
        elif not is_1d_only_ea_dtype(values.dtype):
            # Note: it is *not* the case that axis is always 0 for 1-dim values,
            # as we can have 1D ExtensionArrays that we need to treat as 2D
            assert axis == 0

        dtype = values.dtype
        is_numeric = is_numeric_dtype(dtype)

        # can we do this operation with our cython functions
        # if not raise NotImplementedError
        self._disallow_invalid_ops(dtype, is_numeric)

        if not isinstance(values, np.ndarray):
            # i.e. ExtensionArray
            return self._ea_wrap_cython_operation(
                values,
                min_count=min_count,
                ngroups=ngroups,
                comp_ids=comp_ids,
                **kwargs,
            )

        return self._cython_op_ndim_compat(
            values,
            min_count=min_count,
            ngroups=ngroups,
            comp_ids=comp_ids,
            mask=None,
            **kwargs,
        )


class BaseGrouper:
    """
    This is an internal Grouper class, which actually holds
    the generated groups.

    Parameters
    ----------
    axis : Index
    groupings : Sequence[Grouping]
        all the grouping instances to handle in this grouper;
        for example, when grouping by a list of groupers, the Grouping
        for each element of the list is passed here
    sort : bool, default True
        whether this grouper will give sorted result or not
    """

    axis: Index

    def __init__(
        self,
        axis: Index,
        groupings: Sequence[grouper.Grouping],
        sort: bool = True,
        dropna: bool = True,
    ) -> None:
        assert isinstance(axis, Index), axis

        self.axis = axis
        self._groupings: list[grouper.Grouping] = list(groupings)
        self._sort = sort
        self.dropna = dropna

    @property
    def groupings(self) -> list[grouper.Grouping]:
        return self._groupings

    @property
    def shape(self) -> Shape:
        return tuple(ping.ngroups for ping in self.groupings)

    def __iter__(self) -> Iterator[Hashable]:
        return iter(self.indices)

    @property
    def nkeys(self) -> int:
        return len(self.groupings)

    def get_iterator(
        self, data: NDFrameT, axis: AxisInt = 0
    ) -> Iterator[tuple[Hashable, NDFrameT]]:
        """
        Groupby iterator

        Returns
        -------
        Generator yielding sequence of (name, subsetted object)
        for each group
        """
        splitter = self._get_splitter(data, axis=axis)
        keys = self.group_keys_seq
        yield from zip(keys, splitter)

    @final
    def _get_splitter(self, data: NDFrame, axis: AxisInt = 0) -> DataSplitter:
        """
        Returns
        -------
        Generator yielding subsetted objects
        """
        ids, _, ngroups = self.group_info
        return _get_splitter(data, ids, ngroups, axis=axis)

    @final
    @cache_readonly
    def group_keys_seq(self):
        if len(self.groupings) == 1:
            return self.levels[0]
        else:
            ids, _, ngroups = self.group_info

            # provide "flattened" iterator for multi-group setting
            return get_flattened_list(ids, ngroups, self.levels, self.codes)

    @final
    def apply(
        self, f: Callable, data: DataFrame | Series, axis: AxisInt = 0
    ) -> tuple[list, bool]:
        mutated = False
        splitter = self._get_splitter(data, axis=axis)
        group_keys = self.group_keys_seq
        result_values = []

        # This calls DataSplitter.__iter__
        zipped = zip(group_keys, splitter)

        for key, group in zipped:
            object.__setattr__(group, "name", key)

            # group might be modified
            group_axes = group.axes
            res = f(group)
            if not mutated and not _is_indexed_like(res, group_axes, axis):
                mutated = True
            result_values.append(res)
        # getattr pattern for __name__ is needed for functools.partial objects
        if len(group_keys) == 0 and getattr(f, "__name__", None) in [
            "skew",
            "sum",
            "prod",
        ]:
            # If group_keys is empty, then no function calls have been made,
            # so we will not have raised even if this is an invalid dtype.
            # So do one dummy call here to raise appropriate TypeError.
            f(data.iloc[:0])

        return result_values, mutated

    @cache_readonly
    def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
        """dict {group name -> group indices}"""
        if len(self.groupings) == 1 and isinstance(self.result_index, CategoricalIndex):
            # This shows unused categories in indices GH#38642
            return self.groupings[0].indices
        codes_list = [ping.codes for ping in self.groupings]
        keys = [ping.group_index for ping in self.groupings]
        return get_indexer_dict(codes_list, keys)

    @final
    def result_ilocs(self) -> npt.NDArray[np.intp]:
        """
        Get the original integer locations of result_index in the input.
        """
        # Original indices are where group_index would go via sorting.
        # But when dropna is true, we need to remove null values while accounting for
        # any gaps that then occur because of them.
        group_index = get_group_index(
            self.codes, self.shape, sort=self._sort, xnull=True
        )
        group_index, _ = compress_group_index(group_index, sort=self._sort)

        if self.has_dropped_na:
            mask = np.where(group_index >= 0)
            # Count how many gaps are caused by previous null values for each position
            null_gaps = np.cumsum(group_index == -1)[mask]
            group_index = group_index[mask]

        result = get_group_index_sorter(group_index, self.ngroups)

        if self.has_dropped_na:
            # Shift by the number of prior null gaps
            result += np.take(null_gaps, result)

        return result
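
    # Worked example (illustrative): keys ["a", None, "b", "a"] with dropna=True
    # give group_index [0, -1, 1, 0]; masking the null leaves [0, 1, 0] with
    # null_gaps [0, 1, 1].  The counting sort yields [0, 2, 1] (positions in
    # the masked array), and adding the prior gaps recovers the original
    # locations [0, 3, 2].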

    @final
    @property
    def codes(self) -> list[npt.NDArray[np.signedinteger]]:
        return [ping.codes for ping in self.groupings]

    @property
    def levels(self) -> list[Index]:
        return [ping.group_index for ping in self.groupings]

    @property
    def names(self) -> list[Hashable]:
        return [ping.name for ping in self.groupings]

    @final
    def size(self) -> Series:
        """
        Compute group sizes.
        """
        ids, _, ngroups = self.group_info
        out: np.ndarray | list
        if ngroups:
            out = np.bincount(ids[ids != -1], minlength=ngroups)
        else:
            out = []
        return Series(out, index=self.result_index, dtype="int64")

    @cache_readonly
    def groups(self) -> dict[Hashable, np.ndarray]:
        """dict {group name -> group labels}"""
        if len(self.groupings) == 1:
            return self.groupings[0].groups
        else:
            to_groupby = zip(*(ping.grouping_vector for ping in self.groupings))
            index = Index(to_groupby)
            return self.axis.groupby(index)

    @final
    @cache_readonly
    def is_monotonic(self) -> bool:
        # return if my group orderings are monotonic
        return Index(self.group_info[0]).is_monotonic_increasing

    @final
    @cache_readonly
    def has_dropped_na(self) -> bool:
        """
        Whether grouper has null value(s) that are dropped.
        """
        return bool((self.group_info[0] < 0).any())

    @cache_readonly
    def group_info(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
        comp_ids, obs_group_ids = self._get_compressed_codes()

        ngroups = len(obs_group_ids)
        comp_ids = ensure_platform_int(comp_ids)

        return comp_ids, obs_group_ids, ngroups
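
    # For example (illustrative): a single grouping with codes [0, 0, 1, -1]
    # (the -1 marking a dropped null) and two group levels gives
    # group_info == (array([0, 0, 1, -1]), array([0, 1]), 2).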

    @cache_readonly
    def codes_info(self) -> npt.NDArray[np.intp]:
        # return the codes of items in original grouped axis
        ids, _, _ = self.group_info
        return ids

    @final
    def _get_compressed_codes(
        self,
    ) -> tuple[npt.NDArray[np.signedinteger], npt.NDArray[np.intp]]:
        # The first returned ndarray may have any signed integer dtype
        if len(self.groupings) > 1:
            group_index = get_group_index(self.codes, self.shape, sort=True, xnull=True)
            return compress_group_index(group_index, sort=self._sort)
            # FIXME: compress_group_index's second return value is int64, not intp

        ping = self.groupings[0]
        return ping.codes, np.arange(len(ping.group_index), dtype=np.intp)

    @final
    @cache_readonly
    def ngroups(self) -> int:
        return len(self.result_index)

    @property
    def reconstructed_codes(self) -> list[npt.NDArray[np.intp]]:
        codes = self.codes
        ids, obs_ids, _ = self.group_info
        return decons_obs_group_ids(ids, obs_ids, self.shape, codes, xnull=True)

    @cache_readonly
    def result_index(self) -> Index:
        if len(self.groupings) == 1:
            return self.groupings[0].result_index.rename(self.names[0])

        codes = self.reconstructed_codes
        levels = [ping.result_index for ping in self.groupings]
        return MultiIndex(
            levels=levels, codes=codes, verify_integrity=False, names=self.names
        )

    @final
    def get_group_levels(self) -> list[ArrayLike]:
        # Note: only called from _insert_inaxis_grouper, which
        # is only called for BaseGrouper, never for BinGrouper
        if len(self.groupings) == 1:
            return [self.groupings[0].group_arraylike]

        name_list = []
        for ping, codes in zip(self.groupings, self.reconstructed_codes):
            codes = ensure_platform_int(codes)
            levels = ping.group_arraylike.take(codes)

            name_list.append(levels)

        return name_list

    # ------------------------------------------------------------
    # Aggregation functions

    @final
    def _cython_operation(
        self,
        kind: str,
        values,
        how: str,
        axis: AxisInt,
        min_count: int = -1,
        **kwargs,
    ) -> ArrayLike:
        """
        Returns the values of a cython operation.
        """
        assert kind in ["transform", "aggregate"]

        cy_op = WrappedCythonOp(kind=kind, how=how, has_dropped_na=self.has_dropped_na)

        ids, _, _ = self.group_info
        ngroups = self.ngroups
        return cy_op.cython_operation(
            values=values,
            axis=axis,
            min_count=min_count,
            comp_ids=ids,
            ngroups=ngroups,
            **kwargs,
        )

    @final
    def agg_series(
        self, obj: Series, func: Callable, preserve_dtype: bool = False
    ) -> ArrayLike:
        """
        Parameters
        ----------
        obj : Series
        func : function taking a Series and returning a scalar-like
        preserve_dtype : bool
            Whether the aggregation is known to be dtype-preserving.

        Returns
        -------
        np.ndarray or ExtensionArray
        """
        # test_groupby_empty_with_category gets here with self.ngroups == 0
        # and len(obj) > 0

        if len(obj) > 0 and not isinstance(obj._values, np.ndarray):
            # we can preserve a little bit more aggressively with EA dtype
            # because maybe_cast_pointwise_result will do a try/except
            # with _from_sequence.  NB we are assuming here that _from_sequence
            # is sufficiently strict that it casts appropriately.
            preserve_dtype = True

        result = self._aggregate_series_pure_python(obj, func)

        npvalues = lib.maybe_convert_objects(result, try_float=False)
        if preserve_dtype:
            out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
        else:
            out = npvalues
        return out

    @final
    def _aggregate_series_pure_python(
        self, obj: Series, func: Callable
    ) -> npt.NDArray[np.object_]:
        _, _, ngroups = self.group_info

        result = np.empty(ngroups, dtype="O")
        initialized = False

        splitter = self._get_splitter(obj, axis=0)

        for i, group in enumerate(splitter):
            res = func(group)
            res = libreduction.extract_result(res)

            if not initialized:
                # We only do this validation on the first iteration
                libreduction.check_result_array(res, group.dtype)
                initialized = True

            result[i] = res

        return result


class BinGrouper(BaseGrouper):
    """
    This is an internal Grouper class.

    Parameters
    ----------
    bins : np.ndarray[np.int64]
        the right edges used to split the grouped axis into bins
    binlabels : Index
        the label for each bin
    indexer : np.ndarray[np.intp], optional
        the indexer created by Grouper
        some groupers (TimeGrouper) will sort their axis, and their
        group_info is also sorted, so the indexer is needed to reorder

    Examples
    --------
    bins: [2, 4, 6, 8, 10]
    binlabels: DatetimeIndex(['2005-01-01', '2005-01-03',
        '2005-01-05', '2005-01-07', '2005-01-09'],
        dtype='datetime64[ns]', freq='2D')

    the group_info, which contains the label of each item in the grouped
    axis, the index of each label in the label list, and the number of
    groups, is

    (array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]), array([0, 1, 2, 3, 4]), 5)

    meaning that the grouped axis has 10 items which fall into 5 labels:
    the first and second items belong to the first label, the third and
    fourth items belong to the second label, and so on.
    """

    bins: npt.NDArray[np.int64]
    binlabels: Index

    def __init__(
        self,
        bins,
        binlabels,
        indexer=None,
    ) -> None:
        self.bins = ensure_int64(bins)
        self.binlabels = ensure_index(binlabels)
        self.indexer = indexer

        # These lengths must match, otherwise we could call agg_series
        # with empty self.bins, which would raise in libreduction.
        assert len(self.binlabels) == len(self.bins)

    @cache_readonly
    def groups(self):
        """dict {group name -> group labels}"""
        # this is mainly for compat
        # GH 3881
        result = {
            key: value
            for key, value in zip(self.binlabels, self.bins)
            if key is not NaT
        }
        return result

    @property
    def nkeys(self) -> int:
        # still matches len(self.groupings), but we can hard-code
        return 1

    @cache_readonly
    def codes_info(self) -> npt.NDArray[np.intp]:
        # return the codes of items in original grouped axis
        ids, _, _ = self.group_info
        if self.indexer is not None:
            sorter = np.lexsort((ids, self.indexer))
            ids = ids[sorter]
        return ids

    def get_iterator(self, data: NDFrame, axis: AxisInt = 0):
        """
        Groupby iterator

        Returns
        -------
        Generator yielding sequence of (name, subsetted object)
        for each group
        """
        if axis == 0:
            slicer = lambda start, edge: data.iloc[start:edge]
        else:
            slicer = lambda start, edge: data.iloc[:, start:edge]

        length = len(data.axes[axis])

        start = 0
        for edge, label in zip(self.bins, self.binlabels):
            if label is not NaT:
                yield label, slicer(start, edge)
            start = edge

        if start < length:
            yield self.binlabels[-1], slicer(start, None)

    @cache_readonly
    def indices(self):
        indices = collections.defaultdict(list)

        i = 0
        for label, bin in zip(self.binlabels, self.bins):
            if i < bin:
                if label is not NaT:
                    indices[label] = list(range(i, bin))
            i = bin
        return indices

    @cache_readonly
    def group_info(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
        ngroups = self.ngroups
        obs_group_ids = np.arange(ngroups, dtype=np.intp)
        rep = np.diff(np.r_[0, self.bins])

        rep = ensure_platform_int(rep)
        if ngroups == len(self.bins):
            comp_ids = np.repeat(np.arange(ngroups), rep)
        else:
            comp_ids = np.repeat(np.r_[-1, np.arange(ngroups)], rep)

        return (
            ensure_platform_int(comp_ids),
            obs_group_ids,
            ngroups,
        )
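
    # Worked example: bins [2, 5] over 5 rows -> rep = diff([0, 2, 5]) = [2, 3],
    # so comp_ids = [0, 0, 1, 1, 1], obs_group_ids = [0, 1], ngroups = 2.
    # (If the first binlabel is NaT, ngroups < len(bins) and the first bin's
    # rows get comp_id -1 via the second branch.)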

    @cache_readonly
    def reconstructed_codes(self) -> list[np.ndarray]:
        # get unique result indices, and prepend 0 as groupby starts from the first
        return [np.r_[0, np.flatnonzero(self.bins[1:] != self.bins[:-1]) + 1]]
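
    # For example: bins [2, 4, 4, 6] (the repeated edge marks an empty bin)
    # give flatnonzero(bins[1:] != bins[:-1]) + 1 = [1, 3], so the codes are
    # [0, 1, 3], skipping the empty bin at position 2.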

    @cache_readonly
    def result_index(self) -> Index:
        if len(self.binlabels) != 0 and isna(self.binlabels[0]):
            return self.binlabels[1:]

        return self.binlabels

    @property
    def levels(self) -> list[Index]:
        return [self.binlabels]

    @property
    def names(self) -> list[Hashable]:
        return [self.binlabels.name]

    @property
    def groupings(self) -> list[grouper.Grouping]:
        lev = self.binlabels
        codes = self.group_info[0]
        labels = lev.take(codes)
        ping = grouper.Grouping(
            labels, labels, in_axis=False, level=None, uniques=lev._values
        )
        return [ping]


def _is_indexed_like(obj, axes, axis: AxisInt) -> bool:
    if isinstance(obj, Series):
        if len(axes) > 1:
            return False
        return obj.axes[axis].equals(axes[axis])
    elif isinstance(obj, DataFrame):
        return obj.axes[axis].equals(axes[axis])

    return False


# ----------------------------------------------------------------------
# Splitting / application


class DataSplitter(Generic[NDFrameT]):
    def __init__(
        self,
        data: NDFrameT,
        labels: npt.NDArray[np.intp],
        ngroups: int,
        axis: AxisInt = 0,
    ) -> None:
        self.data = data
        self.labels = ensure_platform_int(labels)  # _should_ already be np.intp
        self.ngroups = ngroups

        self.axis = axis
        assert isinstance(axis, int), axis

    @cache_readonly
    def _slabels(self) -> npt.NDArray[np.intp]:
        # Sorted labels
        return self.labels.take(self._sort_idx)

    @cache_readonly
    def _sort_idx(self) -> npt.NDArray[np.intp]:
        # Counting sort indexer
        return get_group_index_sorter(self.labels, self.ngroups)

    def __iter__(self) -> Iterator:
        sdata = self._sorted_data

        if self.ngroups == 0:
            # we are inside a generator; rather than raise StopIteration
            # we merely return to signal the end
            return

        starts, ends = lib.generate_slices(self._slabels, self.ngroups)

        for start, end in zip(starts, ends):
            yield self._chop(sdata, slice(start, end))
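
    # For example (illustrative): labels [1, 0, 1] with ngroups=2 sort to
    # _slabels [0, 1, 1] via _sort_idx [1, 0, 2]; generate_slices then yields
    # slice(0, 1) and slice(1, 3) over the sorted data, one chunk per group.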

    @cache_readonly
    def _sorted_data(self) -> NDFrameT:
        return self.data.take(self._sort_idx, axis=self.axis)

    def _chop(self, sdata, slice_obj: slice) -> NDFrame:
        raise AbstractMethodError(self)


class SeriesSplitter(DataSplitter):
    def _chop(self, sdata: Series, slice_obj: slice) -> Series:
        # fastpath equivalent to `sdata.iloc[slice_obj]`
        mgr = sdata._mgr.get_slice(slice_obj)
        ser = sdata._constructor(mgr, name=sdata.name, fastpath=True)
        return ser.__finalize__(sdata, method="groupby")


class FrameSplitter(DataSplitter):
    def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
        # Fastpath equivalent to:
        # if self.axis == 0:
        #     return sdata.iloc[slice_obj]
        # else:
        #     return sdata.iloc[:, slice_obj]
        mgr = sdata._mgr.get_slice(slice_obj, axis=1 - self.axis)
        df = sdata._constructor(mgr)
        return df.__finalize__(sdata, method="groupby")


def _get_splitter(
    data: NDFrame, labels: np.ndarray, ngroups: int, axis: AxisInt = 0
) -> DataSplitter:
    if isinstance(data, Series):
        klass: type[DataSplitter] = SeriesSplitter
    else:
        # i.e. DataFrame
        klass = FrameSplitter

    return klass(data, labels, ngroups, axis)