Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_utils.py: 27%
314 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains private utilities used mainly by the base Layer class."""

import functools
import threading

from tensorflow.python import tf2
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.trackable import base as tracking
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

_call_context = threading.local()


def create_mean_metric(value, name=None):
  # Importing `keras` imports `base_layer` and then this module, and `metrics`
  # relies on `base_layer`, which would create a cyclic dependency; import it
  # lazily here instead.
  from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
  metric_obj = metrics_module.Mean(name=name, dtype=value.dtype)
  return metric_obj, metric_obj(value)
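
# Illustrative sketch (added for clarity, not part of the original module):
# how a layer's `add_metric` path might consume this helper. The scalar
# `value` below is hypothetical.
#
#   value = array_ops.ones(())
#   metric_obj, result = create_mean_metric(value, name='my_mean')
#   # `metric_obj` is a `Mean` metric tracking a running average of `value`;
#   # `result` is the metric output after observing `value` once.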


def make_variable(name,
                  shape=None,
                  dtype=dtypes.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf_variables.VariableSynchronization.AUTO,
                  aggregation=tf_variables.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
65 """Temporary util to create a variable (relies on `variable_scope.variable`).
67 Some reuse-related technicalities prevent us from using
68 `variable_scope.get_variable()` directly, so we use a subcomponent
69 that has fewer constraints (`variable_scope.variable()`).
71 In the longer term, it seems like a similar "default variable creator" method
72 should exist in `Trackable` instead. When this happens, we can get
73 rid of this temporary solution.
75 TODO(fchollet): remove this method when no longer needed.
77 Args:
78 name: Variable name.
79 shape: Variable shape.
80 dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
81 initializer: Initializer instance (callable).
82 trainable: Whether the variable should be part of the layer's
83 "trainable_variables" (e.g. variables, biases)
84 or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
85 Note, if the current variable scope is marked as non-trainable
86 then this parameter is ignored and any added variables are also
87 marked as non-trainable. `trainable` defaults to `True` unless
88 `synchronization` is set to `ON_READ`.
89 caching_device: Passed to `tf.Variable`.
90 validate_shape: Passed to `tf.Variable`.
91 constraint: Constraint instance (callable).
92 use_resource: Whether to use a `ResourceVariable`.
93 collections: List of graph collections keys. The new variable is added to
94 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
95 synchronization: Indicates when a distributed a variable will be
96 aggregated. Accepted values are constants defined in the class
97 `tf.VariableSynchronization`. By default the synchronization is set to
98 `AUTO` and the current `DistributionStrategy` chooses
99 when to synchronize. If `synchronization` is set to `ON_READ`,
100 `trainable` must not be set to `True`.
101 aggregation: Indicates how a distributed variable will be aggregated.
102 Accepted values are constants defined in the class
103 `tf.VariableAggregation`.
104 partitioner: Not handled at this time.
106 Returns:
107 Variable instance.
108 """
  initializing_from_value = False
  if initializer is not None and not callable(initializer):
    initializing_from_value = True

  if initializing_from_value:
    init_val = initializer
    variable_dtype = None
  else:
    # Instantiate initializer if provided initializer is a type object.
    if tf_inspect.isclass(initializer):
      initializer = initializer()
    init_val = functools.partial(initializer, shape, dtype=dtype)
    variable_dtype = dtype.base_dtype
  if use_resource is None:
    use_resource = True

  # TODO(apassos,rohanj) figure out how to remove collections from here so we
  # can remove the V1.
  variable_shape = tensor_shape.TensorShape(shape)
  return variable_v1.VariableV1(
      initial_value=init_val,
      name=name,
      trainable=trainable,
      caching_device=caching_device,
      dtype=variable_dtype,
      validate_shape=validate_shape,
      constraint=constraint,
      use_resource=use_resource,
      collections=collections,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=variable_shape if variable_shape else None)
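
# Illustrative sketch (added for clarity, not part of the original module):
# the two `initializer` forms accepted above. `init_ops` is assumed to be
# imported from `tensorflow.python.ops`.
#
#   # Initializer class (or callable) plus an explicit shape:
#   w = make_variable('kernel', shape=(3, 4),
#                     initializer=init_ops.zeros_initializer)
#
#   # Plain value: shape and dtype are taken from the value itself:
#   b = make_variable('bias', initializer=[0., 0., 0.])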


def collect_previous_mask(input_tensors):
  """Retrieves the output mask(s) of the previous node.

  Args:
    input_tensors: An arbitrary structure of Tensors.

  Returns:
    A mask tensor or list of mask tensors.
  """

  def _collect_previous_mask(x):
    return getattr(x, '_keras_mask', None)

  return nest.map_structure(_collect_previous_mask, input_tensors)


def have_all_keras_metadata(tensors):
  return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors))


def generate_placeholders_from_shape(shape):
  return array_ops.placeholder(shape=shape, dtype=backend.floatx())


def create_keras_history(tensors):
  """Wraps TensorFlow Operations for compatibility with the Functional API.

  This method checks to see if a Tensor in `tensors` is missing Keras metadata
  and has its origin in a Keras `Input` Layer. If so, this method will replace
  the raw TensorFlow Operations that created this tensor with
  `TensorFlowOpLayer` instances that create identical operations.

  Any Tensors not originating from a Keras `Input` Layer will be treated as
  constants when constructing `TensorFlowOpLayer` instances.

  Args:
    tensors: A structure of Tensors, some of which come from raw TensorFlow
      operations and need to have Keras metadata assigned to them.

  Returns:
    created_layers: List. The `TensorFlowOpLayer` instances created to wrap
      the raw TensorFlow operations.
  """
  _, created_layers = _create_keras_history_helper(tensors, set(), [])
  return created_layers
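
# Illustrative sketch (added for clarity, not part of the original module):
# the situation this utility handles. With eager execution disabled (TF1
# graph mode), applying a raw TF op to a Keras `Input` yields a tensor
# without `_keras_history`; wrapping recreates that op as `TensorFlowOpLayer`
# instances so the result can be used in a functional `Model`.
#
#   inp = tf.keras.Input(shape=(4,))
#   out = tf.math.square(inp)             # raw op, no Keras metadata yet
#   created = create_keras_history([out])
#   # `created` holds the `TensorFlowOpLayer`s that re-create `tf.math.square`.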


# Unsafe internal attribute.
# If True, Keras will not evaluate the constant-foldable inputs to tf op
# layers in TF1 graphs. This *might* speed up model construction time in
# certain settings, but it means the models will not be
# serializable/deserializable via get_config (only via SavedModels). It may
# also change the semantics of whether generated random numbers are generated
# once and re-used, or recomputed each time.
# Note: This path triggers for TPUEstimators / xla compiled graphs regardless
# of this setting.
_UNSAFE_GRAPH_OP_LAYER_CREATION = False


def _create_keras_history_helper(tensors, processed_ops, created_layers):
  """Helper method for `create_keras_history`.

  Args:
    tensors: A structure of Tensors for which to create Keras metadata.
    processed_ops: Set. TensorFlow operations that have already been wrapped in
      `TensorFlowOpLayer` instances.
    created_layers: List. The `TensorFlowOpLayer` instances created.

  Returns:
    Tuple. First element is the updated set of TensorFlow Operations that
    have been wrapped in `TensorFlowOpLayer` instances. Second element is
    a list of the `TensorFlowOpLayer` instances created.
  """
  if ops.executing_eagerly_outside_functions():
    raise ValueError(
        '`create_keras_history` should only be called if eager is disabled!')
  # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
  # Cannot be imported at top because of circular dependencies.
  # TODO(omalleyt): Resolve circular dependency.
  from tensorflow.python.keras.engine import base_layer  # pylint: disable=g-import-not-at-top
  tensor_list = nest.flatten(tensors)
  sparse_ops = []
  ragged_tensors = []
  for tensor in tensor_list:
    if getattr(tensor, '_keras_history', None) is not None:
      continue
    if isinstance(
        tensor, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
      sparse_ops.append(tensor.op)
      continue
    if tf_utils.is_ragged(tensor):
      # Ragged tensors don't have an op property.
      ragged_tensors.append(tensor)
      continue
    op = tensor.op  # The Op that created this Tensor.
    if op not in processed_ops:
      # Recursively set `_keras_history`.
      op_inputs = list(op.inputs)
      constants = {}
      layer_inputs = []
      for i, op_input in enumerate(op_inputs):
        if uses_keras_history(op_input):
          layer_inputs.append(op_input)
        else:
          # Treat any value not originating from a `keras.Input` as
          # a constant. Variables cannot be supported.
          ds_with_session = (
              distribute_lib.in_cross_replica_context() and
              not ops.executing_eagerly_outside_functions())
          using_xla = control_flow_util.GraphOrParentsInXlaContext(
              ops.get_default_graph())
          if ds_with_session or using_xla or _UNSAFE_GRAPH_OP_LAYER_CREATION:
            # In Legacy Graph mode, evaluating here makes Session be
            # configured improperly. The downside of this is that saving
            # via `get_config` breaks, but SavedModel still works.
            constants[i] = op_input
          else:
            with ops.init_scope():
              constants[i] = backend.function([], op_input)([])
      layer_inputs = unnest_if_single_tensor(layer_inputs)
      processed_ops, created_layers = _create_keras_history_helper(
          layer_inputs, processed_ops, created_layers)
      name = op.name
      node_def = op.node_def.SerializeToString()
      op_layer = base_layer.TensorFlowOpLayer(
          node_def, constants=constants, name=name)
      created_layers.append(op_layer)
      op_layer._set_connectivity_metadata(  # pylint: disable=protected-access
          args=(layer_inputs,),
          kwargs={},
          outputs=op.outputs)
      processed_ops.update([op])
  if sparse_ops or ragged_tensors:
    lambda_example = """
    weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
    output = tf.keras.layers.Lambda(weights_mult)(input)
    """
    raise ValueError(
        'Tensorflow ops that generate ragged or sparse tensor '
        'outputs are currently not supported by Keras automatic '
        'op wrapping. Please wrap these ops in a Lambda layer: '
        '\n\n```\n{example}\n```\n'
        'Sparse ops encountered: {sparse_ops}\n'
        'Ragged tensors encountered: {ragged_tensors}\n'.format(
            example=lambda_example,
            sparse_ops=str(sparse_ops),
            ragged_tensors=str(ragged_tensors)))
  return processed_ops, created_layers


def unnest_if_single_tensor(input_tensors):
  # Preserve compatibility with older configs
  flat_input_tensors = nest.flatten(input_tensors)
  # If this is a single element but not a dict, unwrap. If this is a dict,
  # assume the first layer expects a dict (as is the case with a
  # DenseFeatures layer); pass through.
  if not isinstance(input_tensors, dict) and len(flat_input_tensors) == 1:
    input_tensors = flat_input_tensors[0]
  return input_tensors
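
# Behavior sketch (added for clarity, not part of the original module):
#
#   unnest_if_single_tensor([t])             # -> t (single tensor unwrapped)
#   unnest_if_single_tensor({'feature': t})  # -> {'feature': t} (dict passes through)
#   unnest_if_single_tensor([t1, t2])        # -> [t1, t2] (unchanged)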


def needs_keras_history(tensors, ignore_call_context=False):
  """Check if any Tensors need to be wrapped in TensorFlowOpLayers.

  This will never return True inside a sublayer, because sublayers
  do not need to create Keras History. Otherwise, this returns True
  if one or more of `tensors` originates from a `keras.Input` and
  does not have `_keras_history` set.

  Args:
    tensors: An arbitrary nested structure of Tensors.
    ignore_call_context: Whether to ignore the check of whether we are
      currently outside of a `call` context. This is `True` when creating
      KerasHistory inside `Node`, where we always know that Tensors
      are being used with the Functional API.

  Returns:
    Bool, whether at least one Tensor needs to be wrapped.
  """
  input_tensors = nest.flatten(tensors)
  if call_context().in_call and not ignore_call_context:
    return False
  if all(
      getattr(tensor, '_keras_history', None) is not None
      for tensor in input_tensors):
    # KerasHistory already set.
    return False
  return uses_keras_history(tensors)


def is_in_keras_graph():
  """Returns if currently executing inside of a Keras graph."""
  return call_context().in_keras_graph


def is_in_eager_or_tf_function():
  """Returns if in eager mode or inside of a tf.function."""
  return context.executing_eagerly() or is_in_tf_function()


def is_in_tf_function():
  """Returns if inside of a tf.function."""
  # Check if running in V1 graph mode.
  if not ops.executing_eagerly_outside_functions():
    return False
  if not ops.inside_function():
    return False
  # Check if inside Keras FuncGraph.
  if is_in_keras_graph():
    return False
  # Check for a v1 `wrap_function` FuncGraph.
  graph = ops.get_default_graph()
  if (getattr(graph, 'name', False) and
      graph.name.startswith('wrapped_function')):
    return False
  return True


def uses_keras_history(tensors):
  """Check if at least one Tensor originates from a `keras.Input`.

  This is `True` if at least one Tensor has its origin in a `keras.Input`.
  Any Tensor that originates from a `keras.Input` will have a dependency
  Tensor with a `_keras_history` attribute attached. Tensors that have
  already been checked to not originate from a `keras.Input`
  are marked as `_keras_history_checked`.

  Args:
    tensors: An arbitrary nested structure of Tensors.

  Returns:
    Bool, whether at least one Tensor originates from a `keras.Input`.
  """
  checked_tensors = set()
  tensors_to_check = nest.flatten(tensors)

  while tensors_to_check:
    new_tensors_to_check = []
    for tensor in tensors_to_check:
      if id(tensor) in checked_tensors:
        continue

      checked_tensors.add(id(tensor))

      if getattr(tensor, '_keras_history_checked', None) is not None:
        continue
      if getattr(tensor, '_keras_history', None) is not None:
        return True

      try:
        new_tensors_to_check.extend(tensor.op.inputs)
      except AttributeError:
        # In case `tensor` is a Variable created in an Eager context.
        pass

    tensors_to_check = new_tensors_to_check

  # Mark that these Tensors have been checked once for `_keras_history`,
  # and should not be checked again for performance reasons.
  mark_checked(tensors)
  return False
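
# Illustrative sketch (added for clarity, not part of the original module):
# in TF1 graph mode, a tensor derived from a Keras `Input` reports True,
# while a tensor built only from constants reports False and is marked with
# `_keras_history_checked` so later calls can skip the graph walk.
#
#   inp = tf.keras.Input(shape=(2,))
#   uses_keras_history(inp * 2.)                     # True
#   uses_keras_history(tf.constant([1., 2.]) * 2.)   # False (and marked checked)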


def mark_checked(tensors):
  """Marks that these Tensors should not be tracked.

  This prevents Layers from attempting to create TensorFlowOpLayers
  for these Tensors.

  Args:
    tensors: An arbitrary structure of Tensors.
  """

  def _mark_checked(tensor):
    tensor._keras_history_checked = True  # pylint: disable=protected-access

  nest.map_structure(_mark_checked, tensors)


def call_context():
  """Returns currently active `CallContext`."""
  call_ctx = getattr(_call_context, 'call_context', None)
  if call_ctx is None:
    call_ctx = CallContext()
    _call_context.call_context = call_ctx
  return call_ctx


class CallContext(object):
  """Keeps track of properties currently inside a Layer/Model's `call`.

  Attributes:
    in_call: Whether currently inside the `call` of a Layer.
    layer: The `Layer` whose `call` is currently active.
    inputs: The inputs to the currently active `Layer`.
    build_graph: Whether currently inside a Graph or FuncGraph.
    training: Whether currently executing in training or inference mode.
    saving: Whether currently saving to SavedModel.
    frozen: Whether currently executing inside a `Layer` with `trainable` set to
      `False`.
    in_keras_graph: Whether executing inside the Keras Graph.
  """

  def __init__(self):
    # Handle `in_call` separately as it is the most-read attr and reading it is
    # on the hot path.
    self.in_call = False
    self._state = {
        'layer': None,
        'inputs': None,
        'build_graph': False,
        'training': None,
        'saving': None
    }
    # TODO(b/150169018): This logic can be replaced after the Functional API
    # refactor.
    self._in_keras_graph = False

  def enter(self, layer, inputs, build_graph, training, saving=None):
    """Push a Layer and its inputs and state onto the current call context.

    Args:
      layer: The `Layer` whose `call` is currently active.
      inputs: The inputs to the currently active `Layer`.
      build_graph: Whether currently inside a Graph or FuncGraph.
      training: Whether currently executing in training or inference mode.
      saving: Whether currently saving to SavedModel.

    Returns:
      Context manager.
    """
    state = {
        'layer': layer,
        'inputs': inputs,
        'build_graph': build_graph,
        'training': training,
        'saving': saving
    }
    return CallContextManager(self, state)

  @property
  def layer(self):
    return self._state['layer']

  @property
  def inputs(self):
    return self._state['inputs']

  @property
  def build_graph(self):
    return self._state['build_graph']

  @property
  def training(self):
    return self._state['training']

  @property
  def saving(self):
    return self._state['saving']

  @property
  def frozen(self):
    layer = self._state['layer']
    if not layer:
      return False
    return not layer.trainable

  @property
  def in_keras_graph(self):
    # Returns True even if in a subgraph of the Keras graph, such as those
    # created by control flow ops.
    if context.executing_eagerly():
      return False
    return (self._in_keras_graph or
            getattr(backend.get_graph(), 'name', None) == 'keras_graph')


class CallContextManager(object):
  """Context manager for `CallContext`."""

  def __init__(self, call_ctx, state):
    self._call_ctx = call_ctx
    self._state = state
    self._build_graph = state['build_graph']

  def __enter__(self):
    call_ctx = self._call_ctx
    self._prev_in_call = call_ctx.in_call
    self._prev_state = call_ctx._state

    call_ctx.in_call = True
    call_ctx._state = self._state

    # TODO(b/150169018): This logic can be removed after the Functional API
    # refactor.
    if self._build_graph:
      self._prev_in_keras_graph = call_ctx._in_keras_graph
      call_ctx._in_keras_graph = (
          call_ctx._in_keras_graph or
          getattr(backend.get_graph(), 'name', None) == 'keras_graph')

  def __exit__(self, *exc_info):
    call_ctx = self._call_ctx
    call_ctx.in_call = self._prev_in_call
    call_ctx._state = self._prev_state

    if self._build_graph:
      call_ctx._in_keras_graph = self._prev_in_keras_graph
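
# Illustrative sketch (added for clarity, not part of the original module):
# how a Layer's `__call__` is expected to use the call context defined above.
# `my_layer` and `inputs` are hypothetical.
#
#   ctx = call_context()
#   with ctx.enter(layer=my_layer, inputs=inputs, build_graph=False,
#                  training=True):
#     # Inside the block `ctx.in_call` is True and `ctx.layer` is `my_layer`;
#     # the previous state is restored on exit.
#     outputs = my_layer.call(inputs)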


def training_arg_passed_to_call(argspec, args, kwargs):
  """Returns whether a user passed the `training` argument in `__call__`."""
  # `argspec.args` starts with ['self', 'inputs']
  full_args = dict(zip(argspec.args[2:], args))
  full_args.update(kwargs)
  return 'training' in full_args and full_args['training'] is not None
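
# Illustrative sketch (added for clarity, not part of the original module):
# `argspec` is assumed to come from `tf_inspect.getfullargspec` on a layer's
# `call` method, so `argspec.args` starts with ['self', 'inputs'].
#
#   argspec = tf_inspect.getfullargspec(MyLayer.call)
#   # For MyLayer.call(self, inputs, training=None):
#   training_arg_passed_to_call(argspec, args=(), kwargs={'training': True})  # True
#   training_arg_passed_to_call(argspec, args=(), kwargs={})                  # False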


def is_subclassed(layer):
  """Returns True if the object is a subclassed layer or subclassed model."""
  return (layer.__module__.find('keras.engine') == -1 and
          layer.__module__.find('keras.layers') == -1)


def from_saved_model(layer):
  """Returns whether the layer is loaded from a SavedModel."""
  return layer.__module__.find('keras.saving.saved_model') != -1


def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
  """Checks that tensors passed to `add_*` method match the Keras graph.

  When one of the `add_*` methods is called inside a V2 conditional branch,
  the underlying tensor gets created in a FuncGraph managed by control_flow_v2.
  We need to raise clear error messages in such cases.

  Args:
    tensor: Tensor to check, or `False` if it is known that an error
      should be raised.
    method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
    force_raise: If an error should be raised regardless of `tensor`.

  Raises:
    RuntimeError: In case of an out-of-graph tensor.
  """
  if (force_raise or
      (ops.executing_eagerly_outside_functions() and
       hasattr(tensor, 'graph') and tensor.graph.is_control_flow_graph)):
    if method == 'activity_regularizer':
      bad_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          if training:
            return self.dense(x)
          else:
            return self.dense(x)
      """
      correct_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          return self.dense(x)
      """
      raise RuntimeError(
          'You are using a layer with `activity_regularizer` in a control flow '
          'branch, e.g.:\n{bad_example}\nThis is currently not supported. '
          'Please move your call to the layer with `activity_regularizer` out '
          'of the control flow branch, e.g.:\n{correct_example}\n'
          'You can also resolve this by marking your outer model/layer dynamic'
          ' (eager-only) by passing `dynamic=True` to the layer constructor. '
          'Any kind of control flow is supported with dynamic layers. '
          'Note that using `dynamic=True` requires you to implement static '
          'shape inference in the `compute_output_shape(input_shape)` '
          'method.'.format(
              bad_example=bad_example, correct_example=correct_example))

    if method == 'add_metric':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
          self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
        else:
          metric = 0.
        self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
    elif method == 'add_loss':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
          self.add_loss(loss)
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
        else:
          loss = 0.
        self.add_loss(loss)
        return inputs
      """
    else:
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          self.add_update(self.w.assign_add(1))
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          increment = 1
        else:
          increment = 0
        self.add_update(self.w.assign_add(increment))
        return inputs
      """
    raise RuntimeError(
        'You are using the method `{method}` in a control flow branch '
        'in your layer, e.g.:\n{bad_example}\n'
        'This is not currently supported. '
        'Please move your call to {method} out of the control flow branch, '
        'e.g.:\n{correct_example}\n'
        'You can also resolve this by marking your layer '
        'as dynamic (eager-only) by passing '
        '`dynamic=True` to the layer constructor. '
        'Any kind of control flow is supported with dynamic layers. '
        'Note that using `dynamic=True` requires you '
        'to implement static shape inference '
        'in the `compute_output_shape(input_shape)` method.'.format(
            method=method,
            bad_example=bad_example,
            correct_example=correct_example))


def mark_as_return(outputs, acd):
  """Marks `outputs` as the return values for automatic control deps."""

  def _mark_as_return(tensor):
    """Marks `tensor` as the return value for automatic control deps."""
    if not tensor_util.is_tf_type(tensor):
      return tensor

    # pylint: disable=protected-access
    return_tensor = acd.mark_as_return(tensor)
    if getattr(tensor, '_keras_mask', None) is not None:
      return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask)
    else:
      return_tensor._keras_mask = None

    # Handle TensorFlow Probability attached metadata.
    # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
    if getattr(tensor, '_tfp_distribution', None) is not None:
      return_tensor._tfp_distribution = tensor._tfp_distribution

    return return_tensor
  # pylint: enable=protected-access

  return nest.map_structure(_mark_as_return, outputs)


V2_DTYPE_BEHAVIOR = None


@keras_export(v1=['keras.layers.enable_v2_dtype_behavior'])
def enable_v2_dtype_behavior():
  """Enable the V2 dtype behavior for Keras layers.

  By default, the V2 dtype behavior is enabled in TensorFlow 2, so this function
  is only useful if `tf.compat.v1.disable_v2_behavior` has been called. Since
  mixed precision requires V2 dtype behavior to be enabled, this function allows
  you to use mixed precision in Keras layers if `disable_v2_behavior` has been
  called.

  When enabled, the dtype of Keras layers defaults to floatx (which is typically
  float32) instead of None. In addition, layers will automatically cast
  floating-point inputs to the layer's dtype.

  >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
  >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  >>> print(layer.dtype)  # float32 since V2 dtype behavior is enabled
  float32
  >>> y = layer(x)  # Layer casts inputs since V2 dtype behavior is enabled
  >>> print(y.dtype.name)
  float32

  A layer author can opt-out their layer from the automatic input casting by
  passing `autocast=False` to the base Layer's constructor. This disables the
  autocasting part of the V2 behavior for that layer, but not the defaulting to
  floatx part of the V2 behavior.

  When a global `tf.keras.mixed_precision.Policy` is set, a Keras layer's dtype
  will default to the global policy instead of floatx. Layers will automatically
  cast inputs to the policy's compute_dtype.
  """
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = True


@keras_export(v1=['keras.layers.disable_v2_dtype_behavior'])
def disable_v2_dtype_behavior():
  """Disables the V2 dtype behavior for Keras layers.

  See `tf.compat.v1.keras.layers.enable_v2_dtype_behavior`.
  """
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = False


def v2_dtype_behavior_enabled():
  """Returns True if the V2 dtype behavior is enabled."""
  if V2_DTYPE_BEHAVIOR is None:
    return tf2.enabled()
  return V2_DTYPE_BEHAVIOR


class TrackableWeightHandler(object):
  """Keras wrapper for handling tracking.Trackable object saving and restoring.

  This class handles Trackables in both V1 and V2 modes, ensuring that they can
  be saved and restored with the correct data and without adding additional ops
  on every save.

  Attributes:
    trackable: The trackable to wrap.
    num_tensors: The number of tensors that this trackable requires for saving.
  """

  def __init__(self, trackable):
    if not isinstance(trackable, tracking.Trackable):
      raise ValueError('%s is not a Trackable object.' % (trackable,))
    self._trackable = trackable
    self._distribute_strategy = distribute_lib.get_strategy()

    saveables = saveable_object_util.saveable_objects_from_trackable(
        trackable).values()
    # 'Saveables' won't exist when we're passed a legacy TF1 table like
    # a StaticHashTable.
    if not saveables:
      self._num_tensors = 0
      self._setter = lambda weights: None
      self._getter = lambda: []

    elif len(saveables) == 1:
      saveable = list(saveables)[0]

      if ops.executing_eagerly_outside_functions():
        # If we're in eager mode, we need to defer calling the Trackable's
        # saveable() callable until data export time.
        # However, it is safe to call the saveable as many times as we want, so
        # we will call it now to figure out how many tensors this Trackable will
        # produce.
        self._saveable = saveable
        self._num_tensors = len(self._saveable().specs)
        self._setter = lambda weights: self._saveable().restore(weights, None)
        self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
      else:
        # If we're in Graph mode, we need to evaluate the Saveable only once and
        # cache the resulting restore graph. Failing to do this will result in
        # new assignment ops being added to the graph each time set_weights() is
        # called.
        self._placeholder_tensors = []
        self._saveable = saveable()
        self._num_tensors = len(self._saveable.specs)
        for spec in self._saveable.specs:
          tensor = spec.tensor
          self._placeholder_tensors.append(
              array_ops.placeholder(tensor.dtype, tensor.shape))
        self._assign_op = self._saveable.restore(self._placeholder_tensors,
                                                 None)
        self._setter = self._set_weights_v1
        self._getter = lambda: [spec.tensor for spec in self._saveable.specs]
    else:
      raise ValueError('Only Trackables with one Saveable are supported. '
                       'The Trackable %s has %d Saveables.' %
                       (trackable, len(saveables)))

  @property
  def num_tensors(self):
    return self._num_tensors

  def set_weights(self, weights):
    if len(weights) != self._num_tensors:
      raise ValueError(
          ('Weight handler for trackable %s received the wrong number of ' +
           'weights: expected %s, got %s.') %
          (self._trackable, self._num_tensors, len(weights)))
    self._setter(weights)

  def get_tensors(self):
    return self._getter()

  def _set_weights_v1(self, weights):
    feed_dict = {}
    for idx, tensor in enumerate(weights):
      feed_dict[self._placeholder_tensors[idx]] = tensor
    backend.get_session().run(self._assign_op, feed_dict)
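
# Illustrative sketch (added for clarity, not part of the original module):
# wrapping a single-Saveable trackable (for example a variable or a mutable
# hash table) so Keras can include it in get/set-weights flows. `trackable`
# is hypothetical.
#
#   handler = TrackableWeightHandler(trackable)
#   tensors = handler.get_tensors()   # `handler.num_tensors` tensors
#   handler.set_weights(tensors)      # restores the trackable from them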


class StaticTableHandler(TrackableWeightHandler):
  """Wrapper for handling weight collection for static hash tables."""

  def __init__(self, getter_lambda):  # pylint: disable=super-init-not-called
    self._num_tensors = 2
    self._getter = getter_lambda
    self._distribute_strategy = distribute_lib.get_strategy()

    def raise_error(_):
      raise RuntimeError('This layer contains a static lookup table, which '
                         'cannot be changed via set_weights().')

    self._setter = raise_error


def no_ragged_support(inputs, layer_name):
  input_list = nest.flatten(inputs)
  if any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list):
    raise ValueError('Layer %s does not support RaggedTensors as input. '
                     'Inputs received: %s. You can try converting your '
                     'input to a uniform tensor.' % (layer_name, inputs))


def is_split_variable(v):
  """Returns True if `v` is either a PartitionedVariable or a ShardedVariable."""
  return hasattr(v, '_variable_list') or hasattr(v, '_variables')


def has_weights(obj):
  obj_type = type(obj)
  return (hasattr(obj_type, 'trainable_weights') and
          hasattr(obj_type, 'non_trainable_weights') and
          not isinstance(obj, type))


# TODO(kathywu): This is a temporary hack. When a network of layers is revived
# from SavedModel, only the top-level layer will have losses. This causes issues
# in eager mode because the child layers may have graph losses
# (thus model.losses returns a mix of Eager and graph tensors). To fix this,
# whenever eager losses are added to one layer, add eager losses to all
# child layers. This causes `.losses` to only return eager losses.
REVIVED_LOSS_PLACEHOLDER = (
    'This layer\'s losses have been added to the parent layer.')