Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py: 24%

310 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras SavedModel serialization. 

16 

17TODO (kathywu): Move to layer_serialization.py. Some model-specific logic should 

18go to model_serialization.py. 

19""" 

20 

21import functools 

22import threading 

23import weakref 

24 

25from tensorflow.python.eager import def_function 

26from tensorflow.python.framework import tensor_shape 

27from tensorflow.python.framework import tensor_spec 

28from tensorflow.python.keras import backend as K 

29from tensorflow.python.keras.engine import base_layer_utils 

30from tensorflow.python.keras.engine import input_spec 

31from tensorflow.python.keras.mixed_precision import autocast_variable 

32from tensorflow.python.keras.saving import saving_utils 

33from tensorflow.python.keras.saving.saved_model import constants 

34from tensorflow.python.keras.saving.saved_model import load as keras_load 

35from tensorflow.python.keras.saving.saved_model import serialized_attributes 

36from tensorflow.python.keras.saving.saved_model import utils 

37from tensorflow.python.keras.utils import tf_contextlib 

38from tensorflow.python.keras.utils import tf_inspect 

39from tensorflow.python.keras.utils import tf_utils 

40from tensorflow.python.keras.utils import version_utils 

41from tensorflow.python.keras.utils.generic_utils import LazyLoader 

42from tensorflow.python.platform import tf_logging as logging 

43from tensorflow.python.trackable import data_structures 

44from tensorflow.python.util import nest 

45from tensorflow.python.util import tf_decorator 

46 

47 

# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.
#
# Each LazyLoader below is bound to a module-level name in this file's
# globals() (first argument) and refers to the module path given as the last
# argument. NOTE(review): the exact deferral mechanics live in
# generic_utils.LazyLoader, which is not shown here.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
input_layer = LazyLoader(
    "input_layer", globals(),
    "tensorflow.python.keras.engine.input_layer")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
sequential_lib = LazyLoader(
    "sequential_lib", globals(),
    "tensorflow.python.keras.engine.sequential")
# pylint:enable=g-inconsistent-quotes

69 

70 

def should_skip_serialization(layer):
  """Skip serializing extra objects and functions if layer inputs aren't set.

  Args:
    layer: Keras Layer object being saved.

  Returns:
    True, after logging a warning, when `layer` is not built and (for Models)
    no SavedModel input spec has been recorded. False otherwise.
  """
  # An unbuilt Model can still be traced if a SavedModel input spec was saved.
  is_model = isinstance(layer, training_lib.Model)
  has_saved_input_spec = (
      is_model and
      layer._saved_model_inputs_spec is not None)  # pylint: disable=protected-access
  if layer.built or has_saved_input_spec:
    return False
  logging.warning('Skipping full serialization of Keras layer {}, because '
                  'it is not built.'.format(layer))
  return True

80 

81 

def wrap_layer_objects(layer, serialization_cache):
  """Returns extra trackable objects to attach to the serialized layer.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization. Wrapped loss functions are memoized in the dict stored
      under its 'keras_losses' key, so the same loss callable is wrapped only
      once across all objects being serialized.

  Returns:
    A dictionary containing all checkpointable objects from a
    SerializedAttributes object. See LayerAttributes and ModelAttributes for
    the entire list of objects.
  """
  # pylint: disable=protected-access
  # Collect the callable (regularization) losses of this layer, followed by
  # those of all of its sublayers.
  all_losses = list(layer._callable_losses)
  for child_layer in utils.list_all_layers(layer):
    all_losses.extend(child_layer._callable_losses)

  # Wrap each loss as a tf.function, reusing wrappers already in the cache.
  # The cache size at insertion time provides a unique name index.
  keras_loss_cache = serialization_cache.setdefault('keras_losses', {})
  wrapped_loss_functions = []
  for loss_fn in all_losses:
    if loss_fn not in keras_loss_cache:
      keras_loss_cache[loss_fn] = _wrap_unconditional_loss(
          loss_fn, len(keras_loss_cache))
    wrapped_loss_functions.append(keras_loss_cache[loss_fn])

  # `layer_regularization_losses` holds only this layer's own losses.
  wrapped_layer_losses = [
      keras_loss_cache[loss_fn] for loss_fn in layer._callable_losses]

  layer_metrics = data_structures.wrap_or_unwrap(
      {metric.name: metric for metric in layer._metrics})
  # pylint: enable=protected-access
  return dict(
      variables=data_structures.wrap_or_unwrap(layer.variables),
      trainable_variables=data_structures.wrap_or_unwrap(
          layer.trainable_variables),
      non_trainable_variables=data_structures.wrap_or_unwrap(
          layer.non_trainable_variables),
      layers=data_structures.wrap_or_unwrap(utils.list_all_layers(layer)),
      metrics=data_structures.wrap_or_unwrap(layer.metrics),
      regularization_losses=data_structures.wrap_or_unwrap(
          wrapped_loss_functions),
      layer_regularization_losses=data_structures.wrap_or_unwrap(
          wrapped_layer_losses),
      layer_metrics=layer_metrics)

131 

132 

def wrap_layer_functions(layer, serialization_cache):
  """Returns dict of wrapped layer call function and losses in tf.functions.

  To trace, the call functions of child layers are temporarily replaced with
  their wrapped tf.functions, and the losses of `layer` and its sublayers are
  cleared. Both are restored before this function returns or raises, so a
  failed wrap/trace does not leave the layer in a modified state.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all keras tf.functions to serialize. See
    LayerAttributes and ModelAttributes for the list of all attributes.
  """
  # Since Sequential models may be modified in place using model.add() or
  # model.pop(), don't use saved functions.
  if (isinstance(layer, keras_load.RevivedLayer) and
      not isinstance(layer, sequential_lib.Sequential)):
    return {fn_name: getattr(layer.keras_api, fn_name, None)
            for fn_name in serialized_attributes.LayerAttributes.all_functions}

  # Reset the losses of the layer and its children. The call function in each
  # child layer is replaced with tf.functions.
  original_fns = _replace_child_layer_functions(layer, serialization_cache)
  original_losses = _reset_layer_losses(layer)
  try:
    # Wrap all the layer call and activity regularizer functions.

    # Use LayerCallCollection to ensure that all layer call functions
    # (__call__, call with losses) are traced with the same inputs.
    call_collection = LayerCallCollection(layer)
    call_fn_with_losses = call_collection.add_function(
        _wrap_call_and_conditional_losses(layer),
        '{}_layer_call_and_return_conditional_losses'.format(layer.name),
        # If any of this layer's child layers use the training arg, the traced
        # call functions of this layer will have a training keyword argument.
        # If the original layer does not expect the training arg, then it will
        # have to be removed (by setting `match_layer_training_arg`).
        match_layer_training_arg=True)
    call_fn = call_collection.add_function(
        _extract_outputs_from_fn(layer, call_fn_with_losses),
        '{}_layer_call_fn'.format(layer.name),
        # Since `call_fn` wraps call_fn_with_losses and not the original call
        # function, `match_layer_training_arg` should be set to False.
        match_layer_training_arg=False)

    fns = {'call_and_return_conditional_losses': call_fn_with_losses,
           '__call__': call_fn}

    if layer._activity_regularizer is not None:  # pylint: disable=protected-access
      fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
      fns['call_and_return_all_conditional_losses'] = (
          call_collection.add_function(
              _append_activity_regularizer_loss(
                  layer, call_fn_with_losses, fns['activity_regularizer_fn']),
              '{}_layer_call_and_return_all_conditional_losses'.format(
                  layer.name),
              match_layer_training_arg=False))
    else:
      fns['activity_regularizer_fn'] = None
      fns['call_and_return_all_conditional_losses'] = call_fn_with_losses

    # Manually trigger traces before restoring the overwritten functions. The
    # functions are traced within the layer call context to ensure that layer
    # functions (e.g. add_loss) behave as though running in graph mode.
    with tracing_scope():
      call_collection.trace_with_input_signature()
      with base_layer_utils.call_context().enter(
          layer, inputs=None, build_graph=True, training=None, saving=True):
        for fn in fns.values():
          if fn is not None and fn.input_signature is not None:
            if isinstance(fn, LayerCall):
              fn = fn.wrapped_call
            fn.get_concrete_function()
  finally:
    # Restore overwritten functions and losses, also when wrapping or tracing
    # raised, so the layer is never left with replaced calls or no losses.
    _restore_child_layer_functions(original_fns)
    _restore_layer_losses(original_losses)

  return fns

211 

212 

def default_save_signature(layer):
  """Traces `layer` with `saving_utils.trace_model_call` for the default signature.

  The losses of `layer` and its sublayers are cleared while tracing and the
  saved originals are put back afterwards. Any losses added during tracing are
  therefore discarded. Restoration also happens when tracing raises.

  Args:
    layer: Keras Layer/Model object.

  Returns:
    The function returned by `saving_utils.trace_model_call(layer)`, on which
    `get_concrete_function()` has already been called once.
  """
  original_losses = _reset_layer_losses(layer)
  try:
    fn = saving_utils.trace_model_call(layer)
    fn.get_concrete_function()
  finally:
    _restore_layer_losses(original_losses)
  return fn

219 

220 

def _replace_child_layer_functions(layer, serialization_cache):
  """Replaces functions in the children layers with wrapped tf.functions.

  This step allows functions from parent layers to reference the wrapped
  functions from their children layers instead of retracing the ops.

  Losses are not modified here (see `_reset_layer_losses`). The replaced
  attributes are recorded in the returned dictionary; use
  `_restore_child_layer_functions` to restore them. If replacing fails
  part-way, the attributes replaced so far are restored before the error
  propagates, because the caller never receives the dictionary in that case.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    Dictionary mapping child objects -> their original attributes:
      { Child layer 1: {
            'call': Original call function,
            '_activity_regularizer': Original activity regularizer},
        Child metric 2: {
            '__call__': Original __call__,
            'result': Original result function,
            'update_state': Original update_state function},
        ...
      }
    InputLayers and children without wrapped functions are not included.
  """
  # pylint: disable=protected-access
  original_fns = {}

  def replace_layer_functions(child_layer, serialized_fns):
    """Replaces layer call and activity regularizer with wrapped functions."""
    original_fns[child_layer] = {
        'call': child_layer.call,
        '_activity_regularizer': child_layer._activity_regularizer
    }
    with utils.no_automatic_dependency_tracking_scope(child_layer):
      try:
        child_layer._activity_regularizer = serialized_fns.get(
            'activity_regularizer_fn')
      except AttributeError:
        # Some layers have an unsettable activity regularizer.
        pass
      child_layer.call = utils.use_wrapped_call(
          child_layer,
          serialized_fns['call_and_return_conditional_losses'],
          default_training_value=False)

  def replace_metric_functions(child_layer, serialized_fns):
    """Replaces metric functions with wrapped functions."""
    original_fns[child_layer] = {
        '__call__': child_layer.__call__,
        'result': child_layer.result,
        'update_state': child_layer.update_state
    }
    with utils.no_automatic_dependency_tracking_scope(child_layer):
      child_layer.__call__ = serialized_fns['__call__']
      child_layer.result = serialized_fns['result']
      child_layer.update_state = serialized_fns['update_state']

  try:
    for child_layer in utils.list_all_layers(layer):
      if isinstance(child_layer, input_layer.InputLayer):
        continue

      if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]:
        serialized_functions = (
            child_layer._trackable_saved_model_saver._get_serialized_attributes(
                serialization_cache).functions)
      else:
        serialized_functions = (
            serialization_cache[constants.KERAS_CACHE_KEY][child_layer]
            .functions)
      if not serialized_functions:
        # This indicates either:
        #   - circular dependency, which means the current layer's functions
        #     should be wrapped first.
        #   - Child layer's inputs are not defined, so its functions have not
        #     been wrapped. In this case, no replacement is necessary so move
        #     on to the next child.
        continue

      if isinstance(child_layer, metrics.Metric):
        replace_metric_functions(child_layer, serialized_functions)
      else:
        replace_layer_functions(child_layer, serialized_functions)
  except BaseException:
    # Undo the replacements made so far, then re-raise the original error.
    _restore_child_layer_functions(original_fns)
    raise

  return original_fns
  # pylint: enable=protected-access

305 

306 

def _restore_child_layer_functions(original_fns):
  """Restores attributes replaced with `_replace_child_layer_functions`.

  Args:
    original_fns: Dictionary mapping each child object to a dict of
      {attribute name: original value}, as returned by
      `_replace_child_layer_functions`.
  """
  for child_layer, saved_attrs in original_fns.items():
    with utils.no_automatic_dependency_tracking_scope(child_layer):
      for attr_name, attr_value in saved_attrs.items():
        try:
          setattr(child_layer, attr_name, attr_value)
        except AttributeError:
          # Setting some attributes (e.g. `_activity_regularizer`) may be
          # disallowed on a given layer; leave those as they are.
          pass

317 

318 

# pylint: disable=protected-access
def _reset_layer_losses(parent_layer):
  """Clears the losses of `parent_layer` and all of its sublayers.

  Args:
    parent_layer: Keras Layer object.

  Returns:
    Dictionary mapping each layer from
    `utils.list_all_layers_and_sublayers(parent_layer)` to a dict holding
    copies of its original `_losses` and `_eager_losses` lists under the keys
    'losses' and 'eager_losses'. Pass it to `_restore_layer_losses`.
  """
  saved_losses = {}
  for sublayer in utils.list_all_layers_and_sublayers(parent_layer):
    saved_losses[sublayer] = {
        'losses': sublayer._losses[:],
        'eager_losses': sublayer._eager_losses[:],
    }
    with utils.no_automatic_dependency_tracking_scope(sublayer):
      sublayer._losses = []
      sublayer._eager_losses = []
  return saved_losses

330 

331 

def _restore_layer_losses(losses_dict):
  """Puts back the losses saved by `_reset_layer_losses`.

  Args:
    losses_dict: Dictionary returned by `_reset_layer_losses`, mapping each
      layer to {'losses': [...], 'eager_losses': [...]}.
  """
  for layer, saved in losses_dict.items():
    with utils.no_automatic_dependency_tracking_scope(layer):
      layer._losses = saved['losses']
      layer._eager_losses = saved['eager_losses']
# pylint: enable=protected-access

338 

339 

class LayerTracingContext(threading.local):
  """Per-thread state read by `tracing_scope` and `add_trace_to_queue`.

  Attributes:
    enable_call_tracing: Whether extra traces should be queued. Starts False
      in every thread.
    trace_queue: List of (fn, args, kwargs, training) tuples waiting to be
      traced. Starts empty in every thread.
  """

  def __init__(self):
    super().__init__()
    self.trace_queue = []
    self.enable_call_tracing = False


_thread_local_data = LayerTracingContext()

348 

349 

@tf_contextlib.contextmanager
def tracing_scope():
  """Enables tracing scope.

  Inside the scope, call tracing is enabled for the current thread, so
  `add_trace_to_queue` records traces instead of ignoring them. When the scope
  exits, the queued traces are run. Tracing stays enabled while they run, so
  traces queued along the way are drained as well. The previous tracing flag
  and queue are then restored, even if a queued trace raises, so a failure
  cannot leave call tracing enabled for the rest of the thread.

  Yields:
    None.
  """
  # This enables the LayerCallCollection's tracing mechanism to trace all call
  # functions in the collection.
  previous_value = _thread_local_data.enable_call_tracing
  previous_queue = _thread_local_data.trace_queue
  try:
    _thread_local_data.enable_call_tracing = True
    _thread_local_data.trace_queue = []
    yield
  finally:
    try:
      # Run traces from the queue.
      while _thread_local_data.trace_queue:
        fn, args, kwargs, training = _thread_local_data.trace_queue.pop()
        if training is not None:
          with K.deprecated_internal_learning_phase_scope(training):
            fn.get_concrete_function(*args, **kwargs)
        else:
          fn.get_concrete_function(*args, **kwargs)
    finally:
      # Always restore the enclosing scope's state.
      _thread_local_data.trace_queue = previous_queue
      _thread_local_data.enable_call_tracing = previous_value

372 

373 

def add_trace_to_queue(fn, args, kwargs, training=None):
  """Queues `fn` to be traced later if call tracing is enabled.

  Does nothing when tracing is disabled for the current thread.

  Args:
    fn: Object whose `get_concrete_function(*args, **kwargs)` will be called.
    args: Sliceable sequence of positional args; a shallow copy is queued.
    kwargs: Dict of keyword args; a shallow copy is queued.
    training: Optional learning-phase value to trace under; None means the
      trace runs without a learning-phase scope.
  """
  if not tracing_enabled():
    return
  entry = (fn, args[:], kwargs.copy(), training)
  _thread_local_data.trace_queue.append(entry)

378 

379 

def tracing_enabled():
  """Whether to add extra traces to the queue.

  Returns:
    The current thread's `enable_call_tracing` flag, which `tracing_scope`
    sets to True for its duration.
  """
  state = _thread_local_data
  return state.enable_call_tracing

383 

384 

class LayerCallCollection(object):
  """Groups wrapped layer call functions.

  This is used to ensure that all layer call functions are traced with the same
  inputs:
    - call
    - call_and_return_conditional_losses
    - call_and_return_all_conditional_losses

  Functions are registered with `add_function`, which wraps each one in a
  `LayerCall`. `add_trace` queues traces of every registered function with
  the same arguments.
  """

  def __init__(self, layer):
    self.layer = layer

    # Underlying Python call function (unwrapped if `layer.call` is a
    # def_function.Function).
    self.layer_call_method = _get_layer_call_method(layer)
    # May be True even when `layer._expects_training_arg` is False; the
    # methods below handle that combination separately (see
    # `_maybe_wrap_with_training_arg`).
    self._expects_training_arg = utils.layer_uses_training_bool(layer)
    # Position of the `training` arg. Reassigned by
    # `_maybe_wrap_with_training_arg` when it appends a `training` arg.
    self._training_arg_index = utils.get_training_arg_index(
        self.layer_call_method)

    # If the layer call function has kwargs, then the traced function cannot
    # have an input signature.
    arg_spec = tf_inspect.getfullargspec(self.layer_call_method)
    self._has_kwargs = bool(self._expects_training_arg or
                            arg_spec.defaults or
                            arg_spec.kwonlyargs or
                            arg_spec.varkw)

    self._input_signature = self._generate_input_signature(layer)
    # Maps function name -> `LayerCall.wrapped_call`. Values are held weakly,
    # so this collection by itself does not keep the functions alive.
    self._functions = weakref.WeakValueDictionary()

    # Get the input argument name from the args (dropping the first arg for
    # bound methods); fall back to 'inputs' when there are no named args.
    args = arg_spec.args
    if tf_inspect.ismethod(self.layer_call_method):
      args = args[1:]
    self._input_arg_name = args[0] if args else 'inputs'

  def _generate_input_signature(self, layer):
    """Inspects layer object and returns the inferred input signature.

    Sources, in priority order: an explicit input signature on a tf.function
    `layer.call`; `saving_utils.model_input_signature` for Models; the
    layer's `input_spec` when `_use_input_spec_as_call_signature` is set.

    Args:
      layer: Layer object.

    Returns:
      List of possibly nested TensorSpecs of the layer call function inputs,
      or None if no signature can be inferred. Individual entries may be None
      when a spec's shape is fully unknown. The list does not contain the
      `training` argument.
    """
    if (isinstance(layer.call, def_function.Function) and
        layer.call.input_signature is not None):
      return layer.call.input_signature
    elif isinstance(layer, training_lib.Model):
      return saving_utils.model_input_signature(layer)
    elif (layer.input_spec is not None and
          layer._use_input_spec_as_call_signature):  # pylint: disable=protected-access

      def to_tensor_spec_or_none(x):
        spec = input_spec.to_tensor_spec(x, layer._compute_dtype)  # pylint: disable=protected-access
        # If the shape is too general (e.g. multiple dimensions are allowed),
        # return None so that separate functions can be generated for each
        # inferred input signature.
        # TODO(b/134962016): currently partial signatures are not supported.
        if spec.shape == tensor_shape.TensorShape(None):
          return None
        return spec
      input_signature = [nest.map_structure(
          to_tensor_spec_or_none, layer.input_spec)]

      return input_signature
    else:
      return None

  def add_trace(self, *args, **kwargs):
    """Traces all functions with the same args and kwargs.

    Traces are not run here: each is queued via `add_trace_to_queue`, which
    ignores them unless call tracing is enabled. When this collection expects
    a training arg, every function is queued twice, once with training=True
    and once with training=False.

    Args:
      *args: Positional args passed to the original function.
      **kwargs: Keyword args passed to the original function.
    """
    # Local copies, so the training value written below does not touch the
    # caller's containers.
    args = list(args)
    kwargs = kwargs.copy()

    for fn in self._functions.values():
      # TODO(kathywu): Replace arguments with broader shapes defined in the
      # input signature.
      if self._expects_training_arg:
        # `fn=fn` binds the current loop value at definition time.
        def trace_with_training(value, fn=fn):
          # NOTE(review): set_training_arg appears to write `value` into
          # args/kwargs in place; add_trace_to_queue queues shallow copies.
          utils.set_training_arg(value, self._training_arg_index, args, kwargs)
          add_trace_to_queue(fn, args, kwargs, value)

        trace_with_training(True)
        trace_with_training(False)
      else:
        add_trace_to_queue(fn, args, kwargs)

  @property
  def fn_input_signature(self):
    """Returns input signature for the wrapped layer call function.

    None when the call function takes keyword-style arguments or when any
    entry of the inferred signature is None.
    """
    if self._has_kwargs:
      # Input signatures may only describe tensor arguments and kwargs are not
      # supported.
      return None
    if None in nest.flatten(self._input_signature):
      # TODO(b/134962016): If input signature cannot be partially defined.
      return None
    return self._input_signature

  def training_arg_was_passed(self, args, kwargs):
    """Returns whether a `training` value was passed in `args`/`kwargs`."""
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      # The layer itself has no `training` arg; look it up at the index this
      # collection tracks instead of asking the layer.
      return (utils.get_training_arg(self._training_arg_index, args, kwargs)
              is not None)
    else:
      return self.layer._call_arg_was_passed(  # pylint: disable=protected-access
          'training', args, kwargs, inputs_in_args=True)

  def get_training_arg_value(self, args, kwargs):
    """Returns the `training` value found in `args`/`kwargs`."""
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      return utils.get_training_arg(self._training_arg_index, args, kwargs)
    else:
      return self.layer._get_call_arg_value(  # pylint: disable=protected-access
          'training', args, kwargs, inputs_in_args=True)

  def get_input_arg_value(self, args, kwargs):
    """Returns the value of the layer's first (input) argument."""
    return self.layer._get_call_arg_value(  # pylint: disable=protected-access
        self._input_arg_name, args, kwargs, inputs_in_args=True)

  def _maybe_wrap_with_training_arg(self, call_fn, match_layer_training_arg):
    """Wraps call function with added training argument if necessary.

    Only applies when this collection expects a training arg but the layer
    does not. In that case a `training` parameter (default False) is appended
    to `call_fn`'s advertised argspec, and `self._training_arg_index` is
    updated to point at it (a side effect on the collection).

    Args:
      call_fn: Python function to wrap.
      match_layer_training_arg: If True, the added `training` value is
        stripped before calling `call_fn`.

    Returns:
      The wrapped function, or `call_fn` unchanged.
    """
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      # Add training arg to wrapper function.
      arg_spec = tf_inspect.getfullargspec(call_fn)
      args = arg_spec.args + ['training']
      defaults = list(arg_spec.defaults or [])
      defaults.append(False)
      new_arg_spec = tf_inspect.FullArgSpec(
          args=args,
          varargs=arg_spec.varargs,
          varkw=arg_spec.varkw,
          defaults=defaults,
          kwonlyargs=arg_spec.kwonlyargs,
          kwonlydefaults=arg_spec.kwonlydefaults,
          annotations=arg_spec.annotations)

      # Set new training arg index
      self._training_arg_index = len(args) - 1
      if tf_inspect.ismethod(call_fn):
        self._training_arg_index -= 1

      def wrap_with_training_arg(*args, **kwargs):
        # Note: `self._training_arg_index` is read at call time, so it
        # reflects the most recent assignment on the collection.
        if match_layer_training_arg:
          # Remove the training value, since the original call_fn does not
          # expect a training arg. Instead, the training value will be
          # propagated using the call context created in LayerCall.
          args = list(args)
          kwargs = kwargs.copy()
          utils.remove_training_arg(self._training_arg_index, args, kwargs)
        return call_fn(*args, **kwargs)

      return tf_decorator.make_decorator(
          target=call_fn,
          decorator_func=wrap_with_training_arg,
          decorator_argspec=new_arg_spec)

    return call_fn

  def add_function(self, call_fn, name, match_layer_training_arg):
    """Adds a layer call function to the collection.

    Args:
      call_fn: a python function
      name: Name of call function
      match_layer_training_arg: If True, removes the `training` from the
        function arguments when calling `call_fn`.

    Returns:
      LayerCall (tf.function)
    """
    fn = LayerCall(
        self,
        self._maybe_wrap_with_training_arg(call_fn, match_layer_training_arg),
        name,
        input_signature=self.fn_input_signature)
    self._functions[name] = fn.wrapped_call
    return fn

  def trace_with_input_signature(self):
    """Trace with the layer/models inferred input signature if possible."""
    if (None not in nest.flatten(self._input_signature) and self._has_kwargs):
      # Manually add traces for layers that have keyword arguments and have
      # a fully defined input signature.
      self.add_trace(*self._input_signature)

