1"""
2For compatibility with numpy libraries, pandas functions or methods have to
3accept '*args' and '**kwargs' parameters to accommodate numpy arguments that
4are not actually used or respected in the pandas implementation.
5
6To ensure that users do not abuse these parameters, validation is performed in
7'validators.py' to make sure that any extra parameters passed correspond ONLY
8to those in the numpy signature. Part of that validation includes whether or
9not the user attempted to pass in non-default values for these extraneous
10parameters. As we want to discourage users from relying on these parameters
11when calling the pandas implementation, we want them only to pass in the
12default values for these parameters.
13
14This module provides a set of commonly used default arguments for functions and
15methods that are spread throughout the codebase. This module will make it
16easier to adjust to future upstream changes in the analogous numpy signatures.
17"""
18from __future__ import annotations
19
20from typing import (
21 TYPE_CHECKING,
22 Any,
23 TypeVar,
24 cast,
25 overload,
26)
27
28import numpy as np
29from numpy import ndarray
30
31from pandas._libs.lib import (
32 is_bool,
33 is_integer,
34)
35from pandas.errors import UnsupportedFunctionCall
36from pandas.util._validators import (
37 validate_args,
38 validate_args_and_kwargs,
39 validate_kwargs,
40)
41
42if TYPE_CHECKING:
43 from pandas._typing import (
44 Axis,
45 AxisInt,
46 )
47
48 AxisNoneT = TypeVar("AxisNoneT", Axis, None)
49
50
class CompatValidator:
    """
    Callable that checks numpy-compat ``*args`` / ``**kwargs`` against a
    mapping of accepted parameter names and their default values.

    Parameters
    ----------
    defaults : dict
        Mapping of each accepted numpy parameter name to its default value.
    fname : str, optional
        Name of the function being validated, forwarded to the validators.
    method : {"args", "kwargs", "both"}, optional
        Selects which validator (``validate_args``, ``validate_kwargs`` or
        ``validate_args_and_kwargs``) is used.
    max_fname_arg_count : int, optional
        Forwarded to ``validate_args`` / ``validate_args_and_kwargs``.
    """

    def __init__(
        self,
        defaults,
        fname=None,
        method: str | None = None,
        max_fname_arg_count=None,
    ) -> None:
        self.fname = fname
        self.method = method
        self.defaults = defaults
        self.max_fname_arg_count = max_fname_arg_count

    def __call__(
        self,
        args,
        kwargs,
        fname=None,
        max_fname_arg_count=None,
        method: str | None = None,
    ) -> None:
        """
        Validate ``args`` and ``kwargs``.

        Per-call ``fname``, ``max_fname_arg_count`` and ``method`` override
        the values stored at construction time whenever they are not None.

        Raises
        ------
        ValueError
            If the resolved ``method`` is not "args", "kwargs" or "both".
            This is only checked when there is something to validate.
        """
        # Nothing was forwarded, so there is nothing to check.
        if not (args or kwargs):
            return None

        if fname is None:
            fname = self.fname
        if max_fname_arg_count is None:
            max_fname_arg_count = self.max_fname_arg_count
        if method is None:
            method = self.method

        if method == "both":
            validate_args_and_kwargs(
                fname, args, kwargs, max_fname_arg_count, self.defaults
            )
        elif method == "kwargs":
            validate_kwargs(fname, kwargs, self.defaults)
        elif method == "args":
            validate_args(fname, args, max_fname_arg_count, self.defaults)
        else:
            raise ValueError(f"invalid validation method '{method}'")
93
94
# numpy's argmin/argmax may forward an ``out`` argument; only its default
# value (None) is accepted.
ARGMINMAX_DEFAULTS = dict(out=None)
validate_argmin = CompatValidator(
    ARGMINMAX_DEFAULTS, max_fname_arg_count=1, method="both", fname="argmin"
)
validate_argmax = CompatValidator(
    ARGMINMAX_DEFAULTS, max_fname_arg_count=1, method="both", fname="argmax"
)
102
103
def process_skipna(skipna: bool | ndarray | None, args) -> tuple[bool, Any]:
    """
    Detect a numpy positional argument that landed in the ``skipna`` slot.

    If ``skipna`` is an ndarray or None, it is not a real skipna flag. It is
    prepended to ``args`` so it can be validated, and skipna falls back to
    True. Any other value is returned unchanged along with ``args``.

    Returns
    -------
    tuple
        ``(skipna, args)`` after the adjustment described above.
    """
    if skipna is not None and not isinstance(skipna, ndarray):
        return skipna, args
    return True, (skipna,) + args
