1from __future__ import annotations
2
3import functools
4from typing import (
5 TYPE_CHECKING,
6 Any,
7 Callable,
8)
9
10import numpy as np
11
12from pandas._typing import Scalar
13from pandas.compat._optional import import_optional_dependency
14
15from pandas.core.util.numba_ import jit_user_function
16
17
@functools.lru_cache(maxsize=None)
def generate_numba_apply_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba-compiled rolling ``apply`` kernel around a user function.

    The user function is compiled first through ``jit_user_function``; the
    returned kernel then calls it once per window. The ``nopython``, ``nogil``
    and ``parallel`` flags are forwarded to both compilations, so they apply
    to the user's function _AND_ to the surrounding window loop.

    Results are memoized with ``functools.lru_cache``, so calling this
    generator again with the same arguments returns the same kernel.

    Parameters
    ----------
    func : function
        function to be applied to each window and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        ``roll_apply(values, begin, end, minimum_periods, *args)``, which
        returns one value per ``(begin[i], end[i])`` pair. Windows with fewer
        than ``minimum_periods`` non-NaN entries yield NaN.
    """
    numba_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_apply(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ) -> np.ndarray:
        n_windows = len(begin)
        out = np.empty(n_windows)
        for idx in numba.prange(n_windows):
            window = values[begin[idx] : end[idx]]
            # Only non-NaN entries count towards the minimum_periods threshold.
            n_valid = len(window) - np.sum(np.isnan(window))
            if n_valid >= minimum_periods:
                out[idx] = numba_func(window, *args)
            else:
                out[idx] = np.nan
        return out

    return roll_apply
76
77
@functools.lru_cache(maxsize=None)
def generate_numba_ewm_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Build a numba-compiled exponentially weighted mean or sum kernel.

    The kernel walks every ``values[begin[i]:end[i]]`` slice independently
    and keeps a running weighted value that is decayed at each step and
    updated with each non-NaN observation. With ``normalize=True`` the running
    value is divided by the total weight (a mean); with ``normalize=False``
    it is a decayed running sum.

    Results are memoized with ``functools.lru_cache``, so every argument
    (including ``deltas``) must be hashable.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Determines the smoothing factor: ``alpha = 1 / (1 + com)``.
    adjust : bool
        If True each new observation enters with weight 1.0 and the old
        weight keeps accumulating; otherwise it enters with weight ``alpha``
        and the old weight is reset to 1.0 after each observation.
    ignore_na : bool
        If True, a NaN entry does not decay the running value.
    deltas : tuple
        Exponents applied to the decay factor when ``normalize`` is True;
        ``deltas[k]`` is used together with ``values[k + 1]``.
    normalize : bool
        True computes the weighted mean, False the weighted sum.

    Returns
    -------
    Numba function
        ``ewm(values, begin, end, minimum_periods)`` returning an array of
        ``len(values)`` entries. Positions where fewer than
        ``minimum_periods`` observations have been seen within the slice
        are NaN.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        alpha = 1.0 / (1.0 + com)
        decay = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha
        result = np.empty(len(values))

        for grp in numba.prange(len(begin)):
            lo = begin[grp]
            hi = end[grp]
            window = values[lo:hi]
            out = np.empty(len(window))

            # Seed the recursion with the first element of the slice.
            weighted = window[0]
            nobs = int(not np.isnan(weighted))
            out[0] = weighted if nobs >= minimum_periods else np.nan
            old_wt = 1.0

            for pos in range(1, len(window)):
                cur = window[pos]
                is_observation = not np.isnan(cur)
                nobs += is_observation
                if np.isnan(weighted):
                    # Nothing accumulated so far: start from the first
                    # observation that shows up.
                    if is_observation:
                        weighted = cur
                elif is_observation or not ignore_na:
                    if normalize:
                        # note that len(deltas) = len(vals) - 1 and deltas[i]
                        # is to be used in conjunction with vals[i+1]
                        old_wt *= decay ** deltas[lo + pos - 1]
                    else:
                        weighted = decay * weighted
                    if is_observation:
                        if normalize:
                            # avoid numerical errors on constant series
                            if weighted != cur:
                                weighted = old_wt * weighted + new_wt * cur
                                weighted = weighted / (old_wt + new_wt)
                            if adjust:
                                old_wt += new_wt
                            else:
                                old_wt = 1.0
                        else:
                            weighted += cur

                out[pos] = weighted if nobs >= minimum_periods else np.nan

            result[lo:hi] = out

        return result

    return ewm
174
175
@functools.lru_cache(maxsize=None)
def generate_numba_table_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba-compiled kernel that applies a window function table-wise.

    ``func`` receives an M (window size) x N (number of columns) array and
    must return a 1 x N array. Func is intended to operate row-wise, but the
    result will be transposed for axis=1.

    The user function is compiled through ``jit_user_function`` and then
    called from a jitted loop over the window bounds. Results are memoized
    with ``functools.lru_cache``.

    Parameters
    ----------
    func : function
        function to be applied to each window and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        ``roll_table(values, begin, end, minimum_periods, *args)`` returning
        an array of shape ``(len(begin), values.shape[1])``. Entries whose
        column has fewer than ``minimum_periods`` non-NaN values inside the
        window are NaN.
    """
    numba_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ):
        n_windows = len(begin)
        out = np.empty((n_windows, values.shape[1]))
        # Float array used as a truthy/falsy mask: non-zero means "keep".
        keep = np.empty(out.shape)
        for idx in numba.prange(n_windows):
            window = values[begin[idx] : end[idx]]
            # Per-column count of non-NaN entries, taken before func runs.
            n_valid = len(window) - np.sum(np.isnan(window), axis=0)
            row = numba_func(window, *args)
            keep[idx, :] = n_valid >= minimum_periods
            out[idx, :] = row
        return np.where(keep, out, np.nan)

    return roll_table
