Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py: 25%

1235 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=protected-access 

16"""Contains the base Layer class, from which all layers inherit.""" 

17 

18import collections 

19import copy 

20import functools 

21import itertools 

22import threading 

23import warnings 

24import weakref 

25 

26import numpy as np 

27 

28from google.protobuf import json_format 

29from tensorflow.core.framework import node_def_pb2 

30from tensorflow.python import tf2 

31from tensorflow.python.autograph.core import ag_ctx 

32from tensorflow.python.autograph.impl import api as autograph 

33from tensorflow.python.distribute import distribute_lib 

34from tensorflow.python.eager import backprop 

35from tensorflow.python.eager import context 

36from tensorflow.python.eager import def_function 

37from tensorflow.python.framework import constant_op 

38from tensorflow.python.framework import dtypes 

39from tensorflow.python.framework import func_graph 

40from tensorflow.python.framework import ops 

41from tensorflow.python.framework import sparse_tensor 

42from tensorflow.python.framework import tensor_conversion 

43from tensorflow.python.framework import tensor_spec 

44from tensorflow.python.framework import tensor_util 

45from tensorflow.python.keras import backend 

46from tensorflow.python.keras import constraints 

47from tensorflow.python.keras import initializers 

48from tensorflow.python.keras import regularizers 

49from tensorflow.python.keras.engine import base_layer_utils 

50from tensorflow.python.keras.engine import input_spec 

51from tensorflow.python.keras.engine import keras_tensor 

52from tensorflow.python.keras.engine import node as node_module 

53from tensorflow.python.keras.mixed_precision import autocast_variable 

54from tensorflow.python.keras.mixed_precision import loss_scale_optimizer 

55from tensorflow.python.keras.mixed_precision import policy 

56from tensorflow.python.keras.saving.saved_model import layer_serialization 

57from tensorflow.python.keras.utils import generic_utils 

58from tensorflow.python.keras.utils import layer_utils 

59from tensorflow.python.keras.utils import object_identity 

60from tensorflow.python.keras.utils import tf_inspect 

61from tensorflow.python.keras.utils import tf_utils 

62from tensorflow.python.keras.utils import version_utils 

63from tensorflow.python.keras.utils.generic_utils import to_snake_case # pylint: disable=unused-import 

64from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list # pylint: disable=unused-import 

65from tensorflow.python.module import module 

66from tensorflow.python.ops import array_ops 

67from tensorflow.python.ops import math_ops 

68from tensorflow.python.ops import variables as tf_variables 

69from tensorflow.python.ops.numpy_ops import np_arrays 

70from tensorflow.python.ops.ragged import ragged_tensor 

71from tensorflow.python.platform import tf_logging 

72from tensorflow.python.trackable import autotrackable 

73from tensorflow.python.trackable import base as trackable 

74from tensorflow.python.trackable import data_structures 

75from tensorflow.python.util import compat 

76from tensorflow.python.util import nest 

77from tensorflow.python.util.tf_export import get_canonical_name_for_symbol 

78from tensorflow.python.util.tf_export import keras_export 

79from tensorflow.tools.docs import doc_controls 

80 

81# Modules that only depend on `keras.layers` import these from here. 

82 

83# pylint: disable=g-inconsistent-quotes 

84metrics_mod = generic_utils.LazyLoader( 

85 "metrics_mod", globals(), 

86 "tensorflow.python.keras.metrics") 

87# pylint: enable=g-inconsistent-quotes 

88 

89# Prefix that is added to the TF op layer names. 

90_TF_OP_LAYER_NAME_PREFIX = 'tf_op_layer_' 

91 

92# TODO(mdan): Should we have a single generic type for types that can be passed 

93# to tf.cast? 

94_AUTOCAST_TYPES = (ops.Tensor, sparse_tensor.SparseTensor, 

95 ragged_tensor.RaggedTensor) 

96 

97 

98@keras_export('keras.layers.Layer') 

99class Layer(module.Module, version_utils.LayerVersionSelector): 

100 """This is the class from which all layers inherit. 

101 

102 A layer is a callable object that takes as input one or more tensors and 

103 that outputs one or more tensors. It involves *computation*, defined 

104 in the `call()` method, and a *state* (weight variables), defined 

105 either in the constructor `__init__()` or in the `build()` method. 

106 

107 Users will just instantiate a layer and then treat it as a callable. 

108 

109 Args: 

110 trainable: Boolean, whether the layer's variables should be trainable. 

111 name: String name of the layer. 

112 dtype: The dtype of the layer's computations and weights. Can also be a 

113 `tf.keras.mixed_precision.Policy`, which allows the computation and weight 

114 dtype to differ. Default of `None` means to use 

115 `tf.keras.mixed_precision.global_policy()`, which is a float32 policy 

116 unless set to a different value. 

117 dynamic: Set this to `True` if your layer should only be run eagerly, and 

118 should not be used to generate a static computation graph. 

119 This would be the case for a Tree-RNN or a recursive network, 

120 for example, or generally for any layer that manipulates tensors 

121 using Python control flow. If `False`, we assume that the layer can 

122 safely be used to generate a static computation graph. 

123 

124 Attributes: 

125 name: The name of the layer (string). 

126 dtype: The dtype of the layer's weights. 

127 variable_dtype: Alias of `dtype`. 

128 compute_dtype: The dtype of the layer's computations. Layers automatically 

129 cast inputs to this dtype which causes the computations and output to also 

130 be in this dtype. When mixed precision is used with a 

131 `tf.keras.mixed_precision.Policy`, this will be different than 

132 `variable_dtype`. 

133 dtype_policy: The layer's dtype policy. See the 

134 `tf.keras.mixed_precision.Policy` documentation for details. 

135 trainable_weights: List of variables to be included in backprop. 

136 non_trainable_weights: List of variables that should not be 

137 included in backprop. 

138 weights: The concatenation of the lists trainable_weights and 

139 non_trainable_weights (in this order). 

140 trainable: Whether the layer should be trained (boolean), i.e. whether 

141 its potentially-trainable weights should be returned as part of 

142 `layer.trainable_weights`. 

143 input_spec: Optional (list of) `InputSpec` object(s) specifying the 

144 constraints on inputs that can be accepted by the layer. 

145 

146 We recommend that descendants of `Layer` implement the following methods: 

147 

148 * `__init__()`: Defines custom layer attributes, and creates layer state 

149 variables that do not depend on input shapes, using `add_weight()`. 

150 * `build(self, input_shape)`: This method can be used to create weights that 

151 depend on the shape(s) of the input(s), using `add_weight()`. `__call__()` 

152 will automatically build the layer (if it has not been built yet) by 

153 calling `build()`. 

154 * `call(self, inputs, *args, **kwargs)`: Called in `__call__` after making 

155 sure `build()` has been called. `call()` performs the logic of applying the 

156 layer to the input tensors (which should be passed in as argument). 

157 Two reserved keyword arguments you can optionally use in `call()` are: 

158 - `training` (boolean, whether the call is in inference mode or training 

159 mode). See more details in [the layer/model subclassing guide]( 

160 https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_training_argument_in_the_call_method) 

161 - `mask` (boolean tensor encoding masked timesteps in the input, used 

162 in RNN layers). See more details in [the layer/model subclassing guide]( 

163 https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_mask_argument_in_the_call_method) 

164 A typical signature for this method is `call(self, inputs)`, and users could 

165 optionally add `training` and `mask` if the layer needs them. `*args` and 

166 `**kwargs` are only useful for future extensions when more input parameters 

167 are planned to be added. 

168 * `get_config(self)`: Returns a dictionary containing the configuration used 

169 to initialize this layer. If the keys differ from the arguments 

170 in `__init__`, then override `from_config(self)` as well. 

171 This method is used when saving 

172 the layer or a model that contains this layer. 

173 

174 Examples: 

175 

176 Here's a basic example: a layer with two variables, `w` and `b`, 

177 that returns `y = w . x + b`. 

178 It shows how to implement `build()` and `call()`. 

179 Variables set as attributes of a layer are tracked as weights 

180 of the layers (in `layer.weights`). 

181 

182 ```python 

183 class SimpleDense(Layer): 

184 

185 def __init__(self, units=32): 

186 super(SimpleDense, self).__init__() 

187 self.units = units 

188 

189 def build(self, input_shape): # Create the state of the layer (weights) 

190 w_init = tf.random_normal_initializer() 

191 self.w = tf.Variable( 

192 initial_value=w_init(shape=(input_shape[-1], self.units), 

193 dtype='float32'), 

194 trainable=True) 

195 b_init = tf.zeros_initializer() 

196 self.b = tf.Variable( 

197 initial_value=b_init(shape=(self.units,), dtype='float32'), 

198 trainable=True) 

199 

200 def call(self, inputs): # Defines the computation from inputs to outputs 

201 return tf.matmul(inputs, self.w) + self.b 

202 

203 # Instantiates the layer. 

204 linear_layer = SimpleDense(4) 

205 

206 # This will also call `build(input_shape)` and create the weights. 

207 y = linear_layer(tf.ones((2, 2))) 

208 assert len(linear_layer.weights) == 2 

209 

210 # These weights are trainable, so they're listed in `trainable_weights`: 

211 assert len(linear_layer.trainable_weights) == 2 

212 ``` 

213 

214 Note that the method `add_weight()` offers a shortcut to create weights: 

215 

216 ```python 

217 class SimpleDense(Layer): 

218 

219 def __init__(self, units=32): 

220 super(SimpleDense, self).__init__() 

221 self.units = units 

222 

223 def build(self, input_shape): 

224 self.w = self.add_weight(shape=(input_shape[-1], self.units), 

225 initializer='random_normal', 

226 trainable=True) 

227 self.b = self.add_weight(shape=(self.units,), 

228 initializer='random_normal', 

229 trainable=True) 

230 

231 def call(self, inputs): 

232 return tf.matmul(inputs, self.w) + self.b 

233 ``` 

234 

235 Besides trainable weights, updated via backpropagation during training, 

236 layers can also have non-trainable weights. These weights are meant to 

237 be updated manually during `call()`. Here's an example layer that computes 

238 the running sum of its inputs: 

239 

240 ```python 

241 class ComputeSum(Layer): 

242 

243 def __init__(self, input_dim): 

244 super(ComputeSum, self).__init__() 

245 # Create a non-trainable weight. 

246 self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), 

247 trainable=False) 

248 

249 def call(self, inputs): 

250 self.total.assign_add(tf.reduce_sum(inputs, axis=0)) 

251 return self.total 

252 

253 my_sum = ComputeSum(2) 

254 x = tf.ones((2, 2)) 

255 

256 y = my_sum(x) 

257 print(y.numpy()) # [2. 2.] 

258 

259 y = my_sum(x) 

260 print(y.numpy()) # [4. 4.] 

261 

262 assert my_sum.weights == [my_sum.total] 

263 assert my_sum.non_trainable_weights == [my_sum.total] 

264 assert my_sum.trainable_weights == [] 

265 ``` 

266 

267 For more information about creating layers, see the guide 

268 [Making new Layers and Models via subclassing]( 

269 https://www.tensorflow.org/guide/keras/custom_layers_and_models) 

270 """ 

271 

272 # See tf.Module for the usage of this property. 

273 # The key for _obj_reference_counts_dict is a Trackable, which could be a 

274 # variable or layer etc. tf.Module._flatten will fail to flatten the key 

275 # since it is trying to convert Trackable to a string. This attribute can be 

276 # ignored even after the fix of nest lib, since the trackable object should 

277# already be available as individual attributes. _obj_reference_counts_dict 

278 # just contains a copy of them. 

279 _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain( 

280 ('_obj_reference_counts_dict',), 

281 module.Module._TF_MODULE_IGNORED_PROPERTIES 

282 )) 

283 

284 # When loading from a SavedModel, Layers typically can be revived into a 

285 # generic Layer wrapper. Sometimes, however, layers may implement methods 

286 # that go beyond this wrapper, as in the case of PreprocessingLayers' 

287 # `adapt` method. When this is the case, layer implementers can override 

288 # must_restore_from_config to return True; layers with this property must 

289 # be restored into their actual objects (and will fail if the object is 

290 # not available to the restoration code). 

291 _must_restore_from_config = False 

292 

293 def _get_cell_name(self): 

294 canonical_name = get_canonical_name_for_symbol( 

295 self.__class__, api_name='keras', add_prefix_to_v1_names=True) 

296 if canonical_name is not None: 

297 return 'tf.{}'.format(canonical_name) 

298 return self.__class__.__module__ + '.' + self.__class__.__name__ 

299 

300 def _instrument_layer_creation(self): 

301 self._instrumented_keras_api = False 

302 self._instrumented_keras_layer_class = False 

303 self._instrumented_keras_model_class = False 

304 if not getattr(self, '_disable_keras_instrumentation', False): 

305 self._instrumented_keras_api = True 

306 if getattr(self, '_is_model_for_instrumentation', False): 

307 self._instrumented_keras_model_class = True 

308 else: 

309 self._instrumented_keras_layer_class = True 

310 

311 @trackable.no_automatic_dependency_tracking 

312 def __init__(self, 

313 trainable=True, 

314 name=None, 

315 dtype=None, 

316 dynamic=False, 

317 **kwargs): 

318 self._instrument_layer_creation() 

319 

320 # These properties should be set by the user via keyword arguments. 

321 # note that 'dtype', 'input_shape' and 'batch_input_shape' 

322 # are only applicable to input layers: do not pass these keywords 

323 # to non-input layers. 

324 allowed_kwargs = { 

325 'input_dim', 

326 'input_shape', 

327 'batch_input_shape', 

328 'batch_size', 

329 'weights', 

330 'activity_regularizer', 

331 'autocast', 

332 'implementation', 

333 } 

334 # Validate optional keyword arguments. 

335 generic_utils.validate_kwargs(kwargs, allowed_kwargs) 

336 

337 # Mutable properties 

338 # Indicates whether the layer's weights are updated during training 

339 # and whether the layer's updates are run during training. 

340 self._trainable = trainable 

341 # A stateful layer is a layer whose updates are run during inference too, 

342 # for instance stateful RNNs. 

343 self._stateful = False 

344 # Indicates whether `build` needs to be called upon layer call, to create 

345 # the layer's weights. 

346 self.built = False 

347 # Provides information about which inputs are compatible with the layer. 

348 self._input_spec = None 

349 

350 # SavedModel-related attributes. 

351 # Record the build input shape for loading purposes. 

352 # TODO(kathywu): Move this to Layer._set_save_spec once cl/290121460 is 

353 # submitted. 

354 self._build_input_shape = None 

355 self._saved_model_inputs_spec = None 

356 

357 # `Layer.compute_mask` will be called at the end of `Layer.__call__` if 

358 # `Layer.compute_mask` is overridden, or if the `Layer` subclass sets 

359 # `self.supports_masking=True`. 

360 self._supports_masking = not generic_utils.is_default(self.compute_mask) 

361 

362 self._init_set_name(name) 

363 self._activity_regularizer = regularizers.get( 

364 kwargs.pop('activity_regularizer', None)) 

365 self._maybe_create_attribute('_trainable_weights', []) 

366 self._maybe_create_attribute('_non_trainable_weights', []) 

367 self._updates = [] 

368 # Object to store all thread local layer properties. 

369 self._thread_local = threading.local() 

370 # A list of zero-argument lambdas which return Tensors, used for variable 

371 # regularizers. 

372 self._callable_losses = [] 

373 # A list of symbolic Tensors containing activity regularizers and losses 

374 # manually added through `add_loss` in graph-building mode. 

375 self._losses = [] 

376 # A list of metric instances corresponding to the symbolic metric tensors 

377 # added using the `add_metric` API. 

378 self._metrics = [] 

379 # Ensures the same metric is not added multiple times in `MirroredStrategy`. 

380 self._metrics_lock = threading.Lock() 

381 

382 # Both graph and subclassed networks have a dtype policy. For graph 

383 # networks, the policy's compute and variable dtypes are ignored. Such 

384 # networks only use the policy if it is a PolicyV1, in which case it uses 

385 # the PolicyV1's loss_scale (Policy does not have a loss_scale). For 

386 # subclassed networks, the compute and variable dtypes are used like any 

387 # ordinary layer. 

388 self._set_dtype_policy(dtype) 

389 # Boolean indicating whether the layer automatically casts its inputs to the 

390 # layer's compute_dtype. 

391 self._autocast = kwargs.get('autocast', 

392 base_layer_utils.v2_dtype_behavior_enabled()) 

393 

394 # Tracks `TrackableDataStructure`s, `Module`s, and `Layer`s. 

395 # Ordered by when the object was assigned as an attr. 

396 # Entries are unique. 

397 self._maybe_create_attribute('_self_tracked_trackables', []) 

398 

399 # These lists will be filled via successive calls 

400 # to self._add_inbound_node(). 

401 # Used in symbolic mode only, in conjunction with graph networks. 

402 self._inbound_nodes_value = [] 

403 self._outbound_nodes_value = [] 

404 

405 self._init_call_fn_args() 

406 

407 # Whether the `call` method can be used to build a TF graph without issues. 

408 # This attribute has no effect if the model is created using the Functional 

409 # API. Instead, `model.dynamic` is determined based on the internal layers. 

410 self._dynamic = dynamic 

411 

412 # Manage input shape information if passed. 

413 if 'input_dim' in kwargs and 'input_shape' not in kwargs: 

414 # Backwards compatibility: alias 'input_dim' to 'input_shape'. 

415 kwargs['input_shape'] = (kwargs['input_dim'],) 

416 if 'input_shape' in kwargs or 'batch_input_shape' in kwargs: 

417 # In this case we will later create an input layer 

418 # to insert before the current layer 

419 if 'batch_input_shape' in kwargs: 

420 batch_input_shape = tuple(kwargs['batch_input_shape']) 

421 elif 'input_shape' in kwargs: 

422 if 'batch_size' in kwargs: 

423 batch_size = kwargs['batch_size'] 

424 else: 

425 batch_size = None 

426 batch_input_shape = (batch_size,) + tuple(kwargs['input_shape']) 

427 self._batch_input_shape = batch_input_shape 

428 

429 # Manage initial weight values if passed. 

430 self._initial_weights = kwargs.get('weights', None) 

431 

432 # Whether the layer will track any layers that are set as attributes on itself 

433 # as sub-layers; the weights from the sub-layers will be included in the 

434 # parent layer's variables() as well. 

435 # Defaults to True, which means auto tracking is turned on. Certain subclasses 

436 # might want to turn it off, like the Sequential model. 

437 self._auto_track_sub_layers = True 

438 

439 # For backwards compat reasons, most built-in layers do not guarantee 

440 # that they will 100% preserve the structure of input args when saving 

441 # / loading configs. E.g. they may un-nest an arg that is 

442 # a list with one element. 

443 self._preserve_input_structure_in_config = False 

444 

445 @trackable.no_automatic_dependency_tracking 

446 @generic_utils.default 

447 def build(self, input_shape): 

448 """Creates the variables of the layer (optional, for subclass implementers). 

449 

450 This is a method that implementers of subclasses of `Layer` or `Model` 

451 can override if they need a state-creation step in-between 

452 layer instantiation and layer call. 

453 

454 This is typically used to create the weights of `Layer` subclasses. 

455 

456 Args: 

457 input_shape: Instance of `TensorShape`, or list of instances of 

458 `TensorShape` if the layer expects a list of inputs 

459 (one instance per input). 

460 """ 

461 # Only record the build input shapes of overridden build methods. 

462 if not hasattr(self.build, '_is_default'): 

463 self._build_input_shape = input_shape 

464 self.built = True 

465 

466 @doc_controls.for_subclass_implementers 

467 def call(self, inputs, *args, **kwargs): # pylint: disable=unused-argument 

468 """This is where the layer's logic lives. 

469 

470 Note that the `call()` method in `tf.keras` is a little bit different 

471 from the `keras` API. In the `keras` API, masking support can be passed to 

472 layers as additional arguments, whereas `tf.keras` has a `compute_mask()` 

473 method to support masking. 

474 

475 Args: 

476 inputs: Input tensor, or dict/list/tuple of input tensors. 

477 The first positional `inputs` argument is subject to special rules: 

478 - `inputs` must be explicitly passed. A layer cannot have zero 

479 arguments, and `inputs` cannot be provided via the default value 

480 of a keyword argument. 

481 - NumPy array or Python scalar values in `inputs` get cast as tensors. 

482 - Keras mask metadata is only collected from `inputs`. 

483 - Layers are built (`build(input_shape)` method) 

484 using shape info from `inputs` only. 

485 - `input_spec` compatibility is only checked against `inputs`. 

486 - Mixed precision input casting is only applied to `inputs`. 

487 If a layer has tensor arguments in `*args` or `**kwargs`, their 

488 casting behavior in mixed precision should be handled manually. 

489 - The SavedModel input specification is generated using `inputs` only. 

490 - Integration with various ecosystem packages like TFMOT, TFLite, 

491 TF.js, etc is only supported for `inputs` and not for tensors in 

492 positional and keyword arguments. 

493 *args: Additional positional arguments. May contain tensors, although 

494 this is not recommended, for the reasons above. 

495 **kwargs: Additional keyword arguments. May contain tensors, although 

496 this is not recommended, for the reasons above. 

497 The following optional keyword arguments are reserved: 

498 - `training`: Boolean scalar tensor of Python boolean indicating 

499 whether the `call` is meant for training or inference. 

500 - `mask`: Boolean input mask. If the layer's `call()` method takes a 

501 `mask` argument, its default value will be set to the mask generated 

502 for `inputs` by the previous layer (if `input` did come from a layer 

503 that generated a corresponding mask, i.e. if it came from a Keras 

504 layer with masking support). 

505 

506 Returns: 

507 A tensor or list/tuple of tensors. 
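    Example (a minimal, illustrative sketch of a `call()` that uses the
    reserved `training` keyword argument; the layer name `TrainingNoise` is
    hypothetical and not part of this module):

    ```python
    class TrainingNoise(Layer):

      def call(self, inputs, training=None):
        # `training` is resolved by the framework (see `__call__`) when it is
        # not passed explicitly.
        if training:
          return inputs + tf.random.normal(tf.shape(inputs))
        return inputs
    ```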

508 """ 

509 return inputs 

510 

511 @doc_controls.for_subclass_implementers 

512 def _add_trackable(self, trackable_object, trainable): 

513 """Adds a Trackable object to this layer's state. 

514 

515 Args: 

516 trackable_object: The tf.tracking.Trackable object to add. 

517 trainable: Boolean, whether the variable should be part of the layer's 

518 "trainable_variables" (e.g. variables, biases) or 

519 "non_trainable_variables" (e.g. BatchNorm mean and variance). 

520 

521 Returns: 

522 The TrackableWeightHandler used to track this object. 

523 """ 

524 if isinstance(trackable_object, base_layer_utils.TrackableWeightHandler): 

525 handler = trackable_object 

526 else: 

527 handler = base_layer_utils.TrackableWeightHandler(trackable_object) 

528 if trainable: 

529 self._trainable_weights.append(handler) 

530 else: 

531 self._non_trainable_weights.append(handler) 

532 return handler 

533 

534 @doc_controls.for_subclass_implementers 

535 def add_weight(self, 

536 name=None, 

537 shape=None, 

538 dtype=None, 

539 initializer=None, 

540 regularizer=None, 

541 trainable=None, 

542 constraint=None, 

543 use_resource=None, 

544 synchronization=tf_variables.VariableSynchronization.AUTO, 

545 aggregation=tf_variables.VariableAggregation.NONE, 

546 **kwargs): 

547 """Adds a new variable to the layer. 

548 

549 Args: 

550 name: Variable name. 

551 shape: Variable shape. Defaults to scalar if unspecified. 

552 dtype: The type of the variable. Defaults to `self.dtype`. 

553 initializer: Initializer instance (callable). 

554 regularizer: Regularizer instance (callable). 

555 trainable: Boolean, whether the variable should be part of the layer's 

556 "trainable_variables" (e.g. variables, biases) 

557 or "non_trainable_variables" (e.g. BatchNorm mean and variance). 

558 Note that `trainable` cannot be `True` if `synchronization` 

559 is set to `ON_READ`. 

560 constraint: Constraint instance (callable). 

561 use_resource: Whether to use `ResourceVariable`. 

562 synchronization: Indicates when a distributed variable will be 

563 aggregated. Accepted values are constants defined in the class 

564 `tf.VariableSynchronization`. By default the synchronization is set to 

565 `AUTO` and the current `DistributionStrategy` chooses 

566 when to synchronize. If `synchronization` is set to `ON_READ`, 

567 `trainable` must not be set to `True`. 

568 aggregation: Indicates how a distributed variable will be aggregated. 

569 Accepted values are constants defined in the class 

570 `tf.VariableAggregation`. 

571 **kwargs: Additional keyword arguments. Accepted values are `getter`, 

572 `collections`, `experimental_autocast` and `caching_device`. 

573 

574 Returns: 

575 The variable created. 

576 

577 Raises: 

578 ValueError: When giving unsupported dtype and no initializer or when 

579 trainable has been set to True with synchronization set as `ON_READ`. 
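    Example (a minimal sketch of typical usage inside `build()`; the layer and
    weight names are illustrative only):

    ```python
    class MyDense(Layer):

      def __init__(self, units=8):
        super(MyDense, self).__init__()
        self.units = units

      def build(self, input_shape):
        # Trainable kernel created lazily, once the input shape is known.
        self.kernel = self.add_weight(
            name='kernel',
            shape=(input_shape[-1], self.units),
            initializer='glorot_uniform',
            trainable=True)
        # Non-trainable scalar counter, updated manually in `call()`.
        self.call_count = self.add_weight(
            name='call_count', shape=(), dtype='int64',
            initializer='zeros', trainable=False)

      def call(self, inputs):
        self.call_count.assign_add(1)
        return tf.matmul(inputs, self.kernel)
    ```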

580 """ 

581 if shape is None: 

582 shape = () 

583 kwargs.pop('partitioner', None) # Ignored. 

584 # Validate optional keyword arguments. 

585 for kwarg in kwargs: 

586 if kwarg not in ['collections', 'experimental_autocast', 

587 'caching_device', 'getter']: 

588 raise TypeError('Unknown keyword argument:', kwarg) 

589 collections_arg = kwargs.pop('collections', None) 

590 # 'experimental_autocast' can be set to False by the caller to indicate an 

591 # AutoCastVariable should never be created. 

592 autocast = kwargs.pop('experimental_autocast', True) 

593 # See the docstring for tf.Variable about the details for caching_device. 

594 caching_device = kwargs.pop('caching_device', None) 

595 

596 if dtype is None: 

597 dtype = self.dtype or backend.floatx() 

598 dtype = dtypes.as_dtype(dtype) 

599 if self._dtype_policy.variable_dtype is None: 

600 # The policy is "_infer", so we infer the policy from the variable dtype. 

601 self._set_dtype_policy(policy.Policy(dtype.base_dtype.name)) 

602 initializer = initializers.get(initializer) 

603 regularizer = regularizers.get(regularizer) 

604 constraint = constraints.get(constraint) 

605 

606 if synchronization == tf_variables.VariableSynchronization.ON_READ: 

607 if trainable: 

608 raise ValueError( 

609 'Synchronization value can be set to ' 

610 'VariableSynchronization.ON_READ only for non-trainable variables. ' 

611 'You have specified trainable=True and ' 

612 'synchronization=VariableSynchronization.ON_READ.') 

613 else: 

614 # Set trainable to be false when variable is to be synced on read. 

615 trainable = False 

616 elif trainable is None: 

617 trainable = True 

618 

619 # Initialize variable when no initializer provided 

620 if initializer is None: 

621 # If dtype is DT_FLOAT, provide a uniform unit scaling initializer 

622 if dtype.is_floating: 

623 initializer = initializers.get('glorot_uniform') 

624 # If dtype is DT_INT/DT_UINT, provide a default value `zero` 

625 # If dtype is DT_BOOL, provide a default value `FALSE` 

626 elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool: 

627 initializer = initializers.get('zeros') 

628 # NOTE: Do we need to add support for handling DT_STRING and DT_COMPLEX here? 

629 elif 'getter' not in kwargs: 

630 # When `getter` is specified, it's possibly fine for `initializer` to be 

631 # None since it's up to the custom `getter` to raise error in case it 

632 # indeed needs `initializer`. 

633 raise ValueError('An initializer for variable %s of type %s is required' 

634 ' for layer %s' % (name, dtype.base_dtype, self.name)) 

635 

636 getter = kwargs.pop('getter', base_layer_utils.make_variable) 

637 if (autocast and 

638 self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype 

639 and dtype.is_floating): 

640 old_getter = getter 

641 # Wrap variable constructor to return an AutoCastVariable. 

642 def getter(*args, **kwargs): # pylint: disable=function-redefined 

643 variable = old_getter(*args, **kwargs) 

644 return autocast_variable.create_autocast_variable(variable) 

645 # Also the caching_device does not work with the mixed precision API, 

646 # disable it if it is specified. 

647 # TODO(b/142020079): Reenable it once the bug is fixed. 

648 if caching_device is not None: 

649 tf_logging.warning( 

650 '`caching_device` does not work with mixed precision API. Ignoring ' 

651 'user specified `caching_device`.') 

652 caching_device = None 

653 

654 variable = self._add_variable_with_custom_getter( 

655 name=name, 

656 shape=shape, 

657 # TODO(allenl): a `make_variable` equivalent should be added as a 

658 # `Trackable` method. 

659 getter=getter, 

660 # Manage errors in Layer rather than Trackable. 

661 overwrite=True, 

662 initializer=initializer, 

663 dtype=dtype, 

664 constraint=constraint, 

665 trainable=trainable, 

666 use_resource=use_resource, 

667 collections=collections_arg, 

668 synchronization=synchronization, 

669 aggregation=aggregation, 

670 caching_device=caching_device) 

671 if regularizer is not None: 

672 # TODO(fchollet): in the future, this should be handled at the 

673 # level of variable creation, and weight regularization losses 

674 # should be variable attributes. 

675 name_in_scope = variable.name[:variable.name.find(':')] 

676 self._handle_weight_regularization(name_in_scope, 

677 variable, 

678 regularizer) 

679 if base_layer_utils.is_split_variable(variable): 

680 for v in variable: 

681 backend.track_variable(v) 

682 if trainable: 

683 self._trainable_weights.append(v) 

684 else: 

685 self._non_trainable_weights.append(v) 

686 else: 

687 backend.track_variable(variable) 

688 if trainable: 

689 self._trainable_weights.append(variable) 

690 else: 

691 self._non_trainable_weights.append(variable) 

692 return variable 

693 

694 @generic_utils.default 

695 def get_config(self): 

696 """Returns the config of the layer. 

697 

698 A layer config is a Python dictionary (serializable) 

699 containing the configuration of a layer. 

700 The same layer can be reinstantiated later 

701 (without its trained weights) from this configuration. 

702 

703 The config of a layer does not include connectivity 

704 information, nor the layer class name. These are handled 

705 by `Network` (one layer of abstraction above). 

706 

707 Note that `get_config()` is not guaranteed to return a fresh copy of the dict 

708 every time it is called. Callers should make a copy of the returned dict 

709 if they want to modify it. 

710 

711 Returns: 

712 Python dictionary. 
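    Example (an illustrative sketch of overriding `get_config()` for a layer
    with an extra constructor argument; the `Scale` layer is hypothetical):

    ```python
    class Scale(Layer):

      def __init__(self, factor=2.0, **kwargs):
        super(Scale, self).__init__(**kwargs)
        self.factor = factor

      def call(self, inputs):
        return inputs * self.factor

      def get_config(self):
        config = super(Scale, self).get_config()
        config.update({'factor': self.factor})
        return config
    ```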

713 """ 

714 all_args = tf_inspect.getfullargspec(self.__init__).args 

715 config = { 

716 'name': self.name, 

717 'trainable': self.trainable, 

718 } 

719 if hasattr(self, '_batch_input_shape'): 

720 config['batch_input_shape'] = self._batch_input_shape 

721 config['dtype'] = policy.serialize(self._dtype_policy) 

722 if hasattr(self, 'dynamic'): 

723 # Only include `dynamic` in the `config` if it is `True` 

724 if self.dynamic: 

725 config['dynamic'] = self.dynamic 

726 elif 'dynamic' in all_args: 

727 all_args.remove('dynamic') 

728 expected_args = config.keys() 

729 # Finds all arguments in the `__init__` that are not in the config: 

730 extra_args = [arg for arg in all_args if arg not in expected_args] 

731 # Check that either the only argument in the `__init__` is `self`, 

732 # or that `get_config` has been overridden: 

733 if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'): 

734 raise NotImplementedError('Layer %s has arguments in `__init__` and ' 

735 'therefore must override `get_config`.' % 

736 self.__class__.__name__) 

737 return config 

738 

739 @classmethod 

740 def from_config(cls, config): 

741 """Creates a layer from its config. 

742 

743 This method is the reverse of `get_config`, 

744 capable of instantiating the same layer from the config 

745 dictionary. It does not handle layer connectivity 

746 (handled by Network), nor weights (handled by `set_weights`). 

747 

748 Args: 

749 config: A Python dictionary, typically the 

750 output of get_config. 

751 

752 Returns: 

753 A layer instance. 
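    Example (an illustrative round trip through the config; only constructor
    arguments are carried over, not weights):

    ```python
    layer = tf.keras.layers.Dense(4, activation='relu')
    config = layer.get_config()
    clone = tf.keras.layers.Dense.from_config(config)  # same config, fresh weights
    ```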

754 """ 

755 return cls(**config) 

756 

757 def compute_output_shape(self, input_shape): 

758 """Computes the output shape of the layer. 

759 

760 If the layer has not been built, this method will call `build` on the 

761 layer. This assumes that the layer will later be used with inputs that 

762 match the input shape provided here. 

763 

764 Args: 

765 input_shape: Shape tuple (tuple of integers) 

766 or list of shape tuples (one per output tensor of the layer). 

767 Shape tuples can include None for free dimensions, 

768 instead of an integer. 

769 

770 Returns: 

771 An output shape tuple. 
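    Example (a minimal sketch of overriding `compute_output_shape` for a layer
    that pads along the last axis; the `Pad1D` layer is hypothetical):

    ```python
    class Pad1D(Layer):

      def __init__(self, padding=1, **kwargs):
        super(Pad1D, self).__init__(**kwargs)
        self.padding = padding

      def call(self, inputs):
        return tf.pad(inputs, [[0, 0], [self.padding, self.padding]])

      def compute_output_shape(self, input_shape):
        # (batch, length) -> (batch, length + 2 * padding); `None` stays free.
        length = input_shape[1]
        if length is not None:
          length = length + 2 * self.padding
        return tf.TensorShape([input_shape[0], length])
    ```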

772 """ 

773 if context.executing_eagerly(): 

774 # In this case we build the model first in order to do shape inference. 

775 # This is acceptable because the framework only calls 

776 # `compute_output_shape` on shape values that the layer would later be 

777 # built for. It would however cause issues in case a user attempts to 

778 # use `compute_output_shape` manually with shapes that are incompatible 

779 # with the shape the Layer will be called on (these users will have to 

780 # implement `compute_output_shape` themselves). 

781 self._maybe_build(input_shape) 

782 with func_graph.FuncGraph(str(self.name) + '_scratch_graph').as_default(): 

783 input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) 

784 def _make_placeholder_like(shape): 

785 ph = backend.placeholder(shape=shape, dtype=self.dtype) 

786 ph._keras_mask = None 

787 return ph 

788 inputs = nest.map_structure(_make_placeholder_like, input_shape) 

789 try: 

790 outputs = self(inputs, training=False) 

791 except TypeError as e: 

792 raise NotImplementedError( 

793 'We could not automatically infer the static shape of the ' 

794 'layer\'s output. Please implement the ' 

795 '`compute_output_shape` method on your layer (%s).' % 

796 self.__class__.__name__) from e 

797 return nest.map_structure(lambda t: t.shape, outputs) 

798 raise NotImplementedError( 

799 'Please run in eager mode or implement the `compute_output_shape` ' 

800 'method on your layer (%s).' % self.__class__.__name__) 

801 

802 @doc_controls.for_subclass_implementers 

803 def compute_output_signature(self, input_signature): 

804 """Compute the output tensor signature of the layer based on the inputs. 

805 

806 Unlike a TensorShape object, a TensorSpec object contains both shape 

807 and dtype information for a tensor. This method allows layers to provide 

808 output dtype information if it is different from the input dtype. 

809 For any layer that doesn't implement this function, 

810 the framework will fall back to using `compute_output_shape`, and will 

811 assume that the output dtype matches the input dtype. 

812 

813 Args: 

814 input_signature: Single TensorSpec or nested structure of TensorSpec 

815 objects, describing a candidate input for the layer. 

816 

817 Returns: 

818 Single TensorSpec or nested structure of TensorSpec objects, describing 

819 how the layer would transform the provided input. 

820 

821 Raises: 

822 TypeError: If input_signature contains a non-TensorSpec object. 

823 """ 

824 def check_type_return_shape(s): 

825 if not isinstance(s, tensor_spec.TensorSpec): 

826 raise TypeError('Only TensorSpec signature types are supported, ' 

827 'but saw signature entry: {}.'.format(s)) 

828 return s.shape 

829 input_shape = nest.map_structure(check_type_return_shape, input_signature) 

830 output_shape = self.compute_output_shape(input_shape) 

831 dtype = self._compute_dtype 

832 if dtype is None: 

833 input_dtypes = [s.dtype for s in nest.flatten(input_signature)] 

834 # Default behavior when self.dtype is None, is to use the first input's 

835 # dtype. 

836 dtype = input_dtypes[0] 

837 return nest.map_structure( 

838 lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s), 

839 output_shape) 

840 

841 def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs): 

842 if self.dynamic: 

843 # We will use static shape inference to return symbolic tensors 

844 # matching the specifications of the layer outputs. 

845 # Since `self.dynamic` is True, we will never attempt to 

846 # run the underlying TF graph (which is disconnected). 

847 # TODO(fchollet): consider py_func as an alternative, which 

848 # would enable us to run the underlying graph if needed. 

849 input_signature = nest.map_structure( 

850 lambda x: tensor_spec.TensorSpec(shape=x.shape, dtype=x.dtype), 

851 inputs) 

852 output_signature = self.compute_output_signature(input_signature) 

853 return nest.map_structure(keras_tensor.KerasTensor, output_signature) 

854 else: 

855 return self._infer_output_signature(inputs, args, kwargs, input_masks) 

856 

857 def _infer_output_signature(self, inputs, args, kwargs, input_masks): 

858 """Infers output KerasTensors by calling the layer on placeholders in a scratch graph.""" 

859 

860 call_fn = self.call 

861 # Wrapping `call` function in autograph to allow for dynamic control 

862 # flow and control dependencies in call. We are limiting this to 

863 # subclassed layers as autograph is strictly needed only for 

864 # subclassed layers and models. 

865 # tf_convert will respect the value of autograph setting in the 

866 # enclosing tf.function, if any. 

867 if (base_layer_utils.is_subclassed(self) and 

868 not base_layer_utils.from_saved_model(self)): 

869 call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx()) 

870 

871 # We enter a scratch graph and build placeholder inputs inside of it that 

872 # match the input args. 

873 # We then call the layer inside of the scratch graph to identify the 

874 # output signatures, then we build KerasTensors corresponding to those 

875 # outputs. 

876 scratch_graph = func_graph.FuncGraph(str(self.name) + '_scratch_graph') 

877 with scratch_graph.as_default(): 

878 inputs = nest.map_structure( 

879 keras_tensor.keras_tensor_to_placeholder, inputs) 

880 args = nest.map_structure( 

881 keras_tensor.keras_tensor_to_placeholder, args) 

882 kwargs = nest.map_structure( 

883 keras_tensor.keras_tensor_to_placeholder, kwargs) 

884 input_masks = nest.map_structure( 

885 keras_tensor.keras_tensor_to_placeholder, input_masks) 

886 

887 with backend.name_scope(self._name_scope()): # pylint: disable=not-callable 

888 with autocast_variable.enable_auto_cast_variables( 

889 self._compute_dtype_object): 

890 # Build layer if applicable (if the `build` method has been 

891 # overridden). 

892 # TODO(kaftan): do we maybe_build here, or have we already done it? 

893 self._maybe_build(inputs) 

894 inputs = self._maybe_cast_inputs(inputs) 

895 outputs = call_fn(inputs, *args, **kwargs) 

896 

897 self._handle_activity_regularization(inputs, outputs) 

898 self._set_mask_metadata(inputs, outputs, input_masks, 

899 build_graph=False) 

900 outputs = nest.map_structure( 

901 keras_tensor.keras_tensor_from_tensor, outputs) 

902 

903 if hasattr(self, '_set_inputs') and not self.inputs: 

904 # TODO(kaftan): figure out if we need to do this at all 

905 # Subclassed network: explicitly set metadata normally set by 

906 # a call to self._set_inputs(). 

907 self._set_inputs(inputs, outputs) 

908 del scratch_graph 

909 return outputs 

910 

911 @generic_utils.default 

912 def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument 

913 """Computes an output mask tensor. 

914 

915 Args: 

916 inputs: Tensor or list of tensors. 

917 mask: Tensor or list of tensors. 

918 

919 Returns: 

920 None or a tensor (or list of tensors, 

921 one per output tensor of the layer). 
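    Example (an illustrative sketch of a masking-aware layer that drops the
    first timestep and slices the mask to match; the layer name is
    hypothetical):

    ```python
    class DropFirstStep(Layer):

      def __init__(self, **kwargs):
        super(DropFirstStep, self).__init__(**kwargs)
        self.supports_masking = True

      def call(self, inputs):
        return inputs[:, 1:]

      def compute_mask(self, inputs, mask=None):
        # Keep the mask aligned with the sliced output.
        if mask is None:
          return None
        return mask[:, 1:]
    ```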

922 """ 

923 if not self._supports_masking: 

924 if any(m is not None for m in nest.flatten(mask)): 

925 raise TypeError('Layer ' + self.name + ' does not support masking, ' 

926 'but was passed an input_mask: ' + str(mask)) 

927 # masking not explicitly supported: return None as mask. 

928 return None 

929 # if masking is explicitly supported, by default 

930 # carry over the input mask 

931 return mask 

932 

933 def __call__(self, *args, **kwargs): 

934 """Wraps `call`, applying pre- and post-processing steps. 

935 

936 Args: 

937 *args: Positional arguments to be passed to `self.call`. 

938 **kwargs: Keyword arguments to be passed to `self.call`. 

939 

940 Returns: 

941 Output tensor(s). 

942 

943 Note: 

944 - The following optional keyword arguments are reserved for specific uses: 

945 * `training`: Boolean scalar tensor of Python boolean indicating 

946 whether the `call` is meant for training or inference. 

947 * `mask`: Boolean input mask. 

948 - If the layer's `call` method takes a `mask` argument (as some Keras 

949 layers do), its default value will be set to the mask generated 

950 for `inputs` by the previous layer (if `input` did come from 

951 a layer that generated a corresponding mask, i.e. if it came from 

952 a Keras layer with masking support). 

953 - If the layer is not built, the method will call `build`. 

954 

955 Raises: 

956 ValueError: if the layer's `call` method returns None (an invalid value). 

957 RuntimeError: if `super().__init__()` was not called in the constructor. 

958 """ 

959 if not hasattr(self, '_thread_local'): 

960 raise RuntimeError( 

961 'You must call `super().__init__()` in the layer constructor.') 

962 

963 # `inputs` (the first arg in the method spec) is special cased in 

964 # layer call due to historical reasons. 

965 # This special casing currently takes the form of: 

966 # - 'inputs' must be explicitly passed. A layer cannot have zero arguments, 

967 # and inputs cannot have been provided via the default value of a kwarg. 

968 # - numpy/scalar values in `inputs` get converted to tensors 

969 # - implicit masks / mask metadata are only collected from `inputs` 

970 # - Layers are built using shape info from 'inputs' only 

971 # - input_spec compatibility is only checked against `inputs` 

972 # - mixed precision casting (autocast) is only applied to `inputs`, 

973 # not to any other argument. 

974 # - setting the SavedModel saving spec. 

975 inputs, args, kwargs = self._split_out_first_arg(args, kwargs) 

976 input_list = nest.flatten(inputs) 

977 

978 # Functional Model construction mode is invoked when `Layer`s are called on 

979 # symbolic `KerasTensor`s, i.e.: 

980 # >> inputs = tf.keras.Input(10) 

981 # >> outputs = MyLayer()(inputs) # Functional construction mode. 

982 # >> model = tf.keras.Model(inputs, outputs) 

983 if _in_functional_construction_mode(self, inputs, args, kwargs, input_list): 

984 return self._functional_construction_call(inputs, args, kwargs, 

985 input_list) 

986 

987 # Maintains info about the `Layer.call` stack. 

988 call_context = base_layer_utils.call_context() 

989 

990 # Accept NumPy and scalar inputs by converting to Tensors. 

991 if any(isinstance(x, ( 

992 np_arrays.ndarray, np.ndarray, float, int)) for x in input_list): 

993 inputs = nest.map_structure(_convert_numpy_or_python_types, inputs) 

994 input_list = nest.flatten(inputs) 

995 

996 # Handle `mask` propagation from previous layer to current layer. Masks can 

997 # be propagated explicitly via the `mask` argument, or implicitly via 

998 # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed 

999 # explicitly take priority. 

1000 input_masks, mask_is_implicit = self._get_input_masks( 

1001 inputs, input_list, args, kwargs) 

1002 if self._expects_mask_arg and mask_is_implicit: 

1003 kwargs['mask'] = input_masks 

1004 

1005 # Training mode for `Layer.call` is set via (in order of priority): 

1006 # (1) The `training` argument passed to this `Layer.call`, if it is not None 

1007 # (2) The training mode of an outer `Layer.call`. 

1008 # (3) The default mode set by `tf.keras.backend.set_learning_phase` (if set) 

1009 # (4) Any non-None default value for `training` specified in the call 

1010 # signature 

1011 # (5) False (treating the layer as if it's in inference) 

1012 args, kwargs, training_mode = self._set_training_mode( 

1013 args, kwargs, call_context) 

1014 

1015 # Losses are cleared for all sublayers on the outermost `Layer.call`. 

1016 # Losses are not cleared on inner `Layer.call`s, because sublayers can be 

1017 # called multiple times. 

1018 if not call_context.in_call: 

1019 self._clear_losses() 

1020 

1021 eager = context.executing_eagerly() 

1022 with call_context.enter( 

1023 layer=self, 

1024 inputs=inputs, 

1025 build_graph=not eager, 

1026 training=training_mode): 

1027 

1028 input_spec.assert_input_compatibility(self.input_spec, inputs, self.name) 

1029 if eager: 

1030 call_fn = self.call 

1031 name_scope = self._name 

1032 else: 

1033 name_scope = self._name_scope() # Avoid autoincrementing. # pylint: disable=not-callable 

1034 call_fn = self._autographed_call() 

1035 

1036 with ops.name_scope_v2(name_scope): 

1037 if not self.built: 

1038 self._maybe_build(inputs) 

1039 

1040 if self._autocast: 

1041 inputs = self._maybe_cast_inputs(inputs, input_list) 

1042 

1043 with autocast_variable.enable_auto_cast_variables( 

1044 self._compute_dtype_object): 

1045 outputs = call_fn(inputs, *args, **kwargs) 

1046 

1047 if self._activity_regularizer: 

1048 self._handle_activity_regularization(inputs, outputs) 

1049 if self._supports_masking: 

1050 self._set_mask_metadata(inputs, outputs, input_masks, not eager) 

1051 if self._saved_model_inputs_spec is None: 

1052 self._set_save_spec(inputs) 

1053 

1054 return outputs 

1055 

1056 def _functional_construction_call(self, inputs, args, kwargs, input_list): 

1057 call_context = base_layer_utils.call_context() 

1058 

1059 # Accept NumPy and scalar inputs by converting to Tensors. 

1060 if any(isinstance(x, ( 

1061 np_arrays.ndarray, np.ndarray, float, int)) for x in input_list): 

1062 

1063 def _convert_non_tensor(x): 

1064 # Don't call `ops.convert_to_tensor` on all `inputs` because 

1065 # `SparseTensors` can't be converted to `Tensor`. 

1066 if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)): 

1067 return tensor_conversion.convert_to_tensor_v2_with_dispatch(x) 

1068 return x 

1069 

1070 inputs = nest.map_structure(_convert_non_tensor, inputs) 

1071 input_list = nest.flatten(inputs) 

1072 

1073 # Handle `mask` propagation from previous layer to current layer. Masks can 

1074 # be propagated explicitly via the `mask` argument, or implicitly via 

1075 # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed 

1076 # explicitly take priority. 

1077 mask_arg_passed_by_framework = False 

1078 input_masks, mask_is_implicit = self._get_input_masks( 

1079 inputs, input_list, args, kwargs) 

1080 if self._expects_mask_arg and mask_is_implicit: 

1081 kwargs['mask'] = input_masks 

1082 mask_arg_passed_by_framework = True 

1083 

1084 # If `training` argument is None or not explicitly passed, 

1085 # propagate `training` value from this layer's calling layer. 

1086 training_value = None 

1087 training_arg_passed_by_framework = False 

1088 # Priority 1: `training` was explicitly passed a non-None value. 

1089 if self._call_arg_was_passed('training', args, kwargs): 

1090 training_value = self._get_call_arg_value('training', args, kwargs) 

1091 if not self._expects_training_arg: 

1092 kwargs.pop('training') 

1093 

1094 if training_value is None: 

1095 # Priority 2: `training` was passed to a parent layer. 

1096 if call_context.training is not None: 

1097 training_value = call_context.training 

1098 # Priority 3: `learning_phase()` has been set. 

1099 elif backend.global_learning_phase_is_set(): 

1100 training_value = backend.learning_phase() 

1101 # Force the training_value to be bool type, which matches the contract 

1102 # for layer/model call args. 

1103 if tensor_util.is_tf_type(training_value): 

1104 training_value = math_ops.cast(training_value, dtypes.bool) 

1105 else: 

1106 training_value = bool(training_value) 

1107 # Priority 4: trace layer with the default training argument specified 

1108 # in the `call` signature (or in inference mode if the `call` signature 

1109 # specifies no non-None default). 

1110 else: 

1111 training_value = self._default_training_arg 

1112 # In cases (2), (3), (4) the training argument is passed automatically 

1113 # by the framework, and will not be hard-coded into the model. 

1114 if self._expects_training_arg: 

1115 args, kwargs = self._set_call_arg_value('training', training_value, 

1116 args, kwargs) 

1117 training_arg_passed_by_framework = True 

1118 

1119 with call_context.enter( 

1120 layer=self, inputs=inputs, build_graph=True, training=training_value): 

1121 # Check input assumptions set after layer building, e.g. input shape. 

1122 outputs = self._keras_tensor_symbolic_call( 

1123 inputs, input_masks, args, kwargs) 

1124 

1125 if outputs is None: 

1126 raise ValueError('A layer\'s `call` method should return a ' 

1127 'Tensor or a list of Tensors, not None ' 

1128 '(layer: ' + self.name + ').') 

1129 if training_arg_passed_by_framework: 

1130 args, kwargs = self._set_call_arg_value( 

1131 'training', None, args, kwargs, pop_kwarg_if_none=True) 

1132 if mask_arg_passed_by_framework: 

1133 kwargs.pop('mask') 

1134 # Node connectivity does not special-case the first argument. 

1135 outputs = self._set_connectivity_metadata((inputs,) + args, kwargs, 

1136 outputs) 

1137 return outputs 

1138 

1139 def _set_training_mode(self, args, kwargs, call_context): 

1140 training_mode = None 

1141 if self._expects_training_arg: 

1142 # (1) `training` was passed to this `Layer.call`. 

1143 if self._call_arg_was_passed('training', args, kwargs): 

1144 training_mode = self._get_call_arg_value('training', args, kwargs) 

1145 # If no `training` arg was passed, or `None` was explicitly passed, 

1146 # the framework will decide what the training mode is. 

1147 if training_mode is None: 

1148 call_ctx_training = call_context.training 

1149 # (2) `training` mode is inferred from an outer `Layer.call`. 

1150 if call_ctx_training is not None: 

1151 training_mode = call_ctx_training 

1152 # (3) User set `tf.keras.backend.set_learning_phase`. 

1153 elif backend.global_learning_phase_is_set(): 

1154 training_mode = backend.learning_phase() 

1155 # Ensure value is a `bool` or `tf.bool`. 

1156 if isinstance(training_mode, bool): 

1157 pass 

1158 elif tensor_util.is_tf_type(training_mode): 

1159 training_mode = math_ops.cast(training_mode, dtypes.bool) 

1160 else: 

1161 training_mode = bool(training_mode) 

1162 # (4) We default to using `call`'s default value for `training`, 

1163 # or treating the layer as if it is in inference if no non-None default 

1164 # is specified in the `call` signature. 

1165 else: 

1166 training_mode = self._default_training_arg 

1167 

1168 # For case (2), (3), (4) `training` arg is passed by framework. 

1169 args, kwargs = self._set_call_arg_value('training', training_mode, args, 

1170 kwargs) 

1171 else: 

1172 if 'training' in kwargs: 

1173 # `training` was passed to this `Layer` but is not needed for 

1174 # `Layer.call`. It will set the default mode for inner `Layer.call`s. 

1175 training_mode = kwargs.pop('training') 

1176 else: 

1177 # Grab the current `training` mode from any outer `Layer.call`. 

1178 training_mode = call_context.training 

1179 

1180 return args, kwargs, training_mode 

1181 

1182 def _autographed_call(self): 

1183 # Wrapping `call` function in autograph to allow for dynamic control 

1184 # flow and control dependencies in call. We are limiting this to 

1185 # subclassed layers as autograph is strictly needed only for 

1186 # subclassed layers and models. 

1187 # tf_convert will respect the value of autograph setting in the 

1188 # enclosing tf.function, if any. 

1189 if (base_layer_utils.is_subclassed(self) and 

1190 not base_layer_utils.from_saved_model(self)): 

1191 return autograph.tf_convert(self.call, ag_ctx.control_status_ctx()) 

1192 else: 

1193 return self.call 

1194 

1195 @property 

1196 def dtype(self): 

1197 """The dtype of the layer weights. 

1198 

1199 This is equivalent to `Layer.dtype_policy.variable_dtype`. Unless 

1200 mixed precision is used, this is the same as `Layer.compute_dtype`, the 

1201 dtype of the layer's computations. 

1202 """ 

1203 return self._dtype_policy.variable_dtype 

1204 

1205 @property 

1206 def name(self): 

1207 """Name of the layer (string), set in the constructor.""" 

1208 return self._name 

1209 

1210 @property 

1211 def supports_masking(self): 

1212 """Whether this layer supports computing a mask using `compute_mask`.""" 

1213 return self._supports_masking 

1214 

1215 @supports_masking.setter 

1216 def supports_masking(self, value): 

1217 self._supports_masking = value 

1218 

1219 @property 

1220 def dynamic(self): 

1221 """Whether the layer is dynamic (eager-only); set in the constructor.""" 

1222 return any(layer._dynamic for layer in self._flatten_layers()) 

1223 

1224 @property 

1225 @doc_controls.do_not_doc_inheritable 

1226 def stateful(self): 

1227 return any(layer._stateful for layer in self._flatten_layers()) 

1228 

1229 @stateful.setter 

1230 def stateful(self, value): 

1231 self._stateful = value 

1232 

1233 @property 

1234 def trainable(self): 

1235 return self._trainable 

1236 

1237 @trainable.setter 

1238 def trainable(self, value): 

1239 for layer in self._flatten_layers(): 

1240 layer._trainable = value 

1241 

1242 @property 

1243 def activity_regularizer(self): 

1244 """Optional regularizer function for the output of this layer.""" 

1245 return self._activity_regularizer 

1246 

1247 @activity_regularizer.setter 

1248 def activity_regularizer(self, regularizer): 

1249 """Optional regularizer function for the output of this layer.""" 

1250 self._activity_regularizer = regularizer 

1251 

1252 @property 

1253 def input_spec(self): 

1254 """`InputSpec` instance(s) describing the input format for this layer. 

1255 

1256 When you create a layer subclass, you can set `self.input_spec` to enable 

1257 the layer to run input compatibility checks when it is called. 

1258 Consider a `Conv2D` layer: it can only be called on a single input tensor 

1259 of rank 4. As such, you can set, in `__init__()`: 

1260 

1261 ```python 

1262 self.input_spec = tf.keras.layers.InputSpec(ndim=4) 

1263 ``` 

1264 

1265 Now, if you try to call the layer on an input that isn't rank 4 

1266 (for instance, an input of shape `(2,)`), it will raise a nicely-formatted 

1267 error: 

1268 

1269 ``` 

1270 ValueError: Input 0 of layer conv2d is incompatible with the layer: 

1271 expected ndim=4, found ndim=1. Full shape received: [2] 

1272 ``` 

1273 

1274 Input checks that can be specified via `input_spec` include: 

1275 - Structure (e.g. a single input, a list of 2 inputs, etc) 

1276 - Shape 

1277 - Rank (ndim) 

1278 - Dtype 

1279 

1280 For more information, see `tf.keras.layers.InputSpec`. 

1281 

1282 Returns: 

1283 A `tf.keras.layers.InputSpec` instance, or nested structure thereof. 

1284 """ 

1285 return self._input_spec 

1286 

1287 @input_spec.setter 

1288 # Must be decorated to prevent tracking, since the input_spec can be nested 

1289 # InputSpec objects. 

1290 @trackable.no_automatic_dependency_tracking 

1291 def input_spec(self, value): 

1292 for v in nest.flatten(value): 

1293 if v is not None and not isinstance(v, InputSpec): 

1294 raise TypeError('Layer input_spec must be an instance of InputSpec. ' 

1295 'Got: {}'.format(v)) 

1296 self._input_spec = value 

1297 

1298 @property 

1299 def trainable_weights(self): 

1300 """List of all trainable weights tracked by this layer. 

1301 

1302 Trainable weights are updated via gradient descent during training. 

1303 

1304 Returns: 

1305 A list of trainable variables. 

1306 """ 

1307 if self.trainable: 

1308 children_weights = self._gather_children_attribute('trainable_variables') 

1309 return self._dedup_weights(self._trainable_weights + children_weights) 

1310 else: 

1311 return [] 

1312 

1313 @property 

1314 def non_trainable_weights(self): 

1315 """List of all non-trainable weights tracked by this layer. 

1316 

1317 Non-trainable weights are *not* updated during training. They are expected 

1318 to be updated manually in `call()`. 

1319 

1320 Returns: 

1321 A list of non-trainable variables. 

1322 """ 

1323 if self.trainable: 

1324 children_weights = self._gather_children_attribute( 

1325 'non_trainable_variables') 

1326 non_trainable_weights = self._non_trainable_weights + children_weights 

1327 else: 

1328 children_weights = self._gather_children_attribute('variables') 

1329 non_trainable_weights = ( 

1330 self._trainable_weights + self._non_trainable_weights + 

1331 children_weights) 

1332 return self._dedup_weights(non_trainable_weights) 

1333 

1334 @property 

1335 def weights(self): 

1336 """Returns the list of all layer variables/weights. 

1337 

1338 Returns: 

1339 A list of variables. 

1340 """ 

1341 return self.trainable_weights + self.non_trainable_weights 

1342 
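A brief sketch of how the three weight lists relate, and how toggling `trainable` re-partitions them (assumes only the public `tf.keras.layers.Dense` API).

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(2)
layer.build((None, 4))                   # creates kernel (4, 2) and bias (2,)

print(len(layer.weights))                # 2
print(len(layer.trainable_weights))      # 2
print(len(layer.non_trainable_weights))  # 0

layer.trainable = False                  # freeze the layer
print(len(layer.trainable_weights))      # 0
print(len(layer.non_trainable_weights))  # 2 (same variables, reported here now)
```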

1343 @property 

1344 @doc_controls.do_not_generate_docs 

1345 def updates(self): 

1346 warnings.warn('`layer.updates` will be removed in a future version. ' 

1347 'This property should not be used in TensorFlow 2.0, ' 

1348 'as `updates` are applied automatically.') 

1349 return [] 

1350 

1351 @property 

1352 def losses(self): 

1353 """List of losses added using the `add_loss()` API. 

1354 

1355 Variable regularization tensors are created when this property is accessed, 

1356 so it is eager safe: accessing `losses` under a `tf.GradientTape` will 

1357 propagate gradients back to the corresponding variables. 

1358 

1359 Examples: 

1360 

1361 >>> class MyLayer(tf.keras.layers.Layer): 

1362 ... def call(self, inputs): 

1363 ... self.add_loss(tf.abs(tf.reduce_mean(inputs))) 

1364 ... return inputs 

1365 >>> l = MyLayer() 

1366 >>> l(np.ones((10, 1))) 

1367 >>> l.losses 

1368 [1.0] 

1369 

1370 >>> inputs = tf.keras.Input(shape=(10,)) 

1371 >>> x = tf.keras.layers.Dense(10)(inputs) 

1372 >>> outputs = tf.keras.layers.Dense(1)(x) 

1373 >>> model = tf.keras.Model(inputs, outputs) 

1374 >>> # Activity regularization. 

1375 >>> len(model.losses) 

1376 0 

1377 >>> model.add_loss(tf.abs(tf.reduce_mean(x))) 

1378 >>> len(model.losses) 

1379 1 

1380 

1381 >>> inputs = tf.keras.Input(shape=(10,)) 

1382 >>> d = tf.keras.layers.Dense(10, kernel_initializer='ones') 

1383 >>> x = d(inputs) 

1384 >>> outputs = tf.keras.layers.Dense(1)(x) 

1385 >>> model = tf.keras.Model(inputs, outputs) 

1386 >>> # Weight regularization. 

1387 >>> model.add_loss(lambda: tf.reduce_mean(d.kernel)) 

1388 >>> model.losses 

1389 [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>] 

1390 

1391 Returns: 

1392 A list of tensors. 

1393 """ 

1394 collected_losses = [] 

1395 for layer in self._flatten_layers(): 

1396 # If any eager losses are present, we assume the model to be part of an 

1397 # eager training loop (either a custom one or the one used when 

1398 # `run_eagerly=True`) and so we always return just the eager losses. 

1399 if layer._eager_losses: 

1400 # Filter placeholder losses that may have been added by revived layers. 

1401 # (see base_layer_utils for details). 

1402 if (layer._eager_losses[0] is 

1403 not base_layer_utils.REVIVED_LOSS_PLACEHOLDER): 

1404 collected_losses.extend(layer._eager_losses) 

1405 else: 

1406 collected_losses.extend(layer._losses) 

1407 for regularizer in layer._callable_losses: 

1408 loss_tensor = regularizer() 

1409 if loss_tensor is not None: 

1410 collected_losses.append(loss_tensor) 

1411 return collected_losses 

1412 

1413 def add_loss(self, losses, **kwargs): 

1414 """Add loss tensor(s), potentially dependent on layer inputs. 

1415 

1416 Some losses (for instance, activity regularization losses) may be dependent 

1417 on the inputs passed when calling a layer. Hence, when reusing the same 

1418 layer on different inputs `a` and `b`, some entries in `layer.losses` may 

1419 be dependent on `a` and some on `b`. This method automatically keeps track 

1420 of dependencies. 

1421 

1422 This method can be used inside a subclassed layer or model's `call` 

1423 function, in which case `losses` should be a Tensor or list of Tensors. 

1424 

1425 Example: 

1426 

1427 ```python 

1428 class MyLayer(tf.keras.layers.Layer): 

1429 def call(self, inputs): 

1430 self.add_loss(tf.abs(tf.reduce_mean(inputs))) 

1431 return inputs 

1432 ``` 

1433 

1434 This method can also be called directly on a Functional Model during 

1435 construction. In this case, any loss Tensors passed to this Model must 

1436 be symbolic and be able to be traced back to the model's `Input`s. These 

1437 losses become part of the model's topology and are tracked in `get_config`. 

1438 

1439 Example: 

1440 

1441 ```python 

1442 inputs = tf.keras.Input(shape=(10,)) 

1443 x = tf.keras.layers.Dense(10)(inputs) 

1444 outputs = tf.keras.layers.Dense(1)(x) 

1445 model = tf.keras.Model(inputs, outputs) 

1446 # Activity regularization. 

1447 model.add_loss(tf.abs(tf.reduce_mean(x))) 

1448 ``` 

1449 

1450 If this is not the case for your loss (if, for example, your loss references 

1451 a `Variable` of one of the model's layers), you can wrap your loss in a 

1452 zero-argument lambda. These losses are not tracked as part of the model's 

1453 topology since they can't be serialized. 

1454 

1455 Example: 

1456 

1457 ```python 

1458 inputs = tf.keras.Input(shape=(10,)) 

1459 d = tf.keras.layers.Dense(10) 

1460 x = d(inputs) 

1461 outputs = tf.keras.layers.Dense(1)(x) 

1462 model = tf.keras.Model(inputs, outputs) 

1463 # Weight regularization. 

1464 model.add_loss(lambda: tf.reduce_mean(d.kernel)) 

1465 ``` 

1466 

1467 Args: 

1468 losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses 

1469 may also be zero-argument callables which create a loss tensor. 

1470 **kwargs: Additional keyword arguments for backward compatibility. 

1471 Accepted values: 

1472 inputs - Deprecated, will be automatically inferred. 

1473 """ 

1474 kwargs.pop('inputs', None) 

1475 if kwargs: 

1476 raise TypeError('Unknown keyword arguments: %s' % (kwargs.keys(),)) 

1477 

1478 def _tag_callable(loss): 

1479 """Tags callable loss tensor as `_unconditional_loss`.""" 

1480 if callable(loss): 

1481 # We run the loss without autocasting, as regularizers are often 

1482 # numerically unstable in float16. 

1483 with autocast_variable.enable_auto_cast_variables(None): 

1484 loss = loss() 

1485 if loss is None: 

1486 return None # Will be filtered out when computing the .losses property 

1487 if not tensor_util.is_tf_type(loss): 

1488 loss = tensor_conversion.convert_to_tensor_v2_with_dispatch( 

1489 loss, dtype=backend.floatx() 

1490 ) 

1491 loss._unconditional_loss = True # pylint: disable=protected-access 

1492 return loss 

1493 

1494 losses = nest.flatten(losses) 

1495 

1496 callable_losses = [] 

1497 eager_losses = [] 

1498 symbolic_losses = [] 

1499 for loss in losses: 

1500 if callable(loss): 

1501 callable_losses.append(functools.partial(_tag_callable, loss)) 

1502 continue 

1503 if loss is None: 

1504 continue 

1505 if not tensor_util.is_tf_type(loss) and not isinstance( 

1506 loss, keras_tensor.KerasTensor): 

1507 loss = tensor_conversion.convert_to_tensor_v2_with_dispatch( 

1508 loss, dtype=backend.floatx() 

1509 ) 

1510 # TF Functions should take the eager path. 

1511 if ((tf_utils.is_symbolic_tensor(loss) or 

1512 isinstance(loss, keras_tensor.KerasTensor)) and 

1513 not base_layer_utils.is_in_tf_function()): 

1514 symbolic_losses.append(loss) 

1515 elif tensor_util.is_tf_type(loss): 

1516 eager_losses.append(loss) 

1517 

1518 self._callable_losses.extend(callable_losses) 

1519 

1520 in_call_context = base_layer_utils.call_context().in_call 

1521 if eager_losses and not in_call_context: 

1522 raise ValueError( 

1523 'Expected a symbolic Tensor or a callable for the loss value. ' 

1524 'Please wrap your loss computation in a zero argument `lambda`.') 

1525 

1526 self._eager_losses.extend(eager_losses) 

1527 

1528 for symbolic_loss in symbolic_losses: 

1529 if getattr(self, '_is_graph_network', False): 

1530 self._graph_network_add_loss(symbolic_loss) 

1531 else: 

1532 # Possibly a loss was added in a Layer's `build`. 

1533 self._losses.append(symbolic_loss) 

1534 

1535 def _clear_losses(self): 

1536 """Used every step in eager to reset losses.""" 

1537 # Set to thread local directly to avoid Layer.__setattr__ overhead. 

1538 if not getattr(self, '_self_tracked_trackables', 

1539 None): # Fast path for single Layer. 

1540 self._thread_local._eager_losses = [] 

1541 else: 

1542 for layer in self._flatten_layers(): 

1543 layer._thread_local._eager_losses = [] 

1544 

1545 @property 

1546 def metrics(self): 

1547 """List of metrics added using the `add_metric()` API. 

1548 

1549 Example: 

1550 

1551 >>> input = tf.keras.layers.Input(shape=(3,)) 

1552 >>> d = tf.keras.layers.Dense(2) 

1553 >>> output = d(input) 

1554 >>> d.add_metric(tf.reduce_max(output), name='max') 

1555 >>> d.add_metric(tf.reduce_min(output), name='min') 

1556 >>> [m.name for m in d.metrics] 

1557 ['max', 'min'] 

1558 

1559 Returns: 

1560 A list of `Metric` objects. 

1561 """ 

1562 collected_metrics = [] 

1563 for layer in self._flatten_layers(): 

1564 with layer._metrics_lock: 

1565 collected_metrics.extend(layer._metrics) 

1566 return collected_metrics 

1567 

1568 def add_metric(self, value, name=None, **kwargs): 

1569 """Adds metric tensor to the layer. 

1570 

1571 This method can be used inside the `call()` method of a subclassed layer 

1572 or model. 

1573 

1574 ```python 

1575 class MyMetricLayer(tf.keras.layers.Layer): 

1576 def __init__(self): 

1577 super(MyMetricLayer, self).__init__(name='my_metric_layer') 

1578 self.mean = tf.keras.metrics.Mean(name='metric_1') 

1579 

1580 def call(self, inputs): 

1581 self.add_metric(self.mean(inputs)) 

1582 self.add_metric(tf.reduce_sum(inputs), name='metric_2') 

1583 return inputs 

1584 ``` 

1585 

1586 This method can also be called directly on a Functional Model during 

1587 construction. In this case, any tensor passed to this Model must 

1588 be symbolic and be able to be traced back to the model's `Input`s. These 

1589 metrics become part of the model's topology and are tracked when you 

1590 save the model via `save()`. 

1591 

1592 ```python 

1593 inputs = tf.keras.Input(shape=(10,)) 

1594 x = tf.keras.layers.Dense(10)(inputs) 

1595 outputs = tf.keras.layers.Dense(1)(x) 

1596 model = tf.keras.Model(inputs, outputs) 

1597 model.add_metric(math_ops.reduce_sum(x), name='metric_1') 

1598 ``` 

1599 

1600 Note: Calling `add_metric()` with the result of a metric object on a 

1601 Functional Model, as shown in the example below, is not supported. This is 

1602 because we cannot trace the metric result tensor back to the model's inputs. 

1603 

1604 ```python 

1605 inputs = tf.keras.Input(shape=(10,)) 

1606 x = tf.keras.layers.Dense(10)(inputs) 

1607 outputs = tf.keras.layers.Dense(1)(x) 

1608 model = tf.keras.Model(inputs, outputs) 

1609 model.add_metric(tf.keras.metrics.Mean()(x), name='metric_1') 

1610 ``` 

1611 

1612 Args: 

1613 value: Metric tensor. 

1614 name: String metric name. 

1615 **kwargs: Additional keyword arguments for backward compatibility. 

1616 Accepted values: 

1617 `aggregation` - When the `value` tensor provided is not the result of 

1618 calling a `keras.Metric` instance, it will be aggregated by default 

1619 using a `keras.metrics.Mean`. 

1620 """ 

1621 kwargs_keys = list(kwargs.keys()) 

1622 if (len(kwargs_keys) > 1 or 

1623 (len(kwargs_keys) == 1 and kwargs_keys[0] != 'aggregation')): 

1624 raise TypeError('Unknown keyword arguments: %s' % str(kwargs.keys())) 

1625 

1626 from_metric_obj = hasattr(value, '_metric_obj') 

1627 is_symbolic = isinstance(value, keras_tensor.KerasTensor) 

1628 in_call_context = base_layer_utils.call_context().in_call 

1629 

1630 if name is None and not from_metric_obj: 

1631 # Eg. `self.add_metric(math_ops.reduce_sum(x))` 

1632 # In eager mode, we use the metric name to look up a metric. Without a name, 

1633 # a new Mean metric wrapper will be created on every model/layer call. 

1634 # So, we raise an error when no name is provided. 

1635 # We will do the same for symbolic mode for consistency although a name 

1636 # will be generated if no name is provided. 

1637 

1638 # We will not raise this error in the following use case for the sake of 

1639 # consistency, as the name is provided in the metric constructor. 

1640 # mean = metrics.Mean(name='my_metric') 

1641 # model.add_metric(mean(outputs)) 

1642 raise ValueError('Please provide a name for your metric like ' 

1643 '`self.add_metric(tf.reduce_sum(inputs), ' 

1644 'name=\'mean_activation\')`') 

1645 elif from_metric_obj: 

1646 name = value._metric_obj.name 

1647 

1648 if not in_call_context and not is_symbolic: 

1649 raise ValueError('Expected a symbolic Tensor for the metric value, ' 

1650 'received: ' + str(value)) 

1651 

1652 # If a metric was added in a Layer's `call` or `build`. 

1653 if in_call_context or not getattr(self, '_is_graph_network', False): 

1654 # TF Function path should take the eager path. 

1655 

1656 # If the given metric is available in `metrics` list we just update state 

1657 # on it, otherwise we create a new metric instance and 

1658 # add it to the `metrics` list. 

1659 metric_obj = getattr(value, '_metric_obj', None) 

1660 # Tensors that come from a Metric object already updated the Metric state. 

1661 should_update_state = not metric_obj 

1662 name = metric_obj.name if metric_obj else name 

1663 

1664 with self._metrics_lock: 

1665 match = self._get_existing_metric(name) 

1666 if match: 

1667 metric_obj = match 

1668 elif metric_obj: 

1669 self._metrics.append(metric_obj) 

1670 else: 

1671 # Build the metric object with the value's dtype if it defines one 

1672 metric_obj = metrics_mod.Mean( 

1673 name=name, dtype=getattr(value, 'dtype', None)) 

1674 self._metrics.append(metric_obj) 

1675 

1676 if should_update_state: 

1677 metric_obj(value) 

1678 else: 

1679 if from_metric_obj: 

1680 raise ValueError('Using the result of calling a `Metric` object ' 

1681 'when calling `add_metric` on a Functional ' 

1682 'Model is not supported. Please pass the ' 

1683 'Tensor to monitor directly.') 

1684 

1685 # Insert layers into the Keras Graph Network. 

1686 aggregation = None if from_metric_obj else 'mean' 

1687 self._graph_network_add_metric(value, aggregation, name) 

1688 

1689 @doc_controls.do_not_doc_inheritable 

1690 def add_update(self, updates, inputs=None): 

1691 """Add update op(s), potentially dependent on layer inputs. 

1692 

1693 Weight updates (for instance, the updates of the moving mean and variance 

1694 in a BatchNormalization layer) may be dependent on the inputs passed 

1695 when calling a layer. Hence, when reusing the same layer on 

1696 different inputs `a` and `b`, some entries in `layer.updates` may be 

1697 dependent on `a` and some on `b`. This method automatically keeps track 

1698 of dependencies. 

1699 

1700 This call is ignored when eager execution is enabled (in that case, variable 

1701 updates are run on the fly and thus do not need to be tracked for later 

1702 execution). 

1703 

1704 Args: 

1705 updates: Update op, or list/tuple of update ops, or zero-arg callable 

1706 that returns an update op. A zero-arg callable should be passed in 

1707 order to disable running the updates by setting `trainable=False` 

1708 on this Layer, when executing in Eager mode. 

1709 inputs: Deprecated, will be automatically inferred. 

1710 """ 

1711 if inputs is not None: 

1712 tf_logging.warning( 

1713 '`add_update` `inputs` kwarg has been deprecated. You no longer need ' 

1714 'to pass a value to `inputs` as it is being automatically inferred.') 

1715 call_context = base_layer_utils.call_context() 

1716 # No need to run updates during Functional API construction. 

1717 if call_context.in_keras_graph: 

1718 return 

1719 

1720 # Callable updates are disabled by setting `trainable=False`. 

1721 if not call_context.frozen: 

1722 for update in nest.flatten(updates): 

1723 if callable(update): 

1724 update() # pylint: disable=not-callable 

1725 
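A hypothetical sketch of `add_update` from a subclassed layer's `call`: wrapping the update in a zero-argument callable lets the framework skip it when the layer is frozen with `trainable=False`.

```python
import tensorflow as tf

class CallCounter(tf.keras.layers.Layer):
  """Hypothetical layer that counts how many times it has been called."""

  def build(self, input_shape):
    self.calls = self.add_weight(
        name='calls', shape=(), initializer='zeros', trainable=False)

  def call(self, inputs):
    # Zero-arg callable: only executed when the layer is not frozen.
    self.add_update(lambda: self.calls.assign_add(1.0))
    return inputs
```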

1726 def set_weights(self, weights): 

1727 """Sets the weights of the layer, from NumPy arrays. 

1728 

1729 The weights of a layer represent the state of the layer. This function 

1730 sets the weight values from numpy arrays. The weight values should be 

1731 passed in the order they are created by the layer. Note that the layer's 

1732 weights must be instantiated before calling this function, by calling 

1733 the layer. 

1734 

1735 For example, a `Dense` layer returns a list of two values: the kernel matrix 

1736 and the bias vector. These can be used to set the weights of another 

1737 `Dense` layer: 

1738 

1739 >>> layer_a = tf.keras.layers.Dense(1, 

1740 ... kernel_initializer=tf.constant_initializer(1.)) 

1741 >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]])) 

1742 >>> layer_a.get_weights() 

1743 [array([[1.], 

1744 [1.], 

1745 [1.]], dtype=float32), array([0.], dtype=float32)] 

1746 >>> layer_b = tf.keras.layers.Dense(1, 

1747 ... kernel_initializer=tf.constant_initializer(2.)) 

1748 >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]])) 

1749 >>> layer_b.get_weights() 

1750 [array([[2.], 

1751 [2.], 

1752 [2.]], dtype=float32), array([0.], dtype=float32)] 

1753 >>> layer_b.set_weights(layer_a.get_weights()) 

1754 >>> layer_b.get_weights() 

1755 [array([[1.], 

1756 [1.], 

1757 [1.]], dtype=float32), array([0.], dtype=float32)] 

1758 

1759 Args: 

1760 weights: a list of NumPy arrays. The number 

1761 of arrays and their shape must match 

1762 number of the dimensions of the weights 

1763 of the layer (i.e. it should match the 

1764 output of `get_weights`). 

1765 

1766 Raises: 

1767 ValueError: If the provided weights list does not match the 

1768 layer's specifications. 

1769 """ 

1770 params = self.weights 

1771 

1772 expected_num_weights = 0 

1773 for param in params: 

1774 if isinstance(param, base_layer_utils.TrackableWeightHandler): 

1775 expected_num_weights += param.num_tensors 

1776 else: 

1777 expected_num_weights += 1 

1778 

1779 if expected_num_weights != len(weights): 

1780 raise ValueError( 

1781 'You called `set_weights(weights)` on layer "%s" ' 

1782 'with a weight list of length %s, but the layer was ' 

1783 'expecting %s weights. Provided weights: %s...' % 

1784 (self.name, len(weights), expected_num_weights, str(weights)[:50])) 

1785 

1786 weight_index = 0 

1787 weight_value_tuples = [] 

1788 for param in params: 

1789 if isinstance(param, base_layer_utils.TrackableWeightHandler): 

1790 num_tensors = param.num_tensors 

1791 tensors = weights[weight_index:weight_index + num_tensors] 

1792 param.set_weights(tensors) 

1793 weight_index += num_tensors 

1794 else: 

1795 weight = weights[weight_index] 

1796 weight_shape = weight.shape if hasattr(weight, 'shape') else () 

1797 ref_shape = param.shape 

1798 if not ref_shape.is_compatible_with(weight_shape): 

1799 raise ValueError( 

1800 'Layer weight shape %s not compatible with provided weight ' 

1801 'shape %s' % (ref_shape, weight_shape)) 

1802 weight_value_tuples.append((param, weight)) 

1803 weight_index += 1 

1804 

1805 backend.batch_set_value(weight_value_tuples) 

1806 

1807 # Perform any layer defined finalization of the layer state. 

1808 for layer in self._flatten_layers(): 

1809 layer.finalize_state() 

1810 

1811 def get_weights(self): 

1812 """Returns the current weights of the layer, as NumPy arrays. 

1813 

1814 The weights of a layer represent the state of the layer. This function 

1815 returns both trainable and non-trainable weight values associated with this 

1816 layer as a list of NumPy arrays, which can in turn be used to load state 

1817 into similarly parameterized layers. 

1818 

1819 For example, a `Dense` layer returns a list of two values: the kernel matrix 

1820 and the bias vector. These can be used to set the weights of another 

1821 `Dense` layer: 

1822 

1823 >>> layer_a = tf.keras.layers.Dense(1, 

1824 ... kernel_initializer=tf.constant_initializer(1.)) 

1825 >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]])) 

1826 >>> layer_a.get_weights() 

1827 [array([[1.], 

1828 [1.], 

1829 [1.]], dtype=float32), array([0.], dtype=float32)] 

1830 >>> layer_b = tf.keras.layers.Dense(1, 

1831 ... kernel_initializer=tf.constant_initializer(2.)) 

1832 >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]])) 

1833 >>> layer_b.get_weights() 

1834 [array([[2.], 

1835 [2.], 

1836 [2.]], dtype=float32), array([0.], dtype=float32)] 

1837 >>> layer_b.set_weights(layer_a.get_weights()) 

1838 >>> layer_b.get_weights() 

1839 [array([[1.], 

1840 [1.], 

1841 [1.]], dtype=float32), array([0.], dtype=float32)] 

1842 

1843 Returns: 

1844 Weights values as a list of NumPy arrays. 

1845 """ 

1846 weights = self.weights 

1847 output_weights = [] 

1848 for weight in weights: 

1849 if isinstance(weight, base_layer_utils.TrackableWeightHandler): 

1850 output_weights.extend(weight.get_tensors()) 

1851 else: 

1852 output_weights.append(weight) 

1853 return backend.batch_get_value(output_weights) 

1854 

1855 @doc_controls.do_not_generate_docs 

1856 def finalize_state(self): 

1857 """Finalizes the layers state after updating layer weights. 

1858 

1859 This function can be subclassed in a layer and will be called after updating 

1860 a layer's weights. It can be overridden to finalize any additional layer state 

1861 after a weight update. 

1862 """ 

1863 pass 

1864 
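A hypothetical sketch of overriding `finalize_state`: `set_weights` calls it on every sub-layer after new values are assigned, so it is a natural place to refresh any state derived from the weights.

```python
import tensorflow as tf

class NormTrackingDense(tf.keras.layers.Dense):
  """Hypothetical Dense variant that caches its kernel norm."""

  def finalize_state(self):
    # Invoked by `set_weights` after the new weight values are in place.
    self._kernel_norm = tf.norm(self.kernel)
```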

1865 @doc_controls.do_not_generate_docs 

1866 def get_updates_for(self, inputs): 

1867 """Deprecated, do NOT use! 

1868 

1869 Retrieves updates relevant to a specific set of inputs. 

1870 

1871 Args: 

1872 inputs: Input tensor or list/tuple of input tensors. 

1873 

1874 Returns: 

1875 List of update ops of the layer that depend on `inputs`. 

1876 """ 

1877 warnings.warn('`layer.get_updates_for` is deprecated and ' 

1878 'will be removed in a future version. ' 

1879 'Please use `layer.updates` method instead.') 

1880 return self.updates 

1881 

1882 @doc_controls.do_not_generate_docs 

1883 def get_losses_for(self, inputs): 

1884 """Deprecated, do NOT use! 

1885 

1886 Retrieves losses relevant to a specific set of inputs. 

1887 

1888 Args: 

1889 inputs: Input tensor or list/tuple of input tensors. 

1890 

1891 Returns: 

1892 List of loss tensors of the layer that depend on `inputs`. 

1893 """ 

1894 warnings.warn('`layer.get_losses_for` is deprecated and ' 

1895 'will be removed in a future version. ' 

1896 'Please use `layer.losses` instead.') 

1897 return self.losses 

1898 

1899 @doc_controls.do_not_doc_inheritable 

1900 def get_input_mask_at(self, node_index): 

1901 """Retrieves the input mask tensor(s) of a layer at a given node. 

1902 

1903 Args: 

1904 node_index: Integer, index of the node 

1905 from which to retrieve the attribute. 

1906 E.g. `node_index=0` will correspond to the 

1907 first time the layer was called. 

1908 

1909 Returns: 

1910 A mask tensor 

1911 (or list of tensors if the layer has multiple inputs). 

1912 """ 

1913 inputs = self.get_input_at(node_index) 

1914 if isinstance(inputs, list): 

1915 return [getattr(x, '_keras_mask', None) for x in inputs] 

1916 else: 

1917 return getattr(inputs, '_keras_mask', None) 

1918 

1919 @doc_controls.do_not_doc_inheritable 

1920 def get_output_mask_at(self, node_index): 

1921 """Retrieves the output mask tensor(s) of a layer at a given node. 

1922 

1923 Args: 

1924 node_index: Integer, index of the node 

1925 from which to retrieve the attribute. 

1926 E.g. `node_index=0` will correspond to the 

1927 first time the layer was called. 

1928 

1929 Returns: 

1930 A mask tensor 

1931 (or list of tensors if the layer has multiple outputs). 

1932 """ 

1933 output = self.get_output_at(node_index) 

1934 if isinstance(output, list): 

1935 return [getattr(x, '_keras_mask', None) for x in output] 

1936 else: 

1937 return getattr(output, '_keras_mask', None) 

1938 

1939 @property 

1940 @doc_controls.do_not_doc_inheritable 

1941 def input_mask(self): 

1942 """Retrieves the input mask tensor(s) of a layer. 

1943 

1944 Only applicable if the layer has exactly one inbound node, 

1945 i.e. if it is connected to one incoming layer. 

1946 

1947 Returns: 

1948 Input mask tensor (potentially None) or list of input 

1949 mask tensors. 

1950 

1951 Raises: 

1952 AttributeError: if the layer is connected to 

1953 more than one incoming layers. 

1954 """ 

1955 inputs = self.input 

1956 if isinstance(inputs, list): 

1957 return [getattr(x, '_keras_mask', None) for x in inputs] 

1958 else: 

1959 return getattr(inputs, '_keras_mask', None) 

1960 

1961 @property 

1962 @doc_controls.do_not_doc_inheritable 

1963 def output_mask(self): 

1964 """Retrieves the output mask tensor(s) of a layer. 

1965 

1966 Only applicable if the layer has exactly one inbound node, 

1967 i.e. if it is connected to one incoming layer. 

1968 

1969 Returns: 

1970 Output mask tensor (potentially None) or list of output 

1971 mask tensors. 

1972 

1973 Raises: 

1974 AttributeError: if the layer is connected to 

1975 more than one incoming layers. 

1976 """ 

1977 output = self.output 

1978 if isinstance(output, list): 

1979 return [getattr(x, '_keras_mask', None) for x in output] 

1980 else: 

1981 return getattr(output, '_keras_mask', None) 

1982 

1983 @doc_controls.do_not_doc_inheritable 

1984 def get_input_shape_at(self, node_index): 

1985 """Retrieves the input shape(s) of a layer at a given node. 

1986 

1987 Args: 

1988 node_index: Integer, index of the node 

1989 from which to retrieve the attribute. 

1990 E.g. `node_index=0` will correspond to the 

1991 first time the layer was called. 

1992 

1993 Returns: 

1994 A shape tuple 

1995 (or list of shape tuples if the layer has multiple inputs). 

1996 

1997 Raises: 

1998 RuntimeError: If called in Eager mode. 

1999 """ 

2000 return self._get_node_attribute_at_index(node_index, 'input_shapes', 

2001 'input shape') 

2002 

2003 @doc_controls.do_not_doc_inheritable 

2004 def get_output_shape_at(self, node_index): 

2005 """Retrieves the output shape(s) of a layer at a given node. 

2006 

2007 Args: 

2008 node_index: Integer, index of the node 

2009 from which to retrieve the attribute. 

2010 E.g. `node_index=0` will correspond to the 

2011 first time the layer was called. 

2012 

2013 Returns: 

2014 A shape tuple 

2015 (or list of shape tuples if the layer has multiple outputs). 

2016 

2017 Raises: 

2018 RuntimeError: If called in Eager mode. 

2019 """ 

2020 return self._get_node_attribute_at_index(node_index, 'output_shapes', 

2021 'output shape') 

2022 

2023 @doc_controls.do_not_doc_inheritable 

2024 def get_input_at(self, node_index): 

2025 """Retrieves the input tensor(s) of a layer at a given node. 

2026 

2027 Args: 

2028 node_index: Integer, index of the node 

2029 from which to retrieve the attribute. 

2030 E.g. `node_index=0` will correspond to the 

2031 first input node of the layer. 

2032 

2033 Returns: 

2034 A tensor (or list of tensors if the layer has multiple inputs). 

2035 

2036 Raises: 

2037 RuntimeError: If called in Eager mode. 

2038 """ 

2039 return self._get_node_attribute_at_index(node_index, 'input_tensors', 

2040 'input') 

2041 

2042 @doc_controls.do_not_doc_inheritable 

2043 def get_output_at(self, node_index): 

2044 """Retrieves the output tensor(s) of a layer at a given node. 

2045 

2046 Args: 

2047 node_index: Integer, index of the node 

2048 from which to retrieve the attribute. 

2049 E.g. `node_index=0` will correspond to the 

2050 first output node of the layer. 

2051 

2052 Returns: 

2053 A tensor (or list of tensors if the layer has multiple outputs). 

2054 

2055 Raises: 

2056 RuntimeError: If called in Eager mode. 

2057 """ 

2058 return self._get_node_attribute_at_index(node_index, 'output_tensors', 

2059 'output') 

2060 
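A small functional-API sketch of the per-node accessors: reusing one layer on two inputs creates two inbound nodes, so the plain `input`/`output` properties become ambiguous and the `*_at(node_index)` variants are needed.

```python
import tensorflow as tf

a = tf.keras.Input(shape=(4,))
b = tf.keras.Input(shape=(4,))
shared = tf.keras.layers.Dense(2)

out_a = shared(a)   # creates inbound node 0
out_b = shared(b)   # creates inbound node 1

shared.get_output_at(0)       # symbolic output of the first call (on `a`)
shared.get_output_at(1)       # symbolic output of the second call (on `b`)
shared.get_input_shape_at(1)  # (None, 4)
```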

2061 @property 

2062 def input(self): 

2063 """Retrieves the input tensor(s) of a layer. 

2064 

2065 Only applicable if the layer has exactly one input, 

2066 i.e. if it is connected to one incoming layer. 

2067 

2068 Returns: 

2069 Input tensor or list of input tensors. 

2070 

2071 Raises: 

2072 RuntimeError: If called in Eager mode. 

2073 AttributeError: If no inbound nodes are found. 

2074 """ 

2075 if not self._inbound_nodes: 

2076 raise AttributeError('Layer ' + self.name + 

2077 ' is not connected, no input to return.') 

2078 return self._get_node_attribute_at_index(0, 'input_tensors', 'input') 

2079 

2080 @property 

2081 def output(self): 

2082 """Retrieves the output tensor(s) of a layer. 

2083 

2084 Only applicable if the layer has exactly one output, 

2085 i.e. if it is connected to one incoming layer. 

2086 

2087 Returns: 

2088 Output tensor or list of output tensors. 

2089 

2090 Raises: 

2091 AttributeError: if the layer is connected to more than one incoming 

2092 layers. 

2093 RuntimeError: if called in Eager mode. 

2094 """ 

2095 if not self._inbound_nodes: 

2096 raise AttributeError('Layer ' + self.name + ' has no inbound nodes.') 

2097 return self._get_node_attribute_at_index(0, 'output_tensors', 'output') 

2098 

2099 @property 

2100 @doc_controls.do_not_doc_inheritable 

2101 def input_shape(self): 

2102 """Retrieves the input shape(s) of a layer. 

2103 

2104 Only applicable if the layer has exactly one input, 

2105 i.e. if it is connected to one incoming layer, or if all inputs 

2106 have the same shape. 

2107 

2108 Returns: 

2109 Input shape, as an integer shape tuple 

2110 (or list of shape tuples, one tuple per input tensor). 

2111 

2112 Raises: 

2113 AttributeError: if the layer has no defined input_shape. 

2114 RuntimeError: if called in Eager mode. 

2115 """ 

2116 if not self._inbound_nodes: 

2117 raise AttributeError('The layer has never been called ' 

2118 'and thus has no defined input shape.') 

2119 all_input_shapes = set( 

2120 [str(node.input_shapes) for node in self._inbound_nodes]) 

2121 if len(all_input_shapes) == 1: 

2122 return self._inbound_nodes[0].input_shapes 

2123 else: 

2124 raise AttributeError('The layer "' + str(self.name) + 

2125 '" has multiple inbound nodes, ' 

2126 'with different input shapes. Hence ' 

2127 'the notion of "input shape" is ' 

2128 'ill-defined for the layer. ' 

2129 'Use `get_input_shape_at(node_index)` ' 

2130 'instead.') 

2131 

2132 def count_params(self): 

2133 """Count the total number of scalars composing the weights. 

2134 

2135 Returns: 

2136 An integer count. 

2137 

2138 Raises: 

2139 ValueError: if the layer isn't yet built 

2140 (in which case its weights aren't yet defined). 

2141 """ 

2142 if not self.built: 

2143 if getattr(self, '_is_graph_network', False): 

2144 with tf_utils.maybe_init_scope(self): 

2145 self._maybe_build(self.inputs) 

2146 else: 

2147 raise ValueError('You tried to call `count_params` on ' + self.name + 

2148 ', but the layer isn\'t built. ' 

2149 'You can build it manually via: `' + self.name + 

2150 '.build(batch_input_shape)`.') 

2151 return layer_utils.count_params(self.weights) 

2152 
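A worked example (using only the public `Dense` layer): a `Dense(5)` built on 3-dimensional inputs holds a `(3, 5)` kernel and a `(5,)` bias, i.e. 3 * 5 + 5 = 20 scalars.

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(5)
layer.build((None, 3))       # kernel (3, 5) + bias (5,)
print(layer.count_params())  # 20
```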

2153 @property 

2154 @doc_controls.do_not_doc_inheritable 

2155 def output_shape(self): 

2156 """Retrieves the output shape(s) of a layer. 

2157 

2158 Only applicable if the layer has one output, 

2159 or if all outputs have the same shape. 

2160 

2161 Returns: 

2162 Output shape, as an integer shape tuple 

2163 (or list of shape tuples, one tuple per output tensor). 

2164 

2165 Raises: 

2166 AttributeError: if the layer has no defined output shape. 

2167 RuntimeError: if called in Eager mode. 

2168 """ 

2169 if not self._inbound_nodes: 

2170 raise AttributeError('The layer has never been called ' 

2171 'and thus has no defined output shape.') 

2172 all_output_shapes = set( 

2173 [str(node.output_shapes) for node in self._inbound_nodes]) 

2174 if len(all_output_shapes) == 1: 

2175 return self._inbound_nodes[0].output_shapes 

2176 else: 

2177 raise AttributeError('The layer "%s"' 

2178 ' has multiple inbound nodes, ' 

2179 'with different output shapes. Hence ' 

2180 'the notion of "output shape" is ' 

2181 'ill-defined for the layer. ' 

2182 'Use `get_output_shape_at(node_index)` ' 

2183 'instead.' % self.name) 

2184 

2185 @property 

2186 @doc_controls.do_not_doc_inheritable 

2187 def inbound_nodes(self): 

2188 """Deprecated, do NOT use! Only for compatibility with external Keras.""" 

2189 return self._inbound_nodes 

2190 

2191 @property 

2192 @doc_controls.do_not_doc_inheritable 

2193 def outbound_nodes(self): 

2194 """Deprecated, do NOT use! Only for compatibility with external Keras.""" 

2195 return self._outbound_nodes 

2196 

2197 ############################################################################## 

2198 # Methods & attributes below are public aliases of other methods. # 

2199 ############################################################################## 

2200 

2201 @doc_controls.do_not_doc_inheritable 

2202 def apply(self, inputs, *args, **kwargs): 

2203 """Deprecated, do NOT use! 

2204 

2205 This is an alias of `self.__call__`. 

2206 

2207 Args: 

2208 inputs: Input tensor(s). 

2209 *args: additional positional arguments to be passed to `self.call`. 

2210 **kwargs: additional keyword arguments to be passed to `self.call`. 

2211 

2212 Returns: 

2213 Output tensor(s). 

2214 """ 

2215 warnings.warn('`layer.apply` is deprecated and ' 

2216 'will be removed in a future version. ' 

2217 'Please use `layer.__call__` method instead.') 

2218 return self.__call__(inputs, *args, **kwargs) 

2219 

2220 @doc_controls.do_not_doc_inheritable 

2221 def add_variable(self, *args, **kwargs): 

2222 """Deprecated, do NOT use! Alias for `add_weight`.""" 

2223 warnings.warn('`layer.add_variable` is deprecated and ' 

2224 'will be removed in a future version. ' 

2225 'Please use `layer.add_weight` method instead.') 

2226 return self.add_weight(*args, **kwargs) 

2227 

2228 @property 

2229 @doc_controls.do_not_generate_docs 

2230 def variables(self): 

2231 """Returns the list of all layer variables/weights. 

2232 

2233 Alias of `self.weights`. 

2234 

2235 Note: This will not track the weights of nested `tf.Modules` that are not 

2236 themselves Keras layers. 

2237 

2238 Returns: 

2239 A list of variables. 

2240 """ 

2241 return self.weights 

2242 

2243 @property 

2244 @doc_controls.do_not_generate_docs 

2245 def trainable_variables(self): 

2246 return self.trainable_weights 

2247 

2248 @property 

2249 @doc_controls.do_not_generate_docs 

2250 def non_trainable_variables(self): 

2251 return self.non_trainable_weights 

2252 

2253 ############################################################################## 

2254 # Methods & attributes below are all private and only used by the framework. # 

2255 ############################################################################## 

2256 

2257 @property 

2258 def _inbound_nodes(self): 

2259 return self._inbound_nodes_value 

2260 

2261 @_inbound_nodes.setter 

2262 @trackable.no_automatic_dependency_tracking 

2263 def _inbound_nodes(self, value): 

2264 self._inbound_nodes_value = value 

2265 

2266 @property 

2267 def _outbound_nodes(self): 

2268 return self._outbound_nodes_value 

2269 

2270 @_outbound_nodes.setter 

2271 @trackable.no_automatic_dependency_tracking 

2272 def _outbound_nodes(self, value): 

2273 self._outbound_nodes_value = value 

2274 

2275 def _set_dtype_policy(self, dtype): 

2276 """Sets self._dtype_policy.""" 

2277 if isinstance(dtype, policy.Policy): 

2278 self._dtype_policy = dtype 

2279 elif isinstance(dtype, dict): 

2280 self._dtype_policy = policy.deserialize(dtype) 

2281 elif isinstance(dtype, str) and dtype in ('mixed_float16', 

2282 'mixed_bfloat16'): 

2283 # The isinstance check is required since np.dtype raises an error if 

2284 # compared to a non-dtype string. 

2285 self._dtype_policy = policy.Policy(dtype) 

2286 elif dtype: 

2287 self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name) 

2288 else: 

2289 self._dtype_policy = policy.global_policy() 

2290 if (self._dtype_policy.name == 'mixed_float16' and 

2291 not loss_scale_optimizer.strategy_supports_loss_scaling()): 

2292 # Although it is only loss scaling that doesn't support certain strategies, 

2293 # to avoid confusion we disallow the 'mixed_float16' policy with unsupported 

2294 # strategies. This is because 'mixed_float16' requires loss scaling for 

2295 # numeric stability. 

2296 strategy = distribute_lib.get_strategy() 

2297 raise ValueError('Mixed precision is not supported with the ' 

2298 'tf.distribute.Strategy: %s. Either stop using mixed ' 

2299 'precision by removing the use of the "%s" policy or ' 

2300 'use a different Strategy, e.g. a MirroredStrategy.' % 

2301 (strategy.__class__.__name__, self._dtype_policy.name)) 

2302 

2303 # Performance optimization: cache the compute dtype as a Dtype object or 

2304 # None, so that str to Dtype conversion doesn't happen in Layer.__call__. 

2305 # TODO(b/157486353): Investigate returning DTypes in Policy. 

2306 if self._dtype_policy.compute_dtype: 

2307 self._compute_dtype_object = dtypes.as_dtype( 

2308 self._dtype_policy.compute_dtype) 

2309 else: 

2310 self._compute_dtype_object = None 

2311 

2312 @property 

2313 def dtype_policy(self): 

2314 """The dtype policy associated with this layer. 

2315 

2316 This is an instance of a `tf.keras.mixed_precision.Policy`. 

2317 """ 

2318 return self._dtype_policy 

2319 

2320 @property 

2321 def compute_dtype(self): 

2322 """The dtype of the layer's computations. 

2323 

2324 This is equivalent to `Layer.dtype_policy.compute_dtype`. Unless 

2325 mixed precision is used, this is the same as `Layer.dtype`, the dtype of 

2326 the weights. 

2327 

2328 Layers automatically cast their inputs to the compute dtype, which causes 

2329 computations and the output to be in the compute dtype as well. This is done 

2330 by the base Layer class in `Layer.__call__`, so you do not have to insert 

2331 these casts if implementing your own layer. 

2332 

2333 Layers often perform certain internal computations in higher precision when 

2334 `compute_dtype` is float16 or bfloat16 for numeric stability. The output 

2335 will still typically be float16 or bfloat16 in such cases. 

2336 

2337 Returns: 

2338 The layer's compute dtype. 

2339 """ 

2340 return self._dtype_policy.compute_dtype 

2341 
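A brief mixed-precision sketch (assuming the public `tf.keras` dtype-policy support described above): under `'mixed_float16'` the variables stay float32 while inputs are cast and computed in float16.

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(2, dtype='mixed_float16')

print(layer.dtype)          # float32 -- variable dtype
print(layer.compute_dtype)  # float16 -- computation dtype

y = layer(tf.ones((1, 3)))  # float32 input is auto-cast in `__call__`
print(y.dtype)              # <dtype: 'float16'>
```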

2342 @property 

2343 def _compute_dtype(self): 

2344 """Deprecated alias of `compute_dtype`.""" 

2345 return self._dtype_policy.compute_dtype 

2346 

2347 @property 

2348 def variable_dtype(self): 

2349 """Alias of `Layer.dtype`, the dtype of the weights.""" 

2350 return self.dtype 

2351 

2352 def _maybe_cast_inputs(self, inputs, input_list=None): 

2353 """Maybe casts the inputs to the compute dtype. 

2354 

2355 If self._compute_dtype is floating-point and self._autocast is True, 

2356 floating-point inputs are cast to self._compute_dtype. 

2357 

2358 Args: 

2359 inputs: Input tensor, or structure of input tensors. 

2360 input_list: Flat list of input tensors. 

2361 

2362 Returns: 

2363 `inputs`, but tensors may have been cast to self._compute_dtype 

2364 """ 

2365 if not input_list: 

2366 input_list = nest.flatten(inputs) 

2367 

2368 compute_dtype_object = self._compute_dtype_object 

2369 should_autocast = ( 

2370 self._autocast and compute_dtype_object and 

2371 compute_dtype_object.is_floating) 

2372 

2373 if (should_autocast and 

2374 any(map(self._should_cast_single_input, input_list))): 

2375 # Only perform expensive `nest` operation when needed. 

2376 return nest.map_structure(self._cast_single_input, inputs) 

2377 else: 

2378 return inputs 

2379 

2380 def _should_cast_single_input(self, x): 

2381 if isinstance(x, _AUTOCAST_TYPES): 

2382 return (self._compute_dtype_object and 

2383 x.dtype != self._compute_dtype_object and x.dtype.is_floating) 

2384 return False 

2385 

2386 def _cast_single_input(self, x): 

2387 """Cast a single Tensor or TensorSpec to the compute dtype.""" 

2388 if self._should_cast_single_input(x): 

2389 return math_ops.cast(x, self._compute_dtype_object) 

2390 else: 

2391 return x 

2392 

2393 # _dtype used to be an attribute set in the constructor. We still expose it 

2394 # because some clients still use it. 

2395 # TODO(reedwm): Deprecate, then remove the _dtype property. 

2396 @property 

2397 def _dtype(self): 

2398 # This is equivalent to returning self.dtype . We do not return self.dtype 

2399 # as it would cause infinite recursion in a few subclasses, which override 

2400 # "dtype" to return self._dtype. 

2401 return self._dtype_policy.variable_dtype 

2402 

2403 @_dtype.setter 

2404 def _dtype(self, value): 

2405 value = dtypes.as_dtype(value).name 

2406 self._set_dtype_policy(policy.Policy(value)) 

2407 

2408 def _name_scope(self): # pylint: disable=method-hidden 

2409 if not tf2.enabled(): 

2410 return self.name 

2411 name_scope = self.name 

2412 current_name_scope = ops.get_name_scope() 

2413 if current_name_scope: 

2414 name_scope = current_name_scope + '/' + name_scope 

2415 if name_scope: 

2416 # Note that the trailing `/` prevents autogenerated 

2417 # numerical suffixes from being appended. It will also fully reset 

2418 # nested name scope (i.e. the outer name scope has no effect). 

2419 name_scope += '/' 

2420 return name_scope 

2421 

2422 def _init_set_name(self, name, zero_based=True): 

2423 if not name: 

2424 self._name = backend.unique_object_name( 

2425 generic_utils.to_snake_case(self.__class__.__name__), 

2426 zero_based=zero_based) 

2427 else: 

2428 backend.observe_object_name(name) 

2429 self._name = name 

2430 

2431 def _get_existing_metric(self, name=None): 

2432 match = [m for m in self._metrics if m.name == name] 

2433 if not match: 

2434 return 

2435 if len(match) > 1: 

2436 raise ValueError( 

2437 'Please provide different names for the metrics you have added. ' 

2438 'We found {} metrics with the name: "{}"'.format(len(match), name)) 

2439 return match[0] 

2440 

2441 def _handle_weight_regularization(self, name, variable, regularizer): 

2442 """Create lambdas which compute regularization losses.""" 

2443 

2444 def _loss_for_variable(v): 

2445 """Creates a regularization loss `Tensor` for variable `v`.""" 

2446 with backend.name_scope(name + '/Regularizer'): 

2447 regularization = regularizer(v) 

2448 return regularization 

2449 

2450 if base_layer_utils.is_split_variable(variable): 

2451 for v in variable: 

2452 self.add_loss(functools.partial(_loss_for_variable, v)) 

2453 else: 

2454 self.add_loss(functools.partial(_loss_for_variable, variable)) 

2455 

2456 def _handle_activity_regularization(self, inputs, outputs): 

2457 # Apply activity regularization. 

2458 # Note that it should be applied every time the layer creates a new 

2459 # output, since it is output-specific. 

2460 if self._activity_regularizer: 

2461 output_list = nest.flatten(outputs) 

2462 with backend.name_scope('ActivityRegularizer'): 

2463 for output in output_list: 

2464 activity_loss = self._activity_regularizer(output) 

2465 batch_size = math_ops.cast( 

2466 array_ops.shape(output)[0], activity_loss.dtype) 

2467 # Make activity regularization strength batch-agnostic. 

2468 mean_activity_loss = activity_loss / batch_size 

2469 self.add_loss(mean_activity_loss) 

2470 

2471 def _set_mask_metadata(self, inputs, outputs, previous_mask, build_graph): 

2472 # Many `Layer`s don't need to call `compute_mask`. 

2473 # This method is optimized to do as little work as needed for the common 

2474 # case. 

2475 if not self._supports_masking: 

2476 return 

2477 

2478 flat_outputs = nest.flatten(outputs) 

2479 

2480 mask_already_computed = ( 

2481 getattr(self, '_compute_output_and_mask_jointly', False) or 

2482 all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs)) 

2483 if mask_already_computed: 

2484 if build_graph: 

2485 self._set_mask_keras_history_checked(flat_outputs) 

2486 return 

2487 

2488 output_masks = self.compute_mask(inputs, previous_mask) 

2489 if output_masks is None: 

2490 return 

2491 

2492 flat_masks = nest.flatten(output_masks) 

2493 for tensor, mask in zip(flat_outputs, flat_masks): 

2494 try: 

2495 tensor._keras_mask = mask 

2496 except AttributeError: 

2497 # C Type such as np.ndarray. 

2498 pass 

2499 

2500 if build_graph: 

2501 self._set_mask_keras_history_checked(flat_outputs) 

2502 

2503 def _set_mask_keras_history_checked(self, flat_outputs): 

2504 for output in flat_outputs: 

2505 if getattr(output, '_keras_mask', None) is not None: 

2506 # Do not track masks for `TensorFlowOpLayer` construction. 

2507 output._keras_mask._keras_history_checked = True 

2508 

2509 def _get_input_masks(self, inputs, input_list, args, kwargs): 

2510 if not self._supports_masking and not self._expects_mask_arg: 

2511 # Input masks only need to be retrieved if they are needed for `call` 

2512 # or `compute_mask`. 

2513 input_masks = None 

2514 implicit_mask = False 

2515 elif self._call_arg_was_passed('mask', args, kwargs): 

2516 input_masks = self._get_call_arg_value('mask', args, kwargs) 

2517 implicit_mask = False 

2518 else: 

2519 input_masks = [getattr(t, '_keras_mask', None) for t in input_list] 

2520 if all(mask is None for mask in input_masks): 

2521 input_masks = None 

2522 implicit_mask = False 

2523 else: 

2524 # Only do expensive `nest` op when masking is actually being used. 

2525 input_masks = nest.pack_sequence_as(inputs, input_masks) 

2526 implicit_mask = True 

2527 return input_masks, implicit_mask 

2528 

2529 def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False): 

2530 # Performance optimization: do no work in the most common case. 

2531 if not args and not kwargs: 

2532 return False 

2533 

2534 if arg_name in kwargs: 

2535 return True 

2536 call_fn_args = self._call_fn_args 

2537 if not inputs_in_args: 

2538 # Ignore `inputs` arg. 

2539 call_fn_args = call_fn_args[1:] 

2540 return arg_name in dict(zip(call_fn_args, args)) 

2541 

2542 def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False): 

2543 if arg_name in kwargs: 

2544 return kwargs[arg_name] 

2545 call_fn_args = self._call_fn_args 

2546 if not inputs_in_args: 

2547 # Ignore `inputs` arg. 

2548 call_fn_args = call_fn_args[1:] 

2549 args_dict = dict(zip(call_fn_args, args)) 

2550 return args_dict[arg_name] 

2551 

2552 def _set_call_arg_value( 

2553 self, arg_name, new_value, args, 

2554 kwargs, inputs_in_args=False, pop_kwarg_if_none=False): 

2555 arg_pos = self._call_fn_arg_positions.get(arg_name, None) 

2556 if arg_pos is not None: 

2557 if not inputs_in_args: 

2558 # Ignore `inputs` arg. 

2559 arg_pos = arg_pos - 1 

2560 if len(args) > arg_pos: 

2561 args = list(args) 

2562 args[arg_pos] = new_value 

2563 return tuple(args), kwargs 

2564 if new_value is None and pop_kwarg_if_none: 

2565 kwargs.pop(arg_name, None) 

2566 else: 

2567 kwargs[arg_name] = new_value 

2568 return args, kwargs 

2569 

2570 def _set_connectivity_metadata(self, args, kwargs, outputs): 

2571 # If the layer returns tensors from its inputs unmodified, 

2572 # we copy them to avoid loss of KerasHistory metadata. 

2573 flat_outputs = nest.flatten(outputs) 

2574 flat_inputs = nest.flatten((args, kwargs)) 

2575 input_ids_set = {id(i) for i in flat_inputs} 

2576 outputs_copy = [] 

2577 for x in flat_outputs: 

2578 if id(x) in input_ids_set: 

2579 with backend.name_scope(self.name): 

2580 x = array_ops.identity(x) 

2581 outputs_copy.append(x) 

2582 outputs = nest.pack_sequence_as(outputs, outputs_copy) 

2583 

2584 # Create node, Node wires itself to inbound and outbound layers. 

2585 # The Node constructor actually updates this layer's self._inbound_nodes, 

2586 # sets _keras_history on the outputs, and adds itself to the 

2587 # `_outbound_nodes` of the layers that produced the inputs to this 

2588 # layer call. 

2589 node_module.Node(self, call_args=args, call_kwargs=kwargs, outputs=outputs) 

2590 return outputs 

2591 

2592 def _get_node_attribute_at_index(self, node_index, attr, attr_name): 

2593 """Private utility to retrieves an attribute (e.g. inputs) from a node. 

2594 

2595 This is used to implement the methods: 

2596 - get_input_shape_at 

2597 - get_output_shape_at 

2598 - get_input_at 

2599 etc... 

2600 

2601 Args: 

2602 node_index: Integer index of the node from which 

2603 to retrieve the attribute. 

2604 attr: Exact node attribute name. 

2605 attr_name: Human-readable attribute name, for error messages. 

2606 

2607 Returns: 

2608 The layer's attribute `attr` at the node of index `node_index`. 

2609 

2610 Raises: 

2611 RuntimeError: If the layer has no inbound nodes, or if called in Eager 

2612 mode. 

2613 ValueError: If the index provided does not match any node. 

2614 """ 

2615 if not self._inbound_nodes: 

2616 raise RuntimeError('The layer has never been called ' 

2617 'and thus has no defined ' + attr_name + '.') 

2618 if not len(self._inbound_nodes) > node_index: 

2619 raise ValueError('Asked to get ' + attr_name + ' at node ' + 

2620 str(node_index) + ', but the layer has only ' + 

2621 str(len(self._inbound_nodes)) + ' inbound nodes.') 

2622 values = getattr(self._inbound_nodes[node_index], attr) 

2623 if isinstance(values, list) and len(values) == 1: 

2624 return values[0] 

2625 else: 

2626 return values 

2627 

2628 def _maybe_build(self, inputs): 

2629 # Check input assumptions set before layer building, e.g. input rank. 

2630 if not self.built: 

2631 input_spec.assert_input_compatibility( 

2632 self.input_spec, inputs, self.name) 

2633 input_list = nest.flatten(inputs) 

2634 if input_list and self._dtype_policy.compute_dtype is None: 

2635 try: 

2636 dtype = input_list[0].dtype.base_dtype.name 

2637 except AttributeError: 

2638 pass 

2639 else: 

2640 self._set_dtype_policy(policy.Policy(dtype)) 

2641 input_shapes = None 

2642 # Converts Tensors / CompositeTensors to TensorShapes. 

2643 if all(hasattr(x, 'shape') for x in input_list): 

2644 input_shapes = tf_utils.get_shapes(inputs) 

2645 else: 

2646 # Converts input shape to TensorShapes. 

2647 try: 

2648 input_shapes = tf_utils.convert_shapes(inputs, to_tuples=False) 

2649 except ValueError: 

2650 pass 

2651 # Only call `build` if the user has manually overridden the build method. 

2652 if not hasattr(self.build, '_is_default'): 

2653 # Any setup work performed only once should happen in an `init_scope` 

2654 # to avoid creating symbolic Tensors that will later pollute any eager 

2655 # operations. 

2656 with tf_utils.maybe_init_scope(self): 

2657 self.build(input_shapes) # pylint:disable=not-callable 

2658 # We must also ensure that the layer is marked as built and that the build 

2659 # shape is stored, since user-defined build functions may not call 

2660 # `super().build()`. 

2661 Layer.build(self, input_shapes) 

2662 

2663 # Optionally load weight values specified at layer instantiation. 

2664 if self._initial_weights is not None: 

2665 with ops.init_scope(): 

2666 # Using `init_scope` since we want variable assignment in 

2667 # `set_weights` to be treated like variable initialization. 

2668 self.set_weights(self._initial_weights) 

2669 self._initial_weights = None 

2670 

2671 def _symbolic_call(self, inputs): 

2672 input_shapes = nest.map_structure(lambda x: x.shape, inputs) 

2673 output_shapes = self.compute_output_shape(input_shapes) 

2674 # Convert to TensorShape so that nest.map_structure will not map into 

2675 # individual dim of the shape. 

2676 output_shapes = tf_utils.convert_shapes(output_shapes, to_tuples=False) 

2677 

2678 def _make_placeholder_like(shape): 

2679 ph = backend.placeholder(shape=shape, dtype=self.dtype) 

2680 ph._keras_mask = None 

2681 return ph 

2682 return nest.map_structure(_make_placeholder_like, output_shapes) 

2683 

2684 def _get_trainable_state(self): 

2685 """Get the `trainable` state of each sublayer. 

2686 

2687 Returns: 

2688 A dict mapping all sublayers to their `trainable` value. 

2689 """ 

2690 trainable_state = weakref.WeakKeyDictionary() 

2691 for layer in self._flatten_layers(): 

2692 trainable_state[layer] = layer.trainable 

2693 return trainable_state 

2694 

2695 def _set_trainable_state(self, trainable_state): 

2696 """Set `trainable` state for each sublayer.""" 

2697 for layer in self._flatten_layers(): 

2698 if layer in trainable_state: 

2699 layer.trainable = trainable_state[layer] 

2700 

2701 @property 

2702 def _obj_reference_counts(self): 

2703 """A dictionary counting the number of attributes referencing an object.""" 

2704 self._maybe_create_attribute('_obj_reference_counts_dict', 

2705 object_identity.ObjectIdentityDictionary()) 

2706 return self._obj_reference_counts_dict 

2707 

2708 @trackable.no_automatic_dependency_tracking 

2709 def _maybe_create_attribute(self, name, default_value): 

2710 """Create the attribute with the default value if it hasn't been created. 

2711 

2712 This is useful for fields that are used for tracking purposes, such as 

2713 _trainable_weights or _layers. Note that a user could create a layer subclass 

2714 and assign an internal field before invoking Layer.__init__(); in that case 

2715 __setattr__() needs to create the tracking fields and __init__() must not 

2716 override them. 

2717 

2718 Args: 

2719 name: String, the name of the attribute. 

2720 default_value: Object, the default value of the attribute. 

2721 """ 

2722 if not hasattr(self, name): 

2723 self.__setattr__(name, default_value) 

2724 

2725 def __delattr__(self, name): 

2726 # For any super.__delattr__() call, we will directly use the implementation 

2728 # in Trackable and skip the behavior in AutoTrackable. The Layer originally 

2729 # used Trackable as its base class; the change to using Module as the base 

2730 # class forced us to have AutoTrackable in the class hierarchy. 

2730 # 

2731 # TODO(b/180760306) Keeping the status quo of skipping _delattr__ and 

2732 # __setattr__ in AutoTrackable may be unsustainable. 

2733 existing_value = getattr(self, name, None) 

2734 

2735 # If this value is replacing an existing object assigned to an attribute, we 

2736 # should clean it out to avoid leaking memory. First we check if there are 

2737 # other attributes referencing it. 

2738 reference_counts = self._obj_reference_counts 

2739 if existing_value not in reference_counts: 

2740 super(autotrackable.AutoTrackable, self).__delattr__(name) # pylint: disable=bad-super-call 

2741 return 

2742 

2743 reference_count = reference_counts[existing_value] 

2744 if reference_count > 1: 

2745 # There are other remaining references. We can't remove this object from 

2746 # _layers etc. 

2747 reference_counts[existing_value] = reference_count - 1 

2748 super(autotrackable.AutoTrackable, self).__delattr__(name) # pylint: disable=bad-super-call 

2749 return 

2750 else: 

2751 # This is the last remaining reference. 

2752 del reference_counts[existing_value] 

2753 

2754 super(autotrackable.AutoTrackable, self).__delattr__(name) # pylint: disable=bad-super-call 

2755 

2756 if (isinstance(existing_value, Layer) 

2757 or base_layer_utils.has_weights(existing_value)): 

2758 super(autotrackable.AutoTrackable, self).__setattr__( # pylint: disable=bad-super-call 

2759 '_self_tracked_trackables', 

2760 [l for l in self._self_tracked_trackables if l is not existing_value]) 

2761 if isinstance(existing_value, tf_variables.Variable): 

2762 super(autotrackable.AutoTrackable, self).__setattr__( # pylint: disable=bad-super-call 

2763 '_trainable_weights', 

2764 [w for w in self._trainable_weights if w is not existing_value]) 

2765 super(autotrackable.AutoTrackable, self).__setattr__( # pylint: disable=bad-super-call 

2766 '_non_trainable_weights', 

2767 [w for w in self._non_trainable_weights if w is not existing_value]) 

2768 

2769 def __setattr__(self, name, value): 

2770 if (name == '_self_setattr_tracking' or 

2771 not getattr(self, '_self_setattr_tracking', True) or 

2772 # Exclude @property.setters from tracking 

2773 hasattr(self.__class__, name)): 

2774 try: 

2775 super(autotrackable.AutoTrackable, self).__setattr__(name, value) # pylint: disable=bad-super-call 

2776 except AttributeError: 

2777 raise AttributeError( 

2778 ('Can\'t set the attribute "{}", likely because it conflicts with ' 

2779 'an existing read-only @property of the object. Please choose a ' 

2780 'different name.').format(name)) 

2781 return 

2782 

2783 # Wraps data structures in `Trackable`, unwraps `NoDependency` objects. 

2784 value = data_structures.sticky_attribute_assignment( 

2785 trackable=self, value=value, name=name) 

2786 

2787 reference_counts = self._obj_reference_counts 

2788 reference_counts[value] = reference_counts.get(value, 0) + 1 

2789 

2790 # Clean out the old attribute, which clears _layers and _trainable_weights 

2791 # if necessary. 

2792 try: 

2793 self.__delattr__(name) 

2794 except AttributeError: 

2795 pass 

2796 

2797 # Keep track of metric instance created in subclassed layer. 

2798 for val in nest.flatten(value): 

2799 if isinstance(val, metrics_mod.Metric) and hasattr(self, '_metrics'): 

2800 self._metrics.append(val) 

2801 

2802 # Append value to self._self_tracked_trackables if relevant 

2803 if (getattr(self, '_auto_track_sub_layers', True) and 

2804 (isinstance(value, module.Module) or 

2805 base_layer_utils.has_weights(value))): 

2806 self._maybe_create_attribute('_self_tracked_trackables', []) 

2807 # We need to check object identity to avoid de-duplicating empty 

2808 # container types which compare equal. 

2809 if not any((layer is value for layer in self._self_tracked_trackables)): 

2810 self._self_tracked_trackables.append(value) 

2811 if hasattr(value, '_use_resource_variables'): 

2812 # Legacy layers (V1 tf.layers) must always use 

2813 # resource variables. 

2814 value._use_resource_variables = True 

2815 

2816 # Append value to list of trainable / non-trainable weights if relevant 

2817 # TODO(b/125122625): This won't pick up on any variables added to a 

2818 # list/dict after creation. 

2819 for val in nest.flatten(value, expand_composites=True): 

2820 if not isinstance(val, tf_variables.Variable): 

2821 continue 

2822 

2823 # Users may add extra weights/variables 

2824 # simply by assigning them to attributes (invalid for graph networks) 

2825 self._maybe_create_attribute('_trainable_weights', []) 

2826 self._maybe_create_attribute('_non_trainable_weights', []) 

2827 if val.trainable: 

2828 if any(val is w for w in self._trainable_weights): 

2829 continue 

2830 self._trainable_weights.append(val) 

2831 else: 

2832 if any(val is w for w in self._non_trainable_weights): 

2833 continue 

2834 self._non_trainable_weights.append(val) 

2835 

2836 backend.track_variable(val) 

2837 

2838 # TODO(b/180760306) Skip the auto trackable from tf.Module to keep status 

2839 # quo. See the comment at __delattr__. 

2840 super(autotrackable.AutoTrackable, self).__setattr__(name, value) # pylint: disable=bad-super-call 

2841 
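The reference counting maintained by __setattr__ and consulted by __delattr__ ensures that deleting one attribute does not untrack an object still reachable through another attribute. A hedged illustration (the layer below is hypothetical):

import tensorflow as tf

class Outer(tf.keras.layers.Layer):       # hypothetical layer
    def __init__(self):
        super().__init__()
        shared = tf.keras.layers.Dense(1)
        self.a = shared                   # reference count for `shared` -> 1
        self.b = shared                   # reference count -> 2

outer = Outer()
del outer.a                               # count drops to 1; `shared` stays tracked via `b`
assert any(l is outer.b for l in outer._flatten_layers())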

2842 def _gather_children_attribute(self, attribute): 

2843 assert attribute in { 

2844 'variables', 'trainable_variables', 'non_trainable_variables' 

2845 } 

2846 if hasattr(self, '_self_tracked_trackables'): 

2847 nested_layers = self._flatten_modules(include_self=False, recursive=False) 

2848 return list( 

2849 itertools.chain.from_iterable( 

2850 getattr(layer, attribute) for layer in nested_layers)) 

2851 return [] 

2852 

2853 def _flatten_layers(self, recursive=True, include_self=True): 

2854 for m in self._flatten_modules( 

2855 recursive=recursive, include_self=include_self): 

2856 if isinstance(m, Layer): 

2857 yield m 

2858 

2859 def _flatten_modules(self, recursive=True, include_self=True): 

2860 """Flattens `tf.Module` instances (excluding `Metrics`). 

2861 

2862 Args: 

2863 recursive: Whether to recursively flatten through submodules. 

2864 include_self: Whether to include this `Layer` instance. 

2865 

2866 Yields: 

2867 `tf.Module` instances tracked by this `Layer`.

2868 """ 

2869 if include_self: 

2870 yield self 

2871 

2872 # Only instantiate set and deque if needed. 

2873 trackables = getattr(self, '_self_tracked_trackables', None) 

2874 if trackables: 

2875 seen_object_ids = set() 

2876 deque = collections.deque(trackables) 

2877 while deque: 

2878 trackable_obj = deque.popleft() 

2879 trackable_id = id(trackable_obj) 

2880 if trackable_id in seen_object_ids: 

2881 continue 

2882 seen_object_ids.add(trackable_id) 

2883 

2884 # Metrics are not considered part of the Layer's topology. 

2885 if (isinstance(trackable_obj, module.Module) and 

2886 not isinstance(trackable_obj, metrics_mod.Metric)): 

2887 yield trackable_obj 

2888 # Introspect recursively through sublayers. 

2889 if recursive: 

2890 subtrackables = getattr(trackable_obj, '_self_tracked_trackables', 

2891 None) 

2892 if subtrackables: 

2893 deque.extendleft(reversed(subtrackables)) 

2894 elif isinstance(trackable_obj, data_structures.TrackableDataStructure): 

2895 # Data structures are introspected even with `recursive=False`. 

2896 tracked_values = trackable_obj._values 

2897 if tracked_values: 

2898 deque.extendleft(reversed(tracked_values)) 

2899 
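In effect, _flatten_modules walks the tracked sub-objects depth-first (children are pushed to the left of the deque before later siblings), skips Metric instances, and deduplicates by id(). A hedged sketch of observing this through _flatten_layers (the nested model is illustrative):

import tensorflow as tf

inner = tf.keras.Sequential([tf.keras.layers.Dense(2)], name='inner')
outer = tf.keras.Sequential([inner, tf.keras.layers.Dense(1)], name='outer')
print([l.name for l in outer._flatten_layers(include_self=False)])
# Expected to list 'inner' first, then its Dense, then the outer Dense.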

2900 # This is a hack so that the is_layer (within 

2901 # training/trackable/layer_utils.py) check doesn't get the weights attr. 

2902 # TODO(b/110718070): Remove when fixed. 

2903 def _is_layer(self): 

2904 return True 

2905 

2906 def _init_call_fn_args(self, expects_training_arg=None): 

2907 # Clear cached call function arguments. 

2908 self.__class__._call_full_argspec.fget.cache.pop(self, None) 

2909 self.__class__._call_fn_args.fget.cache.pop(self, None) 

2910 self.__class__._call_accepts_kwargs.fget.cache.pop(self, None) 

2911 

2912 call_fn_args = self._call_fn_args 

2913 call_fn_args += self._call_full_argspec.kwonlyargs or [] 

2914 if expects_training_arg is None: 

2915 self._expects_training_arg = ('training' in call_fn_args or 

2916 self._call_accepts_kwargs) 

2917 else: 

2918 # Use value encoded into the metadata when loading from the SavedModel. 

2919 self._expects_training_arg = expects_training_arg 

2920 # The default training arg will be any (non-None) default specified in the 

2921 # method signature, or None if no value is specified. 

2922 call_fn_arg_defaults = self._call_fn_arg_defaults.copy() 

2923 call_fn_arg_defaults.update(self._call_full_argspec.kwonlydefaults or {}) 

2924 self._default_training_arg = call_fn_arg_defaults.get('training') 

2925 

2926 self._expects_mask_arg = ('mask' in call_fn_args or 

2927 self._call_accepts_kwargs) 

2928 
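A hedged sketch of how a custom layer's `call` signature drives these flags (the layer below is hypothetical):

import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):     # hypothetical layer
    def call(self, inputs, training=False, mask=None):
        return inputs

layer = MyLayer()
# After Layer.__init__ has run _init_call_fn_args:
#   layer._expects_training_arg  -> True   ('training' appears in the signature)
#   layer._default_training_arg  -> False  (the default taken from the signature)
#   layer._expects_mask_arg      -> True   ('mask' appears in the signature)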

2929 @property 

2930 @layer_utils.cached_per_instance 

2931 def _call_full_argspec(self): 

2932 # Argspec inspection is expensive and the call spec is used often, so it 

2933 # makes sense to cache the result. 

2934 return tf_inspect.getfullargspec(self.call) 

2935 

2936 @property 

2937 @layer_utils.cached_per_instance 

2938 def _call_fn_args(self): 

2939 all_args = self._call_full_argspec.args 

2940 # Scrub `self` that appears if a decorator was applied. 

2941 if all_args and all_args[0] == 'self': 

2942 return all_args[1:] 

2943 return all_args 

2944 

2945 @property 

2946 @layer_utils.cached_per_instance 

2947 def _call_fn_arg_defaults(self): 

2948 call_fn_args = self._call_fn_args 

2949 call_fn_defaults = self._call_full_argspec.defaults or [] 

2950 defaults = dict() 

2951 

2952 # The call arg defaults are an n-tuple of the last n elements of the args 

2953 # list. (n = # of elements that have a default argument) 

2954 for i in range(-1 * len(call_fn_defaults), 0): 

2955 defaults[call_fn_args[i]] = call_fn_defaults[i] 

2956 return defaults 

2957 
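The negative-index loop above pairs the trailing defaults tuple with the trailing argument names. A small self-contained sketch of the same pairing (the signature is illustrative):

import inspect

def call(self, inputs, training=None, scale=2.0):   # illustrative signature
    return inputs

spec = inspect.getfullargspec(call)
args = spec.args[1:]               # drop 'self' -> ['inputs', 'training', 'scale']
defaults = spec.defaults or ()     # (None, 2.0): defaults for the last two args
paired = {args[i]: defaults[i] for i in range(-len(defaults), 0)}
print(paired)                      # {'training': None, 'scale': 2.0}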

2958 @property 

2959 @layer_utils.cached_per_instance 

2960 def _call_fn_arg_positions(self): 

2961 call_fn_arg_positions = dict() 

2962 for pos, arg in enumerate(self._call_fn_args): 

2963 call_fn_arg_positions[arg] = pos 

2964 return call_fn_arg_positions 

2965 

2966 @property 

2967 @layer_utils.cached_per_instance 

2968 def _call_accepts_kwargs(self): 

2969 return self._call_full_argspec.varkw is not None 

2970 

2971 @property 

2972 def _eager_losses(self): 

2973 # A list of loss values containing activity regularizers and losses 

2974 # manually added through `add_loss` during eager execution. It is cleared 

2975 # after every batch. 

2976 # Because we plan on eventually allowing the same model instance to be trained

2977 # alternately in eager mode and graph mode, we need to keep track of

2978 # eager losses and symbolic losses via separate attributes. 

2979 if not hasattr(self._thread_local, '_eager_losses'): 

2980 self._thread_local._eager_losses = [] 

2981 return self._thread_local._eager_losses 

2982 

2983 @_eager_losses.setter 

2984 def _eager_losses(self, losses): 

2985 self._thread_local._eager_losses = losses 

2986 

2987 def _dedup_weights(self, weights): 

2988 """Dedupe weights while maintaining order as much as possible.""" 

2989 output, seen_ids = [], set() 

2990 for w in weights: 

2991 if id(w) not in seen_ids: 

2992 output.append(w) 

2993 # Track the Variable's identity to avoid __eq__ issues. 

2994 seen_ids.add(id(w)) 

2995 

2996 return output 

2997 
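Deduplication keys on id() rather than equality because `tf.Variable.__eq__` is elementwise in TF2 and is not suitable for membership tests. A minimal sketch of the same idea:

import tensorflow as tf

v = tf.Variable(1.0)
w = tf.Variable(2.0)
weights = [v, w, v]                # `v` is shared and appears twice
deduped, seen = [], set()
for x in weights:
    if id(x) not in seen:          # identity, not __eq__
        deduped.append(x)
        seen.add(id(x))
print(len(deduped))                # 2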

2998 def _split_out_first_arg(self, args, kwargs): 

2999 # Grab the argument corresponding to the first argument in the 

3000 # layer's `call` method spec. This will either be the first positional 

3001 # argument, or it will be provided as a keyword argument. 

3002 if args: 

3003 inputs = args[0] 

3004 args = args[1:] 

3005 elif self._call_fn_args[0] in kwargs: 

3006 kwargs = copy.copy(kwargs) 

3007 inputs = kwargs.pop(self._call_fn_args[0]) 

3008 else: 

3009 raise ValueError( 

3010 'The first argument to `Layer.call` must always be passed.') 

3011 return inputs, args, kwargs 

3012 
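This lets internal wrappers treat `layer(x)` and `layer(inputs=x)` uniformly. A hedged sketch (Dense is used only because its `call` signature starts with `inputs`; the method is private and may change across versions):

import tensorflow as tf

layer = tf.keras.layers.Dense(1)
x = tf.ones((2, 3))
inputs, args, kwargs = layer._split_out_first_arg((x,), {})
inputs2, _, _ = layer._split_out_first_arg((), {'inputs': x})
assert inputs is x and inputs2 is x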

3013 # SavedModel properties. Please see keras/saving/saved_model for details. 

3014 

3015 @trackable.no_automatic_dependency_tracking 

3016 def _set_save_spec(self, inputs): 

3017 if self._saved_model_inputs_spec is not None: 

3018 return # Already set. 

3019 

3020 self._saved_model_inputs_spec = nest.map_structure(tf_utils.get_tensor_spec, 

3021 inputs) 

3022 

3023 def _get_save_spec(self, dynamic_batch=True): 

3024 if self._saved_model_inputs_spec is None: 

3025 return None 

3026 

3027 return nest.map_structure( 

3028 lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch), 

3029 self._saved_model_inputs_spec) 

3030 

3031 @property 

3032 def _trackable_saved_model_saver(self): 

3033 return layer_serialization.LayerSavedModelSaver(self) 

3034 

3035 @property 

3036 def _object_identifier(self): 

3037 return self._trackable_saved_model_saver.object_identifier 

3038 

3039 @property 

3040 def _tracking_metadata(self): 

3041 """Info about this layer to be saved into the SavedModel.""" 

3042 return self._trackable_saved_model_saver.tracking_metadata 

3043 

3044 def _trackable_children(self, save_type='checkpoint', **kwargs): 

3045 if save_type == 'savedmodel': 

3046 cache = kwargs['cache'] 

3047 # TODO(b/213628533): This must be called before super() to ensure 

3048 # that any input shape changes are applied before getting the config of 

3049 # the model. 

3050 children = self._trackable_saved_model_saver.trackable_children(cache) 

3051 else: 

3052 children = {} 

3053 children.update(super()._trackable_children(save_type, **kwargs)) 

3054 return children 

3055 

3056 @property 

3057 def _use_input_spec_as_call_signature(self): 

3058 # Whether input spec can be used as the call signature when tracing the 

3059 # Layer for SavedModel. By default, this is set to `True` for layers 

3060 # exported from the Keras library, because the layers more rigidly define 

3061 # the `input_spec` property (many custom layers only set the `ndim`).

3062 return get_canonical_name_for_symbol(type(self), 

3063 api_name='keras') is not None 

3064 

3065 def __getstate__(self): 

3066 # Override to support `copy.deepcopy` and pickling. 

3067 # Thread-local objects cannot be copied in Python 3, so pop these. 

3068 # Thread-local objects are used to cache losses in MirroredStrategy, and 

3069 # so shouldn't be copied. 

3070 state = self.__dict__.copy() 

3071 state.pop('_thread_local', None) 

3072 state.pop('_metrics_lock', None) 

3073 return state 

3074 

3075 def __setstate__(self, state): 

3076 state['_thread_local'] = threading.local() 

3077 state['_metrics_lock'] = threading.Lock() 

3078 # Bypass Trackable logic as `__dict__` already contains this info. 

3079 object.__setattr__(self, '__dict__', state) 

3080 
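These two hooks are what make layers deep-copyable and picklable despite the thread-local loss cache and the metrics lock. A hedged sketch:

import copy
import tensorflow as tf

layer = tf.keras.layers.Dense(3)
clone = copy.deepcopy(layer)       # works because thread-locals/locks are dropped
print(clone.get_config() == layer.get_config())   # expected True: same configuration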

3081 

3082class TensorFlowOpLayer(Layer): 

3083 """Wraps a TensorFlow Operation in a Layer. 

3084 

3085 This class is used internally by the Functional API. When a user 

3086 uses a raw TensorFlow Operation on symbolic tensors originating 

3087 from an `Input` Layer, the resultant operation will be wrapped 

3088 with this Layer object in order to make the operation compatible 

3089 with the Keras API. 

3090 

3091 This Layer will create a new, identical operation (except for inputs 

3092 and outputs) every time it is called. If `run_eagerly` is `True`, 

3093 the op creation and calculation will happen inside an Eager function. 

3094 

3095 Instances of this Layer are created when `autolambda` is called, which 

3096 is whenever a Layer's `__call__` encounters symbolic inputs that do 

3097 not have Keras metadata, or when a Network's `__init__` encounters 

3098 outputs that do not have Keras metadata. 

3099 

3100 Attributes: 

3101 node_def: String, the serialized NodeDef of the Op this layer will wrap. 

3102 name: String, the name of the Layer. 

3103 constants: Dict of NumPy arrays, the values of any Tensors needed for this 

3104 Operation that do not originate from a Keras `Input` Layer. Since all 

3105 placeholders must come from Keras `Input` Layers, these Tensors must be 

3106 treated as constant in the Functional API. 

3107 trainable: Bool, whether this Layer is trainable. Currently Variables are 

3108 not supported, and so this parameter has no effect. 

3109 dtype: The default dtype of this Layer. Inherited from `Layer` and has no

3110 effect on this class, but it is used in `get_config`.

3111 """ 

3112 

3113 @trackable.no_automatic_dependency_tracking 

3114 def __init__(self, 

3115 node_def, 

3116 name, 

3117 constants=None, 

3118 trainable=True, 

3119 dtype=None): 

3120 # Pass autocast=False, as if inputs are cast, input types might not match 

3121 # Operation type. 

3122 super(TensorFlowOpLayer, self).__init__( 

3123 name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype, 

3124 autocast=False) 

3125 if isinstance(node_def, dict): 

3126 self.node_def = json_format.ParseDict(node_def, node_def_pb2.NodeDef()) 

3127 else: 

3128 if not isinstance(node_def, bytes): 

3129 node_def = node_def.encode('utf-8') 

3130 self.node_def = node_def_pb2.NodeDef.FromString(node_def) 

3131 # JSON serialization stringifies keys which are integer input indices. 

3132 self.constants = ({ 

3133 int(index): constant for index, constant in constants.items() 

3134 } if constants is not None else {}) 

3135 # Layer uses original op unless it is called on new inputs. 

3136 # This means `built` is not set in `__call__`. 

3137 self.built = True 

3138 

3139 # Do not individually trace TensorflowOpLayers in the SavedModel. 

3140 self._must_restore_from_config = True 

3141 

3142 def call(self, inputs): 

3143 if context.executing_eagerly(): 

3144 return self._defun_call(inputs) 

3145 return self._make_op(inputs) 

3146 

3147 def _make_node_def(self, graph): 

3148 node_def = node_def_pb2.NodeDef() 

3149 node_def.CopyFrom(self.node_def) 

3150 # Used in TPUReplicateContext to indicate whether this node has been cloned 

3151 # and to not add TPU attributes. 

3152 node_def.attr['_cloned'].b = True 

3153 node_def.name = graph.unique_name(node_def.name) 

3154 return node_def 

3155 

3156 def _make_op(self, inputs): 

3157 inputs = nest.flatten(inputs) 

3158 graph = inputs[0].graph 

3159 node_def = self._make_node_def(graph) 

3160 with graph.as_default(): 

3161 for index, constant in self.constants.items(): 

3162 # Recreate constant in graph to add distribution context. 

3163 value = tensor_util.constant_value(constant) 

3164 if value is not None: 

3165 constant = constant_op.constant(value, name=node_def.input[index]) 

3166 inputs.insert(index, constant) 

3167 # TODO(b/183990973): We should drop or consolidate these private api calls 

3168 # for adding an op to the graph and recording its gradient. 

3169 c_op = ops._create_c_op(graph, node_def, inputs, control_inputs=[]) 

3170 op = graph._create_op_from_tf_operation(c_op) 

3171 op._control_flow_post_processing() 

3172 

3173 # Record the gradient because custom-made ops don't go through the 

3174 # code-gen'd eager call path 

3175 op_type = compat.as_str(op.op_def.name) 

3176 attr_names = [compat.as_str(attr.name) for attr in op.op_def.attr] 

3177 attrs = [] 

3178 for attr_name in attr_names: 

3179 attrs.append(attr_name) 

3180 attrs.append(op.get_attr(attr_name)) 

3181 attrs = tuple(attrs) 

3182 backprop.record_gradient(op_type, op.inputs, attrs, op.outputs) 

3183 

3184 if len(op.outputs) == 1: 

3185 return op.outputs[0] 

3186 return op.outputs 

3187 

3188 @def_function.function 

3189 def _defun_call(self, inputs): 

3190 """Wraps the op creation method in an Eager function for `run_eagerly`.""" 

3191 return self._make_op(inputs) 

3192 

3193 def get_config(self): 

3194 config = super(TensorFlowOpLayer, self).get_config() 

3195 config.update({ 

3196 # `__init__` prefixes the name. Revert to the constructor argument. 

3197 'name': config['name'][len(_TF_OP_LAYER_NAME_PREFIX):], 

3198 'node_def': json_format.MessageToDict(self.node_def), 

3199 'constants': { 

3200 i: backend.get_value(c) for i, c in self.constants.items() 

3201 } 

3202 }) 

3203 return config 

3204 
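As the class docstring notes, op layers like this are created automatically when raw TensorFlow ops are applied to symbolic Keras inputs. The concrete wrapper class depends on the Keras version (newer functional-API code paths use a different op-lambda wrapper), so the sketch below only shows the triggering pattern:

import tensorflow as tf

inp = tf.keras.Input(shape=(4,))
out = tf.abs(inp) + 1.0            # raw TensorFlow ops on a symbolic input
model = tf.keras.Model(inp, out)
print([type(l).__name__ for l in model.layers])   # includes auto-generated op layers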

3205 

3206class AddLoss(Layer): 

3207 """Adds its inputs as a loss. 

3208 

3209 Attributes: 

3210 unconditional: Whether the loss is unconditional, i.e. not conditioned on the layer's inputs.

3211 """ 

3212 

3213 def __init__(self, unconditional, **kwargs): 

3214 # Pass autocast=False, as there is no reason to cast loss to a different 

3215 # dtype. 

3216 kwargs['autocast'] = False 

3217 super(AddLoss, self).__init__(**kwargs) 

3218 self.unconditional = unconditional 

3219 

3220 def call(self, inputs): 

3221 self.add_loss(inputs, inputs=(not self.unconditional)) 

3222 return inputs 

3223 

3224 def get_config(self): 

3225 config = super(AddLoss, self).get_config() 

3226 config.update({'unconditional': self.unconditional}) 

3227 return config 

3228 

3229 

3230class AddMetric(Layer): 

3231 """Adds its inputs as a metric. 

3232 

3233 Attributes: 

3234 aggregation: 'mean' or None. How the inputs should be aggregated. 

3235 metric_name: The name to use for this metric. 

3236 """ 

3237 

3238 def __init__(self, aggregation=None, metric_name=None, **kwargs): 

3239 super(AddMetric, self).__init__(**kwargs) 

3240 self.aggregation = aggregation 

3241 self.metric_name = metric_name 

3242 

3243 def call(self, inputs): 

3244 self.add_metric(inputs, aggregation=self.aggregation, name=self.metric_name) 

3245 return inputs 

3246 

3247 def get_config(self): 

3248 config = super(AddMetric, self).get_config() 

3249 config.update({ 

3250 'aggregation': self.aggregation, 

3251 'metric_name': self.metric_name 

3252 }) 

3253 return config 

3254 

3255 

3256def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list): # pylint: disable=unused-argument 

3257 """Check the arguments to see if we are constructing a functional model.""" 

3258 # We are constructing a functional model if any of the inputs 

3259 # are KerasTensors 

3260 return any( 

3261 isinstance(tensor, keras_tensor.KerasTensor) 

3262 for tensor in nest.flatten([inputs, args, kwargs])) 

3263 
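Functional-model construction is thus detected purely by the presence of KerasTensors among the call arguments, which is what `tf.keras.Input` produces. A hedged sketch of the two resulting behaviors:

import tensorflow as tf

symbolic = tf.keras.Input(shape=(3,))      # a symbolic Keras tensor
eager = tf.ones((2, 3))                    # a concrete eager tensor
layer = tf.keras.layers.Dense(2)
print(type(layer(symbolic)).__name__)      # a KerasTensor: functional construction
print(type(layer(eager)).__name__)         # EagerTensor: immediate execution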

3264 

3265def _convert_numpy_or_python_types(x): 

3266 if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)): 

3267 return tensor_conversion.convert_to_tensor_v2_with_dispatch(x) 

3268 return x 

3269 

3270 

3271# Avoid breaking users who directly import this symbol from this file. 

3272# TODO(fchollet): remove this. 

3273InputSpec = input_spec.InputSpec # pylint:disable=invalid-name