Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/theano.py: 28%

25 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1""" 

2Required functions for optimized contractions of numpy arrays using theano. 

3""" 

4 

5import numpy as np 

6 

7from ..sharing import to_backend_cache_wrap 

8 

# Public names of this backend module: the numpy -> theano conversion helper
# and the two functions defined below that build / pre-evaluate contractions.
__all__ = ["to_theano", "build_expression", "evaluate_constants"]

10 

11 

@to_backend_cache_wrap(constants=True)
def to_theano(array, constant=False):
    """Wrap ``array`` so it can take part in a theano graph.

    Parameters
    ----------
    array : numpy.ndarray or object
        Value to convert. Anything that is not a ``numpy.ndarray`` is
        returned unchanged.
    constant : bool, optional
        If true, the array's values are embedded with
        ``theano.tensor.constant``. Otherwise a new symbolic variable is
        created whose dtype and number of dimensions match ``array``, with
        no dimension marked as broadcastable.

    Returns
    -------
    object
        The theano constant, the new symbolic variable, or ``array`` itself
        when it is not a numpy array.

    Notes
    -----
    Wrapped by ``to_backend_cache_wrap(constants=True)`` from ``..sharing``;
    presumably this caches conversions when sharing is enabled -- see that
    module for the exact semantics.
    """
    import theano

    # Guard clause: only numpy arrays need converting.
    if not isinstance(array, np.ndarray):
        return array

    if constant:
        return theano.tensor.constant(array)

    # Symbolic placeholder: same dtype and rank, nothing broadcastable.
    no_broadcast = [False] * len(array.shape)
    placeholder_type = theano.tensor.TensorType(dtype=array.dtype, broadcastable=no_broadcast)
    return placeholder_type()

25 

26 

def build_expression(arrays, expr):
    """Compile ``expr`` into a callable theano function.

    Each entry of ``arrays`` is converted with :func:`to_theano` to form the
    graph inputs. ``expr._contract`` then builds the output with the
    ``'theano'`` backend. Inputs that are ``theano.tensor.TensorConstant``
    instances are part of the graph itself, so they are left out of the
    compiled function's parameter list.

    Parameters
    ----------
    arrays : sequence
        Template operands used to create the graph inputs.
    expr
        Contraction expression providing ``_contract``.

    Returns
    -------
    callable
        ``theano_contract(*arrays)``. It discards every ``TensorConstant``
        argument and calls the compiled graph with the remaining ones.
    """
    import theano

    constant_type = theano.tensor.TensorConstant

    graph_inputs = list(map(to_theano, arrays))
    graph_output = expr._contract(graph_inputs, backend='theano')

    # Constants are baked into the graph; only the other inputs are parameters.
    parameters = [var for var in graph_inputs if not isinstance(var, constant_type)]
    graph = theano.function(parameters, graph_output)

    def theano_contract(*arrays):
        runtime_args = [arg for arg in arrays if not isinstance(arg, constant_type)]
        return graph(*runtime_args)

    return theano_contract

43 

44 

def evaluate_constants(const_arrays, expr):
    """Pre-contract the constant operands of ``expr`` using theano.

    Parameters
    ----------
    const_arrays : sequence
        Constant operands. Each is converted with
        ``to_theano(..., constant=True)``.
    expr
        Contraction expression. It is called with ``backend='theano'`` and
        ``evaluate_constants=True`` and must return
        ``(new_ops, new_contraction_list)``.

    Returns
    -------
    tuple
        ``(new_ops, new_contraction_list)``. Each non-``None`` entry of
        ``new_ops`` has been computed with ``.eval()`` and re-wrapped as a
        theano constant (not a shared variable). ``None`` entries stay
        ``None``.
    """
    # Build the partial graph over the constant inputs.
    theano_consts = [to_theano(arr, constant=True) for arr in const_arrays]
    partial_ops, contraction_list = expr(*theano_consts, backend='theano', evaluate_constants=True)

    # Evaluate each produced op now and store the result as a theano constant.
    evaluated_ops = []
    for op in partial_ops:
        if op is None:
            evaluated_ops.append(None)
        else:
            evaluated_ops.append(to_theano(op.eval(), constant=True))

    return evaluated_ops, contraction_list