574 

def _filtered_inputs(inputs):
  """Returns the tensors and variables among the flattened `inputs`.

  Args:
    inputs: Arbitrarily nested structure, e.g. `[args, kwargs]`.

  Returns:
    List of the leaves of `nest.flatten(inputs)` for which
    `tf_utils.is_tensor_or_variable` is truthy, in flattened order.
  """
  return [leaf for leaf in nest.flatten(inputs)
          if tf_utils.is_tensor_or_variable(leaf)]

577 

578 

def layer_call_wrapper(call_collection, method, name):
  """Ensures layer losses are kept the same, and runs method in call context.

  Args:
    call_collection: LayerCallCollection owning the layer being traced.
    method: Python function to wrap.
    name: Value assigned to the returned function's `__name__`.

  Returns:
    A decorated version of `method`. On each call it clears the losses of the
    layer and its sublayers, then runs `method` inside a Keras call context
    with `autocast_variable.enable_auto_cast_variables` set to the layer's
    compute dtype. Afterwards it restores the original losses, also when
    `method` raises.
  """

  # Create wrapper that deals with losses and call context.
  def wrapper(*args, **kwargs):
    """Calls method within call context."""
    layer = call_collection.layer
    training = None
    inputs = _filtered_inputs([args, kwargs])
    if (args or kwargs) and call_collection.training_arg_was_passed(
        args, kwargs):
      training = call_collection.get_training_arg_value(args, kwargs)
    original_losses = _reset_layer_losses(layer)
    try:
      with base_layer_utils.call_context().enter(
          layer, inputs=inputs, build_graph=False, training=training,
          saving=True):
        with autocast_variable.enable_auto_cast_variables(
            layer._compute_dtype_object):  # pylint: disable=protected-access
          ret = method(*args, **kwargs)
    finally:
      # Losses added while running `method` are discarded. Restore even on
      # error so the layer does not keep an empty loss list.
      _restore_layer_losses(original_losses)
    return ret

  # Rename to `name`, since tf.function doesn't have a name argument. Without
  # this, all functions returned by this method will be named "call", which
  # would be a nightmare to debug.
  fn = tf_decorator.make_decorator(target=method, decorator_func=wrapper)
  fn.__name__ = name
  return fn

