1from collections import Counter
2from contextlib import suppress
3from typing import NamedTuple
4
5import numpy as np
6
7from . import is_scalar_nan
8
9
def _unique(values, *, return_inverse=False, return_counts=False):
    """Helper function to find unique values with support for python objects.

    Uses pure python method for object dtype, and numpy method for
    all other dtypes.

    Parameters
    ----------
    values : ndarray
        Values to check for unknowns.

    return_inverse : bool, default=False
        If True, also return the indices of the unique values.

    return_counts : bool, default=False
        If True, also return the number of times each unique item appears in
        values.

    Returns
    -------
    unique : ndarray
        The sorted unique values.

    unique_inverse : ndarray
        The indices to reconstruct the original array from the unique array.
        Only provided if `return_inverse` is True.

    unique_counts : ndarray
        The number of times each of the unique values comes up in the original
        array. Only provided if `return_counts` is True.
    """
    # Object arrays need the pure-python path; everything else (numerical,
    # str/bytes) goes through numpy.
    helper = _unique_python if values.dtype == object else _unique_np
    return helper(values, return_inverse=return_inverse, return_counts=return_counts)
49
50
def _unique_np(values, return_inverse=False, return_counts=False):
    """Helper function to find unique values for numpy arrays that correctly
    accounts for nans. See `_unique` documentation for details."""
    result = np.unique(
        values, return_inverse=return_inverse, return_counts=return_counts
    )

    # Unpack whatever combination np.unique returned.
    inverse = counts = None
    if return_inverse and return_counts:
        uniques, inverse, counts = result
    elif return_inverse:
        uniques, inverse = result
    elif return_counts:
        uniques, counts = result
    else:
        uniques = result

    # np.unique will have duplicate missing values at the end of `uniques`
    # (nan != nan); clip them so a single nan entry remains.
    if uniques.size and is_scalar_nan(uniques[-1]):
        nan_idx = np.searchsorted(uniques, np.nan)
        uniques = uniques[: nan_idx + 1]
        if inverse is not None:
            # Redirect indices that pointed at a dropped duplicate nan.
            inverse[inverse > nan_idx] = nan_idx
        if counts is not None:
            # Fold the duplicate-nan counts into the surviving nan entry.
            counts[nan_idx] = np.sum(counts[nan_idx:])
            counts = counts[: nan_idx + 1]

    output = [uniques]
    if return_inverse:
        output.append(inverse)
    if return_counts:
        output.append(counts)
    return output[0] if len(output) == 1 else tuple(output)
90
91
class MissingValues(NamedTuple):
    """Data class for missing data information"""

    nan: bool
    none: bool

    def to_list(self):
        """Convert tuple to a list where None is always first."""
        result = [None] if self.none else []
        if self.nan:
            result.append(np.nan)
        return result
106
107
def _extract_missing(values):
    """Extract missing values from `values`.

    Parameters
    ----------
    values: set
        Set of values to extract missing from.

    Returns
    -------
    output: set
        Set with missing values extracted.

    missing_values: MissingValues
        Object with missing value information.
    """
    missing = {v for v in values if v is None or is_scalar_nan(v)}

    if not missing:
        return values, MissingValues(nan=False, none=False)

    has_none = None in missing
    # Any missing marker other than None has to be float('nan') or np.nan.
    has_nan = bool(missing - {None})

    # Return the set without the missing values, plus the marker info.
    return values - missing, MissingValues(nan=has_nan, none=has_none)
144
145
class _nandict(dict):
    """Dictionary with support for nans."""

    def __init__(self, mapping):
        super().__init__(mapping)
        # Remember the value stored under the first nan key (if any) so
        # lookups with a *different* nan object still succeed.
        for key, mapped in mapping.items():
            if is_scalar_nan(key):
                self.nan_value = mapped
                break

    def __missing__(self, key):
        # A nan key never hits the regular dict lookup (nan != nan), so
        # fall back to the remembered nan value when one was stored.
        if is_scalar_nan(key) and hasattr(self, "nan_value"):
            return self.nan_value
        raise KeyError(key)
