Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/base_layer.py: 22%

1285 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16 

17"""Contains the base Layer class, from which all layers inherit.""" 

18 

19import collections 

20import contextlib 

21import functools 

22import itertools 

23import textwrap 

24import threading 

25import warnings 

26import weakref 

27 

28import numpy as np 

29import tensorflow.compat.v2 as tf 

30 

31from keras.src import backend 

32from keras.src import constraints 

33from keras.src import initializers 

34from keras.src import regularizers 

35from keras.src.dtensor import lazy_variable 

36from keras.src.engine import base_layer_utils 

37from keras.src.engine import input_spec 

38from keras.src.engine import keras_tensor 

39from keras.src.engine import node as node_module 

40from keras.src.mixed_precision import autocast_variable 

41from keras.src.mixed_precision import policy 

42from keras.src.saving import serialization_lib 

43from keras.src.saving.legacy.saved_model import layer_serialization 

44from keras.src.utils import generic_utils 

45from keras.src.utils import layer_utils 

46from keras.src.utils import object_identity 

47from keras.src.utils import tf_inspect 

48from keras.src.utils import tf_utils 

49from keras.src.utils import traceback_utils 

50from keras.src.utils import version_utils 

51 

52# A module that only depends on `keras.layers` import these from here. 

53from keras.src.utils.generic_utils import to_snake_case # noqa: F401 

54from keras.src.utils.tf_utils import is_tensor_or_tensor_list # noqa: F401 

55 

56# isort: off 

57from google.protobuf import json_format 

58from tensorflow.python.platform import tf_logging 

59from tensorflow.python.util.tf_export import ( 

60 get_canonical_name_for_symbol, 

61) 

62from tensorflow.python.util.tf_export import keras_export 

63from tensorflow.tools.docs import doc_controls 

64 

65 

66metrics_mod = generic_utils.LazyLoader( 

67 "metrics_mod", globals(), "keras.src.metrics" 

68) 

69 

70 

71# Prefix that is added to the TF op layer names. 

72_TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_" 

73 

74# TODO(mdan): Should we have a single generic type for types that can be passed 

75# to tf.cast? 

76_AUTOCAST_TYPES = (tf.Tensor, tf.SparseTensor, tf.RaggedTensor) 

77 

78keras_layers_gauge = tf.__internal__.monitoring.BoolGauge( 

79 "/tensorflow/api/keras/layers", "keras layers usage", "method" 

80) 

81keras_models_gauge = tf.__internal__.monitoring.BoolGauge( 

82 "/tensorflow/api/keras/models", "keras model usage", "method" 

83) 

84keras_api_gauge = tf.__internal__.monitoring.BoolGauge( 

85 "/tensorflow/api/keras", "keras api usage", "method" 

86) 

87keras_premade_model_gauge = tf.__internal__.monitoring.BoolGauge( 

88 "/tensorflow/api/keras/premade_models", "premade keras model usage", "type" 

89) 

90 

91_is_name_scope_on_model_declaration_enabled = False 

92 

93_name_scope_unnester_stack = threading.local() 

94 

95 

96@contextlib.contextmanager 

97def _name_scope_unnester(full_name_scope): 

98 """Helper to get relative name scope from fully-speced nested name scopes. 

99 

100 Args: 

101 full_name_scope: full(absolute) name scope path. 

102 

103 Yields: 

104 Relative name scope path from the parent `_name_scope_unnester` context 

105 manager. 

106 

107 Example: 

108 ``` 

109 with _name_scope_unnester('a') as name1: # name1 == 'a' 

110 with _name_scope_unnester('a/b') as name2: # name2 == 'b' 

111 with _name_scope_unnester('a/b/c') as name3: # name3 == 'c' 

112 pass 

113 ``` 

114 """ 

115 if not getattr(_name_scope_unnester_stack, "value", None): 

116 _name_scope_unnester_stack.value = [""] 

117 

118 _name_scope_unnester_stack.value.append(full_name_scope) 

119 

120 try: 

121 full_name_scope = _name_scope_unnester_stack.value[-1] 

122 outer_name_scope = _name_scope_unnester_stack.value[-2] 

123 relative_name_scope = full_name_scope.lstrip(outer_name_scope) 

124 relative_name_scope = relative_name_scope.lstrip("/") 

125 yield relative_name_scope 

126 finally: 

127 _name_scope_unnester_stack.value.pop() 

128 

129 

130@keras_export("keras.layers.Layer") 

131class Layer(tf.Module, version_utils.LayerVersionSelector): 

132 """This is the class from which all layers inherit. 

133 

134 A layer is a callable object that takes as input one or more tensors and 

135 that outputs one or more tensors. It involves *computation*, defined 

136 in the `call()` method, and a *state* (weight variables). State can be 

137 created in various places, at the convenience of the subclass implementer: 

138 

139 * in `__init__()`; 

140 * in the optional `build()` method, which is invoked by the first 

141 `__call__()` to the layer, and supplies the shape(s) of the input(s), 

142 which may not have been known at initialization time; 

143 * in the first invocation of `call()`, with some caveats discussed 

144 below. 

145 

146 Layers are recursively composable: If you assign a Layer instance as an 

147 attribute of another Layer, the outer layer will start tracking the weights 

148 created by the inner layer. Nested layers should be instantiated in the 

149 `__init__()` method. 

150 

151 Users will just instantiate a layer and then treat it as a callable. 

152 

153 Args: 

154 trainable: Boolean, whether the layer's variables should be trainable. 

155 name: String name of the layer. 

156 dtype: The dtype of the layer's computations and weights. Can also be a 

157 `tf.keras.mixed_precision.Policy`, which allows the computation and 

158 weight dtype to differ. Default of `None` means to use 

159 `tf.keras.mixed_precision.global_policy()`, which is a float32 policy 

160 unless set to different value. 

161 dynamic: Set this to `True` if your layer should only be run eagerly, and 

162 should not be used to generate a static computation graph. 

163 This would be the case for a Tree-RNN or a recursive network, 

164 for example, or generally for any layer that manipulates tensors 

165 using Python control flow. If `False`, we assume that the layer can 

166 safely be used to generate a static computation graph. 

167 

168 Attributes: 

169 name: The name of the layer (string). 

170 dtype: The dtype of the layer's weights. 

171 variable_dtype: Alias of `dtype`. 

172 compute_dtype: The dtype of the layer's computations. Layers automatically 

173 cast inputs to this dtype which causes the computations and output to 

174 also be in this dtype. When mixed precision is used with a 

175 `tf.keras.mixed_precision.Policy`, this will be different than 

176 `variable_dtype`. 

177 dtype_policy: The layer's dtype policy. See the 

178 `tf.keras.mixed_precision.Policy` documentation for details. 

179 trainable_weights: List of variables to be included in backprop. 

180 non_trainable_weights: List of variables that should not be 

181 included in backprop. 

182 weights: The concatenation of the lists trainable_weights and 

183 non_trainable_weights (in this order). 

184 trainable: Whether the layer should be trained (boolean), i.e. whether 

185 its potentially-trainable weights should be returned as part of 

186 `layer.trainable_weights`. 

187 input_spec: Optional (list of) `InputSpec` object(s) specifying the 

188 constraints on inputs that can be accepted by the layer. 

189 

190 We recommend that descendants of `Layer` implement the following methods: 

191 

192 * `__init__()`: Defines custom layer attributes, and creates layer weights 

193 that do not depend on input shapes, using `add_weight()`, or other state. 

194 * `build(self, input_shape)`: This method can be used to create weights that 

195 depend on the shape(s) of the input(s), using `add_weight()`, or other 

196 state. `__call__()` will automatically build the layer (if it has not been 

197 built yet) by calling `build()`. 

198 * `call(self, inputs, *args, **kwargs)`: Called in `__call__` after making 

199 sure `build()` has been called. `call()` performs the logic of applying 

200 the layer to the `inputs`. The first invocation may additionally create 

201 state that could not be conveniently created in `build()`; see its 

202 docstring for details. 

203 Two reserved keyword arguments you can optionally use in `call()` are: 

204 - `training` (boolean, whether the call is in inference mode or training 

205 mode). See more details in [the layer/model subclassing guide]( 

206 https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_training_argument_in_the_call_method) 

207 - `mask` (boolean tensor encoding masked timesteps in the input, used 

208 in RNN layers). See more details in 

209 [the layer/model subclassing guide]( 

210 https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_mask_argument_in_the_call_method) 

211 A typical signature for this method is `call(self, inputs)`, and user 

212 could optionally add `training` and `mask` if the layer need them. `*args` 

213 and `**kwargs` is only useful for future extension when more input 

214 parameters are planned to be added. 

215 * `get_config(self)`: Returns a dictionary containing the configuration used 

216 to initialize this layer. If the keys differ from the arguments 

217 in `__init__`, then override `from_config(self)` as well. 

218 This method is used when saving 

219 the layer or a model that contains this layer. 

220 

221 Examples: 

222 

223 Here's a basic example: a layer with two variables, `w` and `b`, 

224 that returns `y = w . x + b`. 

225 It shows how to implement `build()` and `call()`. 

226 Variables set as attributes of a layer are tracked as weights 

227 of the layers (in `layer.weights`). 

228 

229 ```python 

230 class SimpleDense(Layer): 

231 

232 def __init__(self, units=32): 

233 super(SimpleDense, self).__init__() 

234 self.units = units 

235 

236 def build(self, input_shape): # Create the state of the layer (weights) 

237 w_init = tf.random_normal_initializer() 

238 self.w = tf.Variable( 

239 initial_value=w_init(shape=(input_shape[-1], self.units), 

240 dtype='float32'), 

241 trainable=True) 

242 b_init = tf.zeros_initializer() 

243 self.b = tf.Variable( 

244 initial_value=b_init(shape=(self.units,), dtype='float32'), 

245 trainable=True) 

246 

247 def call(self, inputs): # Defines the computation from inputs to outputs 

248 return tf.matmul(inputs, self.w) + self.b 

249 

250 # Instantiates the layer. 

251 linear_layer = SimpleDense(4) 

252 

253 # This will also call `build(input_shape)` and create the weights. 

254 y = linear_layer(tf.ones((2, 2))) 

255 assert len(linear_layer.weights) == 2 

256 

257 # These weights are trainable, so they're listed in `trainable_weights`: 

258 assert len(linear_layer.trainable_weights) == 2 

259 ``` 

260 

261 Note that the method `add_weight()` offers a shortcut to create weights: 

262 

263 ```python 

264 class SimpleDense(Layer): 

265 

266 def __init__(self, units=32): 

267 super(SimpleDense, self).__init__() 

268 self.units = units 

269 

270 def build(self, input_shape): 

271 self.w = self.add_weight(shape=(input_shape[-1], self.units), 

272 initializer='random_normal', 

273 trainable=True) 

274 self.b = self.add_weight(shape=(self.units,), 

275 initializer='random_normal', 

276 trainable=True) 

277 

278 def call(self, inputs): 

279 return tf.matmul(inputs, self.w) + self.b 

280 ``` 

281 

282 Besides trainable weights, updated via backpropagation during training, 

283 layers can also have non-trainable weights. These weights are meant to 

284 be updated manually during `call()`. Here's a example layer that computes 

285 the running sum of its inputs: 

286 

287 ```python 

288 class ComputeSum(Layer): 

289 

290 def __init__(self, input_dim): 

291 super(ComputeSum, self).__init__() 

292 # Create a non-trainable weight. 

293 self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), 

294 trainable=False) 

295 

296 def call(self, inputs): 

297 self.total.assign_add(tf.reduce_sum(inputs, axis=0)) 

298 return self.total 

299 

300 my_sum = ComputeSum(2) 

301 x = tf.ones((2, 2)) 

302 

303 y = my_sum(x) 

304 print(y.numpy()) # [2. 2.] 

305 

306 y = my_sum(x) 

307 print(y.numpy()) # [4. 4.] 

308 

309 assert my_sum.weights == [my_sum.total] 

310 assert my_sum.non_trainable_weights == [my_sum.total] 

311 assert my_sum.trainable_weights == [] 

312 ``` 

313 

314 For more information about creating layers, see the guide 

315 [Making new Layers and Models via subclassing]( 

316 https://www.tensorflow.org/guide/keras/custom_layers_and_models) 

317 """ 

318 

319 @tf.__internal__.tracking.no_automatic_dependency_tracking 

320 def __init__( 

321 self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs 

322 ): 

323 self._instrument_layer_creation() 

324 

325 # These properties should be set by the user via keyword arguments. 

326 # note that 'dtype', 'input_shape' and 'batch_input_shape' 

327 # are only applicable to input layers: do not pass these keywords 

328 # to non-input layers. 

329 allowed_kwargs = { 

330 "input_dim", 

331 "input_shape", 

332 "batch_input_shape", 

333 "batch_size", 

334 "weights", 

335 "activity_regularizer", 

336 "autocast", 

337 "implementation", 

338 } 

339 # Validate optional keyword arguments. 

340 generic_utils.validate_kwargs(kwargs, allowed_kwargs) 

341 

342 # Mutable properties 

343 # Indicates whether the layer's weights are updated during training 

344 # and whether the layer's updates are run during training. 

345 if not ( 

346 isinstance(trainable, bool) 

347 or ( 

348 isinstance(trainable, (tf.Tensor, tf.Variable)) 

349 and trainable.dtype is tf.bool 

350 ) 

351 ): 

352 raise TypeError( 

353 "Expected `trainable` argument to be a boolean, " 

354 f"but got: {trainable}" 

355 ) 

356 self._trainable = trainable 

357 # A stateful layer is a layer whose updates are run during inference 

358 # too, for instance stateful RNNs. 

359 self._stateful = False 

360 # Indicates whether `build` needs to be called upon layer call, to 

361 # create the layer's weights. (Note that the first call() may also 

362 # create weights, independent of build().) 

363 self.built = False 

364 # Provides information about which inputs are compatible with the layer. 

365 self._input_spec = None 

366 

367 # SavedModel-related attributes. 

368 # Record the build input shape for loading purposes. 

369 # TODO(kathywu): Move this to Layer._set_save_spec once cl/290121460 is 

370 # submitted. 

371 self._build_input_shape = None 

372 self._saved_model_inputs_spec = None 

373 self._saved_model_arg_spec = None 

374 

375 # `Layer.compute_mask` will be called at the end of `Layer.__call__` if 

376 # `Layer.compute_mask` is overridden, or if the `Layer` subclass sets 

377 # `self.supports_masking=True`. 

378 self._supports_masking = not generic_utils.is_default(self.compute_mask) 

379 

380 self._init_set_name(name) 

381 self._activity_regularizer = regularizers.get( 

382 kwargs.pop("activity_regularizer", None) 

383 ) 

384 self._maybe_create_attribute("_trainable_weights", []) 

385 self._maybe_create_attribute("_non_trainable_weights", []) 

386 self._updates = [] 

387 # Object to store all thread local layer properties. 

388 self._thread_local = threading.local() 

389 # A list of zero-argument lambdas which return Tensors, used for 

390 # variable regularizers. 

391 self._callable_losses = [] 

392 # A list of symbolic Tensors containing activity regularizers and losses 

393 # manually added through `add_loss` in graph-building mode. 

394 self._losses = [] 

395 # A list of metric instances corresponding to the symbolic metric 

396 # tensors added using the `add_metric` API. 

397 self._metrics = [] 

398 # Ensures the same metric is not added multiple times in 

399 # `MirroredStrategy`. 

400 self._metrics_lock = threading.Lock() 

401 

402 # Note that models also have a dtype policy, as they are layers. For 

403 # functional models, the policy is only used in Model.compile, which 

404 # wraps the optimizer with a LossScaleOptimizer if the policy name is 

405 # "mixed_float16". Subclassed models additionally use the policy's 

406 # compute and variable dtypes, as like any ordinary layer. 

407 self._set_dtype_policy(dtype) 

408 # Boolean indicating whether the layer automatically casts its inputs to 

409 # the layer's compute_dtype. 

410 self._autocast = kwargs.get( 

411 "autocast", base_layer_utils.v2_dtype_behavior_enabled() 

412 ) 

413 

414 # Tracks `TrackableDataStructure`s, `Module`s, and `Layer`s. 

415 # Ordered by when the object was assigned as an attr. 

416 # Entries are unique. 

417 self._maybe_create_attribute("_self_tracked_trackables", []) 

418 

419 # These lists will be filled via successive calls 

420 # to self._add_inbound_node(). 

421 # Used in symbolic mode only, only in conjunction with graph-networks 

422 self._inbound_nodes_value = [] 

423 self._outbound_nodes_value = [] 

424 

425 self._init_call_fn_args() 

426 

427 # Whether the `call` method can be used to build a TF graph without 

428 # issues. This attribute has no effect if the model is created using 

429 # the Functional API. Instead, `model.dynamic` is determined based on 

430 # the internal layers. 

431 if not isinstance(dynamic, bool): 

432 raise TypeError( 

433 "Expected `dynamic` argument to be a boolean, " 

434 f"but got: {dynamic}" 

435 ) 

436 self._dynamic = dynamic 

437 

438 # Manage input shape information if passed. 

439 if "input_dim" in kwargs and "input_shape" not in kwargs: 

440 # Backwards compatibility: alias 'input_dim' to 'input_shape'. 

441 kwargs["input_shape"] = (kwargs["input_dim"],) 

442 if "input_shape" in kwargs or "batch_input_shape" in kwargs: 

443 # In this case we will later create an input layer 

444 # to insert before the current layer 

445 if "batch_input_shape" in kwargs: 

446 batch_input_shape = tuple(kwargs["batch_input_shape"]) 

447 elif "input_shape" in kwargs: 

448 if "batch_size" in kwargs: 

449 batch_size = kwargs["batch_size"] 

450 else: 

451 batch_size = None 

452 batch_input_shape = (batch_size,) + tuple(kwargs["input_shape"]) 

453 self._batch_input_shape = batch_input_shape 

454 

455 # Manage initial weight values if passed. 

456 self._initial_weights = kwargs.get("weights", None) 

457 

458 # Whether the layer will track any layers that is set as attribute on 

459 # itself as sub-layers, the weights from the sub-layers will be included 

460 # in the parent layer's variables() as well. Defaults to `True`, which 

461 # means auto tracking is turned on. Certain subclass might want to turn 

462 # it off, like Sequential model. 

463 self._auto_track_sub_layers = True 

464 

465 # For backwards compat reasons, most built-in layers do not guarantee 

466 # That they will 100% preserve the structure of input args when saving 

467 # / loading configs. E.g. they may un-nest an arg that is 

468 # a list with one element. 

469 self._preserve_input_structure_in_config = False 

470 

471 # Save outer name scope at layer declaration so that it is preserved at 

472 # the actual layer construction. 

473 self._name_scope_on_declaration = tf.get_current_name_scope() 

474 

475 # Save the temp regularization losses created in the DTensor use case. 

476 # When DTensor is enable, we will first create LazyInitVariable and then 

477 # DVariable with proper layout afterward. For the weights regularization 

478 # loss, we have to create against the DVariable as well. 

479 self._captured_weight_regularizer = [] 

480 

481 @tf.__internal__.tracking.no_automatic_dependency_tracking 

