Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/pandas/core/window/numba_.py: 12%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

146 statements  

1from __future__ import annotations 

2 

3import functools 

4from typing import ( 

5 TYPE_CHECKING, 

6 Any, 

7 Callable, 

8) 

9 

10import numpy as np 

11 

12from pandas._typing import Scalar 

13from pandas.compat._optional import import_optional_dependency 

14 

15from pandas.core.util.numba_ import jit_user_function 

16 

17 

@functools.lru_cache(maxsize=None)
def generate_numba_apply_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba-compiled rolling ``apply`` kernel around ``func``.

    The user's ``func`` is JIT-compiled first and then called inside a kernel
    that evaluates it over every ``values[begin[k]:end[k]]`` window. The same
    ``nopython`` / ``nogil`` / ``parallel`` settings are used both for the
    user function and for the kernel. Generated kernels are memoized per
    argument tuple by ``functools.lru_cache``.

    Parameters
    ----------
    func : function
        Callable evaluated on each window; it is JIT-compiled.
    nopython : bool
        Forwarded to ``numba.jit`` as ``nopython``.
    nogil : bool
        Forwarded to ``numba.jit`` as ``nogil``.
    parallel : bool
        Forwarded to ``numba.jit`` as ``parallel``.

    Returns
    -------
    Numba function
        ``roll_apply(values, begin, end, minimum_periods, *args)`` returning a
        float array with one entry per window; a window with fewer than
        ``minimum_periods`` non-NaN values produces NaN.
    """
    jitted_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_apply(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ) -> np.ndarray:
        n_windows = len(begin)
        out = np.empty(n_windows)
        for k in numba.prange(n_windows):
            win = values[begin[k] : end[k]]
            # Number of non-NaN entries in this window.
            n_valid = len(win) - np.sum(np.isnan(win))
            if n_valid < minimum_periods:
                out[k] = np.nan
            else:
                out[k] = jitted_func(win, *args)
        return out

    return roll_apply

76 

77 

@functools.lru_cache(maxsize=None)
def generate_numba_ewm_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function specified by values
    from engine_kwargs.

    The returned kernel walks each ``values[begin[i]:end[i]]`` window
    independently and writes the running exponentially weighted value for
    every position of that window into the matching slice of the output.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Center of mass; the smoothing factor is ``alpha = 1 / (1 + com)``.
    adjust : bool
        If True, each new observation gets weight 1 and the accumulated
        weight grows; if False, it gets weight ``alpha`` and the accumulated
        weight is reset to 1 after every observation.
    ignore_na : bool
        If False, NaN entries still decay the accumulated weight (or the
        running sum); if True, they are skipped.
    deltas : tuple
        Per-step exponents applied to the decay factor when ``normalize`` is
        True; ``deltas[k]`` pairs with ``values[k + 1]``.
    normalize : bool
        True computes the weighted mean, False computes the weighted sum.

    Returns
    -------
    Numba function
        ``ewm(values, begin, end, minimum_periods)`` returning a float array
        of ``len(values)``. Positions with fewer than ``minimum_periods``
        observations so far in their window are NaN. Every window is assumed
        to be non-empty (``window[0]`` is read unconditionally).
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        result = np.empty(len(values))
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha

        for i in numba.prange(len(begin)):
            start = begin[i]
            stop = end[i]
            window = values[start:stop]
            sub_result = np.empty(len(window))

            weighted = window[0]
            nobs = int(not np.isnan(weighted))
            sub_result[0] = weighted if nobs >= minimum_periods else np.nan
            old_wt = 1.0

            for j in range(1, len(window)):
                cur = window[j]
                is_observation = not np.isnan(cur)
                nobs += is_observation
                if not np.isnan(weighted):
                    if is_observation or not ignore_na:
                        if normalize:
                            # note that len(deltas) = len(vals) - 1 and deltas[i]
                            # is to be used in conjunction with vals[i+1]
                            old_wt *= old_wt_factor ** deltas[start + j - 1]
                        else:
                            weighted = old_wt_factor * weighted
                        if is_observation:
                            if normalize:
                                # avoid numerical errors on constant series
                                if weighted != cur:
                                    weighted = old_wt * weighted + new_wt * cur
                                    weighted = weighted / (old_wt + new_wt)
                                # The accumulated weight must advance on every
                                # observation, including ones equal to the running
                                # mean; otherwise a run of equal values only decays
                                # old_wt and the next differing value is overweighted.
                                if adjust:
                                    old_wt += new_wt
                                else:
                                    old_wt = 1.0
                            else:
                                weighted += cur
                elif is_observation:
                    weighted = cur

                sub_result[j] = weighted if nobs >= minimum_periods else np.nan

            result[start:stop] = sub_result

        return result

    return ewm

174 

175 

@functools.lru_cache(maxsize=None)
def generate_numba_table_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba-compiled kernel that applies ``func`` table-wise.

    For each window, ``func`` receives the 2-D slice ``values[begin[k]:end[k]]``
    (window rows x all columns) and must return one value per column. Func is
    intended to operate row-wise, but the result will be transposed for
    axis=1. The user's ``func`` is JIT-compiled first and then called inside
    the kernel, with the same numba options applied to both.

    Parameters
    ----------
    func : function
        Callable evaluated on each 2-D window; it is JIT-compiled.
    nopython : bool
        Forwarded to ``numba.jit`` as ``nopython``.
    nogil : bool
        Forwarded to ``numba.jit`` as ``nogil``.
    parallel : bool
        Forwarded to ``numba.jit`` as ``parallel``.

    Returns
    -------
    Numba function
        ``roll_table(values, begin, end, minimum_periods, *args)`` returning a
        ``(len(begin), values.shape[1])`` float array. An entry is NaN when
        that column of the window has fewer than ``minimum_periods`` non-NaN
        values.
    """
    jitted_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ):
        n_windows = len(begin)
        n_cols = values.shape[1]
        out = np.empty((n_windows, n_cols))
        keep = np.empty((n_windows, n_cols))
        for k in numba.prange(n_windows):
            win = values[begin[k] : end[k]]
            # NaNs are counted before the user function runs, so the
            # min_periods check is unaffected if that function modifies ``win``.
            nan_per_col = np.sum(np.isnan(win), axis=0)
            row = jitted_func(win, *args)
            keep[k, :] = len(win) - nan_per_col >= minimum_periods
            out[k, :] = row
        return np.where(keep, out, np.nan)

    return roll_table

