# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
# pylint: disable=g-classes-have-attributes
"""Contains a shim to allow using TF1 get_variable code in TF2."""
import functools

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.module import module
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator


def as_shape(shape):
  """Converts the given object to a TensorShape."""
  if isinstance(shape, tensor_shape.TensorShape):
    return shape
  else:
    return tensor_shape.TensorShape(shape)


def _is_callable_object(obj):
  return hasattr(obj, "__call__") and tf_inspect.ismethod(obj.__call__)


def _has_kwargs(fn):
  """Returns whether the passed callable has **kwargs in its signature.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `bool`: whether `fn` has **kwargs in its signature.

  Raises:
    `TypeError`: If `fn` is not a function or function-like object.
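
  For example (an illustrative sketch of the intended behavior, using
  hypothetical callables):

  ```python
  def getter(name, **kwargs):
    return name

  _has_kwargs(getter)                     # True
  _has_kwargs(functools.partial(getter))  # True
  _has_kwargs(lambda name: name)          # False
  ```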
  """
  if isinstance(fn, functools.partial):
    fn = fn.func
  elif _is_callable_object(fn):
    fn = fn.__call__
  elif not callable(fn):
    raise TypeError(
        "fn should be a function-like object, but is of type {}.".format(
            type(fn)))
  return tf_inspect.getfullargspec(fn).varkw is not None


def fn_args(fn):
  """Get argument names for function-like object.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.

  Raises:
    ValueError: If a partial function has positionally bound arguments.
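
  For example (an illustrative sketch of the intended behavior, using a
  hypothetical getter):

  ```python
  def getter(name, shape=None, dtype=None):
    return name

  fn_args(getter)                                 # ('name', 'shape', 'dtype')
  fn_args(functools.partial(getter, dtype=None))  # ('name', 'shape')
  ```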
  """
  if isinstance(fn, functools.partial):
    args = fn_args(fn.func)
    args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
  else:
    if hasattr(fn, "__call__") and tf_inspect.ismethod(fn.__call__):
      fn = fn.__call__
    args = tf_inspect.getfullargspec(fn).args
    if _is_bound_method(fn) and args:
      # If it's a bound method, it may or may not have a self/cls first
      # argument; for example, self could be captured in *args.
      # If it does have a positional argument, it is self/cls.
      args.pop(0)
  return tuple(args)


def _is_bound_method(fn):
  _, fn = tf_decorator.unwrap(fn)
  return tf_inspect.ismethod(fn) and (fn.__self__ is not None)


def validate_synchronization_aggregation_trainable(
    synchronization, aggregation, trainable, name):
  """Given user-provided variable properties, sets defaults and validates."""
  if aggregation is None:
    aggregation = variables.VariableAggregation.NONE
  else:
    if not isinstance(aggregation,
                      (variables.VariableAggregation,
                       variables.VariableAggregationV2)):
      try:
        aggregation = variables.VariableAggregationV2(aggregation)
      except ValueError:
        raise ValueError(
            "Invalid variable aggregation mode: {} for variable: {}".format(
                aggregation, name))
  if synchronization is None:
    synchronization = variables.VariableSynchronization.AUTO
  else:
    try:
      synchronization = variables.VariableSynchronization(synchronization)
    except ValueError:
      raise ValueError(
          "Invalid variable synchronization mode: {} for variable: {}".format(
              synchronization, name))
  if trainable is None:
    trainable = synchronization != variables.VariableSynchronization.ON_READ
  return synchronization, aggregation, trainable
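
# For example (illustrative): with `synchronization=None`, `aggregation=None`,
# and `trainable=None`, the helper above returns
# (VariableSynchronization.AUTO, VariableAggregation.NONE, True); passing
# `synchronization=VariableSynchronization.ON_READ` instead flips the
# `trainable` default to False.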


class _EagerVariableStore(object):
  """TF2-compatible VariableStore that avoids collections & tracks regularizers.

  New variable names and new variables can be created; all stored
  variables are initialized with the initializer passed to `get_variable`
  (or with a default initializer if none is given).

  All variables get created in `tf.init_scope` to avoid a bad
  interaction between `tf.function` `FuncGraph` internals, Keras
  Functional Models, and TPUStrategy variable initialization.

  Attributes:
    vars: a dictionary with string names (same as passed in GetVar) as keys and
      the corresponding TensorFlow Variables as values.
  """

  __slots__ = ["_vars", "_regularizers", "_store_eager_variables"]

  def __init__(self):
    """Create a variable store."""
    self._vars = {}  # A dictionary of the stored TensorFlow variables.
    self._regularizers = {}  # A dict mapping var names to their regularizers.
    self._store_eager_variables = True

  def get_variable(
      self,
      name,
      shape=None,
      dtype=dtypes.float32,
      initializer=None,
      regularizer=None,
      reuse=None,
      trainable=None,
      collections=None,
      caching_device=None,
      partitioner=None,
      validate_shape=True,
      use_resource=None,
      custom_getter=None,
      constraint=None,
      synchronization=vs.VariableSynchronization.AUTO,
      aggregation=vs.VariableAggregation.NONE):
    """Gets an existing variable with these parameters or creates a new one.

    If a variable with the given name is already stored, we return the stored
    variable. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.
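
    For example (an illustrative sketch, where `store` is assumed to be an
    `_EagerVariableStore` instance):

    ```python
    v1 = store.get_variable("w", shape=[3, 4])              # creates "w"
    v2 = store.get_variable("w", shape=[3, 4], reuse=True)  # returns "w" again
    assert v1 is v2
    ```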

    If initializer is `None` (the default), the default initializer chosen for
    the variable's dtype is used (a new `glorot_uniform_initializer` for
    floating-point dtypes). If initializer is a Tensor, we use it as a value
    and derive the shape from the initializer.

    If a partitioner is provided, a `PartitionedVariable` is returned.
    Accessing this object as a `Tensor` returns the shards concatenated along
    the partition axis.

    Some useful partitioners are available. See, e.g.,
    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

    Args:
      name: The name of the new or existing variable.
      shape: Shape of the new or existing variable.
      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: Initializer for the variable.
      regularizer: A (Tensor -> Tensor or None) function; the result of
        applying it on a newly created variable will be added to the
        collection GraphKeys.REGULARIZATION_LOSSES and can be used for
        regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
        variables. When eager execution is enabled this argument is always
        forced to be False.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
        defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys to add the `Variable` to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the `Variable` reside, to
        deduplicate copying through `Switch` and other conditional statements.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and dtype of the `Variable` to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of
        initial_value must be known.
      use_resource: If False, creates a regular Variable. If True, creates
        instead an experimental ResourceVariable which has well-defined
        semantics. Defaults to False (will later change to True). When eager
        execution is enabled this argument is always forced to be True.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. The signature
        of `custom_getter` should match that of this method,
        but the most future-proof version will allow for changes: `def
        custom_getter(getter, *args, **kwargs)`. Direct access to
        all `get_variable` parameters is also allowed: `def
        custom_getter(getter, name, *args, **kwargs)`. A simple identity
        custom getter that simply creates variables with modified names is:
        ```python
        def custom_getter(getter, name, *args, **kwargs):
          return getter(name + '_suffix', *args, **kwargs)
        ```
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must
        have the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      The created or existing `Variable` (or `PartitionedVariable`, if a
      partitioner was used).

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        or when violating reuse during variable creation.
      RuntimeError: when eager execution is enabled and not called from an
        EagerVariableStore.
    """
    if custom_getter is not None and not callable(custom_getter):
      raise ValueError("Passed a custom_getter which is not callable: %s" %
                       custom_getter)

    with ops.init_scope():
      if context.executing_eagerly():
        # Variable creation and initialization takes place in `init_scope`s;
        # as such, if an `init_scope` lifts us into the eager context, then we
        # need to use `ResourceVariable`s.
        use_resource = True

    # Note that it's fine to reuse eager variables whose initialization was
    # lifted from a function-building graph into the eager context (that's why
    # the following clause is not wrapped in an `init_scope`); lifted variables
    # are tracked by the graph's `VariableStore`.
    if context.executing_eagerly():
      reuse = vs.AUTO_REUSE

    # If a *_ref type is passed in an error would be triggered further down the
    # stack. We prevent this using base_dtype to get a non-ref version of the
    # type, before doing anything else. When _ref types are removed in favor of
    # resources, this line can be removed.
    try:
      dtype = dtype.base_dtype
    except AttributeError:
      # .base_dtype not existing means that we will try and use the raw dtype
      # which was passed in - this might be a NumPy type which is valid.
      pass

    # This is the main logic of get_variable. However, custom_getter
    # may override this logic. So we save it as a callable and pass
    # it to custom_getter.
    # Note: the parameters of _true_getter, and their documentation, match
    # *exactly* item-for-item with the docstring of this method.
    def _true_getter(  # pylint: disable=missing-docstring
        name,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,  # pylint: disable=unused-argument
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        use_resource=None,  # pylint: disable=unused-argument
        constraint=None,
        synchronization=vs.VariableSynchronization.AUTO,
        aggregation=vs.VariableAggregation.NONE):
      # Partitioned variable currently unsupported w/ the shim
      if partitioner is not None:
        raise ValueError(
            "`partitioner` arg for `get_variable` is unsupported in TF2. "
            "File a bug if you need help. You passed %s" % partitioner)

      # Single variable case
      if "%s/part_0" % name in self._vars:
        raise ValueError(
            "No partitioner was provided, but a partitioned version of the "
            "variable was found: %s/part_0. Perhaps a variable of the same "
            "name was already created with partitioning?" % name)

      return self._get_single_variable(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          caching_device=caching_device,
          validate_shape=validate_shape,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    synchronization, aggregation, trainable = (
        validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))

    if custom_getter is not None:
      # Handle backwards compatibility with getter arguments that were added
      # to the API after users started writing custom getters.
      custom_getter_kwargs = {
          "getter": _true_getter,
          "name": name,
          "shape": shape,
          "dtype": dtype,
          "initializer": initializer,
          "regularizer": regularizer,
          "reuse": reuse,
          "trainable": trainable,
          "collections": collections,
          "caching_device": caching_device,
          "partitioner": partitioner,
          "validate_shape": validate_shape,
          "use_resource": use_resource,
          "synchronization": synchronization,
          "aggregation": aggregation,
      }
      # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
      # `lambda`.
      if ("constraint" in fn_args(custom_getter) or
          _has_kwargs(custom_getter)):
        custom_getter_kwargs["constraint"] = constraint
      return custom_getter(**custom_getter_kwargs)
    else:
      return _true_getter(
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

  def _get_single_variable(
      self,
      name,
      shape=None,
      dtype=dtypes.float32,
      initializer=None,
      regularizer=None,
      partition_info=None,
      reuse=None,
      trainable=None,
      caching_device=None,
      validate_shape=True,
      constraint=None,
      synchronization=vs.VariableSynchronization.AUTO,
      aggregation=vs.VariableAggregation.NONE):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning
    components) for details.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object.
      reuse: see get_variable.
      trainable: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      A Variable. See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above.
    """
    # Set to true if initializer is a constant.
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
      initializing_from_value = True
      if shape is not None and initializing_from_value:
        raise ValueError("If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = as_shape(shape)

    if name in self._vars:
      # Here we handle the case when returning an existing variable.
      if reuse is False:  # pylint: disable=g-bool-id-comparison
        err_msg = ("Variable %s already exists, disallowed."
                   " Did you mean to set reuse=True or "
                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
        # ResourceVariables don't have an op associated with them, so no
        # traceback.
        raise ValueError(err_msg)
      found_var = self._vars[name]
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." %
                         (name, shape, found_var.get_shape()))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." %
                         (name, dtype_str, found_type_str))
      return found_var

    # The code below handles only the case of creating a new variable.
    if reuse is True:  # pylint: disable=g-bool-id-comparison
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)

    # Create the tensor to initialize the variable with default value.
    if initializer is None:
      initializer, initializing_from_value = self._get_default_initializer(
          name=name, shape=shape, dtype=dtype)
    # Enter an init scope when creating the initializer.
    with ops.init_scope():
      if initializing_from_value:
        init_val = initializer
        variable_dtype = None
      else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
          initializer = initializer()
        if shape.is_fully_defined():
          if "partition_info" in tf_inspect.getargspec(initializer).args:
            init_val = functools.partial(initializer,
                                         shape.as_list(),
                                         dtype=dtype,
                                         partition_info=partition_info)
          else:
            init_val = functools.partial(initializer,
                                         shape.as_list(), dtype=dtype)
          variable_dtype = dtype.base_dtype
        else:
          init_val = initializer
          variable_dtype = None

    # Create the variable (always eagerly, as a workaround for a strange
    # TPU / FuncGraph / Keras functional model interaction).
    with ops.init_scope():
      v = variables.Variable(
          initial_value=init_val,
          name=name,
          trainable=trainable,
          caching_device=caching_device,
          dtype=variable_dtype,
          validate_shape=validate_shape,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    self._vars[name] = v
    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Run the regularizer if requested and save the resulting loss.
    if regularizer:
      self.add_regularizer(v, regularizer)

    return v

  def add_regularizer(self, var, regularizer):
    self._regularizers[var.name] = functools.partial(regularizer, var)
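    # For example (illustrative): if `regularizer` is `tf.nn.l2_loss`, the
    # stored partial recomputes the penalty from the variable's current value
    # each time `self._regularizers[var.name]()` is called later.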

  # Initialize variable when no initializer provided
  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
    """Provide a default initializer and a corresponding value.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.

    Returns:
      initializer and initializing_from_value. See get_variable above.

    Raises:
      ValueError: When giving unsupported dtype.
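
    For example (an illustrative sketch; `store` is assumed to be an
    `_EagerVariableStore` instance):

    ```python
    init, from_value = store._get_default_initializer(
        "w", dtype=dtypes.float32)
    # `init` is a glorot_uniform initializer; `from_value` is False.
    ```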
    """
    del shape
    # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
    if dtype.is_floating:
      initializer = init_ops.glorot_uniform_initializer()
      initializing_from_value = False
    # If dtype is DT_INT/DT_UINT, provide a default value `zero`
    # If dtype is DT_BOOL, provide a default value `FALSE`
    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
          dtype == dtypes.string):
      initializer = init_ops.zeros_initializer()
      initializing_from_value = False
    # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
    else:
      raise ValueError("An initializer for variable %s of %s is required" %
                       (name, dtype.base_dtype))

    return initializer, initializing_from_value


class VariableAndLossTracker(module.Module):
  """Module that has a scope to capture vars/losses made by `get_variable`."""

  def __init__(self):
    self._var_store = _EagerVariableStore()  # pylint: disable=protected-access
    self._variables = {}

  def _variable_creator(self, next_creator, **kwargs):
    var = next_creator(**kwargs)
    self._variables[var.name] = var

    return var

  @tf_contextlib.contextmanager
  def scope(self):
    with vs.variable_creator_scope(
        self._variable_creator), vs.with_variable_store(self._var_store):
      yield

  def get_regularization_losses(self):
    # TODO(kaftan): Consider adding a regex scope like the collection access.
    # But, < 40-50 usages of get_regularization_loss(es) with `scope`
    # & possible to do manually?
    losses = {}
    for var_name, regularizer in self._var_store._regularizers.items():  # pylint: disable=protected-access
      losses[var_name] = regularizer()
    return losses
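
  # Illustrative usage sketch (hypothetical caller code, assuming TF2 eager
  # execution):
  #
  #   tracker = VariableAndLossTracker()
  #   with tracker.scope():
  #     var = tf.compat.v1.get_variable(
  #         "w", shape=[2], regularizer=tf.nn.l2_loss)
  #   tracker.get_regularization_losses()  # {"w:0": <the l2 penalty tensor>}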


class VariableScopeWrapperLayer(base_layer.Layer):
  """Wrapper Layer to capture `compat.v1.get_variable` and `compat.v1.layers`.

  See go/tf2-migration-model-bookkeeping for background.

  This shim layer allows using large sets of TF1 model-forward-pass code as a
  Keras layer that works in TF2 with TF2 behaviors enabled. To use it,
  override this class and put your TF1 model's forward pass inside your
  implementation for `forward_pass`.

  Below are some examples, and then more details on the functionality of this
  shim layer to wrap TF1 model forward passes.

  Example of capturing tf.compat.v1.layer-based modeling code as a Keras layer:

  ```python
  class WrappedDoubleDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):

    def __init__(self, units, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self.units = units

    def forward_pass(self, inputs, training=None):
      out = tf.compat.v1.layers.dense(
          inputs, self.units, name="dense_one",
          kernel_initializer=init_ops.ones_initializer(),
          kernel_regularizer="l2")
      with variable_scope.variable_scope("nested_scope"):
        out = tf.compat.v1.layers.dense(
            out, self.units, name="dense_two",
            kernel_initializer=init_ops.ones_initializer(),
            kernel_regularizer="l2")
      return out

  # Create a layer that can be used as a standard keras layer
  layer = WrappedDoubleDenseLayer(10)

  # call the layer on inputs
  layer(...)

  # Variables created/used within the scope will be tracked by the layer
  layer.weights
  layer.trainable_variables

  # Regularization losses will be captured in layer.losses after a call,
  # just like any other Keras layer
  reg_losses = layer.losses
  ```

  The same model can equivalently be written with direct
  `tf.compat.v1.get_variable` calls inside the wrapped `forward_pass`:

  ```python
  class WrappedDoubleDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):

    def __init__(self, units, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self.units = units

    def forward_pass(self, inputs, training=None):
      out = inputs
      with tf.compat.v1.variable_scope("dense_one"):
        # The weights are created with a `regularizer`,
        # so the layer should track their regularization losses
        kernel = tf.compat.v1.get_variable(
            shape=[out.shape[-1], self.units],
            regularizer=regularizers.L2(),
            initializer=init_ops.ones_initializer(),
            name="kernel")
        bias = tf.compat.v1.get_variable(
            shape=[self.units,],
            initializer=init_ops.zeros_initializer(),
            name="bias")
        out = tf.compat.v1.math.matmul(out, kernel)
        out = tf.compat.v1.nn.bias_add(out, bias)
      with tf.compat.v1.variable_scope("nested_scope"):
        with tf.compat.v1.variable_scope("dense_two"):
          kernel = tf.compat.v1.get_variable(
              shape=[out.shape[-1], self.units],
              regularizer=regularizers.L2(),
              initializer=init_ops.ones_initializer(),
              name="kernel")
          bias = tf.compat.v1.get_variable(
              shape=[self.units,],
              initializer=init_ops.zeros_initializer(),
              name="bias")
          out = tf.compat.v1.math.matmul(out, kernel)
          out = tf.compat.v1.nn.bias_add(out, bias)
      return out

  # Create a layer that can be used as a standard keras layer
  layer = WrappedDoubleDenseLayer(10)

  # call the layer on inputs
  layer(...)

  # Variables created/used within the scope will be tracked by the layer
  layer.weights
  layer.trainable_variables

  # Regularization losses will be captured in layer.losses after a call,
  # just like any other Keras layer
  reg_losses = layer.losses
  ```

  Regularization losses:
    Any regularizers specified in the `get_variable` calls or `compat.v1.layer`
    creations will get captured by this wrapper layer. Regularization losses
    are accessible in `layer.losses` after a call just like in a standard
    Keras layer, and will be captured by any model that includes this layer.

  Variable scope / variable reuse:
    variable-scope based reuse in the `forward_pass` will be respected,
    and work like variable-scope based reuse in TF1.

  Variable Names/Pre-trained checkpoint loading:
    variable naming from get_variable and `compat.v1.layer` layers will match
    the TF1 names, so you should be able to re-use your old name-based
    checkpoints.

  Training Arg in `forward_pass`:
    Keras will pass a `training` arg to this layer similarly to how it
    passes `training` to other layers in TF2. See more details in the docs
    on `tf.keras.layers.Layer` to understand what will be passed and when.
    Note: tf.compat.v1.layers are usually not called with `training=None`,
    so the training arg to `forward_pass` might not feed through to them
    unless you pass it to their calls explicitly.

  Call signature of the forward pass:
    The semantics of the forward pass signature roughly match the standard
    Keras layer `call` signature, except that a `training` arg will *always*
    be passed, so your `forward_pass` must be able to accept it.

  Limitations:
    * TF2 will not prune unused variable updates (or unused outputs). You may
      need to adjust your forward pass code to avoid computations or variable
      updates that you don't intend to use. (E.g. by adding a flag to the
      `forward_pass` call signature and branching on it).
    * Avoid nesting variable creation in `tf.function`s inside of
      `forward_pass`. While the layer may safely be used from inside a
      `tf.function`, using a `tf.function` inside of `forward_pass` will break
      the variable scoping.
    * TBD: Nesting Keras layers/models or other `VariableScopeWrapperLayer`s
      directly in `forward_pass` may not work correctly just yet.
      Support for this, and instructions for how to do it, is still being
      worked on.

  Coming soon: A better guide, plus a testing/verification guide.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # Relies on keras layers tracking Modules
    self.tracker = VariableAndLossTracker()
    # May need to inspect func to see if it should pass a `training` arg or not

  def forward_pass(self, *args, **kwargs):
    raise NotImplementedError

  def call(self, *args, **kwargs):
    with self.tracker.scope():
      out = self.forward_pass(*args, **kwargs)
    if not self._eager_losses:
      # We have to record regularization losses in the call as if they
      # are activity losses.
      # So, don't double-count regularization losses if the layer is used
      # multiple times in a model
      for loss in self.tracker.get_regularization_losses().values():
        self.add_loss(loss)
    return out