Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py: 24%
310 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras SavedModel serialization.
17TODO (kathywu): Move to layer_serialization.py. Some model-specific logic should
18go to model_serialization.py.
19"""
21import functools
22import threading
23import weakref
25from tensorflow.python.eager import def_function
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.keras import backend as K
29from tensorflow.python.keras.engine import base_layer_utils
30from tensorflow.python.keras.engine import input_spec
31from tensorflow.python.keras.mixed_precision import autocast_variable
32from tensorflow.python.keras.saving import saving_utils
33from tensorflow.python.keras.saving.saved_model import constants
34from tensorflow.python.keras.saving.saved_model import load as keras_load
35from tensorflow.python.keras.saving.saved_model import serialized_attributes
36from tensorflow.python.keras.saving.saved_model import utils
37from tensorflow.python.keras.utils import tf_contextlib
38from tensorflow.python.keras.utils import tf_inspect
39from tensorflow.python.keras.utils import tf_utils
40from tensorflow.python.keras.utils import version_utils
41from tensorflow.python.keras.utils.generic_utils import LazyLoader
42from tensorflow.python.platform import tf_logging as logging
43from tensorflow.python.trackable import data_structures
44from tensorflow.python.util import nest
45from tensorflow.python.util import tf_decorator
# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.
# Each name below is bound to a LazyLoader for the named module; presumably
# the real import happens on first attribute access (see
# generic_utils.LazyLoader) -- NOTE(review): confirm against LazyLoader.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
input_layer = LazyLoader(
    "input_layer", globals(),
    "tensorflow.python.keras.engine.input_layer")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
sequential_lib = LazyLoader(
    "sequential_lib", globals(),
    "tensorflow.python.keras.engine.sequential")
# pylint:enable=g-inconsistent-quotes
def should_skip_serialization(layer):
  """Decides whether extra objects/functions of `layer` should be skipped.

  A layer is skipped (and a warning is logged) when it has not been built and
  it is not a Model with a recorded `_saved_model_inputs_spec`.

  Args:
    layer: Keras Layer object.

  Returns:
    True if full serialization should be skipped, False otherwise.
  """
  is_model = isinstance(layer, training_lib.Model)
  has_saved_inputs = is_model and layer._saved_model_inputs_spec is not None  # pylint: disable=protected-access
  if layer.built or has_saved_inputs:
    return False
  logging.warning('Skipping full serialization of Keras layer {}, because '
                  'it is not built.'.format(layer))
  return True
def wrap_layer_objects(layer, serialization_cache):
  """Collects the extra trackable objects attached to a serialized layer.

  Callable (regularization) losses of the layer and of every layer returned by
  `utils.list_all_layers` are wrapped as tf.functions. Wrapped functions are
  memoized in `serialization_cache['keras_losses']` so a loss shared by
  several layers is only wrapped once.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all checkpointable objects from a
    SerializedAttributes object. See LayerAttributes and ModelAttributes for
    entire list of objects
  """
  # pylint: disable=protected-access
  # The layer's own callable losses come first, then those of its children.
  all_losses = layer._callable_losses[:]
  for child in utils.list_all_layers(layer):
    all_losses.extend(child._callable_losses)

  loss_cache = serialization_cache.setdefault('keras_losses', {})
  regularization_fns = []
  for loss_fn in all_losses:
    if loss_fn not in loss_cache:
      # The cache size at insertion time gives each wrapped loss a unique name.
      loss_cache[loss_fn] = _wrap_unconditional_loss(loss_fn, len(loss_cache))
    regularization_fns.append(loss_cache[loss_fn])

  own_regularization_fns = [
      loss_cache[loss_fn] for loss_fn in layer._callable_losses[:]]

  metrics_by_name = data_structures.wrap_or_unwrap(
      {metric.name: metric for metric in layer._metrics})
  return {
      'variables': data_structures.wrap_or_unwrap(layer.variables),
      'trainable_variables': data_structures.wrap_or_unwrap(
          layer.trainable_variables),
      'non_trainable_variables': data_structures.wrap_or_unwrap(
          layer.non_trainable_variables),
      'layers': data_structures.wrap_or_unwrap(utils.list_all_layers(layer)),
      'metrics': data_structures.wrap_or_unwrap(layer.metrics),
      'regularization_losses': data_structures.wrap_or_unwrap(
          regularization_fns),
      'layer_regularization_losses': data_structures.wrap_or_unwrap(
          own_regularization_fns),
      'layer_metrics': metrics_by_name,
  }
def wrap_layer_functions(layer, serialization_cache):
  """Returns dict of wrapped layer call function and losses in tf.functions.

  While the functions are wrapped and traced, the call functions of child
  layers are replaced with their serialized tf.functions and the losses of
  `layer` and its sublayers are cleared. Both are restored before this
  function returns -- also when wrapping or tracing raises, so a failed save
  does not leave the layers in a modified state.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all keras tf.functions to serialize. See
    LayerAttributes and ModelAttributes for the list of all attributes.
  """
  # Since Sequential models may be modified in place using model.add() or
  # model.pop(), don't use saved functions.
  if (isinstance(layer, keras_load.RevivedLayer) and
      not isinstance(layer, sequential_lib.Sequential)):
    return {fn_name: getattr(layer.keras_api, fn_name, None)
            for fn_name in serialized_attributes.LayerAttributes.all_functions}

  # Reset the losses of the layer and its children. The call function in each
  # child layer is replaced with tf.functions.
  original_fns = _replace_child_layer_functions(layer, serialization_cache)
  original_losses = None
  try:
    original_losses = _reset_layer_losses(layer)

    # Wrap all the layer call and activity regularizer functions.

    # Use LayerCallCollection to ensure that all layer call functions
    # (__call__, call with losses) are traced with the same inputs.
    call_collection = LayerCallCollection(layer)
    call_fn_with_losses = call_collection.add_function(
        _wrap_call_and_conditional_losses(layer),
        '{}_layer_call_and_return_conditional_losses'.format(layer.name),
        # If any of this layer's child layers use the training arg, the traced
        # call functions of this layer will have a training keyword argument.
        # If the original layer does not expect the training arg, then it will
        # have to be removed (by setting `match_layer_training_arg`).
        match_layer_training_arg=True)
    call_fn = call_collection.add_function(
        _extract_outputs_from_fn(layer, call_fn_with_losses),
        '{}_layer_call_fn'.format(layer.name),
        # Since `call_fn` wraps call_fn_with_losses and not the original call
        # function, `match_layer_training_arg` should be set to False.
        match_layer_training_arg=False)

    fns = {'call_and_return_conditional_losses': call_fn_with_losses,
           '__call__': call_fn}

    if layer._activity_regularizer is not None:  # pylint: disable=protected-access
      fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
      fns['call_and_return_all_conditional_losses'] = (
          call_collection.add_function(
              _append_activity_regularizer_loss(
                  layer, call_fn_with_losses, fns['activity_regularizer_fn']),
              '{}_layer_call_and_return_all_conditional_losses'.format(
                  layer.name),
              match_layer_training_arg=False))
    else:
      fns['activity_regularizer_fn'] = None
      fns['call_and_return_all_conditional_losses'] = call_fn_with_losses

    # Manually trigger traces before restoring the overwritten functions. The
    # functions are traced within the layer call context to ensure that layer
    # functions (e.g. add_loss) behave as though running in graph mode.
    with tracing_scope():
      call_collection.trace_with_input_signature()
      with base_layer_utils.call_context().enter(
          layer, inputs=None, build_graph=True, training=None, saving=True):
        for fn in fns.values():
          if fn is not None and fn.input_signature is not None:
            if isinstance(fn, LayerCall):
              fn = fn.wrapped_call
            fn.get_concrete_function()
  finally:
    # Restore overwritten functions and losses, even if wrapping or tracing
    # raised. Losses are only restored if they were actually reset.
    _restore_child_layer_functions(original_fns)
    if original_losses is not None:
      _restore_layer_losses(original_losses)

  return fns
def default_save_signature(layer):
  """Traces `layer` with `saving_utils.trace_model_call` and returns the fn.

  The losses of `layer` and its sublayers are cleared while tracing and are
  restored afterwards. Restoration also happens when tracing raises, so the
  layer is never left with its losses wiped.

  Args:
    layer: Keras Layer/Model object.

  Returns:
    The function returned by `saving_utils.trace_model_call(layer)`, after
    `get_concrete_function()` has been called on it once.
  """
  original_losses = _reset_layer_losses(layer)
  try:
    fn = saving_utils.trace_model_call(layer)
    fn.get_concrete_function()
  finally:
    _restore_layer_losses(original_losses)
  return fn
def _replace_child_layer_functions(layer, serialization_cache):
  """Replaces functions in the children layers with wrapped tf.functions.

  This step allows functions from parent layers to reference the wrapped
  functions from their children layers instead of retracing the ops.

  Only functions are replaced here; losses are not touched (that is done by
  `_reset_layer_losses`). Use `_restore_child_layer_functions` to restore the
  original attributes.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    Dictionary mapping each replaced child object -> its original attributes:
      { Child layer 1: {
          'call': Original call function,
          '_activity_regularizer': Original activity regularizer},
        Child metric 2: {
          '__call__': Original __call__ function,
          'result': Original result function,
          'update_state': Original update_state function},
        ...
      }
    InputLayers, and children whose serialized functions are empty, are not
    included.
  """
  # pylint: disable=protected-access
  original_fns = {}

  def replace_layer_functions(child_layer, serialized_fns):
    """Replaces layer call and activity regularizer with wrapped functions."""
    original_fns[child_layer] = {
        'call': child_layer.call,
        '_activity_regularizer': child_layer._activity_regularizer
    }
    with utils.no_automatic_dependency_tracking_scope(child_layer):
      try:
        child_layer._activity_regularizer = serialized_fns.get(
            'activity_regularizer_fn')
      except AttributeError:
        # Some layers have an unsettable activity regularizer.
        pass
      child_layer.call = utils.use_wrapped_call(
          child_layer,
          serialized_fns['call_and_return_conditional_losses'],
          default_training_value=False)

  def replace_metric_functions(child_layer, serialized_fns):
    """Replaces metric functions with wrapped functions."""
    original_fns[child_layer] = {
        '__call__': child_layer.__call__,
        'result': child_layer.result,
        'update_state': child_layer.update_state
    }
    with utils.no_automatic_dependency_tracking_scope(child_layer):
      child_layer.__call__ = serialized_fns['__call__']
      child_layer.result = serialized_fns['result']
      child_layer.update_state = serialized_fns['update_state']

  for child_layer in utils.list_all_layers(layer):
    if isinstance(child_layer, input_layer.InputLayer):
      continue

    if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]:
      serialized_functions = (
          child_layer._trackable_saved_model_saver._get_serialized_attributes(
              serialization_cache).functions)
    else:
      serialized_functions = (
          serialization_cache[constants.KERAS_CACHE_KEY][child_layer].functions)
    if not serialized_functions:
      # This indicates either:
      #   - circular dependency, which means the current layer's functions
      #     should be wrapped first.
      #   - Child layer's inputs are not defined, so its functions have not
      #     been wrapped. In this case, no replacement is necessary so move on
      #     to the next child.
      continue

    if isinstance(child_layer, metrics.Metric):
      replace_metric_functions(child_layer, serialized_functions)
    else:
      replace_layer_functions(child_layer, serialized_functions)

  return original_fns
  # pylint: enable=protected-access
def _restore_child_layer_functions(original_fns):
  """Puts back the attributes swapped out by `_replace_child_layer_functions`.

  Args:
    original_fns: Mapping of child layer -> {attribute name: original value},
      as returned by `_replace_child_layer_functions`.
  """
  for child_layer in original_fns:
    saved_attrs = original_fns[child_layer]
    with utils.no_automatic_dependency_tracking_scope(child_layer):
      for attr_name in saved_attrs:
        original_value = saved_attrs[attr_name]
        try:
          setattr(child_layer, attr_name, original_value)  # pylint: disable=protected-access
        except AttributeError:
          # Some attributes (e.g. `_activity_regularizer`) may not be
          # settable on every layer; those are deliberately left alone.
          pass
# pylint: disable=protected-access
def _reset_layer_losses(parent_layer):
  """Clears the losses of `parent_layer` and all of its sublayers.

  Args:
    parent_layer: Keras Layer object.

  Returns:
    Dict mapping each layer (from `utils.list_all_layers_and_sublayers`) to a
    snapshot `{'losses': [...], 'eager_losses': [...]}` of its loss lists,
    suitable for passing to `_restore_layer_losses`.
  """
  saved = {}
  for sublayer in utils.list_all_layers_and_sublayers(parent_layer):
    snapshot = {'losses': sublayer._losses[:]}
    snapshot['eager_losses'] = sublayer._eager_losses[:]
    saved[sublayer] = snapshot
    with utils.no_automatic_dependency_tracking_scope(sublayer):
      sublayer._losses, sublayer._eager_losses = [], []
  return saved
def _restore_layer_losses(losses_dict):
  """Reinstates the loss lists captured by `_reset_layer_losses`.

  Args:
    losses_dict: Mapping of layer -> {'losses': [...], 'eager_losses': [...]}.
  """
  for layer, saved in losses_dict.items():
    with utils.no_automatic_dependency_tracking_scope(layer):
      layer._losses = saved['losses']
      layer._eager_losses = saved['eager_losses']
# pylint: enable=protected-access
class LayerTracingContext(threading.local):
  """Per-thread tracing state shared by the tracing helpers in this module.

  Attributes:
    enable_call_tracing: Whether extra call traces should be queued.
    trace_queue: Pending `(fn, args, kwargs, training)` entries to trace.
  """

  def __init__(self):
    super().__init__()
    self.trace_queue = []
    self.enable_call_tracing = False


_thread_local_data = LayerTracingContext()
@tf_contextlib.contextmanager
def tracing_scope():
  """Enables tracing scope.

  While the scope is active, `tracing_enabled()` returns True, so calls to
  `add_trace_to_queue` (made by `LayerCallCollection.add_trace` via
  `LayerCall`) queue traces instead of dropping them. On exit, the queue is
  drained and every queued function is traced. The previous thread-local
  tracing flag and queue are then restored. Restoration also happens if one
  of the queued traces raises, so the thread is never left with tracing
  enabled.
  """
  # This enables the LayerCallCollection's tracing mechanism to trace all call
  # functions in the collection.
  previous_value = _thread_local_data.enable_call_tracing
  previous_queue = _thread_local_data.trace_queue
  try:
    _thread_local_data.enable_call_tracing = True
    _thread_local_data.trace_queue = []
    yield
  finally:
    try:
      # Run traces from the queue. Tracing is still enabled here, so a trace
      # may enqueue further traces; keep draining until the queue is empty.
      while _thread_local_data.trace_queue:
        fn, args, kwargs, training = _thread_local_data.trace_queue.pop()
        if training is not None:
          with K.deprecated_internal_learning_phase_scope(training):
            fn.get_concrete_function(*args, **kwargs)
        else:
          fn.get_concrete_function(*args, **kwargs)
    finally:
      _thread_local_data.trace_queue = previous_queue
      _thread_local_data.enable_call_tracing = previous_value
def add_trace_to_queue(fn, args, kwargs, training=None):
  """Queues `fn` to be traced with shallow copies of `args` and `kwargs`.

  Does nothing unless `tracing_enabled()` is True (i.e. inside
  `tracing_scope`).

  Args:
    fn: Function whose `get_concrete_function` will be called later.
    args: Sequence of positional arguments (copied with `[:]`).
    kwargs: Dict of keyword arguments (copied with `.copy()`).
    training: Optional learning-phase value to trace under.
  """
  if not tracing_enabled():
    return
  entry = (fn, args[:], kwargs.copy(), training)
  _thread_local_data.trace_queue.append(entry)
def tracing_enabled():
  """Whether to add extra traces to the queue.

  Returns:
    The current thread's `enable_call_tracing` flag, which `tracing_scope`
    sets to True while active.
  """
  enabled = _thread_local_data.enable_call_tracing
  return enabled
class LayerCallCollection(object):
  """Groups wrapped layer call functions.

  This is used to ensure that all layer call functions are traced with the same
  inputs:
    - call
    - call_and_return_conditional_losses
    - call_and_return_all_conditional_losses
  """

  def __init__(self, layer):
    """Inspects `layer` and records the metadata used to wrap its calls.

    Args:
      layer: Keras Layer object whose call functions are collected.
    """
    self.layer = layer

    self.layer_call_method = _get_layer_call_method(layer)
    self._expects_training_arg = utils.layer_uses_training_bool(layer)
    # May be overwritten later by `_maybe_wrap_with_training_arg`.
    self._training_arg_index = utils.get_training_arg_index(
        self.layer_call_method)

    # If the layer call function has kwargs, then the traced function cannot
    # have an input signature.
    arg_spec = tf_inspect.getfullargspec(self.layer_call_method)
    self._has_kwargs = bool(self._expects_training_arg or
                            arg_spec.defaults or
                            arg_spec.kwonlyargs or
                            arg_spec.varkw)

    self._input_signature = self._generate_input_signature(layer)
    # Function name -> `LayerCall.wrapped_call` (filled by `add_function`).
    # Values are held weakly, so entries vanish once nothing else references
    # the function.
    self._functions = weakref.WeakValueDictionary()

    # Get the input argument name from the args.
    args = arg_spec.args
    if tf_inspect.ismethod(self.layer_call_method):
      # Bound method: drop the leading `self` argument.
      args = args[1:]
    self._input_arg_name = args[0] if args else 'inputs'

  def _generate_input_signature(self, layer):
    """Inspects layer object and returns the inferred input signature.

    Args:
      layer: Layer object.

    Returns:
      List of possibly nested TensorSpecs of the layer call function inputs.
      The list does not contain the `training` argument. Returns None when no
      signature can be inferred. When the signature is derived from
      `layer.input_spec`, individual entries are None for specs of unknown
      rank.
    """
    if (isinstance(layer.call, def_function.Function) and
        layer.call.input_signature is not None):
      return layer.call.input_signature
    elif isinstance(layer, training_lib.Model):
      return saving_utils.model_input_signature(layer)
    elif (layer.input_spec is not None and
          layer._use_input_spec_as_call_signature):  # pylint: disable=protected-access

      def to_tensor_spec_or_none(x):
        spec = input_spec.to_tensor_spec(x, layer._compute_dtype)  # pylint: disable=protected-access
        # If the shape is too general (e.g. multiple dimensions are allowed),
        # return None so that separate functions can be generated for each
        # inferred input signature.
        # TODO(b/134962016): currently partial signatures are not supported.
        if spec.shape == tensor_shape.TensorShape(None):
          return None
        return spec
      input_signature = [nest.map_structure(
          to_tensor_spec_or_none, layer.input_spec)]

      return input_signature
    else:
      return None

  def add_trace(self, *args, **kwargs):
    """Queues traces of all functions with the same args and kwargs.

    The traces are only queued (via `add_trace_to_queue`); they run when the
    enclosing `tracing_scope` exits. If the collection expects a training
    argument, every function is queued twice, with training=True and with
    training=False.

    Args:
      *args: Positional args passed to the original function.
      **kwargs: Keyword args passed to the original function.
    """
    # Private copies: the training value is written into them below.
    args = list(args)
    kwargs = kwargs.copy()

    for fn in self._functions.values():
      # TODO(kathywu): Replace arguments with broader shapes defined in the
      # input signature.
      if self._expects_training_arg:
        # `fn=fn` binds the current loop value (avoids late binding).
        def trace_with_training(value, fn=fn):
          utils.set_training_arg(value, self._training_arg_index, args, kwargs)
          add_trace_to_queue(fn, args, kwargs, value)

        trace_with_training(True)
        trace_with_training(False)
      else:
        add_trace_to_queue(fn, args, kwargs)

  @property
  def fn_input_signature(self):
    """Returns input signature for the wrapped layer call function."""
    if self._has_kwargs:
      # Input signatures may only describe tensor arguments and kwargs are not
      # supported.
      return None
    if None in nest.flatten(self._input_signature):
      # TODO(b/134962016): If input signature cannot be partially defined.
      return None
    return self._input_signature

  def training_arg_was_passed(self, args, kwargs):
    """Returns whether a `training` value was passed in `args`/`kwargs`.

    If the collection uses a training arg that the layer itself does not
    expect, the value is looked up at `self._training_arg_index`; otherwise
    the layer's own argument lookup is used.
    """
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      return (utils.get_training_arg(self._training_arg_index, args, kwargs)
              is not None)
    else:
      return self.layer._call_arg_was_passed(  # pylint: disable=protected-access
          'training', args, kwargs, inputs_in_args=True)

  def get_training_arg_value(self, args, kwargs):
    """Returns the `training` value from `args`/`kwargs`.

    The lookup is chosen the same way as in `training_arg_was_passed`.
    """
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      return utils.get_training_arg(self._training_arg_index, args, kwargs)
    else:
      return self.layer._get_call_arg_value(  # pylint: disable=protected-access
          'training', args, kwargs, inputs_in_args=True)

  def get_input_arg_value(self, args, kwargs):
    """Returns the value of the layer's first call argument (the inputs)."""
    return self.layer._get_call_arg_value(  # pylint: disable=protected-access
        self._input_arg_name, args, kwargs, inputs_in_args=True)

  def _maybe_wrap_with_training_arg(self, call_fn, match_layer_training_arg):
    """Wraps call function with added training argument if necessary.

    A wrapper is only created when the collection expects a training arg that
    the layer itself does not. In that case `self._training_arg_index` is
    updated to point at the appended `training` argument.

    Args:
      call_fn: Function to wrap.
      match_layer_training_arg: If True, the wrapper strips the training value
        before calling `call_fn`.

    Returns:
      The wrapped function, or `call_fn` unchanged.
    """
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      # Add training arg to wrapper function.
      arg_spec = tf_inspect.getfullargspec(call_fn)
      args = arg_spec.args + ['training']
      defaults = list(arg_spec.defaults or [])
      defaults.append(False)
      new_arg_spec = tf_inspect.FullArgSpec(
          args=args,
          varargs=arg_spec.varargs,
          varkw=arg_spec.varkw,
          defaults=defaults,
          kwonlyargs=arg_spec.kwonlyargs,
          kwonlydefaults=arg_spec.kwonlydefaults,
          annotations=arg_spec.annotations)

      # Set new training arg index
      self._training_arg_index = len(args) - 1
      if tf_inspect.ismethod(call_fn):
        self._training_arg_index -= 1

      def wrap_with_training_arg(*args, **kwargs):
        # NOTE(review): `self._training_arg_index` is read at call time, not
        # captured here, so a later `add_function` call that updates it also
        # affects this wrapper.
        if match_layer_training_arg:
          # Remove the training value, since the original call_fn does not
          # expect a training arg. Instead, the training value will be
          # propagated using the call context created in LayerCall.
          args = list(args)
          kwargs = kwargs.copy()
          utils.remove_training_arg(self._training_arg_index, args, kwargs)
        return call_fn(*args, **kwargs)

      return tf_decorator.make_decorator(
          target=call_fn,
          decorator_func=wrap_with_training_arg,
          decorator_argspec=new_arg_spec)

    return call_fn

  def add_function(self, call_fn, name, match_layer_training_arg):
    """Adds a layer call function to the collection.

    Args:
      call_fn: a python function
      name: Name of call function
      match_layer_training_arg: If True, removes the `training` from the
        function arguments when calling `call_fn`.

    Returns:
      LayerCall (tf.function)
    """
    fn = LayerCall(
        self,
        self._maybe_wrap_with_training_arg(call_fn, match_layer_training_arg),
        name,
        input_signature=self.fn_input_signature)
    self._functions[name] = fn.wrapped_call
    return fn

  def trace_with_input_signature(self):
    """Trace with the layer/models inferred input signature if possible."""
    if (None not in nest.flatten(self._input_signature) and self._has_kwargs):
      # Manually add traces for layers that have keyword arguments and have
      # a fully defined input signature. (With kwargs, `fn_input_signature`
      # is None, so the tf.functions carry no signature of their own.)
      self.add_trace(*self._input_signature)
