Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/data_utils.py: 25%
368 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-import-not-at-top
16"""Utilities for file download and caching."""
18from abc import abstractmethod
19from contextlib import closing
20import functools
21import hashlib
22import multiprocessing
23import multiprocessing.dummy
24import os
25import queue
26import random
27import shutil
28import sys # pylint: disable=unused-import
29import tarfile
30import threading
31import time
32import typing
33import urllib
34import weakref
35import zipfile
37import numpy as np
39from tensorflow.python.framework import ops
40from six.moves.urllib.request import urlopen
41from tensorflow.python.keras.utils import tf_inspect
42from tensorflow.python.keras.utils.generic_utils import Progbar
43from tensorflow.python.keras.utils.io_utils import path_to_string
44from tensorflow.python.util.tf_export import keras_export
# Required to support google internal urlretrieve
if sys.version_info[0] == 2:

  def urlretrieve(url, filename, reporthook=None, data=None):
    """Replacement for `urlretrieve` for Python 2.

    Under Python 2, `urlretrieve` relies on `FancyURLopener` from legacy
    `urllib` module, known to have issues with proxy management.

    Args:
        url: url to retrieve.
        filename: where to store the retrieved data locally.
        reporthook: a hook function that will be called once after each block
          read (including the final, empty read). The hook will be passed
          three arguments; a count of blocks transferred so far, a block size
          in bytes, and the total size of the file (-1 when the server sent
          no `Content-Length` header).
        data: `data` argument passed to `urlopen`.
    """

    def chunk_read(response, chunk_size=8192, reporthook=None):
      """Yields the body of `response` in blocks of up to `chunk_size` bytes."""
      content_length = response.info().get('Content-Length')
      total_size = -1
      if content_length is not None:
        total_size = int(content_length.strip())
      count = 0
      while True:
        chunk = response.read(chunk_size)
        count += 1
        if reporthook is not None:
          reporthook(count, chunk_size, total_size)
        if chunk:
          yield chunk
        else:
          break

    # `closing` releases the connection even when reading or writing fails,
    # without requiring the response object to be a context manager itself.
    with closing(urlopen(url, data)) as response:
      with open(filename, 'wb') as fd:
        for chunk in chunk_read(response, reporthook=reporthook):
          fd.write(chunk)
else:
  from urllib.request import urlretrieve  # pylint: disable=g-importing-member
def is_generator_or_sequence(x):
  """Tells whether `x` is a generator, a Keras `Sequence` or an iterator.

  Tensors, NumPy arrays and the builtin containers (`str`, `list`, `tuple`,
  `dict`, `set`, `frozenset`) are never treated as generators.

  Args:
    x: object to inspect.

  Returns:
    A falsy value for the excluded types above; otherwise a truthy value when
    `x` is a generator, a `Sequence` instance or a `typing.Iterator`.
  """
  excluded_types = (ops.Tensor, np.ndarray,
                    str, list, tuple, dict, set, frozenset)
  if isinstance(x, excluded_types):
    return False
  looks_like_generator = tf_inspect.isgenerator(x)
  return looks_like_generator or isinstance(x, (Sequence, typing.Iterator))
def _extract_archive(file_path, path='.', archive_format='auto'):
  """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.

  Args:
      file_path: path to the archive file
      path: path to extract the archive file
      archive_format: Archive format to try for extracting the file.
          Options are 'auto', 'tar', 'zip', and None.
          'tar' includes tar, tar.gz, and tar.bz files.
          The default 'auto' is ['tar', 'zip'].
          None or an empty list will return no matches found.
          Entries other than 'tar' and 'zip' are ignored.

  Returns:
      True if a match was found and an archive extraction was completed,
      False otherwise.

  Raises:
      tarfile.TarError, RuntimeError, KeyboardInterrupt: re-raised from
          `extractall` after `path` has been removed.
  """
  if archive_format is None:
    return False
  if archive_format == 'auto':
    archive_format = ['tar', 'zip']
  if isinstance(archive_format, str):
    archive_format = [archive_format]

  # Resolve the (opener, matcher) pair of every recognised format up front.
  # An unrecognised entry used to leave the pair unbound (or stale from the
  # previous iteration); it is now skipped so the call reports "no match".
  handlers = []
  for archive_type in archive_format:
    if archive_type == 'tar':
      handlers.append((tarfile.open, tarfile.is_tarfile))
    elif archive_type == 'zip':
      handlers.append((zipfile.ZipFile, zipfile.is_zipfile))
  if not handlers:
    return False

  file_path = path_to_string(file_path)
  path = path_to_string(path)

  for open_fn, is_match_fn in handlers:
    if is_match_fn(file_path):
      with open_fn(file_path) as archive:
        try:
          # NOTE(review): the archive is extracted as-is; archives from
          # untrusted sources should be inspected before reaching this call.
          archive.extractall(path)
        except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
          # Do not leave a partial extraction behind: drop whatever exists at
          # `path` before propagating the error.
          if os.path.exists(path):
            if os.path.isfile(path):
              os.remove(path)
            else:
              shutil.rmtree(path)
          raise
      return True
  return False
