Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/function_serialization.py: 15%

72 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Tools for serializing `Function`s.""" 

16 

17from tensorflow.core.protobuf import saved_object_graph_pb2 

18from tensorflow.python.eager import function as defun 

19from tensorflow.python.framework import func_graph as func_graph_module 

20from tensorflow.python.saved_model import nested_structure_coder 

21from tensorflow.python.util import nest 

22 

23 

def _serialize_function_spec(function_spec):
  """Converts a FunctionSpec into a `saved_object_graph_pb2.FunctionSpec`.

  Args:
    function_spec: The spec to encode. Its `fullargspec` (a namedtuple-like
      object supporting `_replace`), `input_signature` and `jit_compile`
      attributes are read.

  Returns:
    A populated `saved_object_graph_pb2.FunctionSpec` proto whose `is_method`
    is always False.

  Raises:
    TypeError: If the first argument of the spec is named `self`, i.e. the
      spec describes a method that has not been bound to an instance.
  """
  arg_names = function_spec.fullargspec.args
  if arg_names and arg_names[0] == "self":
    raise TypeError(
        "Can not serialize tf.function with unbound 'self' parameter.")

  proto = saved_object_graph_pb2.FunctionSpec()

  # Annotations are deliberately dropped before encoding: they exist for
  # optional development-time type checking and do not affect runtime
  # behavior.
  # https://www.python.org/dev/peps/pep-3107/
  # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
  argspec_without_annotations = function_spec.fullargspec._replace(
      annotations={})
  proto.fullargspec.CopyFrom(
      nested_structure_coder.encode_structure(argspec_without_annotations))

  proto.is_method = False

  encoded_signature = nested_structure_coder.encode_structure(
      function_spec.input_signature)
  proto.input_signature.CopyFrom(encoded_signature)

  # Translate the Python tri-state (None / True / False) into the JitCompile
  # enum. See `tf.function` and the JitCompile proto for details.
  jit_modes = saved_object_graph_pb2.FunctionSpec.JitCompile
  proto.jit_compile = {
      None: jit_modes.DEFAULT,
      True: jit_modes.ON,
      False: jit_modes.OFF,
  }.get(function_spec.jit_compile)

  return proto

57 

58 

def serialize_concrete_function(concrete_function, node_ids):
  """Builds the SavedConcreteFunction proto for `concrete_function`.

  Args:
    concrete_function: The concrete function to serialize. Its `name`,
      `captured_inputs`, `structured_input_signature` and
      `structured_outputs` attributes are read.
    node_ids: Mapping from each captured input to its node id in the saved
      object graph.

  Returns:
    A `saved_object_graph_pb2.SavedConcreteFunction` with the encoded input
    and output signatures. Its `bound_inputs` holds the node ids of the
    captures, in capture order.

  Raises:
    KeyError: If a captured input has no entry in `node_ids`, for example a
      variable the function depends on that is not reachable from the root
      trackable object.
  """
  capture_node_ids = []
  try:
    for captured in concrete_function.captured_inputs:
      capture_node_ids.append(node_ids[captured])
  except KeyError:
    raise KeyError(
        f"Failed to add concrete function '{concrete_function.name}' to object-"
        f"based SavedModel as it captures tensor {captured!r} which is unsupported"
        " or not reachable from root. "
        "One reason could be that a stateful object or a variable that the "
        "function depends on is not assigned to an attribute of the serialized "
        "trackable object (see SaveTest.test_captures_unreachable_variable).")

  output_signature = func_graph_module.convert_structure_to_signature(
      concrete_function.structured_outputs)
  encode = nested_structure_coder.encode_structure

  proto = saved_object_graph_pb2.SavedConcreteFunction()
  proto.canonicalized_input_signature.CopyFrom(
      encode(concrete_function.structured_input_signature))
  proto.output_signature.CopyFrom(encode(output_signature))
  proto.bound_inputs.extend(capture_node_ids)
  return proto

83 

84 

def serialize_bare_concrete_function(concrete_function):
  """Builds the SavedBareConcreteFunction proto for `concrete_function`.

  The proto records the function's name, its number of allowed positional
  arguments and its keyword argument names. When the concrete function has
  a pre-initialized FunctionSpec, the proto also carries that spec in
  serialized form.

  Args:
    concrete_function: The concrete function to describe.

  Returns:
    A `saved_object_graph_pb2.SavedBareConcreteFunction`.
  """
  # pylint: disable=protected-access
  bare_proto = saved_object_graph_pb2.SavedBareConcreteFunction(
      concrete_function_name=concrete_function.name,
      allowed_positional_arguments=concrete_function._num_positional_args,
      argument_keywords=concrete_function._arg_keywords)
  pre_initialized_spec = concrete_function._pre_initialized_function_spec
  # pylint: enable=protected-access

  if pre_initialized_spec is None:
    return bare_proto
  bare_proto.function_spec.CopyFrom(
      _serialize_function_spec(pre_initialized_spec))
  return bare_proto

98 

99 

def serialize_function(function, concrete_functions):
  """Builds the SavedFunction proto for `function`.

  Args:
    function: The function whose `function_spec` is serialized.
    concrete_functions: Iterable of concrete functions to list in the proto.
      Only their `name` attributes are recorded, in iteration order.

  Returns:
    A `saved_object_graph_pb2.SavedFunction`.
  """
  saved_function = saved_object_graph_pb2.SavedFunction()
  saved_function.function_spec.CopyFrom(
      _serialize_function_spec(function.function_spec))
  saved_function.concrete_functions.extend(
      concrete.name for concrete in concrete_functions)
  return saved_function

109 

110 

def wrap_cached_variables(concrete_function):
  """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  While the wrapper is traced, the by-value captures of
  `concrete_function.graph` that hold cached reads are temporarily
  re-pointed at fresh reads of the underlying variables. Those reads are made
  in a new outer graph. The original captures are restored in a `finally`
  clause, so `concrete_function` is put back as it was whether this function
  returns or raises.

  Args:
    concrete_function: A concrete function that may capture cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object is
    returned.
  """
  graph = concrete_function.graph
  outer_graph = func_graph_module.FuncGraph(
      "{}_no_cache".format(graph.name))
  # Maps each capture key to the (external, internal) pair it held before
  # being remapped, so the remapping can be undone.
  original_captures = {}

  try:
    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
      # Iterate over a snapshot, since the loop body replaces captures.
      for capture, placeholder in list(graph.captures):
        cached_variable = getattr(capture, "_cached_variable", None)
        if cached_variable is None:
          continue
        # `_cached_variable` is called to obtain the variable object
        # (presumably it is a weak reference -- TODO confirm).
        new_cached_value = cached_variable().read_value()
        key = id(capture)
        function_captures = graph.function_captures
        original_captures[key] = (function_captures.by_val_external[key],
                                  function_captures.by_val_internal[key])
        function_captures.add_or_replace(
            key=key,
            external=new_cached_value,
            internal=placeholder,
            is_by_ref=False)

    if not original_captures:
      return concrete_function

    inner_concrete = defun.ConcreteFunction(graph)

    def wrap_function(*args):
      return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(concrete_function.structured_input_signature,
                        expand_composites=True)
    func_graph_module.func_graph_from_py_func(
        None, wrap_function, args=tuple(args), kwargs={},
        func_graph=outer_graph)

    # Create concrete function, and copy the attributes necessary to serialize
    # the function.
    # pylint: disable=protected-access
    fn = defun.ConcreteFunction(
        outer_graph, spec=concrete_function._function_spec)
    fn._arg_keywords = concrete_function._arg_keywords
    fn._num_positional_args = concrete_function._num_positional_args
    fn._pre_initialized_function_spec = (
        concrete_function._pre_initialized_function_spec)
    # pylint: enable=protected-access
    return fn
  finally:
    # Return the captures to their original values. This also runs when
    # tracing fails, so `concrete_function` is never left capturing tensors
    # that were created in `outer_graph`.
    for key, (external, internal) in original_captures.items():
      graph.function_captures.add_or_replace(
          key=key,
          external=external,
          internal=internal,
          is_by_ref=False)