Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/jax.py: 38%

13 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:41 +0000

1""" 

2Required functions for optimized contractions of numpy arrays using jax. 

3""" 

4 

5import numpy as np 

6 

7from ..sharing import to_backend_cache_wrap 

8 

# Public functions of this backend module.
__all__ = [
    "build_expression",
    "evaluate_constants",
]

# Cache for the ``(jax module, to_jax converter)`` pair. It stays ``None``
# until ``_get_jax_and_to_jax`` is first called, so jax is imported lazily.
_JAX = None

12 

13 

def _get_jax_and_to_jax():
    """Return the cached ``(jax, to_jax)`` pair, importing jax on first use.

    ``jax`` is imported inside this function rather than at module level,
    so importing this backend module does not itself import jax. The
    converter ``to_jax`` is an identity function compiled with ``jax.jit``
    and then wrapped with ``to_backend_cache_wrap``.

    Returns:
        tuple: ``(jax_module, to_jax)``. After the first call the same
        tuple object is returned every time.
    """
    global _JAX
    if _JAX is not None:
        return _JAX

    import jax

    def to_jax(x):
        return x

    # Equivalent to stacking @to_backend_cache_wrap over @jax.jit. The
    # function keeps the name ``to_jax`` so the wrapped callable reports the
    # same ``__name__`` as before.
    to_jax = to_backend_cache_wrap(jax.jit(to_jax))

    _JAX = jax, to_jax
    return _JAX

27 

28 

def build_expression(_, expr):  # pragma: no cover
    """Compile ``expr`` with ``jax.jit`` and return a numpy-returning wrapper.

    Args:
        _: Ignored (presumably the operand arrays; verify against caller).
        expr: Contraction expression whose ``_contract`` method is jitted.

    Returns:
        callable: ``jax_contract(*arrays)``. It runs the compiled contraction
        with all operands passed as a single tuple argument, then converts
        the result with ``np.asarray``.
    """
    jax = _get_jax_and_to_jax()[0]
    compiled_contract = jax.jit(expr._contract)

    def jax_contract(*arrays):
        # The operands go in as one tuple argument, not as varargs.
        raw_result = compiled_contract(arrays)
        return np.asarray(raw_result)

    return jax_contract

39 

40 

def evaluate_constants(const_arrays, expr):  # pragma: no cover
    """Convert constant arguments to jax arrays and pre-contract them.

    Args:
        const_arrays: Iterable of constant operands. Each element is passed
            through the cached ``to_jax`` converter.
        expr: Contraction expression. It is called with the converted
            operands, ``backend="jax"`` and ``evaluate_constants=True``.

    Returns:
        Whatever ``expr`` returns for that call (presumably the updated
        operands and remaining contractions; confirm against the caller).
    """
    # Only the converter is needed here; the jax module itself is unused.
    _, to_jax = _get_jax_and_to_jax()

    jax_constants = [to_jax(x) for x in const_arrays]
    return expr(*jax_constants, backend="jax", evaluate_constants=True)