1"""
2Implementation of nlargest and nsmallest.
3"""
4
5from __future__ import annotations
6
7from collections.abc import (
8 Hashable,
9 Sequence,
10)
11from typing import (
12 TYPE_CHECKING,
13 cast,
14 final,
15)
16
17import numpy as np
18
19from pandas._libs import algos as libalgos
20
21from pandas.core.dtypes.common import (
22 is_bool_dtype,
23 is_complex_dtype,
24 is_integer_dtype,
25 is_list_like,
26 is_numeric_dtype,
27 needs_i8_conversion,
28)
29from pandas.core.dtypes.dtypes import BaseMaskedDtype
30
31if TYPE_CHECKING:
32 from pandas._typing import (
33 DtypeObj,
34 IndexLabel,
35 )
36
37 from pandas import (
38 DataFrame,
39 Series,
40 )
41
42
class SelectN:
    """
    Common base for the n-largest / n-smallest selectors.

    Stores the target object together with the requested count and the
    tie-handling policy; subclasses supply the actual selection logic in
    ``compute``.

    Parameters
    ----------
    obj : object
        The object to select from (stored unchanged on ``self.obj``).
    n : int
        Number of entries requested.
    keep : {'first', 'last', 'all'}
        Tie-handling policy. Any other value raises ``ValueError``.
    """

    def __init__(self, obj, n: int, keep: str) -> None:
        self.obj = obj
        self.n = n
        self.keep = keep

        # Reject unknown tie-handling policies up front.
        allowed_keep = ("first", "last", "all")
        if keep not in allowed_keep:
            raise ValueError('keep must be either "first", "last" or "all"')

    def compute(self, method: str) -> DataFrame | Series:
        """
        Perform the selection named by ``method``.

        The base class has no implementation and always raises
        ``NotImplementedError``; subclasses override this.
        """
        raise NotImplementedError

    @final
    def nlargest(self):
        """Return ``self.compute("nlargest")``."""
        return self.compute("nlargest")

    @final
    def nsmallest(self):
        """Return ``self.compute("nsmallest")``."""
        return self.compute("nsmallest")

    @final
    @staticmethod
    def is_valid_dtype_n_method(dtype: DtypeObj) -> bool:
        """
        Tell whether ``dtype`` can be used with nsmallest/nlargest.

        Numeric dtypes qualify unless they are complex; non-numeric dtypes
        qualify only when ``needs_i8_conversion`` reports True for them.
        """
        if not is_numeric_dtype(dtype):
            return needs_i8_conversion(dtype)
        return not is_complex_dtype(dtype)
73
74
class SelectNSeries(SelectN):
    """
    Implement n largest/smallest for Series

    Parameters
    ----------
    obj : Series
    n : int
    keep : {'first', 'last', 'all'}
        How ties at the selection boundary are resolved: ``'first'`` prefers
        earlier rows, ``'last'`` prefers later rows, and ``'all'`` keeps every
        row tied with the boundary value (the result may then hold more than
        ``n`` rows).

    Returns
    -------
    nordered : Series
    """

    def compute(self, method: str) -> Series:
        """
        Select the ``n`` largest or smallest entries of the Series.

        Parameters
        ----------
        method : {'nlargest', 'nsmallest'}
            Which end of the ordering to select from.

        Returns
        -------
        Series
            The selected rows, ordered from most to least extreme, carrying
            their original index labels. Missing values are only used to pad
            the result when there are fewer than ``n`` non-missing values.

        Raises
        ------
        TypeError
            If the Series dtype is rejected by ``is_valid_dtype_n_method``.
        """
        from pandas.core.reshape.concat import concat

        n = self.n
        dtype = self.obj.dtype
        if not self.is_valid_dtype_n_method(dtype):
            raise TypeError(f"Cannot use method '{method}' with dtype {dtype}")

        if n <= 0:
            return self.obj[[]]

        # Work on a positional 0..len-1 index and restore the labels at the
        # end. The label-based ``drop`` below would otherwise also discard
        # NaN rows that share a label with a non-NaN row whenever the
        # original index contains duplicates.
        original_index = self.obj.index
        positional = self.obj.reset_index(drop=True)

        # slow method: every row is requested, so this is just a sort.
        # A stable sort keeps tied values in their original order.
        if n >= len(positional):
            ascending = method == "nsmallest"
            result = positional.sort_values(
                ascending=ascending, kind="mergesort"
            ).head(n)
            result.index = original_index.take(result.index)
            return result

        # fast method
        # Split into non-missing values (ranked below) and missing rows
        # (only used as padding at the end of the result).
        dropped = positional.dropna()
        nan_index = positional.drop(dropped.index)

        new_dtype = dropped.dtype

        # Similar to algorithms._ensure_data
        arr = dropped._values
        if needs_i8_conversion(arr.dtype):
            arr = arr.view("i8")
        elif isinstance(arr.dtype, BaseMaskedDtype):
            arr = arr._data
        else:
            arr = np.asarray(arr)
        if arr.dtype.kind == "b":
            arr = arr.view(np.uint8)

        # Everything below selects the *smallest* values of ``arr``, so for
        # nlargest the ordering of ``arr`` is inverted first.
        if method == "nlargest":
            arr = -arr
            if is_integer_dtype(new_dtype):
                # GH 21426: ensure reverse ordering at boundaries
                arr -= 1

            elif is_bool_dtype(new_dtype):
                # GH 26154: ensure False is smaller than True
                arr = 1 - (-arr)

        if self.keep == "last":
            # Reversing makes the stable argsort below prefer later rows.
            arr = arr[::-1]

        nbase = n
        narr = len(arr)
        n = min(n, narr)

        # arr passed into kth_smallest must be contiguous. We copy
        # here because kth_smallest will modify its input
        # avoid OOB access with kth_smallest_c when n <= 0
        if len(arr) > 0:
            kth_val = libalgos.kth_smallest(arr.copy(order="C"), n - 1)
        else:
            kth_val = np.nan
        # Candidates are all positions not greater than the n-th smallest
        # value; a stable argsort orders them while preserving tie order.
        (ns,) = np.nonzero(arr <= kth_val)
        inds = ns[arr[ns].argsort(kind="mergesort")]

        if self.keep != "all":
            inds = inds[:n]
            findex = nbase
        else:
            # With keep="all" every tied candidate is retained; missing rows
            # are appended only when they are needed to reach ``nbase``.
            if len(inds) < nbase <= len(nan_index) + len(inds):
                findex = len(nan_index) + len(inds)
            else:
                findex = len(inds)

        if self.keep == "last":
            # reverse indices
            inds = narr - 1 - inds

        result = concat([dropped.iloc[inds], nan_index]).iloc[:findex]
        # Map positional labels back to the caller's index labels.
        result.index = original_index.take(result.index)
        return result
