1"""
2masked_reductions.py is for reduction algorithms using a mask-based approach
3for missing values.
4"""
5from __future__ import annotations
6
7from typing import (
8 TYPE_CHECKING,
9 Callable,
10)
11import warnings
12
13import numpy as np
14
15from pandas._libs import missing as libmissing
16
17from pandas.core.nanops import check_below_min_count
18
19if TYPE_CHECKING:
20 from pandas._typing import (
21 AxisInt,
22 npt,
23 )
24
25
def _reductions(
    func: Callable,
    values: np.ndarray,
    mask: npt.NDArray[np.bool_],
    *,
    skipna: bool = True,
    min_count: int = 0,
    axis: AxisInt | None = None,
    **kwargs,
):
    """
    Apply a NumPy reduction (sum, prod, mean, var, std, ...) to a masked array.

    Parameters
    ----------
    func : callable
        NumPy reduction such as np.sum, np.prod, np.mean, np.var or np.std.
        It must accept the ``axis`` and ``where`` keywords.
    values : np.ndarray
        Numpy array with the values (can be of any dtype that support the
        operation).
    mask : np.ndarray[bool]
        Boolean numpy array (True values indicate missing values).
    skipna : bool, default True
        Whether to skip NA.
    min_count : int, default 0
        The required number of valid values to perform the operation. If fewer than
        ``min_count`` non-NA values are present the result will be NA.
    axis : int, optional, default None
        Axis forwarded to ``func``.
    **kwargs
        Extra keyword arguments forwarded unchanged to ``func`` (e.g. ``ddof``).

    Returns
    -------
    object
        ``libmissing.NA`` when the NA / ``min_count`` rules below apply,
        otherwise whatever ``func`` returns.
    """
    if not skipna:
        # Without skipping, a single missing value makes the whole result NA.
        if mask.any() or check_below_min_count(values.shape, None, min_count):
            return libmissing.NA
        return func(values, axis=axis, **kwargs)

    # The reduction collapses to a single scalar only in these two cases; the
    # scalar NA short-circuit and the object-dtype path below rely on that.
    scalar_result = axis is None or values.ndim == 1

    if check_below_min_count(values.shape, mask, min_count) and scalar_result:
        return libmissing.NA

    if scalar_result and values.dtype == np.dtype(object):
        # numpy rejects ``where=`` for object-dtype reductions unless an
        # ``initial`` is supplied, so reduce the valid subset directly.
        # Boolean indexing flattens the array, which is only safe when the
        # result is a scalar anyway.
        return func(values[~mask], axis=axis, **kwargs)

    return func(values, where=~mask, axis=axis, **kwargs)
66
67
def sum(
    values: np.ndarray,
    mask: npt.NDArray[np.bool_],
    *,
    skipna: bool = True,
    min_count: int = 0,
    axis: AxisInt | None = None,
):
    """
    Sum of a masked array, delegating to ``_reductions`` with ``np.sum``.

    Parameters
    ----------
    values : np.ndarray
        Numpy array with the values.
    mask : np.ndarray[bool]
        Boolean numpy array (True values indicate missing values).
    skipna : bool, default True
        Whether to skip NA.
    min_count : int, default 0
        The required number of valid values to perform the operation.
    axis : int, optional, default None
    """
    return _reductions(
        np.sum,
        values,
        mask,
        skipna=skipna,
        min_count=min_count,
        axis=axis,
    )
79
80
def prod(
    values: np.ndarray,
    mask: npt.NDArray[np.bool_],
    *,
    skipna: bool = True,
    min_count: int = 0,
    axis: AxisInt | None = None,
):
    """
    Product of a masked array, delegating to ``_reductions`` with ``np.prod``.

    Parameters
    ----------
    values : np.ndarray
        Numpy array with the values.
    mask : np.ndarray[bool]
        Boolean numpy array (True values indicate missing values).
    skipna : bool, default True
        Whether to skip NA.
    min_count : int, default 0
        The required number of valid values to perform the operation.
    axis : int, optional, default None
    """
    return _reductions(
        np.prod,
        values,
        mask,
        skipna=skipna,
        min_count=min_count,
        axis=axis,
    )
