1from __future__ import annotations
2
3from collections.abc import Iterable
4from typing import (
5 TYPE_CHECKING,
6 Literal,
7 cast,
8)
9
10import numpy as np
11
12from pandas.util._decorators import (
13 cache_readonly,
14 doc,
15)
16
17from pandas.core.dtypes.common import (
18 is_integer,
19 is_list_like,
20)
21
22if TYPE_CHECKING:
23 from pandas._typing import PositionalIndexer
24
25 from pandas import (
26 DataFrame,
27 Series,
28 )
29 from pandas.core.groupby import groupby
30
31
class GroupByIndexingMixin:
    """
    Mixin for adding ._positional_selector to GroupBy.

    The private helpers translate a positional index specification (integer,
    integer iterable, slice, or tuple of integers and slices) into a boolean
    row mask built from ``_ascending_count`` and ``_descending_count``.
    """

    @cache_readonly
    def _positional_selector(self) -> GroupByPositionalSelector:
        """
        Return positional selection for each group.

        ``groupby._positional_selector[i:j]`` is similar to
        ``groupby.apply(lambda x: x.iloc[i:j])``
        but much faster and preserves the original index and order.

        ``_positional_selector[]`` is compatible with and extends :meth:`~GroupBy.head`
        and :meth:`~GroupBy.tail`. For example:

        - ``head(5)``
        - ``_positional_selector[5:-5]``
        - ``tail(5)``

        together return all the rows.

        Allowed inputs for the index are:

        - An integer valued iterable, e.g. ``range(2, 4)``.
        - A comma separated list of integers and slices, e.g. ``5``, ``2, 4``, ``2:4``.

        The output format is the same as :meth:`~GroupBy.head` and
        :meth:`~GroupBy.tail`, namely
        a subset of the ``DataFrame`` or ``Series`` with the index and order preserved.

        Returns
        -------
        Series
            The filtered subset of the original Series.
        DataFrame
            The filtered subset of the original DataFrame.

        See Also
        --------
        DataFrame.iloc : Purely integer-location based indexing for selection by
            position.
        GroupBy.head : Return first n rows of each group.
        GroupBy.tail : Return last n rows of each group.
        GroupBy.nth : Take the nth row from each group if n is an int, or a
            subset of rows, if n is a list of ints.

        Notes
        -----
        - The slice step cannot be negative.
        - If the index specification results in overlaps, the item is not duplicated.
        - If the index specification changes the order of items, then
          they are returned in their original order.
          By contrast, ``DataFrame.iloc`` can change the row order.
        - ``groupby()`` parameters such as as_index and dropna are ignored.

        The differences between ``_positional_selector[]`` and :meth:`~GroupBy.nth`
        with ``as_index=False`` are:

        - Input to ``_positional_selector`` can include
          one or more slices whereas ``nth``
          just handles an integer or a list of integers.
        - ``_positional_selector`` can accept a slice relative to the
          last row of each group.
        - ``_positional_selector`` does not have an equivalent to the
          ``nth()`` ``dropna`` parameter.

        Examples
        --------
        >>> df = pd.DataFrame([["a", 1], ["a", 2], ["a", 3], ["b", 4], ["b", 5]],
        ...                   columns=["A", "B"])
        >>> df.groupby("A")._positional_selector[1:2]
           A  B
        1  a  2
        4  b  5

        >>> df.groupby("A")._positional_selector[1, -1]
           A  B
        1  a  2
        2  a  3
        4  b  5
        """
        if TYPE_CHECKING:
            # pylint: disable-next=used-before-assignment
            groupby_self = cast(groupby.GroupBy, self)
        else:
            groupby_self = self

        return GroupByPositionalSelector(groupby_self)

    def _make_mask_from_positional_indexer(
        self,
        arg: PositionalIndexer | tuple,
    ) -> np.ndarray:
        """
        Build a boolean row mask from a positional index specification.

        Parameters
        ----------
        arg : PositionalIndexer | tuple
            An integer, a slice, a list-like of integers, or a list-like
            mixing integers and slices.

        Returns
        -------
        np.ndarray
            Boolean mask with the same shape as ``_ascending_count``.

        Raises
        ------
        TypeError
            If ``arg`` is not an integer, a slice or list-like.
        ValueError
            If a list-like element is neither an integer nor a slice, or a
            slice has a negative step.
        """
        if is_list_like(arg):
            # Materialise once: a one-shot iterable (e.g. a generator) would
            # otherwise be exhausted by the all() check below and the mask
            # would silently come out empty.
            items = list(cast(Iterable, arg))
            if all(is_integer(i) for i in items):
                mask = self._make_mask_from_list(cast(Iterable[int], items))
            else:
                mask = self._make_mask_from_tuple(tuple(items))

        elif isinstance(arg, slice):
            mask = self._make_mask_from_slice(arg)
        elif is_integer(arg):
            mask = self._make_mask_from_int(cast(int, arg))
        else:
            raise TypeError(
                f"Invalid index {type(arg)}. "
                "Must be integer, list-like, slice or a tuple of "
                "integers and slices"
            )

        if isinstance(mask, bool):
            # The helpers return a plain bool when no array condition was
            # applied (empty list, unrestricted slice). Comparing the counts
            # with 0 turns that scalar into an array of the right shape;
            # presumably the counts are never negative, so the result is
            # all-True or all-False respectively.
            if mask:
                mask = self._ascending_count >= 0
            else:
                mask = self._ascending_count < 0

        return cast(np.ndarray, mask)

    def _make_mask_from_int(self, arg: int) -> np.ndarray:
        """
        Return the mask selecting a single position in every group.

        Non-negative ``arg`` is matched against ``_ascending_count``; negative
        ``arg`` is matched against ``_descending_count`` so that ``-1`` maps
        to a descending count of 0.
        """
        if arg >= 0:
            return self._ascending_count == arg
        else:
            return self._descending_count == (-arg - 1)

    def _make_mask_from_list(self, args: Iterable[int]) -> bool | np.ndarray:
        """
        Return the mask selecting any of several integer positions.

        Parameters
        ----------
        args : Iterable[int]
            Integer positions; negative values count from the end of each group.

        Returns
        -------
        bool or np.ndarray
            ``False`` if ``args`` is empty, otherwise a boolean array.
        """
        # ``args`` is traversed twice below, so take a snapshot first; this
        # keeps one-shot iterators from losing their negative entries.
        indices = list(args)
        positive = [arg for arg in indices if arg >= 0]
        negative = [-arg - 1 for arg in indices if arg < 0]

        mask: bool | np.ndarray = False

        if positive:
            mask |= np.isin(self._ascending_count, positive)

        if negative:
            mask |= np.isin(self._descending_count, negative)

        return mask

    def _make_mask_from_tuple(self, args: tuple) -> bool | np.ndarray:
        """
        Return the union of the masks of each integer or slice in ``args``.

        Raises
        ------
        ValueError
            If an element is neither an integer nor a slice.
        """
        mask: bool | np.ndarray = False

        for arg in args:
            if is_integer(arg):
                mask |= self._make_mask_from_int(cast(int, arg))
            elif isinstance(arg, slice):
                mask |= self._make_mask_from_slice(arg)
            else:
                raise ValueError(
                    f"Invalid argument {type(arg)}. Should be int or slice."
                )

        return mask

    def _make_mask_from_slice(self, arg: slice) -> bool | np.ndarray:
        """
        Return the mask selecting the rows of every group covered by a slice.

        Parameters
        ----------
        arg : slice
            Slice whose step is ``None`` or non-negative. A step of 0 is
            treated the same as a step of 1.

        Returns
        -------
        bool or np.ndarray
            ``True`` if the slice imposes no restriction, otherwise a boolean
            array.

        Raises
        ------
        ValueError
            If the slice step is negative.
        """
        start = arg.start
        stop = arg.stop
        step = arg.step

        if step is not None and step < 0:
            raise ValueError(f"Invalid step {step}. Must be non-negative")

        mask: bool | np.ndarray = True

        if step is None:
            step = 1

        if start is None:
            if step > 1:
                mask &= self._ascending_count % step == 0

        elif start >= 0:
            mask &= self._ascending_count >= start

            if step > 1:
                mask &= (self._ascending_count - start) % step == 0

        else:
            mask &= self._descending_count < -start

            # Only steps above 1 thin out the selection. Guarding here, as the
            # other branches do, also avoids a modulo by zero when step == 0.
            if step > 1:
                # Offset of each row relative to the slice start, measured
                # from the end of the group (0 at the start row).
                offset_array = self._descending_count + start + 1
                # Where the two counts plus (start + 1) sum below zero the
                # group is shorter than abs(start), so the slice effectively
                # begins at the first row; step from the ascending count there.
                limit_array = (
                    self._ascending_count + self._descending_count + (start + 1)
                ) < 0
                offset_array = np.where(
                    limit_array, self._ascending_count, offset_array
                )

                mask &= offset_array % step == 0

        if stop is not None:
            if stop >= 0:
                mask &= self._ascending_count < stop
            else:
                mask &= self._descending_count >= -stop

        return mask

    @cache_readonly
    def _ascending_count(self) -> np.ndarray:
        """
        Cached result of ``_cumcount_array()`` on the host GroupBy object.
        """
        if TYPE_CHECKING:
            groupby_self = cast(groupby.GroupBy, self)
        else:
            groupby_self = self

        return groupby_self._cumcount_array()

    @cache_readonly
    def _descending_count(self) -> np.ndarray:
        """
        Cached result of ``_cumcount_array(ascending=False)`` on the host object.
        """
        if TYPE_CHECKING:
            groupby_self = cast(groupby.GroupBy, self)
        else:
            groupby_self = self

        return groupby_self._cumcount_array(ascending=False)
