Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/readers.py: 23%

308 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Python wrappers for reader Datasets.""" 

16import collections 

17import csv 

18import functools 

19import gzip 

20 

21import numpy as np 

22 

23from tensorflow.python import tf2 

24from tensorflow.python.data.experimental.ops import error_ops 

25from tensorflow.python.data.experimental.ops import parsing_ops 

26from tensorflow.python.data.ops import dataset_ops 

27from tensorflow.python.data.ops import map_op 

28from tensorflow.python.data.ops import options as options_lib 

29from tensorflow.python.data.ops import readers as core_readers 

30from tensorflow.python.data.util import convert 

31from tensorflow.python.data.util import nest 

32from tensorflow.python.framework import constant_op 

33from tensorflow.python.framework import dtypes 

34from tensorflow.python.framework import ops 

35from tensorflow.python.framework import tensor_spec 

36from tensorflow.python.framework import tensor_util 

37from tensorflow.python.lib.io import file_io 

38from tensorflow.python.ops import gen_experimental_dataset_ops 

39from tensorflow.python.ops import io_ops 

40from tensorflow.python.platform import gfile 

41from tensorflow.python.util.tf_export import tf_export 

42 

43_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32, 

44 dtypes.int64, dtypes.string) 

45 

46 

47def _is_valid_int32(str_val): 

48 try: 

49 # Checks equality to prevent int32 overflow 

50 return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype( 

51 str_val) 

52 except (ValueError, OverflowError): 

53 return False 

54 

55 

56def _is_valid_int64(str_val): 

57 try: 

58 dtypes.int64.as_numpy_dtype(str_val) 

59 return True 

60 except (ValueError, OverflowError): 

61 return False 

62 

63 

64def _is_valid_float(str_val, float_dtype): 

65 try: 

66 return float_dtype.as_numpy_dtype(str_val) < np.inf 

67 except ValueError: 

68 return False 

69 

70 

71def _infer_type(str_val, na_value, prev_type): 

72 """Given a string, infers its tensor type. 

73 

74 Infers the type of a value by picking the least 'permissive' type possible, 

75 while still allowing the previous type inference for this column to be valid. 

76 

77 Args: 

78 str_val: String value to infer the type of. 

79 na_value: Additional string to recognize as a NA/NaN CSV value. 

80 prev_type: Type previously inferred based on values of this column that 

81 we've seen up till now. 

82 Returns: 

83 Inferred dtype. 

84 """ 

85 if str_val in ("", na_value): 

86 # If the field is null, it gives no extra information about its type 

87 return prev_type 

88 

89 type_list = [ 

90 dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string 

91 ] # list of types to try, ordered from least permissive to most 

92 

93 type_functions = [ 

94 _is_valid_int32, 

95 _is_valid_int64, 

96 lambda str_val: _is_valid_float(str_val, dtypes.float32), 

97 lambda str_val: _is_valid_float(str_val, dtypes.float64), 

98 lambda str_val: True, 

99 ] # Corresponding list of validation functions 

100 

101 for i in range(len(type_list)): 

102 validation_fn = type_functions[i] 

103 if validation_fn(str_val) and (prev_type is None or 

104 prev_type in type_list[:i + 1]): 

105 return type_list[i] 

106 
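# Example (illustrative sketch): inference escalates through
# int32 -> int64 -> float32 -> float64 -> string and never narrows a
# previously inferred type.
#
#   _infer_type("7", "", None)                   # -> dtypes.int32
#   _infer_type("3000000000", "", dtypes.int32)  # -> dtypes.int64
#   _infer_type("1.5", "", dtypes.int64)         # -> dtypes.float32
#   _infer_type("abc", "", dtypes.float32)       # -> dtypes.string
#   _infer_type("", "", dtypes.int64)            # -> dtypes.int64 (null keeps prev type)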

107 

108def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, 

109 file_io_fn): 

110 """Generator that yields rows of CSV file(s) in order.""" 

111 for fn in filenames: 

112 with file_io_fn(fn) as f: 

113 rdr = csv.reader( 

114 f, 

115 delimiter=field_delim, 

116 quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE) 

117 row_num = 1 

118 if header: 

119 next(rdr) # Skip header lines 

120 row_num += 1 

121 

122 for csv_row in rdr: 

123 if len(csv_row) != num_cols: 

124 raise ValueError( 

125 f"Problem inferring types: CSV row {row_num} has {len(csv_row)} " 

126 f"number of fields. Expected: {num_cols}.") 

127 row_num += 1 

128 yield csv_row 

129 
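# Example (illustrative sketch; "data.csv" is a hypothetical file whose
# contents are "x,y\n1,a\n2,b\n"):
#
#   rows = _next_csv_row(["data.csv"], num_cols=2, field_delim=",",
#                        use_quote_delim=True, header=True, file_io_fn=open)
#   next(rows)  # -> ['1', 'a']
#   next(rows)  # -> ['2', 'b']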

130 

131def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim, 

132 na_value, header, num_rows_for_inference, 

133 select_columns, file_io_fn): 

134 """Infers column types from the first N valid CSV records of files.""" 

135 if select_columns is None: 

136 select_columns = range(num_cols) 

137 inferred_types = [None] * len(select_columns) 

138 

139 for i, csv_row in enumerate( 

140 _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, 

141 file_io_fn)): 

142 if num_rows_for_inference is not None and i >= num_rows_for_inference: 

143 break 

144 

145 for j, col_index in enumerate(select_columns): 

146 inferred_types[j] = _infer_type(csv_row[col_index], na_value, 

147 inferred_types[j]) 

148 

149 # Replace None's with a default type 

150 inferred_types = [t or dtypes.string for t in inferred_types] 

151 # Default to 0 or '' for null values 

152 return [ 

153 constant_op.constant([0 if t is not dtypes.string else ""], dtype=t) 

154 for t in inferred_types 

155 ] 

156 
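# Example (illustrative sketch; "data.csv" is a hypothetical file whose data
# rows look like "1,a" and "2,b"):
#
#   _infer_column_defaults(["data.csv"], num_cols=2, field_delim=",",
#                          use_quote_delim=True, na_value="", header=False,
#                          num_rows_for_inference=100, select_columns=None,
#                          file_io_fn=open)
#   # -> [tf.constant([0], tf.int32), tf.constant([""], tf.string)]; every
#   #    column is treated as optional with a zero/empty default.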

157 

158def _infer_column_names(filenames, field_delim, use_quote_delim, file_io_fn): 

159 """Infers column names from first rows of files.""" 

160 csv_kwargs = { 

161 "delimiter": field_delim, 

162 "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE 

163 } 

164 with file_io_fn(filenames[0]) as f: 

165 try: 

166 column_names = next(csv.reader(f, **csv_kwargs)) 

167 except StopIteration: 

168 raise ValueError("Failed when reading the header line of " 

169 f"{filenames[0]}. Is it an empty file?") 

170 

171 for name in filenames[1:]: 

172 with file_io_fn(name) as f: 

173 try: 

174 if next(csv.reader(f, **csv_kwargs)) != column_names: 

175 raise ValueError( 

176 "All input CSV files should have the same column names in the " 

177 f"header row. File {name} has different column names.") 

178 except StopIteration: 

179 raise ValueError("Failed when reading the header line of " 

180 f"{name}. Is it an empty file?") 

181 return column_names 

182 

183 

184def _get_sorted_col_indices(select_columns, column_names): 

185 """Transforms select_columns argument into sorted column indices.""" 

186 names_to_indices = {n: i for i, n in enumerate(column_names)} 

187 num_cols = len(column_names) 

188 

189 results = [] 

190 for v in select_columns: 

191 # If value is already an int, check if it's valid. 

192 if isinstance(v, int): 

193 if v < 0 or v >= num_cols: 

194 raise ValueError( 

195 f"Column index {v} specified in `select_columns` should be > 0 " 

196 f" and <= {num_cols}, which is the number of columns.") 

197 results.append(v) 

198 # Otherwise, check that it's a valid column name and convert it to

199 # the relevant column index.

200 elif v not in names_to_indices: 

201 raise ValueError( 

202 f"Column {v} specified in `select_columns` must be of one of the " 

203 f"columns: {names_to_indices.keys()}.") 

204 else: 

205 results.append(names_to_indices[v]) 

206 

207 # Sort and ensure there are no duplicates

208 deduped = sorted(set(results))

209 if len(deduped) != len(select_columns):

210 sorted_names = sorted(results)

211 duplicate_columns = set([a for a, b in zip(

212 sorted_names[:-1], sorted_names[1:]) if a == b])

213 raise ValueError("The `select_columns` argument contains duplicate "

214 f"columns: {duplicate_columns}.")

215 return deduped

216 
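# Example (illustrative sketch): names and indices may be mixed; the result
# is sorted, and duplicates raise an error.
#
#   _get_sorted_col_indices(["b", 0], ["a", "b", "c"])  # -> [0, 1]
#   _get_sorted_col_indices(["b", 1], ["a", "b", "c"])  # raises ValueError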

217 

218def _maybe_shuffle_and_repeat( 

219 dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed): 

220 """Optionally shuffle and repeat dataset, as requested.""" 

221 if shuffle: 

222 dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) 

223 if num_epochs != 1: 

224 dataset = dataset.repeat(num_epochs) 

225 return dataset 

226 

227 

228def make_tf_record_dataset(file_pattern, 

229 batch_size, 

230 parser_fn=None, 

231 num_epochs=None, 

232 shuffle=True, 

233 shuffle_buffer_size=None, 

234 shuffle_seed=None, 

235 prefetch_buffer_size=None, 

236 num_parallel_reads=None, 

237 num_parallel_parser_calls=None, 

238 drop_final_batch=False): 

239 """Reads and optionally parses TFRecord files into a dataset. 

240 

241 Provides common functionality such as batching, optional parsing, shuffling, 

242 and performant defaults. 

243 
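A minimal usage sketch (illustrative only; the file pattern and feature spec
below are hypothetical):

```python
dataset = make_tf_record_dataset(
    file_pattern="/path/to/records-*.tfrecord",
    batch_size=32,
    parser_fn=lambda serialized: tf.io.parse_single_example(
        serialized, {"label": tf.io.FixedLenFeature([], tf.int64)}),
    num_epochs=1)
for batch in dataset.take(1):
  print(batch["label"].shape)  # (32,), assuming at least 32 records exist.
```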

244 Args: 

245 file_pattern: List of files or patterns of TFRecord file paths. 

246 See `tf.io.gfile.glob` for pattern rules. 

247 batch_size: An int representing the number of records to combine 

248 in a single batch. 

249 parser_fn: (Optional.) A function accepting string input to parse 

250 and process the record contents. This function must map records 

251 to components of a fixed shape, so they may be batched. By 

252 default, uses the record contents unmodified. 

253 num_epochs: (Optional.) An int specifying the number of times this 

254 dataset is repeated. If None (the default), cycles through the 

255 dataset forever. 

256 shuffle: (Optional.) A bool that indicates whether the input 

257 should be shuffled. Defaults to `True`. 

258 shuffle_buffer_size: (Optional.) Buffer size to use for 

259 shuffling. A large buffer size ensures better shuffling, but 

260 increases memory usage and startup time. 

261 shuffle_seed: (Optional.) Randomization seed to use for shuffling. 

262 prefetch_buffer_size: (Optional.) An int specifying the number of 

263 feature batches to prefetch for performance improvement. 

264 Defaults to auto-tune. Set to 0 to disable prefetching. 

265 num_parallel_reads: (Optional.) Number of threads used to read 

266 records from files. If set to a value >1 (as it is by default), the

267 results will be interleaved. Defaults to `24`.

268 num_parallel_parser_calls: (Optional.) Number of records to parse in

269 parallel. Defaults to `batch_size`.

270 drop_final_batch: (Optional.) Whether the last batch should be 

271 dropped in case its size is smaller than `batch_size`; the 

272 default behavior is not to drop the smaller batch. 

273 

274 Returns: 

275 A dataset, where each element matches the output of `parser_fn` 

276 except it will have an additional leading `batch_size` dimension,

277 or a `batch_size`-length 1-D tensor of strings if `parser_fn` is 

278 unspecified. 

279 """ 

280 if num_parallel_reads is None: 

281 # NOTE: We considered auto-tuning this value, but there is a concern 

282 # that this affects the mixing of records from different files, which 

283 # could affect training convergence/accuracy, so we are defaulting to 

284 # a constant for now. 

285 num_parallel_reads = 24 

286 

287 if num_parallel_parser_calls is None: 

288 # TODO(josh11b): if num_parallel_parser_calls is None, use some function 

289 # of num cores instead of `batch_size`. 

290 num_parallel_parser_calls = batch_size 

291 

292 if prefetch_buffer_size is None: 

293 prefetch_buffer_size = dataset_ops.AUTOTUNE 

294 

295 files = dataset_ops.Dataset.list_files( 

296 file_pattern, shuffle=shuffle, seed=shuffle_seed) 

297 

298 dataset = core_readers.TFRecordDataset( 

299 files, num_parallel_reads=num_parallel_reads) 

300 

301 if shuffle_buffer_size is None: 

302 # TODO(josh11b): Auto-tune this value when not specified 

303 shuffle_buffer_size = 10000 

304 dataset = _maybe_shuffle_and_repeat( 

305 dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) 

306 

307 # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to 

308 # improve the shape inference, because it makes the batch dimension static. 

309 # It is safe to do this because in that case we are repeating the input 

310 # indefinitely, and all batches will be full-sized. 

311 drop_final_batch = drop_final_batch or num_epochs is None 

312 

313 if parser_fn is None: 

314 dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch) 

315 else: 

316 dataset = dataset.map( 

317 parser_fn, num_parallel_calls=num_parallel_parser_calls) 

318 dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch) 

319 

320 if prefetch_buffer_size == 0: 

321 return dataset 

322 else: 

323 return dataset.prefetch(buffer_size=prefetch_buffer_size) 

324 

325 

326@tf_export("data.experimental.make_csv_dataset", v1=[]) 

327def make_csv_dataset_v2( 

328 file_pattern, 

329 batch_size, 

330 column_names=None, 

331 column_defaults=None, 

332 label_name=None, 

333 select_columns=None, 

334 field_delim=",", 

335 use_quote_delim=True, 

336 na_value="", 

337 header=True, 

338 num_epochs=None, # TODO(aaudibert): Change default to 1 when graduating. 

339 shuffle=True, 

340 shuffle_buffer_size=10000, 

341 shuffle_seed=None, 

342 prefetch_buffer_size=None, 

343 num_parallel_reads=None, 

344 sloppy=False, 

345 num_rows_for_inference=100, 

346 compression_type=None, 

347 ignore_errors=False, 

348 encoding="utf-8", 

349): 

350 """Reads CSV files into a dataset. 

351 

352 Reads CSV files into a dataset, where each element of the dataset is a 

353 (features, labels) tuple that corresponds to a batch of CSV rows. The features 

354 dictionary maps feature column names to `Tensor`s containing the corresponding 

355 feature data, and labels is a `Tensor` containing the batch's label data. 

356 

357 By default, the first rows of the CSV files are expected to be headers listing 

358 the column names. If the first rows are not headers, set `header=False` and 

359 provide the column names with the `column_names` argument. 

360 

361 By default, the dataset is repeated indefinitely, reshuffling the order each 

362 time. This behavior can be modified by setting the `num_epochs` and `shuffle` 

363 arguments. 

364 

365 For example, suppose you have a CSV file containing 

366 

367 | Feature_A | Feature_B | 

368 | --------- | --------- | 

369 | 1 | "a" | 

370 | 2 | "b" | 

371 | 3 | "c" | 

372 | 4 | "d" | 

373 

374 ``` 

375 # No label column specified 

376 dataset = tf.data.experimental.make_csv_dataset(filename, batch_size=2) 

377 iterator = dataset.as_numpy_iterator() 

378 print(dict(next(iterator))) 

379 # prints a dictionary of batched features: 

380 # OrderedDict([('Feature_A', array([1, 4], dtype=int32)), 

381 # ('Feature_B', array([b'a', b'd'], dtype=object))]) 

382 ``` 

383 

384 ``` 

385 # Set Feature_B as label column 

386 dataset = tf.data.experimental.make_csv_dataset( 

387 filename, batch_size=2, label_name="Feature_B") 

388 iterator = dataset.as_numpy_iterator() 

389 print(next(iterator)) 

390 # prints (features, labels) tuple: 

391 # (OrderedDict([('Feature_A', array([1, 2], dtype=int32))]), 

392 # array([b'a', b'b'], dtype=object)) 

393 ``` 
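Column subsets and explicit defaults can also be requested. The following is
an illustrative sketch (the `-1` default is arbitrary):

```
# Parse only Feature_A, treating it as an optional int32 column that
# defaults to -1 when a field is empty.
dataset = tf.data.experimental.make_csv_dataset(
    filename, batch_size=2,
    select_columns=["Feature_A"],
    column_defaults=[tf.constant([-1], dtype=tf.int32)],
    num_epochs=1, shuffle=False)
```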

394 

395 See the 

396 [Load CSV data guide](https://www.tensorflow.org/tutorials/load_data/csv) for 

397 more examples of using `make_csv_dataset` to read CSV data. 

398 

399 Args: 

400 file_pattern: List of files or patterns of file paths containing CSV 

401 records. See `tf.io.gfile.glob` for pattern rules. 

402 batch_size: An int representing the number of records to combine 

403 in a single batch. 

404 column_names: An optional list of strings that corresponds to the CSV 

405 columns, in order. One per column of the input record. If this is not 

406 provided, infers the column names from the first row of the records. 

407 These names will be the keys of the features dict of each dataset element. 

408 column_defaults: An optional list of default values for the CSV fields. One

409 item per selected column of the input record. Each item in the list is 

410 either a valid CSV dtype (float32, float64, int32, int64, or string), or a 

411 `Tensor` with one of the aforementioned types. The tensor can either be 

412 a scalar default value (if the column is optional), or an empty tensor (if 

413 the column is required). If a dtype is provided instead of a tensor, the 

414 column is also treated as required. If this list is not provided, tries 

415 to infer types based on reading the first num_rows_for_inference rows of 

416 files specified, and assumes all columns are optional, defaulting to `0` 

417 for numeric values and `""` for string values. If both this and 

418 `select_columns` are specified, these must have the same lengths, and 

419 `column_defaults` is assumed to be sorted in order of increasing column 

420 index. 

421 label_name: An optional string corresponding to the label column. If

422 provided, the data for this column is returned as a separate `Tensor` from 

423 the features dictionary, so that the dataset complies with the format 

424 expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input 

425 function. 

426 select_columns: An optional list of integer indices or string column 

427 names, that specifies a subset of columns of CSV data to select. If 

428 column names are provided, these must correspond to names provided in 

429 `column_names` or inferred from the file header lines. When this argument 

430 is specified, only a subset of CSV columns will be parsed and returned, 

431 corresponding to the columns specified. Using this results in faster 

432 parsing and lower memory usage. If both this and `column_defaults` are 

433 specified, these must have the same lengths, and `column_defaults` is 

434 assumed to be sorted in order of increasing column index. 

435 field_delim: An optional `string`. Defaults to `","`. Char delimiter to 

436 separate fields in a record. 

437 use_quote_delim: An optional bool. Defaults to `True`. If false, treats 

438 double quotation marks as regular characters inside of the string fields. 

439 na_value: Additional string to recognize as NA/NaN. 

440 header: A bool that indicates whether the first rows of provided CSV files 

441 correspond to header lines with column names, and should not be included 

442 in the data. 

443 num_epochs: An int specifying the number of times this dataset is repeated. 

444 If None, cycles through the dataset forever. 

445 shuffle: A bool that indicates whether the input should be shuffled. 

446 shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size 

447 ensures better shuffling, but increases memory usage and startup time. 

448 shuffle_seed: Randomization seed to use for shuffling. 

449 prefetch_buffer_size: An int specifying the number of feature 

450 batches to prefetch for performance improvement. Recommended value is the 

451 number of batches consumed per training step. Defaults to auto-tune. 

452 num_parallel_reads: Number of threads used to read CSV records from files. 

453 If >1, the results will be interleaved. Defaults to `1`. 

454 sloppy: If `True`, reading performance will be improved at 

455 the cost of non-deterministic ordering. If `False`, the order of elements 

456 produced is deterministic prior to shuffling (elements are still

457 randomized if `shuffle=True`; note that if the seed is set, the order

458 of elements after shuffling is deterministic). Defaults to `False`.

459 num_rows_for_inference: Number of rows of a file to use for type inference 

460 if `column_defaults` is not provided. If None, reads all the rows of all

461 the files. Defaults to 100. 

462 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 

463 `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no compression. 

464 ignore_errors: (Optional.) If `True`, ignores errors with CSV file parsing, 

465 such as malformed data or empty lines, and moves on to the next valid 

466 CSV record. Otherwise, the dataset raises an error and stops processing 

467 when encountering any invalid records. Defaults to `False`. 

468 encoding: Encoding to use when reading. Defaults to `UTF-8`. 

469 

470 Returns: 

471 A dataset, where each element is a (features, labels) tuple that corresponds 

472 to a batch of `batch_size` CSV rows. The features dictionary maps feature 

473 column names to `Tensor`s containing the corresponding column data, and 

474 labels is a `Tensor` containing the column data for the label column 

475 specified by `label_name`. 

476 

477 Raises: 

478 ValueError: If any of the arguments is malformed. 

479 """ 

480 if num_parallel_reads is None: 

481 num_parallel_reads = 1 

482 

483 if prefetch_buffer_size is None: 

484 prefetch_buffer_size = dataset_ops.AUTOTUNE 

485 

486 # Create dataset of all matching filenames 

487 filenames = _get_file_names(file_pattern, False) 

488 dataset = dataset_ops.Dataset.from_tensor_slices(filenames) 

489 if shuffle: 

490 dataset = dataset.shuffle(len(filenames), shuffle_seed) 

491 

492 # Clean arguments; figure out column names and defaults 

493 if column_names is None or column_defaults is None: 

494 # Figure out which IO function to use to open the files

495 file_io_fn = lambda filename: file_io.FileIO( # pylint: disable=g-long-lambda 

496 filename, "r", encoding=encoding) 

497 if compression_type is not None: 

498 compression_type_value = tensor_util.constant_value(compression_type) 

499 if compression_type_value is None: 

500 raise ValueError( 

501 f"Received unknown `compression_type` {compression_type}. " 

502 "Expected: GZIP, ZLIB or "" (empty string).") 

503 if compression_type_value == "GZIP": 

504 file_io_fn = lambda filename: gzip.open( # pylint: disable=g-long-lambda 

505 filename, "rt", encoding=encoding) 

506 elif compression_type_value == "ZLIB": 

507 raise ValueError( 

508 f"`compression_type` {compression_type} is not supported for " 

509 "probing columns.") 

510 elif compression_type_value != "": 

511 raise ValueError( 

512 f"Received unknown `compression_type` {compression_type}. " 

513 "Expected: GZIP, ZLIB or " 

514 " (empty string).") 

515 if column_names is None: 

516 if not header: 

517 raise ValueError("Expected `column_names` or `header` arguments. Neither " 

518 "is provided.") 

519 # If column names are not provided, infer from the header lines 

520 column_names = _infer_column_names(filenames, field_delim, use_quote_delim, 

521 file_io_fn) 

522 if len(column_names) != len(set(column_names)): 

523 sorted_names = sorted(column_names) 

524 duplicate_columns = set([a for a, b in zip( 

525 sorted_names[:-1], sorted_names[1:]) if a == b]) 

526 raise ValueError( 

527 "Either `column_names` argument or CSV header row contains duplicate " 

528 f"column names: {duplicate_columns}.") 

529 

530 if select_columns is not None: 

531 select_columns = _get_sorted_col_indices(select_columns, column_names) 

532 

533 if column_defaults is not None: 

534 column_defaults = [ 

535 constant_op.constant([], dtype=x) 

536 if not tensor_util.is_tf_type(x) and x in _ACCEPTABLE_CSV_TYPES else x 

537 for x in column_defaults 

538 ] 

539 else: 

540 # If column defaults are not provided, infer from records at graph 

541 # construction time 

542 column_defaults = _infer_column_defaults(filenames, len(column_names), 

543 field_delim, use_quote_delim, 

544 na_value, header, 

545 num_rows_for_inference, 

546 select_columns, file_io_fn) 

547 

548 if select_columns is not None and len(column_defaults) != len(select_columns): 

549 raise ValueError( 

550 "If specified, `column_defaults` and `select_columns` must have the " 

551 f"same length: `column_defaults` has length {len(column_defaults)}, " 

552 f"`select_columns` has length {len(select_columns)}.") 

553 if select_columns is not None and len(column_names) > len(select_columns): 

554 # Pick the relevant subset of column names 

555 column_names = [column_names[i] for i in select_columns] 

556 

557 if label_name is not None and label_name not in column_names: 

558 raise ValueError("`label_name` provided must be one of the columns: " 

559 f"{column_names}. Received: {label_name}.") 

560 

561 def filename_to_dataset(filename): 

562 dataset = CsvDataset( 

563 filename, 

564 record_defaults=column_defaults, 

565 field_delim=field_delim, 

566 use_quote_delim=use_quote_delim, 

567 na_value=na_value, 

568 select_cols=select_columns, 

569 header=header, 

570 compression_type=compression_type 

571 ) 

572 if ignore_errors: 

573 dataset = dataset.apply(error_ops.ignore_errors()) 

574 return dataset 

575 

576 def map_fn(*columns): 

577 """Organizes columns into a features dictionary. 

578 

579 Args: 

580 *columns: list of `Tensor`s corresponding to one csv record. 

581 Returns: 

582 An OrderedDict of feature names to values for that particular record. If 

583 label_name is provided, extracts the label feature to be returned as the 

584 second element of the tuple. 

585 """ 

586 features = collections.OrderedDict(zip(column_names, columns)) 

587 if label_name is not None: 

588 label = features.pop(label_name) 

589 return features, label 

590 return features 

591 

592 if num_parallel_reads == dataset_ops.AUTOTUNE: 

593 dataset = dataset.interleave( 

594 filename_to_dataset, num_parallel_calls=num_parallel_reads) 

595 options = options_lib.Options() 

596 options.deterministic = not sloppy 

597 dataset = dataset.with_options(options) 

598 else: 

599 # Read files sequentially (if num_parallel_reads=1) or in parallel 

600 def apply_fn(dataset): 

601 return core_readers.ParallelInterleaveDataset( 

602 dataset, 

603 filename_to_dataset, 

604 cycle_length=num_parallel_reads, 

605 block_length=1, 

606 sloppy=sloppy, 

607 buffer_output_elements=None, 

608 prefetch_input_elements=None) 

609 

610 dataset = dataset.apply(apply_fn) 

611 

612 dataset = _maybe_shuffle_and_repeat( 

613 dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) 

614 

615 # Apply batch before map for perf, because map has high overhead relative 

616 # to the size of the computation in each map. 

617 # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to 

618 # improve the shape inference, because it makes the batch dimension static. 

619 # It is safe to do this because in that case we are repeating the input 

620 # indefinitely, and all batches will be full-sized. 

621 dataset = dataset.batch(batch_size=batch_size, 

622 drop_remainder=num_epochs is None) 

623 dataset = map_op._MapDataset( # pylint: disable=protected-access 

624 dataset, map_fn, use_inter_op_parallelism=False) 

625 dataset = dataset.prefetch(prefetch_buffer_size) 

626 

627 return dataset 

628 

629 

630@tf_export(v1=["data.experimental.make_csv_dataset"]) 

631def make_csv_dataset_v1( 

632 file_pattern, 

633 batch_size, 

634 column_names=None, 

635 column_defaults=None, 

636 label_name=None, 

637 select_columns=None, 

638 field_delim=",", 

639 use_quote_delim=True, 

640 na_value="", 

641 header=True, 

642 num_epochs=None, 

643 shuffle=True, 

644 shuffle_buffer_size=10000, 

645 shuffle_seed=None, 

646 prefetch_buffer_size=None, 

647 num_parallel_reads=None, 

648 sloppy=False, 

649 num_rows_for_inference=100, 

650 compression_type=None, 

651 ignore_errors=False, 

652 encoding="utf-8", 

653): # pylint: disable=missing-docstring 

654 return dataset_ops.DatasetV1Adapter( 

655 make_csv_dataset_v2(file_pattern, batch_size, column_names, 

656 column_defaults, label_name, select_columns, 

657 field_delim, use_quote_delim, na_value, header, 

658 num_epochs, shuffle, shuffle_buffer_size, 

659 shuffle_seed, prefetch_buffer_size, 

660 num_parallel_reads, sloppy, num_rows_for_inference, 

661 compression_type, ignore_errors, encoding)) 

662make_csv_dataset_v1.__doc__ = make_csv_dataset_v2.__doc__ 

663 

664 

665_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB 

666 

667 

668@tf_export("data.experimental.CsvDataset", v1=[]) 

669class CsvDatasetV2(dataset_ops.DatasetSource): 

670 r"""A Dataset comprising lines from one or more CSV files. 

671 

672 The `tf.data.experimental.CsvDataset` class provides a minimal CSV Dataset 

673 interface. There is also a richer `tf.data.experimental.make_csv_dataset` 

674 function which provides additional convenience features such as column header 

675 parsing, column type-inference, automatic shuffling, and file interleaving. 

676 

677 The elements of this dataset correspond to records from the file(s). 

678 RFC 4180 format is expected for CSV files 

679 (https://tools.ietf.org/html/rfc4180) 

680 Note that we allow leading and trailing spaces for int or float fields. 

681 

682 For example, suppose we have a file 'my_file0.csv' with four CSV columns of 

683 different data types: 

684 

685 >>> with open('/tmp/my_file0.csv', 'w') as f: 

686 ... f.write('abcdefg,4.28E10,5.55E6,12\n') 

687 ... f.write('hijklmn,-5.3E14,,2\n') 

688 

689 We can construct a CsvDataset from it as follows: 

690 

691 >>> dataset = tf.data.experimental.CsvDataset( 

692 ... "/tmp/my_file0.csv", 

693 ... [tf.float32, # Required field, use dtype or empty tensor 

694 ... tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0 

695 ... tf.int32, # Required field, use dtype or empty tensor 

696 ... ], 

697 ... select_cols=[1,2,3] # Only parse last three columns 

698 ... ) 

699 

700 The expected output of its iterations is: 

701 

702 >>> for element in dataset.as_numpy_iterator(): 

703 ... print(element) 

704 (4.28e10, 5.55e6, 12) 

705 (-5.3e14, 0.0, 2) 

706 

707 See 

708 https://www.tensorflow.org/tutorials/load_data/csv#tfdataexperimentalcsvdataset 

709 for more in-depth example usage. 

710 """ 

711 

712 def __init__(self, 

713 filenames, 

714 record_defaults, 

715 compression_type=None, 

716 buffer_size=None, 

717 header=False, 

718 field_delim=",", 

719 use_quote_delim=True, 

720 na_value="", 

721 select_cols=None, 

722 exclude_cols=None): 

723 """Creates a `CsvDataset` by reading and decoding CSV files. 

724 

725 Args: 

726 filenames: A `tf.string` tensor containing one or more filenames. 

727 record_defaults: A list of default values for the CSV fields. Each item in 

728 the list is either a valid CSV `DType` (float32, float64, int32, int64, 

729 string), or a `Tensor` object with one of the above types. One per 

730 column of CSV data, with either a scalar `Tensor` default value for the 

731 column if it is optional, or `DType` or empty `Tensor` if required. If 

732 both this and `select_cols` are specified, these must have the same

733 lengths, and `record_defaults` is assumed to be sorted in order of

734 increasing column index. If both this and `exclude_cols` are specified,

735 the sum of lengths of `record_defaults` and `exclude_cols` should equal

736 the total number of columns in the CSV file. 

737 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 

738 `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no 

739 compression. 

740 buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes 

741 to buffer while reading files. Defaults to 4MB. 

742 header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s) 

743 have header line(s) that should be skipped when parsing. Defaults to 

744 `False`. 

745 field_delim: (Optional.) A `tf.string` scalar containing the delimiter 

746 character that separates fields in a record. Defaults to `","`. 

747 use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats 

748 double quotation marks as regular characters inside of string fields 

749 (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`. 

750 na_value: (Optional.) A `tf.string` scalar indicating a value that will 

751 be treated as NA/NaN. 

752 select_cols: (Optional.) A sorted list of column indices to select from 

753 the input data. If specified, only this subset of columns will be 

754 parsed. Defaults to parsing all columns. At most one of `select_cols` 

755 and `exclude_cols` can be specified. 

756 exclude_cols: (Optional.) A sorted list of column indices to exclude from 

757 the input data. If specified, only the complement of this set of column 

758 will be parsed. Defaults to parsing all columns. At most one of 

759 `select_cols` and `exclude_cols` can be specified. 

760 

761 Raises: 

762 InvalidArgumentError: If exclude_cols is not None and 

763 len(exclude_cols) + len(record_defaults) does not match the total 

764 number of columns in the file(s).

765 

766 
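For illustration, a sketch of `exclude_cols` usage (equivalent to the
`select_cols=[1, 2, 3]` example in the class docstring, since that file has
four columns):

```python
dataset = tf.data.experimental.CsvDataset(
    "/tmp/my_file0.csv",
    [tf.float32, tf.constant([0.0], dtype=tf.float32), tf.int32],
    exclude_cols=[0])  # len(record_defaults) + len(exclude_cols) == 4 columns
```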

767 """ 

768 self._filenames = ops.convert_to_tensor( 

769 filenames, dtype=dtypes.string, name="filenames") 

770 self._compression_type = convert.optional_param_to_tensor( 

771 "compression_type", 

772 compression_type, 

773 argument_default="", 

774 argument_dtype=dtypes.string) 

775 record_defaults = [ 

776 constant_op.constant([], dtype=x) 

777 if not tensor_util.is_tf_type(x) and x in _ACCEPTABLE_CSV_TYPES else x 

778 for x in record_defaults 

779 ] 

780 self._record_defaults = ops.convert_n_to_tensor( 

781 record_defaults, name="record_defaults") 

782 self._buffer_size = convert.optional_param_to_tensor( 

783 "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) 

784 self._header = ops.convert_to_tensor( 

785 header, dtype=dtypes.bool, name="header") 

786 self._field_delim = ops.convert_to_tensor( 

787 field_delim, dtype=dtypes.string, name="field_delim") 

788 self._use_quote_delim = ops.convert_to_tensor( 

789 use_quote_delim, dtype=dtypes.bool, name="use_quote_delim") 

790 self._na_value = ops.convert_to_tensor( 

791 na_value, dtype=dtypes.string, name="na_value") 

792 self._select_cols = convert.optional_param_to_tensor( 

793 "select_cols", 

794 select_cols, 

795 argument_default=[], 

796 argument_dtype=dtypes.int64, 

797 ) 

798 self._exclude_cols = convert.optional_param_to_tensor( 

799 "exclude_cols", 

800 exclude_cols, 

801 argument_default=[], 

802 argument_dtype=dtypes.int64, 

803 ) 

804 self._element_spec = tuple( 

805 tensor_spec.TensorSpec([], d.dtype) for d in self._record_defaults) 

806 variant_tensor = gen_experimental_dataset_ops.csv_dataset_v2( 

807 filenames=self._filenames, 

808 record_defaults=self._record_defaults, 

809 buffer_size=self._buffer_size, 

810 header=self._header, 

811 output_shapes=self._flat_shapes, 

812 field_delim=self._field_delim, 

813 use_quote_delim=self._use_quote_delim, 

814 na_value=self._na_value, 

815 select_cols=self._select_cols, 

816 exclude_cols=self._exclude_cols, 

817 compression_type=self._compression_type) 

818 super(CsvDatasetV2, self).__init__(variant_tensor) 

819 

820 @property 

821 def element_spec(self): 

822 return self._element_spec 

823 

824 

825@tf_export(v1=["data.experimental.CsvDataset"]) 

826class CsvDatasetV1(dataset_ops.DatasetV1Adapter): 

827 """A Dataset comprising lines from one or more CSV files.""" 

828 

829 @functools.wraps(CsvDatasetV2.__init__, ("__module__", "__name__")) 

830 def __init__(self, 

831 filenames, 

832 record_defaults, 

833 compression_type=None, 

834 buffer_size=None, 

835 header=False, 

836 field_delim=",", 

837 use_quote_delim=True, 

838 na_value="", 

839 select_cols=None): 

840 """Creates a `CsvDataset` by reading and decoding CSV files. 

841 

842 The elements of this dataset correspond to records from the file(s). 

843 RFC 4180 format is expected for CSV files 

844 (https://tools.ietf.org/html/rfc4180) 

845 Note that we allow leading and trailing spaces for int or float fields.

846 

847 

848 For example, suppose we have a file 'my_file0.csv' with four CSV columns of 

849 different data types: 

850 ``` 

851 abcdefg,4.28E10,5.55E6,12 

852 hijklmn,-5.3E14,,2 

853 ``` 

854 

855 We can construct a CsvDataset from it as follows: 

856 

857 ```python 

858 dataset = tf.data.experimental.CsvDataset( 

859 "my_file*.csv", 

860 [tf.float32, # Required field, use dtype or empty tensor 

861 tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0 

862 tf.int32, # Required field, use dtype or empty tensor 

863 ], 

864 select_cols=[1,2,3] # Only parse last three columns 

865 ) 

866 ``` 

867 

868 The expected output of its iterations is: 

869 

870 ```python 

871 for element in dataset: 

872 print(element) 

873 

874 >> (4.28e10, 5.55e6, 12) 

875 >> (-5.3e14, 0.0, 2) 

876 ``` 

877 

878 Args: 

879 filenames: A `tf.string` tensor containing one or more filenames. 

880 record_defaults: A list of default values for the CSV fields. Each item in 

881 the list is either a valid CSV `DType` (float32, float64, int32, int64, 

882 string), or a `Tensor` object with one of the above types. One per 

883 column of CSV data, with either a scalar `Tensor` default value for the 

884 column if it is optional, or `DType` or empty `Tensor` if required. If 

885 both this and `select_cols` are specified, these must have the same

886 lengths, and `record_defaults` is assumed to be sorted in order of

887 increasing column index.

890 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 

891 `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no 

892 compression. 

893 buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes 

894 to buffer while reading files. Defaults to 4MB. 

895 header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s) 

896 have header line(s) that should be skipped when parsing. Defaults to 

897 `False`. 

898 field_delim: (Optional.) A `tf.string` scalar containing the delimiter 

899 character that separates fields in a record. Defaults to `","`. 

900 use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats double 

901 quotation marks as regular characters inside of string fields (ignoring 

902 RFC 4180, Section 2, Bullet 5). Defaults to `True`. 

903 na_value: (Optional.) A `tf.string` scalar indicating a value that will be 

904 treated as NA/NaN. 

905 select_cols: (Optional.) A sorted list of column indices to select from 

906 the input data. If specified, only this subset of columns will be 

907 parsed. Defaults to parsing all columns. At most one of `select_cols` 

908 and `exclude_cols` can be specified. 

909 """ 

910 wrapped = CsvDatasetV2(filenames, record_defaults, compression_type, 

911 buffer_size, header, field_delim, use_quote_delim, 

912 na_value, select_cols) 

913 super(CsvDatasetV1, self).__init__(wrapped) 

914 

915 

916@tf_export("data.experimental.make_batched_features_dataset", v1=[]) 

917def make_batched_features_dataset_v2(file_pattern, 

918 batch_size, 

919 features, 

920 reader=None, 

921 label_key=None, 

922 reader_args=None, 

923 num_epochs=None, 

924 shuffle=True, 

925 shuffle_buffer_size=10000, 

926 shuffle_seed=None, 

927 prefetch_buffer_size=None, 

928 reader_num_threads=None, 

929 parser_num_threads=None, 

930 sloppy_ordering=False, 

931 drop_final_batch=False): 

932 """Returns a `Dataset` of feature dictionaries from `Example` protos. 

933 

934 If the `label_key` argument is provided, returns a `Dataset` of tuples,

935 each comprising a feature dictionary and a label.

936 

937 Example: 

938 

939 ``` 

940 serialized_examples = [ 

941 features { 

942 feature { key: "age" value { int64_list { value: [ 0 ] } } } 

943 feature { key: "gender" value { bytes_list { value: [ "f" ] } } } 

944 feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } } 

945 }, 

946 features { 

947 feature { key: "age" value { int64_list { value: [] } } } 

948 feature { key: "gender" value { bytes_list { value: [ "f" ] } } } 

949 feature { key: "kws" value { bytes_list { value: [ "sports" ] } } } 

950 } 

951 ] 

952 ``` 

953 

954 We can use arguments: 

955 

956 ``` 

957 features: { 

958 "age": FixedLenFeature([], dtype=tf.int64, default_value=-1), 

959 "gender": FixedLenFeature([], dtype=tf.string), 

960 "kws": VarLenFeature(dtype=tf.string), 

961 } 

962 ``` 

963 

964 And the expected output is: 

965 

966 ```python 

967 { 

968 "age": [[0], [-1]], 

969 "gender": [["f"], ["f"]], 

970 "kws": SparseTensor( 

971 indices=[[0, 0], [0, 1], [1, 0]], 

972 values=["code", "art", "sports"] 

973 dense_shape=[2, 2]), 

974 } 

975 ``` 

976 
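A call producing such batches might look like the following (an illustrative
sketch; the file pattern is hypothetical):

```python
dataset = tf.data.experimental.make_batched_features_dataset(
    file_pattern="/path/to/examples-*.tfrecord",
    batch_size=2,
    features={
        "age": tf.io.FixedLenFeature([], dtype=tf.int64, default_value=-1),
        "gender": tf.io.FixedLenFeature([], dtype=tf.string),
        "kws": tf.io.VarLenFeature(dtype=tf.string),
    },
    label_key="age",
    num_epochs=1)
for feature_batch, labels in dataset.take(1):
  print(labels)  # A batch of two "age" values.
```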

977 Args: 

978 file_pattern: List of files or patterns of file paths containing 

979 `Example` records. See `tf.io.gfile.glob` for pattern rules. 

980 batch_size: An int representing the number of records to combine 

981 in a single batch. 

982 features: A `dict` mapping feature keys to `FixedLenFeature` or 

983 `VarLenFeature` values. See `tf.io.parse_example`. 

984 reader: A function or class that can be 

985 called with a `filenames` tensor and (optional) `reader_args` and returns 

986 a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`. 

987 label_key: (Optional) A string corresponding to the key under which labels

988 are stored in the `tf.Example` protos. If provided, it must be one of the

989 `features` keys; otherwise a `ValueError` is raised.

990 reader_args: Additional arguments to pass to the reader class. 

991 num_epochs: Integer specifying the number of times to read through the 

992 dataset. If None, cycles through the dataset forever. Defaults to `None`. 

993 shuffle: A boolean, indicates whether the input should be shuffled. Defaults 

994 to `True`. 

995 shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity 

996 ensures better shuffling but would increase memory usage and startup time. 

997 shuffle_seed: Randomization seed to use for shuffling. 

998 prefetch_buffer_size: Number of feature batches to prefetch in order to 

999 improve performance. Recommended value is the number of batches consumed 

1000 per training step. Defaults to auto-tune. 

1001 reader_num_threads: Number of threads used to read `Example` records. If >1, 

1002 the results will be interleaved. Defaults to `1`. 

1003 parser_num_threads: Number of threads to use for parsing `Example` tensors 

1004 into a dictionary of `Feature` tensors. Defaults to `2`. 

1005 sloppy_ordering: If `True`, reading performance will be improved at 

1006 the cost of non-deterministic ordering. If `False`, the order of elements 

1007 produced is deterministic prior to shuffling (elements are still

1008 randomized if `shuffle=True`; note that if the seed is set, the order

1009 of elements after shuffling is deterministic). Defaults to `False`.

1010 drop_final_batch: If `True`, and the batch size does not evenly divide the 

1011 input dataset size, the final smaller batch will be dropped. Defaults to 

1012 `False`. 

1013 

1014 Returns: 

1015 A dataset of `dict` elements, (or a tuple of `dict` elements and label). 

1016 Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects. 

1017 

1018 Raises: 

1019 TypeError: If `reader` is of the wrong type. 

1020 ValueError: If `label_key` is not one of the `features` keys. 

1021 """ 

1022 if reader is None: 

1023 reader = core_readers.TFRecordDataset 

1024 

1025 if reader_num_threads is None: 

1026 reader_num_threads = 1 

1027 if parser_num_threads is None: 

1028 parser_num_threads = 2 

1029 if prefetch_buffer_size is None: 

1030 prefetch_buffer_size = dataset_ops.AUTOTUNE 

1031 

1032 # Create dataset of all matching filenames 

1033 dataset = dataset_ops.Dataset.list_files( 

1034 file_pattern, shuffle=shuffle, seed=shuffle_seed) 

1035 

1036 if isinstance(reader, type) and issubclass(reader, io_ops.ReaderBase): 

1037 raise TypeError("The `reader` argument must return a `Dataset` object. " 

1038 "`tf.ReaderBase` subclasses are not supported. For " 

1039 "example, pass `tf.data.TFRecordDataset` instead of " 

1040 "`tf.TFRecordReader`.") 

1041 

1042 # Read `Example` records from files as tensor objects. 

1043 if reader_args is None: 

1044 reader_args = [] 

1045 

1046 if reader_num_threads == dataset_ops.AUTOTUNE: 

1047 dataset = dataset.interleave( 

1048 lambda filename: reader(filename, *reader_args), 

1049 num_parallel_calls=reader_num_threads) 

1050 options = options_lib.Options() 

1051 options.deterministic = not sloppy_ordering 

1052 dataset = dataset.with_options(options) 

1053 else: 

1054 # Read files sequentially (if reader_num_threads=1) or in parallel 

1055 def apply_fn(dataset): 

1056 return core_readers.ParallelInterleaveDataset( 

1057 dataset, 

1058 lambda filename: reader(filename, *reader_args), 

1059 cycle_length=reader_num_threads, 

1060 block_length=1, 

1061 sloppy=sloppy_ordering, 

1062 buffer_output_elements=None, 

1063 prefetch_input_elements=None) 

1064 

1065 dataset = dataset.apply(apply_fn) 

1066 

1067 # Extract values if the `Example` tensors are stored as key-value tuples. 

1068 if dataset_ops.get_legacy_output_types(dataset) == ( 

1069 dtypes.string, dtypes.string): 

1070 dataset = map_op._MapDataset( # pylint: disable=protected-access 

1071 dataset, lambda _, v: v, use_inter_op_parallelism=False) 

1072 

1073 # Apply dataset repeat and shuffle transformations. 

1074 dataset = _maybe_shuffle_and_repeat( 

1075 dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) 

1076 

1077 # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to 

1078 # improve the shape inference, because it makes the batch dimension static. 

1079 # It is safe to do this because in that case we are repeating the input 

1080 # indefinitely, and all batches will be full-sized. 

1081 dataset = dataset.batch( 

1082 batch_size, drop_remainder=drop_final_batch or num_epochs is None) 

1083 

1084 # Parse `Example` tensors to a dictionary of `Feature` tensors. 

1085 dataset = dataset.apply( 

1086 parsing_ops.parse_example_dataset( 

1087 features, num_parallel_calls=parser_num_threads)) 

1088 

1089 if label_key: 

1090 if label_key not in features: 

1091 raise ValueError( 

1092 f"The `label_key` provided ({label_key}) must be one of the " 

1093 f"`features` keys: {features.keys()}.") 

1094 dataset = dataset.map(lambda x: (x, x.pop(label_key))) 

1095 

1096 dataset = dataset.prefetch(prefetch_buffer_size) 

1097 return dataset 

1098 

1099 

1100@tf_export(v1=["data.experimental.make_batched_features_dataset"]) 

1101def make_batched_features_dataset_v1(file_pattern, # pylint: disable=missing-docstring 

1102 batch_size, 

1103 features, 

1104 reader=None, 

1105 label_key=None, 

1106 reader_args=None, 

1107 num_epochs=None, 

1108 shuffle=True, 

1109 shuffle_buffer_size=10000, 

1110 shuffle_seed=None, 

1111 prefetch_buffer_size=None, 

1112 reader_num_threads=None, 

1113 parser_num_threads=None, 

1114 sloppy_ordering=False, 

1115 drop_final_batch=False): 

1116 return dataset_ops.DatasetV1Adapter(make_batched_features_dataset_v2( 

1117 file_pattern, batch_size, features, reader, label_key, reader_args, 

1118 num_epochs, shuffle, shuffle_buffer_size, shuffle_seed, 

1119 prefetch_buffer_size, reader_num_threads, parser_num_threads, 

1120 sloppy_ordering, drop_final_batch)) 

1121make_batched_features_dataset_v1.__doc__ = ( 

1122 make_batched_features_dataset_v2.__doc__) 

1123 

1124 

1125def _get_file_names(file_pattern, shuffle): 

1126 """Parse list of file names from pattern, optionally shuffled. 

1127 

1128 Args: 

1129 file_pattern: File glob pattern, or list of glob patterns. 

1130 shuffle: Whether to shuffle the order of file names. 

1131 

1132 Returns: 

1133 List of file names matching `file_pattern`. 

1134 

1135 Raises: 

1136 ValueError: If `file_pattern` is empty, or pattern matches no files. 

1137 """ 

1138 if isinstance(file_pattern, list): 

1139 if not file_pattern: 

1140 raise ValueError("Argument `file_pattern` should not be empty.") 

1141 file_names = [] 

1142 for entry in file_pattern: 

1143 file_names.extend(gfile.Glob(entry)) 

1144 else: 

1145 file_names = list(gfile.Glob(file_pattern)) 

1146 

1147 if not file_names: 

1148 raise ValueError(f"No files match `file_pattern` {file_pattern}.") 

1149 

1150 # Sort files so it will be deterministic for unit tests. 

1151 if not shuffle: 

1152 file_names = sorted(file_names) 

1153 return file_names 

1154 
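# Example (illustrative sketch; the paths are hypothetical):
#
#   _get_file_names(["/data/train-*.csv", "/data/extra.csv"], shuffle=False)
#   # -> sorted list of matching paths; raises ValueError if nothing matches.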

1155 

1156@tf_export("data.experimental.SqlDataset", v1=[]) 

1157class SqlDatasetV2(dataset_ops.DatasetSource): 

1158 """A `Dataset` consisting of the results from a SQL query. 

1159 

1160 `SqlDataset` allows a user to read data from the result set of a SQL query. 

1161 For example: 

1162 

1163 ```python 

1164 dataset = tf.data.experimental.SqlDataset("sqlite", "/foo/bar.sqlite3", 

1165 "SELECT name, age FROM people", 

1166 (tf.string, tf.int32)) 

1167 # Prints the rows of the result set of the above query. 

1168 for element in dataset: 

1169 print(element) 

1170 ``` 

1171 """ 

1172 

1173 def __init__(self, driver_name, data_source_name, query, output_types): 

1174 """Creates a `SqlDataset`. 

1175 

1176 Args: 

1177 driver_name: A 0-D `tf.string` tensor containing the database type. 

1178 Currently, the only supported value is 'sqlite'. 

1179 data_source_name: A 0-D `tf.string` tensor containing a connection string 

1180 to connect to the database. 

1181 query: A 0-D `tf.string` tensor containing the SQL query to execute. 

1182 output_types: A tuple of `tf.DType` objects representing the types of the 

1183 columns returned by `query`. 

1184 """ 

1185 self._driver_name = ops.convert_to_tensor( 

1186 driver_name, dtype=dtypes.string, name="driver_name") 

1187 self._data_source_name = ops.convert_to_tensor( 

1188 data_source_name, dtype=dtypes.string, name="data_source_name") 

1189 self._query = ops.convert_to_tensor( 

1190 query, dtype=dtypes.string, name="query") 

1191 self._element_spec = nest.map_structure( 

1192 lambda dtype: tensor_spec.TensorSpec([], dtype), output_types) 

1193 variant_tensor = gen_experimental_dataset_ops.sql_dataset( 

1194 self._driver_name, self._data_source_name, self._query, 

1195 **self._flat_structure) 

1196 super(SqlDatasetV2, self).__init__(variant_tensor) 

1197 

1198 @property 

1199 def element_spec(self): 

1200 return self._element_spec 

1201 

1202 

1203@tf_export(v1=["data.experimental.SqlDataset"]) 

1204class SqlDatasetV1(dataset_ops.DatasetV1Adapter): 

1205 """A `Dataset` consisting of the results from a SQL query.""" 

1206 

1207 @functools.wraps(SqlDatasetV2.__init__) 

1208 def __init__(self, driver_name, data_source_name, query, output_types): 

1209 wrapped = SqlDatasetV2(driver_name, data_source_name, query, output_types) 

1210 super(SqlDatasetV1, self).__init__(wrapped) 

1211 

1212 

1213if tf2.enabled(): 

1214 CsvDataset = CsvDatasetV2 

1215 SqlDataset = SqlDatasetV2 

1216 make_batched_features_dataset = make_batched_features_dataset_v2 

1217 make_csv_dataset = make_csv_dataset_v2 

1218else: 

1219 CsvDataset = CsvDatasetV1 

1220 SqlDataset = SqlDatasetV1 

1221 make_batched_features_dataset = make_batched_features_dataset_v1 

1222 make_csv_dataset = make_csv_dataset_v1