110
111
def validate_argmin_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool:
    """
    Resolve ``skipna`` for 'Series.argmin' and validate numpy-compat extras.

    When called via numpy, the third parameter of numpy's signature is 'out'
    (an ndarray or None). That value can arrive in the ``skipna`` slot, but a
    genuine skipna is a boolean. So an ndarray or None is moved into ``args``
    before validation, and skipna defaults to True.
    """
    resolved, extra_args = process_skipna(skipna, args)
    validate_argmin(extra_args, kwargs)
    return resolved
122
123
def validate_argmax_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool:
    """
    Resolve ``skipna`` for 'Series.argmax' and validate numpy-compat extras.

    When called via numpy, the third parameter of numpy's signature is 'out'
    (an ndarray or None). That value can arrive in the ``skipna`` slot, but a
    genuine skipna is a boolean. So an ndarray or None is moved into ``args``
    before validation, and skipna defaults to True.
    """
    resolved, extra_args = process_skipna(skipna, args)
    validate_argmax(extra_args, kwargs)
    return resolved
134
135
# Accepted numpy extras (with their defaults) for argsort signatures that do
# not take ``kind`` as an explicit pandas parameter. The effective default for
# ``kind`` is None. A former "quicksort" assignment was immediately
# overwritten and has been dropped. Key order is kept stable: the pandas
# validators presumably match positional extras to these keys in order
# (TODO confirm against pandas.util._validators).
ARGSORT_DEFAULTS: dict[str, int | str | None] = {
    "axis": -1,
    "kind": None,
    "order": None,
    "stable": None,
}
142
143
# Counterpart of validate_argsort_kind; here ``kind`` is among the extras
# that must keep their default.
validate_argsort = CompatValidator(
    ARGSORT_DEFAULTS, method="both", max_fname_arg_count=0, fname="argsort"
)
147
# argsort has two signatures. This second validator is for methods that
# accept ``kind`` themselves, so ``kind`` is not one of the extras here.
ARGSORT_DEFAULTS_KIND: dict[str, int | None] = {
    "axis": -1,
    "order": None,
    "stable": None,
}
validate_argsort_kind = CompatValidator(
    ARGSORT_DEFAULTS_KIND, method="both", max_fname_arg_count=0, fname="argsort"
)
157
158
def validate_argsort_with_ascending(ascending: bool | int | None, args, kwargs) -> bool:
    """
    Resolve ``ascending`` for 'Categorical.argsort' and validate numpy extras.

    When called via numpy, the first parameter of numpy's signature is 'axis'
    (an integer or None), which can arrive in the ``ascending`` slot. A genuine
    ``ascending`` is a boolean. So a value that is an integer (per
    ``is_integer``) or None is moved to the front of ``args``, and
    ``ascending`` becomes True.
    """
    came_from_numpy = ascending is None or is_integer(ascending)
    if came_from_numpy:
        args = (ascending,) + args
        ascending = True

    validate_argsort_kind(args, kwargs, max_fname_arg_count=3)
    return cast(bool, ascending)
173
174
# numpy's clip may forward ``out``; only its default (None) is accepted.
CLIP_DEFAULTS: dict[str, Any] = dict(out=None)
validate_clip = CompatValidator(
    CLIP_DEFAULTS, max_fname_arg_count=3, method="both", fname="clip"
)
179
180
@overload
def validate_clip_with_axis(axis: ndarray, args, kwargs) -> None:
    ...


@overload
def validate_clip_with_axis(axis: AxisNoneT, args, kwargs) -> AxisNoneT:
    ...


def validate_clip_with_axis(
    axis: ndarray | AxisNoneT, args, kwargs
) -> AxisNoneT | None:
    """
    Resolve ``axis`` for 'NDFrame.clip' and validate numpy-compat extras.

    When 'NDFrame.clip' is called via numpy, the third parameter of numpy's
    signature is 'out', which can take an ndarray. A genuine ``axis`` should
    be an integer or None. So an ndarray found in the ``axis`` slot is moved
    to the front of ``args`` for validation, and None is returned as the axis.
    """
    if not isinstance(axis, ndarray):
        validate_clip(args, kwargs)
        return axis

    # The ndarray is really numpy's ``out`` argument; validate it as an extra.
    validate_clip((axis,) + args, kwargs)
    return None
210
211
# validate_cum_func has no fixed fname; validate_cum_func_with_skipna passes
# it on each call.
CUM_FUNC_DEFAULTS: dict[str, Any] = {"dtype": None, "out": None}
validate_cum_func = CompatValidator(
    CUM_FUNC_DEFAULTS, max_fname_arg_count=1, method="both"
)
validate_cumsum = CompatValidator(
    CUM_FUNC_DEFAULTS, max_fname_arg_count=1, method="both", fname="cumsum"
)
221
222
def validate_cum_func_with_skipna(skipna: bool, args, kwargs, name) -> bool:
    """
    Resolve ``skipna`` for cumulative functions and validate numpy extras.

    When called via numpy, the third parameter of numpy's signature is
    'dtype' (a numpy dtype or None), which can arrive in the ``skipna`` slot.
    Any value that is not a boolean (per ``is_bool``) is moved to the front
    of ``args``, and ``skipna`` defaults to True. A ``np.bool_`` is converted
    to a Python bool.

    ``name`` is passed as ``fname`` to ``validate_cum_func``.
    """
    if is_bool(skipna):
        if isinstance(skipna, np.bool_):
            skipna = bool(skipna)
    else:
        args = (skipna,) + args
        skipna = True

    validate_cum_func(args, kwargs, fname=name)
    return skipna
