Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/data_utils.py: 22% of 449 statements
(coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for file download and caching."""

import functools
import hashlib
import multiprocessing.dummy
import os
import pathlib
import queue
import random
import shutil
import tarfile
import threading
import time
import typing
import urllib
import warnings
import weakref
import zipfile
from abc import abstractmethod
from contextlib import closing

import numpy as np
import tensorflow.compat.v2 as tf
from six.moves.urllib.parse import urlsplit

from keras.src.utils import io_utils
from keras.src.utils import tf_inspect
from keras.src.utils.generic_utils import Progbar

# isort: off
from tensorflow.python.util.tf_export import keras_export
from six.moves.urllib.request import urlopen

# Required to support google internal urlretrieve
if True:  # This gets transformed to `if sys.version_info[0] == 2:` in OSS.

    def urlretrieve(url, filename, reporthook=None, data=None):
        """Replacement for `urlretrieve` for Python 2.

        Under Python 2, `urlretrieve` relies on `FancyURLopener` from legacy
        `urllib` module, known to have issues with proxy management.

        Args:
            url: url to retrieve.
            filename: where to store the retrieved data locally.
            reporthook: a hook function that will be called once on
                establishment of the network connection and once after each
                block read thereafter. The hook will be passed three
                arguments; a count of blocks transferred so far, a block
                size in bytes, and the total size of the file.
            data: `data` argument passed to `urlopen`.
        """

        def chunk_read(response, chunk_size=8192, reporthook=None):
            content_type = response.info().get("Content-Length")
            total_size = -1
            if content_type is not None:
                total_size = int(content_type.strip())
            count = 0
            while True:
                chunk = response.read(chunk_size)
                count += 1
                if reporthook is not None:
                    reporthook(count, chunk_size, total_size)
                if chunk:
                    yield chunk
                else:
                    break

        response = urlopen(url, data)
        with open(filename, "wb") as fd:
            for chunk in chunk_read(response, reporthook=reporthook):
                fd.write(chunk)

else:
    from urllib.request import urlretrieve
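# --- Illustrative sketch (not part of the original module) ---
# The `reporthook` contract described above: the hook receives a running
# block count, the block size in bytes, and the total size (-1 when the
# server does not report Content-Length). The URL, destination path, and
# helper names below are hypothetical placeholders for illustration only.
def _example_reporthook(count, block_size, total_size):
    if total_size > 0:
        done = min(count * block_size, total_size)
        print(f"downloaded {done}/{total_size} bytes")
    else:
        print(f"downloaded ~{count * block_size} bytes (total size unknown)")


def _example_urlretrieve_usage():  # pragma: no cover
    urlretrieve(
        "https://example.com/data.bin",  # placeholder URL
        "/tmp/data.bin",
        reporthook=_example_reporthook,
    )
# --- end of sketch ---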

def is_generator_or_sequence(x):
    """Check if `x` is a Keras generator type."""
    builtin_iterators = (str, list, tuple, dict, set, frozenset)
    if isinstance(x, (tf.Tensor, np.ndarray) + builtin_iterators):
        return False
    return (
        tf_inspect.isgenerator(x)
        or isinstance(x, Sequence)
        or isinstance(x, typing.Iterator)
    )


def _resolve_path(path):
    return os.path.realpath(os.path.abspath(path))


def _is_path_in_dir(path, base_dir):
    return _resolve_path(os.path.join(base_dir, path)).startswith(base_dir)


def _is_link_in_dir(info, base):
    tip = _resolve_path(os.path.join(base, os.path.dirname(info.name)))
    return _is_path_in_dir(info.linkname, base_dir=tip)


def _filter_safe_paths(members):
    base_dir = _resolve_path(".")
    for finfo in members:
        valid_path = False
        if _is_path_in_dir(finfo.name, base_dir):
            valid_path = True
            yield finfo
        elif finfo.issym() or finfo.islnk():
            if _is_link_in_dir(finfo, base_dir):
                valid_path = True
                yield finfo
        if not valid_path:
            warnings.warn(
                "Skipping invalid path during archive extraction: "
                f"'{finfo.name}'."
            )


def _extract_archive(file_path, path=".", archive_format="auto"):
    """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.

    Args:
        file_path: Path to the archive file.
        path: Where to extract the archive file.
        archive_format: Archive format to try for extracting the file.
            Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
            `'tar'` includes tar, tar.gz, and tar.bz files.
            The default 'auto' is `['tar', 'zip']`.
            `None` or an empty list will return no matches found.

    Returns:
        True if a match was found and an archive extraction was completed,
        False otherwise.
    """
    if archive_format is None:
        return False
    if archive_format == "auto":
        archive_format = ["tar", "zip"]
    if isinstance(archive_format, str):
        archive_format = [archive_format]

    file_path = io_utils.path_to_string(file_path)
    path = io_utils.path_to_string(path)

    for archive_type in archive_format:
        if archive_type == "tar":
            open_fn = tarfile.open
            is_match_fn = tarfile.is_tarfile
        if archive_type == "zip":
            open_fn = zipfile.ZipFile
            is_match_fn = zipfile.is_zipfile

        if is_match_fn(file_path):
            with open_fn(file_path) as archive:
                try:
                    if zipfile.is_zipfile(file_path):
                        # Zip archive.
                        archive.extractall(path)
                    else:
                        # Tar archive, perhaps unsafe. Filter paths.
                        archive.extractall(
                            path, members=_filter_safe_paths(archive)
                        )
                except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
                    if os.path.exists(path):
                        if os.path.isfile(path):
                            os.remove(path)
                        else:
                            shutil.rmtree(path)
                    raise
            return True
    return False
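# --- Illustrative sketch (not part of the original module) ---
# How the helper above is typically used: point it at an archive on disk and
# a destination directory. The paths below are hypothetical placeholders.
def _example_extract_archive_usage():  # pragma: no cover
    # Tries 'tar' then 'zip' because archive_format defaults to "auto".
    extracted = _extract_archive("/tmp/flower_photos.tgz", path="/tmp/data")
    if not extracted:
        print("no supported archive format matched")
# --- end of sketch ---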

@keras_export("keras.utils.get_file")
def get_file(
    fname=None,
    origin=None,
    untar=False,
    md5_hash=None,
    file_hash=None,
    cache_subdir="datasets",
    hash_algorithm="auto",
    extract=False,
    archive_format="auto",
    cache_dir=None,
):
    """Downloads a file from a URL if it is not already in the cache.

    By default the file at the url `origin` is downloaded to the
    cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
    and given the filename `fname`. The final location of a file
    `example.txt` would therefore be `~/.keras/datasets/example.txt`.

    Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
    Passing a hash will verify the file after download. The command line
    programs `shasum` and `sha256sum` can compute the hash.

    Example:

    ```python
    path_to_downloaded_file = tf.keras.utils.get_file(
        origin="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
        extract=True,
    )
    ```

    Args:
        fname: Name of the file. If an absolute path `/path/to/file.txt` is
            specified the file will be saved at that location. If `None`, the
            name of the file at `origin` will be used.
        origin: Original URL of the file.
        untar: Deprecated in favor of `extract` argument.
            Boolean, whether the file should be decompressed.
        md5_hash: Deprecated in favor of `file_hash` argument.
            md5 hash of the file for verification.
        file_hash: The expected hash string of the file after download.
            The sha256 and md5 hash algorithms are both supported.
        cache_subdir: Subdirectory under the Keras cache dir where the file is
            saved. If an absolute path `/path/to/folder` is
            specified the file will be saved at that location.
        hash_algorithm: Select the hash algorithm to verify the file.
            Options are `'md5'`, `'sha256'`, and `'auto'`.
            The default 'auto' detects the hash algorithm in use.
        extract: True tries extracting the file as an Archive, like tar or zip.
        archive_format: Archive format to try for extracting the file.
            Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
            `'tar'` includes tar, tar.gz, and tar.bz files.
            The default `'auto'` corresponds to `['tar', 'zip']`.
            None or an empty list will return no matches found.
        cache_dir: Location to store cached files, when None it
            defaults to the default directory `~/.keras/`.

    Returns:
        Path to the downloaded file.

    ⚠️ **Warning on malicious downloads** ⚠️

    Downloading something from the Internet carries a risk.
    NEVER download a file/archive if you do not trust the source.
    We recommend that you specify the `file_hash` argument
    (if the hash of the source file is known) to make sure that the file you
    are getting is the one you expect.
    """
    if origin is None:
        raise ValueError(
            'Please specify the "origin" argument (URL of the file '
            "to download)."
        )

    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser("~"), ".keras")
    if md5_hash is not None and file_hash is None:
        file_hash = md5_hash
        hash_algorithm = "md5"
    datadir_base = os.path.expanduser(cache_dir)
    if not os.access(datadir_base, os.W_OK):
        datadir_base = os.path.join("/tmp", ".keras")
    datadir = os.path.join(datadir_base, cache_subdir)
    _makedirs_exist_ok(datadir)

    fname = io_utils.path_to_string(fname)
    if not fname:
        fname = os.path.basename(urlsplit(origin).path)
        if not fname:
            raise ValueError(
                "Can't parse the file name from the origin provided: "
                f"'{origin}'. "
                "Please specify the `fname` as the input param."
            )

    if untar:
        if fname.endswith(".tar.gz"):
            fname = pathlib.Path(fname)
            # The 2 `.with_suffix()` are because of `.tar.gz` as pathlib
            # considers it as 2 suffixes.
            fname = fname.with_suffix("").with_suffix("")
            fname = str(fname)
        untar_fpath = os.path.join(datadir, fname)
        fpath = untar_fpath + ".tar.gz"
    else:
        fpath = os.path.join(datadir, fname)

    download = False
    if os.path.exists(fpath):
        # File found; verify integrity if a hash was provided.
        if file_hash is not None:
            if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
                io_utils.print_msg(
                    "A local file was found, but it seems to be "
                    f"incomplete or outdated because the {hash_algorithm} "
                    "file hash does not match the original value of "
                    f"{file_hash} "
                    "so we will re-download the data."
                )
                download = True
    else:
        download = True

    if download:
        io_utils.print_msg(f"Downloading data from {origin}")

        class DLProgbar:
            """Manage progress bar state for use in urlretrieve."""

            def __init__(self):
                self.progbar = None
                self.finished = False

            def __call__(self, block_num, block_size, total_size):
                if not self.progbar:
                    if total_size == -1:
                        total_size = None
                    self.progbar = Progbar(total_size)
                current = block_num * block_size

                if total_size is None:
                    self.progbar.update(current)
                else:
                    if current < total_size:
                        self.progbar.update(current)
                    elif not self.finished:
                        self.progbar.update(self.progbar.target)
                        self.finished = True

        error_msg = "URL fetch failure on {}: {} -- {}"
        try:
            try:
                urlretrieve(origin, fpath, DLProgbar())
            except urllib.error.HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg))
            except urllib.error.URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason))
        except (Exception, KeyboardInterrupt):
            if os.path.exists(fpath):
                os.remove(fpath)
            raise

        # Validate download if succeeded and user provided an expected hash.
        # Security conscious users would get the hash of the file from a
        # separate channel and pass it to this API to prevent MITM / corruption:
        if os.path.exists(fpath) and file_hash is not None:
            if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
                raise ValueError(
                    "Incomplete or corrupted file detected. "
                    f"The {hash_algorithm} "
                    "file hash does not match the provided value "
                    f"of {file_hash}."
                )

    if untar:
        if not os.path.exists(untar_fpath):
            _extract_archive(fpath, datadir, archive_format="tar")
        return untar_fpath

    if extract:
        _extract_archive(fpath, datadir, archive_format)

    return fpath
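# --- Illustrative sketch (not part of the original module) ---
# A typical `get_file` call with integrity checking, as the docstring above
# recommends. The origin URL and the sha256 value are placeholders, not real
# data; in practice the hash should come from a trusted, separate channel.
def _example_get_file_usage():  # pragma: no cover
    path = get_file(
        fname="flower_photos.tgz",
        origin="https://example.com/flower_photos.tgz",  # placeholder URL
        file_hash="0" * 64,  # placeholder sha256; 64 hex chars selects sha256
        extract=True,
    )
    return path
# --- end of sketch ---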

def _makedirs_exist_ok(datadir):
    os.makedirs(datadir, exist_ok=True)


def _resolve_hasher(algorithm, file_hash=None):
    """Returns hash algorithm as hashlib function."""
    if algorithm == "sha256":
        return hashlib.sha256()

    if algorithm == "auto" and file_hash is not None and len(file_hash) == 64:
        return hashlib.sha256()

    # This is used only for legacy purposes.
    return hashlib.md5()


def _hash_file(fpath, algorithm="sha256", chunk_size=65535):
    """Calculates a file sha256 or md5 hash.

    Example:

    ```python
    _hash_file('/path/to/file.zip')
    'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
    ```

    Args:
        fpath: Path to the file being validated.
        algorithm: Hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`.
            The default `'auto'` detects the hash algorithm in use.
        chunk_size: Bytes to read at a time, important for large files.

    Returns:
        The file hash.
    """
    if isinstance(algorithm, str):
        hasher = _resolve_hasher(algorithm)
    else:
        hasher = algorithm

    with open(fpath, "rb") as fpath_file:
        for chunk in iter(lambda: fpath_file.read(chunk_size), b""):
            hasher.update(chunk)

    return hasher.hexdigest()


def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535):
    """Validates a file against a sha256 or md5 hash.

    Args:
        fpath: path to the file being validated
        file_hash: The expected hash string of the file.
            The sha256 and md5 hash algorithms are both supported.
        algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
            The default 'auto' detects the hash algorithm in use.
        chunk_size: Bytes to read at a time, important for large files.

    Returns:
        Whether the file is valid.
    """
    hasher = _resolve_hasher(algorithm, file_hash)

    if str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash):
        return True
    else:
        return False
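# --- Illustrative sketch (not part of the original module) ---
# Pairing `_hash_file` and `validate_file`: hash a local file once, then use
# that digest to verify the same file later. The path is a placeholder.
def _example_validate_file_usage():  # pragma: no cover
    digest = _hash_file("/tmp/data.bin", algorithm="sha256")
    # With algorithm="auto", a 64-character digest is treated as sha256.
    assert validate_file("/tmp/data.bin", digest, algorithm="auto")
# --- end of sketch ---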

class ThreadsafeIter:
    """Wrap an iterator with a lock and propagate exceptions to all threads."""

    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

        # After a generator throws an exception all subsequent next() calls
        # raise a StopIteration Exception. This, however, presents an issue
        # when mixing generators and threading because it means the order of
        # retrieval need not match the order in which the generator was
        # called. This can make it appear that a generator exited normally
        # when in fact the terminating exception is just in a different
        # thread. In order to provide thread safety, once self.it has thrown
        # an exception we continue to throw the same exception.
        self._exception = None

    def __iter__(self):
        return self

    def next(self):
        return self.__next__()

    def __next__(self):
        with self.lock:
            if self._exception:
                raise self._exception

            try:
                return next(self.it)
            except Exception as e:
                self._exception = e
                raise


def threadsafe_generator(f):
    @functools.wraps(f)
    def g(*a, **kw):
        return ThreadsafeIter(f(*a, **kw))

    return g
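# --- Illustrative sketch (not part of the original module) ---
# Decorating a generator with `threadsafe_generator` so several threads can
# safely pull items from one shared stream. The names below are hypothetical.
@threadsafe_generator
def _example_counter(n):
    for i in range(n):
        yield i


def _example_threadsafe_usage():  # pragma: no cover
    stream = _example_counter(100)  # a ThreadsafeIter wrapping the generator
    threads = [threading.Thread(target=lambda: list(stream)) for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
# --- end of sketch ---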

@keras_export("keras.utils.Sequence")
class Sequence:
    """Base object for fitting to a sequence of data, such as a dataset.

    Every `Sequence` must implement the `__getitem__` and the `__len__`
    methods. If you want to modify your dataset between epochs, you may
    implement `on_epoch_end`. The method `__getitem__` should return a
    complete batch.

    Notes:

    `Sequence` is a safer way to do multiprocessing. This structure guarantees
    that the network will only train once on each sample per epoch, which is
    not the case with generators.

    Examples:

    ```python
    from skimage.io import imread
    from skimage.transform import resize
    import numpy as np
    import math

    # Here, `x_set` is a list of paths to the images
    # and `y_set` are the associated classes.

    class CIFAR10Sequence(tf.keras.utils.Sequence):

        def __init__(self, x_set, y_set, batch_size):
            self.x, self.y = x_set, y_set
            self.batch_size = batch_size

        def __len__(self):
            return math.ceil(len(self.x) / self.batch_size)

        def __getitem__(self, idx):
            low = idx * self.batch_size
            # Cap upper bound at array length; the last batch may be smaller
            # if the total number of items is not a multiple of batch size.
            high = min(low + self.batch_size, len(self.x))
            batch_x = self.x[low:high]
            batch_y = self.y[low:high]

            return np.array([
                resize(imread(file_name), (200, 200))
                for file_name in batch_x]), np.array(batch_y)
    ```
    """

    @abstractmethod
    def __getitem__(self, index):
        """Gets batch at position `index`.

        Args:
            index: position of the batch in the Sequence.

        Returns:
            A batch.
        """
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        """Number of batches in the Sequence.

        Returns:
            The number of batches in the Sequence.
        """
        raise NotImplementedError

    def on_epoch_end(self):
        """Method called at the end of every epoch."""
        pass

    def __iter__(self):
        """Create a generator that iterates over the Sequence."""
        for item in (self[i] for i in range(len(self))):
            yield item


def iter_sequence_infinite(seq):
    """Iterates indefinitely over a Sequence.

    Args:
        seq: `Sequence` instance.

    Yields:
        Batches of data from the `Sequence`.
    """
    while True:
        for item in seq:
            yield item
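# --- Illustrative sketch (not part of the original module) ---
# Draining a few batches from the infinite iterator above.
# `_ExampleArraySequence` is a hypothetical minimal Sequence over an
# in-memory numpy array, defined here purely for illustration.
class _ExampleArraySequence(Sequence):
    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.data) / self.batch_size))

    def __getitem__(self, idx):
        return self.data[idx * self.batch_size : (idx + 1) * self.batch_size]


def _example_iter_sequence_infinite_usage():  # pragma: no cover
    import itertools

    seq = _ExampleArraySequence(np.arange(10), batch_size=4)
    first_five = list(itertools.islice(iter_sequence_infinite(seq), 5))
    return first_five  # wraps around after the 3 batches in `seq`
# --- end of sketch ---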

# Global variables to be shared across processes
_SHARED_SEQUENCES = {}
# We use a Value to provide unique id to different processes.
_SEQUENCE_COUNTER = None


# Because multiprocessing pools are inherently unsafe, starting from a clean
# state can be essential to avoiding deadlocks. In order to accomplish this, we
# need to be able to check on the status of Pools that we create.
_DATA_POOLS = weakref.WeakSet()
_WORKER_ID_QUEUE = None  # Only created if needed.
_WORKER_IDS = set()
_FORCE_THREADPOOL = False
_FORCE_THREADPOOL_LOCK = threading.RLock()


def dont_use_multiprocessing_pool(f):
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        with _FORCE_THREADPOOL_LOCK:
            global _FORCE_THREADPOOL
            old_force_threadpool, _FORCE_THREADPOOL = _FORCE_THREADPOOL, True
            out = f(*args, **kwargs)
            _FORCE_THREADPOOL = old_force_threadpool
            return out

    return wrapped


def get_pool_class(use_multiprocessing):
    global _FORCE_THREADPOOL
    if not use_multiprocessing or _FORCE_THREADPOOL:
        return multiprocessing.dummy.Pool  # ThreadPool
    return multiprocessing.Pool


def get_worker_id_queue():
    """Lazily create the queue to track worker ids."""
    global _WORKER_ID_QUEUE
    if _WORKER_ID_QUEUE is None:
        _WORKER_ID_QUEUE = multiprocessing.Queue()
    return _WORKER_ID_QUEUE


def init_pool(seqs):
    global _SHARED_SEQUENCES
    _SHARED_SEQUENCES = seqs


def get_index(uid, i):
    """Get the value from the Sequence `uid` at index `i`.

    To allow multiple Sequences to be used at the same time, we use `uid` to
    get a specific one. A single Sequence would cause the validation to
    overwrite the training Sequence.

    Args:
        uid: int, Sequence identifier
        i: index

    Returns:
        The value at index `i`.
    """
    return _SHARED_SEQUENCES[uid][i]
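# --- Illustrative sketch (not part of the original module) ---
# How the pool selection above behaves: `get_pool_class(True)` normally
# returns a process pool, but inside a function decorated with
# `dont_use_multiprocessing_pool` it falls back to the thread pool.
@dont_use_multiprocessing_pool
def _example_forced_threadpool():  # pragma: no cover
    # Thread-backed pool even though multiprocessing was requested.
    return get_pool_class(True)


def _example_pool_class_usage():  # pragma: no cover
    assert get_pool_class(False) is multiprocessing.dummy.Pool
    assert _example_forced_threadpool() is multiprocessing.dummy.Pool
# --- end of sketch ---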

@keras_export("keras.utils.SequenceEnqueuer")
class SequenceEnqueuer:
    """Base class to enqueue inputs.

    The task of an Enqueuer is to use parallelism to speed up preprocessing.
    This is done with processes or threads.

    Example:

    ```python
    enqueuer = SequenceEnqueuer(...)
    enqueuer.start()
    datas = enqueuer.get()
    for data in datas:
        # Use the inputs; training, evaluating, predicting.
        # ... stop sometime.
    enqueuer.stop()
    ```

    The `enqueuer.get()` should be an infinite stream of data.
    """

    def __init__(self, sequence, use_multiprocessing=False):
        self.sequence = sequence
        self.use_multiprocessing = use_multiprocessing

        global _SEQUENCE_COUNTER
        if _SEQUENCE_COUNTER is None:
            try:
                _SEQUENCE_COUNTER = multiprocessing.Value("i", 0)
            except OSError:
                # In this case the OS does not allow us to use
                # multiprocessing. We resort to an int
                # for enqueuer indexing.
                _SEQUENCE_COUNTER = 0

        if isinstance(_SEQUENCE_COUNTER, int):
            self.uid = _SEQUENCE_COUNTER
            _SEQUENCE_COUNTER += 1
        else:
            # Doing Multiprocessing.Value += x is not process-safe.
            with _SEQUENCE_COUNTER.get_lock():
                self.uid = _SEQUENCE_COUNTER.value
                _SEQUENCE_COUNTER.value += 1

        self.workers = 0
        self.executor_fn = None
        self.queue = None
        self.run_thread = None
        self.stop_signal = None

    def is_running(self):
        return self.stop_signal is not None and not self.stop_signal.is_set()

    def start(self, workers=1, max_queue_size=10):
        """Starts the handler's workers.

        Args:
            workers: Number of workers.
            max_queue_size: queue size
                (when full, workers could block on `put()`)
        """
        if self.use_multiprocessing:
            self.executor_fn = self._get_executor_init(workers)
        else:
            # We do not need the init since it's threads.
            self.executor_fn = lambda _: get_pool_class(False)(workers)
        self.workers = workers
        self.queue = queue.Queue(max_queue_size)
        self.stop_signal = threading.Event()
        self.run_thread = threading.Thread(target=self._run)
        self.run_thread.daemon = True
        self.run_thread.start()

    def _send_sequence(self):
        """Sends current Iterable to all workers."""
        # For new processes that may spawn
        _SHARED_SEQUENCES[self.uid] = self.sequence

    def stop(self, timeout=None):
        """Stops running threads and waits for them to exit, if necessary.

        Should be called by the same thread which called `start()`.

        Args:
            timeout: maximum time to wait on `thread.join()`
        """
        self.stop_signal.set()
        with self.queue.mutex:
            self.queue.queue.clear()
            self.queue.unfinished_tasks = 0
            self.queue.not_full.notify()
        self.run_thread.join(timeout)
        _SHARED_SEQUENCES[self.uid] = None

    def __del__(self):
        if self.is_running():
            self.stop()

    @abstractmethod
    def _run(self):
        """Submits requests to the executor and queues the `Future` objects."""
        raise NotImplementedError

    @abstractmethod
    def _get_executor_init(self, workers):
        """Gets the Pool initializer for multiprocessing.

        Args:
            workers: Number of workers.

        Returns:
            Function, a Function to initialize the pool
        """
        raise NotImplementedError

    @abstractmethod
    def get(self):
        """Creates a generator to extract data from the queue.

        Skip the data if it is `None`.

        # Returns
            Generator yielding tuples `(inputs, targets)`
            or `(inputs, targets, sample_weights)`.
        """
        raise NotImplementedError


@keras_export("keras.utils.OrderedEnqueuer")
class OrderedEnqueuer(SequenceEnqueuer):
    """Builds an Enqueuer from a Sequence.

    Args:
        sequence: A `tf.keras.utils.data_utils.Sequence` object.
        use_multiprocessing: use multiprocessing if True, otherwise threading
        shuffle: whether to shuffle the data at the beginning of each epoch
    """

    def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
        super().__init__(sequence, use_multiprocessing)
        self.shuffle = shuffle

    def _get_executor_init(self, workers):
        """Gets the Pool initializer for multiprocessing.

        Args:
            workers: Number of workers.

        Returns:
            Function, a Function to initialize the pool
        """

        def pool_fn(seqs):
            pool = get_pool_class(True)(
                workers,
                initializer=init_pool_generator,
                initargs=(seqs, None, get_worker_id_queue()),
            )
            _DATA_POOLS.add(pool)
            return pool

        return pool_fn

    def _wait_queue(self):
        """Wait for the queue to be empty."""
        while True:
            time.sleep(0.1)
            if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
                return

    def _run(self):
        """Submits requests to the executor and queues the `Future` objects."""
        sequence = list(range(len(self.sequence)))
        self._send_sequence()  # Share the initial sequence
        while True:
            if self.shuffle:
                random.shuffle(sequence)

            with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
                for i in sequence:
                    if self.stop_signal.is_set():
                        return

                    self.queue.put(
                        executor.apply_async(get_index, (self.uid, i)),
                        block=True,
                    )

                # Done with the current epoch, waiting for the final batches
                self._wait_queue()

                if self.stop_signal.is_set():
                    # We're done
                    return

            # Call the internal on epoch end.
            self.sequence.on_epoch_end()
            self._send_sequence()  # Update the pool

    def get(self):
        """Creates a generator to extract data from the queue.

        Skip the data if it is `None`.

        Yields:
            The next element in the queue, i.e. a tuple
            `(inputs, targets)` or
            `(inputs, targets, sample_weights)`.
        """
        while self.is_running():
            try:
                inputs = self.queue.get(block=True, timeout=5).get()
                if self.is_running():
                    self.queue.task_done()
                if inputs is not None:
                    yield inputs
            except queue.Empty:
                pass
            except Exception as e:
                self.stop()
                raise e
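# --- Illustrative sketch (not part of the original module) ---
# End-to-end use of `OrderedEnqueuer` with the hypothetical
# `_ExampleArraySequence` defined earlier: start worker threads, consume a
# few batches from the generator, then stop the enqueuer.
def _example_ordered_enqueuer_usage():  # pragma: no cover
    seq = _ExampleArraySequence(np.arange(32), batch_size=8)
    enqueuer = OrderedEnqueuer(seq, use_multiprocessing=False, shuffle=True)
    enqueuer.start(workers=2, max_queue_size=4)
    try:
        gen = enqueuer.get()
        batches = [next(gen) for _ in range(4)]  # one epoch of 4 batches
    finally:
        enqueuer.stop()
    return batches
# --- end of sketch ---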

def init_pool_generator(gens, random_seed=None, id_queue=None):
    """Initializer function for pool workers.

    Args:
        gens: State which should be made available to worker processes.
        random_seed: An optional value with which to seed child processes.
        id_queue: A multiprocessing Queue of worker ids. This is used to
            indicate that a worker process was created by Keras and can be
            terminated using the cleanup_all_keras_forkpools utility.
    """
    global _SHARED_SEQUENCES
    _SHARED_SEQUENCES = gens

    worker_proc = multiprocessing.current_process()

    # name isn't used for anything, but setting a more descriptive name is
    # helpful when diagnosing orphaned processes.
    worker_proc.name = f"Keras_worker_{worker_proc.name}"

    if random_seed is not None:
        np.random.seed(random_seed + worker_proc.ident)

    if id_queue is not None:
        # If a worker dies during init, the pool will just create a
        # replacement.
        id_queue.put(worker_proc.ident, block=True, timeout=0.1)


def next_sample(uid):
    """Gets the next value from the generator `uid`.

    To allow multiple generators to be used at the same time, we use `uid` to
    get a specific one. A single generator would cause the validation to
    overwrite the training generator.

    Args:
        uid: int, generator identifier

    Returns:
        The next value of generator `uid`.
    """
    return next(_SHARED_SEQUENCES[uid])


@keras_export("keras.utils.GeneratorEnqueuer")
class GeneratorEnqueuer(SequenceEnqueuer):
    """Builds a queue out of a data generator.

    The provided generator can be finite in which case the class will throw
    a `StopIteration` exception.

    Args:
        generator: a generator function which yields data
        use_multiprocessing: use multiprocessing if True, otherwise threading
        random_seed: Initial seed for workers,
            will be incremented by one for each worker.
    """

    def __init__(self, generator, use_multiprocessing=False, random_seed=None):
        super().__init__(generator, use_multiprocessing)
        self.random_seed = random_seed

    def _get_executor_init(self, workers):
        """Gets the Pool initializer for multiprocessing.

        Args:
            workers: Number of workers.

        Returns:
            A Function to initialize the pool
        """

        def pool_fn(seqs):
            pool = get_pool_class(True)(
                workers,
                initializer=init_pool_generator,
                initargs=(seqs, self.random_seed, get_worker_id_queue()),
            )
            _DATA_POOLS.add(pool)
            return pool

        return pool_fn

    def _run(self):
        """Submits requests to the executor and queues the `Future` objects."""
        self._send_sequence()  # Share the initial generator
        with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
            while True:
                if self.stop_signal.is_set():
                    return

                self.queue.put(
                    executor.apply_async(next_sample, (self.uid,)), block=True
                )

    def get(self):
        """Creates a generator to extract data from the queue.

        Skip the data if it is `None`.

        Yields:
            The next element in the queue, i.e. a tuple
            `(inputs, targets)` or
            `(inputs, targets, sample_weights)`.
        """
        try:
            while self.is_running():
                inputs = self.queue.get(block=True).get()
                self.queue.task_done()
                if inputs is not None:
                    yield inputs
        except StopIteration:
            # Special case for finite generators
            last_ones = []
            while self.queue.qsize() > 0:
                last_ones.append(self.queue.get(block=True))
            # Wait for them to complete
            for f in last_ones:
                f.wait()
            # Keep the good ones
            last_ones = [
                future.get() for future in last_ones if future.successful()
            ]
            for inputs in last_ones:
                if inputs is not None:
                    yield inputs
        except Exception as e:
            self.stop()
            if "generator already executing" in str(e):
                raise RuntimeError(
                    "Your generator is NOT thread-safe. "
                    "Keras requires a thread-safe generator when "
                    "`use_multiprocessing=False, workers > 1`. "
                )
            raise e
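# --- Illustrative sketch (not part of the original module) ---
# Using `GeneratorEnqueuer` with a thread-safe generator, as the error
# message above requires for threaded workers. The generator function and
# batch contents are hypothetical placeholders.
def _example_generator_enqueuer_usage():  # pragma: no cover
    @threadsafe_generator
    def batch_generator():
        while True:
            x = np.random.random((8, 4))
            y = np.random.randint(0, 2, size=(8,))
            yield x, y

    enqueuer = GeneratorEnqueuer(batch_generator(), use_multiprocessing=False)
    enqueuer.start(workers=2, max_queue_size=4)
    try:
        stream = enqueuer.get()
        first_batches = [next(stream) for _ in range(3)]
    finally:
        enqueuer.stop()
    return first_batches
# --- end of sketch ---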

@keras_export(
    "keras.utils.pad_sequences", "keras.preprocessing.sequence.pad_sequences"
)
def pad_sequences(
    sequences,
    maxlen=None,
    dtype="int32",
    padding="pre",
    truncating="pre",
    value=0.0,
):
    """Pads sequences to the same length.

    This function transforms a list (of length `num_samples`)
    of sequences (lists of integers)
    into a 2D Numpy array of shape `(num_samples, num_timesteps)`.
    `num_timesteps` is either the `maxlen` argument if provided,
    or the length of the longest sequence in the list.

    Sequences that are shorter than `num_timesteps`
    are padded with `value` until they are `num_timesteps` long.

    Sequences longer than `num_timesteps` are truncated
    so that they fit the desired length.

    The position where padding or truncation happens is determined by
    the arguments `padding` and `truncating`, respectively.
    Pre-padding or removing values from the beginning of the sequence is the
    default.

    >>> sequence = [[1], [2, 3], [4, 5, 6]]
    >>> tf.keras.utils.pad_sequences(sequence)
    array([[0, 0, 1],
           [0, 2, 3],
           [4, 5, 6]], dtype=int32)

    >>> tf.keras.utils.pad_sequences(sequence, value=-1)
    array([[-1, -1,  1],
           [-1,  2,  3],
           [ 4,  5,  6]], dtype=int32)

    >>> tf.keras.utils.pad_sequences(sequence, padding='post')
    array([[1, 0, 0],
           [2, 3, 0],
           [4, 5, 6]], dtype=int32)

    >>> tf.keras.utils.pad_sequences(sequence, maxlen=2)
    array([[0, 1],
           [2, 3],
           [5, 6]], dtype=int32)

    Args:
        sequences: List of sequences (each sequence is a list of integers).
        maxlen: Optional Int, maximum length of all sequences. If not provided,
            sequences will be padded to the length of the longest individual
            sequence.
        dtype: (Optional, defaults to `"int32"`). Type of the output sequences.
            To pad sequences with variable length strings, you can use `object`.
        padding: String, "pre" or "post" (optional, defaults to `"pre"`):
            pad either before or after each sequence.
        truncating: String, "pre" or "post" (optional, defaults to `"pre"`):
            remove values from sequences larger than
            `maxlen`, either at the beginning or at the end of the sequences.
        value: Float or String, padding value. (Optional, defaults to 0.)

    Returns:
        Numpy array with shape `(len(sequences), maxlen)`

    Raises:
        ValueError: In case of invalid values for `truncating` or `padding`,
            or in case of invalid shape for a `sequences` entry.
    """
    if not hasattr(sequences, "__len__"):
        raise ValueError("`sequences` must be iterable.")
    num_samples = len(sequences)

    lengths = []
    sample_shape = ()
    flag = True

    # take the sample shape from the first non empty sequence
    # checking for consistency in the main loop below.

    for x in sequences:
        try:
            lengths.append(len(x))
            if flag and len(x):
                sample_shape = np.asarray(x).shape[1:]
                flag = False
        except TypeError as e:
            raise ValueError(
                "`sequences` must be a list of iterables. "
                f"Found non-iterable: {str(x)}"
            ) from e

    if maxlen is None:
        maxlen = np.max(lengths)

    is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(
        dtype, np.unicode_
    )
    if isinstance(value, str) and dtype != object and not is_dtype_str:
        raise ValueError(
            f"`dtype` {dtype} is not compatible with `value`'s type: "
            f"{type(value)}\nYou should set `dtype=object` for variable length "
            "strings."
        )

    x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
    for idx, s in enumerate(sequences):
        if not len(s):
            continue  # empty list/array was found
        if truncating == "pre":
            trunc = s[-maxlen:]
        elif truncating == "post":
            trunc = s[:maxlen]
        else:
            raise ValueError(f'Truncating type "{truncating}" not understood')

        # check `trunc` has expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError(
                f"Shape of sample {trunc.shape[1:]} of sequence at "
                f"position {idx} is different from expected shape "
                f"{sample_shape}"
            )

        if padding == "post":
            x[idx, : len(trunc)] = trunc
        elif padding == "pre":
            x[idx, -len(trunc) :] = trunc
        else:
            raise ValueError(f'Padding type "{padding}" not understood')
    return x
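# --- Illustrative sketch (not part of the original module) ---
# Padding variable-length string sequences, which requires `dtype=object` and
# a string padding value, as noted in the docstring above. Data is made up.
def _example_pad_string_sequences():  # pragma: no cover
    tokens = [["hello"], ["how", "are", "you"]]
    padded = pad_sequences(tokens, dtype=object, value="<pad>", padding="post")
    # -> array([['hello', '<pad>', '<pad>'],
    #           ['how', 'are', 'you']], dtype=object)
    return padded
# --- end of sketch ---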