1"""Common utilities for Numba operations"""
2from __future__ import annotations
3
4import types
5from typing import (
6 TYPE_CHECKING,
7 Callable,
8)
9
10import numpy as np
11
12from pandas.compat._optional import import_optional_dependency
13from pandas.errors import NumbaUtilError
14
# Module-wide default consulted by ``maybe_use_numba`` when no ``engine`` is
# passed explicitly; it is only changed through ``set_use_numba``.
GLOBAL_USE_NUMBA: bool = False
16
17
def maybe_use_numba(engine: str | None) -> bool:
    """
    Decide whether numba-backed routines should be used.

    An explicit ``engine`` always wins: only the string ``"numba"`` selects
    numba. When ``engine`` is ``None`` the module-wide ``GLOBAL_USE_NUMBA``
    flag is used instead.
    """
    if engine is None:
        # No explicit choice: defer to the global setting.
        return GLOBAL_USE_NUMBA
    return engine == "numba"
21
22
def set_use_numba(enable: bool = False) -> None:
    """
    Set the module-wide default used by ``maybe_use_numba``.

    Parameters
    ----------
    enable : bool, default False
        New value for ``GLOBAL_USE_NUMBA``. When truthy, numba is imported
        first, so a missing installation fails before the flag is changed.
    """
    global GLOBAL_USE_NUMBA
    if not enable:
        # Disabling never needs numba to be importable.
        GLOBAL_USE_NUMBA = enable
        return
    # Fail fast, leaving the flag untouched, if numba is unavailable.
    import_optional_dependency("numba")
    GLOBAL_USE_NUMBA = enable
28
29
def get_jit_arguments(
    engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None
) -> dict[str, bool]:
    """
    Build the option dict for ``numba.jit``, filling in pandas defaults.

    Only the ``nopython``, ``nogil`` and ``parallel`` options are read from
    ``engine_kwargs``; any other keys are ignored. Missing options default to
    ``nopython=True``, ``nogil=False`` and ``parallel=False``.

    Parameters
    ----------
    engine_kwargs : dict, default None
        Options supplied by the user for ``numba.jit``.
    kwargs : dict, default None
        Keyword arguments the user wants forwarded to the jitted function.

    Returns
    -------
    dict[str, bool]
        Mapping with exactly the keys ``nopython``, ``nogil``, ``parallel``
        (in that order).

    Raises
    ------
    NumbaUtilError
        If ``kwargs`` is non-empty while ``nopython`` resolves truthy.
    """
    supplied = {} if engine_kwargs is None else engine_kwargs
    pandas_defaults = {"nopython": True, "nogil": False, "parallel": False}
    resolved = {
        option: supplied.get(option, fallback)
        for option, fallback in pandas_defaults.items()
    }

    # Keyword arguments cannot be passed to a nopython-compiled function.
    if kwargs and resolved["nopython"]:
        raise NumbaUtilError(
            "numba does not support kwargs with nopython=True: "
            "https://github.com/numba/numba/issues/2916"
        )
    return resolved
64
65
def jit_user_function(func: Callable) -> Callable:
    """
    Make a user-supplied function usable from numba-compiled code.

    Functions already jitted by numba, top-level numpy functions and Python
    builtin functions are returned unchanged. Anything else is passed through
    ``numba.extending.register_jitable`` and the result is returned.

    Parameters
    ----------
    func : function
        The user's function.

    Returns
    -------
    function
        ``func`` itself, or the value returned by ``register_jitable(func)``.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    # A function numba has already compiled is used as-is.
    if numba.extending.is_jitted(func):
        return func

    # numpy functions and builtins need no wrapping; register_jitable would
    # break them.
    # NOTE(review): callables without ``__name__`` (e.g. functools.partial)
    # raise AttributeError on this lookup.
    is_numpy_function = getattr(np, func.__name__, False) is func
    if is_numpy_function or isinstance(func, types.BuiltinFunctionType):
        return func

    return numba.extending.register_jitable(func)