482 @generic_utils.default 

483 def build(self, input_shape): 

484 """Creates the variables of the layer (for subclass implementers). 

485 

486 This is a method that implementers of subclasses of `Layer` or `Model` 

487 can override if they need a state-creation step in-between 

488 layer instantiation and layer call. It is invoked automatically before 

489 the first execution of `call()`. 

490 

491 This is typically used to create the weights of `Layer` subclasses 

492 (at the discretion of the subclass implementer). 

493 

494 Args: 

495 input_shape: Instance of `TensorShape`, or list of instances of 

496 `TensorShape` if the layer expects a list of inputs 

497 (one instance per input). 

498 """ 

499 self._build_input_shape = input_shape 

500 self.built = True 

501 

502 @doc_controls.for_subclass_implementers 

503 def call(self, inputs, *args, **kwargs): 

504 """This is where the layer's logic lives. 

505 

506 The `call()` method may not create state (except in its first 

507 invocation, wrapping the creation of variables or other resources in 

508 `tf.init_scope()`). It is recommended to create state, including 

509 `tf.Variable` instances and nested `Layer` instances, 

510 in `__init__()`, or in the `build()` method that is 

511 called automatically before `call()` executes for the first time. 

512 

513 Args: 

514 inputs: Input tensor, or dict/list/tuple of input tensors. 

515 The first positional `inputs` argument is subject to special rules: 

516 - `inputs` must be explicitly passed. A layer cannot have zero 

517 arguments, and `inputs` cannot be provided via the default value 

518 of a keyword argument. 

519 - NumPy array or Python scalar values in `inputs` get cast as 

520 tensors. 

521 - Keras mask metadata is only collected from `inputs`. 

522 - Layers are built (`build(input_shape)` method) 

523 using shape info from `inputs` only. 

524 - `input_spec` compatibility is only checked against `inputs`. 

525 - Mixed precision input casting is only applied to `inputs`. 

526 If a layer has tensor arguments in `*args` or `**kwargs`, their 

527 casting behavior in mixed precision should be handled manually. 

528 - The SavedModel input specification is generated using `inputs` 

529 only. 

530 - Integration with various ecosystem packages like TFMOT, TFLite, 

531 TF.js, etc is only supported for `inputs` and not for tensors in 

532 positional and keyword arguments. 

533 *args: Additional positional arguments. May contain tensors, although 

534 this is not recommended, for the reasons above. 

535 **kwargs: Additional keyword arguments. May contain tensors, although 

536 this is not recommended, for the reasons above. 

537 The following optional keyword arguments are reserved: 

538 - `training`: Boolean scalar tensor of Python boolean indicating 

539 whether the `call` is meant for training or inference. 

540 - `mask`: Boolean input mask. If the layer's `call()` method takes a 

541 `mask` argument, its default value will be set to the mask 

542 generated for `inputs` by the previous layer (if `input` did come 

543 from a layer that generated a corresponding mask, i.e. if it came 

544 from a Keras layer with masking support). 

545 

546 Returns: 

547 A tensor or list/tuple of tensors. 

548 """ 

549 return inputs 

550 

551 @doc_controls.for_subclass_implementers 

552 def add_weight( 

553 self, 

554 name=None, 

555 shape=None, 

556 dtype=None, 

557 initializer=None, 

558 regularizer=None, 

559 trainable=None, 

560 constraint=None, 

561 use_resource=None, 

562 synchronization=tf.VariableSynchronization.AUTO, 

563 aggregation=tf.VariableAggregation.NONE, 

564 **kwargs, 

565 ): 

566 """Adds a new variable to the layer. 

567 

568 Args: 

569 name: Variable name. 

570 shape: Variable shape. Defaults to scalar if unspecified. 

571 dtype: The type of the variable. Defaults to `self.dtype`. 

572 initializer: Initializer instance (callable). 

573 regularizer: Regularizer instance (callable). 

574 trainable: Boolean, whether the variable should be part of the layer's 

575 "trainable_variables" (e.g. variables, biases) 

576 or "non_trainable_variables" (e.g. BatchNorm mean and variance). 

577 Note that `trainable` cannot be `True` if `synchronization` 

578 is set to `ON_READ`. 

579 constraint: Constraint instance (callable). 

580 use_resource: Whether to use a `ResourceVariable` or not. 

581 See [this guide]( 

582 https://www.tensorflow.org/guide/migrate/tf1_vs_tf2#resourcevariables_instead_of_referencevariables) 

583 for more information. 

584 synchronization: Indicates when a distributed a variable will be 

585 aggregated. Accepted values are constants defined in the class 

586 `tf.VariableSynchronization`. By default the synchronization is set 

587 to `AUTO` and the current `DistributionStrategy` chooses when to 

588 synchronize. If `synchronization` is set to `ON_READ`, `trainable` 

589 must not be set to `True`. 

590 aggregation: Indicates how a distributed variable will be aggregated. 

591 Accepted values are constants defined in the class 

592 `tf.VariableAggregation`. 

593 **kwargs: Additional keyword arguments. Accepted values are `getter`, 

594 `collections`, `experimental_autocast` and `caching_device`. 

595 

596 Returns: 

597 The variable created. 

598 

599 Raises: 

600 ValueError: When giving unsupported dtype and no initializer or when 

601 trainable has been set to True with synchronization set as 

602 `ON_READ`. 

603 """ 

604 if shape is None: 

605 shape = () 

606 kwargs.pop("partitioner", None) # Ignored. 

607 # Validate optional keyword arguments. 

608 for kwarg in kwargs: 

609 if kwarg not in [ 

610 "collections", 

611 "experimental_autocast", 

612 "caching_device", 

613 "getter", 

614 "layout", 

615 "experimental_enable_variable_lifting", 

616 ]: 

617 raise TypeError("Unknown keyword argument:", kwarg) 

618 collections_arg = kwargs.pop("collections", None) 

619 # 'experimental_autocast' can be set to False by the caller to indicate 

620 # an AutoCastVariable should never be created. 

621 autocast = kwargs.pop("experimental_autocast", True) 

622 # See the docstring for tf.Variable about the details for 

623 # caching_device. 

624 caching_device = kwargs.pop("caching_device", None) 

625 

626 layout = kwargs.pop("layout", None) 

627 # Specially handling of auto layout fetch, based on the variable name 

628 # and attribute name. For built-in keras layers, usually the variable 

629 # name, eg 'kernel', will match with a 'kernel_layout' attribute name on 

630 # the instance. We will try to do this auto fetch if layout is not 

631 # explicitly specified. This is mainly a quick workaround for not 

632 # applying too many interface change to built-in layers, until DTensor 

633 # is a public API. Also see dtensor.utils.allow_initializer_layout for 

634 # more details. 

635 # TODO(scottzhu): Remove this once dtensor is public to end user. 

636 if not layout and name: 

637 layout = getattr(self, name + "_layout", None) 

638 

639 if dtype is None: 

640 dtype = self.dtype or backend.floatx() 

641 dtype = tf.as_dtype(dtype) 

642 if self._dtype_policy.variable_dtype is None: 

643 # The policy is "_infer", so we infer the policy from the variable 

644 # dtype. 

645 self._set_dtype_policy(policy.Policy(dtype.base_dtype.name)) 

646 initializer = initializers.get(initializer) 

647 regularizer = regularizers.get(regularizer) 

648 constraint = constraints.get(constraint) 

649 

650 if synchronization == tf.VariableSynchronization.ON_READ: 

651 if trainable: 

652 raise ValueError( 

653 "Synchronization value can be set to " 

654 "VariableSynchronization.ON_READ only for non-trainable " 

655 "variables. You have specified trainable=True and " 

656 "synchronization=VariableSynchronization.ON_READ." 

657 ) 

658 else: 

659 # Set trainable to be false when variable is to be synced on 

660 # read. 

661 trainable = False 

662 elif trainable is None: 

663 trainable = True 

664 

665 # Initialize variable when no initializer provided 

666 if initializer is None: 

667 # If dtype is DT_FLOAT, provide a uniform unit scaling initializer 

668 if dtype.is_floating: 

669 initializer = initializers.get("glorot_uniform") 

670 # If dtype is DT_INT/DT_UINT, provide a default value `zero` 

671 # If dtype is DT_BOOL, provide a default value `FALSE` 

672 elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool: 

673 initializer = initializers.get("zeros") 

674 # NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX 

675 # here? 

676 elif "getter" not in kwargs: 

677 # When `getter` is specified, it's possibly fine for 

678 # `initializer` to be None since it's up to the custom `getter` 

679 # to raise error in case it indeed needs `initializer`. 

680 raise ValueError( 

681 f"An initializer for variable {name} of type " 

682 f"{dtype.base_dtype} is required for layer " 

683 f"{self.name}. Received: {initializer}." 

684 ) 

685 

686 getter = kwargs.pop("getter", base_layer_utils.make_variable) 

687 if ( 

688 autocast 

689 and self._dtype_policy.compute_dtype 

690 != self._dtype_policy.variable_dtype 

691 and dtype.is_floating 

692 ): 

693 old_getter = getter 

694 

695 # Wrap variable constructor to return an AutoCastVariable. 

696 def getter(*args, **kwargs): 

697 variable = old_getter(*args, **kwargs) 

698 return autocast_variable.create_autocast_variable(variable) 

699 

700 # Also the caching_device does not work with the mixed precision 

701 # API, disable it if it is specified. 

702 # TODO(b/142020079): Re-enable it once the bug is fixed. 

703 if caching_device is not None: 

704 tf_logging.warning( 

705 "`caching_device` does not work with mixed precision API. " 

706 "Ignoring user specified `caching_device`." 

707 ) 

708 caching_device = None 

709 if layout: 

710 getter = functools.partial(getter, layout=layout) 

711 

712 variable = self._add_variable_with_custom_getter( 

713 name=name, 

714 shape=shape, 

715 # TODO(allenl): a `make_variable` equivalent should be added as a 

716 # `Trackable` method. 

717 getter=getter, 

718 # Manage errors in Layer rather than Trackable. 

719 overwrite=True, 

720 initializer=initializer, 

721 dtype=dtype, 

722 constraint=constraint, 

723 trainable=trainable, 

724 use_resource=use_resource, 

725 collections=collections_arg, 

726 synchronization=synchronization, 

727 aggregation=aggregation, 

728 caching_device=caching_device, 

729 ) 

730 if regularizer is not None: 

731 # TODO(fchollet): in the future, this should be handled at the 

732 # level of variable creation, and weight regularization losses 

733 # should be variable attributes. 

734 name_in_scope = variable.name[: variable.name.find(":")] 

735 self._handle_weight_regularization( 

736 name_in_scope, variable, regularizer 

737 ) 

738 if base_layer_utils.is_split_variable(variable): 

739 for v in variable: 

740 backend.track_variable(v) 

741 if trainable: 

742 self._trainable_weights.append(v) 

743 else: 

744 self._non_trainable_weights.append(v) 

745 else: 

746 backend.track_variable(variable) 

747 if trainable: 

748 self._trainable_weights.append(variable) 

749 else: 

750 self._non_trainable_weights.append(variable) 

751 return variable 

752 

753 def __new__(cls, *args, **kwargs): 

754 # Generate a config to be returned by default by `get_config()`. 

755 arg_names = tf_inspect.getfullargspec(cls.__init__).args 

756 kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args))) 

757 instance = super(Layer, cls).__new__(cls, *args, **kwargs) 

758 # For safety, we only rely on auto-configs for a small set of 

759 # serializable types. 

760 supported_types = (str, int, float, bool, type(None)) 

761 try: 

762 flat_arg_values = tf.nest.flatten(kwargs) 

763 auto_get_config = True 

764 for value in flat_arg_values: 

765 if not isinstance(value, supported_types): 

766 auto_get_config = False 

767 break 

768 except TypeError: 

769 auto_get_config = False 

770 try: 

771 instance._auto_get_config = auto_get_config 

772 if auto_get_config: 

773 instance._auto_config = serialization_lib.Config(**kwargs) 

774 except RecursionError: 

775 # Setting an instance attribute in __new__ has the potential 

776 # to trigger an infinite recursion if a subclass overrides 

777 # setattr in an unsafe way. 

778 pass 

779 return instance 

780 

781 @generic_utils.default 

782 def get_config(self): 

783 """Returns the config of the layer. 

784 

785 A layer config is a Python dictionary (serializable) 

786 containing the configuration of a layer. 

787 The same layer can be reinstantiated later 

788 (without its trained weights) from this configuration. 

789 

790 The config of a layer does not include connectivity 

791 information, nor the layer class name. These are handled 

792 by `Network` (one layer of abstraction above). 

793 

794 Note that `get_config()` does not guarantee to return a fresh copy of 

795 dict every time it is called. The callers should make a copy of the 

796 returned dict if they want to modify it. 

797 

798 Returns: 

799 Python dictionary. 

800 """ 

801 config = { 

802 "name": self.name, 

803 "trainable": self.trainable, 

804 } 

805 config["dtype"] = policy.serialize(self._dtype_policy) 

806 if hasattr(self, "_batch_input_shape"): 

807 config["batch_input_shape"] = self._batch_input_shape 

808 

809 if not generic_utils.is_default(self.get_config): 

810 # In this case the subclass implements get_config() 

811 return config 

812 

813 # In this case the subclass doesn't implement get_config(): 

814 # Let's see if we can autogenerate it. 

815 if getattr(self, "_auto_get_config", False): 

816 xtra_args = set(config.keys()) 

817 config.update(self._auto_config.config) 

818 # Remove args non explicitly supported 

819 argspec = tf_inspect.getfullargspec(self.__init__) 

820 if argspec.varkw != "kwargs": 

821 for key in xtra_args - xtra_args.intersection(argspec.args[1:]): 

822 config.pop(key, None) 

823 return config 

824 else: 

825 raise NotImplementedError( 

826 textwrap.dedent( 

827 f""" 

828 Layer {self.__class__.__name__} was created by passing 

829 non-serializable argument values in `__init__()`, 

830 and therefore the layer must override `get_config()` in 

831 order to be serializable. Please implement `get_config()`. 

832 

833 Example: 

834 

835 class CustomLayer(keras.layers.Layer): 

836 def __init__(self, arg1, arg2, **kwargs): 

837 super().__init__(**kwargs) 

838 self.arg1 = arg1 

839 self.arg2 = arg2 

840 

841 def get_config(self): 

842 config = super().get_config() 

843 config.update({{ 

844 "arg1": self.arg1, 

845 "arg2": self.arg2, 

846 }}) 

847 return config""" 

848 ) 

849 ) 

850 

851 @classmethod 

852 def from_config(cls, config): 

853 """Creates a layer from its config. 

854 

855 This method is the reverse of `get_config`, 

856 capable of instantiating the same layer from the config 

857 dictionary. It does not handle layer connectivity 

858 (handled by Network), nor weights (handled by `set_weights`). 

859 

860 Args: 

861 config: A Python dictionary, typically the 

862 output of get_config. 

863 

864 Returns: 

865 A layer instance. 

866 """ 

867 try: 

868 return cls(**config) 

869 except Exception as e: 

870 raise TypeError( 

871 f"Error when deserializing class '{cls.__name__}' using " 

872 f"config={config}.\n\nException encountered: {e}" 

873 ) 

874 

875 def compute_output_shape(self, input_shape): 

876 """Computes the output shape of the layer. 

877 

878 This method will cause the layer's state to be built, if that has not 

879 happened before. This requires that the layer will later be used with 

880 inputs that match the input shape provided here. 

881 

882 Args: 

883 input_shape: Shape tuple (tuple of integers) or `tf.TensorShape`, 

884 or structure of shape tuples / `tf.TensorShape` instances 

885 (one per output tensor of the layer). 

886 Shape tuples can include None for free dimensions, 

887 instead of an integer. 

888 

889 Returns: 

890 A `tf.TensorShape` instance 

891 or structure of `tf.TensorShape` instances. 

892 """ 

893 if tf.executing_eagerly(): 

894 # In this case we build the model first in order to do shape 

895 # inference. This is acceptable because the framework only calls 

896 # `compute_output_shape` on shape values that the layer would later 

897 # be built for. It would however cause issues in case a user 

898 # attempts to use `compute_output_shape` manually with shapes that 

899 # are incompatible with the shape the Layer will be called on (these 

900 # users will have to implement `compute_output_shape` themselves). 

901 self._maybe_build(input_shape) 

902 graph_name = str(self.name) + "_scratch_graph" 

903 with tf.__internal__.FuncGraph(graph_name).as_default(): 

904 input_shape = tf_utils.convert_shapes( 

905 input_shape, to_tuples=False 

906 ) 

907 

908 def _make_placeholder_like(shape): 

909 ph = backend.placeholder(shape=shape, dtype=self.dtype) 

910 ph._keras_mask = None 

911 return ph 

912 

913 inputs = tf.nest.map_structure( 

914 _make_placeholder_like, input_shape 

915 ) 

916 try: 

917 outputs = self(inputs, training=False) 

918 except TypeError as e: 

919 raise NotImplementedError( 

920 "We could not automatically infer the static shape of " 

921 "the layer's output. Please implement the " 

922 "`compute_output_shape` method on your layer (%s)." 

923 % self.__class__.__name__ 

924 ) from e 

925 return tf.nest.map_structure(lambda t: t.shape, outputs) 

926 raise NotImplementedError( 

927 "Please run in eager mode or implement the `compute_output_shape` " 

928 "method on your layer (%s)." % self.__class__.__name__ 

929 ) 

930 

931 @doc_controls.for_subclass_implementers 

932 def compute_output_signature(self, input_signature): 

933 """Compute the output tensor signature of the layer based on the inputs. 

934 

935 Unlike a TensorShape object, a TensorSpec object contains both shape 

936 and dtype information for a tensor. This method allows layers to provide 

937 output dtype information if it is different from the input dtype. 

938 For any layer that doesn't implement this function, 

939 the framework will fall back to use `compute_output_shape`, and will 

940 assume that the output dtype matches the input dtype. 

941 

942 Args: 

943 input_signature: Single TensorSpec or nested structure of TensorSpec 

944 objects, describing a candidate input for the layer. 

945 

946 Returns: 

947 Single TensorSpec or nested structure of TensorSpec objects, 

948 describing how the layer would transform the provided input. 

949 

950 Raises: 

951 TypeError: If input_signature contains a non-TensorSpec object. 

952 """ 

953 

954 def check_type_return_shape(s): 

955 if not isinstance(s, tf.TensorSpec): 

956 raise TypeError( 

957 "Only TensorSpec signature types are supported. " 

958 f"Received: {s}." 

959 ) 

960 return s.shape 

961 

962 input_shape = tf.nest.map_structure( 

963 check_type_return_shape, input_signature 

964 ) 

965 output_shape = self.compute_output_shape(input_shape) 

966 

967 try: 

968 dtype = self.output.dtype 

969 except AttributeError: 

970 dtype = self._compute_dtype 

971 

972 if dtype is None: 

973 input_dtypes = [s.dtype for s in tf.nest.flatten(input_signature)] 

974 # Default behavior when self.dtype is None, is to use the first 

975 # input's dtype. 

976 dtype = input_dtypes[0] 

