1from __future__ import annotations
2
3import functools
4from typing import (
5 TYPE_CHECKING,
6 Any,
7 Callable,
8)
9
10if TYPE_CHECKING:
11 from pandas._typing import Scalar
12
13import numpy as np
14
15from pandas.compat._optional import import_optional_dependency
16
17
@functools.cache
def generate_apply_looper(func, nopython=True, nogil=True, parallel=False):
    """
    Build a jitted function that applies ``func`` slice-by-slice to a 2D array.

    Parameters
    ----------
    func : function
        1D kernel; it is registered with ``numba.extending.register_jitable``
        so that it can be called from the jitted loop.
    nopython : bool, default True
        nopython to be passed into numba.jit
    nogil : bool, default True
        nogil to be passed into numba.jit
    parallel : bool, default False
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        ``nb_looper(values, axis)``. With ``axis == 0`` the kernel receives
        ``values[:, j]`` for each ``j``; otherwise the first call receives
        ``values[0]``. The returned looper is memoized per argument
        combination through ``functools.cache``.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")
    jitable_func = numba.extending.register_jitable(func)

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def nb_looper(values, axis):
        # The kernel is run once on the first row/col up front, because the
        # shape of its result determines the shape of the output buffer.
        if axis == 0:
            head = values[:, 0]
            n_slices = values.shape[1]
        else:
            head = values[0]
            n_slices = values.shape[0]
        head_result = jitable_func(head)
        # Use np.asarray to get shape for
        # https://github.com/numba/numba/issues/4202#issuecomment-1185981507
        out_shape = (n_slices,) + np.atleast_1d(np.asarray(head_result)).shape
        if axis == 0:
            # Column-wise results are laid out along the last axis.
            out_shape = out_shape[::-1]
        out = np.empty(out_shape)

        # Slot 0 is already computed; the loops below start at index 1.
        if axis == 1:
            out[0] = head_result
            for row in numba.prange(1, values.shape[0]):
                out[row] = jitable_func(values[row])
        else:
            out[:, 0] = head_result
            for col in numba.prange(1, values.shape[1]):
                out[:, col] = jitable_func(values[:, col])
        return out

    return nb_looper
55
56
@functools.cache
def make_looper(func, result_dtype, is_grouped_kernel, nopython, nogil, parallel):
    """
    Build a jitted function that runs a 1D kernel over every row of ``values``.

    Parameters
    ----------
    func : function
        Kernel called once per ``values[i]``. It must return a pair
        ``(output, na_pos)``: ``output`` is stored as row ``i`` of the result
        and ``na_pos`` is recorded for row ``i`` when it is non-empty.
    result_dtype : dtype
        dtype of the allocated result array; also forwarded to ``func``.
    is_grouped_kernel : bool
        If True, the looper takes ``(values, labels, ngroups, min_periods,
        *args)``; otherwise it takes ``(values, start, end, min_periods,
        *args)``.
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        Returns ``(result, na_positions)`` where ``na_positions`` maps a row
        index to an array built from that row's ``na_pos``. The looper is
        memoized per argument combination through ``functools.cache``.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    if is_grouped_kernel:

        @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
        def column_looper(
            values: np.ndarray,
            labels: np.ndarray,
            ngroups: int,
            min_periods: int,
            *args,
        ):
            n_rows = values.shape[0]
            # One output slot per group for every row of ``values``.
            out = np.empty((n_rows, ngroups), dtype=result_dtype)
            # NOTE(review): this dict is written inside a prange loop;
            # confirm that is safe when parallel=True.
            na_by_row = {}
            for row in numba.prange(n_rows):
                row_out, row_na = func(
                    values[row], result_dtype, labels, ngroups, min_periods, *args
                )
                out[row] = row_out
                if len(row_na) > 0:
                    na_by_row[row] = np.array(row_na)
            return out, na_by_row

    else:

        @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
        def column_looper(
            values: np.ndarray,
            start: np.ndarray,
            end: np.ndarray,
            min_periods: int,
            *args,
        ):
            n_rows = values.shape[0]
            # One output slot per entry of ``start`` for every row of ``values``.
            out = np.empty((n_rows, len(start)), dtype=result_dtype)
            # NOTE(review): this dict is written inside a prange loop;
            # confirm that is safe when parallel=True.
            na_by_row = {}
            for row in numba.prange(n_rows):
                row_out, row_na = func(
                    values[row], result_dtype, start, end, min_periods, *args
                )
                out[row] = row_out
                if len(row_na) > 0:
                    na_by_row[row] = np.array(row_na)
            return out, na_by_row

    return column_looper
