Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/prefetching_ops.py: 31%
98 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrapper for prefetching_ops."""
16from tensorflow.python.data.ops import dataset_ops
17from tensorflow.python.data.ops import iterator_ops
18from tensorflow.python.data.ops import structured_function
19from tensorflow.python.data.util import structure
20from tensorflow.python.eager import def_function
21from tensorflow.python.framework import device as framework_device
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_spec
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import functional_ops
27from tensorflow.python.ops import gen_dataset_ops
28from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
29from tensorflow.python.ops import resource_variable_ops
30from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.prefetch_to_device")
def prefetch_to_device(device, buffer_size=None):
  """A transformation that prefetches dataset values to the given `device`.

  NOTE: Although the transformation creates a `tf.data.Dataset`, the
  transformation must be the final `Dataset` in the input pipeline.

  For example,
  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/cpu:0"))
  >>> for element in dataset:
  ...   print(f'Tensor {element} is on device {element.device}')
  Tensor 1 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 2 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 3 is on device /job:localhost/replica:0/task:0/device:CPU:0

  Args:
    device: A string. The name of a device to which elements will be prefetched.
    buffer_size: (Optional.) The number of elements to buffer on `device`.
      Defaults to an automatically chosen value.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # Copy elements onto the requested device, then prefetch there so the
    # buffer lives on `device` rather than on the host.
    copied = dataset.apply(copy_to_device(target_device=device))
    return copied.prefetch(buffer_size)

  return _apply_fn
@tf_export("data.experimental.copy_to_device")
def copy_to_device(target_device, source_device="/cpu:0"):
  """A transformation that copies dataset elements to the given `target_device`.

  Args:
    target_device: The name of a device to which elements will be copied.
    source_device: The original device on which `input_dataset` will be placed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # Construction is deferred until the transformation is applied to a
    # concrete dataset.
    return _CopyToDeviceDataset(
        dataset,
        target_device=target_device,
        source_device=source_device)

  return _apply_fn
# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
# all inputs to the Op are in host memory, thereby avoiding some unnecessary
# Sends and Recvs.
class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that copies elements to another device.

  Implemented as a `GeneratorDataset` placed on the target device whose
  init/next/finalize functions each issue a `remote_call` back to the source
  device, where the actual iterator over `input_dataset` lives.
  """

  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    self._input_dataset = input_dataset._apply_debug_options()  # pylint: disable=protected-access
    self._target_device = target_device
    spec = framework_device.DeviceSpec().from_string(self._target_device)
    # Recorded so make_one_shot_iterator (below) can refuse GPU targets.
    self._is_gpu_target = (spec.device_type == "GPU")
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    # Wrap the input dataset's variant tensor so that it can be captured by
    # the functions below and unwrapped on the source device.
    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

    @def_function.function()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      resource = gen_dataset_ops.anonymous_iterator(
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      # Return the string handle only after the iterator has been bound to
      # the dataset.
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func.get_concrete_function()  # pylint: disable=protected-access

    @def_function.function()
    def _remote_init_func():
      # Runs `_init_func` on the source device; returns the iterator's string
      # handle produced there.
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func.get_concrete_function()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
        # Flatten the element to a tensor list, the calling convention of
        # remote_call / GeneratorDataset.
        return structure.to_tensor_list(self.element_spec, iterator.get_next())

    next_func_concrete = _next_func.get_concrete_function()  # pylint: disable=protected-access

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        experimental_attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      # Runs `_next_func` on the source device, passing the handle plus the
      # concrete function's captured inputs.
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._input_dataset._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func.get_concrete_function()
    self._next_captured_args = self._next_func.captured_inputs

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      # ignore_lookup_error tolerates the resource already being gone.
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func.get_concrete_function()  # pylint: disable=protected-access

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      # Runs `_finalize_func` on the source device to release the iterator.
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func.get_concrete_function(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Register the three concrete functions in the current graph so the
    # GeneratorDataset op can reference them by name.
    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)
    # pylint: enable=protected-access

    # The GeneratorDataset itself is placed on the target device; its
    # functions remote_call back to the source device.
    with ops.device(self._target_device):
      variant_tensor = gen_dataset_ops.generator_dataset(
          self._init_captured_args,
          self._next_captured_args,
          self._finalize_captured_args,
          init_func=self._init_func,
          next_func=self._next_func,
          finalize_func=self._finalize_func,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
    super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)

  # The one_shot_iterator implementation needs a 0 arg _make_dataset function
  # that thereby captures all the inputs required to create the dataset. Since
  # there are strings that are inputs to the GeneratorDataset which can't be
  # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
  # GPU
  def make_one_shot_iterator(self):
    if self._is_gpu_target:
      raise ValueError(
          "`make_one_shot_iterator` is not compatible with GPU execution. "
          "Please use `Dataset.make_initializable_iterator()` instead."
      )
    else:
      return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
class _MapOnGpuDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that applies a function to its elements on a GPU."""

  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
    """See `Dataset.map()` for details."""
    self._use_inter_op_parallelism = use_inter_op_parallelism
    self._input_dataset = input_dataset

    # Wrap the user function, keeping integer tensors on-device so the
    # function can execute on the GPU.
    wrapped = structured_function.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        defun_kwargs={"experimental_ints_on_device": True})
    self._map_func = wrapped

    variant_tensor = ged_ops.experimental_map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        wrapped.function.captured_inputs,
        f=wrapped.function,
        use_inter_op_parallelism=use_inter_op_parallelism,
        **self._flat_structure)
    super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    # The single transformation function attached to this dataset.
    return [self._map_func]

  @property
  def element_spec(self):
    # Element structure is determined by the wrapped function's outputs.
    return self._map_func.output_structure

  def _transformation_name(self):
    return "map_on_gpu()"
def map_on_gpu(map_func):
  """Maps `map_func` across the elements of this dataset.

  NOTE: This is a highly experimental version of `tf.data.Dataset.map` that runs
  `map_func` on GPU. It must be used after applying the
  `tf.data.experimental.copy_to_device` transformation with a GPU device
  argument.

  Args:
    map_func: A function mapping a nested structure of tensors (having shapes
      and types defined by `self.output_shapes` and `self.output_types`) to
      another nested structure of tensors.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # Dataset construction is deferred to apply() time.
    return _MapOnGpuDataset(input_dataset=dataset, map_func=map_func)

  return _apply_fn