Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/legacy_tf_layers/variable_scope_shim.py: 22%

176 statements  


# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
# pylint: disable=g-classes-have-attributes
"""Contains a shim to allow using TF1 get_variable code in TF2."""
import functools

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.module import module
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator


def as_shape(shape):
  """Converts the given object to a TensorShape."""
  if isinstance(shape, tensor_shape.TensorShape):
    return shape
  else:
    return tensor_shape.TensorShape(shape)

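# A minimal usage sketch (illustrative comment, not part of the original
# module): `as_shape` accepts anything `tensor_shape.TensorShape` accepts.
#
#   as_shape(tensor_shape.TensorShape([2, 3]))  # returned unchanged
#   as_shape([2, 3])                            # -> TensorShape([2, 3])
#   as_shape(None)                              # -> unknown-rank TensorShape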

def _is_callable_object(obj):
  return hasattr(obj, "__call__") and tf_inspect.ismethod(obj.__call__)


def _has_kwargs(fn):
  """Returns whether the passed callable has **kwargs in its signature.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `bool`: if `fn` has **kwargs in its signature.

  Raises:
    `TypeError`: If fn is not a Function, or function-like object.
  """
  if isinstance(fn, functools.partial):
    fn = fn.func
  elif _is_callable_object(fn):
    fn = fn.__call__
  elif not callable(fn):
    raise TypeError(
        "fn should be a function-like object, but is of type {}.".format(
            type(fn)))
  return tf_inspect.getfullargspec(fn).varkw is not None


def fn_args(fn):
  """Get argument names for function-like object.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.

  Raises:
    ValueError: if partial function has positionally bound arguments
  """
  if isinstance(fn, functools.partial):
    args = fn_args(fn.func)
    args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
  else:
    if hasattr(fn, "__call__") and tf_inspect.ismethod(fn.__call__):
      fn = fn.__call__
    args = tf_inspect.getfullargspec(fn).args
    if _is_bound_method(fn) and args:
      # If it's a bound method, it may or may not have a self/cls first
      # argument; for example, self could be captured in *args.
      # If it does have a positional argument, it is self/cls.
      args.pop(0)
  return tuple(args)


def _is_bound_method(fn):
  _, fn = tf_decorator.unwrap(fn)
  return tf_inspect.ismethod(fn) and (fn.__self__ is not None)

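# Illustrative sketch (comment only, not part of the original module) of how
# these introspection helpers behave on hypothetical custom getters:
#
#   def narrow_getter(getter, name, shape=None, dtype=None):
#     return getter(name, shape=shape, dtype=dtype)
#
#   def flexible_getter(getter, *args, **kwargs):
#     return getter(*args, **kwargs)
#
#   fn_args(narrow_getter)        # -> ("getter", "name", "shape", "dtype")
#   _has_kwargs(narrow_getter)    # -> False
#   _has_kwargs(flexible_getter)  # -> True; newer args such as `constraint`
#                                 #    can be forwarded (see get_variable below)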

def validate_synchronization_aggregation_trainable(
    synchronization, aggregation, trainable, name):
  """Given user-provided variable properties, sets defaults and validates."""
  if aggregation is None:
    aggregation = variables.VariableAggregation.NONE
  else:
    if not isinstance(aggregation,
                      (variables.VariableAggregation,
                       variables.VariableAggregationV2)):
      try:
        aggregation = variables.VariableAggregationV2(aggregation)
      except ValueError:
        raise ValueError(
            "Invalid variable aggregation mode: {} for variable: {}".format(
                aggregation, name))
  if synchronization is None:
    synchronization = variables.VariableSynchronization.AUTO
  else:
    try:
      synchronization = variables.VariableSynchronization(synchronization)
    except ValueError:
      raise ValueError(
          "Invalid variable synchronization mode: {} for variable: {}".format(
              synchronization, name))
  if trainable is None:
    trainable = synchronization != variables.VariableSynchronization.ON_READ
  return synchronization, aggregation, trainable

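# Illustrative sketch of the defaulting behavior above (comment only, not part
# of the original module):
#
#   validate_synchronization_aggregation_trainable(None, None, None, "v")
#   # -> (VariableSynchronization.AUTO, VariableAggregation.NONE, True)
#
#   validate_synchronization_aggregation_trainable(
#       variables.VariableSynchronization.ON_READ, None, None, "v")
#   # -> trainable defaults to False for ON_READ variables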

class _EagerVariableStore(object):
  """TF2-compatible VariableStore that avoids collections & tracks regularizers.

  New variable names and new variables can be created; all stored
  variables are initialized with the initializer passed to __init__.

  All variables get created in `tf.init_scope` to avoid a bad
  interaction between `tf.function` `FuncGraph` internals, Keras
  Functional Models, and TPUStrategy variable initialization.

  Attributes:
    vars: a dictionary with string names (same as passed in GetVar) as keys and
      the corresponding TensorFlow Variables as values.
  """

  __slots__ = ["_vars", "_regularizers", "_store_eager_variables"]

  def __init__(self):
    """Create a variable store."""
    self._vars = {}  # A dictionary of the stored TensorFlow variables.
    self._regularizers = {}  # A dict mapping var names to their regularizers.
    self._store_eager_variables = True


  def get_variable(
      self,
      name,
      shape=None,
      dtype=dtypes.float32,
      initializer=None,
      regularizer=None,
      reuse=None,
      trainable=None,
      collections=None,
      caching_device=None,
      partitioner=None,
      validate_shape=True,
      use_resource=None,
      custom_getter=None,
      constraint=None,
      synchronization=vs.VariableSynchronization.AUTO,
      aggregation=vs.VariableAggregation.NONE):
    """Gets an existing variable with these parameters or creates a new one.

    If a variable with the given name is already stored, we return the stored
    variable. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), the default initializer passed in
    the constructor is used. If that one is `None` too, we use a new
    `glorot_uniform_initializer`. If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.

    If a partitioner is provided, a `PartitionedVariable` is returned.
    Accessing this object as a `Tensor` returns the shards concatenated along
    the partition axis.

    Some useful partitioners are available. See, e.g.,
    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

    Args:
      name: The name of the new or existing variable.
      shape: Shape of the new or existing variable.
      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: Initializer for the variable.
      regularizer: A (Tensor -> Tensor or None) function; the result of applying
        it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
        variables. When eager execution is enabled this argument is always
        forced to be tf.AUTO_REUSE.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
        defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys to add the `Variable` to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the `Variable` reside, to
        deduplicate copying through `Switch` and other conditional statements.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and dtype of the `Variable` to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates
        instead an experimental ResourceVariable which has well-defined
        semantics. Defaults to False (will later change to True). When eager
        execution is enabled this argument is always forced to be true.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. The signature
        of `custom_getter` should match that of this method,
        but the most future-proof version will allow for changes:
        `def custom_getter(getter, *args, **kwargs)`. Direct access to
        all `get_variable` parameters is also allowed:
        `def custom_getter(getter, name, *args, **kwargs)`. A simple identity
        custom getter that creates variables with modified names is:
        ```python
        def custom_getter(getter, name, *args, **kwargs):
          return getter(name + '_suffix', *args, **kwargs)
        ```
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      The created or existing `Variable` (or `PartitionedVariable`, if a
      partitioner was used).

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        or when violating reuse during variable creation.
      RuntimeError: when eager execution is enabled and not called from an
        EagerVariableStore.
    """

    if custom_getter is not None and not callable(custom_getter):
      raise ValueError("Passed a custom_getter which is not callable: %s" %
                       custom_getter)

    with ops.init_scope():
      if context.executing_eagerly():
        # Variable creation and initialization takes place in `init_scope`s;
        # as such, if an `init_scope` lifts us into the eager context, then we
        # need to use `ResourceVariable`s.
        use_resource = True

    # Note that it's fine to reuse eager variables whose initialization was
    # lifted from a function-building graph into the eager context (that's why
    # the following clause is not wrapped in an `init_scope`); lifted variables
    # are tracked by the graph's `VariableStore`.
    if context.executing_eagerly():
      reuse = vs.AUTO_REUSE

    # If a *_ref type is passed in an error would be triggered further down the
    # stack. We prevent this using base_dtype to get a non-ref version of the
    # type, before doing anything else. When _ref types are removed in favor of
    # resources, this line can be removed.
    try:
      dtype = dtype.base_dtype
    except AttributeError:
      # .base_dtype not existing means that we will try and use the raw dtype
      # which was passed in - this might be a NumPy type which is valid.
      pass

    # This is the main logic of get_variable. However, custom_getter
    # may override this logic. So we save it as a callable and pass
    # it to custom_getter.
    # Note: the parameters of _true_getter, and their documentation, match
    # *exactly* item-for-item with the docstring of this method.
    def _true_getter(  # pylint: disable=missing-docstring
        name,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,  # pylint: disable=unused-argument
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        use_resource=None,  # pylint: disable=unused-argument
        constraint=None,
        synchronization=vs.VariableSynchronization.AUTO,
        aggregation=vs.VariableAggregation.NONE):
      # Partitioned variable currently unsupported w/ the shim
      if partitioner is not None:
        raise ValueError(
            "`partitioner` arg for `get_variable` is unsupported in TF2. "
            "File a bug if you need help. You passed %s" % partitioner)

      # Single variable case
      if "%s/part_0" % name in self._vars:
        raise ValueError(
            "No partitioner was provided, but a partitioned version of the "
            "variable was found: %s/part_0. Perhaps a variable of the same "
            "name was already created with partitioning?" % name)

      return self._get_single_variable(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          caching_device=caching_device,
          validate_shape=validate_shape,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    synchronization, aggregation, trainable = (
        validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))

    if custom_getter is not None:
      # Handle backwards compatibility with getter arguments that were added
      # to the API after users started writing custom getters.
      custom_getter_kwargs = {
          "getter": _true_getter,
          "name": name,
          "shape": shape,
          "dtype": dtype,
          "initializer": initializer,
          "regularizer": regularizer,
          "reuse": reuse,
          "trainable": trainable,
          "collections": collections,
          "caching_device": caching_device,
          "partitioner": partitioner,
          "validate_shape": validate_shape,
          "use_resource": use_resource,
          "synchronization": synchronization,
          "aggregation": aggregation,
      }
      # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
      # `lambda`.
      if ("constraint" in fn_args(custom_getter) or
          _has_kwargs(custom_getter)):
        custom_getter_kwargs["constraint"] = constraint
      return custom_getter(**custom_getter_kwargs)
    else:
      return _true_getter(
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

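  # Illustrative sketch (comment only, not part of the original module) of the
  # reuse semantics above, calling the store directly in eager mode:
  #
  #   store = _EagerVariableStore()
  #   v1 = store.get_variable("dense/kernel", shape=[3, 4])  # created
  #   v2 = store.get_variable("dense/kernel", shape=[3, 4])  # eager forces
  #                                                          # AUTO_REUSE -> v1
  #   assert v1 is v2
  #   store.get_variable("dense/kernel", shape=[5, 4])  # ValueError: shape
  #                                                     # conflicts with v1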

  def _get_single_variable(
      self,
      name,
      shape=None,
      dtype=dtypes.float32,
      initializer=None,
      regularizer=None,
      partition_info=None,
      reuse=None,
      trainable=None,
      caching_device=None,
      validate_shape=True,
      constraint=None,
      synchronization=vs.VariableSynchronization.AUTO,
      aggregation=vs.VariableAggregation.NONE):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning components)
    for details.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object.
      reuse: see get_variable.
      trainable: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      A Variable. See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above.
    """

    # Set to true if initializer is a constant.
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
      initializing_from_value = True
    if shape is not None and initializing_from_value:
      raise ValueError("If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = as_shape(shape)

    if name in self._vars:
      # Here we handle the case when returning an existing variable.
      if reuse is False:  # pylint: disable=g-bool-id-comparison
        err_msg = ("Variable %s already exists, disallowed."
                   " Did you mean to set reuse=True or "
                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
        # ResourceVariables don't have an op associated with them, so no
        # traceback.
        raise ValueError(err_msg)
      found_var = self._vars[name]
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." %
                         (name, shape, found_var.get_shape()))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." %
                         (name, dtype_str, found_type_str))
      return found_var

    # The code below handles only the case of creating a new variable.
    if reuse is True:  # pylint: disable=g-bool-id-comparison
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)

    # Create the tensor to initialize the variable with default value.
    if initializer is None:
      initializer, initializing_from_value = self._get_default_initializer(
          name=name, shape=shape, dtype=dtype)
    # Enter an init scope when creating the initializer.
    with ops.init_scope():
      if initializing_from_value:
        init_val = initializer
        variable_dtype = None
      else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
          initializer = initializer()
        if shape.is_fully_defined():
          if "partition_info" in tf_inspect.getargspec(initializer).args:
            init_val = functools.partial(initializer,
                                         shape.as_list(),
                                         dtype=dtype,
                                         partition_info=partition_info)
          else:
            init_val = functools.partial(initializer,
                                         shape.as_list(), dtype=dtype)
          variable_dtype = dtype.base_dtype
        else:
          init_val = initializer
          variable_dtype = None

    # Create the variable (always eagerly, as a workaround for a strange
    # tpu / funcgraph / keras functional model interaction).
    with ops.init_scope():
      v = variables.Variable(
          initial_value=init_val,
          name=name,
          trainable=trainable,
          caching_device=caching_device,
          dtype=variable_dtype,
          validate_shape=validate_shape,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    self._vars[name] = v
    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Run the regularizer if requested and save the resulting loss.
    if regularizer:
      self.add_regularizer(v, regularizer)

    return v

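  # Illustrative sketch (comment only, not part of the original module) of the
  # two initializer paths above:
  #
  #   store = _EagerVariableStore()
  #   # Callable initializer: the shape is required and is passed through.
  #   a = store.get_variable(
  #       "a", shape=[2], initializer=init_ops.zeros_initializer())
  #   # Constant (non-callable) initializer: omit `shape`; it is derived
  #   # from the value itself.
  #   b = store.get_variable("b", initializer=[1.0, 2.0, 3.0])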

  def add_regularizer(self, var, regularizer):
    self._regularizers[var.name] = functools.partial(regularizer, var)


  # Initialize variable when no initializer provided
  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
    """Provide a default initializer and a corresponding value.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.

    Returns:
      initializer and initializing_from_value. See get_variable above.

    Raises:
      ValueError: When giving unsupported dtype.
    """
    del shape
    # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
    if dtype.is_floating:
      initializer = init_ops.glorot_uniform_initializer()
      initializing_from_value = False
    # If dtype is DT_INT/DT_UINT, provide a default value `zero`
    # If dtype is DT_BOOL, provide a default value `FALSE`
    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
          dtype == dtypes.string):
      initializer = init_ops.zeros_initializer()
      initializing_from_value = False
    # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
    else:
      raise ValueError("An initializer for variable %s of %s is required" %
                       (name, dtype.base_dtype))

    return initializer, initializing_from_value

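# Illustrative sketch of the defaults above (comment only, not part of the
# original module):
#
#   store = _EagerVariableStore()
#   store._get_default_initializer("w", dtype=dtypes.float32)
#   # -> (glorot_uniform initializer, False)
#   store._get_default_initializer("step", dtype=dtypes.int64)
#   # -> (zeros initializer, False)
#   store._get_default_initializer("c", dtype=dtypes.complex64)
#   # -> ValueError: an explicit initializer is required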

class VariableAndLossTracker(module.Module):
  """Module that has a scope to capture vars/losses made by `get_variable`."""

  def __init__(self):
    self._var_store = _EagerVariableStore()  # pylint: disable=protected-access
    self._variables = {}

  def _variable_creator(self, next_creator, **kwargs):
    var = next_creator(**kwargs)
    self._variables[var.name] = var

    return var

  @tf_contextlib.contextmanager
  def scope(self):
    with vs.variable_creator_scope(
        self._variable_creator), vs.with_variable_store(self._var_store):
      yield

  def get_regularization_losses(self):
    # TODO(kaftan): Consider adding a regex scope like the collection access.
    # But, < 40-50 usages of get_regularization_loss(es) with `scope`
    # & possible to do manually?
    losses = {}
    for var_name, regularizer in self._var_store._regularizers.items():  # pylint: disable=protected-access
      losses[var_name] = regularizer()
    return losses

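# Illustrative sketch (comment only, not part of the original module) of how
# the tracker captures `get_variable` usage; `regularizers.L2` here refers to
# `tf.keras.regularizers.L2`, as in the class docstring below:
#
#   tracker = VariableAndLossTracker()
#   with tracker.scope():
#     with vs.variable_scope("dense"):
#       kernel = vs.get_variable(
#           "kernel", shape=[3, 4], regularizer=regularizers.L2())
#   tracker.get_regularization_losses()  # {"dense/kernel:0": <loss tensor>}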

class VariableScopeWrapperLayer(base_layer.Layer):
  """Wrapper Layer to capture `compat.v1.get_variable` and `compat.v1.layers`.

  See go/tf2-migration-model-bookkeeping for background.

  This shim layer allows using large sets of TF1 model-forward-pass code as a
  Keras layer that works in TF2 with TF2 behaviors enabled. To use it,
  override this class and put your TF1 model's forward pass inside your
  implementation for `forward_pass`.

  Below are some examples, and then more details on the functionality of this
  shim layer to wrap TF1 model forward passes.

  Example of capturing tf.compat.v1.layer-based modeling code as a Keras layer:

  ```python
  class WrappedDoubleDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):

    def __init__(self, units, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self.units = units

    def forward_pass(self, inputs, training=None):
      out = tf.compat.v1.layers.dense(
          inputs, self.units, name="dense_one",
          kernel_initializer=init_ops.ones_initializer(),
          kernel_regularizer="l2")
      with variable_scope.variable_scope("nested_scope"):
        out = tf.compat.v1.layers.dense(
            out, self.units, name="dense_two",
            kernel_initializer=init_ops.ones_initializer(),
            kernel_regularizer="l2")
      return out

  # Create a layer that can be used as a standard keras layer
  layer = WrappedDoubleDenseLayer(10)

  # call the layer on inputs
  layer(...)

  # Variables created/used within the scope will be tracked by the layer
  layer.weights
  layer.trainable_variables

  # Regularization losses will be captured in layer.losses after a call,
  # just like any other Keras layer
  reg_losses = layer.losses
  ```


  The same can be done for `tf.compat.v1.get_variable`-based modeling code, by
  wrapping the model construction and execution in `forward_pass` in the same
  way:

  ```python
  class WrappedDoubleDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):

    def __init__(self, units, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self.units = units

    def forward_pass(self, inputs, training=None):
      out = inputs
      with tf.compat.v1.variable_scope("dense_one"):
        # The weights are created with a `regularizer`,
        # so the layer should track their regularization losses
        kernel = tf.compat.v1.get_variable(
            shape=[out.shape[-1], self.units],
            regularizer=regularizers.L2(),
            initializer=init_ops.ones_initializer(),
            name="kernel")
        bias = tf.compat.v1.get_variable(
            shape=[self.units,],
            initializer=init_ops.zeros_initializer(),
            name="bias")
        out = tf.compat.v1.math.matmul(out, kernel)
        out = tf.compat.v1.nn.bias_add(out, bias)
      with tf.compat.v1.variable_scope("nested_scope"):
        with tf.compat.v1.variable_scope("dense_two"):
          kernel = tf.compat.v1.get_variable(
              shape=[out.shape[-1], self.units],
              regularizer=regularizers.L2(),
              initializer=init_ops.ones_initializer(),
              name="kernel")
          bias = tf.compat.v1.get_variable(
              shape=[self.units,],
              initializer=init_ops.zeros_initializer(),
              name="bias")
          out = tf.compat.v1.math.matmul(out, kernel)
          out = tf.compat.v1.nn.bias_add(out, bias)
      return out

  # Create a layer that can be used as a standard keras layer
  layer = WrappedDoubleDenseLayer(10)

  # call the layer on inputs
  layer(...)

  # Variables created/used within the scope will be tracked by the layer
  layer.weights
  layer.trainable_variables

  # Regularization losses will be captured in layer.losses after a call,
  # just like any other Keras layer
  reg_losses = layer.losses
  ```


  Regularization losses:
    Any regularizers specified in the `get_variable` calls or `compat.v1.layer`
    creations will get captured by this wrapper layer. Regularization losses
    are accessible in `layer.losses` after a call just like in a standard
    Keras layer, and will be captured by any model that includes this layer.

  Variable scope / variable reuse:
    variable-scope based reuse in the `forward_pass` will be respected,
    and work like variable-scope based reuse in TF1.

  Variable Names/Pre-trained checkpoint loading:
    variable naming from get_variable and `compat.v1.layer` layers will match
    the TF1 names, so you should be able to re-use your old name-based
    checkpoints.

  Training Arg in `forward_pass`:
    Keras will pass a `training` arg to this layer similarly to how it
    passes `training` to other layers in TF2. See more details in the docs
    on `tf.keras.layers.Layer` to understand what will be passed and when.
    Note: tf.compat.v1.layers are usually not called with `training=None`,
    so the training arg to `forward_pass` might not feed through to them
    unless you pass it to their calls explicitly.

  Call signature of the forward pass:
    The semantics of the forward pass signature roughly match the standard
    Keras layer `call` signature, except that a `training` arg will *always*
    be passed, so your `forward_pass` must accept it.

  Limitations:
    * TF2 will not prune unused variable updates (or unused outputs). You may
      need to adjust your forward pass code to avoid computations or variable
      updates that you don't intend to use. (E.g. by adding a flag to the
      `forward_pass` call signature and branching on it.)
    * Avoid nesting variable creation in `tf.function` inside of
      `forward_pass`. While the layer may safely be used from inside a
      `tf.function`, using a function inside of `forward_pass` will break the
      variable scoping.
    * TBD: Nesting keras layers/models or other `VariableScopeWrapperLayer`s
      directly in `forward_pass` may not work correctly just yet.
      Support for this, and instructions for how to do it, are still being
      worked on.

  Coming soon: A better guide and a testing/verification guide.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # Relies on keras layers tracking Modules
    self.tracker = VariableAndLossTracker()
    # May need to inspect func to see if it should pass a `training` arg or not

  def forward_pass(self, *args, **kwargs):
    raise NotImplementedError

  def call(self, *args, **kwargs):
    with self.tracker.scope():
      out = self.forward_pass(*args, **kwargs)
    if not self._eager_losses:
      # We have to record regularization losses in the call as if they
      # are activity losses.
      # So, don't double-count regularization losses if the layer is used
      # multiple times in a model
      for loss in self.tracker.get_regularization_losses().values():
        self.add_loss(loss)
    return out
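
# Illustrative sketch (comment only, not part of the original module): because
# the layer owns a single `VariableAndLossTracker`, repeated calls reuse the
# variables created on the first call instead of creating new ones, mirroring
# TF1 variable-scope reuse.
#
#   layer = WrappedDoubleDenseLayer(10)  # subclass from the class docstring
#   out = layer(tf.ones([2, 4]))         # variables created on the first call
#   num_weights = len(layer.weights)
#   out = layer(tf.ones([2, 4]))         # same variables are reused
#   assert len(layer.weights) == num_weights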