Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py: 24%
131 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions shared between SavedModel saving/loading implementations."""
17import itertools
18import threading
19import types
21from tensorflow.python.eager import context
22from tensorflow.python.keras import backend as K
23from tensorflow.python.keras.engine import base_layer_utils
24from tensorflow.python.keras.utils import control_flow_util
25from tensorflow.python.keras.utils import tf_contextlib
26from tensorflow.python.keras.utils import tf_inspect
27from tensorflow.python.keras.utils.generic_utils import LazyLoader
28from tensorflow.python.util import tf_decorator
# `keras.engine.training` is loaded lazily, presumably to break an import
# cycle with the engine package -- TODO confirm. In this module it supplies
# `training_lib.Model` for `list_all_layers`.
# pylint:disable=g-inconsistent-quotes
training_lib = LazyLoader(
    "training_lib",
    globals(),
    "tensorflow.python.keras.engine.training",
)
# pylint:enable=g-inconsistent-quotes
def use_wrapped_call(layer, call_fn, default_training_value=None,
                     return_method=False):
  """Wraps `call_fn` so the losses it returns are registered on `layer`.

  Args:
    layer: A Keras layer object.
    call_fn: tf.function that takes layer inputs (and possibly a training arg)
      and returns a tuple of (outputs, list of losses). An object exposing an
      `original_layer_call` attribute (a LayerCall) is also accepted; its
      `original_layer_call` is used for signature inspection and its bound
      `__call__` is what gets invoked.
    default_training_value: Default value of the training kwarg. If `None`, the
      default is `K.learning_phase()`.
    return_method: Whether to return a method bound to the layer. When True,
      the first positional argument (the bound instance) is dropped before
      `call_fn` is invoked.

  Returns:
    A callable that calls `call_fn`, passes the returned losses to
    `layer.add_loss(..., inputs=True)`, and returns only the outputs.
  """
  expects_training_arg = layer_uses_training_bool(layer)

  if not hasattr(call_fn, 'original_layer_call'):
    original_call = call_fn
  else:
    # LayerCall object: inspect the underlying layer call, but invoke the
    # bound `__call__`, since in Python 3 callable objects are not compatible
    # with inspect.getargspec.
    original_call = call_fn.original_layer_call
    call_fn = call_fn.__call__

  fn, arg_spec = maybe_add_training_arg(
      original_call, call_fn, expects_training_arg, default_training_value)

  def replace_child_losses_with_placeholder():
    # TODO(kathywu): This is a temporary hack. When a network of layers is
    # revived from SavedModel, only the top-level layer will have losses. This
    # causes issues in eager mode because the child layers may have graph
    # losses (thus model.losses returns a mix of Eager and graph tensors). To
    # fix this, whenever eager losses are added to one layer, add eager losses
    # to all child layers. This causes `.losses` to only return eager losses.
    # pylint: disable=protected-access
    for child in layer._flatten_layers():
      if child is layer:
        continue
      child._eager_losses = [base_layer_utils.REVIVED_LOSS_PLACEHOLDER]
    # pylint: enable=protected-access

  def return_outputs_and_add_losses(*args, **kwargs):
    """Returns the outputs from the layer call function, and adds the losses."""
    call_args = args[1:] if return_method else args
    outputs, losses = fn(*call_args, **kwargs)
    layer.add_loss(losses, inputs=True)
    if context.executing_eagerly():
      replace_child_losses_with_placeholder()
    return outputs

  # make_decorator is called directly from this function (not a helper) so
  # the decorator name it infers from the calling frame stays the same.
  decorated = tf_decorator.make_decorator(
      target=call_fn,
      decorator_func=return_outputs_and_add_losses,
      decorator_argspec=arg_spec)

  return types.MethodType(decorated, layer) if return_method else decorated
def layer_uses_training_bool(layer):
  """Reports whether `layer` or any layer nested inside it takes `training`.

  The root layer must define `_expects_training_arg`. Nested objects that lack
  the attribute are conservatively treated as expecting a training argument.

  Args:
    layer: The root layer to inspect.

  Returns:
    True as soon as any visited layer expects a training argument, else False.
  """
  if layer._expects_training_arg:  # pylint: disable=protected-access
    return True

  seen = {layer}
  pending = list_all_layers(layer)
  while pending:
    candidate = pending.pop()
    if candidate in seen:
      continue
    if getattr(candidate, '_expects_training_arg', True):
      return True
    seen.add(candidate)
    pending.extend(list_all_layers(candidate))
  return False
def list_all_layers(obj):
  """Returns the direct (non-recursive) child layers of `obj`.

  Args:
    obj: A Keras layer or model.

  Returns:
    For a `training_lib.Model`, its `layers` attribute; otherwise a list of the
    layers yielded by `obj._flatten_layers(include_self=False,
    recursive=False)`.
  """
  if not isinstance(obj, training_lib.Model):
    return list(obj._flatten_layers(include_self=False, recursive=False))  # pylint: disable=protected-access
  # Models go through `.layers` to handle the special case of Sequential,
  # which does not return the `Input` layer there (per the original comment).
  return obj.layers
def list_all_layers_and_sublayers(obj):
  """Returns the set containing `obj` and every layer reachable beneath it.

  Children are discovered with `list_all_layers`. Traversal is iterative and
  each object is expanded at most once. As a result, layers shared by several
  parents are not re-walked once per path, and a reference cycle cannot cause
  unbounded recursion.

  Args:
    obj: The root layer or model.

  Returns:
    A `set` of `obj` plus all transitively nested layers.
  """
  found = {obj}
  pending = list(list_all_layers(obj))
  while pending:
    layer = pending.pop()
    if layer in found:
      continue  # Already expanded (shared sublayer or cycle back-edge).
    found.add(layer)
    pending.extend(list_all_layers(layer))
  return found
def maybe_add_training_arg(
    original_call, wrapped_call, expects_training_arg, default_training_value):
  """Decorate call and optionally add a training argument.

  If a layer expects a training argument, this function ensures that
  'training' is present in the layer args or kwonly args, with the default
  training value.

  Args:
    original_call: Original call function. Its signature determines whether
      `training` is passed positionally or by keyword.
    wrapped_call: Wrapped call function; this is what actually gets invoked.
    expects_training_arg: Whether to include 'training' argument.
    default_training_value: Default value of the training kwarg to include in
      the arg spec. At call time, if `training` is missing or `None`, a falsy
      `default_training_value` (including `None`) falls back to
      `K.learning_phase()`.

  Returns:
    Tuple of (
      function that calls `wrapped_call` and sets the training arg,
      Argspec of returned function or `None` if the argspec is unchanged)
  """
  if not expects_training_arg:
    return wrapped_call, None

  # `original_call`'s signature does not change between invocations, so locate
  # `training` once here instead of re-inspecting it on every call.
  training_arg_index = get_training_arg_index(original_call)

  def wrap_with_training_arg(*args, **kwargs):
    """Wrap the `wrapped_call` function, and set training argument."""
    training = get_training_arg(training_arg_index, args, kwargs)
    if training is None:
      # `or` (not `is None`) is kept for backward compatibility: falsy
      # defaults such as False also defer to the learning phase.
      training = default_training_value or K.learning_phase()

    # Mutable copies so set_training_arg can overwrite the value in place
    # without touching the caller's containers.
    args = list(args)
    kwargs = kwargs.copy()

    def replace_training_and_call(training):
      set_training_arg(training, training_arg_index, args, kwargs)
      return wrapped_call(*args, **kwargs)

    # Each branch hands wrapped_call a concrete Python bool for `training`.
    return control_flow_util.smart_cond(
        training, lambda: replace_training_and_call(True),
        lambda: replace_training_and_call(False))

  # Create arg spec for decorated function. If 'training' is not defined in the
  # args of the original arg spec, then add it to kwonlyargs.
  arg_spec = tf_inspect.getfullargspec(original_call)
  defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else []

  # Copy before modifying: the inspected spec may be an object stored on a
  # decorator (specs built here are passed as `decorator_argspec`), and
  # mutating it in place would leak 'training' back into it.
  kwonlyargs = list(arg_spec.kwonlyargs or [])
  kwonlydefaults = dict(arg_spec.kwonlydefaults or {})
  if 'training' not in arg_spec.args:
    # Add `training` as keyword-only unless it already is one, so the spec
    # never lists it twice.
    if 'training' not in kwonlyargs:
      kwonlyargs.append('training')
    kwonlydefaults['training'] = default_training_value
  else:
    # `training` is positional: only replace its default if that is None.
    index = arg_spec.args.index('training')
    training_default_index = len(arg_spec.args) - index
    if (arg_spec.defaults and
        len(arg_spec.defaults) >= training_default_index and
        defaults[-training_default_index] is None):
      defaults[-training_default_index] = default_training_value

  decorator_argspec = tf_inspect.FullArgSpec(
      args=arg_spec.args,
      varargs=arg_spec.varargs,
      varkw=arg_spec.varkw,
      defaults=defaults,
      kwonlyargs=kwonlyargs,
      kwonlydefaults=kwonlydefaults,
      annotations=arg_spec.annotations)
  return wrap_with_training_arg, decorator_argspec
def get_training_arg_index(call_fn):
  """Locates the 'training' argument in `call_fn`'s signature.

  Args:
    call_fn: Call function. If it is a method, its first positional argument
      (the bound instance) is excluded when computing the index.

  Returns:
    - n: index of 'training' in the call function's positional arguments
      (only when the function has no `*args`).
    - -1: if 'training' is not positional, but it is keyword-only or the
      function accepts `**kwargs`.
    - None: if the function cannot take a training argument.
  """
  argspec = tf_inspect.getfullargspec(call_fn)

  def accepts_training_keyword():
    return 'training' in argspec.kwonlyargs or argspec.varkw

  # With `*args` present, `training` can only be passed as a keyword, so the
  # positional search is skipped.
  if not argspec.varargs:
    if tf_inspect.ismethod(call_fn):
      positional = argspec.args[1:]
    else:
      positional = argspec.args
    if 'training' in positional:
      return positional.index('training')

  return -1 if accepts_training_keyword() else None
def set_training_arg(training, index, args, kwargs):
  """Stores `training` positionally when `index` fits `args`, else as a kwarg.

  Args:
    training: Value to store.
    index: Positional index of 'training'; None, a negative value, or an
      out-of-range value routes the value to `kwargs['training']`.
    args: Mutable list of positional arguments (modified in place).
    kwargs: Dict of keyword arguments (modified in place).

  Returns:
    The same `(args, kwargs)` objects.
  """
  if index is not None and 0 <= index < len(args):
    args[index] = training
  else:
    kwargs['training'] = training
  return args, kwargs
def get_training_arg(index, args, kwargs):
  """Reads the 'training' value from positional or keyword arguments.

  Args:
    index: Positional index of 'training'; None, a negative value, or an
      out-of-range value means it is looked up in `kwargs` instead.
    args: Sequence of positional arguments.
    kwargs: Dict of keyword arguments.

  Returns:
    `args[index]` when the index is usable, else `kwargs.get('training')`.
  """
  if index is not None and 0 <= index < len(args):
    return args[index]
  return kwargs.get('training', None)
def remove_training_arg(index, args, kwargs):
  """Removes the 'training' value from positional or keyword arguments.

  Args:
    index: Positional index of 'training'; None, a negative value, or an
      out-of-range value means `kwargs['training']` is removed instead (if
      present).
    args: Mutable list of positional arguments (modified in place).
    kwargs: Dict of keyword arguments (modified in place).
  """
  if index is not None and 0 <= index < len(args):
    args.pop(index)
  else:
    kwargs.pop('training', None)
class SaveOptionsContext(threading.local):
  """Per-thread holder for Keras-specific SavedModel save options."""

  def __init__(self):
    super().__init__()
    # Whether layer functions are traced on save; defaults to True.
    self.save_traces = True


# Module-wide instance read and written by the save-option helpers below.
_save_options_context = SaveOptionsContext()
@tf_contextlib.contextmanager
def keras_option_scope(save_traces):
  """Temporarily sets the current thread's `save_traces` option.

  Args:
    save_traces: Value reported by `should_save_traces()` inside the scope.

  Yields:
    Nothing; the previous value is restored on exit, even if an error occurs.
  """
  saved = _save_options_context.save_traces
  try:
    _save_options_context.save_traces = save_traces
    yield
  finally:
    _save_options_context.save_traces = saved
def should_save_traces():
  """Whether to trace layer functions-can be disabled in the save_traces arg."""
  options = _save_options_context
  return options.save_traces
@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
  """A context that disables automatic dependency tracking when assigning attrs.

  Objects that inherit from AutoTrackable automatically create dependencies
  on trackable objects through attribute assignment, and wrap data structures
  (lists or dicts) in trackable classes. This scope temporarily disables that
  behavior by setting `obj._setattr_tracking` to False. It works similarly to
  the decorator `no_automatic_dependency_tracking`.

  Example usage:
  ```
  model = tf.keras.Model()
  model.arr1 = []  # Creates a ListWrapper object
  with no_automatic_dependency_tracking_scope(model):
    model.arr2 = []  # Creates a regular, untracked python list
  ```

  Args:
    obj: A trackable object.

  Yields:
    A scope in which the object doesn't track dependencies. On exit the prior
    `_setattr_tracking` value (True if it was unset) is restored.
  """
  saved_tracking = getattr(obj, '_setattr_tracking', True)
  obj._setattr_tracking = False  # pylint: disable=protected-access
  try:
    yield
  finally:
    obj._setattr_tracking = saved_tracking  # pylint: disable=protected-access