92
93
def _minmax(
    func: Callable,
    values: np.ndarray,
    mask: npt.NDArray[np.bool_],
    *,
    skipna: bool = True,
    axis: AxisInt | None = None,
):
    """
    Minimum / maximum reduction for a 1D masked array.

    Parameters
    ----------
    func : np.min or np.max
    values : np.ndarray
        Numpy array with the values (can be of any dtype that support the
        operation).
    mask : np.ndarray[bool]
        Boolean numpy array (True values indicate missing values).
    skipna : bool, default True
        Whether to skip NA.
    axis : int, optional, default None

    Returns
    -------
    object
        ``libmissing.NA`` if a missing value is present and ``skipna`` is
        False, or if there is nothing left to reduce; otherwise the result
        of ``func``.
    """
    if skipna:
        # Drop the missing entries up front; boolean indexing yields a copy
        # containing only the valid values.
        candidates = values[~mask]
    else:
        if mask.any():
            return libmissing.NA
        candidates = values

    if not candidates.size:
        # min/max with empty array raise in numpy, pandas returns NA
        return libmissing.NA

    return func(candidates, axis=axis)
130
131
def min(
    values: np.ndarray,
    mask: npt.NDArray[np.bool_],
    *,
    skipna: bool = True,
    axis: AxisInt | None = None,
):
    """
    Minimum of a masked array, delegating to ``_minmax`` with ``np.min``.

    Parameters
    ----------
    values : np.ndarray
        Numpy array with the values.
    mask : np.ndarray[bool]
        Boolean numpy array (True values indicate missing values).
    skipna : bool, default True
        Whether to skip NA.
    axis : int, optional, default None
    """
    return _minmax(
        np.min,
        values,
        mask,
        skipna=skipna,
        axis=axis,
    )
140
141
def max(
    values: np.ndarray,
    mask: npt.NDArray[np.bool_],
    *,
    skipna: bool = True,
    axis: AxisInt | None = None,
):
    """
    Maximum of a masked array, delegating to ``_minmax`` with ``np.max``.

    Parameters
    ----------
    values : np.ndarray
        Numpy array with the values.
    mask : np.ndarray[bool]
        Boolean numpy array (True values indicate missing values).
    skipna : bool, default True
        Whether to skip NA.
    axis : int, optional, default None
    """
    return _minmax(
        np.max,
        values,
        mask,
        skipna=skipna,
        axis=axis,
    )
150
151
def mean(
    values: np.ndarray,
    mask: npt.NDArray[np.bool_],
    *,
    skipna: bool = True,
    axis: AxisInt | None = None,
):
    """
    Mean of a masked array, delegating to ``_reductions`` with ``np.mean``.

    Returns NA straight away when ``values`` is empty or every entry is
    masked, without calling ``np.mean``.

    Parameters
    ----------
    values : np.ndarray
        Numpy array with the values.
    mask : np.ndarray[bool]
        Boolean numpy array (True values indicate missing values).
    skipna : bool, default True
        Whether to skip NA.
    axis : int, optional, default None
    """
    if values.size and not mask.all():
        return _reductions(
            np.mean,
            values,
            mask,
            skipna=skipna,
            axis=axis,
        )
    return libmissing.NA
162
163
def var(
    values: np.ndarray,
    mask: npt.NDArray[np.bool_],
    *,
    skipna: bool = True,
    axis: AxisInt | None = None,
    ddof: int = 1,
):
    """
    Variance of a masked array, delegating to ``_reductions`` with ``np.var``.

    Returns NA straight away when ``values`` is empty or every entry is
    masked.

    Parameters
    ----------
    values : np.ndarray
        Numpy array with the values.
    mask : np.ndarray[bool]
        Boolean numpy array (True values indicate missing values).
    skipna : bool, default True
        Whether to skip NA.
    axis : int, optional, default None
    ddof : int, default 1
        Forwarded to ``np.var`` as its ``ddof`` argument.
    """
    if values.size and not mask.all():
        # RuntimeWarnings raised during the reduction are silenced here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            result = _reductions(
                np.var,
                values,
                mask,
                skipna=skipna,
                axis=axis,
                ddof=ddof,
            )
        return result
    return libmissing.NA
180
181
def std(
    values: np.ndarray,
    mask: npt.NDArray[np.bool_],
    *,
    skipna: bool = True,
    axis: AxisInt | None = None,
    ddof: int = 1,
):
    """
    Standard deviation of a masked array, delegating to ``_reductions`` with
    ``np.std``.

    Returns NA straight away when ``values`` is empty or every entry is
    masked.

    Parameters
    ----------
    values : np.ndarray
        Numpy array with the values.
    mask : np.ndarray[bool]
        Boolean numpy array (True values indicate missing values).
    skipna : bool, default True
        Whether to skip NA.
    axis : int, optional, default None
    ddof : int, default 1
        Forwarded to ``np.std`` as its ``ddof`` argument.
    """
    if values.size and not mask.all():
        # RuntimeWarnings raised during the reduction are silenced here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            result = _reductions(
                np.std,
                values,
                mask,
                skipna=skipna,
                axis=axis,
                ddof=ddof,
            )
        return result
    return libmissing.NA