Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/tensorflow.py: 21%
63 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:41 +0000
"""
Required functions for optimized contractions of numpy arrays using tensorflow.
"""
5import numpy as np
7from ..sharing import to_backend_cache_wrap
__all__ = [
    "to_tensorflow",
    "build_expression",
    "evaluate_constants",
]

# Filled in on first use by ``_get_tensorflow_and_device`` with the tuple
# ``(tensorflow_module, device_name, eager_flag)``; ``None`` until then.
_CACHED_TF_DEVICE = None
def _get_tensorflow_and_device():
    """Return ``(tensorflow_module, device_name, eager_flag)``, computing it once.

    On the first call this imports tensorflow and detects whether eager
    execution is active. It then chooses a device: the name returned by
    ``tf.test.gpu_device_name()``, or ``"cpu"`` when that name is empty.
    The triple is stored in the module-level ``_CACHED_TF_DEVICE``, and
    every later call returns that same object.
    """
    global _CACHED_TF_DEVICE

    if _CACHED_TF_DEVICE is not None:
        return _CACHED_TF_DEVICE

    import tensorflow as tf

    # Probe eager mode through whichever API this tensorflow exposes; if
    # neither attribute exists, treat execution as graph mode.
    try:
        is_eager = tf.executing_eagerly()
    except AttributeError:
        try:
            is_eager = tf.contrib.eager.in_eager_mode()
        except AttributeError:
            is_eager = False

    # An empty GPU name means no GPU was reported; fall back to the CPU.
    gpu_name = tf.test.gpu_device_name()
    _CACHED_TF_DEVICE = tf, gpu_name or "cpu", is_eager
    return _CACHED_TF_DEVICE
@to_backend_cache_wrap(constants=True)
def to_tensorflow(array, constant=False):
    """Convert a numpy array into a tensorflow object; pass other inputs through.

    - Eager mode: a numpy array becomes a tensor created under
      ``tf.device(device)``.
    - Graph mode with ``constant=True``: a numpy array becomes a tensor via
      ``tf.convert_to_tensor``.
    - Graph mode otherwise: a numpy array becomes a ``tf.placeholder`` with
      the array's dtype and shape, to be fed at run time.

    Any input that is not a ``np.ndarray`` is returned unchanged. The
    ``constant`` flag only has an effect in graph mode.
    """
    # Resolved before the type check, so tensorflow is loaded (and the
    # device/eager cache populated) even when ``array`` is passed through.
    tf, device, eager = _get_tensorflow_and_device()

    if not isinstance(array, np.ndarray):
        return array

    if eager:
        with tf.device(device):
            return tf.convert_to_tensor(array)

    if constant:
        return tf.convert_to_tensor(array)

    return tf.placeholder(array.dtype, array.shape)
# Standard graph mode
def build_expression_graph(arrays, expr):
    """Build a tensorflow function based on ``arrays`` and ``expr``.

    Each entry of ``arrays`` is converted with ``to_tensorflow``, which
    produces a placeholder for numpy arrays in graph mode. ``expr`` is then
    contracted once over those inputs to build the graph.

    Parameters
    ----------
    arrays : sequence
        Template operands used to create the graph's inputs.
    expr
        Contraction expression; its ``_contract`` method builds the graph.

    Returns
    -------
    callable
        ``tensorflow_contract(*arrays)``. It feeds its arguments positionally
        into the graph's placeholder inputs and returns the result of
        ``session.run`` in the default tensorflow session.

    Raises
    ------
    RuntimeError
        Raised by the returned function when no default tensorflow session
        is active at call time.
    """
    tf, _, _ = _get_tensorflow_and_device()

    placeholders = [to_tensorflow(array) for array in arrays]
    graph = expr._contract(placeholders, backend="tensorflow")

    def tensorflow_contract(*arrays):
        session = tf.get_default_session()
        if session is None:
            # Without this check the failure surfaces as an opaque
            # "'NoneType' object has no attribute 'run'".
            raise RuntimeError(
                "No default tensorflow session is active; call the contraction "
                "inside 'with tf.Session() as sess:' or 'with sess.as_default():'."
            )
        # only want to feed placeholders - constant tensors already have values
        feed_dict = {p: a for p, a in zip(placeholders, arrays) if p.op.type == "Placeholder"}
        return session.run(graph, feed_dict=feed_dict)

    return tensorflow_contract