977 return tf.nest.map_structure( 

978 lambda s: tf.TensorSpec(dtype=dtype, shape=s), output_shape 

979 ) 

980 

981 @generic_utils.default 

982 def compute_mask(self, inputs, mask=None): 

983 """Computes an output mask tensor. 

984 

985 Args: 

986 inputs: Tensor or list of tensors. 

987 mask: Tensor or list of tensors. 

988 

989 Returns: 

990 None or a tensor (or list of tensors, 

991 one per output tensor of the layer). 

992 """ 

993 if not self._supports_masking: 

994 if any(m is not None for m in tf.nest.flatten(mask)): 

995 raise TypeError( 

996 "Layer " + self.name + " does not support masking, " 

997 "but was passed an input_mask: " + str(mask) 

998 ) 

999 # masking not explicitly supported: return None as mask. 

1000 return None 

1001 # if masking is explicitly supported, by default 

1002 # carry over the input mask 

1003 return mask 

1004 

1005 @traceback_utils.filter_traceback 

1006 def __call__(self, *args, **kwargs): 

1007 """Wraps `call`, applying pre- and post-processing steps. 

1008 

1009 Args: 

1010 *args: Positional arguments to be passed to `self.call`. 

1011 **kwargs: Keyword arguments to be passed to `self.call`. 

1012 

1013 Returns: 

1014 Output tensor(s). 

1015 

1016 Note: 

1017 - The following optional keyword arguments are reserved for specific 

1018 uses: 

1019 * `training`: Boolean scalar tensor of Python boolean indicating 

1020 whether the `call` is meant for training or inference. 

1021 * `mask`: Boolean input mask. 

1022 - If the layer's `call` method takes a `mask` argument (as some Keras 

1023 layers do), its default value will be set to the mask generated 

1024 for `inputs` by the previous layer (if `input` did come from 

1025 a layer that generated a corresponding mask, i.e. if it came from 

1026 a Keras layer with masking support. 

1027 - If the layer is not built, the method will call `build`. 

1028 

1029 Raises: 

1030 ValueError: if the layer's `call` method returns None (an invalid 

1031 value). 

1032 RuntimeError: if `super().__init__()` was not called in the 

1033 constructor. 

1034 """ 

1035 if not hasattr(self, "_thread_local"): 

1036 raise RuntimeError( 

1037 "You must call `super().__init__()` in the layer constructor." 

1038 ) 

1039 

1040 # `inputs` (the first arg in the method spec) is special cased in 

1041 # layer call due to historical reasons. 

1042 # This special casing currently takes the form of: 

1043 # - 'inputs' must be explicitly passed. A layer cannot have zero 

1044 # arguments, and inputs cannot have been provided via the default 

1045 # value of a kwarg. 

1046 # - numpy/scalar values in `inputs` get converted to tensors 

1047 # - implicit masks / mask metadata are only collected from 'inputs` 

1048 # - Layers are built using shape info from 'inputs' only 

1049 # - input_spec compatibility is only checked against `inputs` 

1050 # - mixed precision casting (autocast) is only applied to `inputs`, 

1051 # not to any other argument. 

1052 inputs, args, kwargs = self._call_spec.split_out_first_arg(args, kwargs) 

1053 input_list = tf.nest.flatten(inputs) 

1054 

1055 # Functional Model construction mode is invoked when `Layer`s are called 

1056 # on symbolic `KerasTensor`s, i.e.: 

1057 # >> inputs = tf.keras.Input(10) 

1058 # >> outputs = MyLayer()(inputs) # Functional construction mode. 

1059 # >> model = tf.keras.Model(inputs, outputs) 

1060 if _in_functional_construction_mode( 

1061 self, inputs, args, kwargs, input_list 

1062 ): 

1063 return self._functional_construction_call( 

1064 inputs, args, kwargs, input_list 

1065 ) 

1066 

1067 # Maintains info about the `Layer.call` stack. 

1068 call_context = base_layer_utils.call_context() 

1069 

1070 # Accept NumPy and scalar inputs by converting to Tensors. 

1071 if any( 

1072 isinstance(x, (tf.Tensor, np.ndarray, float, int)) 

1073 for x in input_list 

1074 ): 

1075 inputs = tf.nest.map_structure( 

1076 _convert_numpy_or_python_types, inputs 

1077 ) 

1078 input_list = tf.nest.flatten(inputs) 

1079 

1080 # Handle `mask` propagation from previous layer to current layer. Masks 

1081 # can be propagated explicitly via the `mask` argument, or implicitly 

1082 # via setting the `_keras_mask` attribute on the inputs to a Layer. 

1083 # Masks passed explicitly take priority. 

1084 input_masks, mask_is_implicit = self._get_input_masks( 

1085 inputs, input_list, args, kwargs 

1086 ) 

1087 if self._expects_mask_arg and mask_is_implicit: 

1088 kwargs["mask"] = input_masks 

1089 

1090 # Training mode for `Layer.call` is set via (in order of priority): 

1091 # (1) The `training` argument passed to this `Layer.call`, if it is not 

1092 # None 

1093 # (2) The training mode of an outer `Layer.call`. 

1094 # (3) The default mode set by `tf.keras.backend.set_learning_phase` (if 

1095 # set) 

1096 # (4) Any non-None default value for `training` specified in the call 

1097 # signature 

1098 # (5) False (treating the layer as if it's in inference) 

1099 args, kwargs, training_mode = self._set_training_mode( 

1100 args, kwargs, call_context 

1101 ) 

1102 

1103 # Losses are cleared for all sublayers on the outermost `Layer.call`. 

1104 # Losses are not cleared on inner `Layer.call`s, because sublayers can 

1105 # be called multiple times. 

1106 if not call_context.in_call: 

1107 self._clear_losses() 

1108 

1109 eager = tf.executing_eagerly() 

1110 with call_context.enter( 

1111 layer=self, 

1112 inputs=inputs, 

1113 build_graph=not eager, 

1114 training=training_mode, 

1115 ): 

1116 

1117 input_spec.assert_input_compatibility( 

1118 self.input_spec, inputs, self.name 

1119 ) 

1120 

1121 if eager: 

1122 call_fn = self.call 

1123 name_scope = self._name 

1124 else: 

1125 name_scope = self._get_unnested_name_scope() 

1126 call_fn = self._autographed_call() 

1127 

1128 call_fn = traceback_utils.inject_argument_info_in_traceback( 

1129 call_fn, 

1130 object_name=( 

1131 f"layer '{self.name}' (type {self.__class__.__name__})" 

1132 ), 

1133 ) 

1134 with contextlib.ExitStack() as namescope_stack: 

1135 if _is_name_scope_on_model_declaration_enabled: 

1136 namescope_stack.enter_context( 

1137 _name_scope_unnester(self._name_scope_on_declaration) 

1138 ) 

1139 namescope_stack.enter_context(tf.name_scope(name_scope)) 

1140 

1141 if not self.built: 

1142 self._maybe_build(inputs) 

1143 

1144 if self._autocast: 

1145 inputs = self._maybe_cast_inputs(inputs, input_list) 

1146 

1147 with autocast_variable.enable_auto_cast_variables( 

1148 self._compute_dtype_object 

1149 ): 

1150 outputs = call_fn(inputs, *args, **kwargs) 

1151 

1152 if self._activity_regularizer: 

1153 self._handle_activity_regularization(inputs, outputs) 

1154 if self._supports_masking: 

1155 self._set_mask_metadata( 

1156 inputs, outputs, input_masks, not eager 

1157 ) 

1158 if self._saved_model_inputs_spec is None: 

1159 self._set_save_spec(inputs, args, kwargs) 

1160 

1161 return outputs 

1162 

1163 def _get_unnested_name_scope(self): 

1164 if _is_name_scope_on_model_declaration_enabled: 

1165 with _name_scope_unnester( 

1166 self._name_scope_on_declaration 

1167 ) as relative_name_scope_on_declaration: 

1168 # To avoid `tf.name_scope` autoincrement, use absolute path. 

1169 relative_name_scope = filter( 

1170 None, 

1171 [ 

1172 tf.get_current_name_scope(), 

1173 relative_name_scope_on_declaration, 

1174 ], 

1175 ) 

1176 current_name_scope = "/".join(relative_name_scope) + "/" 

1177 if current_name_scope == "/": 

1178 current_name_scope = self._name_scope_on_declaration 

1179 with tf.name_scope(current_name_scope): 

1180 name_scope = self._name_scope() # Avoid autoincrementing. 

1181 else: 

1182 name_scope = self._name_scope() 

1183 

1184 return name_scope 

1185 

1186 @property 

1187 def dtype(self): 

1188 """The dtype of the layer weights. 

1189 

1190 This is equivalent to `Layer.dtype_policy.variable_dtype`. Unless 

1191 mixed precision is used, this is the same as `Layer.compute_dtype`, the 

1192 dtype of the layer's computations. 

1193 """ 

1194 return self._dtype_policy.variable_dtype 

1195 

1196 @property 

1197 def name(self): 

1198 """Name of the layer (string), set in the constructor.""" 

1199 return self._name 

1200 

1201 @property 

1202 def supports_masking(self): 

1203 """Whether this layer supports computing a mask using `compute_mask`.""" 

1204 return self._supports_masking 

1205 

1206 @supports_masking.setter 

1207 def supports_masking(self, value): 

1208 self._supports_masking = value 

1209 

1210 @property 

1211 def dynamic(self): 

1212 """Whether the layer is dynamic (eager-only); set in the constructor.""" 

1213 return any(layer._dynamic for layer in self._flatten_layers()) 

1214 

1215 @property 

1216 @doc_controls.do_not_doc_inheritable 

1217 def stateful(self): 

1218 return any(layer._stateful for layer in self._flatten_layers()) 

1219 

1220 @stateful.setter 

1221 def stateful(self, value): 

1222 self._stateful = value 

1223 

1224 @property 

1225 def trainable(self): 

1226 return self._trainable 

1227 

1228 @trainable.setter 

1229 def trainable(self, value): 

1230 """Sets trainable attribute for the layer and its sublayers. 

1231 

1232 When this value is changed during training (e.g. with a 

1233 `tf.keras.callbacks.Callback`) you need to call the parent 

1234 `tf.keras.Model.make_train_function` with `force=True` in order to 

1235 recompile the training graph. 

1236 

1237 Args: 

1238 value: Boolean with the desired state for the layer's trainable 

1239 attribute. 

1240 """ 

1241 for layer in self._flatten_layers(): 

1242 layer._trainable = value 

1243 

1244 @property 

1245 def activity_regularizer(self): 

1246 """Optional regularizer function for the output of this layer.""" 

1247 return self._activity_regularizer 

1248 

1249 @activity_regularizer.setter 

1250 def activity_regularizer(self, regularizer): 

1251 """Optional regularizer function for the output of this layer.""" 

1252 self._activity_regularizer = regularizer 

1253 

1254 @property 

1255 def input_spec(self): 

1256 """`InputSpec` instance(s) describing the input format for this layer. 

1257 

1258 When you create a layer subclass, you can set `self.input_spec` to 

1259 enable the layer to run input compatibility checks when it is called. 

1260 Consider a `Conv2D` layer: it can only be called on a single input 

1261 tensor of rank 4. As such, you can set, in `__init__()`: 

1262 

1263 ```python 

1264 self.input_spec = tf.keras.layers.InputSpec(ndim=4) 

1265 ``` 

1266 

1267 Now, if you try to call the layer on an input that isn't rank 4 

1268 (for instance, an input of shape `(2,)`, it will raise a 

1269 nicely-formatted error: 

1270 

1271 ``` 

1272 ValueError: Input 0 of layer conv2d is incompatible with the layer: 

1273 expected ndim=4, found ndim=1. Full shape received: [2] 

1274 ``` 

1275 

1276 Input checks that can be specified via `input_spec` include: 

1277 - Structure (e.g. a single input, a list of 2 inputs, etc) 

1278 - Shape 

1279 - Rank (ndim) 

1280 - Dtype 

1281 

1282 For more information, see `tf.keras.layers.InputSpec`. 

1283 

1284 Returns: 

1285 A `tf.keras.layers.InputSpec` instance, or nested structure thereof. 

1286 """ 

1287 return self._input_spec 

1288 

1289 @input_spec.setter 

1290 # Must be decorated to prevent tracking, since the input_spec can be nested 

1291 # InputSpec objects. 

1292 @tf.__internal__.tracking.no_automatic_dependency_tracking 

1293 def input_spec(self, value): 

1294 for v in tf.nest.flatten(value): 

1295 if v is not None and not isinstance(v, input_spec.InputSpec): 

1296 raise TypeError( 

1297 "Layer input_spec must be an instance of InputSpec. " 

1298 "Got: {}".format(v) 

1299 ) 

1300 self._input_spec = value 

1301 

1302 @property 

1303 def trainable_weights(self): 

1304 """List of all trainable weights tracked by this layer. 

1305 

1306 Trainable weights are updated via gradient descent during training. 

1307 

1308 Returns: 

1309 A list of trainable variables. 

1310 """ 

1311 self._update_trackables() 

1312 if self.trainable: 

1313 children_weights = self._gather_children_attribute( 

1314 "trainable_variables" 

1315 ) 

1316 return self._dedup_weights( 

1317 self._trainable_weights + children_weights 

1318 ) 

1319 else: 

1320 return [] 

1321 

1322 @property 

1323 def non_trainable_weights(self): 

1324 """List of all non-trainable weights tracked by this layer. 

1325 

1326 Non-trainable weights are *not* updated during training. They are 

1327 expected to be updated manually in `call()`. 

1328 

1329 Returns: 

1330 A list of non-trainable variables. 

1331 """ 

1332 self._update_trackables() 

1333 if self.trainable: 

1334 children_weights = self._gather_children_attribute( 

1335 "non_trainable_variables" 

1336 ) 

1337 non_trainable_weights = ( 

1338 self._non_trainable_weights + children_weights 

1339 ) 

1340 else: 

1341 children_weights = self._gather_children_attribute("variables") 

1342 non_trainable_weights = ( 

1343 self._trainable_weights 

1344 + self._non_trainable_weights 

1345 + children_weights 

1346 ) 

1347 return self._dedup_weights(non_trainable_weights) 

1348 

1349 @property 

1350 def weights(self): 

1351 """Returns the list of all layer variables/weights. 

1352 

1353 Returns: 

1354 A list of variables. 

1355 """ 

1356 return self.trainable_weights + self.non_trainable_weights 

1357 

1358 @property 

1359 @doc_controls.do_not_generate_docs 

1360 def updates(self): 

1361 warnings.warn( 

1362 "`layer.updates` will be removed in a future version. " 

1363 "This property should not be used in TensorFlow 2.0, " 

1364 "as `updates` are applied automatically.", 

1365 stacklevel=2, 

1366 ) 

1367 return [] 

1368 

1369 @property 

1370 def losses(self): 

1371 """List of losses added using the `add_loss()` API. 

1372 

1373 Variable regularization tensors are created when this property is 

1374 accessed, so it is eager safe: accessing `losses` under a 

1375 `tf.GradientTape` will propagate gradients back to the corresponding 

1376 variables. 

1377 

1378 Examples: 

1379 

1380 >>> class MyLayer(tf.keras.layers.Layer): 

1381 ... def call(self, inputs): 

1382 ... self.add_loss(tf.abs(tf.reduce_mean(inputs))) 

1383 ... return inputs 

1384 >>> l = MyLayer() 

1385 >>> l(np.ones((10, 1))) 

1386 >>> l.losses 

1387 [1.0] 

1388 

1389 >>> inputs = tf.keras.Input(shape=(10,)) 

1390 >>> x = tf.keras.layers.Dense(10)(inputs) 

1391 >>> outputs = tf.keras.layers.Dense(1)(x) 

1392 >>> model = tf.keras.Model(inputs, outputs) 

1393 >>> # Activity regularization. 

1394 >>> len(model.losses) 

1395 0 

1396 >>> model.add_loss(tf.abs(tf.reduce_mean(x))) 

1397 >>> len(model.losses) 

1398 1 

1399 

1400 >>> inputs = tf.keras.Input(shape=(10,)) 

1401 >>> d = tf.keras.layers.Dense(10, kernel_initializer='ones') 

1402 >>> x = d(inputs) 

1403 >>> outputs = tf.keras.layers.Dense(1)(x) 

1404 >>> model = tf.keras.Model(inputs, outputs) 

1405 >>> # Weight regularization. 

1406 >>> model.add_loss(lambda: tf.reduce_mean(d.kernel)) 

1407 >>> model.losses 

1408 [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>] 

1409 

1410 Returns: 

1411 A list of tensors. 

1412 """ 

1413 collected_losses = [] 

1414 for layer in self._flatten_layers(): 

1415 # If any eager losses are present, we assume the model to be part of 

1416 # an eager training loop (either a custom one or the one used when 

1417 # `run_eagerly=True`) and so we always return just the eager losses. 

1418 if layer._eager_losses: 

1419 # Filter placeholder losses that may have been added by revived 

1420 # layers. (see base_layer_utils for details). 

1421 if ( 

1422 layer._eager_losses[0] 

1423 is not base_layer_utils.REVIVED_LOSS_PLACEHOLDER 

1424 ): 

1425 collected_losses.extend(layer._eager_losses) 

1426 else: 

1427 collected_losses.extend(layer._losses) 

1428 for regularizer in layer._callable_losses: 

1429 loss_tensor = regularizer() 

1430 if loss_tensor is not None: 

1431 collected_losses.append(loss_tensor) 

1432 return collected_losses 

1433 

1434 def add_loss(self, losses, **kwargs): 

1435 """Add loss tensor(s), potentially dependent on layer inputs. 

1436 

1437 Some losses (for instance, activity regularization losses) may be 

1438 dependent on the inputs passed when calling a layer. Hence, when reusing 

1439 the same layer on different inputs `a` and `b`, some entries in 

1440 `layer.losses` may be dependent on `a` and some on `b`. This method 

1441 automatically keeps track of dependencies. 

1442 

1443 This method can be used inside a subclassed layer or model's `call` 

1444 function, in which case `losses` should be a Tensor or list of Tensors. 

1445 

1446 Example: 

1447 

1448 ```python 

1449 class MyLayer(tf.keras.layers.Layer): 

1450 def call(self, inputs): 

1451 self.add_loss(tf.abs(tf.reduce_mean(inputs))) 

1452 return inputs 

1453 ``` 

1454 

1455 The same code works in distributed training: the input to `add_loss()` 

1456 is treated like a regularization loss and averaged across replicas 

1457 by the training loop (both built-in `Model.fit()` and compliant custom 

1458 training loops). 

1459 

1460 The `add_loss` method can also be called directly on a Functional Model 

1461 during construction. In this case, any loss Tensors passed to this Model 

1462 must be symbolic and be able to be traced back to the model's `Input`s. 

1463 These losses become part of the model's topology and are tracked in 

1464 `get_config`. 

1465 

1466 Example: 

1467 

1468 ```python 

1469 inputs = tf.keras.Input(shape=(10,)) 

1470 x = tf.keras.layers.Dense(10)(inputs) 

1471 outputs = tf.keras.layers.Dense(1)(x) 

1472 model = tf.keras.Model(inputs, outputs) 

1473 # Activity regularization. 

1474 model.add_loss(tf.abs(tf.reduce_mean(x))) 

1475 ``` 

1476 

1477 If this is not the case for your loss (if, for example, your loss 

1478 references a `Variable` of one of the model's layers), you can wrap your 

1479 loss in a zero-argument lambda. These losses are not tracked as part of 

1480 the model's topology since they can't be serialized. 

1481 

1482 Example: 

1483 

1484 ```python 

1485 inputs = tf.keras.Input(shape=(10,)) 

1486 d = tf.keras.layers.Dense(10) 

1487 x = d(inputs) 

1488 outputs = tf.keras.layers.Dense(1)(x) 

1489 model = tf.keras.Model(inputs, outputs) 

1490 # Weight regularization. 

1491 model.add_loss(lambda: tf.reduce_mean(d.kernel)) 

1492 ``` 

1493 

1494 Args: 

1495 losses: Loss tensor, or list/tuple of tensors. Rather than tensors, 

1496 losses may also be zero-argument callables which create a loss 

1497 tensor. 

1498 **kwargs: Used for backwards compatibility only. 

1499 """ 

