Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/function_serialization.py: 15%
72 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tools for serializing `Function`s."""
17from tensorflow.core.protobuf import saved_object_graph_pb2
18from tensorflow.python.eager import function as defun
19from tensorflow.python.framework import func_graph as func_graph_module
20from tensorflow.python.saved_model import nested_structure_coder
21from tensorflow.python.util import nest
24def _serialize_function_spec(function_spec):
25 """Serialize a FunctionSpec object into its proto representation."""
26 if (
27 function_spec.fullargspec.args
28 and function_spec.fullargspec.args[0] == "self"
29 ):
30 raise TypeError(
31 "Can not serialize tf.function with unbound 'self' parameter."
32 )
34 proto = saved_object_graph_pb2.FunctionSpec()
36 # Intentionally skip encoding annotations of a function because function
37 # annotations are mainly for optional type checking during development
38 # and does not affect runtime behavior.
39 # https://www.python.org/dev/peps/pep-3107/
40 # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
41 proto.fullargspec.CopyFrom(
42 nested_structure_coder.encode_structure(
43 function_spec.fullargspec._replace(annotations={})))
45 proto.is_method = False
46 proto.input_signature.CopyFrom(
47 nested_structure_coder.encode_structure(function_spec.input_signature))
49 # See `tf.function` and the JitCompile proto for details.
50 proto.jit_compile = {
51 None: saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT,
52 True: saved_object_graph_pb2.FunctionSpec.JitCompile.ON,
53 False: saved_object_graph_pb2.FunctionSpec.JitCompile.OFF,
54 }.get(function_spec.jit_compile)
56 return proto
def serialize_concrete_function(concrete_function, node_ids):
  """Builds a `SavedConcreteFunction` proto for `concrete_function`.

  Args:
    concrete_function: The concrete function to serialize. Its `name`,
      `captured_inputs`, `structured_outputs` and `structured_input_signature`
      attributes are read.
    node_ids: Mapping from each captured input to its node id in the object
      graph being saved.

  Returns:
    A `saved_object_graph_pb2.SavedConcreteFunction` proto.

  Raises:
    KeyError: If a captured input has no entry in `node_ids`.
  """
  bound_inputs = []
  for capture in concrete_function.captured_inputs:
    try:
      node_id = node_ids[capture]
    except KeyError:
      raise KeyError(
          f"Failed to add concrete function '{concrete_function.name}' to object-"
          f"based SavedModel as it captures tensor {capture!r} which is unsupported"
          " or not reachable from root. "
          "One reason could be that a stateful object or a variable that the "
          "function depends on is not assigned to an attribute of the serialized "
          "trackable object (see SaveTest.test_captures_unreachable_variable).")
    bound_inputs.append(node_id)

  proto = saved_object_graph_pb2.SavedConcreteFunction()
  output_signature = func_graph_module.convert_structure_to_signature(
      concrete_function.structured_outputs)
  encoded_inputs = nested_structure_coder.encode_structure(
      concrete_function.structured_input_signature)
  proto.canonicalized_input_signature.CopyFrom(encoded_inputs)
  encoded_outputs = nested_structure_coder.encode_structure(output_signature)
  proto.output_signature.CopyFrom(encoded_outputs)
  proto.bound_inputs.extend(bound_inputs)
  return proto
def serialize_bare_concrete_function(concrete_function):
  """Builds a `SavedBareConcreteFunction` proto.

  Args:
    concrete_function: The concrete function to serialize. Its name,
      positional-argument count and argument keywords are always recorded;
      its pre-initialized function spec is recorded only when it is not None.

  Returns:
    A `saved_object_graph_pb2.SavedBareConcreteFunction` proto.
  """
  # pylint: disable=protected-access
  proto = saved_object_graph_pb2.SavedBareConcreteFunction(
      concrete_function_name=concrete_function.name,
      allowed_positional_arguments=concrete_function._num_positional_args,
      argument_keywords=concrete_function._arg_keywords)
  pre_initialized_spec = concrete_function._pre_initialized_function_spec
  # pylint: enable=protected-access
  if pre_initialized_spec is None:
    return proto
  proto.function_spec.CopyFrom(_serialize_function_spec(pre_initialized_spec))
  return proto
def serialize_function(function, concrete_functions):
  """Builds a `SavedFunction` proto.

  Args:
    function: The function being saved; only its `function_spec` is read.
    concrete_functions: Iterable of concrete functions associated with
      `function`; only their `name` attributes are recorded.

  Returns:
    A `saved_object_graph_pb2.SavedFunction` proto that references the
    concrete functions by name.
  """
  proto = saved_object_graph_pb2.SavedFunction()
  proto.function_spec.CopyFrom(
      _serialize_function_spec(function.function_spec))
  proto.concrete_functions.extend(
      concrete.name for concrete in concrete_functions)
  return proto
def wrap_cached_variables(concrete_function):
  """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  The captures of `concrete_function.graph` are temporarily replaced while the
  wrapper is traced. They are always restored before this function returns,
  including when tracing raises, so the caller's concrete function is left
  unchanged.

  Args:
    concrete_function: A concrete function that may capture cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object is
    returned.
  """
  inner_graph = concrete_function.graph
  outer_graph = func_graph_module.FuncGraph(
      "{}_no_cache".format(inner_graph.name))
  function_captures = inner_graph.function_captures
  # (external, internal) pair each remapped capture had before remapping,
  # keyed the same way as `function_captures`, used to restore the graph.
  original_captures = {}

  try:
    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
      # Snapshot the captures: entries are replaced while looping.
      for capture, placeholder in list(inner_graph.captures):
        cached_variable_ref = getattr(capture, "_cached_variable", None)
        if cached_variable_ref is None:
          continue
        cached_variable = cached_variable_ref()
        if cached_variable is None:
          # The referenced variable no longer exists, so there is nothing to
          # re-read; leave this capture as it is.
          continue
        new_cached_value = cached_variable.read_value()
        key = id(capture)
        original_captures[key] = (function_captures.by_val_external[key],
                                  function_captures.by_val_internal[key])
        function_captures.add_or_replace(
            key=key,
            external=new_cached_value,
            internal=placeholder,
            is_by_ref=False)

    if not original_captures:
      return concrete_function

    inner_concrete = defun.ConcreteFunction(inner_graph)

    def wrap_function(*args):
      return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(concrete_function.structured_input_signature,
                        expand_composites=True)
    func_graph_module.func_graph_from_py_func(
        None, wrap_function, args=tuple(args), kwargs={},
        func_graph=outer_graph)

    # Create concrete function, and copy the attributes necessary to serialize
    # the function.
    # pylint: disable=protected-access
    fn = defun.ConcreteFunction(
        outer_graph, spec=concrete_function._function_spec)
    fn._arg_keywords = concrete_function._arg_keywords
    fn._num_positional_args = concrete_function._num_positional_args
    fn._pre_initialized_function_spec = (
        concrete_function._pre_initialized_function_spec)
    # pylint: enable=protected-access
    return fn
  finally:
    # Return the captures to their original values, even if wrapping failed
    # part-way through.
    for key, (external, internal) in original_captures.items():
      function_captures.add_or_replace(
          key=key,
          external=external,
          internal=internal,
          is_by_ref=False)