Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/utils_impl.py: 42%

67 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""SavedModel utility functions implementation.""" 

16 

17from tensorflow.core.framework import types_pb2 

18from tensorflow.core.protobuf import meta_graph_pb2 

19from tensorflow.core.protobuf import struct_pb2 

20from tensorflow.python.eager import context 

21from tensorflow.python.framework import byte_swap_tensor as bst 

22from tensorflow.python.framework import composite_tensor 

23from tensorflow.python.framework import dtypes 

24from tensorflow.python.framework import ops 

25from tensorflow.python.framework import sparse_tensor 

26from tensorflow.python.framework import tensor_shape 

27from tensorflow.python.ops import resource_variable_ops 

28from tensorflow.python.saved_model import nested_structure_coder 

29from tensorflow.python.util import deprecation 

30from tensorflow.python.util import nest 

31from tensorflow.python.util.tf_export import tf_export 

32 

33 

34# TensorInfo helpers. 

35_DEPRECATION_MSG = ( 

36 "This API was designed for TensorFlow v1. See " 

37 "https://www.tensorflow.org/guide/migrate for instructions on how to " 

38 "migrate your code to TensorFlow v2.") 

39 

40 

@tf_export(
    v1=["saved_model.build_tensor_info", "saved_model.utils.build_tensor_info"])
@deprecation.deprecated(None, _DEPRECATION_MSG)
def build_tensor_info(tensor):
  """Utility function to build TensorInfo proto from a Tensor.

  Graph-mode entry point; the actual proto construction is delegated to
  `build_tensor_info_internal`.

  Args:
    tensor: Tensor or SparseTensor whose name, dtype and shape are used to
      build the TensorInfo. For SparseTensors, the names of the three
      constituent Tensors are used. Other CompositeTensors are encoded via
      their type spec plus one TensorInfo per flattened component.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.

  Raises:
    RuntimeError: If eager execution is enabled.

  @compatibility(TF2)
  This API is not compatible with eager execution as `tensor` needs to be a
  graph tensor, and there is no replacement for it in TensorFlow 2.x. To start
  writing programs using TensorFlow 2.x, please refer to the [Effective
  TensorFlow 2](https://www.tensorflow.org/guide/effective_tf2) guide.
  @end_compatibility
  """
  # Only graph tensors carry the names a TensorInfo records, so refuse to run
  # under eager execution.
  if not context.executing_eagerly():
    return build_tensor_info_internal(tensor)
  raise RuntimeError("`build_tensor_info` is not supported in eager "
                     "execution.")

69 

70 

def build_tensor_info_internal(tensor):
  """Utility function to build TensorInfo proto from a Tensor.

  Unlike `build_tensor_info`, this does not check for eager execution.

  Args:
    tensor: A dense Tensor, a SparseTensor, a ResourceVariable, or another
      CompositeTensor.

  Returns:
    A `meta_graph_pb2.TensorInfo`.
    - Generic CompositeTensors (anything composite other than SparseTensor or
      ResourceVariable) are delegated to `_build_composite_tensor_info_internal`.
    - SparseTensors use the `coo_sparse` encoding.
    - Everything else is recorded by `name`.
  """
  needs_composite_encoding = (
      isinstance(tensor, composite_tensor.CompositeTensor) and
      not isinstance(tensor, (sparse_tensor.SparseTensor,
                              resource_variable_ops.ResourceVariable)))
  if needs_composite_encoding:
    return _build_composite_tensor_info_internal(tensor)

  info = meta_graph_pb2.TensorInfo(
      dtype=dtypes.as_dtype(tensor.dtype).as_datatype_enum,
      tensor_shape=tensor.get_shape().as_proto())

  if not isinstance(tensor, sparse_tensor.SparseTensor):
    # Dense tensor (or resource variable): referenced by its graph name.
    info.name = tensor.name
    return info

  # Writing through the submessage reference populates the `coo_sparse`
  # oneof member of `info`.
  coo = info.coo_sparse
  coo.values_tensor_name = tensor.values.name
  coo.indices_tensor_name = tensor.indices.name
  coo.dense_shape_tensor_name = tensor.dense_shape.name
  return info

88 

89 

def _build_composite_tensor_info_internal(tensor):
  """Utility function to build TensorInfo proto from a CompositeTensor.

  The result uses the `composite_tensor` encoding: the tensor's type spec
  (serialized with `nested_structure_coder`) plus one TensorInfo per component
  produced by fully flattening the tensor with `expand_composites=True`.

  Args:
    tensor: A CompositeTensor exposing a `_type_spec`.

  Returns:
    A `meta_graph_pb2.TensorInfo` with its `composite_tensor` field populated.
  """
  info = meta_graph_pb2.TensorInfo()
  composite_info = info.composite_tensor

  type_spec = tensor._type_spec  # pylint: disable=protected-access
  encoded_spec = nested_structure_coder.encode_structure(type_spec)
  composite_info.type_spec.CopyFrom(encoded_spec.type_spec_value)

  for leaf in nest.flatten(tensor, expand_composites=True):
    leaf_info = build_tensor_info_internal(leaf)
    composite_info.components.add().CopyFrom(leaf_info)
  return info

100 

101 

def build_tensor_info_from_op(op):
  """Utility function to build TensorInfo proto from an Op.

  Note that this function should be used with caution. It is strictly restricted
  to TensorFlow internal use-cases only. Please make sure you do need it before
  using it.

  This utility function overloads the TensorInfo proto by setting the name to
  the Op's name, dtype to DT_INVALID and tensor_shape to an unknown shape
  (`tensor_shape.unknown_shape()`). One typical usage is for the Op of the call
  site for the defunned function:
  ```python
    @function.defun
    def some_variable_initialization_fn(value_a, value_b):
      a = value_a
      b = value_b

    value_a = constant_op.constant(1, name="a")
    value_b = constant_op.constant(2, name="b")
    op_info = utils.build_tensor_info_from_op(
        some_variable_initialization_fn(value_a, value_b))
  ```

  Args:
    op: An Op whose name is used to build the TensorInfo. The name that points
      to the Op could be fetched at run time in the Loader session.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.

  Raises:
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "`build_tensor_info_from_op` is not supported in eager execution.")
  # DT_INVALID plus an unknown shape marks this entry as naming an op rather
  # than a real tensor.
  return meta_graph_pb2.TensorInfo(
      dtype=types_pb2.DT_INVALID,
      tensor_shape=tensor_shape.unknown_shape().as_proto(),
      name=op.name)

