Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_utils.py: 27%

314 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains private utilities used mainly by the base Layer class.""" 

16 

17import functools 

18import threading 

19 

20from tensorflow.python import tf2 

21from tensorflow.python.distribute import distribute_lib 

22from tensorflow.python.eager import context 

23from tensorflow.python.framework import dtypes 

24from tensorflow.python.framework import ops 

25from tensorflow.python.framework import sparse_tensor 

26from tensorflow.python.framework import tensor_shape 

27from tensorflow.python.framework import tensor_util 

28from tensorflow.python.keras import backend 

29from tensorflow.python.keras.utils import control_flow_util 

30from tensorflow.python.keras.utils import tf_inspect 

31from tensorflow.python.keras.utils import tf_utils 

32from tensorflow.python.ops import array_ops 

33from tensorflow.python.ops import variable_v1 

34from tensorflow.python.ops import variables as tf_variables 

35from tensorflow.python.ops.ragged import ragged_tensor 

36from tensorflow.python.trackable import base as tracking 

37from tensorflow.python.training.saving import saveable_object_util 

38from tensorflow.python.util import nest 

39from tensorflow.python.util.tf_export import keras_export 

40 

41_call_context = threading.local() 

42 

43 

44def create_mean_metric(value, name=None): 

45 # Importing keras will import base_layer and then this module, and metrics 

46 # rely on base_layer, which results in a cyclic dependency. 

47 from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top 

48 metric_obj = metrics_module.Mean(name=name, dtype=value.dtype) 

49 return metric_obj, metric_obj(value) 

50 
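# Illustrative sketch, not part of the original module: `create_mean_metric`
# returns the `Mean` metric object together with the result of updating it
# once with `value`. The wrapper function and its name are hypothetical;
# `value` is assumed to be a float tensor.
def _example_create_mean_metric(value):
  metric_obj, update_result = create_mean_metric(value, name='mean_activation')
  # `metric_obj` can be updated again later; `update_result` is the value/op
  # produced by the first update.
  return metric_obj, update_result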

51 

52def make_variable(name, 

53 shape=None, 

54 dtype=dtypes.float32, 

55 initializer=None, 

56 trainable=None, 

57 caching_device=None, 

58 validate_shape=True, 

59 constraint=None, 

60 use_resource=None, 

61 collections=None, 

62 synchronization=tf_variables.VariableSynchronization.AUTO, 

63 aggregation=tf_variables.VariableAggregation.NONE, 

64 partitioner=None): # pylint: disable=unused-argument 

65 """Temporary util to create a variable (relies on `variable_scope.variable`). 

66 

67 Some reuse-related technicalities prevent us from using 

68 `variable_scope.get_variable()` directly, so we use a subcomponent 

69 that has fewer constraints (`variable_scope.variable()`). 

70 

71 In the longer term, it seems like a similar "default variable creator" method 

72 should exist in `Trackable` instead. When this happens, we can get 

73 rid of this temporary solution. 

74 

75 TODO(fchollet): remove this method when no longer needed. 

76 

77 Args: 

78 name: Variable name. 

79 shape: Variable shape. 

80 dtype: The type of the variable. Defaults to `self.dtype` or `float32`. 

81 initializer: Initializer instance (callable). 

82 trainable: Whether the variable should be part of the layer's 

83 "trainable_variables" (e.g. variables, biases) 

84 or "non_trainable_variables" (e.g. BatchNorm mean, stddev). 

85 Note, if the current variable scope is marked as non-trainable 

86 then this parameter is ignored and any added variables are also 

87 marked as non-trainable. `trainable` defaults to `True` unless 

88 `synchronization` is set to `ON_READ`. 

89 caching_device: Passed to `tf.Variable`. 

90 validate_shape: Passed to `tf.Variable`. 

91 constraint: Constraint instance (callable). 

92 use_resource: Whether to use a `ResourceVariable`. 

93 collections: List of graph collections keys. The new variable is added to 

94 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 

95 synchronization: Indicates when a distributed variable will be 

96 aggregated. Accepted values are constants defined in the class 

97 `tf.VariableSynchronization`. By default the synchronization is set to 

98 `AUTO` and the current `DistributionStrategy` chooses 

99 when to synchronize. If `synchronization` is set to `ON_READ`, 

100 `trainable` must not be set to `True`. 

101 aggregation: Indicates how a distributed variable will be aggregated. 

102 Accepted values are constants defined in the class 

103 `tf.VariableAggregation`. 

104 partitioner: Not handled at this time. 

105 

106 Returns: 

107 Variable instance. 

108 """ 

109 initializing_from_value = False 

110 if initializer is not None and not callable(initializer): 

111 initializing_from_value = True 

112 

113 if initializing_from_value: 

114 init_val = initializer 

115 variable_dtype = None 

116 else: 

117 # Instantiate initializer if provided initializer is a type object. 

118 if tf_inspect.isclass(initializer): 

119 initializer = initializer() 

120 init_val = functools.partial(initializer, shape, dtype=dtype) 

121 variable_dtype = dtype.base_dtype 

122 if use_resource is None: 

123 use_resource = True 

124 

125 # TODO(apassos,rohanj) figure out how to remove collections from here so we 

126 # can remove the V1. 

127 variable_shape = tensor_shape.TensorShape(shape) 

128 return variable_v1.VariableV1( 

129 initial_value=init_val, 

130 name=name, 

131 trainable=trainable, 

132 caching_device=caching_device, 

133 dtype=variable_dtype, 

134 validate_shape=validate_shape, 

135 constraint=constraint, 

136 use_resource=use_resource, 

137 collections=collections, 

138 synchronization=synchronization, 

139 aggregation=aggregation, 

140 shape=variable_shape if variable_shape else None) 

141 
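# Illustrative sketch, not part of the original module: a typical call to
# `make_variable` with a callable initializer, as the base Layer's
# variable-creation path might issue it. The helper name, shape, and the
# zeros initializer are assumptions for illustration; `array_ops` and
# `dtypes` are already imported above.
def _example_make_variable():
  zeros_init = lambda shape, dtype: array_ops.zeros(shape, dtype=dtype)
  return make_variable(
      'kernel',
      shape=(64, 10),
      dtype=dtypes.float32,
      initializer=zeros_init,
      trainable=True)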

142 

143def collect_previous_mask(input_tensors): 

144 """Retrieves the output mask(s) of the previous node. 

145 

146 Args: 

147 input_tensors: An arbitrary structure of Tensors. 

148 

149 Returns: 

150 A mask tensor or list of mask tensors. 

151 """ 

152 

153 def _collect_previous_mask(x): 

154 return getattr(x, '_keras_mask', None) 

155 

156 return nest.map_structure(_collect_previous_mask, input_tensors) 

157 

158 

159def have_all_keras_metadata(tensors): 

160 return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors)) 

161 

162 

163def generate_placeholders_from_shape(shape): 

164 return array_ops.placeholder(shape=shape, dtype=backend.floatx()) 

165 

166 

167def create_keras_history(tensors): 

168 """Wraps TensorFlow Operations for compatibility with the Functional API. 

169 

170 This method checks to see if a Tensor in `tensors` is missing Keras metadata 

171 and has its origin in a Keras `Input` Layer. If so, this method will replace 

172 the raw TensorFlow Operations that created this tensor with 

173 `TensorFlowOpLayer` instances that create identical operations. 

174 

175 Any Tensors not originating from a Keras `Input` Layer will be treated as 

176 constants when constructing `TensorFlowOpLayer` instances. 

177 

178 Args: 

179 tensors: A structure of Tensors, some of which come from raw TensorFlow 

180 operations and need to have Keras metadata assigned to them. 

181 

182 Returns: 

183 created_layers: List. The `TensorFlowOpLayer` instances created to wrap 

184 the raw TensorFlow operations. 

185 """ 

186 _, created_layers = _create_keras_history_helper(tensors, set(), []) 

187 return created_layers 

188 

189 

190# Unsafe Internal attribute. 

191# If True, Keras will not evaluate the constant-foldable inputs to tf op 

192# layers in TF1 graphs. This *might* speed up model construction time in 

193# certain settings, but it means the models will not be serializable or 

194# deserializable via `get_config` (only via SavedModel). It may also 

195# change the semantics of whether 

196# generated random numbers are generated once and re-used, or recomputed 

197# each time. 

198# Note: This path triggers for TPUEstimators / xla compiled graphs regardless 

199# of this setting. 

200_UNSAFE_GRAPH_OP_LAYER_CREATION = False 

201 

202 

203def _create_keras_history_helper(tensors, processed_ops, created_layers): 

204 """Helper method for `create_keras_history`. 

205 

206 Args: 

207 tensors: A structure of Tensors for which to create Keras metadata. 

208 processed_ops: Set. TensorFlow operations that have already been wrapped in 

209 `TensorFlowOpLayer` instances. 

210 created_layers: List. The `TensorFlowOpLayer` instances created. 

211 

212 Returns: 

213 Tuple. First element is the updated set of TensorFlow Operations that 

214 have been wrapped in `TensorFlowOpLayer` instances. Second element is 

215 a list of the `TensorFlowOpLayer` instances created. 

216 """ 

217 if ops.executing_eagerly_outside_functions(): 

218 raise ValueError( 

219 '`create_keras_history` should only be called if eager is disabled!') 

220 # Import of `base_layer` needed in order to create `TensorFlowOpLayer`. 

221 # Cannot be imported at top because of circular dependencies. 

222 # TODO(omalleyt): Resolve circular dependency. 

223 from tensorflow.python.keras.engine import base_layer # pylint: disable=g-import-not-at-top 

224 tensor_list = nest.flatten(tensors) 

225 sparse_ops = [] 

226 ragged_tensors = [] 

227 for tensor in tensor_list: 

228 if getattr(tensor, '_keras_history', None) is not None: 

229 continue 

230 if isinstance( 

231 tensor, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): 

232 sparse_ops.append(tensor.op) 

233 continue 

234 if tf_utils.is_ragged(tensor): 

235 # Ragged tensors don't have an op property 

236 ragged_tensors.append(tensor) 

237 continue 

238 op = tensor.op # The Op that created this Tensor. 

239 if op not in processed_ops: 

240 # Recursively set `_keras_history`. 

241 op_inputs = list(op.inputs) 

242 constants = {} 

243 layer_inputs = [] 

244 for i, op_input in enumerate(op_inputs): 

245 if uses_keras_history(op_input): 

246 layer_inputs.append(op_input) 

247 else: 

248 # Treat any value not originating from a `keras.Input` as 

249 # a constant. Variables cannot be supported. 

250 ds_with_session = ( 

251 distribute_lib.in_cross_replica_context() and 

252 not ops.executing_eagerly_outside_functions()) 

253 using_xla = control_flow_util.GraphOrParentsInXlaContext( 

254 ops.get_default_graph()) 

255 if ds_with_session or using_xla or _UNSAFE_GRAPH_OP_LAYER_CREATION: 

256 # In Legacy Graph mode, evaluating here would cause the Session to 

257 # be configured improperly. The downside of this is that saving 

258 # via `get_config` breaks, but SavedModel still works. 

259 constants[i] = op_input 

260 else: 

261 with ops.init_scope(): 

262 constants[i] = backend.function([], op_input)([]) 

263 layer_inputs = unnest_if_single_tensor(layer_inputs) 

264 processed_ops, created_layers = _create_keras_history_helper( 

265 layer_inputs, processed_ops, created_layers) 

266 name = op.name 

267 node_def = op.node_def.SerializeToString() 

268 op_layer = base_layer.TensorFlowOpLayer( 

269 node_def, constants=constants, name=name) 

270 created_layers.append(op_layer) 

271 op_layer._set_connectivity_metadata( # pylint: disable=protected-access 

272 args=(layer_inputs,), 

273 kwargs={}, 

274 outputs=op.outputs) 

275 processed_ops.update([op]) 

276 if sparse_ops or ragged_tensors: 

277 lambda_example = """ 

278 weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights) 

279 output = tf.keras.layers.Lambda(weights_mult)(input) 

280 """ 

281 raise ValueError( 

282 'Tensorflow ops that generate ragged or sparse tensor ' 

283 'outputs are currently not supported by Keras automatic ' 

284 'op wrapping. Please wrap these ops in a Lambda layer: ' 

285 '\n\n```\n{example}\n```\n' 

286 'Sparse ops encountered: {sparse_ops}\n' 

287 'Ragged tensors encountered: {ragged_tensors}\n'.format( 

288 example=lambda_example, 

289 sparse_ops=str(sparse_ops), 

290 ragged_tensors=str(ragged_tensors))) 

291 return processed_ops, created_layers 

292 

293 

294def unnest_if_single_tensor(input_tensors): 

295 # Preserve compatibility with older configs 

296 flat_input_tensors = nest.flatten(input_tensors) 

297 # If this is a single element but not a dict, unwrap. If this is a dict, 

298 # assume the first layer expects a dict (as is the case with a 

299 # DenseFeatures layer); pass through. 

300 if not isinstance(input_tensors, dict) and len(flat_input_tensors) == 1: 

301 input_tensors = flat_input_tensors[0] 

302 return input_tensors 

303 
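# Illustrative sketch, not part of the original module: a single-element
# (non-dict) structure is unwrapped, while a dict passes through unchanged.
# The helper name and placeholder values are hypothetical.
def _example_unnest_if_single_tensor():
  single = unnest_if_single_tensor(['only_item'])      # -> 'only_item'
  as_dict = unnest_if_single_tensor({'feature': 'x'})  # -> {'feature': 'x'}
  return single, as_dict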

304 

305def needs_keras_history(tensors, ignore_call_context=False): 

306 """Check if any Tensors need to be wrapped in TensorFlowOpLayers. 

307 

308 This will never return True inside a sublayer, because sublayers 

309 do not need to create Keras History. Otherwise, this returns True 

310 if one or more of `tensors` originates from a `keras.Input` and 

311 does not have `_keras_history` set. 

312 

313 Args: 

314 tensors: An arbitrary nested structure of Tensors. 

315 ignore_call_context: Whether to skip checking that we are currently 

316 outside of a `call` context. This is `True` when creating 

317 KerasHistory inside `Node`, where we always know that Tensors 

318 are being used with the Functional API. 

319 

320 Returns: 

321 Bool, whether at least one Tensor needs to be wrapped. 

322 """ 

323 input_tensors = nest.flatten(tensors) 

324 if call_context().in_call and not ignore_call_context: 

325 return False 

326 if all( 

327 getattr(tensor, '_keras_history', None) is not None 

328 for tensor in input_tensors): 

329 # KerasHistory already set. 

330 return False 

331 return uses_keras_history(tensors) 

332 

333 

334def is_in_keras_graph(): 

335 """Returns if currently executing inside of a Keras graph.""" 

336 return call_context().in_keras_graph 

337 

338 

339def is_in_eager_or_tf_function(): 

340 """Returns if in eager mode or inside of a tf.function.""" 

341 return context.executing_eagerly() or is_in_tf_function() 

342 

343 

344def is_in_tf_function(): 

345 """Returns if inside of a tf.function.""" 

346 # Check if running in V1 graph mode. 

347 if not ops.executing_eagerly_outside_functions(): 

348 return False 

349 if not ops.inside_function(): 

350 return False 

351 # Check if inside Keras FuncGraph. 

352 if is_in_keras_graph(): 

353 return False 

354 # Check for a v1 `wrap_function` FuncGraph. 

355 graph = ops.get_default_graph() 

356 if (getattr(graph, 'name', False) and 

357 graph.name.startswith('wrapped_function')): 

358 return False 

359 return True 

360 

361 

362def uses_keras_history(tensors): 

363 """Check if at least one Tensor originates from a `keras.Input`. 

364 

365 This is `True` if at least one Tensor has its origin in a `keras.Input`. 

366 Any Tensor that originates from a `keras.Input` will have a dependency 

367 Tensor with a `_keras_history` attribute attached. Tensors that have 

368 already been checked to not originate from a `keras.Input` 

369 are marked as `_keras_history_checked`. 

370 

371 Args: 

372 tensors: An arbitrary nested structure of Tensors. 

373 

374 Returns: 

375 Bool, whether at least one Tensor originates from a `keras.Input`. 

376 """ 

377 checked_tensors = set() 

378 tensors_to_check = nest.flatten(tensors) 

379 

380 while tensors_to_check: 

381 new_tensors_to_check = [] 

382 for tensor in tensors_to_check: 

383 if id(tensor) in checked_tensors: 

384 continue 

385 

386 checked_tensors.add(id(tensor)) 

387 

388 if getattr(tensor, '_keras_history_checked', None) is not None: 

389 continue 

390 if getattr(tensor, '_keras_history', None) is not None: 

391 return True 

392 

393 try: 

394 new_tensors_to_check.extend(tensor.op.inputs) 

395 except AttributeError: 

396 # In case `tensor` is a Variable created in an Eager context. 

397 pass 

398 

399 tensors_to_check = new_tensors_to_check 

400 

401 # Mark that these Tensors have been checked once for `_keras_history`, 

402 # and should not be checked again for performance reasons. 

403 mark_checked(tensors) 

404 return False 

405 

406 

407def mark_checked(tensors): 

408 """Marks that these Tensors should not be tracked. 

409 

410 This prevents Layers from attempting to create TensorFlowOpLayers 

411 for these Tensors. 

412 

413 Args: 

414 tensors: An arbitrary structure of Tensors. 

415 """ 

416 

417 def _mark_checked(tensor): 

418 tensor._keras_history_checked = True # pylint: disable=protected-access 

419 

420 nest.map_structure(_mark_checked, tensors) 

421 

422 

423def call_context(): 

424 """Returns currently active `CallContext`.""" 

425 call_ctx = getattr(_call_context, 'call_context', None) 

426 if call_ctx is None: 

427 call_ctx = CallContext() 

428 _call_context.call_context = call_ctx 

429 return call_ctx 

430 
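# Illustrative sketch, not part of the original module: how a Layer's
# `__call__` might push state onto the call context while invoking `call`.
# The helper name is hypothetical; `layer` and `inputs` are placeholders.
def _example_call_context_usage(layer, inputs):
  ctx = call_context()
  with ctx.enter(layer=layer, inputs=inputs, build_graph=False, training=True):
    # Inside the block, the active state is queryable from anywhere;
    # `ctx.layer` and `ctx.frozen` are also available.
    return ctx.in_call, ctx.training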

431 

432class CallContext(object): 

433 """Keeps track of properties currently inside a Layer/Model's `call`. 

434 

435 Attributes: 

436 in_call: Whether currently inside the `call` of a Layer. 

437 layer: The `Layer` whose `call` is currently active. 

438 inputs: The inputs to the currently active `Layer`. 

439 build_graph: Whether currently inside a Graph or FuncGraph. 

440 training: Whether currently executing in training or inference mode. 

441 saving: Whether currently saving to SavedModel. 

442 frozen: Whether currently executing inside a `Layer` with `trainable` set to 

443 `False`. 

444 in_keras_graph: Whether executing inside the Keras Graph. 

445 """ 

446 

447 def __init__(self): 

448 # Handle `in_call` separately as it is the most-read attr and reading it is 

449 # on the hot path. 

450 self.in_call = False 

451 self._state = { 

452 'layer': None, 

453 'inputs': None, 

454 'build_graph': False, 

455 'training': None, 

456 'saving': None 

457 } 

458 # TODO(b/150169018): This logic can be replaced after the Functional API 

459 # refactor. 

460 self._in_keras_graph = False 

461 

462 def enter(self, layer, inputs, build_graph, training, saving=None): 

463 """Push a Layer and its inputs and state onto the current call context. 

464 

465 Args: 

466 layer: The `Layer` whose `call` is currently active. 

467 inputs: The inputs to the currently active `Layer`. 

468 build_graph: Whether currently inside a Graph or FuncGraph. 

469 training: Whether currently executing in training or inference mode. 

470 saving: Whether currently saving to SavedModel. 

471 

472 Returns: 

473 Context manager. 

474 """ 

475 state = { 

476 'layer': layer, 

477 'inputs': inputs, 

478 'build_graph': build_graph, 

479 'training': training, 

480 'saving': saving 

481 } 

482 return CallContextManager(self, state) 

483 

484 @property 

485 def layer(self): 

486 return self._state['layer'] 

487 

488 @property 

489 def inputs(self): 

490 return self._state['inputs'] 

491 

492 @property 

493 def build_graph(self): 

494 return self._state['build_graph'] 

495 

496 @property 

497 def training(self): 

498 return self._state['training'] 

499 

500 @property 

501 def saving(self): 

502 return self._state['saving'] 

503 

504 @property 

505 def frozen(self): 

506 layer = self._state['layer'] 

507 if not layer: 

508 return False 

509 return not layer.trainable 

510 

511 @property 

512 def in_keras_graph(self): 

513 # Returns True even if in a subgraph of the Keras graph, such as those 

514 # created by control flow ops. 

515 if context.executing_eagerly(): 

516 return False 

517 return (self._in_keras_graph or 

518 getattr(backend.get_graph(), 'name', None) == 'keras_graph') 

519 

520 

521class CallContextManager(object): 

522 """Context manager for `CallContext`.""" 

523 

524 def __init__(self, call_ctx, state): 

525 self._call_ctx = call_ctx 

526 self._state = state 

527 self._build_graph = state['build_graph'] 

528 

529 def __enter__(self): 

530 call_ctx = self._call_ctx 

531 self._prev_in_call = call_ctx.in_call 

532 self._prev_state = call_ctx._state 

533 

534 call_ctx.in_call = True 

535 call_ctx._state = self._state 

536 

537 # TODO(b/150169018): This logic can be removed after the Functional API 

538 # refactor. 

539 if self._build_graph: 

540 self._prev_in_keras_graph = call_ctx._in_keras_graph 

541 call_ctx._in_keras_graph = ( 

542 call_ctx._in_keras_graph or 

543 getattr(backend.get_graph(), 'name', None) == 'keras_graph') 

544 

545 def __exit__(self, *exc_info): 

546 call_ctx = self._call_ctx 

547 call_ctx.in_call = self._prev_in_call 

548 call_ctx._state = self._prev_state 

549 

550 if self._build_graph: 

551 call_ctx._in_keras_graph = self._prev_in_keras_graph 

552 

553 

554def training_arg_passed_to_call(argspec, args, kwargs): 

555 """Returns whether a user passed the `training` argument in `__call__`.""" 

556 # `argspec.args` starts with ['self', 'inputs'] 

557 full_args = dict(zip(argspec.args[2:], args)) 

558 full_args.update(kwargs) 

559 return 'training' in full_args and full_args['training'] is not None 

560 
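# Illustrative sketch, not part of the original module: checking whether the
# caller passed `training` explicitly to a layer's `__call__`. The helper
# name is hypothetical; `tf_inspect` is already imported above.
def _example_training_arg_check(layer, args, kwargs):
  argspec = tf_inspect.getfullargspec(layer.call)
  return training_arg_passed_to_call(argspec, args, kwargs)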

561 

562def is_subclassed(layer): 

563 """Returns True if the object is a subclassed layer or subclassed model.""" 

564 return (layer.__module__.find('keras.engine') == -1 and 

565 layer.__module__.find('keras.layers') == -1) 

566 

567 

568def from_saved_model(layer): 

569 """Returns whether the layer is loaded from a SavedModel.""" 

570 return layer.__module__.find('keras.saving.saved_model') != -1 

571 

572 

573def check_graph_consistency(tensor=None, method='add_loss', force_raise=False): 

574 """Checks that tensors passed to `add_*` method match the Keras graph. 

575 

576 When one of the `add_*` method is called inside a V2 conditional branch, 

577 the underlying tensor gets created in a FuncGraph managed by control_flow_v2. 

578 We need to raise clear error messages in such cases. 

579 

580 Args: 

581 tensor: Tensor to check, or `False` if it is known that an error 

582 should be raised. 

583 method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}. 

584 force_raise: If an error should be raised regardless of `tensor`. 

585 

586 Raises: 

587 RuntimeError: In case of an out-of-graph tensor. 

588 """ 

589 if (force_raise or 

590 (ops.executing_eagerly_outside_functions() and 

591 hasattr(tensor, 'graph') and tensor.graph.is_control_flow_graph)): 

592 if method == 'activity_regularizer': 

593 bad_example = """ 

594 class TestModel(tf.keras.Model): 

595 

596 def __init__(self): 

597 super(TestModel, self).__init__(name='test_model') 

598 self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2') 

599 

600 def call(self, x, training=None): 

601 if training: 

602 return self.dense(x) 

603 else: 

604 return self.dense(x) 

605 """ 

606 correct_example = """ 

607 class TestModel(tf.keras.Model): 

608 

609 def __init__(self): 

610 super(TestModel, self).__init__(name='test_model') 

611 self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2') 

612 

613 def call(self, x, training=None): 

614 return self.dense(x) 

615 """ 

616 raise RuntimeError( 

617 'You are using a layer with `activity_regularizer` in a control flow ' 

618 'branch, e.g.:\n{bad_example}\nThis is currently not supported. ' 

619 'Please move your call to the layer with `activity_regularizer` out ' 

620 'of the control flow branch, e.g.:\n{correct_example}\n' 

621 'You can also resolve this by marking your outer model/layer dynamic' 

622 ' (eager-only) by passing `dynamic=True` to the layer constructor. ' 

623 'Any kind of control flow is supported with dynamic layers. ' 

624 'Note that using `dynamic=True` requires you to implement static ' 

625 'shape inference in the `compute_output_shape(input_shape)` ' 

626 'method.'.format( 

627 bad_example=bad_example, correct_example=correct_example)) 

628 

629 if method == 'add_metric': 

630 bad_example = """ 

631 def call(self, inputs, training=None): 

632 if training: 

633 metric = compute_metric(inputs) 

634 self.add_metric(metric, name='my_metric', aggregation='mean') 

635 return inputs 

636 """ 

637 correct_example = """ 

638 def call(self, inputs, training=None): 

639 if training: 

640 metric = compute_metric(inputs) 

641 else: 

642 metric = 0. 

643 self.add_metric(metric, name='my_metric', aggregation='mean') 

644 return inputs 

645 """ 

646 elif method == 'add_loss': 

647 bad_example = """ 

648 def call(self, inputs, training=None): 

649 if training: 

650 loss = compute_loss(inputs) 

651 self.add_loss(loss) 

652 return inputs 

653 """ 

654 correct_example = """ 

655 def call(self, inputs, training=None): 

656 if training: 

657 loss = compute_loss(inputs) 

658 else: 

659 loss = 0. 

660 self.add_loss(loss) 

661 return inputs 

662 """ 

663 else: 

664 bad_example = """ 

665 def call(self, inputs, training=None): 

666 if training: 

667 self.add_update(self.w.assign_add(1)) 

668 return inputs 

669 """ 

670 correct_example = """ 

671 def call(self, inputs, training=None): 

672 if training: 

673 increment = 1 

674 else: 

675 increment = 0 

676 self.add_update(self.w.assign_add(increment)) 

677 return inputs 

678 """ 

679 raise RuntimeError( 

680 'You are using the method `{method}` in a control flow branch ' 

681 'in your layer, e.g.:\n{bad_example}\n' 

682 'This is not currently supported. ' 

683 'Please move your call to {method} out of the control flow branch, ' 

684 'e.g.:\n{correct_example}\n' 

685 'You can also resolve this by marking your layer ' 

686 'as dynamic (eager-only) by passing ' 

687 '`dynamic=True` to the layer constructor. ' 

688 'Any kind of control flow is supported with dynamic layers. ' 

689 'Note that using `dynamic=True` requires you ' 

690 'to implement static shape inference ' 

691 'in the `compute_output_shape(input_shape)` method.'.format( 

692 method=method, 

693 bad_example=bad_example, 

694 correct_example=correct_example)) 

695 

696 

697def mark_as_return(outputs, acd): 

698 """Marks `outputs` as the return values for automatic control deps.""" 

699 

700 def _mark_as_return(tensor): 

701 """Marks `tensor` as the return value for automatic control deps.""" 

702 if not tensor_util.is_tf_type(tensor): 

703 return tensor 

704 

705 # pylint: disable=protected-access 

706 return_tensor = acd.mark_as_return(tensor) 

707 if getattr(tensor, '_keras_mask', None) is not None: 

708 return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask) 

709 else: 

710 return_tensor._keras_mask = None 

711 

712 # Handle TensorFlow Probability attached metadata. 

713 # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`. 

714 if getattr(tensor, '_tfp_distribution', None) is not None: 

715 return_tensor._tfp_distribution = tensor._tfp_distribution 

716 

717 return return_tensor 

718 # pylint: enable=protected-access 

719 

720 return nest.map_structure(_mark_as_return, outputs) 

721 

722 

723V2_DTYPE_BEHAVIOR = None 

724 

725 

726@keras_export(v1=['keras.layers.enable_v2_dtype_behavior']) 

727def enable_v2_dtype_behavior(): 

728 """Enable the V2 dtype behavior for Keras layers. 

729 

730 By default, the V2 dtype behavior is enabled in TensorFlow 2, so this function 

731 is only useful if `tf.compat.v1.disable_v2_behavior` has been called. Since 

732 mixed precision requires V2 dtype behavior to be enabled, this function allows 

733 you to use mixed precision in Keras layers if `disable_v2_behavior` has been 

734 called. 

735 

736 When enabled, the dtype of Keras layers defaults to floatx (which is typically 

737 float32) instead of None. In addition, layers will automatically cast 

738 floating-point inputs to the layer's dtype. 

739 

740 >>> x = tf.ones((4, 4, 4, 4), dtype='float64') 

741 >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2) 

742 >>> print(layer.dtype) # float32 since V2 dtype behavior is enabled 

743 float32 

744 >>> y = layer(x) # Layer casts inputs since V2 dtype behavior is enabled 

745 >>> print(y.dtype.name) 

746 float32 

747 

748 A layer author can opt-out their layer from the automatic input casting by 

749 passing `autocast=False` to the base Layer's constructor. This disables the 

750 autocasting part of the V2 behavior for that layer, but not the defaulting to 

751 floatx part of the V2 behavior. 

752 

753 When a global `tf.keras.mixed_precision.Policy` is set, a Keras layer's dtype 

754 will default to the global policy instead of floatx. Layers will automatically 

755 cast inputs to the policy's compute_dtype. 

756 """ 

757 global V2_DTYPE_BEHAVIOR 

758 V2_DTYPE_BEHAVIOR = True 

759 

760 

761@keras_export(v1=['keras.layers.disable_v2_dtype_behavior']) 

762def disable_v2_dtype_behavior(): 

763 """Disables the V2 dtype behavior for Keras layers. 

764 

765 See `tf.compat.v1.keras.layers.enable_v2_dtype_behavior`. 

766 """ 

767 global V2_DTYPE_BEHAVIOR 

768 V2_DTYPE_BEHAVIOR = False 

769 

770 

771def v2_dtype_behavior_enabled(): 

772 """Returns True if the V2 dtype behavior is enabled.""" 

773 if V2_DTYPE_BEHAVIOR is None: 

774 return tf2.enabled() 

775 return V2_DTYPE_BEHAVIOR 

776 

777 

778class TrackableWeightHandler(object): 

779 """Keras wrapper for handling tracking.Trackable object saving and restoring. 

780 

781 This class handles Trackables in both V1 and V2 modes, ensuring that they can 

782 be saved and restored with the correct data and without adding additional ops 

783 on every save. 

784 

785 Attributes: 

786 trackable: The trackable to wrap. 

787 num_tensors: The number of tensors that this trackable requires for saving. 

788 """ 

789 

790 def __init__(self, trackable): 

791 if not isinstance(trackable, tracking.Trackable): 

792 raise ValueError('%s is not a Trackable object.' % (trackable,)) 

793 self._trackable = trackable 

794 self._distribute_strategy = distribute_lib.get_strategy() 

795 

796 saveables = saveable_object_util.saveable_objects_from_trackable( 

797 trackable).values() 

798 # 'Saveables' won't exist when we're passed a legacy TF1 table like 

799 # a StaticHashTable. 

800 if not saveables: 

801 self._num_tensors = 0 

802 self._setter = lambda weights: None 

803 self._getter = lambda: [] 

804 

805 elif len(saveables) == 1: 

806 saveable = list(saveables)[0] 

807 

808 if ops.executing_eagerly_outside_functions(): 

809 # If we're in eager mode, we need to defer calling the Trackable's 

810 # saveable() callable until data export time. 

811 # However, it is safe to call the saveable as many times as we want, so 

812 # we will call it now to figure out how many tensors this Trackable will 

813 # produce. 

814 self._saveable = saveable 

815 self._num_tensors = len(self._saveable().specs) 

816 self._setter = lambda weights: self._saveable().restore(weights, None) 

817 self._getter = lambda: [spec.tensor for spec in self._saveable().specs] 

818 else: 

819 # If we're in Graph mode, we need to evaluate the Saveable only once and 

820 # cache the resulting restore graph. Failing to do this will result in 

821 # new assignment ops being added to the graph each time set_weights() is 

822 # called. 

823 self._placeholder_tensors = [] 

824 self._saveable = saveable() 

825 self._num_tensors = len(self._saveable.specs) 

826 for spec in self._saveable.specs: 

827 tensor = spec.tensor 

828 self._placeholder_tensors.append( 

829 array_ops.placeholder(tensor.dtype, tensor.shape)) 

830 self._assign_op = self._saveable.restore(self._placeholder_tensors, 

831 None) 

832 self._setter = self._set_weights_v1 

833 self._getter = lambda: [spec.tensor for spec in self._saveable.specs] 

834 else: 

835 raise ValueError('Only Trackables with one Saveable are supported. ' 

836 'The Trackable %s has %d Saveables.' % 

837 (trackable, len(saveables))) 

838 

839 @property 

840 def num_tensors(self): 

841 return self._num_tensors 

842 

843 def set_weights(self, weights): 

844 if len(weights) != self._num_tensors: 

845 raise ValueError( 

846 ('Weight handler for trackable %s received the wrong number of ' + 

847 'weights: expected %s, got %s.') % 

848 (self._trackable, self._num_tensors, len(weights))) 

849 self._setter(weights) 

850 

851 def get_tensors(self): 

852 return self._getter() 

853 

854 def _set_weights_v1(self, weights): 

855 feed_dict = {} 

856 for idx, tensor in enumerate(weights): 

857 feed_dict[self._placeholder_tensors[idx]] = tensor 

858 backend.get_session().run(self._assign_op, feed_dict) 

859 
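# Illustrative sketch, not part of the original module: round-tripping a
# Trackable's weights through the handler. The helper name is hypothetical;
# `trackable` is assumed to be a tracking.Trackable exposing at most one
# Saveable (e.g. a lookup table).
def _example_weight_handler_roundtrip(trackable):
  handler = TrackableWeightHandler(trackable)
  tensors = handler.get_tensors()            # one tensor per saveable spec
  assert len(tensors) == handler.num_tensors
  handler.set_weights(tensors)               # restores the same values
  return handler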

860 

861class StaticTableHandler(TrackableWeightHandler): 

862 """Wrapper for handling weight collection for static hash tables.""" 

863 

864 def __init__(self, getter_lambda): # pylint: disable=super-init-not-called 

865 self._num_tensors = 2 

866 self._getter = getter_lambda 

867 self._distribute_strategy = distribute_lib.get_strategy() 

868 

869 def raise_error(_): 

870 raise RuntimeError('This layer contains a static lookup table, which ' 

871 'cannot be changed via set_weights().') 

872 

873 self._setter = raise_error 

874 

875 

876def no_ragged_support(inputs, layer_name): 

877 input_list = nest.flatten(inputs) 

878 if any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list): 

879 raise ValueError('Layer %s does not support RaggedTensors as input. ' 

880 'Inputs received: %s. You can try converting your ' 

881 'input to a uniform tensor.' % (layer_name, inputs)) 

882 

883 

884def is_split_variable(v): 

885 """Returns True if `v` is either a PartionedVariable or a ShardedVariable.""" 

886 return hasattr(v, '_variable_list') or hasattr(v, '_variables') 

887 

888 

889def has_weights(obj): 

890 obj_type = type(obj) 

891 return (hasattr(obj_type, 'trainable_weights') and 

892 hasattr(obj_type, 'non_trainable_weights') and 

893 not isinstance(obj, type)) 

894 

895 

896# TODO(kathywu): This is a temporary hack. When a network of layers is revived 

897# from SavedModel, only the top-level layer will have losses. This causes issues 

898# in eager mode because the child layers may have graph losses 

899# (thus model.losses returns a mix of Eager and graph tensors). To fix this, 

900# whenever eager losses are added to one layer, add eager losses to all 

901# child layers. This causes `.losses` to only return eager losses. 

902REVIVED_LOSS_PLACEHOLDER = ( 

903 'This layer\'s losses have been added to the parent layer.')