Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/theano.py: 28%
25 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""
2Required functions for optimized contractions of numpy arrays using theano.
3"""
5import numpy as np
7from ..sharing import to_backend_cache_wrap
# Public API of this backend module: the conversion helper plus the two hooks
# (expression building and constant pre-evaluation) defined below.
__all__ = ["to_theano", "build_expression", "evaluate_constants"]
@to_backend_cache_wrap(constants=True)
def to_theano(array, constant=False):
    """Wrap a numpy array as a theano graph node.

    Parameters
    ----------
    array : numpy.ndarray or object
        The value to convert. Anything that is not a ``numpy.ndarray``
        (for example an existing theano variable) is returned unchanged.
    constant : bool, optional
        When True, embed the array's values in the graph via
        ``theano.tensor.constant``. When False (the default), create a new
        symbolic variable of a ``theano.tensor.TensorType`` with the same
        dtype and number of dimensions as ``array``.

    Returns
    -------
    object
        The theano constant, the new symbolic variable, or ``array`` itself.
    """
    import theano

    if not isinstance(array, np.ndarray):
        return array

    if constant:
        return theano.tensor.constant(array)

    # Symbolic placeholder: matching dtype and rank, no axis broadcastable.
    placeholder_type = theano.tensor.TensorType(
        dtype=array.dtype,
        broadcastable=[False] * array.ndim,
    )
    return placeholder_type()
def build_expression(arrays, expr):
    """Compile ``expr`` into a reusable theano function.

    Parameters
    ----------
    arrays : sequence
        Template operands. Each one is passed through ``to_theano``, so numpy
        arrays become fresh symbolic variables. Items that are already theano
        objects (such as pre-evaluated constants) are kept as they are.
    expr : object
        Must provide ``_contract(operands, backend='theano')`` returning the
        symbolic output; presumably an opt_einsum ContractExpression (confirm
        against callers).

    Returns
    -------
    callable
        ``theano_contract(*arrays)``. It drops any ``TensorConstant`` entries
        from its arguments and feeds the remainder, in order, to the compiled
        theano function.
    """
    import theano

    symbolic_inputs = [to_theano(operand) for operand in arrays]
    symbolic_output = expr._contract(symbolic_inputs, backend='theano')

    constant_cls = theano.tensor.TensorConstant

    def _is_free(node):
        # Constants are already baked into the graph, so they must not be
        # declared as (or supplied to) function inputs.
        return not isinstance(node, constant_cls)

    compiled = theano.function(list(filter(_is_free, symbolic_inputs)), symbolic_output)

    def theano_contract(*arrays):
        return compiled(*filter(_is_free, arrays))

    return theano_contract
def evaluate_constants(const_arrays, expr):
    """Pre-compute the constant-only part of a contraction using theano.

    Parameters
    ----------
    const_arrays : sequence
        Constant operands. Each is passed through
        ``to_theano(..., constant=True)``, so numpy arrays become
        ``theano.tensor.constant`` nodes.
    expr : object
        Called as ``expr(*ops, backend='theano', evaluate_constants=True)``
        and expected to return ``(new_ops, new_contraction_list)``;
        presumably an opt_einsum ContractExpression (confirm against callers).

    Returns
    -------
    new_ops : list
        Same length as the ``new_ops`` returned by ``expr``.
        - ``None`` entries are kept as ``None``. Presumably these mark
          operands still to be supplied at call time; verify against
          ``expr``.
        - Every other entry is evaluated with ``.eval()`` and re-wrapped as
          a theano constant. It is not a shared variable.
    new_contraction_list
        Returned unchanged from ``expr``.
    """
    # Build the partial graph made only of the constant operands and let
    # ``expr`` contract whatever depends solely on them.
    theano_consts = [to_theano(arr, constant=True) for arr in const_arrays]
    new_ops, new_contraction_list = expr(*theano_consts, backend='theano', evaluate_constants=True)

    # Evaluate each partially contracted node to a concrete value and re-wrap
    # it as a theano *constant* (``theano.tensor.constant``, not a shared
    # variable). build_expression filters TensorConstant nodes out of the
    # compiled function's inputs, so these values are folded into the graph.
    new_ops = [None if op is None else to_theano(op.eval(), constant=True) for op in new_ops]

    return new_ops, new_contraction_list