Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/from_generator_op.py: 18%

107 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""The implementation of `tf.data.Dataset.from_generator`.""" 

16 

17import numpy as np 

18 

19from tensorflow.python.data.ops import dataset_ops 

20from tensorflow.python.data.ops import structured_function 

21from tensorflow.python.data.util import nest 

22from tensorflow.python.data.util import structure 

23from tensorflow.python.framework import dtypes 

24from tensorflow.python.framework import ops 

25from tensorflow.python.framework import tensor_shape 

26from tensorflow.python.framework import tensor_spec 

27from tensorflow.python.framework import type_spec 

28from tensorflow.python.ops import gen_dataset_ops 

29from tensorflow.python.ops import script_ops 

30 

31 

def _from_generator(generator, output_types, output_shapes, args,
                    output_signature, name):
  """Builds a `Dataset` whose elements are produced by calling `generator`.

  Note: the implementation routes every element through `tf.numpy_function` /
  `tf.py_function`, so it inherits their constraints: dataset and iterator
  ops must be placed in the same process as the Python program, the body of
  `generator` is never serialized into a `GraphDef` (so the result cannot be
  restored in a different environment), and tf.data service cannot be used to
  scale out the pipeline.

  Elements yielded by `generator` must conform either to `output_signature`
  (the recommended, `tf.TypeSpec`-based convention) or to the deprecated
  `output_types` / optional `output_shapes` pair, whichever was supplied.
  The two conventions are mutually exclusive.

  Note: if `generator` touches mutable global or external state, be aware the
  runtime may call it multiple times (to support repeating the dataset) and
  at any point between `from_generator()` and production of the first
  element; cache external state inside `generator` up front. Also avoid
  running tf.data operations inside `generator` itself — that is an
  anti-pattern and can cause incremental memory growth.

  Args:
    generator: A callable returning an object that supports the `iter()`
      protocol. Takes no arguments unless `args` is given, in which case it
      must accept one argument per value in `args`.
    output_types: (Optional.) A (nested) structure of `tf.DType` objects, one
      per component of a yielded element.
    output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
      objects, one per component of a yielded element.
    args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
      and passed to `generator` as NumPy-array arguments.
    output_signature: (Optional.) A (nested) structure of `tf.TypeSpec`
      objects, one per component of a yielded element.
    name: (Optional.) A name for the tf.data operations used by
      `from_generator`.

  Returns:
    Dataset: A `Dataset`.

  Raises:
    TypeError: If `generator` is not callable, or the signature arguments are
      combined inconsistently, or `output_signature` contains a non-TypeSpec.
  """
  if not callable(generator):
    raise TypeError("`generator` must be a Python callable.")

  if output_signature is not None:
    # The TypeSpec-based convention excludes the legacy types/shapes one.
    if output_types is not None:
      raise TypeError("The `output_types` argument can not be used together "
                      "with the `output_signature` argument.")
    if output_shapes is not None:
      raise TypeError("The `output_shapes` argument can not be used together "
                      "with the `output_signature` argument.")
    for spec in nest.flatten(output_signature):
      if not isinstance(spec, type_spec.TypeSpec):
        raise TypeError(f"`output_signature` must contain objects that are "
                        f"subclass of `tf.TypeSpec` but found {type(spec)} "
                        f"which is not.")
  elif output_types is None:
    raise TypeError("To specify the output signature you need to provide "
                    "either the `output_signature` argument or the "
                    "`output_types` argument.")

  if output_signature is None:
    # Legacy path: synthesize a signature of `TensorSpec`s from the given
    # types and the (possibly defaulted-to-unknown) shapes.
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(output_types,
                                               tensor_shape.as_shape,
                                               output_shapes)
    output_signature = nest.map_structure_up_to(output_types,
                                                tensor_spec.TensorSpec,
                                                output_shapes, output_types)
  flat_signature = nest.flatten(output_signature)
  if all(isinstance(s, tensor_spec.TensorSpec) for s in flat_signature):
    # All-dense signature: recover flat types/shapes so the faster
    # `numpy_function`-based branch of `generator_next_fn` can be taken.
    output_types = nest.pack_sequence_as(output_signature,
                                         [s.dtype for s in flat_signature])
    output_shapes = nest.pack_sequence_as(output_signature,
                                          [s.shape for s in flat_signature])

  if args is None:
    args = ()
  else:
    args = tuple(ops.convert_n_to_tensor(args, name="args"))

  generator_state = dataset_ops.DatasetV2._GeneratorState(generator)  # pylint: disable=protected-access

  def get_iterator_id_fn(unused_dummy):
    """Mints a unique `iterator_id` for one traversal of the dataset.

    The ID disambiguates between multiple concurrently live iterators.

    Args:
      unused_dummy: Ignored value.

    Returns:
      A `tf.int64` tensor uniquely identifying an iterator inside
      `generator_state`.
    """
    return script_ops.numpy_function(generator_state.get_next_id, args,
                                     dtypes.int64)

  def generator_next_fn(iterator_id_t):
    """Produces the next element from the iterator with ID `iterator_id_t`.

    This function is mapped across an infinite repetition of `iterator_id_t`;
    the underlying Python iterator raising `StopIteration` terminates the
    iteration.

    Args:
      iterator_id_t: A `tf.int64` tensor identifying which iterator in
        `generator_state` to pull the element from.

    Returns:
      The next element generated by that iterator.
    """
    if output_types and output_shapes:
      flattened_types = [
          dtypes.as_dtype(t) for t in nest.flatten(output_types)
      ]
      flattened_shapes = nest.flatten(output_shapes)

      def generator_py_func(iterator_id):
        """A `py_func` that will be called to invoke the iterator."""
        # `next()` signals exhaustion by raising `StopIteration`.
        values = next(generator_state.get_iterator(iterator_id))

        # Convert the yielded components to arrays early (reusing the
        # py_func() implementation's `_convert`) so they can be validated.
        try:
          flattened_values = nest.flatten_up_to(output_types, values)
        except (TypeError, ValueError) as e:
          raise TypeError(
              f"`generator` yielded an element that did not match the "
              f"expected structure. The expected structure was "
              f"{output_types}, but the yielded element was {values}.") from e
        ret_arrays = []
        for ret, dtype in zip(flattened_values, flattened_types):
          try:
            ret_arrays.append(
                script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
                    ret,
                    dtype=dtype.as_numpy_dtype))
          except (TypeError, ValueError) as e:
            raise TypeError(
                f"`generator` yielded an element that could not be "
                f"converted to the expected type. The expected type was "
                f"{dtype.name}, but the yielded element was {ret}.") from e

        # Check every converted component against the declared dtype and
        # shape from `output_types` / `output_shapes`.
        for (ret_array, expected_dtype,
             expected_shape) in zip(ret_arrays, flattened_types,
                                    flattened_shapes):
          if ret_array.dtype != expected_dtype.as_numpy_dtype:
            raise TypeError(
                f"`generator` yielded an element of type {ret_array.dtype} "
                f"where an element of type {expected_dtype.as_numpy_dtype} "
                f"was expected.")
          if not expected_shape.is_compatible_with(ret_array.shape):
            raise TypeError(
                f"`generator` yielded an element of shape {ret_array.shape} "
                f"where an element of shape {expected_shape} was expected.")

        return ret_arrays

      flat_values = script_ops.numpy_function(generator_py_func,
                                              [iterator_id_t],
                                              flattened_types)

      # In debug mode `numpy_function` returns a bare scalar rather than a
      # list when `generator_py_func` produces a single value.
      if not isinstance(flat_values, (list, tuple)):
        flat_values = [flat_values]

      # `py_func()` drops the inferred static shapes; reattach them.
      if output_shapes is not None:
        for ret_t, shape in zip(flat_values, flattened_shapes):
          ret_t.set_shape(shape)

      return nest.pack_sequence_as(output_types, flat_values)

    flat_output_types = structure.get_flat_tensor_types(output_signature)

    def generator_py_func(iterator_id):
      """A `py_func` that will be called to invoke the iterator."""
      # `next()` signals exhaustion by raising `StopIteration`.
      values = next(generator_state.get_iterator(iterator_id.numpy()))

      try:
        values = structure.normalize_element(values, output_signature)
      except (TypeError, ValueError) as e:
        raise TypeError(
            f"`generator` yielded an element that did not match the "
            f"expected structure. The expected structure was "
            f"{output_signature}, but the yielded element was "
            f"{values}.") from e

      values_spec = structure.type_spec_from_value(values)

      if not structure.are_compatible(values_spec, output_signature):
        raise TypeError(
            f"`generator` yielded an element of {values_spec} where an "
            f"element of {output_signature} was expected.")

      return structure.to_tensor_list(output_signature, values)

    return script_ops.eager_py_func(
        generator_py_func, inp=[iterator_id_t], Tout=flat_output_types)

  def finalize_fn(iterator_id_t):
    """Releases host-side state for the iterator with ID `iterator_id_t`."""

    def finalize_py_func(iterator_id):
      generator_state.iterator_completed(iterator_id)
      # A dummy return keeps `finalize_fn`'s signature valid.
      # NOTE(mrry): Build the `np.int64` explicitly — implicit casting in
      # `py_func()` produces `np.int32` on Windows, causing a runtime error.
      return np.array(0, dtype=np.int64)

    return script_ops.numpy_function(finalize_py_func, [iterator_id_t],
                                     dtypes.int64)

  # Associates each traversal of `generator` with a unique iterator ID.
  def flat_map_fn(dummy_arg):
    # `get_iterator_id_fn` mints an ID for this instance of the generator;
    # `generator_next_fn` pulls elements for that ID and raises StopIteration
    # once the iterator is exhausted.
    return _GeneratorDataset(
        dummy_arg,
        get_iterator_id_fn,
        generator_next_fn,
        finalize_fn,
        output_signature,
        name=name)

  # A single-element dataset that, each time it is evaluated, carries a
  # freshly generated int64 ID used to look up the Python-side state that is
  # encapsulated in `generator_state` and captured by `get_iterator_id_fn`.
  id_dataset = dataset_ops.Dataset.from_tensors(0, name=name)

  # Lifting the per-iterator work into a flat_map forces a new ID for every
  # traversal, which is what makes repetitions and nested versions of the
  # returned dataset well-defined.
  return id_dataset.flat_map(flat_map_fn, name=name)

332 

333 

class _GeneratorDataset(dataset_ops.DatasetSource):
  """A `Dataset` that produces elements by invoking user-supplied functions."""

  def __init__(self,
               init_args,
               init_func,
               next_func,
               finalize_func,
               output_signature,
               name=None):
    """Constructs a `_GeneratorDataset`.

    Args:
      init_args: A (nested) structure representing the arguments to
        `init_func`.
      init_func: A TensorFlow function applied to `init_args` each time a C++
        iterator over this dataset is constructed. Its result is a (nested)
        structure representing the "state" of the dataset.
      next_func: A TensorFlow function applied to `init_func`'s result to
        produce each element; raises `OutOfRangeError` to end iteration.
      finalize_func: A TensorFlow function applied to `init_func`'s result
        immediately before the C++ iterator is destroyed. Its return value is
        ignored.
      output_signature: A (nested) structure of `tf.TypeSpec` objects
        describing the output of `next_func`.
      name: Optional. A name for the tf.data transformation.
    """
    self._init_args = init_args
    self._init_structure = structure.type_spec_from_value(init_args)

    def _wrap(fn, input_structure):
      # Traces `fn` as a tf.data structured function over `input_structure`.
      return structured_function.StructuredFunctionWrapper(
          fn, self._transformation_name(), input_structure=input_structure)

    self._init_func = _wrap(init_func, self._init_structure)
    # Both `next_func` and `finalize_func` consume the state produced by
    # `init_func`.
    self._next_func = _wrap(next_func, self._init_func.output_structure)
    self._finalize_func = _wrap(finalize_func,
                                self._init_func.output_structure)

    self._output_signature = output_signature
    self._name = name

    variant_tensor = gen_dataset_ops.generator_dataset(
        structure.to_tensor_list(self._init_structure, self._init_args) +
        self._init_func.function.captured_inputs,
        self._next_func.function.captured_inputs,
        self._finalize_func.function.captured_inputs,
        init_func=self._init_func.function,
        next_func=self._next_func.function,
        finalize_func=self._finalize_func.function,
        **self._common_args)
    super().__init__(variant_tensor)

  @property
  def element_spec(self):
    # The element structure is exactly the signature supplied at
    # construction time.
    return self._output_signature

  def _transformation_name(self):
    return "Dataset.from_generator()"