def _filtered_inputs(inputs):
  """Returns the tensors/variables found anywhere in the nested `inputs`."""
  return [item for item in nest.flatten(inputs)
          if tf_utils.is_tensor_or_variable(item)]
def layer_call_wrapper(call_collection, method, name):
  """Ensures layer losses are kept the same, and runs method in call context.

  Args:
    call_collection: `LayerCallCollection` that `method` belongs to.
    method: Python function to wrap.
    name: `__name__` given to the returned function.

  Returns:
    A decorated function. Each call does the following:
      - clears the losses of the layer and its sublayers;
      - runs `method` inside the Keras call context (with the `training`
        value taken from the arguments when one was passed) and the autocast
        scope;
      - restores the original losses, also when `method` raises.
  """

  # Create wrapper that deals with losses and call context.
  def wrapper(*args, **kwargs):
    """Calls method within call context."""
    layer = call_collection.layer
    training = None
    inputs = _filtered_inputs([args, kwargs])
    # pylint: disable=protected-access
    if (args or kwargs) and call_collection.training_arg_was_passed(
        args, kwargs):
      training = call_collection.get_training_arg_value(args, kwargs)
    # pylint: enable=protected-access
    original_losses = _reset_layer_losses(layer)
    try:
      with base_layer_utils.call_context().enter(
          layer, inputs=inputs, build_graph=False, training=training,
          saving=True):
        with autocast_variable.enable_auto_cast_variables(
            layer._compute_dtype_object):  # pylint: disable=protected-access
          return method(*args, **kwargs)
    finally:
      # Restore losses even if `method` raised, so a failed call/trace does
      # not leave the layer with its losses cleared.
      _restore_layer_losses(original_losses)

  # Rename to `name`, since tf.function doesn't have a name argument. Without
  # this, all functions returned by this method will be named "call", which
  # would be a nightmare to debug.
  fn = tf_decorator.make_decorator(target=method, decorator_func=wrapper)
  fn.__name__ = name
  return fn