@keras_export('keras.utils.get_file')
def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  """Downloads a file from a URL if it is not already in the cache.

  With the default arguments the file found at `origin` is stored as
  `~/.keras/datasets/<fname>`: `~/.keras` is the cache_dir, `datasets` the
  cache_subdir. When the cache directory is not writable, `/tmp/.keras` is
  used in its place.

  Archives in tar, tar.gz, tar.bz, and zip formats can be extracted after
  the download. When a hash is given, a file that is already present in the
  cache is checked against it and downloaded again on mismatch. The command
  line programs `shasum` and `sha256sum` can compute the hash.

  Example:

  ```python
  path_to_downloaded_file = tf.keras.utils.get_file(
      "flower_photos",
      "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
      untar=True)
  ```

  Args:
      fname: Name of the file. If an absolute path `/path/to/file.txt` is
          specified the file will be saved at that location.
      origin: Original URL of the file.
      untar: Deprecated in favor of `extract` argument.
          boolean, whether the file should be decompressed. When set, the
          archive is stored as `<fname>.tar.gz` and the returned path is the
          one without that suffix.
      md5_hash: Deprecated in favor of `file_hash` argument.
          md5 hash of the file for verification. Only used when `file_hash`
          is None.
      file_hash: The expected hash string of the file after download.
          The sha256 and md5 hash algorithms are both supported.
      cache_subdir: Subdirectory under the Keras cache dir where the file is
          saved. If an absolute path `/path/to/folder` is
          specified the file will be saved at that location.
      hash_algorithm: Select the hash algorithm to verify the file.
          options are `'md5'`, `'sha256'`, and `'auto'`.
          The default 'auto' detects the hash algorithm in use.
      extract: True tries extracting the file as an Archive, like tar or zip.
      archive_format: Archive format to try for extracting the file.
          Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
          `'tar'` includes tar, tar.gz, and tar.bz files.
          The default `'auto'` corresponds to `['tar', 'zip']`.
          None or an empty list will return no matches found.
      cache_dir: Location to store cached files, when None it
          defaults to the default directory `~/.keras/`.

  Returns:
      Path to the downloaded file

  Raises:
      Exception: if fetching `origin` fails with an HTTP or URL error. Any
          partially downloaded file is removed first.
  """
  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'

  cache_root = os.path.expanduser(cache_dir)
  if not os.access(cache_root, os.W_OK):
    cache_root = os.path.join('/tmp', '.keras')
  datadir = os.path.join(cache_root, cache_subdir)
  _makedirs_exist_ok(datadir)

  fname = path_to_string(fname)

  if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  if not os.path.exists(fpath):
    download = True
  elif file_hash is not None and not validate_file(
      fpath, file_hash, algorithm=hash_algorithm):
    # A cached copy exists but does not match the expected hash.
    print('A local file was found, but it seems to be '
          'incomplete or outdated because the ' + hash_algorithm +
          ' file hash does not match the original value of ' + file_hash +
          ' so we will re-download the data.')
    download = True
  else:
    download = False

  if download:
    print('Downloading data from', origin)

    # Holds the progress bar for the lifetime of the download; it is created
    # lazily by the first report-hook call.
    tracker = {'progbar': None}

    def dl_progress(count, block_size, total_size):
      if tracker['progbar'] is None:
        if total_size == -1:
          total_size = None
        tracker['progbar'] = Progbar(total_size)
      else:
        tracker['progbar'].update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        urlretrieve(origin, fpath, dl_progress)
      except urllib.error.HTTPError as http_error:
        raise Exception(
            error_msg.format(origin, http_error.code, http_error.msg))
      except urllib.error.URLError as url_error:
        raise Exception(
            error_msg.format(origin, url_error.errno, url_error.reason))
    except (Exception, KeyboardInterrupt):
      # Never leave a truncated file in the cache.
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    tracker['progbar'] = None

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  return fpath
def _makedirs_exist_ok(datadir):
  """Creates `datadir` and its missing parents; an existing one is accepted."""
  os.makedirs(name=datadir, exist_ok=True)  # pylint: disable=unexpected-keyword-arg