237
238
239# This function will no longer be needed once numba supports
240# axis for all np.nan* agg functions
241# https://github.com/numba/numba/issues/1269
@functools.lru_cache(maxsize=None)
def generate_manual_numpy_nan_agg_with_axis(nan_func):
    """
    Wrap a 1-D aggregation so that it reduces a 2-D table column by column.

    Parameters
    ----------
    nan_func : function
        Aggregation called on each column ``table[:, i]``; it must be usable
        from numba nopython code and return a scalar.

    Returns
    -------
    Numba function
        ``nan_agg_with_axis(table)`` returning an array with one entry per
        column of ``table``. It is always compiled with ``nopython=True``,
        ``nogil=True`` and ``parallel=True``.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=True, nogil=True, parallel=True)
    def nan_agg_with_axis(table):
        n_cols = table.shape[1]
        out = np.empty(n_cols)
        for col in numba.prange(n_cols):
            out[col] = nan_func(table[:, col])
        return out

    return nan_agg_with_axis
258
259
@functools.lru_cache(maxsize=None)
def generate_numba_ewm_table_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Build a numba-compiled table-wise exponentially weighted mean/sum kernel.

    The kernel advances through the rows of a 2-D array one at a time and
    keeps, for every column, a running weighted value together with its
    accumulated weight. With ``normalize=True`` the running value is divided
    by the total weight (a mean); with ``normalize=False`` it is a decayed
    running sum.

    Results are memoized with ``functools.lru_cache``, so every argument
    (including ``deltas``) must be hashable.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Determines the smoothing factor: ``alpha = 1 / (1 + com)``.
    adjust : bool
        If True each new observation enters with weight 1.0 and the old
        weight keeps accumulating; otherwise it enters with weight ``alpha``
        and the old weight is reset to 1.0 after each observation.
    ignore_na : bool
        If True, a NaN entry does not decay the running value of its column.
    deltas : tuple
        Exponents applied to the decay factor when ``normalize`` is True;
        ``deltas[k]`` is used together with row ``k + 1``.
    normalize : bool
        True computes the weighted mean, False the weighted sum.

    Returns
    -------
    Numba function
        ``ewm_table(values, begin, end, minimum_periods)`` returning an array
        with the same shape as ``values``. ``begin`` and ``end`` are accepted
        but not read by the kernel. Entries where fewer than
        ``minimum_periods`` observations have been seen in the column so far
        are NaN.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        alpha = 1.0 / (1.0 + com)
        decay = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha
        old_wt = np.ones(values.shape[1])

        result = np.empty(values.shape)
        # Seed every column's recursion with the first row.
        weighted = values[0].copy()
        nobs = (~np.isnan(weighted)).astype(np.int64)
        result[0] = np.where(nobs >= minimum_periods, weighted, np.nan)
        for row in range(1, len(values)):
            cur = values[row]
            observed = ~np.isnan(cur)
            nobs += observed.astype(np.int64)
            for col in numba.prange(len(cur)):
                if np.isnan(weighted[col]):
                    # Nothing accumulated so far for this column: start from
                    # the first observation that shows up.
                    if observed[col]:
                        weighted[col] = cur[col]
                elif observed[col] or not ignore_na:
                    if normalize:
                        # note that len(deltas) = len(vals) - 1 and deltas[i]
                        # is to be used in conjunction with vals[i+1]
                        old_wt[col] *= decay ** deltas[row - 1]
                    else:
                        weighted[col] = decay * weighted[col]
                    if observed[col]:
                        if normalize:
                            # avoid numerical errors on constant series
                            if weighted[col] != cur[col]:
                                weighted[col] = (
                                    old_wt[col] * weighted[col] + new_wt * cur[col]
                                )
                                weighted[col] = weighted[col] / (
                                    old_wt[col] + new_wt
                                )
                            if adjust:
                                old_wt[col] += new_wt
                            else:
                                old_wt[col] = 1.0
                        else:
                            weighted[col] += cur[col]

            result[row] = np.where(nobs >= minimum_periods, weighted, np.nan)

        return result

    return ewm_table