class LayerCall(object):
  """Function that triggers traces of other functions in the same collection."""

  def __init__(self, call_collection, call_fn, name, input_signature):
    """Builds the tf.function for `call_fn` and records its collection.

    Args:
      call_collection: a LayerCallCollection, which contains the other layer
        call functions (e.g. call_with_conditional_losses, call). These
        functions should be traced with the same arguments.
      call_fn: A call function.
      name: Name of the call function.
      input_signature: Input signature of call_fn (can be None).
    """
    wrapper = layer_call_wrapper(call_collection, call_fn, name)
    self.call_collection = call_collection
    self.input_signature = input_signature
    self.wrapped_call = def_function.function(
        wrapper, input_signature=input_signature)
    self.original_layer_call = call_collection.layer_call_method

  def _maybe_trace(self, args, kwargs):
    """Queues traces of the sibling functions (and training variants)."""
    if not tracing_enabled():
      return
    self.call_collection.add_trace(*args, **kwargs)

  def __call__(self, *args, **kwargs):
    """Queues sibling traces, then calls the underlying tf.function."""
    self._maybe_trace(args, kwargs)
    return self.wrapped_call(*args, **kwargs)

  def get_concrete_function(self, *args, **kwargs):
    """Queues sibling traces, then returns the concrete function for args."""
    self._maybe_trace(args, kwargs)
    return self.wrapped_call.get_concrete_function(*args, **kwargs)
