Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/from_generator_op.py: 18%
107 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The implementation of `tf.data.Dataset.from_generator`."""
17import numpy as np
19from tensorflow.python.data.ops import dataset_ops
20from tensorflow.python.data.ops import structured_function
21from tensorflow.python.data.util import nest
22from tensorflow.python.data.util import structure
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.framework import tensor_spec
27from tensorflow.python.framework import type_spec
28from tensorflow.python.ops import gen_dataset_ops
29from tensorflow.python.ops import script_ops
def _from_generator(generator, output_types, output_shapes, args,
                    output_signature, name):
  """Creates a `Dataset` whose elements are generated by `generator`.

  Note: The current implementation of `Dataset.from_generator()` uses
  `tf.numpy_function` and inherits the same constraints. In particular, it
  requires the dataset and iterator related operations to be placed
  on a device in the same process as the Python program that called
  `Dataset.from_generator()`. In particular, using `from_generator` will
  preclude the use of tf.data service for scaling out dataset processing.
  The body of `generator` will not be serialized in a `GraphDef`, and you
  should not use this method if you need to serialize your model and restore
  it in a different environment.

  The `generator` argument must be a callable object that returns
  an object that supports the `iter()` protocol (e.g. a generator function).

  The elements generated by `generator` must be compatible with either the
  given `output_signature` argument or with the given `output_types` and
  (optionally) `output_shapes` arguments, whichever was specified.

  The recommended way to call `from_generator` is to use the
  `output_signature` argument. In this case the output will be assumed to
  consist of objects with the classes, shapes and types defined by
  `tf.TypeSpec` objects from `output_signature` argument:

  >>> def gen():
  ...   ragged_tensor = tf.ragged.constant([[1, 2], [3]])
  ...   yield 42, ragged_tensor
  >>>
  >>> dataset = tf.data.Dataset.from_generator(
  ...     gen,
  ...     output_signature=(
  ...         tf.TensorSpec(shape=(), dtype=tf.int32),
  ...         tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))
  >>>
  >>> list(dataset.take(1))
  [(<tf.Tensor: shape=(), dtype=int32, numpy=42>,
  <tf.RaggedTensor [[1, 2], [3]]>)]

  There is also a deprecated way to call `from_generator` by either with
  `output_types` argument alone or together with `output_shapes` argument.
  In this case the output of the function will be assumed to consist of
  `tf.Tensor` objects with the types defined by `output_types` and with the
  shapes which are either unknown or defined by `output_shapes`.

  Note: If `generator` depends on mutable global variables or other external
  state, be aware that the runtime may invoke `generator` multiple times
  (in order to support repeating the `Dataset`) and at any time
  between the call to `Dataset.from_generator()` and the production of the
  first element from the generator. Mutating global variables or external
  state can cause undefined behavior, and we recommend that you explicitly
  cache any external state in `generator` before calling
  `Dataset.from_generator()`.

  Note: While the `output_signature` parameter makes it possible to yield
  `Dataset` elements, the scope of `Dataset.from_generator()` should be
  limited to logic that cannot be expressed through tf.data operations. Using
  tf.data operations within the generator function is an anti-pattern and may
  result in incremental memory growth.

  Args:
    generator: A callable object that returns an object that supports the
      `iter()` protocol. If `args` is not specified, `generator` must take no
      arguments; otherwise it must take as many arguments as there are values in
      `args`.
    output_types: (Optional.) A (nested) structure of `tf.DType` objects
      corresponding to each component of an element yielded by `generator`.
    output_shapes: (Optional.) A (nested) structure of `tf.TensorShape` objects
      corresponding to each component of an element yielded by `generator`.
    args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated and
      passed to `generator` as NumPy-array arguments.
    output_signature: (Optional.) A (nested) structure of `tf.TypeSpec` objects
      corresponding to each component of an element yielded by `generator`.
    name: (Optional.) A name for the tf.data operations used by
      `from_generator`.

  Returns:
    Dataset: A `Dataset`.
  """
  if not callable(generator):
    raise TypeError("`generator` must be a Python callable.")

  # `output_signature` is mutually exclusive with the deprecated
  # `output_types`/`output_shapes` pair; validate the combinations up front so
  # that misuse fails before any graph state is created.
  if output_signature is not None:
    if output_types is not None:
      raise TypeError("The `output_types` argument can not be used together "
                      "with the `output_signature` argument.")
    if output_shapes is not None:
      raise TypeError("The `output_shapes` argument can not be used together "
                      "with the `output_signature` argument.")
    for spec in nest.flatten(output_signature):
      if not isinstance(spec, type_spec.TypeSpec):
        raise TypeError(f"`output_signature` must contain objects that are "
                        f"subclass of `tf.TypeSpec` but found {type(spec)} "
                        f"which is not.")
  else:
    if output_types is None:
      raise TypeError("To specify the output signature you need to provide "
                      "either the `output_signature` argument or the "
                      "`output_types` argument.")

  # Normalize the deprecated calling convention into an `output_signature` of
  # `TensorSpec`s: default every missing shape to fully-unknown.
  if output_signature is None:
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(output_types,
                                               tensor_shape.as_shape,
                                               output_shapes)
    output_signature = nest.map_structure_up_to(output_types,
                                                tensor_spec.TensorSpec,
                                                output_shapes, output_types)
  # Conversely, when the signature consists purely of dense `TensorSpec`s,
  # derive `output_types`/`output_shapes` from it so the (faster,
  # graph-compatible) `numpy_function` code path below can be used.
  if all(
      isinstance(x, tensor_spec.TensorSpec)
      for x in nest.flatten(output_signature)):
    output_types = nest.pack_sequence_as(
        output_signature, [x.dtype for x in nest.flatten(output_signature)])
    output_shapes = nest.pack_sequence_as(
        output_signature, [x.shape for x in nest.flatten(output_signature)])

  if args is None:
    args = ()
  else:
    args = tuple(ops.convert_n_to_tensor(args, name="args"))

  # Host-side registry mapping iterator IDs to live Python iterators obtained
  # from `generator`.
  generator_state = dataset_ops.DatasetV2._GeneratorState(generator)  # pylint: disable=protected-access

  def get_iterator_id_fn(unused_dummy):
    """Creates a unique `iterator_id` for each pass over the dataset.

    The returned `iterator_id` disambiguates between multiple concurrently
    existing iterators.

    Args:
      unused_dummy: Ignored value.

    Returns:
      A `tf.int64` tensor whose value uniquely identifies an iterator in
      `generator_state`.
    """
    # `args` are forwarded here so that `generator(*args)` is invoked with the
    # evaluated NumPy values when the new iterator is created.
    return script_ops.numpy_function(generator_state.get_next_id, args,
                                     dtypes.int64)

  def generator_next_fn(iterator_id_t):
    """Generates the next element from iterator with ID `iterator_id_t`.

    We map this function across an infinite repetition of the
    `iterator_id_t`, and raise `StopIteration` to terminate the iteration.

    Args:
      iterator_id_t: A `tf.int64` tensor whose value uniquely identifies the
        iterator in `generator_state` from which to generate an element.

    Returns:
      The next element to generate from the iterator.
    """
    # Dense-tensors-only fast path: both were derived above whenever the
    # signature is all `TensorSpec`s, so this branch handles plain tensors and
    # the `else` branch handles composite tensors (ragged, sparse, etc.).
    if output_types and output_shapes:
      flattened_types = [
          dtypes.as_dtype(dt) for dt in nest.flatten(output_types)
      ]
      flattened_shapes = nest.flatten(output_shapes)

      def generator_py_func(iterator_id):
        """A `py_func` that will be called to invoke the iterator."""
        # `next()` raises `StopIteration` when there are no more
        # elements remaining to be generated.
        values = next(generator_state.get_iterator(iterator_id))

        # Use the same _convert function from the py_func() implementation to
        # convert the returned values to arrays early, so that we can inspect
        # their values.
        try:
          flattened_values = nest.flatten_up_to(output_types, values)
        except (TypeError, ValueError) as e:
          raise TypeError(
              f"`generator` yielded an element that did not match the "
              f"expected structure. The expected structure was "
              f"{output_types}, but the yielded element was {values}.") from e
        ret_arrays = []
        for ret, dtype in zip(flattened_values, flattened_types):
          try:
            ret_arrays.append(
                script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
                    ret,
                    dtype=dtype.as_numpy_dtype))
          except (TypeError, ValueError) as e:
            raise TypeError(
                f"`generator` yielded an element that could not be "
                f"converted to the expected type. The expected type was "
                f"{dtype.name}, but the yielded element was {ret}.") from e

        # Additional type and shape checking to ensure that the components of
        # the generated element match the `output_types` and `output_shapes`
        # arguments.
        for (ret_array, expected_dtype,
             expected_shape) in zip(ret_arrays, flattened_types,
                                    flattened_shapes):
          if ret_array.dtype != expected_dtype.as_numpy_dtype:
            raise TypeError(
                f"`generator` yielded an element of type {ret_array.dtype} "
                f"where an element of type {expected_dtype.as_numpy_dtype} "
                f"was expected.")
          if not expected_shape.is_compatible_with(ret_array.shape):
            raise TypeError(
                f"`generator` yielded an element of shape {ret_array.shape} "
                f"where an element of shape {expected_shape} was expected.")

        return ret_arrays

      flat_values = script_ops.numpy_function(generator_py_func,
                                              [iterator_id_t], flattened_types)

      # In debug mode the numpy_function will return a scalar if
      # generator_py_func produces only a single value.
      if not isinstance(flat_values, (list, tuple)):
        flat_values = [flat_values]

      # The `py_func()` op drops the inferred shapes, so we add them back in
      # here.
      if output_shapes is not None:
        for ret_t, shape in zip(flat_values, flattened_shapes):
          ret_t.set_shape(shape)

      return nest.pack_sequence_as(output_types, flat_values)
    else:
      flat_output_types = structure.get_flat_tensor_types(output_signature)

      def generator_py_func(iterator_id):
        """A `py_func` that will be called to invoke the iterator."""
        # `next()` raises `StopIteration` when there are no more
        # elements remaining to be generated.
        # NOTE: `eager_py_func` passes an EagerTensor here, hence `.numpy()`
        # (unlike the `numpy_function` branch above, which receives a scalar).
        values = next(generator_state.get_iterator(iterator_id.numpy()))

        try:
          values = structure.normalize_element(values, output_signature)
        except (TypeError, ValueError) as e:
          raise TypeError(
              f"`generator` yielded an element that did not match the "
              f"expected structure. The expected structure was "
              f"{output_signature}, but the yielded element was "
              f"{values}.") from e

        values_spec = structure.type_spec_from_value(values)

        if not structure.are_compatible(values_spec, output_signature):
          raise TypeError(
              f"`generator` yielded an element of {values_spec} where an "
              f"element of {output_signature} was expected.")

        return structure.to_tensor_list(output_signature, values)

      # `eager_py_func` (rather than `numpy_function`) is required here to
      # support composite-tensor elements.
      return script_ops.eager_py_func(
          generator_py_func, inp=[iterator_id_t], Tout=flat_output_types)

  def finalize_fn(iterator_id_t):
    """Releases host-side state for the iterator with ID `iterator_id_t`."""

    def finalize_py_func(iterator_id):
      generator_state.iterator_completed(iterator_id)
      # We return a dummy value so that the `finalize_fn` has a valid
      # signature.
      # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
      # casting in `py_func()` will create an array of `np.int32` on Windows,
      # leading to a runtime error.
      return np.array(0, dtype=np.int64)

    return script_ops.numpy_function(finalize_py_func, [iterator_id_t],
                                     dtypes.int64)

  # This function associates each traversal of `generator` with a unique
  # iterator ID.
  def flat_map_fn(dummy_arg):
    # The `get_iterator_id_fn` gets a unique ID for the current instance of
    # of the generator.
    # The `generator_next_fn` gets the next element from the iterator with the
    # given ID, and raises StopIteration when that iterator contains no
    # more elements.
    return _GeneratorDataset(
        dummy_arg,
        get_iterator_id_fn,
        generator_next_fn,
        finalize_fn,
        output_signature,
        name=name)

  # A single-element dataset that, each time it is evaluated, contains a
  # freshly-generated and unique (for the returned dataset) int64
  # ID that will be used to identify the appropriate Python state, which
  # is encapsulated in `generator_state`, and captured in
  # `get_iterator_id_map_fn`.
  dummy = 0
  id_dataset = dataset_ops.Dataset.from_tensors(dummy, name=name)

  # A dataset that contains all of the elements generated by a
  # single iterator created from `generator`, identified by the
  # iterator ID contained in `id_dataset`. Lifting the iteration
  # into a flat_map here enables multiple repetitions and/or nested
  # versions of the returned dataset to be created, because it forces
  # the generation of a new ID for each version.
  return id_dataset.flat_map(flat_map_fn, name=name)
class _GeneratorDataset(dataset_ops.DatasetSource):
  """A `Dataset` that generates elements by invoking a function."""

  def __init__(self,
               init_args,
               init_func,
               next_func,
               finalize_func,
               output_signature,
               name=None):
    """Constructs a `_GeneratorDataset`.

    Args:
      init_args: A (nested) structure representing the arguments to `init_func`.
      init_func: A TensorFlow function that will be called on `init_args` each
        time a C++ iterator over this dataset is constructed. Returns a (nested)
        structure representing the "state" of the dataset.
      next_func: A TensorFlow function that will be called on the result of
        `init_func` to produce each element, and that raises `OutOfRangeError`
        to terminate iteration.
      finalize_func: A TensorFlow function that will be called on the result of
        `init_func` immediately before a C++ iterator over this dataset is
        destroyed. The return value is ignored.
      output_signature: A (nested) structure of `tf.TypeSpec` objects describing
        the output of `next_func`.
      name: Optional. A name for the tf.data transformation.
    """
    self._init_args = init_args

    # The type spec of `init_args` doubles as the input structure of
    # `init_func` below.
    self._init_structure = structure.type_spec_from_value(init_args)

    # Each Python callable is traced into a ConcreteFunction so it can be
    # passed as a function attr of the `GeneratorDataset` op.
    self._init_func = structured_function.StructuredFunctionWrapper(
        init_func,
        self._transformation_name(),
        input_structure=self._init_structure)

    # `next_func` and `finalize_func` both consume the "state" produced by
    # `init_func`, hence they share its output structure as their input.
    self._next_func = structured_function.StructuredFunctionWrapper(
        next_func,
        self._transformation_name(),
        input_structure=self._init_func.output_structure)

    self._finalize_func = structured_function.StructuredFunctionWrapper(
        finalize_func,
        self._transformation_name(),
        input_structure=self._init_func.output_structure)

    self._output_signature = output_signature

    self._name = name

    # The op consumes the flattened `init_args` plus the captured inputs of
    # each traced function. NOTE(review): `self._common_args` is presumably
    # provided by the `DatasetSource` hierarchy (supplying output types/shapes
    # and metadata) — not visible in this file.
    variant_tensor = gen_dataset_ops.generator_dataset(
        structure.to_tensor_list(self._init_structure, self._init_args) +
        self._init_func.function.captured_inputs,
        self._next_func.function.captured_inputs,
        self._finalize_func.function.captured_inputs,
        init_func=self._init_func.function,
        next_func=self._next_func.function,
        finalize_func=self._finalize_func.function,
        **self._common_args)
    super().__init__(variant_tensor)

  @property
  def element_spec(self):
    # The element structure is exactly the signature of `next_func`'s output.
    return self._output_signature

  def _transformation_name(self):
    return "Dataset.from_generator()"