Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/theano.py: 28%
25 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:41 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:41 +0000
1"""
2Required functions for optimized contractions of numpy arrays using theano.
3"""
5import numpy as np
7from ..sharing import to_backend_cache_wrap
# Public API of this backend module: array -> theano conversion, compilation
# of a contraction expression, and pre-evaluation of constant operands.
__all__ = [
    "to_theano",
    "build_expression",
    "evaluate_constants",
]
@to_backend_cache_wrap(constants=True)
def to_theano(array, constant=False):
    """Convert a numpy array to a theano tensor.

    Parameters
    ----------
    array : object
        The operand to convert. Anything that is not a ``numpy.ndarray``
        (e.g. a theano variable that was already built, or ``None``) is
        returned unchanged.
    constant : bool, optional
        If ``True``, wrap ``array`` with ``theano.tensor.constant`` so its
        values are embedded in the graph. Otherwise return a fresh symbolic
        variable of a ``theano.tensor.TensorType`` with the same dtype and
        number of dimensions as ``array``.

    Returns
    -------
    object
        A theano constant, a new symbolic tensor variable, or ``array`` itself.
    """
    import theano

    # Only numpy arrays are converted; everything else passes straight through.
    if not isinstance(array, np.ndarray):
        return array

    if constant:
        return theano.tensor.constant(array)

    # No axis is declared broadcastable, i.e. no dimension is pinned to length 1.
    tensor_type = theano.tensor.TensorType(dtype=array.dtype, broadcastable=[False] * array.ndim)
    return tensor_type()
def build_expression(arrays, expr):
    """Compile a theano function that evaluates ``expr`` on ``arrays``.

    Parameters
    ----------
    arrays : sequence
        Operands of the contraction. Each is passed through ``to_theano``:
        numpy arrays become new symbolic variables, while objects that are
        already theano variables (including constants) are used as-is.
    expr : object
        The contraction expression; its ``_contract`` method is called with
        the theano variables and ``backend="theano"`` to build the graph.

    Returns
    -------
    callable
        ``theano_contract(*arrays)``, which drops any ``TensorConstant``
        arguments (they are already baked into the graph) and feeds the
        remaining ones to the compiled theano function.
    """
    import theano

    tensor_constant = theano.tensor.TensorConstant

    def _is_input(obj):
        # Constants live inside the graph, so they are never graph inputs.
        return not isinstance(obj, tensor_constant)

    symbolic = [to_theano(operand) for operand in arrays]
    result = expr._contract(symbolic, backend="theano")

    compiled = theano.function([var for var in symbolic if _is_input(var)], result)

    def theano_contract(*arrays):
        return compiled(*filter(_is_input, arrays))

    return theano_contract
def evaluate_constants(const_arrays, expr):
    """Pre-compute the parts of ``expr`` that depend only on constant operands.

    Parameters
    ----------
    const_arrays : sequence
        Operands forwarded to ``expr``. Each is converted with
        ``to_theano(..., constant=True)``; entries that are not numpy arrays
        pass through that conversion unchanged.
    expr : callable
        The contraction expression, called with ``backend="theano"`` and
        ``evaluate_constants=True``; it must return a pair
        ``(ops, contraction_list)``.

    Returns
    -------
    tuple
        ``(new_ops, new_contraction_list)``: every non-``None`` entry of the
        ops returned by ``expr`` is evaluated and re-wrapped as a theano
        constant; ``new_contraction_list`` is passed back unchanged.
    """
    theano_consts = [to_theano(arr, constant=True) for arr in const_arrays]
    partial_ops, contraction_list = expr(*theano_consts, backend="theano", evaluate_constants=True)

    # Evaluate each partial result now and store it as a theano *constant*
    # (not a shared variable); ``None`` placeholders are kept as-is.
    evaluated = []
    for op in partial_ops:
        evaluated.append(op if op is None else to_theano(op.eval(), constant=True))

    return evaluated, contraction_list