Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/jax.py: 38%
13 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:41 +0000
1"""
2Required functions for optimized contractions of numpy arrays using jax.
3"""
5import numpy as np
7from ..sharing import to_backend_cache_wrap
__all__ = ["build_expression", "evaluate_constants"]

# Cache for the ``(jax_module, to_jax)`` pair. It stays ``None`` until
# ``_get_jax_and_to_jax`` first runs, so jax is only imported on demand.
_JAX = None


def _get_jax_and_to_jax():
    """Return the ``(jax, to_jax)`` pair, importing jax on the first call only.

    ``to_jax`` is an identity function compiled with ``jax.jit`` and then
    wrapped by ``to_backend_cache_wrap`` (its caching behavior is defined in
    ``..sharing``). Presumably calling it turns an input array into a jax
    array — confirm against jax's jit semantics. The pair is stored in the
    module-level ``_JAX``, so later calls return it without re-importing
    jax or re-creating ``to_jax``.
    """
    global _JAX
    if _JAX is not None:
        return _JAX

    # Deferred import: only pay for jax when the jax backend is actually used.
    import jax

    @to_backend_cache_wrap
    @jax.jit
    def to_jax(x):
        return x

    _JAX = (jax, to_jax)
    return _JAX
def build_expression(_, expr):  # pragma: no cover
    """Create a callable that evaluates ``expr`` through a jax-compiled kernel.

    The first positional argument is not used. ``expr._contract`` is passed to
    ``jax.jit`` once, when this function runs. The returned callable accepts
    the operands positionally, hands them to the compiled contraction as one
    tuple, and converts the result with ``np.asarray``.
    """
    jax_module = _get_jax_and_to_jax()[0]
    compiled_contract = jax_module.jit(expr._contract)

    def jax_contract(*arrays):
        # ``_contract`` receives the operands as a single sequence argument.
        result = compiled_contract(arrays)
        return np.asarray(result)

    return jax_contract
def evaluate_constants(const_arrays, expr):  # pragma: no cover
    """Move the constant operands onto jax and pre-contract what is possible.

    Each entry of ``const_arrays`` goes through the cached ``to_jax`` helper.
    ``expr`` is then called with the converted arrays, ``backend="jax"`` and
    ``evaluate_constants=True``, and its result is returned unchanged.
    """
    _, to_jax = _get_jax_and_to_jax()
    jax_constants = list(map(to_jax, const_arrays))
    return expr(*jax_constants, backend="jax", evaluate_constants=True)