Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/tensorflow.py: 21%

63 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1""" 

2Required functions for optimized contractions of numpy arrays using tensorflow. 

3""" 

4 

5import numpy as np 

6 

7from ..sharing import to_backend_cache_wrap 

8 

# Public names exported by a star-import of this backend module.
__all__ = [
    "to_tensorflow",
    "build_expression",
    "evaluate_constants",
]

# Filled in by ``_get_tensorflow_and_device`` on first use with a
# ``(tf, device, eager)`` tuple; ``None`` means tensorflow has not been
# imported and probed yet.
_CACHED_TF_DEVICE = None

12 

13 

def _get_tensorflow_and_device():
    """Import tensorflow lazily and return the cached ``(tf, device, eager)``.

    On the first call tensorflow is imported, the execution mode and target
    device are probed, and the result is stored in ``_CACHED_TF_DEVICE``;
    later calls return that stored tuple unchanged.

    Returns
    -------
    tuple
        ``(tf, device, eager)``:

        * ``tf`` -- the imported tensorflow module;
        * ``device`` -- the string from ``tf.test.gpu_device_name()``, or
          ``'cpu'`` when that string is empty;
        * ``eager`` -- whether tensorflow reports eager execution.
    """
    global _CACHED_TF_DEVICE

    if _CACHED_TF_DEVICE is not None:
        return _CACHED_TF_DEVICE

    import tensorflow as tf

    # Probe eager mode with ``tf.executing_eagerly`` first, then the contrib
    # helper; if neither attribute exists, assume graph mode.
    try:
        is_eager = tf.executing_eagerly()
    except AttributeError:
        try:
            is_eager = tf.contrib.eager.in_eager_mode()
        except AttributeError:
            is_eager = False

    # An empty device name means no GPU was reported.
    target_device = tf.test.gpu_device_name() or 'cpu'

    _CACHED_TF_DEVICE = tf, target_device, is_eager
    return _CACHED_TF_DEVICE

35 

36 

@to_backend_cache_wrap(constants=True)
def to_tensorflow(array, constant=False):
    """Convert a numpy array into a tensorflow object usable for contraction.

    Parameters
    ----------
    array : object
        Operand to convert. Only ``numpy.ndarray`` instances are converted;
        anything else is returned unchanged.
    constant : bool, optional
        Only consulted in graph mode. If True the array's values are embedded
        with ``tf.convert_to_tensor``; otherwise a placeholder of the same
        dtype and shape is created instead (its values must be fed later).

    Returns
    -------
    object
        * eager mode: ``tf.convert_to_tensor(array)`` created under
          ``tf.device(device)`` for the cached device;
        * graph mode, ``constant=True``: ``tf.convert_to_tensor(array)``;
        * graph mode, ``constant=False``: a placeholder with
          ``array.dtype`` and ``array.shape``;
        * non-ndarray input: ``array`` itself.
    """
    tf, device, eager = _get_tensorflow_and_device()

    if not isinstance(array, np.ndarray):
        return array

    if eager:
        with tf.device(device):
            return tf.convert_to_tensor(array)

    if constant:
        return tf.convert_to_tensor(array)

    # ``tf.placeholder`` is only a top-level name in tensorflow 1.x; under
    # tensorflow 2.x graph mode (eager disabled) it lives in ``tf.compat.v1``.
    placeholder = getattr(tf, 'placeholder', None)
    if placeholder is None:
        placeholder = tf.compat.v1.placeholder
    return placeholder(array.dtype, array.shape)

57 

58 

59# Standard graph mode 

60 

61 

def build_expression_graph(arrays, expr):
    """Build a callable that evaluates ``expr`` via a tensorflow graph.

    Every entry of ``arrays`` is passed through ``to_tensorflow`` (without
    ``constant=True``), so numpy arrays become placeholders. The contraction
    graph is then built once from those inputs with ``expr._contract``.

    Parameters
    ----------
    arrays : sequence
        Template operands used to create the graph inputs.
    expr : object
        Contraction expression; only its ``_contract`` method is used here.

    Returns
    -------
    callable
        ``tensorflow_contract(*arrays)``, which looks up the default tensorflow
        session and runs the graph in it. The given arrays are fed
        positionally into the placeholder inputs. It returns the result of
        ``session.run``, and raises ``RuntimeError`` if no default session is
        active when it is called.
    """
    tf, _, _ = _get_tensorflow_and_device()

    placeholders = [to_tensorflow(array) for array in arrays]
    graph = expr._contract(placeholders, backend='tensorflow')

    # The session API is top-level in tensorflow 1.x but only available under
    # ``tf.compat.v1`` in tensorflow 2.x.
    get_default_session = getattr(tf, 'get_default_session', None)
    if get_default_session is None:
        get_default_session = tf.compat.v1.get_default_session

    def tensorflow_contract(*values):
        session = get_default_session()
        if session is None:
            raise RuntimeError(
                "No default tensorflow session is active: evaluate the "
                "contraction inside a 'with session:' or "
                "'with session.as_default():' block.")
        # only want to feed placeholders - constant tensors already have values
        feed_dict = {p: v for p, v in zip(placeholders, values)
                     if p.op.type == 'Placeholder'}
        return session.run(graph, feed_dict=feed_dict)

    return tensorflow_contract