609 

610 

class LayerCall(object):
  """Function that triggers traces of other functions in the same collection."""

  def __init__(self, call_collection, call_fn, name, input_signature):
    """Initializes a LayerCall object.

    Args:
      call_collection: a LayerCallCollection, which contains the other layer
        call functions (e.g. call_with_conditional_losses, call). These
        functions should be traced with the same arguments.
      call_fn: A call function.
      name: Name of the call function.
      input_signature: Input signature of call_fn (can be None).
    """
    self.call_collection = call_collection
    self.input_signature = input_signature
    wrapped_python_fn = layer_call_wrapper(call_collection, call_fn, name)
    self.wrapped_call = def_function.function(
        wrapped_python_fn, input_signature=input_signature)
    self.original_layer_call = call_collection.layer_call_method

  def _maybe_trace(self, args, kwargs):
    """Queues traces of every function in the collection when tracing is on."""
    if not tracing_enabled():
      return
    self.call_collection.add_trace(*args, **kwargs)

  def __call__(self, *args, **kwargs):
    """Queues sibling traces, then calls the wrapped tf.function."""
    self._maybe_trace(args, kwargs)
    return self.wrapped_call(*args, **kwargs)

  def get_concrete_function(self, *args, **kwargs):
    """Queues sibling traces, then returns the wrapped concrete function."""
    self._maybe_trace(args, kwargs)
    return self.wrapped_call.get_concrete_function(*args, **kwargs)

