Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/core/function/capture/restore_captures.py: 24%
50 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=unidiomatic-typecheck
16"""A shim layer for working with functions exported/restored from saved models.
18This functionality should ultimately be moved into a first-class core API.
19"""
21import warnings
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import handle_data_util
27from tensorflow.python.ops import resource_variable_ops
28from tensorflow.python.ops import variables as variables_lib
29from tensorflow.python.trackable import asset
30from tensorflow.python.trackable import resource
def get_tensor_from_node(node):
  """Resolves a saved model graph node into a tensor to be captured.

  The checks below run inside `ops.init_scope()` and are evaluated in order;
  the first match wins:

  * objects flagged `is_distributed_variable`, `is_distributed_table` or
    `is_sharded_variable` are returned unchanged;
  * resource variables resolve to their `handle`;
  * `asset.Asset` objects resolve to their `asset_path`;
  * values that are already TF types are returned unchanged;
  * `resource.CapturableResource` objects resolve to their `resource_handle`.

  Args:
    node: a tensor, variable, or resource to be resolved into a capturable
      tensor.

  Returns:
    A single capturable value (not a list): either `node` itself or the
    handle/path derived from it as described above.

  Raises:
    ValueError: if the node cannot be converted into a tensor.
  """
  with ops.init_scope():
    # TODO(b/210144904): Use __tf_tensor__ instead of `is_[...]` checks
    # Distributed/sharded objects are passed through untouched; presumably so
    # that callers can apply their own capture logic to them.
    if any(
        getattr(node, flag, False)
        for flag in (
            "is_distributed_variable",
            "is_distributed_table",
            "is_sharded_variable",
        )
    ):
      return node
    if resource_variable_ops.is_resource_variable(node):
      return node.handle
    if isinstance(node, asset.Asset):
      return node.asset_path
    if tensor_util.is_tf_type(node):
      return node
    if isinstance(node, resource.CapturableResource):
      # Note: this executes restored functions in the CapturableResource.
      return node.resource_handle
    raise ValueError(f"Cannot convert node {node} to tensor.")
def _copy_capture_handle_data(bound_input, internal_capture):
  """Copies resource handle data from a restored value onto its placeholder.

  Intended to be called only when `internal_capture` has dtype
  `dtypes.resource`; the caller performs that check.

  Args:
    bound_input: the outer value (as produced by `get_tensor_from_node`) that
      is being bound to `internal_capture`.
    internal_capture: the function-graph input that receives the handle data.
  """
  if resource_variable_ops.is_resource_variable(bound_input):
    try:
      handle = bound_input.handle
    except ValueError:
      # For mirrored variables we'll copy handle data for components
      # as they get captured.
      pass
    else:
      handle_data_util.copy_handle_data(handle, internal_capture)
  else:
    # TODO(b/213451747): Remove need to call copy_handle_data
    handle_data_util.copy_handle_data(bound_input, internal_capture)


def restore_captures(concrete_function, inputs):
  """Restore captures for the concrete function.

  Used at deserialization time. For functions that are being deserialized,
  saved model restores objects that tensors were captured from, but functions
  only know about their tensors -- object information is destroyed by tracing.
  This additional logic extracts the tensors which the function originally
  captured.

  Args:
    concrete_function: the concrete function for which to restore captures.
    inputs: a list of tensors or other Python objects (such as variables)
      which contain tensors that were originally captured by the function.
      They are paired positionally with the trailing entries of
      `concrete_function.inputs`.
  """
  bound_inputs = [get_tensor_from_node(obj) for obj in inputs]
  # Only the inputs that are variable objects are handed to `set_variables`.
  # pylint: disable=g-complex-comprehension
  bound_variables = [
      obj
      for obj in inputs
      if isinstance(
          obj,
          (variables_lib.Variable, resource_variable_ops.BaseResourceVariable),
      )
  ]
  # TODO(b/205010575): This is only injecting the captured inputs into the
  # concrete function, note that we did not modify the FuncGraph
  # itself.
  captured_inputs_list = []
  concrete_function.set_variables(bound_variables)
  if bound_inputs:
    # The last len(bound_inputs) function inputs are treated as the captured
    # placeholders, matched one-to-one (zip truncates on a length mismatch).
    internal_captures = concrete_function.inputs[-len(bound_inputs) :]
    for bound_input, internal_capture in zip(bound_inputs, internal_captures):
      if hasattr(bound_input, "__tf_experimental_restore_capture__"):
        # Distributed inputs have special logic for capturing, so we call
        # their custom restoration methods.
        captured_inputs_list.append(
            bound_input.__tf_experimental_restore_capture__(
                concrete_function, internal_capture
            )
        )
        continue

      captured_inputs_list.append(bound_input)
      concrete_function.graph.replace_capture(bound_input, internal_capture)
      if internal_capture.dtype == dtypes.resource:
        _copy_capture_handle_data(bound_input, internal_capture)
      # Setting "captures" first means "capture" won't create a new
      # placeholder for this input.
      concrete_function.graph.capture(bound_input)

  # A `None` entry can come from a custom restore hook above; the warning text
  # attributes this case to ShardedVariables.
  if any(inp is None for inp in captured_inputs_list):
    warnings.warn(
        "Trying to load ShardedVariables using tf.saved_model.load. "
        "This won't work if using a tf.distribute.Strategy, and may "
        "use excess memory if not using a Strategy. Ignore this "
        "warning if using tf.keras.models.load_model."
    )
  concrete_function.set_external_captures(captured_inputs_list)