Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/tensorflow.py: 21%
63 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
"""
Required functions for optimized contractions of numpy arrays using tensorflow.
"""
5import numpy as np
7from ..sharing import to_backend_cache_wrap
# Public entry points of this backend module.
__all__ = [
    "to_tensorflow",
    "build_expression",
    "evaluate_constants",
]

# Filled in by ``_get_tensorflow_and_device`` with a ``(tf, device, eager)``
# tuple on its first call; ``None`` until then.
_CACHED_TF_DEVICE = None
def _get_tensorflow_and_device():
    """Import tensorflow once and return a cached ``(tf, device, eager)`` tuple.

    ``tf`` is the imported tensorflow module. ``device`` is the name reported
    by ``tf.test.gpu_device_name()``, or ``'cpu'`` when that is empty.
    ``eager`` records whether eager execution was detected.

    The tuple is stored in the module-level ``_CACHED_TF_DEVICE`` on the first
    call and returned unchanged on every later call. A change of execution
    mode after that first call is therefore not picked up.
    """
    global _CACHED_TF_DEVICE

    if _CACHED_TF_DEVICE is not None:
        return _CACHED_TF_DEVICE

    import tensorflow as tf

    try:
        is_eager = tf.executing_eagerly()
    except AttributeError:
        # Fall back to the ``tf.contrib`` API, presumably for tensorflow
        # releases that predate ``tf.executing_eagerly``. Assume graph mode
        # if neither API is available.
        try:
            is_eager = tf.contrib.eager.in_eager_mode()
        except AttributeError:
            is_eager = False

    gpu_name = tf.test.gpu_device_name()
    _CACHED_TF_DEVICE = tf, gpu_name or 'cpu', is_eager
    return _CACHED_TF_DEVICE
@to_backend_cache_wrap(constants=True)
def to_tensorflow(array, constant=False):
    """Convert a numpy array into a tensorflow object usable for contraction.

    Inputs that are not numpy arrays (for example objects that are already
    tensorflow tensors) are returned unchanged.

    In eager mode, a numpy array is converted with ``tf.convert_to_tensor``
    inside ``tf.device(device)`` for the detected device. ``constant`` is
    ignored in that mode.

    In graph mode, a numpy array becomes a constant tensor when ``constant``
    is true. Otherwise it becomes a placeholder with the array's dtype and
    shape, whose value has to be fed when the graph is run.

    Parameters
    ----------
    array : object
        The operand to convert.
    constant : bool, optional
        Graph mode only: embed ``array`` as a constant instead of creating a
        placeholder for it.

    Returns
    -------
    object
        A tensorflow tensor or placeholder, or ``array`` itself if it is not
        a numpy array.
    """
    tf, device, eager = _get_tensorflow_and_device()

    if not isinstance(array, np.ndarray):
        return array

    if eager:
        with tf.device(device):
            return tf.convert_to_tensor(array)

    if constant:
        return tf.convert_to_tensor(array)

    # ``tf.placeholder`` is not part of the TensorFlow 2.x top-level namespace
    # (it lives under ``tf.compat.v1``). Graph mode on TF2, i.e. with eager
    # execution disabled, would otherwise fail here with AttributeError.
    placeholder = getattr(tf, 'placeholder', None)
    if placeholder is None:
        placeholder = tf.compat.v1.placeholder
    return placeholder(array.dtype, array.shape)
# Standard graph mode
def _get_default_session(tf):
    """Return tensorflow's active default session on both TF 1.x and TF 2.x.

    ``tf.get_default_session`` is absent from the TensorFlow 2.x top-level
    namespace and is looked up under ``tf.compat.v1`` instead.

    Raises
    ------
    RuntimeError
        If no default session is active. Without this check, the caller's
        ``session.run`` would fail with an opaque AttributeError on ``None``.
    """
    get_default_session = getattr(tf, 'get_default_session', None)
    if get_default_session is None:
        get_default_session = tf.compat.v1.get_default_session

    session = get_default_session()
    if session is None:
        raise RuntimeError(
            "The tensorflow backend in graph mode requires an active default "
            "session, e.g. run the contraction inside a ``with tf.Session():`` "
            "block (``tf.compat.v1.Session`` on TensorFlow 2.x)."
        )
    return session