def _resolve_hasher(algorithm, file_hash=None):
  """Returns a fresh `hashlib` hasher for the requested algorithm.

  Args:
    algorithm: `'sha256'` selects sha256. `'auto'` selects sha256 only when
      `file_hash` is given and is 64 characters long.
    file_hash: optional expected hash string, consulted for `'auto'`.

  Returns:
    A `hashlib.sha256()` object in the cases above, `hashlib.md5()` in every
    other case (kept for legacy purposes).
  """
  use_sha256 = algorithm == 'sha256' or (
      algorithm == 'auto' and file_hash is not None and len(file_hash) == 64)
  return hashlib.sha256() if use_sha256 else hashlib.md5()
def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
  """Computes the hex digest of a file.

  Example:

  ```python
  _hash_file('/path/to/file.zip')
  'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
  ```

  Args:
      fpath: path to the file being hashed
      algorithm: either the name of a hash algorithm (`'auto'`, `'sha256'`,
          or `'md5'`; the default is `'sha256'`) or an already constructed
          hasher object exposing `update` and `hexdigest`.
      chunk_size: Bytes to read at a time, important for large files.

  Returns:
      The file hash
  """
  hasher = _resolve_hasher(algorithm) if isinstance(algorithm, str) else algorithm

  with open(fpath, 'rb') as stream:
    while True:
      block_bytes = stream.read(chunk_size)
      if not block_bytes:
        # An empty read marks the end of the file.
        break
      hasher.update(block_bytes)

  return hasher.hexdigest()
def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
  """Checks a file against an expected sha256 or md5 hash.

  Args:
      fpath: path to the file being validated
      file_hash: The expected hash string of the file.
          The sha256 and md5 hash algorithms are both supported.
      algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
          The default 'auto' detects the hash algorithm in use.
      chunk_size: Bytes to read at a time, important for large files.

  Returns:
      Whether the file is valid
  """
  hasher = _resolve_hasher(algorithm, file_hash)
  actual_hash = str(_hash_file(fpath, hasher, chunk_size))
  return actual_hash == str(file_hash)
class ThreadsafeIter(object):
  """Iterator wrapper that serialises `next()` calls and replays failures.

  Every call into the wrapped iterator happens under a lock. The first
  exception raised by the wrapped iterator (including `StopIteration`) is
  remembered and raised again on every later call.
  """

  def __init__(self, it):
    # Once a generator has thrown, further next() calls on it simply raise
    # StopIteration. With several threads pulling from the same generator,
    # the thread that observes the real error need not be the one that made
    # the failing call appear last, so other threads could believe the
    # generator finished normally. Remembering the exception and raising it
    # for everyone avoids that.
    self._exception = None
    self.lock = threading.Lock()
    self.it = it

  def __iter__(self):
    return self

  def next(self):
    return self.__next__()

  def __next__(self):
    with self.lock:
      if self._exception:
        raise self._exception  # pylint: disable=raising-bad-type

      try:
        value = next(self.it)
      except Exception as error:
        self._exception = error
        raise
      return value
def threadsafe_generator(f):
  """Decorator wrapping the iterator returned by `f` in a `ThreadsafeIter`."""

  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    iterator = f(*args, **kwargs)
    return ThreadsafeIter(iterator)

  return wrapper
@keras_export('keras.utils.Sequence')
class Sequence(object):
  """Base object for fitting to a sequence of data, such as a dataset.

  Subclasses must provide `__getitem__`, returning one complete batch, and
  `__len__`, returning the number of batches. Override `on_epoch_end` to
  modify the dataset between epochs.

  Notes:

  `Sequence` are a safer way to do multiprocessing. This structure guarantees
  that the network will only train once
   on each sample per epoch which is not the case with generators.

  Examples:

  ```python
  from skimage.io import imread
  from skimage.transform import resize
  import numpy as np
  import math

  # Here, `x_set` is list of path to the images
  # and `y_set` are the associated classes.

  class CIFAR10Sequence(Sequence):

      def __init__(self, x_set, y_set, batch_size):
          self.x, self.y = x_set, y_set
          self.batch_size = batch_size

      def __len__(self):
          return math.ceil(len(self.x) / self.batch_size)

      def __getitem__(self, idx):
          batch_x = self.x[idx * self.batch_size:(idx + 1) *
          self.batch_size]
          batch_y = self.y[idx * self.batch_size:(idx + 1) *
          self.batch_size]

          return np.array([
              resize(imread(file_name), (200, 200))
                 for file_name in batch_x]), np.array(batch_y)
  ```
  """

  @abstractmethod
  def __getitem__(self, index):
    """Returns the batch stored at position `index`.

    Args:
        index: position of the batch in the Sequence.

    Returns:
        A batch
    """
    raise NotImplementedError

  @abstractmethod
  def __len__(self):
    """Returns how many batches the Sequence holds.

    Returns:
        The number of batches in the Sequence.
    """
    raise NotImplementedError

  def on_epoch_end(self):
    """Hook invoked at the end of every epoch; does nothing by default."""
    pass

  def __iter__(self):
    """Yields every batch once, from index 0 to `len(self) - 1`."""
    for batch_index in range(len(self)):
      yield self[batch_index]
