Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/theano.py: 28%

25 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:41 +0000

1""" 

2Required functions for optimized contractions of numpy arrays using theano. 

3""" 

4 

5import numpy as np 

6 

7from ..sharing import to_backend_cache_wrap 

8 

9__all__ = ["to_theano", "build_expression", "evaluate_constants"] 

10 

11 

@to_backend_cache_wrap(constants=True)
def to_theano(array, constant=False):
    """Turn a numpy array into its theano counterpart.

    Parameters
    ----------
    array : object
        Value to convert. Only ``numpy.ndarray`` instances are converted;
        any other object is handed back unchanged.
    constant : bool, optional
        When true, a numpy array is wrapped with ``theano.tensor.constant``.
        Otherwise a new variable is instantiated from a
        ``theano.tensor.TensorType`` that has the array's dtype and one
        non-broadcastable dimension per array axis.

    Returns
    -------
    object
        The theano constant or variable for numpy input, else ``array``
        itself.
    """
    # Imported lazily, inside the function, rather than at module import time.
    import theano

    # Guard clause: non-numpy inputs pass straight through.
    if not isinstance(array, np.ndarray):
        return array

    if constant:
        return theano.tensor.constant(array)

    # One ``False`` flag per axis, i.e. no dimension is marked broadcastable.
    no_broadcast = [False] * len(array.shape)
    tensor_type = theano.tensor.TensorType(dtype=array.dtype, broadcastable=no_broadcast)
    return tensor_type()

24 

25 

def build_expression(arrays, expr):
    """Compile ``expr`` into a theano function shaped by ``arrays``.

    Parameters
    ----------
    arrays : iterable
        Operands; each one is passed through ``to_theano`` to obtain the
        graph inputs.
    expr : object
        Contraction expression. Its ``_contract`` method is called with the
        converted inputs and ``backend="theano"`` to produce the output
        variable.

    Returns
    -------
    callable
        ``theano_contract(*operands)``, which calls the compiled function
        with every operand that is not a ``theano.tensor.TensorConstant``.
    """
    import theano

    def _is_constant(value):
        # Single definition of "is a constant", shared by graph construction
        # and by the runtime wrapper so both filter the same way.
        return isinstance(value, theano.tensor.TensorConstant)

    symbolic_inputs = [to_theano(operand) for operand in arrays]
    symbolic_output = expr._contract(symbolic_inputs, backend="theano")

    # Constants are not declared as inputs of the compiled function.
    free_inputs = [var for var in symbolic_inputs if not _is_constant(var)]
    compiled = theano.function(free_inputs, symbolic_output)

    def theano_contract(*operands):
        # Drop constant operands so the call matches ``free_inputs``.
        runtime_args = [op for op in operands if not _is_constant(op)]
        return compiled(*runtime_args)

    return theano_contract

41 

42 

def evaluate_constants(const_arrays, expr):
    """Pre-evaluate the constant part of ``expr`` with the theano backend.

    Parameters
    ----------
    const_arrays : iterable
        Constant operands; each is converted with
        ``to_theano(..., constant=True)``.
    expr : callable
        Called with the converted constants, ``backend="theano"`` and
        ``evaluate_constants=True``. It must return a pair
        ``(new_ops, new_contraction_list)``.

    Returns
    -------
    tuple
        ``(new_ops, new_contraction_list)``. In ``new_ops`` every entry that
        is not ``None`` has been evaluated via its ``eval()`` method and
        converted again with ``to_theano(..., constant=True)``; ``None``
        entries are kept as ``None``.
    """
    # Build the partial graph over the constant inputs.
    theano_constants = [to_theano(arr, constant=True) for arr in const_arrays]
    partial_ops, new_contraction_list = expr(*theano_constants, backend="theano", evaluate_constants=True)

    # Evaluate each new input and wrap the result as a theano constant;
    # ``None`` placeholders are carried over untouched.
    new_ops = []
    for op in partial_ops:
        if op is None:
            new_ops.append(None)
        else:
            new_ops.append(to_theano(op.eval(), constant=True))

    return new_ops, new_contraction_list