1500 kwargs.pop("inputs", None) 

1501 if kwargs: 

1502 raise TypeError(f"Unknown keyword arguments: {kwargs.keys()}") 

1503 

1504 def _tag_callable(loss): 

1505 """Tags callable loss tensor as `_unconditional_loss`.""" 

1506 if callable(loss): 

1507 # We run the loss without autocasting, as regularizers are often 

1508 # numerically unstable in float16. 

1509 with autocast_variable.enable_auto_cast_variables(None): 

1510 loss = loss() 

1511 if loss is None: 

1512 # Will be filtered out when computing the .losses property 

1513 return None 

1514 if not tf.is_tensor(loss): 

1515 loss = tf.convert_to_tensor(loss, dtype=backend.floatx()) 

1516 loss._unconditional_loss = True 

1517 return loss 

1518 

1519 losses = tf.nest.flatten(losses) 

1520 

1521 callable_losses = [] 

1522 eager_losses = [] 

1523 symbolic_losses = [] 

1524 for loss in losses: 

1525 if callable(loss): 

1526 callable_losses.append(functools.partial(_tag_callable, loss)) 

1527 continue 

1528 if loss is None: 

1529 continue 

1530 if not tf.is_tensor(loss) and not isinstance( 

1531 loss, keras_tensor.KerasTensor 

1532 ): 

1533 loss = tf.convert_to_tensor(loss, dtype=backend.floatx()) 

1534 # TF Functions should take the eager path. 

1535 if ( 

1536 tf_utils.is_symbolic_tensor(loss) 

1537 or isinstance(loss, keras_tensor.KerasTensor) 

1538 ) and not base_layer_utils.is_in_tf_function(): 

1539 symbolic_losses.append(loss) 

1540 elif tf.is_tensor(loss): 

1541 eager_losses.append(loss) 

1542 

1543 self._callable_losses.extend(callable_losses) 

1544 

1545 in_call_context = base_layer_utils.call_context().in_call 

1546 if eager_losses and not in_call_context: 

1547 raise ValueError( 

1548 "Expected a symbolic Tensors or a callable for the loss value. " 

1549 "Please wrap your loss computation in a zero argument `lambda`." 

1550 ) 

1551 

1552 self._eager_losses.extend(eager_losses) 

1553 

1554 for symbolic_loss in symbolic_losses: 

1555 if getattr(self, "_is_graph_network", False): 

1556 self._graph_network_add_loss(symbolic_loss) 

1557 else: 

1558 # Possible a loss was added in a Layer's `build`. 

1559 self._losses.append(symbolic_loss) 

1560 

1561 @property 

1562 def metrics(self): 

1563 """List of metrics added using the `add_metric()` API. 

1564 

1565 Example: 

1566 

1567 >>> input = tf.keras.layers.Input(shape=(3,)) 

1568 >>> d = tf.keras.layers.Dense(2) 

1569 >>> output = d(input) 

1570 >>> d.add_metric(tf.reduce_max(output), name='max') 

1571 >>> d.add_metric(tf.reduce_min(output), name='min') 

1572 >>> [m.name for m in d.metrics] 

1573 ['max', 'min'] 

1574 

1575 Returns: 

1576 A list of `Metric` objects. 

1577 """ 

1578 collected_metrics = [] 

1579 for layer in self._flatten_layers(): 

1580 if not hasattr(layer, "_metrics_lock"): 

1581 continue 

1582 with layer._metrics_lock: 

1583 collected_metrics.extend(layer._metrics) 

1584 return collected_metrics 

1585 

1586 def add_metric(self, value, name=None, **kwargs): 

1587 """Adds metric tensor to the layer. 

1588 

1589 This method can be used inside the `call()` method of a subclassed layer 

1590 or model. 

1591 

1592 ```python 

1593 class MyMetricLayer(tf.keras.layers.Layer): 

1594 def __init__(self): 

1595 super(MyMetricLayer, self).__init__(name='my_metric_layer') 

1596 self.mean = tf.keras.metrics.Mean(name='metric_1') 

1597 

1598 def call(self, inputs): 

1599 self.add_metric(self.mean(inputs)) 

1600 self.add_metric(tf.reduce_sum(inputs), name='metric_2') 

1601 return inputs 

1602 ``` 

1603 

1604 This method can also be called directly on a Functional Model during 

1605 construction. In this case, any tensor passed to this Model must 

1606 be symbolic and be able to be traced back to the model's `Input`s. These 

1607 metrics become part of the model's topology and are tracked when you 

1608 save the model via `save()`. 

1609 

1610 ```python 

1611 inputs = tf.keras.Input(shape=(10,)) 

1612 x = tf.keras.layers.Dense(10)(inputs) 

1613 outputs = tf.keras.layers.Dense(1)(x) 

1614 model = tf.keras.Model(inputs, outputs) 

1615 model.add_metric(math_ops.reduce_sum(x), name='metric_1') 

1616 ``` 

1617 

1618 Note: Calling `add_metric()` with the result of a metric object on a 

1619 Functional Model, as shown in the example below, is not supported. This 

1620 is because we cannot trace the metric result tensor back to the model's 

1621 inputs. 

1622 

1623 ```python 

1624 inputs = tf.keras.Input(shape=(10,)) 

1625 x = tf.keras.layers.Dense(10)(inputs) 

1626 outputs = tf.keras.layers.Dense(1)(x) 

1627 model = tf.keras.Model(inputs, outputs) 

1628 model.add_metric(tf.keras.metrics.Mean()(x), name='metric_1') 

1629 ``` 

1630 

1631 Args: 

1632 value: Metric tensor. 

1633 name: String metric name. 

1634 **kwargs: Additional keyword arguments for backward compatibility. 

1635 Accepted values: 

1636 `aggregation` - When the `value` tensor provided is not the result 

1637 of calling a `keras.Metric` instance, it will be aggregated by 

1638 default using a `keras.Metric.Mean`. 

1639 """ 

1640 kwargs_keys = list(kwargs.keys()) 

1641 if len(kwargs_keys) > 1 or ( 

1642 len(kwargs_keys) == 1 and kwargs_keys[0] != "aggregation" 

1643 ): 

1644 raise TypeError( 

1645 f"Unknown keyword arguments: {kwargs.keys()}. " 

1646 "Expected `aggregation`." 

1647 ) 

1648 

1649 from_metric_obj = hasattr(value, "_metric_obj") 

1650 is_symbolic = isinstance(value, keras_tensor.KerasTensor) 

1651 in_call_context = base_layer_utils.call_context().in_call 

1652 

1653 if name is None and not from_metric_obj: 

1654 # Eg. `self.add_metric(math_ops.reduce_sum(x))` In eager mode, we 

1655 # use metric name to lookup a metric. Without a name, a new Mean 

1656 # metric wrapper will be created on every model/layer call. So, we 

1657 # raise an error when no name is provided. We will do the same for 

1658 # symbolic mode for consistency although a name will be generated if 

1659 # no name is provided. 

1660 

1661 # We will not raise this error in the foll use case for the sake of 

1662 # consistency as name in provided in the metric constructor. 

1663 # mean = metrics.Mean(name='my_metric') 

1664 # model.add_metric(mean(outputs)) 

1665 raise ValueError( 

1666 "Please provide a name for your metric like " 

1667 "`self.add_metric(tf.reduce_sum(inputs), " 

1668 "name='mean_activation')`" 

1669 ) 

1670 elif from_metric_obj: 

1671 name = value._metric_obj.name 

1672 

1673 if not in_call_context and not is_symbolic: 

1674 raise ValueError( 

1675 "Expected a symbolic Tensor for the metric value, received: " 

1676 + str(value) 

1677 ) 

1678 

1679 # If a metric was added in a Layer's `call` or `build`. 

1680 if in_call_context or not getattr(self, "_is_graph_network", False): 

1681 # TF Function path should take the eager path. 

1682 

1683 # If the given metric is available in `metrics` list we just update 

1684 # state on it, otherwise we create a new metric instance and 

1685 # add it to the `metrics` list. 

1686 metric_obj = getattr(value, "_metric_obj", None) 

1687 # Tensors that come from a Metric object already updated the Metric 

1688 # state. 

1689 should_update_state = not metric_obj 

1690 name = metric_obj.name if metric_obj else name 

1691 

1692 with self._metrics_lock: 

1693 match = self._get_existing_metric(name) 

1694 if match: 

1695 metric_obj = match 

1696 elif metric_obj: 

1697 self._metrics.append(metric_obj) 

1698 else: 

1699 # Build the metric object with the value's dtype if it 

1700 # defines one 

1701 metric_obj = metrics_mod.Mean( 

1702 name=name, dtype=getattr(value, "dtype", None) 

1703 ) 

1704 self._metrics.append(metric_obj) 

1705 

1706 if should_update_state: 

1707 metric_obj(value) 

1708 else: 

1709 if from_metric_obj: 

1710 raise ValueError( 

1711 "Using the result of calling a `Metric` object " 

1712 "when calling `add_metric` on a Functional " 

1713 "Model is not supported. Please pass the " 

1714 "Tensor to monitor directly." 

1715 ) 

1716 

1717 # Insert layers into the Keras Graph Network. 

1718 aggregation = None if from_metric_obj else "mean" 

1719 self._graph_network_add_metric(value, aggregation, name) 

1720 

1721 @doc_controls.do_not_doc_inheritable 

1722 def add_update(self, updates): 

1723 """Add update op(s), potentially dependent on layer inputs. 

1724 

1725 Weight updates (for instance, the updates of the moving mean and 

1726 variance in a BatchNormalization layer) may be dependent on the inputs 

1727 passed when calling a layer. Hence, when reusing the same layer on 

1728 different inputs `a` and `b`, some entries in `layer.updates` may be 

1729 dependent on `a` and some on `b`. This method automatically keeps track 

1730 of dependencies. 

1731 

1732 This call is ignored when eager execution is enabled (in that case, 

1733 variable updates are run on the fly and thus do not need to be tracked 

1734 for later execution). 

1735 

1736 Args: 

1737 updates: Update op, or list/tuple of update ops, or zero-arg callable 

1738 that returns an update op. A zero-arg callable should be passed in 

1739 order to disable running the updates by setting `trainable=False` 

1740 on this Layer, when executing in Eager mode. 

1741 """ 

1742 call_context = base_layer_utils.call_context() 

1743 # No need to run updates during Functional API construction. 

1744 if call_context.in_keras_graph: 

1745 return 

1746 

1747 # Callable updates are disabled by setting `trainable=False`. 

1748 if not call_context.frozen: 

1749 for update in tf.nest.flatten(updates): 

1750 if callable(update): 

1751 update() 

1752 

1753 def set_weights(self, weights): 

1754 """Sets the weights of the layer, from NumPy arrays. 

1755 

1756 The weights of a layer represent the state of the layer. This function 

1757 sets the weight values from numpy arrays. The weight values should be 

1758 passed in the order they are created by the layer. Note that the layer's 

1759 weights must be instantiated before calling this function, by calling 

1760 the layer. 

1761 

1762 For example, a `Dense` layer returns a list of two values: the kernel 

1763 matrix and the bias vector. These can be used to set the weights of 

1764 another `Dense` layer: 

1765 

1766 >>> layer_a = tf.keras.layers.Dense(1, 

1767 ... kernel_initializer=tf.constant_initializer(1.)) 

1768 >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]])) 

1769 >>> layer_a.get_weights() 

1770 [array([[1.], 

1771 [1.], 

1772 [1.]], dtype=float32), array([0.], dtype=float32)] 

1773 >>> layer_b = tf.keras.layers.Dense(1, 

1774 ... kernel_initializer=tf.constant_initializer(2.)) 

1775 >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]])) 

1776 >>> layer_b.get_weights() 

1777 [array([[2.], 

1778 [2.], 

1779 [2.]], dtype=float32), array([0.], dtype=float32)] 

1780 >>> layer_b.set_weights(layer_a.get_weights()) 

1781 >>> layer_b.get_weights() 

1782 [array([[1.], 

1783 [1.], 

1784 [1.]], dtype=float32), array([0.], dtype=float32)] 

1785 

1786 Args: 

1787 weights: a list of NumPy arrays. The number 

1788 of arrays and their shape must match 

1789 number of the dimensions of the weights 

1790 of the layer (i.e. it should match the 

1791 output of `get_weights`). 

1792 

1793 Raises: 

1794 ValueError: If the provided weights list does not match the 

1795 layer's specifications. 

1796 """ 

1797 params = self.weights 

1798 

1799 expected_num_weights = 0 

1800 for param in params: 

1801 if isinstance(param, base_layer_utils.TrackableWeightHandler): 

1802 expected_num_weights += param.num_tensors 

1803 else: 

1804 expected_num_weights += 1 

1805 

1806 if expected_num_weights != len(weights): 

1807 raise ValueError( 

1808 'You called `set_weights(weights)` on layer "%s" ' 

1809 "with a weight list of length %s, but the layer was " 

1810 "expecting %s weights. Provided weights: %s..." 

1811 % ( 

1812 self.name, 

1813 len(weights), 

1814 expected_num_weights, 

1815 str(weights)[:50], 

1816 ) 

1817 ) 

1818 

1819 weight_index = 0 

1820 weight_value_tuples = [] 

1821 for param in params: 

1822 if isinstance(param, base_layer_utils.TrackableWeightHandler): 

1823 num_tensors = param.num_tensors 

1824 tensors = weights[weight_index : weight_index + num_tensors] 

1825 param.set_weights(tensors) 

1826 weight_index += num_tensors 

1827 else: 

1828 weight = weights[weight_index] 

1829 weight_shape = weight.shape if hasattr(weight, "shape") else () 

1830 ref_shape = param.shape 

1831 if not ref_shape.is_compatible_with(weight_shape): 

1832 raise ValueError( 

1833 f"Layer {self.name} weight shape {ref_shape} " 

1834 "is not compatible with provided weight " 

1835 f"shape {weight_shape}." 

1836 ) 

1837 weight_value_tuples.append((param, weight)) 

1838 weight_index += 1 

1839 

1840 backend.batch_set_value(weight_value_tuples) 

1841 

1842 # Perform any layer defined finalization of the layer state. 

1843 for layer in self._flatten_layers(): 

1844 layer.finalize_state() 

1845 

1846 def get_weights(self): 

1847 """Returns the current weights of the layer, as NumPy arrays. 

1848 

1849 The weights of a layer represent the state of the layer. This function 

1850 returns both trainable and non-trainable weight values associated with 

1851 this layer as a list of NumPy arrays, which can in turn be used to load 

1852 state into similarly parameterized layers. 

1853 

1854 For example, a `Dense` layer returns a list of two values: the kernel 

1855 matrix and the bias vector. These can be used to set the weights of 

1856 another `Dense` layer: 

1857 

1858 >>> layer_a = tf.keras.layers.Dense(1, 

1859 ... kernel_initializer=tf.constant_initializer(1.)) 

1860 >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]])) 

1861 >>> layer_a.get_weights() 

1862 [array([[1.], 

1863 [1.], 

1864 [1.]], dtype=float32), array([0.], dtype=float32)] 

1865 >>> layer_b = tf.keras.layers.Dense(1, 

1866 ... kernel_initializer=tf.constant_initializer(2.)) 

1867 >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]])) 

1868 >>> layer_b.get_weights() 

1869 [array([[2.], 

1870 [2.], 

1871 [2.]], dtype=float32), array([0.], dtype=float32)] 

1872 >>> layer_b.set_weights(layer_a.get_weights()) 

1873 >>> layer_b.get_weights() 

1874 [array([[1.], 

1875 [1.], 

1876 [1.]], dtype=float32), array([0.], dtype=float32)] 

1877 

1878 Returns: 

1879 Weights values as a list of NumPy arrays. 

1880 """ 

1881 weights = self.weights 

1882 output_weights = [] 

1883 for weight in weights: 

1884 if isinstance(weight, base_layer_utils.TrackableWeightHandler): 

1885 output_weights.extend(weight.get_tensors()) 

1886 else: 

1887 output_weights.append(weight) 

1888 return backend.batch_get_value(output_weights) 

1889 

1890 @doc_controls.do_not_generate_docs 

1891 def finalize_state(self): 

1892 """Finalizes the layers state after updating layer weights. 

1893 

1894 This function can be subclassed in a layer and will be called after 

1895 updating a layer weights. It can be overridden to finalize any 

1896 additional layer state after a weight update. 

1897 

1898 This function will be called after weights of a layer have been restored 

1899 from a loaded model. 

1900 """ 

1901 pass 

1902 

1903 @doc_controls.do_not_doc_inheritable 

1904 def get_input_mask_at(self, node_index): 

1905 """Retrieves the input mask tensor(s) of a layer at a given node. 

1906 

1907 Args: 

1908 node_index: Integer, index of the node 

1909 from which to retrieve the attribute. 

1910 E.g. `node_index=0` will correspond to the 

1911 first time the layer was called. 

1912 

1913 Returns: 

1914 A mask tensor 

1915 (or list of tensors if the layer has multiple inputs). 

1916 """ 

1917 inputs = self.get_input_at(node_index) 

1918 if isinstance(inputs, list): 

1919 return [getattr(x, "_keras_mask", None) for x in inputs] 

1920 else: 

1921 return getattr(inputs, "_keras_mask", None) 

1922 

1923 @doc_controls.do_not_doc_inheritable 

1924 def get_output_mask_at(self, node_index): 

1925 """Retrieves the output mask tensor(s) of a layer at a given node. 

1926 

1927 Args: 

1928 node_index: Integer, index of the node 

1929 from which to retrieve the attribute. 

1930 E.g. `node_index=0` will correspond to the 

1931 first time the layer was called. 

1932 

1933 Returns: 

1934 A mask tensor 

1935 (or list of tensors if the layer has multiple outputs). 

1936 """ 

1937 output = self.get_output_at(node_index) 

1938 if isinstance(output, list): 

1939 return [getattr(x, "_keras_mask", None) for x in output] 

1940 else: 

1941 return getattr(output, "_keras_mask", None) 

1942 

