1"""
2Implementation of nlargest and nsmallest.
3"""
4
5from __future__ import annotations
6
7from typing import (
8 TYPE_CHECKING,
9 Hashable,
10 Sequence,
11 cast,
12 final,
13)
14
15import numpy as np
16
17from pandas._libs import algos as libalgos
18from pandas._typing import (
19 DtypeObj,
20 IndexLabel,
21)
22
23from pandas.core.dtypes.common import (
24 is_bool_dtype,
25 is_complex_dtype,
26 is_integer_dtype,
27 is_list_like,
28 is_numeric_dtype,
29 needs_i8_conversion,
30)
31from pandas.core.dtypes.dtypes import BaseMaskedDtype
32
33if TYPE_CHECKING:
34 from pandas import (
35 DataFrame,
36 Series,
37 )
38
39
class SelectN:
    """
    Common base for the n largest/smallest selectors.

    Stores the object to select from, the number of entries requested and
    the tie-handling policy. Subclasses provide ``compute``.

    Parameters
    ----------
    obj : object
        The container the selection is performed on.
    n : int
        Number of entries requested.
    keep : {'first', 'last', 'all'}
        How ties are resolved.

    Raises
    ------
    ValueError
        If ``keep`` is not one of the accepted options.
    """

    def __init__(self, obj, n: int, keep: str) -> None:
        self.obj, self.n, self.keep = obj, n, keep

        keep_is_valid = keep in ("first", "last", "all")
        if not keep_is_valid:
            raise ValueError('keep must be either "first", "last" or "all"')

    def compute(self, method: str) -> DataFrame | Series:
        """
        Perform the selection named by ``method``; subclasses must override.
        """
        raise NotImplementedError

    @final
    def nlargest(self):
        """
        Return the result of ``compute("nlargest")``.
        """
        return self.compute("nlargest")

    @final
    def nsmallest(self):
        """
        Return the result of ``compute("nsmallest")``.
        """
        return self.compute("nsmallest")

    @final
    @staticmethod
    def is_valid_dtype_n_method(dtype: DtypeObj) -> bool:
        """
        Tell whether ``dtype`` can be used with nsmallest/nlargest.

        Numeric dtypes qualify unless they are complex; any other dtype
        qualifies only when it needs i8 conversion.
        """
        numeric = is_numeric_dtype(dtype)
        return (not is_complex_dtype(dtype)) if numeric else needs_i8_conversion(dtype)
70
71
class SelectNSeries(SelectN):
    """
    Implement n largest/smallest for Series

    Parameters
    ----------
    obj : Series
    n : int
    keep : {'first', 'last', 'all'}, default 'first'

    Returns
    -------
    nordered : Series
    """

    def compute(self, method: str) -> Series:
        """
        Select the ``n`` largest or smallest entries of the Series.

        Parameters
        ----------
        method : {'nlargest', 'nsmallest'}
            Which end of the ordering to select from.

        Returns
        -------
        Series
            The selected non-NA entries in rank order, followed by NA
            entries when fewer than ``n`` non-NA entries exist.

        Raises
        ------
        TypeError
            If the Series dtype is rejected by ``is_valid_dtype_n_method``.
        """
        from pandas.core.reshape.concat import concat

        n = self.n
        dtype = self.obj.dtype
        if not self.is_valid_dtype_n_method(dtype):
            raise TypeError(f"Cannot use method '{method}' with dtype {dtype}")

        if n <= 0:
            return self.obj[[]]

        # slow method
        # NOTE(review): this path does not consult ``self.keep``.
        if n >= len(self.obj):
            ascending = method == "nsmallest"
            return self.obj.sort_values(ascending=ascending).head(n)

        # Partition by a positional NA mask rather than by label: dropping
        # the non-NA labels would also discard NA rows that share a label
        # with a non-NA row when the index contains duplicates.
        na_mask = self.obj.isna()
        dropped = self.obj.dropna()
        nan_index = self.obj[na_mask]

        # fast method
        new_dtype = dropped.dtype

        # Similar to algorithms._ensure_data
        arr = dropped._values
        if needs_i8_conversion(arr.dtype):
            arr = arr.view("i8")
        elif isinstance(arr.dtype, BaseMaskedDtype):
            arr = arr._data
        else:
            arr = np.asarray(arr)
        # Boolean buffers cannot be negated below; reinterpret them as
        # uint8 regardless of which branch produced them.
        if arr.dtype.kind == "b":
            arr = arr.view(np.uint8)

        if method == "nlargest":
            arr = -arr
            if is_integer_dtype(new_dtype):
                # GH 21426: ensure reverse ordering at boundaries
                arr -= 1

            elif is_bool_dtype(new_dtype):
                # GH 26154: ensure False is smaller than True
                arr = 1 - (-arr)

        if self.keep == "last":
            arr = arr[::-1]

        nbase = n
        narr = len(arr)
        n = min(n, narr)

        # arr passed into kth_smallest must be contiguous. We copy
        # here because kth_smallest will modify its input
        if narr > 0:
            kth_val = libalgos.kth_smallest(arr.copy(order="C"), n - 1)
        else:
            # No non-NA values: n is 0 here, so asking for rank ``n - 1``
            # would be out of range. NaN compares False with everything,
            # leaving the selection below empty.
            kth_val = np.nan
        (ns,) = np.nonzero(arr <= kth_val)
        inds = ns[arr[ns].argsort(kind="mergesort")]

        if self.keep != "all":
            inds = inds[:n]
            findex = nbase
        else:
            if len(inds) < nbase <= len(nan_index) + len(inds):
                findex = len(nan_index) + len(inds)
            else:
                findex = len(inds)

        if self.keep == "last":
            # reverse indices
            inds = narr - 1 - inds

        return concat([dropped.iloc[inds], nan_index]).iloc[:findex]
