1from __future__ import annotations
2
3from typing import (
4 TYPE_CHECKING,
5 Iterable,
6 Literal,
7 cast,
8)
9
10import numpy as np
11
12from pandas._typing import PositionalIndexer
13from pandas.util._decorators import (
14 cache_readonly,
15 doc,
16)
17
18from pandas.core.dtypes.common import (
19 is_integer,
20 is_list_like,
21)
22
23if TYPE_CHECKING:
24 from pandas import (
25 DataFrame,
26 Series,
27 )
28 from pandas.core.groupby import groupby
29
30
class GroupByIndexingMixin:
    """
    Mixin for adding ._positional_selector to GroupBy.

    The mask-building helpers below read ``self._ascending_count`` and
    ``self._descending_count``, which are obtained from the host class's
    ``_cumcount_array`` method.
    """

    @cache_readonly
    def _positional_selector(self) -> GroupByPositionalSelector:
        """
        Return positional selection for each group.

        ``groupby._positional_selector[i:j]`` is similar to
        ``groupby.apply(lambda x: x.iloc[i:j])``
        but much faster and preserves the original index and order.

        ``_positional_selector[]`` is compatible with and extends :meth:`~GroupBy.head`
        and :meth:`~GroupBy.tail`. For example:

        - ``head(5)``
        - ``_positional_selector[5:-5]``
        - ``tail(5)``

        together return all the rows.

        Allowed inputs for the index are:

        - An integer valued iterable, e.g. ``range(2, 4)``.
        - A comma separated list of integers and slices, e.g. ``5``, ``2, 4``, ``2:4``.

        The output format is the same as :meth:`~GroupBy.head` and
        :meth:`~GroupBy.tail`, namely
        a subset of the ``DataFrame`` or ``Series`` with the index and order preserved.

        Returns
        -------
        Series
            The filtered subset of the original Series.
        DataFrame
            The filtered subset of the original DataFrame.

        See Also
        --------
        DataFrame.iloc : Purely integer-location based indexing for selection by
            position.
        GroupBy.head : Return first n rows of each group.
        GroupBy.tail : Return last n rows of each group.
        GroupBy.nth : Take the nth row from each group if n is an int, or a
            subset of rows, if n is a list of ints.

        Notes
        -----
        - The slice step cannot be negative.
        - If the index specification results in overlaps, the item is not duplicated.
        - If the index specification changes the order of items, then
          they are returned in their original order.
          By contrast, ``DataFrame.iloc`` can change the row order.
        - ``groupby()`` parameters such as as_index and dropna are ignored.

        The differences between ``_positional_selector[]`` and :meth:`~GroupBy.nth`
        with ``as_index=False`` are:

        - Input to ``_positional_selector`` can include
          one or more slices whereas ``nth``
          just handles an integer or a list of integers.
        - ``_positional_selector`` can accept a slice relative to the
          last row of each group.
        - ``_positional_selector`` does not have an equivalent to the
          ``nth()`` ``dropna`` parameter.

        Examples
        --------
        >>> df = pd.DataFrame([["a", 1], ["a", 2], ["a", 3], ["b", 4], ["b", 5]],
        ...                   columns=["A", "B"])
        >>> df.groupby("A")._positional_selector[1:2]
           A  B
        1  a  2
        4  b  5

        >>> df.groupby("A")._positional_selector[1, -1]
           A  B
        1  a  2
        2  a  3
        4  b  5
        """
        if TYPE_CHECKING:
            # pylint: disable-next=used-before-assignment
            groupby_self = cast(groupby.GroupBy, self)
        else:
            groupby_self = self

        return GroupByPositionalSelector(groupby_self)

    def _make_mask_from_positional_indexer(
        self,
        arg: PositionalIndexer | tuple,
    ) -> np.ndarray:
        """
        Build a boolean row mask from a positional index specification.

        Parameters
        ----------
        arg : PositionalIndexer | tuple
            An integer, a slice, a list-like of integers, or a list-like
            mixing integers and slices. One-shot iterables such as
            generators are accepted.

        Returns
        -------
        np.ndarray
            Boolean mask; a scalar ``True``/``False`` produced by the
            helpers is expanded to a full-length array.

        Raises
        ------
        TypeError
            If ``arg`` is not an integer, a slice or list-like.
        ValueError
            If a list-like contains something other than integers and
            slices, or a slice has a negative step.
        """
        if is_list_like(arg):
            # Materialize exactly once: the integer check and the mask builders
            # each iterate the input, so a one-shot iterator (e.g. a generator)
            # would otherwise be exhausted by the check and select nothing.
            items = list(cast(Iterable, arg))
            if all(is_integer(i) for i in items):
                mask = self._make_mask_from_list(cast(Iterable[int], items))
            else:
                mask = self._make_mask_from_tuple(tuple(items))

        elif isinstance(arg, slice):
            mask = self._make_mask_from_slice(arg)
        elif is_integer(arg):
            mask = self._make_mask_from_int(cast(int, arg))
        else:
            raise TypeError(
                f"Invalid index {type(arg)}. "
                "Must be integer, list-like, slice or a tuple of "
                "integers and slices"
            )

        if isinstance(mask, bool):
            # The helpers return a plain bool when no array condition was
            # applied; turn it into an array shaped like the count arrays
            # (relies on the counts being non-negative).
            if mask:
                mask = self._ascending_count >= 0
            else:
                mask = self._ascending_count < 0

        return cast(np.ndarray, mask)

    def _make_mask_from_int(self, arg: int) -> np.ndarray:
        """
        Return the mask selecting position ``arg`` within each group.

        A non-negative ``arg`` is compared against ``_ascending_count``;
        a negative ``arg`` is compared against ``_descending_count`` so
        that ``-1`` matches a descending count of 0.
        """
        if arg >= 0:
            return self._ascending_count == arg
        else:
            return self._descending_count == (-arg - 1)

    def _make_mask_from_list(self, args: Iterable[int]) -> bool | np.ndarray:
        """
        Return the mask selecting any of the integer positions in ``args``.

        Returns the scalar ``False`` when ``args`` is empty.
        """
        # ``args`` is scanned twice below, so snapshot it first to remain
        # correct for one-shot iterators.
        args = list(args)
        positive = [arg for arg in args if arg >= 0]
        negative = [-arg - 1 for arg in args if arg < 0]

        mask: bool | np.ndarray = False

        if positive:
            mask |= np.isin(self._ascending_count, positive)

        if negative:
            mask |= np.isin(self._descending_count, negative)

        return mask

    def _make_mask_from_tuple(self, args: tuple) -> bool | np.ndarray:
        """
        Return the union of the masks for each integer or slice in ``args``.

        Raises
        ------
        ValueError
            If an element is neither an integer nor a slice.
        """
        mask: bool | np.ndarray = False

        for arg in args:
            if is_integer(arg):
                mask |= self._make_mask_from_int(cast(int, arg))
            elif isinstance(arg, slice):
                mask |= self._make_mask_from_slice(arg)
            else:
                raise ValueError(
                    f"Invalid argument {type(arg)}. Should be int or slice."
                )

        return mask

    def _make_mask_from_slice(self, arg: slice) -> bool | np.ndarray:
        """
        Return the mask selecting the per-group positions covered by a slice.

        Returns the scalar ``True`` when the slice imposes no condition
        (no start, no stop, and a step of ``None`` or 1).

        Raises
        ------
        ValueError
            If the slice step is negative.
        """
        start = arg.start
        stop = arg.stop
        step = arg.step

        if step is not None and step < 0:
            raise ValueError(f"Invalid step {step}. Must be non-negative")

        mask: bool | np.ndarray = True

        if step is None:
            step = 1

        if start is None:
            if step > 1:
                mask &= self._ascending_count % step == 0

        elif start >= 0:
            mask &= self._ascending_count >= start

            if step > 1:
                mask &= (self._ascending_count - start) % step == 0

        else:
            mask &= self._descending_count < -start

            # Distance of each row from the (negative) start position.
            offset_array = self._descending_count + start + 1
            # Where ascending + descending + (start + 1) is negative, the
            # ascending count is used as the stepping offset instead.
            limit_array = (
                self._ascending_count + self._descending_count + (start + 1)
            ) < 0
            offset_array = np.where(limit_array, self._ascending_count, offset_array)

            mask &= offset_array % step == 0

        if stop is not None:
            if stop >= 0:
                mask &= self._ascending_count < stop
            else:
                mask &= self._descending_count >= -stop

        return mask

    @cache_readonly
    def _ascending_count(self) -> np.ndarray:
        """
        Cached result of ``self._cumcount_array()``.

        The mask builders treat it as each row's zero-based position
        within its group, counted from the start.
        """
        if TYPE_CHECKING:
            groupby_self = cast(groupby.GroupBy, self)
        else:
            groupby_self = self

        return groupby_self._cumcount_array()

    @cache_readonly
    def _descending_count(self) -> np.ndarray:
        """
        Cached result of ``self._cumcount_array(ascending=False)``.

        The mask builders treat it as each row's zero-based position
        within its group, counted from the end.
        """
        if TYPE_CHECKING:
            groupby_self = cast(groupby.GroupBy, self)
        else:
            groupby_self = self

        return groupby_self._cumcount_array(ascending=False)