1943 @property 

1944 @doc_controls.do_not_doc_inheritable 

1945 def input_mask(self): 

1946 """Retrieves the input mask tensor(s) of a layer. 

1947 

1948 Only applicable if the layer has exactly one inbound node, 

1949 i.e. if it is connected to one incoming layer. 

1950 

1951 Returns: 

1952 Input mask tensor (potentially None) or list of input 

1953 mask tensors. 

1954 

1955 Raises: 

1956 AttributeError: if the layer is connected to 

1957 more than one incoming layers. 

1958 """ 

1959 inputs = self.input 

1960 if isinstance(inputs, list): 

1961 return [getattr(x, "_keras_mask", None) for x in inputs] 

1962 else: 

1963 return getattr(inputs, "_keras_mask", None) 

1964 

1965 @property 

1966 @doc_controls.do_not_doc_inheritable 

1967 def output_mask(self): 

1968 """Retrieves the output mask tensor(s) of a layer. 

1969 

1970 Only applicable if the layer has exactly one inbound node, 

1971 i.e. if it is connected to one incoming layer. 

1972 

1973 Returns: 

1974 Output mask tensor (potentially None) or list of output 

1975 mask tensors. 

1976 

1977 Raises: 

1978 AttributeError: if the layer is connected to 

1979 more than one incoming layers. 

1980 """ 

1981 output = self.output 

1982 if isinstance(output, list): 

1983 return [getattr(x, "_keras_mask", None) for x in output] 

1984 else: 

1985 return getattr(output, "_keras_mask", None) 

1986 

1987 @doc_controls.do_not_doc_inheritable 

1988 def get_input_shape_at(self, node_index): 

1989 """Retrieves the input shape(s) of a layer at a given node. 

1990 

1991 Args: 

1992 node_index: Integer, index of the node 

1993 from which to retrieve the attribute. 

1994 E.g. `node_index=0` will correspond to the 

1995 first time the layer was called. 

1996 

1997 Returns: 

1998 A shape tuple 

1999 (or list of shape tuples if the layer has multiple inputs). 

2000 

2001 Raises: 

2002 RuntimeError: If called in Eager mode. 

2003 """ 

2004 return self._get_node_attribute_at_index( 

2005 node_index, "input_shapes", "input shape" 

2006 ) 

2007 

2008 @doc_controls.do_not_doc_inheritable 

2009 def get_output_shape_at(self, node_index): 

2010 """Retrieves the output shape(s) of a layer at a given node. 

2011 

2012 Args: 

2013 node_index: Integer, index of the node 

2014 from which to retrieve the attribute. 

2015 E.g. `node_index=0` will correspond to the 

2016 first time the layer was called. 

2017 

2018 Returns: 

2019 A shape tuple 

2020 (or list of shape tuples if the layer has multiple outputs). 

2021 

2022 Raises: 

2023 RuntimeError: If called in Eager mode. 

2024 """ 

2025 return self._get_node_attribute_at_index( 

2026 node_index, "output_shapes", "output shape" 

2027 ) 

2028 

2029 @doc_controls.do_not_doc_inheritable 

2030 def get_input_at(self, node_index): 

2031 """Retrieves the input tensor(s) of a layer at a given node. 

2032 

2033 Args: 

2034 node_index: Integer, index of the node 

2035 from which to retrieve the attribute. 

2036 E.g. `node_index=0` will correspond to the 

2037 first input node of the layer. 

2038 

2039 Returns: 

2040 A tensor (or list of tensors if the layer has multiple inputs). 

2041 

2042 Raises: 

2043 RuntimeError: If called in Eager mode. 

2044 """ 

2045 return self._get_node_attribute_at_index( 

2046 node_index, "input_tensors", "input" 

2047 ) 

2048 

2049 @doc_controls.do_not_doc_inheritable 

2050 def get_output_at(self, node_index): 

2051 """Retrieves the output tensor(s) of a layer at a given node. 

2052 

2053 Args: 

2054 node_index: Integer, index of the node 

2055 from which to retrieve the attribute. 

2056 E.g. `node_index=0` will correspond to the 

2057 first output node of the layer. 

2058 

2059 Returns: 

2060 A tensor (or list of tensors if the layer has multiple outputs). 

2061 

2062 Raises: 

2063 RuntimeError: If called in Eager mode. 

2064 """ 

2065 return self._get_node_attribute_at_index( 

2066 node_index, "output_tensors", "output" 

2067 ) 

2068 

2069 @property 

2070 def input(self): 

2071 """Retrieves the input tensor(s) of a layer. 

2072 

2073 Only applicable if the layer has exactly one input, 

2074 i.e. if it is connected to one incoming layer. 

2075 

2076 Returns: 

2077 Input tensor or list of input tensors. 

2078 

2079 Raises: 

2080 RuntimeError: If called in Eager mode. 

2081 AttributeError: If no inbound nodes are found. 

2082 """ 

2083 if not self._inbound_nodes: 

2084 raise AttributeError( 

2085 "Layer " + self.name + " is not connected, no input to return." 

2086 ) 

2087 return self._get_node_attribute_at_index(0, "input_tensors", "input") 

2088 

2089 @property 

2090 def output(self): 

2091 """Retrieves the output tensor(s) of a layer. 

2092 

2093 Only applicable if the layer has exactly one output, 

2094 i.e. if it is connected to one incoming layer. 

2095 

2096 Returns: 

2097 Output tensor or list of output tensors. 

2098 

2099 Raises: 

2100 AttributeError: if the layer is connected to more than one incoming 

2101 layers. 

2102 RuntimeError: if called in Eager mode. 

2103 """ 

2104 if not self._inbound_nodes: 

2105 raise AttributeError( 

2106 "Layer " + self.name + " has no inbound nodes." 

2107 ) 

2108 return self._get_node_attribute_at_index(0, "output_tensors", "output") 

2109 

2110 @property 

2111 @doc_controls.do_not_doc_inheritable 

2112 def input_shape(self): 

2113 """Retrieves the input shape(s) of a layer. 

2114 

2115 Only applicable if the layer has exactly one input, 

2116 i.e. if it is connected to one incoming layer, or if all inputs 

2117 have the same shape. 

2118 

2119 Returns: 

2120 Input shape, as an integer shape tuple 

2121 (or list of shape tuples, one tuple per input tensor). 

2122 

2123 Raises: 

2124 AttributeError: if the layer has no defined input_shape. 

2125 RuntimeError: if called in Eager mode. 

2126 """ 

2127 if not self._inbound_nodes: 

2128 raise AttributeError( 

2129 f'The layer "{self.name}" has never been called ' 

2130 "and thus has no defined input shape. Note that the " 

2131 "`input_shape` property is only available for " 

2132 "Functional and Sequential models." 

2133 ) 

2134 all_input_shapes = set( 

2135 [str(node.input_shapes) for node in self._inbound_nodes] 

2136 ) 

2137 if len(all_input_shapes) == 1: 

2138 return self._inbound_nodes[0].input_shapes 

2139 else: 

2140 raise AttributeError( 

2141 'The layer "' 

2142 + str(self.name) 

2143 + '" has multiple inbound nodes, ' 

2144 "with different input shapes. Hence " 

2145 'the notion of "input shape" is ' 

2146 "ill-defined for the layer. " 

2147 "Use `get_input_shape_at(node_index)` " 

2148 "instead." 

2149 ) 

2150 

2151 def count_params(self): 

2152 """Count the total number of scalars composing the weights. 

2153 

2154 Returns: 

2155 An integer count. 

2156 

2157 Raises: 

2158 ValueError: if the layer isn't yet built 

2159 (in which case its weights aren't yet defined). 

2160 """ 

2161 if not self.built: 

2162 if getattr(self, "_is_graph_network", False): 

2163 with tf_utils.maybe_init_scope(self): 

2164 self._maybe_build(self.inputs) 

2165 else: 

2166 raise ValueError( 

2167 "You tried to call `count_params` " 

2168 f"on layer {self.name}" 

2169 ", but the layer isn't built. " 

2170 "You can build it manually via: " 

2171 f"`{self.name}.build(batch_input_shape)`." 

2172 ) 

2173 return layer_utils.count_params(self.weights) 

2174 

2175 @property 

2176 @doc_controls.do_not_doc_inheritable 

2177 def output_shape(self): 

2178 """Retrieves the output shape(s) of a layer. 

2179 

2180 Only applicable if the layer has one output, 

2181 or if all outputs have the same shape. 

2182 

2183 Returns: 

2184 Output shape, as an integer shape tuple 

2185 (or list of shape tuples, one tuple per output tensor). 

2186 

2187 Raises: 

2188 AttributeError: if the layer has no defined output shape. 

2189 RuntimeError: if called in Eager mode. 

2190 """ 

2191 if not self._inbound_nodes: 

2192 raise AttributeError( 

2193 f'The layer "{self.name}" has never been called ' 

2194 "and thus has no defined output shape." 

2195 ) 

2196 all_output_shapes = set( 

2197 [str(node.output_shapes) for node in self._inbound_nodes] 

2198 ) 

2199 if len(all_output_shapes) == 1: 

2200 return self._inbound_nodes[0].output_shapes 

2201 else: 

2202 raise AttributeError( 

2203 'The layer "%s"' 

2204 " has multiple inbound nodes, " 

2205 "with different output shapes. Hence " 

2206 'the notion of "output shape" is ' 

2207 "ill-defined for the layer. " 

2208 "Use `get_output_shape_at(node_index)` " 

2209 "instead." % self.name 

2210 ) 

2211 

2212 @property 

2213 def dtype_policy(self): 

2214 """The dtype policy associated with this layer. 

2215 

2216 This is an instance of a `tf.keras.mixed_precision.Policy`. 

2217 """ 

2218 return self._dtype_policy 

2219 

2220 @property 

2221 def compute_dtype(self): 

2222 """The dtype of the layer's computations. 

2223 

2224 This is equivalent to `Layer.dtype_policy.compute_dtype`. Unless 

2225 mixed precision is used, this is the same as `Layer.dtype`, the dtype of 

2226 the weights. 

2227 

2228 Layers automatically cast their inputs to the compute dtype, which 

2229 causes computations and the output to be in the compute dtype as well. 

2230 This is done by the base Layer class in `Layer.__call__`, so you do not 

2231 have to insert these casts if implementing your own layer. 

2232 

2233 Layers often perform certain internal computations in higher precision 

2234 when `compute_dtype` is float16 or bfloat16 for numeric stability. The 

2235 output will still typically be float16 or bfloat16 in such cases. 

2236 

2237 Returns: 

2238 The layer's compute dtype. 

2239 """ 

2240 return self._dtype_policy.compute_dtype 

2241 

2242 @property 

2243 def variable_dtype(self): 

2244 """Alias of `Layer.dtype`, the dtype of the weights.""" 

2245 return self.dtype 

2246 

2247 @property 

2248 @doc_controls.do_not_doc_inheritable 

2249 def inbound_nodes(self): 

2250 """Return Functional API nodes upstream of this layer.""" 

2251 return self._inbound_nodes 

2252 

2253 @property 

2254 @doc_controls.do_not_doc_inheritable 

2255 def outbound_nodes(self): 

2256 """Return Functional API nodes downstream of this layer.""" 

2257 return self._outbound_nodes 

2258 

2259 ############################################################################ 

2260 # Methods & attributes below are public aliases of other methods. # 

2261 ############################################################################ 

2262 

2263 @property 

2264 @doc_controls.do_not_generate_docs 

2265 def variables(self): 

2266 """Returns the list of all layer variables/weights. 

2267 

2268 Alias of `self.weights`. 

2269 

2270 Note: This will not track the weights of nested `tf.Modules` that are 

2271 not themselves Keras layers. 

2272 

2273 Returns: 

2274 A list of variables. 

2275 """ 

2276 return self.weights 

2277 

2278 @property 

2279 @doc_controls.do_not_generate_docs 

2280 def trainable_variables(self): 

2281 return self.trainable_weights 

2282 

2283 @property 

2284 @doc_controls.do_not_generate_docs 

2285 def non_trainable_variables(self): 

2286 return self.non_trainable_weights 

2287 

2288 @doc_controls.do_not_doc_inheritable 

2289 def add_variable(self, *args, **kwargs): 

2290 """Deprecated, do NOT use! Alias for `add_weight`.""" 

2291 warnings.warn( 

2292 "`layer.add_variable` is deprecated and " 

2293 "will be removed in a future version. " 

2294 "Please use the `layer.add_weight()` method instead.", 

2295 stacklevel=2, 

2296 ) 

2297 return self.add_weight(*args, **kwargs) 

2298 

2299 def get_build_config(self): 

2300 """Returns a dictionary with the layer's input shape. 

2301 

2302 This method returns a config dict that can be used by 

2303 `build_from_config(config)` to create all states (e.g. Variables and 

2304 Lookup tables) needed by the layer. 

2305 

2306 By default, the config only contains the input shape that the layer 

2307 was built with. If you're writing a custom layer that creates state in 

2308 an unusual way, you should override this method to make sure this state 

2309 is already created when Keras attempts to load its value upon model 

2310 loading. 

2311 

2312 Returns: 

2313 A dict containing the input shape associated with the layer. 

2314 """ 

2315 if self._build_input_shape is not None: 

2316 

2317 def convert_tensorshapes(x): 

2318 if isinstance(x, tf.TensorShape) and x._dims: 

2319 return tuple(x.as_list()) 

2320 return x 

2321 

2322 return { 

2323 "input_shape": tf.nest.map_structure( 

2324 convert_tensorshapes, self._build_input_shape 

2325 ) 

2326 } 

2327 

2328 def build_from_config(self, config): 

2329 """Builds the layer's states with the supplied config dict. 

2330 

2331 By default, this method calls the `build(config["input_shape"])` method, 

2332 which creates weights based on the layer's input shape in the supplied 

2333 config. If your config contains other information needed to load the 

2334 layer's state, you should override this method. 

2335 

2336 Args: 

2337 config: Dict containing the input shape associated with this layer. 

2338 """ 

2339 input_shape = config["input_shape"] 

2340 if input_shape is not None: 

2341 self.build(input_shape) 

2342 

2343 ############################################################################ 

2344 # Methods & attributes below are all private and only used by the framework. 

2345 ############################################################################ 

2346 

2347 # See tf.Module for the usage of this property. 

2348 # The key for _obj_reference_counts_dict is a Trackable, which could be a 

2349 # variable or layer etc. tf.Module._flatten will fail to flatten the key 

2350 # since it is trying to convert Trackable to a string. This attribute can be 

2351 # ignored even after the fix of nest lib, since the trackable object should 

2352 # already been available as individual attributes. 

2353 # _obj_reference_counts_dict just contains a copy of them. 

2354 _TF_MODULE_IGNORED_PROPERTIES = frozenset( 

2355 itertools.chain( 

2356 ("_obj_reference_counts_dict",), 

2357 tf.Module._TF_MODULE_IGNORED_PROPERTIES, 

2358 ) 

2359 ) 

2360 

2361 # When loading from a SavedModel, Layers typically can be revived into a 

2362 # generic Layer wrapper. Sometimes, however, layers may implement methods 

2363 # that go beyond this wrapper, as in the case of PreprocessingLayers' 

2364 # `adapt` method. When this is the case, layer implementers can override 

2365 # must_restore_from_config to return True; layers with this property must 

2366 # be restored into their actual objects (and will fail if the object is 

2367 # not available to the restoration code). 

2368 _must_restore_from_config = False 

2369 

2370 def _get_cell_name(self): 

2371 canonical_name = get_canonical_name_for_symbol( 

2372 self.__class__, api_name="keras", add_prefix_to_v1_names=True 

2373 ) 

2374 if canonical_name is not None: 

2375 return f"tf.{canonical_name}" 

2376 return self.__class__.__module__ + "." + self.__class__.__name__ 

2377 

2378 def _instrument_layer_creation(self): 

2379 self._instrumented_keras_api = False 

2380 self._instrumented_keras_layer_class = False 

2381 self._instrumented_keras_model_class = False 

2382 if not getattr(self, "_disable_keras_instrumentation", False): 

2383 keras_api_gauge.get_cell("layer").set(True) 

2384 self._instrumented_keras_api = True 

2385 if getattr(self, "_is_model_for_instrumentation", False): 

2386 keras_models_gauge.get_cell(self._get_cell_name()).set(True) 

2387 self._instrumented_keras_model_class = True 

2388 else: 

2389 keras_layers_gauge.get_cell(self._get_cell_name()).set(True) 

2390 self._instrumented_keras_layer_class = True 

2391 else: 

2392 # This is a legacy layer that has disabled instrumentation 

2393 # as a native keras object. We still instrument this as 

2394 # legacy usage. 

2395 keras_api_gauge.get_cell("legacy_layer").set(True) 

2396 

2397 @doc_controls.for_subclass_implementers 

2398 def _add_trackable(self, trackable_object, trainable): 

2399 """Adds a Trackable object to this layer's state. 

2400 

2401 Args: 

2402 trackable_object: The tf.tracking.Trackable object to add. 

2403 trainable: Boolean, whether the variable should be part of the layer's 

2404 "trainable_variables" (e.g. variables, biases) or 

2405 "non_trainable_variables" (e.g. BatchNorm mean and variance). 

2406 

2407 Returns: 

2408 The TrackableWeightHandler used to track this object. 

2409 """ 

2410 if isinstance( 

2411 trackable_object, base_layer_utils.TrackableWeightHandler 

2412 ): 

2413 handler = trackable_object 

2414 else: 

2415 handler = base_layer_utils.TrackableWeightHandler(trackable_object) 

2416 if trainable: 

2417 self._trainable_weights.append(handler) 

2418 else: 

2419 self._non_trainable_weights.append(handler) 

2420 return handler 

2421 

2422 def _clear_losses(self): 

2423 """Used every step in eager to reset losses.""" 

2424 # Set to thread local directly to avoid Layer.__setattr__ overhead. 

2425 if not getattr( 

2426 self, "_self_tracked_trackables", None 

2427 ): # Fast path for single Layer. 

2428 self._thread_local._eager_losses = [] 

2429 else: 

2430 for layer in self._flatten_layers(): 

2431 layer._thread_local._eager_losses = [] 

2432 

2433 def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs): 

2434 if self.dynamic: 

2435 # We will use static shape inference to return symbolic tensors 

2436 # matching the specifications of the layer outputs. 

2437 # Since `self.dynamic` is True, we will never attempt to 

2438 # run the underlying TF graph (which is disconnected). 

2439 # TODO(fchollet): consider py_func as an alternative, which 

2440 # would enable us to run the underlying graph if needed. 

2441 input_signature = tf.nest.map_structure( 

2442 lambda x: tf.TensorSpec(shape=x.shape, dtype=x.dtype), inputs 

2443 ) 

2444 output_signature = self.compute_output_signature(input_signature) 

2445 return tf.nest.map_structure( 

2446 keras_tensor.KerasTensor, output_signature 

2447 ) 

2448 else: 

2449 return self._infer_output_signature( 

2450 inputs, args, kwargs, input_masks 

2451 ) 

2452 

2453 def _infer_output_signature(self, inputs, args, kwargs, input_masks): 