def build_expression_graph(arrays, expr):
    """Build a graph-mode tensorflow function based on ``arrays`` and ``expr``.

    Each operand in ``arrays`` is passed through ``to_tensorflow``, so numpy
    arrays become placeholders and other operands (such as constant tensors)
    are kept as they are. The contraction graph is built once from these
    operands. The returned function feeds its arguments into the placeholders
    and evaluates the graph in the current default session.

    Parameters
    ----------
    arrays : sequence
        Operands whose dtypes and shapes define the placeholders.
    expr : object
        The contraction expression; its ``_contract`` method builds the graph
        (presumably an opt_einsum ``ContractExpression``).

    Returns
    -------
    callable
        ``tensorflow_contract(*operands)`` returning the result of
        ``session.run`` on the built graph.
    """
    tf, _, _ = _get_tensorflow_and_device()

    placeholders = [to_tensorflow(array) for array in arrays]
    graph = expr._contract(placeholders, backend='tensorflow')

    def tensorflow_contract(*operands):
        session = _get_default_session(tf)
        # Only feed real placeholders; constant tensors already carry values.
        feed_dict = {p: a for p, a in zip(placeholders, operands) if p.op.type == 'Placeholder'}
        return session.run(graph, feed_dict=feed_dict)

    return tensorflow_contract


def evaluate_constants_graph(const_arrays, expr):
    """Convert constant arguments to tensorflow constants and pre-contract them.

    Any contraction that involves only constants is performed by building the
    partial graph and evaluating it in the current default session. The
    results are then wrapped again as tensorflow constants.

    Parameters
    ----------
    const_arrays : sequence
        Operands to treat as constants.
    expr : object
        The contraction expression, called with ``evaluate_constants=True``.

    Returns
    -------
    tuple
        ``(new_ops, new_contraction_list)``. ``new_ops`` keeps ``None`` in the
        positions of non-constant operands.
    """
    tf, _, _ = _get_tensorflow_and_device()

    # compute the partial graph of new inputs
    const_tensors = [to_tensorflow(x, constant=True) for x in const_arrays]
    new_ops, new_contraction_list = expr(*const_tensors, backend='tensorflow', evaluate_constants=True)

    # Evaluate the new inputs and convert them back to tensorflow constants,
    # keeping None as the marker for non-constant operands.
    session = _get_default_session(tf)
    new_consts = iter(session.run([x for x in new_ops if x is not None]))
    new_ops = [None if x is None else to_tensorflow(next(new_consts), constant=True) for x in new_ops]

    return new_ops, new_contraction_list
# Eager execution mode
def build_expression_eager(_, expr):
    """Build an eager-mode tensorflow function based on ``expr``.

    The first argument (the example operands) is unused because nothing needs
    to be pre-built in eager mode. The returned function converts its operands
    with ``to_tensorflow``, contracts them with the tensorflow backend and
    returns ``.numpy()`` of the result.
    """
    def tensorflow_eager_contract(*arrays):
        tensors = list(map(to_tensorflow, arrays))
        result = expr._contract(tensors, backend='tensorflow')
        return result.numpy()

    return tensorflow_eager_contract
def evaluate_constants_eager(const_arrays, expr):
    """Convert constant arguments to eager tensorflow tensors and pre-contract.

    Each constant is converted with ``to_tensorflow``. ``expr`` is then invoked
    on the tensorflow backend with ``evaluate_constants=True``, so that any
    possible constant contractions are performed. Whatever that call returns
    is passed straight back to the caller.
    """
    converted = [to_tensorflow(arr) for arr in const_arrays]
    return expr(*converted, backend='tensorflow', evaluate_constants=True)
# Dispatch to eager or graph mode
def build_expression(arrays, expr):
    """Build a tensorflow contraction function for ``expr``.

    Delegates to ``build_expression_eager`` when tensorflow was detected to be
    in eager mode, and to ``build_expression_graph`` otherwise.
    """
    is_eager = _get_tensorflow_and_device()[2]
    if is_eager:
        return build_expression_eager(arrays, expr)
    return build_expression_graph(arrays, expr)
def evaluate_constants(const_arrays, expr):
    """Pre-evaluate the constant part of ``expr`` with the tensorflow backend.

    Delegates to ``evaluate_constants_eager`` when tensorflow was detected to
    be in eager mode, and to ``evaluate_constants_graph`` otherwise.
    """
    is_eager = _get_tensorflow_and_device()[2]
    if is_eager:
        return evaluate_constants_eager(const_arrays, expr)
    return evaluate_constants_graph(const_arrays, expr)