644 

645 

def _wrap_call_and_conditional_losses(layer):
  """Wraps call function that returns a tuple of (outputs, losses).

  The losses returned are conditional on the inputs passed to the call function.
  Unconditional losses (e.g. weight regularization) are wrapped separately.

  Args:
    layer: a Keras layer object

  Returns:
    python call function that returns outputs and conditional losses -- excludes
    activity regularizer
  """
  layer_call = _get_layer_call_method(layer)

  def call_and_return_conditional_losses(*args, **kwargs):
    """Returns layer (call_output, conditional losses) tuple."""
    outputs = layer_call(*args, **kwargs)
    if not version_utils.is_v1_layer_or_model(layer):
      # Skip losses tagged `_unconditional_loss`; those are wrapped separately.
      losses = [loss for loss in layer.losses
                if not hasattr(loss, '_unconditional_loss')]
    else:
      losses = layer.get_losses_for(_filtered_inputs([args, kwargs]))
    return outputs, losses

  return _create_call_fn_decorator(layer, call_and_return_conditional_losses)

674 

675 

def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
  """Returns a function that returns only call function outputs.

  For a `keras_load.RevivedLayer` the revived `keras_api.__call__` is returned
  as-is. Otherwise the result calls `call_and_return_conditional_losses` and
  keeps only element 0 of what it returns.
  """
  if isinstance(layer, keras_load.RevivedLayer):
    return layer.keras_api.__call__  # pylint: disable=protected-access

  def call(inputs, *args, **kwargs):
    results = call_and_return_conditional_losses(inputs, *args, **kwargs)
    return results[0]

  return _create_call_fn_decorator(layer, call)

