Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/dataset_utils.py: 10%

266 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras image dataset loading utilities."""

import multiprocessing
import os
import random
import time
import warnings

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src.utils import io_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.utils.split_dataset", v1=[])
def split_dataset(
    dataset, left_size=None, right_size=None, shuffle=False, seed=None
):
    """Split a dataset into a left half and a right half (e.g. train / test).

    Args:
        dataset: A `tf.data.Dataset` object, or a list/tuple of arrays with the
            same length.
        left_size: If float (in the range `[0, 1]`), it signifies the fraction
            of the data to pack in the left dataset. If integer, it signifies
            the number of samples to pack in the left dataset. If `None`, it
            uses the complement to `right_size`. Defaults to `None`.
        right_size: If float (in the range `[0, 1]`), it signifies the fraction
            of the data to pack in the right dataset. If integer, it signifies
            the number of samples to pack in the right dataset. If `None`, it
            uses the complement to `left_size`. Defaults to `None`.
        shuffle: Boolean, whether to shuffle the data before splitting it.
        seed: A random seed for shuffling.

    Returns:
        A tuple of two `tf.data.Dataset` objects: the left and right splits.

    Example:

    >>> data = np.random.random(size=(1000, 4))
    >>> left_ds, right_ds = tf.keras.utils.split_dataset(data, left_size=0.8)
    >>> int(left_ds.cardinality())
    800
    >>> int(right_ds.cardinality())
    200
    """
    dataset_type_spec = _get_type_spec(dataset)

    if dataset_type_spec not in [tf.data.Dataset, list, tuple, np.ndarray]:
        raise TypeError(
            "The `dataset` argument must be either a `tf.data.Dataset` "
            "object or a list/tuple of arrays. "
            f"Received: dataset={dataset} of type {type(dataset)}"
        )

    if right_size is None and left_size is None:
        raise ValueError(
            "At least one of the `left_size` or `right_size` "
            "must be specified. Received: left_size=None and "
            "right_size=None"
        )

    dataset_as_list = _convert_dataset_to_list(dataset, dataset_type_spec)

    if shuffle:
        if seed is None:
            seed = random.randint(0, int(1e6))
        random.seed(seed)
        random.shuffle(dataset_as_list)

    total_length = len(dataset_as_list)

    left_size, right_size = _rescale_dataset_split_sizes(
        left_size, right_size, total_length
    )
    left_split = list(dataset_as_list[:left_size])
    right_split = list(dataset_as_list[-right_size:])

    left_split = _restore_dataset_from_list(
        left_split, dataset_type_spec, dataset
    )
    right_split = _restore_dataset_from_list(
        right_split, dataset_type_spec, dataset
    )

    left_split = tf.data.Dataset.from_tensor_slices(left_split)
    right_split = tf.data.Dataset.from_tensor_slices(right_split)

    # Apply batching to the splits if the original dataset was batched.
    if dataset_type_spec is tf.data.Dataset and is_batched(dataset):
        batch_size = get_batch_size(dataset)
        if batch_size is not None:
            left_split = left_split.batch(batch_size)
            right_split = right_split.batch(batch_size)

    left_split = left_split.prefetch(tf.data.AUTOTUNE)
    right_split = right_split.prefetch(tf.data.AUTOTUNE)

    return left_split, right_split
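

# Illustrative usage sketch (editor addition, not part of the original module):
# when the input is a batched `tf.data.Dataset` of (features, labels) pairs,
# `split_dataset` unbatches it, splits the underlying samples, and re-applies
# the detected batch size to both halves.
def _example_split_batched_dataset():
    features = np.random.random((100, 4)).astype("float32")
    labels = np.random.randint(0, 2, size=(100,))
    ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(10)
    left_ds, right_ds = split_dataset(ds, left_size=0.8, shuffle=True, seed=42)
    # Both splits are re-batched with the detected batch_size of 10.
    return left_ds, right_ds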


def _convert_dataset_to_list(
    dataset,
    dataset_type_spec,
    data_size_warning_flag=True,
    ensure_shape_similarity=True,
):
    """Convert `tf.data.Dataset` object or list/tuple of NumPy arrays to a list.

    Args:
        dataset: A `tf.data.Dataset` object or a list/tuple of arrays.
        dataset_type_spec: the type of the dataset.
        data_size_warning_flag (bool, optional): If set to True, a warning will
            be issued if the dataset takes longer than 10 seconds to iterate.
            Defaults to `True`.
        ensure_shape_similarity (bool, optional): If set to True, the shape of
            the first sample will be used to validate the shape of the rest of
            the samples. Defaults to `True`.

    Returns:
        List: A list of tuples/NumPy arrays.
    """
    dataset_iterator = _get_data_iterator_from_dataset(
        dataset, dataset_type_spec
    )
    dataset_as_list = []

    start_time = time.time()
    for sample in _get_next_sample(
        dataset_iterator,
        ensure_shape_similarity,
        data_size_warning_flag,
        start_time,
    ):
        if dataset_type_spec in [tuple, list]:
            # The try-except here is for NumPy 1.24 compatibility, see:
            # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
            try:
                arr = np.array(sample)
            except ValueError:
                arr = np.array(sample, dtype=object)
            dataset_as_list.append(arr)
        else:
            dataset_as_list.append(sample)

    return dataset_as_list
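

# Illustrative sketch (editor addition, not part of the original module): for a
# tuple of same-length arrays, `_convert_dataset_to_list` produces one entry
# per sample, pairing the corresponding rows of every input array.
def _example_convert_tuple_to_list():
    features = np.arange(10).reshape(5, 2)
    labels = np.arange(5)
    samples = _convert_dataset_to_list((features, labels), tuple)
    # len(samples) == 5; samples[0] combines features[0] and labels[0].
    return samples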


def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
    """Get the iterator from a dataset.

    Args:
        dataset: A `tf.data.Dataset` object or a list/tuple of arrays.
        dataset_type_spec: the type of the dataset.

    Raises:
        ValueError:
            - If the dataset is empty.
            - If the dataset is not a `tf.data.Dataset` object
              or a list/tuple of arrays.
            - If the dataset is a list/tuple of arrays and the
              arrays do not all have the same length.

    Returns:
        iterator: An `iterator` object.
    """
    if dataset_type_spec == list:
        if len(dataset) == 0:
            raise ValueError(
                "Received an empty list dataset. "
                "Please provide a non-empty list of arrays."
            )

        if _get_type_spec(dataset[0]) is np.ndarray:
            expected_shape = dataset[0].shape
            for i, element in enumerate(dataset):
                if np.array(element).shape[0] != expected_shape[0]:
                    raise ValueError(
                        "Received a list of NumPy arrays with different "
                        f"lengths. Mismatch found at index {i}, "
                        f"Expected shape={expected_shape} "
                        f"Received shape={np.array(element).shape}. "
                        "Please provide a list of NumPy arrays with "
                        "the same length."
                    )
        else:
            raise ValueError(
                "Expected a list of `numpy.ndarray` objects, "
                f"Received: {type(dataset[0])}"
            )

        return iter(zip(*dataset))
    elif dataset_type_spec == tuple:
        if len(dataset) == 0:
            raise ValueError(
                "Received an empty tuple dataset. "
                "Please provide a non-empty tuple of arrays."
            )

        if _get_type_spec(dataset[0]) is np.ndarray:
            expected_shape = dataset[0].shape
            for i, element in enumerate(dataset):
                if np.array(element).shape[0] != expected_shape[0]:
                    raise ValueError(
                        "Received a tuple of NumPy arrays with different "
                        f"lengths. Mismatch found at index {i}, "
                        f"Expected shape={expected_shape} "
                        f"Received shape={np.array(element).shape}. "
                        "Please provide a tuple of NumPy arrays with "
                        "the same length."
                    )
        else:
            raise ValueError(
                "Expected a tuple of `numpy.ndarray` objects, "
                f"Received: {type(dataset[0])}"
            )

        return iter(zip(*dataset))
    elif dataset_type_spec == tf.data.Dataset:
        if is_batched(dataset):
            dataset = dataset.unbatch()
        return iter(dataset)
    elif dataset_type_spec == np.ndarray:
        return iter(dataset)
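

# Illustrative sketch (editor addition, not part of the original module): a
# list of arrays whose first dimensions differ is rejected before iteration
# starts.
def _example_length_mismatch():
    try:
        _get_data_iterator_from_dataset([np.zeros((3, 2)), np.zeros((4, 2))], list)
    except ValueError as e:
        return str(e)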


def _get_next_sample(
    dataset_iterator,
    ensure_shape_similarity,
    data_size_warning_flag,
    start_time,
):
    """Yield data samples from the `dataset_iterator`.

    Args:
        dataset_iterator: An `iterator` object.
        ensure_shape_similarity (bool, optional): If set to True, the shape of
            the first sample will be used to validate the shape of the rest of
            the samples. Defaults to `True`.
        data_size_warning_flag (bool, optional): If set to True, a warning will
            be issued if the dataset takes longer than 10 seconds to iterate.
            Defaults to `True`.
        start_time (float): the start time of the dataset iteration. This is
            used only if `data_size_warning_flag` is set to True.

    Raises:
        ValueError:
            - If the dataset is empty.
            - If `ensure_shape_similarity` is set to True and the
              shape of the first sample is not equal to the shape of
              at least one of the rest of the samples.

    Yields:
        data_sample: A tuple/list of NumPy arrays.
    """
    try:
        dataset_iterator = iter(dataset_iterator)
        first_sample = next(dataset_iterator)
        if isinstance(first_sample, (tf.Tensor, np.ndarray)):
            first_sample_shape = np.array(first_sample).shape
        else:
            first_sample_shape = None
            ensure_shape_similarity = False
        yield first_sample
    except StopIteration:
        raise ValueError(
            "Received an empty Dataset. `dataset` must "
            "be a non-empty list/tuple of `numpy.ndarray` objects "
            "or `tf.data.Dataset` objects."
        )

    for i, sample in enumerate(dataset_iterator):
        if ensure_shape_similarity:
            if first_sample_shape != np.array(sample).shape:
                raise ValueError(
                    "All `dataset` samples must have the same shape. "
                    f"Expected shape: {np.array(first_sample).shape} "
                    f"Received shape: {np.array(sample).shape} at index "
                    f"{i}."
                )
        if data_size_warning_flag:
            if i % 10 == 0:
                cur_time = time.time()
                # Warn the user if the dataset takes longer than 10 seconds
                # to iterate over.
                if int(cur_time - start_time) > 10 and data_size_warning_flag:
                    warnings.warn(
                        "The dataset is taking longer than 10 seconds to "
                        "iterate over. This may be due to the size of the "
                        "dataset. Keep in mind that the `split_dataset` "
                        "utility is only for small in-memory datasets "
                        "(e.g. < 10,000 samples).",
                        category=ResourceWarning,
                        source="split_dataset",
                    )
                    data_size_warning_flag = False
        yield sample
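

# Illustrative sketch (editor addition, not part of the original module): with
# `ensure_shape_similarity=True`, a sample whose shape differs from the first
# sample's shape aborts iteration with a ValueError.
def _example_shape_similarity_check():
    samples = [np.zeros((3,)), np.zeros((4,))]
    try:
        list(_get_next_sample(iter(samples), True, False, time.time()))
    except ValueError as e:
        return str(e)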


def _restore_dataset_from_list(
    dataset_as_list, dataset_type_spec, original_dataset
):
    """Restore the dataset from the list of arrays."""
    if dataset_type_spec in [tuple, list]:
        return tuple(np.array(sample) for sample in zip(*dataset_as_list))
    elif dataset_type_spec == tf.data.Dataset:
        if isinstance(original_dataset.element_spec, dict):
            restored_dataset = {}
            for d in dataset_as_list:
                for k, v in d.items():
                    if k not in restored_dataset:
                        restored_dataset[k] = [v]
                    else:
                        restored_dataset[k].append(v)
            return restored_dataset
        else:
            return tuple(np.array(sample) for sample in zip(*dataset_as_list))
    return dataset_as_list
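

# Illustrative sketch (editor addition, not part of the original module): for
# list/tuple datasets, the per-sample entries are re-stacked column-wise,
# recovering one array per original input array.
def _example_restore_from_list():
    samples = [np.array([0, 10]), np.array([1, 11]), np.array([2, 12])]
    features, labels = _restore_dataset_from_list(samples, tuple, None)
    # features == array([0, 1, 2]); labels == array([10, 11, 12])
    return features, labels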


def _rescale_dataset_split_sizes(left_size, right_size, total_length):
    """Rescale the dataset split sizes.

    We want to ensure that the sum of
    the split sizes is equal to the total length of the dataset.

    Args:
        left_size: The size of the left dataset split.
        right_size: The size of the right dataset split.
        total_length: The total length of the dataset.

    Raises:
        TypeError: - If `left_size` or `right_size` is not an integer or float.
        ValueError: - If `left_size` or `right_size` is negative or greater
                      than 1 or greater than `total_length`.

    Returns:
        tuple: A tuple of rescaled `left_size` and `right_size`.
    """
    left_size_type = type(left_size)
    right_size_type = type(right_size)

    # Check that both left_size and right_size are integers or floats.
    if (left_size is not None and left_size_type not in [int, float]) and (
        right_size is not None and right_size_type not in [int, float]
    ):
        raise TypeError(
            "Invalid `left_size` and `right_size` Types. Expected: "
            "integer or float or None, Received: type(left_size)="
            f"{left_size_type} and type(right_size)={right_size_type}"
        )

    # Check that left_size is an integer or float.
    if left_size is not None and left_size_type not in [int, float]:
        raise TypeError(
            "Invalid `left_size` Type. Expected: int or float or None, "
            f"Received: type(left_size)={left_size_type}."
        )

    # Check that right_size is an integer or float.
    if right_size is not None and right_size_type not in [int, float]:
        raise TypeError(
            "Invalid `right_size` Type. "
            "Expected: int or float or None, "
            f"Received: type(right_size)={right_size_type}."
        )

    # Check that left_size and right_size are not both zero.
    if left_size == 0 and right_size == 0:
        raise ValueError(
            "Both `left_size` and `right_size` are zero. "
            "At least one of the split sizes must be non-zero."
        )

    # Check that left_size is positive and, if an integer, smaller than
    # total_length or, if a float, smaller than 1.
    if (
        left_size_type == int
        and (left_size <= 0 or left_size >= total_length)
        or left_size_type == float
        and (left_size <= 0 or left_size >= 1)
    ):
        raise ValueError(
            "`left_size` should be either a positive integer "
            f"smaller than {total_length}, or a float "
            "within the range `[0, 1]`. Received: left_size="
            f"{left_size}"
        )

    # Check that right_size is positive and, if an integer, smaller than
    # total_length or, if a float, smaller than 1.
    if (
        right_size_type == int
        and (right_size <= 0 or right_size >= total_length)
        or right_size_type == float
        and (right_size <= 0 or right_size >= 1)
    ):
        raise ValueError(
            "`right_size` should be either a positive integer "
            f"smaller than {total_length}, or a float "
            "within the range `[0, 1]`. Received: right_size="
            f"{right_size}"
        )

    # Check that the sum of left_size and right_size does not exceed
    # total_length.
    if (
        right_size_type == left_size_type == float
        and right_size + left_size > 1
    ):
        raise ValueError(
            "The sum of `left_size` and `right_size` is greater "
            "than 1. It must be less than or equal to 1."
        )

    if left_size_type == float:
        left_size = round(left_size * total_length)
    elif left_size_type == int:
        left_size = float(left_size)

    if right_size_type == float:
        right_size = round(right_size * total_length)
    elif right_size_type == int:
        right_size = float(right_size)

    if left_size is None:
        left_size = total_length - right_size
    elif right_size is None:
        right_size = total_length - left_size

    if left_size + right_size > total_length:
        raise ValueError(
            "The sum of `left_size` and `right_size` should "
            f"be smaller than the total length ({total_length}). "
            f"Received: left_size + right_size = {left_size + right_size} "
            f"and total_length = {total_length}"
        )

    for split, side in [(left_size, "left"), (right_size, "right")]:
        if split == 0:
            raise ValueError(
                f"With `dataset` of length={total_length}, `left_size`="
                f"{left_size} and `right_size`={right_size}, the "
                f"resulting {side} side dataset split will be empty. "
                "Adjust any of the aforementioned parameters."
            )

    left_size, right_size = int(left_size), int(right_size)
    return left_size, right_size
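

# Illustrative worked example (editor addition, not part of the original
# module): with 1000 samples, `left_size=0.8` and `right_size=None` rescale to
# the absolute sizes (800, 200).
def _example_rescale_split_sizes():
    left, right = _rescale_dataset_split_sizes(0.8, None, 1000)
    assert (left, right) == (800, 200)
    return left, right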


def _get_type_spec(dataset):
    """Get the type spec of the dataset."""
    if isinstance(dataset, tuple):
        return tuple
    elif isinstance(dataset, list):
        return list
    elif isinstance(dataset, np.ndarray):
        return np.ndarray
    elif isinstance(dataset, dict):
        return dict
    elif isinstance(dataset, tf.data.Dataset):
        return tf.data.Dataset
    else:
        return None


def is_batched(tf_dataset):
    """Check if the `tf.data.Dataset` is batched."""
    return hasattr(tf_dataset, "_batch_size")


def get_batch_size(tf_dataset):
    """Get the batch size of the dataset."""
    if is_batched(tf_dataset):
        return tf_dataset._batch_size
    else:
        return None
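

# Illustrative sketch (editor addition, not part of the original module):
# `is_batched` relies on the private `_batch_size` attribute that
# `Dataset.batch()` sets, so it only detects a dataset whose outermost
# transformation is a batch.
def _example_batch_detection():
    ds = tf.data.Dataset.range(100).batch(4)
    assert is_batched(ds)
    assert int(get_batch_size(ds)) == 4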


def index_directory(
    directory,
    labels,
    formats,
    class_names=None,
    shuffle=True,
    seed=None,
    follow_links=False,
):
    """Make list of all files in `directory`, with their labels.

    Args:
        directory: Directory where the data is located.
            If `labels` is "inferred", it should contain
            subdirectories, each containing files for a class.
            Otherwise, the directory structure is ignored.
        labels: Either "inferred"
            (labels are generated from the directory structure),
            None (no labels),
            or a list/tuple of integer labels of the same size as the number
            of valid files found in the directory. Labels should be sorted
            according to the alphanumeric order of the image file paths
            (obtained via `os.walk(directory)` in Python).
        formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt").
        class_names: Only valid if `labels` is "inferred". This is the explicit
            list of class names (must match names of subdirectories). Used
            to control the order of the classes
            (otherwise alphanumerical order is used).
        shuffle: Whether to shuffle the data. Default: True.
            If set to False, sorts the data in alphanumeric order.
        seed: Optional random seed for shuffling.
        follow_links: Whether to visit subdirectories pointed to by symlinks.

    Returns:
        tuple (file_paths, labels, class_names).
            file_paths: list of file paths (strings).
            labels: list of matching integer labels (same length as
                file_paths).
            class_names: names of the classes corresponding to these labels,
                in order.
    """
    if labels != "inferred":
        # In the explicit/no-label cases, index from the parent directory
        # down.
        subdirs = [""]
        class_names = subdirs
    else:
        subdirs = []
        for subdir in sorted(tf.io.gfile.listdir(directory)):
            if tf.io.gfile.isdir(tf.io.gfile.join(directory, subdir)):
                if subdir.endswith("/"):
                    subdir = subdir[:-1]
                subdirs.append(subdir)
        if not class_names:
            class_names = subdirs
        else:
            if set(class_names) != set(subdirs):
                raise ValueError(
                    "The `class_names` passed did not match the "
                    "names of the subdirectories of the target directory. "
                    f"Expected: {subdirs}, but received: {class_names}"
                )
    class_indices = dict(zip(class_names, range(len(class_names))))

    # Build an index of the files
    # in the different class subfolders.
    pool = multiprocessing.pool.ThreadPool()
    results = []
    filenames = []

    for dirpath in (tf.io.gfile.join(directory, subdir) for subdir in subdirs):
        results.append(
            pool.apply_async(
                index_subdirectory,
                (dirpath, class_indices, follow_links, formats),
            )
        )
    labels_list = []
    for res in results:
        partial_filenames, partial_labels = res.get()
        labels_list.append(partial_labels)
        filenames += partial_filenames
    if labels not in ("inferred", None):
        if len(labels) != len(filenames):
            raise ValueError(
                "Expected the length of `labels` to match the number "
                "of files in the target directory. len(labels) is "
                f"{len(labels)} while we found {len(filenames)} files "
                f"in directory {directory}."
            )
        class_names = sorted(set(labels))
    else:
        i = 0
        labels = np.zeros((len(filenames),), dtype="int32")
        for partial_labels in labels_list:
            labels[i : i + len(partial_labels)] = partial_labels
            i += len(partial_labels)

    if labels is None:
        io_utils.print_msg(f"Found {len(filenames)} files.")
    else:
        io_utils.print_msg(
            f"Found {len(filenames)} files belonging "
            f"to {len(class_names)} classes."
        )
    pool.close()
    pool.join()
    file_paths = [tf.io.gfile.join(directory, fname) for fname in filenames]

    if shuffle:
        # Shuffle globally to erase macro-structure.
        if seed is None:
            seed = np.random.randint(1e6)
        rng = np.random.RandomState(seed)
        rng.shuffle(file_paths)
        rng = np.random.RandomState(seed)
        rng.shuffle(labels)
    return file_paths, labels, class_names
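

# Illustrative usage sketch (editor addition, not part of the original module):
# with a hypothetical directory laid out as flowers/daisy/*.jpg and
# flowers/rose/*.jpg, inferred indexing returns every file path, one integer
# label per file, and the class names ["daisy", "rose"] (alphanumeric order).
def _example_index_directory():
    file_paths, labels, class_names = index_directory(
        "flowers",
        labels="inferred",
        formats=(".jpg", ".jpeg", ".png"),
        shuffle=True,
        seed=1337,
    )
    return file_paths, labels, class_names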


def iter_valid_files(directory, follow_links, formats):
    if not follow_links:
        walk = tf.io.gfile.walk(directory)
    else:
        walk = os.walk(directory, followlinks=follow_links)
    for root, _, files in sorted(walk, key=lambda x: x[0]):
        for fname in sorted(files):
            if fname.lower().endswith(formats):
                yield root, fname


def index_subdirectory(directory, class_indices, follow_links, formats):
    """Recursively walk a directory and list image paths and their class index.

    Args:
        directory: string, target directory.
        class_indices: dict mapping class names to their index.
        follow_links: boolean, whether to recursively follow subdirectories
            (if False, we only list top-level images in `directory`).
        formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt").

    Returns:
        tuple `(filenames, labels)`. `filenames` is a list of relative file
        paths, and `labels` is a list of integer labels corresponding to these
        files.
    """
    dirname = os.path.basename(directory)
    valid_files = iter_valid_files(directory, follow_links, formats)
    labels = []
    filenames = []
    for root, fname in valid_files:
        labels.append(class_indices[dirname])
        absolute_path = tf.io.gfile.join(root, fname)
        relative_path = tf.io.gfile.join(
            dirname, os.path.relpath(absolute_path, directory)
        )
        filenames.append(relative_path)
    return filenames, labels


def get_training_or_validation_split(samples, labels, validation_split, subset):
    """Potentially restrict samples & labels to a training or validation split.

    Args:
        samples: List of elements.
        labels: List of corresponding labels.
        validation_split: Float, fraction of data to reserve for validation.
        subset: Subset of the data to return.
            Either "training", "validation", or None. If None, we return all
            of the data.

    Returns:
        tuple (samples, labels), potentially restricted to the specified
        subset.
    """
    if not validation_split:
        return samples, labels

    num_val_samples = int(validation_split * len(samples))
    if subset == "training":
        print(f"Using {len(samples) - num_val_samples} files for training.")
        samples = samples[:-num_val_samples]
        labels = labels[:-num_val_samples]
    elif subset == "validation":
        print(f"Using {num_val_samples} files for validation.")
        samples = samples[-num_val_samples:]
        labels = labels[-num_val_samples:]
    else:
        raise ValueError(
            '`subset` must be either "training" '
            f'or "validation", received: {subset}'
        )
    return samples, labels
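

# Illustrative worked example (editor addition, not part of the original
# module): with 10 samples and `validation_split=0.2`, "training" keeps the
# first 8 entries and "validation" the last 2, so the subsets never overlap.
def _example_training_validation_split():
    samples = [f"file_{i}.jpg" for i in range(10)]
    labels = list(range(10))
    train, _ = get_training_or_validation_split(samples, labels, 0.2, "training")
    val, _ = get_training_or_validation_split(samples, labels, 0.2, "validation")
    assert train == samples[:8] and val == samples[8:]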


def labels_to_dataset(labels, label_mode, num_classes):
    """Create a `tf.data.Dataset` from the list/tuple of labels.

    Args:
        labels: list/tuple of labels to be converted into a `tf.data.Dataset`.
        label_mode: String describing the encoding of `labels`. Options are:
            - 'binary' indicates that the labels (there can be only 2) are
              encoded as `float32` scalars with values 0 or 1
              (e.g. for `binary_crossentropy`).
            - 'categorical' means that the labels are mapped into a categorical
              vector. (e.g. for `categorical_crossentropy` loss).
        num_classes: number of classes of labels.

    Returns:
        A `Dataset` instance.
    """
    label_ds = tf.data.Dataset.from_tensor_slices(labels)
    if label_mode == "binary":
        label_ds = label_ds.map(
            lambda x: tf.expand_dims(tf.cast(x, "float32"), axis=-1),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
    elif label_mode == "categorical":
        label_ds = label_ds.map(
            lambda x: tf.one_hot(x, num_classes),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
    return label_ds
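

# Illustrative sketch (editor addition, not part of the original module):
# "categorical" one-hot encodes each integer label, while "binary" yields
# float32 values of shape (1,).
def _example_labels_to_dataset():
    ds = labels_to_dataset([0, 2, 1], label_mode="categorical", num_classes=3)
    first = next(iter(ds))
    # first is the one-hot vector [1., 0., 0.]
    return first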


def check_validation_split_arg(validation_split, subset, shuffle, seed):
    """Raise errors in case of invalid argument values.

    Args:
        validation_split: float between 0 and 1, fraction of data to reserve
            for validation.
        subset: One of "training", "validation" or "both". Only used if
            `validation_split` is set.
        shuffle: Whether to shuffle the data. Either True or False.
        seed: random seed for shuffling and transformations.
    """
    if validation_split and not 0 < validation_split < 1:
        raise ValueError(
            "`validation_split` must be between 0 and 1, "
            f"received: {validation_split}"
        )
    if (validation_split or subset) and not (validation_split and subset):
        raise ValueError(
            "If `subset` is set, `validation_split` must be set, and inversely."
        )
    if subset not in ("training", "validation", "both", None):
        raise ValueError(
            '`subset` must be either "training", '
            f'"validation" or "both", received: {subset}'
        )
    if validation_split and shuffle and seed is None:
        raise ValueError(
            "If using `validation_split` and shuffling the data, you must "
            "provide a `seed` argument, to make sure that there is no "
            "overlap between the training and validation subset."
        )
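

# Illustrative sketch (editor addition, not part of the original module):
# shuffling with a validation split but no seed is rejected, since an unseeded
# shuffle could mix the training and validation subsets between calls.
def _example_check_validation_split_arg():
    try:
        check_validation_split_arg(0.2, "training", shuffle=True, seed=None)
    except ValueError as e:
        return str(e)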