1from __future__ import annotations
2
3import functools
4from typing import (
5 TYPE_CHECKING,
6 Callable,
7)
8
9import numpy as np
10
11from pandas._typing import Scalar
12from pandas.compat._optional import import_optional_dependency
13
14
15@functools.lru_cache(maxsize=None)
16def generate_shared_aggregator(
17 func: Callable[..., Scalar],
18 nopython: bool,
19 nogil: bool,
20 parallel: bool,
21):
22 """
23 Generate a Numba function that loops over the columns 2D object and applies
24 a 1D numba kernel over each column.
25
26 Parameters
27 ----------
28 func : function
29 aggregation function to be applied to each column
30 nopython : bool
31 nopython to be passed into numba.jit
32 nogil : bool
33 nogil to be passed into numba.jit
34 parallel : bool
35 parallel to be passed into numba.jit
36
37 Returns
38 -------
39 Numba function
40 """
41 if TYPE_CHECKING:
42 import numba
43 else:
44 numba = import_optional_dependency("numba")
45
46 @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
47 def column_looper(
48 values: np.ndarray,
49 start: np.ndarray,
50 end: np.ndarray,
51 min_periods: int,
52 *args,
53 ):
54 result = np.empty((len(start), values.shape[1]), dtype=np.float64)
55 for i in numba.prange(values.shape[1]):
56 result[:, i] = func(values[:, i], start, end, min_periods, *args)
57 return result
58
59 return column_looper