Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/tensorflow.py: 21%

63 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:41 +0000

1""" 

2Required functions for optimized contractions of numpy arrays using tensorflow. 

3""" 

4 

5import numpy as np 

6 

7from ..sharing import to_backend_cache_wrap 

8 

# Public API of this backend module.
__all__ = ["to_tensorflow", "build_expression", "evaluate_constants"]

# Filled in on first use by ``_get_tensorflow_and_device`` with the tuple
# ``(tensorflow module, default device name, eager flag)``; ``None`` until then.
_CACHED_TF_DEVICE = None

12 

13 

def _get_tensorflow_and_device():
    """Return the cached ``(tf, device, eager)`` triple, importing tensorflow lazily.

    On the first call tensorflow is imported, the execution mode is detected,
    and a default device is chosen: the name returned by
    ``tf.test.gpu_device_name()``, or ``"cpu"`` when that name is empty. The
    triple is stored in the module-level ``_CACHED_TF_DEVICE`` and returned
    as-is on every later call.
    """
    global _CACHED_TF_DEVICE

    # Fast path: everything was already resolved on an earlier call.
    if _CACHED_TF_DEVICE is not None:
        return _CACHED_TF_DEVICE

    import tensorflow as tf

    # Prefer ``tf.executing_eagerly``; fall back to the contrib helper when it
    # is absent, and assume graph mode if neither is available.
    try:
        is_eager = tf.executing_eagerly()
    except AttributeError:
        try:
            is_eager = tf.contrib.eager.in_eager_mode()
        except AttributeError:
            is_eager = False

    # An empty device name means no GPU was reported.
    target_device = tf.test.gpu_device_name() or "cpu"

    _CACHED_TF_DEVICE = tf, target_device, is_eager
    return _CACHED_TF_DEVICE

35 

36 

@to_backend_cache_wrap(constants=True)
def to_tensorflow(array, constant=False):
    """Convert a numpy array into a tensorflow object usable in a contraction.

    * Eager mode: a numpy array becomes a tensor created on the device chosen
      by ``_get_tensorflow_and_device``; ``constant`` has no effect.
    * Graph mode: a numpy array becomes a constant tensor when ``constant`` is
      true, otherwise a placeholder with the array's dtype and shape.

    Anything that is not a ``numpy.ndarray`` (for example an existing tensor)
    is returned unchanged.
    """
    tf, device, eager = _get_tensorflow_and_device()

    # Non-numpy inputs pass straight through in both execution modes.
    if not isinstance(array, np.ndarray):
        return array

    if eager:
        with tf.device(device):
            return tf.convert_to_tensor(array)

    if constant:
        return tf.convert_to_tensor(array)

    # ``tf.placeholder`` is a TF 1.x top-level name; TF 2.x only provides it as
    # ``tf.compat.v1.placeholder`` (graph mode with eager execution disabled).
    placeholder = getattr(tf, "placeholder", None) or tf.compat.v1.placeholder
    return placeholder(array.dtype, array.shape)

56 

57 

58# Standard graph mode 

59 

60 

def build_expression_graph(arrays, expr):
    """Build a graph-mode tensorflow function based on ``arrays`` and ``expr``.

    Each input is converted with ``to_tensorflow`` (placeholders for numpy
    arrays) and the contraction graph is built once. The returned callable
    feeds its positional arguments into the placeholders and evaluates the
    graph in the default tensorflow session.

    Raises (from the returned callable):
        RuntimeError: if no default tensorflow session is registered.
    """
    tf, _, _ = _get_tensorflow_and_device()

    placeholders = [to_tensorflow(array) for array in arrays]
    graph = expr._contract(placeholders, backend="tensorflow")

    # TF 2.x only exposes the session API under ``tf.compat.v1``.
    get_default_session = getattr(tf, "get_default_session", None) or tf.compat.v1.get_default_session

    def tensorflow_contract(*arrays):
        session = get_default_session()
        if session is None:
            # Without this check the failure is an opaque AttributeError on None.
            raise RuntimeError(
                "tensorflow graph mode requires a default session, but none is registered"
            )
        # only want to feed placeholders - constant tensors already have values
        feed_dict = {p: a for p, a in zip(placeholders, arrays) if p.op.type == "Placeholder"}
        return session.run(graph, feed_dict=feed_dict)

    return tensorflow_contract

75 

76 

def evaluate_constants_graph(const_arrays, expr):
    """Convert constant arguments to tensorflow constants, and perform any
    possible constant contractions. Requires evaluating a tensorflow graph.

    ``expr`` is called with ``evaluate_constants=True`` and yields
    ``(new_ops, new_contraction_list)``. The non-``None`` entries of
    ``new_ops`` are evaluated in the default session and converted back into
    tensorflow constants; ``None`` entries (non-constant slots) stay ``None``.

    Raises:
        RuntimeError: if no default tensorflow session is registered.
    """
    tf, _, _ = _get_tensorflow_and_device()

    # compute the partial graph of new inputs
    const_arrays = [to_tensorflow(x, constant=True) for x in const_arrays]
    new_ops, new_contraction_list = expr(*const_arrays, backend="tensorflow", evaluate_constants=True)

    # TF 2.x only exposes the session API under ``tf.compat.v1``.
    get_default_session = getattr(tf, "get_default_session", None) or tf.compat.v1.get_default_session
    session = get_default_session()
    if session is None:
        # Without this check the failure is an opaque AttributeError on None.
        raise RuntimeError(
            "tensorflow graph mode requires a default session, but none is registered"
        )

    # evaluate the new inputs and convert back to tensorflow, maintaining None as non-consts
    new_consts = iter(session.run([x for x in new_ops if x is not None]))
    new_ops = [None if x is None else to_tensorflow(next(new_consts), constant=True) for x in new_ops]

    return new_ops, new_contraction_list

93 

94 

95# Eager execution mode 

96 

97 

def build_expression_eager(_, expr):
    """Build an eager tensorflow function for ``expr``.

    The first argument (the example arrays) is not used. The returned callable
    converts its inputs with ``to_tensorflow`` on every call, contracts them
    with the tensorflow backend and returns the result's ``.numpy()`` value.
    """

    def tensorflow_eager_contract(*arrays):
        tensors = list(map(to_tensorflow, arrays))
        contracted = expr._contract(tensors, backend="tensorflow")
        return contracted.numpy()

    return tensorflow_eager_contract

105 

106 

def evaluate_constants_eager(const_arrays, expr):
    """Turn constant arguments into eager tensorflow tensors and let ``expr``
    pre-contract whatever it can.

    Returns whatever ``expr`` returns when called with ``evaluate_constants=True``.
    """
    tensors = [to_tensorflow(const) for const in const_arrays]
    return expr(*tensors, backend="tensorflow", evaluate_constants=True)

112 

113 

114# Dispatch to eager or graph mode 

115 

116 

def build_expression(arrays, expr):
    """Build a tensorflow contraction function, choosing the eager or graph
    implementation according to the detected execution mode."""
    is_eager = _get_tensorflow_and_device()[2]
    if is_eager:
        return build_expression_eager(arrays, expr)
    return build_expression_graph(arrays, expr)

121 

122 

def evaluate_constants(const_arrays, expr):
    """Pre-evaluate constant operands, choosing the eager or graph
    implementation according to the detected execution mode."""
    is_eager = _get_tensorflow_and_device()[2]
    if is_eager:
        return evaluate_constants_eager(const_arrays, expr)
    return evaluate_constants_graph(const_arrays, expr)