141 

142 

@tf_export(v1=["saved_model.get_tensor_from_tensor_info",
               "saved_model.utils.get_tensor_from_tensor_info"])
@deprecation.deprecated(None, _DEPRECATION_MSG)
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Returns the Tensor or CompositeTensor described by a TensorInfo proto.

  Args:
    tensor_info: A TensorInfo proto describing a Tensor or SparseTensor or
      CompositeTensor.
    graph: The tf.Graph in which tensors are looked up. If None, the
      current default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
      string before lookup.

  Returns:
    The Tensor or SparseTensor or CompositeTensor in `graph` described by
    `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
    ValueError: If `tensor_info` is malformed.
  """
  lookup_graph = graph or ops.get_default_graph()

  def _lookup(tensor_name):
    # Apply the optional import scope before resolving the name in the graph.
    scoped_name = ops.prepend_name_scope(tensor_name, import_scope=import_scope)
    return lookup_graph.get_tensor_by_name(scoped_name)

  encoding = tensor_info.WhichOneof("encoding")

  if encoding == "name":
    return _lookup(tensor_info.name)

  if encoding == "coo_sparse":
    coo = tensor_info.coo_sparse
    indices = _lookup(coo.indices_tensor_name)
    values = _lookup(coo.values_tensor_name)
    dense_shape = _lookup(coo.dense_shape_tensor_name)
    return sparse_tensor.SparseTensor(indices, values, dense_shape)

  if encoding == "composite_tensor":
    composite_info = tensor_info.composite_tensor
    # Rebuild the TypeSpec, then pack the looked-up component tensors into it.
    spec = nested_structure_coder.decode_proto(
        struct_pb2.StructuredValue(type_spec_value=composite_info.type_spec))
    leaves = [_lookup(leaf.name) for leaf in composite_info.components]
    return nest.pack_sequence_as(spec, leaves, expand_composites=True)

  raise ValueError(f"Invalid TensorInfo.encoding: {encoding}. Expected `"
                   "coo_sparse`, `composite_tensor`, or `name` for a dense "
                   "tensor.")

188 

189 

def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Returns the element in the graph described by a TensorInfo proto.

  Unlike `get_tensor_from_tensor_info`, this resolves only
  `tensor_info.name`, via `Graph.as_graph_element`, so the name may refer to
  either an op or a tensor.

  Args:
    tensor_info: A TensorInfo proto describing an Op or Tensor by name.
    graph: The tf.Graph in which tensors are looked up. If None, the current
      default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
      string before lookup.

  Returns:
    Op or tensor in `graph` described by `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to an op or tensor in `graph`
  """
  if not graph:
    graph = ops.get_default_graph()
  element_name = ops.prepend_name_scope(
      tensor_info.name, import_scope=import_scope)
  return graph.as_graph_element(element_name)

209 

210 

def swap_function_tensor_content(meta_graph_def, from_endiness, to_endiness):
  """Forwards to `byte_swap_tensor.swap_tensor_content_in_graph_function`.

  NOTE(review): presumably this converts the byte order of tensor contents
  held in `meta_graph_def`'s functions from `from_endiness` to `to_endiness`,
  modifying `meta_graph_def` in place. Confirm against `byte_swap_tensor`.

  Args:
    meta_graph_def: Passed through unchanged as the first argument.
    from_endiness: Passed through unchanged as the second argument.
    to_endiness: Passed through unchanged as the third argument.

  Returns:
    None. The delegate's return value, if any, is discarded.
  """
  swap_in_functions = bst.swap_tensor_content_in_graph_function
  swap_in_functions(meta_graph_def, from_endiness, to_endiness)