2454 """Call the layer on input KerasTensors, returns output KerasTensors.""" 

2455 

2456 keras_tensor_inputs = inputs 

2457 call_fn = self.call 

2458 # Wrapping `call` function in autograph to allow for dynamic control 

2459 # flow and control dependencies in call. We are limiting this to 

2460 # subclassed layers as autograph is strictly needed only for 

2461 # subclassed layers and models. 

2462 # tf_convert will respect the value of autograph setting in the 

2463 # enclosing tf.function, if any. 

2464 if base_layer_utils.is_subclassed( 

2465 self 

2466 ) and not base_layer_utils.from_saved_model(self): 

2467 call_fn = tf.__internal__.autograph.tf_convert( 

2468 self.call, tf.__internal__.autograph.control_status_ctx() 

2469 ) 

2470 

2471 call_fn = traceback_utils.inject_argument_info_in_traceback( 

2472 call_fn, 

2473 object_name=f'layer "{self.name}" (type {self.__class__.__name__})', 

2474 ) 

2475 

2476 # We enter a scratch graph and build placeholder inputs inside of it 

2477 # that match the input args. 

2478 # We then call the layer inside of the scratch graph to identify the 

2479 # output signatures, then we build KerasTensors corresponding to those 

2480 # outputs. 

2481 scratch_graph = tf.__internal__.FuncGraph( 

2482 str(self.name) + "_scratch_graph" 

2483 ) 

2484 with scratch_graph.as_default(): 

2485 inputs = tf.nest.map_structure( 

2486 keras_tensor.keras_tensor_to_placeholder, inputs 

2487 ) 

2488 args = tf.nest.map_structure( 

2489 keras_tensor.keras_tensor_to_placeholder, args 

2490 ) 

2491 kwargs = tf.nest.map_structure( 

2492 keras_tensor.keras_tensor_to_placeholder, kwargs 

2493 ) 

2494 input_masks = tf.nest.map_structure( 

2495 keras_tensor.keras_tensor_to_placeholder, input_masks 

2496 ) 

2497 

2498 with backend.name_scope(self._name_scope()): 

2499 with autocast_variable.enable_auto_cast_variables( 

2500 self._compute_dtype_object 

2501 ): 

2502 # Build layer if applicable (if the `build` method has been 

2503 # overridden). 

2504 # TODO(kaftan): do we maybe_build here, or have we already 

2505 # done it? 

2506 self._maybe_build(inputs) 

2507 inputs = self._maybe_cast_inputs(inputs) 

2508 outputs = call_fn(inputs, *args, **kwargs) 

2509 

2510 self._handle_activity_regularization(inputs, outputs) 

2511 self._set_mask_metadata( 

2512 inputs, outputs, input_masks, build_graph=False 

2513 ) 

2514 outputs = tf.nest.map_structure( 

2515 keras_tensor.keras_tensor_from_tensor, outputs 

2516 ) 

2517 

2518 self._set_save_spec(keras_tensor_inputs, args, kwargs) 

2519 if hasattr(self, "_set_inputs") and not self.inputs: 

2520 # TODO(kaftan): figure out if we need to do this at all 

2521 # Subclassed network: explicitly set metadata normally set by 

2522 # a call to self._set_inputs(). 

2523 self._set_inputs(inputs, outputs) 

2524 del scratch_graph 

2525 return outputs 

2526 

2527 def _functional_construction_call(self, inputs, args, kwargs, input_list): 

2528 call_context = base_layer_utils.call_context() 

2529 

2530 # Accept NumPy and scalar inputs by converting to Tensors. 

2531 if any( 

2532 isinstance(x, (tf.Tensor, np.ndarray, float, int)) 

2533 for x in input_list 

2534 ): 

2535 

2536 def _convert_non_tensor(x): 

2537 # Don't call `ops.convert_to_tensor` on all `inputs` because 

2538 # `SparseTensors` can't be converted to `Tensor`. 

2539 if isinstance(x, (tf.Tensor, np.ndarray, float, int)): 

2540 return tf.convert_to_tensor(x) 

2541 return x 

2542 

2543 inputs = tf.nest.map_structure(_convert_non_tensor, inputs) 

2544 input_list = tf.nest.flatten(inputs) 

2545 

2546 # Handle `mask` propagation from previous layer to current layer. Masks 

2547 # can be propagated explicitly via the `mask` argument, or implicitly 

2548 # via setting the `_keras_mask` attribute on the inputs to a Layer. 

2549 # Masks passed explicitly take priority. 

2550 mask_arg_passed_by_framework = False 

2551 input_masks, mask_is_implicit = self._get_input_masks( 

2552 inputs, input_list, args, kwargs 

2553 ) 

2554 if self._expects_mask_arg and mask_is_implicit: 

2555 kwargs["mask"] = input_masks 

2556 mask_arg_passed_by_framework = True 

2557 

2558 # If `training` argument is None or not explicitly passed, 

2559 # propagate `training` value from this layer's calling layer. 

2560 training_value = None 

2561 training_arg_passed_by_framework = False 

2562 # Priority 1: `training` was explicitly passed a non-None value. 

2563 if self._call_spec.arg_was_passed("training", args, kwargs): 

2564 training_value = self._call_spec.get_arg_value( 

2565 "training", args, kwargs 

2566 ) 

2567 if not self._expects_training_arg: 

2568 kwargs.pop("training") 

2569 

2570 if training_value is None: 

2571 # Priority 2: `training` was passed to a parent layer. 

2572 if call_context.training is not None: 

2573 training_value = call_context.training 

2574 # Priority 3: `learning_phase()` has been set. 

2575 elif backend.global_learning_phase_is_set(): 

2576 training_value = backend.learning_phase() 

2577 # Force the training_value to be bool type which matches to the 

2578 # contract for layer/model call args. 

2579 if tf.is_tensor(training_value): 

2580 training_value = tf.cast(training_value, tf.bool) 

2581 else: 

2582 training_value = bool(training_value) 

2583 # Priority 4: trace layer with the default training argument 

2584 # specified in the `call` signature (or in inference mode if the 

2585 # `call` signature specifies no non-None default). 

2586 else: 

2587 training_value = self._call_spec.default_training_arg 

2588 # In cases (2), (3), (4) the training argument is passed 

2589 # automatically by the framework, and will not be hard-coded into 

2590 # the model. 

2591 if self._expects_training_arg: 

2592 args, kwargs = self._call_spec.set_arg_value( 

2593 "training", training_value, args, kwargs 

2594 ) 

2595 training_arg_passed_by_framework = True 

2596 

2597 with call_context.enter( 

2598 layer=self, inputs=inputs, build_graph=True, training=training_value 

2599 ): 

2600 # Check input assumptions set after layer building, e.g. input 

2601 # shape. 

2602 try: 

2603 outputs = self._keras_tensor_symbolic_call( 

2604 inputs, input_masks, args, kwargs 

2605 ) 

2606 except TypeError as e: 

2607 if "DictWrapper" in str(e): 

2608 raise TypeError( 

2609 f"{self} could not be deserialized properly. Please" 

2610 " ensure that components that are Python object" 

2611 " instances (layers, models, etc.) returned by" 

2612 " `get_config()` are explicitly deserialized in the" 

2613 " model's `from_config()` method." 

2614 ) from e 

2615 else: 

2616 raise e 

2617 

2618 if outputs is None: 

2619 raise ValueError( 

2620 "A layer's `call` method should return a " 

2621 "Tensor or a list of Tensors, not None " 

2622 "(layer: " + self.name + ")." 

2623 ) 

2624 if training_arg_passed_by_framework: 

2625 args, kwargs = self._call_spec.set_arg_value( 

2626 "training", None, args, kwargs, pop_kwarg_if_none=True 

2627 ) 

2628 if mask_arg_passed_by_framework: 

2629 kwargs.pop("mask") 

2630 # Node connectivity does not special-case the first argument. 

2631 outputs = self._set_connectivity_metadata( 

2632 (inputs,) + args, kwargs, outputs 

2633 ) 

2634 return outputs 

2635 

2636 def _set_training_mode(self, args, kwargs, call_context): 

2637 training_mode = None 

2638 if self._expects_training_arg: 

2639 # (1) `training` was passed to this `Layer.call`. 

2640 if self._call_spec.arg_was_passed("training", args, kwargs): 

2641 training_mode = self._call_spec.get_arg_value( 

2642 "training", args, kwargs 

2643 ) 

2644 # If no `training` arg was passed, or `None` was explicitly passed, 

2645 # the framework will make a decision about the training mode is. 

2646 if training_mode is None: 

2647 call_ctx_training = call_context.training 

2648 # (2) `training` mode is inferred from an outer `Layer.call`. 

2649 if call_ctx_training is not None: 

2650 training_mode = call_ctx_training 

2651 # (3) User set `tf.keras.backend.set_learning_phase`. 

2652 elif backend.global_learning_phase_is_set(): 

2653 training_mode = backend.learning_phase() 

2654 # Ensure value is a `bool` or `tf.bool`. 

2655 if isinstance(training_mode, bool): 

2656 pass 

2657 elif tf.is_tensor(training_mode): 

2658 training_mode = tf.cast(training_mode, tf.bool) 

2659 else: 

2660 training_mode = bool(training_mode) 

2661 # (4) We default to using `call`'s default value for `training`, 

2662 # or treating the layer as if it is in inference if no non-None 

2663 # default is specified in the `call` signature. 

2664 else: 

2665 training_mode = self._call_spec.default_training_arg 

2666 

2667 # For case (2), (3), (4) `training` arg is passed by framework. 

2668 args, kwargs = self._call_spec.set_arg_value( 

2669 "training", training_mode, args, kwargs 

2670 ) 

2671 else: 

2672 if "training" in kwargs: 

2673 # `training` was passed to this `Layer` but is not needed for 

2674 # `Layer.call`. It will set the default mode for inner 

2675 # `Layer.call`s. 

2676 training_mode = kwargs.pop("training") 

2677 else: 

2678 # Grab the current `training` mode from any outer `Layer.call`. 

2679 training_mode = call_context.training 

2680 

2681 return args, kwargs, training_mode 

2682 

2683 def _autographed_call(self): 

2684 # Wrapping `call` function in autograph to allow for dynamic control 

2685 # flow and control dependencies in call. We are limiting this to 

2686 # subclassed layers as autograph is strictly needed only for 

2687 # subclassed layers and models. 

2688 # tf_convert will respect the value of autograph setting in the 

2689 # enclosing tf.function, if any. 

2690 if base_layer_utils.is_subclassed( 

2691 self 

2692 ) and not base_layer_utils.from_saved_model(self): 

2693 return tf.__internal__.autograph.tf_convert( 

2694 self.call, tf.__internal__.autograph.control_status_ctx() 

2695 ) 

2696 else: 

2697 return self.call 

2698 

2699 @property 

2700 def _inbound_nodes(self): 

2701 return self._inbound_nodes_value 

2702 

2703 @_inbound_nodes.setter 

2704 @tf.__internal__.tracking.no_automatic_dependency_tracking 

2705 def _inbound_nodes(self, value): 

2706 self._inbound_nodes_value = value 

2707 

2708 @property 

2709 def _outbound_nodes(self): 

2710 return self._outbound_nodes_value 

2711 

2712 @_outbound_nodes.setter 

2713 @tf.__internal__.tracking.no_automatic_dependency_tracking 

2714 def _outbound_nodes(self, value): 

2715 self._outbound_nodes_value = value 

2716 

2717 def _set_dtype_policy(self, dtype): 

2718 """Sets self._dtype_policy.""" 

2719 self._dtype_policy = policy.get_policy(dtype) 

2720 

2721 # Performance optimization: cache the compute dtype as a Dtype object or 

2722 # None, so that str to Dtype conversion doesn't happen in 

2723 # Layer.__call__. 

2724 # TODO(b/157486353): Investigate returning DTypes in Policy. 

2725 if self._dtype_policy.compute_dtype: 

2726 self._compute_dtype_object = tf.as_dtype( 

2727 self._dtype_policy.compute_dtype 

2728 ) 

2729 else: 

2730 self._compute_dtype_object = None 

2731 

2732 @property 

2733 def _compute_dtype(self): 

2734 """Deprecated alias of `compute_dtype`.""" 

2735 return self._dtype_policy.compute_dtype 

2736 

2737 def _maybe_cast_inputs(self, inputs, input_list=None): 

2738 """Maybe casts the inputs to the compute dtype. 

2739 

2740 If self._compute_dtype is floating-point, and self_autocast is True, 

2741 floating-point inputs are casted to self._compute_dtype. 

2742 

2743 Args: 

2744 inputs: Input tensor, or structure of input tensors. 

2745 input_list: Flat list of input tensors. 

2746 

2747 Returns: 

2748 `inputs`, but tensors may have been casted to self._compute_dtype 

2749 """ 

2750 if not input_list: 

2751 input_list = tf.nest.flatten(inputs) 

2752 

2753 compute_dtype_object = self._compute_dtype_object 

2754 should_autocast = ( 

2755 self._autocast 

2756 and compute_dtype_object 

2757 and compute_dtype_object.is_floating 

2758 ) 

2759 

2760 if should_autocast and any( 

2761 map(self._should_cast_single_input, input_list) 

2762 ): 

2763 # Only perform expensive `nest` operation when needed. 

2764 return tf.nest.map_structure(self._cast_single_input, inputs) 

2765 else: 

2766 return inputs 

2767 

2768 def _should_cast_single_input(self, x): 

2769 if isinstance(x, _AUTOCAST_TYPES): 

2770 return ( 

2771 self._compute_dtype_object 

2772 and x.dtype != self._compute_dtype_object 

2773 and x.dtype.is_floating 

2774 ) 

2775 return False 

2776 

2777 def _cast_single_input(self, x): 

2778 """Cast a single Tensor or TensorSpec to the compute dtype.""" 

2779 if self._should_cast_single_input(x): 

2780 return tf.cast(x, self._compute_dtype_object) 

2781 else: 

2782 return x 

2783 

2784 # _dtype used to be an attribute set in the constructor. We still expose it 

2785 # because some clients still use it. 

2786 # TODO(reedwm): Deprecate, then remove the _dtype property. 

2787 @property 

2788 def _dtype(self): 

2789 # This is equivalent to returning self.dtype . We do not return 

2790 # self.dtype as it would cause infinite recursion in a few subclasses, 

2791 # which override "dtype" to return self._dtype. 

2792 return self._dtype_policy.variable_dtype 

2793 

2794 @_dtype.setter 

2795 def _dtype(self, value): 

2796 value = tf.as_dtype(value).name 

2797 self._set_dtype_policy(policy.Policy(value)) 

2798 

2799 def _name_scope(self): 

2800 if not tf.__internal__.tf2.enabled(): 

2801 return self.name 

2802 name_scope = self.name 

2803 current_name_scope = tf.__internal__.get_name_scope() 

2804 if current_name_scope: 

2805 name_scope = current_name_scope + "/" + name_scope 

2806 if name_scope: 

2807 # Note that the trailing `/` prevents autogenerated 

2808 # numerical suffixes to get appended. It will also fully reset 

2809 # nested name scope (i.e. the outer name scope has no effect). 

2810 name_scope += "/" 

2811 return name_scope 

2812 

2813 def _init_set_name(self, name, zero_based=True): 

2814 if name is None: 

2815 self._name = backend.unique_object_name( 

2816 generic_utils.to_snake_case(self.__class__.__name__), 

2817 zero_based=zero_based, 

2818 ) 

2819 elif isinstance(name, str): 

2820 backend.observe_object_name(name) 

2821 self._name = name 

2822 else: 

2823 raise TypeError( 

2824 f"Expected `name` argument to be a string, but got: {name}" 

2825 ) 

2826 

2827 def _get_existing_metric(self, name=None): 

2828 match = [m for m in self._metrics if m.name == name] 

2829 if not match: 

2830 return 

2831 if len(match) > 1: 

2832 raise ValueError( 

2833 "Please provide different names for the metrics you have " 

2834 'added. We found {} metrics with the name: "{}"'.format( 

2835 len(match), name 

2836 ) 

2837 ) 

2838 return match[0] 

2839 

2840 def _handle_weight_regularization(self, name, variable, regularizer): 

2841 """Create lambdas which compute regularization losses.""" 

2842 

2843 def _loss_for_variable(v): 

2844 """Creates a regularization loss `Tensor` for variable `v`.""" 

2845 with backend.name_scope(name + "/Regularizer"): 

2846 regularization = regularizer(v) 

2847 return regularization 

2848 

2849 if base_layer_utils.is_split_variable(variable): 

2850 for v in variable: 

2851 self.add_loss(functools.partial(_loss_for_variable, v)) 

2852 elif isinstance(variable, lazy_variable.LazyInitVariable): 

2853 self._captured_weight_regularizer.append( 

2854 (name, variable, regularizer) 

2855 ) 

2856 else: 

2857 self.add_loss(functools.partial(_loss_for_variable, variable)) 

2858 

2859 def _handle_activity_regularization(self, inputs, outputs): 

2860 # Apply activity regularization. 

2861 # Note that it should be applied every time the layer creates a new 

2862 # output, since it is output-specific. 

2863 if self._activity_regularizer: 

2864 output_list = tf.nest.flatten(outputs) 

2865 with backend.name_scope("ActivityRegularizer"): 

2866 for output in output_list: 

2867 activity_loss = tf.convert_to_tensor( 

2868 self._activity_regularizer(output) 

2869 ) 

2870 batch_size = tf.cast( 

2871 tf.shape(output)[0], activity_loss.dtype 

2872 ) 

2873 # Make activity regularization strength batch-agnostic. 

2874 mean_activity_loss = activity_loss / batch_size 

2875 self.add_loss(mean_activity_loss) 

2876 

2877 def _set_mask_metadata(self, inputs, outputs, previous_mask, build_graph): 

2878 # Many `Layer`s don't need to call `compute_mask`. 

2879 # This method is optimized to do as little work as needed for the common 

2880 # case. 

2881 if not self._supports_masking: 

2882 return 

2883 

2884 flat_outputs = tf.nest.flatten(outputs) 

2885 

2886 mask_already_computed = getattr( 

2887 self, "_compute_output_and_mask_jointly", False 

2888 ) or all( 

2889 getattr(x, "_keras_mask", None) is not None for x in flat_outputs 

2890 ) 

2891 if mask_already_computed: 

2892 if build_graph: 

2893 self._set_mask_keras_history_checked(flat_outputs) 

2894 return 

2895 

2896 output_masks = self.compute_mask(inputs, previous_mask) 

2897 if output_masks is None: 

2898 return 

2899 

2900 flat_masks = tf.nest.flatten(output_masks) 

2901 for tensor, mask in zip(flat_outputs, flat_masks): 

2902 try: 

2903 tensor._keras_mask = mask 

2904 except AttributeError: 

2905 # C Type such as np.ndarray. 

2906 pass 

2907 

2908 if build_graph: 

2909 self._set_mask_keras_history_checked(flat_outputs) 

2910 

2911 def _set_mask_keras_history_checked(self, flat_outputs): 

2912 for output in flat_outputs: 

2913 if getattr(output, "_keras_mask", None) is not None: 

