Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/base_layer_utils.py: 24% of 299 statements

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains private utilities used mainly by the base Layer class."""

import functools
import threading

import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.dtensor import dtensor_api as dtensor
from keras.src.utils import control_flow_util
from keras.src.utils import tf_inspect
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export

_call_context = threading.local()


def create_mean_metric(value, name=None):
    # Importing keras pulls in base_layer and then this module, and metrics
    # rely on base_layer, which would result in a cyclic dependency.
    from keras.src import metrics as metrics_module

    metric_obj = metrics_module.Mean(name=name, dtype=value.dtype)
    return metric_obj, metric_obj(value)


def infer_init_val_and_dtype(initializer, dtype, shape, layout=None):
    if initializer is not None and not callable(initializer):
        init_val = initializer
        variable_dtype = None
    else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
            initializer = initializer()
        if layout:
            init_val = functools.partial(
                initializer, shape, dtype=dtype, layout=layout
            )
        else:
            init_val = functools.partial(initializer, shape, dtype=dtype)
        variable_dtype = dtype.base_dtype
    return init_val, variable_dtype

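# A minimal usage sketch for `infer_init_val_and_dtype` (values below are
# illustrative, not taken from the original module): passing an initializer
# class rather than an instance, plus a concrete shape and dtype.
#
#   init_val, var_dtype = infer_init_val_and_dtype(
#       tf.keras.initializers.GlorotUniform, tf.float32, (3, 4)
#   )
#   value = init_val()  # a (3, 4) float32 tensor; var_dtype is tf.float32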

def make_variable(
    name,
    shape=None,
    dtype=tf.float32,
    initializer=None,
    trainable=None,
    caching_device=None,
    validate_shape=True,
    constraint=None,
    use_resource=None,
    collections=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.VariableAggregation.NONE,
    partitioner=None,
    layout=None,
    experimental_enable_variable_lifting=True,
):
    """Util to create a variable (relies on `variable_scope.variable`).

    Some reuse-related technicalities prevent us from using
    `variable_scope.get_variable()` directly, so we use a subcomponent
    that has fewer constraints (`variable_scope.variable()`).

    In the longer term, it seems like a similar "default variable creator"
    method should exist in `Trackable` instead. When this happens, we can get
    rid of this temporary solution.

    TODO(fchollet): remove this method when no longer needed.

    Args:
        name: Variable name.
        shape: Variable shape.
        dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
        initializer: Initializer instance (callable).
        trainable: Whether the variable should be part of the layer's
            "trainable_variables" (e.g. variables, biases)
            or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
            Note that if the current variable scope is marked as non-trainable,
            this parameter is ignored and any added variables are also
            marked as non-trainable. `trainable` becomes `True` unless
            `synchronization` is set to `ON_READ`. Defaults to `None`.
        caching_device: Passed to `tf.Variable`.
        validate_shape: Passed to `tf.Variable`.
        constraint: Constraint instance (callable).
        use_resource: Whether to use a `ResourceVariable`.
        collections: List of graph collections keys. The new variable is added
            to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
        synchronization: Indicates when a distributed variable will be
            aggregated. Accepted values are constants defined in the class
            `tf.VariableSynchronization`. By default the synchronization is set
            to `AUTO` and the current `DistributionStrategy` chooses
            when to synchronize. If `synchronization` is set to `ON_READ`,
            `trainable` must not be set to `True`.
        aggregation: Indicates how a distributed variable will be aggregated.
            Accepted values are constants defined in the class
            `tf.VariableAggregation`.
        partitioner: Not handled at this time.
        layout: Optional DTensor layout, used for creating a DVariable.

    Returns:
        Variable instance.
    """
    init_val, variable_dtype = infer_init_val_and_dtype(
        initializer, dtype, shape, layout
    )
    variable_shape = tf.TensorShape(shape)

    if use_resource is None:
        use_resource = True

    if layout is None:
        # In theory, if `use_resource` is True and `collections` is empty
        # (that is to say, in TF2), we can use tf.Variable.
        # However, this breaks legacy (Estimator) checkpoints because
        # it changes variable names. Remove this when V1 is fully deprecated.
        return tf1.Variable(
            initial_value=init_val,
            name=name,
            trainable=trainable,
            caching_device=caching_device,
            dtype=variable_dtype,
            validate_shape=validate_shape,
            constraint=constraint,
            use_resource=use_resource,
            collections=collections,
            synchronization=synchronization,
            aggregation=aggregation,
            shape=variable_shape if variable_shape else None,
            experimental_enable_variable_lifting=experimental_enable_variable_lifting,  # noqa: E501
        )
    else:
        return dtensor.DVariable(
            initial_value=init_val,
            name=name,
            trainable=trainable,
            caching_device=caching_device,
            dtype=variable_dtype,
            validate_shape=validate_shape,
            constraint=constraint,
            collections=collections,
            synchronization=synchronization,
            aggregation=aggregation,
            shape=variable_shape if variable_shape else None,
        )

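# A minimal sketch of how `make_variable` might be called; the name, shape and
# initializer below are illustrative assumptions, not from the original module:
#
#   kernel = make_variable(
#       "kernel",
#       shape=(16, 4),
#       dtype=tf.float32,
#       initializer=tf.keras.initializers.GlorotUniform(),
#       trainable=True,
#   )
#   # Returns a resource variable of shape (16, 4); if a DTensor `layout` were
#   # passed, a `dtensor.DVariable` would be created instead.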

def collect_previous_mask(input_tensors):
    """Retrieves the output mask(s) of the previous node.

    Args:
        input_tensors: An arbitrary structure of Tensors.

    Returns:
        A mask tensor or list of mask tensors.
    """

    def _collect_previous_mask(x):
        return getattr(x, "_keras_mask", None)

    return tf.nest.map_structure(_collect_previous_mask, input_tensors)

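# Illustrative sketch (hypothetical tensors `a` and `b`): the helper mirrors
# the input structure, reading `_keras_mask` off each tensor:
#
#   masks = collect_previous_mask({"a": a, "b": b})
#   # -> {"a": <mask of a, or None>, "b": <mask of b, or None>}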

def have_all_keras_metadata(tensors):
    return all(hasattr(x, "_keras_history") for x in tf.nest.flatten(tensors))


def generate_placeholders_from_shape(shape):
    return tf1.placeholder(shape=shape, dtype=backend.floatx())


def create_keras_history(tensors):
    """Wraps TensorFlow Operations for compatibility with the Functional API.

    This method checks to see if a Tensor in `tensors` is missing Keras
    metadata and has its origin in a Keras `Input` Layer. If so, this method
    will replace the raw TensorFlow Operations that created this tensor with
    `TensorFlowOpLayer` instances that create identical operations.

    Any Tensors not originating from a Keras `Input` Layer will be treated as
    constants when constructing `TensorFlowOpLayer` instances.

    Args:
        tensors: A structure of Tensors, some of which come from raw TensorFlow
            operations and need to have Keras metadata assigned to them.

    Returns:
        created_layers: List. The `TensorFlowOpLayer` instances created to wrap
            the raw TensorFlow operations.
    """
    _, created_layers = _create_keras_history_helper(tensors, set(), [])
    return created_layers


# Unsafe internal attribute.
# If True, Keras will not evaluate the constant-foldable inputs to tf op
# layers in TF1 graphs. This *might* speed up model construction time in
# certain settings, but it means the models will not be
# serializable/deserializable via get_config (only via SavedModels). It may
# also change the semantics of whether generated random numbers are generated
# once and reused, or recomputed each time.
# Note: This path triggers for TPUEstimators / xla compiled graphs regardless
# of this setting.
_UNSAFE_GRAPH_OP_LAYER_CREATION = False


def _create_keras_history_helper(tensors, processed_ops, created_layers):
    """Helper method for `create_keras_history`.

    Args:
        tensors: A structure of Tensors for which to create Keras metadata.
        processed_ops: Set. TensorFlow operations that have already been
            wrapped in `TensorFlowOpLayer` instances.
        created_layers: List. The `TensorFlowOpLayer` instances created.

    Returns:
        Tuple. First element is the updated set of TensorFlow Operations that
        have been wrapped in `TensorFlowOpLayer` instances. Second element is
        a list of the `TensorFlowOpLayer` instances created.
    """
    if tf1.executing_eagerly_outside_functions():
        raise ValueError(
            "`create_keras_history` should only be called if eager is disabled!"
        )
    # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
    # Cannot be imported at top because of circular dependencies.
    # TODO(omalleyt): Resolve circular dependency.
    from keras.src.engine import base_layer

    tensor_list = tf.nest.flatten(tensors)
    sparse_ops = []
    ragged_tensors = []
    for tensor in tensor_list:
        if getattr(tensor, "_keras_history", None) is not None:
            continue
        if isinstance(tensor, (tf.SparseTensor, tf1.SparseTensorValue)):
            sparse_ops.append(tensor.op)
            continue
        if tf_utils.is_ragged(tensor):
            # Ragged tensors don't have an op property.
            ragged_tensors.append(tensor)
            continue
        op = tensor.op  # The Op that created this Tensor.
        if op not in processed_ops:
            # Recursively set `_keras_history`.
            op_inputs = list(op.inputs)
            constants = {}
            layer_inputs = []
            for i, op_input in enumerate(op_inputs):
                if uses_keras_history(op_input):
                    layer_inputs.append(op_input)
                else:
                    # Treat any value not originating from a `keras.Input` as
                    # a constant. Variables cannot be supported.
                    ds_with_session = (
                        tf.distribute.in_cross_replica_context()
                        and not tf1.executing_eagerly_outside_functions()
                    )
                    using_xla = control_flow_util.GraphOrParentsInXlaContext(
                        tf1.get_default_graph()
                    )
                    if (
                        ds_with_session
                        or using_xla
                        or _UNSAFE_GRAPH_OP_LAYER_CREATION
                    ):
                        # In legacy graph mode, evaluating here causes the
                        # Session to be configured improperly. The downside of
                        # this is that saving via `get_config` breaks, but
                        # SavedModel still works.
                        constants[i] = op_input
                    else:
                        with tf.init_scope():
                            constants[i] = backend.function([], op_input)([])
            layer_inputs = unnest_if_single_tensor(layer_inputs)
            processed_ops, created_layers = _create_keras_history_helper(
                layer_inputs, processed_ops, created_layers
            )
            name = op.name
            node_def = op.node_def.SerializeToString()
            op_layer = base_layer.TensorFlowOpLayer(
                node_def, constants=constants, name=name
            )
            created_layers.append(op_layer)
            op_layer._set_connectivity_metadata(
                args=(layer_inputs,), kwargs={}, outputs=op.outputs
            )
            processed_ops.update([op])
    if sparse_ops or ragged_tensors:
        lambda_example = """
        weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
        output = tf.keras.layers.Lambda(weights_mult)(input)
        """
        raise ValueError(
            "TensorFlow ops that generate ragged or sparse tensor "
            "outputs are currently not supported by Keras automatic "
            "op wrapping. Please wrap these ops in a Lambda layer: "
            "\n\n```\n{example}\n```\n"
            "Sparse ops encountered: {sparse_ops}\n"
            "Ragged tensors encountered: {ragged_tensors}\n".format(
                example=lambda_example,
                sparse_ops=str(sparse_ops),
                ragged_tensors=str(ragged_tensors),
            )
        )
    return processed_ops, created_layers


def unnest_if_single_tensor(input_tensors):
    # Preserve compatibility with older configs.
    flat_input_tensors = tf.nest.flatten(input_tensors)
    # If this is a single element but not a dict, unwrap. If this is a dict,
    # assume the first layer expects a dict (as is the case with a
    # DenseFeatures layer); pass it through.
    if not isinstance(input_tensors, dict) and len(flat_input_tensors) == 1:
        input_tensors = flat_input_tensors[0]
    return input_tensors

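# Illustrative sketch of the unwrapping rule (hypothetical tensors `t1`, `t2`):
#
#   unnest_if_single_tensor([t1])             # -> t1 (single element unwrapped)
#   unnest_if_single_tensor([t1, t2])         # -> [t1, t2] (unchanged)
#   unnest_if_single_tensor({"feature": t1})  # -> {"feature": t1} (dicts pass through)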

def needs_keras_history(tensors, ignore_call_context=False):
    """Check if any Tensors need to be wrapped in TensorFlowOpLayers.

    This will never return True inside a sublayer, because sublayers
    do not need to create Keras History. Otherwise, this returns True
    if one or more of `tensors` originates from a `keras.Input` and
    does not have `_keras_history` set.

    Args:
        tensors: An arbitrary nested structure of Tensors.
        ignore_call_context: Whether to skip the check that we are currently
            outside of a `call` context. This is `True` when creating
            KerasHistory inside `Node`, where we always know that Tensors
            are being used with the Functional API.

    Returns:
        Bool, whether at least one Tensor needs to be wrapped.
    """
    input_tensors = tf.nest.flatten(tensors)
    if call_context().in_call and not ignore_call_context:
        return False
    if all(
        getattr(tensor, "_keras_history", None) is not None
        for tensor in input_tensors
    ):
        # KerasHistory already set.
        return False
    return uses_keras_history(tensors)


def is_in_keras_graph():
    """Returns whether currently executing inside of a Keras graph."""
    return call_context().in_keras_graph


def is_in_eager_or_tf_function():
    """Returns whether in eager mode or inside of a tf.function."""
    return tf.executing_eagerly() or is_in_tf_function()


def is_in_tf_function():
    """Returns whether inside of a tf.function."""
    # Check if running in V1 graph mode.
    if not tf1.executing_eagerly_outside_functions():
        return False
    if not tf.inside_function():
        return False
    # Check if inside Keras FuncGraph.
    if is_in_keras_graph():
        return False
    # Check for a v1 `wrap_function` FuncGraph.
    graph = tf1.get_default_graph()
    if getattr(graph, "name", False) and graph.name.startswith(
        "wrapped_function"
    ):
        return False
    return True


def uses_keras_history(tensors):
    """Check if at least one Tensor originates from a `keras.Input`.

    This is `True` if at least one Tensor has its origin in a `keras.Input`.
    Any Tensor that originates from a `keras.Input` will have a dependency
    Tensor with a `_keras_history` attribute attached. Tensors that have
    already been checked to not originate from a `keras.Input`
    are marked as `_keras_history_checked`.

    Args:
        tensors: An arbitrary nested structure of Tensors.

    Returns:
        Bool, whether at least one Tensor originates from a `keras.Input`.
    """
    checked_tensors = set()
    tensors_to_check = tf.nest.flatten(tensors)

    while tensors_to_check:
        new_tensors_to_check = []
        for tensor in tensors_to_check:
            if id(tensor) in checked_tensors:
                continue

            checked_tensors.add(id(tensor))

            if getattr(tensor, "_keras_history_checked", None) is not None:
                continue
            if getattr(tensor, "_keras_history", None) is not None:
                return True

            try:
                new_tensors_to_check.extend(tensor.op.inputs)
            except AttributeError:
                # In case `tensor` is a Variable created in an Eager context.
                pass

        tensors_to_check = new_tensors_to_check

    # Mark that these Tensors have been checked once for `_keras_history`,
    # and should not be checked again for performance reasons.
    mark_checked(tensors)
    return False

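# Illustrative sketch, assuming a disabled-eager (graph) setting: a tensor
# produced from a `keras.Input` carries `_keras_history` somewhere in its op
# ancestry, so the walk over `tensor.op.inputs` eventually finds it:
#
#   inp = tf.keras.Input(shape=(4,))
#   hidden = tf.math.square(inp)          # raw TF op, no _keras_history yet
#   uses_keras_history(hidden)            # -> True (reached the Input's history)
#   uses_keras_history(tf.constant(1.0))  # -> False; the constant is then
#                                         #    marked `_keras_history_checked`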

def mark_checked(tensors):
    """Marks that these Tensors should not be tracked.

    This prevents Layers from attempting to create TensorFlowOpLayers
    for these Tensors.

    Args:
        tensors: An arbitrary structure of Tensors.
    """

    def _mark_checked(tensor):
        tensor._keras_history_checked = True

    tf.nest.map_structure(_mark_checked, tensors)


def call_context():
    """Returns the currently active `CallContext`."""
    call_ctx = getattr(_call_context, "call_context", None)
    if call_ctx is None:
        call_ctx = CallContext()
        _call_context.call_context = call_ctx
    return call_ctx


# Inject the call_context function into keras_deps to remove TFLite's
# dependency on Keras.
tf.__internal__.register_call_context_function(call_context)


class CallContext:
    """Keeps track of properties currently inside a Layer/Model's `call`.

    Attributes:
        in_call: Whether currently inside the `call` of a Layer.
        layer: The `Layer` whose `call` is currently active.
        inputs: The inputs to the currently active `Layer`.
        build_graph: Whether currently inside a Graph or FuncGraph.
        training: Whether currently executing in training or inference mode.
        saving: Whether currently saving to SavedModel.
        frozen: Whether currently executing inside a `Layer` with `trainable`
            set to `False`.
        in_keras_graph: Whether executing inside the Keras Graph.
    """

    def __init__(self):
        # Handle `in_call` separately as it is the most-read attr and reading
        # it is on the hot path.
        self.in_call = False
        self._state = {
            "layer": None,
            "inputs": None,
            "build_graph": False,
            "training": None,
            "saving": None,
        }
        # TODO(b/150169018): This logic can be replaced after the Functional
        # API refactor.
        self._in_keras_graph = False

    def enter(self, layer, inputs, build_graph, training, saving=None):
        """Push a Layer and its inputs and state onto the current call context.

        Args:
            layer: The `Layer` whose `call` is currently active.
            inputs: The inputs to the currently active `Layer`.
            build_graph: Whether currently inside a Graph or FuncGraph.
            training: Whether currently executing in training or inference
                mode.
            saving: Whether currently saving to SavedModel.

        Returns:
            Context manager.
        """
        state = {
            "layer": layer,
            "inputs": inputs,
            "build_graph": build_graph,
            "training": training,
            "saving": saving,
        }
        return CallContextManager(self, state)

    @property
    def layer(self):
        return self._state["layer"]

    @property
    def inputs(self):
        return self._state["inputs"]

    @property
    def build_graph(self):
        return self._state["build_graph"]

    @property
    def training(self):
        return self._state["training"]

    @property
    def saving(self):
        return self._state["saving"]

    @property
    def frozen(self):
        layer = self._state["layer"]
        if not layer:
            return False
        return not layer.trainable

    @property
    def in_keras_graph(self):
        # Returns True even if in a subgraph of the Keras graph, such as those
        # created by control flow ops.
        if tf.executing_eagerly():
            return False
        return (
            self._in_keras_graph
            or getattr(backend.get_graph(), "name", None) == "keras_graph"
        )


class CallContextManager:
    """Context manager for `CallContext`."""

    def __init__(self, call_ctx, state):
        self._call_ctx = call_ctx
        self._state = state
        self._build_graph = state["build_graph"]

    def __enter__(self):
        call_ctx = self._call_ctx
        self._prev_in_call = call_ctx.in_call
        self._prev_state = call_ctx._state

        call_ctx.in_call = True
        call_ctx._state = self._state

        # TODO(b/150169018): This logic can be removed after the Functional API
        # refactor.
        if self._build_graph:
            self._prev_in_keras_graph = call_ctx._in_keras_graph
            call_ctx._in_keras_graph = (
                call_ctx._in_keras_graph
                or getattr(backend.get_graph(), "name", None) == "keras_graph"
            )

    def __exit__(self, *exc_info):
        call_ctx = self._call_ctx
        call_ctx.in_call = self._prev_in_call
        call_ctx._state = self._prev_state

        if self._build_graph:
            call_ctx._in_keras_graph = self._prev_in_keras_graph

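# A minimal sketch of how the base Layer is expected to drive the call context;
# `my_layer` and `x` are illustrative placeholders:
#
#   ctx = call_context()
#   with ctx.enter(layer=my_layer, inputs=x, build_graph=False, training=True):
#       # Inside the block, ctx.in_call is True and ctx.layer is my_layer;
#       # nested layer calls observe the same thread-local state.
#       outputs = my_layer.call(x)
#   # On exit, the previous context state is restored.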

def training_arg_passed_to_call(argspec, args, kwargs):
    """Returns whether a user passed the `training` argument in `__call__`."""
    # `argspec.args` starts with ['self', 'inputs'].
    full_args = dict(zip(argspec.args[2:], args))
    full_args.update(kwargs)
    return "training" in full_args and full_args["training"] is not None

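# Illustrative sketch, assuming a layer whose `call` signature is
# `call(self, inputs, training=None)`:
#
#   argspec = tf_inspect.getfullargspec(my_layer.call)
#   training_arg_passed_to_call(argspec, args=(), kwargs={"training": True})
#   # -> True; passing training=None, or not passing it at all, yields False.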

def is_subclassed(layer):
    """Returns True if the object is a subclassed layer or subclassed model."""
    return (
        layer.__module__.find("keras.engine") == -1
        and layer.__module__.find("keras.layers") == -1
    )


def from_saved_model(layer):
    """Returns whether the layer is loaded from a SavedModel."""
    return layer.__module__.find("keras.saving.legacy.saved_model") != -1

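# Illustrative sketch: both helpers key off the layer's `__module__` string.
#
#   class MyLayer(tf.keras.layers.Layer):  # defined in user code ("__main__")
#       pass
#
#   is_subclassed(MyLayer())  # -> True: "__main__" contains neither
#                             #    "keras.engine" nor "keras.layers"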

def check_graph_consistency(tensor=None, method="add_loss", force_raise=False):
    """Checks that tensors passed to an `add_*` method match the Keras graph.

    When one of the `add_*` methods is called inside a V2 conditional branch,
    the underlying tensor gets created in a FuncGraph managed by
    control_flow_v2. We need to raise clear error messages in such cases.

    Args:
        tensor: Tensor to check, or `False` if it is known that an error
            should be raised.
        method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
        force_raise: If an error should be raised regardless of `tensor`.

    Raises:
        RuntimeError: In case of an out-of-graph tensor.
    """
    if force_raise or (
        tf1.executing_eagerly_outside_functions()
        and hasattr(tensor, "graph")
        and tensor.graph.is_control_flow_graph
    ):
        if method == "activity_regularizer":
            bad_example = """
            class TestModel(tf.keras.Model):

                def __init__(self):
                    super(TestModel, self).__init__(name='test_model')
                    self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

                def call(self, x, training=None):
                    if training:
                        return self.dense(x)
                    else:
                        return self.dense(x)
            """
            correct_example = """
            class TestModel(tf.keras.Model):

                def __init__(self):
                    super(TestModel, self).__init__(name='test_model')
                    self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

                def call(self, x, training=None):
                    return self.dense(x)
            """
            raise RuntimeError(
                "You are using a layer with `activity_regularizer` in a "
                f"control flow branch, e.g.:\n{bad_example}\nThis is currently "
                "not supported. Please move your call to the layer with "
                "`activity_regularizer` out of the control flow branch, "
                f"e.g.:\n{correct_example}\nYou can also resolve this by "
                "marking your outer model/layer dynamic (eager-only) by "
                "passing `dynamic=True` to the layer constructor. Any kind of "
                "control flow is supported with dynamic layers. Note that "
                "using `dynamic=True` requires you to implement static shape "
                "inference in the `compute_output_shape(input_shape)` "
                "method."
            )

        if method == "add_metric":
            bad_example = """
            def call(self, inputs, training=None):
                if training:
                    metric = compute_metric(inputs)
                    self.add_metric(metric, name='my_metric', aggregation='mean')
                return inputs
            """
            correct_example = """
            def call(self, inputs, training=None):
                if training:
                    metric = compute_metric(inputs)
                else:
                    metric = 0.
                self.add_metric(metric, name='my_metric', aggregation='mean')
                return inputs
            """
        elif method == "add_loss":
            bad_example = """
            def call(self, inputs, training=None):
                if training:
                    loss = compute_loss(inputs)
                    self.add_loss(loss)
                return inputs
            """
            correct_example = """
            def call(self, inputs, training=None):
                if training:
                    loss = compute_loss(inputs)
                else:
                    loss = 0.
                self.add_loss(loss)
                return inputs
            """
        else:
            bad_example = """
            def call(self, inputs, training=None):
                if training:
                    self.add_update(self.w.assign_add(1))
                return inputs
            """
            correct_example = """
            def call(self, inputs, training=None):
                if training:
                    increment = 1
                else:
                    increment = 0
                self.add_update(self.w.assign_add(increment))
                return inputs
            """
        raise RuntimeError(
            "You are using the method `{method}` in a control flow branch "
            "in your layer, e.g.:\n{bad_example}\n"
            "This is not currently supported. "
            "Please move your call to {method} out of the control flow branch, "
            "e.g.:\n{correct_example}\n"
            "You can also resolve this by marking your layer "
            "as dynamic (eager-only) by passing "
            "`dynamic=True` to the layer constructor. "
            "Any kind of control flow is supported with dynamic layers. "
            "Note that using `dynamic=True` requires you "
            "to implement static shape inference "
            "in the `compute_output_shape(input_shape)` method.".format(
                method=method,
                bad_example=bad_example,
                correct_example=correct_example,
            )
        )


def mark_as_return(outputs, acd):
    """Marks `outputs` as the return values for automatic control deps."""

    def _mark_as_return(tensor):
        """Marks `tensor` as the return value for automatic control deps."""
        if not tf.is_tensor(tensor):
            return tensor

        return_tensor = acd.mark_as_return(tensor)
        if getattr(tensor, "_keras_mask", None) is not None:
            return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask)
        else:
            return_tensor._keras_mask = None

        # Handle TensorFlow Probability attached metadata.
        # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
        if getattr(tensor, "_tfp_distribution", None) is not None:
            return_tensor._tfp_distribution = tensor._tfp_distribution

        return return_tensor

    return tf.nest.map_structure(_mark_as_return, outputs)


V2_DTYPE_BEHAVIOR = None


@keras_export(v1=["keras.layers.enable_v2_dtype_behavior"])
def enable_v2_dtype_behavior():
    """Enable the V2 dtype behavior for Keras layers.

    By default, the V2 dtype behavior is enabled in TensorFlow 2, so this
    function is only useful if `tf.compat.v1.disable_v2_behavior` has been
    called. Since mixed precision requires V2 dtype behavior to be enabled,
    this function allows you to use mixed precision in Keras layers if
    `disable_v2_behavior` has been called.

    When enabled, the dtype of Keras layers defaults to floatx (which is
    typically float32) instead of None. In addition, layers will automatically
    cast floating-point inputs to the layer's dtype.

    >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
    >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
    >>> print(layer.dtype)  # float32 since V2 dtype behavior is enabled
    float32
    >>> y = layer(x)  # Layer casts inputs since V2 dtype behavior is enabled
    >>> print(y.dtype.name)
    float32

    A layer author can opt their layer out of the automatic input casting by
    passing `autocast=False` to the base Layer's constructor. This disables the
    autocasting part of the V2 behavior for that layer, but not the defaulting
    to floatx part of the V2 behavior.

    When a global `tf.keras.mixed_precision.Policy` is set, a Keras layer's
    dtype will default to the global policy instead of floatx. Layers will
    automatically cast inputs to the policy's compute_dtype.
    """
    global V2_DTYPE_BEHAVIOR
    V2_DTYPE_BEHAVIOR = True


@keras_export(v1=["keras.layers.disable_v2_dtype_behavior"])
def disable_v2_dtype_behavior():
    """Disables the V2 dtype behavior for Keras layers.

    See `tf.compat.v1.keras.layers.enable_v2_dtype_behavior`.
    """
    global V2_DTYPE_BEHAVIOR
    V2_DTYPE_BEHAVIOR = False


def v2_dtype_behavior_enabled():
    """Returns True if the V2 dtype behavior is enabled."""
    if V2_DTYPE_BEHAVIOR is None:
        return tf.__internal__.tf2.enabled()
    return V2_DTYPE_BEHAVIOR


class TrackableWeightHandler:
    """Keras wrapper for handling Trackable object saving and restoring.

    This class handles Trackables in both V1 and V2 modes, ensuring that they
    can be saved and restored with the correct data and without adding
    additional ops on every save.

    Attributes:
        trackable: The trackable to wrap.
        num_tensors: The number of tensors that this trackable requires for
            saving.
    """

    def __init__(self, trackable):
        if not isinstance(trackable, tf.__internal__.tracking.Trackable):
            raise ValueError(f"{trackable} is not a Trackable object.")
        self._trackable = trackable
        self._distribute_strategy = tf.distribute.get_strategy()

        saveables = tf.__internal__.tracking.saveable_objects_from_trackable(
            trackable
        ).values()
        # 'Saveables' won't exist when we're passed a legacy TF1 table like
        # a StaticHashTable.
        if not saveables:
            self._num_tensors = 0
            self._setter = lambda weights: None
            self._getter = lambda: []

        elif len(saveables) == 1:
            saveable = list(saveables)[0]

            if tf1.executing_eagerly_outside_functions():
                # If we're in eager mode, we need to defer calling the
                # Trackable's saveable() callable until data export time.
                # However, it is safe to call the saveable as many times as we
                # want, so we will call it now to figure out how many tensors
                # this Trackable will produce.
                self._saveable = saveable
                self._num_tensors = len(self._saveable().specs)
                self._setter = lambda weights: self._saveable().restore(
                    weights, None
                )
                self._getter = lambda: [
                    spec.tensor for spec in self._saveable().specs
                ]
            else:
                # If we're in Graph mode, we need to evaluate the Saveable only
                # once and cache the resulting restore graph. Failing to do
                # this will result in new assignment ops being added to the
                # graph each time set_weights() is called.
                self._placeholder_tensors = []
                self._saveable = saveable()
                self._num_tensors = len(self._saveable.specs)
                for spec in self._saveable.specs:
                    tensor = spec.tensor
                    self._placeholder_tensors.append(
                        tf1.placeholder(tensor.dtype, tensor.shape)
                    )
                self._assign_op = self._saveable.restore(
                    self._placeholder_tensors, None
                )
                self._setter = self._set_weights_v1
                self._getter = lambda: [
                    spec.tensor for spec in self._saveable.specs
                ]
        else:
            raise ValueError(
                "Only Trackables with one Saveable are supported. "
                f"The Trackable {trackable} has {len(saveables)} Saveables."
            )

    @property
    def num_tensors(self):
        return self._num_tensors

    def set_weights(self, weights):
        if len(weights) != self._num_tensors:
            raise ValueError(
                f"Weight handler for trackable {self._trackable} received "
                "an incorrect number of weights: "
                f"expected {self._num_tensors} weights, "
                f"got {len(weights)} weights."
            )
        self._setter(weights)

    def get_tensors(self):
        return self._getter()

    def _set_weights_v1(self, weights):
        feed_dict = {}
        for idx, tensor in enumerate(weights):
            feed_dict[self._placeholder_tensors[idx]] = tensor
        backend.get_session().run(self._assign_op, feed_dict)

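# A minimal usage sketch, assuming `my_trackable` is a Trackable exposing
# exactly one Saveable (the name is hypothetical):
#
#   handler = TrackableWeightHandler(my_trackable)
#   tensors = handler.get_tensors()  # list of handler.num_tensors tensors
#   handler.set_weights(tensors)     # restores the same values
#   # Passing a list of a different length raises ValueError.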

def no_ragged_support(inputs, layer_name):
    input_list = tf.nest.flatten(inputs)
    if any(isinstance(x, tf.RaggedTensor) for x in input_list):
        raise ValueError(
            f"Layer {layer_name} does not support RaggedTensors as input. "
            f"Inputs received: {inputs}. You can try converting your "
            "input to a dense (uniform) tensor."
        )


def is_split_variable(v):
    """Returns True if `v` is a PartitionedVariable or a ShardedVariable."""
    return not {clz.__name__ for clz in v.__class__.__mro__}.isdisjoint(
        {"PartitionedVariable", "ShardedVariable"}
    )

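# Illustrative sketch: the check only compares class *names* in the MRO, so it
# works without importing the actual partitioned/sharded variable classes:
#
#   class ShardedVariable:  # hypothetical stand-in with the matching name
#       pass
#
#   is_split_variable(ShardedVariable())  # -> True
#   is_split_variable(tf.Variable(1.0))   # -> False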

def has_weights(obj):
    obj_type = type(obj)
    return (
        hasattr(obj_type, "trainable_weights")
        and hasattr(obj_type, "non_trainable_weights")
        and not isinstance(obj, type)
    )

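# Illustrative sketch: `has_weights` is True for layer/model *instances* whose
# type defines both weight properties, and False for the classes themselves:
#
#   has_weights(tf.keras.layers.Dense(2))  # -> True (an instance)
#   has_weights(tf.keras.layers.Dense)     # -> False (a type, not an instance)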

# TODO(kathywu): This is a temporary hack. When a network of layers is revived
# from SavedModel, only the top-level layer will have losses. This causes
# issues in eager mode because the child layers may have graph losses (thus
# `model.losses` returns a mix of eager and graph tensors). To fix this,
# whenever eager losses are added to one layer, add eager losses to all
# child layers. This causes `.losses` to only return eager losses.
REVIVED_LOSS_PLACEHOLDER = (
    "This layer's losses have been added to the parent layer."
)