1from __future__ import annotations
2
3import functools
4from typing import (
5 TYPE_CHECKING,
6 Any,
7 Callable,
8)
9
10import numpy as np
11
12from pandas.compat._optional import import_optional_dependency
13
14from pandas.core.util.numba_ import jit_user_function
15
16if TYPE_CHECKING:
17 from pandas._typing import Scalar
18
19
@functools.cache
def generate_numba_apply_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba-jitted rolling ``apply`` kernel around a user function.

    The user function is first prepared with ``jit_user_function`` so it can be
    called from inside the compiled loop. The loop itself is then compiled with
    ``numba.jit`` using the given engine options. Because of
    ``functools.cache``, each unique argument combination is compiled only once.

    Parameters
    ----------
    func : function
        Function applied to each window as ``func(window, *args)``; its result
        is stored as a float in the output.
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        ``roll_apply(values, begin, end, minimum_periods, *args)``. It returns a
        float array with one entry per ``[begin[k], end[k])`` window. Windows
        with fewer than ``minimum_periods`` non-NaN values get NaN and do not
        call ``func``.
    """
    numba_func = jit_user_function(func)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_apply(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ) -> np.ndarray:
        n_windows = len(begin)
        out = np.empty(n_windows)
        for k in numba.prange(n_windows):
            window = values[begin[k] : end[k]]
            n_valid = len(window) - np.sum(np.isnan(window))
            if n_valid < minimum_periods:
                # Too few real observations: skip the user function entirely.
                out[k] = np.nan
            else:
                out[k] = numba_func(window, *args)
        return out

    return roll_apply
78
79
@functools.cache
def generate_numba_ewm_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function specified by values
    from engine_kwargs.

    The returned function processes each ``[begin[i], end[i])`` window
    independently. It computes a running exponentially weighted statistic over
    the window and writes it back into the same positions of the output array.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Center of mass; the smoothing factor is ``alpha = 1 / (1 + com)`` and
        old values decay by ``1 - alpha`` per step.
    adjust : bool
        Only used when ``normalize`` is True.
        If True, a new observation has weight 1 and the old weight accumulates.
        If False, a new observation has weight ``alpha`` and the old weight is
        reset to 1 after each observation.
    ignore_na : bool
        If True, missing values do not decay the previous weights. If False,
        decay is applied at every step after the first observation, including
        steps with missing values.
    deltas : tuple
        Decay exponents, used only when ``normalize`` is True. ``deltas[k]`` is
        applied when stepping onto ``values[k + 1]``. Indices are positions in
        the full ``values`` array, not within a window.
    normalize : bool
        If True, compute a weighted mean (divide by the accumulated weights).
        If False, compute an exponentially decayed running sum.

    Returns
    -------
    Numba function
        ``ewm(values, begin, end, minimum_periods)``. It returns a float array
        with the same length as ``values``. Within each window, positions with
        fewer than ``minimum_periods`` non-NaN observations so far are NaN.
        Positions not covered by any window are left uninitialized.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        result = np.empty(len(values))
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha

        for i in numba.prange(len(begin)):
            start = begin[i]
            stop = end[i]
            window = values[start:stop]
            sub_result = np.empty(len(window))

            # Seed with NaN and zero observations, then let the loop handle the
            # first element like any other: the first non-NaN value becomes the
            # running value. This gives the same result as seeding with
            # window[0], but never indexes an empty window. nopython code does
            # not bounds-check, so window[0] on an empty window would read
            # (and sub_result[0] would write) out of bounds.
            weighted = np.nan
            nobs = 0
            old_wt = 1.0

            for j in range(len(window)):
                cur = window[j]
                is_observation = not np.isnan(cur)
                nobs += is_observation
                if not np.isnan(weighted):
                    if is_observation or not ignore_na:
                        if normalize:
                            # note that len(deltas) = len(vals) - 1 and deltas[k]
                            # is to be used in conjunction with vals[k+1]
                            old_wt *= old_wt_factor ** deltas[start + j - 1]
                        else:
                            weighted = old_wt_factor * weighted
                        if is_observation:
                            if normalize:
                                # avoid numerical errors on constant series
                                if weighted != cur:
                                    weighted = old_wt * weighted + new_wt * cur
                                    weighted = weighted / (old_wt + new_wt)
                                if adjust:
                                    old_wt += new_wt
                                else:
                                    old_wt = 1.0
                            else:
                                weighted += cur
                elif is_observation:
                    # First real observation in this window starts the average.
                    weighted = cur

                sub_result[j] = weighted if nobs >= minimum_periods else np.nan

            result[start:stop] = sub_result

        return result

    return ewm
176
177
@functools.cache
def generate_numba_table_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba-jitted kernel that applies a user function table-wise.

    For every window, the user function receives the full 2D slice of rows
    (M rows in the window x N columns). It must return N values, one per
    column. Func is intended to operate row-wise, but the result will be
    transposed for axis=1.

    The user function is prepared with ``jit_user_function`` and the window
    loop is compiled with ``numba.jit`` using the given engine options.

    Parameters
    ----------
    func : function
        Function applied to each window as ``func(window, *args)``; must return
        one value per column.
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        ``roll_table(values, begin, end, minimum_periods, *args)``. It returns a
        ``(len(begin), n_columns)`` float array. Cells whose window column has
        fewer than ``minimum_periods`` non-NaN values are NaN.
    """
    numba_func = jit_user_function(func)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ):
        n_windows = len(begin)
        n_cols = values.shape[1]
        out = np.empty((n_windows, n_cols))
        # True where a (window, column) cell has enough non-NaN observations.
        enough_obs = np.empty((n_windows, n_cols), dtype=np.bool_)
        for k in numba.prange(n_windows):
            window = values[begin[k] : end[k]]
            n_valid = len(window) - np.sum(np.isnan(window), axis=0)
            enough_obs[k, :] = n_valid >= minimum_periods
            out[k, :] = numba_func(window, *args)
        # Mask after the loop so the user function always sees every window.
        return np.where(enough_obs, out, np.nan)

    return roll_table
