Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saved_model/save_impl.py: 23%

294 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras SavedModel serialization. 

16 

17TODO (kathywu): Move to layer_serialization.py. Some model-specific logic should 

18go to model_serialization.py. 

19""" 

20 

21import functools 

22import threading 

23import weakref 

24 

25import tensorflow.compat.v1.logging as logging 

26import tensorflow.compat.v2 as tf 

27 

28from keras.src import backend 

29from keras.src.engine import base_layer_utils 

30from keras.src.engine import input_spec 

31from keras.src.mixed_precision import autocast_variable 

32from keras.src.saving.legacy import saving_utils 

33from keras.src.saving.legacy.saved_model import constants 

34from keras.src.saving.legacy.saved_model import load as keras_load 

35from keras.src.saving.legacy.saved_model import serialized_attributes 

36from keras.src.saving.legacy.saved_model import utils 

37from keras.src.utils import layer_utils 

38from keras.src.utils import tf_contextlib 

39from keras.src.utils import tf_utils 

40from keras.src.utils import version_utils 

41from keras.src.utils.generic_utils import LazyLoader 

42 

43# To avoid circular dependencies between keras/engine and keras/saving, 

44# code in keras/saving must delay imports. 

45 

46# TODO(b/134426265): Switch back to single-quotes to match the rest of the file 

47# once the issue with copybara is fixed. 

48 

# Modules resolved through LazyLoader rather than imported eagerly, so that
# keras/saving does not create an import-time cycle with keras/engine.
base_layer = LazyLoader(
    "base_layer",
    globals(),
    "keras.src.engine.base_layer",
)
metrics = LazyLoader(
    "metrics",
    globals(),
    "keras.src.metrics",
)
input_layer = LazyLoader(
    "input_layer",
    globals(),
    "keras.src.engine.input_layer",
)
training_lib = LazyLoader(
    "training_lib",
    globals(),
    "keras.src.engine.training",
)
sequential_lib = LazyLoader(
    "sequential_lib",
    globals(),
    "keras.src.engine.sequential",
)

56 

57 

def should_skip_serialization(layer):
    """Decides whether serializing extra objects/functions must be skipped.

    The extra objects and functions can only be serialized when the layer
    inputs are known: either the layer is built, or it is a `Model` whose
    `_saved_model_inputs_spec` has been recorded.

    Args:
      layer: Keras Layer object.

    Returns:
      True (after logging a warning) if neither condition holds, else False.
    """
    is_model = isinstance(layer, training_lib.Model)
    has_saved_inputs_spec = (
        is_model and layer._saved_model_inputs_spec is not None
    )
    if layer.built or has_saved_inputs_spec:
        return False
    logging.warning(
        "Skipping full serialization of Keras layer {}, because "
        "it is not built.".format(layer)
    )
    return True

72 

73 

def _filter_shards(variables):
    """Drops the variables that are shards of a sharded container.

    A variable counts as a shard when it has a `_sharded_container`
    attribute.

    Args:
      variables: Iterable of variables.

    Returns:
      A new list with only the variables lacking `_sharded_container`, in
      their original order.
    """

    def _is_not_shard(candidate):
        return not hasattr(candidate, "_sharded_container")

    return list(filter(_is_not_shard, variables))

76 

77 

def wrap_layer_objects(layer, serialization_cache):
    """Collects the extra trackable objects attached to a serialized layer.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization. Wrapped loss functions are memoized under its
        "keras_losses" key, so each callable loss is wrapped only once.

    Returns:
      A dictionary containing all checkpointable objects from a
      SerializedAttributes object. See LayerAttributes and ModelAttributes
      for the complete list of objects.
    """
    wrap = tf.__internal__.tracking.wrap

    # Regularization losses of this layer, followed by those of every
    # sublayer.
    all_losses = layer._callable_losses[:]
    for child_layer in utils.list_all_layers(layer):
        all_losses.extend(child_layer._callable_losses)

    # Turn every loss into a tf.function, reusing wrappers already present in
    # the shared cache. The cache size at insertion time names the function.
    keras_loss_cache = serialization_cache.setdefault("keras_losses", {})
    wrapped_loss_functions = []
    for loss_fn in all_losses:
        if loss_fn not in keras_loss_cache:
            keras_loss_cache[loss_fn] = _wrap_unconditional_loss(
                loss_fn, len(keras_loss_cache)
            )
        wrapped_loss_functions.append(keras_loss_cache[loss_fn])

    # Wrapped versions of only this layer's own callable losses.
    wrapped_layer_losses = [
        keras_loss_cache[own_loss] for own_loss in layer._callable_losses[:]
    ]

    layer_metrics = wrap({metric.name: metric for metric in layer._metrics})

    # Avoid duplicate creation of shard Variables on loading.
    # `layer.variables` will return the shard Variables rather than the
    # ShardedVariables (b/224541446), but Keras loading will create new
    # ShardedVariables (and thus shard Variables) from Keras metadata if
    # needed. There's no need to also save the shard Variables here, so
    # filter them out.
    variables = _filter_shards(layer.variables)
    trainable_variables = _filter_shards(layer.trainable_variables)
    non_trainable_variables = _filter_shards(layer.non_trainable_variables)

    return {
        "variables": wrap(variables),
        "trainable_variables": wrap(trainable_variables),
        "non_trainable_variables": wrap(non_trainable_variables),
        "layers": wrap(utils.list_all_layers(layer)),
        "metrics": wrap(layer.metrics),
        "regularization_losses": wrap(wrapped_loss_functions),
        "layer_regularization_losses": wrap(wrapped_layer_losses),
        "layer_metrics": layer_metrics,
    }

142 

143 

def wrap_layer_functions(layer, serialization_cache):
    """Returns dict of wrapped layer call function and losses in tf.functions.

    While the functions are being traced, the call functions of child layers
    are swapped for their wrapped tf.functions, and the losses of `layer` and
    its sublayers are cleared. Both are restored before returning, including
    when tracing raises, so a failed save does not leave the live model
    modified.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      A dictionary containing all keras tf.functions to serialize. See
      LayerAttributes and ModelAttributes for the list of all attributes.
    """
    # Since Sequential models may be modified in place using model.add() or
    # model.pop(), don't use saved functions.
    if isinstance(layer, keras_load.RevivedLayer) and not isinstance(
        layer, sequential_lib.Sequential
    ):
        return {
            fn_name: getattr(layer.keras_api, fn_name, None)
            for fn_name in serialized_attributes.LayerAttributes.all_functions
        }

    # Reset the losses of the layer and its children. The call function in
    # each child layer is replaced with tf.functions.
    original_fns = _replace_child_layer_functions(layer, serialization_cache)
    original_losses = _reset_layer_losses(layer)
    try:
        # Wrap all the layer call and activity regularizer functions.

        # Use LayerCallCollection to ensure that all layer call functions
        # (__call__, call with losses) are traced with the same inputs.
        call_collection = LayerCallCollection(layer)
        call_fn_with_losses = call_collection.add_function(
            _wrap_call_and_conditional_losses(layer),
            f"{layer.name}_layer_call_and_return_conditional_losses",
            # If any of this layer's child layers use the training arg, the
            # traced call functions of this layer will have a training
            # keyword argument. If the original layer does not expect the
            # training arg, then it will have to be removed (by setting
            # `match_layer_training_arg`).
            match_layer_training_arg=True,
        )
        call_fn = call_collection.add_function(
            _extract_outputs_from_fn(layer, call_fn_with_losses),
            f"{layer.name}_layer_call_fn",
            # Since `call_fn` wraps call_fn_with_losses and not the original
            # call function, `match_layer_training_arg` should be False.
            match_layer_training_arg=False,
        )

        fns = {
            "call_and_return_conditional_losses": call_fn_with_losses,
            "__call__": call_fn,
        }

        if layer._activity_regularizer is not None:
            fns["activity_regularizer_fn"] = _wrap_activity_regularizer(layer)
            fns[
                "call_and_return_all_conditional_losses"
            ] = call_collection.add_function(
                _append_activity_regularizer_loss(
                    layer, call_fn_with_losses, fns["activity_regularizer_fn"]
                ),
                f"{layer.name}_layer_call_and_return_all_conditional_losses",
                match_layer_training_arg=False,
            )
        else:
            fns["activity_regularizer_fn"] = None
            fns["call_and_return_all_conditional_losses"] = call_fn_with_losses

        # Manually trigger traces before restoring the overwritten functions.
        # The functions are traced within the layer call context to ensure
        # that layer functions (e.g. add_loss) behave as though running in
        # graph mode.
        with tracing_scope():
            call_collection.trace_with_input_signature()
            with base_layer_utils.call_context().enter(
                layer,
                inputs=None,
                build_graph=True,
                training=None,
                saving=True,
            ):
                for fn in fns.values():
                    if fn is not None and not isinstance(fn, LayerCall):
                        fn.get_concrete_function()
    finally:
        # Restore overwritten functions and losses, on success and failure
        # alike; otherwise an exception while tracing would leave the child
        # layers calling the wrapped tf.functions and the losses cleared.
        _restore_child_layer_functions(original_fns)
        _restore_layer_losses(original_losses)

    return fns

230 

231 

def default_save_signature(layer):
    """Traces `layer` with `saving_utils.trace_model_call`.

    The losses of `layer` and all of its sublayers are cleared for the
    duration of the trace. They are restored afterwards, also when tracing
    raises, so a failure does not leave the layer without its losses.

    Args:
      layer: Keras layer to trace (presumably a Model; TODO confirm against
        callers).

    Returns:
      Whatever `saving_utils.trace_model_call(layer)` returns.
    """
    original_losses = _reset_layer_losses(layer)
    try:
        return saving_utils.trace_model_call(layer)
    finally:
        _restore_layer_losses(original_losses)

237 

238 

def _replace_child_layer_functions(layer, serialization_cache):
    """Replaces functions in the children layers with wrapped tf.functions.

    This step allows functions from parent layers to reference the wrapped
    functions from their children layers instead of retracing the ops.

    Only the replaced attributes are recorded and swapped; losses are not
    touched here (`_reset_layer_losses` handles those). Use
    `_restore_child_layer_functions` to restore the original attributes.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization. Indexed with `constants.KERAS_CACHE_KEY` for every
        non-input child layer.

    Returns:
      Dictionary mapping child layer objects -> original attributes:
        { Child layer 1: {
            'call': Original call function,
            '_activity_regularizer': Original activity regularizer},
          Child metric 2: {
            '__call__': Original __call__,
            'result': Original result function,
            'update_state': Original update_state function},
          ...
        }
      Input layers and children without serialized functions are skipped
      and do not appear in the dictionary.
    """

    original_fns = {}

    def replace_layer_functions(child_layer, serialized_fns):
        """Replaces layer call and activity regularizer with wrapped
        functions."""
        original_fns[child_layer] = {
            "call": child_layer.call,
            "_activity_regularizer": child_layer._activity_regularizer,
        }
        with utils.no_automatic_dependency_tracking_scope(child_layer):
            try:
                child_layer._activity_regularizer = serialized_fns.get(
                    "activity_regularizer_fn"
                )
            except AttributeError:
                # Some layers have an unsettable activity regularizer.
                pass
            child_layer.call = utils.use_wrapped_call(
                child_layer,
                serialized_fns["call_and_return_conditional_losses"],
                child_layer._call_spec,
                default_training_value=False,
            )

    def replace_metric_functions(child_layer, serialized_fns):
        """Replaces metric functions with wrapped functions."""
        original_fns[child_layer] = {
            "__call__": child_layer.__call__,
            "result": child_layer.result,
            "update_state": child_layer.update_state,
        }
        with utils.no_automatic_dependency_tracking_scope(child_layer):
            # NOTE(review): implicit `metric(...)` calls resolve `__call__`
            # on the type, so this instance attribute only affects explicit
            # `metric.__call__(...)` accesses. Confirm this is intended.
            child_layer.__call__ = serialized_fns["__call__"]
            child_layer.result = serialized_fns["result"]
            child_layer.update_state = serialized_fns["update_state"]

    for child_layer in utils.list_all_layers(layer):
        if isinstance(child_layer, input_layer.InputLayer):
            continue

        keras_cache = serialization_cache[constants.KERAS_CACHE_KEY]
        if child_layer not in keras_cache:
            saver = child_layer._trackable_saved_model_saver
            serialized_functions = saver._get_serialized_attributes(
                serialization_cache
            ).functions
        else:
            serialized_functions = keras_cache[child_layer].functions
        if not serialized_functions:
            # This indicates either:
            #   - circular dependency, which means the current layer's
            #     functions should be wrapped first.
            #   - Child layer's inputs are not defined, so its functions have
            #     not been wrapped. In this case, no replacement is necessary
            #     so move on to the next child.
            continue

        if isinstance(child_layer, metrics.Metric):
            replace_metric_functions(child_layer, serialized_functions)
        else:
            replace_layer_functions(child_layer, serialized_functions)

    return original_fns

327 

328 

def _restore_child_layer_functions(original_fns):
    """Puts back the attributes swapped by `_replace_child_layer_functions`.

    Args:
      original_fns: Mapping from child layer to a dict of
        attribute name -> original value, as returned by
        `_replace_child_layer_functions`.
    """
    for child_layer, saved_attrs in original_fns.items():
        with utils.no_automatic_dependency_tracking_scope(child_layer):
            for attr_name, original_value in saved_attrs.items():
                try:
                    setattr(child_layer, attr_name, original_value)
                except AttributeError:
                    # Some layers refuse assignment (e.g. of
                    # `_activity_regularizer`); leave those untouched.
                    pass

340 

341 

def _reset_layer_losses(parent_layer):
    """Clears the losses of `parent_layer` and every sublayer.

    Args:
      parent_layer: Keras layer whose losses (and its sublayers') are reset.

    Returns:
      Dict mapping each layer to shallow copies of its original `_losses`
      and `_eager_losses` lists (keys "losses" and "eager_losses"), for use
      with `_restore_layer_losses`.
    """
    saved = {}
    for sublayer in utils.list_all_layers_and_sublayers(parent_layer):
        saved[sublayer] = dict(
            losses=sublayer._losses[:],
            eager_losses=sublayer._eager_losses[:],
        )
        with utils.no_automatic_dependency_tracking_scope(sublayer):
            sublayer._losses = []
            sublayer._eager_losses = []
    return saved

354 

355 

def _restore_layer_losses(losses_dict):
    """Restores the losses saved by `_reset_layer_losses`.

    Args:
      losses_dict: Mapping from layer to a dict with "losses" and
        "eager_losses" entries, as returned by `_reset_layer_losses`.
    """
    for layer, saved in losses_dict.items():
        with utils.no_automatic_dependency_tracking_scope(layer):
            layer._losses = saved["losses"]
            layer._eager_losses = saved["eager_losses"]

361 

362 

class LayerTracingContext(threading.local):
    """Thread-local state for the layer call tracing machinery.

    Attributes:
      enable_call_tracing: Whether calls should queue extra traces; read by
        `tracing_enabled` and toggled by `tracing_scope`.
      trace_queue: Pending `(fn, args, kwargs, training)` trace entries.
    """

    def __init__(self):
        super().__init__()
        self.trace_queue = []
        self.enable_call_tracing = False


# Shared module-level instance; as a `threading.local`, each thread sees its
# own attribute values.
_thread_local_data = LayerTracingContext()

371 

372 

@tf_contextlib.contextmanager
def tracing_scope():
    """Enables tracing scope.

    Inside the scope, `LayerCall` objects queue traces of the other functions
    in their `LayerCallCollection` via `add_trace_to_queue`. On exit the
    queue is drained with tracing still enabled, since queued traces may
    enqueue further ones. The caller's thread-local state (flag and queue) is
    then restored. Restoration happens even if the body or one of the queued
    traces raises, so tracing cannot remain enabled for the thread after a
    failure.
    """
    previous_value = _thread_local_data.enable_call_tracing
    previous_queue = _thread_local_data.trace_queue
    try:
        _thread_local_data.enable_call_tracing = True
        _thread_local_data.trace_queue = []
        yield
    finally:
        try:
            # Run traces from the queue until it is empty.
            while _thread_local_data.trace_queue:
                fn, args, kwargs, training = (
                    _thread_local_data.trace_queue.pop(0)
                )
                if training is not None:
                    with backend.deprecated_internal_learning_phase_scope(
                        training
                    ):
                        fn.get_concrete_function(*args, **kwargs)
                else:
                    fn.get_concrete_function(*args, **kwargs)
        finally:
            # Always restore the outer state, even if a queued trace failed.
            _thread_local_data.trace_queue = previous_queue
            _thread_local_data.enable_call_tracing = previous_value

395 

396 

def add_trace_to_queue(fn, args, kwargs, training=None):
    """Queues a trace of `fn` if call tracing is enabled for this thread.

    Shallow copies of `args` and `kwargs` are stored, so later changes the
    caller makes to its own containers do not affect the queued entry.

    Args:
      fn: Object exposing `get_concrete_function`.
      args: Sliceable sequence of positional arguments.
      kwargs: Dict of keyword arguments.
      training: Optional training value. When not None, `tracing_scope`
        runs the trace under a learning-phase scope with this value.
    """
    if not tracing_enabled():
        return
    entry = (fn, args[:], kwargs.copy(), training)
    _thread_local_data.trace_queue.append(entry)

402 

403 

def tracing_enabled():
    """Whether to add extra traces to the queue.

    Returns:
      The current thread's `enable_call_tracing` flag, which `tracing_scope`
      turns on for its duration.
    """
    state = _thread_local_data
    return state.enable_call_tracing

407 

408 

class LayerCallCollection:
    """Groups wrapped layer call functions.

    This is used to ensure that all layer call functions are traced with the
    same inputs:
      - __call__
      - call_and_return_conditional_losses
      - call_and_return_all_conditional_losses
    """

    def __init__(self, layer):
        """Initializes the collection for `layer`.

        Args:
          layer: Keras Layer object whose call functions are grouped.
        """
        self.layer = layer

        self.layer_call_method = _get_layer_call_method(layer)
        # Appears to be True when this layer or one of its child layers uses
        # a `training` argument (see the call-spec comment below).
        self._expects_training_arg = utils.layer_uses_training_bool(layer)
        self._call_spec = layer._call_spec

        # Create new call spec if the layer itself does not accept a training
        # arg, but one of its child layers does. When this layer's call
        # functions are traced, they will be traced with an added `training`
        # keyword argument.
        if not self.layer._expects_training_arg and self._expects_training_arg:
            arg_spec = utils.set_training_arg_spec(
                self._call_spec.full_argspec, False
            )
            self._call_spec = layer_utils.CallFunctionSpec(arg_spec)

        self._layer_inputs = self._get_layer_inputs(layer)
        # name -> wrapped tf.function. Values are held weakly, so an entry
        # disappears once nothing else references its function.
        self._functions = weakref.WeakValueDictionary()

        # Get the input argument name from the args.
        if self._call_spec.arg_names:
            self._input_arg_name = self._call_spec.arg_names[0]
        else:
            # Layer could be defined with only varargs, in which case use a
            # default name.
            self._input_arg_name = "inputs"

    def _get_layer_inputs(self, layer):
        """Inspects layer object and returns the inferred input signature.

        Args:
          layer: Layer object.

        Returns:
          An `(args, kwargs)` pair describing the call inputs with possibly
          nested TensorSpecs. For Models the result of
          `saving_utils.model_call_inputs` is returned as is. Input-spec
          entries whose shape is fully unknown are `None`. `(None, None)` is
          returned when no signature can be inferred.
        """
        if (
            isinstance(layer.call, tf.__internal__.function.Function)
            and layer.call.input_signature is not None
        ):
            return layer.call.input_signature, {}
        elif isinstance(layer, training_lib.Model):
            return saving_utils.model_call_inputs(layer)
        elif (
            layer.input_spec is not None
            and layer._use_input_spec_as_call_signature
        ):

            def to_tensor_spec_or_none(x):
                spec = input_spec.to_tensor_spec(x, layer._compute_dtype)
                # If the shape is too general (e.g. multiple dimensions are
                # allowed), return None so that separate functions can be
                # generated for each inferred input signature.
                # TODO(b/134962016): currently partial signatures are not
                # supported.
                if spec.shape == tf.TensorShape(None):
                    # A single None leaf (previously a stray `(None, None)`
                    # tuple), keeping the structure parallel to input_spec.
                    return None
                return spec

            input_signature = [
                tf.nest.map_structure(to_tensor_spec_or_none, layer.input_spec)
            ]

            return input_signature, {}
        else:
            return None, None

    def add_trace(self, *args, **kwargs):
        """Traces all functions with the same args and kwargs.

        When the collection expects a `training` argument, every function is
        queued twice: once with `training=True` and once with
        `training=False`.

        Args:
          *args: Positional args passed to the original function.
          **kwargs: Keyword args passed to the original function.
        """
        args = list(args)
        # Copy so `set_arg_value` below never touches the caller's dict.
        kwargs = kwargs.copy()

        for fn in self._functions.values():
            # TODO(kathywu): Replace arguments with broader shapes defined in
            # the input signature.
            if self._expects_training_arg:

                def trace_with_training(value, fn=fn):
                    nonlocal args, kwargs
                    args, kwargs = self._call_spec.set_arg_value(
                        "training", value, args, kwargs, inputs_in_args=True
                    )
                    add_trace_to_queue(fn, args, kwargs, value)

                trace_with_training(True)
                trace_with_training(False)
            else:
                add_trace_to_queue(fn, args, kwargs)

    def training_arg_was_passed(self, args, kwargs):
        """Returns whether `training` appears in `args`/`kwargs`."""
        return self._call_spec.arg_was_passed(
            "training", args, kwargs, inputs_in_args=True
        )

    def get_training_arg_value(self, args, kwargs):
        """Returns the passed `training` value, or None if it is absent."""
        try:
            return self._call_spec.get_arg_value(
                "training", args, kwargs, inputs_in_args=True
            )
        except KeyError:  # Training is not in args or kwargs.
            return None

    def get_input_arg_value(self, args, kwargs):
        """Returns the value passed for the layer's input argument."""
        return self._call_spec.get_arg_value(
            self._input_arg_name, args, kwargs, inputs_in_args=True
        )

    def _maybe_wrap_with_training_arg(self, call_fn, match_layer_training_arg):
        """Wraps call function with added training argument if necessary.

        Args:
          call_fn: Python call function.
          match_layer_training_arg: If True, the `training` value is removed
            before `call_fn` is invoked.

        Returns:
          `call_fn` itself, or a decorator around it that advertises this
          collection's call spec (including `training`).
        """
        if not self.layer._expects_training_arg and self._expects_training_arg:
            # Add training arg to wrapper function.
            def wrap_with_training_arg(*args, **kwargs):
                if match_layer_training_arg:
                    # Remove the training value, since the original call_fn
                    # does not expect a training arg. Instead, the training
                    # value will be propagated using the call context created
                    # in LayerCall.
                    args = list(args)
                    kwargs = kwargs.copy()
                    args, kwargs = self._call_spec.set_arg_value(
                        "training",
                        None,
                        args,
                        kwargs,
                        inputs_in_args=True,
                        pop_kwarg_if_none=True,
                    )
                return call_fn(*args, **kwargs)

            return tf.__internal__.decorator.make_decorator(
                target=call_fn,
                decorator_func=wrap_with_training_arg,
                decorator_argspec=self._call_spec.full_argspec,
            )

        return call_fn

    def add_function(self, call_fn, name, match_layer_training_arg):
        """Adds a layer call function to the collection.

        Args:
          call_fn: a python function
          name: Name of call function
          match_layer_training_arg: If True, removes the `training` from the
            function arguments when calling `call_fn`.

        Returns:
          LayerCall (tf.function)
        """
        fn = LayerCall(
            self,
            self._maybe_wrap_with_training_arg(
                call_fn, match_layer_training_arg
            ),
            name,
        )
        self._functions[name] = fn.wrapped_call
        return fn

    def trace_with_input_signature(self):
        """Trace with the layer/models inferred input signature if possible."""
        if self._layer_inputs[0] is None:
            return

        args, kwargs = self._layer_inputs
        if self._expects_training_arg:
            args, kwargs = self._call_spec.set_arg_value(
                "training", False, args, kwargs, inputs_in_args=True
            )
        if None not in tf.nest.flatten([args, kwargs]):
            # Manually add traces for layers that have keyword arguments and
            # have a fully defined input signature.
            self.add_trace(*args, **kwargs)

599 

600 

def _filtered_inputs(inputs):
    """Returns the tensors and variables contained in nested `inputs`.

    Args:
      inputs: Arbitrarily nested structure.

    Returns:
      List of flattened leaves accepted by `tf_utils.is_tensor_or_variable`,
      in `tf.nest.flatten` order.
    """
    return [
        leaf
        for leaf in tf.nest.flatten(inputs)
        if tf_utils.is_tensor_or_variable(leaf)
    ]

603 

604 

def layer_call_wrapper(call_collection, method, name):
    """Ensures layer losses are kept the same, and runs method in call
    context.

    Args:
      call_collection: `LayerCallCollection` that owns the layer.
      method: Python function to run.
      name: Name assigned to the returned function's `__name__`.

    Returns:
      A decorated function that runs `method` inside the layer call context
      with autocasting enabled. It clears the layer's losses for the duration
      of the call and restores them afterwards, including when `method`
      raises.
    """

    # Create wrapper that deals with losses and call context.
    def wrapper(*args, **kwargs):
        """Calls method within call context."""
        layer = call_collection.layer
        training = None
        inputs = _filtered_inputs([args, kwargs])

        if (args or kwargs) and call_collection.training_arg_was_passed(
            args, kwargs
        ):
            training = call_collection.get_training_arg_value(args, kwargs)

        original_losses = _reset_layer_losses(layer)
        try:
            with base_layer_utils.call_context().enter(
                layer,
                inputs=inputs,
                build_graph=False,
                training=training,
                saving=True,
            ):
                with autocast_variable.enable_auto_cast_variables(
                    layer._compute_dtype_object
                ):
                    return method(*args, **kwargs)
        finally:
            # Restore even if `method` raised; otherwise the layer would be
            # left with its losses cleared.
            _restore_layer_losses(original_losses)

    # Rename to `name`, since tf.function doesn't have a name argument.
    # Without this, all functions returned by this method will be named
    # "call", which would be a nightmare to debug.
    fn = tf.__internal__.decorator.make_decorator(
        target=method, decorator_func=wrapper
    )
    fn.__name__ = name
    return fn

644 

645 

class LayerCall:
    """Function that triggers traces of other functions in the same
    collection."""

    def __init__(self, call_collection, call_fn, name):
        """Initializes a LayerCall object.

        Args:
          call_collection: a LayerCallCollection, which contains the other
            layer call functions (e.g. call_with_conditional_losses, call).
            These functions should be traced with the same arguments.
          call_fn: A call function.
          name: Name of the call function.
        """
        self.call_collection = call_collection
        python_fn = layer_call_wrapper(call_collection, call_fn, name)
        self.wrapped_call = tf.function(python_fn)

    def _maybe_trace(self, args, kwargs):
        """Queues traces of the sibling functions while tracing is enabled.

        This includes the extra training-arg traces.
        """
        if not tracing_enabled():
            return
        self.call_collection.add_trace(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        self._maybe_trace(args, kwargs)
        return self.wrapped_call(*args, **kwargs)

    def get_concrete_function(self, *args, **kwargs):
        self._maybe_trace(args, kwargs)
        concrete_getter = self.wrapped_call.get_concrete_function
        return concrete_getter(*args, **kwargs)

677 

678 

def _wrap_call_and_conditional_losses(layer):
    """Wraps call function that returns a tuple of (outputs, losses).

    The losses returned are conditional on the inputs passed to the call
    function. Unconditional losses (e.g. weight regularization) are wrapped
    separately.

    Args:
      layer: a Keras layer object

    Returns:
      python call function that returns outputs and conditional losses --
      excludes activity regularizer
    """
    layer_call = _get_layer_call_method(layer)

    def call_and_return_conditional_losses(*args, **kwargs):
        """Returns layer (call_output, conditional losses) tuple."""
        call_output = layer_call(*args, **kwargs)
        if not version_utils.is_v1_layer_or_model(layer):
            # V2 layers: every loss not tagged `_unconditional_loss`.
            conditional_losses = [
                loss
                for loss in layer.losses
                if not hasattr(loss, "_unconditional_loss")
            ]
        else:
            # V1 layers: losses associated with these inputs.
            conditional_losses = layer.get_losses_for(
                _filtered_inputs([args, kwargs])
            )
        return call_output, conditional_losses

    return _create_call_fn_decorator(layer, call_and_return_conditional_losses)

710 

711 

def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
    """Returns a function that returns only call function outputs.

    Revived layers reuse their saved `keras_api.__call__`. For other layers,
    the result is a decorated function that drops the losses element from
    `call_and_return_conditional_losses`.
    """
    if isinstance(layer, keras_load.RevivedLayer):
        return layer.keras_api.__call__

    def call(inputs, *args, **kwargs):
        outputs_and_losses = call_and_return_conditional_losses(
            inputs, *args, **kwargs
        )
        return outputs_and_losses[0]

    return _create_call_fn_decorator(layer, call)

721 

722 

def _append_activity_regularizer_loss(
    layer, call_fn_with_losses, activity_regularizer_fn
):
    """Appends activity regularizer loss to losses returned by the wrapped
    fn.

    Args:
      layer: Keras layer whose call signature the result mimics.
      call_fn_with_losses: Callable returning `(outputs, losses)`, where
        `losses` supports `append`.
      activity_regularizer_fn: Callable applied to the outputs.

    Returns:
      A decorated function returning `(outputs, losses)`. Its `losses` end
      with `activity_regularizer_fn(outputs)`.
    """

    def fn(inputs, *args, **kwargs):
        call_outputs, call_losses = call_fn_with_losses(
            inputs, *args, **kwargs
        )
        regularization_loss = activity_regularizer_fn(call_outputs)
        # The list returned by `call_fn_with_losses` is extended in place.
        call_losses.append(regularization_loss)
        return call_outputs, call_losses

    return _create_call_fn_decorator(layer, fn)

735 

736 

def _create_call_fn_decorator(layer, wrapped_call):
    """Decorates `wrapped_call` so it presents the layer call's signature.

    `utils.maybe_add_training_arg` receives the layer's call spec and its
    `_expects_training_arg` flag (with a False default training value). It
    returns the function and argspec handed to `make_decorator`; presumably
    it adds a `training` parameter only when the layer expects one. TODO
    confirm.

    Args:
      layer: Keras layer whose call method is the decoration target.
      wrapped_call: Python function implementing the call.

    Returns:
      The decorated function produced by `make_decorator`.
    """
    decorator_func, decorator_argspec = utils.maybe_add_training_arg(
        layer._call_spec,
        wrapped_call,
        layer._expects_training_arg,
        default_training_value=False,
    )
    target = _get_layer_call_method(layer)
    return tf.__internal__.decorator.make_decorator(
        target=target,
        decorator_func=decorator_func,
        decorator_argspec=decorator_argspec,
    )

748 

749 

def _wrap_unconditional_loss(loss_fn, index):
    """Wraps callable/unconditional loss, returning a serializable function.

    Args:
      loss_fn: Loss callable taking no arguments, or a `functools.partial`
        whose first positional argument is that callable.
      index: Integer used in the generated name `loss_fn_{index}`.

    Returns:
      The loss as a `tf.__internal__.function.Function` with an empty input
      signature; an existing Function is returned unchanged.
    """
    # Unwrap the original loss function from a partial.
    if isinstance(loss_fn, functools.partial):
        fn = loss_fn.args[0]
    else:
        fn = loss_fn
    if not isinstance(fn, tf.__internal__.function.Function):
        fn = tf.__internal__.function.Function(
            fn, f"loss_fn_{index}", input_signature=[]
        )
    return fn

760 

761 

def _wrap_activity_regularizer(layer):
    """Wraps the activity regularizer.

    Args:
      layer: Keras layer with a non-None `_activity_regularizer`.

    Returns:
      The regularizer itself if it already is a
      `tf.__internal__.function.Function`. Otherwise a new Function named
      `{layer.name}_activity_regularizer`, whose single input spec has an
      unknown shape and the layer's compute dtype (falling back to
      `backend.floatx()`).
    """
    regularizer = layer._activity_regularizer
    if isinstance(regularizer, tf.__internal__.function.Function):
        return regularizer
    input_dtype = layer._compute_dtype or backend.floatx()
    return tf.__internal__.function.Function(
        regularizer,
        f"{layer.name}_activity_regularizer",
        input_signature=[tf.TensorSpec(None, input_dtype)],
    )

776 

777 

def _get_layer_call_method(layer):
    """Returns the Python function behind `layer.call`.

    If `layer.call` is a `tf.__internal__.function.Function`, its underlying
    `python_function` is returned; otherwise `layer.call` itself is returned.
    """
    call = layer.call
    if not isinstance(call, tf.__internal__.function.Function):
        return call
    return call.python_function

782