1"""Common utilities for Numba operations with groupby ops"""
2from __future__ import annotations
3
4import functools
5import inspect
6from typing import (
7 TYPE_CHECKING,
8 Any,
9 Callable,
10)
11
12import numpy as np
13
14from pandas.compat._optional import import_optional_dependency
15
16from pandas.core.util.numba_ import (
17 NumbaUtilError,
18 jit_user_function,
19)
20
21if TYPE_CHECKING:
22 from pandas._typing import Scalar
23
24
def validate_udf(func: Callable) -> None:
    """
    Validate user defined function for ops when using Numba with groupby ops.

    The first signature arguments should include:

    def f(values, index, ...):
        ...

    Parameters
    ----------
    func : callable
        User defined function to validate.

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        If ``func`` is not callable (the Numba engine only accepts a single
        function).
    NumbaUtilError
        If the first two parameters of ``func`` are not named ``values`` and
        ``index`` (in that order).
    """
    if not callable(func):
        raise NotImplementedError(
            "Numba engine can only be used with a single function."
        )
    udf_signature = list(inspect.signature(func).parameters)
    expected_args = ["values", "index"]
    min_number_args = len(expected_args)
    # Slicing a shorter parameter list yields a shorter list, which can never
    # equal ``expected_args``, so this also rejects functions with too few
    # parameters.
    if udf_signature[:min_number_args] != expected_args:
        # functools.partial objects and instances with __call__ are callable
        # but need not define __name__; fall back to repr so the intended
        # NumbaUtilError is raised instead of an AttributeError.
        func_name = getattr(func, "__name__", repr(func))
        raise NumbaUtilError(
            f"The first {min_number_args} arguments to {func_name} must be "
            f"{expected_args}"
        )
62
63
@functools.cache
def generate_numba_agg_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build a numba-compiled groupby aggregation kernel around a user function.

    The user's ``func`` is wrapped with ``jit_user_function`` so it can be
    called from numba code. A ``numba.jit``-compiled loop is returned that
    applies it to every (group, column) pair. Because of ``functools.cache``,
    repeated calls with the same arguments return the same compiled kernel.

    Parameters
    ----------
    func : function
        Reducer applied to each group, called as
        ``func(group_values, group_index, *args)``; it will be JITed.
    nopython : bool
        Forwarded as ``nopython`` to ``numba.jit``.
    nogil : bool
        Forwarded as ``nogil`` to ``numba.jit``.
    parallel : bool
        Forwarded as ``parallel`` to ``numba.jit``.

    Returns
    -------
    Numba function
        ``group_agg(values, index, begin, end, num_columns, *args)``, which
        returns a float64 array of shape ``(len(begin), num_columns)``. Entry
        ``[g, col]`` is ``func`` applied to rows ``begin[g]:end[g]`` of
        column ``col``.
    """
    numba_func = jit_user_function(func)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def group_agg(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:
        # begin/end are paired row bounds: group g spans rows begin[g]:end[g],
        # so rows of one group are assumed to be contiguous in ``values``.
        assert len(begin) == len(end)
        n_groups = len(begin)

        # np.empty defaults to float64, so each reduced value is stored as float.
        out = np.empty((n_groups, num_columns))
        for g in numba.prange(n_groups):
            lo = begin[g]
            hi = end[g]
            # The index slice is shared by every column of this group.
            idx_slice = index[lo:hi]
            for col in numba.prange(num_columns):
                out[g, col] = numba_func(values[lo:hi, col], idx_slice, *args)
        return out

    return group_agg
122
123
@functools.cache
def generate_numba_transform_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build a numba-compiled groupby transform kernel around a user function.

    The user's ``func`` is wrapped with ``jit_user_function`` so it can be
    called from numba code. A ``numba.jit``-compiled loop is returned that
    applies it to every (group, column) pair. The loop writes each result
    back into the rows that group occupies. Because of ``functools.cache``,
    repeated calls with the same arguments return the same compiled kernel.

    Parameters
    ----------
    func : function
        Transformation applied to each group, called as
        ``func(group_values, group_index, *args)``; it will be JITed. Its
        result is assigned into the group's row slice, so it is presumably
        expected to have the group's length.
    nopython : bool
        Forwarded as ``nopython`` to ``numba.jit``.
    nogil : bool
        Forwarded as ``nogil`` to ``numba.jit``.
    parallel : bool
        Forwarded as ``parallel`` to ``numba.jit``.

    Returns
    -------
    Numba function
        ``group_transform(values, index, begin, end, num_columns, *args)``,
        which returns a float64 array of shape ``(len(values), num_columns)``.
    """
    numba_func = jit_user_function(func)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def group_transform(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:
        # begin/end are paired row bounds: group g spans rows begin[g]:end[g].
        assert len(begin) == len(end)
        n_groups = len(begin)

        # Output keeps the input's row count. Rows not covered by any
        # begin[g]:end[g] range are never written and stay uninitialized
        # (np.empty); presumably callers pass ranges covering every row.
        out = np.empty((len(values), num_columns))
        for g in numba.prange(n_groups):
            lo = begin[g]
            hi = end[g]
            # The index slice is shared by every column of this group.
            idx_slice = index[lo:hi]
            for col in numba.prange(num_columns):
                out[lo:hi, col] = numba_func(values[lo:hi, col], idx_slice, *args)
        return out

    return group_transform