77 

78 

def evaluate_constants_graph(const_arrays, expr):
    """Perform the constant-only part of a contraction in graph mode.

    Each array in ``const_arrays`` is converted with
    ``to_tensorflow(..., constant=True)``, and ``expr`` is asked for the
    partial contraction with ``evaluate_constants=True``. The resulting
    tensors are evaluated in the default tensorflow session, and the numeric
    results are wrapped back into constant tensors.

    Parameters
    ----------
    const_arrays : sequence
        Constant operands.
    expr : callable
        Contraction expression; called with ``backend='tensorflow'`` and
        ``evaluate_constants=True`` and unpacked as
        ``(new_ops, new_contraction_list)``.

    Returns
    -------
    tuple
        ``(new_ops, new_contraction_list)``. Each non-``None`` entry of
        ``new_ops`` is replaced by a constant tensor holding its evaluated
        value. ``None`` entries, which mark non-constant operands, stay
        ``None``.

    Raises
    ------
    RuntimeError
        If there are constant results to evaluate but no default tensorflow
        session is active.
    """
    tf, _, _ = _get_tensorflow_and_device()

    # compute the partial graph of new inputs
    const_arrays = [to_tensorflow(x, constant=True) for x in const_arrays]
    new_ops, new_contraction_list = expr(*const_arrays, backend='tensorflow', evaluate_constants=True)

    # evaluate the new inputs and convert back to tensorflow, maintaining None as non-consts
    fetches = [op for op in new_ops if op is not None]
    if fetches:
        # The session API is top-level in tensorflow 1.x but only available
        # under ``tf.compat.v1`` in tensorflow 2.x.
        get_default_session = getattr(tf, 'get_default_session', None)
        if get_default_session is None:
            get_default_session = tf.compat.v1.get_default_session
        session = get_default_session()
        if session is None:
            raise RuntimeError(
                "No default tensorflow session is active: evaluate constants "
                "inside a 'with session:' or 'with session.as_default():' block.")
        new_consts = iter(session.run(fetches))
    else:
        # Nothing constant to evaluate, so no session is needed.
        new_consts = iter(())

    new_ops = [None if op is None else to_tensorflow(next(new_consts), constant=True)
               for op in new_ops]

    return new_ops, new_contraction_list

95 

96 

97# Eager execution mode 

98 

99 

def build_expression_eager(_, expr):
    """Build an eager tensorflow function for ``expr``.

    The first argument is unused here; it is accepted so the signature
    matches ``build_expression_graph``, which uses it to create the graph
    inputs.

    Each call of the returned function converts its arguments with
    ``to_tensorflow``, contracts them through
    ``expr._contract(..., backend='tensorflow')`` and returns the
    ``.numpy()`` value of the result.
    """
    def tensorflow_eager_contract(*arrays):
        operands = [to_tensorflow(operand) for operand in arrays]
        contracted = expr._contract(operands, backend='tensorflow')
        return contracted.numpy()

    return tensorflow_eager_contract

107 

108 

def evaluate_constants_eager(const_arrays, expr):
    """Convert constant operands for eager mode and pre-contract them.

    Each entry of ``const_arrays`` is passed through ``to_tensorflow``.
    ``expr`` is then called on the converted operands with
    ``backend='tensorflow'`` and ``evaluate_constants=True``, and its
    result is returned unchanged.
    """
    eager_operands = [to_tensorflow(item) for item in const_arrays]
    return expr(*eager_operands, backend='tensorflow', evaluate_constants=True)

114 

115 

116# Dispatch to eager or graph mode 

117 

118 

def build_expression(arrays, expr):
    """Build a contraction callable for ``expr``, picking eager or graph mode.

    Delegates to ``build_expression_eager`` when the cached tensorflow state
    reports eager execution, and to ``build_expression_graph`` otherwise.
    """
    _tf, _device, is_eager = _get_tensorflow_and_device()
    if is_eager:
        return build_expression_eager(arrays, expr)
    return build_expression_graph(arrays, expr)

123 

124 

def evaluate_constants(const_arrays, expr):
    """Pre-contract constant operands, picking eager or graph mode.

    Delegates to ``evaluate_constants_eager`` when the cached tensorflow state
    reports eager execution, and to ``evaluate_constants_graph`` otherwise.
    """
    _tf, _device, is_eager = _get_tensorflow_and_device()
    if is_eager:
        return evaluate_constants_eager(const_arrays, expr)
    return evaluate_constants_graph(const_arrays, expr)