2914 # Do not track masks for `TensorFlowOpLayer` construction. 

2915 output._keras_mask._keras_history_checked = True 

2916 

2917 def _get_input_masks(self, inputs, input_list, args, kwargs): 

2918 if not self._supports_masking and not self._expects_mask_arg: 

2919 # Input masks only need to be retrieved if they are needed for 

2920 # `call` or `compute_mask`. 

2921 input_masks = None 

2922 implicit_mask = False 

2923 elif self._call_spec.arg_was_passed("mask", args, kwargs): 

2924 input_masks = self._call_spec.get_arg_value("mask", args, kwargs) 

2925 implicit_mask = False 

2926 else: 

2927 input_masks = [getattr(t, "_keras_mask", None) for t in input_list] 

2928 if all(mask is None for mask in input_masks): 

2929 input_masks = None 

2930 implicit_mask = False 

2931 else: 

2932 # Only do expensive `nest` op when masking is actually being 

2933 # used. 

2934 input_masks = tf.nest.pack_sequence_as(inputs, input_masks) 

2935 implicit_mask = True 

2936 return input_masks, implicit_mask 

2937 

2938 def _set_connectivity_metadata(self, args, kwargs, outputs): 

2939 # If the layer returns tensors from its inputs unmodified, 

2940 # we copy them to avoid loss of KerasHistory metadata. 

2941 flat_outputs = tf.nest.flatten(outputs) 

2942 flat_inputs = tf.nest.flatten((args, kwargs)) 

2943 input_ids_set = {id(i) for i in flat_inputs} 

2944 outputs_copy = [] 

2945 for x in flat_outputs: 

2946 if id(x) in input_ids_set: 

2947 with backend.name_scope(self.name): 

2948 x = tf.identity(x) 

2949 outputs_copy.append(x) 

2950 outputs = tf.nest.pack_sequence_as(outputs, outputs_copy) 

2951 

2952 # Create node, Node wires itself to inbound and outbound layers. The 

2953 # Node constructor actually updates this layer's self._inbound_nodes, 

2954 # sets _keras_history on the outputs, and adds itself to the 

2955 # `_outbound_nodes` of the layers that produced the inputs to this layer 

2956 # call. 

2957 node_module.Node( 

2958 self, call_args=args, call_kwargs=kwargs, outputs=outputs 

2959 ) 

2960 return outputs 

2961 

2962 def _get_node_attribute_at_index(self, node_index, attr, attr_name): 

2963 """Private utility to retrieves an attribute (e.g. inputs) from a node. 

2964 

2965 This is used to implement the methods: 

2966 - get_input_shape_at 

2967 - get_output_shape_at 

2968 - get_input_at 

2969 etc... 

2970 

2971 Args: 

2972 node_index: Integer index of the node from which 

2973 to retrieve the attribute. 

2974 attr: Exact node attribute name. 

2975 attr_name: Human-readable attribute name, for error messages. 

2976 

2977 Returns: 

2978 The layer's attribute `attr` at the node of index `node_index`. 

2979 

2980 Raises: 

2981 RuntimeError: If the layer has no inbound nodes, or if called in 

2982 Eager mode. 

2983 ValueError: If the index provided does not match any node. 

2984 """ 

2985 if not self._inbound_nodes: 

2986 raise RuntimeError( 

2987 f"The layer {self.name} has never been called " 

2988 f"and thus has no defined {attr_name}." 

2989 ) 

2990 if not len(self._inbound_nodes) > node_index: 

2991 raise ValueError( 

2992 f"Asked to get {attr_name} at node " 

2993 f"{node_index}, but the layer has only " 

2994 f"{len(self._inbound_nodes)} inbound nodes." 

2995 ) 

2996 values = getattr(self._inbound_nodes[node_index], attr) 

2997 if isinstance(values, list) and len(values) == 1: 

2998 return values[0] 

2999 else: 

3000 return values 

3001 

3002 def _maybe_build(self, inputs): 

3003 # Check input assumptions set before layer building, e.g. input rank. 

3004 if not self.built: 

3005 input_spec.assert_input_compatibility( 

3006 self.input_spec, inputs, self.name 

3007 ) 

3008 input_list = tf.nest.flatten(inputs) 

3009 if input_list and self._dtype_policy.compute_dtype is None: 

3010 try: 

3011 dtype = input_list[0].dtype.base_dtype.name 

3012 except AttributeError: 

3013 pass 

3014 else: 

3015 self._set_dtype_policy(policy.Policy(dtype)) 

3016 input_shapes = None 

3017 # Converts Tensors / CompositeTensors to TensorShapes. 

3018 if any(hasattr(x, "shape") for x in input_list): 

3019 input_shapes = tf_utils.get_shapes(inputs) 

3020 else: 

3021 # Converts input shape to TensorShapes. 

3022 try: 

3023 input_shapes = tf_utils.convert_shapes( 

3024 inputs, to_tuples=False 

3025 ) 

3026 except ValueError: 

3027 pass 

3028 # Only call `build` if the user has manually overridden the build 

3029 # method. 

3030 if not hasattr(self.build, "_is_default"): 

3031 # Any setup work performed only once should happen in an 

3032 # `init_scope` to avoid creating symbolic Tensors that will 

3033 # later pollute any eager operations. 

3034 with tf_utils.maybe_init_scope(self): 

3035 self.build(input_shapes) 

3036 # We must set also ensure that the layer is marked as built, and the 

3037 # build shape is stored since user defined build functions may not 

3038 # be calling `super.build()` 

3039 Layer.build(self, input_shapes) 

3040 

3041 # Optionally load weight values specified at layer instantiation. 

3042 if self._initial_weights is not None: 

3043 with tf.init_scope(): 

3044 # Using `init_scope` since we want variable assignment in 

3045 # `set_weights` to be treated like variable initialization. 

3046 self.set_weights(self._initial_weights) 

3047 self._initial_weights = None 

3048 

3049 def _get_trainable_state(self): 

3050 """Get the `trainable` state of each sublayer. 

3051 

3052 Returns: 

3053 A dict mapping all sublayers to their `trainable` value. 

3054 """ 

3055 trainable_state = weakref.WeakKeyDictionary() 

3056 for layer in self._flatten_layers(): 

3057 trainable_state[layer] = layer.trainable 

3058 return trainable_state 

3059 

3060 def _set_trainable_state(self, trainable_state): 

3061 """Set `trainable` state for each sublayer.""" 

3062 for layer in self._flatten_layers(): 

3063 if layer in trainable_state: 

3064 layer.trainable = trainable_state[layer] 

3065 

3066 @property 

3067 def _obj_reference_counts(self): 

3068 """A dict counting the number of attributes referencing an object.""" 

3069 self._maybe_create_attribute( 

3070 "_obj_reference_counts_dict", 

3071 object_identity.ObjectIdentityDictionary(), 

3072 ) 

3073 return self._obj_reference_counts_dict 

3074 

3075 @tf.__internal__.tracking.no_automatic_dependency_tracking 

3076 def _maybe_create_attribute(self, name, default_value): 

3077 """Create attribute (with the default value) if it hasn't been created. 

3078 

3079 This is useful for fields that is used for tracking purpose, 

3080 _trainable_weights, or _layers. Note that user could create a layer 

3081 subclass and assign an internal field before invoking the 

3082 Layer.__init__(), the __setattr__() need to create the tracking fields 

3083 and __init__() need to not override them. 

3084 

3085 Args: 

3086 name: String, the name of the attribute. 

3087 default_value: Object, the default value of the attribute. 

3088 """ 

3089 if not hasattr(self, name): 

3090 self.__setattr__(name, default_value) 

3091 

3092 def __delattr__(self, name): 

3093 # For any super.__delattr__() call, we will directly use the 

3094 # implementation in Trackable and skip the behavior in AutoTrackable. 

3095 # The Layer was originally use Trackable as base class, the change of 

3096 # using Module as base class forced us to have AutoTrackable in the 

3097 # class hierarchy. 

3098 # 

3099 # TODO(b/180760306) Keeping the status quo of skipping _delattr__ and 

3100 # __setattr__ in AutoTrackable may be unsustainable. 

3101 existing_value = getattr(self, name, None) 

3102 

3103 # If this value is replacing an existing object assigned to an 

3104 # attribute, we should clean it out to avoid leaking memory. First we 

3105 # check if there are other attributes referencing it. 

3106 reference_counts = self._obj_reference_counts 

3107 if existing_value not in reference_counts: 

3108 super(tf.__internal__.tracking.AutoTrackable, self).__delattr__( 

3109 name 

3110 ) 

3111 return 

3112 

3113 reference_count = reference_counts[existing_value] 

3114 if reference_count > 1: 

3115 # There are other remaining references. We can't remove this object 

3116 # from _layers etc. 

3117 reference_counts[existing_value] = reference_count - 1 

3118 super(tf.__internal__.tracking.AutoTrackable, self).__delattr__( 

3119 name 

3120 ) 

3121 return 

3122 else: 

3123 # This is the last remaining reference. 

3124 del reference_counts[existing_value] 

3125 

3126 super(tf.__internal__.tracking.AutoTrackable, self).__delattr__(name) 

3127 

3128 if isinstance(existing_value, Layer) or base_layer_utils.has_weights( 

3129 existing_value 

3130 ): 

3131 super(tf.__internal__.tracking.AutoTrackable, self).__setattr__( 

3132 "_self_tracked_trackables", 

3133 [ 

3134 l 

3135 for l in self._self_tracked_trackables 

3136 if l is not existing_value 

3137 ], 

3138 ) 

3139 if isinstance(existing_value, tf.Variable): 

3140 super(tf.__internal__.tracking.AutoTrackable, self).__setattr__( 

3141 "_trainable_weights", 

3142 [w for w in self._trainable_weights if w is not existing_value], 

3143 ) 

3144 super(tf.__internal__.tracking.AutoTrackable, self).__setattr__( 

3145 "_non_trainable_weights", 

3146 [ 

3147 w 

3148 for w in self._non_trainable_weights 

3149 if w is not existing_value 

3150 ], 

3151 ) 

3152 

3153 def __setattr__(self, name, value): 

3154 if ( 

3155 name == "_self_setattr_tracking" 

3156 or not getattr(self, "_self_setattr_tracking", True) 

3157 # Exclude @property.setters from tracking 

3158 or hasattr(self.__class__, name) 

3159 ): 

3160 try: 

3161 super(tf.__internal__.tracking.AutoTrackable, self).__setattr__( 

3162 name, value 

3163 ) 

3164 except AttributeError: 

3165 raise AttributeError( 

3166 ( 

3167 'Can\'t set the attribute "{}", likely because it ' 

3168 "conflicts with an existing read-only @property of the " 

3169 "object. Please choose a different name." 

3170 ).format(name) 

3171 ) 

3172 return 

3173 

3174 # Wraps data structures in `Trackable`, unwraps `NoDependency` objects. 

3175 value = tf.__internal__.tracking.sticky_attribute_assignment( 

3176 trackable=self, value=value, name=name 

3177 ) 

3178 

3179 reference_counts = self._obj_reference_counts 

3180 reference_counts[value] = reference_counts.get(value, 0) + 1 

3181 

3182 # When replacing an existing tf.Variable with a new one, we want to 

3183 # check its existing position in the 

3184 # self._trainable/non_trainable_variable, so that we can put it back to 

3185 # the original position. 

3186 if isinstance(value, tf.Variable) and isinstance( 

3187 getattr(self, name, None), tf.Variable 

3188 ): 

3189 existing_variable = getattr(self, name) 

3190 

3191 def _get_variable_from_list(var_list, var): 

3192 # helper function to get the tf.variable from the list 

3193 # the default list.index() use == for comparison, which will 

3194 # cause issue for eager tensor. 

3195 for i in range(len(var_list)): 

3196 if var_list[i] is var: 

3197 return i 

3198 return None 

3199 

3200 if existing_variable.trainable: 

3201 self._maybe_create_attribute("_trainable_weights", []) 

3202 position = _get_variable_from_list( 

3203 self._trainable_weights, existing_variable 

3204 ) 

3205 else: 

3206 self._maybe_create_attribute("_non_trainable_variable", []) 

3207 position = _get_variable_from_list( 

3208 self._non_trainable_variable, existing_variable 

3209 ) 

3210 else: 

3211 position = None 

3212 

3213 # Clean out the old attribute, which clears _layers and 

3214 # _trainable_weights if necessary. 

3215 try: 

3216 self.__delattr__(name) 

3217 except AttributeError: 

3218 pass 

3219 

3220 # Keep track of metric instance created in subclassed layer. 

3221 for val in tf.nest.flatten(value): 

3222 if isinstance(val, metrics_mod.Metric) and hasattr( 

3223 self, "_metrics" 

3224 ): 

3225 self._metrics.append(val) 

3226 

3227 # Append value to self._self_tracked_trackables if relevant 

3228 if getattr(self, "_auto_track_sub_layers", True) and ( 

3229 isinstance(value, tf.Module) or base_layer_utils.has_weights(value) 

3230 ): 

3231 self._maybe_create_attribute("_self_tracked_trackables", []) 

3232 # We need to check object identity to avoid de-duplicating empty 

3233 # container types which compare equal. 

3234 if not any( 

3235 (layer is value for layer in self._self_tracked_trackables) 

3236 ): 

3237 self._self_tracked_trackables.append(value) 

3238 if hasattr(value, "_use_resource_variables"): 

3239 # Legacy layers (V1 tf.layers) must always use 

3240 # resource variables. 

3241 value._use_resource_variables = True 

3242 

3243 # Append value to list of trainable / non-trainable weights if relevant 

3244 # TODO(b/125122625): This won't pick up on any variables added to a 

3245 # list/dict after creation. 

3246 self._track_variables(value, position=position) 

3247 

3248 # TODO(b/180760306) Skip the auto trackable from tf.Module to keep 

3249 # status quo. See the comment at __delattr__. 

3250 super(tf.__internal__.tracking.AutoTrackable, self).__setattr__( 

3251 name, value 

3252 ) 

3253 

3254 def _update_trackables(self): 

3255 """Track variables added to lists/dicts after creation""" 

3256 for trackable_obj in self._self_tracked_trackables: 

3257 if isinstance( 

3258 trackable_obj, tf.__internal__.tracking.TrackableDataStructure 

3259 ): 

3260 self._track_variables(trackable_obj) 

3261 

3262 def _track_variables(self, value, position=None): 

3263 """Tracks `Variable`s including `Variable`s in `CompositeTensor`s.""" 

3264 for val in tf.nest.flatten(value): 

3265 if isinstance(val, tf.Variable): 

3266 self._track_variable(val, position=position) 

3267 elif tf_utils.is_extension_type(val): 

3268 # Manually expand extension types to track resource variables. 

3269 nested_vals = tf_utils.type_spec_from_value(val)._to_components( 

3270 val 

3271 ) 

3272 self._track_variables(nested_vals, position=position) 

3273 

3274 def _track_variable(self, val, position=None): 

3275 """Tracks the given `tf.Variable`.""" 

3276 # Users may add extra weights/variables simply by assigning them to 

3277 # attributes (invalid for graph networks) 

3278 self._maybe_create_attribute("_trainable_weights", []) 

3279 self._maybe_create_attribute("_non_trainable_weights", []) 

3280 if val.trainable: 

3281 if any(val is w for w in self._trainable_weights): 

3282 return 

3283 if position is None: 

3284 self._trainable_weights.append(val) 

3285 else: 

3286 self._trainable_weights.insert(position, val) 

3287 else: 

3288 if any(val is w for w in self._non_trainable_weights): 

3289 return 

3290 if position is None: 

3291 self._non_trainable_weights.append(val) 

3292 else: 

3293 self._non_trainable_weights.insert(position, val) 

3294 backend.track_variable(val) 

3295 

3296 def _gather_children_attribute(self, attribute): 

3297 assert attribute in { 

3298 "variables", 

3299 "trainable_variables", 

3300 "non_trainable_variables", 

3301 } 

3302 if hasattr(self, "_self_tracked_trackables"): 

3303 nested_layers = self._flatten_modules( 

3304 include_self=False, recursive=False 

3305 ) 

3306 return list( 

3307 itertools.chain.from_iterable( 

3308 getattr(layer, attribute) for layer in nested_layers 

3309 ) 

3310 ) 

3311 return [] 

3312 

3313 def _flatten_layers(self, recursive=True, include_self=True): 

3314 for m in self._flatten_modules( 

3315 recursive=recursive, include_self=include_self 

3316 ): 

3317 if isinstance(m, Layer): 

3318 yield m 

3319 

3320 def _flatten_modules(self, recursive=True, include_self=True): 

3321 """Flattens `tf.Module` instances (excluding `Metrics`). 

3322 

3323 Args: 

3324 recursive: Whether to recursively flatten through submodules. 

3325 include_self: Whether to include this `Layer` instance. 

3326 

3327 Yields: 

3328 `tf.Module` instance tracked by this `Layer`. 

3329 """ 

3330 if include_self: 

3331 yield self 

3332 

3333 # Only instantiate set and deque if needed. 

3334 trackables = getattr(self, "_self_tracked_trackables", None) 

3335 if trackables: 

3336 seen_object_ids = set() 

3337 deque = collections.deque(trackables) 

3338 while deque: 

3339 trackable_obj = deque.popleft() 

3340 trackable_id = id(trackable_obj) 

3341 if trackable_id in seen_object_ids: 

3342 continue 

3343 seen_object_ids.add(trackable_id) 

3344 

3345 # Metrics are not considered part of the Layer's topology. 

3346 if isinstance(trackable_obj, tf.Module) and not isinstance( 

3347 trackable_obj, metrics_mod.Metric 

3348 ): 

3349 yield trackable_obj 

3350 # Introspect recursively through sublayers. 

3351 if recursive: 

3352 subtrackables = getattr( 

3353 trackable_obj, "_self_tracked_trackables", None 

3354 ) 

3355 if subtrackables: 

3356 deque.extendleft(reversed(subtrackables)) 

3357 elif isinstance( 

3358 trackable_obj, 

3359 tf.__internal__.tracking.TrackableDataStructure, 

3360 ): 

3361 # Data structures are introspected even with 

3362 # `recursive=False`. 

3363 tracked_values = trackable_obj._values 

3364 if tracked_values: 

3365 deque.extendleft(reversed(tracked_values)) 

3366 

3367 # This is a hack so that the is_layer (within 

3368 # training/trackable/layer_utils.py) check doesn't get the weights attr. 

3369 # TODO(b/110718070): Remove when fixed. 

3370 def _is_layer(self): 

3371 return True 

3372 

3373 def _init_call_fn_args(self, expects_training_arg=None): 

3374 self._call_spec = layer_utils.CallFunctionSpec( 

3375 tf_inspect.getfullargspec(self.call) 

3376 ) 

3377 if expects_training_arg is not None: 

3378 self._call_spec.expects_training_arg = expects_training_arg 

3379 

3380 @property 

3381 def _expects_training_arg(self): 

3382 """Whether the call function uses 'training' as a parameter.""" 

3383 return self._call_spec.expects_training_arg 

3384 

3385 @property 

3386 def _expects_mask_arg(self): 

3387 return self._call_spec.expects_mask_arg 

3388 

