Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/jax.py: 38%

13 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1""" 

2Required functions for optimized contractions of numpy arrays using jax. 

3""" 

4 

5import numpy as np 

6 

7from ..sharing import to_backend_cache_wrap 

8 

__all__ = ['build_expression', 'evaluate_constants']

# Lazily-populated cache of ``(jax_module, to_jax_converter)``; stays ``None``
# until ``_get_jax_and_to_jax`` is first called, so importing this backend
# does not import jax.
_JAX = None

13 

14 

def _get_jax_and_to_jax():
    """Return ``(jax, to_jax)``, importing jax and building ``to_jax`` on first use.

    ``to_jax`` is a ``jax.jit``-compiled identity function, further wrapped by
    ``to_backend_cache_wrap`` (from ``..sharing``; presumably this caches
    conversions while opt_einsum sharing is active; confirm there). Both
    objects are stored in the module-level ``_JAX`` so that later calls
    reuse them.
    """
    global _JAX

    # Fast path: everything has already been built on a previous call.
    if _JAX is not None:
        return _JAX

    # Deferred import so that jax is only required when this backend is used.
    import jax

    @to_backend_cache_wrap
    @jax.jit
    def to_jax(x):
        # Identity under jit: the output comes back as a jax array.
        return x

    _JAX = (jax, to_jax)
    return _JAX

28 

29 

def build_expression(_, expr):  # pragma: no cover
    """Build a jax-compiled contraction function for ``expr``.

    The first positional argument is ignored. ``expr._contract`` is wrapped
    with ``jax.jit``. The returned callable takes the operand arrays
    positionally, passes them to the compiled contraction as a single tuple,
    and converts the result to a numpy array with ``np.asarray``.
    """
    jax_module = _get_jax_and_to_jax()[0]
    compiled_contract = jax_module.jit(expr._contract)

    def jax_contract(*arrays):
        result = compiled_contract(arrays)
        return np.asarray(result)

    return jax_contract

41 

42 

def evaluate_constants(const_arrays, expr):  # pragma: no cover
    """Convert constant arguments to jax arrays, and perform any possible
    constant contractions.

    Each array in ``const_arrays`` is passed through the cached ``to_jax``
    converter. The converted arrays are then handed to ``expr`` with
    ``backend='jax'`` and ``evaluate_constants=True``; whatever ``expr``
    returns is returned unchanged.
    """
    # Only the converter is needed here; the jax module itself is unused.
    _, to_jax = _get_jax_and_to_jax()

    return expr(*[to_jax(x) for x in const_arrays], backend='jax', evaluate_constants=True)