Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saved_model/utils.py: 28%
112 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions shared between SavedModel saving/loading
16implementations."""
18import copy
19import itertools
20import threading
21import types
23import tensorflow.compat.v2 as tf
25from keras.src import backend
26from keras.src.engine import base_layer_utils
27from keras.src.utils import control_flow_util
28from keras.src.utils import tf_contextlib
29from keras.src.utils.generic_utils import LazyLoader
30from keras.src.utils.layer_utils import CallFunctionSpec
# Deferred import of `keras.src.engine.training`; it is only dereferenced
# (as `training_lib.Model`) inside `list_all_layers`.
# NOTE(review): presumably lazy to avoid an import cycle with the engine
# package — confirm before turning this into a regular import.
training_lib = LazyLoader(
    "training_lib",
    globals(),
    "keras.src.engine.training",
)
def use_wrapped_call(
    layer, call_fn, call_spec, default_training_value=None, return_method=False
):
    """Wrap `call_fn` so that the losses it returns are added to `layer`.

    Args:
        layer: A Keras layer object.
        call_fn: tf.function that takes layer inputs (and possibly a training
            arg), and returns a tuple of (outputs, list of losses).
        call_spec: The `CallFunctionSpec` for the layer's call function.
        default_training_value: Default value of the training kwarg, forwarded
            to `maybe_add_training_arg`. If `None`, the default is
            `tf.keras.backend.learning_phase()`.
        return_method: Whether to return a method bound to the layer.

    Returns:
        A callable (a method bound to `layer` when `return_method` is True)
        that invokes `call_fn`, adds the returned losses to `layer`, and
        returns only the outputs.
    """
    needs_training = layer_uses_training_bool(layer)
    inner_fn, argspec = maybe_add_training_arg(
        call_spec, call_fn, needs_training, default_training_value
    )

    def return_outputs_and_add_losses(*args, **kwargs):
        """Returns the outputs from the layer call function, and adds the
        losses."""
        # When exposed as a bound method, the first positional argument is the
        # layer itself and must not be forwarded to the wrapped call.
        call_args = args[1:] if return_method else args
        outputs, losses = inner_fn(*call_args, **kwargs)
        layer.add_loss(losses)

        # TODO(kathywu): Temporary hack. When a network of layers is revived
        # from SavedModel, only the top-level layer has losses, so in eager
        # mode child layers may still hold graph losses (and `model.losses`
        # would mix eager and graph tensors). Whenever eager losses are added
        # to one layer, mark every child layer with eager losses too, so that
        # `.losses` only returns eager losses.
        if tf.executing_eagerly():
            for sublayer in layer._flatten_layers():
                if sublayer is layer:
                    continue
                sublayer._eager_losses = [
                    base_layer_utils.REVIVED_LOSS_PLACEHOLDER
                ]

        return outputs

    decorated = tf.__internal__.decorator.make_decorator(
        target=call_fn,
        decorator_func=return_outputs_and_add_losses,
        decorator_argspec=argspec,
    )
    return types.MethodType(decorated, layer) if return_method else decorated
def layer_uses_training_bool(layer):
    """Return whether `layer` or any layer nested inside it uses `training`.

    The root layer is checked through its `_expects_training_arg` attribute.
    Nested layers are discovered with `list_all_layers` in an iterative,
    stack-based (last-in, first-out) walk. Each object is inspected at most
    once. A nested object that has no `_expects_training_arg` attribute counts
    as expecting the training arg.

    Args:
        layer: The layer or model to inspect.

    Returns:
        True as soon as any visited layer expects the training arg, otherwise
        False.
    """
    if layer._expects_training_arg:
        return True

    seen = {layer}
    pending = list_all_layers(layer)
    while pending:
        candidate = pending.pop()
        if candidate in seen:
            continue
        # Missing attribute defaults to True: assume the training arg is used.
        if getattr(candidate, "_expects_training_arg", True):
            return True
        seen.add(candidate)
        pending.extend(list_all_layers(candidate))
    return False
def list_all_layers(obj):
    """Return the direct (non-recursive) child layers of `obj`.

    Args:
        obj: A layer or model.

    Returns:
        For `Model` instances, `obj.layers`. For anything else, a list of the
        object's immediate sublayers from `_flatten_layers`, excluding `obj`
        itself.
    """
    if not isinstance(obj, training_lib.Model):
        return list(obj._flatten_layers(include_self=False, recursive=False))
    # Models go through `.layers` to handle the special case of `Sequential`,
    # which doesn't return the `Input` layer there.
    return obj.layers
def list_all_layers_and_sublayers(obj):
    """Return the set containing `obj` and every layer nested inside it.

    Children are discovered with `list_all_layers`. The walk is iterative and
    expands each object at most once. Layers shared by several parents are
    therefore not re-traversed once per parent, and deep nesting cannot hit
    Python's recursion limit.

    Args:
        obj: The root layer or model.

    Returns:
        A `set` holding `obj` and all of its transitive sublayers.
    """
    found = {obj}
    stack = [obj]
    while stack:
        current = stack.pop()
        for child in list_all_layers(current):
            # Only expand a child the first time it is reached.
            if child not in found:
                found.add(child)
                stack.append(child)
    return found
def maybe_add_training_arg(
    call_spec, wrapped_call, expects_training_arg, default_training_value
):
    """Decorate a call function and, if needed, add a `training` argument.

    If the layer expects a training argument, this ensures that `training` is
    present in the arg spec (as a positional arg or a keyword-only arg) with
    the given default. Each call then resolves a concrete training value and
    dispatches on it.

    Args:
        call_spec: CallFunctionSpec of the layer.
        wrapped_call: Wrapped call function.
        expects_training_arg: Whether to include the `training` argument.
        default_training_value: Default value of the training kwarg to include
            in the arg spec. When the caller passes no `training` (or passes
            None), a truthy default is used. Otherwise the value falls back to
            the current call context's `training`, then to
            `backend.learning_phase()`.

    Returns:
        Tuple of (function that calls `wrapped_call` with the training arg
        set, arg spec of that function). If `expects_training_arg` is False,
        returns `(wrapped_call, None)`, meaning the arg spec is unchanged.
    """
    if not expects_training_arg:
        return wrapped_call, None

    new_arg_spec = set_training_arg_spec(
        call_spec.full_argspec, default_training_value
    )
    # Spec rebuilt from the patched argspec so that `training` can be
    # located and overwritten in the forwarded args/kwargs.
    training_call_spec = CallFunctionSpec(new_arg_spec)

    def wrap_with_training_arg(*args, **kwargs):
        """Wrap the `wrapped_call` function, and set training argument."""
        try:
            training = training_call_spec.get_arg_value(
                "training", args, kwargs, inputs_in_args=True
            )
        except KeyError:
            training = None

        if training is None:
            # NOTE(review): this is an `or` chain, so a falsy
            # `default_training_value` (None *or* False) defers to the call
            # context and then to the learning phase.
            training = (
                default_training_value
                or base_layer_utils.call_context().training
                or backend.learning_phase()
            )

        positional = list(args)
        keyword = kwargs.copy()

        def call_with(training_flag):
            new_args, new_kwargs = training_call_spec.set_arg_value(
                "training",
                training_flag,
                positional,
                keyword,
                inputs_in_args=True,
            )
            return wrapped_call(*new_args, **new_kwargs)

        # `training` may be a Python bool or a tensor; smart_cond picks a
        # branch statically when it can.
        return control_flow_util.smart_cond(
            training,
            lambda: call_with(True),
            lambda: call_with(False),
        )

    return wrap_with_training_arg, new_arg_spec
def set_training_arg_spec(arg_spec, default_training_value):
    """Set `training=DEFAULT` argument in an ArgSpec.

    Args:
        arg_spec: An `inspect.FullArgSpec`-style namedtuple (it must support
            `_replace`) describing a call signature. It is not mutated.
        default_training_value: Value to use as the default of `training`.

    Returns:
        A new arg spec, or `arg_spec` itself when nothing needs to change:
        - `training` is a positional arg whose default is None: that default
          is replaced with `default_training_value`, and `defaults` is
          returned as a tuple, as `FullArgSpec` specifies.
        - `training` is positional with no default, or with a non-None
          default: `arg_spec` is returned unchanged.
        - `training` is absent: it is appended as a keyword-only argument
          defaulting to `default_training_value`.
        - `training` is already keyword-only: `arg_spec` is returned
          unchanged.
    """
    if "training" in arg_spec.args:
        # Defaults are right-aligned with `args`, so the distance of
        # `training` from the end of `args` is also its distance from the
        # end of `defaults`.
        offset_from_end = len(arg_spec.args) - arg_spec.args.index("training")
        defaults = list(arg_spec.defaults or ())
        if (
            len(defaults) >= offset_from_end
            and defaults[-offset_from_end] is None
        ):
            defaults[-offset_from_end] = default_training_value
            # `FullArgSpec.defaults` is a tuple (or None), never a list.
            return arg_spec._replace(defaults=tuple(defaults))
    elif "training" not in arg_spec.kwonlyargs:
        kwonlyargs = arg_spec.kwonlyargs + ["training"]
        kwonlydefaults = copy.copy(arg_spec.kwonlydefaults) or {}
        kwonlydefaults["training"] = default_training_value
        return arg_spec._replace(
            kwonlyargs=kwonlyargs, kwonlydefaults=kwonlydefaults
        )

    return arg_spec
class SaveOptionsContext(threading.local):
    """Per-thread holder for the Keras SavedModel save options.

    Attributes:
        save_traces: Whether layer call functions should be traced when
            saving. Defaults to True.
        in_tf_saved_model_scope: Flag overridden by `keras_option_scope`.
            Defaults to False.
    """

    def __init__(self):
        super(SaveOptionsContext, self).__init__()
        # A `threading.local` subclass runs `__init__` the first time each
        # thread touches the instance, so every thread starts from these
        # defaults.
        self.in_tf_saved_model_scope = False
        self.save_traces = True


# Module-wide instance whose attribute values are private to each thread.
_save_options_context = SaveOptionsContext()
@tf_contextlib.contextmanager
def keras_option_scope(save_traces, in_tf_saved_model_scope=True):
    """Temporarily override the current thread's Keras save options.

    Args:
        save_traces: Value of the `save_traces` option inside the scope.
        in_tf_saved_model_scope: Value of the `in_tf_saved_model_scope`
            option inside the scope. Defaults to True.

    Yields:
        Nothing. On exit the previous values are restored, even if the body
        raised.
    """
    previous = (
        _save_options_context.save_traces,
        _save_options_context.in_tf_saved_model_scope,
    )
    try:
        _save_options_context.save_traces = save_traces
        _save_options_context.in_tf_saved_model_scope = in_tf_saved_model_scope
        yield
    finally:
        (
            _save_options_context.save_traces,
            _save_options_context.in_tf_saved_model_scope,
        ) = previous
def should_save_traces():
    """Return whether layer functions should be traced when saving.

    Reads the current thread's `save_traces` option. It defaults to True and
    can be disabled through the `save_traces` argument.
    """
    options = _save_options_context
    return options.save_traces
def in_tf_saved_model_scope():
    """Return the current thread's `in_tf_saved_model_scope` option.

    The flag defaults to False and is overridden for the duration of a
    `keras_option_scope`.
    """
    options = _save_options_context
    return options.in_tf_saved_model_scope
@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
    """Context that disables automatic dependency tracking when assigning attrs.

    Objects that inherit from AutoTrackable automatically create dependencies
    on trackable objects through attribute assignment, and wrap data
    structures (lists or dicts) in trackable classes. This scope temporarily
    disables that behavior by setting `obj._setattr_tracking` to False. It
    works similarly to the `no_automatic_dependency_tracking` decorator.

    On exit, `_setattr_tracking` is restored to its previous value. If the
    attribute did not exist before, it is set to True.

    Example usage:
    ```
    model = tf.keras.Model()
    model.arr1 = []  # Creates a ListWrapper object
    with no_automatic_dependency_tracking_scope(model):
        model.arr2 = []  # Creates a regular, untracked python list
    ```

    Args:
        obj: A trackable object.

    Yields:
        A scope in which the object doesn't track dependencies.
    """
    saved_tracking = getattr(obj, "_setattr_tracking", True)
    obj._setattr_tracking = False
    try:
        yield
    finally:
        obj._setattr_tracking = saved_tracking