Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/prefetching_ops.py: 31%

98 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import structure
from tensorflow.python.eager import def_function
from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.prefetch_to_device")
def prefetch_to_device(device, buffer_size=None):
  """A transformation that prefetches dataset values to the given `device`.

  NOTE: Although the transformation creates a `tf.data.Dataset`, the
  transformation must be the final `Dataset` in the input pipeline.

  For example,
  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/cpu:0"))
  >>> for element in dataset:
  ...   print(f'Tensor {element} is on device {element.device}')
  Tensor 1 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 2 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 3 is on device /job:localhost/replica:0/task:0/device:CPU:0

  Args:
    device: A string. The name of a device to which elements will be prefetched.
    buffer_size: (Optional.) The number of elements to buffer on `device`.
      Defaults to an automatically chosen value.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  def _apply_fn(dataset):
    return dataset.apply(
        copy_to_device(target_device=device)).prefetch(buffer_size)

  return _apply_fn
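
# A hedged usage sketch (not part of the original module): "/gpu:0" is an
# assumed device name, and `buffer_size=2` bounds how many elements are staged
# on that device ahead of the consumer. Per the NOTE above, the transformation
# is applied last in the pipeline.
#
#   dataset = tf.data.Dataset.range(100).batch(16)
#   dataset = dataset.apply(
#       tf.data.experimental.prefetch_to_device("/gpu:0", buffer_size=2))
#   for batch in dataset:
#     ...  # each `batch` is already resident on the GPU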



@tf_export("data.experimental.copy_to_device")
def copy_to_device(target_device, source_device="/cpu:0"):
  """A transformation that copies dataset elements to the given `target_device`.

  Args:
    target_device: The name of a device to which elements will be copied.
    source_device: The original device on which `input_dataset` will be placed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _CopyToDeviceDataset(
        dataset, target_device=target_device, source_device=source_device)

  return _apply_fn
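
# A hedged usage sketch (not part of the original module): copy host-produced
# elements to an assumed GPU device and buffer one element there; this is the
# pattern that `prefetch_to_device` above wraps.
#
#   dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
#   dataset = dataset.apply(tf.data.experimental.copy_to_device("/gpu:0"))
#   with tf.device("/gpu:0"):
#     dataset = dataset.prefetch(1)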



# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
# all inputs to the Op are in host memory, thereby avoiding some unnecessary
# Sends and Recvs.
class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that copies elements to another device."""

  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    self._input_dataset = input_dataset._apply_debug_options()  # pylint: disable=protected-access
    self._target_device = target_device
    spec = framework_device.DeviceSpec().from_string(self._target_device)
    self._is_gpu_target = (spec.device_type == "GPU")
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

    @def_function.function()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      resource = gen_dataset_ops.anonymous_iterator(
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func.get_concrete_function()  # pylint: disable=protected-access

    @def_function.function()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func.get_concrete_function()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
      return structure.to_tensor_list(self.element_spec, iterator.get_next())

    next_func_concrete = _next_func.get_concrete_function()  # pylint: disable=protected-access

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        experimental_attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._input_dataset._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func.get_concrete_function()
    self._next_captured_args = self._next_func.captured_inputs

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func.get_concrete_function()  # pylint: disable=protected-access

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func.get_concrete_function(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

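    # Descriptive note (added): the three concrete functions above execute on
    # the source device via `remote_call`; below they are registered in the
    # current graph and wired into a GeneratorDataset whose variant tensor is
    # placed on the target device, so elements are pulled from the source
    # iterator and materialized on `target_device`.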

    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)
    # pylint: enable=protected-access

    with ops.device(self._target_device):
      variant_tensor = gen_dataset_ops.generator_dataset(
          self._init_captured_args,
          self._next_captured_args,
          self._finalize_captured_args,
          init_func=self._init_func,
          next_func=self._next_func,
          finalize_func=self._finalize_func,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
    super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)

  # The one_shot_iterator implementation needs a 0-arg _make_dataset function
  # that thereby captures all the inputs required to create the dataset. Since
  # there are strings that are inputs to the GeneratorDataset which can't be
  # placed on a GPU, this fails for the GPU case; it is therefore disabled for
  # GPU targets.
  def make_one_shot_iterator(self):
    if self._is_gpu_target:
      raise ValueError(
          "`make_one_shot_iterator` is not compatible with GPU execution. "
          "Please use `Dataset.make_initializable_iterator()` instead."
      )
    else:
      return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
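
  # A hedged illustration (not part of the original class): with a GPU target,
  # a TF1-style pipeline would fall back to an initializable iterator, as the
  # error above suggests, e.g.:
  #
  #   dataset = tf.compat.v1.data.Dataset.range(10)
  #   dataset = dataset.apply(tf.data.experimental.copy_to_device("/gpu:0"))
  #   iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
  #   next_element = iterator.get_next()
  #   with tf.compat.v1.Session() as sess:
  #     sess.run(iterator.initializer)
  #     print(sess.run(next_element))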



class _MapOnGpuDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over elements in its input using a GPU."""

  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism

    self._map_func = structured_function.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        defun_kwargs={"experimental_ints_on_device": True})
    variant_tensor = ged_ops.experimental_map_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        **self._flat_structure)
    super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def element_spec(self):
    return self._map_func.output_structure

  def _transformation_name(self):
    return "map_on_gpu()"



def map_on_gpu(map_func):
  """Maps `map_func` across the elements of this dataset.

  NOTE: This is a highly experimental version of `tf.data.Dataset.map` that runs
  `map_func` on GPU. It must be used after applying the
  `tf.data.experimental.copy_to_device` transformation with a GPU device
  argument.

  Args:
    map_func: A function mapping a nested structure of tensors (having shapes
      and types defined by `self.output_shapes` and `self.output_types`) to
      another nested structure of tensors.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _MapOnGpuDataset(dataset, map_func)

  return _apply_fn
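
# A hedged usage sketch (not part of the original module): `map_on_gpu` is not
# exported via `tf_export`, so the import below assumes direct use of this
# module, and "/gpu:0" is an assumed device name. Note the required ordering
# after `copy_to_device` with a GPU target.
#
#   from tensorflow.python.data.experimental.ops import prefetching_ops
#
#   dataset = tf.data.Dataset.range(10)
#   dataset = dataset.apply(tf.data.experimental.copy_to_device("/gpu:0"))
#   dataset = dataset.apply(prefetching_ops.map_on_gpu(lambda x: x * 2))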