237
238
# Extras that numpy's all/any may forward; only these defaults are accepted.
ALLANY_DEFAULTS: dict[str, bool | None] = {
    "dtype": None,
    "out": None,
    "keepdims": False,
    "axis": None,
}
validate_all = CompatValidator(
    ALLANY_DEFAULTS, max_fname_arg_count=1, method="both", fname="all"
)
validate_any = CompatValidator(
    ALLANY_DEFAULTS, max_fname_arg_count=1, method="both", fname="any"
)
250
# Keyword-only validation; fname is not fixed here.
LOGICAL_FUNC_DEFAULTS = {
    "out": None,
    "keepdims": False,
}
validate_logical_func = CompatValidator(method="kwargs", defaults=LOGICAL_FUNC_DEFAULTS)
253
# Extras that numpy's min/max may forward; only these defaults are accepted.
MINMAX_DEFAULTS = {
    "axis": None,
    "dtype": None,
    "out": None,
    "keepdims": False,
}
validate_min = CompatValidator(
    MINMAX_DEFAULTS, max_fname_arg_count=1, method="both", fname="min"
)
validate_max = CompatValidator(
    MINMAX_DEFAULTS, max_fname_arg_count=1, method="both", fname="max"
)
261
# Only numpy's default memory order "C" is accepted for reshape.
RESHAPE_DEFAULTS: dict[str, str] = dict(order="C")
validate_reshape = CompatValidator(
    RESHAPE_DEFAULTS, max_fname_arg_count=1, method="both", fname="reshape"
)
266
# numpy's repeat may forward ``axis``; only its default (None) is accepted.
REPEAT_DEFAULTS: dict[str, Any] = dict(axis=None)
validate_repeat = CompatValidator(
    REPEAT_DEFAULTS, max_fname_arg_count=1, method="both", fname="repeat"
)
271
# numpy's round may forward ``out``; only its default (None) is accepted.
ROUND_DEFAULTS: dict[str, Any] = dict(out=None)
validate_round = CompatValidator(
    ROUND_DEFAULTS, max_fname_arg_count=1, method="both", fname="round"
)
276
# Keyword extras for sort; unlike ARGSORT_DEFAULTS, ``kind`` defaults to
# "quicksort" here.
SORT_DEFAULTS: dict[str, int | str | None] = {
    "axis": -1,
    "kind": "quicksort",
    "order": None,
}
validate_sort = CompatValidator(SORT_DEFAULTS, method="kwargs", fname="sort")
282
# Base extras shared by the reduction validators. Each derived mapping below
# copies the base *before* "keepdims" is appended to it. The resulting key
# order therefore differs between the mappings and is kept exactly as-is.
STAT_FUNC_DEFAULTS: dict[str, Any | None] = {"dtype": None, "out": None}

SUM_DEFAULTS = {
    **STAT_FUNC_DEFAULTS,
    "axis": None,
    "keepdims": False,
    "initial": None,
}

# prod and mean accept the same extras as sum (independent copies).
PROD_DEFAULTS = dict(SUM_DEFAULTS)

MEAN_DEFAULTS = dict(SUM_DEFAULTS)

MEDIAN_DEFAULTS = {
    **STAT_FUNC_DEFAULTS,
    "overwrite_input": False,
    "keepdims": False,
}

