Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/readers.py: 41%

205 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.

2 #

3 # Licensed under the Apache License, Version 2.0 (the "License");

4 # you may not use this file except in compliance with the License.

5 # You may obtain a copy of the License at

6 #

7 #     http://www.apache.org/licenses/LICENSE-2.0

8 #

9 # Unless required by applicable law or agreed to in writing, software

10 # distributed under the License is distributed on an "AS IS" BASIS,

11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

12 # See the License for the specific language governing permissions and

13 # limitations under the License.

14 # ==============================================================================

15 """Python wrappers for reader Datasets."""

16 import os

17

18 from tensorflow.python import tf2

19 from tensorflow.python.data.ops import dataset_ops

20 from tensorflow.python.data.ops import from_tensor_slices_op

21 from tensorflow.python.data.ops import structured_function

22 from tensorflow.python.data.util import convert

23 from tensorflow.python.framework import dtypes

24 from tensorflow.python.framework import ops

25 from tensorflow.python.framework import tensor_shape

26 from tensorflow.python.framework import tensor_spec

27 from tensorflow.python.framework import type_spec

28 from tensorflow.python.ops import array_ops

29 from tensorflow.python.ops import gen_dataset_ops

30 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops

31 from tensorflow.python.types import data as data_types

32 from tensorflow.python.util import nest

33 from tensorflow.python.util.tf_export import tf_export

34

35 _DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024  # 256 KB

36

37 

38 def _normalise_fspath(path):

39 """Convert pathlib-like objects to str (__fspath__ compatibility, PEP 519).""" 

40 return os.fspath(path) if isinstance(path, os.PathLike) else path 
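
# Editor's note: a minimal illustrative sketch (not part of the TensorFlow
# source) of what `_normalise_fspath` accepts; the helper name below is
# hypothetical and assumes POSIX-style paths.
def _example_normalise_fspath():
  import pathlib
  assert _normalise_fspath(pathlib.Path("/tmp/data.txt")) == "/tmp/data.txt"
  assert _normalise_fspath("/tmp/data.txt") == "/tmp/data.txt"  # strings pass through unchanged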

41 

42 

43 def _create_or_validate_filenames_dataset(filenames, name=None):

44 """Creates (or validates) a dataset of filenames. 

45 

46 Args: 

47 filenames: Either a list or dataset of filenames. If it is a list, it is

48 converted to a dataset. If it is a dataset, its type and shape are validated.

49 name: (Optional.) A name for the tf.data operation. 

50 

51 Returns: 

52 A dataset of filenames. 

53 """ 

54 if isinstance(filenames, data_types.DatasetV2): 

55 element_type = dataset_ops.get_legacy_output_types(filenames) 

56 if element_type != dtypes.string: 

57 raise TypeError( 

58 "The `filenames` argument must contain `tf.string` elements. Got a " 

59 f"dataset of `{element_type!r}` elements.") 

60 element_shape = dataset_ops.get_legacy_output_shapes(filenames) 

61 if not element_shape.is_compatible_with(tensor_shape.TensorShape([])): 

62 raise TypeError( 

63 "The `filenames` argument must contain `tf.string` elements of shape " 

64 "[] (i.e. scalars). Got a dataset of element shape " 

65 f"{element_shape!r}.") 

66 else: 

67 filenames = nest.map_structure(_normalise_fspath, filenames) 

68 filenames = ops.convert_to_tensor(filenames, dtype_hint=dtypes.string) 

69 if filenames.dtype != dtypes.string: 

70 raise TypeError( 

71 "The `filenames` argument must contain `tf.string` elements. Got " 

72 f"`{filenames.dtype!r}` elements.") 

73 filenames = array_ops.reshape(filenames, [-1], name="flat_filenames") 

74 filenames = from_tensor_slices_op._TensorSliceDataset( # pylint: disable=protected-access 

75 filenames, 

76 is_files=True, 

77 name=name) 

78 return filenames 
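
# Editor's note: a hedged sketch (not part of the TensorFlow source) of the two
# input forms handled above; the file paths and the helper name are hypothetical.
def _example_filenames_dataset():
  # A Python list (or pathlib paths) is flattened into a _TensorSliceDataset.
  from_list = _create_or_validate_filenames_dataset(["/tmp/a.txt", "/tmp/b.txt"])
  # An existing dataset of scalar `tf.string` elements is validated and returned as-is.
  from_dataset = _create_or_validate_filenames_dataset(from_list)
  return from_list, from_dataset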

79 

80 

81 def _create_dataset_reader(dataset_creator,

82 filenames, 

83 num_parallel_reads=None, 

84 name=None): 

85 """Creates a dataset that reads the given files using the given reader. 

86 

87 Args: 

88 dataset_creator: A function that takes in a single file name and returns a 

89 dataset. 

90 filenames: A `tf.data.Dataset` containing one or more filenames. 

91 num_parallel_reads: The number of parallel reads we should do. 

92 name: (Optional.) A name for the tf.data operation. 

93 

94 Returns: 

95 A `Dataset` that reads data from `filenames`. 

96 """ 

97 

98 def read_one_file(filename): 

99 filename = ops.convert_to_tensor(filename, dtypes.string, name="filename") 

100 return dataset_creator(filename) 

101 

102 if num_parallel_reads is None: 

103 return filenames.flat_map(read_one_file, name=name) 

104 elif num_parallel_reads == dataset_ops.AUTOTUNE: 

105 return filenames.interleave( 

106 read_one_file, num_parallel_calls=num_parallel_reads, name=name) 

107 else: 

108 return ParallelInterleaveDataset( 

109 filenames, 

110 read_one_file, 

111 cycle_length=num_parallel_reads, 

112 block_length=1, 

113 sloppy=False, 

114 buffer_output_elements=None, 

115 prefetch_input_elements=None, 

116 name=name) 
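
# Editor's note: a minimal sketch (not part of the TensorFlow source) of the
# three dispatch branches above; the wrapper name and file paths are hypothetical.
def _example_dataset_reader(num_parallel_reads=None):
  filenames = _create_or_validate_filenames_dataset(["/tmp/a.txt", "/tmp/b.txt"])
  # num_parallel_reads=None                 -> sequential flat_map over the files.
  # num_parallel_reads=dataset_ops.AUTOTUNE -> Dataset.interleave with autotuned parallelism.
  # num_parallel_reads=<positive int>       -> deterministic ParallelInterleaveDataset.
  return _create_dataset_reader(
      lambda filename: _TextLineDataset(filename),
      filenames,
      num_parallel_reads=num_parallel_reads)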

117 

118 

119 def _get_type(value):

120 """Returns the type of `value` if it is a TypeSpec.""" 

121 

122 if isinstance(value, type_spec.TypeSpec): 

123 return value.value_type  # `value_type` is a property, not a method.

124 else: 

125 return type(value) 

126 

127 

128 class _TextLineDataset(dataset_ops.DatasetSource):

129 """A `Dataset` comprising records from one or more text files.""" 

130 

131 def __init__(self, 

132 filenames, 

133 compression_type=None, 

134 buffer_size=None, 

135 name=None): 

136 """Creates a `TextLineDataset`. 

137 

138 Args: 

139 filenames: A `tf.string` tensor containing one or more filenames. 

140 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 

141 `""` (no compression), `"ZLIB"`, or `"GZIP"`. 

142 buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes 

143 to buffer. A value of 0 results in the default buffering values chosen 

144 based on the compression type. 

145 name: (Optional.) A name for the tf.data operation. 

146 """ 

147 self._filenames = filenames 

148 self._compression_type = convert.optional_param_to_tensor( 

149 "compression_type", 

150 compression_type, 

151 argument_default="", 

152 argument_dtype=dtypes.string) 

153 self._buffer_size = convert.optional_param_to_tensor( 

154 "buffer_size", 

155 buffer_size, 

156 argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES) 

157 self._name = name 

158 

159 variant_tensor = gen_dataset_ops.text_line_dataset( 

160 self._filenames, 

161 self._compression_type, 

162 self._buffer_size, 

163 metadata=self._metadata.SerializeToString()) 

164 super(_TextLineDataset, self).__init__(variant_tensor) 

165 

166 @property 

167 def element_spec(self): 

168 return tensor_spec.TensorSpec([], dtypes.string) 

169 

170 

171 @tf_export("data.TextLineDataset", v1=[])

172 class TextLineDatasetV2(dataset_ops.DatasetSource):

173 r"""Creates a `Dataset` comprising lines from one or more text files. 

174 

175 The `tf.data.TextLineDataset` loads text from text files and creates a dataset 

176 where each line of the files becomes an element of the dataset. 

177 

178 For example, suppose we have 2 files "text_lines0.txt" and "text_lines1.txt" 

179 with the following lines: 

180 

181 >>> with open('/tmp/text_lines0.txt', 'w') as f: 

182 ... f.write('the cow\n') 

183 ... f.write('jumped over\n') 

184 ... f.write('the moon\n') 

185 >>> with open('/tmp/text_lines1.txt', 'w') as f: 

186 ... f.write('jack and jill\n') 

187 ... f.write('went up\n') 

188 ... f.write('the hill\n') 

189 

190 We can construct a TextLineDataset from them as follows: 

191 

192 >>> dataset = tf.data.TextLineDataset(['/tmp/text_lines0.txt', 

193 ... '/tmp/text_lines1.txt']) 

194 

195 The elements of the dataset are expected to be: 

196 

197 >>> for element in dataset.as_numpy_iterator(): 

198 ... print(element) 

199 b'the cow' 

200 b'jumped over' 

201 b'the moon' 

202 b'jack and jill' 

203 b'went up' 

204 b'the hill' 

205 """ 

206 

207 def __init__(self, 

208 filenames, 

209 compression_type=None, 

210 buffer_size=None, 

211 num_parallel_reads=None, 

212 name=None): 

213 r"""Creates a `TextLineDataset`. 

214 

215 The elements of the dataset will be the lines of the input files, using 

216 the newline character '\n' to denote line splits. The newline characters 

217 will be stripped off of each element. 

218 

219 Args: 

220 filenames: A `tf.data.Dataset` whose elements are `tf.string` scalars, a 

221 `tf.string` tensor, or a value that can be converted to a `tf.string` 

222 tensor (such as a list of Python strings). 

223 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 

224 `""` (no compression), `"ZLIB"`, or `"GZIP"`. 

225 buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes 

226 to buffer. A value of 0 results in the default buffering values chosen 

227 based on the compression type. 

228 num_parallel_reads: (Optional.) A `tf.int64` scalar representing the 

229 number of files to read in parallel. If greater than one, the records of 

230 files read in parallel are outputted in an interleaved order. If your 

231 input pipeline is I/O bottlenecked, consider setting this parameter to a 

232 value greater than one to parallelize the I/O. If `None`, files will be 

233 read sequentially. 

234 name: (Optional.) A name for the tf.data operation. 

235 """ 

236 filenames = _create_or_validate_filenames_dataset(filenames, name=name) 

237 self._filenames = filenames 

238 self._compression_type = compression_type 

239 self._buffer_size = buffer_size 

240 

241 def creator_fn(filename): 

242 return _TextLineDataset( 

243 filename, compression_type, buffer_size, name=name) 

244 

245 self._impl = _create_dataset_reader( 

246 creator_fn, filenames, num_parallel_reads, name=name) 

247 variant_tensor = self._impl._variant_tensor # pylint: disable=protected-access 

248 

249 super(TextLineDatasetV2, self).__init__(variant_tensor) 

250 

251 @property 

252 def element_spec(self): 

253 return tensor_spec.TensorSpec([], dtypes.string) 
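
# Editor's note: a hedged usage sketch (not part of the TensorFlow source) of the
# class above, which the public API exposes as `tf.data.TextLineDataset`. The
# gzip-compressed log files named here are assumed to exist.
def _example_text_line_dataset():
  return TextLineDatasetV2(
      ["/tmp/logs_0.txt.gz", "/tmp/logs_1.txt.gz"],
      compression_type="GZIP",   # the files are gzip-compressed
      buffer_size=1024 * 1024,   # 1 MB read buffer per file
      num_parallel_reads=2)      # interleave lines from both files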

254 

255 

256 @tf_export(v1=["data.TextLineDataset"])

257 class TextLineDatasetV1(dataset_ops.DatasetV1Adapter):

258 """A `Dataset` comprising lines from one or more text files.""" 

259 

260 def __init__(self, 

261 filenames, 

262 compression_type=None, 

263 buffer_size=None, 

264 num_parallel_reads=None, 

265 name=None): 

266 wrapped = TextLineDatasetV2(filenames, compression_type, buffer_size, 

267 num_parallel_reads, name) 

268 super(TextLineDatasetV1, self).__init__(wrapped) 

269 

270 __init__.__doc__ = TextLineDatasetV2.__init__.__doc__ 

271 

272 @property 

273 def _filenames(self): 

274 return self._dataset._filenames # pylint: disable=protected-access 

275 

276 @_filenames.setter 

277 def _filenames(self, value): 

278 self._dataset._filenames = value # pylint: disable=protected-access 

279 

280 

281 class _TFRecordDataset(dataset_ops.DatasetSource):

282 """A `Dataset` comprising records from one or more TFRecord files.""" 

283 

284 def __init__(self, 

285 filenames, 

286 compression_type=None, 

287 buffer_size=None, 

288 name=None): 

289 """Creates a `TFRecordDataset`. 

290 

291 Args: 

292 filenames: A `tf.string` tensor containing one or more filenames. 

293 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 

294 `""` (no compression), `"ZLIB"`, or `"GZIP"`. 

295 buffer_size: (Optional.) A `tf.int64` scalar representing the number of 

296 bytes in the read buffer. 0 means no buffering. 

297 name: (Optional.) A name for the tf.data operation. 

298 """ 

299 self._filenames = filenames 

300 self._compression_type = convert.optional_param_to_tensor( 

301 "compression_type", 

302 compression_type, 

303 argument_default="", 

304 argument_dtype=dtypes.string) 

305 self._buffer_size = convert.optional_param_to_tensor( 

306 "buffer_size", 

307 buffer_size, 

308 argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES) 

309 self._name = name 

310 

311 variant_tensor = gen_dataset_ops.tf_record_dataset( 

312 self._filenames, self._compression_type, self._buffer_size, 

313 metadata=self._metadata.SerializeToString()) 

314 super(_TFRecordDataset, self).__init__(variant_tensor) 

315 

316 @property 

317 def element_spec(self): 

318 return tensor_spec.TensorSpec([], dtypes.string) 

319 

320 

321 class ParallelInterleaveDataset(dataset_ops.UnaryDataset):

322 """A `Dataset` that maps a function over its input and flattens the result.""" 

323 

324 def __init__(self, 

325 input_dataset, 

326 map_func, 

327 cycle_length, 

328 block_length, 

329 sloppy, 

330 buffer_output_elements, 

331 prefetch_input_elements, 

332 name=None): 

333 """See `tf.data.experimental.parallel_interleave()` for details.""" 

334 self._input_dataset = input_dataset 

335 self._map_func = structured_function.StructuredFunctionWrapper( 

336 map_func, self._transformation_name(), dataset=input_dataset) 

337 if not isinstance(self._map_func.output_structure, dataset_ops.DatasetSpec): 

338 raise TypeError( 

339 "The `map_func` argument must return a `Dataset` object. Got " 

340 f"{_get_type(self._map_func.output_structure)!r}.") 

341 self._element_spec = self._map_func.output_structure._element_spec # pylint: disable=protected-access 

342 self._cycle_length = ops.convert_to_tensor( 

343 cycle_length, dtype=dtypes.int64, name="cycle_length") 

344 self._block_length = ops.convert_to_tensor( 

345 block_length, dtype=dtypes.int64, name="block_length") 

346 self._buffer_output_elements = convert.optional_param_to_tensor( 

347 "buffer_output_elements", 

348 buffer_output_elements, 

349 argument_default=2 * block_length) 

350 self._prefetch_input_elements = convert.optional_param_to_tensor( 

351 "prefetch_input_elements", 

352 prefetch_input_elements, 

353 argument_default=2 * cycle_length) 

354 if sloppy is None: 

355 self._deterministic = "default" 

356 elif sloppy: 

357 self._deterministic = "false" 

358 else: 

359 self._deterministic = "true" 

360 self._name = name 

361 

362 variant_tensor = ged_ops.legacy_parallel_interleave_dataset_v2( 

363 self._input_dataset._variant_tensor, # pylint: disable=protected-access 

364 self._map_func.function.captured_inputs, 

365 self._cycle_length, 

366 self._block_length, 

367 self._buffer_output_elements, 

368 self._prefetch_input_elements, 

369 f=self._map_func.function, 

370 deterministic=self._deterministic, 

371 **self._common_args) 

372 super(ParallelInterleaveDataset, self).__init__(input_dataset, 

373 variant_tensor) 

374 

375 def _functions(self): 

376 return [self._map_func] 

377 

378 @property 

379 def element_spec(self): 

380 return self._element_spec 

381 

382 def _transformation_name(self): 

383 return "tf.data.experimental.parallel_interleave()" 
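
# Editor's note: a minimal sketch (not part of the TensorFlow source) of
# constructing this dataset directly, mirroring what `_create_dataset_reader`
# does above; the file paths are hypothetical. Note how `sloppy` maps onto the
# `deterministic` attr: None -> "default", True -> "false", False -> "true".
def _example_parallel_interleave():
  filenames = _create_or_validate_filenames_dataset(["/tmp/a.txt", "/tmp/b.txt"])
  return ParallelInterleaveDataset(
      filenames,
      lambda filename: _TextLineDataset(filename),
      cycle_length=2,                # read two files concurrently
      block_length=1,                # emit one element per file per round
      sloppy=False,                  # deterministic="true"
      buffer_output_elements=None,   # defaults to 2 * block_length
      prefetch_input_elements=None)  # defaults to 2 * cycle_length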

384 

385 

386 @tf_export("data.TFRecordDataset", v1=[])

387 class TFRecordDatasetV2(dataset_ops.DatasetV2):

388 """A `Dataset` comprising records from one or more TFRecord files. 

389 

390 This dataset loads TFRecords from the files as bytes, exactly as they were 

391 written. `TFRecordDataset` does not do any parsing or decoding on its own.

392 Parsing and decoding can be done by applying `Dataset.map` transformations 

393 after the `TFRecordDataset`. 

394 

395 A minimal example is given below: 

396 

397 >>> import tempfile 

398 >>> example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords") 

399 >>> np.random.seed(0) 

400 

401 >>> # Write the records to a file. 

402 ... with tf.io.TFRecordWriter(example_path) as file_writer: 

403 ... for _ in range(4): 

404 ... x, y = np.random.random(), np.random.random() 

405 ... 

406 ... record_bytes = tf.train.Example(features=tf.train.Features(feature={ 

407 ... "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])), 

408 ... "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])), 

409 ... })).SerializeToString() 

410 ... file_writer.write(record_bytes) 

411 

412 >>> # Read the data back out. 

413 >>> def decode_fn(record_bytes): 

414 ... return tf.io.parse_single_example( 

415 ... # Data 

416 ... record_bytes, 

417 ... 

418 ... # Schema 

419 ... {"x": tf.io.FixedLenFeature([], dtype=tf.float32), 

420 ... "y": tf.io.FixedLenFeature([], dtype=tf.float32)} 

421 ... ) 

422 

423 >>> for batch in tf.data.TFRecordDataset([example_path]).map(decode_fn): 

424 ... print("x = {x:.4f}, y = {y:.4f}".format(**batch)) 

425 x = 0.5488, y = 0.7152 

426 x = 0.6028, y = 0.5449 

427 x = 0.4237, y = 0.6459 

428 x = 0.4376, y = 0.8918 

429 """ 

430 

431 def __init__(self, 

432 filenames, 

433 compression_type=None, 

434 buffer_size=None, 

435 num_parallel_reads=None, 

436 name=None): 

437 """Creates a `TFRecordDataset` to read one or more TFRecord files. 

438 

439 Each element of the dataset will contain a single TFRecord. 

440 

441 Args: 

442 filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or 

443 more filenames. 

444 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 

445 `""` (no compression), `"ZLIB"`, or `"GZIP"`. 

446 buffer_size: (Optional.) A `tf.int64` scalar representing the number of 

447 bytes in the read buffer. If your input pipeline is I/O bottlenecked, 

448 consider setting this parameter to a value of 1-100 MB. If `None`, a

449 sensible default for both local and remote file systems is used. 

450 num_parallel_reads: (Optional.) A `tf.int64` scalar representing the 

451 number of files to read in parallel. If greater than one, the records of 

452 files read in parallel are outputted in an interleaved order. If your 

453 input pipeline is I/O bottlenecked, consider setting this parameter to a 

454 value greater than one to parallelize the I/O. If `None`, files will be 

455 read sequentially. 

456 name: (Optional.) A name for the tf.data operation. 

457 

458 Raises: 

459 TypeError: If any argument does not have the expected type. 

460 ValueError: If any argument does not have the expected shape. 

461 """ 

462 filenames = _create_or_validate_filenames_dataset(filenames, name=name) 

463 

464 self._filenames = filenames 

465 self._compression_type = compression_type 

466 self._buffer_size = buffer_size 

467 self._num_parallel_reads = num_parallel_reads 

468 

469 def creator_fn(filename): 

470 return _TFRecordDataset( 

471 filename, compression_type, buffer_size, name=name) 

472 

473 self._impl = _create_dataset_reader( 

474 creator_fn, filenames, num_parallel_reads, name=name) 

475 variant_tensor = self._impl._variant_tensor # pylint: disable=protected-access 

476 super(TFRecordDatasetV2, self).__init__(variant_tensor) 

477 

478 def _inputs(self): 

479 return self._impl._inputs() # pylint: disable=protected-access 

480 

481 @property 

482 def element_spec(self): 

483 return tensor_spec.TensorSpec([], dtypes.string) 
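
# Editor's note: a hedged usage sketch (not part of the TensorFlow source) of the
# class above, exposed publicly as `tf.data.TFRecordDataset`. The shard files
# named here are assumed to exist and to be gzip-compressed; parsing the
# serialized examples would be done with a later `map`, as in the class docstring.
def _example_tf_record_dataset():
  return TFRecordDatasetV2(
      ["/tmp/shard-00000.tfrecord.gz", "/tmp/shard-00001.tfrecord.gz"],
      compression_type="GZIP",
      buffer_size=8 * 1024 * 1024,              # 8 MB read buffer per file
      num_parallel_reads=dataset_ops.AUTOTUNE)  # let tf.data choose the parallelism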

484 

485 

486 @tf_export(v1=["data.TFRecordDataset"])

487 class TFRecordDatasetV1(dataset_ops.DatasetV1Adapter):

488 """A `Dataset` comprising records from one or more TFRecord files.""" 

489 

490 def __init__(self, 

491 filenames, 

492 compression_type=None, 

493 buffer_size=None, 

494 num_parallel_reads=None, 

495 name=None): 

496 wrapped = TFRecordDatasetV2( 

497 filenames, compression_type, buffer_size, num_parallel_reads, name=name) 

498 super(TFRecordDatasetV1, self).__init__(wrapped) 

499 

500 __init__.__doc__ = TFRecordDatasetV2.__init__.__doc__ 

501 

502 @property 

503 def _filenames(self): 

504 return self._dataset._filenames # pylint: disable=protected-access 

505 

506 @_filenames.setter 

507 def _filenames(self, value): 

508 self._dataset._filenames = value # pylint: disable=protected-access 

509 

510 

511 class _FixedLengthRecordDataset(dataset_ops.DatasetSource):

512 """A `Dataset` of fixed-length records from one or more binary files.""" 

513 

514 def __init__(self, 

515 filenames, 

516 record_bytes, 

517 header_bytes=None, 

518 footer_bytes=None, 

519 buffer_size=None, 

520 compression_type=None, 

521 name=None): 

522 """Creates a `FixedLengthRecordDataset`. 

523 

524 Args: 

525 filenames: A `tf.string` tensor containing one or more filenames. 

526 record_bytes: A `tf.int64` scalar representing the number of bytes in each 

527 record. 

528 header_bytes: (Optional.) A `tf.int64` scalar representing the number of 

529 bytes to skip at the start of a file. 

530 footer_bytes: (Optional.) A `tf.int64` scalar representing the number of 

531 bytes to ignore at the end of a file. 

532 buffer_size: (Optional.) A `tf.int64` scalar representing the number of 

533 bytes to buffer when reading. 

534 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 

535 `""` (no compression), `"ZLIB"`, or `"GZIP"`. 

536 name: (Optional.) A name for the tf.data operation. 

537 """ 

538 self._filenames = filenames 

539 self._record_bytes = ops.convert_to_tensor( 

540 record_bytes, dtype=dtypes.int64, name="record_bytes") 

541 self._header_bytes = convert.optional_param_to_tensor( 

542 "header_bytes", header_bytes) 

543 self._footer_bytes = convert.optional_param_to_tensor( 

544 "footer_bytes", footer_bytes) 

545 self._buffer_size = convert.optional_param_to_tensor( 

546 "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) 

547 self._compression_type = convert.optional_param_to_tensor( 

548 "compression_type", 

549 compression_type, 

550 argument_default="", 

551 argument_dtype=dtypes.string) 

552 self._name = name 

553 

554 variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2( 

555 self._filenames, 

556 self._header_bytes, 

557 self._record_bytes, 

558 self._footer_bytes, 

559 self._buffer_size, 

560 self._compression_type, 

561 metadata=self._metadata.SerializeToString()) 

562 super(_FixedLengthRecordDataset, self).__init__(variant_tensor) 

563 

564 @property 

565 def element_spec(self): 

566 return tensor_spec.TensorSpec([], dtypes.string) 

567 

568 

569 @tf_export("data.FixedLengthRecordDataset", v1=[])

570 class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource):

571 """A `Dataset` of fixed-length records from one or more binary files. 

572 

573 The `tf.data.FixedLengthRecordDataset` reads fixed length records from binary 

574 files and creates a dataset where each record becomes an element of the 

575 dataset. The binary files can have a fixed length header and a fixed length 

576 footer, which will both be skipped. 

577 

578 For example, suppose we have 2 files "fixed_length0.bin" and 

579 "fixed_length1.bin" with the following content: 

580 

581 >>> with open('/tmp/fixed_length0.bin', 'wb') as f: 

582 ... f.write(b'HEADER012345FOOTER') 

583 >>> with open('/tmp/fixed_length1.bin', 'wb') as f: 

584 ... f.write(b'HEADER6789abFOOTER') 

585 

586 We can construct a `FixedLengthRecordDataset` from them as follows: 

587 

588 >>> dataset1 = tf.data.FixedLengthRecordDataset( 

589 ... filenames=['/tmp/fixed_length0.bin', '/tmp/fixed_length1.bin'], 

590 ... record_bytes=2, header_bytes=6, footer_bytes=6) 

591 

592 The elements of the dataset are: 

593 

594 >>> for element in dataset1.as_numpy_iterator(): 

595 ... print(element) 

596 b'01' 

597 b'23' 

598 b'45' 

599 b'67' 

600 b'89' 

601 b'ab' 

602 """ 

603 

604 def __init__(self, 

605 filenames, 

606 record_bytes, 

607 header_bytes=None, 

608 footer_bytes=None, 

609 buffer_size=None, 

610 compression_type=None, 

611 num_parallel_reads=None, 

612 name=None): 

613 """Creates a `FixedLengthRecordDataset`. 

614 

615 Args: 

616 filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or 

617 more filenames. 

618 record_bytes: A `tf.int64` scalar representing the number of bytes in each 

619 record. 

620 header_bytes: (Optional.) A `tf.int64` scalar representing the number of 

621 bytes to skip at the start of a file. 

622 footer_bytes: (Optional.) A `tf.int64` scalar representing the number of 

623 bytes to ignore at the end of a file. 

624 buffer_size: (Optional.) A `tf.int64` scalar representing the number of 

625 bytes to buffer when reading. 

626 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 

627 `""` (no compression), `"ZLIB"`, or `"GZIP"`. 

628 num_parallel_reads: (Optional.) A `tf.int64` scalar representing the 

629 number of files to read in parallel. If greater than one, the records of 

630 files read in parallel are outputted in an interleaved order. If your 

631 input pipeline is I/O bottlenecked, consider setting this parameter to a 

632 value greater than one to parallelize the I/O. If `None`, files will be 

633 read sequentially. 

634 name: (Optional.) A name for the tf.data operation. 

635 """ 

636 filenames = _create_or_validate_filenames_dataset(filenames, name=name) 

637 

638 self._filenames = filenames 

639 self._record_bytes = record_bytes 

640 self._header_bytes = header_bytes 

641 self._footer_bytes = footer_bytes 

642 self._buffer_size = buffer_size 

643 self._compression_type = compression_type 

644 

645 def creator_fn(filename): 

646 return _FixedLengthRecordDataset( 

647 filename, 

648 record_bytes, 

649 header_bytes, 

650 footer_bytes, 

651 buffer_size, 

652 compression_type, 

653 name=name) 

654 

655 self._impl = _create_dataset_reader( 

656 creator_fn, filenames, num_parallel_reads, name=name) 

657 variant_tensor = self._impl._variant_tensor # pylint: disable=protected-access 

658 super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor) 

659 

660 @property 

661 def element_spec(self): 

662 return tensor_spec.TensorSpec([], dtypes.string) 
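
# Editor's note: a hedged sketch (not part of the TensorFlow source) tying the
# constructor arguments to the record layout: each file yields
# (file_size - header_bytes - footer_bytes) / record_bytes elements. The values
# below mirror the 18-byte files from the class docstring (6-byte header,
# three 2-byte records, 6-byte footer).
def _example_fixed_length_record_dataset():
  return FixedLengthRecordDatasetV2(
      ["/tmp/fixed_length0.bin", "/tmp/fixed_length1.bin"],
      record_bytes=2,
      header_bytes=6,
      footer_bytes=6)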

663 

664 

665 @tf_export(v1=["data.FixedLengthRecordDataset"])

666 class FixedLengthRecordDatasetV1(dataset_ops.DatasetV1Adapter):

667 """A `Dataset` of fixed-length records from one or more binary files.""" 

668 

669 def __init__(self, 

670 filenames, 

671 record_bytes, 

672 header_bytes=None, 

673 footer_bytes=None, 

674 buffer_size=None, 

675 compression_type=None, 

676 num_parallel_reads=None, 

677 name=None): 

678 wrapped = FixedLengthRecordDatasetV2( 

679 filenames, 

680 record_bytes, 

681 header_bytes, 

682 footer_bytes, 

683 buffer_size, 

684 compression_type, 

685 num_parallel_reads, 

686 name=name) 

687 super(FixedLengthRecordDatasetV1, self).__init__(wrapped) 

688 

689 __init__.__doc__ = FixedLengthRecordDatasetV2.__init__.__doc__ 

690 

691 @property 

692 def _filenames(self): 

693 return self._dataset._filenames # pylint: disable=protected-access 

694 

695 @_filenames.setter 

696 def _filenames(self, value): 

697 self._dataset._filenames = value # pylint: disable=protected-access 

698 

699 

700 if tf2.enabled():

701 FixedLengthRecordDataset = FixedLengthRecordDatasetV2 

702 TFRecordDataset = TFRecordDatasetV2 

703 TextLineDataset = TextLineDatasetV2 

704 else:

705 FixedLengthRecordDataset = FixedLengthRecordDatasetV1 

706 TFRecordDataset = TFRecordDatasetV1 

707 TextLineDataset = TextLineDatasetV1