239
240
# This function will no longer be needed once numba supports
# axis for all np.nan* agg functions
# https://github.com/numba/numba/issues/1269
@functools.cache
def generate_manual_numpy_nan_agg_with_axis(nan_func):
    """
    Return a numba-jitted function that applies ``nan_func`` column by column.

    The generated function takes a 2D ``table`` and returns a float array of
    length ``table.shape[1]``. Entry ``c`` is ``nan_func(table[:, c])``. It is
    always compiled with ``nopython=True, nogil=True, parallel=True``, and
    columns are processed with ``numba.prange``.

    Parameters
    ----------
    nan_func : function
        A 1D reduction callable from nopython code (for example a NumPy
        ``nan*`` aggregation).

    Returns
    -------
    Numba function
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=True, nogil=True, parallel=True)
    def nan_agg_with_axis(table):
        n_cols = table.shape[1]
        out = np.empty(n_cols)
        for col in numba.prange(n_cols):
            out[col] = nan_func(table[:, col])
        return out

    return nan_agg_with_axis
260
261
@functools.cache
def generate_numba_ewm_table_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function applied table wise specified
    by values from engine_kwargs.

    The returned function walks the rows of a 2D ``values`` array in order. It
    updates one running exponentially weighted statistic per column and
    processes the columns of each row with ``numba.prange``.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Center of mass; the smoothing factor is ``alpha = 1 / (1 + com)`` and
        old values decay by ``1 - alpha`` per step.
    adjust : bool
        Only used when ``normalize`` is True.
        If True, a new observation has weight 1 and the old weight accumulates.
        If False, a new observation has weight ``alpha`` and the old weight is
        reset to 1 after each observation.
    ignore_na : bool
        If True, missing values do not decay the previous weights. If False,
        decay is applied at every row after a column's first observation,
        including rows where that column is missing.
    deltas : tuple
        Decay exponents, used only when ``normalize`` is True. ``deltas[k]`` is
        applied when stepping onto row ``k + 1``.
    normalize: bool
        If True, compute a weighted mean (divide by the accumulated weights).
        If False, compute an exponentially decayed running sum.

    Returns
    -------
    Numba function
        ``ewm_table(values, begin, end, minimum_periods)``. It returns a float
        array shaped like ``values``. Cells whose column has fewer than
        ``minimum_periods`` non-NaN observations so far are NaN. ``begin`` and
        ``end`` are accepted but not used: the whole table is processed as one
        sequence.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha
        n_cols = values.shape[1]
        old_wt = np.ones(n_cols)

        result = np.empty(values.shape)
        # Seed every column with NaN and zero observations, then process row 0
        # in the main loop. The first non-NaN value in a column becomes its
        # running value, which matches seeding with values[0]. Unlike
        # values[0], this never indexes an empty table, and nopython code does
        # not bounds-check such reads.
        weighted = np.full(n_cols, np.nan)
        nobs = np.zeros(n_cols, dtype=np.int64)
        for i in range(len(values)):
            cur = values[i]
            is_observations = ~np.isnan(cur)
            nobs += is_observations.astype(np.int64)
            for j in numba.prange(n_cols):
                if not np.isnan(weighted[j]):
                    if is_observations[j] or not ignore_na:
                        if normalize:
                            # note that len(deltas) = len(vals) - 1 and deltas[k]
                            # is to be used in conjunction with vals[k+1]
                            old_wt[j] *= old_wt_factor ** deltas[i - 1]
                        else:
                            weighted[j] = old_wt_factor * weighted[j]
                        if is_observations[j]:
                            if normalize:
                                # avoid numerical errors on constant series
                                if weighted[j] != cur[j]:
                                    weighted[j] = (
                                        old_wt[j] * weighted[j] + new_wt * cur[j]
                                    )
                                    weighted[j] = weighted[j] / (old_wt[j] + new_wt)
                                if adjust:
                                    old_wt[j] += new_wt
                                else:
                                    old_wt[j] = 1.0
                            else:
                                weighted[j] += cur[j]
                elif is_observations[j]:
                    # First real observation in this column starts the average.
                    weighted[j] = cur[j]

            result[i] = np.where(nobs >= minimum_periods, weighted, np.nan)

        return result

    return ewm_table