Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/data_utils.py: 25%

368 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=g-import-not-at-top 

16"""Utilities for file download and caching.""" 

17 

18from abc import abstractmethod 

19from contextlib import closing 

20import functools 

21import hashlib 

22import multiprocessing 

23import multiprocessing.dummy 

24import os 

25import queue 

26import random 

27import shutil 

28import sys # pylint: disable=unused-import 

29import tarfile 

30import threading 

31import time 

32import typing 

33import urllib 

34import weakref 

35import zipfile 

36 

37import numpy as np 

38 

39from tensorflow.python.framework import ops 

40from six.moves.urllib.request import urlopen 

41from tensorflow.python.keras.utils import tf_inspect 

42from tensorflow.python.keras.utils.generic_utils import Progbar 

43from tensorflow.python.keras.utils.io_utils import path_to_string 

44from tensorflow.python.util.tf_export import keras_export 

45 

46# Required to support google internal urlretrieve 

47if sys.version_info[0] == 2: 

48 

49 def urlretrieve(url, filename, reporthook=None, data=None): 

50 """Replacement for `urlretrieve` for Python 2. 

51 

52 Under Python 2, `urlretrieve` relies on `FancyURLopener` from legacy 

53 `urllib` module, known to have issues with proxy management. 

54 

55 Args: 

56 url: url to retrieve. 

57 filename: where to store the retrieved data locally. 

58 reporthook: a hook function that will be called once on establishment of 

59 the network connection and once after each block read thereafter. The 

60 hook will be passed three arguments; a count of blocks transferred so 

61 far, a block size in bytes, and the total size of the file. 

62 data: `data` argument passed to `urlopen`. 

63 """ 

64 

65 def chunk_read(response, chunk_size=8192, reporthook=None): 

66 content_length = response.info().get('Content-Length')

67 total_size = -1

68 if content_length is not None:

69 total_size = int(content_length.strip())

70 count = 0 

71 while True: 

72 chunk = response.read(chunk_size) 

73 count += 1 

74 if reporthook is not None: 

75 reporthook(count, chunk_size, total_size) 

76 if chunk: 

77 yield chunk 

78 else: 

79 break 

80 

81 response = urlopen(url, data) 

82 with open(filename, 'wb') as fd: 

83 for chunk in chunk_read(response, reporthook=reporthook): 

84 fd.write(chunk) 

85else: 

86 from urllib.request import urlretrieve # pylint: disable=g-importing-member 

87 

88 

89def is_generator_or_sequence(x): 

90 """Check if `x` is a Keras generator type.""" 

91 builtin_iterators = (str, list, tuple, dict, set, frozenset) 

92 if isinstance(x, (ops.Tensor, np.ndarray) + builtin_iterators): 

93 return False 

94 return (tf_inspect.isgenerator(x) or 

95 isinstance(x, Sequence) or 

96 isinstance(x, typing.Iterator)) 

97 

98 
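# A minimal, hypothetical sketch of what `is_generator_or_sequence` accepts:
# generators qualify, builtin containers do not. The helper name below is an
# editor-chosen placeholder, not part of the original API.
def _example_is_generator_or_sequence():
  gen = (i * i for i in range(3))
  assert is_generator_or_sequence(gen)            # generator -> True
  assert not is_generator_or_sequence([1, 2, 3])  # builtin list -> False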

99def _extract_archive(file_path, path='.', archive_format='auto'): 

100 """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats. 

101 

102 Args: 

103 file_path: path to the archive file 

104 path: path to extract the archive file 

105 archive_format: Archive format to try for extracting the file. 

106 Options are 'auto', 'tar', 'zip', and None. 

107 'tar' includes tar, tar.gz, and tar.bz files. 

108 The default 'auto' is ['tar', 'zip']. 

109 None or an empty list will find no matches (nothing is extracted).

110 

111 Returns: 

112 True if a match was found and an archive extraction was completed, 

113 False otherwise. 

114 """ 

115 if archive_format is None: 

116 return False 

117 if archive_format == 'auto': 

118 archive_format = ['tar', 'zip'] 

119 if isinstance(archive_format, str): 

120 archive_format = [archive_format] 

121 

122 file_path = path_to_string(file_path) 

123 path = path_to_string(path) 

124 

125 for archive_type in archive_format: 

126 if archive_type == 'tar': 

127 open_fn = tarfile.open 

128 is_match_fn = tarfile.is_tarfile 

129 if archive_type == 'zip': 

130 open_fn = zipfile.ZipFile 

131 is_match_fn = zipfile.is_zipfile 

132 

133 if is_match_fn(file_path): 

134 with open_fn(file_path) as archive: 

135 try: 

136 archive.extractall(path) 

137 except (tarfile.TarError, RuntimeError, KeyboardInterrupt): 

138 if os.path.exists(path): 

139 if os.path.isfile(path): 

140 os.remove(path) 

141 else: 

142 shutil.rmtree(path) 

143 raise 

144 return True 

145 return False 

146 

147 
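# A minimal usage sketch for `_extract_archive`; the paths are hypothetical
# placeholders and the call assumes the archive already exists on disk.
def _example_extract_archive():
  # True only if the file matched a requested format and extraction finished.
  return _extract_archive('/tmp/example_archive.tar.gz',
                          path='/tmp/example_extracted',
                          archive_format='tar')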

148@keras_export('keras.utils.get_file') 

149def get_file(fname, 

150 origin, 

151 untar=False, 

152 md5_hash=None, 

153 file_hash=None, 

154 cache_subdir='datasets', 

155 hash_algorithm='auto', 

156 extract=False, 

157 archive_format='auto', 

158 cache_dir=None): 

159 """Downloads a file from a URL if it is not already in the cache.

160 

161 By default the file at the url `origin` is downloaded to the 

162 cache_dir `~/.keras`, placed in the cache_subdir `datasets`, 

163 and given the filename `fname`. The final location of a file 

164 `example.txt` would therefore be `~/.keras/datasets/example.txt`. 

165 

166 Files in tar, tar.gz, tar.bz, and zip formats can also be extracted. 

167 Passing a hash will verify the file after download. The command line 

168 programs `shasum` and `sha256sum` can compute the hash. 

169 

170 Example: 

171 

172 ```python 

173 path_to_downloaded_file = tf.keras.utils.get_file( 

174 "flower_photos", 

175 "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz", 

176 untar=True) 

177 ``` 

178 

179 Args: 

180 fname: Name of the file. If an absolute path `/path/to/file.txt` is 

181 specified the file will be saved at that location. 

182 origin: Original URL of the file. 

183 untar: Deprecated in favor of `extract` argument. 

184 boolean, whether the file should be decompressed 

185 md5_hash: Deprecated in favor of `file_hash` argument. 

186 md5 hash of the file for verification 

187 file_hash: The expected hash string of the file after download. 

188 The sha256 and md5 hash algorithms are both supported. 

189 cache_subdir: Subdirectory under the Keras cache dir where the file is 

190 saved. If an absolute path `/path/to/folder` is 

191 specified the file will be saved at that location. 

192 hash_algorithm: Select the hash algorithm to verify the file. 

193 options are `'md5'`, `'sha256'`, and `'auto'`. 

194 The default 'auto' detects the hash algorithm in use. 

195 extract: True tries extracting the file as an Archive, like tar or zip. 

196 archive_format: Archive format to try for extracting the file. 

197 Options are `'auto'`, `'tar'`, `'zip'`, and `None`. 

198 `'tar'` includes tar, tar.gz, and tar.bz files. 

199 The default `'auto'` corresponds to `['tar', 'zip']`. 

200 None or an empty list will find no matches (nothing is extracted).

201 cache_dir: Location to store cached files, when None it 

202 defaults to the default directory `~/.keras/`. 

203 

204 Returns: 

205 Path to the downloaded file 

206 """ 

207 if cache_dir is None: 

208 cache_dir = os.path.join(os.path.expanduser('~'), '.keras') 

209 if md5_hash is not None and file_hash is None: 

210 file_hash = md5_hash 

211 hash_algorithm = 'md5' 

212 datadir_base = os.path.expanduser(cache_dir) 

213 if not os.access(datadir_base, os.W_OK): 

214 datadir_base = os.path.join('/tmp', '.keras') 

215 datadir = os.path.join(datadir_base, cache_subdir) 

216 _makedirs_exist_ok(datadir) 

217 

218 fname = path_to_string(fname) 

219 

220 if untar: 

221 untar_fpath = os.path.join(datadir, fname) 

222 fpath = untar_fpath + '.tar.gz' 

223 else: 

224 fpath = os.path.join(datadir, fname) 

225 

226 download = False 

227 if os.path.exists(fpath): 

228 # File found; verify integrity if a hash was provided. 

229 if file_hash is not None: 

230 if not validate_file(fpath, file_hash, algorithm=hash_algorithm): 

231 print('A local file was found, but it seems to be ' 

232 'incomplete or outdated because the ' + hash_algorithm + 

233 ' file hash does not match the original value of ' + file_hash + 

234 ', so we will re-download the data.')

235 download = True 

236 else: 

237 download = True 

238 

239 if download: 

240 print('Downloading data from', origin) 

241 

242 class ProgressTracker(object): 

243 # Maintain progbar for the lifetime of download. 

244 # This design was chosen for Python 2.7 compatibility. 

245 progbar = None 

246 

247 def dl_progress(count, block_size, total_size): 

248 if ProgressTracker.progbar is None: 

249 if total_size == -1: 

250 total_size = None 

251 ProgressTracker.progbar = Progbar(total_size) 

252 else: 

253 ProgressTracker.progbar.update(count * block_size) 

254 

255 error_msg = 'URL fetch failure on {}: {} -- {}' 

256 try: 

257 try: 

258 urlretrieve(origin, fpath, dl_progress) 

259 except urllib.error.HTTPError as e: 

260 raise Exception(error_msg.format(origin, e.code, e.msg)) 

261 except urllib.error.URLError as e: 

262 raise Exception(error_msg.format(origin, e.errno, e.reason)) 

263 except (Exception, KeyboardInterrupt) as e: 

264 if os.path.exists(fpath): 

265 os.remove(fpath) 

266 raise 

267 ProgressTracker.progbar = None 

268 

269 if untar: 

270 if not os.path.exists(untar_fpath): 

271 _extract_archive(fpath, datadir, archive_format='tar') 

272 return untar_fpath 

273 

274 if extract: 

275 _extract_archive(fpath, datadir, archive_format) 

276 

277 return fpath 

278 

279 
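# A minimal usage sketch for `get_file`; the URL is a hypothetical placeholder.
# The file is cached under `~/.keras/datasets/` and only re-downloaded when a
# provided `file_hash` fails to verify.
def _example_get_file():
  return get_file(
      'example_data.tar.gz',
      origin='https://example.com/example_data.tar.gz',  # hypothetical URL
      file_hash=None,  # pass a sha256 or md5 digest here to enable verification
      extract=True)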

280def _makedirs_exist_ok(datadir): 

281 os.makedirs(datadir, exist_ok=True) # pylint: disable=unexpected-keyword-arg 

282 

283 

284def _resolve_hasher(algorithm, file_hash=None): 

285 """Returns hash algorithm as hashlib function.""" 

286 if algorithm == 'sha256': 

287 return hashlib.sha256() 

288 

289 if algorithm == 'auto' and file_hash is not None and len(file_hash) == 64: 

290 return hashlib.sha256() 

291 

292 # This is used only for legacy purposes. 

293 return hashlib.md5() 

294 

295 

296def _hash_file(fpath, algorithm='sha256', chunk_size=65535): 

297 """Calculates a file's sha256 or md5 hash.

298 

299 Example: 

300 

301 ```python 

302 _hash_file('/path/to/file.zip') 

303 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' 

304 ``` 

305 

306 Args: 

307 fpath: path to the file being validated 

308 algorithm: hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`. 

309 The default `'auto'` detects the hash algorithm in use. 

310 chunk_size: Bytes to read at a time, important for large files. 

311 

312 Returns: 

313 The file hash 

314 """ 

315 if isinstance(algorithm, str): 

316 hasher = _resolve_hasher(algorithm) 

317 else: 

318 hasher = algorithm 

319 

320 with open(fpath, 'rb') as fpath_file: 

321 for chunk in iter(lambda: fpath_file.read(chunk_size), b''): 

322 hasher.update(chunk) 

323 

324 return hasher.hexdigest() 

325 

326 

327def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535): 

328 """Validates a file against a sha256 or md5 hash. 

329 

330 Args: 

331 fpath: path to the file being validated 

332 file_hash: The expected hash string of the file. 

333 The sha256 and md5 hash algorithms are both supported. 

334 algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'. 

335 The default 'auto' detects the hash algorithm in use. 

336 chunk_size: Bytes to read at a time, important for large files. 

337 

338 Returns: 

339 Whether the file is valid 

340 """ 

341 hasher = _resolve_hasher(algorithm, file_hash) 

342 

343 if str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash): 

344 return True 

345 else: 

346 return False 

347 

348 
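# A minimal sketch of the hashing helpers: `_hash_file` computes a digest in
# `chunk_size` blocks and `validate_file` compares it against an expected
# value. The file path is a hypothetical placeholder.
def _example_validate_file():
  expected = _hash_file('/tmp/example_data.tar.gz', algorithm='sha256')
  # 'auto' infers sha256 here because the expected digest is 64 hex characters.
  return validate_file('/tmp/example_data.tar.gz', expected, algorithm='auto')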

349class ThreadsafeIter(object): 

350 """Wrap an iterator with a lock and propagate exceptions to all threads.""" 

351 

352 def __init__(self, it): 

353 self.it = it 

354 self.lock = threading.Lock() 

355 

356 # After a generator throws an exception all subsequent next() calls raise a 

357 # StopIteration Exception. This, however, presents an issue when mixing 

358 # generators and threading because it means the order of retrieval need not 

359 # match the order in which the generator was called. This can make it appear 

360 # that a generator exited normally when in fact the terminating exception is 

361 # just in a different thread. In order to provide thread safety, once 

362 # self.it has thrown an exception we continue to throw the same exception. 

363 self._exception = None 

364 

365 def __iter__(self): 

366 return self 

367 

368 def next(self): 

369 return self.__next__() 

370 

371 def __next__(self): 

372 with self.lock: 

373 if self._exception: 

374 raise self._exception # pylint: disable=raising-bad-type 

375 

376 try: 

377 return next(self.it) 

378 except Exception as e: 

379 self._exception = e 

380 raise 

381 

382 

383def threadsafe_generator(f): 

384 

385 @functools.wraps(f) 

386 def g(*a, **kw): 

387 return ThreadsafeIter(f(*a, **kw)) 

388 

389 return g 

390 

391 
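# A minimal sketch of `threadsafe_generator`: the decorated generator can be
# advanced from several threads because `ThreadsafeIter` serializes `next()`
# calls behind a lock. The generator below is a hypothetical example.
@threadsafe_generator
def _example_counting_generator(n):
  for i in range(n):
    yield i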

392@keras_export('keras.utils.Sequence') 

393class Sequence(object): 

394 """Base object for fitting to a sequence of data, such as a dataset. 

395 

396 Every `Sequence` must implement the `__getitem__` and the `__len__` methods. 

397 If you want to modify your dataset between epochs you may implement 

398 `on_epoch_end`. 

399 The method `__getitem__` should return a complete batch. 

400 

401 Notes: 

402 

403 `Sequence` is a safer way to do multiprocessing. This structure guarantees

404 that the network will only train once on each sample per epoch, which is

405 not the case with generators.

406 

407 Examples: 

408 

409 ```python 

410 from skimage.io import imread 

411 from skimage.transform import resize 

412 import numpy as np 

413 import math 

414 

415 # Here, `x_set` is a list of paths to the images

416 # and `y_set` are the associated classes. 

417 

418 class CIFAR10Sequence(Sequence): 

419 

420 def __init__(self, x_set, y_set, batch_size): 

421 self.x, self.y = x_set, y_set 

422 self.batch_size = batch_size 

423 

424 def __len__(self): 

425 return math.ceil(len(self.x) / self.batch_size) 

426 

427 def __getitem__(self, idx): 

428 batch_x = self.x[idx * self.batch_size:(idx + 1) * 

429 self.batch_size] 

430 batch_y = self.y[idx * self.batch_size:(idx + 1) * 

431 self.batch_size] 

432 

433 return np.array([ 

434 resize(imread(file_name), (200, 200)) 

435 for file_name in batch_x]), np.array(batch_y) 

436 ``` 

437 """ 

438 

439 @abstractmethod 

440 def __getitem__(self, index): 

441 """Gets batch at position `index`. 

442 

443 Args: 

444 index: position of the batch in the Sequence. 

445 

446 Returns: 

447 A batch 

448 """ 

449 raise NotImplementedError 

450 

451 @abstractmethod 

452 def __len__(self): 

453 """Number of batches in the Sequence.

454 

455 Returns: 

456 The number of batches in the Sequence. 

457 """ 

458 raise NotImplementedError 

459 

460 def on_epoch_end(self): 

461 """Method called at the end of every epoch. 

462 """ 

463 pass 

464 

465 def __iter__(self): 

466 """Creates a generator that iterates over the Sequence."""

467 for item in (self[i] for i in range(len(self))): 

468 yield item 

469 

470 
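# A minimal `Sequence` sketch over in-memory numpy arrays; unlike the
# CIFAR10Sequence example in the class docstring it needs no image files.
# The class name is a hypothetical placeholder.
class _ExampleArraySequence(Sequence):

  def __init__(self, x, y, batch_size):
    self.x, self.y, self.batch_size = x, y, batch_size

  def __len__(self):
    return int(np.ceil(len(self.x) / float(self.batch_size)))

  def __getitem__(self, idx):
    start = idx * self.batch_size
    end = start + self.batch_size
    return self.x[start:end], self.y[start:end]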

471def iter_sequence_infinite(seq): 

472 """Iterates indefinitely over a Sequence. 

473 

474 Args: 

475 seq: `Sequence` instance. 

476 

477 Yields: 

478 Batches of data from the `Sequence`. 

479 """ 

480 while True: 

481 for item in seq: 

482 yield item 

483 

484 

485# Global variables to be shared across processes 

486_SHARED_SEQUENCES = {} 

487# We use a Value to provide unique id to different processes. 

488_SEQUENCE_COUNTER = None 

489 

490 

491# Because multiprocessing pools are inherently unsafe, starting from a clean 

492# state can be essential to avoiding deadlocks. In order to accomplish this, we 

493# need to be able to check on the status of Pools that we create. 

494_DATA_POOLS = weakref.WeakSet() 

495_WORKER_ID_QUEUE = None # Only created if needed. 

496_WORKER_IDS = set() 

497_FORCE_THREADPOOL = False 

498_FORCE_THREADPOOL_LOCK = threading.RLock() 

499 

500 

501def dont_use_multiprocessing_pool(f): 

502 @functools.wraps(f) 

503 def wrapped(*args, **kwargs): 

504 with _FORCE_THREADPOOL_LOCK: 

505 global _FORCE_THREADPOOL 

506 old_force_threadpool, _FORCE_THREADPOOL = _FORCE_THREADPOOL, True 

507 out = f(*args, **kwargs) 

508 _FORCE_THREADPOOL = old_force_threadpool 

509 return out 

510 return wrapped 

511 

512 

513def get_pool_class(use_multiprocessing): 

514 global _FORCE_THREADPOOL 

515 if not use_multiprocessing or _FORCE_THREADPOOL: 

516 return multiprocessing.dummy.Pool # ThreadPool 

517 return multiprocessing.Pool 

518 

519 
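# A minimal sketch of how `dont_use_multiprocessing_pool` interacts with
# `get_pool_class`: inside the decorated call, even `use_multiprocessing=True`
# falls back to the thread-based pool. The function name is hypothetical.
@dont_use_multiprocessing_pool
def _example_forced_threadpool():
  pool_cls = get_pool_class(use_multiprocessing=True)
  return pool_cls is multiprocessing.dummy.Pool  # True inside this call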

520def get_worker_id_queue(): 

521 """Lazily create the queue to track worker ids.""" 

522 global _WORKER_ID_QUEUE 

523 if _WORKER_ID_QUEUE is None: 

524 _WORKER_ID_QUEUE = multiprocessing.Queue() 

525 return _WORKER_ID_QUEUE 

526 

527 

528def init_pool(seqs): 

529 global _SHARED_SEQUENCES 

530 _SHARED_SEQUENCES = seqs 

531 

532 

533def get_index(uid, i): 

534 """Get the value from the Sequence `uid` at index `i`. 

535 

536 To allow multiple Sequences to be used at the same time, we use `uid` to 

537 get a specific one. A single Sequence would cause the validation to 

538 overwrite the training Sequence. 

539 

540 Args: 

541 uid: int, Sequence identifier 

542 i: index 

543 

544 Returns: 

545 The value at index `i`. 

546 """ 

547 return _SHARED_SEQUENCES[uid][i] 

548 

549 

550@keras_export('keras.utils.SequenceEnqueuer') 

551class SequenceEnqueuer(object): 

552 """Base class to enqueue inputs. 

553 

554 The task of an Enqueuer is to use parallelism to speed up preprocessing. 

555 This is done with processes or threads. 

556 

557 Example: 

558 

559 ```python 

560 enqueuer = SequenceEnqueuer(...) 

561 enqueuer.start() 

562 data_stream = enqueuer.get()

563 for data in data_stream:

564 # Use the inputs; training, evaluating, predicting. 

565 # ... stop sometime. 

566 enqueuer.stop() 

567 ``` 

568 

569 The `enqueuer.get()` call should yield an infinite stream of data.

570 """ 

571 

572 def __init__(self, sequence, 

573 use_multiprocessing=False): 

574 self.sequence = sequence 

575 self.use_multiprocessing = use_multiprocessing 

576 

577 global _SEQUENCE_COUNTER 

578 if _SEQUENCE_COUNTER is None: 

579 try: 

580 _SEQUENCE_COUNTER = multiprocessing.Value('i', 0) 

581 except OSError: 

582 # In this case the OS does not allow us to use 

583 # multiprocessing. We resort to an int 

584 # for enqueuer indexing. 

585 _SEQUENCE_COUNTER = 0 

586 

587 if isinstance(_SEQUENCE_COUNTER, int): 

588 self.uid = _SEQUENCE_COUNTER 

589 _SEQUENCE_COUNTER += 1 

590 else: 

591 # Doing Multiprocessing.Value += x is not process-safe. 

592 with _SEQUENCE_COUNTER.get_lock(): 

593 self.uid = _SEQUENCE_COUNTER.value 

594 _SEQUENCE_COUNTER.value += 1 

595 

596 self.workers = 0 

597 self.executor_fn = None 

598 self.queue = None 

599 self.run_thread = None 

600 self.stop_signal = None 

601 

602 def is_running(self): 

603 return self.stop_signal is not None and not self.stop_signal.is_set() 

604 

605 def start(self, workers=1, max_queue_size=10): 

606 """Starts the handler's workers. 

607 

608 Args: 

609 workers: Number of workers. 

610 max_queue_size: queue size 

611 (when full, workers could block on `put()`) 

612 """ 

613 if self.use_multiprocessing: 

614 self.executor_fn = self._get_executor_init(workers) 

615 else: 

616 # We do not need the init since it's threads. 

617 self.executor_fn = lambda _: get_pool_class(False)(workers) 

618 self.workers = workers 

619 self.queue = queue.Queue(max_queue_size) 

620 self.stop_signal = threading.Event() 

621 self.run_thread = threading.Thread(target=self._run) 

622 self.run_thread.daemon = True 

623 self.run_thread.start() 

624 

625 def _send_sequence(self): 

626 """Sends current Iterable to all workers.""" 

627 # For new processes that may spawn 

628 _SHARED_SEQUENCES[self.uid] = self.sequence 

629 

630 def stop(self, timeout=None): 

631 """Stops running threads and waits for them to exit, if necessary.

632 

633 Should be called by the same thread which called `start()`. 

634 

635 Args: 

636 timeout: maximum time to wait on `thread.join()` 

637 """ 

638 self.stop_signal.set() 

639 with self.queue.mutex: 

640 self.queue.queue.clear() 

641 self.queue.unfinished_tasks = 0 

642 self.queue.not_full.notify() 

643 self.run_thread.join(timeout) 

644 _SHARED_SEQUENCES[self.uid] = None 

645 

646 def __del__(self): 

647 if self.is_running(): 

648 self.stop() 

649 

650 @abstractmethod 

651 def _run(self): 

652 """Submits requests to the executor and queues the `Future` objects."""

653 raise NotImplementedError 

654 

655 @abstractmethod 

656 def _get_executor_init(self, workers): 

657 """Gets the Pool initializer for multiprocessing. 

658 

659 Args: 

660 workers: Number of workers. 

661 

662 Returns: 

663 Function, a Function to initialize the pool 

664 """ 

665 raise NotImplementedError 

666 

667 @abstractmethod 

668 def get(self): 

669 """Creates a generator to extract data from the queue. 

670 

671 Skip the data if it is `None`. 

672 # Returns 

673 Generator yielding tuples `(inputs, targets)` 

674 or `(inputs, targets, sample_weights)`. 

675 """ 

676 raise NotImplementedError 

677 

678 

679@keras_export('keras.utils.OrderedEnqueuer') 

680class OrderedEnqueuer(SequenceEnqueuer): 

681 """Builds an Enqueuer from a Sequence.

682 

683 Args: 

684 sequence: A `tf.keras.utils.data_utils.Sequence` object. 

685 use_multiprocessing: use multiprocessing if True, otherwise threading 

686 shuffle: whether to shuffle the data at the beginning of each epoch 

687 """ 

688 

689 def __init__(self, sequence, use_multiprocessing=False, shuffle=False): 

690 super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing) 

691 self.shuffle = shuffle 

692 

693 def _get_executor_init(self, workers): 

694 """Gets the Pool initializer for multiprocessing. 

695 

696 Args: 

697 workers: Number of workers. 

698 

699 Returns: 

700 Function, a Function to initialize the pool 

701 """ 

702 def pool_fn(seqs): 

703 pool = get_pool_class(True)( 

704 workers, initializer=init_pool_generator, 

705 initargs=(seqs, None, get_worker_id_queue())) 

706 _DATA_POOLS.add(pool) 

707 return pool 

708 

709 return pool_fn 

710 

711 def _wait_queue(self): 

712 """Wait for the queue to be empty.""" 

713 while True: 

714 time.sleep(0.1) 

715 if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set(): 

716 return 

717 

718 def _run(self): 

719 """Submits requests to the executor and queues the `Future` objects."""

720 sequence = list(range(len(self.sequence))) 

721 self._send_sequence() # Share the initial sequence 

722 while True: 

723 if self.shuffle: 

724 random.shuffle(sequence) 

725 

726 with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor: 

727 for i in sequence: 

728 if self.stop_signal.is_set(): 

729 return 

730 

731 self.queue.put( 

732 executor.apply_async(get_index, (self.uid, i)), block=True) 

733 

734 # Done with the current epoch, waiting for the final batches 

735 self._wait_queue() 

736 

737 if self.stop_signal.is_set(): 

738 # We're done 

739 return 

740 

741 # Call the internal on epoch end. 

742 self.sequence.on_epoch_end() 

743 self._send_sequence() # Update the pool 

744 

745 def get(self): 

746 """Creates a generator to extract data from the queue. 

747 

748 Skip the data if it is `None`. 

749 

750 Yields: 

751 The next element in the queue, i.e. a tuple 

752 `(inputs, targets)` or 

753 `(inputs, targets, sample_weights)`. 

754 """ 

755 while self.is_running(): 

756 try: 

757 inputs = self.queue.get(block=True, timeout=5).get() 

758 if self.is_running(): 

759 self.queue.task_done() 

760 if inputs is not None: 

761 yield inputs 

762 except queue.Empty: 

763 pass 

764 except Exception as e: # pylint: disable=broad-except 

765 self.stop() 

766 raise e 

767 

768 
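# A minimal usage sketch for `OrderedEnqueuer`, assuming `seq` is a `Sequence`
# instance (for example the array-backed sketch earlier in this file); batches
# come back in order, or reshuffled each epoch when `shuffle=True`.
def _example_ordered_enqueuer(seq, steps):
  enqueuer = OrderedEnqueuer(seq, use_multiprocessing=False, shuffle=False)
  enqueuer.start(workers=2, max_queue_size=10)
  output = enqueuer.get()
  batches = [next(output) for _ in range(steps)]
  enqueuer.stop()
  return batches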

769def init_pool_generator(gens, random_seed=None, id_queue=None): 

770 """Initializer function for pool workers. 

771 

772 Args: 

773 gens: State which should be made available to worker processes. 

774 random_seed: An optional value with which to seed child processes. 

775 id_queue: A multiprocessing Queue of worker ids. This is used to indicate 

776 that a worker process was created by Keras and can be terminated using 

777 the cleanup_all_keras_forkpools utility. 

778 """ 

779 global _SHARED_SEQUENCES 

780 _SHARED_SEQUENCES = gens 

781 

782 worker_proc = multiprocessing.current_process() 

783 

784 # name isn't used for anything, but setting a more descriptive name is helpful 

785 # when diagnosing orphaned processes. 

786 worker_proc.name = 'Keras_worker_{}'.format(worker_proc.name) 

787 

788 if random_seed is not None: 

789 np.random.seed(random_seed + worker_proc.ident) 

790 

791 if id_queue is not None: 

792 # If a worker dies during init, the pool will just create a replacement. 

793 id_queue.put(worker_proc.ident, block=True, timeout=0.1) 

794 

795 

796def next_sample(uid): 

797 """Gets the next value from the generator `uid`. 

798 

799 To allow multiple generators to be used at the same time, we use `uid` to 

800 get a specific one. A single generator would cause the validation to 

801 overwrite the training generator. 

802 

803 Args: 

804 uid: int, generator identifier 

805 

806 Returns: 

807 The next value of generator `uid`. 

808 """ 

809 return next(_SHARED_SEQUENCES[uid]) 

810 

811 

812@keras_export('keras.utils.GeneratorEnqueuer') 

813class GeneratorEnqueuer(SequenceEnqueuer): 

814 """Builds a queue out of a data generator. 

815 

816 The provided generator can be finite, in which case the class will raise

817 a `StopIteration` exception. 

818 

819 Args: 

820 generator: a generator function which yields data 

821 use_multiprocessing: use multiprocessing if True, otherwise threading 

822 random_seed: Initial seed for workers, 

823 will be incremented by one for each worker. 

824 """ 

825 

826 def __init__(self, generator, 

827 use_multiprocessing=False, 

828 random_seed=None): 

829 super(GeneratorEnqueuer, self).__init__(generator, use_multiprocessing) 

830 self.random_seed = random_seed 

831 

832 def _get_executor_init(self, workers): 

833 """Gets the Pool initializer for multiprocessing. 

834 

835 Args: 

836 workers: Number of workers.

837 

838 Returns: 

839 A Function to initialize the pool 

840 """ 

841 def pool_fn(seqs): 

842 pool = get_pool_class(True)( 

843 workers, initializer=init_pool_generator, 

844 initargs=(seqs, self.random_seed, get_worker_id_queue())) 

845 _DATA_POOLS.add(pool) 

846 return pool 

847 return pool_fn 

848 

849 def _run(self): 

850 """Submits requests to the executor and queues the `Future` objects."""

851 self._send_sequence() # Share the initial generator 

852 with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor: 

853 while True: 

854 if self.stop_signal.is_set(): 

855 return 

856 

857 self.queue.put( 

858 executor.apply_async(next_sample, (self.uid,)), block=True) 

859 

860 def get(self): 

861 """Creates a generator to extract data from the queue. 

862 

863 Skip the data if it is `None`. 

864 

865 Yields: 

866 The next element in the queue, i.e. a tuple 

867 `(inputs, targets)` or 

868 `(inputs, targets, sample_weights)`. 

869 """ 

870 try: 

871 while self.is_running(): 

872 inputs = self.queue.get(block=True).get() 

873 self.queue.task_done() 

874 if inputs is not None: 

875 yield inputs 

876 except StopIteration: 

877 # Special case for finite generators 

878 last_ones = [] 

879 while self.queue.qsize() > 0: 

880 last_ones.append(self.queue.get(block=True)) 

881 # Wait for them to complete 

882 for f in last_ones: 

883 f.wait() 

884 # Keep the good ones 

885 last_ones = [future.get() for future in last_ones if future.successful()] 

886 for inputs in last_ones: 

887 if inputs is not None: 

888 yield inputs 

889 except Exception as e: # pylint: disable=broad-except 

890 self.stop() 

891 if 'generator already executing' in str(e): 

892 raise RuntimeError( 

893 'Your generator is NOT thread-safe. ' 

894 'Keras requires a thread-safe generator when ' 

895 '`use_multiprocessing=False, workers > 1`. ') 

896 raise e
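# A minimal usage sketch for `GeneratorEnqueuer`, mirroring the
# `OrderedEnqueuer` sketch above but driven by a plain (hypothetical)
# generator function.
def _example_generator_enqueuer(steps):
  def _data_gen():
    while True:
      yield np.zeros((2, 4)), np.zeros((2,))

  enqueuer = GeneratorEnqueuer(_data_gen(), use_multiprocessing=False)
  enqueuer.start(workers=1, max_queue_size=10)
  output = enqueuer.get()
  batches = [next(output) for _ in range(steps)]
  enqueuer.stop()
  return batches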