Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/tf_utils.py: 22%
282 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow-related utilities."""
17import collections
18import contextlib
19import copy
20import platform
21import random
22import threading
24import numpy as np
25import tensorflow.compat.v2 as tf
26from absl import logging
28from keras.src import backend
29from keras.src.engine import keras_tensor
30from keras.src.utils import object_identity
31from keras.src.utils import tf_contextlib
33# isort: off
34from tensorflow.python.framework import ops
35from tensorflow.python.util.tf_export import keras_export
36from tensorflow.python import pywrap_tfe
@keras_export("keras.utils.set_random_seed", v1=[])
def set_random_seed(seed):
    """Seeds every global random number generator used by a Keras program.

    Seeds Python's `random` module, NumPy's global generator and
    TensorFlow's global seed. It also replaces the Keras backend seed
    generator with a fresh `random.Random(seed)`. This makes almost any Keras
    program deterministic. Randomness introduced by network communication
    (e.g. parameter server distribution) or by non-deterministic cuDNN ops
    is not covered.

    Apart from the backend seed generator, calling this utility is
    equivalent to the following:

    ```python
    import random
    import numpy as np
    import tensorflow as tf
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    ```

    Arguments:
        seed: Integer, the random seed to use.

    Raises:
        ValueError: If `seed` is not an `int`.
    """
    if isinstance(seed, int):
        for seeder in (random.seed, np.random.seed, tf.random.set_seed):
            seeder(seed)
        backend._SEED_GENERATOR.generator = random.Random(seed)
        return
    raise ValueError(
        "Expected `seed` argument to be an integer. "
        f"Received: seed={seed} (of type {type(seed)})"
    )
def get_random_seed():
    """Retrieve a seed value to seed a random generator.

    If a seeded generator is installed as `backend._SEED_GENERATOR.generator`
    (for example by `set_random_seed`), the seed is drawn from it so the
    sequence of returned seeds is reproducible. Otherwise Python's global
    `random` module is used.

    Returns:
        The random seed as an integer in the closed range [1, 10**9].
    """
    # `randint` needs integer bounds. A float bound such as `1e9` is
    # deprecated since Python 3.10 and raises TypeError since Python 3.12.
    upper = 10**9
    generator = getattr(backend._SEED_GENERATOR, "generator", None)
    if generator:
        return generator.randint(1, upper)
    return random.randint(1, upper)
def is_tensor_or_tensor_list(v):
    """Returns whether the first flattened leaf of `v` is a `tf.Tensor`.

    `v` may be a single value or any nested structure accepted by
    `tf.nest.flatten`. Only the first leaf is inspected. An empty structure
    yields False.
    """
    leaves = tf.nest.flatten(v)
    return bool(leaves) and isinstance(leaves[0], tf.Tensor)
def get_reachable_from_inputs(inputs, targets=None):
    """Returns the set of tensors/ops reachable from `inputs`.

    Walks forward from `inputs`. Operations lead to their outputs and
    control outputs, variables lead to their `op`, and tensors lead to their
    consumers. Stops early once every element of `targets` (if given) has
    been newly visited.

    Only valid in Symbolic mode, not Eager mode.

    Args:
        inputs: List of tensors.
        targets: List of tensors.

    Returns:
        A set of tensors reachable from the inputs (includes the inputs
        themselves).

    Raises:
        TypeError: If an element met during the walk is not a
            `tf.Operation`, a `tf.Variable` or a tensor.
    """
    flat_inputs = tf.nest.flatten(inputs, expand_composites=True)
    visited = object_identity.ObjectIdentitySet(flat_inputs)
    pending = None
    if targets:
        pending = object_identity.ObjectIdentitySet(tf.nest.flatten(targets))
    frontier = collections.deque(flat_inputs)

    while frontier:
        node = frontier.pop()
        if isinstance(node, tuple(_user_convertible_tensor_types)):
            # Consumers of user-registered tensor types cannot be queried.
            continue

        if isinstance(node, tf.Operation):
            successors = node.outputs[:] or []
            successors += node._control_outputs
        elif isinstance(node, tf.Variable):
            try:
                successors = [node.op]
            except AttributeError:
                # Variables can be created in an Eager context.
                successors = []
        elif tf.is_tensor(node):
            successors = node.consumers()
        else:
            raise TypeError(
                "Expected tf.Operation, tf.Variable, or tf.Tensor. "
                f"Received: {node}"
            )

        for successor in successors:
            if successor in visited:
                continue
            visited.add(successor)
            if pending is not None:
                pending.discard(successor)
            frontier.appendleft(successor)

        if pending is not None and not pending:
            return visited

    return visited
# This function needs access to private functions of `nest`.
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
    """Maps the atomic elements of a nested structure.

    Any element for which `is_atomic_fn` returns True is replaced by
    `map_fn(element)` and is not descended into. Everything else must be
    nested and is rebuilt with the same structure via `sequence_like`.
    Mappings are visited in sorted-key order and attrs classes in their
    `__attrs_attrs__` field order.

    Args:
        is_atomic_fn: A function that determines if an element of `nested`
            is atomic.
        map_fn: The function to apply to atomic elements of `nested`.
        nested: A nested structure.

    Returns:
        The nested structure, with atomic elements mapped according to
        `map_fn`.

    Raises:
        ValueError: If an element that is neither atomic nor a sequence is
            encountered.
    """
    if is_atomic_fn(nested):
        return map_fn(nested)

    if not tf.nest.is_nested(nested):
        raise ValueError(
            f"Received non-atomic and non-sequence element: {nested} "
            f"of type {type(nested)}"
        )

    if tf.__internal__.nest.is_mapping(nested):
        children = [nested[key] for key in sorted(nested.keys())]
    elif tf.__internal__.nest.is_attrs(nested):
        children = _astuple(nested)
    else:
        children = nested

    def _recurse(child):
        return map_structure_with_atomic(is_atomic_fn, map_fn, child)

    return tf.__internal__.nest.sequence_like(
        nested, [_recurse(child) for child in children]
    )
def get_shapes(tensors):
    """Gets shapes from tensors.

    Returns a structure mirroring `tensors`. Each leaf is replaced by its
    `.shape` attribute, or by `None` if the leaf has no `shape`.
    """

    def _shape_or_none(leaf):
        if hasattr(leaf, "shape"):
            return leaf.shape
        return None

    return tf.nest.map_structure(_shape_or_none, tensors)
def convert_shapes(input_shape, to_tuples=True):
    """Converts nested shape representations to the requested format.

    With `to_tuples=True`, every shape becomes a tuple of ints/None. With
    `to_tuples=False`, every shape becomes a `TensorShape`.

    The following are treated as single shapes (atomic):
    - TensorShapes
    - tuples/lists whose elements are int, None or `Dimension`
    - ints
    - None

    Args:
        input_shape: A nested structure of objects to be converted to
            TensorShapes.
        to_tuples: If `True`, converts all TensorShape to tuples. Otherwise
            converts all tuples representing shapes to TensorShapes.

    Returns:
        Nested structure of shapes in desired format.

    Raises:
        ValueError: when the input tensor shape can't be converted to
            tuples, e.g. an unknown tensor shape.
    """

    def _is_shape_component(value):
        return value is None or isinstance(
            value, (int, tf.compat.v1.Dimension)
        )

    def _is_atomic_shape(candidate):
        # Ex: TensorShape or (None, 10, 32) or 5 or `None`
        if _is_shape_component(candidate) or isinstance(
            candidate, tf.TensorShape
        ):
            return True
        return isinstance(candidate, (tuple, list)) and all(
            _is_shape_component(dim) for dim in candidate
        )

    def _convert_shape(shape):
        shape = tf.TensorShape(shape)
        return tuple(shape.as_list()) if to_tuples else shape

    return map_structure_with_atomic(
        _is_atomic_shape, _convert_shape, input_shape
    )
def validate_axis(axis, input_shape):
    """Validates an axis value and returns its standardized form.

    Args:
        axis: Value to validate. Can be a single integer (Python `int` or
            NumPy integer scalar) or a list/tuple of integers. Integers may
            be negative.
        input_shape: Reference input shape that the axis/axes refer to.

    Returns:
        Normalized form of `axis`, i.e. a list with all-positive values.

    Raises:
        ValueError: If the rank of `input_shape` is unknown or 0, if an axis
            is out of range, or if an axis appears more than once.
    """
    input_shape = tf.TensorShape(input_shape)
    rank = input_shape.rank
    # `not rank` rejects both an unknown rank (None) and rank 0.
    if not rank:
        raise ValueError(
            f"Input has undefined rank. Received: input_shape={input_shape}"
        )

    # Convert axis to list. A NumPy integer scalar is a single axis; without
    # this branch `list(np.int64(1))` would raise a TypeError.
    if isinstance(axis, int):
        axis = [axis]
    elif isinstance(axis, np.integer):
        axis = [int(axis)]
    else:
        axis = list(axis)

    # Resolve negative axes relative to the rank.
    for idx, x in enumerate(axis):
        if x < 0:
            axis[idx] = rank + x

    # Validate axes
    for x in axis:
        if x < 0 or x >= rank:
            raise ValueError(
                "Invalid value for `axis` argument. "
                "Expected 0 <= axis < inputs.rank (with "
                f"inputs.rank={rank}). Received: axis={tuple(axis)}"
            )
    if len(axis) != len(set(axis)):
        raise ValueError(f"Duplicate axis: {tuple(axis)}")
    return axis
class ListWrapper:
    """Boxes a list so that `nest` utilities treat it as a single element.

    The wrapped list is stored by reference, not copied. `as_list()` returns
    that same list object.
    """

    def __init__(self, list_to_wrap):
        self._list = list_to_wrap

    def as_list(self):
        """Returns the wrapped list object itself."""
        return self._list
def convert_inner_node_data(nested, wrap=False):
    """Wraps or unwraps innermost node-data lists in `ListWrapper` objects.

    Args:
        nested: A nested data structure.
        wrap: If `True`, wrap innermost lists in `ListWrapper` objects. If
            `False`, unwraps `ListWrapper` objects into lists.

    Returns:
        Structure of same type as nested, with lists wrapped/unwrapped.
    """

    def _is_serialized_node_data(candidate):
        # Node data can be of form `[layer_name, node_id, tensor_id]` or
        # `[layer_name, node_id, tensor_id, kwargs]`.
        return (
            isinstance(candidate, list)
            and len(candidate) in (3, 4)
            and isinstance(candidate[0], str)
        )

    def _is_atomic_nested(candidate):
        """Returns `True` if `candidate` must not be descended into."""
        if isinstance(candidate, ListWrapper) or _is_serialized_node_data(
            candidate
        ):
            return True
        return not tf.nest.is_nested(candidate)

    def _wrap(candidate):
        """Boxes a node-data list; leaves everything else untouched."""
        if not isinstance(
            candidate, ListWrapper
        ) and _is_serialized_node_data(candidate):
            return ListWrapper(candidate)
        return candidate

    def _unwrap(candidate):
        """Unboxes a `ListWrapper`; leaves everything else untouched."""
        if isinstance(candidate, ListWrapper):
            return candidate.as_list()
        return candidate

    return map_structure_with_atomic(
        _is_atomic_nested, _wrap if wrap else _unwrap, nested
    )
def shape_type_conversion(fn):
    """Decorator that handles tuple/TensorShape conversion.

    Used in `compute_output_shape` and `build`. The decorated method
    receives `input_shape` as (nested) tuples. Whatever it returns is
    converted back to (nested) `TensorShape`s. `None` passes through
    unchanged in both directions.

    Args:
        fn: function to wrap.

    Returns:
        Wrapped function. Its `__name__`, `__doc__` and other metadata are
        copied from `fn` via `functools.wraps`.
    """
    import functools

    @functools.wraps(fn)
    def wrapper(instance, input_shape):
        # Pass shapes as tuples to `fn`
        # This preserves compatibility with external Keras.
        if input_shape is not None:
            input_shape = convert_shapes(input_shape, to_tuples=True)
        output_shape = fn(instance, input_shape)
        # Return shapes from `fn` as TensorShapes.
        if output_shape is not None:
            output_shape = convert_shapes(output_shape, to_tuples=False)
        return output_shape

    return wrapper
def are_all_symbolic_tensors(tensors):
    """Returns True if every element of `tensors` is a symbolic tensor.

    An empty iterable yields True. Checking stops at the first element for
    which `is_symbolic_tensor` is falsy.
    """
    return all(is_symbolic_tensor(t) for t in tensors)
384_user_convertible_tensor_types = set()
def is_extension_type(tensor):
    """Returns whether `tensor` is of an ExtensionType.

    See github.com/tensorflow/community/pull/269. Currently this is an
    `isinstance` check against `CompositeTensor`. It is meant to become an
    ExtensionType protocol check once ExtensionType is made public.

    Args:
        tensor: An object to test.

    Returns:
        True if `tensor` is an extension type object, False otherwise.
    """
    composite_cls = tf.__internal__.CompositeTensor
    return isinstance(tensor, composite_cls)
def is_symbolic_tensor(tensor):
    """Returns whether `tensor` is symbolic (from a TF graph) or eager.

    A Variable can be seen as either. It counts as symbolic when we are in a
    graph scope (or when it carries `_keras_history`), and as eager when we
    are in an eager scope. Instances of user-registered types are converted
    to a tensor/composite first and then re-checked.

    Args:
        tensor: A tensor instance to test.

    Returns:
        A truthy value for symbolic tensors, a falsy one for eager tensors
        and non-tensors.
    """
    if isinstance(tensor, tf.Tensor):
        return hasattr(tensor, "graph")
    if is_extension_type(tensor):
        parts = tf.nest.flatten(tensor, expand_composites=True)
        return any(hasattr(part, "graph") for part in parts)
    if isinstance(tensor, tf.Variable):
        # Variables that are output of a Keras Layer in Functional API mode
        # should be considered symbolic.
        # TODO(omalleyt): We need a better way to check this in order to
        # enable `run_eagerly=True` for Models containing Layers that
        # return Variables as outputs.
        return (
            getattr(tensor, "_keras_history", False)
            or not tf.executing_eagerly()
        )
    if isinstance(tensor, tuple(_user_convertible_tensor_types)):
        converted = ops.convert_to_tensor_or_composite(tensor)
        return is_symbolic_tensor(converted)
    return False
@keras_export("keras.__internal__.utils.register_symbolic_tensor_type", v1=[])
def register_symbolic_tensor_type(cls):
    """Allows users to specify types regarded as symbolic `Tensor`s.

    Use this together with `tf.register_tensor_conversion_function`.
    Calling `tf.keras.__internal__.utils.register_symbolic_tensor_type(cls)`
    lets non-`Tensor` objects be plumbed through Keras layers. Registering
    the same class twice is a no-op.

    Example:

    ```python
    # One-time setup.
    class Foo:
        def __init__(self, input_):
            self._input = input_
        def value(self):
            return tf.constant(42.)

    tf.register_tensor_conversion_function(
        Foo, lambda x, *args, **kwargs: x.value())

    tf.keras.__internal__.utils.register_symbolic_tensor_type(Foo)

    # User-land.
    layer = tf.keras.layers.Lambda(lambda input_: Foo(input_))
    ```

    Args:
        cls: A `class` type which shall be regarded as a symbolic `Tensor`.
    """
    if cls in _user_convertible_tensor_types:
        return
    keras_tensor.register_keras_tensor_specialization(
        cls, keras_tensor.UserRegisteredTypeKerasTensor
    )
    # The module-level set is mutated in place, so no `global` is needed.
    _user_convertible_tensor_types.add(cls)
def type_spec_from_value(value):
    """Grab type_spec without converting array-likes to tensors.

    Extension types return their own `_type_spec`. Objects exposing both
    `shape` and `dtype` get a `tf.TensorSpec` built directly from those
    attributes. Anything else is delegated to `tf.type_spec_from_value`.
    """
    if is_extension_type(value):
        return value._type_spec
    is_array_like = hasattr(value, "shape") and hasattr(value, "dtype")
    if not is_array_like:
        return tf.type_spec_from_value(value)
    # Build the spec from metadata so the data is never converted to a
    # Tensor.
    return tf.TensorSpec(value.shape, value.dtype)
def is_ragged(tensor):
    """Returns true if `tensor` is a ragged tensor or ragged tensor value."""
    ragged_types = (tf.RaggedTensor, tf.compat.v1.ragged.RaggedTensorValue)
    return isinstance(tensor, ragged_types)
def is_sparse(tensor):
    """Returns true if `tensor` is a sparse tensor or sparse tensor value."""
    sparse_types = (tf.SparseTensor, tf.compat.v1.SparseTensorValue)
    return isinstance(tensor, sparse_types)
def is_tensor_or_variable(x):
    """Returns True if `x` satisfies `tf.is_tensor` or is a `tf.Variable`."""
    if tf.is_tensor(x):
        return True
    return isinstance(x, tf.Variable)
def is_tensor_or_extension_type(x):
    """Returns true if 'x' is a TF-native type or an ExtensionType."""
    if tf.is_tensor(x):
        return True
    return is_extension_type(x)
def convert_variables_to_tensors(values):
    """Converts `Variable`s in `values` to `Tensor`s.

    This is a Keras version of `convert_variables_to_tensors` in TensorFlow
    variable_utils.py.

    Each `tf.Variable` leaf is converted with `tf.convert_to_tensor`. Each
    ExtensionType leaf is replaced by the result of its
    `_convert_variables_to_tensors` method, so an override of that method
    can convert its variable components too. All other leaves are returned
    unchanged.

    Args:
        values: A nested structure of `ResourceVariable`s, or any other
            objects.

    Returns:
        A new structure with `ResourceVariable`s in `values` converted to
        `Tensor`s.
    """

    def _to_tensor_if_variable(leaf):
        if isinstance(leaf, tf.Variable):
            return tf.convert_to_tensor(leaf)
        if is_extension_type(leaf):
            return leaf._convert_variables_to_tensors()
        return leaf

    return tf.nest.map_structure(_to_tensor_if_variable, values)
def assert_no_legacy_layers(layers):
    """Prevent tf.layers.Layers from being used with Keras.

    Certain legacy layers inherit from their keras analogs. However, they
    are not supported with keras and can lead to subtle and hard to diagnose
    bugs. A layer counts as legacy when its `_is_legacy_layer` attribute is
    truthy.

    Args:
        layers: An iterable of layers to check.

    Raises:
        TypeError: If any elements of layers are tf.layers.Layers. The
            message lists every offending layer.
    """
    # isinstance check for tf.layers.Layer introduces a circular dependency.
    offenders = [
        layer for layer in layers if getattr(layer, "_is_legacy_layer", None)
    ]
    if not offenders:
        return
    layer_str = "\n".join(" " + str(layer) for layer in offenders)
    raise TypeError(
        f"The following are legacy tf.layers.Layers:\n{layer_str}\n"
        "To use keras as a "
        "framework (for instance using the Network, Model, or Sequential "
        "classes), please use the tf.keras.layers implementation instead. "
        "(Or, if writing custom layers, subclass from tf.keras.layers "
        "rather than tf.layers)"
    )
@tf_contextlib.contextmanager
def maybe_init_scope(layer):
    """Open an `init_scope` if in V2 mode and using the keras graph.

    Args:
        layer: The Layer/Model that is currently active.

    Yields:
        None
    """
    # Don't open an init_scope in V1 mode, when using legacy tf.layers, or in
    # a local-variable scope.
    # The local-variable scope should ensure that created variables are local
    # to the function being executed, rather than lifted out of the graph by
    # `init_scope`. This way the variables are freely usable and mutable
    # within the function, which enables a visitation guarantee for model
    # evaluation, when the scope is applied to metric variable creation.
    use_init_scope = (
        tf.compat.v1.executing_eagerly_outside_functions()
        and getattr(layer, "_keras_style", True)
        and not in_local_vars_context()
    )
    if not use_init_scope:
        yield
        return
    with tf.init_scope():
        yield
@tf_contextlib.contextmanager
def graph_context_for_symbolic_tensors(*args, **kwargs):
    """Enters the Keras graph if any argument is a symbolic tensor.

    Positional arguments are checked first, then keyword-argument values.
    If one of them is symbolic, the body runs under
    `backend.get_graph().as_default()`. Otherwise no context is entered.
    """
    candidates = list(args) + list(kwargs.values())
    if not any(is_symbolic_tensor(c) for c in candidates):
        yield
        return
    with backend.get_graph().as_default():
        yield
def dataset_is_infinite(dataset):
    """True if the passed dataset is infinite.

    When executing eagerly outside functions, this returns the result of
    `tf.equal` on the dataset cardinality. Otherwise the cardinality is
    evaluated in the Keras session and compared with `==`.
    """
    infinite = tf.data.experimental.INFINITE_CARDINALITY
    if tf.compat.v1.executing_eagerly_outside_functions():
        return tf.equal(tf.data.experimental.cardinality(dataset), infinite)
    session = backend.get_session()
    size = session.run(tf.data.experimental.cardinality(dataset))
    return size == infinite
def get_tensor_spec(t, dynamic_batch=False, name=None):
    """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`.

    Args:
        t: A `tf.TypeSpec`, an extension type, an object with
            `_keras_history`, a `KerasTensor`, or anything exposing `shape`
            and `dtype`.
        dynamic_batch: If True, the leading dimension of the spec is
            replaced by `None` (for rank >= 1). This is not applied to specs
            taken from `_keras_history`.
        name: Name for the spec. Only used when the spec is built from
            `shape`/`dtype`.

    Returns:
        A type spec, or `None` if `t` is not tensor-like.
    """
    if isinstance(t, tf.TypeSpec):
        spec = t
    elif is_extension_type(t):
        # TODO(b/148821952): Should these specs have a name attr?
        spec = t._type_spec
    elif hasattr(t, "_keras_history") and hasattr(
        t._keras_history[0], "_type_spec"
    ):
        return t._keras_history[0]._type_spec
    elif isinstance(t, keras_tensor.KerasTensor):
        spec = t.type_spec
    elif hasattr(t, "shape") and hasattr(t, "dtype"):
        spec = tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
    else:
        return None  # Allow non-Tensors to pass through.

    # Scalars and unknown-rank specs have no batch dimension to relax.
    if not dynamic_batch or spec.shape.rank in (None, 0):
        return spec

    dims = spec.shape.as_list()
    dims[0] = None
    # TODO(b/203201161) Remove this deepcopy once type_spec_with_shape has
    # been updated to not mutate spec.
    spec = copy.deepcopy(spec)
    return keras_tensor.type_spec_with_shape(spec, tf.TensorShape(dims))
def sync_to_numpy_or_python_type(tensors):
    """Syncs and converts a structure of `Tensor`s to `NumPy` arrays or Python
    scalar types.

    For each tensor, it calls `tensor.numpy()`. If the result is a scalar
    value, it converts it to a Python type, such as a float or int, by
    calling `result.item()`.

    Numpy scalars are converted, as Python types are often more convenient
    to deal with. This is especially useful for bfloat16 Numpy scalars,
    which don't support as many operations as other Numpy values.

    Async strategies (such as `TPUStrategy` and `ParameterServerStrategy`)
    are forced to sync during this process.

    Args:
        tensors: A structure of tensors. It may also be a `RemoteValue`, or
            a list whose first element is a `RemoteValue`. These are
            fetched before conversion.

    Returns:
        `tensors`, but scalar tensors are converted to Python types and
        non-scalar tensors are converted to Numpy arrays.
    """
    remote_value_cls = tf.distribute.experimental.coordinator.RemoteValue
    if isinstance(tensors, remote_value_cls):
        tensors = tensors.fetch()
    # Check for emptiness before indexing: `tensors[0]` on an empty list
    # would raise IndexError.
    if (
        isinstance(tensors, list)
        and tensors
        and isinstance(tensors[0], remote_value_cls)
    ):
        tensors = tf.nest.map_structure(lambda t: t.fetch(), tensors)

    def _to_single_numpy_or_python_type(t):
        # Don't turn ragged or sparse tensors to NumPy.
        if isinstance(t, tf.Tensor):
            t = t.numpy()
        # Strings, ragged and sparse tensors don't have .item(). Return them
        # as-is.
        if not isinstance(t, (np.ndarray, np.generic)):
            return t
        return t.item() if np.ndim(t) == 0 else t

    return tf.nest.map_structure(_to_single_numpy_or_python_type, tensors)
697def _astuple(attrs):
698 """Converts the given attrs to tuple non-recursively."""
699 cls = type(attrs)
700 fields = getattr(cls, "__attrs_attrs__", None)
701 if fields is None:
702 raise ValueError(f"{cls} is not an attrs-decorated class.")
703 values = []
704 for field in fields:
705 values.append(getattr(attrs, field.name))
706 return tuple(values)
def can_jit_compile(warn=False):
    """Returns True if TensorFlow XLA is available for the platform.

    XLA is reported unavailable on macOS with an ARM processor, and whenever
    `TF_ListPluggablePhysicalDevices()` returns a non-empty result.

    Args:
        warn: If True, log a warning explaining the fallback to
            `jit_compile=False` when XLA is unavailable.
    """
    on_apple_arm = (
        platform.system() == "Darwin"
        and "arm" in platform.processor().lower()
    )
    if on_apple_arm:
        reason = (
            "XLA (`jit_compile`) is not yet supported on Apple M1/M2 ARM "
            "processors. Falling back to `jit_compile=False`."
        )
    elif pywrap_tfe.TF_ListPluggablePhysicalDevices():
        reason = (
            "XLA (`jit_compile`) is not supported on your system. "
            "Falling back to `jit_compile=False`."
        )
    else:
        return True
    if warn:
        logging.warning(reason)
    return False
728_metric_local_vars_scope = threading.local()
def get_metric_local_vars_scope():
    """Returns this thread's current metric local-vars scope, or None.

    `None` is returned when no `current` attribute has been set on the
    thread-local holder.
    """
    return getattr(_metric_local_vars_scope, "current", None)
def in_local_vars_context():
    """Returns whether a metric local-vars scope is set for this thread."""
    return get_metric_local_vars_scope() is not None
@contextlib.contextmanager
def with_metric_local_vars_scope():
    """Context manager that turns on local variable creation for Metrics.

    Installs a fresh `MetricLocalVarsScope` as this thread's current scope
    for the duration of the `with` block. Afterwards it restores whatever
    scope (possibly `None`) was current before, including when the body
    raises.
    """
    previous_scope = getattr(_metric_local_vars_scope, "current", None)
    _metric_local_vars_scope.current = MetricLocalVarsScope()
    try:
        yield
    finally:
        # Restore even on exceptions. Otherwise the local-vars scope would
        # leak into unrelated code that runs later on this thread.
        _metric_local_vars_scope.current = previous_scope
class MetricLocalVarsScope:
    """Marker that turns on local variable creation for Metrics.

    The class has no state or behavior of its own. An instance is installed
    as the current scope by `with_metric_local_vars_scope()`, and its
    presence is what modulates Metric's variable creation.
    """