107
108
# Maps every supported input dtype to the widest dtype of the same kind
# (signed int -> int64, unsigned int -> uint64, float -> float64,
# complex -> complex128).
default_dtype_mapping: dict[np.dtype, Any] = {
    np.dtype(name): widest
    for names, widest in (
        (("int8", "int16", "int32", "int64"), np.int64),
        (("uint8", "uint16", "uint32", "uint64"), np.uint64),
        (("float32", "float64"), np.float64),
        (("complex64", "complex128"), np.complex128),
    )
    for name in names
}
123
124
125# TODO: Preserve complex dtypes
126
# Maps every supported input dtype (complex included) to float64.
float_dtype_mapping: dict[np.dtype, Any] = dict.fromkeys(
    (
        np.dtype(name)
        for name in (
            "int8",
            "int16",
            "int32",
            "int64",
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "float32",
            "float64",
            "complex64",
            "complex128",
        )
    ),
    np.float64,
)
141
# Maps every supported input dtype to the numpy scalar type of the same name,
# i.e. the result keeps the input's dtype.
identity_dtype_mapping: dict[np.dtype, Any] = {
    np.dtype(name): getattr(np, name)
    for name in (
        "int8",
        "int16",
        "int32",
        "int64",
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "float32",
        "float64",
        "complex64",
        "complex128",
    )
}
156
157
def generate_shared_aggregator(
    func: Callable[..., Scalar],
    dtype_mapping: dict[np.dtype, np.dtype],
    is_grouped_kernel: bool,
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Generate a function that applies a 1D numba kernel to each ``values[i]``
    of a 2D array and fills in NaN where the kernel reported missing results.

    Parameters
    ----------
    func : function
        aggregation function to be applied to each column; it is forwarded
        to ``make_looper``
    dtype_mapping : dict
        Maps the dtype of the input values to the result dtype. It is
        required: the lookup raises ``KeyError`` for a dtype that is not a
        key of the mapping.
    is_grouped_kernel : bool
        Whether func operates using the group labels (True)
        or using starts/ends arrays

        If true, you also need to pass the number of groups to the
        returned function
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    function
        ``looper_wrapper(values, start=None, end=None, labels=None,
        ngroups=None, min_periods=0, **kwargs)`` returning the result array.
    """

    # A wrapper around the looper function,
    # to dispatch based on dtype since numba is unable to do that in nopython mode

    # It also post-processes the values by inserting nans where number of observations
    # is less than min_periods
    # Cannot do this in numba nopython mode
    # (you'll run into type-unification error when you cast int -> float)
    def looper_wrapper(
        values,
        start=None,
        end=None,
        labels=None,
        ngroups=None,
        min_periods: int = 0,
        **kwargs,
    ):
        result_dtype = dtype_mapping[values.dtype]
        column_looper = make_looper(
            func, result_dtype, is_grouped_kernel, nopython, nogil, parallel
        )
        # Need to unpack kwargs since numba only supports *args
        # (the values are passed positionally, in the kwargs' insertion order)
        if is_grouped_kernel:
            result, na_positions = column_looper(
                values, labels, ngroups, min_periods, *kwargs.values()
            )
        else:
            result, na_positions = column_looper(
                values, start, end, min_periods, *kwargs.values()
            )
        # Integer results cannot hold nan, so the whole block is converted to
        # float64 as soon as any position has to be set to nan.
        # Unsigned results (kind "u") need this just like signed ones (kind "i").
        # This is OK since int dtype cannot hold nan,
        # so if min_periods not satisfied for 1 col, it is not satisfied for
        # all columns at that index
        if result.dtype.kind in "iu" and any(
            len(na_pos) > 0 for na_pos in na_positions.values()
        ):
            result = result.astype("float64")
        # TODO: Optimize this
        for i, na_pos in na_positions.items():
            if len(na_pos) > 0:
                result[i, na_pos] = np.nan
        return result

    return looper_wrapper