def iter_sequence_infinite(seq):
  """Loops over a Sequence forever, restarting it each time it is exhausted.

  Args:
    seq: `Sequence` instance.

  Yields:
    Batches of data from the `Sequence`.
  """
  while True:
    one_pass = iter(seq)
    for batch in one_pass:
      yield batch
# Global variables to be shared across processes.
# Maps an enqueuer `uid` to the Sequence or generator it serves; written by
# the enqueuers and by the pool initializers, read by `get_index` and
# `next_sample`.
_SHARED_SEQUENCES = {}
# We use a Value to provide unique id to different processes.
# Created lazily by the first `SequenceEnqueuer`; falls back to a plain int
# when the OS refuses to create a `multiprocessing.Value`.
_SEQUENCE_COUNTER = None


# Because multiprocessing pools are inherently unsafe, starting from a clean
# state can be essential to avoiding deadlocks. In order to accomplish this, we
# need to be able to check on the status of Pools that we create.
# Weak references only: a pool drops out of the set once nothing else holds it.
_DATA_POOLS = weakref.WeakSet()
_WORKER_ID_QUEUE = None  # Only created if needed.
# NOTE(review): not referenced in the visible part of this file; presumably
# holds worker ids drained from `_WORKER_ID_QUEUE` -- TODO confirm.
_WORKER_IDS = set()
# When True, `get_pool_class` hands out thread pools even if multiprocessing
# was requested.
_FORCE_THREADPOOL = False
# Held by `dont_use_multiprocessing_pool` while it overrides the flag above.
_FORCE_THREADPOOL_LOCK = threading.RLock()
def dont_use_multiprocessing_pool(f):
  """Decorator forcing thread pools for the duration of each call to `f`.

  While the wrapped function runs, `_FORCE_THREADPOOL` is True, so
  `get_pool_class` returns the thread-based pool regardless of its argument.
  The previous value of the flag is restored when the call ends, whether it
  returns or raises.

  Args:
    f: the callable to wrap.

  Returns:
    The wrapped callable; it returns whatever `f` returns.
  """
  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    global _FORCE_THREADPOOL
    with _FORCE_THREADPOOL_LOCK:
      old_force_threadpool, _FORCE_THREADPOOL = _FORCE_THREADPOOL, True
      try:
        return f(*args, **kwargs)
      finally:
        # Restore the flag even if `f` raised; otherwise one failing call
        # would leave thread pools forced for the rest of the process.
        _FORCE_THREADPOOL = old_force_threadpool
  return wrapped
def get_pool_class(use_multiprocessing):
  """Picks the pool class to instantiate.

  Args:
    use_multiprocessing: whether a process-based pool is wanted.

  Returns:
    `multiprocessing.Pool` when `use_multiprocessing` is truthy and thread
    pools are not being forced; `multiprocessing.dummy.Pool` otherwise.
  """
  if use_multiprocessing and not _FORCE_THREADPOOL:
    return multiprocessing.Pool
  return multiprocessing.dummy.Pool  # ThreadPool
def get_worker_id_queue():
  """Returns the queue tracking worker ids, creating it on first use."""
  global _WORKER_ID_QUEUE
  if _WORKER_ID_QUEUE is not None:
    return _WORKER_ID_QUEUE
  _WORKER_ID_QUEUE = multiprocessing.Queue()
  return _WORKER_ID_QUEUE