def _wrap_call_and_conditional_losses(layer):
  """Wraps call function that returns a tuple of (outputs, losses).

  The losses returned are conditional on the inputs passed to the call function.
  Unconditional losses (e.g. weight regularization) are wrapped separately.

  Args:
    layer: a Keras layer object

  Returns:
    python call function that returns outputs and conditional losses -- excludes
    activity regularizer
  """
  layer_call = _get_layer_call_method(layer)

  def call_and_return_conditional_losses(*args, **kwargs):
    """Returns layer (call_output, conditional losses) tuple."""
    outputs = layer_call(*args, **kwargs)
    if not version_utils.is_v1_layer_or_model(layer):
      # Losses not tagged `_unconditional_loss` count as conditional.
      losses = [loss for loss in layer.losses
                if not hasattr(loss, '_unconditional_loss')]
    else:
      losses = layer.get_losses_for(_filtered_inputs([args, kwargs]))
    return outputs, losses

  return _create_call_fn_decorator(layer, call_and_return_conditional_losses)
def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
  """Returns a function that returns only call function outputs.

  For a revived layer, its saved `keras_api.__call__` is returned instead.
  """
  if not isinstance(layer, keras_load.RevivedLayer):
    def call(inputs, *args, **kwargs):
      result = call_and_return_conditional_losses(inputs, *args, **kwargs)
      return result[0]
    return _create_call_fn_decorator(layer, call)
  return layer.keras_api.__call__  # pylint: disable=protected-access