683 

684 

def _append_activity_regularizer_loss(
    layer, call_fn_with_losses, activity_regularizer_fn):
  """Appends activity regularizer loss to losses returned by the wrapped fn.

  Args:
    layer: Keras layer whose call signature the returned function mirrors.
    call_fn_with_losses: Callable returning an `(outputs, losses)` pair.
    activity_regularizer_fn: Callable applied to `outputs`.

  Returns:
    Function returning `(outputs, losses)`, where `losses` is the list from
    `call_fn_with_losses`, extended in place with
    `activity_regularizer_fn(outputs)`.
  """
  def fn(inputs, *args, **kwargs):
    result = call_fn_with_losses(inputs, *args, **kwargs)
    outputs, losses = result
    regularization_loss = activity_regularizer_fn(outputs)
    losses.append(regularization_loss)
    return outputs, losses
  return _create_call_fn_decorator(layer, fn)

693 

694 

def _create_call_fn_decorator(layer, wrapped_call):
  """Decorates `wrapped_call` so it targets the layer's call method.

  Args:
    layer: Keras layer whose (unwrapped) call method is the decorator target.
    wrapped_call: Python function to decorate.

  Returns:
    Result of `tf_decorator.make_decorator`, using the function and argspec
    produced by `utils.maybe_add_training_arg` (with
    default_training_value=False).
  """
  target = _get_layer_call_method(layer)
  decorator_func, argspec = utils.maybe_add_training_arg(
      target,
      wrapped_call,
      layer._expects_training_arg,  # pylint: disable=protected-access
      default_training_value=False)
  return tf_decorator.make_decorator(
      target=target, decorator_func=decorator_func, decorator_argspec=argspec)