def init_pool(seqs):
  """Replaces the module-level `_SHARED_SEQUENCES` with `seqs`."""
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = seqs
def get_index(uid, i):
  """Fetches item `i` of the shared Sequence registered under `uid`.

  Several Sequences can be in use at once, so each one is looked up by its
  own `uid`; with a single shared slot the validation Sequence would
  overwrite the training one.

  Args:
      uid: int, Sequence identifier
      i: index

  Returns:
      The value at index `i`.
  """
  sequence = _SHARED_SEQUENCES[uid]
  return sequence[i]
@keras_export('keras.utils.SequenceEnqueuer')
class SequenceEnqueuer(object):
  """Base class to enqueue inputs.

  An Enqueuer uses parallelism, through processes or threads, to speed up
  preprocessing.

  Example:

  ```python
      enqueuer = SequenceEnqueuer(...)
      enqueuer.start()
      datas = enqueuer.get()
      for data in datas:
          # Use the inputs; training, evaluating, predicting.
          # ... stop sometime.
      enqueuer.stop()
  ```

  The `enqueuer.get()` should be an infinite stream of datas.
  """

  def __init__(self, sequence,
               use_multiprocessing=False):
    self.sequence = sequence
    self.use_multiprocessing = use_multiprocessing

    global _SEQUENCE_COUNTER
    if _SEQUENCE_COUNTER is None:
      try:
        _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
      except OSError:
        # The OS does not let us use multiprocessing here, so a plain int
        # is used to number the enqueuers instead.
        _SEQUENCE_COUNTER = 0

    if not isinstance(_SEQUENCE_COUNTER, int):
      # Doing Multiprocessing.Value += x is not process-safe.
      with _SEQUENCE_COUNTER.get_lock():
        self.uid = _SEQUENCE_COUNTER.value
        _SEQUENCE_COUNTER.value += 1
    else:
      self.uid = _SEQUENCE_COUNTER
      _SEQUENCE_COUNTER += 1

    # Everything below is populated by `start()`.
    self.workers = 0
    self.executor_fn = None
    self.queue = None
    self.run_thread = None
    self.stop_signal = None

  def is_running(self):
    """Returns True once started and until the stop signal has been set."""
    signal = self.stop_signal
    if signal is None:
      return False
    return not signal.is_set()

  def start(self, workers=1, max_queue_size=10):
    """Starts the handler's workers.

    Args:
        workers: Number of workers.
        max_queue_size: queue size
            (when full, workers could block on `put()`)
    """
    if not self.use_multiprocessing:
      # Threads need no pool initializer.
      def thread_pool_fn(_):
        return get_pool_class(False)(workers)

      self.executor_fn = thread_pool_fn
    else:
      self.executor_fn = self._get_executor_init(workers)
    self.workers = workers
    self.queue = queue.Queue(max_queue_size)
    self.stop_signal = threading.Event()
    self.run_thread = threading.Thread(target=self._run)
    self.run_thread.daemon = True
    self.run_thread.start()

  def _send_sequence(self):
    """Publishes the current sequence under this enqueuer's uid."""
    # For new processes that may spawn
    _SHARED_SEQUENCES[self.uid] = self.sequence

  def stop(self, timeout=None):
    """Stops running threads and wait for them to exit, if necessary.

    Should be called by the same thread which called `start()`.

    Args:
        timeout: maximum time to wait on `thread.join()`
    """
    self.stop_signal.set()
    with self.queue.mutex:
      # Empty the queue and wake a producer that may be blocked in `put()`.
      self.queue.queue.clear()
      self.queue.unfinished_tasks = 0
      self.queue.not_full.notify()
    self.run_thread.join(timeout)
    _SHARED_SEQUENCES[self.uid] = None

  def __del__(self):
    if self.is_running():
      self.stop()

  @abstractmethod
  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    raise NotImplementedError

  @abstractmethod
  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
        workers: Number of workers.

    Returns:
        Function, a Function to initialize the pool
    """
    raise NotImplementedError

  @abstractmethod
  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Returns:
        Generator yielding tuples `(inputs, targets)`
            or `(inputs, targets, sample_weights)`.
    """
    raise NotImplementedError