160
161
def _map_to_integer(values, uniques):
    """Map values based on its position in uniques."""
    table = _nandict(dict(zip(uniques, range(len(uniques)))))
    return np.array([table[value] for value in values])
166
167
def _unique_python(values, *, return_inverse, return_counts):
    # Only used in `_uniques`, see docstring there for details
    try:
        uniques_set, missing_values = _extract_missing(set(values))
        # Sort the regular values, then append missing markers at the end.
        uniques = np.array(
            sorted(uniques_set) + missing_values.to_list(), dtype=values.dtype
        )
    except TypeError:
        # Mixed str/number input cannot be sorted; report the offending types.
        types = sorted(t.__qualname__ for t in set(type(v) for v in values))
        raise TypeError(
            "Encoders require their input argument must be uniformly "
            f"strings or numbers. Got {types}"
        )

    output = [uniques]
    if return_inverse:
        output.append(_map_to_integer(values, uniques))
    if return_counts:
        output.append(_get_counts(values, uniques))
    return output[0] if len(output) == 1 else tuple(output)
192
193
def _encode(values, *, uniques, check_unknown=True):
    """Helper function to encode values into [0, n_uniques - 1].

    Uses pure python method for object dtype, and numpy method for
    all other dtypes.
    The numpy method has the limitation that the `uniques` need to
    be sorted. Importantly, this is not checked but assumed to already be
    the case. The calling method needs to ensure this for all non-object
    values.

    Parameters
    ----------
    values : ndarray
        Values to encode.
    uniques : ndarray
        The unique values in `values`. If the dtype is not object, then
        `uniques` needs to be sorted.
    check_unknown : bool, default=True
        If True, check for values in `values` that are not in `unique`
        and raise an error. This is ignored for object dtype, and treated as
        True in this case. This parameter is useful for
        _BaseEncoder._transform() to avoid calling _check_unknown()
        twice.

    Returns
    -------
    encoded : ndarray
        Encoded values
    """
    if values.dtype.kind not in "OUS":
        # Numerical path: `uniques` is assumed sorted, so positions come
        # straight from binary search.
        if check_unknown:
            diff = _check_unknown(values, uniques)
            if diff:
                raise ValueError(f"y contains previously unseen labels: {str(diff)}")
        return np.searchsorted(uniques, values)

    # Object/str path: dictionary lookup; unknown values surface as KeyError.
    try:
        return _map_to_integer(values, uniques)
    except KeyError as e:
        raise ValueError(f"y contains previously unseen labels: {str(e)}")
