Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/pandas/core/groupby/numba_.py: 25%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

53 statements  

1"""Common utilities for Numba operations with groupby ops""" 

2from __future__ import annotations 

3 

4import functools 

5import inspect 

6from typing import ( 

7 TYPE_CHECKING, 

8 Any, 

9 Callable, 

10) 

11 

12import numpy as np 

13 

14from pandas.compat._optional import import_optional_dependency 

15 

16from pandas.core.util.numba_ import ( 

17 NumbaUtilError, 

18 jit_user_function, 

19) 

20 

21if TYPE_CHECKING: 

22 from pandas._typing import Scalar 

23 

24 

def validate_udf(func: Callable) -> None:
    """
    Validate user defined function for ops when using Numba with groupby ops.

    The first signature arguments should include:

    def f(values, index, ...):
        ...

    Parameters
    ----------
    func : function
        user defined function

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        If ``func`` is not callable.
    NumbaUtilError
        If the first two parameters of ``func`` are not ``values`` and
        ``index`` (in that order).
    """
    if not callable(func):
        raise NotImplementedError(
            "Numba engine can only be used with a single function."
        )
    udf_signature = list(inspect.signature(func).parameters.keys())
    expected_args = ["values", "index"]
    min_number_args = len(expected_args)
    if (
        len(udf_signature) < min_number_args
        or udf_signature[:min_number_args] != expected_args
    ):
        # Not every callable exposes ``__name__`` (e.g. functools.partial
        # objects, or instances defining ``__call__``). Fall back to repr so
        # the intended NumbaUtilError is raised rather than AttributeError.
        func_name = getattr(func, "__name__", repr(func))
        raise NumbaUtilError(
            f"The first {min_number_args} arguments to {func_name} must be "
            f"{expected_args}"
        )

62 

63 

64@functools.cache 

65def generate_numba_agg_func( 

66 func: Callable[..., Scalar], 

67 nopython: bool, 

68 nogil: bool, 

69 parallel: bool, 

70) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]: 

71 """ 

72 Generate a numba jitted agg function specified by values from engine_kwargs. 

73 

74 1. jit the user's function 

75 2. Return a groupby agg function with the jitted function inline 

76 

77 Configurations specified in engine_kwargs apply to both the user's 

78 function _AND_ the groupby evaluation loop. 

79 

80 Parameters 

81 ---------- 

82 func : function 

83 function to be applied to each group and will be JITed 

84 nopython : bool 

85 nopython to be passed into numba.jit 

86 nogil : bool 

87 nogil to be passed into numba.jit 

88 parallel : bool 

89 parallel to be passed into numba.jit 

90 

91 Returns 

92 ------- 

93 Numba function 

94 """ 

95 numba_func = jit_user_function(func) 

96 if TYPE_CHECKING: 

97 import numba 

98 else: 

99 numba = import_optional_dependency("numba") 

100 

101 @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) 

102 def group_agg( 

103 values: np.ndarray, 

104 index: np.ndarray, 

105 begin: np.ndarray, 

106 end: np.ndarray, 

107 num_columns: int, 

108 *args: Any, 

109 ) -> np.ndarray: 

110 assert len(begin) == len(end) 

111 num_groups = len(begin) 

112 

113 result = np.empty((num_groups, num_columns)) 

114 for i in numba.prange(num_groups): 

115 group_index = index[begin[i] : end[i]] 

116 for j in numba.prange(num_columns): 

117 group = values[begin[i] : end[i], j] 

118 result[i, j] = numba_func(group, group_index, *args) 

119 return result 

120 

121 return group_agg 

122 

123 

124@functools.cache 

125def generate_numba_transform_func( 

126 func: Callable[..., np.ndarray], 

127 nopython: bool, 

128 nogil: bool, 

129 parallel: bool, 

130) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]: 

131 """ 

132 Generate a numba jitted transform function specified by values from engine_kwargs. 

133 

134 1. jit the user's function 

135 2. Return a groupby transform function with the jitted function inline 

136 

137 Configurations specified in engine_kwargs apply to both the user's 

138 function _AND_ the groupby evaluation loop. 

139 

140 Parameters 

141 ---------- 

142 func : function 

143 function to be applied to each window and will be JITed 

144 nopython : bool 

145 nopython to be passed into numba.jit 

146 nogil : bool 

147 nogil to be passed into numba.jit 

148 parallel : bool 

149 parallel to be passed into numba.jit 

150 

151 Returns 

152 ------- 

153 Numba function 

154 """ 

155 numba_func = jit_user_function(func) 

156 if TYPE_CHECKING: 

157 import numba 

158 else: 

159 numba = import_optional_dependency("numba") 

160 

161 @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) 

162 def group_transform( 

163 values: np.ndarray, 

164 index: np.ndarray, 

165 begin: np.ndarray, 

166 end: np.ndarray, 

167 num_columns: int, 

168 *args: Any, 

169 ) -> np.ndarray: 

170 assert len(begin) == len(end) 

171 num_groups = len(begin) 

172 

173 result = np.empty((len(values), num_columns)) 

174 for i in numba.prange(num_groups): 

175 group_index = index[begin[i] : end[i]] 

176 for j in numba.prange(num_columns): 

177 group = values[begin[i] : end[i], j] 

178 result[begin[i] : end[i], j] = numba_func(group, group_index, *args) 

179 return result 

180 

181 return group_transform