STAT_FUNC_DEFAULTS["keepdims"] = False
301
# The generic reduction validator checks keyword extras only and has no fixed
# fname. The named reductions below check positional and keyword extras.
validate_stat_func = CompatValidator(method="kwargs", defaults=STAT_FUNC_DEFAULTS)
validate_sum = CompatValidator(
    SUM_DEFAULTS, max_fname_arg_count=1, method="both", fname="sum"
)
validate_prod = CompatValidator(
    PROD_DEFAULTS, max_fname_arg_count=1, method="both", fname="prod"
)
validate_mean = CompatValidator(
    MEAN_DEFAULTS, max_fname_arg_count=1, method="both", fname="mean"
)
validate_median = CompatValidator(
    MEDIAN_DEFAULTS, max_fname_arg_count=1, method="both", fname="median"
)
315
# Keyword extras for ddof-taking reductions; fname is not fixed here.
STAT_DDOF_FUNC_DEFAULTS: dict[str, bool | None] = {
    "dtype": None,
    "out": None,
    "keepdims": False,
}
validate_stat_ddof_func = CompatValidator(method="kwargs", defaults=STAT_DDOF_FUNC_DEFAULTS)
321
# Extras numpy's take may forward. validate_take_with_convert overrides the
# method and max_fname_arg_count per call.
TAKE_DEFAULTS: dict[str, str | None] = {"out": None, "mode": "raise"}
validate_take = CompatValidator(TAKE_DEFAULTS, method="kwargs", fname="take")
326
327
def validate_take_with_convert(convert: ndarray | bool | None, args, kwargs) -> bool:
    """
    Resolve ``convert`` for 'take' and validate numpy-compat extras.

    When called via numpy, a positional argument from numpy's signature (the
    third parameter) can arrive in the ``convert`` slot. A real ``convert`` is
    not an ndarray or None. So such a value is moved to the front of ``args``
    for validation, and ``convert`` defaults to True.
    """
    if convert is not None and not isinstance(convert, ndarray):
        validate_take(args, kwargs, max_fname_arg_count=3, method="both")
        return convert

    # A numpy positional argument landed here; validate it as an extra.
    validate_take((convert,) + args, kwargs, max_fname_arg_count=3, method="both")
    return True
340
341
# numpy's transpose may forward ``axes``; only its default (None) is accepted.
TRANSPOSE_DEFAULTS = dict(axes=None)
validate_transpose = CompatValidator(
    TRANSPOSE_DEFAULTS, max_fname_arg_count=0, method="both", fname="transpose"
)
346
347
def validate_groupby_func(name: str, args, kwargs, allowed=None) -> None:
    """
    Reject numpy-style extras passed to a groupby method.

    ``args`` must be empty, and ``kwargs`` may only contain names listed in
    ``allowed``. Every real parameter is spelled out in the calling method's
    signature, so anything else must have come from a numpy call.

    Raises
    ------
    UnsupportedFunctionCall
        If any positional argument or any non-allowed keyword is present.
    """
    permitted = set() if allowed is None else set(allowed)
    unexpected = set(kwargs).difference(permitted)

    if len(args) > 0 or len(unexpected) > 0:
        raise UnsupportedFunctionCall(
            "numpy operations are not valid with groupby. "
            f"Use .groupby(...).{name}() instead"
        )
364
365
# Resampler reductions that share a name with a numpy function. Passing extras
# to one of these produces a message pointing at the resample API.
RESAMPLER_NUMPY_OPS = ("min", "max", "sum", "prod", "mean", "std", "var")


def validate_resampler_func(method: str, args, kwargs) -> None:
    """
    Reject numpy-style extras passed to a resampler method.

    ``args`` and ``kwargs`` must be empty, because every real parameter is
    spelled out in the calling method's signature.

    Raises
    ------
    UnsupportedFunctionCall
        If extras are present and ``method`` is in ``RESAMPLER_NUMPY_OPS``.
    TypeError
        If extras are present for any other ``method``.
    """
    if len(args) + len(kwargs) == 0:
        return
    if method not in RESAMPLER_NUMPY_OPS:
        raise TypeError("too many arguments passed in")
    raise UnsupportedFunctionCall(
        "numpy operations are not valid with resample. "
        f"Use .resample(...).{method}() instead"
    )
381
382
def validate_minmax_axis(axis: AxisInt | None, ndim: int = 1) -> None:
    """
    Check that ``axis`` is a valid axis for an object with ``ndim`` dimensions.

    Used for min, max, argmin and argmax, where an out-of-range axis would
    otherwise be silently ignored. None is always accepted. Otherwise ``axis``
    must satisfy ``-ndim <= axis < ndim``, so with the default ``ndim=1`` only
    0, -1 and None pass.

    Parameters
    ----------
    axis : int or None
    ndim : int, default 1

    Raises
    ------
    ValueError
        If ``axis`` is outside ``[-ndim, ndim)``.
    """
    if axis is not None and not (-ndim <= axis < ndim):
        raise ValueError(f"`axis` must be fewer than the number of dimensions ({ndim})")
401
402
# Reductions with a dedicated numpy-compat validator. Any other name falls
# back to the generic keyword-only validate_stat_func.
_validation_funcs = dict(
    median=validate_median,
    mean=validate_mean,
    min=validate_min,
    max=validate_max,
    sum=validate_sum,
    prod=validate_prod,
)


def validate_func(fname, args, kwargs) -> None:
    """
    Validate numpy-compat extras for the reduction named ``fname``.

    Dispatches to the validator registered in ``_validation_funcs``. Names
    without one go to ``validate_stat_func``, with ``fname`` passed along.
    """
    specific = _validation_funcs.get(fname)
    if specific is None:
        return validate_stat_func(args, kwargs, fname=fname)
    return specific(args, kwargs)