# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains private utilities used mainly by the base Layer class."""

import functools
import threading

import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.dtensor import dtensor_api as dtensor
from keras.src.utils import control_flow_util
from keras.src.utils import tf_inspect
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export

_call_context = threading.local()


def create_mean_metric(value, name=None):
    # Importing keras will import base_layer and then this module, and metrics
    # rely on base_layer, which results in a cyclic dependency.
    from keras.src import metrics as metrics_module

    metric_obj = metrics_module.Mean(name=name, dtype=value.dtype)
    return metric_obj, metric_obj(value)
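
# Illustrative usage sketch (not part of the module): the helper returns both
# the `Mean` metric object (for tracking) and the result of updating it once
# with `value`:
#     metric_obj, result = create_mean_metric(tf.constant(2.0), name="my_mean")
#     # metric_obj.result() then reflects the mean of all values added so far.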


def infer_init_val_and_dtype(initializer, dtype, shape, layout=None):
    if initializer is not None and not callable(initializer):
        init_val = initializer
        variable_dtype = None
    else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
            initializer = initializer()
        if layout:
            init_val = functools.partial(
                initializer, shape, dtype=dtype, layout=layout
            )
        else:
            init_val = functools.partial(initializer, shape, dtype=dtype)
        variable_dtype = dtype.base_dtype
    return init_val, variable_dtype
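
# Sketch of the two branches above (hypothetical values): a plain constant is
# returned as-is with `variable_dtype=None`, while an initializer class or
# instance yields a zero-argument callable bound to `shape`/`dtype`:
#     init_val, var_dtype = infer_init_val_and_dtype(
#         tf.keras.initializers.Zeros, tf.float32, (3, 4)
#     )
#     # init_val() -> a (3, 4) float32 tensor of zeros; var_dtype == tf.float32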


def make_variable(
    name,
    shape=None,
    dtype=tf.float32,
    initializer=None,
    trainable=None,
    caching_device=None,
    validate_shape=True,
    constraint=None,
    use_resource=None,
    collections=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.VariableAggregation.NONE,
    partitioner=None,
    layout=None,
    experimental_enable_variable_lifting=True,
):
    """Util to create a variable (relies on `variable_scope.variable`).

    Some reuse-related technicalities prevent us from using
    `variable_scope.get_variable()` directly, so we use a subcomponent
    that has fewer constraints (`variable_scope.variable()`).

    In the longer term, it seems like a similar "default variable creator"
    method should exist in `Trackable` instead. When this happens, we can get
    rid of this temporary solution.

    TODO(fchollet): remove this method when no longer needed.

    Args:
        name: Variable name.
        shape: Variable shape.
        dtype: The type of the variable. Defaults to `self.dtype` or
            `float32`.
        initializer: Initializer instance (callable).
        trainable: Whether the variable should be part of the layer's
            "trainable_variables" (e.g. variables, biases)
            or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
            Note, if the current variable scope is marked as non-trainable
            then this parameter is ignored and any added variables are also
            marked as non-trainable. `trainable` becomes `True` unless
            `synchronization` is set to `ON_READ`. Defaults to `None`.
        caching_device: Passed to `tf.Variable`.
        validate_shape: Passed to `tf.Variable`.
        constraint: Constraint instance (callable).
        use_resource: Whether to use a `ResourceVariable`.
        collections: List of graph collections keys. The new variable is added
            to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
        synchronization: Indicates when a distributed variable will be
            aggregated. Accepted values are constants defined in the class
            `tf.VariableSynchronization`. By default the synchronization is
            set to `AUTO` and the current `DistributionStrategy` chooses
            when to synchronize. If `synchronization` is set to `ON_READ`,
            `trainable` must not be set to `True`.
        aggregation: Indicates how a distributed variable will be aggregated.
            Accepted values are constants defined in the class
            `tf.VariableAggregation`.
        partitioner: Not handled at this time.
        layout: the optional DTensor layout, used for creating DVariable.

    Returns:
        Variable instance.
    """
    init_val, variable_dtype = infer_init_val_and_dtype(
        initializer, dtype, shape, layout
    )
    variable_shape = tf.TensorShape(shape)

    if use_resource is None:
        use_resource = True

    if layout is None:
        # In theory, if `use_resource` is True and `collections` is empty
        # (that is to say, in TF2), we can use tf.Variable.
        # However, this breaks legacy (Estimator) checkpoints because
        # it changes variable names. Remove this when V1 is fully deprecated.
        return tf1.Variable(
            initial_value=init_val,
            name=name,
            trainable=trainable,
            caching_device=caching_device,
            dtype=variable_dtype,
            validate_shape=validate_shape,
            constraint=constraint,
            use_resource=use_resource,
            collections=collections,
            synchronization=synchronization,
            aggregation=aggregation,
            shape=variable_shape if variable_shape else None,
            experimental_enable_variable_lifting=experimental_enable_variable_lifting,  # noqa: E501
        )
    else:
        return dtensor.DVariable(
            initial_value=init_val,
            name=name,
            trainable=trainable,
            caching_device=caching_device,
            dtype=variable_dtype,
            validate_shape=validate_shape,
            constraint=constraint,
            collections=collections,
            synchronization=synchronization,
            aggregation=aggregation,
            shape=variable_shape if variable_shape else None,
        )
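
# Hypothetical usage sketch: Keras normally reaches this through the layer
# machinery (`Layer.add_weight`), but called directly it behaves roughly like:
#     w = make_variable(
#         "kernel",
#         shape=(3, 4),
#         dtype=tf.float32,
#         initializer=tf.keras.initializers.GlorotUniform(),
#         trainable=True,
#     )
#     # Returns a tf.Variable (or a dtensor.DVariable when `layout` is given).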


def collect_previous_mask(input_tensors):
    """Retrieves the output mask(s) of the previous node.

    Args:
        input_tensors: An arbitrary structure of Tensors.

    Returns:
        A mask tensor or list of mask tensors.
    """

    def _collect_previous_mask(x):
        return getattr(x, "_keras_mask", None)

    return tf.nest.map_structure(_collect_previous_mask, input_tensors)
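
# Behavior sketch (assumed inputs): masks are gathered per structure leaf, with
# `None` for tensors that carry no `_keras_mask`:
#     collect_previous_mask({"a": masked_t, "b": plain_t})
#     # -> {"a": <mask tensor of masked_t>, "b": None}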


def have_all_keras_metadata(tensors):
    return all(hasattr(x, "_keras_history") for x in tf.nest.flatten(tensors))


def generate_placeholders_from_shape(shape):
    return tf1.placeholder(shape=shape, dtype=backend.floatx())


def create_keras_history(tensors):
    """Wraps TensorFlow Operations for compatibility with the Functional API.

    This method checks to see if a Tensor in `tensors` is missing Keras
    metadata and has its origin in a Keras `Input` Layer. If so, this method
    will replace the raw TensorFlow Operations that created this tensor with
    `TensorFlowOpLayer` instances that create identical operations.

    Any Tensors not originating from a Keras `Input` Layer will be treated as
    constants when constructing `TensorFlowOpLayer` instances.

    Args:
        tensors: A structure of Tensors, some of which come from raw TensorFlow
            operations and need to have Keras metadata assigned to them.

    Returns:
        created_layers: List. The `TensorFlowOpLayer` instances created to wrap
            the raw Tensorflow operations.
    """
    _, created_layers = _create_keras_history_helper(tensors, set(), [])
    return created_layers
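
# Illustrative sketch (only valid when eager execution is disabled; see the
# guard in the helper below). If a raw TF op is applied to a `keras.Input`,
# wrapping recovers the Functional-API metadata:
#     inputs = tf.keras.Input(shape=(4,))
#     outputs = tf.math.square(inputs)         # raw op, no Keras layer
#     created = create_keras_history(outputs)  # wraps the Square op
#     # `created` holds the TensorFlowOpLayer instances that were inserted.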


# Unsafe internal attribute.
# If True, Keras will not evaluate the constant-foldable inputs to tf op
# layers in TF1 graphs. This *might* speed up model construction time in
# certain settings, but it means the models will not be
# serializable/deserializable via get_config (only via SavedModels). It may
# also change the semantics of whether generated random numbers are generated
# once and re-used, or recomputed each time.
# Note: This path triggers for TPUEstimators / xla compiled graphs regardless
# of this setting.
_UNSAFE_GRAPH_OP_LAYER_CREATION = False


def _create_keras_history_helper(tensors, processed_ops, created_layers):
    """Helper method for `create_keras_history`.

    Args:
        tensors: A structure of Tensors for which to create Keras metadata.
        processed_ops: Set. TensorFlow operations that have already been
            wrapped in `TensorFlowOpLayer` instances.
        created_layers: List. The `TensorFlowOpLayer` instances created.

    Returns:
        Tuple. First element is the updated set of TensorFlow Operations that
        have been wrapped in `TensorFlowOpLayer` instances. Second element is
        a list of the `TensorFlowOpLayer` instances created.
    """
    if tf1.executing_eagerly_outside_functions():
        raise ValueError(
            "`create_keras_history` should only be called if eager is disabled!"
        )
    # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
    # Cannot be imported at top because of circular dependencies.
    # TODO(omalleyt): Resolve circular dependency.
    from keras.src.engine import base_layer

    tensor_list = tf.nest.flatten(tensors)
    sparse_ops = []
    ragged_tensors = []
    for tensor in tensor_list:
        if getattr(tensor, "_keras_history", None) is not None:
            continue
        if isinstance(tensor, (tf.SparseTensor, tf1.SparseTensorValue)):
            sparse_ops.append(tensor.op)
            continue
        if tf_utils.is_ragged(tensor):
            # Ragged tensors don't have an op property.
            ragged_tensors.append(tensor)
            continue
        op = tensor.op  # The Op that created this Tensor.
        if op not in processed_ops:
            # Recursively set `_keras_history`.
            op_inputs = list(op.inputs)
            constants = {}
            layer_inputs = []
            for i, op_input in enumerate(op_inputs):
                if uses_keras_history(op_input):
                    layer_inputs.append(op_input)
                else:
                    # Treat any value not originating from a `keras.Input` as
                    # a constant. Variables cannot be supported.
                    ds_with_session = (
                        tf.distribute.in_cross_replica_context()
                        and not tf1.executing_eagerly_outside_functions()
                    )
                    using_xla = control_flow_util.GraphOrParentsInXlaContext(
                        tf1.get_default_graph()
                    )
                    if (
                        ds_with_session
                        or using_xla
                        or _UNSAFE_GRAPH_OP_LAYER_CREATION
                    ):
                        # In Legacy Graph mode, evaluating here makes Session
                        # be configured improperly. The downside of this is
                        # that saving via `get_config` breaks, but SavedModel
                        # still works.
                        constants[i] = op_input
                    else:
                        with tf.init_scope():
                            constants[i] = backend.function([], op_input)([])
            layer_inputs = unnest_if_single_tensor(layer_inputs)
            processed_ops, created_layers = _create_keras_history_helper(
                layer_inputs, processed_ops, created_layers
            )
            name = op.name
            node_def = op.node_def.SerializeToString()
            op_layer = base_layer.TensorFlowOpLayer(
                node_def, constants=constants, name=name
            )
            created_layers.append(op_layer)
            op_layer._set_connectivity_metadata(
                args=(layer_inputs,), kwargs={}, outputs=op.outputs
            )
            processed_ops.update([op])
    if sparse_ops or ragged_tensors:
        lambda_example = """
        weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
        output = tf.keras.layers.Lambda(weights_mult)(input)
        """
        raise ValueError(
            "Tensorflow ops that generate ragged or sparse tensor "
            "outputs are currently not supported by Keras automatic "
            "op wrapping. Please wrap these ops in a Lambda layer: "
            "\n\n```\n{example}\n```\n"
            "Sparse ops encountered: {sparse_ops}\n"
            "Ragged tensors encountered: {ragged_tensors}\n".format(
                example=lambda_example,
                sparse_ops=str(sparse_ops),
                ragged_tensors=str(ragged_tensors),
            )
        )
    return processed_ops, created_layers


def unnest_if_single_tensor(input_tensors):
    # Preserve compatibility with older configs.
    flat_input_tensors = tf.nest.flatten(input_tensors)
    # If this is a single element but not a dict, unwrap. If this is a dict,
    # assume the first layer expects a dict (as is the case with a
    # DenseFeatures layer); pass through.
    if not isinstance(input_tensors, dict) and len(flat_input_tensors) == 1:
        input_tensors = flat_input_tensors[0]
    return input_tensors
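
# Behavior sketch (hypothetical inputs):
#     unnest_if_single_tensor([t])       # -> t (single-element list unwrapped)
#     unnest_if_single_tensor({"x": t})  # -> {"x": t} (dicts pass through)
#     unnest_if_single_tensor([t1, t2])  # -> [t1, t2] (multi-element unchanged)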


def needs_keras_history(tensors, ignore_call_context=False):
    """Check if any Tensors need to be wrapped in TensorFlowOpLayers.

    This will never return True inside a sublayer, because sublayers
    do not need to create Keras History. Otherwise, this returns True
    if one or more of `tensors` originates from a `keras.Input` and
    does not have `_keras_history` set.

    Args:
        tensors: An arbitrary nested structure of Tensors.
        ignore_call_context: Whether to ignore the check of whether we are
            currently outside of a `call` context. This is `True` when
            creating KerasHistory inside `Node`, where we always know that
            Tensors are being used with the Functional API.

    Returns:
        Bool, whether at least one Tensor needs to be wrapped.
    """
    input_tensors = tf.nest.flatten(tensors)
    if call_context().in_call and not ignore_call_context:
        return False
    if all(
        getattr(tensor, "_keras_history", None) is not None
        for tensor in input_tensors
    ):
        # KerasHistory already set.
        return False
    return uses_keras_history(tensors)


def is_in_keras_graph():
    """Returns if currently executing inside of a Keras graph."""
    return call_context().in_keras_graph


def is_in_eager_or_tf_function():
    """Returns if in eager mode or inside of a tf.function."""
    return tf.executing_eagerly() or is_in_tf_function()


def is_in_tf_function():
    """Returns if inside of a tf.function."""
    # Check if running in V1 graph mode.
    if not tf1.executing_eagerly_outside_functions():
        return False
    if not tf.inside_function():
        return False
    # Check if inside Keras FuncGraph.
    if is_in_keras_graph():
        return False
    # Check for a v1 `wrap_function` FuncGraph.
    graph = tf1.get_default_graph()
    if getattr(graph, "name", False) and graph.name.startswith(
        "wrapped_function"
    ):
        return False
    return True
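
# Rough decision table for the checks above (assumed contexts):
#     V1 graph mode (no eager outside functions)        -> False
#     plain eager execution (not inside a FuncGraph)     -> False
#     inside the Keras graph or a v1 wrap_function graph -> False
#     inside an ordinary `tf.function` trace             -> True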


def uses_keras_history(tensors):
    """Check if at least one Tensor originates from a `keras.Input`.

    This is `True` if at least one Tensor has its origin in a `keras.Input`.
    Any Tensor that originates from a `keras.Input` will have a dependency
    Tensor with a `_keras_history` attribute attached. Tensors that have
    already been checked to not originate from a `keras.Input`
    are marked as `_keras_history_checked`.

    Args:
        tensors: An arbitrary nested structure of Tensors.

    Returns:
        Bool, whether at least one Tensor originates from a `keras.Input`.
    """
    checked_tensors = set()
    tensors_to_check = tf.nest.flatten(tensors)

    while tensors_to_check:
        new_tensors_to_check = []
        for tensor in tensors_to_check:
            if id(tensor) in checked_tensors:
                continue

            checked_tensors.add(id(tensor))

            if getattr(tensor, "_keras_history_checked", None) is not None:
                continue
            if getattr(tensor, "_keras_history", None) is not None:
                return True

            try:
                new_tensors_to_check.extend(tensor.op.inputs)
            except AttributeError:
                # In case `tensor` is a Variable created in an Eager context.
                pass

        tensors_to_check = new_tensors_to_check

    # Mark that these Tensors have been checked once for `_keras_history`,
    # and should not be checked again for performance reasons.
    mark_checked(tensors)
    return False
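
# Traversal sketch (illustrative graph only): the loop above walks backwards
# breadth-first through `tensor.op.inputs`. For `y = tf.math.square(x + 1.0)`
# with `x` a `keras.Input`, it visits y -> add output -> (x, constant) and
# returns True at the input that carries `_keras_history`.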


def mark_checked(tensors):
    """Marks that these Tensors should not be tracked.

    This prevents Layers from attempting to create TensorFlowOpLayers
    for these Tensors.

    Args:
        tensors: An arbitrary structure of Tensors.
    """

    def _mark_checked(tensor):
        tensor._keras_history_checked = True

    tf.nest.map_structure(_mark_checked, tensors)


def call_context():
    """Returns currently active `CallContext`."""
    call_ctx = getattr(_call_context, "call_context", None)
    if call_ctx is None:
        call_ctx = CallContext()
        _call_context.call_context = call_ctx
    return call_ctx


# Inject the call_context function to keras_deps to remove the dependency
# from TFLite to Keras.
tf.__internal__.register_call_context_function(call_context)


class CallContext:
    """Keeps track of properties currently inside a Layer/Model's `call`.

    Attributes:
        in_call: Whether currently inside the `call` of a Layer.
        layer: The `Layer` whose `call` is currently active.
        inputs: The inputs to the currently active `Layer`.
        build_graph: Whether currently inside a Graph or FuncGraph.
        training: Whether currently executing in training or inference mode.
        saving: Whether currently saving to SavedModel.
        frozen: Whether currently executing inside a `Layer` with `trainable`
            set to `False`.
        in_keras_graph: Whether executing inside the Keras Graph.
    """

    def __init__(self):
        # Handle `in_call` separately as it is the most-read attr and reading
        # it is on the hot path.
        self.in_call = False
        self._state = {
            "layer": None,
            "inputs": None,
            "build_graph": False,
            "training": None,
            "saving": None,
        }
        # TODO(b/150169018): This logic can be replaced after the Functional
        # API refactor.
        self._in_keras_graph = False

    def enter(self, layer, inputs, build_graph, training, saving=None):
        """Push a Layer and its inputs and state onto the current call context.

        Args:
            layer: The `Layer` whose `call` is currently active.
            inputs: The inputs to the currently active `Layer`.
            build_graph: Whether currently inside a Graph or FuncGraph.
            training: Whether currently executing in training or inference
                mode.
            saving: Whether currently saving to SavedModel.

        Returns:
            Context manager.
        """
        state = {
            "layer": layer,
            "inputs": inputs,
            "build_graph": build_graph,
            "training": training,
            "saving": saving,
        }
        return CallContextManager(self, state)

    @property
    def layer(self):
        return self._state["layer"]

    @property
    def inputs(self):
        return self._state["inputs"]

    @property
    def build_graph(self):
        return self._state["build_graph"]

    @property
    def training(self):
        return self._state["training"]

    @property
    def saving(self):
        return self._state["saving"]

    @property
    def frozen(self):
        layer = self._state["layer"]
        if not layer:
            return False
        return not layer.trainable

    @property
    def in_keras_graph(self):
        # Returns True even if in a subgraph of the Keras graph, such as
        # those created by control flow ops.
        if tf.executing_eagerly():
            return False
        return (
            self._in_keras_graph
            or getattr(backend.get_graph(), "name", None) == "keras_graph"
        )


class CallContextManager:
    """Context manager for `CallContext`."""

    def __init__(self, call_ctx, state):
        self._call_ctx = call_ctx
        self._state = state
        self._build_graph = state["build_graph"]

    def __enter__(self):
        call_ctx = self._call_ctx
        self._prev_in_call = call_ctx.in_call
        self._prev_state = call_ctx._state

        call_ctx.in_call = True
        call_ctx._state = self._state

        # TODO(b/150169018): This logic can be removed after the Functional
        # API refactor.
        if self._build_graph:
            self._prev_in_keras_graph = call_ctx._in_keras_graph
            call_ctx._in_keras_graph = (
                call_ctx._in_keras_graph
                or getattr(backend.get_graph(), "name", None) == "keras_graph"
            )

    def __exit__(self, *exc_info):
        call_ctx = self._call_ctx
        call_ctx.in_call = self._prev_in_call
        call_ctx._state = self._prev_state

        if self._build_graph:
            call_ctx._in_keras_graph = self._prev_in_keras_graph
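
# Hypothetical usage sketch: `Layer.__call__` enters the context roughly like
# this, so nested layer calls can observe `call_context().in_call` and friends:
#     ctx = call_context()
#     with ctx.enter(
#         layer=my_layer, inputs=x, build_graph=False, training=True
#     ):
#         outputs = my_layer.call(x)
#     # On exit, the previous state (including `in_call`) is restored.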


def training_arg_passed_to_call(argspec, args, kwargs):
    """Returns whether a user passed the `training` argument in `__call__`."""
    # `argspec.args` starts with ['self', 'inputs'].
    full_args = dict(zip(argspec.args[2:], args))
    full_args.update(kwargs)
    return "training" in full_args and full_args["training"] is not None
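
# Example sketch (hypothetical argspec): for
# `def __call__(self, inputs, training=None, mask=None)`, `argspec.args[2:]`
# is ['training', 'mask'], so:
#     training_arg_passed_to_call(argspec, args=(True,), kwargs={})  # -> True
#     training_arg_passed_to_call(argspec, args=(), kwargs={})       # -> False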


def is_subclassed(layer):
    """Returns True if the object is a subclassed layer or subclassed model."""
    return (
        layer.__module__.find("keras.engine") == -1
        and layer.__module__.find("keras.layers") == -1
    )


def from_saved_model(layer):
    """Returns whether the layer is loaded from a SavedModel."""
    return layer.__module__.find("keras.saving.legacy.saved_model") != -1


def check_graph_consistency(tensor=None, method="add_loss", force_raise=False):
    """Checks that tensors passed to `add_*` methods match the Keras graph.

    When one of the `add_*` methods is called inside a V2 conditional branch,
    the underlying tensor gets created in a FuncGraph managed by
    control_flow_v2. We need to raise clear error messages in such cases.

    Args:
        tensor: Tensor to check, or `False` if it is known that an error
            should be raised.
        method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
        force_raise: If an error should be raised regardless of `tensor`.

    Raises:
        RuntimeError: In case of an out-of-graph tensor.
    """
    if force_raise or (
        tf1.executing_eagerly_outside_functions()
        and hasattr(tensor, "graph")
        and tensor.graph.is_control_flow_graph
    ):
        if method == "activity_regularizer":
            bad_example = """
            class TestModel(tf.keras.Model):

                def __init__(self):
                    super(TestModel, self).__init__(name='test_model')
                    self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

                def call(self, x, training=None):
                    if training:
                        return self.dense(x)
                    else:
                        return self.dense(x)
            """
            correct_example = """
            class TestModel(tf.keras.Model):

                def __init__(self):
                    super(TestModel, self).__init__(name='test_model')
                    self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

                def call(self, x, training=None):
                    return self.dense(x)
            """
            raise RuntimeError(
                "You are using a layer with `activity_regularizer` in a "
                f"control flow branch, e.g.:\n{bad_example}\nThis is currently "
                "not supported. Please move your call to the layer with "
                "`activity_regularizer` out of the control flow branch, "
                f"e.g.:\n{correct_example}\nYou can also resolve this by "
                "marking your outer model/layer dynamic (eager-only) by "
                "passing `dynamic=True` to the layer constructor. Any kind of "
                "control flow is supported with dynamic layers. Note that "
                "using `dynamic=True` requires you to implement static shape "
                "inference in the `compute_output_shape(input_shape)` "
                "method."
            )

        if method == "add_metric":
            bad_example = """
            def call(self, inputs, training=None):
                if training:
                    metric = compute_metric(inputs)
                    self.add_metric(metric, name='my_metric', aggregation='mean')
                return inputs
            """
            correct_example = """
            def call(self, inputs, training=None):
                if training:
                    metric = compute_metric(inputs)
                else:
                    metric = 0.
                self.add_metric(metric, name='my_metric', aggregation='mean')
                return inputs
            """
        elif method == "add_loss":
            bad_example = """
            def call(self, inputs, training=None):
                if training:
                    loss = compute_loss(inputs)
                    self.add_loss(loss)
                return inputs
            """
            correct_example = """
            def call(self, inputs, training=None):
                if training:
                    loss = compute_loss(inputs)
                else:
                    loss = 0.
                self.add_loss(loss)
                return inputs
            """
        else:
            bad_example = """
            def call(self, inputs, training=None):
                if training:
                    self.add_update(self.w.assign_add(1))
                return inputs
            """
            correct_example = """
            def call(self, inputs, training=None):
                if training:
                    increment = 1
                else:
                    increment = 0
                self.add_update(self.w.assign_add(increment))
                return inputs
            """
        raise RuntimeError(
            "You are using the method `{method}` in a control flow branch "
            "in your layer, e.g.:\n{bad_example}\n"
            "This is not currently supported. "
            "Please move your call to {method} out of the control flow branch, "
            "e.g.:\n{correct_example}\n"
            "You can also resolve this by marking your layer "
            "as dynamic (eager-only) by passing "
            "`dynamic=True` to the layer constructor. "
            "Any kind of control flow is supported with dynamic layers. "
            "Note that using `dynamic=True` requires you "
            "to implement static shape inference "
            "in the `compute_output_shape(input_shape)` method.".format(
                method=method,
                bad_example=bad_example,
                correct_example=correct_example,
            )
        )


def mark_as_return(outputs, acd):
    """Marks `outputs` as the return values for automatic control deps."""

    def _mark_as_return(tensor):
        """Marks `tensor` as the return value for automatic control deps."""
        if not tf.is_tensor(tensor):
            return tensor

        return_tensor = acd.mark_as_return(tensor)
        if getattr(tensor, "_keras_mask", None) is not None:
            return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask)
        else:
            return_tensor._keras_mask = None

        # Handle TensorFlow Probability attached metadata.
        # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
        if getattr(tensor, "_tfp_distribution", None) is not None:
            return_tensor._tfp_distribution = tensor._tfp_distribution

        return return_tensor

    return tf.nest.map_structure(_mark_as_return, outputs)


V2_DTYPE_BEHAVIOR = None


@keras_export(v1=["keras.layers.enable_v2_dtype_behavior"])
def enable_v2_dtype_behavior():
    """Enable the V2 dtype behavior for Keras layers.

    By default, the V2 dtype behavior is enabled in TensorFlow 2, so this
    function is only useful if `tf.compat.v1.disable_v2_behavior` has been
    called. Since mixed precision requires V2 dtype behavior to be enabled,
    this function allows you to use mixed precision in Keras layers if
    `disable_v2_behavior` has been called.

    When enabled, the dtype of Keras layers defaults to floatx (which is
    typically float32) instead of None. In addition, layers will automatically
    cast floating-point inputs to the layer's dtype.

    >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
    >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
    >>> print(layer.dtype)  # float32 since V2 dtype behavior is enabled
    float32
    >>> y = layer(x)  # Layer casts inputs since V2 dtype behavior is enabled
    >>> print(y.dtype.name)
    float32

    A layer author can opt-out their layer from the automatic input casting by
    passing `autocast=False` to the base Layer's constructor. This disables
    the autocasting part of the V2 behavior for that layer, but not the
    defaulting to floatx part of the V2 behavior.

    When a global `tf.keras.mixed_precision.Policy` is set, a Keras layer's
    dtype will default to the global policy instead of floatx. Layers will
    automatically cast inputs to the policy's compute_dtype.
    """
    global V2_DTYPE_BEHAVIOR
    V2_DTYPE_BEHAVIOR = True


@keras_export(v1=["keras.layers.disable_v2_dtype_behavior"])
def disable_v2_dtype_behavior():
    """Disables the V2 dtype behavior for Keras layers.

    See `tf.compat.v1.keras.layers.enable_v2_dtype_behavior`.
    """
    global V2_DTYPE_BEHAVIOR
    V2_DTYPE_BEHAVIOR = False


def v2_dtype_behavior_enabled():
    """Returns True if the V2 dtype behavior is enabled."""
    if V2_DTYPE_BEHAVIOR is None:
        return tf.__internal__.tf2.enabled()
    return V2_DTYPE_BEHAVIOR


class TrackableWeightHandler:
    """Keras wrapper for handling Trackable object saving and restoring.

    This class handles Trackables in both V1 and V2 modes, ensuring that they
    can be saved and restored with the correct data and without adding
    additional ops on every save.

    Attributes:
        trackable: The trackable to wrap.
        num_tensors: The number of tensors that this trackable requires for
            saving.
    """

    def __init__(self, trackable):
        if not isinstance(trackable, tf.__internal__.tracking.Trackable):
            raise ValueError(f"{trackable} is not a Trackable object.")
        self._trackable = trackable
        self._distribute_strategy = tf.distribute.get_strategy()

        saveables = tf.__internal__.tracking.saveable_objects_from_trackable(
            trackable
        ).values()
        # 'Saveables' won't exist when we're passed a legacy TF1 table like
        # a StaticHashTable.
        if not saveables:
            self._num_tensors = 0
            self._setter = lambda weights: None
            self._getter = lambda: []

        elif len(saveables) == 1:
            saveable = list(saveables)[0]

            if tf1.executing_eagerly_outside_functions():
                # If we're in eager mode, we need to defer calling the
                # Trackable's saveable() callable until data export time.
                # However, it is safe to call the saveable as many times as we
                # want, so we will call it now to figure out how many tensors
                # this Trackable will produce.
                self._saveable = saveable
                self._num_tensors = len(self._saveable().specs)
                self._setter = lambda weights: self._saveable().restore(
                    weights, None
                )
                self._getter = lambda: [
                    spec.tensor for spec in self._saveable().specs
                ]
            else:
                # If we're in Graph mode, we need to evaluate the Saveable
                # only once and cache the resulting restore graph. Failing to
                # do this will result in new assignment ops being added to
                # the graph each time set_weights() is called.
                self._placeholder_tensors = []
                self._saveable = saveable()
                self._num_tensors = len(self._saveable.specs)
                for spec in self._saveable.specs:
                    tensor = spec.tensor
                    self._placeholder_tensors.append(
                        tf1.placeholder(tensor.dtype, tensor.shape)
                    )
                self._assign_op = self._saveable.restore(
                    self._placeholder_tensors, None
                )
                self._setter = self._set_weights_v1
                self._getter = lambda: [
                    spec.tensor for spec in self._saveable.specs
                ]
        else:
            raise ValueError(
                "Only Trackables with one Saveable are supported. "
                f"The Trackable {trackable} has {len(saveables)} Saveables."
            )

    @property
    def num_tensors(self):
        return self._num_tensors

    def set_weights(self, weights):
        if len(weights) != self._num_tensors:
            raise ValueError(
                f"Weight handler for trackable {self._trackable} received "
                "an incorrect number of weights: "
                f"expected {self._num_tensors} weights, "
                f"got {len(weights)} weights."
            )
        self._setter(weights)

    def get_tensors(self):
        return self._getter()

    def _set_weights_v1(self, weights):
        feed_dict = {}
        for idx, tensor in enumerate(weights):
            feed_dict[self._placeholder_tensors[idx]] = tensor
        backend.get_session().run(self._assign_op, feed_dict)
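
# Hypothetical usage sketch (assuming a Trackable with a single Saveable, such
# as a lookup-table resource tracked by a layer):
#     handler = TrackableWeightHandler(some_trackable)
#     tensors = handler.get_tensors()  # list of length handler.num_tensors
#     handler.set_weights(tensors)     # restore the same values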


def no_ragged_support(inputs, layer_name):
    input_list = tf.nest.flatten(inputs)
    if any(isinstance(x, tf.RaggedTensor) for x in input_list):
        raise ValueError(
            f"Layer {layer_name} does not support RaggedTensors as input. "
            f"Inputs received: {inputs}. You can try converting your "
            "input to a dense (uniform) tensor."
        )


def is_split_variable(v):
    """Returns True if `v` is a PartitionedVariable or a ShardedVariable."""
    return not {clz.__name__ for clz in v.__class__.__mro__}.isdisjoint(
        {"PartitionedVariable", "ShardedVariable"}
    )
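
# Note on the check above: it inspects class names in the MRO rather than
# using isinstance(), which avoids importing the version-dependent variable
# classes. Sketch with a stand-in class (illustrative only):
#     class ShardedVariable:  # any object whose MRO contains this name
#         pass
#     is_split_variable(ShardedVariable())  # -> True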


def has_weights(obj):
    obj_type = type(obj)
    return (
        hasattr(obj_type, "trainable_weights")
        and hasattr(obj_type, "non_trainable_weights")
        and not isinstance(obj, type)
    )


# TODO(kathywu): This is a temporary hack. When a network of layers is revived
# from SavedModel, only the top-level layer will have losses. This causes
# issues in eager mode because the child layers may have graph losses
# (thus model.losses returns a mix of Eager and graph tensors). To fix this,
# whenever eager losses are added to one layer, add eager losses to all
# child layers. This causes `.losses` to only return eager losses.
REVIVED_LOSS_PLACEHOLDER = (
    "This layer's losses have been added to the parent layer."
)