245
246
@doc(GroupByIndexingMixin._positional_selector)
class GroupByPositionalSelector:
    # Indexable helper returned by ``GroupBy._positional_selector``; it only
    # holds the GroupBy object that the selection is delegated to.
    def __init__(self, groupby_object: groupby.GroupBy) -> None:
        self.groupby_object = groupby_object

    def __getitem__(self, arg: PositionalIndexer | tuple) -> DataFrame | Series:
        """
        Return the rows at the requested positions within each group.

        This is the implementation behind ``GroupBy._positional_selector``.

        Parameters
        ----------
        arg : PositionalIndexer | tuple
            One of:
            - an int
            - an int valued iterable such as a list or range
            - a slice whose step is None or positive
            - a tuple of integers and slices

        Returns
        -------
        Series
            The filtered subset of the original groupby Series.
        DataFrame
            The filtered subset of the original groupby DataFrame.

        See Also
        --------
        DataFrame.iloc : Integer-location based indexing for selection by position.
        GroupBy.head : Return first n rows of each group.
        GroupBy.tail : Return last n rows of each group.
        GroupBy._positional_selector : Return positional selection for each group.
        GroupBy.nth : Take the nth row from each group if n is an int, or a
            subset of rows, if n is a list of ints.
        """
        grouped = self.groupby_object
        # Translate the index specification into a boolean mask, then let the
        # GroupBy object apply that mask to its selected object.
        return grouped._mask_selected_obj(
            grouped._make_mask_from_positional_indexer(arg)
        )
285
286
class GroupByNthSelector:
    """
    Stand-in for GroupBy.nth that can be both called and indexed.

    Both forms forward to the ``_nth`` method of the wrapped GroupBy object.
    """

    def __init__(self, groupby_object: groupby.GroupBy) -> None:
        self.groupby_object = groupby_object

    def __call__(
        self,
        n: PositionalIndexer | tuple,
        dropna: Literal["any", "all", None] = None,
    ) -> DataFrame | Series:
        # Call form: ``nth(n, dropna)``.
        nth_impl = self.groupby_object._nth
        return nth_impl(n, dropna)

    def __getitem__(self, n: PositionalIndexer | tuple) -> DataFrame | Series:
        # Index form: ``nth[n]``; no ``dropna`` is passed on.
        nth_impl = self.groupby_object._nth
        return nth_impl(n)