1from __future__ import annotations
2
3from typing import TYPE_CHECKING
4
5import numpy as np
6
7from pandas.compat._optional import import_optional_dependency
8
9
10def generate_online_numba_ewma_func(
11 nopython: bool,
12 nogil: bool,
13 parallel: bool,
14):
15 """
16 Generate a numba jitted groupby ewma function specified by values
17 from engine_kwargs.
18
19 Parameters
20 ----------
21 nopython : bool
22 nopython to be passed into numba.jit
23 nogil : bool
24 nogil to be passed into numba.jit
25 parallel : bool
26 parallel to be passed into numba.jit
27
28 Returns
29 -------
30 Numba function
31 """
32 if TYPE_CHECKING:
33 import numba
34 else:
35 numba = import_optional_dependency("numba")
36
37 @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
38 def online_ewma(
39 values: np.ndarray,
40 deltas: np.ndarray,
41 minimum_periods: int,
42 old_wt_factor: float,
43 new_wt: float,
44 old_wt: np.ndarray,
45 adjust: bool,
46 ignore_na: bool,
47 ):
48 """
49 Compute online exponentially weighted mean per column over 2D values.
50
51 Takes the first observation as is, then computes the subsequent
52 exponentially weighted mean accounting minimum periods.
53 """
54 result = np.empty(values.shape)
55 weighted_avg = values[0]
56 nobs = (~np.isnan(weighted_avg)).astype(np.int64)
57 result[0] = np.where(nobs >= minimum_periods, weighted_avg, np.nan)
58
59 for i in range(1, len(values)):
60 cur = values[i]
61 is_observations = ~np.isnan(cur)
62 nobs += is_observations.astype(np.int64)
63 for j in numba.prange(len(cur)):
64 if not np.isnan(weighted_avg[j]):
65 if is_observations[j] or not ignore_na:
66 # note that len(deltas) = len(vals) - 1 and deltas[i] is to be
67 # used in conjunction with vals[i+1]
68 old_wt[j] *= old_wt_factor ** deltas[j - 1]
69 if is_observations[j]:
70 # avoid numerical errors on constant series
71 if weighted_avg[j] != cur[j]:
72 weighted_avg[j] = (
73 (old_wt[j] * weighted_avg[j]) + (new_wt * cur[j])
74 ) / (old_wt[j] + new_wt)
75 if adjust:
76 old_wt[j] += new_wt
77 else:
78 old_wt[j] = 1.0
79 elif is_observations[j]:
80 weighted_avg[j] = cur[j]
81
82 result[i] = np.where(nobs >= minimum_periods, weighted_avg, np.nan)
83
84 return result, old_wt
85
86 return online_ewma
87
88
class EWMMeanState:
    """
    State carried between successive online exponentially weighted mean updates.

    Parameters
    ----------
    com : float
        Center of mass; converted to ``alpha = 1 / (1 + com)``.
    adjust : bool
        Forwarded to the ewm function. Also selects ``new_wt``:
        ``1.0`` when true, otherwise ``alpha``.
    ignore_na : bool
        Forwarded to the ewm function.
    axis : int
        Used with ``shape`` to size ``old_wt``.
    shape : tuple of int
        ``old_wt`` gets ``shape[axis - 1]`` entries. Presumably ``shape`` is 2D,
        in which case this is the length of the axis other than ``axis``;
        confirm with callers.
    """

    def __init__(self, com, adjust, ignore_na, axis, shape) -> None:
        smoothing = 1.0 / (1.0 + com)

        self.axis = axis
        self.shape = shape
        self.adjust = adjust
        self.ignore_na = ignore_na

        # Weight given to each incoming observation.
        if adjust:
            self.new_wt = 1.0
        else:
            self.new_wt = smoothing
        # Per-unit-delta decay applied to the accumulated weight.
        self.old_wt_factor = 1.0 - smoothing

        self.old_wt = self._unit_weights()
        self.last_ewm = None

    def _unit_weights(self):
        """Return a fresh all-ones weight vector sized from ``shape`` and ``axis``."""
        return np.ones(self.shape[self.axis - 1])

    def run_ewm(self, weighted_avg, deltas, min_periods, ewm_func):
        """
        Run ``ewm_func`` on a new chunk of data and record the resulting state.

        ``ewm_func`` is called positionally as
        ``(weighted_avg, deltas, min_periods, old_wt_factor, new_wt, old_wt,
        adjust, ignore_na)`` and must return ``(result, old_wt)``. The returned
        weights replace ``self.old_wt``, and the last row of ``result`` is kept
        as ``self.last_ewm``.

        Returns
        -------
        The ``result`` produced by ``ewm_func``.
        """
        call_args = (
            weighted_avg,
            deltas,
            min_periods,
            self.old_wt_factor,
            self.new_wt,
            self.old_wt,
            self.adjust,
            self.ignore_na,
        )
        output, self.old_wt = ewm_func(*call_args)
        self.last_ewm = output[-1]
        return output

    def reset(self) -> None:
        """Discard accumulated state: unit weights again and no remembered row."""
        self.old_wt = self._unit_weights()
        self.last_ewm = None