# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
# pylint: disable=g-classes-have-attributes
"""Contains the base Layer class, from which all layers inherit."""
import copy
import warnings

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.legacy_tf_layers import variable_scope_shim
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

# Avoid breaking users who directly import this symbol from this file.
# TODO(fchollet): remove this.
InputSpec = base_layer.InputSpec  # pylint: disable=invalid-name

_KERAS_STYLE_SCOPE = False


@keras_export(
    v1=['keras.__internal__.legacy.layers.experimental.keras_style_scope'])
@tf_contextlib.contextmanager
def keras_style_scope():
  """Use Keras-style variable management.

  All tf.layers and tf RNN cells created in this scope use Keras-style
  variable management. Creating such layers with a scope= argument is
  disallowed, and reuse=True is disallowed.

  The purpose of this scope is to allow users of existing layers to
  slowly transition to a Keras layers API without breaking existing
  functionality.

  One example of this is when using TensorFlow's RNN classes with Keras
  Models or Networks. Because Keras models do not properly set variable
  scopes, users of RNNs may either accidentally share scopes between two
  different models, or get errors about variables that already exist.

  Example:

  ```python
  class RNNModel(tf.keras.Model):

    def __init__(self, name):
      super(RNNModel, self).__init__(name=name)
      self.rnn = tf.compat.v1.nn.rnn_cell.MultiRNNCell(
          [tf.compat.v1.nn.rnn_cell.LSTMCell(64) for _ in range(2)])

    def call(self, input, state):
      return self.rnn(input, state)

  model_1 = RNNModel("model_1")
  model_2 = RNNModel("model_2")

  # OK
  output_1, next_state_1 = model_1(input, state)
  # Raises an error about trying to create an already existing variable.
  output_2, next_state_2 = model_2(input, state)
  ```

  The solution is to wrap the model construction and execution in a keras-style
  scope:

  ```python
  with keras_style_scope():
    model_1 = RNNModel("model_1")
    model_2 = RNNModel("model_2")

    # model_1 and model_2 are guaranteed to create their own variables.
    output_1, next_state_1 = model_1(input, state)
    output_2, next_state_2 = model_2(input, state)

    assert len(model_1.weights) > 0
    assert len(model_2.weights) > 0
    assert model_1.weights != model_2.weights
  ```

  Yields:
    A keras layer style scope.
  """
  global _KERAS_STYLE_SCOPE
  stack = _KERAS_STYLE_SCOPE
  _KERAS_STYLE_SCOPE = True
  try:
    yield
  finally:
    _KERAS_STYLE_SCOPE = stack


@keras_export(
    v1=['keras.__internal__.legacy.layers.experimental.set_keras_style'])
def set_keras_style():
  """Use Keras-style variable management.

  All tf.layers and tf RNN cells created after keras style has been enabled
  use Keras-style variable management. Creating such layers with a
  scope= argument is disallowed, and reuse=True is disallowed.

  The purpose of this function is to allow users of existing layers to
  slowly transition to the Keras layers API without breaking existing
  functionality.

  For more details, see the documentation for `keras_style_scope`.

  Note, once keras style has been set, it is set globally for the entire
  program and cannot be unset.

  Example:

  ```python
  set_keras_style()

  model_1 = RNNModel(name="model_1")
  model_2 = RNNModel(name="model_2")

  # model_1 and model_2 are guaranteed to create their own variables.
  output_1, next_state_1 = model_1(input, state)
  output_2, next_state_2 = model_2(input, state)

  assert len(model_1.weights) > 0
  assert len(model_2.weights) > 0
  assert model_1.weights != model_2.weights
  ```
  """
  global _KERAS_STYLE_SCOPE
  _KERAS_STYLE_SCOPE = True


def _is_in_keras_style_scope():
  global _KERAS_STYLE_SCOPE
  return _KERAS_STYLE_SCOPE


@keras_export(v1=['keras.__internal__.legacy.layers.Layer'])
class Layer(base_layer.Layer):
  """Base layer class.

  It is considered legacy, and we recommend the use of `tf.keras.layers.Layer`
  instead.

  Args:
    trainable: Boolean, whether the layer's variables should be trainable.
    name: String name of the layer.
    dtype: Default dtype of the layer's weights (default of `None` means use
      the type of the first input).

  Read-only properties:
    name: The name of the layer (string).
    dtype: Default dtype of the layer's weights (default of `None` means use
      the type of the first input).
    trainable_variables: List of trainable variables.
    non_trainable_variables: List of non-trainable variables.
    variables: List of all variables of this layer, trainable and
      non-trainable.
    updates: List of update ops of this layer.
    losses: List of losses added by this layer.
    trainable_weights: List of variables to be included in backprop.
    non_trainable_weights: List of variables that should not be
      included in backprop.
    weights: The concatenation of the lists trainable_weights and
      non_trainable_weights (in this order).

  Mutable properties:
    trainable: Whether the layer should be trained (boolean).
    input_spec: Optional (list of) `InputSpec` object(s) specifying the
      constraints on inputs that can be accepted by the layer.
  """

  def __init__(self, trainable=True, name=None, dtype=None,
               **kwargs):
    # For backwards compatibility, legacy layers do not use `ResourceVariable`
    # by default.
    self._use_resource_variables = False
    scope = kwargs.pop('_scope', None)
    self._reuse = kwargs.pop('_reuse', None)

    # Avoid an incorrect lint error
    self._trainable_weights = []
    self.built = False

    if dtype is None:
      # Indicates to infer dtype from inputs. When the V2 dtype behavior is
      # enabled, Keras layers default their dtype to floatx instead, so we pass
      # an "_infer" policy to keep the old V1 behavior.
      dtype = policy.Policy('_infer')

    if 'autocast' not in kwargs:
      kwargs['autocast'] = False

    # Mark that legacy layers should not be instrumented as Keras usage
    self._disable_keras_instrumentation = True

    super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype,
                                **kwargs)

    if _is_in_keras_style_scope():
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      if self._reuse is not None:
        raise ValueError(
            'reuse argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(self._reuse))
      self._keras_style = True
    else:
      self._keras_style = False

    self._call_has_scope_arg = 'scope' in self._call_fn_args
    if scope:
      with vs.variable_scope(scope) as captured_scope:
        self._scope = captured_scope
    else:
      self._scope = None
    self._current_scope = None
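
  # A minimal construction sketch (illustrative comment only; `MyDense` is a
  # hypothetical subclass, not part of this module). The private `_scope` and
  # `_reuse` kwargs popped above bind the layer to a tf.compat.v1 variable
  # scope at construction time:
  #
  #   layer = MyDense(units=4, _scope='shared', _reuse=True)
  #   y = layer(x)  # variables are looked up under the 'shared' scope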

  # We no longer track graph in tf.layers layers. This property is only kept to
  # maintain API backward compatibility.
  @property
  def graph(self):
    warnings.warn('`Layer.graph` is deprecated and '
                  'will be removed in a future version. '
                  'Please stop using this property because tf.layers layers no '
                  'longer track their graph.')
    if context.executing_eagerly():
      raise RuntimeError('Layer.graph not supported when executing eagerly.')
    return None

  def _init_set_name(self, name):
    # Determine layer name (non-unique).
    if isinstance(name, vs.VariableScope):
      base_name = name.name
      self._name, _ = self._make_unique_name()
    else:
      base_name = name
      self._name = name
    if not name:
      self._name, base_name = self._make_unique_name()
    self._base_name = base_name

  def _make_unique_name(self, name_uid_map=None, avoid_names=None,
                        namespace='', zero_based=False):
    base_name = base_layer.to_snake_case(self.__class__.__name__)
    name = backend.unique_object_name(
        base_name,
        name_uid_map=name_uid_map,
        avoid_names=avoid_names,
        namespace=namespace,
        zero_based=zero_based)
    return (name, base_name)
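
  # Naming sketch (illustrative comment only; `MyDense` is a hypothetical
  # subclass): the base name is the snake_cased class name, uniquified via the
  # name-uid map, so successive unnamed instances would typically be named
  # 'my_dense', 'my_dense_1', 'my_dense_2', and so on.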

  @property
  def scope_name(self):
    if not self._scope:
      raise ValueError(
          'No name available for layer scope because the layer "' +
          self._name + '" has not been used yet. The scope name is '
          'determined the first time the layer instance is called. You must '
          'therefore call the layer before querying `scope_name`.')
    return self._scope.name
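
  # For example (illustrative comment only; `layer` is a hypothetical
  # instance of a `Layer` subclass):
  #
  #   y = layer(x)       # the first call captures the variable scope
  #   layer.scope_name   # e.g. 'my_dense', depending on enclosing scopes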

  def add_loss(self, losses, inputs=None):
    previous_losses_length = len(self._losses)
    previous_callable_losses_length = len(self._callable_losses)
    super(Layer, self).add_loss(losses, inputs=inputs)
    if not context.executing_eagerly():
      # TODO(fchollet): deprecate collection below.
      new_losses = self._losses[previous_losses_length:]
      new_callable_losses = self._callable_losses[
          previous_callable_losses_length:]
      for regularizer in new_callable_losses:
        loss_tensor = regularizer()
        if loss_tensor is not None:
          new_losses.append(loss_tensor)
      _add_elements_to_collection(
          new_losses,
          ops.GraphKeys.REGULARIZATION_LOSSES)
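
  # A minimal graph-mode sketch (illustrative comment only): losses added here
  # are mirrored into the REGULARIZATION_LOSSES collection, so a training loop
  # can recover them with:
  #
  #   reg_losses = tf.compat.v1.get_collection(
  #       tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)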

  def _name_scope(self):  # pylint: disable=method-hidden
    """Determines op naming for the Layer."""
    if self._keras_style:
      return super(Layer, self)._name_scope()
    return self._current_scope.original_name_scope

  def _set_scope(self, scope=None):
    if self._scope is None:
      # If constructed with _scope=None, lazy setting of scope.
      if self._reuse:
        with vs.variable_scope(
            scope if scope is not None else self._base_name) as captured_scope:
          self._scope = captured_scope
      else:
        with vs.variable_scope(
            scope, default_name=self._base_name) as captured_scope:
          self._scope = captured_scope

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 constraint=None,
                 use_resource=None,
                 synchronization=vs.VariableSynchronization.AUTO,
                 aggregation=vs.VariableAggregation.NONE,
                 partitioner=None,
                 **kwargs):
    """Adds a new variable to the layer, or gets an existing one; returns it.

    Args:
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: initializer instance (callable).
      regularizer: regularizer instance (callable).
      trainable: whether the variable should be part of the layer's
        "trainable_variables" (e.g. weights, biases)
        or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
        Note, if the current variable scope is marked as non-trainable
        then this parameter is ignored and any added variables are also
        marked as non-trainable. `trainable` defaults to `True` unless
        `synchronization` is set to `ON_READ`.
      constraint: constraint instance (callable).
      use_resource: Whether to use a `ResourceVariable`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      partitioner: (optional) partitioner instance (callable). If
        provided, when the requested variable is created it will be split
        into multiple partitions according to `partitioner`. In this case,
        an instance of `PartitionedVariable` is returned. Available
        partitioners include `tf.compat.v1.fixed_size_partitioner` and
        `tf.compat.v1.variable_axis_size_partitioner`. For more details, see
        the documentation of `tf.compat.v1.get_variable` and the "Variable
        Partitioners and Sharding" section of the API guide.
      **kwargs: Additional keyword arguments.

    Returns:
      The created variable. Usually either a `Variable` or `ResourceVariable`
      instance. If `partitioner` is not `None`, a `PartitionedVariable`
      instance is returned.

    Raises:
      RuntimeError: If called with partitioned variable regularization and
        eager execution is enabled.
      ValueError: When trainable has been set to True with synchronization
        set as `ON_READ`.
    """
    for kwarg in kwargs:
      if kwarg != 'experimental_autocast':
        raise TypeError('Unknown keyword argument: %s' % kwarg)
    if self._keras_style:
      return super(Layer, self).add_weight(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          trainable=trainable and self.trainable,
          constraint=constraint,
          use_resource=use_resource,
          synchronization=vs.VariableSynchronization.AUTO,
          aggregation=vs.VariableAggregation.NONE,
          partitioner=partitioner,
          **kwargs)

    if synchronization == vs.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            'Synchronization value can be set to '
            'VariableSynchronization.ON_READ only for non-trainable variables. '
            'You have specified trainable=True and '
            'synchronization=VariableSynchronization.ON_READ.')
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    def _should_add_regularizer(variable, existing_variable_set):
      if base_layer_utils.is_split_variable(variable):
        for var in variable:
          if var in existing_variable_set:
            return False
        return True
      else:
        return variable not in existing_variable_set

    init_graph = None
    if not context.executing_eagerly():
      default_graph = ops.get_default_graph()
      if default_graph.building_function:
        with ops.init_scope():
          # Retrieve the variables from the graph into which variables
          # will be lifted; if initialization ops will be lifted into
          # the eager context, then there is nothing to retrieve, since variable
          # collections are not supported when eager execution is enabled.
          if not context.executing_eagerly():
            init_graph = ops.get_default_graph()
            existing_variables = set(tf_variables.global_variables())
      else:
        # Initialization ops will not be lifted out of the default graph.
        init_graph = default_graph
        existing_variables = set(tf_variables.global_variables())

    if dtype is None:
      dtype = self.dtype or dtypes.float32

    self._set_scope(None)
    reuse = self.built or self._reuse
    prev_len_trainable = len(self._trainable_weights)
    with vs.variable_scope(
        self._scope, reuse=reuse, auxiliary_name_scope=False) as scope:
      self._current_scope = scope
      with backend.name_scope(self._name_scope()):  # pylint: disable=not-callable
        use_resource = (use_resource or
                        self._use_resource_variables or
                        scope.use_resource)
        if initializer is None:
          initializer = scope.initializer
        variable = super(Layer, self).add_weight(
            name,
            shape,
            dtype=dtypes.as_dtype(dtype),
            initializer=initializer,
            trainable=trainable and self.trainable,
            constraint=constraint,
            partitioner=partitioner,
            use_resource=use_resource,
            synchronization=synchronization,
            aggregation=aggregation,
            getter=vs.get_variable,
            **kwargs)

        if regularizer:
          if (ops.executing_eagerly_outside_functions()
              or _should_add_regularizer(variable, existing_variables)):
            self._handle_weight_regularization(name, variable, regularizer)
          var_store = vs._get_default_variable_store()  # pylint: disable=protected-access
          # When the shim to get variable scope working in TF2 is used,
          # we need to explicitly make the shim track the regularization
          # losses as the collections will not be accessible.
          if hasattr(var_store, 'add_regularizer'):
            var_store.add_regularizer(variable, regularizer)

        if init_graph is not None:
          # Handle edge case where a custom getter has overridden `trainable`.
          # There is one known occurrence of this, in unit test
          # testBasicRNNCellNotTrainable in
          # contrib.rnn.python.kernel_tests.core_rnn_cell_test
          with init_graph.as_default():
            trainable_variables = tf_variables.trainable_variables()
          if (trainable and self.trainable and
              variable not in trainable_variables):
            # A custom getter / variable scope overrode the trainable flag.
            extra_trainable_vars = self._trainable_weights[prev_len_trainable:]
            self._trainable_weights = self._trainable_weights[
                :prev_len_trainable]
            self._non_trainable_weights += extra_trainable_vars
    return variable
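
  # A minimal usage sketch (illustrative comment only; `MyScale` is a
  # hypothetical subclass, not part of this module). Because `add_weight`
  # routes creation through `tf.compat.v1.get_variable`, the variable obeys
  # the surrounding variable scope, including `reuse`:
  #
  #   class MyScale(Layer):
  #
  #     def build(self, input_shape):
  #       self.scale = self.add_weight(
  #           'scale', shape=[], initializer=tf.compat.v1.ones_initializer())
  #       self.built = True
  #
  #     def call(self, inputs):
  #       return inputs * self.scale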

  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Args:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
    """
    scope = kwargs.pop('scope', None)

    if self._keras_style:
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      return super(Layer, self).__call__(inputs, *args, **kwargs)

    self._set_scope(scope)

    if self.built:
      try:
        # Some classes which inherit from Layer do not use its constructor, so
        # rather than initializing to None we check for an AttributeError.
        scope_context_manager = self._always_reuse_variable_scope  # pylint: disable=access-member-before-definition
      except AttributeError:
        scope_context_manager = None

      if scope_context_manager is None:
        # From this point we will always set reuse=True, so create a "final"
        # variable scope with this setting. We avoid re-creating variable scopes
        # after this point as an optimization.
        scope_context_manager = vs.variable_scope(
            self._scope, reuse=True, auxiliary_name_scope=False)

        # Do not cache variable scopes if Eager mode is enabled. If Eager mode
        # is enabled then we don't want to reuse scopes because the cached scope
        # might be from a FuncGraph or Eager scope we are no longer in.
        if not ops.executing_eagerly_outside_functions():
          self._always_reuse_variable_scope = scope_context_manager
    else:
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    with scope_context_manager as scope:
      self._current_scope = scope

      try:
        call_has_scope_arg = self._call_has_scope_arg
      except AttributeError:
        self._call_fn_args = variable_scope_shim.fn_args(self.call)
        self._call_has_scope_arg = 'scope' in self._call_fn_args
        call_has_scope_arg = self._call_has_scope_arg
      if call_has_scope_arg:
        kwargs['scope'] = scope

      # Actually call layer
      outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
      # Update global default collections.
      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs
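
  # A minimal call sketch (illustrative comment only; `layer` is a hypothetical
  # instance of a `Layer` subclass). Passing the reserved `scope` kwarg binds
  # the layer to that variable scope on first use; later calls reuse the
  # captured scope with reuse=True:
  #
  #   with tf.compat.v1.variable_scope('block') as block_scope:
  #     y1 = layer(x1, scope=block_scope)  # creates variables under 'block'
  #   y2 = layer(x2)                       # reuses the same variables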

  def __deepcopy__(self, memo):
    no_copy = set(['_graph', '_thread_local', '_metrics_lock'])
    shallow_copy = set(['_scope', '_always_reuse_variable_scope'])
    cls = self.__class__
    result = cls.__new__(cls)
    memo[id(self)] = result
    for k, v in self.__dict__.items():
      if k in no_copy:
        setattr(result, k, v)
      elif k in shallow_copy:
        setattr(result, k, copy.copy(v))
      elif base_layer.is_tensor_or_tensor_list(v):
        setattr(result, k, v)
      else:
        setattr(result, k, copy.deepcopy(v, memo))
    return result
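
  # Copy-semantics sketch (illustrative comment only): `_graph`,
  # `_thread_local` and `_metrics_lock` are shared with the copy, the captured
  # variable scope is shallow-copied, tensors are aliased, and all remaining
  # state is deep-copied:
  #
  #   layer_copy = copy.deepcopy(layer)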

  def __setattr__(self, name, value):
    # By-pass the automatic dependency tracking performed by the parent Layer.
    super(trackable.Trackable, self).__setattr__(name, value)  # pylint: disable=bad-super-call

  @property
  def _is_legacy_layer(self):
    """Used by keras to check compatibility. This should not be overridden."""
    return True


def _add_elements_to_collection(elements, collection_list):
  if context.executing_eagerly():
    raise RuntimeError('Using collections from Layers not supported in Eager '
                       'mode. Tried to add %s to %s' % (elements,
                                                        collection_list))
  elements = nest.flatten(elements)
  collection_list = nest.flatten(collection_list)
  for name in collection_list:
    collection = ops.get_collection_ref(name)
    collection_set = {id(e) for e in collection}
    for element in elements:
      if id(element) not in collection_set:
        collection.append(element)
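

# A minimal collection sketch (illustrative comment only): in graph mode,
# `Layer.__call__` funnels layer update ops (e.g. batch-norm moving-average
# updates) into the global UPDATE_OPS collection via this helper, so a
# training loop can fetch and run them:
#
#   update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
#   with tf.control_dependencies(update_ops):
#     train_op = optimizer.minimize(loss)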