246
247
@doc(GroupByIndexingMixin._positional_selector)
class GroupByPositionalSelector:
    # No class docstring on purpose: ``@doc`` builds ``__doc__`` from the
    # ``_positional_selector`` docstring referenced in the decorator.

    def __init__(self, groupby_object: groupby.GroupBy) -> None:
        # Object that supplies the mask-building and mask-applying helpers.
        self.groupby_object = groupby_object

    def __getitem__(self, arg: PositionalIndexer | tuple) -> DataFrame | Series:
        """
        Return the rows at the requested positions within every group.

        This is the implementation behind ``GroupBy._positional_selector``.

        Parameters
        ----------
        arg : PositionalIndexer | tuple
            One of:

            - an int
            - an iterable of ints, such as a list or range
            - a slice whose step is None or positive
            - a tuple mixing ints and slices

        Returns
        -------
        Series
            The filtered subset of the original groupby Series.
        DataFrame
            The filtered subset of the original groupby DataFrame.

        See Also
        --------
        DataFrame.iloc : Integer-location based indexing for selection by position.
        GroupBy.head : Return first n rows of each group.
        GroupBy.tail : Return last n rows of each group.
        GroupBy._positional_selector : Return positional selection for each group.
        GroupBy.nth : Take the nth row from each group if n is an int, or a
            subset of rows, if n is a list of ints.
        """
        owner = self.groupby_object
        # Translate the index specification into a row mask, then apply it.
        row_mask = owner._make_mask_from_positional_indexer(arg)
        return owner._mask_selected_obj(row_mask)
286
287
class GroupByNthSelector:
    """
    Dynamically substituted for GroupBy.nth to enable both call and index.

    ``selector(n, dropna)`` and ``selector[n]`` both forward to the ``_nth``
    method of the wrapped groupby object.
    """

    def __init__(self, groupby_object: groupby.GroupBy) -> None:
        # Object whose ``_nth`` method performs the actual selection.
        self.groupby_object = groupby_object

    def __call__(
        self,
        n: PositionalIndexer | tuple,
        dropna: Literal["any", "all", None] = None,
    ) -> DataFrame | Series:
        """
        Forward ``n`` and ``dropna`` positionally to the wrapped ``_nth``.
        """
        target = self.groupby_object
        return target._nth(n, dropna)

    def __getitem__(self, n: PositionalIndexer | tuple) -> DataFrame | Series:
        """
        Forward ``n`` alone to the wrapped ``_nth``.
        """
        target = self.groupby_object
        return target._nth(n)