# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Autograph specifc overrides for dataset_ops."""
from tensorflow.python.autograph.operators import control_flow
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.data.experimental.ops import take_while_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest


def _general_purpose_scan(ds, init_state, body):
  """Variant of Dataset.scan with semantics of general-purpose computation."""
  # Datasets are typically intended for data preprocessing. However, in
  # autograph loops they usually appear as general-purpose computations (for
  # example, a custom training loop). These two use cases require significantly
  # different optimization policies, the most important of which is the device
  # placement. The flag override for use_default_device below instructs the
  # runtime to treat the computation as general-purpose, rather than data
  # preprocessing.

  # Loaded lazily due to a circular dependency (dataset_ops ->
  # scan_op -> dataset_ops).
  # pylint: disable=g-import-not-at-top,protected-access
  from tensorflow.python.data.ops import scan_op
  return scan_op._ScanDataset(ds, init_state, body, use_default_device=False)
  # pylint: enable=g-import-not-at-top,protected-access
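

# Illustrative sketch (not part of the original module): the state-threading
# semantics used here mirror the public tf.data.Dataset.scan API, which
# carries a state across elements and emits one output per element. For
# example, a running sum over [0, 1, 2, 3]:
#
#   import tensorflow as tf
#
#   ds = tf.data.Dataset.range(4)
#   sums = ds.scan(
#       initial_state=tf.constant(0, tf.int64),
#       scan_func=lambda state, x: (state + x, state + x))
#   print(list(sums.as_numpy_iterator()))  # [0, 1, 3, 6]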


def _tf_ag_dataset_for_stmt(
    ds, extra_test, body, get_state, set_state, symbol_names, opts
):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt."""
  # Note: This is easier to follow with the insight that the computations in
  # a dataset pipeline are transposed (aka fused).
  # For example, given a pipeline input -> scan -> take_while -> reduce,
  # and a dataset with input [1, 2, 3], the computations occur in the following
  # order:
  #   reduce(take_while(scan(1)))
  #   reduce(take_while(scan(2)))
  #   reduce(take_while(scan(3)))

  init_vars = get_state()
  control_flow.verify_loop_init_vars(init_vars, symbol_names)

  # Workaround for Dataset.reduce not allowing empty state tensors - create
  # a dummy state variable that remains unused.
  # TODO(mdan): reduce should allow and match empty structures.
  if not init_vars:
    init_vars = (constant_op.constant(0),)
    symbol_names = ("<internal dummy>",)

    def dummy_set_state(unused_dummy):
      pass

    def dummy_get_state():
      return (constant_op.constant(0),)

    get_state, set_state = dummy_get_state, dummy_set_state

  def scan_body(scan_state, scan_inputs):
    """Main body of the Dataset.scan."""
    loop_vars, iterate = scan_state, scan_inputs
    set_state(loop_vars)

    def main_path():
      body(iterate)
      new_loop_vars = get_state()
      control_flow.verify_tf_loop_vars(
          init_vars,
          loop_vars,
          new_loop_vars,
          symbol_names,
          opts,
          check_shapes=False)
      return new_loop_vars

    if extra_test is not None:
      extra_cond = extra_test()
      new_loop_vars = cond.cond(extra_cond, main_path,
                                lambda: loop_vars)
    else:
      # TODO(mdan): the optimizer should be able to remove an invariant cond?
      extra_cond = (constant_op.constant(True),)  # dummy value, unused
      new_loop_vars = main_path()

    scan_outputs = new_loop_vars, extra_cond
    new_scan_state = new_loop_vars
    return new_scan_state, scan_outputs

  def take_while_predicate(unused_loop_vars, extra_cond):
    return extra_cond

  def reduce_body(unused_reduce_state, scan_outputs):
    output_loop_vars, unused_extra_cond = scan_outputs
    new_reduce_state = output_loop_vars
    return new_reduce_state

  ds = _general_purpose_scan(ds, init_vars, scan_body)
  if extra_test is not None:
    ds = ds.apply(take_while_ops.take_while(take_while_predicate))
  final_loop_vars = ds.reduce(init_vars, reduce_body)
  set_state(final_loop_vars)
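

# Illustrative sketch (hypothetical user code, not part of this module):
# inside a tf.function, autograph routes a for-loop over a dataset with a
# break condition through the scan -> take_while -> reduce pipeline above.
#
#   import tensorflow as tf
#
#   @tf.function
#   def sum_first_n(ds, n):
#     total = tf.constant(0, tf.int64)
#     count = tf.constant(0, tf.int64)
#     for x in ds:        # dispatched to _tf_ag_dataset_for_stmt
#       if count >= n:    # the break materializes as the take_while predicate
#         break
#       total += x
#       count += 1
#     return total
#
#   sum_first_n(tf.data.Dataset.range(100), tf.constant(3, tf.int64))  # 0+1+2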


def _tf_ag_dataset_abs(ds):
  specs = nest.flatten(ds.element_spec)
  if len(specs) == 1:
    return ds.map(math_ops.abs, num_parallel_calls=dataset_ops.AUTOTUNE)
  return ds.map(
      lambda *e: nest.map_structure(math_ops.abs, e),
      num_parallel_calls=dataset_ops.AUTOTUNE)


def _tf_ag_dataset_len(s):
  """Autograph override of the builtin len for dataset_ops.DatasetV2."""
  l = s.cardinality()
  msg = gen_string_ops.string_join([
      "len requires dataset with definitive cardinality, got ",
      gen_string_ops.as_string(l),
  ])
  # TODO(yongtang): UNKNOWN is treated as an error.
  # In case there are more UNKNOWN cases for dataset, we could
  # use dataset.reduce() to find out the length (in an expensive way).
  with ops.control_dependencies([
      control_flow_assert.Assert(
          math_ops.logical_and(
              math_ops.not_equal(l, dataset_ops.INFINITE),
              math_ops.not_equal(l, dataset_ops.UNKNOWN)), [msg])
  ]):
    l = array_ops.identity(l)

  return l
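

# Illustrative sketch (hypothetical user code): under tf.function, len() on a
# dataset dispatches here and succeeds only when the cardinality is finite
# and known; otherwise the Assert above fails at runtime.
#
#   import tensorflow as tf
#
#   @tf.function
#   def ds_len(ds):
#     return len(ds)  # dispatched to _tf_ag_dataset_len
#
#   ds_len(tf.data.Dataset.range(5))           # tf.Tensor(5)
#   ds_len(tf.data.Dataset.range(5).repeat())  # fails: infinite cardinality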


def _tf_ag_dataset_enumerate(ds, start=0):
  return ds.enumerate(start)


def _tf_ag_dataset_zip(*iterables, strict=False):
  if strict:
    raise ValueError("strict zip not supported by Dataset")
  return dataset_ops.DatasetV2.zip(iterables)


def _tf_ag_dataset_map(fn, *iterables):
  return dataset_ops.DatasetV2.zip(iterables).map(fn)


def _tf_ag_dataset_filter(fn, iterable):
  return iterable.filter(fn)


# The any() builtin essentially asks whether a True element exists. It can
# therefore be translated to a filter() that keeps only True elements,
# followed by take(1). This works well in tf.data because filter and take
# are pipelined, so evaluation stops as soon as take(1) produces an element.
def _tf_ag_dataset_any(iterable):
  # Check that iterable.element_spec consists of a single tf.bool element.
  specs = nest.flatten(iterable.element_spec)
  if len(specs) != 1 or specs[0].dtype != dtypes.bool:
    raise ValueError('in graph mode, the "any" builtin only supports datasets '
                     'that return bool scalars; got: {}'.format(
                         iterable.element_spec))
  ds = iterable.filter(lambda x: x)
  ds = ds.take(1)
  ds = ds.reduce(constant_op.constant(False, dtype=dtypes.bool), lambda _, y: y)
  return ds
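

# Illustrative sketch (hypothetical user code): because filter + take(1) is
# pipelined, any() stops consuming the dataset at the first True element.
#
#   import tensorflow as tf
#
#   @tf.function
#   def has_negative(ds):
#     return any(ds.map(lambda x: x < 0))  # dispatched to _tf_ag_dataset_any
#
#   has_negative(tf.data.Dataset.range(-5, 5))  # tf.Tensor(True)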


# The all() builtin is analogous to any(): filter out everything except False
# elements, take(1), and check whether a False element was found.
def _tf_ag_dataset_all(iterable):
  # Check that iterable.element_spec consists of a single tf.bool element.
  specs = nest.flatten(iterable.element_spec)
  if len(specs) != 1 or specs[0].dtype != dtypes.bool:
    raise ValueError('in graph mode, the "all" builtin only supports datasets '
                     'that return bool scalars; got: {}'.format(
                         iterable.element_spec))
  ds = iterable.filter(math_ops.logical_not)
  ds = ds.take(1)
  ds = ds.reduce(constant_op.constant(True, dtype=dtypes.bool), lambda _, y: y)
  return ds


def register_overrides():
  """Registers the autograph specific overrides for dataset_ops."""
  control_flow.for_loop_registry.register(
      dataset_ops.DatasetV2, _tf_ag_dataset_for_stmt
  )
  py_builtins.abs_registry.register(dataset_ops.DatasetV2, _tf_ag_dataset_abs)
  py_builtins.len_registry.register(dataset_ops.DatasetV2, _tf_ag_dataset_len)
  py_builtins.enumerate_registry.register(
      dataset_ops.DatasetV2, _tf_ag_dataset_enumerate
  )
  py_builtins.zip_registry.register(dataset_ops.DatasetV2, _tf_ag_dataset_zip)
  py_builtins.map_registry.register(dataset_ops.DatasetV2, _tf_ag_dataset_map)
  py_builtins.filter_registry.register(
      dataset_ops.DatasetV2, _tf_ag_dataset_filter
  )
  py_builtins.any_registry.register(dataset_ops.DatasetV2, _tf_ag_dataset_any)
  py_builtins.all_registry.register(dataset_ops.DatasetV2, _tf_ag_dataset_all)
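

# Illustrative note: this module has an effect only once register_overrides()
# has been called, which TensorFlow does as part of its own import setup.
# From then on, the builtins and for-loops above are rewritten by autograph
# when applied to DatasetV2 instances inside a tf.function, e.g.
# (hypothetical user code):
#
#   import tensorflow as tf
#
#   @tf.function
#   def total(ds):
#     s = tf.constant(0, tf.int64)
#     for i, x in enumerate(ds):  # enumerate -> _tf_ag_dataset_enumerate
#       s += x                    # the loop -> _tf_ag_dataset_for_stmt
#     return s
#
#   total(tf.data.Dataset.range(4))  # tf.Tensor(6)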