@keras_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
  """Builds a Enqueuer from a Sequence.

  Args:
      sequence: A `tf.keras.utils.data_utils.Sequence` object.
      use_multiprocessing: use multiprocessing if True, otherwise threading
      shuffle: whether to shuffle the data at the beginning of each epoch
  """

  def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
    super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.shuffle = shuffle

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
        workers: Number of workers.

    Returns:
        Function, a Function to initialize the pool
    """
    def pool_fn(seqs):
      pool_class = get_pool_class(True)
      pool = pool_class(
          workers, initializer=init_pool_generator,
          initargs=(seqs, None, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool

    return pool_fn

  def _wait_queue(self):
    """Polls until the queue has no unfinished task or a stop is requested."""
    while True:
      time.sleep(0.1)
      if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
        return

  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    batch_indices = list(range(len(self.sequence)))
    self._send_sequence()  # Share the initial sequence
    while True:
      if self.shuffle:
        random.shuffle(batch_indices)

      # A new pool is created for every epoch and closed when it is over.
      with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
        for batch_index in batch_indices:
          if self.stop_signal.is_set():
            return

          future = executor.apply_async(get_index, (self.uid, batch_index))
          self.queue.put(future, block=True)

        # Done with the current epoch, waiting for the final batches
        self._wait_queue()

        if self.stop_signal.is_set():
          # We're done
          return

      # Call the internal on epoch end.
      self.sequence.on_epoch_end()
      self._send_sequence()  # Update the pool

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    while self.is_running():
      try:
        future = self.queue.get(block=True, timeout=5)
        inputs = future.get()
        if self.is_running():
          self.queue.task_done()
        if inputs is not None:
          yield inputs
      except queue.Empty:
        # Nothing arrived within the timeout; check `is_running` again.
        continue
      except Exception as e:  # pylint: disable=broad-except
        self.stop()
        raise e
def init_pool_generator(gens, random_seed=None, id_queue=None):
  """Initializer function for pool workers.

  Args:
    gens: State which should be made available to worker processes.
    random_seed: An optional value with which to seed child processes. The
      NumPy seed actually used is `random_seed` plus the process ident.
    id_queue: A multiprocessing Queue of worker ids. This is used to indicate
      that a worker process was created by Keras and can be terminated using
      the cleanup_all_keras_forkpools utility.
  """
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = gens

  current = multiprocessing.current_process()

  # name isn't used for anything, but setting a more descriptive name is helpful
  # when diagnosing orphaned processes.
  current.name = 'Keras_worker_{}'.format(current.name)

  if random_seed is not None:
    worker_seed = random_seed + current.ident
    np.random.seed(worker_seed)

  if id_queue is not None:
    # If a worker dies during init, the pool will just create a replacement.
    id_queue.put(current.ident, block=True, timeout=0.1)
def next_sample(uid):
  """Advances the shared generator registered under `uid` by one step.

  Several generators can be in use at once, so each one is looked up by its
  own `uid`; with a single shared slot the validation generator would
  overwrite the training one.

  Args:
      uid: int, generator identifier

  Returns:
      The next value of generator `uid`.
  """
  generator = _SHARED_SEQUENCES[uid]
  return next(generator)
@keras_export('keras.utils.GeneratorEnqueuer')
class GeneratorEnqueuer(SequenceEnqueuer):
  """Builds a queue out of a data generator.

  The provided generator can be finite in which case the class will throw
  a `StopIteration` exception.

  Args:
      generator: a generator function which yields data
      use_multiprocessing: use multiprocessing if True, otherwise threading
      random_seed: Initial seed for workers,
          will be incremented by one for each worker.
  """

  def __init__(self, generator,
               use_multiprocessing=False,
               random_seed=None):
    super(GeneratorEnqueuer, self).__init__(generator, use_multiprocessing)
    self.random_seed = random_seed

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
        A Function to initialize the pool
    """
    def pool_fn(seqs):
      pool_class = get_pool_class(True)
      pool = pool_class(
          workers, initializer=init_pool_generator,
          initargs=(seqs, self.random_seed, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool
    return pool_fn

  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    self._send_sequence()  # Share the initial generator
    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
      while not self.stop_signal.is_set():
        future = executor.apply_async(next_sample, (self.uid,))
        self.queue.put(future, block=True)

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        future = self.queue.get(block=True)
        inputs = future.get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except StopIteration:
      # Special case for finite generators: drain what is already queued.
      pending = []
      while self.queue.qsize() > 0:
        pending.append(self.queue.get(block=True))
      # Wait for them to complete
      for future in pending:
        future.wait()
      # Keep the good ones
      remaining = [future.get() for future in pending if future.successful()]
      for inputs in remaining:
        if inputs is not None:
          yield inputs
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      if 'generator already executing' in str(e):
        raise RuntimeError(
            'Your generator is NOT thread-safe. '
            'Keras requires a thread-safe generator when '
            '`use_multiprocessing=False, workers > 1`. ')
      raise e