237 

238 

# This helper will no longer be needed once numba supports the ``axis``
# argument for all np.nan* aggregation functions.
# https://github.com/numba/numba/issues/1269
@functools.lru_cache(maxsize=None)
def generate_manual_numpy_nan_agg_with_axis(nan_func):
    """
    Wrap a 1-D reduction so it is applied to each column of a 2-D table.

    Parameters
    ----------
    nan_func : function
        Reduction called once per column. It is presumably one of the
        ``np.nan*`` aggregations (confirm against callers). It must be
        callable from numba nopython mode.

    Returns
    -------
    Numba function
        ``nan_agg_with_axis(table)`` returning a 1-D float array with one
        entry per column of ``table``. It is always compiled with
        ``nopython=True, nogil=True, parallel=True``.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=True, nogil=True, parallel=True)
    def nan_agg_with_axis(table):
        n_cols = table.shape[1]
        out = np.empty(n_cols)
        for col in numba.prange(n_cols):
            out[col] = nan_func(table[:, col])
        return out

    return nan_agg_with_axis

258 

259 

@functools.lru_cache(maxsize=None)
def generate_numba_ewm_table_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function applied table wise specified
    by values from engine_kwargs.

    The returned kernel walks the rows of a 2-D ``values`` array in order and
    updates an independent running exponentially weighted value per column.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Center of mass; the smoothing factor is ``alpha = 1 / (1 + com)``.
    adjust : bool
        If True, each new observation gets weight 1 and the accumulated
        weight grows; if False, it gets weight ``alpha`` and the accumulated
        weight is reset to 1 after every observation.
    ignore_na : bool
        If False, NaN entries still decay the accumulated weight (or the
        running sum); if True, they are skipped.
    deltas : tuple
        Per-row exponents applied to the decay factor when ``normalize`` is
        True; ``deltas[k]`` pairs with row ``k + 1``.
    normalize : bool
        True computes the weighted mean, False computes the weighted sum.

    Returns
    -------
    Numba function
        ``ewm_table(values, begin, end, minimum_periods)`` returning a float
        array with the same shape as ``values``. An entry is NaN while its
        column has fewer than ``minimum_periods`` observations so far.
        ``begin`` and ``end`` are not used by the kernel (presumably accepted
        to match the other kernels' call signature). ``values`` is assumed to
        have at least one row (``values[0]`` is read unconditionally).
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha
        old_wt = np.ones(values.shape[1])

        result = np.empty(values.shape)
        weighted = values[0].copy()
        nobs = (~np.isnan(weighted)).astype(np.int64)
        result[0] = np.where(nobs >= minimum_periods, weighted, np.nan)
        for i in range(1, len(values)):
            cur = values[i]
            is_observations = ~np.isnan(cur)
            nobs += is_observations.astype(np.int64)
            # Columns are independent: iteration j touches only index j of
            # weighted / old_wt.
            for j in numba.prange(len(cur)):
                if not np.isnan(weighted[j]):
                    if is_observations[j] or not ignore_na:
                        if normalize:
                            # note that len(deltas) = len(vals) - 1 and deltas[i]
                            # is to be used in conjunction with vals[i+1]
                            old_wt[j] *= old_wt_factor ** deltas[i - 1]
                        else:
                            weighted[j] = old_wt_factor * weighted[j]
                        if is_observations[j]:
                            if normalize:
                                # avoid numerical errors on constant series
                                if weighted[j] != cur[j]:
                                    weighted[j] = (
                                        old_wt[j] * weighted[j] + new_wt * cur[j]
                                    )
                                    weighted[j] = weighted[j] / (old_wt[j] + new_wt)
                                # The accumulated weight must advance on every
                                # observation, including ones equal to the running
                                # mean; otherwise a run of equal values only decays
                                # old_wt and the next differing value is overweighted.
                                if adjust:
                                    old_wt[j] += new_wt
                                else:
                                    old_wt[j] = 1.0
                            else:
                                weighted[j] += cur[j]
                elif is_observations[j]:
                    weighted[j] = cur[j]

            result[i] = np.where(nobs >= minimum_periods, weighted, np.nan)

        return result

    return ewm_table