Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py: 25%
1235 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Contains the base Layer class, from which all layers inherit."""
18import collections
19import copy
20import functools
21import itertools
22import threading
23import warnings
24import weakref
26import numpy as np
28from google.protobuf import json_format
29from tensorflow.core.framework import node_def_pb2
30from tensorflow.python import tf2
31from tensorflow.python.autograph.core import ag_ctx
32from tensorflow.python.autograph.impl import api as autograph
33from tensorflow.python.distribute import distribute_lib
34from tensorflow.python.eager import backprop
35from tensorflow.python.eager import context
36from tensorflow.python.eager import def_function
37from tensorflow.python.framework import constant_op
38from tensorflow.python.framework import dtypes
39from tensorflow.python.framework import func_graph
40from tensorflow.python.framework import ops
41from tensorflow.python.framework import sparse_tensor
42from tensorflow.python.framework import tensor_conversion
43from tensorflow.python.framework import tensor_spec
44from tensorflow.python.framework import tensor_util
45from tensorflow.python.keras import backend
46from tensorflow.python.keras import constraints
47from tensorflow.python.keras import initializers
48from tensorflow.python.keras import regularizers
49from tensorflow.python.keras.engine import base_layer_utils
50from tensorflow.python.keras.engine import input_spec
51from tensorflow.python.keras.engine import keras_tensor
52from tensorflow.python.keras.engine import node as node_module
53from tensorflow.python.keras.mixed_precision import autocast_variable
54from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
55from tensorflow.python.keras.mixed_precision import policy
56from tensorflow.python.keras.saving.saved_model import layer_serialization
57from tensorflow.python.keras.utils import generic_utils
58from tensorflow.python.keras.utils import layer_utils
59from tensorflow.python.keras.utils import object_identity
60from tensorflow.python.keras.utils import tf_inspect
61from tensorflow.python.keras.utils import tf_utils
62from tensorflow.python.keras.utils import version_utils
63from tensorflow.python.keras.utils.generic_utils import to_snake_case # pylint: disable=unused-import
64from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list # pylint: disable=unused-import
65from tensorflow.python.module import module
66from tensorflow.python.ops import array_ops
67from tensorflow.python.ops import math_ops
68from tensorflow.python.ops import variables as tf_variables
69from tensorflow.python.ops.numpy_ops import np_arrays
70from tensorflow.python.ops.ragged import ragged_tensor
71from tensorflow.python.platform import tf_logging
72from tensorflow.python.trackable import autotrackable
73from tensorflow.python.trackable import base as trackable
74from tensorflow.python.trackable import data_structures
75from tensorflow.python.util import compat
76from tensorflow.python.util import nest
77from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
78from tensorflow.python.util.tf_export import keras_export
79from tensorflow.tools.docs import doc_controls
81# Modules that only depend on `keras.layers` import these from here.
83# pylint: disable=g-inconsistent-quotes
84metrics_mod = generic_utils.LazyLoader(
85 "metrics_mod", globals(),
86 "tensorflow.python.keras.metrics")
87# pylint: enable=g-inconsistent-quotes
89# Prefix that is added to the TF op layer names.
90_TF_OP_LAYER_NAME_PREFIX = 'tf_op_layer_'
92# TODO(mdan): Should we have a single generic type for types that can be passed
93# to tf.cast?
94_AUTOCAST_TYPES = (ops.Tensor, sparse_tensor.SparseTensor,
95 ragged_tensor.RaggedTensor)
98@keras_export('keras.layers.Layer')
99class Layer(module.Module, version_utils.LayerVersionSelector):
100 """This is the class from which all layers inherit.
102 A layer is a callable object that takes as input one or more tensors and
103 that outputs one or more tensors. It involves *computation*, defined
104 in the `call()` method, and a *state* (weight variables), defined
105 either in the constructor `__init__()` or in the `build()` method.
107 Users will just instantiate a layer and then treat it as a callable.
109 Args:
110 trainable: Boolean, whether the layer's variables should be trainable.
111 name: String name of the layer.
112 dtype: The dtype of the layer's computations and weights. Can also be a
113 `tf.keras.mixed_precision.Policy`, which allows the computation and weight
114 dtype to differ. Default of `None` means to use
115 `tf.keras.mixed_precision.global_policy()`, which is a float32 policy
116 unless set to a different value.
117 dynamic: Set this to `True` if your layer should only be run eagerly, and
118 should not be used to generate a static computation graph.
119 This would be the case for a Tree-RNN or a recursive network,
120 for example, or generally for any layer that manipulates tensors
121 using Python control flow. If `False`, we assume that the layer can
122 safely be used to generate a static computation graph.
124 Attributes:
125 name: The name of the layer (string).
126 dtype: The dtype of the layer's weights.
127 variable_dtype: Alias of `dtype`.
128 compute_dtype: The dtype of the layer's computations. Layers automatically
129 cast inputs to this dtype which causes the computations and output to also
130 be in this dtype. When mixed precision is used with a
131 `tf.keras.mixed_precision.Policy`, this will be different than
132 `variable_dtype`.
133 dtype_policy: The layer's dtype policy. See the
134 `tf.keras.mixed_precision.Policy` documentation for details.
135 trainable_weights: List of variables to be included in backprop.
136 non_trainable_weights: List of variables that should not be
137 included in backprop.
138 weights: The concatenation of the lists trainable_weights and
139 non_trainable_weights (in this order).
140 trainable: Whether the layer should be trained (boolean), i.e. whether
141 its potentially-trainable weights should be returned as part of
142 `layer.trainable_weights`.
143 input_spec: Optional (list of) `InputSpec` object(s) specifying the
144 constraints on inputs that can be accepted by the layer.
146 We recommend that descendants of `Layer` implement the following methods:
148 * `__init__()`: Defines custom layer attributes, and creates layer state
149 variables that do not depend on input shapes, using `add_weight()`.
150 * `build(self, input_shape)`: This method can be used to create weights that
151 depend on the shape(s) of the input(s), using `add_weight()`. `__call__()`
152 will automatically build the layer (if it has not been built yet) by
153 calling `build()`.
154 * `call(self, inputs, *args, **kwargs)`: Called in `__call__` after making
155 sure `build()` has been called. `call()` performs the logic of applying the
156 layer to the input tensors (which should be passed in as the first argument).
157 Two reserved keyword arguments you can optionally use in `call()` are:
158 - `training` (boolean, whether the call is in inference mode or training
159 mode). See more details in [the layer/model subclassing guide](
160 https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_training_argument_in_the_call_method)
161 - `mask` (boolean tensor encoding masked timesteps in the input, used
162 in RNN layers). See more details in [the layer/model subclassing guide](
163 https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_mask_argument_in_the_call_method)
164 A typical signature for this method is `call(self, inputs)`, and users can
165 optionally add `training` and `mask` if the layer needs them. `*args` and
166 `**kwargs` are only useful for future extension when more input parameters
167 are planned to be added.
168 * `get_config(self)`: Returns a dictionary containing the configuration used
169 to initialize this layer. If the keys differ from the arguments
170 in `__init__`, then override `from_config(self)` as well.
171 This method is used when saving
172 the layer or a model that contains this layer.
174 Examples:
176 Here's a basic example: a layer with two variables, `w` and `b`,
177 that returns `y = w . x + b`.
178 It shows how to implement `build()` and `call()`.
179 Variables set as attributes of a layer are tracked as weights
180 of the layers (in `layer.weights`).
182 ```python
183 class SimpleDense(Layer):
185 def __init__(self, units=32):
186 super(SimpleDense, self).__init__()
187 self.units = units
189 def build(self, input_shape): # Create the state of the layer (weights)
190 w_init = tf.random_normal_initializer()
191 self.w = tf.Variable(
192 initial_value=w_init(shape=(input_shape[-1], self.units),
193 dtype='float32'),
194 trainable=True)
195 b_init = tf.zeros_initializer()
196 self.b = tf.Variable(
197 initial_value=b_init(shape=(self.units,), dtype='float32'),
198 trainable=True)
200 def call(self, inputs): # Defines the computation from inputs to outputs
201 return tf.matmul(inputs, self.w) + self.b
203 # Instantiates the layer.
204 linear_layer = SimpleDense(4)
206 # This will also call `build(input_shape)` and create the weights.
207 y = linear_layer(tf.ones((2, 2)))
208 assert len(linear_layer.weights) == 2
210 # These weights are trainable, so they're listed in `trainable_weights`:
211 assert len(linear_layer.trainable_weights) == 2
212 ```
214 Note that the method `add_weight()` offers a shortcut to create weights:
216 ```python
217 class SimpleDense(Layer):
219 def __init__(self, units=32):
220 super(SimpleDense, self).__init__()
221 self.units = units
223 def build(self, input_shape):
224 self.w = self.add_weight(shape=(input_shape[-1], self.units),
225 initializer='random_normal',
226 trainable=True)
227 self.b = self.add_weight(shape=(self.units,),
228 initializer='random_normal',
229 trainable=True)
231 def call(self, inputs):
232 return tf.matmul(inputs, self.w) + self.b
233 ```
235 Besides trainable weights, updated via backpropagation during training,
236 layers can also have non-trainable weights. These weights are meant to
237 be updated manually during `call()`. Here's an example layer that computes
238 the running sum of its inputs:
240 ```python
241 class ComputeSum(Layer):
243 def __init__(self, input_dim):
244 super(ComputeSum, self).__init__()
245 # Create a non-trainable weight.
246 self.total = tf.Variable(initial_value=tf.zeros((input_dim,)),
247 trainable=False)
249 def call(self, inputs):
250 self.total.assign_add(tf.reduce_sum(inputs, axis=0))
251 return self.total
253 my_sum = ComputeSum(2)
254 x = tf.ones((2, 2))
256 y = my_sum(x)
257 print(y.numpy()) # [2. 2.]
259 y = my_sum(x)
260 print(y.numpy()) # [4. 4.]
262 assert my_sum.weights == [my_sum.total]
263 assert my_sum.non_trainable_weights == [my_sum.total]
264 assert my_sum.trainable_weights == []
265 ```
267 For more information about creating layers, see the guide
268 [Making new Layers and Models via subclassing](
269 https://www.tensorflow.org/guide/keras/custom_layers_and_models)
270 """
272 # See tf.Module for the usage of this property.
273 # The key for _obj_reference_counts_dict is a Trackable, which could be a
274 # variable or layer etc. tf.Module._flatten will fail to flatten the key
275 # since it is trying to convert Trackable to a string. This attribute can be
276 # ignored even after the nest lib fix, since the trackable object should
277 # already be available as individual attributes. _obj_reference_counts_dict
278 # just contains a copy of them.
279 _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
280 ('_obj_reference_counts_dict',),
281 module.Module._TF_MODULE_IGNORED_PROPERTIES
282 ))
284 # When loading from a SavedModel, Layers typically can be revived into a
285 # generic Layer wrapper. Sometimes, however, layers may implement methods
286 # that go beyond this wrapper, as in the case of PreprocessingLayers'
287 # `adapt` method. When this is the case, layer implementers can override
288 # must_restore_from_config to return True; layers with this property must
289 # be restored into their actual objects (and will fail if the object is
290 # not available to the restoration code).
291 _must_restore_from_config = False
293 def _get_cell_name(self):
294 canonical_name = get_canonical_name_for_symbol(
295 self.__class__, api_name='keras', add_prefix_to_v1_names=True)
296 if canonical_name is not None:
297 return 'tf.{}'.format(canonical_name)
298 return self.__class__.__module__ + '.' + self.__class__.__name__
300 def _instrument_layer_creation(self):
301 self._instrumented_keras_api = False
302 self._instrumented_keras_layer_class = False
303 self._instrumented_keras_model_class = False
304 if not getattr(self, '_disable_keras_instrumentation', False):
305 self._instrumented_keras_api = True
306 if getattr(self, '_is_model_for_instrumentation', False):
307 self._instrumented_keras_model_class = True
308 else:
309 self._instrumented_keras_layer_class = True
311 @trackable.no_automatic_dependency_tracking
312 def __init__(self,
313 trainable=True,
314 name=None,
315 dtype=None,
316 dynamic=False,
317 **kwargs):
318 self._instrument_layer_creation()
320 # These properties should be set by the user via keyword arguments.
321 # note that 'dtype', 'input_shape' and 'batch_input_shape'
322 # are only applicable to input layers: do not pass these keywords
323 # to non-input layers.
324 allowed_kwargs = {
325 'input_dim',
326 'input_shape',
327 'batch_input_shape',
328 'batch_size',
329 'weights',
330 'activity_regularizer',
331 'autocast',
332 'implementation',
333 }
334 # Validate optional keyword arguments.
335 generic_utils.validate_kwargs(kwargs, allowed_kwargs)
337 # Mutable properties
338 # Indicates whether the layer's weights are updated during training
339 # and whether the layer's updates are run during training.
340 self._trainable = trainable
341 # A stateful layer is a layer whose updates are run during inference too,
342 # for instance stateful RNNs.
343 self._stateful = False
344 # Indicates whether `build` needs to be called upon layer call, to create
345 # the layer's weights.
346 self.built = False
347 # Provides information about which inputs are compatible with the layer.
348 self._input_spec = None
350 # SavedModel-related attributes.
351 # Record the build input shape for loading purposes.
352 # TODO(kathywu): Move this to Layer._set_save_spec once cl/290121460 is
353 # submitted.
354 self._build_input_shape = None
355 self._saved_model_inputs_spec = None
357 # `Layer.compute_mask` will be called at the end of `Layer.__call__` if
358 # `Layer.compute_mask` is overridden, or if the `Layer` subclass sets
359 # `self.supports_masking=True`.
360 self._supports_masking = not generic_utils.is_default(self.compute_mask)
362 self._init_set_name(name)
363 self._activity_regularizer = regularizers.get(
364 kwargs.pop('activity_regularizer', None))
365 self._maybe_create_attribute('_trainable_weights', [])
366 self._maybe_create_attribute('_non_trainable_weights', [])
367 self._updates = []
368 # Object to store all thread local layer properties.
369 self._thread_local = threading.local()
370 # A list of zero-argument lambdas which return Tensors, used for variable
371 # regularizers.
372 self._callable_losses = []
373 # A list of symbolic Tensors containing activity regularizers and losses
374 # manually added through `add_loss` in graph-building mode.
375 self._losses = []
376 # A list of metric instances corresponding to the symbolic metric tensors
377 # added using the `add_metric` API.
378 self._metrics = []
379 # Ensures the same metric is not added multiple times in `MirroredStrategy`.
380 self._metrics_lock = threading.Lock()
382 # Both graph and subclassed networks have a dtype policy. For graph
383 # networks, the policy's compute and variable dtypes are ignored. Such
384 # networks only use the policy if it is a PolicyV1, in which case it uses
385 # the PolicyV1's loss_scale (Policy does not have a loss_scale). For
386 # subclassed networks, the compute and variable dtypes are used as in any
387 # ordinary layer.
388 self._set_dtype_policy(dtype)
389 # Boolean indicating whether the layer automatically casts its inputs to the
390 # layer's compute_dtype.
391 self._autocast = kwargs.get('autocast',
392 base_layer_utils.v2_dtype_behavior_enabled())
394 # Tracks `TrackableDataStructure`s, `Module`s, and `Layer`s.
395 # Ordered by when the object was assigned as an attr.
396 # Entries are unique.
397 self._maybe_create_attribute('_self_tracked_trackables', [])
399 # These lists will be filled via successive calls
400 # to self._add_inbound_node().
401 # Used in symbolic mode only, only in conjunction with graph-networks
402 self._inbound_nodes_value = []
403 self._outbound_nodes_value = []
405 self._init_call_fn_args()
407 # Whether the `call` method can be used to build a TF graph without issues.
408 # This attribute has no effect if the model is created using the Functional
409 # API. Instead, `model.dynamic` is determined based on the internal layers.
410 self._dynamic = dynamic
412 # Manage input shape information if passed.
413 if 'input_dim' in kwargs and 'input_shape' not in kwargs:
414 # Backwards compatibility: alias 'input_dim' to 'input_shape'.
415 kwargs['input_shape'] = (kwargs['input_dim'],)
416 if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
417 # In this case we will later create an input layer
418 # to insert before the current layer
419 if 'batch_input_shape' in kwargs:
420 batch_input_shape = tuple(kwargs['batch_input_shape'])
421 elif 'input_shape' in kwargs:
422 if 'batch_size' in kwargs:
423 batch_size = kwargs['batch_size']
424 else:
425 batch_size = None
426 batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
427 self._batch_input_shape = batch_input_shape
429 # Manage initial weight values if passed.
430 self._initial_weights = kwargs.get('weights', None)
432 # Whether the layer will track any layers that are set as attributes on itself
433 # as sub-layers; the weights from the sub-layers will be included in the
434 # parent layer's variables() as well.
435 # Defaults to True, which means auto tracking is turned on. Certain subclasses
436 # might want to turn it off, like the Sequential model.
437 self._auto_track_sub_layers = True
439 # For backwards compat reasons, most built-in layers do not guarantee
440 # that they will 100% preserve the structure of input args when saving
441 # / loading configs. E.g. they may un-nest an arg that is
442 # a list with one element.
443 self._preserve_input_structure_in_config = False
445 @trackable.no_automatic_dependency_tracking
446 @generic_utils.default
447 def build(self, input_shape):
448 """Creates the variables of the layer (optional, for subclass implementers).
450 This is a method that implementers of subclasses of `Layer` or `Model`
451 can override if they need a state-creation step in-between
452 layer instantiation and layer call.
454 This is typically used to create the weights of `Layer` subclasses.
456 Args:
457 input_shape: Instance of `TensorShape`, or list of instances of
458 `TensorShape` if the layer expects a list of inputs
459 (one instance per input).
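Example (a minimal sketch; `MyLayer` and the weight shapes here are
illustrative, not part of this module):

```python
import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):

  def build(self, input_shape):
    # Create a kernel whose shape depends on the last input dimension.
    self.kernel = self.add_weight(
        name='kernel',
        shape=(int(input_shape[-1]), 8),
        initializer='glorot_uniform')

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel)
```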
460 """
461 # Only record the build input shapes of overridden build methods.
462 if not hasattr(self.build, '_is_default'):
463 self._build_input_shape = input_shape
464 self.built = True
466 @doc_controls.for_subclass_implementers
467 def call(self, inputs, *args, **kwargs): # pylint: disable=unused-argument
468 """This is where the layer's logic lives.
470 Note that the `call()` method in `tf.keras` is a little bit different from
471 the `keras` API. In the `keras` API, you can pass masking support to
472 layers as additional arguments, whereas `tf.keras` has a `compute_mask()`
473 method for masking support.
475 Args:
476 inputs: Input tensor, or dict/list/tuple of input tensors.
477 The first positional `inputs` argument is subject to special rules:
478 - `inputs` must be explicitly passed. A layer cannot have zero
479 arguments, and `inputs` cannot be provided via the default value
480 of a keyword argument.
481 - NumPy array or Python scalar values in `inputs` get cast as tensors.
482 - Keras mask metadata is only collected from `inputs`.
483 - Layers are built (`build(input_shape)` method)
484 using shape info from `inputs` only.
485 - `input_spec` compatibility is only checked against `inputs`.
486 - Mixed precision input casting is only applied to `inputs`.
487 If a layer has tensor arguments in `*args` or `**kwargs`, their
488 casting behavior in mixed precision should be handled manually.
489 - The SavedModel input specification is generated using `inputs` only.
490 - Integration with various ecosystem packages like TFMOT, TFLite,
491 TF.js, etc is only supported for `inputs` and not for tensors in
492 positional and keyword arguments.
493 *args: Additional positional arguments. May contain tensors, although
494 this is not recommended, for the reasons above.
495 **kwargs: Additional keyword arguments. May contain tensors, although
496 this is not recommended, for the reasons above.
497 The following optional keyword arguments are reserved:
498 - `training`: Boolean scalar tensor of Python boolean indicating
499 whether the `call` is meant for training or inference.
500 - `mask`: Boolean input mask. If the layer's `call()` method takes a
501 `mask` argument, its default value will be set to the mask generated
502 for `inputs` by the previous layer (if `inputs` did come from a layer
503 that generated a corresponding mask, i.e. if it came from a Keras
504 layer with masking support).
506 Returns:
507 A tensor or list/tuple of tensors.
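Example (an illustrative sketch of a `call()` that uses the reserved
`training` argument; `NoiseDuringTraining` is a hypothetical layer):

```python
import tensorflow as tf

class NoiseDuringTraining(tf.keras.layers.Layer):

  def call(self, inputs, training=None):
    # Add noise only in training mode; pass inputs through at inference.
    if training:
      return inputs + tf.random.normal(tf.shape(inputs), stddev=0.1)
    return inputs
```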
508 """
509 return inputs
511 @doc_controls.for_subclass_implementers
512 def _add_trackable(self, trackable_object, trainable):
513 """Adds a Trackable object to this layer's state.
515 Args:
516 trackable_object: The tf.tracking.Trackable object to add.
517 trainable: Boolean, whether the variable should be part of the layer's
518 "trainable_variables" (e.g. variables, biases) or
519 "non_trainable_variables" (e.g. BatchNorm mean and variance).
521 Returns:
522 The TrackableWeightHandler used to track this object.
523 """
524 if isinstance(trackable_object, base_layer_utils.TrackableWeightHandler):
525 handler = trackable_object
526 else:
527 handler = base_layer_utils.TrackableWeightHandler(trackable_object)
528 if trainable:
529 self._trainable_weights.append(handler)
530 else:
531 self._non_trainable_weights.append(handler)
532 return handler
534 @doc_controls.for_subclass_implementers
535 def add_weight(self,
536 name=None,
537 shape=None,
538 dtype=None,
539 initializer=None,
540 regularizer=None,
541 trainable=None,
542 constraint=None,
543 use_resource=None,
544 synchronization=tf_variables.VariableSynchronization.AUTO,
545 aggregation=tf_variables.VariableAggregation.NONE,
546 **kwargs):
547 """Adds a new variable to the layer.
549 Args:
550 name: Variable name.
551 shape: Variable shape. Defaults to scalar if unspecified.
552 dtype: The type of the variable. Defaults to `self.dtype`.
553 initializer: Initializer instance (callable).
554 regularizer: Regularizer instance (callable).
555 trainable: Boolean, whether the variable should be part of the layer's
556 "trainable_variables" (e.g. variables, biases)
557 or "non_trainable_variables" (e.g. BatchNorm mean and variance).
558 Note that `trainable` cannot be `True` if `synchronization`
559 is set to `ON_READ`.
560 constraint: Constraint instance (callable).
561 use_resource: Whether to use `ResourceVariable`.
562 synchronization: Indicates when a distributed variable will be
563 aggregated. Accepted values are constants defined in the class
564 `tf.VariableSynchronization`. By default the synchronization is set to
565 `AUTO` and the current `DistributionStrategy` chooses
566 when to synchronize. If `synchronization` is set to `ON_READ`,
567 `trainable` must not be set to `True`.
568 aggregation: Indicates how a distributed variable will be aggregated.
569 Accepted values are constants defined in the class
570 `tf.VariableAggregation`.
571 **kwargs: Additional keyword arguments. Accepted values are `getter`,
572 `collections`, `experimental_autocast` and `caching_device`.
574 Returns:
575 The variable created.
577 Raises:
578 ValueError: When giving unsupported dtype and no initializer or when
579 trainable has been set to True with synchronization set as `ON_READ`.
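Example (a minimal sketch; the layer and weight names are illustrative):

```python
import tensorflow as tf

class Linear(tf.keras.layers.Layer):

  def build(self, input_shape):
    # A trainable, L2-regularized kernel and a non-trainable bias.
    self.kernel = self.add_weight(
        name='kernel',
        shape=(int(input_shape[-1]), 4),
        initializer='glorot_uniform',
        regularizer=tf.keras.regularizers.l2(1e-4),
        trainable=True)
    self.bias = self.add_weight(
        name='bias', shape=(4,), initializer='zeros', trainable=False)

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel) + self.bias
```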
580 """
581 if shape is None:
582 shape = ()
583 kwargs.pop('partitioner', None) # Ignored.
584 # Validate optional keyword arguments.
585 for kwarg in kwargs:
586 if kwarg not in ['collections', 'experimental_autocast',
587 'caching_device', 'getter']:
588 raise TypeError('Unknown keyword argument:', kwarg)
589 collections_arg = kwargs.pop('collections', None)
590 # 'experimental_autocast' can be set to False by the caller to indicate an
591 # AutoCastVariable should never be created.
592 autocast = kwargs.pop('experimental_autocast', True)
593 # See the docstring for tf.Variable about the details for caching_device.
594 caching_device = kwargs.pop('caching_device', None)
596 if dtype is None:
597 dtype = self.dtype or backend.floatx()
598 dtype = dtypes.as_dtype(dtype)
599 if self._dtype_policy.variable_dtype is None:
600 # The policy is "_infer", so we infer the policy from the variable dtype.
601 self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
602 initializer = initializers.get(initializer)
603 regularizer = regularizers.get(regularizer)
604 constraint = constraints.get(constraint)
606 if synchronization == tf_variables.VariableSynchronization.ON_READ:
607 if trainable:
608 raise ValueError(
609 'Synchronization value can be set to '
610 'VariableSynchronization.ON_READ only for non-trainable variables. '
611 'You have specified trainable=True and '
612 'synchronization=VariableSynchronization.ON_READ.')
613 else:
614 # Set trainable to be false when variable is to be synced on read.
615 trainable = False
616 elif trainable is None:
617 trainable = True
619 # Initialize variable when no initializer provided
620 if initializer is None:
621 # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
622 if dtype.is_floating:
623 initializer = initializers.get('glorot_uniform')
624 # If dtype is DT_INT/DT_UINT, provide a default value `zero`
625 # If dtype is DT_BOOL, provide a default value `FALSE`
626 elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
627 initializer = initializers.get('zeros')
628 # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
629 elif 'getter' not in kwargs:
630 # When `getter` is specified, it's possibly fine for `initializer` to be
631 # None since it's up to the custom `getter` to raise an error in case it
632 # indeed needs `initializer`.
633 raise ValueError('An initializer for variable %s of type %s is required'
634 ' for layer %s' % (name, dtype.base_dtype, self.name))
636 getter = kwargs.pop('getter', base_layer_utils.make_variable)
637 if (autocast and
638 self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype
639 and dtype.is_floating):
640 old_getter = getter
641 # Wrap variable constructor to return an AutoCastVariable.
642 def getter(*args, **kwargs): # pylint: disable=function-redefined
643 variable = old_getter(*args, **kwargs)
644 return autocast_variable.create_autocast_variable(variable)
645 # Also the caching_device does not work with the mixed precision API,
646 # disable it if it is specified.
647 # TODO(b/142020079): Reenable it once the bug is fixed.
648 if caching_device is not None:
649 tf_logging.warning(
650 '`caching_device` does not work with mixed precision API. Ignoring '
651 'user specified `caching_device`.')
652 caching_device = None
654 variable = self._add_variable_with_custom_getter(
655 name=name,
656 shape=shape,
657 # TODO(allenl): a `make_variable` equivalent should be added as a
658 # `Trackable` method.
659 getter=getter,
660 # Manage errors in Layer rather than Trackable.
661 overwrite=True,
662 initializer=initializer,
663 dtype=dtype,
664 constraint=constraint,
665 trainable=trainable,
666 use_resource=use_resource,
667 collections=collections_arg,
668 synchronization=synchronization,
669 aggregation=aggregation,
670 caching_device=caching_device)
671 if regularizer is not None:
672 # TODO(fchollet): in the future, this should be handled at the
673 # level of variable creation, and weight regularization losses
674 # should be variable attributes.
675 name_in_scope = variable.name[:variable.name.find(':')]
676 self._handle_weight_regularization(name_in_scope,
677 variable,
678 regularizer)
679 if base_layer_utils.is_split_variable(variable):
680 for v in variable:
681 backend.track_variable(v)
682 if trainable:
683 self._trainable_weights.append(v)
684 else:
685 self._non_trainable_weights.append(v)
686 else:
687 backend.track_variable(variable)
688 if trainable:
689 self._trainable_weights.append(variable)
690 else:
691 self._non_trainable_weights.append(variable)
692 return variable
694 @generic_utils.default
695 def get_config(self):
696 """Returns the config of the layer.
698 A layer config is a Python dictionary (serializable)
699 containing the configuration of a layer.
700 The same layer can be reinstantiated later
701 (without its trained weights) from this configuration.
703 The config of a layer does not include connectivity
704 information, nor the layer class name. These are handled
705 by `Network` (one layer of abstraction above).
707 Note that `get_config()` is not guaranteed to return a fresh copy of the dict
708 every time it is called. Callers should make a copy of the returned dict
709 if they want to modify it.
711 Returns:
712 Python dictionary.
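Example (an illustrative sketch of overriding `get_config()` in a subclass
with extra constructor arguments; `ScaledDense` is hypothetical):

```python
import tensorflow as tf

class ScaledDense(tf.keras.layers.Layer):

  def __init__(self, units, scale=1.0, **kwargs):
    super(ScaledDense, self).__init__(**kwargs)
    self.units = units
    self.scale = scale

  def get_config(self):
    # Start from the base config (name, trainable, dtype, ...) and add the
    # arguments this subclass needs to be re-created by `from_config`.
    config = super(ScaledDense, self).get_config()
    config.update({'units': self.units, 'scale': self.scale})
    return config
```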
713 """
714 all_args = tf_inspect.getfullargspec(self.__init__).args
715 config = {
716 'name': self.name,
717 'trainable': self.trainable,
718 }
719 if hasattr(self, '_batch_input_shape'):
720 config['batch_input_shape'] = self._batch_input_shape
721 config['dtype'] = policy.serialize(self._dtype_policy)
722 if hasattr(self, 'dynamic'):
723 # Only include `dynamic` in the `config` if it is `True`
724 if self.dynamic:
725 config['dynamic'] = self.dynamic
726 elif 'dynamic' in all_args:
727 all_args.remove('dynamic')
728 expected_args = config.keys()
729 # Finds all arguments in the `__init__` that are not in the config:
730 extra_args = [arg for arg in all_args if arg not in expected_args]
731 # Check that either the only argument in the `__init__` is `self`,
732 # or that `get_config` has been overridden:
733 if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
734 raise NotImplementedError('Layer %s has arguments in `__init__` and '
735 'therefore must override `get_config`.' %
736 self.__class__.__name__)
737 return config
739 @classmethod
740 def from_config(cls, config):
741 """Creates a layer from its config.
743 This method is the reverse of `get_config`,
744 capable of instantiating the same layer from the config
745 dictionary. It does not handle layer connectivity
746 (handled by Network), nor weights (handled by `set_weights`).
748 Args:
749 config: A Python dictionary, typically the
750 output of get_config.
752 Returns:
753 A layer instance.
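Example (a minimal round-trip sketch using a built-in layer):

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(4, activation='relu')
config = layer.get_config()
# Rebuild an equivalent (untrained) layer from the config dictionary.
clone = tf.keras.layers.Dense.from_config(config)
```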
754 """
755 return cls(**config)
757 def compute_output_shape(self, input_shape):
758 """Computes the output shape of the layer.
760 If the layer has not been built, this method will call `build` on the
761 layer. This assumes that the layer will later be used with inputs that
762 match the input shape provided here.
764 Args:
765 input_shape: Shape tuple (tuple of integers)
766 or list of shape tuples (one per output tensor of the layer).
767 Shape tuples can include None for free dimensions,
768 instead of an integer.
770 Returns:
771 An output shape tuple.
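Example (a minimal sketch; the shapes are illustrative):

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(8)
# `None` marks a free (e.g. batch) dimension.
print(layer.compute_output_shape((None, 16)))  # (None, 8)
```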
772 """
773 if context.executing_eagerly():
774 # In this case we build the model first in order to do shape inference.
775 # This is acceptable because the framework only calls
776 # `compute_output_shape` on shape values that the layer would later be
777 # built for. It would however cause issues in case a user attempts to
778 # use `compute_output_shape` manually with shapes that are incompatible
779 # with the shape the Layer will be called on (these users will have to
780 # implement `compute_output_shape` themselves).
781 self._maybe_build(input_shape)
782 with func_graph.FuncGraph(str(self.name) + '_scratch_graph').as_default():
783 input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
784 def _make_placeholder_like(shape):
785 ph = backend.placeholder(shape=shape, dtype=self.dtype)
786 ph._keras_mask = None
787 return ph
788 inputs = nest.map_structure(_make_placeholder_like, input_shape)
789 try:
790 outputs = self(inputs, training=False)
791 except TypeError as e:
792 raise NotImplementedError(
793 'We could not automatically infer the static shape of the '
794 'layer\'s output. Please implement the '
795 '`compute_output_shape` method on your layer (%s).' %
796 self.__class__.__name__) from e
797 return nest.map_structure(lambda t: t.shape, outputs)
798 raise NotImplementedError(
799 'Please run in eager mode or implement the `compute_output_shape` '
800 'method on your layer (%s).' % self.__class__.__name__)
802 @doc_controls.for_subclass_implementers
803 def compute_output_signature(self, input_signature):
804 """Compute the output tensor signature of the layer based on the inputs.
806 Unlike a TensorShape object, a TensorSpec object contains both shape
807 and dtype information for a tensor. This method allows layers to provide
808 output dtype information if it is different from the input dtype.
809 For any layer that doesn't implement this function,
810 the framework will fall back to using `compute_output_shape`, and will
811 assume that the output dtype matches the input dtype.
813 Args:
814 input_signature: Single TensorSpec or nested structure of TensorSpec
815 objects, describing a candidate input for the layer.
817 Returns:
818 Single TensorSpec or nested structure of TensorSpec objects, describing
819 how the layer would transform the provided input.
821 Raises:
822 TypeError: If input_signature contains a non-TensorSpec object.
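Example (a minimal sketch; the spec values are illustrative):

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(8)
spec = tf.TensorSpec(shape=(None, 16), dtype=tf.float32)
# Falls back to `compute_output_shape` plus the layer's compute dtype,
# yielding a TensorSpec with shape (None, 8) and dtype float32.
print(layer.compute_output_signature(spec))
```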
823 """
824 def check_type_return_shape(s):
825 if not isinstance(s, tensor_spec.TensorSpec):
826 raise TypeError('Only TensorSpec signature types are supported, '
827 'but saw signature entry: {}.'.format(s))
828 return s.shape
829 input_shape = nest.map_structure(check_type_return_shape, input_signature)
830 output_shape = self.compute_output_shape(input_shape)
831 dtype = self._compute_dtype
832 if dtype is None:
833 input_dtypes = [s.dtype for s in nest.flatten(input_signature)]
834 # Default behavior when self.dtype is None is to use the first input's
835 # dtype.
836 dtype = input_dtypes[0]
837 return nest.map_structure(
838 lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
839 output_shape)
841 def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs):
842 if self.dynamic:
843 # We will use static shape inference to return symbolic tensors
844 # matching the specifications of the layer outputs.
845 # Since `self.dynamic` is True, we will never attempt to
846 # run the underlying TF graph (which is disconnected).
847 # TODO(fchollet): consider py_func as an alternative, which
848 # would enable us to run the underlying graph if needed.
849 input_signature = nest.map_structure(
850 lambda x: tensor_spec.TensorSpec(shape=x.shape, dtype=x.dtype),
851 inputs)
852 output_signature = self.compute_output_signature(input_signature)
853 return nest.map_structure(keras_tensor.KerasTensor, output_signature)
854 else:
855 return self._infer_output_signature(inputs, args, kwargs, input_masks)
857 def _infer_output_signature(self, inputs, args, kwargs, input_masks):
858 """TODO(kaftan): Docstring."""
860 call_fn = self.call
861 # Wrapping `call` function in autograph to allow for dynamic control
862 # flow and control dependencies in call. We are limiting this to
863 # subclassed layers as autograph is strictly needed only for
864 # subclassed layers and models.
865 # tf_convert will respect the value of autograph setting in the
866 # enclosing tf.function, if any.
867 if (base_layer_utils.is_subclassed(self) and
868 not base_layer_utils.from_saved_model(self)):
869 call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
871 # We enter a scratch graph and build placeholder inputs inside of it that
872 # match the input args.
873 # We then call the layer inside of the scratch graph to identify the
874 # output signatures, then we build KerasTensors corresponding to those
875 # outputs.
876 scratch_graph = func_graph.FuncGraph(str(self.name) + '_scratch_graph')
877 with scratch_graph.as_default():
878 inputs = nest.map_structure(
879 keras_tensor.keras_tensor_to_placeholder, inputs)
880 args = nest.map_structure(
881 keras_tensor.keras_tensor_to_placeholder, args)
882 kwargs = nest.map_structure(
883 keras_tensor.keras_tensor_to_placeholder, kwargs)
884 input_masks = nest.map_structure(
885 keras_tensor.keras_tensor_to_placeholder, input_masks)
887 with backend.name_scope(self._name_scope()): # pylint: disable=not-callable
888 with autocast_variable.enable_auto_cast_variables(
889 self._compute_dtype_object):
890 # Build layer if applicable (if the `build` method has been
891 # overridden).
892 # TODO(kaftan): do we maybe_build here, or have we already done it?
893 self._maybe_build(inputs)
894 inputs = self._maybe_cast_inputs(inputs)
895 outputs = call_fn(inputs, *args, **kwargs)
897 self._handle_activity_regularization(inputs, outputs)
898 self._set_mask_metadata(inputs, outputs, input_masks,
899 build_graph=False)
900 outputs = nest.map_structure(
901 keras_tensor.keras_tensor_from_tensor, outputs)
903 if hasattr(self, '_set_inputs') and not self.inputs:
904 # TODO(kaftan): figure out if we need to do this at all
905 # Subclassed network: explicitly set metadata normally set by
906 # a call to self._set_inputs().
907 self._set_inputs(inputs, outputs)
908 del scratch_graph
909 return outputs
911 @generic_utils.default
912 def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument
913 """Computes an output mask tensor.
915 Args:
916 inputs: Tensor or list of tensors.
917 mask: Tensor or list of tensors.
919 Returns:
920 None or a tensor (or list of tensors,
921 one per output tensor of the layer).
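Example (an illustrative sketch; `DropFirstTimestep` is hypothetical):

```python
import tensorflow as tf

class DropFirstTimestep(tf.keras.layers.Layer):

  def __init__(self, **kwargs):
    super(DropFirstTimestep, self).__init__(**kwargs)
    self.supports_masking = True

  def call(self, inputs):
    return inputs[:, 1:, :]

  def compute_mask(self, inputs, mask=None):
    # Keep the propagated mask in sync with the sliced output.
    if mask is None:
      return None
    return mask[:, 1:]
```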
922 """
923 if not self._supports_masking:
924 if any(m is not None for m in nest.flatten(mask)):
925 raise TypeError('Layer ' + self.name + ' does not support masking, '
926 'but was passed an input_mask: ' + str(mask))
927 # masking not explicitly supported: return None as mask.
928 return None
929 # if masking is explicitly supported, by default
930 # carry over the input mask
931 return mask
933 def __call__(self, *args, **kwargs):
934 """Wraps `call`, applying pre- and post-processing steps.
936 Args:
937 *args: Positional arguments to be passed to `self.call`.
938 **kwargs: Keyword arguments to be passed to `self.call`.
940 Returns:
941 Output tensor(s).
943 Note:
944 - The following optional keyword arguments are reserved for specific uses:
945 * `training`: Boolean scalar tensor of Python boolean indicating
946 whether the `call` is meant for training or inference.
947 * `mask`: Boolean input mask.
948 - If the layer's `call` method takes a `mask` argument (as some Keras
949 layers do), its default value will be set to the mask generated
950 for `inputs` by the previous layer (if `inputs` did come from
951 a layer that generated a corresponding mask, i.e. if it came from
952 a Keras layer with masking support).
953 - If the layer is not built, the method will call `build`.
955 Raises:
956 ValueError: if the layer's `call` method returns None (an invalid value).
957 RuntimeError: if `super().__init__()` was not called in the constructor.
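Example (a minimal sketch of the two ways a layer is typically called):

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(2)

# Eager call on concrete tensors: builds the layer, then runs `call()`.
y = layer(tf.ones((3, 4)))

# Functional-construction call on a symbolic `KerasTensor`.
inputs = tf.keras.Input(shape=(4,))
outputs = layer(inputs)
model = tf.keras.Model(inputs, outputs)
```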
958 """
959 if not hasattr(self, '_thread_local'):
960 raise RuntimeError(
961 'You must call `super().__init__()` in the layer constructor.')
963 # `inputs` (the first arg in the method spec) is special cased in
964 # layer call due to historical reasons.
965 # This special casing currently takes the form of:
966 # - 'inputs' must be explicitly passed. A layer cannot have zero arguments,
967 # and inputs cannot have been provided via the default value of a kwarg.
968 # - numpy/scalar values in `inputs` get converted to tensors
969 # - implicit masks / mask metadata are only collected from 'inputs`
970 # - Layers are built using shape info from 'inputs' only
971 # - input_spec compatibility is only checked against `inputs`
972 # - mixed precision casting (autocast) is only applied to `inputs`,
973 # not to any other argument.
974 # - setting the SavedModel saving spec.
975 inputs, args, kwargs = self._split_out_first_arg(args, kwargs)
976 input_list = nest.flatten(inputs)
978 # Functional Model construction mode is invoked when `Layer`s are called on
979 # symbolic `KerasTensor`s, i.e.:
980 # >> inputs = tf.keras.Input(10)
981 # >> outputs = MyLayer()(inputs) # Functional construction mode.
982 # >> model = tf.keras.Model(inputs, outputs)
983 if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
984 return self._functional_construction_call(inputs, args, kwargs,
985 input_list)
987 # Maintains info about the `Layer.call` stack.
988 call_context = base_layer_utils.call_context()
990 # Accept NumPy and scalar inputs by converting to Tensors.
991 if any(isinstance(x, (
992 np_arrays.ndarray, np.ndarray, float, int)) for x in input_list):
993 inputs = nest.map_structure(_convert_numpy_or_python_types, inputs)
994 input_list = nest.flatten(inputs)
996 # Handle `mask` propagation from previous layer to current layer. Masks can
997 # be propagated explicitly via the `mask` argument, or implicitly via
998 # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
999 # explicitly take priority.
1000 input_masks, mask_is_implicit = self._get_input_masks(
1001 inputs, input_list, args, kwargs)
1002 if self._expects_mask_arg and mask_is_implicit:
1003 kwargs['mask'] = input_masks
1005 # Training mode for `Layer.call` is set via (in order of priority):
1006 # (1) The `training` argument passed to this `Layer.call`, if it is not None
1007 # (2) The training mode of an outer `Layer.call`.
1008 # (3) The default mode set by `tf.keras.backend.set_learning_phase` (if set)
1009 # (4) Any non-None default value for `training` specified in the call
1010 # signature
1011 # (5) False (treating the layer as if it's in inference)
1012 args, kwargs, training_mode = self._set_training_mode(
1013 args, kwargs, call_context)
1015 # Losses are cleared for all sublayers on the outermost `Layer.call`.
1016 # Losses are not cleared on inner `Layer.call`s, because sublayers can be
1017 # called multiple times.
1018 if not call_context.in_call:
1019 self._clear_losses()
1021 eager = context.executing_eagerly()
1022 with call_context.enter(
1023 layer=self,
1024 inputs=inputs,
1025 build_graph=not eager,
1026 training=training_mode):
1028 input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
1029 if eager:
1030 call_fn = self.call
1031 name_scope = self._name
1032 else:
1033 name_scope = self._name_scope() # Avoid autoincrementing. # pylint: disable=not-callable
1034 call_fn = self._autographed_call()
1036 with ops.name_scope_v2(name_scope):
1037 if not self.built:
1038 self._maybe_build(inputs)
1040 if self._autocast:
1041 inputs = self._maybe_cast_inputs(inputs, input_list)
1043 with autocast_variable.enable_auto_cast_variables(
1044 self._compute_dtype_object):
1045 outputs = call_fn(inputs, *args, **kwargs)
1047 if self._activity_regularizer:
1048 self._handle_activity_regularization(inputs, outputs)
1049 if self._supports_masking:
1050 self._set_mask_metadata(inputs, outputs, input_masks, not eager)
1051 if self._saved_model_inputs_spec is None:
1052 self._set_save_spec(inputs)
1054 return outputs
1056 def _functional_construction_call(self, inputs, args, kwargs, input_list):
1057 call_context = base_layer_utils.call_context()
1059 # Accept NumPy and scalar inputs by converting to Tensors.
1060 if any(isinstance(x, (
1061 np_arrays.ndarray, np.ndarray, float, int)) for x in input_list):
1063 def _convert_non_tensor(x):
1064 # Don't call `ops.convert_to_tensor` on all `inputs` because
1065 # `SparseTensors` can't be converted to `Tensor`.
1066 if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)):
1067 return tensor_conversion.convert_to_tensor_v2_with_dispatch(x)
1068 return x
1070 inputs = nest.map_structure(_convert_non_tensor, inputs)
1071 input_list = nest.flatten(inputs)
1073 # Handle `mask` propagation from previous layer to current layer. Masks can
1074 # be propagated explicitly via the `mask` argument, or implicitly via
1075 # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
1076 # explicitly take priority.
1077 mask_arg_passed_by_framework = False
1078 input_masks, mask_is_implicit = self._get_input_masks(
1079 inputs, input_list, args, kwargs)
1080 if self._expects_mask_arg and mask_is_implicit:
1081 kwargs['mask'] = input_masks
1082 mask_arg_passed_by_framework = True
1084 # If `training` argument is None or not explicitly passed,
1085 # propagate `training` value from this layer's calling layer.
1086 training_value = None
1087 training_arg_passed_by_framework = False
1088 # Priority 1: `training` was explicitly passed a non-None value.
1089 if self._call_arg_was_passed('training', args, kwargs):
1090 training_value = self._get_call_arg_value('training', args, kwargs)
1091 if not self._expects_training_arg:
1092 kwargs.pop('training')
1094 if training_value is None:
1095 # Priority 2: `training` was passed to a parent layer.
1096 if call_context.training is not None:
1097 training_value = call_context.training
1098 # Priority 3: `learning_phase()` has been set.
1099 elif backend.global_learning_phase_is_set():
1100 training_value = backend.learning_phase()
1101 # Force the training_value to be bool type, which matches the contract
1102 # for layer/model call args.
1103 if tensor_util.is_tf_type(training_value):
1104 training_value = math_ops.cast(training_value, dtypes.bool)
1105 else:
1106 training_value = bool(training_value)
1107 # Priority 4: trace layer with the default training argument specified
1108 # in the `call` signature (or in inference mode if the `call` signature
1109 # specifies no non-None default).
1110 else:
1111 training_value = self._default_training_arg
1112 # In cases (2), (3), (4) the training argument is passed automatically
1113 # by the framework, and will not be hard-coded into the model.
1114 if self._expects_training_arg:
1115 args, kwargs = self._set_call_arg_value('training', training_value,
1116 args, kwargs)
1117 training_arg_passed_by_framework = True
1119 with call_context.enter(
1120 layer=self, inputs=inputs, build_graph=True, training=training_value):
1121 # Check input assumptions set after layer building, e.g. input shape.
1122 outputs = self._keras_tensor_symbolic_call(
1123 inputs, input_masks, args, kwargs)
1125 if outputs is None:
1126 raise ValueError('A layer\'s `call` method should return a '
1127 'Tensor or a list of Tensors, not None '
1128 '(layer: ' + self.name + ').')
1129 if training_arg_passed_by_framework:
1130 args, kwargs = self._set_call_arg_value(
1131 'training', None, args, kwargs, pop_kwarg_if_none=True)
1132 if mask_arg_passed_by_framework:
1133 kwargs.pop('mask')
1134 # Node connectivity does not special-case the first argument.
1135 outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
1136 outputs)
1137 return outputs
1139 def _set_training_mode(self, args, kwargs, call_context):
1140 training_mode = None
1141 if self._expects_training_arg:
1142 # (1) `training` was passed to this `Layer.call`.
1143 if self._call_arg_was_passed('training', args, kwargs):
1144 training_mode = self._get_call_arg_value('training', args, kwargs)
1145 # If no `training` arg was passed, or `None` was explicitly passed,
1146 # the framework will decide what the training mode should be.
1147 if training_mode is None:
1148 call_ctx_training = call_context.training
1149 # (2) `training` mode is inferred from an outer `Layer.call`.
1150 if call_ctx_training is not None:
1151 training_mode = call_ctx_training
1152 # (3) User set `tf.keras.backend.set_learning_phase`.
1153 elif backend.global_learning_phase_is_set():
1154 training_mode = backend.learning_phase()
1155 # Ensure value is a `bool` or `tf.bool`.
1156 if isinstance(training_mode, bool):
1157 pass
1158 elif tensor_util.is_tf_type(training_mode):
1159 training_mode = math_ops.cast(training_mode, dtypes.bool)
1160 else:
1161 training_mode = bool(training_mode)
1162 # (4) We default to using `call`'s default value for `training`,
1163 # or treating the layer as if it is in inference if no non-None default
1164 # is specified in the `call` signature.
1165 else:
1166 training_mode = self._default_training_arg
1168 # For case (2), (3), (4) `training` arg is passed by framework.
1169 args, kwargs = self._set_call_arg_value('training', training_mode, args,
1170 kwargs)
1171 else:
1172 if 'training' in kwargs:
1173 # `training` was passed to this `Layer` but is not needed for
1174 # `Layer.call`. It will set the default mode for inner `Layer.call`s.
1175 training_mode = kwargs.pop('training')
1176 else:
1177 # Grab the current `training` mode from any outer `Layer.call`.
1178 training_mode = call_context.training
1180 return args, kwargs, training_mode
1182 def _autographed_call(self):
1183 # Wrapping `call` function in autograph to allow for dynamic control
1184 # flow and control dependencies in call. We are limiting this to
1185 # subclassed layers as autograph is strictly needed only for
1186 # subclassed layers and models.
1187 # tf_convert will respect the value of autograph setting in the
1188 # enclosing tf.function, if any.
1189 if (base_layer_utils.is_subclassed(self) and
1190 not base_layer_utils.from_saved_model(self)):
1191 return autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
1192 else:
1193 return self.call
1195 @property
1196 def dtype(self):
1197 """The dtype of the layer weights.
1199 This is equivalent to `Layer.dtype_policy.variable_dtype`. Unless
1200 mixed precision is used, this is the same as `Layer.compute_dtype`, the
1201 dtype of the layer's computations.
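Example (a minimal sketch, assuming the default float32 policy):

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(2)
assert layer.dtype == layer.dtype_policy.variable_dtype == 'float32'
```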
1202 """
1203 return self._dtype_policy.variable_dtype
1205 @property
1206 def name(self):
1207 """Name of the layer (string), set in the constructor."""
1208 return self._name
1210 @property
1211 def supports_masking(self):
1212 """Whether this layer supports computing a mask using `compute_mask`."""
1213 return self._supports_masking
1215 @supports_masking.setter
1216 def supports_masking(self, value):
1217 self._supports_masking = value
1219 @property
1220 def dynamic(self):
1221 """Whether the layer is dynamic (eager-only); set in the constructor."""
1222 return any(layer._dynamic for layer in self._flatten_layers())
1224 @property
1225 @doc_controls.do_not_doc_inheritable
1226 def stateful(self):
1227 return any(layer._stateful for layer in self._flatten_layers())
1229 @stateful.setter
1230 def stateful(self, value):
1231 self._stateful = value
1233 @property
1234 def trainable(self):
1235 return self._trainable
1237 @trainable.setter
1238 def trainable(self, value):
1239 for layer in self._flatten_layers():
1240 layer._trainable = value
1242 @property
1243 def activity_regularizer(self):
1244 """Optional regularizer function for the output of this layer."""
1245 return self._activity_regularizer
1247 @activity_regularizer.setter
1248 def activity_regularizer(self, regularizer):
1249 """Optional regularizer function for the output of this layer."""
1250 self._activity_regularizer = regularizer
1252 @property
1253 def input_spec(self):
1254 """`InputSpec` instance(s) describing the input format for this layer.
1256 When you create a layer subclass, you can set `self.input_spec` to enable
1257 the layer to run input compatibility checks when it is called.
1258 Consider a `Conv2D` layer: it can only be called on a single input tensor
1259 of rank 4. As such, you can set, in `__init__()`:
1261 ```python
1262 self.input_spec = tf.keras.layers.InputSpec(ndim=4)
1263 ```
1265 Now, if you try to call the layer on an input that isn't rank 4
1266 (for instance, an input of shape `(2,)`), it will raise a nicely-formatted
1267 error:
1269 ```
1270 ValueError: Input 0 of layer conv2d is incompatible with the layer:
1271 expected ndim=4, found ndim=1. Full shape received: [2]
1272 ```
1274 Input checks that can be specified via `input_spec` include:
1275 - Structure (e.g. a single input, a list of 2 inputs, etc)
1276 - Shape
1277 - Rank (ndim)
1278 - Dtype
1280 For more information, see `tf.keras.layers.InputSpec`.
1282 Returns:
1283 A `tf.keras.layers.InputSpec` instance, or nested structure thereof.
1284 """
1285 return self._input_spec
1287 @input_spec.setter
1288 # Must be decorated to prevent tracking, since the input_spec can be nested
1289 # InputSpec objects.
1290 @trackable.no_automatic_dependency_tracking
1291 def input_spec(self, value):
1292 for v in nest.flatten(value):
1293 if v is not None and not isinstance(v, InputSpec):
1294 raise TypeError('Layer input_spec must be an instance of InputSpec. '
1295 'Got: {}'.format(v))
1296 self._input_spec = value
1298 @property
1299 def trainable_weights(self):
1300 """List of all trainable weights tracked by this layer.
1302 Trainable weights are updated via gradient descent during training.
1304 Returns:
1305 A list of trainable variables.
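Example (a minimal custom-training sketch; the layer and data are
illustrative):

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(2)
x = tf.ones((8, 4))

with tf.GradientTape() as tape:
  y = layer(x)
  loss = tf.reduce_mean(tf.square(y))

# Only the trainable weights receive gradients.
grads = tape.gradient(loss, layer.trainable_weights)
```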
1306 """
1307 if self.trainable:
1308 children_weights = self._gather_children_attribute('trainable_variables')
1309 return self._dedup_weights(self._trainable_weights + children_weights)
1310 else:
1311 return []
1313 @property
1314 def non_trainable_weights(self):
1315 """List of all non-trainable weights tracked by this layer.
1317 Non-trainable weights are *not* updated during training. They are expected
1318 to be updated manually in `call()`.
1320 Returns:
1321 A list of non-trainable variables.
1322 """
1323 if self.trainable:
1324 children_weights = self._gather_children_attribute(
1325 'non_trainable_variables')
1326 non_trainable_weights = self._non_trainable_weights + children_weights
1327 else:
1328 children_weights = self._gather_children_attribute('variables')
1329 non_trainable_weights = (
1330 self._trainable_weights + self._non_trainable_weights +
1331 children_weights)
1332 return self._dedup_weights(non_trainable_weights)
1334 @property
1335 def weights(self):
1336 """Returns the list of all layer variables/weights.
1338 Returns:
1339 A list of variables.
1340 """
1341 return self.trainable_weights + self.non_trainable_weights
1343 @property
1344 @doc_controls.do_not_generate_docs
1345 def updates(self):
1346 warnings.warn('`layer.updates` will be removed in a future version. '
1347 'This property should not be used in TensorFlow 2.0, '
1348 'as `updates` are applied automatically.')
1349 return []
1351 @property
1352 def losses(self):
1353 """List of losses added using the `add_loss()` API.
1355 Variable regularization tensors are created when this property is accessed,
1356 so it is eager safe: accessing `losses` under a `tf.GradientTape` will
1357 propagate gradients back to the corresponding variables.
1359 Examples:
1361 >>> class MyLayer(tf.keras.layers.Layer):
1362 ... def call(self, inputs):
1363 ... self.add_loss(tf.abs(tf.reduce_mean(inputs)))
1364 ... return inputs
1365 >>> l = MyLayer()
1366 >>> l(np.ones((10, 1)))
1367 >>> l.losses
1368 [1.0]
1370 >>> inputs = tf.keras.Input(shape=(10,))
1371 >>> x = tf.keras.layers.Dense(10)(inputs)
1372 >>> outputs = tf.keras.layers.Dense(1)(x)
1373 >>> model = tf.keras.Model(inputs, outputs)
1374 >>> # Activity regularization.
1375 >>> len(model.losses)
1376 0
1377 >>> model.add_loss(tf.abs(tf.reduce_mean(x)))
1378 >>> len(model.losses)
1379 1
1381 >>> inputs = tf.keras.Input(shape=(10,))
1382 >>> d = tf.keras.layers.Dense(10, kernel_initializer='ones')
1383 >>> x = d(inputs)
1384 >>> outputs = tf.keras.layers.Dense(1)(x)
1385 >>> model = tf.keras.Model(inputs, outputs)
1386 >>> # Weight regularization.
1387 >>> model.add_loss(lambda: tf.reduce_mean(d.kernel))
1388 >>> model.losses
1389 [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>]
1391 Returns:
1392 A list of tensors.
1393 """
1394 collected_losses = []
1395 for layer in self._flatten_layers():
1396 # If any eager losses are present, we assume the model to be part of an
1397 # eager training loop (either a custom one or the one used when
1398 # `run_eagerly=True`) and so we always return just the eager losses.
1399 if layer._eager_losses:
1400 # Filter placeholder losses that may have been added by revived layers.
1401 # (see base_layer_utils for details).
1402 if (layer._eager_losses[0] is
1403 not base_layer_utils.REVIVED_LOSS_PLACEHOLDER):
1404 collected_losses.extend(layer._eager_losses)
1405 else:
1406 collected_losses.extend(layer._losses)
1407 for regularizer in layer._callable_losses:
1408 loss_tensor = regularizer()
1409 if loss_tensor is not None:
1410 collected_losses.append(loss_tensor)
1411 return collected_losses
1413 def add_loss(self, losses, **kwargs):
1414 """Add loss tensor(s), potentially dependent on layer inputs.
1416 Some losses (for instance, activity regularization losses) may be dependent
1417 on the inputs passed when calling a layer. Hence, when reusing the same
1418 layer on different inputs `a` and `b`, some entries in `layer.losses` may
1419 be dependent on `a` and some on `b`. This method automatically keeps track
1420 of dependencies.
1422 This method can be used inside a subclassed layer or model's `call`
1423 function, in which case `losses` should be a Tensor or list of Tensors.
1425 Example:
1427 ```python
1428 class MyLayer(tf.keras.layers.Layer):
1429 def call(self, inputs):
1430 self.add_loss(tf.abs(tf.reduce_mean(inputs)))
1431 return inputs
1432 ```
1434 This method can also be called directly on a Functional Model during
1435 construction. In this case, any loss Tensors passed to this Model must
1436 be symbolic and be able to be traced back to the model's `Input`s. These
1437 losses become part of the model's topology and are tracked in `get_config`.
1439 Example:
1441 ```python
1442 inputs = tf.keras.Input(shape=(10,))
1443 x = tf.keras.layers.Dense(10)(inputs)
1444 outputs = tf.keras.layers.Dense(1)(x)
1445 model = tf.keras.Model(inputs, outputs)
1446 # Activity regularization.
1447 model.add_loss(tf.abs(tf.reduce_mean(x)))
1448 ```
1450 If this is not the case for your loss (if, for example, your loss references
1451 a `Variable` of one of the model's layers), you can wrap your loss in a
1452 zero-argument lambda. These losses are not tracked as part of the model's
1453 topology since they can't be serialized.
1455 Example:
1457 ```python
1458 inputs = tf.keras.Input(shape=(10,))
1459 d = tf.keras.layers.Dense(10)
1460 x = d(inputs)
1461 outputs = tf.keras.layers.Dense(1)(x)
1462 model = tf.keras.Model(inputs, outputs)
1463 # Weight regularization.
1464 model.add_loss(lambda: tf.reduce_mean(d.kernel))
1465 ```
1467 Args:
1468 losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
1469 may also be zero-argument callables which create a loss tensor.
1470 **kwargs: Additional keyword arguments for backward compatibility.
1471 Accepted values:
1472 inputs - Deprecated, will be automatically inferred.
1473 """
1474 kwargs.pop('inputs', None)
1475 if kwargs:
1476 raise TypeError('Unknown keyword arguments: %s' % (kwargs.keys(),))
1478 def _tag_callable(loss):
1479 """Tags callable loss tensor as `_unconditional_loss`."""
1480 if callable(loss):
1481 # We run the loss without autocasting, as regularizers are often
1482 # numerically unstable in float16.
1483 with autocast_variable.enable_auto_cast_variables(None):
1484 loss = loss()
1485 if loss is None:
1486 return None # Will be filtered out when computing the .losses property
1487 if not tensor_util.is_tf_type(loss):
1488 loss = tensor_conversion.convert_to_tensor_v2_with_dispatch(
1489 loss, dtype=backend.floatx()
1490 )
1491 loss._unconditional_loss = True # pylint: disable=protected-access
1492 return loss
1494 losses = nest.flatten(losses)
1496 callable_losses = []
1497 eager_losses = []
1498 symbolic_losses = []
1499 for loss in losses:
1500 if callable(loss):
1501 callable_losses.append(functools.partial(_tag_callable, loss))
1502 continue
1503 if loss is None:
1504 continue
1505 if not tensor_util.is_tf_type(loss) and not isinstance(
1506 loss, keras_tensor.KerasTensor):
1507 loss = tensor_conversion.convert_to_tensor_v2_with_dispatch(
1508 loss, dtype=backend.floatx()
1509 )
1510 # TF Functions should take the eager path.
1511 if ((tf_utils.is_symbolic_tensor(loss) or
1512 isinstance(loss, keras_tensor.KerasTensor)) and
1513 not base_layer_utils.is_in_tf_function()):
1514 symbolic_losses.append(loss)
1515 elif tensor_util.is_tf_type(loss):
1516 eager_losses.append(loss)
1518 self._callable_losses.extend(callable_losses)
1520 in_call_context = base_layer_utils.call_context().in_call
1521 if eager_losses and not in_call_context:
1522 raise ValueError(
1523 'Expected a symbolic Tensor or a callable for the loss value. '
1524 'Please wrap your loss computation in a zero-argument `lambda`.')
1526 self._eager_losses.extend(eager_losses)
1528 for symbolic_loss in symbolic_losses:
1529 if getattr(self, '_is_graph_network', False):
1530 self._graph_network_add_loss(symbolic_loss)
1531 else:
1532 # Possibly a loss was added in a Layer's `build`.
1533 self._losses.append(symbolic_loss)
1535 def _clear_losses(self):
1536 """Used every step in eager to reset losses."""
1537 # Set to thread local directly to avoid Layer.__setattr__ overhead.
1538 if not getattr(self, '_self_tracked_trackables',
1539 None): # Fast path for single Layer.
1540 self._thread_local._eager_losses = []
1541 else:
1542 for layer in self._flatten_layers():
1543 layer._thread_local._eager_losses = []
1545 @property
1546 def metrics(self):
1547 """List of metrics added using the `add_metric()` API.
1549 Example:
1551 >>> input = tf.keras.layers.Input(shape=(3,))
1552 >>> d = tf.keras.layers.Dense(2)
1553 >>> output = d(input)
1554 >>> d.add_metric(tf.reduce_max(output), name='max')
1555 >>> d.add_metric(tf.reduce_min(output), name='min')
1556 >>> [m.name for m in d.metrics]
1557 ['max', 'min']
1559 Returns:
1560 A list of `Metric` objects.
1561 """
1562 collected_metrics = []
1563 for layer in self._flatten_layers():
1564 with layer._metrics_lock:
1565 collected_metrics.extend(layer._metrics)
1566 return collected_metrics
1568 def add_metric(self, value, name=None, **kwargs):
1569 """Adds metric tensor to the layer.
1571 This method can be used inside the `call()` method of a subclassed layer
1572 or model.
1574 ```python
1575 class MyMetricLayer(tf.keras.layers.Layer):
1576 def __init__(self):
1577 super(MyMetricLayer, self).__init__(name='my_metric_layer')
1578 self.mean = tf.keras.metrics.Mean(name='metric_1')
1580 def call(self, inputs):
1581 self.add_metric(self.mean(inputs))
1582 self.add_metric(tf.reduce_sum(inputs), name='metric_2')
1583 return inputs
1584 ```
1586 This method can also be called directly on a Functional Model during
1587 construction. In this case, any tensor passed to this Model must
1588 be symbolic and be able to be traced back to the model's `Input`s. These
1589 metrics become part of the model's topology and are tracked when you
1590 save the model via `save()`.
1592 ```python
1593 inputs = tf.keras.Input(shape=(10,))
1594 x = tf.keras.layers.Dense(10)(inputs)
1595 outputs = tf.keras.layers.Dense(1)(x)
1596 model = tf.keras.Model(inputs, outputs)
1597 model.add_metric(tf.reduce_sum(x), name='metric_1')
1598 ```
1600 Note: Calling `add_metric()` with the result of a metric object on a
1601 Functional Model, as shown in the example below, is not supported. This is
1602 because we cannot trace the metric result tensor back to the model's inputs.
1604 ```python
1605 inputs = tf.keras.Input(shape=(10,))
1606 x = tf.keras.layers.Dense(10)(inputs)
1607 outputs = tf.keras.layers.Dense(1)(x)
1608 model = tf.keras.Model(inputs, outputs)
1609 model.add_metric(tf.keras.metrics.Mean()(x), name='metric_1')
1610 ```
1612 Args:
1613 value: Metric tensor.
1614 name: String metric name.
1615 **kwargs: Additional keyword arguments for backward compatibility.
1616 Accepted values:
1617 `aggregation` - When the `value` tensor provided is not the result of
1618 calling a `keras.Metric` instance, it will be aggregated by default
1619 using a `keras.metrics.Mean`.
1620 """
1621 kwargs_keys = list(kwargs.keys())
1622 if (len(kwargs_keys) > 1 or
1623 (len(kwargs_keys) == 1 and kwargs_keys[0] != 'aggregation')):
1624 raise TypeError('Unknown keyword arguments: %s' % str(kwargs.keys()))
1626 from_metric_obj = hasattr(value, '_metric_obj')
1627 is_symbolic = isinstance(value, keras_tensor.KerasTensor)
1628 in_call_context = base_layer_utils.call_context().in_call
1630 if name is None and not from_metric_obj:
1631 # Eg. `self.add_metric(math_ops.reduce_sum(x))`
1632 # In eager mode, we use metric name to lookup a metric. Without a name,
1633 # a new Mean metric wrapper will be created on every model/layer call.
1634 # So, we raise an error when no name is provided.
1635 # We will do the same for symbolic mode for consistency although a name
1636 # will be generated if no name is provided.
1638 # We will not raise this error in the following use case for the sake of
1639 # consistency, as a name is provided in the metric constructor.
1640 # mean = metrics.Mean(name='my_metric')
1641 # model.add_metric(mean(outputs))
1642 raise ValueError('Please provide a name for your metric like '
1643 '`self.add_metric(tf.reduce_sum(inputs), '
1644 'name=\'mean_activation\')`')
1645 elif from_metric_obj:
1646 name = value._metric_obj.name
1648 if not in_call_context and not is_symbolic:
1649 raise ValueError('Expected a symbolic Tensor for the metric value, '
1650 'received: ' + str(value))
1652 # If a metric was added in a Layer's `call` or `build`.
1653 if in_call_context or not getattr(self, '_is_graph_network', False):
1654 # TF Function path should take the eager path.
1656 # If the given metric is available in the `metrics` list, we just update its
1657 # state; otherwise we create a new metric instance and
1658 # add it to the `metrics` list.
1659 metric_obj = getattr(value, '_metric_obj', None)
1660 # Tensors that come from a Metric object already updated the Metric state.
1661 should_update_state = not metric_obj
1662 name = metric_obj.name if metric_obj else name
1664 with self._metrics_lock:
1665 match = self._get_existing_metric(name)
1666 if match:
1667 metric_obj = match
1668 elif metric_obj:
1669 self._metrics.append(metric_obj)
1670 else:
1671 # Build the metric object with the value's dtype if it defines one
1672 metric_obj = metrics_mod.Mean(
1673 name=name, dtype=getattr(value, 'dtype', None))
1674 self._metrics.append(metric_obj)
1676 if should_update_state:
1677 metric_obj(value)
1678 else:
1679 if from_metric_obj:
1680 raise ValueError('Using the result of calling a `Metric` object '
1681 'when calling `add_metric` on a Functional '
1682 'Model is not supported. Please pass the '
1683 'Tensor to monitor directly.')
1685 # Insert layers into the Keras Graph Network.
1686 aggregation = None if from_metric_obj else 'mean'
1687 self._graph_network_add_metric(value, aggregation, name)
1689 @doc_controls.do_not_doc_inheritable
1690 def add_update(self, updates, inputs=None):
1691 """Add update op(s), potentially dependent on layer inputs.
1693 Weight updates (for instance, the updates of the moving mean and variance
1694 in a BatchNormalization layer) may be dependent on the inputs passed
1695 when calling a layer. Hence, when reusing the same layer on
1696 different inputs `a` and `b`, some entries in `layer.updates` may be
1697 dependent on `a` and some on `b`. This method automatically keeps track
1698 of dependencies.
1700 This call is ignored when eager execution is enabled (in that case, variable
1701 updates are run on the fly and thus do not need to be tracked for later
1702 execution).
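Example (an illustrative sketch of a custom layer that tracks a running total
of its inputs; `TotalLayer` and its `total` variable are hypothetical names):

```python
class TotalLayer(tf.keras.layers.Layer):
  def build(self, input_shape):
    self.total = self.add_weight(
        'total', shape=(), initializer='zeros', trainable=False)

  def call(self, inputs):
    # Passing a zero-arg callable allows the update to be skipped when
    # `trainable=False` is set on this layer in Eager mode.
    self.add_update(lambda: self.total.assign_add(tf.reduce_sum(inputs)))
    return inputs
```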
1704 Args:
1705 updates: Update op, or list/tuple of update ops, or zero-arg callable
1706 that returns an update op. A zero-arg callable should be passed so
1707 that the updates can be disabled by setting `trainable=False`
1708 on this Layer, when executing in Eager mode.
1709 inputs: Deprecated, will be automatically inferred.
1710 """
1711 if inputs is not None:
1712 tf_logging.warning(
1713 '`add_update` `inputs` kwarg has been deprecated. You no longer need '
1714 'to pass a value to `inputs` as it is being automatically inferred.')
1715 call_context = base_layer_utils.call_context()
1716 # No need to run updates during Functional API construction.
1717 if call_context.in_keras_graph:
1718 return
1720 # Callable updates are disabled by setting `trainable=False`.
1721 if not call_context.frozen:
1722 for update in nest.flatten(updates):
1723 if callable(update):
1724 update() # pylint: disable=not-callable
1726 def set_weights(self, weights):
1727 """Sets the weights of the layer, from NumPy arrays.
1729 The weights of a layer represent the state of the layer. This function
1730 sets the weight values from numpy arrays. The weight values should be
1731 passed in the order they are created by the layer. Note that the layer's
1732 weights must be instantiated before calling this function, by calling
1733 the layer.
1735 For example, a `Dense` layer returns a list of two values: the kernel matrix
1736 and the bias vector. These can be used to set the weights of another
1737 `Dense` layer:
1739 >>> layer_a = tf.keras.layers.Dense(1,
1740 ... kernel_initializer=tf.constant_initializer(1.))
1741 >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]]))
1742 >>> layer_a.get_weights()
1743 [array([[1.],
1744 [1.],
1745 [1.]], dtype=float32), array([0.], dtype=float32)]
1746 >>> layer_b = tf.keras.layers.Dense(1,
1747 ... kernel_initializer=tf.constant_initializer(2.))
1748 >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]]))
1749 >>> layer_b.get_weights()
1750 [array([[2.],
1751 [2.],
1752 [2.]], dtype=float32), array([0.], dtype=float32)]
1753 >>> layer_b.set_weights(layer_a.get_weights())
1754 >>> layer_b.get_weights()
1755 [array([[1.],
1756 [1.],
1757 [1.]], dtype=float32), array([0.], dtype=float32)]
1759 Args:
1760 weights: a list of NumPy arrays. The number
1761 of arrays and their shapes must match the
1762 number and shapes of the weights
1763 of the layer (i.e. it should match the
1764 output of `get_weights`).
1766 Raises:
1767 ValueError: If the provided weights list does not match the
1768 layer's specifications.
1769 """
1770 params = self.weights
1772 expected_num_weights = 0
1773 for param in params:
1774 if isinstance(param, base_layer_utils.TrackableWeightHandler):
1775 expected_num_weights += param.num_tensors
1776 else:
1777 expected_num_weights += 1
1779 if expected_num_weights != len(weights):
1780 raise ValueError(
1781 'You called `set_weights(weights)` on layer "%s" '
1782 'with a weight list of length %s, but the layer was '
1783 'expecting %s weights. Provided weights: %s...' %
1784 (self.name, len(weights), expected_num_weights, str(weights)[:50]))
1786 weight_index = 0
1787 weight_value_tuples = []
1788 for param in params:
1789 if isinstance(param, base_layer_utils.TrackableWeightHandler):
1790 num_tensors = param.num_tensors
1791 tensors = weights[weight_index:weight_index + num_tensors]
1792 param.set_weights(tensors)
1793 weight_index += num_tensors
1794 else:
1795 weight = weights[weight_index]
1796 weight_shape = weight.shape if hasattr(weight, 'shape') else ()
1797 ref_shape = param.shape
1798 if not ref_shape.is_compatible_with(weight_shape):
1799 raise ValueError(
1800 'Layer weight shape %s not compatible with provided weight '
1801 'shape %s' % (ref_shape, weight_shape))
1802 weight_value_tuples.append((param, weight))
1803 weight_index += 1
1805 backend.batch_set_value(weight_value_tuples)
1807 # Perform any layer defined finalization of the layer state.
1808 for layer in self._flatten_layers():
1809 layer.finalize_state()
1811 def get_weights(self):
1812 """Returns the current weights of the layer, as NumPy arrays.
1814 The weights of a layer represent the state of the layer. This function
1815 returns both trainable and non-trainable weight values associated with this
1816 layer as a list of NumPy arrays, which can in turn be used to load state
1817 into similarly parameterized layers.
1819 For example, a `Dense` layer returns a list of two values: the kernel matrix
1820 and the bias vector. These can be used to set the weights of another
1821 `Dense` layer:
1823 >>> layer_a = tf.keras.layers.Dense(1,
1824 ... kernel_initializer=tf.constant_initializer(1.))
1825 >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]]))
1826 >>> layer_a.get_weights()
1827 [array([[1.],
1828 [1.],
1829 [1.]], dtype=float32), array([0.], dtype=float32)]
1830 >>> layer_b = tf.keras.layers.Dense(1,
1831 ... kernel_initializer=tf.constant_initializer(2.))
1832 >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]]))
1833 >>> layer_b.get_weights()
1834 [array([[2.],
1835 [2.],
1836 [2.]], dtype=float32), array([0.], dtype=float32)]
1837 >>> layer_b.set_weights(layer_a.get_weights())
1838 >>> layer_b.get_weights()
1839 [array([[1.],
1840 [1.],
1841 [1.]], dtype=float32), array([0.], dtype=float32)]
1843 Returns:
1844 Weights values as a list of NumPy arrays.
1845 """
1846 weights = self.weights
1847 output_weights = []
1848 for weight in weights:
1849 if isinstance(weight, base_layer_utils.TrackableWeightHandler):
1850 output_weights.extend(weight.get_tensors())
1851 else:
1852 output_weights.append(weight)
1853 return backend.batch_get_value(output_weights)
1855 @doc_controls.do_not_generate_docs
1856 def finalize_state(self):
1857 """Finalizes the layers state after updating layer weights.
1859 This function can be subclassed in a layer and will be called after updating
1860 a layer weights. It can be overridden to finalize any additional layer state
1861 after a weight update.
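For example (an illustrative sketch; `NormCachingLayer` is a hypothetical
layer that caches a value derived from its weights):

```python
class NormCachingLayer(tf.keras.layers.Layer):
  def build(self, input_shape):
    self.kernel = self.add_weight(
        'kernel', shape=(int(input_shape[-1]), 4))
    self._kernel_norm = None

  def finalize_state(self):
    # Refresh derived state after `set_weights` has replaced the kernel.
    self._kernel_norm = tf.norm(self.kernel)

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel)
```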
1862 """
1863 pass
1865 @doc_controls.do_not_generate_docs
1866 def get_updates_for(self, inputs):
1867 """Deprecated, do NOT use!
1869 Retrieves updates relevant to a specific set of inputs.
1871 Args:
1872 inputs: Input tensor or list/tuple of input tensors.
1874 Returns:
1875 List of update ops of the layer that depend on `inputs`.
1876 """
1877 warnings.warn('`layer.get_updates_for` is deprecated and '
1878 'will be removed in a future version. '
1879 'Please use the `layer.updates` property instead.')
1880 return self.updates
1882 @doc_controls.do_not_generate_docs
1883 def get_losses_for(self, inputs):
1884 """Deprecated, do NOT use!
1886 Retrieves losses relevant to a specific set of inputs.
1888 Args:
1889 inputs: Input tensor or list/tuple of input tensors.
1891 Returns:
1892 List of loss tensors of the layer that depend on `inputs`.
1893 """
1894 warnings.warn('`layer.get_losses_for` is deprecated and '
1895 'will be removed in a future version. '
1896 'Please use `layer.losses` instead.')
1897 return self.losses
1899 @doc_controls.do_not_doc_inheritable
1900 def get_input_mask_at(self, node_index):
1901 """Retrieves the input mask tensor(s) of a layer at a given node.
1903 Args:
1904 node_index: Integer, index of the node
1905 from which to retrieve the attribute.
1906 E.g. `node_index=0` will correspond to the
1907 first time the layer was called.
1909 Returns:
1910 A mask tensor
1911 (or list of tensors if the layer has multiple inputs).
1912 """
1913 inputs = self.get_input_at(node_index)
1914 if isinstance(inputs, list):
1915 return [getattr(x, '_keras_mask', None) for x in inputs]
1916 else:
1917 return getattr(inputs, '_keras_mask', None)
1919 @doc_controls.do_not_doc_inheritable
1920 def get_output_mask_at(self, node_index):
1921 """Retrieves the output mask tensor(s) of a layer at a given node.
1923 Args:
1924 node_index: Integer, index of the node
1925 from which to retrieve the attribute.
1926 E.g. `node_index=0` will correspond to the
1927 first time the layer was called.
1929 Returns:
1930 A mask tensor
1931 (or list of tensors if the layer has multiple outputs).
1932 """
1933 output = self.get_output_at(node_index)
1934 if isinstance(output, list):
1935 return [getattr(x, '_keras_mask', None) for x in output]
1936 else:
1937 return getattr(output, '_keras_mask', None)
1939 @property
1940 @doc_controls.do_not_doc_inheritable
1941 def input_mask(self):
1942 """Retrieves the input mask tensor(s) of a layer.
1944 Only applicable if the layer has exactly one inbound node,
1945 i.e. if it is connected to one incoming layer.
1947 Returns:
1948 Input mask tensor (potentially None) or list of input
1949 mask tensors.
1951 Raises:
1952 AttributeError: if the layer is connected to
1953 more than one incoming layer.
1954 """
1955 inputs = self.input
1956 if isinstance(inputs, list):
1957 return [getattr(x, '_keras_mask', None) for x in inputs]
1958 else:
1959 return getattr(inputs, '_keras_mask', None)
1961 @property
1962 @doc_controls.do_not_doc_inheritable
1963 def output_mask(self):
1964 """Retrieves the output mask tensor(s) of a layer.
1966 Only applicable if the layer has exactly one inbound node,
1967 i.e. if it is connected to one incoming layer.
1969 Returns:
1970 Output mask tensor (potentially None) or list of output
1971 mask tensors.
1973 Raises:
1974 AttributeError: if the layer is connected to
1975 more than one incoming layer.
1976 """
1977 output = self.output
1978 if isinstance(output, list):
1979 return [getattr(x, '_keras_mask', None) for x in output]
1980 else:
1981 return getattr(output, '_keras_mask', None)
1983 @doc_controls.do_not_doc_inheritable
1984 def get_input_shape_at(self, node_index):
1985 """Retrieves the input shape(s) of a layer at a given node.
1987 Args:
1988 node_index: Integer, index of the node
1989 from which to retrieve the attribute.
1990 E.g. `node_index=0` will correspond to the
1991 first time the layer was called.
1993 Returns:
1994 A shape tuple
1995 (or list of shape tuples if the layer has multiple inputs).
1997 Raises:
1998 RuntimeError: If called in Eager mode.
1999 """
2000 return self._get_node_attribute_at_index(node_index, 'input_shapes',
2001 'input shape')
2003 @doc_controls.do_not_doc_inheritable
2004 def get_output_shape_at(self, node_index):
2005 """Retrieves the output shape(s) of a layer at a given node.
2007 Args:
2008 node_index: Integer, index of the node
2009 from which to retrieve the attribute.
2010 E.g. `node_index=0` will correspond to the
2011 first time the layer was called.
2013 Returns:
2014 A shape tuple
2015 (or list of shape tuples if the layer has multiple outputs).
2017 Raises:
2018 RuntimeError: If called in Eager mode.
2019 """
2020 return self._get_node_attribute_at_index(node_index, 'output_shapes',
2021 'output shape')
2023 @doc_controls.do_not_doc_inheritable
2024 def get_input_at(self, node_index):
2025 """Retrieves the input tensor(s) of a layer at a given node.
2027 Args:
2028 node_index: Integer, index of the node
2029 from which to retrieve the attribute.
2030 E.g. `node_index=0` will correspond to the
2031 first input node of the layer.
2033 Returns:
2034 A tensor (or list of tensors if the layer has multiple inputs).
2036 Raises:
2037 RuntimeError: If called in Eager mode.
2038 """
2039 return self._get_node_attribute_at_index(node_index, 'input_tensors',
2040 'input')
2042 @doc_controls.do_not_doc_inheritable
2043 def get_output_at(self, node_index):
2044 """Retrieves the output tensor(s) of a layer at a given node.
2046 Args:
2047 node_index: Integer, index of the node
2048 from which to retrieve the attribute.
2049 E.g. `node_index=0` will correspond to the
2050 first output node of the layer.
2052 Returns:
2053 A tensor (or list of tensors if the layer has multiple outputs).
2055 Raises:
2056 RuntimeError: If called in Eager mode.
2057 """
2058 return self._get_node_attribute_at_index(node_index, 'output_tensors',
2059 'output')
2061 @property
2062 def input(self):
2063 """Retrieves the input tensor(s) of a layer.
2065 Only applicable if the layer has exactly one input,
2066 i.e. if it is connected to one incoming layer.
2068 Returns:
2069 Input tensor or list of input tensors.
2071 Raises:
2072 RuntimeError: If called in Eager mode.
2073 AttributeError: If no inbound nodes are found.
2074 """
2075 if not self._inbound_nodes:
2076 raise AttributeError('Layer ' + self.name +
2077 ' is not connected, no input to return.')
2078 return self._get_node_attribute_at_index(0, 'input_tensors', 'input')
2080 @property
2081 def output(self):
2082 """Retrieves the output tensor(s) of a layer.
2084 Only applicable if the layer has exactly one output,
2085 i.e. if it is connected to one incoming layer.
2087 Returns:
2088 Output tensor or list of output tensors.
2090 Raises:
2091 AttributeError: if the layer is connected to more than one incoming
2092 layer.
2093 RuntimeError: if called in Eager mode.
2094 """
2095 if not self._inbound_nodes:
2096 raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
2097 return self._get_node_attribute_at_index(0, 'output_tensors', 'output')
2099 @property
2100 @doc_controls.do_not_doc_inheritable
2101 def input_shape(self):
2102 """Retrieves the input shape(s) of a layer.
2104 Only applicable if the layer has exactly one input,
2105 i.e. if it is connected to one incoming layer, or if all inputs
2106 have the same shape.
2108 Returns:
2109 Input shape, as an integer shape tuple
2110 (or list of shape tuples, one tuple per input tensor).
2112 Raises:
2113 AttributeError: if the layer has no defined input_shape.
2114 RuntimeError: if called in Eager mode.
2115 """
2116 if not self._inbound_nodes:
2117 raise AttributeError('The layer has never been called '
2118 'and thus has no defined input shape.')
2119 all_input_shapes = set(
2120 [str(node.input_shapes) for node in self._inbound_nodes])
2121 if len(all_input_shapes) == 1:
2122 return self._inbound_nodes[0].input_shapes
2123 else:
2124 raise AttributeError('The layer "' + str(self.name) +
2125 '" has multiple inbound nodes, '
2126 'with different input shapes. Hence '
2127 'the notion of "input shape" is '
2128 'ill-defined for the layer. '
2129 'Use `get_input_shape_at(node_index)` '
2130 'instead.')
2132 def count_params(self):
2133 """Count the total number of scalars composing the weights.
2135 Returns:
2136 An integer count.
2138 Raises:
2139 ValueError: if the layer isn't yet built
2140 (in which case its weights aren't yet defined).
2141 """
2142 if not self.built:
2143 if getattr(self, '_is_graph_network', False):
2144 with tf_utils.maybe_init_scope(self):
2145 self._maybe_build(self.inputs)
2146 else:
2147 raise ValueError('You tried to call `count_params` on ' + self.name +
2148 ', but the layer isn\'t built. '
2149 'You can build it manually via: `' + self.name +
2150 '.build(batch_input_shape)`.')
2151 return layer_utils.count_params(self.weights)
2153 @property
2154 @doc_controls.do_not_doc_inheritable
2155 def output_shape(self):
2156 """Retrieves the output shape(s) of a layer.
2158 Only applicable if the layer has one output,
2159 or if all outputs have the same shape.
2161 Returns:
2162 Output shape, as an integer shape tuple
2163 (or list of shape tuples, one tuple per output tensor).
2165 Raises:
2166 AttributeError: if the layer has no defined output shape.
2167 RuntimeError: if called in Eager mode.
2168 """
2169 if not self._inbound_nodes:
2170 raise AttributeError('The layer has never been called '
2171 'and thus has no defined output shape.')
2172 all_output_shapes = set(
2173 [str(node.output_shapes) for node in self._inbound_nodes])
2174 if len(all_output_shapes) == 1:
2175 return self._inbound_nodes[0].output_shapes
2176 else:
2177 raise AttributeError('The layer "%s"'
2178 ' has multiple inbound nodes, '
2179 'with different output shapes. Hence '
2180 'the notion of "output shape" is '
2181 'ill-defined for the layer. '
2182 'Use `get_output_shape_at(node_index)` '
2183 'instead.' % self.name)
2185 @property
2186 @doc_controls.do_not_doc_inheritable
2187 def inbound_nodes(self):
2188 """Deprecated, do NOT use! Only for compatibility with external Keras."""
2189 return self._inbound_nodes
2191 @property
2192 @doc_controls.do_not_doc_inheritable
2193 def outbound_nodes(self):
2194 """Deprecated, do NOT use! Only for compatibility with external Keras."""
2195 return self._outbound_nodes
2197 ##############################################################################
2198 # Methods & attributes below are public aliases of other methods. #
2199 ##############################################################################
2201 @doc_controls.do_not_doc_inheritable
2202 def apply(self, inputs, *args, **kwargs):
2203 """Deprecated, do NOT use!
2205 This is an alias of `self.__call__`.
2207 Args:
2208 inputs: Input tensor(s).
2209 *args: additional positional arguments to be passed to `self.call`.
2210 **kwargs: additional keyword arguments to be passed to `self.call`.
2212 Returns:
2213 Output tensor(s).
2214 """
2215 warnings.warn('`layer.apply` is deprecated and '
2216 'will be removed in a future version. '
2217 'Please use `layer.__call__` method instead.')
2218 return self.__call__(inputs, *args, **kwargs)
2220 @doc_controls.do_not_doc_inheritable
2221 def add_variable(self, *args, **kwargs):
2222 """Deprecated, do NOT use! Alias for `add_weight`."""
2223 warnings.warn('`layer.add_variable` is deprecated and '
2224 'will be removed in a future version. '
2225 'Please use `layer.add_weight` method instead.')
2226 return self.add_weight(*args, **kwargs)
2228 @property
2229 @doc_controls.do_not_generate_docs
2230 def variables(self):
2231 """Returns the list of all layer variables/weights.
2233 Alias of `self.weights`.
2235 Note: This will not track the weights of nested `tf.Modules` that are not
2236 themselves Keras layers.
2238 Returns:
2239 A list of variables.
2240 """
2241 return self.weights
2243 @property
2244 @doc_controls.do_not_generate_docs
2245 def trainable_variables(self):
2246 return self.trainable_weights
2248 @property
2249 @doc_controls.do_not_generate_docs
2250 def non_trainable_variables(self):
2251 return self.non_trainable_weights
2253 ##############################################################################
2254 # Methods & attributes below are all private and only used by the framework. #
2255 ##############################################################################
2257 @property
2258 def _inbound_nodes(self):
2259 return self._inbound_nodes_value
2261 @_inbound_nodes.setter
2262 @trackable.no_automatic_dependency_tracking
2263 def _inbound_nodes(self, value):
2264 self._inbound_nodes_value = value
2266 @property
2267 def _outbound_nodes(self):
2268 return self._outbound_nodes_value
2270 @_outbound_nodes.setter
2271 @trackable.no_automatic_dependency_tracking
2272 def _outbound_nodes(self, value):
2273 self._outbound_nodes_value = value
2275 def _set_dtype_policy(self, dtype):
2276 """Sets self._dtype_policy."""
2277 if isinstance(dtype, policy.Policy):
2278 self._dtype_policy = dtype
2279 elif isinstance(dtype, dict):
2280 self._dtype_policy = policy.deserialize(dtype)
2281 elif isinstance(dtype, str) and dtype in ('mixed_float16',
2282 'mixed_bfloat16'):
2283 # The isinstance check is required since np.dtype raises an error if
2284 # compared to a non-dtype string.
2285 self._dtype_policy = policy.Policy(dtype)
2286 elif dtype:
2287 self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name)
2288 else:
2289 self._dtype_policy = policy.global_policy()
2290 if (self._dtype_policy.name == 'mixed_float16' and
2291 not loss_scale_optimizer.strategy_supports_loss_scaling()):
2292 # Although it is only loss scaling that is unsupported with certain
2293 # strategies, we disallow the 'mixed_float16' policy with unsupported
2294 # strategies to avoid confusion. This is because 'mixed_float16' requires
2295 # loss scaling for numeric stability.
2296 strategy = distribute_lib.get_strategy()
2297 raise ValueError('Mixed precision is not supported with the '
2298 'tf.distribute.Strategy: %s. Either stop using mixed '
2299 'precision by removing the use of the "%s" policy or '
2300 'use a different Strategy, e.g. a MirroredStrategy.' %
2301 (strategy.__class__.__name__, self._dtype_policy.name))
2303 # Performance optimization: cache the compute dtype as a Dtype object or
2304 # None, so that str to Dtype conversion doesn't happen in Layer.__call__.
2305 # TODO(b/157486353): Investigate returning DTypes in Policy.
2306 if self._dtype_policy.compute_dtype:
2307 self._compute_dtype_object = dtypes.as_dtype(
2308 self._dtype_policy.compute_dtype)
2309 else:
2310 self._compute_dtype_object = None
2312 @property
2313 def dtype_policy(self):
2314 """The dtype policy associated with this layer.
2316 This is an instance of a `tf.keras.mixed_precision.Policy`.
2317 """
2318 return self._dtype_policy
2320 @property
2321 def compute_dtype(self):
2322 """The dtype of the layer's computations.
2324 This is equivalent to `Layer.dtype_policy.compute_dtype`. Unless
2325 mixed precision is used, this is the same as `Layer.dtype`, the dtype of
2326 the weights.
2328 Layers automatically cast their inputs to the compute dtype, which causes
2329 computations and the output to be in the compute dtype as well. This is done
2330 by the base Layer class in `Layer.__call__`, so you do not have to insert
2331 these casts if implementing your own layer.
2333 Layers often perform certain internal computations in higher precision when
2334 `compute_dtype` is float16 or bfloat16 for numeric stability. The output
2335 will still typically be float16 or bfloat16 in such cases.
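For example (an illustrative sketch), under a mixed precision policy the
compute dtype differs from the variable dtype:

```python
layer = tf.keras.layers.Dense(4, dtype='mixed_float16')
layer.compute_dtype  # 'float16': inputs and computations are cast to this.
layer.dtype          # 'float32': the dtype of the weights.
```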
2337 Returns:
2338 The layer's compute dtype.
2339 """
2340 return self._dtype_policy.compute_dtype
2342 @property
2343 def _compute_dtype(self):
2344 """Deprecated alias of `compute_dtype`."""
2345 return self._dtype_policy.compute_dtype
2347 @property
2348 def variable_dtype(self):
2349 """Alias of `Layer.dtype`, the dtype of the weights."""
2350 return self.dtype
2352 def _maybe_cast_inputs(self, inputs, input_list=None):
2353 """Maybe casts the inputs to the compute dtype.
2355 If self._compute_dtype is floating-point, and self._autocast is True,
2356 floating-point inputs are cast to self._compute_dtype.
2358 Args:
2359 inputs: Input tensor, or structure of input tensors.
2360 input_list: Flat list of input tensors.
2362 Returns:
2363 `inputs`, but tensors may have been cast to self._compute_dtype
2364 """
2365 if not input_list:
2366 input_list = nest.flatten(inputs)
2368 compute_dtype_object = self._compute_dtype_object
2369 should_autocast = (
2370 self._autocast and compute_dtype_object and
2371 compute_dtype_object.is_floating)
2373 if (should_autocast and
2374 any(map(self._should_cast_single_input, input_list))):
2375 # Only perform expensive `nest` operation when needed.
2376 return nest.map_structure(self._cast_single_input, inputs)
2377 else:
2378 return inputs
2380 def _should_cast_single_input(self, x):
2381 if isinstance(x, _AUTOCAST_TYPES):
2382 return (self._compute_dtype_object and
2383 x.dtype != self._compute_dtype_object and x.dtype.is_floating)
2384 return False
2386 def _cast_single_input(self, x):
2387 """Cast a single Tensor or TensorSpec to the compute dtype."""
2388 if self._should_cast_single_input(x):
2389 return math_ops.cast(x, self._compute_dtype_object)
2390 else:
2391 return x
2393 # _dtype used to be an attribute set in the constructor. We still expose it
2394 # because some clients still use it.
2395 # TODO(reedwm): Deprecate, then remove the _dtype property.
2396 @property
2397 def _dtype(self):
2398 # This is equivalent to returning self.dtype. We do not return self.dtype
2399 # as it would cause infinite recursion in a few subclasses, which override
2400 # "dtype" to return self._dtype.
2401 return self._dtype_policy.variable_dtype
2403 @_dtype.setter
2404 def _dtype(self, value):
2405 value = dtypes.as_dtype(value).name
2406 self._set_dtype_policy(policy.Policy(value))
2408 def _name_scope(self): # pylint: disable=method-hidden
2409 if not tf2.enabled():
2410 return self.name
2411 name_scope = self.name
2412 current_name_scope = ops.get_name_scope()
2413 if current_name_scope:
2414 name_scope = current_name_scope + '/' + name_scope
2415 if name_scope:
2416 # Note that the trailing `/` prevents autogenerated
2417 # numerical suffixes from being appended. It will also fully reset
2418 # nested name scope (i.e. the outer name scope has no effect).
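# For example (illustrative): a layer named 'dense' called inside an
# existing 'outer' name scope yields the scope 'outer/dense/'.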
2419 name_scope += '/'
2420 return name_scope
2422 def _init_set_name(self, name, zero_based=True):
2423 if not name:
2424 self._name = backend.unique_object_name(
2425 generic_utils.to_snake_case(self.__class__.__name__),
2426 zero_based=zero_based)
2427 else:
2428 backend.observe_object_name(name)
2429 self._name = name
2431 def _get_existing_metric(self, name=None):
2432 match = [m for m in self._metrics if m.name == name]
2433 if not match:
2434 return
2435 if len(match) > 1:
2436 raise ValueError(
2437 'Please provide different names for the metrics you have added. '
2438 'We found {} metrics with the name: "{}"'.format(len(match), name))
2439 return match[0]
2441 def _handle_weight_regularization(self, name, variable, regularizer):
2442 """Create lambdas which compute regularization losses."""
2444 def _loss_for_variable(v):
2445 """Creates a regularization loss `Tensor` for variable `v`."""
2446 with backend.name_scope(name + '/Regularizer'):
2447 regularization = regularizer(v)
2448 return regularization
2450 if base_layer_utils.is_split_variable(variable):
2451 for v in variable:
2452 self.add_loss(functools.partial(_loss_for_variable, v))
2453 else:
2454 self.add_loss(functools.partial(_loss_for_variable, variable))
2456 def _handle_activity_regularization(self, inputs, outputs):
2457 # Apply activity regularization.
2458 # Note that it should be applied every time the layer creates a new
2459 # output, since it is output-specific.
2460 if self._activity_regularizer:
2461 output_list = nest.flatten(outputs)
2462 with backend.name_scope('ActivityRegularizer'):
2463 for output in output_list:
2464 activity_loss = self._activity_regularizer(output)
2465 batch_size = math_ops.cast(
2466 array_ops.shape(output)[0], activity_loss.dtype)
2467 # Make activity regularization strength batch-agnostic.
2468 mean_activity_loss = activity_loss / batch_size
2469 self.add_loss(mean_activity_loss)
2471 def _set_mask_metadata(self, inputs, outputs, previous_mask, build_graph):
2472 # Many `Layer`s don't need to call `compute_mask`.
2473 # This method is optimized to do as little work as needed for the common
2474 # case.
2475 if not self._supports_masking:
2476 return
2478 flat_outputs = nest.flatten(outputs)
2480 mask_already_computed = (
2481 getattr(self, '_compute_output_and_mask_jointly', False) or
2482 all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs))
2483 if mask_already_computed:
2484 if build_graph:
2485 self._set_mask_keras_history_checked(flat_outputs)
2486 return
2488 output_masks = self.compute_mask(inputs, previous_mask)
2489 if output_masks is None:
2490 return
2492 flat_masks = nest.flatten(output_masks)
2493 for tensor, mask in zip(flat_outputs, flat_masks):
2494 try:
2495 tensor._keras_mask = mask
2496 except AttributeError:
2497 # C Type such as np.ndarray.
2498 pass
2500 if build_graph:
2501 self._set_mask_keras_history_checked(flat_outputs)
2503 def _set_mask_keras_history_checked(self, flat_outputs):
2504 for output in flat_outputs:
2505 if getattr(output, '_keras_mask', None) is not None:
2506 # Do not track masks for `TensorFlowOpLayer` construction.
2507 output._keras_mask._keras_history_checked = True
2509 def _get_input_masks(self, inputs, input_list, args, kwargs):
2510 if not self._supports_masking and not self._expects_mask_arg:
2511 # Input masks only need to be retrieved if they are needed for `call`
2512 # or `compute_mask`.
2513 input_masks = None
2514 implicit_mask = False
2515 elif self._call_arg_was_passed('mask', args, kwargs):
2516 input_masks = self._get_call_arg_value('mask', args, kwargs)
2517 implicit_mask = False
2518 else:
2519 input_masks = [getattr(t, '_keras_mask', None) for t in input_list]
2520 if all(mask is None for mask in input_masks):
2521 input_masks = None
2522 implicit_mask = False
2523 else:
2524 # Only do expensive `nest` op when masking is actually being used.
2525 input_masks = nest.pack_sequence_as(inputs, input_masks)
2526 implicit_mask = True
2527 return input_masks, implicit_mask
2529 def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
2530 # Performance optimization: do no work in the most common case.
2531 if not args and not kwargs:
2532 return False
2534 if arg_name in kwargs:
2535 return True
2536 call_fn_args = self._call_fn_args
2537 if not inputs_in_args:
2538 # Ignore `inputs` arg.
2539 call_fn_args = call_fn_args[1:]
2540 return arg_name in dict(zip(call_fn_args, args))
2542 def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
2543 if arg_name in kwargs:
2544 return kwargs[arg_name]
2545 call_fn_args = self._call_fn_args
2546 if not inputs_in_args:
2547 # Ignore `inputs` arg.
2548 call_fn_args = call_fn_args[1:]
2549 args_dict = dict(zip(call_fn_args, args))
2550 return args_dict[arg_name]
2552 def _set_call_arg_value(
2553 self, arg_name, new_value, args,
2554 kwargs, inputs_in_args=False, pop_kwarg_if_none=False):
2555 arg_pos = self._call_fn_arg_positions.get(arg_name, None)
2556 if arg_pos is not None:
2557 if not inputs_in_args:
2558 # Ignore `inputs` arg.
2559 arg_pos = arg_pos - 1
2560 if len(args) > arg_pos:
2561 args = list(args)
2562 args[arg_pos] = new_value
2563 return tuple(args), kwargs
2564 if new_value is None and pop_kwarg_if_none:
2565 kwargs.pop(arg_name, None)
2566 else:
2567 kwargs[arg_name] = new_value
2568 return args, kwargs
2570 def _set_connectivity_metadata(self, args, kwargs, outputs):
2571 # If the layer returns tensors from its inputs unmodified,
2572 # we copy them to avoid loss of KerasHistory metadata.
2573 flat_outputs = nest.flatten(outputs)
2574 flat_inputs = nest.flatten((args, kwargs))
2575 input_ids_set = {id(i) for i in flat_inputs}
2576 outputs_copy = []
2577 for x in flat_outputs:
2578 if id(x) in input_ids_set:
2579 with backend.name_scope(self.name):
2580 x = array_ops.identity(x)
2581 outputs_copy.append(x)
2582 outputs = nest.pack_sequence_as(outputs, outputs_copy)
2584 # Create node, Node wires itself to inbound and outbound layers.
2585 # The Node constructor actually updates this layer's self._inbound_nodes,
2586 # sets _keras_history on the outputs, and adds itself to the
2587 # `_outbound_nodes` of the layers that produced the inputs to this
2588 # layer call.
2589 node_module.Node(self, call_args=args, call_kwargs=kwargs, outputs=outputs)
2590 return outputs
2592 def _get_node_attribute_at_index(self, node_index, attr, attr_name):
2593 """Private utility to retrieves an attribute (e.g. inputs) from a node.
2595 This is used to implement the methods:
2596 - get_input_shape_at
2597 - get_output_shape_at
2598 - get_input_at
2599 etc...
2601 Args:
2602 node_index: Integer index of the node from which
2603 to retrieve the attribute.
2604 attr: Exact node attribute name.
2605 attr_name: Human-readable attribute name, for error messages.
2607 Returns:
2608 The layer's attribute `attr` at the node of index `node_index`.
2610 Raises:
2611 RuntimeError: If the layer has no inbound nodes, or if called in Eager
2612 mode.
2613 ValueError: If the index provided does not match any node.
2614 """
2615 if not self._inbound_nodes:
2616 raise RuntimeError('The layer has never been called '
2617 'and thus has no defined ' + attr_name + '.')
2618 if not len(self._inbound_nodes) > node_index:
2619 raise ValueError('Asked to get ' + attr_name + ' at node ' +
2620 str(node_index) + ', but the layer has only ' +
2621 str(len(self._inbound_nodes)) + ' inbound nodes.')
2622 values = getattr(self._inbound_nodes[node_index], attr)
2623 if isinstance(values, list) and len(values) == 1:
2624 return values[0]
2625 else:
2626 return values
2628 def _maybe_build(self, inputs):
2629 # Check input assumptions set before layer building, e.g. input rank.
2630 if not self.built:
2631 input_spec.assert_input_compatibility(
2632 self.input_spec, inputs, self.name)
2633 input_list = nest.flatten(inputs)
2634 if input_list and self._dtype_policy.compute_dtype is None:
2635 try:
2636 dtype = input_list[0].dtype.base_dtype.name
2637 except AttributeError:
2638 pass
2639 else:
2640 self._set_dtype_policy(policy.Policy(dtype))
2641 input_shapes = None
2642 # Converts Tensors / CompositeTensors to TensorShapes.
2643 if all(hasattr(x, 'shape') for x in input_list):
2644 input_shapes = tf_utils.get_shapes(inputs)
2645 else:
2646 # Converts input shape to TensorShapes.
2647 try:
2648 input_shapes = tf_utils.convert_shapes(inputs, to_tuples=False)
2649 except ValueError:
2650 pass
2651 # Only call `build` if the user has manually overridden the build method.
2652 if not hasattr(self.build, '_is_default'):
2653 # Any setup work performed only once should happen in an `init_scope`
2654 # to avoid creating symbolic Tensors that will later pollute any eager
2655 # operations.
2656 with tf_utils.maybe_init_scope(self):
2657 self.build(input_shapes) # pylint:disable=not-callable
2658 # We must also ensure that the layer is marked as built, and the build
2659 # shape is stored, since user-defined build functions may not be calling
2660 # `super.build()`
2661 Layer.build(self, input_shapes)
2663 # Optionally load weight values specified at layer instantiation.
2664 if self._initial_weights is not None:
2665 with ops.init_scope():
2666 # Using `init_scope` since we want variable assignment in
2667 # `set_weights` to be treated like variable initialization.
2668 self.set_weights(self._initial_weights)
2669 self._initial_weights = None
2671 def _symbolic_call(self, inputs):
2672 input_shapes = nest.map_structure(lambda x: x.shape, inputs)
2673 output_shapes = self.compute_output_shape(input_shapes)
2674 # Convert to TensorShape so that nest.map_structure will not map into
2675 # individual dim of the shape.
2676 output_shapes = tf_utils.convert_shapes(output_shapes, to_tuples=False)
2678 def _make_placeholder_like(shape):
2679 ph = backend.placeholder(shape=shape, dtype=self.dtype)
2680 ph._keras_mask = None
2681 return ph
2682 return nest.map_structure(_make_placeholder_like, output_shapes)
2684 def _get_trainable_state(self):
2685 """Get the `trainable` state of each sublayer.
2687 Returns:
2688 A dict mapping all sublayers to their `trainable` value.
2689 """
2690 trainable_state = weakref.WeakKeyDictionary()
2691 for layer in self._flatten_layers():
2692 trainable_state[layer] = layer.trainable
2693 return trainable_state
2695 def _set_trainable_state(self, trainable_state):
2696 """Set `trainable` state for each sublayer."""
2697 for layer in self._flatten_layers():
2698 if layer in trainable_state:
2699 layer.trainable = trainable_state[layer]
2701 @property
2702 def _obj_reference_counts(self):
2703 """A dictionary counting the number of attributes referencing an object."""
2704 self._maybe_create_attribute('_obj_reference_counts_dict',
2705 object_identity.ObjectIdentityDictionary())
2706 return self._obj_reference_counts_dict
2708 @trackable.no_automatic_dependency_tracking
2709 def _maybe_create_attribute(self, name, default_value):
2710 """Create the attribute with the default value if it hasn't been created.
2712 This is useful for fields that are used for tracking purposes, such as
2713 _trainable_weights or _layers. Note that a user could create a layer subclass
2714 and assign an internal field before invoking Layer.__init__(), so
2715 __setattr__() needs to create the tracking fields and __init__() needs to not
2716 override them.
2718 Args:
2719 name: String, the name of the attribute.
2720 default_value: Object, the default value of the attribute.
2721 """
2722 if not hasattr(self, name):
2723 self.__setattr__(name, default_value)
2725 def __delattr__(self, name):
2726 # For any super.__delattr__() call, we will directly use the implementation
2727 # in Trackable and skip the behavior in AutoTrackable. The Layer
2728 # originally used Trackable as its base class; the change to using Module as
2729 # the base class forced us to have AutoTrackable in the class hierarchy.
2730 #
2731 # TODO(b/180760306) Keeping the status quo of skipping _delattr__ and
2732 # __setattr__ in AutoTrackable may be unsustainable.
2733 existing_value = getattr(self, name, None)
2735 # If this value is replacing an existing object assigned to an attribute, we
2736 # should clean it out to avoid leaking memory. First we check if there are
2737 # other attributes referencing it.
2738 reference_counts = self._obj_reference_counts
2739 if existing_value not in reference_counts:
2740 super(autotrackable.AutoTrackable, self).__delattr__(name) # pylint: disable=bad-super-call
2741 return
2743 reference_count = reference_counts[existing_value]
2744 if reference_count > 1:
2745 # There are other remaining references. We can't remove this object from
2746 # _layers etc.
2747 reference_counts[existing_value] = reference_count - 1
2748 super(autotrackable.AutoTrackable, self).__delattr__(name) # pylint: disable=bad-super-call
2749 return
2750 else:
2751 # This is the last remaining reference.
2752 del reference_counts[existing_value]
2754 super(autotrackable.AutoTrackable, self).__delattr__(name) # pylint: disable=bad-super-call
2756 if (isinstance(existing_value, Layer)
2757 or base_layer_utils.has_weights(existing_value)):
2758 super(autotrackable.AutoTrackable, self).__setattr__( # pylint: disable=bad-super-call
2759 '_self_tracked_trackables',
2760 [l for l in self._self_tracked_trackables if l is not existing_value])
2761 if isinstance(existing_value, tf_variables.Variable):
2762 super(autotrackable.AutoTrackable, self).__setattr__( # pylint: disable=bad-super-call
2763 '_trainable_weights',
2764 [w for w in self._trainable_weights if w is not existing_value])
2765 super(autotrackable.AutoTrackable, self).__setattr__( # pylint: disable=bad-super-call
2766 '_non_trainable_weights',
2767 [w for w in self._non_trainable_weights if w is not existing_value])
2769 def __setattr__(self, name, value):
2770 if (name == '_self_setattr_tracking' or
2771 not getattr(self, '_self_setattr_tracking', True) or
2772 # Exclude @property.setters from tracking
2773 hasattr(self.__class__, name)):
2774 try:
2775 super(autotrackable.AutoTrackable, self).__setattr__(name, value) # pylint: disable=bad-super-call
2776 except AttributeError:
2777 raise AttributeError(
2778 ('Can\'t set the attribute "{}", likely because it conflicts with '
2779 'an existing read-only @property of the object. Please choose a '
2780 'different name.').format(name))
2781 return
2783 # Wraps data structures in `Trackable`, unwraps `NoDependency` objects.
2784 value = data_structures.sticky_attribute_assignment(
2785 trackable=self, value=value, name=name)
2787 reference_counts = self._obj_reference_counts
2788 reference_counts[value] = reference_counts.get(value, 0) + 1
2790 # Clean out the old attribute, which clears _layers and _trainable_weights
2791 # if necessary.
2792 try:
2793 self.__delattr__(name)
2794 except AttributeError:
2795 pass
2797 # Keep track of metric instance created in subclassed layer.
2798 for val in nest.flatten(value):
2799 if isinstance(val, metrics_mod.Metric) and hasattr(self, '_metrics'):
2800 self._metrics.append(val)
2802 # Append value to self._self_tracked_trackables if relevant
2803 if (getattr(self, '_auto_track_sub_layers', True) and
2804 (isinstance(value, module.Module) or
2805 base_layer_utils.has_weights(value))):
2806 self._maybe_create_attribute('_self_tracked_trackables', [])
2807 # We need to check object identity to avoid de-duplicating empty
2808 # container types which compare equal.
2809 if not any((layer is value for layer in self._self_tracked_trackables)):
2810 self._self_tracked_trackables.append(value)
2811 if hasattr(value, '_use_resource_variables'):
2812 # Legacy layers (V1 tf.layers) must always use
2813 # resource variables.
2814 value._use_resource_variables = True
2816 # Append value to list of trainable / non-trainable weights if relevant
2817 # TODO(b/125122625): This won't pick up on any variables added to a
2818 # list/dict after creation.
2819 for val in nest.flatten(value, expand_composites=True):
2820 if not isinstance(val, tf_variables.Variable):
2821 continue
2823 # Users may add extra weights/variables
2824 # simply by assigning them to attributes (invalid for graph networks)
2825 self._maybe_create_attribute('_trainable_weights', [])
2826 self._maybe_create_attribute('_non_trainable_weights', [])
2827 if val.trainable:
2828 if any(val is w for w in self._trainable_weights):
2829 continue
2830 self._trainable_weights.append(val)
2831 else:
2832 if any(val is w for w in self._non_trainable_weights):
2833 continue
2834 self._non_trainable_weights.append(val)
2836 backend.track_variable(val)
2838 # TODO(b/180760306) Skip the auto trackable from tf.Module to keep status
2839 # quo. See the comment at __delattr__.
2840 super(autotrackable.AutoTrackable, self).__setattr__(name, value) # pylint: disable=bad-super-call
2842 def _gather_children_attribute(self, attribute):
2843 assert attribute in {
2844 'variables', 'trainable_variables', 'non_trainable_variables'
2845 }
2846 if hasattr(self, '_self_tracked_trackables'):
2847 nested_layers = self._flatten_modules(include_self=False, recursive=False)
2848 return list(
2849 itertools.chain.from_iterable(
2850 getattr(layer, attribute) for layer in nested_layers))
2851 return []
2853 def _flatten_layers(self, recursive=True, include_self=True):
2854 for m in self._flatten_modules(
2855 recursive=recursive, include_self=include_self):
2856 if isinstance(m, Layer):
2857 yield m
2859 def _flatten_modules(self, recursive=True, include_self=True):
2860 """Flattens `tf.Module` instances (excluding `Metrics`).
2862 Args:
2863 recursive: Whether to recursively flatten through submodules.
2864 include_self: Whether to include this `Layer` instance.
2866 Yields:
2867 `tf.Module` instances tracked by this `Layer`.
2868 """
2869 if include_self:
2870 yield self
2872 # Only instantiate set and deque if needed.
2873 trackables = getattr(self, '_self_tracked_trackables', None)
2874 if trackables:
2875 seen_object_ids = set()
2876 deque = collections.deque(trackables)
2877 while deque:
2878 trackable_obj = deque.popleft()
2879 trackable_id = id(trackable_obj)
2880 if trackable_id in seen_object_ids:
2881 continue
2882 seen_object_ids.add(trackable_id)
2884 # Metrics are not considered part of the Layer's topology.
2885 if (isinstance(trackable_obj, module.Module) and
2886 not isinstance(trackable_obj, metrics_mod.Metric)):
2887 yield trackable_obj
2888 # Introspect recursively through sublayers.
2889 if recursive:
2890 subtrackables = getattr(trackable_obj, '_self_tracked_trackables',
2891 None)
2892 if subtrackables:
2893 deque.extendleft(reversed(subtrackables))
2894 elif isinstance(trackable_obj, data_structures.TrackableDataStructure):
2895 # Data structures are introspected even with `recursive=False`.
2896 tracked_values = trackable_obj._values
2897 if tracked_values:
2898 deque.extendleft(reversed(tracked_values))
2900 # This is a hack so that the is_layer (within
2901 # training/trackable/layer_utils.py) check doesn't get the weights attr.
2902 # TODO(b/110718070): Remove when fixed.
2903 def _is_layer(self):
2904 return True
2906 def _init_call_fn_args(self, expects_training_arg=None):
2907 # Clear cached call function arguments.
2908 self.__class__._call_full_argspec.fget.cache.pop(self, None)
2909 self.__class__._call_fn_args.fget.cache.pop(self, None)
2910 self.__class__._call_accepts_kwargs.fget.cache.pop(self, None)
2912 call_fn_args = self._call_fn_args
2913 call_fn_args += self._call_full_argspec.kwonlyargs or []
2914 if expects_training_arg is None:
2915 self._expects_training_arg = ('training' in call_fn_args or
2916 self._call_accepts_kwargs)
2917 else:
2918 # Use value encoded into the metadata when loading from the SavedModel.
2919 self._expects_training_arg = expects_training_arg
2920 # The default training arg will be any (non-None) default specified in the
2921 # method signature, or None if no value is specified.
2922 call_fn_arg_defaults = self._call_fn_arg_defaults.copy()
2923 call_fn_arg_defaults.update(self._call_full_argspec.kwonlydefaults or {})
2924 self._default_training_arg = call_fn_arg_defaults.get('training')
2926 self._expects_mask_arg = ('mask' in call_fn_args or
2927 self._call_accepts_kwargs)
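# Example (hedged, hypothetical signatures): for
#   def call(self, inputs, training=False): ...
# `_call_fn_args` is ['inputs', 'training'], so `_expects_training_arg` becomes
# True and `_default_training_arg` is False. For `def call(self, inputs): ...`
# both `_expects_training_arg` and `_expects_mask_arg` are False unless the
# signature also accepts `**kwargs`.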
2929 @property
2930 @layer_utils.cached_per_instance
2931 def _call_full_argspec(self):
2932 # Argspec inspection is expensive and the call spec is used often, so it
2933 # makes sense to cache the result.
2934 return tf_inspect.getfullargspec(self.call)
2936 @property
2937 @layer_utils.cached_per_instance
2938 def _call_fn_args(self):
2939 all_args = self._call_full_argspec.args
2940 # Scrub `self` that appears if a decorator was applied.
2941 if all_args and all_args[0] == 'self':
2942 return all_args[1:]
2943 return all_args
2945 @property
2946 @layer_utils.cached_per_instance
2947 def _call_fn_arg_defaults(self):
2948 call_fn_args = self._call_fn_args
2949 call_fn_defaults = self._call_full_argspec.defaults or []
2950 defaults = dict()
2952 # The call arg defaults are an n-tuple of the last n elements of the args
2953 # list. (n = # of elements that have a default argument)
2954 for i in range(-1 * len(call_fn_defaults), 0):
2955 defaults[call_fn_args[i]] = call_fn_defaults[i]
2956 return defaults
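# Example (hedged): for a call signature
#   def call(self, inputs, training=None, mask=None): ...
# `call_fn_args` is ['inputs', 'training', 'mask'] and `call_fn_defaults` is
# (None, None), so the negative-index loop above produces
#   {'training': None, 'mask': None}
# leaving `inputs`, which has no default, out of the mapping.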
2958 @property
2959 @layer_utils.cached_per_instance
2960 def _call_fn_arg_positions(self):
2961 call_fn_arg_positions = dict()
2962 for pos, arg in enumerate(self._call_fn_args):
2963 call_fn_arg_positions[arg] = pos
2964 return call_fn_arg_positions
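# Example: for `_call_fn_args` of ['inputs', 'training', 'mask'] this caches
#   {'inputs': 0, 'training': 1, 'mask': 2}
# so later argument handling can map a keyword back to its positional slot
# without re-running argspec inspection.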
2966 @property
2967 @layer_utils.cached_per_instance
2968 def _call_accepts_kwargs(self):
2969 return self._call_full_argspec.varkw is not None
2971 @property
2972 def _eager_losses(self):
2973 # A list of loss values from activity regularizers and losses added
2974 # manually through `add_loss` during eager execution. It is cleared
2975 # after every batch.
2976 # Because we eventually plan to allow the same model instance to be trained
2977 # alternately in eager and graph mode, eager losses and symbolic losses are
2978 # tracked in separate attributes.
2979 if not hasattr(self._thread_local, '_eager_losses'):
2980 self._thread_local._eager_losses = []
2981 return self._thread_local._eager_losses
2983 @_eager_losses.setter
2984 def _eager_losses(self, losses):
2985 self._thread_local._eager_losses = losses
2987 def _dedup_weights(self, weights):
2988 """Dedupe weights while maintaining order as much as possible."""
2989 output, seen_ids = [], set()
2990 for w in weights:
2991 if id(w) not in seen_ids:
2992 output.append(w)
2993 # Track the Variable's identity to avoid __eq__ issues.
2994 seen_ids.add(id(w))
2996 return output
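# Why `id()` rather than a set of variables (hedged): `tf.Variable.__eq__` is
# overloaded to return an element-wise comparison tensor rather than a plain
# bool, so object identity is the safe dedup key. Illustration with a
# hypothetical shared variable `v`:
#
#   v = tf.Variable([1., 2.])
#   layer._dedup_weights([v, v])  # -> [v], order preserved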
2998 def _split_out_first_arg(self, args, kwargs):
2999 # Grab the argument corresponding to the first argument in the
3000 # layer's `call` method spec. This will either be the first positional
3001 # argument, or it will be provided as a keyword argument.
3002 if args:
3003 inputs = args[0]
3004 args = args[1:]
3005 elif self._call_fn_args[0] in kwargs:
3006 kwargs = copy.copy(kwargs)
3007 inputs = kwargs.pop(self._call_fn_args[0])
3008 else:
3009 raise ValueError(
3010 'The first argument to `Layer.call` must always be passed.')
3011 return inputs, args, kwargs
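# Example (hedged, hypothetical layer whose call signature is
# `call(self, inputs, mask=None)`):
#
#   layer._split_out_first_arg((x,), {'mask': m})   # -> (x, (), {'mask': m})
#   layer._split_out_first_arg((), {'inputs': x})   # -> (x, (), {})
#
# If the first call argument is passed neither positionally nor by name, the
# ValueError above is raised.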
3013 # SavedModel properties. Please see keras/saving/saved_model for details.
3015 @trackable.no_automatic_dependency_tracking
3016 def _set_save_spec(self, inputs):
3017 if self._saved_model_inputs_spec is not None:
3018 return # Already set.
3020 self._saved_model_inputs_spec = nest.map_structure(tf_utils.get_tensor_spec,
3021 inputs)
3023 def _get_save_spec(self, dynamic_batch=True):
3024 if self._saved_model_inputs_spec is None:
3025 return None
3027 return nest.map_structure(
3028 lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
3029 self._saved_model_inputs_spec)
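# Sketch (hedged): if the layer's save spec was recorded from a float32 input
# of shape (32, 10), `_get_save_spec(dynamic_batch=True)` would return roughly
#   tf.TensorSpec(shape=(None, 10), dtype=tf.float32)
# i.e. the stored spec with the batch dimension relaxed, while
# `dynamic_batch=False` keeps the concrete (32, 10) shape.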
3031 @property
3032 def _trackable_saved_model_saver(self):
3033 return layer_serialization.LayerSavedModelSaver(self)
3035 @property
3036 def _object_identifier(self):
3037 return self._trackable_saved_model_saver.object_identifier
3039 @property
3040 def _tracking_metadata(self):
3041 """Info about this layer to be saved into the SavedModel."""
3042 return self._trackable_saved_model_saver.tracking_metadata
3044 def _trackable_children(self, save_type='checkpoint', **kwargs):
3045 if save_type == 'savedmodel':
3046 cache = kwargs['cache']
3047 # TODO(b/213628533): This must be called before super() to ensure
3048 # that any input shape changes are applied before getting the config of
3049 # the model.
3050 children = self._trackable_saved_model_saver.trackable_children(cache)
3051 else:
3052 children = {}
3053 children.update(super()._trackable_children(save_type, **kwargs))
3054 return children
3056 @property
3057 def _use_input_spec_as_call_signature(self):
3058 # Whether input spec can be used as the call signature when tracing the
3059 # Layer for SavedModel. By default, this is set to `True` for layers
3060 # exported from the Keras library, because those layers define the
3061 # `input_spec` property more rigidly (many custom layers only set `ndim`).
3062 return get_canonical_name_for_symbol(type(self),
3063 api_name='keras') is not None
3065 def __getstate__(self):
3066 # Override to support `copy.deepcopy` and pickling.
3067 # Thread-local objects cannot be copied in Python 3, so pop these.
3068 # Thread-local objects are used to cache losses in MirroredStrategy, and
3069 # so shouldn't be copied.
3070 state = self.__dict__.copy()
3071 state.pop('_thread_local', None)
3072 state.pop('_metrics_lock', None)
3073 return state
3075 def __setstate__(self, state):
3076 state['_thread_local'] = threading.local()
3077 state['_metrics_lock'] = threading.Lock()
3078 # Bypass Trackable logic as `__dict__` already contains this info.
3079 object.__setattr__(self, '__dict__', state)
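# Round-trip sketch (hedged): these two hooks are what make
#   clone = copy.deepcopy(layer)    # or pickle.loads(pickle.dumps(layer))
# work: `__getstate__` drops the unpicklable thread-local loss cache and the
# metrics lock, and `__setstate__` rebuilds fresh instances of both on the
# copy before restoring the remaining `__dict__` entries.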
3082class TensorFlowOpLayer(Layer):
3083 """Wraps a TensorFlow Operation in a Layer.
3085 This class is used internally by the Functional API. When a user
3086 applies a raw TensorFlow Operation to symbolic tensors originating
3087 from an `Input` Layer, the resultant operation will be wrapped
3088 with this Layer object in order to make the operation compatible
3089 with the Keras API.
3091 This Layer will create a new, identical operation (except for inputs
3092 and outputs) every time it is called. If `run_eagerly` is `True`,
3093 the op creation and calculation will happen inside an Eager function.
3095 Instances of this Layer are created when `autolambda` is called, which
3096 is whenever a Layer's `__call__` encounters symbolic inputs that do
3097 not have Keras metadata, or when a Network's `__init__` encounters
3098 outputs that do not have Keras metadata.
3100 Attributes:
3101 node_def: String, the serialized NodeDef of the Op this layer will wrap.
3102 name: String, the name of the Layer.
3103 constants: Dict of NumPy arrays, the values of any Tensors needed for this
3104 Operation that do not originate from a Keras `Input` Layer. Since all
3105 placeholders must come from Keras `Input` Layers, these Tensors must be
3106 treated as constant in the Functional API.
3107 trainable: Bool, whether this Layer is trainable. Currently Variables are
3108 not supported, and so this parameter has no effect.
3109 dtype: The default dtype of this Layer. Inherited from `Layer` and has no
3110 effect on this class; however, it is used in `get_config`.
3111 """
3113 @trackable.no_automatic_dependency_tracking
3114 def __init__(self,
3115 node_def,
3116 name,
3117 constants=None,
3118 trainable=True,
3119 dtype=None):
3120 # Pass autocast=False, as if inputs are cast, input types might not match
3121 # Operation type.
3122 super(TensorFlowOpLayer, self).__init__(
3123 name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype,
3124 autocast=False)
3125 if isinstance(node_def, dict):
3126 self.node_def = json_format.ParseDict(node_def, node_def_pb2.NodeDef())
3127 else:
3128 if not isinstance(node_def, bytes):
3129 node_def = node_def.encode('utf-8')
3130 self.node_def = node_def_pb2.NodeDef.FromString(node_def)
3131 # JSON serialization stringifies keys which are integer input indices.
3132 self.constants = ({
3133 int(index): constant for index, constant in constants.items()
3134 } if constants is not None else {})
3135 # Layer uses original op unless it is called on new inputs.
3136 # This means `built` is not set in `__call__`.
3137 self.built = True
3139 # Do not individually trace TensorflowOpLayers in the SavedModel.
3140 self._must_restore_from_config = True
3142 def call(self, inputs):
3143 if context.executing_eagerly():
3144 return self._defun_call(inputs)
3145 return self._make_op(inputs)
3147 def _make_node_def(self, graph):
3148 node_def = node_def_pb2.NodeDef()
3149 node_def.CopyFrom(self.node_def)
3150 # Used by TPUReplicateContext to indicate that this node has already been
3151 # cloned, so TPU attributes should not be added again.
3152 node_def.attr['_cloned'].b = True
3153 node_def.name = graph.unique_name(node_def.name)
3154 return node_def
3156 def _make_op(self, inputs):
3157 inputs = nest.flatten(inputs)
3158 graph = inputs[0].graph
3159 node_def = self._make_node_def(graph)
3160 with graph.as_default():
3161 for index, constant in self.constants.items():
3162 # Recreate constant in graph to add distribution context.
3163 value = tensor_util.constant_value(constant)
3164 if value is not None:
3165 constant = constant_op.constant(value, name=node_def.input[index])
3166 inputs.insert(index, constant)
3167 # TODO(b/183990973): We should drop or consolidate these private api calls
3168 # for adding an op to the graph and recording its gradient.
3169 c_op = ops._create_c_op(graph, node_def, inputs, control_inputs=[])
3170 op = graph._create_op_from_tf_operation(c_op)
3171 op._control_flow_post_processing()
3173 # Record the gradient because custom-made ops don't go through the
3174 # code-gen'd eager call path
3175 op_type = compat.as_str(op.op_def.name)
3176 attr_names = [compat.as_str(attr.name) for attr in op.op_def.attr]
3177 attrs = []
3178 for attr_name in attr_names:
3179 attrs.append(attr_name)
3180 attrs.append(op.get_attr(attr_name))
3181 attrs = tuple(attrs)
3182 backprop.record_gradient(op_type, op.inputs, attrs, op.outputs)
3184 if len(op.outputs) == 1:
3185 return op.outputs[0]
3186 return op.outputs
3188 @def_function.function
3189 def _defun_call(self, inputs):
3190 """Wraps the op creation method in an Eager function for `run_eagerly`."""
3191 return self._make_op(inputs)
3193 def get_config(self):
3194 config = super(TensorFlowOpLayer, self).get_config()
3195 config.update({
3196 # `__init__` prefixes the name. Revert to the constructor argument.
3197 'name': config['name'][len(_TF_OP_LAYER_NAME_PREFIX):],
3198 'node_def': json_format.MessageToDict(self.node_def),
3199 'constants': {
3200 i: backend.get_value(c) for i, c in self.constants.items()
3201 }
3202 })
3203 return config
3206class AddLoss(Layer):
3207 """Adds its inputs as a loss.
3209 Attributes:
3210 unconditional: Whether the loss is unconditional (i.e. not conditioned on the layer's inputs).
3211 """
3213 def __init__(self, unconditional, **kwargs):
3214 # Pass autocast=False, as there is no reason to cast loss to a different
3215 # dtype.
3216 kwargs['autocast'] = False
3217 super(AddLoss, self).__init__(**kwargs)
3218 self.unconditional = unconditional
3220 def call(self, inputs):
3221 self.add_loss(inputs, inputs=(not self.unconditional))
3222 return inputs
3224 def get_config(self):
3225 config = super(AddLoss, self).get_config()
3226 config.update({'unconditional': self.unconditional})
3227 return config
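# Usage sketch (hedged): the functional API builds this layer when a symbolic
# tensor is passed to `add_loss` during graph network construction, roughly:
#
#   loss_out = AddLoss(unconditional=False)(symbolic_loss)
#
# `unconditional=False` marks the loss as conditional on inputs, which is why
# `call` forwards `inputs=(not self.unconditional)` to `add_loss`.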
3230class AddMetric(Layer):
3231 """Adds its inputs as a metric.
3233 Attributes:
3234 aggregation: 'mean' or None. How the inputs should be aggregated.
3235 metric_name: The name to use for this metric.
3236 """
3238 def __init__(self, aggregation=None, metric_name=None, **kwargs):
3239 super(AddMetric, self).__init__(**kwargs)
3240 self.aggregation = aggregation
3241 self.metric_name = metric_name
3243 def call(self, inputs):
3244 self.add_metric(inputs, aggregation=self.aggregation, name=self.metric_name)
3245 return inputs
3247 def get_config(self):
3248 config = super(AddMetric, self).get_config()
3249 config.update({
3250 'aggregation': self.aggregation,
3251 'metric_name': self.metric_name
3252 })
3253 return config
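# Usage sketch (hedged): mirrors AddLoss for symbolic `add_metric` calls, e.g.
#
#   out = AddMetric(aggregation='mean', metric_name='act_mean')(symbolic_value)
#
# which simply forwards the tensor to `add_metric` under the stored name and
# aggregation, returning the input unchanged.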
3256def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list): # pylint: disable=unused-argument
3257 """Check the arguments to see if we are constructing a functional model."""
3258 # We are constructing a functional model if any of the inputs
3259 # are KerasTensors.
3260 return any(
3261 isinstance(tensor, keras_tensor.KerasTensor)
3262 for tensor in nest.flatten([inputs, args, kwargs]))
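# Example (hedged): with `x = tf.keras.Input(shape=(3,))`, `x` is a
# KerasTensor, so this returns True and `Layer.__call__` records functional
# graph topology instead of executing the op; with an eager `tf.constant`
# input it returns False and the layer runs immediately.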
3265def _convert_numpy_or_python_types(x):
3266 if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)):
3267 return tensor_conversion.convert_to_tensor_v2_with_dispatch(x)
3268 return x
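# Example: plain Python and NumPy values are promoted to tensors, everything
# else passes through untouched:
#
#   _convert_numpy_or_python_types(3.0)            # -> tf.Tensor with value 3.0
#   _convert_numpy_or_python_types(np.ones((2,)))  # -> tf.Tensor([1., 1.])
#   _convert_numpy_or_python_types('not numeric')  # -> 'not numeric'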
3271# Avoid breaking users who directly import this symbol from this file.
3272# TODO(fchollet): remove this.
3273InputSpec = input_spec.InputSpec # pylint:disable=invalid-name