Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saved_model/save_impl.py: 23%
294 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras SavedModel serialization.

TODO (kathywu): Move to layer_serialization.py. Some model-specific logic
should go to model_serialization.py.
"""

import functools
import threading
import weakref

import tensorflow.compat.v1.logging as logging
import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine import base_layer_utils
from keras.src.engine import input_spec
from keras.src.mixed_precision import autocast_variable
from keras.src.saving.legacy import saving_utils
from keras.src.saving.legacy.saved_model import constants
from keras.src.saving.legacy.saved_model import load as keras_load
from keras.src.saving.legacy.saved_model import serialized_attributes
from keras.src.saving.legacy.saved_model import utils
from keras.src.utils import layer_utils
from keras.src.utils import tf_contextlib
from keras.src.utils import tf_utils
from keras.src.utils import version_utils
from keras.src.utils.generic_utils import LazyLoader

# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
base_layer = LazyLoader("base_layer", globals(), "keras.src.engine.base_layer")
metrics = LazyLoader("metrics", globals(), "keras.src.metrics")
input_layer = LazyLoader(
    "input_layer", globals(), "keras.src.engine.input_layer"
)
training_lib = LazyLoader(
    "training_lib", globals(), "keras.src.engine.training"
)
sequential_lib = LazyLoader(
    "sequential_lib", globals(), "keras.src.engine.sequential"
)


def should_skip_serialization(layer):
    """Skip serializing extra objects and functions if layer inputs aren't
    set."""
    saved_model_input_spec_set = (
        isinstance(layer, training_lib.Model)
        and layer._saved_model_inputs_spec is not None
    )
    if not layer.built and not saved_model_input_spec_set:
        logging.warning(
            "Skipping full serialization of Keras layer {}, because "
            "it is not built.".format(layer)
        )
        return True
    return False
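

# A minimal usage sketch (hypothetical, not part of this module): an unbuilt
# layer is skipped, while a built one is serialized in full.
#
#     layer = tf.keras.layers.Dense(4)
#     should_skip_serialization(layer)  # True: layer is not built yet
#     layer.build((None, 8))
#     should_skip_serialization(layer)  # False: weights exist, safe to save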


def _filter_shards(variables):
    return [var for var in variables if not hasattr(var, "_sharded_container")]


def wrap_layer_objects(layer, serialization_cache):
    """Returns extra trackable objects to attach to the serialized layer.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      A dictionary containing all checkpointable objects from a
      SerializedAttributes object. See LayerAttributes and ModelAttributes for
      the entire list of objects.
    """
    # Wrap all regularization losses as tf.functions.
    # First, generate a list of all regularization losses in this layer and
    # its sublayers.
    all_losses = layer._callable_losses[:]
    for child_layer in utils.list_all_layers(layer):
        all_losses.extend(child_layer._callable_losses)
    # Next, wrap all loss functions as tf.functions. Use the serialization
    # cache to store already-wrapped functions.
    keras_loss_cache = serialization_cache.setdefault("keras_losses", {})
    wrapped_loss_functions = []
    for loss_fn in all_losses:
        if loss_fn in keras_loss_cache:
            wrapped_loss_functions.append(keras_loss_cache[loss_fn])
        else:
            wrapped_loss = _wrap_unconditional_loss(
                loss_fn, len(keras_loss_cache)
            )
            keras_loss_cache[loss_fn] = wrapped_loss
            wrapped_loss_functions.append(wrapped_loss)
    wrapped_layer_losses = [
        keras_loss_cache[fn] for fn in layer._callable_losses[:]
    ]

    layer_metrics = tf.__internal__.tracking.wrap(
        {m.name: m for m in layer._metrics}
    )

    # Avoid duplicate creation of shard Variables on loading.
    # `layer.variables` will return the shard Variables rather than the
    # ShardedVariables (b/224541446), but Keras loading will create new
    # ShardedVariables (and thus shard Variables) from Keras metadata if
    # needed. There's no need to also save the shard Variables here, so
    # filter them out.
    variables = _filter_shards(layer.variables)
    trainable_variables = _filter_shards(layer.trainable_variables)
    non_trainable_variables = _filter_shards(layer.non_trainable_variables)
    return dict(
        variables=tf.__internal__.tracking.wrap(variables),
        trainable_variables=tf.__internal__.tracking.wrap(trainable_variables),
        non_trainable_variables=tf.__internal__.tracking.wrap(
            non_trainable_variables
        ),
        layers=tf.__internal__.tracking.wrap(utils.list_all_layers(layer)),
        metrics=tf.__internal__.tracking.wrap(layer.metrics),
        regularization_losses=tf.__internal__.tracking.wrap(
            wrapped_loss_functions
        ),
        layer_regularization_losses=tf.__internal__.tracking.wrap(
            wrapped_layer_losses
        ),
        layer_metrics=layer_metrics,
    )
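

# Hedged sketch of the returned mapping; the keys below are fixed by the
# return statement above, while `layer` and `serialization_cache` are
# illustrative:
#
#     objects = wrap_layer_objects(layer, serialization_cache)
#     sorted(objects)
#     # ['layer_metrics', 'layer_regularization_losses', 'layers', 'metrics',
#     #  'non_trainable_variables', 'regularization_losses',
#     #  'trainable_variables', 'variables']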


def wrap_layer_functions(layer, serialization_cache):
    """Returns a dict of the layer call functions and losses as tf.functions.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      A dictionary containing all Keras tf.functions to serialize. See
      LayerAttributes and ModelAttributes for the list of all attributes.
    """
    # Since Sequential models may be modified in place using model.add() or
    # model.pop(), don't use saved functions.
    if isinstance(layer, keras_load.RevivedLayer) and not isinstance(
        layer, sequential_lib.Sequential
    ):
        return {
            fn_name: getattr(layer.keras_api, fn_name, None)
            for fn_name in serialized_attributes.LayerAttributes.all_functions
        }

    # Reset the losses of the layer and its children. The call function in
    # each child layer is replaced with tf.functions.
    original_fns = _replace_child_layer_functions(layer, serialization_cache)
    original_losses = _reset_layer_losses(layer)

    # Wrap all the layer call and activity regularizer functions.

    # Use LayerCallCollection to ensure that all layer call functions
    # (__call__, call with losses) are traced with the same inputs.
    call_collection = LayerCallCollection(layer)
    call_fn_with_losses = call_collection.add_function(
        _wrap_call_and_conditional_losses(layer),
        f"{layer.name}_layer_call_and_return_conditional_losses",
        # If any of this layer's child layers use the training arg, the
        # traced call functions of this layer will have a training keyword
        # argument. If the original layer does not expect the training arg,
        # then it will have to be removed (by setting
        # `match_layer_training_arg`).
        match_layer_training_arg=True,
    )
    call_fn = call_collection.add_function(
        _extract_outputs_from_fn(layer, call_fn_with_losses),
        f"{layer.name}_layer_call_fn",
        # Since `call_fn` wraps call_fn_with_losses and not the original call
        # function, `match_layer_training_arg` should be set to False.
        match_layer_training_arg=False,
    )

    fns = {
        "call_and_return_conditional_losses": call_fn_with_losses,
        "__call__": call_fn,
    }

    if layer._activity_regularizer is not None:
        fns["activity_regularizer_fn"] = _wrap_activity_regularizer(layer)
        fns[
            "call_and_return_all_conditional_losses"
        ] = call_collection.add_function(
            _append_activity_regularizer_loss(
                layer, call_fn_with_losses, fns["activity_regularizer_fn"]
            ),
            f"{layer.name}_layer_call_and_return_all_conditional_losses",
            match_layer_training_arg=False,
        )
    else:
        fns["activity_regularizer_fn"] = None
        fns["call_and_return_all_conditional_losses"] = call_fn_with_losses

    # Manually trigger traces before restoring the overwritten functions. The
    # functions are traced within the layer call context to ensure that layer
    # functions (e.g. add_loss) behave as though running in graph mode.
    with tracing_scope():
        call_collection.trace_with_input_signature()
        with base_layer_utils.call_context().enter(
            layer, inputs=None, build_graph=True, training=None, saving=True
        ):
            for fn in fns.values():
                if fn is not None and not isinstance(fn, LayerCall):
                    fn.get_concrete_function()

    # Restore the overwritten functions and losses.
    _restore_child_layer_functions(original_fns)
    _restore_layer_losses(original_losses)

    return fns
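

# Hedged sketch of the result: the dict always carries these four entries
# (`activity_regularizer_fn` is None when the layer has no activity
# regularizer; `serialization_cache` must already hold the Keras cache key):
#
#     fns = wrap_layer_functions(layer, serialization_cache)
#     sorted(fns)
#     # ['__call__', 'activity_regularizer_fn',
#     #  'call_and_return_all_conditional_losses',
#     #  'call_and_return_conditional_losses']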


def default_save_signature(layer):
    original_losses = _reset_layer_losses(layer)
    fn = saving_utils.trace_model_call(layer)
    _restore_layer_losses(original_losses)
    return fn
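

# Hedged sketch, assuming a model whose input shape is known so that
# `saving_utils.trace_model_call` can infer a signature (the model below is
# illustrative):
#
#     model = tf.keras.Sequential(
#         [tf.keras.layers.Dense(1, input_shape=(3,))]
#     )
#     signature = default_save_signature(model)  # tf.function for serving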


def _replace_child_layer_functions(layer, serialization_cache):
    """Replaces functions in the children layers with wrapped tf.functions.

    This step allows functions from parent layers to reference the wrapped
    functions from their children layers instead of retracing the ops.

    This function also resets all losses stored in the layer. These are
    stored in the returned dictionary. Use `_restore_child_layer_functions`
    to restore the original attributes.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      Dictionary mapping layer objects -> original functions and losses:
        { Child layer 1: {
            'losses': Original losses,
            'call': Original call function,
            '_activity_regularizer': Original activity regularizer},
          Child layer 2: ...
        }
    """
    original_fns = {}

    def replace_layer_functions(child_layer, serialized_fns):
        """Replaces layer call and activity regularizer with wrapped
        functions."""
        original_fns[child_layer] = {
            "call": child_layer.call,
            "_activity_regularizer": child_layer._activity_regularizer,
        }
        with utils.no_automatic_dependency_tracking_scope(child_layer):
            try:
                child_layer._activity_regularizer = serialized_fns.get(
                    "activity_regularizer_fn"
                )
            except AttributeError:
                # Some layers have an unsettable activity regularizer.
                pass
            child_layer.call = utils.use_wrapped_call(
                child_layer,
                serialized_fns["call_and_return_conditional_losses"],
                child_layer._call_spec,
                default_training_value=False,
            )

    def replace_metric_functions(child_layer, serialized_fns):
        """Replaces metric functions with wrapped functions."""
        original_fns[child_layer] = {
            "__call__": child_layer.__call__,
            "result": child_layer.result,
            "update_state": child_layer.update_state,
        }
        with utils.no_automatic_dependency_tracking_scope(child_layer):
            child_layer.__call__ = serialized_fns["__call__"]
            child_layer.result = serialized_fns["result"]
            child_layer.update_state = serialized_fns["update_state"]

    for child_layer in utils.list_all_layers(layer):
        if isinstance(child_layer, input_layer.InputLayer):
            continue

        if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]:
            serialized_functions = child_layer._trackable_saved_model_saver._get_serialized_attributes(  # noqa: E501
                serialization_cache
            ).functions
        else:
            serialized_functions = serialization_cache[
                constants.KERAS_CACHE_KEY
            ][child_layer].functions
        if not serialized_functions:
            # This indicates either:
            #   - a circular dependency, which means the current layer's
            #     functions should be wrapped first.
            #   - the child layer's inputs are not defined, so its functions
            #     have not been wrapped. In this case, no replacement is
            #     necessary, so move on to the next child.
            continue

        if isinstance(child_layer, metrics.Metric):
            replace_metric_functions(child_layer, serialized_functions)
        else:
            replace_layer_functions(child_layer, serialized_functions)

    return original_fns


def _restore_child_layer_functions(original_fns):
    """Restores attributes replaced with `_replace_child_layer_functions`."""
    for child_layer, fns in original_fns.items():
        with utils.no_automatic_dependency_tracking_scope(child_layer):
            for fn_name, fn in fns.items():
                try:
                    setattr(child_layer, fn_name, fn)
                except AttributeError:
                    # In the case of _activity_regularizer, setting the
                    # attribute may be disallowed.
                    pass


def _reset_layer_losses(parent_layer):
    """Resets losses of layer and its sublayers, and returns original
    losses."""
    losses_dict = {}
    for layer in utils.list_all_layers_and_sublayers(parent_layer):
        losses_dict[layer] = {
            "losses": layer._losses[:],
            "eager_losses": layer._eager_losses[:],
        }
        with utils.no_automatic_dependency_tracking_scope(layer):
            layer._losses = []
            layer._eager_losses = []
    return losses_dict


def _restore_layer_losses(losses_dict):
    for layer in losses_dict:
        with utils.no_automatic_dependency_tracking_scope(layer):
            layer._losses = losses_dict[layer]["losses"]
            layer._eager_losses = losses_dict[layer]["eager_losses"]
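

# These two helpers are used as a save/restore pair so that tracing never
# leaks losses into user-visible state. A minimal sketch of the pattern used
# throughout this file:
#
#     snapshot = _reset_layer_losses(layer)  # layer losses now empty
#     ...  # trace functions that may call layer.add_loss(...)
#     _restore_layer_losses(snapshot)        # original losses come back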


class LayerTracingContext(threading.local):
    def __init__(self):
        super().__init__()
        self.enable_call_tracing = False
        self.trace_queue = []


_thread_local_data = LayerTracingContext()


@tf_contextlib.contextmanager
def tracing_scope():
    """Enables tracing scope."""
    # This enables the LayerCallCollection's tracing mechanism to trace all
    # call functions in the collection.
    previous_value = _thread_local_data.enable_call_tracing
    previous_queue = _thread_local_data.trace_queue
    try:
        _thread_local_data.enable_call_tracing = True
        _thread_local_data.trace_queue = []
        yield
    finally:
        # Run traces from the queue.
        while _thread_local_data.trace_queue:
            fn, args, kwargs, training = _thread_local_data.trace_queue.pop(0)
            if training is not None:
                with backend.deprecated_internal_learning_phase_scope(
                    training
                ):
                    fn.get_concrete_function(*args, **kwargs)
            else:
                fn.get_concrete_function(*args, **kwargs)
        _thread_local_data.trace_queue = previous_queue
        _thread_local_data.enable_call_tracing = previous_value
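

# Hedged usage sketch: inside the scope, traces are only queued; they run when
# the scope exits (`some_tf_fn` and `x` are illustrative):
#
#     with tracing_scope():
#         add_trace_to_queue(some_tf_fn, args=(x,), kwargs={}, training=True)
#         # nothing has been traced yet
#     # on exit, some_tf_fn.get_concrete_function(x) runs under a
#     # training=True learning-phase scope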


def add_trace_to_queue(fn, args, kwargs, training=None):
    if tracing_enabled():
        _thread_local_data.trace_queue.append(
            (fn, args[:], kwargs.copy(), training)
        )


def tracing_enabled():
    """Whether to add extra traces to the queue."""
    return _thread_local_data.enable_call_tracing


class LayerCallCollection:
    """Groups wrapped layer call functions.

    This is used to ensure that all layer call functions are traced with the
    same inputs:
      - call
      - call_and_return_conditional_losses
      - call_and_return_all_conditional_losses
    """

    def __init__(self, layer):
        self.layer = layer

        self.layer_call_method = _get_layer_call_method(layer)
        self._expects_training_arg = utils.layer_uses_training_bool(layer)
        self._call_spec = layer._call_spec

        # Create a new call spec if the layer itself does not accept a
        # training arg, but one of its child layers does. When this layer's
        # call functions are traced, they will be traced with an added
        # `training` keyword argument.
        if not self.layer._expects_training_arg and self._expects_training_arg:
            arg_spec = utils.set_training_arg_spec(
                self._call_spec.full_argspec, False
            )
            self._call_spec = layer_utils.CallFunctionSpec(arg_spec)

        self._layer_inputs = self._get_layer_inputs(layer)
        self._functions = weakref.WeakValueDictionary()

        # Get the input argument name from the args.
        if self._call_spec.arg_names:
            self._input_arg_name = self._call_spec.arg_names[0]
        else:
            # The layer could be defined with only varargs, in which case use
            # a default name.
            self._input_arg_name = "inputs"

    def _get_layer_inputs(self, layer):
        """Inspects the layer object and returns the inferred input signature.

        Args:
          layer: Layer object.

        Returns:
          List of possibly nested TensorSpecs of the layer call function
          inputs in the form of `(args, kwargs)`.
        """
        if (
            isinstance(layer.call, tf.__internal__.function.Function)
            and layer.call.input_signature is not None
        ):
            return layer.call.input_signature, {}
        elif isinstance(layer, training_lib.Model):
            return saving_utils.model_call_inputs(layer)
        elif (
            layer.input_spec is not None
            and layer._use_input_spec_as_call_signature
        ):

            def to_tensor_spec_or_none(x):
                spec = input_spec.to_tensor_spec(x, layer._compute_dtype)
                # If the shape is too general (e.g. multiple dimensions are
                # allowed), return None so that separate functions can be
                # generated for each inferred input signature.
                # TODO(b/134962016): currently partial signatures are not
                # supported.
                if spec.shape == tf.TensorShape(None):
                    return None, None
                return spec

            input_signature = [
                tf.nest.map_structure(to_tensor_spec_or_none, layer.input_spec)
            ]

            return input_signature, {}
        else:
            return None, None

    def add_trace(self, *args, **kwargs):
        """Traces all functions with the same args and kwargs.

        Args:
          *args: Positional args passed to the original function.
          **kwargs: Keyword args passed to the original function.
        """
        args = list(args)
        kwargs = kwargs.copy()

        for fn in self._functions.values():
            # TODO(kathywu): Replace arguments with broader shapes defined in
            # the input signature.
            if self._expects_training_arg:

                def trace_with_training(value, fn=fn):
                    nonlocal args, kwargs
                    args, kwargs = self._call_spec.set_arg_value(
                        "training", value, args, kwargs, inputs_in_args=True
                    )
                    add_trace_to_queue(fn, args, kwargs, value)

                trace_with_training(True)
                trace_with_training(False)
            else:
                add_trace_to_queue(fn, args, kwargs)

    def training_arg_was_passed(self, args, kwargs):
        return self._call_spec.arg_was_passed(
            "training", args, kwargs, inputs_in_args=True
        )

    def get_training_arg_value(self, args, kwargs):
        try:
            return self._call_spec.get_arg_value(
                "training", args, kwargs, inputs_in_args=True
            )
        except KeyError:  # Training is not in args or kwargs.
            return None

    def get_input_arg_value(self, args, kwargs):
        return self._call_spec.get_arg_value(
            self._input_arg_name, args, kwargs, inputs_in_args=True
        )

    def _maybe_wrap_with_training_arg(self, call_fn, match_layer_training_arg):
        """Wraps the call function with an added training argument if
        necessary."""
        if not self.layer._expects_training_arg and self._expects_training_arg:
            # Add a training arg to the wrapper function.
            def wrap_with_training_arg(*args, **kwargs):
                if match_layer_training_arg:
                    # Remove the training value, since the original call_fn
                    # does not expect a training arg. Instead, the training
                    # value will be propagated using the call context created
                    # in LayerCall.
                    args = list(args)
                    kwargs = kwargs.copy()
                    args, kwargs = self._call_spec.set_arg_value(
                        "training",
                        None,
                        args,
                        kwargs,
                        inputs_in_args=True,
                        pop_kwarg_if_none=True,
                    )
                return call_fn(*args, **kwargs)

            return tf.__internal__.decorator.make_decorator(
                target=call_fn,
                decorator_func=wrap_with_training_arg,
                decorator_argspec=self._call_spec.full_argspec,
            )
        return call_fn

    def add_function(self, call_fn, name, match_layer_training_arg):
        """Adds a layer call function to the collection.

        Args:
          call_fn: a python function
          name: Name of the call function
          match_layer_training_arg: If True, removes the `training` arg from
            the function arguments when calling `call_fn`.

        Returns:
          LayerCall (tf.function)
        """
        fn = LayerCall(
            self,
            self._maybe_wrap_with_training_arg(
                call_fn, match_layer_training_arg
            ),
            name,
        )
        self._functions[name] = fn.wrapped_call
        return fn

    def trace_with_input_signature(self):
        """Traces with the layer/model's inferred input signature if
        possible."""
        if self._layer_inputs[0] is None:
            return

        args, kwargs = self._layer_inputs
        if self._expects_training_arg:
            args, kwargs = self._call_spec.set_arg_value(
                "training", False, args, kwargs, inputs_in_args=True
            )
        if None not in tf.nest.flatten([args, kwargs]):
            # Manually add traces for layers that have keyword arguments and
            # have a fully defined input signature.
            self.add_trace(*args, **kwargs)
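

# Hedged sketch mirroring how `wrap_layer_functions` drives this class: every
# function registered via `add_function` is retraced together, so all saved
# call functions share concrete signatures (`layer` is illustrative):
#
#     collection = LayerCallCollection(layer)
#     call_fn = collection.add_function(
#         _wrap_call_and_conditional_losses(layer),
#         f"{layer.name}_layer_call_and_return_conditional_losses",
#         match_layer_training_arg=True,
#     )
#     with tracing_scope():
#         collection.trace_with_input_signature()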


def _filtered_inputs(inputs):
    return list(filter(tf_utils.is_tensor_or_variable, tf.nest.flatten(inputs)))


def layer_call_wrapper(call_collection, method, name):
    """Ensures layer losses are kept the same, and runs the method in a call
    context."""

    # Create a wrapper that deals with losses and the call context.
    def wrapper(*args, **kwargs):
        """Calls the method within a call context."""
        layer = call_collection.layer
        training = None
        inputs = _filtered_inputs([args, kwargs])

        if (args or kwargs) and call_collection.training_arg_was_passed(
            args, kwargs
        ):
            training = call_collection.get_training_arg_value(args, kwargs)

        original_losses = _reset_layer_losses(layer)
        with base_layer_utils.call_context().enter(
            layer,
            inputs=inputs,
            build_graph=False,
            training=training,
            saving=True,
        ):
            with autocast_variable.enable_auto_cast_variables(
                layer._compute_dtype_object
            ):
                ret = method(*args, **kwargs)
        _restore_layer_losses(original_losses)
        return ret

    # Rename to `name`, since tf.function doesn't have a name argument.
    # Without this, all functions returned by this method would be named
    # "call", which would be a nightmare to debug.
    fn = tf.__internal__.decorator.make_decorator(
        target=method, decorator_func=wrapper
    )
    fn.__name__ = name
    return fn
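

# Hedged sketch: the wrapper keeps `method`'s behavior but reports the given
# name, which is what appears in traces and debug output (`collection` and
# the name below are illustrative):
#
#     wrapped = layer_call_wrapper(collection, layer.call, "dense_call_fn")
#     wrapped.__name__  # -> 'dense_call_fn'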


class LayerCall:
    """Function that triggers traces of other functions in the same
    collection."""

    def __init__(self, call_collection, call_fn, name):
        """Initializes a LayerCall object.

        Args:
          call_collection: a LayerCallCollection, which contains the other
            layer call functions (e.g. call_with_conditional_losses, call).
            These functions should be traced with the same arguments.
          call_fn: A call function.
          name: Name of the call function.
        """
        self.call_collection = call_collection
        self.wrapped_call = tf.function(
            layer_call_wrapper(call_collection, call_fn, name)
        )

    def _maybe_trace(self, args, kwargs):
        # Trigger traces of other call functions + extra training-arg traces.
        if tracing_enabled():
            self.call_collection.add_trace(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        self._maybe_trace(args, kwargs)
        return self.wrapped_call(*args, **kwargs)

    def get_concrete_function(self, *args, **kwargs):
        self._maybe_trace(args, kwargs)
        return self.wrapped_call.get_concrete_function(*args, **kwargs)
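

# Hedged sketch: calling one LayerCall inside a tracing scope queues traces
# for every sibling function in its collection, not just itself:
#
#     with tracing_scope():
#         outputs = layer_call_fn(inputs)  # runs, and queues sibling traces
#     # on exit, __call__, call_and_return_conditional_losses, etc. are all
#     # traced with the same `inputs`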


def _wrap_call_and_conditional_losses(layer):
    """Wraps the call function into one that returns `(outputs, losses)`.

    The losses returned are conditional on the inputs passed to the call
    function. Unconditional losses (e.g. weight regularization) are wrapped
    separately.

    Args:
      layer: a Keras layer object

    Returns:
      python call function that returns outputs and conditional losses --
      excludes activity regularizer
    """
    # Create a function that generates both outputs and losses.
    layer_call = _get_layer_call_method(layer)

    def call_and_return_conditional_losses(*args, **kwargs):
        """Returns the layer's (call_output, conditional losses) tuple."""
        call_output = layer_call(*args, **kwargs)
        if version_utils.is_v1_layer_or_model(layer):
            conditional_losses = layer.get_losses_for(
                _filtered_inputs([args, kwargs])
            )
        else:
            conditional_losses = [
                l for l in layer.losses if not hasattr(l, "_unconditional_loss")
            ]
        return call_output, conditional_losses

    return _create_call_fn_decorator(layer, call_and_return_conditional_losses)
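

# Hedged illustration of the conditional/unconditional split: an
# input-dependent `add_loss` in `call` is conditional and surfaces here,
# while e.g. a `kernel_regularizer` loss is unconditional and is wrapped by
# `_wrap_unconditional_loss` instead (`MyLayer` is hypothetical):
#
#     class MyLayer(tf.keras.layers.Layer):
#         def call(self, inputs):
#             self.add_loss(tf.reduce_sum(inputs))  # depends on inputs
#             return inputs
#
#     wrapped = _wrap_call_and_conditional_losses(MyLayer())
#     outputs, losses = wrapped(tf.ones((2, 2)))  # losses: one entry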


def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
    """Returns a function that returns only the call function outputs."""
    if isinstance(layer, keras_load.RevivedLayer):
        return layer.keras_api.__call__

    def call(inputs, *args, **kwargs):
        return call_and_return_conditional_losses(inputs, *args, **kwargs)[0]

    return _create_call_fn_decorator(layer, call)


def _append_activity_regularizer_loss(
    layer, call_fn_with_losses, activity_regularizer_fn
):
    """Appends the activity regularizer loss to the losses returned by the
    wrapped fn."""

    def fn(inputs, *args, **kwargs):
        outputs, losses = call_fn_with_losses(inputs, *args, **kwargs)
        losses.append(activity_regularizer_fn(outputs))
        return outputs, losses

    return _create_call_fn_decorator(layer, fn)


def _create_call_fn_decorator(layer, wrapped_call):
    call_fn = _get_layer_call_method(layer)
    fn, arg_spec = utils.maybe_add_training_arg(
        layer._call_spec,
        wrapped_call,
        layer._expects_training_arg,
        default_training_value=False,
    )
    return tf.__internal__.decorator.make_decorator(
        target=call_fn, decorator_func=fn, decorator_argspec=arg_spec
    )


def _wrap_unconditional_loss(loss_fn, index):
    """Wraps callable/unconditional loss, returning a serializable function."""
    # Extract the original loss function from the partial function.
    fn = loss_fn.args[0] if isinstance(loss_fn, functools.partial) else loss_fn
    if isinstance(fn, tf.__internal__.function.Function):
        return fn
    else:
        return tf.__internal__.function.Function(
            fn, f"loss_fn_{index}", input_signature=[]
        )
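

# Hedged sketch: unconditional losses take no inputs, hence the empty
# input_signature. A weight regularizer is the typical case (this relies on
# the private `_callable_losses` attribute, as the callers above do):
#
#     dense = tf.keras.layers.Dense(
#         2, kernel_regularizer=tf.keras.regularizers.l2(0.01)
#     )
#     dense.build((None, 4))
#     wrapped = _wrap_unconditional_loss(dense._callable_losses[0], 0)
#     wrapped()  # scalar L2 penalty on the kernel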


def _wrap_activity_regularizer(layer):
    """Wraps the activity regularizer."""
    if isinstance(
        layer._activity_regularizer, tf.__internal__.function.Function
    ):
        return layer._activity_regularizer
    return tf.__internal__.function.Function(
        layer._activity_regularizer,
        f"{layer.name}_activity_regularizer",
        input_signature=[
            tf.TensorSpec(None, layer._compute_dtype or backend.floatx())
        ],
    )
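

# Hedged sketch: the wrapped regularizer accepts a tensor of any shape (the
# TensorSpec has shape None) and returns a scalar penalty (`layer` below is
# illustrative):
#
#     layer = tf.keras.layers.Dense(
#         2, activity_regularizer=tf.keras.regularizers.l1(0.1)
#     )
#     reg_fn = _wrap_activity_regularizer(layer)
#     reg_fn(tf.ones((3, 2)))  # scalar L1 penalty on the activations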


def _get_layer_call_method(layer):
    if isinstance(layer.call, tf.__internal__.function.Function):
        return layer.call.python_function
    return layer.call