Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/pandas/core/window/numba_.py: 12%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

147 statements  

1from __future__ import annotations 

2 

3import functools 

4from typing import ( 

5 TYPE_CHECKING, 

6 Any, 

7 Callable, 

8) 

9 

10import numpy as np 

11 

12from pandas.compat._optional import import_optional_dependency 

13 

14from pandas.core.util.numba_ import jit_user_function 

15 

16if TYPE_CHECKING: 

17 from pandas._typing import Scalar 

18 

19 

@functools.cache
def generate_numba_apply_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba-jitted rolling ``apply`` kernel around a user function.

    1. The user's function is jitted through ``jit_user_function``.
    2. A jitted loop over the windows is returned which calls that jitted
       function once per window.

    ``nopython``, ``nogil`` and ``parallel`` are forwarded to ``numba.jit``
    for the rolling loop. The generated kernel is memoized on the argument
    tuple by ``functools.cache``.

    Parameters
    ----------
    func : function
        function to be applied to each window and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        ``roll_apply(values, begin, end, minimum_periods, *args)`` returning
        one value per ``(begin[i], end[i])`` window.
    """
    numba_func = jit_user_function(func)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_apply(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ) -> np.ndarray:
        n_windows = len(begin)
        out = np.empty(n_windows)
        for idx in numba.prange(n_windows):
            chunk = values[begin[idx] : end[idx]]
            # Only non-NaN entries count towards minimum_periods.
            n_valid = len(chunk) - np.sum(np.isnan(chunk))
            if n_valid < minimum_periods:
                out[idx] = np.nan
            else:
                out[idx] = numba_func(chunk, *args)
        return out

    return roll_apply

78 

79 

@functools.cache
def generate_numba_ewm_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function specified by values
    from engine_kwargs.

    The generated kernel is memoized on the argument tuple by
    ``functools.cache`` (so every argument, including ``deltas``, must be
    hashable).

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Center of mass; the smoothing factor is ``alpha = 1 / (1 + com)``.
    adjust : bool
        If True, a new observation gets weight 1 and the accumulated weight
        keeps growing; if False, a new observation gets weight ``alpha`` and
        the accumulated weight is reset to 1 after each observation.
    ignore_na : bool
        If False, the decay is also applied at positions holding NaN; if
        True, NaN positions leave the running state untouched.
    deltas : tuple
        Exponents applied to the decay factor, indexed by position in
        ``values``; ``deltas[k]`` is used together with ``values[k + 1]``.
        Only read when ``normalize`` is True.
    normalize : bool
        If True, compute the weighted mean; if False, compute the decayed
        running sum.

    Returns
    -------
    Numba function
        ``ewm(values, begin, end, minimum_periods)`` returning an array with
        the same length as ``values``.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        result = np.empty(len(values))
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha

        # Each (begin[i], end[i]) window is smoothed independently and
        # written back to the matching slice of ``result``.
        for i in numba.prange(len(begin)):
            start = begin[i]
            stop = end[i]
            window = values[start:stop]
            sub_result = np.empty(len(window))

            weighted = window[0]
            nobs = int(not np.isnan(weighted))
            sub_result[0] = weighted if nobs >= minimum_periods else np.nan
            old_wt = 1.0

            for j in range(1, len(window)):
                cur = window[j]
                is_observation = not np.isnan(cur)
                nobs += is_observation
                if not np.isnan(weighted):
                    if is_observation or not ignore_na:
                        if normalize:
                            # note that len(deltas) = len(vals) - 1 and deltas[i]
                            # is to be used in conjunction with vals[i+1]
                            old_wt *= old_wt_factor ** deltas[start + j - 1]
                        else:
                            weighted = old_wt_factor * weighted
                        if is_observation:
                            if normalize:
                                # avoid numerical errors on constant series
                                if weighted != cur:
                                    weighted = old_wt * weighted + new_wt * cur
                                    # normalize is already known to be True
                                    # here, so no further check is needed
                                    weighted = weighted / (old_wt + new_wt)
                                if adjust:
                                    old_wt += new_wt
                                else:
                                    old_wt = 1.0
                            else:
                                weighted += cur
                elif is_observation:
                    # First valid observation seen in this window.
                    weighted = cur

                sub_result[j] = weighted if nobs >= minimum_periods else np.nan

            result[start:stop] = sub_result

        return result

    return ewm

176 

177 

@functools.cache
def generate_numba_table_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba-jitted kernel that applies a window function table-wise.

    ``func`` receives an (M window size) x (N columns) array and must return
    a 1 x N array, i.e. one value per column.
    NOTE(review): the original notes say the result is transposed for axis=1;
    that happens outside this function - confirm against the caller.

    1. The user's function is jitted through ``jit_user_function``.
    2. A jitted loop over the windows is returned which calls that jitted
       function once per window.

    The generated kernel is memoized on the argument tuple by
    ``functools.cache``.

    Parameters
    ----------
    func : function
        function to be applied to each window and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        ``roll_table(values, begin, end, minimum_periods, *args)`` returning
        an array of shape ``(len(begin), values.shape[1])``.
    """
    numba_func = jit_user_function(func)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ):
        n_windows = len(begin)
        out = np.empty((n_windows, values.shape[1]))
        # Non-zero where a column has enough non-NaN rows in the window.
        enough_obs = np.empty(out.shape)
        for row in numba.prange(n_windows):
            chunk = values[begin[row] : end[row]]
            nan_per_col = np.sum(np.isnan(chunk), axis=0)
            out[row, :] = numba_func(chunk, *args)
            enough_obs[row, :] = len(chunk) - nan_per_col >= minimum_periods
        # Cells that miss the minimum_periods requirement become NaN.
        return np.where(enough_obs, out, np.nan)

    return roll_table

