Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/tf_utils.py: 25%
219 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow-related utilities."""
17import collections
18import copy
19import numpy as np
21from tensorflow.python.data.experimental.ops import cardinality
22from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
23from tensorflow.python.eager import context
24from tensorflow.python.framework import composite_tensor
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.framework import tensor_spec
29from tensorflow.python.framework import tensor_util
30from tensorflow.python.framework import type_spec
31from tensorflow.python.keras import backend as K
32from tensorflow.python.keras.engine import keras_tensor
33from tensorflow.python.keras.utils import object_identity
34from tensorflow.python.keras.utils import tf_contextlib
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import variables
37from tensorflow.python.ops.ragged import ragged_tensor
38from tensorflow.python.ops.ragged import ragged_tensor_value
39from tensorflow.python.util import nest
40from tensorflow.python.util.tf_export import keras_export
def is_tensor_or_tensor_list(v):
  """Returns True if the first flattened leaf of `v` is an `ops.Tensor`.

  Only the leading leaf of the flattened structure is inspected. An empty
  structure yields False.

  Args:
    v: A single value or an arbitrary nested structure.

  Returns:
    Whether the first leaf of `v` is a `Tensor`.
  """
  leaves = nest.flatten(v)
  return bool(leaves) and isinstance(leaves[0], ops.Tensor)
def get_reachable_from_inputs(inputs, targets=None):
  """Returns the set of tensors/ops reachable from `inputs`.

  Walks graph consumers breadth-first, starting from `inputs`. When `targets`
  is given, the walk stops as soon as every target has been reached.

  Only valid in Symbolic mode, not Eager mode.

  NOTE: a target that is already one of `inputs` is never counted as found,
  because only newly discovered nodes are removed from the pending targets.

  Args:
    inputs: List of tensors.
    targets: List of tensors.

  Returns:
    A set of tensors reachable from the inputs (includes the inputs themselves).
    If the walk stops early, the set holds only what was visited so far.

  Raises:
    TypeError: If an encountered element is not an Operation, Variable or
      Tensor (and not a user-registered tensor type).
  """
  start_nodes = nest.flatten(inputs, expand_composites=True)
  visited = object_identity.ObjectIdentitySet(start_nodes)
  if targets:
    pending_targets = object_identity.ObjectIdentitySet(nest.flatten(targets))
  # Nodes are appended on the left and popped from the right (FIFO order).
  frontier = collections.deque(start_nodes)

  while frontier:
    node = frontier.pop()
    if isinstance(node, tuple(_user_convertible_tensor_types)):
      # Consumers of user-registered types cannot be discovered.
      continue

    if isinstance(node, ops.Operation):
      successors = node.outputs[:] or []
      successors += node._control_outputs  # pylint: disable=protected-access
    elif isinstance(node, variables.Variable):
      try:
        successors = [node.op]
      except AttributeError:
        # Variables can be created in an Eager context.
        successors = []
    elif tensor_util.is_tf_type(node):
      successors = node.consumers()
    else:
      raise TypeError('Expected Operation, Variable, or Tensor, got ' +
                      str(node))

    for successor in successors:
      if successor in visited:
        continue
      visited.add(successor)
      if targets:
        pending_targets.discard(successor)
      frontier.appendleft(successor)

    if targets and not pending_targets:
      return visited

  return visited
104# This function needs access to private functions of `nest`.
105# pylint: disable=protected-access
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
  """Maps the atomic elements of a nested structure.

  Recurses into mappings, attrs instances and other nested sequences until
  `is_atomic_fn` returns True for an element. That element is then replaced
  by `map_fn(element)`. Mapping values are visited in sorted-key order.

  Args:
    is_atomic_fn: A function that determines if an element of `nested` is
      atomic.
    map_fn: The function to apply to atomic elements of `nested`.
    nested: A nested structure.

  Returns:
    The nested structure, with atomic elements mapped according to `map_fn`.

  Raises:
    ValueError: If an element that is neither atomic nor a sequence is
      encountered.
  """
  if is_atomic_fn(nested):
    return map_fn(nested)

  # Anything that is not atomic must be a structure we can recurse into.
  if not nest.is_nested(nested):
    raise ValueError(
        'Received non-atomic and non-sequence element: {}'.format(nested))

  if nest.is_mapping(nested):
    children = [nested[key] for key in sorted(nested.keys())]
  elif nest.is_attrs(nested):
    children = _astuple(nested)
  else:
    children = nested
  mapped_children = [
      map_structure_with_atomic(is_atomic_fn, map_fn, child)
      for child in children
  ]
  # Reassemble the mapped values into the shape of `nested`.
  return nest._sequence_like(nested, mapped_children)  # pylint: disable=protected-access
def get_shapes(tensors):
  """Gets shapes from tensors.

  Args:
    tensors: A nested structure whose leaves expose a `shape` attribute.

  Returns:
    The same structure with every leaf replaced by its `.shape`.
  """

  def _shape_of(leaf):
    return leaf.shape

  return nest.map_structure(_shape_of, tensors)
146# pylint: enable=protected-access
def convert_shapes(input_shape, to_tuples=True):
  """Converts nested shape representations to desired format.

  Performs:

  TensorShapes -> tuples if `to_tuples=True`.
  tuples of int or None -> TensorShapes if `to_tuples=False`.

  Valid objects to be converted are:
  - TensorShapes
  - tuples with elements of type int or None.
  - ints
  - None

  Args:
    input_shape: A nested structure of objects to be converted to TensorShapes.
    to_tuples: If `True`, converts all TensorShape to tuples. Otherwise converts
      all tuples representing shapes to TensorShapes.

  Returns:
    Nested structure of shapes in desired format.

  Raises:
    ValueError: when the input tensor shape can't be converted to tuples, eg
      unknown tensor shape.
  """

  def _is_shape_component(value):
    return value is None or isinstance(value, (int, tensor_shape.Dimension))

  def _is_atomic_shape(candidate):
    # Ex: TensorShape or (None, 10, 32) or 5 or `None`
    if (_is_shape_component(candidate) or
        isinstance(candidate, tensor_shape.TensorShape)):
      return True
    return (isinstance(candidate, (tuple, list)) and
            all(_is_shape_component(dim) for dim in candidate))

  def _convert_shape(shape):
    as_tensor_shape = tensor_shape.TensorShape(shape)
    if not to_tuples:
      return as_tensor_shape
    return tuple(as_tensor_shape.as_list())

  return map_structure_with_atomic(
      _is_atomic_shape, _convert_shape, input_shape)
class ListWrapper:
  """Holder that makes `nest` treat a wrapped list as a single element.

  The wrapped list is stored as-is (not copied) and handed back by
  `as_list()`.
  """

  def __init__(self, list_to_wrap):
    self._list = list_to_wrap

  def as_list(self):
    """Returns the wrapped list object itself."""
    return self._list
def convert_inner_node_data(nested, wrap=False):
  """Either wraps or unwraps innermost node data lists in `ListWrapper` objects.

  Args:
    nested: A nested data structure.
    wrap: If `True`, wrap innermost lists in `ListWrapper` objects. If `False`,
      unwraps `ListWrapper` objects into lists.

  Returns:
    Structure of same type as nested, with lists wrapped/unwrapped.
  """

  def _is_serialized_node_data(candidate):
    # Node data can be of form `[layer_name, node_id, tensor_id]` or
    # `[layer_name, node_id, tensor_id, kwargs]`.
    return (isinstance(candidate, list) and len(candidate) in (3, 4) and
            isinstance(candidate[0], str))

  def _is_atomic_nested(candidate):
    """True for `ListWrapper`s, serialized node data, and non-nested leaves."""
    return (isinstance(candidate, ListWrapper) or
            _is_serialized_node_data(candidate) or
            not nest.is_nested(candidate))

  def _wrap_node_data(candidate):
    # Already-wrapped values and non-node-data leaves pass through unchanged.
    if (not isinstance(candidate, ListWrapper) and
        _is_serialized_node_data(candidate)):
      return ListWrapper(candidate)
    return candidate

  def _unwrap_node_data(candidate):
    if isinstance(candidate, ListWrapper):
      return candidate.as_list()
    return candidate

  converter = _wrap_node_data if wrap else _unwrap_node_data
  return map_structure_with_atomic(_is_atomic_nested, converter, nested)
def shape_type_conversion(fn):
  """Decorator that handles tuple/TensorShape conversion.

  Used in `compute_output_shape` and `build`.

  A non-`None` `input_shape` is converted to tuples before `fn` is called.
  Any non-`None` value returned by `fn` is converted back to `TensorShape`s.

  The wrapper is built with `functools.wraps`. It therefore carries `fn`'s
  `__name__`, `__qualname__`, `__doc__` and `__module__`, exposes `fn` as
  `__wrapped__`, and has `fn.__dict__` merged into its own attributes.

  Args:
    fn: function to wrap.

  Returns:
    Wrapped function.
  """
  import functools  # pylint: disable=g-import-not-at-top

  @functools.wraps(fn)
  def wrapper(instance, input_shape):
    # Pass shapes as tuples to `fn`
    # This preserves compatibility with external Keras.
    if input_shape is not None:
      input_shape = convert_shapes(input_shape, to_tuples=True)
    output_shape = fn(instance, input_shape)
    # Return shapes from `fn` as TensorShapes.
    if output_shape is not None:
      output_shape = convert_shapes(output_shape, to_tuples=False)
    return output_shape

  return wrapper
def are_all_symbolic_tensors(tensors):
  """Returns True if every element of `tensors` is a symbolic tensor.

  An empty iterable yields True. Evaluation stops at the first non-symbolic
  element.
  """
  return all(is_symbolic_tensor(candidate) for candidate in tensors)


# Classes registered through `register_symbolic_tensor_type`.
# `is_symbolic_tensor` converts instances of these classes before checking
# them. `get_reachable_from_inputs` skips them because their consumers cannot
# be discovered.
_user_convertible_tensor_types = set()
def is_extension_type(tensor):
  """Returns whether a tensor is of an ExtensionType.

  github.com/tensorflow/community/pull/269

  For now this is an `isinstance` check against `CompositeTensor`. It is
  expected to become an ExtensionType protocol check once ExtensionType is
  made public.

  Args:
    tensor: An object to test

  Returns:
    True if the tensor is an extension type object, false if not.
  """
  extension_base = composite_tensor.CompositeTensor
  return isinstance(tensor, extension_base)
def is_symbolic_tensor(tensor):
  """Returns whether a tensor is symbolic (from a TF graph) or an eager tensor.

  A Variable can be seen as either: it is considered symbolic
  when we are in a graph scope, and eager when we are in an eager scope.

  Args:
    tensor: A tensor instance to test.

  Returns:
    True for symbolic tensors, False for eager tensors. Objects that are not
    tensors, composite tensors, variables or registered user types also give
    False. The result is always a Python `bool`.
  """
  if isinstance(tensor, ops.Tensor):
    return hasattr(tensor, 'graph')
  elif is_extension_type(tensor):
    # A composite counts as symbolic if any of its components has a `graph`.
    component_tensors = nest.flatten(tensor, expand_composites=True)
    return any(hasattr(t, 'graph') for t in component_tensors)
  elif isinstance(tensor, variables.Variable):
    # Variables that are output of a Keras Layer in Functional API mode
    # should be considered symbolic.
    # TODO(omalleyt): We need a better way to check this in order to
    # enable `run_eagerly=True` for Models containing Layers that
    # return Variables as outputs.
    # `bool(...)` so the `_keras_history` object itself is never returned
    # (an `or` expression yields its operand, not a bool).
    return bool(getattr(tensor, '_keras_history', False) or
                not context.executing_eagerly())
  elif isinstance(tensor, tuple(_user_convertible_tensor_types)):
    tensor = ops.convert_to_tensor_or_composite(tensor)
    return is_symbolic_tensor(tensor)
  else:
    return False
@keras_export('keras.__internal__.utils.register_symbolic_tensor_type', v1=[])
def register_symbolic_tensor_type(cls):
  """Allows users to specify types regarded as symbolic `Tensor`s.

  Used in conjunction with `tf.register_tensor_conversion_function`, calling
  `tf.keras.__internal__.utils.register_symbolic_tensor_type(cls)`
  allows non-`Tensor` objects to be plumbed through Keras layers.

  Registering the same class more than once has no further effect.

  Example:

  ```python
  # One-time setup.
  class Foo(object):
    def __init__(self, input_):
      self._input = input_
    def value(self):
      return tf.constant(42.)

  tf.register_tensor_conversion_function(
      Foo, lambda x, *args, **kwargs: x.value())

  tf.keras.__internal__.utils.register_symbolic_tensor_type(Foo)

  # User-land.
  layer = tf.keras.layers.Lambda(lambda input_: Foo(input_))
  ```

  Args:
    cls: A `class` type which shall be regarded as a symbolic `Tensor`.
  """
  # The module-level registry is only mutated here, never rebound, so no
  # `global` declaration is needed.
  if cls in _user_convertible_tensor_types:
    return
  keras_tensor.register_keras_tensor_specialization(
      cls, keras_tensor.UserRegisteredTypeKerasTensor)
  _user_convertible_tensor_types.add(cls)
def type_spec_from_value(value):
  """Grab type_spec without converting array-likes to tensors.

  Extension types return their own `_type_spec`. Objects exposing both
  `shape` and `dtype` get a `TensorSpec` built from that metadata alone.
  Anything else is delegated to `type_spec.type_spec_from_value`.
  """
  if is_extension_type(value):
    return value._type_spec  # pylint: disable=protected-access
  looks_array_like = hasattr(value, 'shape') and hasattr(value, 'dtype')
  if not looks_array_like:
    return type_spec.type_spec_from_value(value)
  # Build the spec from metadata only; the data is never converted.
  return tensor_spec.TensorSpec(value.shape, value.dtype)
def is_ragged(tensor):
  """Returns true if `tensor` is a ragged tensor or ragged tensor value."""
  ragged_types = (ragged_tensor.RaggedTensor,
                  ragged_tensor_value.RaggedTensorValue)
  return isinstance(tensor, ragged_types)
def is_sparse(tensor):
  """Returns true if `tensor` is a sparse tensor or sparse tensor value."""
  sparse_types = (sparse_tensor.SparseTensor,
                  sparse_tensor.SparseTensorValue)
  return isinstance(tensor, sparse_types)
def is_tensor_or_variable(x):
  """Returns whether `x` is a TF tensor-like value or a `Variable`."""
  is_tensor_like = tensor_util.is_tf_type(x)
  return is_tensor_like or isinstance(x, variables.Variable)
def assert_no_legacy_layers(layers):
  """Prevent tf.layers.Layers from being used with Keras.

  Certain legacy layers inherit from their keras analogs; however they are
  not supported with keras and can lead to subtle and hard to diagnose bugs.

  Args:
    layers: A list of layers to check

  Raises:
    TypeError: If any elements of layers are tf.layers.Layers
  """
  # isinstance check for tf.layers.Layer introduces a circular dependency,
  # so legacy layers are detected through the `_is_legacy_layer` marker.
  offenders = [
      layer for layer in layers if getattr(layer, '_is_legacy_layer', None)
  ]
  if not offenders:
    return
  layer_str = '\n'.join(' ' + str(layer) for layer in offenders)
  raise TypeError(
      'The following are legacy tf.layers.Layers:\n{}\nTo use keras as a '
      'framework (for instance using the Network, Model, or Sequential '
      'classes), please use the tf.keras.layers implementation instead. '
      '(Or, if writing custom layers, subclass from tf.keras.layers rather '
      'than tf.layers)'.format(layer_str))
@tf_contextlib.contextmanager
def maybe_init_scope(layer):
  """Open an `init_scope` if in V2 mode and using the keras graph.

  Args:
    layer: The Layer/Model that is currently active.

  Yields:
    None
  """
  # Don't open an init_scope in V1 mode or when using legacy tf.layers.
  use_init_scope = (ops.executing_eagerly_outside_functions() and
                    getattr(layer, '_keras_style', True))
  if not use_init_scope:
    yield
    return
  with ops.init_scope():
    yield
@tf_contextlib.contextmanager
def graph_context_for_symbolic_tensors(*args, **kwargs):
  """Returns graph context manager if any of the inputs is a symbolic tensor.

  Positional arguments are checked first, then keyword values. When one of
  them is symbolic, the body runs under the Keras graph (`K.get_graph()`).
  """
  candidates = (*args, *kwargs.values())
  if not any(is_symbolic_tensor(value) for value in candidates):
    yield
    return
  with K.get_graph().as_default():
    yield
def dataset_is_infinite(dataset):
  """True if the passed dataset is infinite.

  When `ops.executing_eagerly_outside_functions()` is true, the result of
  `math_ops.equal` (a tensor) is returned. Otherwise the cardinality op is
  run in the Keras session and a plain comparison result is returned.
  """
  if not ops.executing_eagerly_outside_functions():
    num_elements = K.get_session().run(cardinality.cardinality(dataset))
    return num_elements == cardinality.INFINITE
  return math_ops.equal(
      cardinality.cardinality(dataset), cardinality.INFINITE)
def get_tensor_spec(t, dynamic_batch=False, name=None):
  """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`.

  Args:
    t: A `TypeSpec`, an extension-type value, an object with a
      `_keras_history`, or any object exposing `shape` and `dtype`.
    dynamic_batch: If True, returns a deep copy of the spec whose leading
      dimension is set to None (only when the rank is known and positive).
    name: Name for the spec. Used only when a new `TensorSpec` is built from
      `shape`/`dtype`.

  Returns:
    A spec, or None when `t` is not recognized. NOTE: the `_keras_history`
    path returns its spec directly and ignores `dynamic_batch`.
  """
  # pylint: disable=protected-access
  if isinstance(t, type_spec.TypeSpec):
    base_spec = t
  elif is_extension_type(t):
    # TODO(b/148821952): Should these specs have a name attr?
    base_spec = t._type_spec
  elif (hasattr(t, '_keras_history') and
        hasattr(t._keras_history[0], '_type_spec')):
    return t._keras_history[0]._type_spec
  elif hasattr(t, 'shape') and hasattr(t, 'dtype'):
    base_spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
  else:
    return None  # Allow non-Tensors to pass through.

  if not dynamic_batch:
    return base_spec

  relaxed_spec = copy.deepcopy(base_spec)
  # RaggedTensorSpec only has a private _shape.
  dims = relaxed_spec._shape
  if dims.rank is not None and dims.rank > 0:
    relaxed_spec._shape = tensor_shape.TensorShape(
        [None] + dims.as_list()[1:])
  return relaxed_spec
  # pylint: enable=protected-access
def sync_to_numpy_or_python_type(tensors):
  """Syncs and converts a structure of `Tensor`s to `NumPy` arrays or Python scalar types.

  For each tensor, it calls `tensor.numpy()`. If the result is a scalar value,
  it converts it to a Python type, such as a float or int, by calling
  `result.item()`.

  Numpy scalars are converted, as Python types are often more convenient to deal
  with. This is especially useful for bfloat16 Numpy scalars, which don't
  support as many operations as other Numpy values.

  Async strategies (such as `TPUStrategy` and `ParameterServerStrategy`) are
  forced to sync during this process.

  A `RemoteValue` input is returned via its `fetch()` result and is not
  converted further.

  Args:
    tensors: A structure of tensors.

  Returns:
    `tensors`, but scalar tensors are converted to Python types and non-scalar
    tensors are converted to Numpy arrays.
  """
  if isinstance(tensors, coordinator_lib.RemoteValue):
    return tensors.fetch()

  def _convert_leaf(leaf):
    if not isinstance(leaf, ops.Tensor):
      return leaf  # Don't turn ragged or sparse tensors to NumPy.
    value = leaf.numpy()
    if np.ndim(value) == 0:
      return value.item()
    return value

  return nest.map_structure(_convert_leaf, tensors)
def _astuple(attrs):
  """Converts the given attrs to tuple non-recursively.

  Field values are read in `__attrs_attrs__` order and returned as-is. Nested
  attrs instances are not converted.

  Raises:
    ValueError: If the type of `attrs` has no `__attrs_attrs__`.
  """
  klass = type(attrs)
  attrs_fields = getattr(klass, '__attrs_attrs__', None)
  if attrs_fields is None:
    raise ValueError('%r is not an attrs-decorated class.' % klass)
  return tuple(getattr(attrs, field.name) for field in attrs_fields)