157
158
class SelectNFrame(SelectN):
    """
    Implement n largest/smallest for DataFrame

    Parameters
    ----------
    obj : DataFrame
    n : int
    keep : {'first', 'last', 'all'}
    columns : list or str
        Column label(s) to rank by, in priority order. A tuple or any
        non-list-like value is treated as a single label.

    Returns
    -------
    nordered : DataFrame
    """

    def __init__(
        self, obj: DataFrame, n: int, keep: str, columns: IndexLabel
    ) -> None:
        super().__init__(obj, n, keep)

        # A tuple is a single (possibly multi-level) label, not a list of labels.
        single_label = isinstance(columns, tuple) or not is_list_like(columns)
        labels = [columns] if single_label else columns
        self.columns = list(cast(Sequence[Hashable], labels))

    def compute(self, method: str) -> DataFrame:
        """
        Select the ``n`` rows ranking highest/lowest on ``self.columns``.

        Parameters
        ----------
        method : {'nlargest', 'nsmallest'}
            Name of the Series method applied to each ranking column.

        Returns
        -------
        DataFrame
            The selected rows, carrying their original index labels.

        Raises
        ------
        TypeError
            If any ranking column has a dtype rejected by
            ``is_valid_dtype_n_method``.
        """
        from pandas.core.api import Index

        wanted = self.n
        source = self.obj
        by = self.columns

        # Reject unsupported dtypes before doing any work.
        for label in by:
            col_dtype = source[label].dtype
            if not self.is_valid_dtype_n_method(col_dtype):
                raise TypeError(
                    f"Column {repr(label)} has dtype {col_dtype}, "
                    f"cannot use method {repr(method)} with this dtype"
                )

        smallest_first = method == "nsmallest"

        def combine(collected, incoming):
            """
            Join two positional indexers; the order depends on `method`.
            """
            if smallest_first:
                return collected.append(incoming)
            return incoming.append(collected)

        # The index may contain duplicates, so work on a default index
        # and put the original labels back at the end.
        saved_index = source.index
        source = source.reset_index(drop=True)
        candidates = source
        remaining = wanted
        picked = Index([], dtype=np.int64)
        final_pos = len(by) - 1

        for pos, label in enumerate(by):
            # Rank the current candidates on this column. Every column but
            # the final one keeps all ties so that later columns can
            # break them.
            on_final = pos == final_pos
            column_values = candidates[label]
            ranked = getattr(column_values, method)(
                remaining, keep=self.keep if on_final else "all"
            )

            # Either nothing is left to break ties with, or there are no
            # surplus ties to resolve: accept everything and stop.
            if on_final or len(ranked) <= remaining:
                picked = combine(picked, ranked.index)
                break

            # Rows equal to the last-ranked value are ambiguous: some of
            # them belong in the result and some do not.
            edge_label = ranked.index[-1]
            at_edge = ranked == ranked[edge_label]
            tied = ranked[at_edge]

            # Everything else is definitely part of the result.
            settled = ranked[~at_edge]
            picked = combine(picked, settled.index)

            # Resolve the ambiguous rows on the remaining columns.
            candidates = candidates.loc[tied.index]
            remaining = wanted - len(picked)

        result = source.take(picked)

        # Restore the original index labels.
        result.index = saved_index.take(picked)

        # With a single ranking column the rows are already ordered.
        if len(by) == 1:
            return result

        return result.sort_values(by, ascending=smallest_first, kind="mergesort")