def _append_activity_regularizer_loss(
    layer, call_fn_with_losses, activity_regularizer_fn):
  """Appends activity regularizer loss to losses returned by the wrapped fn.

  The returned call function runs `call_fn_with_losses` and appends
  `activity_regularizer_fn(outputs)` in place to the losses list it returned.
  """
  def fn(inputs, *args, **kwargs):
    outputs, losses = call_fn_with_losses(inputs, *args, **kwargs)
    regularization = activity_regularizer_fn(outputs)
    losses.append(regularization)
    return outputs, losses
  return _create_call_fn_decorator(layer, fn)
def _create_call_fn_decorator(layer, wrapped_call):
  """Decorates `wrapped_call` with the layer call method as its target.

  `utils.maybe_add_training_arg` decides, based on the layer's
  `_expects_training_arg`, which function and argspec the decorator uses.
  """
  target_fn = _get_layer_call_method(layer)
  decorator_fn, decorator_argspec = utils.maybe_add_training_arg(
      target_fn, wrapped_call,
      layer._expects_training_arg,  # pylint: disable=protected-access
      default_training_value=False)
  return tf_decorator.make_decorator(
      target=target_fn, decorator_func=decorator_fn,
      decorator_argspec=decorator_argspec)
def _wrap_unconditional_loss(loss_fn, index):
  """Wraps callable/unconditional loss, returning a serializable function.

  Args:
    loss_fn: Loss callable, or a `functools.partial` whose first positional
      argument is the loss callable.
    index: Integer used to name the new function `loss_fn_{index}`.

  Returns:
    The loss callable itself if it already is a `def_function.Function`,
    otherwise a new `def_function.Function` with an empty input signature.
  """
  # Extract original loss function from partial function
  if isinstance(loss_fn, functools.partial):
    fn = loss_fn.args[0]
  else:
    fn = loss_fn
  if not isinstance(fn, def_function.Function):
    fn = def_function.Function(
        fn, 'loss_fn_{}'.format(index), input_signature=[])
  return fn
def _wrap_activity_regularizer(layer):
  """Wraps the activity regularizer.

  Returns the layer's `_activity_regularizer` unchanged if it is already a
  `def_function.Function`. Otherwise it is wrapped in one that accepts a
  single tensor of any shape, with dtype `layer._compute_dtype` (falling back
  to `K.floatx()`).
  """
  # pylint: disable=protected-access
  regularizer = layer._activity_regularizer
  if isinstance(regularizer, def_function.Function):
    return regularizer
  dtype = layer._compute_dtype or K.floatx()
  return def_function.Function(
      regularizer,
      '{}_activity_regularizer'.format(layer.name),
      input_signature=[tensor_spec.TensorSpec(None, dtype)])
  # pylint: enable=protected-access
def _get_layer_call_method(layer):
  """Returns the plain Python callable behind `layer.call`.

  If `layer.call` is a `def_function.Function`, its `python_function` is
  returned; otherwise `layer.call` itself is returned.
  """
  call = layer.call
  if isinstance(call, def_function.Function):
    return call.python_function
  return call