1"""Common utilities for Numba operations with groupby ops"""
2from __future__ import annotations
3
4import functools
5import inspect
6from typing import (
7 TYPE_CHECKING,
8 Any,
9 Callable,
10)
11
12import numpy as np
13
14from pandas._typing import Scalar
15from pandas.compat._optional import import_optional_dependency
16
17from pandas.core.util.numba_ import (
18 NumbaUtilError,
19 jit_user_function,
20)
21
22
def validate_udf(func: Callable) -> None:
    """
    Validate user defined function for ops when using Numba with groupby ops.

    The first signature arguments should include:

    def f(values, index, ...):
        ...

    Parameters
    ----------
    func : function
        user defined function

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        If ``func`` is not callable.
    NumbaUtilError
        If the leading parameters of ``func`` are not named ``values`` and
        ``index``, in that order.
    """
    if not callable(func):
        raise NotImplementedError(
            "Numba engine can only be used with a single function."
        )
    udf_signature = list(inspect.signature(func).parameters.keys())
    expected_args = ["values", "index"]
    min_number_args = len(expected_args)
    # A signature that is too short yields a shorter slice, so the single
    # comparison below also covers the "not enough arguments" case.
    if udf_signature[:min_number_args] != expected_args:
        # Not every callable carries ``__name__`` (e.g. functools.partial
        # objects or instances defining ``__call__``); fall back to the repr
        # so the intended NumbaUtilError is raised, not an AttributeError.
        func_name = getattr(func, "__name__", repr(func))
        raise NumbaUtilError(
            f"The first {min_number_args} arguments to {func_name} must be "
            f"{expected_args}"
        )
60
61
@functools.lru_cache(maxsize=None)
def generate_numba_agg_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build a numba jitted groupby aggregation function.

    The user's function is first passed through ``jit_user_function`` and the
    returned groupby loop calls that jitted function once per (group, column)
    pair. The ``nopython``, ``nogil`` and ``parallel`` settings are applied to
    both the user's function _AND_ the groupby evaluation loop.

    Results are memoized on the argument tuple via ``functools.lru_cache``.

    Parameters
    ----------
    func : function
        function to be applied to each group and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        Callable taking ``(values, index, begin, end, num_columns, *args)``
        and returning an array of shape ``(len(begin), num_columns)``.
    """
    numba_func = jit_user_function(func, nopython, nogil, parallel)
    if not TYPE_CHECKING:
        numba = import_optional_dependency("numba")
    else:
        import numba

    def group_agg(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:
        n_groups = len(begin)
        assert n_groups == len(end)

        # One output row per group, one output column per input column.
        out = np.empty((n_groups, num_columns))
        for grp in numba.prange(n_groups):
            # begin/end hold the row bounds of each group.
            start = begin[grp]
            stop = end[grp]
            group_index = index[start:stop]
            for col in numba.prange(num_columns):
                out[grp, col] = numba_func(
                    values[start:stop, col], group_index, *args
                )
        return out

    # Apply the jit decorator explicitly with the requested engine settings.
    return numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)(group_agg)
120
121
@functools.lru_cache(maxsize=None)
def generate_numba_transform_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build a numba jitted groupby transform function.

    The user's function is first passed through ``jit_user_function`` and the
    returned groupby loop calls that jitted function once per (group, column)
    pair, writing its output back into the rows belonging to that group. The
    ``nopython``, ``nogil`` and ``parallel`` settings are applied to both the
    user's function _AND_ the groupby evaluation loop.

    Results are memoized on the argument tuple via ``functools.lru_cache``.

    Parameters
    ----------
    func : function
        function to be applied to each group and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        Callable taking ``(values, index, begin, end, num_columns, *args)``
        and returning an array of shape ``(len(values), num_columns)``.
    """
    numba_func = jit_user_function(func, nopython, nogil, parallel)
    if not TYPE_CHECKING:
        numba = import_optional_dependency("numba")
    else:
        import numba

    def group_transform(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:
        n_groups = len(begin)
        assert n_groups == len(end)

        # Output has one row per input row, unlike the aggregation variant.
        out = np.empty((len(values), num_columns))
        for grp in numba.prange(n_groups):
            # begin/end hold the row bounds of each group.
            start = begin[grp]
            stop = end[grp]
            group_index = index[start:stop]
            for col in numba.prange(num_columns):
                out[start:stop, col] = numba_func(
                    values[start:stop, col], group_index, *args
                )
        return out

    # Apply the jit decorator explicitly with the requested engine settings.
    return numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)(
        group_transform
    )