Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/pandas/core/groupby/numba_.py: 25%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

52 statements  

1"""Common utilities for Numba operations with groupby ops""" 

2from __future__ import annotations 

3 

4import functools 

5import inspect 

6from typing import ( 

7 TYPE_CHECKING, 

8 Any, 

9 Callable, 

10) 

11 

12import numpy as np 

13 

14from pandas._typing import Scalar 

15from pandas.compat._optional import import_optional_dependency 

16 

17from pandas.core.util.numba_ import ( 

18 NumbaUtilError, 

19 jit_user_function, 

20) 

21 

22 

def validate_udf(func: Callable) -> None:
    """
    Validate user defined function for ops when using Numba with groupby ops.

    The first signature arguments should include:

    def f(values, index, ...):
        ...

    Parameters
    ----------
    func : callable
        user defined function

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        If ``func`` is not callable.
    NumbaUtilError
        If the leading parameters of ``func`` are not named ``values`` and
        ``index``, in that order.
    """
    if not callable(func):
        raise NotImplementedError(
            "Numba engine can only be used with a single function."
        )
    udf_signature = list(inspect.signature(func).parameters.keys())
    expected_args = ["values", "index"]
    min_number_args = len(expected_args)
    # A signature with too few parameters yields a shorter slice, so this
    # single comparison also rejects functions with fewer than two arguments.
    if udf_signature[:min_number_args] != expected_args:
        # Callables such as functools.partial objects or instances defining
        # __call__ have no __name__; fall back to repr so the documented
        # NumbaUtilError is raised instead of an AttributeError.
        func_name = getattr(func, "__name__", repr(func))
        raise NumbaUtilError(
            f"The first {min_number_args} arguments to {func_name} must be "
            f"{expected_args}"
        )

60 

61 

@functools.lru_cache(maxsize=None)
def generate_numba_agg_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build a numba-compiled groupby aggregation loop around ``func``.

    The user's function is first passed through ``jit_user_function`` and the
    returned callable is then invoked once per (group, column) pair from a
    loop that is itself decorated with ``numba.jit``.

    The same ``nopython``/``nogil``/``parallel`` settings are handed to both
    the user's function _AND_ the groupby evaluation loop. Because of the
    ``lru_cache`` decorator, repeated calls with the same arguments return
    the previously generated function.

    Parameters
    ----------
    func : function
        function to be applied to each group and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        Called as ``group_agg(values, index, begin, end, num_columns, *args)``;
        returns an array of shape ``(len(begin), num_columns)``.
    """
    jitted_udf = jit_user_function(func, nopython, nogil, parallel)
    if not TYPE_CHECKING:
        numba = import_optional_dependency("numba")
    else:
        import numba

    loop_jit = numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)

    @loop_jit
    def group_agg(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:
        n_groups = len(begin)
        assert n_groups == len(end)

        # One output row per group, one output column per input column.
        out = np.empty((n_groups, num_columns))
        for grp in numba.prange(n_groups):
            # begin/end give the half-open row range of this group.
            start = begin[grp]
            stop = end[grp]
            labels = index[start:stop]
            for col in numba.prange(num_columns):
                chunk = values[start:stop, col]
                out[grp, col] = jitted_udf(chunk, labels, *args)
        return out

    return group_agg

120 

121 

@functools.lru_cache(maxsize=None)
def generate_numba_transform_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build a numba-compiled groupby transform loop around ``func``.

    The user's function is first passed through ``jit_user_function`` and the
    returned callable is then invoked once per (group, column) pair from a
    loop that is itself decorated with ``numba.jit``. Each call's result is
    written back into the rows of the output that belong to that group.

    The same ``nopython``/``nogil``/``parallel`` settings are handed to both
    the user's function _AND_ the groupby evaluation loop. Because of the
    ``lru_cache`` decorator, repeated calls with the same arguments return
    the previously generated function.

    Parameters
    ----------
    func : function
        function to be applied to each group and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        Called as
        ``group_transform(values, index, begin, end, num_columns, *args)``;
        returns an array of shape ``(len(values), num_columns)``.
    """
    jitted_udf = jit_user_function(func, nopython, nogil, parallel)
    if not TYPE_CHECKING:
        numba = import_optional_dependency("numba")
    else:
        import numba

    loop_jit = numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)

    @loop_jit
    def group_transform(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:
        n_groups = len(begin)
        assert n_groups == len(end)

        # Output has one row per input row; rows not covered by any
        # begin/end range are left as allocated by np.empty.
        out = np.empty((len(values), num_columns))
        for grp in numba.prange(n_groups):
            # begin/end give the half-open row range of this group.
            start = begin[grp]
            stop = end[grp]
            labels = index[start:stop]
            for col in numba.prange(num_columns):
                chunk = values[start:stop, col]
                out[start:stop, col] = jitted_udf(chunk, labels, *args)
        return out

    return group_transform