704 

705 

def _wrap_unconditional_loss(loss_fn, index):
  """Wraps callable/unconditional loss, returning a serializable function.

  Args:
    loss_fn: Callable loss. When it is a `functools.partial`, its first
      positional argument is the function that gets wrapped.
    index: Integer used to name the new function 'loss_fn_{index}'.

  Returns:
    The extracted function itself if it already is a `def_function.Function`,
    otherwise a new `def_function.Function` with an empty input signature.
  """
  if isinstance(loss_fn, functools.partial):
    # Extract original loss function from partial function.
    fn = loss_fn.args[0]
  else:
    fn = loss_fn
  if isinstance(fn, def_function.Function):
    return fn
  return def_function.Function(
      fn, 'loss_fn_{}'.format(index), input_signature=[])

715 

716 

def _wrap_activity_regularizer(layer):
  """Wraps the activity regularizer.

  Returns `layer._activity_regularizer` unchanged if it already is a
  `def_function.Function`. Otherwise it wraps it in one named
  '<layer name>_activity_regularizer'. The input signature is a single
  TensorSpec of unknown shape with the layer's compute dtype, or `K.floatx()`
  when that is falsy.
  """
  # pylint: disable=protected-access
  regularizer = layer._activity_regularizer
  if isinstance(regularizer, def_function.Function):
    return regularizer
  return def_function.Function(
      regularizer,
      '{}_activity_regularizer'.format(layer.name),
      input_signature=[
          tensor_spec.TensorSpec(None, layer._compute_dtype or K.floatx())])
  # pylint: enable=protected-access

729 

730 

def _get_layer_call_method(layer):
  """Returns the Python function behind `layer.call`.

  If `layer.call` is a `def_function.Function`, its `python_function` is
  returned; otherwise `layer.call` itself.
  """
  call = layer.call
  if not isinstance(call, def_function.Function):
    return call
  return call.python_function