Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/jax.py: 38%
13 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""
2Required functions for optimized contractions of numpy arrays using jax.
3"""
5import numpy as np
7from ..sharing import to_backend_cache_wrap
# Public API of this backend module.
__all__ = ["build_expression", "evaluate_constants"]
# Cache for ``(jax_module, to_jax)``; starts as None and is filled in by
# ``_get_jax_and_to_jax`` the first time it runs.
_JAX = None
def _get_jax_and_to_jax():
    """Return the cached ``(jax, to_jax)`` pair, creating it on first use.

    ``jax`` is imported inside this function, so it is only imported the
    first time this helper runs. ``to_jax`` is an identity function compiled
    with ``jax.jit`` and then wrapped by ``to_backend_cache_wrap``. The pair
    is stored in the module-level ``_JAX`` and reused on later calls.
    """
    global _JAX
    if _JAX is not None:
        return _JAX

    import jax

    def to_jax(x):
        return x

    # Same order as stacking ``@to_backend_cache_wrap`` over ``@jax.jit``.
    # The inner name stays ``to_jax``. NOTE(review): the cache wrapper may
    # use the wrapped function's name, so confirm before renaming it.
    to_jax = to_backend_cache_wrap(jax.jit(to_jax))

    _JAX = (jax, to_jax)
    return _JAX
def build_expression(_, expr):  # pragma: no cover
    """Build a callable that evaluates ``expr`` through a jax-jitted function.

    The first positional argument is ignored; only ``expr`` is used. The
    returned function takes the operands positionally, passes them as a
    single tuple to the jitted ``expr._contract``, and converts the result
    back with ``np.asarray``.
    """
    compiled = _get_jax_and_to_jax()[0].jit(expr._contract)

    def jax_contract(*arrays):
        # ``expr._contract`` receives all operands as one tuple argument.
        result = compiled(arrays)
        return np.asarray(result)

    return jax_contract
def evaluate_constants(const_arrays, expr):  # pragma: no cover
    """Convert constant arguments to jax arrays, and perform any possible
    constant contractions.

    Each element of ``const_arrays`` is passed through the cached ``to_jax``
    converter. ``expr`` is then called with the converted arrays and with
    ``backend='jax'`` and ``evaluate_constants=True``, and its return value
    is passed back unchanged.
    """
    # Only the converter is needed; the jax module itself is not used here.
    _, to_jax = _get_jax_and_to_jax()

    jax_arrays = [to_jax(x) for x in const_arrays]
    return expr(*jax_arrays, backend='jax', evaluate_constants=True)