def evaluate_constants_graph(const_arrays, expr):
    """Convert constant arguments to tensorflow constants, and perform any
    possible constant contractions. Requires evaluating a tensorflow graph.

    Parameters
    ----------
    const_arrays : sequence
        Constant operands; each is converted with
        ``to_tensorflow(x, constant=True)``.
    expr
        Contraction expression, called with ``evaluate_constants=True`` to
        obtain the partially contracted operands and the remaining
        contraction list.

    Returns
    -------
    tuple
        ``(new_ops, new_contraction_list)``. ``new_ops`` has one entry per
        operand returned by ``expr``: ``None`` where ``expr`` returned
        ``None`` (a non-constant slot), otherwise the evaluated value
        converted back with ``to_tensorflow(..., constant=True)``.

    Raises
    ------
    RuntimeError
        If no default tensorflow session is active.
    """
    tf, _, _ = _get_tensorflow_and_device()

    # compute the partial graph of new inputs
    const_arrays = [to_tensorflow(x, constant=True) for x in const_arrays]
    new_ops, new_contraction_list = expr(*const_arrays, backend="tensorflow", evaluate_constants=True)

    session = tf.get_default_session()
    if session is None:
        # Without this check the failure surfaces as an opaque
        # "'NoneType' object has no attribute 'run'".
        raise RuntimeError(
            "No default tensorflow session is active; evaluate constants "
            "inside 'with tf.Session() as sess:' or 'with sess.as_default():'."
        )

    # evaluate the new inputs and convert back to tensorflow, maintaining None as non-consts
    new_consts = iter(session.run([x for x in new_ops if x is not None]))
    new_ops = [None if x is None else to_tensorflow(next(new_consts), constant=True) for x in new_ops]

    return new_ops, new_contraction_list
# Eager execution mode
def build_expression_eager(_, expr):
    """Build an eager-mode tensorflow function for ``expr``.

    The first argument (the template arrays) is accepted for interface
    parity with the graph builder and is not used here.

    The returned ``tensorflow_eager_contract(*arrays)`` converts each
    argument with ``to_tensorflow`` and contracts them with ``expr`` on the
    ``"tensorflow"`` backend. It returns ``.numpy()`` of the result.
    """

    def tensorflow_eager_contract(*arrays):
        operands = [to_tensorflow(operand) for operand in arrays]
        contracted = expr._contract(operands, backend="tensorflow")
        return contracted.numpy()

    return tensorflow_eager_contract
def evaluate_constants_eager(const_arrays, expr):
    """Convert constant operands to eager tensorflow tensors and let ``expr``
    pre-contract whatever it can.

    Returns whatever ``expr(..., backend="tensorflow", evaluate_constants=True)``
    returns for the converted operands.
    """
    tensors = [to_tensorflow(array) for array in const_arrays]
    return expr(*tensors, backend="tensorflow", evaluate_constants=True)
# Dispatch to eager or graph mode
def build_expression(arrays, expr):
    """Build a tensorflow contraction function for ``expr``.

    Delegates to ``build_expression_eager`` when tensorflow is executing
    eagerly, otherwise to ``build_expression_graph``.
    """
    _, _, eager = _get_tensorflow_and_device()
    if eager:
        return build_expression_eager(arrays, expr)
    return build_expression_graph(arrays, expr)
def evaluate_constants(const_arrays, expr):
    """Pre-evaluate constant operands of ``expr`` with tensorflow.

    Delegates to ``evaluate_constants_eager`` when tensorflow is executing
    eagerly, otherwise to ``evaluate_constants_graph``.
    """
    _, _, eager = _get_tensorflow_and_device()
    if eager:
        return evaluate_constants_eager(const_arrays, expr)
    return evaluate_constants_graph(const_arrays, expr)