Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/save_options.py: 40%
45 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Options for saving SavedModels."""
17import enum
19from tensorflow.python.util import compat
20from tensorflow.python.util.tf_export import tf_export
@tf_export("saved_model.experimental.VariablePolicy")
class VariablePolicy(enum.Enum):
  """Enum defining options for variable handling when saving.

  NONE
    No policy applied: Distributed variables are saved as one variable, with no
    device attached.

  SAVE_VARIABLE_DEVICES
    When saving variables, also save their device assignment.
    This is useful if one wants to hardcode devices in saved models, but it also
    makes them non-portable if soft device placement is disabled (more details
    in `tf.config.set_soft_device_placement`). This is currently not
    fully supported by `saved_model.load`, and is mainly intended to be used
    when one will be reading the saved model at a lower API level. In the
    example below, the graph saved by the call to `saved_model.save` will have
    the variable devices correctly specified:
    ```python
    exported = tf.train.Checkpoint()
    with tf.device('/GPU:0'):
      exported.x_gpu = tf.Variable(1.0)
    with tf.device('/CPU:0'):
      exported.x_cpu = tf.Variable(1.0)
    tf.saved_model.save(exported, export_dir,
        options = tf.saved_model.SaveOptions(
            experimental_variable_policy=
              tf.saved_model.experimental.VariablePolicy.SAVE_VARIABLE_DEVICES))
    ```
    Distributed variables are still saved as one variable under this policy.

  EXPAND_DISTRIBUTED_VARIABLES
    Distributed variables will be saved with information about their components,
    allowing for their restoration on load. Also, the saved graph will contain
    references to those variables. This is useful when one wants to use the
    model for training in environments where the original distribution strategy
    is not available.
  """

  NONE = None

  SAVE_VARIABLE_DEVICES = "save_variable_devices"

  EXPAND_DISTRIBUTED_VARIABLES = "expand_distributed_variables"

  def _save_variable_devices(self):
    """Checks whether variable devices should be saved."""
    # Every policy except NONE records device placement.
    return self is not VariablePolicy.NONE

  def _expand_distributed_variables(self):
    """Checks whether distributed variables should be expanded."""
    # Only the most permissive policy expands component variables.
    return self is VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES

  @staticmethod
  def from_obj(obj):
    """Tries to convert `obj` to a VariablePolicy instance."""
    # `None` maps to the NONE policy directly; it would otherwise be
    # stringified to "none" below and fail to match NONE's `None` value.
    if obj is None:
      return VariablePolicy.NONE
    if isinstance(obj, VariablePolicy):
      return obj
    # Accept the policy's value string, case-insensitively.
    wanted = str(obj).lower()
    match = next((p for p in VariablePolicy if p.value == wanted), None)
    if match is not None:
      return match
    raise ValueError(f"Received invalid VariablePolicy value: {obj}.")
@tf_export("saved_model.SaveOptions")
class SaveOptions:
  """Options for saving to SavedModel.

  This function may be used in the `options` argument in functions that
  save a SavedModel (`tf.saved_model.save`, `tf.keras.models.save_model`).
  """

  # Restrict instances to these attributes for lower memory use and faster
  # attribute access.
  __slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases",
               "experimental_io_device", "experimental_variable_policy",
               "experimental_custom_gradients")

  def __init__(self,
               namespace_whitelist=None,
               save_debug_info=False,
               function_aliases=None,
               experimental_io_device=None,
               experimental_variable_policy=None,
               experimental_custom_gradients=True):
    """Creates an object that stores options for SavedModel saving.

    Args:
      namespace_whitelist: List of strings containing op namespaces to whitelist
        when saving a model. Saving an object that uses namespaced ops must
        explicitly add all namespaces to the whitelist. The namespaced ops must
        be registered into the framework when loading the SavedModel. If no
        whitelist is provided, all namespaced ops will be allowed.
      save_debug_info: Boolean indicating whether debug information is saved. If
        True, then a debug/saved_model_debug_info.pb file will be written with
        the contents of a GraphDebugInfo binary protocol buffer containing stack
        trace information for all ops and functions that are saved.
      function_aliases: Python dict. Mapping from string to object returned by
        @tf.function. A single tf.function can generate many ConcreteFunctions.
        If a downstream tool wants to refer to all concrete functions generated
        by a single tf.function you can use the `function_aliases` argument to
        store a map from the alias name to all concrete function names.
        E.g.

        >>> class Adder(tf.Module):
        ...   @tf.function
        ...   def double(self, x):
        ...     return x + x

        >>> model = Adder()
        >>> model.double.get_concrete_function(
        ...   tf.TensorSpec(shape=[], dtype=tf.float32, name="float_input"))
        >>> model.double.get_concrete_function(
        ...   tf.TensorSpec(shape=[], dtype=tf.string, name="string_input"))

        >>> options = tf.saved_model.SaveOptions(
        ...   function_aliases={'double': model.double})
        >>> tf.saved_model.save(model, '/tmp/adder', options=options)

      experimental_io_device: string. Applies in a distributed setting.
        Tensorflow device to use to access the filesystem. If `None` (default)
        then for each variable the filesystem is accessed from the CPU:0 device
        of the host where that variable is assigned. If specified, the
        filesystem is instead accessed from that device for all variables.

        This is for example useful if you want to save to a local directory,
        such as "/tmp" when running in a distributed setting. In that case pass
        a device for the host where the "/tmp" directory is accessible.
      experimental_variable_policy: The policy to apply to variables when
        saving. This is either a `saved_model.experimental.VariablePolicy` enum
        instance or one of its value strings (case is not important). See that
        enum documentation for details. A value of `None` corresponds to the
        default policy.
      experimental_custom_gradients: Boolean. When True, will save traced
        gradient functions for the functions decorated by `tf.custom_gradient`.
        Defaults to `True`.
    """
    # Whitelist is validated/normalized up front so bad input fails here
    # rather than at save time.
    self.namespace_whitelist = _validate_namespace_whitelist(
        namespace_whitelist)
    # The caller's dict is stored as-is (not copied); a falsy value becomes a
    # fresh empty dict.
    self.function_aliases = function_aliases if function_aliases else {}
    # Strings are accepted and converted to the VariablePolicy enum here.
    self.experimental_variable_policy = VariablePolicy.from_obj(
        experimental_variable_policy)
    self.save_debug_info = save_debug_info
    self.experimental_io_device = experimental_io_device
    self.experimental_custom_gradients = experimental_custom_gradients
171def _validate_namespace_whitelist(namespace_whitelist):
172 """Validates namespace whitelist argument."""
173 if namespace_whitelist is None:
174 return None
175 if not isinstance(namespace_whitelist, list):
176 raise TypeError("`namespace_whitelist` must be a list of strings. Got: "
177 f"{namespace_whitelist} with type "
178 f"{type(namespace_whitelist)}.")
180 processed = []
181 for namespace in namespace_whitelist:
182 if not isinstance(namespace, str):
183 raise ValueError("Whitelisted namespace must be a string. Got: "
184 f"{namespace} of type {type(namespace)}.")
185 processed.append(compat.as_str(namespace))
186 return processed