Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/base_layer.py: 22%
1285 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
17"""Contains the base Layer class, from which all layers inherit."""
19import collections
20import contextlib
21import functools
22import itertools
23import textwrap
24import threading
25import warnings
26import weakref
28import numpy as np
29import tensorflow.compat.v2 as tf
31from keras.src import backend
32from keras.src import constraints
33from keras.src import initializers
34from keras.src import regularizers
35from keras.src.dtensor import lazy_variable
36from keras.src.engine import base_layer_utils
37from keras.src.engine import input_spec
38from keras.src.engine import keras_tensor
39from keras.src.engine import node as node_module
40from keras.src.mixed_precision import autocast_variable
41from keras.src.mixed_precision import policy
42from keras.src.saving import serialization_lib
43from keras.src.saving.legacy.saved_model import layer_serialization
44from keras.src.utils import generic_utils
45from keras.src.utils import layer_utils
46from keras.src.utils import object_identity
47from keras.src.utils import tf_inspect
48from keras.src.utils import tf_utils
49from keras.src.utils import traceback_utils
50from keras.src.utils import version_utils
52# A module that only depends on `keras.layers` import these from here.
53from keras.src.utils.generic_utils import to_snake_case # noqa: F401
54from keras.src.utils.tf_utils import is_tensor_or_tensor_list # noqa: F401
56# isort: off
57from google.protobuf import json_format
58from tensorflow.python.platform import tf_logging
59from tensorflow.python.util.tf_export import (
60 get_canonical_name_for_symbol,
61)
62from tensorflow.python.util.tf_export import keras_export
63from tensorflow.tools.docs import doc_controls
66metrics_mod = generic_utils.LazyLoader(
67 "metrics_mod", globals(), "keras.src.metrics"
68)
71# Prefix that is added to the TF op layer names.
72_TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_"
74# TODO(mdan): Should we have a single generic type for types that can be passed
75# to tf.cast?
76_AUTOCAST_TYPES = (tf.Tensor, tf.SparseTensor, tf.RaggedTensor)
78keras_layers_gauge = tf.__internal__.monitoring.BoolGauge(
79 "/tensorflow/api/keras/layers", "keras layers usage", "method"
80)
81keras_models_gauge = tf.__internal__.monitoring.BoolGauge(
82 "/tensorflow/api/keras/models", "keras model usage", "method"
83)
84keras_api_gauge = tf.__internal__.monitoring.BoolGauge(
85 "/tensorflow/api/keras", "keras api usage", "method"
86)
87keras_premade_model_gauge = tf.__internal__.monitoring.BoolGauge(
88 "/tensorflow/api/keras/premade_models", "premade keras model usage", "type"
89)
91_is_name_scope_on_model_declaration_enabled = False
93_name_scope_unnester_stack = threading.local()
96@contextlib.contextmanager
97def _name_scope_unnester(full_name_scope):
98 """Helper to get relative name scope from fully-speced nested name scopes.
100 Args:
101 full_name_scope: full(absolute) name scope path.
103 Yields:
104 Relative name scope path from the parent `_name_scope_unnester` context
105 manager.
107 Example:
108 ```
109 with _name_scope_unnester('a') as name1: # name1 == 'a'
110 with _name_scope_unnester('a/b') as name2: # name2 == 'b'
111 with _name_scope_unnester('a/b/c') as name3: # name3 == 'c'
112 pass
113 ```
114 """
115 if not getattr(_name_scope_unnester_stack, "value", None):
116 _name_scope_unnester_stack.value = [""]
118 _name_scope_unnester_stack.value.append(full_name_scope)
120 try:
121 full_name_scope = _name_scope_unnester_stack.value[-1]
122 outer_name_scope = _name_scope_unnester_stack.value[-2]
123 relative_name_scope = full_name_scope.lstrip(outer_name_scope)
124 relative_name_scope = relative_name_scope.lstrip("/")
125 yield relative_name_scope
126 finally:
127 _name_scope_unnester_stack.value.pop()
130@keras_export("keras.layers.Layer")
131class Layer(tf.Module, version_utils.LayerVersionSelector):
132 """This is the class from which all layers inherit.
134 A layer is a callable object that takes as input one or more tensors and
135 that outputs one or more tensors. It involves *computation*, defined
136 in the `call()` method, and a *state* (weight variables). State can be
137 created in various places, at the convenience of the subclass implementer:
139 * in `__init__()`;
140 * in the optional `build()` method, which is invoked by the first
141 `__call__()` to the layer, and supplies the shape(s) of the input(s),
142 which may not have been known at initialization time;
143 * in the first invocation of `call()`, with some caveats discussed
144 below.
146 Layers are recursively composable: If you assign a Layer instance as an
147 attribute of another Layer, the outer layer will start tracking the weights
148 created by the inner layer. Nested layers should be instantiated in the
149 `__init__()` method.
151 Users will just instantiate a layer and then treat it as a callable.
153 Args:
154 trainable: Boolean, whether the layer's variables should be trainable.
155 name: String name of the layer.
156 dtype: The dtype of the layer's computations and weights. Can also be a
157 `tf.keras.mixed_precision.Policy`, which allows the computation and
158 weight dtype to differ. Default of `None` means to use
159 `tf.keras.mixed_precision.global_policy()`, which is a float32 policy
160 unless set to different value.
161 dynamic: Set this to `True` if your layer should only be run eagerly, and
162 should not be used to generate a static computation graph.
163 This would be the case for a Tree-RNN or a recursive network,
164 for example, or generally for any layer that manipulates tensors
165 using Python control flow. If `False`, we assume that the layer can
166 safely be used to generate a static computation graph.
168 Attributes:
169 name: The name of the layer (string).
170 dtype: The dtype of the layer's weights.
171 variable_dtype: Alias of `dtype`.
172 compute_dtype: The dtype of the layer's computations. Layers automatically
173 cast inputs to this dtype which causes the computations and output to
174 also be in this dtype. When mixed precision is used with a
175 `tf.keras.mixed_precision.Policy`, this will be different than
176 `variable_dtype`.
177 dtype_policy: The layer's dtype policy. See the
178 `tf.keras.mixed_precision.Policy` documentation for details.
179 trainable_weights: List of variables to be included in backprop.
180 non_trainable_weights: List of variables that should not be
181 included in backprop.
182 weights: The concatenation of the lists trainable_weights and
183 non_trainable_weights (in this order).
184 trainable: Whether the layer should be trained (boolean), i.e. whether
185 its potentially-trainable weights should be returned as part of
186 `layer.trainable_weights`.
187 input_spec: Optional (list of) `InputSpec` object(s) specifying the
188 constraints on inputs that can be accepted by the layer.
190 We recommend that descendants of `Layer` implement the following methods:
192 * `__init__()`: Defines custom layer attributes, and creates layer weights
193 that do not depend on input shapes, using `add_weight()`, or other state.
194 * `build(self, input_shape)`: This method can be used to create weights that
195 depend on the shape(s) of the input(s), using `add_weight()`, or other
196 state. `__call__()` will automatically build the layer (if it has not been
197 built yet) by calling `build()`.
198 * `call(self, inputs, *args, **kwargs)`: Called in `__call__` after making
199 sure `build()` has been called. `call()` performs the logic of applying
200 the layer to the `inputs`. The first invocation may additionally create
201 state that could not be conveniently created in `build()`; see its
202 docstring for details.
203 Two reserved keyword arguments you can optionally use in `call()` are:
204 - `training` (boolean, whether the call is in inference mode or training
205 mode). See more details in [the layer/model subclassing guide](
206 https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_training_argument_in_the_call_method)
207 - `mask` (boolean tensor encoding masked timesteps in the input, used
208 in RNN layers). See more details in
209 [the layer/model subclassing guide](
210 https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_mask_argument_in_the_call_method)
211 A typical signature for this method is `call(self, inputs)`, and user
212 could optionally add `training` and `mask` if the layer need them. `*args`
213 and `**kwargs` is only useful for future extension when more input
214 parameters are planned to be added.
215 * `get_config(self)`: Returns a dictionary containing the configuration used
216 to initialize this layer. If the keys differ from the arguments
217 in `__init__`, then override `from_config(self)` as well.
218 This method is used when saving
219 the layer or a model that contains this layer.
221 Examples:
223 Here's a basic example: a layer with two variables, `w` and `b`,
224 that returns `y = w . x + b`.
225 It shows how to implement `build()` and `call()`.
226 Variables set as attributes of a layer are tracked as weights
227 of the layers (in `layer.weights`).
229 ```python
230 class SimpleDense(Layer):
232 def __init__(self, units=32):
233 super(SimpleDense, self).__init__()
234 self.units = units
236 def build(self, input_shape): # Create the state of the layer (weights)
237 w_init = tf.random_normal_initializer()
238 self.w = tf.Variable(
239 initial_value=w_init(shape=(input_shape[-1], self.units),
240 dtype='float32'),
241 trainable=True)
242 b_init = tf.zeros_initializer()
243 self.b = tf.Variable(
244 initial_value=b_init(shape=(self.units,), dtype='float32'),
245 trainable=True)
247 def call(self, inputs): # Defines the computation from inputs to outputs
248 return tf.matmul(inputs, self.w) + self.b
250 # Instantiates the layer.
251 linear_layer = SimpleDense(4)
253 # This will also call `build(input_shape)` and create the weights.
254 y = linear_layer(tf.ones((2, 2)))
255 assert len(linear_layer.weights) == 2
257 # These weights are trainable, so they're listed in `trainable_weights`:
258 assert len(linear_layer.trainable_weights) == 2
259 ```
261 Note that the method `add_weight()` offers a shortcut to create weights:
263 ```python
264 class SimpleDense(Layer):
266 def __init__(self, units=32):
267 super(SimpleDense, self).__init__()
268 self.units = units
270 def build(self, input_shape):
271 self.w = self.add_weight(shape=(input_shape[-1], self.units),
272 initializer='random_normal',
273 trainable=True)
274 self.b = self.add_weight(shape=(self.units,),
275 initializer='random_normal',
276 trainable=True)
278 def call(self, inputs):
279 return tf.matmul(inputs, self.w) + self.b
280 ```
282 Besides trainable weights, updated via backpropagation during training,
283 layers can also have non-trainable weights. These weights are meant to
284 be updated manually during `call()`. Here's a example layer that computes
285 the running sum of its inputs:
287 ```python
288 class ComputeSum(Layer):
290 def __init__(self, input_dim):
291 super(ComputeSum, self).__init__()
292 # Create a non-trainable weight.
293 self.total = tf.Variable(initial_value=tf.zeros((input_dim,)),
294 trainable=False)
296 def call(self, inputs):
297 self.total.assign_add(tf.reduce_sum(inputs, axis=0))
298 return self.total
300 my_sum = ComputeSum(2)
301 x = tf.ones((2, 2))
303 y = my_sum(x)
304 print(y.numpy()) # [2. 2.]
306 y = my_sum(x)
307 print(y.numpy()) # [4. 4.]
309 assert my_sum.weights == [my_sum.total]
310 assert my_sum.non_trainable_weights == [my_sum.total]
311 assert my_sum.trainable_weights == []
312 ```
314 For more information about creating layers, see the guide
315 [Making new Layers and Models via subclassing](
316 https://www.tensorflow.org/guide/keras/custom_layers_and_models)
317 """
319 @tf.__internal__.tracking.no_automatic_dependency_tracking
320 def __init__(
321 self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs
322 ):
323 self._instrument_layer_creation()
325 # These properties should be set by the user via keyword arguments.
326 # note that 'dtype', 'input_shape' and 'batch_input_shape'
327 # are only applicable to input layers: do not pass these keywords
328 # to non-input layers.
329 allowed_kwargs = {
330 "input_dim",
331 "input_shape",
332 "batch_input_shape",
333 "batch_size",
334 "weights",
335 "activity_regularizer",
336 "autocast",
337 "implementation",
338 }
339 # Validate optional keyword arguments.
340 generic_utils.validate_kwargs(kwargs, allowed_kwargs)
342 # Mutable properties
343 # Indicates whether the layer's weights are updated during training
344 # and whether the layer's updates are run during training.
345 if not (
346 isinstance(trainable, bool)
347 or (
348 isinstance(trainable, (tf.Tensor, tf.Variable))
349 and trainable.dtype is tf.bool
350 )
351 ):
352 raise TypeError(
353 "Expected `trainable` argument to be a boolean, "
354 f"but got: {trainable}"
355 )
356 self._trainable = trainable
357 # A stateful layer is a layer whose updates are run during inference
358 # too, for instance stateful RNNs.
359 self._stateful = False
360 # Indicates whether `build` needs to be called upon layer call, to
361 # create the layer's weights. (Note that the first call() may also
362 # create weights, independent of build().)
363 self.built = False
364 # Provides information about which inputs are compatible with the layer.
365 self._input_spec = None
367 # SavedModel-related attributes.
368 # Record the build input shape for loading purposes.
369 # TODO(kathywu): Move this to Layer._set_save_spec once cl/290121460 is
370 # submitted.
371 self._build_input_shape = None
372 self._saved_model_inputs_spec = None
373 self._saved_model_arg_spec = None
375 # `Layer.compute_mask` will be called at the end of `Layer.__call__` if
376 # `Layer.compute_mask` is overridden, or if the `Layer` subclass sets
377 # `self.supports_masking=True`.
378 self._supports_masking = not generic_utils.is_default(self.compute_mask)
380 self._init_set_name(name)
381 self._activity_regularizer = regularizers.get(
382 kwargs.pop("activity_regularizer", None)
383 )
384 self._maybe_create_attribute("_trainable_weights", [])
385 self._maybe_create_attribute("_non_trainable_weights", [])
386 self._updates = []
387 # Object to store all thread local layer properties.
388 self._thread_local = threading.local()
389 # A list of zero-argument lambdas which return Tensors, used for
390 # variable regularizers.
391 self._callable_losses = []
392 # A list of symbolic Tensors containing activity regularizers and losses
393 # manually added through `add_loss` in graph-building mode.
394 self._losses = []
395 # A list of metric instances corresponding to the symbolic metric
396 # tensors added using the `add_metric` API.
397 self._metrics = []
398 # Ensures the same metric is not added multiple times in
399 # `MirroredStrategy`.
400 self._metrics_lock = threading.Lock()
402 # Note that models also have a dtype policy, as they are layers. For
403 # functional models, the policy is only used in Model.compile, which
404 # wraps the optimizer with a LossScaleOptimizer if the policy name is
405 # "mixed_float16". Subclassed models additionally use the policy's
406 # compute and variable dtypes, as like any ordinary layer.
407 self._set_dtype_policy(dtype)
408 # Boolean indicating whether the layer automatically casts its inputs to
409 # the layer's compute_dtype.
410 self._autocast = kwargs.get(
411 "autocast", base_layer_utils.v2_dtype_behavior_enabled()
412 )
414 # Tracks `TrackableDataStructure`s, `Module`s, and `Layer`s.
415 # Ordered by when the object was assigned as an attr.
416 # Entries are unique.
417 self._maybe_create_attribute("_self_tracked_trackables", [])
419 # These lists will be filled via successive calls
420 # to self._add_inbound_node().
421 # Used in symbolic mode only, only in conjunction with graph-networks
422 self._inbound_nodes_value = []
423 self._outbound_nodes_value = []
425 self._init_call_fn_args()
427 # Whether the `call` method can be used to build a TF graph without
428 # issues. This attribute has no effect if the model is created using
429 # the Functional API. Instead, `model.dynamic` is determined based on
430 # the internal layers.
431 if not isinstance(dynamic, bool):
432 raise TypeError(
433 "Expected `dynamic` argument to be a boolean, "
434 f"but got: {dynamic}"
435 )
436 self._dynamic = dynamic
438 # Manage input shape information if passed.
439 if "input_dim" in kwargs and "input_shape" not in kwargs:
440 # Backwards compatibility: alias 'input_dim' to 'input_shape'.
441 kwargs["input_shape"] = (kwargs["input_dim"],)
442 if "input_shape" in kwargs or "batch_input_shape" in kwargs:
443 # In this case we will later create an input layer
444 # to insert before the current layer
445 if "batch_input_shape" in kwargs:
446 batch_input_shape = tuple(kwargs["batch_input_shape"])
447 elif "input_shape" in kwargs:
448 if "batch_size" in kwargs:
449 batch_size = kwargs["batch_size"]
450 else:
451 batch_size = None
452 batch_input_shape = (batch_size,) + tuple(kwargs["input_shape"])
453 self._batch_input_shape = batch_input_shape
455 # Manage initial weight values if passed.
456 self._initial_weights = kwargs.get("weights", None)
458 # Whether the layer will track any layers that is set as attribute on
459 # itself as sub-layers, the weights from the sub-layers will be included
460 # in the parent layer's variables() as well. Defaults to `True`, which
461 # means auto tracking is turned on. Certain subclass might want to turn
462 # it off, like Sequential model.
463 self._auto_track_sub_layers = True
465 # For backwards compat reasons, most built-in layers do not guarantee
466 # That they will 100% preserve the structure of input args when saving
467 # / loading configs. E.g. they may un-nest an arg that is
468 # a list with one element.
469 self._preserve_input_structure_in_config = False
471 # Save outer name scope at layer declaration so that it is preserved at
472 # the actual layer construction.
473 self._name_scope_on_declaration = tf.get_current_name_scope()
475 # Save the temp regularization losses created in the DTensor use case.
476 # When DTensor is enable, we will first create LazyInitVariable and then
477 # DVariable with proper layout afterward. For the weights regularization
478 # loss, we have to create against the DVariable as well.
479 self._captured_weight_regularizer = []
481 @tf.__internal__.tracking.no_automatic_dependency_tracking
482 @generic_utils.default
483 def build(self, input_shape):
484 """Creates the variables of the layer (for subclass implementers).
486 This is a method that implementers of subclasses of `Layer` or `Model`
487 can override if they need a state-creation step in-between
488 layer instantiation and layer call. It is invoked automatically before
489 the first execution of `call()`.
491 This is typically used to create the weights of `Layer` subclasses
492 (at the discretion of the subclass implementer).
494 Args:
495 input_shape: Instance of `TensorShape`, or list of instances of
496 `TensorShape` if the layer expects a list of inputs
497 (one instance per input).
498 """
499 self._build_input_shape = input_shape
500 self.built = True
502 @doc_controls.for_subclass_implementers
503 def call(self, inputs, *args, **kwargs):
504 """This is where the layer's logic lives.
506 The `call()` method may not create state (except in its first
507 invocation, wrapping the creation of variables or other resources in
508 `tf.init_scope()`). It is recommended to create state, including
509 `tf.Variable` instances and nested `Layer` instances,
510 in `__init__()`, or in the `build()` method that is
511 called automatically before `call()` executes for the first time.
513 Args:
514 inputs: Input tensor, or dict/list/tuple of input tensors.
515 The first positional `inputs` argument is subject to special rules:
516 - `inputs` must be explicitly passed. A layer cannot have zero
517 arguments, and `inputs` cannot be provided via the default value
518 of a keyword argument.
519 - NumPy array or Python scalar values in `inputs` get cast as
520 tensors.
521 - Keras mask metadata is only collected from `inputs`.
522 - Layers are built (`build(input_shape)` method)
523 using shape info from `inputs` only.
524 - `input_spec` compatibility is only checked against `inputs`.
525 - Mixed precision input casting is only applied to `inputs`.
526 If a layer has tensor arguments in `*args` or `**kwargs`, their
527 casting behavior in mixed precision should be handled manually.
528 - The SavedModel input specification is generated using `inputs`
529 only.
530 - Integration with various ecosystem packages like TFMOT, TFLite,
531 TF.js, etc is only supported for `inputs` and not for tensors in
532 positional and keyword arguments.
533 *args: Additional positional arguments. May contain tensors, although
534 this is not recommended, for the reasons above.
535 **kwargs: Additional keyword arguments. May contain tensors, although
536 this is not recommended, for the reasons above.
537 The following optional keyword arguments are reserved:
538 - `training`: Boolean scalar tensor of Python boolean indicating
539 whether the `call` is meant for training or inference.
540 - `mask`: Boolean input mask. If the layer's `call()` method takes a
541 `mask` argument, its default value will be set to the mask
542 generated for `inputs` by the previous layer (if `input` did come
543 from a layer that generated a corresponding mask, i.e. if it came
544 from a Keras layer with masking support).
546 Returns:
547 A tensor or list/tuple of tensors.
548 """
549 return inputs
551 @doc_controls.for_subclass_implementers
552 def add_weight(
553 self,
554 name=None,
555 shape=None,
556 dtype=None,
557 initializer=None,
558 regularizer=None,
559 trainable=None,
560 constraint=None,
561 use_resource=None,
562 synchronization=tf.VariableSynchronization.AUTO,
563 aggregation=tf.VariableAggregation.NONE,
564 **kwargs,
565 ):
566 """Adds a new variable to the layer.
568 Args:
569 name: Variable name.
570 shape: Variable shape. Defaults to scalar if unspecified.
571 dtype: The type of the variable. Defaults to `self.dtype`.
572 initializer: Initializer instance (callable).
573 regularizer: Regularizer instance (callable).
574 trainable: Boolean, whether the variable should be part of the layer's
575 "trainable_variables" (e.g. variables, biases)
576 or "non_trainable_variables" (e.g. BatchNorm mean and variance).
577 Note that `trainable` cannot be `True` if `synchronization`
578 is set to `ON_READ`.
579 constraint: Constraint instance (callable).
580 use_resource: Whether to use a `ResourceVariable` or not.
581 See [this guide](
582 https://www.tensorflow.org/guide/migrate/tf1_vs_tf2#resourcevariables_instead_of_referencevariables)
583 for more information.
584 synchronization: Indicates when a distributed a variable will be
585 aggregated. Accepted values are constants defined in the class
586 `tf.VariableSynchronization`. By default the synchronization is set
587 to `AUTO` and the current `DistributionStrategy` chooses when to
588 synchronize. If `synchronization` is set to `ON_READ`, `trainable`
589 must not be set to `True`.
590 aggregation: Indicates how a distributed variable will be aggregated.
591 Accepted values are constants defined in the class
592 `tf.VariableAggregation`.
593 **kwargs: Additional keyword arguments. Accepted values are `getter`,
594 `collections`, `experimental_autocast` and `caching_device`.
596 Returns:
597 The variable created.
599 Raises:
600 ValueError: When giving unsupported dtype and no initializer or when
601 trainable has been set to True with synchronization set as
602 `ON_READ`.
603 """
604 if shape is None:
605 shape = ()
606 kwargs.pop("partitioner", None) # Ignored.
607 # Validate optional keyword arguments.
608 for kwarg in kwargs:
609 if kwarg not in [
610 "collections",
611 "experimental_autocast",
612 "caching_device",
613 "getter",
614 "layout",
615 "experimental_enable_variable_lifting",
616 ]:
617 raise TypeError("Unknown keyword argument:", kwarg)
618 collections_arg = kwargs.pop("collections", None)
619 # 'experimental_autocast' can be set to False by the caller to indicate
620 # an AutoCastVariable should never be created.
621 autocast = kwargs.pop("experimental_autocast", True)
622 # See the docstring for tf.Variable about the details for
623 # caching_device.
624 caching_device = kwargs.pop("caching_device", None)
626 layout = kwargs.pop("layout", None)
627 # Specially handling of auto layout fetch, based on the variable name
628 # and attribute name. For built-in keras layers, usually the variable
629 # name, eg 'kernel', will match with a 'kernel_layout' attribute name on
630 # the instance. We will try to do this auto fetch if layout is not
631 # explicitly specified. This is mainly a quick workaround for not
632 # applying too many interface change to built-in layers, until DTensor
633 # is a public API. Also see dtensor.utils.allow_initializer_layout for
634 # more details.
635 # TODO(scottzhu): Remove this once dtensor is public to end user.
636 if not layout and name:
637 layout = getattr(self, name + "_layout", None)
639 if dtype is None:
640 dtype = self.dtype or backend.floatx()
641 dtype = tf.as_dtype(dtype)
642 if self._dtype_policy.variable_dtype is None:
643 # The policy is "_infer", so we infer the policy from the variable
644 # dtype.
645 self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
646 initializer = initializers.get(initializer)
647 regularizer = regularizers.get(regularizer)
648 constraint = constraints.get(constraint)
650 if synchronization == tf.VariableSynchronization.ON_READ:
651 if trainable:
652 raise ValueError(
653 "Synchronization value can be set to "
654 "VariableSynchronization.ON_READ only for non-trainable "
655 "variables. You have specified trainable=True and "
656 "synchronization=VariableSynchronization.ON_READ."
657 )
658 else:
659 # Set trainable to be false when variable is to be synced on
660 # read.
661 trainable = False
662 elif trainable is None:
663 trainable = True
665 # Initialize variable when no initializer provided
666 if initializer is None:
667 # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
668 if dtype.is_floating:
669 initializer = initializers.get("glorot_uniform")
670 # If dtype is DT_INT/DT_UINT, provide a default value `zero`
671 # If dtype is DT_BOOL, provide a default value `FALSE`
672 elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
673 initializer = initializers.get("zeros")
674 # NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX
675 # here?
676 elif "getter" not in kwargs:
677 # When `getter` is specified, it's possibly fine for
678 # `initializer` to be None since it's up to the custom `getter`
679 # to raise error in case it indeed needs `initializer`.
680 raise ValueError(
681 f"An initializer for variable {name} of type "
682 f"{dtype.base_dtype} is required for layer "
683 f"{self.name}. Received: {initializer}."
684 )
686 getter = kwargs.pop("getter", base_layer_utils.make_variable)
687 if (
688 autocast
689 and self._dtype_policy.compute_dtype
690 != self._dtype_policy.variable_dtype
691 and dtype.is_floating
692 ):
693 old_getter = getter
695 # Wrap variable constructor to return an AutoCastVariable.
696 def getter(*args, **kwargs):
697 variable = old_getter(*args, **kwargs)
698 return autocast_variable.create_autocast_variable(variable)
700 # Also the caching_device does not work with the mixed precision
701 # API, disable it if it is specified.
702 # TODO(b/142020079): Re-enable it once the bug is fixed.
703 if caching_device is not None:
704 tf_logging.warning(
705 "`caching_device` does not work with mixed precision API. "
706 "Ignoring user specified `caching_device`."
707 )
708 caching_device = None
709 if layout:
710 getter = functools.partial(getter, layout=layout)
712 variable = self._add_variable_with_custom_getter(
713 name=name,
714 shape=shape,
715 # TODO(allenl): a `make_variable` equivalent should be added as a
716 # `Trackable` method.
717 getter=getter,
718 # Manage errors in Layer rather than Trackable.
719 overwrite=True,
720 initializer=initializer,
721 dtype=dtype,
722 constraint=constraint,
723 trainable=trainable,
724 use_resource=use_resource,
725 collections=collections_arg,
726 synchronization=synchronization,
727 aggregation=aggregation,
728 caching_device=caching_device,
729 )
730 if regularizer is not None:
731 # TODO(fchollet): in the future, this should be handled at the
732 # level of variable creation, and weight regularization losses
733 # should be variable attributes.
734 name_in_scope = variable.name[: variable.name.find(":")]
735 self._handle_weight_regularization(
736 name_in_scope, variable, regularizer
737 )
738 if base_layer_utils.is_split_variable(variable):
739 for v in variable:
740 backend.track_variable(v)
741 if trainable:
742 self._trainable_weights.append(v)
743 else:
744 self._non_trainable_weights.append(v)
745 else:
746 backend.track_variable(variable)
747 if trainable:
748 self._trainable_weights.append(variable)
749 else:
750 self._non_trainable_weights.append(variable)
751 return variable
753 def __new__(cls, *args, **kwargs):
754 # Generate a config to be returned by default by `get_config()`.
755 arg_names = tf_inspect.getfullargspec(cls.__init__).args
756 kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
757 instance = super(Layer, cls).__new__(cls, *args, **kwargs)
758 # For safety, we only rely on auto-configs for a small set of
759 # serializable types.
760 supported_types = (str, int, float, bool, type(None))
761 try:
762 flat_arg_values = tf.nest.flatten(kwargs)
763 auto_get_config = True
764 for value in flat_arg_values:
765 if not isinstance(value, supported_types):
766 auto_get_config = False
767 break
768 except TypeError:
769 auto_get_config = False
770 try:
771 instance._auto_get_config = auto_get_config
772 if auto_get_config:
773 instance._auto_config = serialization_lib.Config(**kwargs)
774 except RecursionError:
775 # Setting an instance attribute in __new__ has the potential
776 # to trigger an infinite recursion if a subclass overrides
777 # setattr in an unsafe way.
778 pass
779 return instance
781 @generic_utils.default
782 def get_config(self):
783 """Returns the config of the layer.
785 A layer config is a Python dictionary (serializable)
786 containing the configuration of a layer.
787 The same layer can be reinstantiated later
788 (without its trained weights) from this configuration.
790 The config of a layer does not include connectivity
791 information, nor the layer class name. These are handled
792 by `Network` (one layer of abstraction above).
794 Note that `get_config()` does not guarantee to return a fresh copy of
795 dict every time it is called. The callers should make a copy of the
796 returned dict if they want to modify it.
798 Returns:
799 Python dictionary.
800 """
801 config = {
802 "name": self.name,
803 "trainable": self.trainable,
804 }
805 config["dtype"] = policy.serialize(self._dtype_policy)
806 if hasattr(self, "_batch_input_shape"):
807 config["batch_input_shape"] = self._batch_input_shape
809 if not generic_utils.is_default(self.get_config):
810 # In this case the subclass implements get_config()
811 return config
813 # In this case the subclass doesn't implement get_config():
814 # Let's see if we can autogenerate it.
815 if getattr(self, "_auto_get_config", False):
816 xtra_args = set(config.keys())
817 config.update(self._auto_config.config)
818 # Remove args non explicitly supported
819 argspec = tf_inspect.getfullargspec(self.__init__)
820 if argspec.varkw != "kwargs":
821 for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
822 config.pop(key, None)
823 return config
824 else:
825 raise NotImplementedError(
826 textwrap.dedent(
827 f"""
828 Layer {self.__class__.__name__} was created by passing
829 non-serializable argument values in `__init__()`,
830 and therefore the layer must override `get_config()` in
831 order to be serializable. Please implement `get_config()`.
833 Example:
835 class CustomLayer(keras.layers.Layer):
836 def __init__(self, arg1, arg2, **kwargs):
837 super().__init__(**kwargs)
838 self.arg1 = arg1
839 self.arg2 = arg2
841 def get_config(self):
842 config = super().get_config()
843 config.update({{
844 "arg1": self.arg1,
845 "arg2": self.arg2,
846 }})
847 return config"""
848 )
849 )
851 @classmethod
852 def from_config(cls, config):
853 """Creates a layer from its config.
855 This method is the reverse of `get_config`,
856 capable of instantiating the same layer from the config
857 dictionary. It does not handle layer connectivity
858 (handled by Network), nor weights (handled by `set_weights`).
860 Args:
861 config: A Python dictionary, typically the
862 output of get_config.
864 Returns:
865 A layer instance.
866 """
867 try:
868 return cls(**config)
869 except Exception as e:
870 raise TypeError(
871 f"Error when deserializing class '{cls.__name__}' using "
872 f"config={config}.\n\nException encountered: {e}"
873 )
875 def compute_output_shape(self, input_shape):
876 """Computes the output shape of the layer.
878 This method will cause the layer's state to be built, if that has not
879 happened before. This requires that the layer will later be used with
880 inputs that match the input shape provided here.
882 Args:
883 input_shape: Shape tuple (tuple of integers) or `tf.TensorShape`,
884 or structure of shape tuples / `tf.TensorShape` instances
885 (one per output tensor of the layer).
886 Shape tuples can include None for free dimensions,
887 instead of an integer.
889 Returns:
890 A `tf.TensorShape` instance
891 or structure of `tf.TensorShape` instances.
892 """
893 if tf.executing_eagerly():
894 # In this case we build the model first in order to do shape
895 # inference. This is acceptable because the framework only calls
896 # `compute_output_shape` on shape values that the layer would later
897 # be built for. It would however cause issues in case a user
898 # attempts to use `compute_output_shape` manually with shapes that
899 # are incompatible with the shape the Layer will be called on (these
900 # users will have to implement `compute_output_shape` themselves).
901 self._maybe_build(input_shape)
902 graph_name = str(self.name) + "_scratch_graph"
903 with tf.__internal__.FuncGraph(graph_name).as_default():
904 input_shape = tf_utils.convert_shapes(
905 input_shape, to_tuples=False
906 )
908 def _make_placeholder_like(shape):
909 ph = backend.placeholder(shape=shape, dtype=self.dtype)
910 ph._keras_mask = None
911 return ph
913 inputs = tf.nest.map_structure(
914 _make_placeholder_like, input_shape
915 )
916 try:
917 outputs = self(inputs, training=False)
918 except TypeError as e:
919 raise NotImplementedError(
920 "We could not automatically infer the static shape of "
921 "the layer's output. Please implement the "
922 "`compute_output_shape` method on your layer (%s)."
923 % self.__class__.__name__
924 ) from e
925 return tf.nest.map_structure(lambda t: t.shape, outputs)
926 raise NotImplementedError(
927 "Please run in eager mode or implement the `compute_output_shape` "
928 "method on your layer (%s)." % self.__class__.__name__
929 )
931 @doc_controls.for_subclass_implementers
932 def compute_output_signature(self, input_signature):
933 """Compute the output tensor signature of the layer based on the inputs.
935 Unlike a TensorShape object, a TensorSpec object contains both shape
936 and dtype information for a tensor. This method allows layers to provide
937 output dtype information if it is different from the input dtype.
938 For any layer that doesn't implement this function,
939 the framework will fall back to use `compute_output_shape`, and will
940 assume that the output dtype matches the input dtype.
942 Args:
943 input_signature: Single TensorSpec or nested structure of TensorSpec
944 objects, describing a candidate input for the layer.
946 Returns:
947 Single TensorSpec or nested structure of TensorSpec objects,
948 describing how the layer would transform the provided input.
950 Raises:
951 TypeError: If input_signature contains a non-TensorSpec object.
952 """
954 def check_type_return_shape(s):
955 if not isinstance(s, tf.TensorSpec):
956 raise TypeError(
957 "Only TensorSpec signature types are supported. "
958 f"Received: {s}."
959 )
960 return s.shape
962 input_shape = tf.nest.map_structure(
963 check_type_return_shape, input_signature
964 )
965 output_shape = self.compute_output_shape(input_shape)
967 try:
968 dtype = self.output.dtype
969 except AttributeError:
970 dtype = self._compute_dtype
972 if dtype is None:
973 input_dtypes = [s.dtype for s in tf.nest.flatten(input_signature)]
974 # Default behavior when self.dtype is None, is to use the first
975 # input's dtype.
976 dtype = input_dtypes[0]
977 return tf.nest.map_structure(
978 lambda s: tf.TensorSpec(dtype=dtype, shape=s), output_shape
979 )
981 @generic_utils.default
982 def compute_mask(self, inputs, mask=None):
983 """Computes an output mask tensor.
985 Args:
986 inputs: Tensor or list of tensors.
987 mask: Tensor or list of tensors.
989 Returns:
990 None or a tensor (or list of tensors,
991 one per output tensor of the layer).
992 """
993 if not self._supports_masking:
994 if any(m is not None for m in tf.nest.flatten(mask)):
995 raise TypeError(
996 "Layer " + self.name + " does not support masking, "
997 "but was passed an input_mask: " + str(mask)
998 )
999 # masking not explicitly supported: return None as mask.
1000 return None
1001 # if masking is explicitly supported, by default
1002 # carry over the input mask
1003 return mask
1005 @traceback_utils.filter_traceback
1006 def __call__(self, *args, **kwargs):
1007 """Wraps `call`, applying pre- and post-processing steps.
1009 Args:
1010 *args: Positional arguments to be passed to `self.call`.
1011 **kwargs: Keyword arguments to be passed to `self.call`.
1013 Returns:
1014 Output tensor(s).
1016 Note:
1017 - The following optional keyword arguments are reserved for specific
1018 uses:
1019 * `training`: Boolean scalar tensor of Python boolean indicating
1020 whether the `call` is meant for training or inference.
1021 * `mask`: Boolean input mask.
1022 - If the layer's `call` method takes a `mask` argument (as some Keras
1023 layers do), its default value will be set to the mask generated
1024 for `inputs` by the previous layer (if `input` did come from
1025 a layer that generated a corresponding mask, i.e. if it came from
1026 a Keras layer with masking support.
1027 - If the layer is not built, the method will call `build`.
1029 Raises:
1030 ValueError: if the layer's `call` method returns None (an invalid
1031 value).
1032 RuntimeError: if `super().__init__()` was not called in the
1033 constructor.
1034 """
1035 if not hasattr(self, "_thread_local"):
1036 raise RuntimeError(
1037 "You must call `super().__init__()` in the layer constructor."
1038 )
1040 # `inputs` (the first arg in the method spec) is special cased in
1041 # layer call due to historical reasons.
1042 # This special casing currently takes the form of:
1043 # - 'inputs' must be explicitly passed. A layer cannot have zero
1044 # arguments, and inputs cannot have been provided via the default
1045 # value of a kwarg.
1046 # - numpy/scalar values in `inputs` get converted to tensors
1047 # - implicit masks / mask metadata are only collected from 'inputs`
1048 # - Layers are built using shape info from 'inputs' only
1049 # - input_spec compatibility is only checked against `inputs`
1050 # - mixed precision casting (autocast) is only applied to `inputs`,
1051 # not to any other argument.
1052 inputs, args, kwargs = self._call_spec.split_out_first_arg(args, kwargs)
1053 input_list = tf.nest.flatten(inputs)
1055 # Functional Model construction mode is invoked when `Layer`s are called
1056 # on symbolic `KerasTensor`s, i.e.:
1057 # >> inputs = tf.keras.Input(10)
1058 # >> outputs = MyLayer()(inputs) # Functional construction mode.
1059 # >> model = tf.keras.Model(inputs, outputs)
1060 if _in_functional_construction_mode(
1061 self, inputs, args, kwargs, input_list
1062 ):
1063 return self._functional_construction_call(
1064 inputs, args, kwargs, input_list
1065 )
1067 # Maintains info about the `Layer.call` stack.
1068 call_context = base_layer_utils.call_context()
1070 # Accept NumPy and scalar inputs by converting to Tensors.
1071 if any(
1072 isinstance(x, (tf.Tensor, np.ndarray, float, int))
1073 for x in input_list
1074 ):
1075 inputs = tf.nest.map_structure(
1076 _convert_numpy_or_python_types, inputs
1077 )
1078 input_list = tf.nest.flatten(inputs)
1080 # Handle `mask` propagation from previous layer to current layer. Masks
1081 # can be propagated explicitly via the `mask` argument, or implicitly
1082 # via setting the `_keras_mask` attribute on the inputs to a Layer.
1083 # Masks passed explicitly take priority.
1084 input_masks, mask_is_implicit = self._get_input_masks(
1085 inputs, input_list, args, kwargs
1086 )
1087 if self._expects_mask_arg and mask_is_implicit:
1088 kwargs["mask"] = input_masks
1090 # Training mode for `Layer.call` is set via (in order of priority):
1091 # (1) The `training` argument passed to this `Layer.call`, if it is not
1092 # None
1093 # (2) The training mode of an outer `Layer.call`.
1094 # (3) The default mode set by `tf.keras.backend.set_learning_phase` (if
1095 # set)
1096 # (4) Any non-None default value for `training` specified in the call
1097 # signature
1098 # (5) False (treating the layer as if it's in inference)
1099 args, kwargs, training_mode = self._set_training_mode(
1100 args, kwargs, call_context
1101 )
1103 # Losses are cleared for all sublayers on the outermost `Layer.call`.
1104 # Losses are not cleared on inner `Layer.call`s, because sublayers can
1105 # be called multiple times.
1106 if not call_context.in_call:
1107 self._clear_losses()
1109 eager = tf.executing_eagerly()
1110 with call_context.enter(
1111 layer=self,
1112 inputs=inputs,
1113 build_graph=not eager,
1114 training=training_mode,
1115 ):
1117 input_spec.assert_input_compatibility(
1118 self.input_spec, inputs, self.name
1119 )
1121 if eager:
1122 call_fn = self.call
1123 name_scope = self._name
1124 else:
1125 name_scope = self._get_unnested_name_scope()
1126 call_fn = self._autographed_call()
1128 call_fn = traceback_utils.inject_argument_info_in_traceback(
1129 call_fn,
1130 object_name=(
1131 f"layer '{self.name}' (type {self.__class__.__name__})"
1132 ),
1133 )
1134 with contextlib.ExitStack() as namescope_stack:
1135 if _is_name_scope_on_model_declaration_enabled:
1136 namescope_stack.enter_context(
1137 _name_scope_unnester(self._name_scope_on_declaration)
1138 )
1139 namescope_stack.enter_context(tf.name_scope(name_scope))
1141 if not self.built:
1142 self._maybe_build(inputs)
1144 if self._autocast:
1145 inputs = self._maybe_cast_inputs(inputs, input_list)
1147 with autocast_variable.enable_auto_cast_variables(
1148 self._compute_dtype_object
1149 ):
1150 outputs = call_fn(inputs, *args, **kwargs)
1152 if self._activity_regularizer:
1153 self._handle_activity_regularization(inputs, outputs)
1154 if self._supports_masking:
1155 self._set_mask_metadata(
1156 inputs, outputs, input_masks, not eager
1157 )
1158 if self._saved_model_inputs_spec is None:
1159 self._set_save_spec(inputs, args, kwargs)
1161 return outputs
1163 def _get_unnested_name_scope(self):
1164 if _is_name_scope_on_model_declaration_enabled:
1165 with _name_scope_unnester(
1166 self._name_scope_on_declaration
1167 ) as relative_name_scope_on_declaration:
1168 # To avoid `tf.name_scope` autoincrement, use absolute path.
1169 relative_name_scope = filter(
1170 None,
1171 [
1172 tf.get_current_name_scope(),
1173 relative_name_scope_on_declaration,
1174 ],
1175 )
1176 current_name_scope = "/".join(relative_name_scope) + "/"
1177 if current_name_scope == "/":
1178 current_name_scope = self._name_scope_on_declaration
1179 with tf.name_scope(current_name_scope):
1180 name_scope = self._name_scope() # Avoid autoincrementing.
1181 else:
1182 name_scope = self._name_scope()
1184 return name_scope
1186 @property
1187 def dtype(self):
1188 """The dtype of the layer weights.
1190 This is equivalent to `Layer.dtype_policy.variable_dtype`. Unless
1191 mixed precision is used, this is the same as `Layer.compute_dtype`, the
1192 dtype of the layer's computations.
1193 """
1194 return self._dtype_policy.variable_dtype
1196 @property
1197 def name(self):
1198 """Name of the layer (string), set in the constructor."""
1199 return self._name
1201 @property
1202 def supports_masking(self):
1203 """Whether this layer supports computing a mask using `compute_mask`."""
1204 return self._supports_masking
1206 @supports_masking.setter
1207 def supports_masking(self, value):
1208 self._supports_masking = value
1210 @property
1211 def dynamic(self):
1212 """Whether the layer is dynamic (eager-only); set in the constructor."""
1213 return any(layer._dynamic for layer in self._flatten_layers())
1215 @property
1216 @doc_controls.do_not_doc_inheritable
1217 def stateful(self):
1218 return any(layer._stateful for layer in self._flatten_layers())
1220 @stateful.setter
1221 def stateful(self, value):
1222 self._stateful = value
1224 @property
1225 def trainable(self):
1226 return self._trainable
1228 @trainable.setter
1229 def trainable(self, value):
1230 """Sets trainable attribute for the layer and its sublayers.
1232 When this value is changed during training (e.g. with a
1233 `tf.keras.callbacks.Callback`) you need to call the parent
1234 `tf.keras.Model.make_train_function` with `force=True` in order to
1235 recompile the training graph.
1237 Args:
1238 value: Boolean with the desired state for the layer's trainable
1239 attribute.
1240 """
1241 for layer in self._flatten_layers():
1242 layer._trainable = value
1244 @property
1245 def activity_regularizer(self):
1246 """Optional regularizer function for the output of this layer."""
1247 return self._activity_regularizer
1249 @activity_regularizer.setter
1250 def activity_regularizer(self, regularizer):
1251 """Optional regularizer function for the output of this layer."""
1252 self._activity_regularizer = regularizer
1254 @property
1255 def input_spec(self):
1256 """`InputSpec` instance(s) describing the input format for this layer.
1258 When you create a layer subclass, you can set `self.input_spec` to
1259 enable the layer to run input compatibility checks when it is called.
1260 Consider a `Conv2D` layer: it can only be called on a single input
1261 tensor of rank 4. As such, you can set, in `__init__()`:
1263 ```python
1264 self.input_spec = tf.keras.layers.InputSpec(ndim=4)
1265 ```
1267 Now, if you try to call the layer on an input that isn't rank 4
1268 (for instance, an input of shape `(2,)`, it will raise a
1269 nicely-formatted error:
1271 ```
1272 ValueError: Input 0 of layer conv2d is incompatible with the layer:
1273 expected ndim=4, found ndim=1. Full shape received: [2]
1274 ```
1276 Input checks that can be specified via `input_spec` include:
1277 - Structure (e.g. a single input, a list of 2 inputs, etc)
1278 - Shape
1279 - Rank (ndim)
1280 - Dtype
1282 For more information, see `tf.keras.layers.InputSpec`.
1284 Returns:
1285 A `tf.keras.layers.InputSpec` instance, or nested structure thereof.
1286 """
1287 return self._input_spec
1289 @input_spec.setter
1290 # Must be decorated to prevent tracking, since the input_spec can be nested
1291 # InputSpec objects.
1292 @tf.__internal__.tracking.no_automatic_dependency_tracking
1293 def input_spec(self, value):
1294 for v in tf.nest.flatten(value):
1295 if v is not None and not isinstance(v, input_spec.InputSpec):
1296 raise TypeError(
1297 "Layer input_spec must be an instance of InputSpec. "
1298 "Got: {}".format(v)
1299 )
1300 self._input_spec = value
1302 @property
1303 def trainable_weights(self):
1304 """List of all trainable weights tracked by this layer.
1306 Trainable weights are updated via gradient descent during training.
1308 Returns:
1309 A list of trainable variables.
1310 """
1311 self._update_trackables()
1312 if self.trainable:
1313 children_weights = self._gather_children_attribute(
1314 "trainable_variables"
1315 )
1316 return self._dedup_weights(
1317 self._trainable_weights + children_weights
1318 )
1319 else:
1320 return []
1322 @property
1323 def non_trainable_weights(self):
1324 """List of all non-trainable weights tracked by this layer.
1326 Non-trainable weights are *not* updated during training. They are
1327 expected to be updated manually in `call()`.
1329 Returns:
1330 A list of non-trainable variables.
1331 """
1332 self._update_trackables()
1333 if self.trainable:
1334 children_weights = self._gather_children_attribute(
1335 "non_trainable_variables"
1336 )
1337 non_trainable_weights = (
1338 self._non_trainable_weights + children_weights
1339 )
1340 else:
1341 children_weights = self._gather_children_attribute("variables")
1342 non_trainable_weights = (
1343 self._trainable_weights
1344 + self._non_trainable_weights
1345 + children_weights
1346 )
1347 return self._dedup_weights(non_trainable_weights)
1349 @property
1350 def weights(self):
1351 """Returns the list of all layer variables/weights.
1353 Returns:
1354 A list of variables.
1355 """
1356 return self.trainable_weights + self.non_trainable_weights
1358 @property
1359 @doc_controls.do_not_generate_docs
1360 def updates(self):
1361 warnings.warn(
1362 "`layer.updates` will be removed in a future version. "
1363 "This property should not be used in TensorFlow 2.0, "
1364 "as `updates` are applied automatically.",
1365 stacklevel=2,
1366 )
1367 return []
1369 @property
1370 def losses(self):
1371 """List of losses added using the `add_loss()` API.
1373 Variable regularization tensors are created when this property is
1374 accessed, so it is eager safe: accessing `losses` under a
1375 `tf.GradientTape` will propagate gradients back to the corresponding
1376 variables.
1378 Examples:
1380 >>> class MyLayer(tf.keras.layers.Layer):
1381 ... def call(self, inputs):
1382 ... self.add_loss(tf.abs(tf.reduce_mean(inputs)))
1383 ... return inputs
1384 >>> l = MyLayer()
1385 >>> l(np.ones((10, 1)))
1386 >>> l.losses
1387 [1.0]
1389 >>> inputs = tf.keras.Input(shape=(10,))
1390 >>> x = tf.keras.layers.Dense(10)(inputs)
1391 >>> outputs = tf.keras.layers.Dense(1)(x)
1392 >>> model = tf.keras.Model(inputs, outputs)
1393 >>> # Activity regularization.
1394 >>> len(model.losses)
1395 0
1396 >>> model.add_loss(tf.abs(tf.reduce_mean(x)))
1397 >>> len(model.losses)
1398 1
1400 >>> inputs = tf.keras.Input(shape=(10,))
1401 >>> d = tf.keras.layers.Dense(10, kernel_initializer='ones')
1402 >>> x = d(inputs)
1403 >>> outputs = tf.keras.layers.Dense(1)(x)
1404 >>> model = tf.keras.Model(inputs, outputs)
1405 >>> # Weight regularization.
1406 >>> model.add_loss(lambda: tf.reduce_mean(d.kernel))
1407 >>> model.losses
1408 [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>]
1410 Returns:
1411 A list of tensors.
1412 """
1413 collected_losses = []
1414 for layer in self._flatten_layers():
1415 # If any eager losses are present, we assume the model to be part of
1416 # an eager training loop (either a custom one or the one used when
1417 # `run_eagerly=True`) and so we always return just the eager losses.
1418 if layer._eager_losses:
1419 # Filter placeholder losses that may have been added by revived
1420 # layers. (see base_layer_utils for details).
1421 if (
1422 layer._eager_losses[0]
1423 is not base_layer_utils.REVIVED_LOSS_PLACEHOLDER
1424 ):
1425 collected_losses.extend(layer._eager_losses)
1426 else:
1427 collected_losses.extend(layer._losses)
1428 for regularizer in layer._callable_losses:
1429 loss_tensor = regularizer()
1430 if loss_tensor is not None:
1431 collected_losses.append(loss_tensor)
1432 return collected_losses
1434 def add_loss(self, losses, **kwargs):
1435 """Add loss tensor(s), potentially dependent on layer inputs.
1437 Some losses (for instance, activity regularization losses) may be
1438 dependent on the inputs passed when calling a layer. Hence, when reusing
1439 the same layer on different inputs `a` and `b`, some entries in
1440 `layer.losses` may be dependent on `a` and some on `b`. This method
1441 automatically keeps track of dependencies.
1443 This method can be used inside a subclassed layer or model's `call`
1444 function, in which case `losses` should be a Tensor or list of Tensors.
1446 Example:
1448 ```python
1449 class MyLayer(tf.keras.layers.Layer):
1450 def call(self, inputs):
1451 self.add_loss(tf.abs(tf.reduce_mean(inputs)))
1452 return inputs
1453 ```
1455 The same code works in distributed training: the input to `add_loss()`
1456 is treated like a regularization loss and averaged across replicas
1457 by the training loop (both built-in `Model.fit()` and compliant custom
1458 training loops).
1460 The `add_loss` method can also be called directly on a Functional Model
1461 during construction. In this case, any loss Tensors passed to this Model
1462 must be symbolic and be able to be traced back to the model's `Input`s.
1463 These losses become part of the model's topology and are tracked in
1464 `get_config`.
1466 Example:
1468 ```python
1469 inputs = tf.keras.Input(shape=(10,))
1470 x = tf.keras.layers.Dense(10)(inputs)
1471 outputs = tf.keras.layers.Dense(1)(x)
1472 model = tf.keras.Model(inputs, outputs)
1473 # Activity regularization.
1474 model.add_loss(tf.abs(tf.reduce_mean(x)))
1475 ```
1477 If this is not the case for your loss (if, for example, your loss
1478 references a `Variable` of one of the model's layers), you can wrap your
1479 loss in a zero-argument lambda. These losses are not tracked as part of
1480 the model's topology since they can't be serialized.
1482 Example:
1484 ```python
1485 inputs = tf.keras.Input(shape=(10,))
1486 d = tf.keras.layers.Dense(10)
1487 x = d(inputs)
1488 outputs = tf.keras.layers.Dense(1)(x)
1489 model = tf.keras.Model(inputs, outputs)
1490 # Weight regularization.
1491 model.add_loss(lambda: tf.reduce_mean(d.kernel))
1492 ```
1494 Args:
1495 losses: Loss tensor, or list/tuple of tensors. Rather than tensors,
1496 losses may also be zero-argument callables which create a loss
1497 tensor.
1498 **kwargs: Used for backwards compatibility only.
1499 """
1500 kwargs.pop("inputs", None)
1501 if kwargs:
1502 raise TypeError(f"Unknown keyword arguments: {kwargs.keys()}")
1504 def _tag_callable(loss):
1505 """Tags callable loss tensor as `_unconditional_loss`."""
1506 if callable(loss):
1507 # We run the loss without autocasting, as regularizers are often
1508 # numerically unstable in float16.
1509 with autocast_variable.enable_auto_cast_variables(None):
1510 loss = loss()
1511 if loss is None:
1512 # Will be filtered out when computing the .losses property
1513 return None
1514 if not tf.is_tensor(loss):
1515 loss = tf.convert_to_tensor(loss, dtype=backend.floatx())
1516 loss._unconditional_loss = True
1517 return loss
1519 losses = tf.nest.flatten(losses)
1521 callable_losses = []
1522 eager_losses = []
1523 symbolic_losses = []
1524 for loss in losses:
1525 if callable(loss):
1526 callable_losses.append(functools.partial(_tag_callable, loss))
1527 continue
1528 if loss is None:
1529 continue
1530 if not tf.is_tensor(loss) and not isinstance(
1531 loss, keras_tensor.KerasTensor
1532 ):
1533 loss = tf.convert_to_tensor(loss, dtype=backend.floatx())
1534 # TF Functions should take the eager path.
1535 if (
1536 tf_utils.is_symbolic_tensor(loss)
1537 or isinstance(loss, keras_tensor.KerasTensor)
1538 ) and not base_layer_utils.is_in_tf_function():
1539 symbolic_losses.append(loss)
1540 elif tf.is_tensor(loss):
1541 eager_losses.append(loss)
1543 self._callable_losses.extend(callable_losses)
1545 in_call_context = base_layer_utils.call_context().in_call
1546 if eager_losses and not in_call_context:
1547 raise ValueError(
1548 "Expected a symbolic Tensors or a callable for the loss value. "
1549 "Please wrap your loss computation in a zero argument `lambda`."
1550 )
1552 self._eager_losses.extend(eager_losses)
1554 for symbolic_loss in symbolic_losses:
1555 if getattr(self, "_is_graph_network", False):
1556 self._graph_network_add_loss(symbolic_loss)
1557 else:
1558 # Possible a loss was added in a Layer's `build`.
1559 self._losses.append(symbolic_loss)
1561 @property
1562 def metrics(self):
1563 """List of metrics added using the `add_metric()` API.
1565 Example:
1567 >>> input = tf.keras.layers.Input(shape=(3,))
1568 >>> d = tf.keras.layers.Dense(2)
1569 >>> output = d(input)
1570 >>> d.add_metric(tf.reduce_max(output), name='max')
1571 >>> d.add_metric(tf.reduce_min(output), name='min')
1572 >>> [m.name for m in d.metrics]
1573 ['max', 'min']
1575 Returns:
1576 A list of `Metric` objects.
1577 """
1578 collected_metrics = []
1579 for layer in self._flatten_layers():
1580 if not hasattr(layer, "_metrics_lock"):
1581 continue
1582 with layer._metrics_lock:
1583 collected_metrics.extend(layer._metrics)
1584 return collected_metrics
1586 def add_metric(self, value, name=None, **kwargs):
1587 """Adds metric tensor to the layer.
1589 This method can be used inside the `call()` method of a subclassed layer
1590 or model.
1592 ```python
1593 class MyMetricLayer(tf.keras.layers.Layer):
1594 def __init__(self):
1595 super(MyMetricLayer, self).__init__(name='my_metric_layer')
1596 self.mean = tf.keras.metrics.Mean(name='metric_1')
1598 def call(self, inputs):
1599 self.add_metric(self.mean(inputs))
1600 self.add_metric(tf.reduce_sum(inputs), name='metric_2')
1601 return inputs
1602 ```
1604 This method can also be called directly on a Functional Model during
1605 construction. In this case, any tensor passed to this Model must
1606 be symbolic and be able to be traced back to the model's `Input`s. These
1607 metrics become part of the model's topology and are tracked when you
1608 save the model via `save()`.
1610 ```python
1611 inputs = tf.keras.Input(shape=(10,))
1612 x = tf.keras.layers.Dense(10)(inputs)
1613 outputs = tf.keras.layers.Dense(1)(x)
1614 model = tf.keras.Model(inputs, outputs)
1615 model.add_metric(math_ops.reduce_sum(x), name='metric_1')
1616 ```
1618 Note: Calling `add_metric()` with the result of a metric object on a
1619 Functional Model, as shown in the example below, is not supported. This
1620 is because we cannot trace the metric result tensor back to the model's
1621 inputs.
1623 ```python
1624 inputs = tf.keras.Input(shape=(10,))
1625 x = tf.keras.layers.Dense(10)(inputs)
1626 outputs = tf.keras.layers.Dense(1)(x)
1627 model = tf.keras.Model(inputs, outputs)
1628 model.add_metric(tf.keras.metrics.Mean()(x), name='metric_1')
1629 ```
1631 Args:
1632 value: Metric tensor.
1633 name: String metric name.
1634 **kwargs: Additional keyword arguments for backward compatibility.
1635 Accepted values:
1636 `aggregation` - When the `value` tensor provided is not the result
1637 of calling a `keras.Metric` instance, it will be aggregated by
1638 default using a `keras.Metric.Mean`.
1639 """
1640 kwargs_keys = list(kwargs.keys())
1641 if len(kwargs_keys) > 1 or (
1642 len(kwargs_keys) == 1 and kwargs_keys[0] != "aggregation"
1643 ):
1644 raise TypeError(
1645 f"Unknown keyword arguments: {kwargs.keys()}. "
1646 "Expected `aggregation`."
1647 )
1649 from_metric_obj = hasattr(value, "_metric_obj")
1650 is_symbolic = isinstance(value, keras_tensor.KerasTensor)
1651 in_call_context = base_layer_utils.call_context().in_call
1653 if name is None and not from_metric_obj:
1654 # Eg. `self.add_metric(math_ops.reduce_sum(x))` In eager mode, we
1655 # use metric name to lookup a metric. Without a name, a new Mean
1656 # metric wrapper will be created on every model/layer call. So, we
1657 # raise an error when no name is provided. We will do the same for
1658 # symbolic mode for consistency although a name will be generated if
1659 # no name is provided.
1661 # We will not raise this error in the foll use case for the sake of
1662 # consistency as name in provided in the metric constructor.
1663 # mean = metrics.Mean(name='my_metric')
1664 # model.add_metric(mean(outputs))
1665 raise ValueError(
1666 "Please provide a name for your metric like "
1667 "`self.add_metric(tf.reduce_sum(inputs), "
1668 "name='mean_activation')`"
1669 )
1670 elif from_metric_obj:
1671 name = value._metric_obj.name
1673 if not in_call_context and not is_symbolic:
1674 raise ValueError(
1675 "Expected a symbolic Tensor for the metric value, received: "
1676 + str(value)
1677 )
1679 # If a metric was added in a Layer's `call` or `build`.
1680 if in_call_context or not getattr(self, "_is_graph_network", False):
1681 # TF Function path should take the eager path.
1683 # If the given metric is available in `metrics` list we just update
1684 # state on it, otherwise we create a new metric instance and
1685 # add it to the `metrics` list.
1686 metric_obj = getattr(value, "_metric_obj", None)
1687 # Tensors that come from a Metric object already updated the Metric
1688 # state.
1689 should_update_state = not metric_obj
1690 name = metric_obj.name if metric_obj else name
1692 with self._metrics_lock:
1693 match = self._get_existing_metric(name)
1694 if match:
1695 metric_obj = match
1696 elif metric_obj:
1697 self._metrics.append(metric_obj)
1698 else:
1699 # Build the metric object with the value's dtype if it
1700 # defines one
1701 metric_obj = metrics_mod.Mean(
1702 name=name, dtype=getattr(value, "dtype", None)
1703 )
1704 self._metrics.append(metric_obj)
1706 if should_update_state:
1707 metric_obj(value)
1708 else:
1709 if from_metric_obj:
1710 raise ValueError(
1711 "Using the result of calling a `Metric` object "
1712 "when calling `add_metric` on a Functional "
1713 "Model is not supported. Please pass the "
1714 "Tensor to monitor directly."
1715 )
1717 # Insert layers into the Keras Graph Network.
1718 aggregation = None if from_metric_obj else "mean"
1719 self._graph_network_add_metric(value, aggregation, name)
1721 @doc_controls.do_not_doc_inheritable
1722 def add_update(self, updates):
1723 """Add update op(s), potentially dependent on layer inputs.
1725 Weight updates (for instance, the updates of the moving mean and
1726 variance in a BatchNormalization layer) may be dependent on the inputs
1727 passed when calling a layer. Hence, when reusing the same layer on
1728 different inputs `a` and `b`, some entries in `layer.updates` may be
1729 dependent on `a` and some on `b`. This method automatically keeps track
1730 of dependencies.
1732 This call is ignored when eager execution is enabled (in that case,
1733 variable updates are run on the fly and thus do not need to be tracked
1734 for later execution).
1736 Args:
1737 updates: Update op, or list/tuple of update ops, or zero-arg callable
1738 that returns an update op. A zero-arg callable should be passed in
1739 order to disable running the updates by setting `trainable=False`
1740 on this Layer, when executing in Eager mode.
1741 """
1742 call_context = base_layer_utils.call_context()
1743 # No need to run updates during Functional API construction.
1744 if call_context.in_keras_graph:
1745 return
1747 # Callable updates are disabled by setting `trainable=False`.
1748 if not call_context.frozen:
1749 for update in tf.nest.flatten(updates):
1750 if callable(update):
1751 update()
1753 def set_weights(self, weights):
1754 """Sets the weights of the layer, from NumPy arrays.
1756 The weights of a layer represent the state of the layer. This function
1757 sets the weight values from numpy arrays. The weight values should be
1758 passed in the order they are created by the layer. Note that the layer's
1759 weights must be instantiated before calling this function, by calling
1760 the layer.
1762 For example, a `Dense` layer returns a list of two values: the kernel
1763 matrix and the bias vector. These can be used to set the weights of
1764 another `Dense` layer:
1766 >>> layer_a = tf.keras.layers.Dense(1,
1767 ... kernel_initializer=tf.constant_initializer(1.))
1768 >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]]))
1769 >>> layer_a.get_weights()
1770 [array([[1.],
1771 [1.],
1772 [1.]], dtype=float32), array([0.], dtype=float32)]
1773 >>> layer_b = tf.keras.layers.Dense(1,
1774 ... kernel_initializer=tf.constant_initializer(2.))
1775 >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]]))
1776 >>> layer_b.get_weights()
1777 [array([[2.],
1778 [2.],
1779 [2.]], dtype=float32), array([0.], dtype=float32)]
1780 >>> layer_b.set_weights(layer_a.get_weights())
1781 >>> layer_b.get_weights()
1782 [array([[1.],
1783 [1.],
1784 [1.]], dtype=float32), array([0.], dtype=float32)]
1786 Args:
1787 weights: a list of NumPy arrays. The number
1788 of arrays and their shape must match
1789 number of the dimensions of the weights
1790 of the layer (i.e. it should match the
1791 output of `get_weights`).
1793 Raises:
1794 ValueError: If the provided weights list does not match the
1795 layer's specifications.
1796 """
1797 params = self.weights
1799 expected_num_weights = 0
1800 for param in params:
1801 if isinstance(param, base_layer_utils.TrackableWeightHandler):
1802 expected_num_weights += param.num_tensors
1803 else:
1804 expected_num_weights += 1
1806 if expected_num_weights != len(weights):
1807 raise ValueError(
1808 'You called `set_weights(weights)` on layer "%s" '
1809 "with a weight list of length %s, but the layer was "
1810 "expecting %s weights. Provided weights: %s..."
1811 % (
1812 self.name,
1813 len(weights),
1814 expected_num_weights,
1815 str(weights)[:50],
1816 )
1817 )
1819 weight_index = 0
1820 weight_value_tuples = []
1821 for param in params:
1822 if isinstance(param, base_layer_utils.TrackableWeightHandler):
1823 num_tensors = param.num_tensors
1824 tensors = weights[weight_index : weight_index + num_tensors]
1825 param.set_weights(tensors)
1826 weight_index += num_tensors
1827 else:
1828 weight = weights[weight_index]
1829 weight_shape = weight.shape if hasattr(weight, "shape") else ()
1830 ref_shape = param.shape
1831 if not ref_shape.is_compatible_with(weight_shape):
1832 raise ValueError(
1833 f"Layer {self.name} weight shape {ref_shape} "
1834 "is not compatible with provided weight "
1835 f"shape {weight_shape}."
1836 )
1837 weight_value_tuples.append((param, weight))
1838 weight_index += 1
1840 backend.batch_set_value(weight_value_tuples)
1842 # Perform any layer defined finalization of the layer state.
1843 for layer in self._flatten_layers():
1844 layer.finalize_state()
1846 def get_weights(self):
1847 """Returns the current weights of the layer, as NumPy arrays.
1849 The weights of a layer represent the state of the layer. This function
1850 returns both trainable and non-trainable weight values associated with
1851 this layer as a list of NumPy arrays, which can in turn be used to load
1852 state into similarly parameterized layers.
1854 For example, a `Dense` layer returns a list of two values: the kernel
1855 matrix and the bias vector. These can be used to set the weights of
1856 another `Dense` layer:
1858 >>> layer_a = tf.keras.layers.Dense(1,
1859 ... kernel_initializer=tf.constant_initializer(1.))
1860 >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]]))
1861 >>> layer_a.get_weights()
1862 [array([[1.],
1863 [1.],
1864 [1.]], dtype=float32), array([0.], dtype=float32)]
1865 >>> layer_b = tf.keras.layers.Dense(1,
1866 ... kernel_initializer=tf.constant_initializer(2.))
1867 >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]]))
1868 >>> layer_b.get_weights()
1869 [array([[2.],
1870 [2.],
1871 [2.]], dtype=float32), array([0.], dtype=float32)]
1872 >>> layer_b.set_weights(layer_a.get_weights())
1873 >>> layer_b.get_weights()
1874 [array([[1.],
1875 [1.],
1876 [1.]], dtype=float32), array([0.], dtype=float32)]
1878 Returns:
1879 Weights values as a list of NumPy arrays.
1880 """
1881 weights = self.weights
1882 output_weights = []
1883 for weight in weights:
1884 if isinstance(weight, base_layer_utils.TrackableWeightHandler):
1885 output_weights.extend(weight.get_tensors())
1886 else:
1887 output_weights.append(weight)
1888 return backend.batch_get_value(output_weights)
1890 @doc_controls.do_not_generate_docs
1891 def finalize_state(self):
1892 """Finalizes the layers state after updating layer weights.
1894 This function can be subclassed in a layer and will be called after
1895 updating a layer weights. It can be overridden to finalize any
1896 additional layer state after a weight update.
1898 This function will be called after weights of a layer have been restored
1899 from a loaded model.
1900 """
1901 pass
1903 @doc_controls.do_not_doc_inheritable
1904 def get_input_mask_at(self, node_index):
1905 """Retrieves the input mask tensor(s) of a layer at a given node.
1907 Args:
1908 node_index: Integer, index of the node
1909 from which to retrieve the attribute.
1910 E.g. `node_index=0` will correspond to the
1911 first time the layer was called.
1913 Returns:
1914 A mask tensor
1915 (or list of tensors if the layer has multiple inputs).
1916 """
1917 inputs = self.get_input_at(node_index)
1918 if isinstance(inputs, list):
1919 return [getattr(x, "_keras_mask", None) for x in inputs]
1920 else:
1921 return getattr(inputs, "_keras_mask", None)
1923 @doc_controls.do_not_doc_inheritable
1924 def get_output_mask_at(self, node_index):
1925 """Retrieves the output mask tensor(s) of a layer at a given node.
1927 Args:
1928 node_index: Integer, index of the node
1929 from which to retrieve the attribute.
1930 E.g. `node_index=0` will correspond to the
1931 first time the layer was called.
1933 Returns:
1934 A mask tensor
1935 (or list of tensors if the layer has multiple outputs).
1936 """
1937 output = self.get_output_at(node_index)
1938 if isinstance(output, list):
1939 return [getattr(x, "_keras_mask", None) for x in output]
1940 else:
1941 return getattr(output, "_keras_mask", None)
1943 @property
1944 @doc_controls.do_not_doc_inheritable
1945 def input_mask(self):
1946 """Retrieves the input mask tensor(s) of a layer.
1948 Only applicable if the layer has exactly one inbound node,
1949 i.e. if it is connected to one incoming layer.
1951 Returns:
1952 Input mask tensor (potentially None) or list of input
1953 mask tensors.
1955 Raises:
1956 AttributeError: if the layer is connected to
1957 more than one incoming layers.
1958 """
1959 inputs = self.input
1960 if isinstance(inputs, list):
1961 return [getattr(x, "_keras_mask", None) for x in inputs]
1962 else:
1963 return getattr(inputs, "_keras_mask", None)
1965 @property
1966 @doc_controls.do_not_doc_inheritable
1967 def output_mask(self):
1968 """Retrieves the output mask tensor(s) of a layer.
1970 Only applicable if the layer has exactly one inbound node,
1971 i.e. if it is connected to one incoming layer.
1973 Returns:
1974 Output mask tensor (potentially None) or list of output
1975 mask tensors.
1977 Raises:
1978 AttributeError: if the layer is connected to
1979 more than one incoming layers.
1980 """
1981 output = self.output
1982 if isinstance(output, list):
1983 return [getattr(x, "_keras_mask", None) for x in output]
1984 else:
1985 return getattr(output, "_keras_mask", None)
1987 @doc_controls.do_not_doc_inheritable
1988 def get_input_shape_at(self, node_index):
1989 """Retrieves the input shape(s) of a layer at a given node.
1991 Args:
1992 node_index: Integer, index of the node
1993 from which to retrieve the attribute.
1994 E.g. `node_index=0` will correspond to the
1995 first time the layer was called.
1997 Returns:
1998 A shape tuple
1999 (or list of shape tuples if the layer has multiple inputs).
2001 Raises:
2002 RuntimeError: If called in Eager mode.
2003 """
2004 return self._get_node_attribute_at_index(
2005 node_index, "input_shapes", "input shape"
2006 )
2008 @doc_controls.do_not_doc_inheritable
2009 def get_output_shape_at(self, node_index):
2010 """Retrieves the output shape(s) of a layer at a given node.
2012 Args:
2013 node_index: Integer, index of the node
2014 from which to retrieve the attribute.
2015 E.g. `node_index=0` will correspond to the
2016 first time the layer was called.
2018 Returns:
2019 A shape tuple
2020 (or list of shape tuples if the layer has multiple outputs).
2022 Raises:
2023 RuntimeError: If called in Eager mode.
2024 """
2025 return self._get_node_attribute_at_index(
2026 node_index, "output_shapes", "output shape"
2027 )
2029 @doc_controls.do_not_doc_inheritable
2030 def get_input_at(self, node_index):
2031 """Retrieves the input tensor(s) of a layer at a given node.
2033 Args:
2034 node_index: Integer, index of the node
2035 from which to retrieve the attribute.
2036 E.g. `node_index=0` will correspond to the
2037 first input node of the layer.
2039 Returns:
2040 A tensor (or list of tensors if the layer has multiple inputs).
2042 Raises:
2043 RuntimeError: If called in Eager mode.
2044 """
2045 return self._get_node_attribute_at_index(
2046 node_index, "input_tensors", "input"
2047 )
2049 @doc_controls.do_not_doc_inheritable
2050 def get_output_at(self, node_index):
2051 """Retrieves the output tensor(s) of a layer at a given node.
2053 Args:
2054 node_index: Integer, index of the node
2055 from which to retrieve the attribute.
2056 E.g. `node_index=0` will correspond to the
2057 first output node of the layer.
2059 Returns:
2060 A tensor (or list of tensors if the layer has multiple outputs).
2062 Raises:
2063 RuntimeError: If called in Eager mode.
2064 """
2065 return self._get_node_attribute_at_index(
2066 node_index, "output_tensors", "output"
2067 )
2069 @property
2070 def input(self):
2071 """Retrieves the input tensor(s) of a layer.
2073 Only applicable if the layer has exactly one input,
2074 i.e. if it is connected to one incoming layer.
2076 Returns:
2077 Input tensor or list of input tensors.
2079 Raises:
2080 RuntimeError: If called in Eager mode.
2081 AttributeError: If no inbound nodes are found.
2082 """
2083 if not self._inbound_nodes:
2084 raise AttributeError(
2085 "Layer " + self.name + " is not connected, no input to return."
2086 )
2087 return self._get_node_attribute_at_index(0, "input_tensors", "input")
2089 @property
2090 def output(self):
2091 """Retrieves the output tensor(s) of a layer.
2093 Only applicable if the layer has exactly one output,
2094 i.e. if it is connected to one incoming layer.
2096 Returns:
2097 Output tensor or list of output tensors.
2099 Raises:
2100 AttributeError: if the layer is connected to more than one incoming
2101 layers.
2102 RuntimeError: if called in Eager mode.
2103 """
2104 if not self._inbound_nodes:
2105 raise AttributeError(
2106 "Layer " + self.name + " has no inbound nodes."
2107 )
2108 return self._get_node_attribute_at_index(0, "output_tensors", "output")
2110 @property
2111 @doc_controls.do_not_doc_inheritable
2112 def input_shape(self):
2113 """Retrieves the input shape(s) of a layer.
2115 Only applicable if the layer has exactly one input,
2116 i.e. if it is connected to one incoming layer, or if all inputs
2117 have the same shape.
2119 Returns:
2120 Input shape, as an integer shape tuple
2121 (or list of shape tuples, one tuple per input tensor).
2123 Raises:
2124 AttributeError: if the layer has no defined input_shape.
2125 RuntimeError: if called in Eager mode.
2126 """
2127 if not self._inbound_nodes:
2128 raise AttributeError(
2129 f'The layer "{self.name}" has never been called '
2130 "and thus has no defined input shape. Note that the "
2131 "`input_shape` property is only available for "
2132 "Functional and Sequential models."
2133 )
2134 all_input_shapes = set(
2135 [str(node.input_shapes) for node in self._inbound_nodes]
2136 )
2137 if len(all_input_shapes) == 1:
2138 return self._inbound_nodes[0].input_shapes
2139 else:
2140 raise AttributeError(
2141 'The layer "'
2142 + str(self.name)
2143 + '" has multiple inbound nodes, '
2144 "with different input shapes. Hence "
2145 'the notion of "input shape" is '
2146 "ill-defined for the layer. "
2147 "Use `get_input_shape_at(node_index)` "
2148 "instead."
2149 )
2151 def count_params(self):
2152 """Count the total number of scalars composing the weights.
2154 Returns:
2155 An integer count.
2157 Raises:
2158 ValueError: if the layer isn't yet built
2159 (in which case its weights aren't yet defined).
2160 """
2161 if not self.built:
2162 if getattr(self, "_is_graph_network", False):
2163 with tf_utils.maybe_init_scope(self):
2164 self._maybe_build(self.inputs)
2165 else:
2166 raise ValueError(
2167 "You tried to call `count_params` "
2168 f"on layer {self.name}"
2169 ", but the layer isn't built. "
2170 "You can build it manually via: "
2171 f"`{self.name}.build(batch_input_shape)`."
2172 )
2173 return layer_utils.count_params(self.weights)
2175 @property
2176 @doc_controls.do_not_doc_inheritable
2177 def output_shape(self):
2178 """Retrieves the output shape(s) of a layer.
2180 Only applicable if the layer has one output,
2181 or if all outputs have the same shape.
2183 Returns:
2184 Output shape, as an integer shape tuple
2185 (or list of shape tuples, one tuple per output tensor).
2187 Raises:
2188 AttributeError: if the layer has no defined output shape.
2189 RuntimeError: if called in Eager mode.
2190 """
2191 if not self._inbound_nodes:
2192 raise AttributeError(
2193 f'The layer "{self.name}" has never been called '
2194 "and thus has no defined output shape."
2195 )
2196 all_output_shapes = set(
2197 [str(node.output_shapes) for node in self._inbound_nodes]
2198 )
2199 if len(all_output_shapes) == 1:
2200 return self._inbound_nodes[0].output_shapes
2201 else:
2202 raise AttributeError(
2203 'The layer "%s"'
2204 " has multiple inbound nodes, "
2205 "with different output shapes. Hence "
2206 'the notion of "output shape" is '
2207 "ill-defined for the layer. "
2208 "Use `get_output_shape_at(node_index)` "
2209 "instead." % self.name
2210 )
2212 @property
2213 def dtype_policy(self):
2214 """The dtype policy associated with this layer.
2216 This is an instance of a `tf.keras.mixed_precision.Policy`.
2217 """
2218 return self._dtype_policy
2220 @property
2221 def compute_dtype(self):
2222 """The dtype of the layer's computations.
2224 This is equivalent to `Layer.dtype_policy.compute_dtype`. Unless
2225 mixed precision is used, this is the same as `Layer.dtype`, the dtype of
2226 the weights.
2228 Layers automatically cast their inputs to the compute dtype, which
2229 causes computations and the output to be in the compute dtype as well.
2230 This is done by the base Layer class in `Layer.__call__`, so you do not
2231 have to insert these casts if implementing your own layer.
2233 Layers often perform certain internal computations in higher precision
2234 when `compute_dtype` is float16 or bfloat16 for numeric stability. The
2235 output will still typically be float16 or bfloat16 in such cases.
2237 Returns:
2238 The layer's compute dtype.
2239 """
2240 return self._dtype_policy.compute_dtype
2242 @property
2243 def variable_dtype(self):
2244 """Alias of `Layer.dtype`, the dtype of the weights."""
2245 return self.dtype
2247 @property
2248 @doc_controls.do_not_doc_inheritable
2249 def inbound_nodes(self):
2250 """Return Functional API nodes upstream of this layer."""
2251 return self._inbound_nodes
2253 @property
2254 @doc_controls.do_not_doc_inheritable
2255 def outbound_nodes(self):
2256 """Return Functional API nodes downstream of this layer."""
2257 return self._outbound_nodes
2259 ############################################################################
2260 # Methods & attributes below are public aliases of other methods. #
2261 ############################################################################
2263 @property
2264 @doc_controls.do_not_generate_docs
2265 def variables(self):
2266 """Returns the list of all layer variables/weights.
2268 Alias of `self.weights`.
2270 Note: This will not track the weights of nested `tf.Modules` that are
2271 not themselves Keras layers.
2273 Returns:
2274 A list of variables.
2275 """
2276 return self.weights
2278 @property
2279 @doc_controls.do_not_generate_docs
2280 def trainable_variables(self):
2281 return self.trainable_weights
2283 @property
2284 @doc_controls.do_not_generate_docs
2285 def non_trainable_variables(self):
2286 return self.non_trainable_weights
2288 @doc_controls.do_not_doc_inheritable
2289 def add_variable(self, *args, **kwargs):
2290 """Deprecated, do NOT use! Alias for `add_weight`."""
2291 warnings.warn(
2292 "`layer.add_variable` is deprecated and "
2293 "will be removed in a future version. "
2294 "Please use the `layer.add_weight()` method instead.",
2295 stacklevel=2,
2296 )
2297 return self.add_weight(*args, **kwargs)
2299 def get_build_config(self):
2300 """Returns a dictionary with the layer's input shape.
2302 This method returns a config dict that can be used by
2303 `build_from_config(config)` to create all states (e.g. Variables and
2304 Lookup tables) needed by the layer.
2306 By default, the config only contains the input shape that the layer
2307 was built with. If you're writing a custom layer that creates state in
2308 an unusual way, you should override this method to make sure this state
2309 is already created when Keras attempts to load its value upon model
2310 loading.
2312 Returns:
2313 A dict containing the input shape associated with the layer.
2314 """
2315 if self._build_input_shape is not None:
2317 def convert_tensorshapes(x):
2318 if isinstance(x, tf.TensorShape) and x._dims:
2319 return tuple(x.as_list())
2320 return x
2322 return {
2323 "input_shape": tf.nest.map_structure(
2324 convert_tensorshapes, self._build_input_shape
2325 )
2326 }
2328 def build_from_config(self, config):
2329 """Builds the layer's states with the supplied config dict.
2331 By default, this method calls the `build(config["input_shape"])` method,
2332 which creates weights based on the layer's input shape in the supplied
2333 config. If your config contains other information needed to load the
2334 layer's state, you should override this method.
2336 Args:
2337 config: Dict containing the input shape associated with this layer.
2338 """
2339 input_shape = config["input_shape"]
2340 if input_shape is not None:
2341 self.build(input_shape)
2343 ############################################################################
2344 # Methods & attributes below are all private and only used by the framework.
2345 ############################################################################
2347 # See tf.Module for the usage of this property.
2348 # The key for _obj_reference_counts_dict is a Trackable, which could be a
2349 # variable or layer etc. tf.Module._flatten will fail to flatten the key
2350 # since it is trying to convert Trackable to a string. This attribute can be
2351 # ignored even after the fix of nest lib, since the trackable object should
2352 # already been available as individual attributes.
2353 # _obj_reference_counts_dict just contains a copy of them.
2354 _TF_MODULE_IGNORED_PROPERTIES = frozenset(
2355 itertools.chain(
2356 ("_obj_reference_counts_dict",),
2357 tf.Module._TF_MODULE_IGNORED_PROPERTIES,
2358 )
2359 )
2361 # When loading from a SavedModel, Layers typically can be revived into a
2362 # generic Layer wrapper. Sometimes, however, layers may implement methods
2363 # that go beyond this wrapper, as in the case of PreprocessingLayers'
2364 # `adapt` method. When this is the case, layer implementers can override
2365 # must_restore_from_config to return True; layers with this property must
2366 # be restored into their actual objects (and will fail if the object is
2367 # not available to the restoration code).
2368 _must_restore_from_config = False
2370 def _get_cell_name(self):
2371 canonical_name = get_canonical_name_for_symbol(
2372 self.__class__, api_name="keras", add_prefix_to_v1_names=True
2373 )
2374 if canonical_name is not None:
2375 return f"tf.{canonical_name}"
2376 return self.__class__.__module__ + "." + self.__class__.__name__
2378 def _instrument_layer_creation(self):
2379 self._instrumented_keras_api = False
2380 self._instrumented_keras_layer_class = False
2381 self._instrumented_keras_model_class = False
2382 if not getattr(self, "_disable_keras_instrumentation", False):
2383 keras_api_gauge.get_cell("layer").set(True)
2384 self._instrumented_keras_api = True
2385 if getattr(self, "_is_model_for_instrumentation", False):
2386 keras_models_gauge.get_cell(self._get_cell_name()).set(True)
2387 self._instrumented_keras_model_class = True
2388 else:
2389 keras_layers_gauge.get_cell(self._get_cell_name()).set(True)
2390 self._instrumented_keras_layer_class = True
2391 else:
2392 # This is a legacy layer that has disabled instrumentation
2393 # as a native keras object. We still instrument this as
2394 # legacy usage.
2395 keras_api_gauge.get_cell("legacy_layer").set(True)
2397 @doc_controls.for_subclass_implementers
2398 def _add_trackable(self, trackable_object, trainable):
2399 """Adds a Trackable object to this layer's state.
2401 Args:
2402 trackable_object: The tf.tracking.Trackable object to add.
2403 trainable: Boolean, whether the variable should be part of the layer's
2404 "trainable_variables" (e.g. variables, biases) or
2405 "non_trainable_variables" (e.g. BatchNorm mean and variance).
2407 Returns:
2408 The TrackableWeightHandler used to track this object.
2409 """
2410 if isinstance(
2411 trackable_object, base_layer_utils.TrackableWeightHandler
2412 ):
2413 handler = trackable_object
2414 else:
2415 handler = base_layer_utils.TrackableWeightHandler(trackable_object)
2416 if trainable:
2417 self._trainable_weights.append(handler)
2418 else:
2419 self._non_trainable_weights.append(handler)
2420 return handler
2422 def _clear_losses(self):
2423 """Used every step in eager to reset losses."""
2424 # Set to thread local directly to avoid Layer.__setattr__ overhead.
2425 if not getattr(
2426 self, "_self_tracked_trackables", None
2427 ): # Fast path for single Layer.
2428 self._thread_local._eager_losses = []
2429 else:
2430 for layer in self._flatten_layers():
2431 layer._thread_local._eager_losses = []
2433 def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs):
2434 if self.dynamic:
2435 # We will use static shape inference to return symbolic tensors
2436 # matching the specifications of the layer outputs.
2437 # Since `self.dynamic` is True, we will never attempt to
2438 # run the underlying TF graph (which is disconnected).
2439 # TODO(fchollet): consider py_func as an alternative, which
2440 # would enable us to run the underlying graph if needed.
2441 input_signature = tf.nest.map_structure(
2442 lambda x: tf.TensorSpec(shape=x.shape, dtype=x.dtype), inputs
2443 )
2444 output_signature = self.compute_output_signature(input_signature)
2445 return tf.nest.map_structure(
2446 keras_tensor.KerasTensor, output_signature
2447 )
2448 else:
2449 return self._infer_output_signature(
2450 inputs, args, kwargs, input_masks
2451 )
2453 def _infer_output_signature(self, inputs, args, kwargs, input_masks):
2454 """Call the layer on input KerasTensors, returns output KerasTensors."""
2456 keras_tensor_inputs = inputs
2457 call_fn = self.call
2458 # Wrapping `call` function in autograph to allow for dynamic control
2459 # flow and control dependencies in call. We are limiting this to
2460 # subclassed layers as autograph is strictly needed only for
2461 # subclassed layers and models.
2462 # tf_convert will respect the value of autograph setting in the
2463 # enclosing tf.function, if any.
2464 if base_layer_utils.is_subclassed(
2465 self
2466 ) and not base_layer_utils.from_saved_model(self):
2467 call_fn = tf.__internal__.autograph.tf_convert(
2468 self.call, tf.__internal__.autograph.control_status_ctx()
2469 )
2471 call_fn = traceback_utils.inject_argument_info_in_traceback(
2472 call_fn,
2473 object_name=f'layer "{self.name}" (type {self.__class__.__name__})',
2474 )
2476 # We enter a scratch graph and build placeholder inputs inside of it
2477 # that match the input args.
2478 # We then call the layer inside of the scratch graph to identify the
2479 # output signatures, then we build KerasTensors corresponding to those
2480 # outputs.
2481 scratch_graph = tf.__internal__.FuncGraph(
2482 str(self.name) + "_scratch_graph"
2483 )
2484 with scratch_graph.as_default():
2485 inputs = tf.nest.map_structure(
2486 keras_tensor.keras_tensor_to_placeholder, inputs
2487 )
2488 args = tf.nest.map_structure(
2489 keras_tensor.keras_tensor_to_placeholder, args
2490 )
2491 kwargs = tf.nest.map_structure(
2492 keras_tensor.keras_tensor_to_placeholder, kwargs
2493 )
2494 input_masks = tf.nest.map_structure(
2495 keras_tensor.keras_tensor_to_placeholder, input_masks
2496 )
2498 with backend.name_scope(self._name_scope()):
2499 with autocast_variable.enable_auto_cast_variables(
2500 self._compute_dtype_object
2501 ):
2502 # Build layer if applicable (if the `build` method has been
2503 # overridden).
2504 # TODO(kaftan): do we maybe_build here, or have we already
2505 # done it?
2506 self._maybe_build(inputs)
2507 inputs = self._maybe_cast_inputs(inputs)
2508 outputs = call_fn(inputs, *args, **kwargs)
2510 self._handle_activity_regularization(inputs, outputs)
2511 self._set_mask_metadata(
2512 inputs, outputs, input_masks, build_graph=False
2513 )
2514 outputs = tf.nest.map_structure(
2515 keras_tensor.keras_tensor_from_tensor, outputs
2516 )
2518 self._set_save_spec(keras_tensor_inputs, args, kwargs)
2519 if hasattr(self, "_set_inputs") and not self.inputs:
2520 # TODO(kaftan): figure out if we need to do this at all
2521 # Subclassed network: explicitly set metadata normally set by
2522 # a call to self._set_inputs().
2523 self._set_inputs(inputs, outputs)
2524 del scratch_graph
2525 return outputs
2527 def _functional_construction_call(self, inputs, args, kwargs, input_list):
2528 call_context = base_layer_utils.call_context()
2530 # Accept NumPy and scalar inputs by converting to Tensors.
2531 if any(
2532 isinstance(x, (tf.Tensor, np.ndarray, float, int))
2533 for x in input_list
2534 ):
2536 def _convert_non_tensor(x):
2537 # Don't call `ops.convert_to_tensor` on all `inputs` because
2538 # `SparseTensors` can't be converted to `Tensor`.
2539 if isinstance(x, (tf.Tensor, np.ndarray, float, int)):
2540 return tf.convert_to_tensor(x)
2541 return x
2543 inputs = tf.nest.map_structure(_convert_non_tensor, inputs)
2544 input_list = tf.nest.flatten(inputs)
2546 # Handle `mask` propagation from previous layer to current layer. Masks
2547 # can be propagated explicitly via the `mask` argument, or implicitly
2548 # via setting the `_keras_mask` attribute on the inputs to a Layer.
2549 # Masks passed explicitly take priority.
2550 mask_arg_passed_by_framework = False
2551 input_masks, mask_is_implicit = self._get_input_masks(
2552 inputs, input_list, args, kwargs
2553 )
2554 if self._expects_mask_arg and mask_is_implicit:
2555 kwargs["mask"] = input_masks
2556 mask_arg_passed_by_framework = True
2558 # If `training` argument is None or not explicitly passed,
2559 # propagate `training` value from this layer's calling layer.
2560 training_value = None
2561 training_arg_passed_by_framework = False
2562 # Priority 1: `training` was explicitly passed a non-None value.
2563 if self._call_spec.arg_was_passed("training", args, kwargs):
2564 training_value = self._call_spec.get_arg_value(
2565 "training", args, kwargs
2566 )
2567 if not self._expects_training_arg:
2568 kwargs.pop("training")
2570 if training_value is None:
2571 # Priority 2: `training` was passed to a parent layer.
2572 if call_context.training is not None:
2573 training_value = call_context.training
2574 # Priority 3: `learning_phase()` has been set.
2575 elif backend.global_learning_phase_is_set():
2576 training_value = backend.learning_phase()
2577 # Force the training_value to be bool type which matches to the
2578 # contract for layer/model call args.
2579 if tf.is_tensor(training_value):
2580 training_value = tf.cast(training_value, tf.bool)
2581 else:
2582 training_value = bool(training_value)
2583 # Priority 4: trace layer with the default training argument
2584 # specified in the `call` signature (or in inference mode if the
2585 # `call` signature specifies no non-None default).
2586 else:
2587 training_value = self._call_spec.default_training_arg
2588 # In cases (2), (3), (4) the training argument is passed
2589 # automatically by the framework, and will not be hard-coded into
2590 # the model.
2591 if self._expects_training_arg:
2592 args, kwargs = self._call_spec.set_arg_value(
2593 "training", training_value, args, kwargs
2594 )
2595 training_arg_passed_by_framework = True
2597 with call_context.enter(
2598 layer=self, inputs=inputs, build_graph=True, training=training_value
2599 ):
2600 # Check input assumptions set after layer building, e.g. input
2601 # shape.
2602 try:
2603 outputs = self._keras_tensor_symbolic_call(
2604 inputs, input_masks, args, kwargs
2605 )
2606 except TypeError as e:
2607 if "DictWrapper" in str(e):
2608 raise TypeError(
2609 f"{self} could not be deserialized properly. Please"
2610 " ensure that components that are Python object"
2611 " instances (layers, models, etc.) returned by"
2612 " `get_config()` are explicitly deserialized in the"
2613 " model's `from_config()` method."
2614 ) from e
2615 else:
2616 raise e
2618 if outputs is None:
2619 raise ValueError(
2620 "A layer's `call` method should return a "
2621 "Tensor or a list of Tensors, not None "
2622 "(layer: " + self.name + ")."
2623 )
2624 if training_arg_passed_by_framework:
2625 args, kwargs = self._call_spec.set_arg_value(
2626 "training", None, args, kwargs, pop_kwarg_if_none=True
2627 )
2628 if mask_arg_passed_by_framework:
2629 kwargs.pop("mask")
2630 # Node connectivity does not special-case the first argument.
2631 outputs = self._set_connectivity_metadata(
2632 (inputs,) + args, kwargs, outputs
2633 )
2634 return outputs
2636 def _set_training_mode(self, args, kwargs, call_context):
2637 training_mode = None
2638 if self._expects_training_arg:
2639 # (1) `training` was passed to this `Layer.call`.
2640 if self._call_spec.arg_was_passed("training", args, kwargs):
2641 training_mode = self._call_spec.get_arg_value(
2642 "training", args, kwargs
2643 )
2644 # If no `training` arg was passed, or `None` was explicitly passed,
2645 # the framework will make a decision about the training mode is.
2646 if training_mode is None:
2647 call_ctx_training = call_context.training
2648 # (2) `training` mode is inferred from an outer `Layer.call`.
2649 if call_ctx_training is not None:
2650 training_mode = call_ctx_training
2651 # (3) User set `tf.keras.backend.set_learning_phase`.
2652 elif backend.global_learning_phase_is_set():
2653 training_mode = backend.learning_phase()
2654 # Ensure value is a `bool` or `tf.bool`.
2655 if isinstance(training_mode, bool):
2656 pass
2657 elif tf.is_tensor(training_mode):
2658 training_mode = tf.cast(training_mode, tf.bool)
2659 else:
2660 training_mode = bool(training_mode)
2661 # (4) We default to using `call`'s default value for `training`,
2662 # or treating the layer as if it is in inference if no non-None
2663 # default is specified in the `call` signature.
2664 else:
2665 training_mode = self._call_spec.default_training_arg
2667 # For case (2), (3), (4) `training` arg is passed by framework.
2668 args, kwargs = self._call_spec.set_arg_value(
2669 "training", training_mode, args, kwargs
2670 )
2671 else:
2672 if "training" in kwargs:
2673 # `training` was passed to this `Layer` but is not needed for
2674 # `Layer.call`. It will set the default mode for inner
2675 # `Layer.call`s.
2676 training_mode = kwargs.pop("training")
2677 else:
2678 # Grab the current `training` mode from any outer `Layer.call`.
2679 training_mode = call_context.training
2681 return args, kwargs, training_mode
2683 def _autographed_call(self):
2684 # Wrapping `call` function in autograph to allow for dynamic control
2685 # flow and control dependencies in call. We are limiting this to
2686 # subclassed layers as autograph is strictly needed only for
2687 # subclassed layers and models.
2688 # tf_convert will respect the value of autograph setting in the
2689 # enclosing tf.function, if any.
2690 if base_layer_utils.is_subclassed(
2691 self
2692 ) and not base_layer_utils.from_saved_model(self):
2693 return tf.__internal__.autograph.tf_convert(
2694 self.call, tf.__internal__.autograph.control_status_ctx()
2695 )
2696 else:
2697 return self.call
2699 @property
2700 def _inbound_nodes(self):
2701 return self._inbound_nodes_value
2703 @_inbound_nodes.setter
2704 @tf.__internal__.tracking.no_automatic_dependency_tracking
2705 def _inbound_nodes(self, value):
2706 self._inbound_nodes_value = value
2708 @property
2709 def _outbound_nodes(self):
2710 return self._outbound_nodes_value
2712 @_outbound_nodes.setter
2713 @tf.__internal__.tracking.no_automatic_dependency_tracking
2714 def _outbound_nodes(self, value):
2715 self._outbound_nodes_value = value
2717 def _set_dtype_policy(self, dtype):
2718 """Sets self._dtype_policy."""
2719 self._dtype_policy = policy.get_policy(dtype)
2721 # Performance optimization: cache the compute dtype as a Dtype object or
2722 # None, so that str to Dtype conversion doesn't happen in
2723 # Layer.__call__.
2724 # TODO(b/157486353): Investigate returning DTypes in Policy.
2725 if self._dtype_policy.compute_dtype:
2726 self._compute_dtype_object = tf.as_dtype(
2727 self._dtype_policy.compute_dtype
2728 )
2729 else:
2730 self._compute_dtype_object = None
2732 @property
2733 def _compute_dtype(self):
2734 """Deprecated alias of `compute_dtype`."""
2735 return self._dtype_policy.compute_dtype
2737 def _maybe_cast_inputs(self, inputs, input_list=None):
2738 """Maybe casts the inputs to the compute dtype.
2740 If self._compute_dtype is floating-point, and self_autocast is True,
2741 floating-point inputs are casted to self._compute_dtype.
2743 Args:
2744 inputs: Input tensor, or structure of input tensors.
2745 input_list: Flat list of input tensors.
2747 Returns:
2748 `inputs`, but tensors may have been casted to self._compute_dtype
2749 """
2750 if not input_list:
2751 input_list = tf.nest.flatten(inputs)
2753 compute_dtype_object = self._compute_dtype_object
2754 should_autocast = (
2755 self._autocast
2756 and compute_dtype_object
2757 and compute_dtype_object.is_floating
2758 )
2760 if should_autocast and any(
2761 map(self._should_cast_single_input, input_list)
2762 ):
2763 # Only perform expensive `nest` operation when needed.
2764 return tf.nest.map_structure(self._cast_single_input, inputs)
2765 else:
2766 return inputs
2768 def _should_cast_single_input(self, x):
2769 if isinstance(x, _AUTOCAST_TYPES):
2770 return (
2771 self._compute_dtype_object
2772 and x.dtype != self._compute_dtype_object
2773 and x.dtype.is_floating
2774 )
2775 return False
2777 def _cast_single_input(self, x):
2778 """Cast a single Tensor or TensorSpec to the compute dtype."""
2779 if self._should_cast_single_input(x):
2780 return tf.cast(x, self._compute_dtype_object)
2781 else:
2782 return x
2784 # _dtype used to be an attribute set in the constructor. We still expose it
2785 # because some clients still use it.
2786 # TODO(reedwm): Deprecate, then remove the _dtype property.
2787 @property
2788 def _dtype(self):
2789 # This is equivalent to returning self.dtype . We do not return
2790 # self.dtype as it would cause infinite recursion in a few subclasses,
2791 # which override "dtype" to return self._dtype.
2792 return self._dtype_policy.variable_dtype
2794 @_dtype.setter
2795 def _dtype(self, value):
2796 value = tf.as_dtype(value).name
2797 self._set_dtype_policy(policy.Policy(value))
2799 def _name_scope(self):
2800 if not tf.__internal__.tf2.enabled():
2801 return self.name
2802 name_scope = self.name
2803 current_name_scope = tf.__internal__.get_name_scope()
2804 if current_name_scope:
2805 name_scope = current_name_scope + "/" + name_scope
2806 if name_scope:
2807 # Note that the trailing `/` prevents autogenerated
2808 # numerical suffixes to get appended. It will also fully reset
2809 # nested name scope (i.e. the outer name scope has no effect).
2810 name_scope += "/"
2811 return name_scope
2813 def _init_set_name(self, name, zero_based=True):
2814 if name is None:
2815 self._name = backend.unique_object_name(
2816 generic_utils.to_snake_case(self.__class__.__name__),
2817 zero_based=zero_based,
2818 )
2819 elif isinstance(name, str):
2820 backend.observe_object_name(name)
2821 self._name = name
2822 else:
2823 raise TypeError(
2824 f"Expected `name` argument to be a string, but got: {name}"
2825 )
2827 def _get_existing_metric(self, name=None):
2828 match = [m for m in self._metrics if m.name == name]
2829 if not match:
2830 return
2831 if len(match) > 1:
2832 raise ValueError(
2833 "Please provide different names for the metrics you have "
2834 'added. We found {} metrics with the name: "{}"'.format(
2835 len(match), name
2836 )
2837 )
2838 return match[0]
2840 def _handle_weight_regularization(self, name, variable, regularizer):
2841 """Create lambdas which compute regularization losses."""
2843 def _loss_for_variable(v):
2844 """Creates a regularization loss `Tensor` for variable `v`."""
2845 with backend.name_scope(name + "/Regularizer"):
2846 regularization = regularizer(v)
2847 return regularization
2849 if base_layer_utils.is_split_variable(variable):
2850 for v in variable:
2851 self.add_loss(functools.partial(_loss_for_variable, v))
2852 elif isinstance(variable, lazy_variable.LazyInitVariable):
2853 self._captured_weight_regularizer.append(
2854 (name, variable, regularizer)
2855 )
2856 else:
2857 self.add_loss(functools.partial(_loss_for_variable, variable))
2859 def _handle_activity_regularization(self, inputs, outputs):
2860 # Apply activity regularization.
2861 # Note that it should be applied every time the layer creates a new
2862 # output, since it is output-specific.
2863 if self._activity_regularizer:
2864 output_list = tf.nest.flatten(outputs)
2865 with backend.name_scope("ActivityRegularizer"):
2866 for output in output_list:
2867 activity_loss = tf.convert_to_tensor(
2868 self._activity_regularizer(output)
2869 )
2870 batch_size = tf.cast(
2871 tf.shape(output)[0], activity_loss.dtype
2872 )
2873 # Make activity regularization strength batch-agnostic.
2874 mean_activity_loss = activity_loss / batch_size
2875 self.add_loss(mean_activity_loss)
2877 def _set_mask_metadata(self, inputs, outputs, previous_mask, build_graph):
2878 # Many `Layer`s don't need to call `compute_mask`.
2879 # This method is optimized to do as little work as needed for the common
2880 # case.
2881 if not self._supports_masking:
2882 return
2884 flat_outputs = tf.nest.flatten(outputs)
2886 mask_already_computed = getattr(
2887 self, "_compute_output_and_mask_jointly", False
2888 ) or all(
2889 getattr(x, "_keras_mask", None) is not None for x in flat_outputs
2890 )
2891 if mask_already_computed:
2892 if build_graph:
2893 self._set_mask_keras_history_checked(flat_outputs)
2894 return
2896 output_masks = self.compute_mask(inputs, previous_mask)
2897 if output_masks is None:
2898 return
2900 flat_masks = tf.nest.flatten(output_masks)
2901 for tensor, mask in zip(flat_outputs, flat_masks):
2902 try:
2903 tensor._keras_mask = mask
2904 except AttributeError:
2905 # C Type such as np.ndarray.
2906 pass
2908 if build_graph:
2909 self._set_mask_keras_history_checked(flat_outputs)
2911 def _set_mask_keras_history_checked(self, flat_outputs):
2912 for output in flat_outputs:
2913 if getattr(output, "_keras_mask", None) is not None:
2914 # Do not track masks for `TensorFlowOpLayer` construction.
2915 output._keras_mask._keras_history_checked = True
2917 def _get_input_masks(self, inputs, input_list, args, kwargs):
2918 if not self._supports_masking and not self._expects_mask_arg:
2919 # Input masks only need to be retrieved if they are needed for
2920 # `call` or `compute_mask`.
2921 input_masks = None
2922 implicit_mask = False
2923 elif self._call_spec.arg_was_passed("mask", args, kwargs):
2924 input_masks = self._call_spec.get_arg_value("mask", args, kwargs)
2925 implicit_mask = False
2926 else:
2927 input_masks = [getattr(t, "_keras_mask", None) for t in input_list]
2928 if all(mask is None for mask in input_masks):
2929 input_masks = None
2930 implicit_mask = False
2931 else:
2932 # Only do expensive `nest` op when masking is actually being
2933 # used.
2934 input_masks = tf.nest.pack_sequence_as(inputs, input_masks)
2935 implicit_mask = True
2936 return input_masks, implicit_mask
2938 def _set_connectivity_metadata(self, args, kwargs, outputs):
2939 # If the layer returns tensors from its inputs unmodified,
2940 # we copy them to avoid loss of KerasHistory metadata.
2941 flat_outputs = tf.nest.flatten(outputs)
2942 flat_inputs = tf.nest.flatten((args, kwargs))
2943 input_ids_set = {id(i) for i in flat_inputs}
2944 outputs_copy = []
2945 for x in flat_outputs:
2946 if id(x) in input_ids_set:
2947 with backend.name_scope(self.name):
2948 x = tf.identity(x)
2949 outputs_copy.append(x)
2950 outputs = tf.nest.pack_sequence_as(outputs, outputs_copy)
2952 # Create node, Node wires itself to inbound and outbound layers. The
2953 # Node constructor actually updates this layer's self._inbound_nodes,
2954 # sets _keras_history on the outputs, and adds itself to the
2955 # `_outbound_nodes` of the layers that produced the inputs to this layer
2956 # call.
2957 node_module.Node(
2958 self, call_args=args, call_kwargs=kwargs, outputs=outputs
2959 )
2960 return outputs
2962 def _get_node_attribute_at_index(self, node_index, attr, attr_name):
2963 """Private utility to retrieves an attribute (e.g. inputs) from a node.
2965 This is used to implement the methods:
2966 - get_input_shape_at
2967 - get_output_shape_at
2968 - get_input_at
2969 etc...
2971 Args:
2972 node_index: Integer index of the node from which
2973 to retrieve the attribute.
2974 attr: Exact node attribute name.
2975 attr_name: Human-readable attribute name, for error messages.
2977 Returns:
2978 The layer's attribute `attr` at the node of index `node_index`.
2980 Raises:
2981 RuntimeError: If the layer has no inbound nodes, or if called in
2982 Eager mode.
2983 ValueError: If the index provided does not match any node.
2984 """
2985 if not self._inbound_nodes:
2986 raise RuntimeError(
2987 f"The layer {self.name} has never been called "
2988 f"and thus has no defined {attr_name}."
2989 )
2990 if not len(self._inbound_nodes) > node_index:
2991 raise ValueError(
2992 f"Asked to get {attr_name} at node "
2993 f"{node_index}, but the layer has only "
2994 f"{len(self._inbound_nodes)} inbound nodes."
2995 )
2996 values = getattr(self._inbound_nodes[node_index], attr)
2997 if isinstance(values, list) and len(values) == 1:
2998 return values[0]
2999 else:
3000 return values
3002 def _maybe_build(self, inputs):
3003 # Check input assumptions set before layer building, e.g. input rank.
3004 if not self.built:
3005 input_spec.assert_input_compatibility(
3006 self.input_spec, inputs, self.name
3007 )
3008 input_list = tf.nest.flatten(inputs)
3009 if input_list and self._dtype_policy.compute_dtype is None:
3010 try:
3011 dtype = input_list[0].dtype.base_dtype.name
3012 except AttributeError:
3013 pass
3014 else:
3015 self._set_dtype_policy(policy.Policy(dtype))
3016 input_shapes = None
3017 # Converts Tensors / CompositeTensors to TensorShapes.
3018 if any(hasattr(x, "shape") for x in input_list):
3019 input_shapes = tf_utils.get_shapes(inputs)
3020 else:
3021 # Converts input shape to TensorShapes.
3022 try:
3023 input_shapes = tf_utils.convert_shapes(
3024 inputs, to_tuples=False
3025 )
3026 except ValueError:
3027 pass
3028 # Only call `build` if the user has manually overridden the build
3029 # method.
3030 if not hasattr(self.build, "_is_default"):
3031 # Any setup work performed only once should happen in an
3032 # `init_scope` to avoid creating symbolic Tensors that will
3033 # later pollute any eager operations.
3034 with tf_utils.maybe_init_scope(self):
3035 self.build(input_shapes)
3036 # We must set also ensure that the layer is marked as built, and the
3037 # build shape is stored since user defined build functions may not
3038 # be calling `super.build()`
3039 Layer.build(self, input_shapes)
3041 # Optionally load weight values specified at layer instantiation.
3042 if self._initial_weights is not None:
3043 with tf.init_scope():
3044 # Using `init_scope` since we want variable assignment in
3045 # `set_weights` to be treated like variable initialization.
3046 self.set_weights(self._initial_weights)
3047 self._initial_weights = None
3049 def _get_trainable_state(self):
3050 """Get the `trainable` state of each sublayer.
3052 Returns:
3053 A dict mapping all sublayers to their `trainable` value.
3054 """
3055 trainable_state = weakref.WeakKeyDictionary()
3056 for layer in self._flatten_layers():
3057 trainable_state[layer] = layer.trainable
3058 return trainable_state
3060 def _set_trainable_state(self, trainable_state):
3061 """Set `trainable` state for each sublayer."""
3062 for layer in self._flatten_layers():
3063 if layer in trainable_state:
3064 layer.trainable = trainable_state[layer]
3066 @property
3067 def _obj_reference_counts(self):
3068 """A dict counting the number of attributes referencing an object."""
3069 self._maybe_create_attribute(
3070 "_obj_reference_counts_dict",
3071 object_identity.ObjectIdentityDictionary(),
3072 )
3073 return self._obj_reference_counts_dict
3075 @tf.__internal__.tracking.no_automatic_dependency_tracking
3076 def _maybe_create_attribute(self, name, default_value):
3077 """Create attribute (with the default value) if it hasn't been created.
3079 This is useful for fields that is used for tracking purpose,
3080 _trainable_weights, or _layers. Note that user could create a layer
3081 subclass and assign an internal field before invoking the
3082 Layer.__init__(), the __setattr__() need to create the tracking fields
3083 and __init__() need to not override them.
3085 Args:
3086 name: String, the name of the attribute.
3087 default_value: Object, the default value of the attribute.
3088 """
3089 if not hasattr(self, name):
3090 self.__setattr__(name, default_value)
3092 def __delattr__(self, name):
3093 # For any super.__delattr__() call, we will directly use the
3094 # implementation in Trackable and skip the behavior in AutoTrackable.
3095 # The Layer was originally use Trackable as base class, the change of
3096 # using Module as base class forced us to have AutoTrackable in the
3097 # class hierarchy.
3098 #
3099 # TODO(b/180760306) Keeping the status quo of skipping _delattr__ and
3100 # __setattr__ in AutoTrackable may be unsustainable.
3101 existing_value = getattr(self, name, None)
3103 # If this value is replacing an existing object assigned to an
3104 # attribute, we should clean it out to avoid leaking memory. First we
3105 # check if there are other attributes referencing it.
3106 reference_counts = self._obj_reference_counts
3107 if existing_value not in reference_counts:
3108 super(tf.__internal__.tracking.AutoTrackable, self).__delattr__(
3109 name
3110 )
3111 return
3113 reference_count = reference_counts[existing_value]
3114 if reference_count > 1:
3115 # There are other remaining references. We can't remove this object
3116 # from _layers etc.
3117 reference_counts[existing_value] = reference_count - 1
3118 super(tf.__internal__.tracking.AutoTrackable, self).__delattr__(
3119 name
3120 )
3121 return
3122 else:
3123 # This is the last remaining reference.
3124 del reference_counts[existing_value]
3126 super(tf.__internal__.tracking.AutoTrackable, self).__delattr__(name)
3128 if isinstance(existing_value, Layer) or base_layer_utils.has_weights(
3129 existing_value
3130 ):
3131 super(tf.__internal__.tracking.AutoTrackable, self).__setattr__(
3132 "_self_tracked_trackables",
3133 [
3134 l
3135 for l in self._self_tracked_trackables
3136 if l is not existing_value
3137 ],
3138 )
3139 if isinstance(existing_value, tf.Variable):
3140 super(tf.__internal__.tracking.AutoTrackable, self).__setattr__(
3141 "_trainable_weights",
3142 [w for w in self._trainable_weights if w is not existing_value],
3143 )
3144 super(tf.__internal__.tracking.AutoTrackable, self).__setattr__(
3145 "_non_trainable_weights",
3146 [
3147 w
3148 for w in self._non_trainable_weights
3149 if w is not existing_value
3150 ],
3151 )
3153 def __setattr__(self, name, value):
3154 if (
3155 name == "_self_setattr_tracking"
3156 or not getattr(self, "_self_setattr_tracking", True)
3157 # Exclude @property.setters from tracking
3158 or hasattr(self.__class__, name)
3159 ):
3160 try:
3161 super(tf.__internal__.tracking.AutoTrackable, self).__setattr__(
3162 name, value
3163 )
3164 except AttributeError:
3165 raise AttributeError(
3166 (
3167 'Can\'t set the attribute "{}", likely because it '
3168 "conflicts with an existing read-only @property of the "
3169 "object. Please choose a different name."
3170 ).format(name)
3171 )
3172 return
3174 # Wraps data structures in `Trackable`, unwraps `NoDependency` objects.
3175 value = tf.__internal__.tracking.sticky_attribute_assignment(
3176 trackable=self, value=value, name=name
3177 )
3179 reference_counts = self._obj_reference_counts
3180 reference_counts[value] = reference_counts.get(value, 0) + 1
3182 # When replacing an existing tf.Variable with a new one, we want to
3183 # check its existing position in the
3184 # self._trainable/non_trainable_variable, so that we can put it back to
3185 # the original position.
3186 if isinstance(value, tf.Variable) and isinstance(
3187 getattr(self, name, None), tf.Variable
3188 ):
3189 existing_variable = getattr(self, name)
3191 def _get_variable_from_list(var_list, var):
3192 # helper function to get the tf.variable from the list
3193 # the default list.index() use == for comparison, which will
3194 # cause issue for eager tensor.
3195 for i in range(len(var_list)):
3196 if var_list[i] is var:
3197 return i
3198 return None
3200 if existing_variable.trainable:
3201 self._maybe_create_attribute("_trainable_weights", [])
3202 position = _get_variable_from_list(
3203 self._trainable_weights, existing_variable
3204 )
3205 else:
3206 self._maybe_create_attribute("_non_trainable_variable", [])
3207 position = _get_variable_from_list(
3208 self._non_trainable_variable, existing_variable
3209 )
3210 else:
3211 position = None
3213 # Clean out the old attribute, which clears _layers and
3214 # _trainable_weights if necessary.
3215 try:
3216 self.__delattr__(name)
3217 except AttributeError:
3218 pass
3220 # Keep track of metric instance created in subclassed layer.
3221 for val in tf.nest.flatten(value):
3222 if isinstance(val, metrics_mod.Metric) and hasattr(
3223 self, "_metrics"
3224 ):
3225 self._metrics.append(val)
3227 # Append value to self._self_tracked_trackables if relevant
3228 if getattr(self, "_auto_track_sub_layers", True) and (
3229 isinstance(value, tf.Module) or base_layer_utils.has_weights(value)
3230 ):
3231 self._maybe_create_attribute("_self_tracked_trackables", [])
3232 # We need to check object identity to avoid de-duplicating empty
3233 # container types which compare equal.
3234 if not any(
3235 (layer is value for layer in self._self_tracked_trackables)
3236 ):
3237 self._self_tracked_trackables.append(value)
3238 if hasattr(value, "_use_resource_variables"):
3239 # Legacy layers (V1 tf.layers) must always use
3240 # resource variables.
3241 value._use_resource_variables = True
3243 # Append value to list of trainable / non-trainable weights if relevant
3244 # TODO(b/125122625): This won't pick up on any variables added to a
3245 # list/dict after creation.
3246 self._track_variables(value, position=position)
3248 # TODO(b/180760306) Skip the auto trackable from tf.Module to keep
3249 # status quo. See the comment at __delattr__.
3250 super(tf.__internal__.tracking.AutoTrackable, self).__setattr__(
3251 name, value
3252 )
3254 def _update_trackables(self):
3255 """Track variables added to lists/dicts after creation"""
3256 for trackable_obj in self._self_tracked_trackables:
3257 if isinstance(
3258 trackable_obj, tf.__internal__.tracking.TrackableDataStructure
3259 ):
3260 self._track_variables(trackable_obj)
3262 def _track_variables(self, value, position=None):
3263 """Tracks `Variable`s including `Variable`s in `CompositeTensor`s."""
3264 for val in tf.nest.flatten(value):
3265 if isinstance(val, tf.Variable):
3266 self._track_variable(val, position=position)
3267 elif tf_utils.is_extension_type(val):
3268 # Manually expand extension types to track resource variables.
3269 nested_vals = tf_utils.type_spec_from_value(val)._to_components(
3270 val
3271 )
3272 self._track_variables(nested_vals, position=position)
3274 def _track_variable(self, val, position=None):
3275 """Tracks the given `tf.Variable`."""
3276 # Users may add extra weights/variables simply by assigning them to
3277 # attributes (invalid for graph networks)
3278 self._maybe_create_attribute("_trainable_weights", [])
3279 self._maybe_create_attribute("_non_trainable_weights", [])
3280 if val.trainable:
3281 if any(val is w for w in self._trainable_weights):
3282 return
3283 if position is None:
3284 self._trainable_weights.append(val)
3285 else:
3286 self._trainable_weights.insert(position, val)
3287 else:
3288 if any(val is w for w in self._non_trainable_weights):
3289 return
3290 if position is None:
3291 self._non_trainable_weights.append(val)
3292 else:
3293 self._non_trainable_weights.insert(position, val)
3294 backend.track_variable(val)
3296 def _gather_children_attribute(self, attribute):
3297 assert attribute in {
3298 "variables",
3299 "trainable_variables",
3300 "non_trainable_variables",
3301 }
3302 if hasattr(self, "_self_tracked_trackables"):
3303 nested_layers = self._flatten_modules(
3304 include_self=False, recursive=False
3305 )
3306 return list(
3307 itertools.chain.from_iterable(
3308 getattr(layer, attribute) for layer in nested_layers
3309 )
3310 )
3311 return []
3313 def _flatten_layers(self, recursive=True, include_self=True):
3314 for m in self._flatten_modules(
3315 recursive=recursive, include_self=include_self
3316 ):
3317 if isinstance(m, Layer):
3318 yield m
3320 def _flatten_modules(self, recursive=True, include_self=True):
3321 """Flattens `tf.Module` instances (excluding `Metrics`).
3323 Args:
3324 recursive: Whether to recursively flatten through submodules.
3325 include_self: Whether to include this `Layer` instance.
3327 Yields:
3328 `tf.Module` instance tracked by this `Layer`.
3329 """
3330 if include_self:
3331 yield self
3333 # Only instantiate set and deque if needed.
3334 trackables = getattr(self, "_self_tracked_trackables", None)
3335 if trackables:
3336 seen_object_ids = set()
3337 deque = collections.deque(trackables)
3338 while deque:
3339 trackable_obj = deque.popleft()
3340 trackable_id = id(trackable_obj)
3341 if trackable_id in seen_object_ids:
3342 continue
3343 seen_object_ids.add(trackable_id)
3345 # Metrics are not considered part of the Layer's topology.
3346 if isinstance(trackable_obj, tf.Module) and not isinstance(
3347 trackable_obj, metrics_mod.Metric
3348 ):
3349 yield trackable_obj
3350 # Introspect recursively through sublayers.
3351 if recursive:
3352 subtrackables = getattr(
3353 trackable_obj, "_self_tracked_trackables", None
3354 )
3355 if subtrackables:
3356 deque.extendleft(reversed(subtrackables))
3357 elif isinstance(
3358 trackable_obj,
3359 tf.__internal__.tracking.TrackableDataStructure,
3360 ):
3361 # Data structures are introspected even with
3362 # `recursive=False`.
3363 tracked_values = trackable_obj._values
3364 if tracked_values:
3365 deque.extendleft(reversed(tracked_values))
3367 # This is a hack so that the is_layer (within
3368 # training/trackable/layer_utils.py) check doesn't get the weights attr.
3369 # TODO(b/110718070): Remove when fixed.
3370 def _is_layer(self):
3371 return True
3373 def _init_call_fn_args(self, expects_training_arg=None):
3374 self._call_spec = layer_utils.CallFunctionSpec(
3375 tf_inspect.getfullargspec(self.call)
3376 )
3377 if expects_training_arg is not None:
3378 self._call_spec.expects_training_arg = expects_training_arg
3380 @property
3381 def _expects_training_arg(self):
3382 """Whether the call function uses 'training' as a parameter."""
3383 return self._call_spec.expects_training_arg
3385 @property
3386 def _expects_mask_arg(self):
3387 return self._call_spec.expects_mask_arg
3389 @property
3390 def _eager_losses(self):
3391 # A list of loss values containing activity regularizers and losses
3392 # manually added through `add_loss` during eager execution. It is
3393 # cleared after every batch. Because we plan on eventually allowing a
3394 # same model instance to be trained in eager mode or graph mode
3395 # alternatively, we need to keep track of eager losses and symbolic
3396 # losses via separate attributes.
3397 if not hasattr(self._thread_local, "_eager_losses"):
3398 self._thread_local._eager_losses = []
3399 return self._thread_local._eager_losses
3401 @_eager_losses.setter
3402 def _eager_losses(self, losses):
3403 self._thread_local._eager_losses = losses
3405 def _dedup_weights(self, weights):
3406 """Dedupe weights while maintaining order as much as possible."""
3407 output, seen_ids = [], set()
3408 for w in weights:
3409 if id(w) not in seen_ids:
3410 output.append(w)
3411 # Track the Variable's identity to avoid __eq__ issues.
3412 seen_ids.add(id(w))
3413 return output
3415 # SavedModel properties. Please see keras/saving/saved_model for details.
3417 @tf.__internal__.tracking.no_automatic_dependency_tracking
3418 def _set_save_spec(self, inputs, args=None, kwargs=None):
3419 """Defines the save spec so that serialization can trace layer calls.
3421 The TensorSpecs of the call function `inputs`, `args`, and `kwargs` are
3422 saved into a tuple of `([inputs] + args, kwargs)`.
3424 Args:
3425 inputs: possibly nested inputs passed into the call function.
3426 args: a list of positional arguments passed into call.
3427 kwargs: a dictionary of keyword arguments passed into call.
3428 """
3429 if self._saved_model_inputs_spec is not None:
3430 return # Already set.
3432 inputs_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, inputs)
3433 args_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, args or [])
3434 kwargs_spec = {}
3435 # Filter out non-tensor arguments from kwargs.
3436 for key, kwarg in kwargs.items():
3437 flat_kwarg = tf.nest.flatten(kwarg)
3438 flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg]
3439 if any(s is None for s in flat_specs):
3440 continue
3441 kwargs_spec[key] = tf.nest.pack_sequence_as(kwarg, flat_specs)
3443 self._saved_model_inputs_spec = inputs_spec
3444 self._saved_model_arg_spec = (
3445 [inputs_spec] + list(args_spec),
3446 kwargs_spec,
3447 )
3449 def _get_save_spec(self, dynamic_batch=True, inputs_only=True):
3450 if self._saved_model_inputs_spec is None:
3451 return None
3453 spec = tf.nest.map_structure(
3454 lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
3455 self._saved_model_arg_spec,
3456 )
3457 return spec[0][0] if inputs_only else spec
3459 @property
3460 def _trackable_saved_model_saver(self):
3461 return layer_serialization.LayerSavedModelSaver(self)
3463 @property
3464 def _object_identifier(self):
3465 return self._trackable_saved_model_saver.object_identifier
3467 @property
3468 def _tracking_metadata(self):
3469 """Info about this layer to be saved into the SavedModel."""
3470 return self._trackable_saved_model_saver.tracking_metadata
3472 def _trackable_children(self, save_type="checkpoint", **kwargs):
3473 if save_type == "savedmodel":
3474 cache = kwargs["cache"]
3475 # TODO(b/213628533): This must be called before super() to ensure
3476 # that any input shape changes are applied before getting the config
3477 # of the model.
3478 children = self._trackable_saved_model_saver.trackable_children(
3479 cache
3480 )
3481 else:
3482 children = {}
3483 children.update(super()._trackable_children(save_type, **kwargs))
3484 return children
3486 @property
3487 def _use_input_spec_as_call_signature(self):
3488 # Whether input spec can be used as the call signature when tracing the
3489 # Layer for SavedModel. By default, this is set to `True` for layers
3490 # exported from the Keras library, because the layers more rigidly
3491 # define the `input_specs` property (many custom layers only set the
3492 # `ndims`)
3493 return (
3494 get_canonical_name_for_symbol(type(self), api_name="keras")
3495 is not None
3496 )
3498 def __getstate__(self):
3499 # Override to support `copy.deepcopy` and pickling.
3500 # Thread-local objects cannot be copied in Python 3, so pop these.
3501 # Thread-local objects are used to cache losses in MirroredStrategy, and
3502 # so shouldn't be copied.
3503 state = self.__dict__.copy()
3504 state.pop("_thread_local", None)
3505 state.pop("_metrics_lock", None)
3506 return state
3508 def __setstate__(self, state):
3509 state["_thread_local"] = threading.local()
3510 state["_metrics_lock"] = threading.Lock()
3511 # Bypass Trackable logic as `__dict__` already contains this info.
3512 object.__setattr__(self, "__dict__", state)
3514 def save_own_variables(self, store):
3515 """Saves the state of the layer.
3517 You can override this method to take full control of how the state of
3518 the layer is saved upon calling `model.save()`.
3520 Args:
3521 store: Dict where the state of the model will be saved.
3522 """
3523 all_vars = self._trainable_weights + self._non_trainable_weights
3524 for i, v in enumerate(all_vars):
3525 store[f"{i}"] = v.numpy()
3527 def load_own_variables(self, store):
3528 """Loads the state of the layer.
3530 You can override this method to take full control of how the state of
3531 the layer is loaded upon calling `keras.models.load_model()`.
3533 Args:
3534 store: Dict from which the state of the model will be loaded.
3535 """
3536 self._update_trackables()
3537 all_vars = self._trainable_weights + self._non_trainable_weights
3538 if len(store.keys()) != len(all_vars):
3539 raise ValueError(
3540 f"Layer '{self.name}' expected {len(all_vars)} variables, "
3541 "but received "
3542 f"{len(store.keys())} variables during loading. "
3543 f"Expected: {[v.name for v in all_vars]}"
3544 )
3545 for i, v in enumerate(all_vars):
3546 # TODO(rchao): check shapes and raise errors.
3547 v.assign(store[f"{i}"])
3550class TensorFlowOpLayer(Layer):
3551 """Wraps a TensorFlow Operation in a Layer.
3553 This class is used internally by the Functional API. When a user
3554 uses a raw TensorFlow Operation on symbolic tensors originating
3555 from an `Input` Layer, the resultant operation will be wrapped
3556 with this Layer object in order to make the operation compatible
3557 with the Keras API.
3559 This Layer will create a new, identical operation (except for inputs
3560 and outputs) every time it is called. If `run_eagerly` is `True`,
3561 the op creation and calculation will happen inside an Eager function.
3563 Instances of this Layer are created when `autolambda` is called, which
3564 is whenever a Layer's `__call__` encounters symbolic inputs that do
3565 not have Keras metadata, or when a Network's `__init__` encounters
3566 outputs that do not have Keras metadata.
3568 Attributes:
3569 node_def: String, the serialized NodeDef of the Op this layer will wrap.
3570 name: String, the name of the Layer.
3571 constants: Dict of NumPy arrays, the values of any Tensors needed for this
3572 Operation that do not originate from a Keras `Input` Layer. Since all
3573 placeholders must come from Keras `Input` Layers, these Tensors must be
3574 treated as constant in the Functional API.
3575 trainable: Bool, whether this Layer is trainable. Currently Variables are
3576 not supported, and so this parameter has no effect.
3577 dtype: The default dtype of this Layer. Inherited from `Layer` and has no
3578 effect on this class, however is used in `get_config`.
3579 """
3581 @tf.__internal__.tracking.no_automatic_dependency_tracking
3582 def __init__(
3583 self, node_def, name, constants=None, trainable=True, dtype=None
3584 ):
3585 # Pass autocast=False, as if inputs are cast, input types might not
3586 # match Operation type.
3587 super(TensorFlowOpLayer, self).__init__(
3588 name=_TF_OP_LAYER_NAME_PREFIX + name,
3589 trainable=trainable,
3590 dtype=dtype,
3591 autocast=False,
3592 )
3593 if isinstance(node_def, dict):
3594 self.node_def = json_format.ParseDict(
3595 node_def, tf.compat.v1.NodeDef()
3596 )
3597 else:
3598 if not isinstance(node_def, bytes):
3599 node_def = node_def.encode("utf-8")
3600 self.node_def = tf.compat.v1.NodeDef.FromString(node_def)
3601 # JSON serialization stringifies keys which are integer input indices.
3602 self.constants = (
3603 {int(index): constant for index, constant in constants.items()}
3604 if constants is not None
3605 else {}
3606 )
3607 # Layer uses original op unless it is called on new inputs.
3608 # This means `built` is not set in `__call__`.
3609 self.built = True
3611 # Do not individually trace TensorflowOpLayers in the SavedModel.
3612 self._must_restore_from_config = True
3614 def call(self, inputs):
3615 if tf.executing_eagerly():
3616 return self._defun_call(inputs)
3617 return self._make_op(inputs)
3619 def _make_node_def(self, graph):
3620 node_def = tf.compat.v1.NodeDef()
3621 node_def.CopyFrom(self.node_def)
3622 # Used in TPUReplicateContext to indicate whether this node has been
3623 # cloned and to not add TPU attributes.
3624 node_def.attr["_cloned"].b = True
3625 node_def.name = graph.unique_name(node_def.name)
3626 return node_def
3628 def _make_op(self, inputs):
3629 inputs = tf.nest.flatten(inputs)
3630 graph = inputs[0].graph
3631 node_def = self._make_node_def(graph)
3632 with graph.as_default():
3633 for index, constant in self.constants.items():
3634 # Recreate constant in graph to add distribution context.
3635 value = tf.get_static_value(constant)
3636 if value is not None:
3637 if isinstance(value, dict):
3638 value = serialization_lib.deserialize_keras_object(
3639 value
3640 )
3641 constant = tf.constant(value, name=node_def.input[index])
3642 inputs.insert(index, constant)
3643 # TODO(b/183990973): We should drop or consolidate these private api
3644 # calls for adding an op to the graph and recording its gradient.
3645 c_op = tf.__internal__.create_c_op(
3646 graph, node_def, inputs, control_inputs=[]
3647 )
3648 op = graph._create_op_from_tf_operation(c_op)
3649 op._control_flow_post_processing()
3651 # Record the gradient because custom-made ops don't go through the
3652 # code-gen'd eager call path
3653 op_type = tf.compat.as_str(op.op_def.name)
3654 attr_names = [
3655 tf.compat.as_str(attr.name) for attr in op.op_def.attr
3656 ]
3657 attrs = []
3658 for attr_name in attr_names:
3659 attrs.append(attr_name)
3660 attrs.append(op.get_attr(attr_name))
3661 attrs = tuple(attrs)
3662 tf.__internal__.record_gradient(
3663 op_type, op.inputs, attrs, op.outputs
3664 )
3666 if len(op.outputs) == 1:
3667 return op.outputs[0]
3668 return op.outputs
3670 @tf.function
3671 def _defun_call(self, inputs):
3672 """Wraps op creation method in an Eager function for `run_eagerly`."""
3673 return self._make_op(inputs)
3675 def get_config(self):
3676 config = super(TensorFlowOpLayer, self).get_config()
3677 config.update(
3678 {
3679 # `__init__` prefixes the name. Revert to the constructor
3680 # argument.
3681 "name": config["name"][len(_TF_OP_LAYER_NAME_PREFIX) :],
3682 "node_def": json_format.MessageToDict(self.node_def),
3683 "constants": {
3684 i: backend.get_value(c) for i, c in self.constants.items()
3685 },
3686 }
3687 )
3688 return config
3691class AddLoss(Layer):
3692 """Adds its inputs as a loss.
3694 Attributes:
3695 unconditional: Whether or not the loss should be conditioned on the
3696 inputs.
3697 """
3699 def __init__(self, unconditional, **kwargs):
3700 # Pass autocast=False, as there is no reason to cast loss to a different
3701 # dtype.
3702 kwargs["autocast"] = False
3703 super(AddLoss, self).__init__(**kwargs)
3704 self.unconditional = unconditional
3706 def call(self, inputs):
3707 self.add_loss(inputs, inputs=(not self.unconditional))
3708 return inputs
3710 def get_config(self):
3711 config = super(AddLoss, self).get_config()
3712 config.update({"unconditional": self.unconditional})
3713 return config
3716class AddMetric(Layer):
3717 """Adds its inputs as a metric.
3719 Attributes:
3720 aggregation: 'mean' or None. How the inputs should be aggregated.
3721 metric_name: The name to use for this metric.
3722 """
3724 def __init__(self, aggregation=None, metric_name=None, **kwargs):
3725 super(AddMetric, self).__init__(**kwargs)
3726 self.aggregation = aggregation
3727 self.metric_name = metric_name
3729 def call(self, inputs):
3730 self.add_metric(
3731 inputs, aggregation=self.aggregation, name=self.metric_name
3732 )
3733 return inputs
3735 def get_config(self):
3736 config = super(AddMetric, self).get_config()
3737 config.update(
3738 {"aggregation": self.aggregation, "metric_name": self.metric_name}
3739 )
3740 return config
3743def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list):
3744 """Check the arguments to see if we are constructing a functional model."""
3745 # We are constructing a functional model if any of the inputs
3746 # are KerasTensors
3747 return any(
3748 isinstance(tensor, keras_tensor.KerasTensor)
3749 for tensor in tf.nest.flatten([inputs, args, kwargs])
3750 )
3753def _convert_numpy_or_python_types(x):
3754 if isinstance(x, (tf.Tensor, np.ndarray, float, int)):
3755 return tf.convert_to_tensor(x)
3756 return x
3759@keras_export("keras.__internal__.apply_name_scope_on_model_declaration", v1=[])
3760def _apply_name_scope_on_model_declaration(enable):
3761 """Apply `with tf.name_scope(...)` on model declaration.
3763 ```python
3764 tf.keras.__internal__.apply_name_scope_on_model_declaration(True)
3766 inputs = input_layer.Input((3,))
3767 with tf.name_scope('MyScope'):
3768 outputs = layers.Dense(10, name='MyDense')(inputs)
3769 model = tf.keras.Model(inputs, outputs)
3771 # with `tf.keras.__internal__.apply_name_scope_on_model_declaration(True)`,
3772 # The name of the dense layer is "model/MyScope/MyDense/*", and without,
3773 # "model/MyDense/*"
3774 ```
3776 Args:
3777 enable: Enables if `True`, disables if `False`.
3778 """
3779 if not isinstance(enable, bool):
3780 raise TypeError(
3781 f"`enable` argument must be `True` or `False`, got {enable}"
3782 )
3784 global _is_name_scope_on_model_declaration_enabled
3785 _is_name_scope_on_model_declaration_enabled = enable
3788@keras_export("keras.__internal__.layers.BaseRandomLayer")
3789class BaseRandomLayer(Layer):
3790 """A layer handle the random number creation and savemodel behavior."""
3792 @tf.__internal__.tracking.no_automatic_dependency_tracking
3793 def __init__(
3794 self, seed=None, force_generator=False, rng_type=None, **kwargs
3795 ):
3796 """Initialize the BaseRandomLayer.
3798 Note that the constructor is annotated with
3799 @no_automatic_dependency_tracking. This is to skip the auto
3800 tracking of self._random_generator instance, which is an AutoTrackable.
3801 The backend.RandomGenerator could contain a tf.random.Generator instance
3802 which will have tf.Variable as the internal state. We want to avoid
3803 saving that state into model.weights and checkpoints for backward
3804 compatibility reason. In the meantime, we still need to make them
3805 visible to SavedModel when it is tracing the tf.function for the
3806 `call()`.
3807 See _list_extra_dependencies_for_serialization below for more details.
3809 Args:
3810 seed: optional integer, used to create RandomGenerator.
3811 force_generator: boolean, default to False, whether to force the
3812 RandomGenerator to use the code branch of tf.random.Generator.
3813 rng_type: string, the rng type that will be passed to backend
3814 RandomGenerator. `None` will allow RandomGenerator to choose
3815 types by itself. Valid values are "stateful", "stateless",
3816 "legacy_stateful". Defaults to `None`.
3817 **kwargs: other keyword arguments that will be passed to the parent
3818 *class
3819 """
3820 super().__init__(**kwargs)
3821 self._random_generator = backend.RandomGenerator(
3822 seed, force_generator=force_generator, rng_type=rng_type
3823 )
3825 def build(self, input_shape):
3826 super().build(input_shape)
3827 self._random_generator._maybe_init()
3829 def _trackable_children(self, save_type="checkpoint", **kwargs):
3830 if save_type == "savedmodel":
3831 cache = kwargs["cache"]
3832 # TODO(b/213628533): This must be called before super() to ensure
3833 # that any input shape changes are applied before getting the config
3834 # of the model.
3835 children = self._trackable_saved_model_saver.trackable_children(
3836 cache
3837 )
3838 # This method exposes the self._random_generator to SavedModel only
3839 # (not layer.weights and checkpoint).
3840 children["_random_generator"] = self._random_generator
3841 else:
3842 children = {}
3843 children.update(super()._trackable_children(save_type, **kwargs))
3844 return children
3846 def _lookup_dependency(self, name):
3847 # When loading from a Keras SavedModel load, make sure that the loader
3848 # can find the random generator, otherwise the loader will assume that
3849 # it does not exist, and will try to create a new generator.
3850 if name == "_random_generator":
3851 return self._random_generator
3852 else:
3853 return super()._lookup_dependency(name)