239 

240 

# This function will no longer be needed once numba supports
# axis for all np.nan* agg functions
# https://github.com/numba/numba/issues/1269
@functools.cache
def generate_manual_numpy_nan_agg_with_axis(nan_func):
    """
    Wrap a 1-D aggregation so that it reduces each column of a 2-D table.

    Parameters
    ----------
    nan_func : function
        Aggregation called on one column (``table[:, i]``) at a time; it must
        be callable from numba-compiled code.

    Returns
    -------
    Numba function
        ``nan_agg_with_axis(table)`` returning one value per column.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=True, nogil=True, parallel=True)
    def nan_agg_with_axis(table):
        n_cols = table.shape[1]
        out = np.empty(n_cols)
        for col in numba.prange(n_cols):
            out[col] = nan_func(table[:, col])
        return out

    return nan_agg_with_axis

260 

261 

@functools.cache
def generate_numba_ewm_table_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function applied table wise specified
    by values from engine_kwargs.

    The generated kernel is memoized on the argument tuple by
    ``functools.cache`` (so every argument, including ``deltas``, must be
    hashable).

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Center of mass; the smoothing factor is ``alpha = 1 / (1 + com)``.
    adjust : bool
        If True, a new observation gets weight 1 and the accumulated weight
        keeps growing; if False, a new observation gets weight ``alpha`` and
        the accumulated weight is reset to 1 after each observation.
    ignore_na : bool
        If False, the decay is also applied at positions holding NaN; if
        True, NaN positions leave the running state untouched.
    deltas : tuple
        Exponents applied to the decay factor, indexed by row; ``deltas[k]``
        is used together with ``values[k + 1]``. Only read when
        ``normalize`` is True.
    normalize : bool
        If True, compute the weighted mean; if False, compute the decayed
        running sum.

    Returns
    -------
    Numba function
        ``ewm_table(values, begin, end, minimum_periods)`` returning an array
        with the same shape as ``values``. ``begin`` and ``end`` are accepted
        but not read by the kernel.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha
        # One accumulated weight per column.
        old_wt = np.ones(values.shape[1])

        result = np.empty(values.shape)
        # Copy so the running state does not write into ``values``.
        weighted = values[0].copy()
        nobs = (~np.isnan(weighted)).astype(np.int64)
        result[0] = np.where(nobs >= minimum_periods, weighted, np.nan)
        # Rows are processed in order; columns are updated independently.
        for i in range(1, len(values)):
            cur = values[i]
            is_observations = ~np.isnan(cur)
            nobs += is_observations.astype(np.int64)
            for j in numba.prange(len(cur)):
                if not np.isnan(weighted[j]):
                    if is_observations[j] or not ignore_na:
                        if normalize:
                            # note that len(deltas) = len(vals) - 1 and deltas[i]
                            # is to be used in conjunction with vals[i+1]
                            old_wt[j] *= old_wt_factor ** deltas[i - 1]
                        else:
                            weighted[j] = old_wt_factor * weighted[j]
                        if is_observations[j]:
                            if normalize:
                                # avoid numerical errors on constant series
                                if weighted[j] != cur[j]:
                                    weighted[j] = (
                                        old_wt[j] * weighted[j] + new_wt * cur[j]
                                    )
                                    # normalize is already known to be True
                                    # here, so no further check is needed
                                    weighted[j] = weighted[j] / (old_wt[j] + new_wt)
                                if adjust:
                                    old_wt[j] += new_wt
                                else:
                                    old_wt[j] = 1.0
                            else:
                                weighted[j] += cur[j]
                elif is_observations[j]:
                    # First valid observation seen in this column.
                    weighted[j] = cur[j]

            result[i] = np.where(nobs >= minimum_periods, weighted, np.nan)

        return result

    return ewm_table