164
165
class SelectNFrame(SelectN):
    """
    Implement n largest/smallest for DataFrame

    Parameters
    ----------
    obj : DataFrame
    n : int
    keep : {'first', 'last', 'all'}
        Tie-handling policy, applied on the last of ``columns``.
    columns : list or str
        Column label(s) to order by. A single label, or a tuple (treated as
        one label), is wrapped into a one-element list.

    Returns
    -------
    nordered : DataFrame
    """

    def __init__(self, obj: DataFrame, n: int, keep: str, columns: IndexLabel) -> None:
        super().__init__(obj, n, keep)
        # A tuple is treated as a single label rather than a list of labels.
        if isinstance(columns, tuple) or not is_list_like(columns):
            columns = [columns]

        self.columns = list(cast(Sequence[Hashable], columns))

    def compute(self, method: str) -> DataFrame:
        """
        Select the ``n`` largest or smallest rows ordered by ``self.columns``.

        Parameters
        ----------
        method : {'nlargest', 'nsmallest'}
            Name of the Series method applied column by column.

        Returns
        -------
        DataFrame
            The selected rows with their original index labels.

        Raises
        ------
        TypeError
            If any of the ordering columns has a dtype rejected by
            ``is_valid_dtype_n_method``.
        """
        from pandas.core.api import Index

        n_wanted = self.n
        source = self.obj
        sort_cols = self.columns

        # Validate every ordering column before doing any work.
        for label in sort_cols:
            col_dtype = source[label].dtype
            if self.is_valid_dtype_n_method(col_dtype):
                continue
            raise TypeError(
                f"Column {repr(label)} has dtype {col_dtype}, "
                f"cannot use method {repr(method)} with this dtype"
            )

        smallest_first = method == "nsmallest"

        def combine(collected, incoming):
            """
            Join two indexers; which one goes first depends on ``method``.
            """
            if smallest_first:
                return collected.append(incoming)
            return incoming.append(collected)

        # Save the index and switch to a positional one, in case the
        # original index contains duplicates.
        saved_index = source.index
        positional = source.reset_index(drop=True)
        remaining = positional
        still_needed = n_wanted
        picked = Index([], dtype=np.int64)
        final_pos = len(sort_cols) - 1

        for pos, label in enumerate(sort_cols):
            # Apply ``method`` to the current column of the remaining rows.
            # Stop on the last column, or once no surplus rows came back.
            # Otherwise the boundary value is duplicated and the following
            # columns decide which of the tied rows are kept.
            on_final = pos == final_pos
            tie_policy = self.keep if on_final else "all"
            selected = getattr(remaining[label], method)(
                still_needed, keep=tie_policy
            )

            if on_final or len(selected) <= still_needed:
                picked = combine(picked, selected.index)
                break

            # Rows equal to the last returned value sit on the boundary:
            # some of them belong to the top-n, some do not.
            at_border = selected == selected[selected.index[-1]]
            tied = selected[at_border]

            # Rows strictly inside the boundary are certainly selected.
            certain = selected[~at_border]
            picked = combine(picked, certain.index)

            # Resolve the tied rows using the remaining columns.
            remaining = remaining.loc[tied.index]
            still_needed = n_wanted - len(picked)

        result = positional.take(picked)

        # Put the original index labels back.
        result.index = saved_index.take(picked)

        # With a single column the rows are already in sorted order.
        if len(sort_cols) == 1:
            return result

        return result.sort_values(
            sort_cols, ascending=smallest_first, kind="mergesort"
        )