3389 @property 

3390 def _eager_losses(self): 

3391 # A list of loss values containing activity regularizers and losses 

3392 # manually added through `add_loss` during eager execution. It is 

3393 # cleared after every batch. Because we plan on eventually allowing a 

3394 # same model instance to be trained in eager mode or graph mode 

3395 # alternatively, we need to keep track of eager losses and symbolic 

3396 # losses via separate attributes. 

3397 if not hasattr(self._thread_local, "_eager_losses"): 

3398 self._thread_local._eager_losses = [] 

3399 return self._thread_local._eager_losses 

3400 

3401 @_eager_losses.setter 

3402 def _eager_losses(self, losses): 

3403 self._thread_local._eager_losses = losses 

3404 

3405 def _dedup_weights(self, weights): 

3406 """Dedupe weights while maintaining order as much as possible.""" 

3407 output, seen_ids = [], set() 

3408 for w in weights: 

3409 if id(w) not in seen_ids: 

3410 output.append(w) 

3411 # Track the Variable's identity to avoid __eq__ issues. 

3412 seen_ids.add(id(w)) 

3413 return output 

3414 

3415 # SavedModel properties. Please see keras/saving/saved_model for details. 

3416 

3417 @tf.__internal__.tracking.no_automatic_dependency_tracking 

3418 def _set_save_spec(self, inputs, args=None, kwargs=None): 

3419 """Defines the save spec so that serialization can trace layer calls. 

3420 

3421 The TensorSpecs of the call function `inputs`, `args`, and `kwargs` are 

3422 saved into a tuple of `([inputs] + args, kwargs)`. 

3423 

3424 Args: 

3425 inputs: possibly nested inputs passed into the call function. 

3426 args: a list of positional arguments passed into call. 

3427 kwargs: a dictionary of keyword arguments passed into call. 

3428 """ 

3429 if self._saved_model_inputs_spec is not None: 

3430 return # Already set. 

3431 

3432 inputs_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, inputs) 

3433 args_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, args or []) 

3434 kwargs_spec = {} 

3435 # Filter out non-tensor arguments from kwargs. 

3436 for key, kwarg in kwargs.items(): 

3437 flat_kwarg = tf.nest.flatten(kwarg) 

3438 flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg] 

3439 if any(s is None for s in flat_specs): 

3440 continue 

3441 kwargs_spec[key] = tf.nest.pack_sequence_as(kwarg, flat_specs) 

3442 

3443 self._saved_model_inputs_spec = inputs_spec 

3444 self._saved_model_arg_spec = ( 

3445 [inputs_spec] + list(args_spec), 

3446 kwargs_spec, 

3447 ) 

3448 

3449 def _get_save_spec(self, dynamic_batch=True, inputs_only=True): 

3450 if self._saved_model_inputs_spec is None: 

3451 return None 

3452 

3453 spec = tf.nest.map_structure( 

3454 lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch), 

3455 self._saved_model_arg_spec, 

3456 ) 

3457 return spec[0][0] if inputs_only else spec 

3458 

3459 @property 

3460 def _trackable_saved_model_saver(self): 

3461 return layer_serialization.LayerSavedModelSaver(self) 

3462 

3463 @property 

3464 def _object_identifier(self): 

3465 return self._trackable_saved_model_saver.object_identifier 

3466 

3467 @property 

3468 def _tracking_metadata(self): 

3469 """Info about this layer to be saved into the SavedModel.""" 

3470 return self._trackable_saved_model_saver.tracking_metadata 

3471 

3472 def _trackable_children(self, save_type="checkpoint", **kwargs): 

3473 if save_type == "savedmodel": 

3474 cache = kwargs["cache"] 

3475 # TODO(b/213628533): This must be called before super() to ensure 

3476 # that any input shape changes are applied before getting the config 

3477 # of the model. 

3478 children = self._trackable_saved_model_saver.trackable_children( 

3479 cache 

3480 ) 

3481 else: 

3482 children = {} 

3483 children.update(super()._trackable_children(save_type, **kwargs)) 

3484 return children 

3485 

3486 @property 

3487 def _use_input_spec_as_call_signature(self): 

3488 # Whether input spec can be used as the call signature when tracing the 

3489 # Layer for SavedModel. By default, this is set to `True` for layers 

3490 # exported from the Keras library, because the layers more rigidly 

3491 # define the `input_specs` property (many custom layers only set the 

3492 # `ndims`) 

3493 return ( 

3494 get_canonical_name_for_symbol(type(self), api_name="keras") 

3495 is not None 

3496 ) 

3497 

3498 def __getstate__(self): 

3499 # Override to support `copy.deepcopy` and pickling. 

3500 # Thread-local objects cannot be copied in Python 3, so pop these. 

3501 # Thread-local objects are used to cache losses in MirroredStrategy, and 

3502 # so shouldn't be copied. 

3503 state = self.__dict__.copy() 

3504 state.pop("_thread_local", None) 

3505 state.pop("_metrics_lock", None) 

3506 return state 

3507 

3508 def __setstate__(self, state): 

3509 state["_thread_local"] = threading.local() 

3510 state["_metrics_lock"] = threading.Lock() 

3511 # Bypass Trackable logic as `__dict__` already contains this info. 

3512 object.__setattr__(self, "__dict__", state) 

3513 

3514 def save_own_variables(self, store): 

3515 """Saves the state of the layer. 

3516 

3517 You can override this method to take full control of how the state of 

3518 the layer is saved upon calling `model.save()`. 

3519 

3520 Args: 

3521 store: Dict where the state of the model will be saved. 

3522 """ 

3523 all_vars = self._trainable_weights + self._non_trainable_weights 

3524 for i, v in enumerate(all_vars): 

3525 store[f"{i}"] = v.numpy() 

3526 

3527 def load_own_variables(self, store): 

3528 """Loads the state of the layer. 

3529 

3530 You can override this method to take full control of how the state of 

3531 the layer is loaded upon calling `keras.models.load_model()`. 

3532 

3533 Args: 

3534 store: Dict from which the state of the model will be loaded. 

3535 """ 

3536 self._update_trackables() 

3537 all_vars = self._trainable_weights + self._non_trainable_weights 

3538 if len(store.keys()) != len(all_vars): 

3539 raise ValueError( 

3540 f"Layer '{self.name}' expected {len(all_vars)} variables, " 

3541 "but received " 

3542 f"{len(store.keys())} variables during loading. " 

3543 f"Expected: {[v.name for v in all_vars]}" 

3544 ) 

3545 for i, v in enumerate(all_vars): 

3546 # TODO(rchao): check shapes and raise errors. 

3547 v.assign(store[f"{i}"]) 

3548 

3549 

3550class TensorFlowOpLayer(Layer): 

3551 """Wraps a TensorFlow Operation in a Layer. 

3552 

3553 This class is used internally by the Functional API. When a user 

3554 uses a raw TensorFlow Operation on symbolic tensors originating 

3555 from an `Input` Layer, the resultant operation will be wrapped 

3556 with this Layer object in order to make the operation compatible 

3557 with the Keras API. 

3558 

3559 This Layer will create a new, identical operation (except for inputs 

3560 and outputs) every time it is called. If `run_eagerly` is `True`, 

3561 the op creation and calculation will happen inside an Eager function. 

3562 

3563 Instances of this Layer are created when `autolambda` is called, which 

3564 is whenever a Layer's `__call__` encounters symbolic inputs that do 

3565 not have Keras metadata, or when a Network's `__init__` encounters 

3566 outputs that do not have Keras metadata. 

3567 

3568 Attributes: 

3569 node_def: String, the serialized NodeDef of the Op this layer will wrap. 

3570 name: String, the name of the Layer. 

3571 constants: Dict of NumPy arrays, the values of any Tensors needed for this 

3572 Operation that do not originate from a Keras `Input` Layer. Since all 

3573 placeholders must come from Keras `Input` Layers, these Tensors must be 

3574 treated as constant in the Functional API. 

3575 trainable: Bool, whether this Layer is trainable. Currently Variables are 

3576 not supported, and so this parameter has no effect. 

3577 dtype: The default dtype of this Layer. Inherited from `Layer` and has no 

3578 effect on this class, however is used in `get_config`. 

3579 """ 

3580 

3581 @tf.__internal__.tracking.no_automatic_dependency_tracking 

3582 def __init__( 

3583 self, node_def, name, constants=None, trainable=True, dtype=None 

3584 ): 

3585 # Pass autocast=False, as if inputs are cast, input types might not 

3586 # match Operation type. 

3587 super(TensorFlowOpLayer, self).__init__( 

3588 name=_TF_OP_LAYER_NAME_PREFIX + name, 

3589 trainable=trainable, 

3590 dtype=dtype, 

3591 autocast=False, 

3592 ) 

3593 if isinstance(node_def, dict): 

3594 self.node_def = json_format.ParseDict( 

3595 node_def, tf.compat.v1.NodeDef() 

3596 ) 

3597 else: 

3598 if not isinstance(node_def, bytes): 

3599 node_def = node_def.encode("utf-8") 

3600 self.node_def = tf.compat.v1.NodeDef.FromString(node_def) 

3601 # JSON serialization stringifies keys which are integer input indices. 

3602 self.constants = ( 

3603 {int(index): constant for index, constant in constants.items()} 

3604 if constants is not None 

3605 else {} 

3606 ) 

3607 # Layer uses original op unless it is called on new inputs. 

3608 # This means `built` is not set in `__call__`. 

3609 self.built = True 

3610 

3611 # Do not individually trace TensorflowOpLayers in the SavedModel. 

3612 self._must_restore_from_config = True 

3613 

3614 def call(self, inputs): 

3615 if tf.executing_eagerly(): 

3616 return self._defun_call(inputs) 

3617 return self._make_op(inputs) 

3618 

3619 def _make_node_def(self, graph): 

3620 node_def = tf.compat.v1.NodeDef() 

3621 node_def.CopyFrom(self.node_def) 

3622 # Used in TPUReplicateContext to indicate whether this node has been 

3623 # cloned and to not add TPU attributes. 

3624 node_def.attr["_cloned"].b = True 

3625 node_def.name = graph.unique_name(node_def.name) 

3626 return node_def 

3627 

3628 def _make_op(self, inputs): 

3629 inputs = tf.nest.flatten(inputs) 

3630 graph = inputs[0].graph 

3631 node_def = self._make_node_def(graph) 

3632 with graph.as_default(): 

3633 for index, constant in self.constants.items(): 

3634 # Recreate constant in graph to add distribution context. 

3635 value = tf.get_static_value(constant) 

3636 if value is not None: 

3637 if isinstance(value, dict): 

3638 value = serialization_lib.deserialize_keras_object( 

3639 value 

3640 ) 

3641 constant = tf.constant(value, name=node_def.input[index]) 

3642 inputs.insert(index, constant) 

3643 # TODO(b/183990973): We should drop or consolidate these private api 

3644 # calls for adding an op to the graph and recording its gradient. 

3645 c_op = tf.__internal__.create_c_op( 

3646 graph, node_def, inputs, control_inputs=[] 

3647 ) 

3648 op = graph._create_op_from_tf_operation(c_op) 

3649 op._control_flow_post_processing() 

3650 

3651 # Record the gradient because custom-made ops don't go through the 

3652 # code-gen'd eager call path 

3653 op_type = tf.compat.as_str(op.op_def.name) 

3654 attr_names = [ 

3655 tf.compat.as_str(attr.name) for attr in op.op_def.attr 

3656 ] 

3657 attrs = [] 

3658 for attr_name in attr_names: 

3659 attrs.append(attr_name) 

3660 attrs.append(op.get_attr(attr_name)) 

3661 attrs = tuple(attrs) 

3662 tf.__internal__.record_gradient( 

3663 op_type, op.inputs, attrs, op.outputs 

3664 ) 

3665 

3666 if len(op.outputs) == 1: 

3667 return op.outputs[0] 

3668 return op.outputs 

3669 

3670 @tf.function 

3671 def _defun_call(self, inputs): 

3672 """Wraps op creation method in an Eager function for `run_eagerly`.""" 

3673 return self._make_op(inputs) 

3674 

3675 def get_config(self): 

3676 config = super(TensorFlowOpLayer, self).get_config() 

3677 config.update( 

3678 { 

3679 # `__init__` prefixes the name. Revert to the constructor 

3680 # argument. 

3681 "name": config["name"][len(_TF_OP_LAYER_NAME_PREFIX) :], 

3682 "node_def": json_format.MessageToDict(self.node_def), 

3683 "constants": { 

3684 i: backend.get_value(c) for i, c in self.constants.items() 

3685 }, 

3686 } 

3687 ) 

3688 return config 

3689 

3690 

3691class AddLoss(Layer): 

3692 """Adds its inputs as a loss. 

3693 

3694 Attributes: 

3695 unconditional: Whether or not the loss should be conditioned on the 

3696 inputs. 

3697 """ 

3698 

3699 def __init__(self, unconditional, **kwargs): 

3700 # Pass autocast=False, as there is no reason to cast loss to a different 

3701 # dtype. 

3702 kwargs["autocast"] = False 

3703 super(AddLoss, self).__init__(**kwargs) 

3704 self.unconditional = unconditional 

3705 

3706 def call(self, inputs): 

3707 self.add_loss(inputs, inputs=(not self.unconditional)) 

3708 return inputs 

3709 

3710 def get_config(self): 

3711 config = super(AddLoss, self).get_config() 

3712 config.update({"unconditional": self.unconditional}) 

3713 return config 

3714 

3715 

3716class AddMetric(Layer): 

3717 """Adds its inputs as a metric. 

3718 

3719 Attributes: 

3720 aggregation: 'mean' or None. How the inputs should be aggregated. 

3721 metric_name: The name to use for this metric. 

3722 """ 

3723 

3724 def __init__(self, aggregation=None, metric_name=None, **kwargs): 

3725 super(AddMetric, self).__init__(**kwargs) 

3726 self.aggregation = aggregation 

3727 self.metric_name = metric_name 

3728 

3729 def call(self, inputs): 

3730 self.add_metric( 

3731 inputs, aggregation=self.aggregation, name=self.metric_name 

3732 ) 

3733 return inputs 

3734 

3735 def get_config(self): 

3736 config = super(AddMetric, self).get_config() 

3737 config.update( 

3738 {"aggregation": self.aggregation, "metric_name": self.metric_name} 

3739 ) 

3740 return config 

3741 

3742 

3743def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list): 

3744 """Check the arguments to see if we are constructing a functional model.""" 

3745 # We are constructing a functional model if any of the inputs 

3746 # are KerasTensors 

3747 return any( 

3748 isinstance(tensor, keras_tensor.KerasTensor) 

3749 for tensor in tf.nest.flatten([inputs, args, kwargs]) 

3750 ) 

3751 

3752 

3753def _convert_numpy_or_python_types(x): 

3754 if isinstance(x, (tf.Tensor, np.ndarray, float, int)): 

3755 return tf.convert_to_tensor(x) 

3756 return x 

3757 

3758 

3759@keras_export("keras.__internal__.apply_name_scope_on_model_declaration", v1=[]) 

3760def _apply_name_scope_on_model_declaration(enable): 

3761 """Apply `with tf.name_scope(...)` on model declaration. 

3762 

3763 ```python 

3764 tf.keras.__internal__.apply_name_scope_on_model_declaration(True) 

3765 

3766 inputs = input_layer.Input((3,)) 

3767 with tf.name_scope('MyScope'): 

3768 outputs = layers.Dense(10, name='MyDense')(inputs) 

3769 model = tf.keras.Model(inputs, outputs) 

3770 

3771 # with `tf.keras.__internal__.apply_name_scope_on_model_declaration(True)`, 

3772 # The name of the dense layer is "model/MyScope/MyDense/*", and without, 

3773 # "model/MyDense/*" 

3774 ``` 

3775 

3776 Args: 

3777 enable: Enables if `True`, disables if `False`. 

3778 """ 

3779 if not isinstance(enable, bool): 

3780 raise TypeError( 

3781 f"`enable` argument must be `True` or `False`, got {enable}" 

3782 ) 

3783 

3784 global _is_name_scope_on_model_declaration_enabled 

3785 _is_name_scope_on_model_declaration_enabled = enable 

3786 

3787 

3788@keras_export("keras.__internal__.layers.BaseRandomLayer") 

3789class BaseRandomLayer(Layer): 

3790 """A layer handle the random number creation and savemodel behavior.""" 

3791 

3792 @tf.__internal__.tracking.no_automatic_dependency_tracking 

3793 def __init__( 

3794 self, seed=None, force_generator=False, rng_type=None, **kwargs 

3795 ): 

3796 """Initialize the BaseRandomLayer. 

3797 

3798 Note that the constructor is annotated with 

3799 @no_automatic_dependency_tracking. This is to skip the auto 

3800 tracking of self._random_generator instance, which is an AutoTrackable. 

3801 The backend.RandomGenerator could contain a tf.random.Generator instance 

3802 which will have tf.Variable as the internal state. We want to avoid 

3803 saving that state into model.weights and checkpoints for backward 

3804 compatibility reason. In the meantime, we still need to make them 

3805 visible to SavedModel when it is tracing the tf.function for the 

3806 `call()`. 

3807 See _list_extra_dependencies_for_serialization below for more details. 

3808 

3809 Args: 

3810 seed: optional integer, used to create RandomGenerator. 

3811 force_generator: boolean, default to False, whether to force the 

3812 RandomGenerator to use the code branch of tf.random.Generator. 

3813 rng_type: string, the rng type that will be passed to backend 

3814 RandomGenerator. `None` will allow RandomGenerator to choose 

3815 types by itself. Valid values are "stateful", "stateless", 

3816 "legacy_stateful". Defaults to `None`. 

3817 **kwargs: other keyword arguments that will be passed to the parent 

3818 *class 

3819 """ 

3820 super().__init__(**kwargs) 

3821 self._random_generator = backend.RandomGenerator( 

3822 seed, force_generator=force_generator, rng_type=rng_type 

3823 ) 

3824 

3825 def build(self, input_shape): 

3826 super().build(input_shape) 

3827 self._random_generator._maybe_init() 

3828 

3829 def _trackable_children(self, save_type="checkpoint", **kwargs): 

3830 if save_type == "savedmodel": 

3831 cache = kwargs["cache"] 

3832 # TODO(b/213628533): This must be called before super() to ensure 

3833 # that any input shape changes are applied before getting the config 

3834 # of the model. 

3835 children = self._trackable_saved_model_saver.trackable_children( 

3836 cache 

3837 ) 

3838 # This method exposes the self._random_generator to SavedModel only 

3839 # (not layer.weights and checkpoint). 

3840 children["_random_generator"] = self._random_generator 

3841 else: 

3842 children = {} 

3843 children.update(super()._trackable_children(save_type, **kwargs)) 

3844 return children 

3845 

3846 def _lookup_dependency(self, name): 

3847 # When loading from a Keras SavedModel load, make sure that the loader 

3848 # can find the random generator, otherwise the loader will assume that 

3849 # it does not exist, and will try to create a new generator. 

3850 if name == "_random_generator": 

3851 return self._random_generator 

3852 else: 

3853 return super()._lookup_dependency(name) 

3854