Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/iterator_autograph.py: 46%
35 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""AutoGraph-specific overrides for tf.data.ops."""
16import functools
18import numpy as np
20from tensorflow.python.autograph.operators import control_flow
21from tensorflow.python.autograph.operators import py_builtins
22from tensorflow.python.data.ops import iterator_ops
23from tensorflow.python.framework import tensor_conversion
24from tensorflow.python.framework import tensor_spec
25from tensorflow.python.ops import cond
26from tensorflow.python.util import nest
29# TODO(mdan): These checks should be easier. Fix the nest API.
def _verify_spec_compatible(input_name, spec_name, input_, spec):
  """Verifies that a symbol has a type compatible with a given spec.

  Here, compatibility is viewed in the general TensorFlow sense: that the dtypes
  are the same after implicit conversion, if both are tensors.

  This verifier ensures consistent treatment of types across AutoGraph.

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify.
    spec: TensorSpec that `input_` must be compatible with.

  Raises:
    ValueError: if `input_` is None.
    TypeError: if the dtype of `input_` (after converting Python bool, int,
      float, str values and NumPy arrays to tensors) differs from `spec.dtype`,
      or if `input_` has no `dtype` attribute at all.
  """
  assert isinstance(spec, tensor_spec.TensorSpec)
  # Test the argument `input_`, not the `input` builtin (which is never None).
  if input_ is None:
    raise ValueError("{} cannot be None".format(input_name))

  # TODO(mdan): Use TensorCompatible when ready.
  # Plain Python / NumPy values carry no TF dtype; convert them so the
  # comparison below sees the dtype implicit conversion would produce.
  if isinstance(input_, (bool, int, float, str, np.ndarray)):
    input_ = tensor_conversion.convert_to_tensor_v2(input_)

  input_dtype = getattr(input_, "dtype", None)

  if input_dtype != spec.dtype:
    input_dtype_str = "no dtype" if input_dtype is None else str(input_dtype)

    raise TypeError(
        "{} must have the same dtype as {}. Expected {}, got {}".format(
            input_name, spec_name, spec.dtype, input_dtype_str
        )
    )


def _verify_structure_compatible(input_name, spec_name, input_, spec):
  """Verifies that a possibly-structured symbol has types compatible with another.

  See _verify_spec_compatible for a more concrete meaning of "compatible".
  Unlike _verify_spec_compatible, which handles singular Tensor-spec objects,
  _verify_structure_compatible can process structures recognized by tf.nest.

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify. May, but doesn't need to, be a structure.
    spec: Any, value that `input_` must be compatible with. May, but doesn't
      need to, be a structure.

  Raises:
    TypeError: if `input_` and `spec` do not have the same nested structure
      (with composites expanded). Errors raised by `_verify_spec_compatible`
      for individual leaf pairs propagate unchanged.
  """
  try:
    nest.assert_same_structure(input_, spec, expand_composites=True)
  except (ValueError, TypeError) as e:
    # Normalize both nest failure modes into a single TypeError, keeping the
    # original nest message for context.
    raise TypeError(
        "{} must have the same element structure as {}.\n\n{}".format(
            input_name, spec_name, str(e)
        )
    ) from e

  # Structures match; now check each leaf against its corresponding spec.
  nest.map_structure(
      functools.partial(_verify_spec_compatible, input_name, spec_name), input_,
      spec)


def _next_tf_iterator(iterator, default=py_builtins.UNSPECIFIED):
  """AutoGraph override of the `next` builtin for `OwnedIterator` objects.

  Args:
    iterator: The iterator to advance.
    default: Optional value to produce when `iterator` has no next element. It
      must be compatible (same structure and dtypes) with
      `iterator.element_spec`.

  Returns:
    Without `default`, whatever `next(iterator)` returns. With `default`, the
    result of a `cond.cond` that selects the next element when one is present
    and `default` otherwise.

  Raises:
    TypeError: if `default` is not compatible with `iterator.element_spec`
      (other errors from the validation helpers propagate unchanged).
  """
  if default is py_builtins.UNSPECIFIED:
    # Without a default, fall back to the "normal" behavior which raises
    # a runtime exception.
    return next(iterator)
  # Validate the default before requesting the next element: the request may
  # advance the iterator (e.g. when executing eagerly), so an incompatible
  # default should be rejected before any element could be consumed.
  _verify_structure_compatible(
      "the default argument", "the iterate", default, iterator.element_spec
  )
  opt_iterate = iterator.get_next_as_optional()
  return cond.cond(
      opt_iterate.has_value(), opt_iterate.get_value, lambda: default
  )


def register_overrides():
  """Registers the tf.data AutoGraph overrides for `OwnedIterator`.

  Installs `_next_tf_iterator` as the `next` implementation and
  `control_flow._tf_iterator_for_stmt` as the `for` loop implementation used
  by AutoGraph when the target is an `iterator_ops.OwnedIterator`.
  """
  owned_iterator_type = iterator_ops.OwnedIterator

  py_builtins.next_registry.register(owned_iterator_type, _next_tf_iterator)

  for_stmt_impl = control_flow._tf_iterator_for_stmt  # pylint: disable=protected-access
  control_flow.for_loop_registry.register(owned_iterator_type, for_stmt_impl)