234
235
def _check_unknown(values, known_values, return_mask=False):
    """
    Helper function to check for unknowns in values to be encoded.

    Uses pure python method for object dtype, and numpy method for
    all other dtypes.

    Parameters
    ----------
    values : array
        Values to check for unknowns.
    known_values : array
        Known values. Must be unique.
    return_mask : bool, default=False
        If True, return a mask of the same shape as `values` indicating
        the valid values.

    Returns
    -------
    diff : list
        The unique values present in `values` and not in `known_values`.
    valid_mask : boolean array
        Additionally returned if ``return_mask=True``.

    """
    valid_mask = None

    if values.dtype.kind in "OUS":
        # Object/str path: compare python sets, extracting None/nan first
        # because they do not behave under set membership (nan != nan).
        values_set = set(values)
        values_set, missing_in_values = _extract_missing(values_set)

        uniques_set = set(known_values)
        uniques_set, missing_in_uniques = _extract_missing(uniques_set)
        diff = values_set - uniques_set

        # A missing marker is only "unknown" when it occurs in `values`
        # but not in `known_values`.
        nan_in_diff = missing_in_values.nan and not missing_in_uniques.nan
        none_in_diff = missing_in_values.none and not missing_in_uniques.none

        def is_valid(value):
            # Valid when the value is a known regular value, or a missing
            # marker (None/nan) that is also present in `known_values`.
            return (
                value in uniques_set
                or missing_in_uniques.none
                and value is None
                or missing_in_uniques.nan
                and is_scalar_nan(value)
            )

        if return_mask:
            if diff or nan_in_diff or none_in_diff:
                valid_mask = np.array([is_valid(value) for value in values])
            else:
                # Nothing unknown: every entry is valid.
                valid_mask = np.ones(len(values), dtype=bool)

        # Append missing markers to `diff` in a deterministic order.
        diff = list(diff)
        if none_in_diff:
            diff.append(None)
        if nan_in_diff:
            diff.append(np.nan)
    else:
        # Numerical path: set difference via numpy.
        unique_values = np.unique(values)
        diff = np.setdiff1d(unique_values, known_values, assume_unique=True)
        if return_mask:
            if diff.size:
                valid_mask = np.isin(values, known_values)
            else:
                valid_mask = np.ones(len(values), dtype=bool)

        # check for nans in the known_values
        if np.isnan(known_values).any():
            diff_is_nan = np.isnan(diff)
            if diff_is_nan.any():
                # nan is actually known, but np.isin above could not match it
                # (nan != nan), so mark every nan entry in `values` as valid
                if diff.size and return_mask:
                    is_nan = np.isnan(values)
                    valid_mask[is_nan] = 1

                # remove nan from diff
                diff = diff[~diff_is_nan]
        diff = list(diff)

    if return_mask:
        return diff, valid_mask
    return diff
319
320
class _NaNCounter(Counter):
    """Counter with support for nan values."""

    def __init__(self, items):
        super().__init__(self._generate_items(items))

    def _generate_items(self, items):
        """Generate items without nans. Stores the nan counts separately."""
        for item in items:
            if is_scalar_nan(item):
                # Keep nan occurrences in a lazily-created side counter,
                # since a Counter cannot aggregate nan keys (nan != nan).
                if not hasattr(self, "nan_count"):
                    self.nan_count = 0
                self.nan_count += 1
            else:
                yield item

    def __missing__(self, key):
        # nan keys never hit the regular lookup; serve the side counter.
        if is_scalar_nan(key) and hasattr(self, "nan_count"):
            return self.nan_count
        raise KeyError(key)
341
342
def _get_counts(values, uniques):
    """Get the count of each of the `uniques` in `values`.

    The counts will use the order passed in by `uniques`. For non-object dtypes,
    `uniques` is assumed to be sorted and `np.nan` is at the end.

    Parameters
    ----------
    values : ndarray
        Values to count.
    uniques : ndarray
        Unique values whose occurrences in `values` are counted.

    Returns
    -------
    counts : ndarray of shape (len(uniques),), dtype=int64
        Number of occurrences of each entry of `uniques` in `values`;
        entries of `uniques` absent from `values` get a zero count.
    """
    if values.dtype.kind in "OU":
        # Object/str path: nan-aware Counter; uniques absent from `values`
        # keep their zero count (the KeyError is suppressed).
        counter = _NaNCounter(values)
        output = np.zeros(len(uniques), dtype=np.int64)
        for i, item in enumerate(uniques):
            with suppress(KeyError):
                output[i] = counter[item]
        return output

    unique_values, counts = _unique_np(values, return_counts=True)

    # Reorder unique_values based on input: `uniques`
    uniques_in_values = np.isin(uniques, unique_values, assume_unique=True)
    # np.isin cannot match nan (nan != nan); both arrays are sorted with nan
    # last, so mark it present explicitly when both end with nan. The size
    # guards avoid an IndexError on empty input.
    if (
        unique_values.size
        and uniques.size
        and np.isnan(unique_values[-1])
        and np.isnan(uniques[-1])
    ):
        uniques_in_values[-1] = True

    unique_valid_indices = np.searchsorted(unique_values, uniques[uniques_in_values])
    output = np.zeros_like(uniques, dtype=np.int64)
    output[uniques_in_values] = counts[unique_valid_indices]
    return output