Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/data_utils.py: 22% of 449 statements
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for file download and caching."""

import functools
import hashlib
import multiprocessing.dummy
import os
import pathlib
import queue
import random
import shutil
import tarfile
import threading
import time
import typing
import urllib
import warnings
import weakref
import zipfile
from abc import abstractmethod
from contextlib import closing

import numpy as np
import tensorflow.compat.v2 as tf
from six.moves.urllib.parse import urlsplit

from keras.src.utils import io_utils
from keras.src.utils import tf_inspect
from keras.src.utils.generic_utils import Progbar

# isort: off
from tensorflow.python.util.tf_export import keras_export
from six.moves.urllib.request import urlopen

# Required to support google internal urlretrieve
if True:  # This gets transformed to `if sys.version_info[0] == 2:` in OSS.

    def urlretrieve(url, filename, reporthook=None, data=None):
        """Replacement for `urlretrieve` for Python 2.

        Under Python 2, `urlretrieve` relies on `FancyURLopener` from legacy
        `urllib` module, known to have issues with proxy management.

        Args:
            url: url to retrieve.
            filename: where to store the retrieved data locally.
            reporthook: a hook function that will be called once on
                establishment of the network connection and once after each
                block read thereafter. The hook will be passed three
                arguments; a count of blocks transferred so far, a block size
                in bytes, and the total size of the file.
            data: `data` argument passed to `urlopen`.
        """

        def chunk_read(response, chunk_size=8192, reporthook=None):
            content_type = response.info().get("Content-Length")
            total_size = -1
            if content_type is not None:
                total_size = int(content_type.strip())
            count = 0
            while True:
                chunk = response.read(chunk_size)
                count += 1
                if reporthook is not None:
                    reporthook(count, chunk_size, total_size)
                if chunk:
                    yield chunk
                else:
                    break

        response = urlopen(url, data)
        with open(filename, "wb") as fd:
            for chunk in chunk_read(response, reporthook=reporthook):
                fd.write(chunk)

else:
    from urllib.request import urlretrieve

def is_generator_or_sequence(x):
    """Check if `x` is a Keras generator type."""
    builtin_iterators = (str, list, tuple, dict, set, frozenset)
    if isinstance(x, (tf.Tensor, np.ndarray) + builtin_iterators):
        return False
    return (
        tf_inspect.isgenerator(x)
        or isinstance(x, Sequence)
        or isinstance(x, typing.Iterator)
    )

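# Illustrative sketch, not part of the original module: how
# `is_generator_or_sequence` classifies a few common inputs. Kept behind a
# helper (hypothetical name) so importing this file does not execute it.
def _example_is_generator_or_sequence_checks():
    def _gen():
        yield 1

    assert not is_generator_or_sequence([1, 2, 3])  # builtin container
    assert not is_generator_or_sequence(np.arange(3))  # ndarray
    assert is_generator_or_sequence(_gen())  # generator object
    assert is_generator_or_sequence(iter([1, 2, 3]))  # typing.Iterator
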
def _resolve_path(path):
    return os.path.realpath(os.path.abspath(path))


def _is_path_in_dir(path, base_dir):
    return _resolve_path(os.path.join(base_dir, path)).startswith(base_dir)


def _is_link_in_dir(info, base):
    tip = _resolve_path(os.path.join(base, os.path.dirname(info.name)))
    return _is_path_in_dir(info.linkname, base_dir=tip)


def _filter_safe_paths(members):
    base_dir = _resolve_path(".")
    for finfo in members:
        valid_path = False
        if _is_path_in_dir(finfo.name, base_dir):
            valid_path = True
            yield finfo
        elif finfo.issym() or finfo.islnk():
            if _is_link_in_dir(finfo, base_dir):
                valid_path = True
                yield finfo
        if not valid_path:
            warnings.warn(
                "Skipping invalid path during archive extraction: "
                f"'{finfo.name}'."
            )

def _extract_archive(file_path, path=".", archive_format="auto"):
    """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.

    Args:
        file_path: Path to the archive file.
        path: Where to extract the archive file.
        archive_format: Archive format to try for extracting the file.
            Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
            `'tar'` includes tar, tar.gz, and tar.bz files.
            The default `'auto'` is `['tar', 'zip']`.
            `None` or an empty list will return no matches found.

    Returns:
        True if a match was found and an archive extraction was completed,
        False otherwise.
    """
    if archive_format is None:
        return False
    if archive_format == "auto":
        archive_format = ["tar", "zip"]
    if isinstance(archive_format, str):
        archive_format = [archive_format]

    file_path = io_utils.path_to_string(file_path)
    path = io_utils.path_to_string(path)

    for archive_type in archive_format:
        if archive_type == "tar":
            open_fn = tarfile.open
            is_match_fn = tarfile.is_tarfile
        if archive_type == "zip":
            open_fn = zipfile.ZipFile
            is_match_fn = zipfile.is_zipfile

        if is_match_fn(file_path):
            with open_fn(file_path) as archive:
                try:
                    if zipfile.is_zipfile(file_path):
                        # Zip archive.
                        archive.extractall(path)
                    else:
                        # Tar archive, perhaps unsafe. Filter paths.
                        archive.extractall(
                            path, members=_filter_safe_paths(archive)
                        )
                except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
                    if os.path.exists(path):
                        if os.path.isfile(path):
                            os.remove(path)
                        else:
                            shutil.rmtree(path)
                    raise
            return True
    return False

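# Illustrative sketch, not part of the original module: building a small
# tar.gz archive in a temporary directory and extracting it with
# `_extract_archive`. File names are hypothetical.
def _example_extract_archive():
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "payload.txt")
        with open(src, "w") as f:
            f.write("hello")
        archive_path = os.path.join(tmp, "payload.tar.gz")
        with tarfile.open(archive_path, "w:gz") as tar:
            tar.add(src, arcname="payload.txt")
        out_dir = os.path.join(tmp, "extracted")
        # Returns True because the file matches the "tar" format.
        assert _extract_archive(archive_path, out_dir, archive_format="auto")
        assert os.path.exists(os.path.join(out_dir, "payload.txt"))
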
@keras_export("keras.utils.get_file")
def get_file(
    fname=None,
    origin=None,
    untar=False,
    md5_hash=None,
    file_hash=None,
    cache_subdir="datasets",
    hash_algorithm="auto",
    extract=False,
    archive_format="auto",
    cache_dir=None,
):
    """Downloads a file from a URL if it is not already in the cache.

    By default the file at the url `origin` is downloaded to the
    cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
    and given the filename `fname`. The final location of a file
    `example.txt` would therefore be `~/.keras/datasets/example.txt`.

    Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
    Passing a hash will verify the file after download. The command line
    programs `shasum` and `sha256sum` can compute the hash.

    Example:

    ```python
    path_to_downloaded_file = tf.keras.utils.get_file(
        origin="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
        extract=True,
    )
    ```

    Args:
        fname: Name of the file. If an absolute path `/path/to/file.txt` is
            specified, the file will be saved at that location. If `None`, the
            name of the file at `origin` will be used.
        origin: Original URL of the file.
        untar: Deprecated in favor of the `extract` argument.
            Boolean, whether the file should be decompressed.
        md5_hash: Deprecated in favor of the `file_hash` argument.
            md5 hash of the file for verification.
        file_hash: The expected hash string of the file after download.
            The sha256 and md5 hash algorithms are both supported.
        cache_subdir: Subdirectory under the Keras cache dir where the file is
            saved. If an absolute path `/path/to/folder` is
            specified, the file will be saved at that location.
        hash_algorithm: Select the hash algorithm to verify the file.
            Options are `'md5'`, `'sha256'`, and `'auto'`.
            The default `'auto'` detects the hash algorithm in use.
        extract: True tries extracting the file as an Archive, like tar or
            zip.
        archive_format: Archive format to try for extracting the file.
            Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
            `'tar'` includes tar, tar.gz, and tar.bz files.
            The default `'auto'` corresponds to `['tar', 'zip']`.
            `None` or an empty list will return no matches found.
        cache_dir: Location to store cached files, when `None` it
            defaults to the default directory `~/.keras/`.

    Returns:
        Path to the downloaded file.

    ⚠️ **Warning on malicious downloads** ⚠️

    Downloading something from the Internet carries a risk.
    NEVER download a file/archive if you do not trust the source.
    We recommend that you specify the `file_hash` argument
    (if the hash of the source file is known) to make sure that the file you
    are getting is the one you expect.
    """
    if origin is None:
        raise ValueError(
            'Please specify the "origin" argument (URL of the file '
            "to download)."
        )

    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser("~"), ".keras")
    if md5_hash is not None and file_hash is None:
        file_hash = md5_hash
        hash_algorithm = "md5"
    datadir_base = os.path.expanduser(cache_dir)
    if not os.access(datadir_base, os.W_OK):
        datadir_base = os.path.join("/tmp", ".keras")
    datadir = os.path.join(datadir_base, cache_subdir)
    _makedirs_exist_ok(datadir)

    fname = io_utils.path_to_string(fname)
    if not fname:
        fname = os.path.basename(urlsplit(origin).path)
        if not fname:
            raise ValueError(
                "Can't parse the file name from the origin provided: "
                f"'{origin}'. "
                "Please specify the `fname` as the input param."
            )

    if untar:
        if fname.endswith(".tar.gz"):
            fname = pathlib.Path(fname)
            # The 2 `.with_suffix()` are because of `.tar.gz` as pathlib
            # considers it as 2 suffixes.
            fname = fname.with_suffix("").with_suffix("")
            fname = str(fname)
        untar_fpath = os.path.join(datadir, fname)
        fpath = untar_fpath + ".tar.gz"
    else:
        fpath = os.path.join(datadir, fname)

    download = False
    if os.path.exists(fpath):
        # File found; verify integrity if a hash was provided.
        if file_hash is not None:
            if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
                io_utils.print_msg(
                    "A local file was found, but it seems to be "
                    f"incomplete or outdated because the {hash_algorithm} "
                    "file hash does not match the original value of "
                    f"{file_hash} "
                    "so we will re-download the data."
                )
                download = True
    else:
        download = True

    if download:
        io_utils.print_msg(f"Downloading data from {origin}")

        class DLProgbar:
            """Manage progress bar state for use in urlretrieve."""

            def __init__(self):
                self.progbar = None
                self.finished = False

            def __call__(self, block_num, block_size, total_size):
                if not self.progbar:
                    if total_size == -1:
                        total_size = None
                    self.progbar = Progbar(total_size)
                current = block_num * block_size

                if total_size is None:
                    self.progbar.update(current)
                else:
                    if current < total_size:
                        self.progbar.update(current)
                    elif not self.finished:
                        self.progbar.update(self.progbar.target)
                        self.finished = True

        error_msg = "URL fetch failure on {}: {} -- {}"
        try:
            try:
                urlretrieve(origin, fpath, DLProgbar())
            except urllib.error.HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg))
            except urllib.error.URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason))
        except (Exception, KeyboardInterrupt):
            if os.path.exists(fpath):
                os.remove(fpath)
            raise

        # Validate download if succeeded and user provided an expected hash
        # Security conscious users would get the hash of the file from a
        # separate channel and pass it to this API to prevent MITM /
        # corruption:
        if os.path.exists(fpath) and file_hash is not None:
            if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
                raise ValueError(
                    "Incomplete or corrupted file detected. "
                    f"The {hash_algorithm} "
                    "file hash does not match the provided value "
                    f"of {file_hash}."
                )

    if untar:
        if not os.path.exists(untar_fpath):
            _extract_archive(fpath, datadir, archive_format="tar")
        return untar_fpath

    if extract:
        _extract_archive(fpath, datadir, archive_format)

    return fpath

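# Illustrative sketch, not part of the original module: fetching a file into a
# custom cache directory and verifying it against a known sha256 hash. The
# URL, file name, hash, and cache path below are placeholders, not real
# values; the call only succeeds when pointed at a real URL and its actual
# hash, which should come from a trusted channel.
def _example_get_file_usage():
    path = get_file(
        fname="flowers.tgz",  # hypothetical file name
        origin="https://example.com/flowers.tgz",  # placeholder URL
        file_hash="0" * 64,  # placeholder sha256 hash (64 hex chars)
        extract=True,
        cache_dir="/tmp/keras_cache_example",  # hypothetical cache location
    )
    # `path` points at the downloaded archive; extracted contents live next
    # to it under the same cache subdirectory.
    return path
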
def _makedirs_exist_ok(datadir):
    os.makedirs(datadir, exist_ok=True)


def _resolve_hasher(algorithm, file_hash=None):
    """Returns hash algorithm as hashlib function."""
    if algorithm == "sha256":
        return hashlib.sha256()

    if algorithm == "auto" and file_hash is not None and len(file_hash) == 64:
        return hashlib.sha256()

    # This is used only for legacy purposes.
    return hashlib.md5()

def _hash_file(fpath, algorithm="sha256", chunk_size=65535):
    """Calculates a file sha256 or md5 hash.

    Example:

    ```python
    _hash_file('/path/to/file.zip')
    'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
    ```

    Args:
        fpath: Path to the file being validated.
        algorithm: Hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`.
            `'auto'` detects the hash algorithm in use; defaults to
            `'sha256'`.
        chunk_size: Bytes to read at a time, important for large files.

    Returns:
        The file hash.
    """
    if isinstance(algorithm, str):
        hasher = _resolve_hasher(algorithm)
    else:
        hasher = algorithm

    with open(fpath, "rb") as fpath_file:
        for chunk in iter(lambda: fpath_file.read(chunk_size), b""):
            hasher.update(chunk)

    return hasher.hexdigest()

def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535):
    """Validates a file against a sha256 or md5 hash.

    Args:
        fpath: path to the file being validated
        file_hash: The expected hash string of the file.
            The sha256 and md5 hash algorithms are both supported.
        algorithm: Hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`.
            The default `'auto'` detects the hash algorithm in use.
        chunk_size: Bytes to read at a time, important for large files.

    Returns:
        Whether the file is valid.
    """
    hasher = _resolve_hasher(algorithm, file_hash)

    if str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash):
        return True
    else:
        return False

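# Illustrative sketch, not part of the original module: computing a file's
# sha256 with `_hash_file` and round-tripping it through `validate_file`. The
# file written below is a throwaway temporary file.
def _example_validate_file():
    import tempfile

    with tempfile.NamedTemporaryFile("wb", delete=False, suffix=".bin") as f:
        f.write(b"some payload")
        fpath = f.name
    try:
        digest = _hash_file(fpath, algorithm="sha256")
        # A 64-character hash is auto-detected as sha256 by `_resolve_hasher`.
        assert validate_file(fpath, digest, algorithm="auto")
        assert not validate_file(fpath, "0" * 64, algorithm="auto")
    finally:
        os.remove(fpath)
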
class ThreadsafeIter:
    """Wrap an iterator with a lock and propagate exceptions to all threads."""

    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

        # After a generator throws an exception all subsequent next() calls
        # raise a StopIteration Exception. This, however, presents an issue
        # when mixing generators and threading because it means the order of
        # retrieval need not match the order in which the generator was
        # called. This can make it appear that a generator exited normally
        # when in fact the terminating exception is just in a different
        # thread. In order to provide thread safety, once self.it has thrown
        # an exception we continue to throw the same exception.
        self._exception = None

    def __iter__(self):
        return self

    def next(self):
        return self.__next__()

    def __next__(self):
        with self.lock:
            if self._exception:
                raise self._exception

            try:
                return next(self.it)
            except Exception as e:
                self._exception = e
                raise


def threadsafe_generator(f):
    @functools.wraps(f)
    def g(*a, **kw):
        return ThreadsafeIter(f(*a, **kw))

    return g

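# Illustrative sketch, not part of the original module: wrapping a generator
# with `threadsafe_generator` so several threads can pull from it without
# entering the generator frame concurrently. Names are hypothetical.
@threadsafe_generator
def _example_counter(n):
    for i in range(n):
        yield i


def _example_threadsafe_consumption():
    it = _example_counter(100)
    results = []

    def worker():
        for value in it:
            results.append(value)

    threads = [threading.Thread(target=worker) for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # Every value is yielded exactly once across all threads.
    assert sorted(results) == list(range(100))
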
@keras_export("keras.utils.Sequence")
class Sequence:
    """Base object for fitting to a sequence of data, such as a dataset.

    Every `Sequence` must implement the `__getitem__` and the `__len__`
    methods. If you want to modify your dataset between epochs, you may
    implement `on_epoch_end`. The method `__getitem__` should return a
    complete batch.

    Notes:

    `Sequence` is a safer way to do multiprocessing. This structure guarantees
    that the network will only train once on each sample per epoch, which is
    not the case with generators.

    Examples:

    ```python
    from skimage.io import imread
    from skimage.transform import resize
    import numpy as np
    import math

    # Here, `x_set` is a list of paths to the images
    # and `y_set` are the associated classes.

    class CIFAR10Sequence(tf.keras.utils.Sequence):

        def __init__(self, x_set, y_set, batch_size):
            self.x, self.y = x_set, y_set
            self.batch_size = batch_size

        def __len__(self):
            return math.ceil(len(self.x) / self.batch_size)

        def __getitem__(self, idx):
            low = idx * self.batch_size
            # Cap upper bound at array length; the last batch may be smaller
            # if the total number of items is not a multiple of batch size.
            high = min(low + self.batch_size, len(self.x))
            batch_x = self.x[low:high]
            batch_y = self.y[low:high]

            return np.array([
                resize(imread(file_name), (200, 200))
                for file_name in batch_x]), np.array(batch_y)
    ```
    """

    @abstractmethod
    def __getitem__(self, index):
        """Gets batch at position `index`.

        Args:
            index: position of the batch in the Sequence.

        Returns:
            A batch.
        """
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        """Number of batches in the Sequence.

        Returns:
            The number of batches in the Sequence.
        """
        raise NotImplementedError

    def on_epoch_end(self):
        """Method called at the end of every epoch."""
        pass

    def __iter__(self):
        """Creates a generator that iterates over the Sequence."""
        for item in (self[i] for i in range(len(self))):
            yield item

def iter_sequence_infinite(seq):
    """Iterates indefinitely over a Sequence.

    Args:
        seq: `Sequence` instance.

    Yields:
        Batches of data from the `Sequence`.
    """
    while True:
        for item in seq:
            yield item

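# Illustrative sketch, not part of the original module: a tiny in-memory
# `Sequence` and an infinite iterator over it. The class name and array
# shapes are hypothetical.
class _ExampleArraySequence(Sequence):
    def __init__(self, x, y, batch_size):
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        low = idx * self.batch_size
        high = min(low + self.batch_size, len(self.x))
        return self.x[low:high], self.y[low:high]


def _example_iter_sequence_infinite():
    seq = _ExampleArraySequence(np.arange(10), np.arange(10) * 2, batch_size=4)
    stream = iter_sequence_infinite(seq)
    # 3 batches per epoch (4 + 4 + 2 samples); the stream then wraps around.
    first_five = [next(stream) for _ in range(5)]
    return first_five
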
# Global variables to be shared across processes
_SHARED_SEQUENCES = {}
# We use a Value to provide unique id to different processes.
_SEQUENCE_COUNTER = None


# Because multiprocessing pools are inherently unsafe, starting from a clean
# state can be essential to avoiding deadlocks. In order to accomplish this,
# we need to be able to check on the status of Pools that we create.
_DATA_POOLS = weakref.WeakSet()
_WORKER_ID_QUEUE = None  # Only created if needed.
_WORKER_IDS = set()
_FORCE_THREADPOOL = False
_FORCE_THREADPOOL_LOCK = threading.RLock()


def dont_use_multiprocessing_pool(f):
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        with _FORCE_THREADPOOL_LOCK:
            global _FORCE_THREADPOOL
            old_force_threadpool, _FORCE_THREADPOOL = _FORCE_THREADPOOL, True
            out = f(*args, **kwargs)
            _FORCE_THREADPOOL = old_force_threadpool
            return out

    return wrapped


def get_pool_class(use_multiprocessing):
    global _FORCE_THREADPOOL
    if not use_multiprocessing or _FORCE_THREADPOOL:
        return multiprocessing.dummy.Pool  # ThreadPool
    return multiprocessing.Pool

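# Illustrative sketch, not part of the original module: how the thread-pool
# override interacts with `get_pool_class`. Inside a function decorated with
# `dont_use_multiprocessing_pool`, even `use_multiprocessing=True` resolves to
# a thread-backed pool. The function name is hypothetical.
@dont_use_multiprocessing_pool
def _example_pool_selection():
    # Forced to the thread-backed pool while the decorator is active.
    assert get_pool_class(True) is multiprocessing.dummy.Pool
    assert get_pool_class(False) is multiprocessing.dummy.Pool
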
def get_worker_id_queue():
    """Lazily create the queue to track worker ids."""
    global _WORKER_ID_QUEUE
    if _WORKER_ID_QUEUE is None:
        _WORKER_ID_QUEUE = multiprocessing.Queue()
    return _WORKER_ID_QUEUE


def init_pool(seqs):
    global _SHARED_SEQUENCES
    _SHARED_SEQUENCES = seqs


def get_index(uid, i):
    """Get the value from the Sequence `uid` at index `i`.

    To allow multiple Sequences to be used at the same time, we use `uid` to
    get a specific one. A single Sequence would cause the validation to
    overwrite the training Sequence.

    Args:
        uid: int, Sequence identifier
        i: index

    Returns:
        The value at index `i`.
    """
    return _SHARED_SEQUENCES[uid][i]

@keras_export("keras.utils.SequenceEnqueuer")
class SequenceEnqueuer:
    """Base class to enqueue inputs.

    The task of an Enqueuer is to use parallelism to speed up preprocessing.
    This is done with processes or threads.

    Example:

    ```python
    enqueuer = SequenceEnqueuer(...)
    enqueuer.start()
    datas = enqueuer.get()
    for data in datas:
        # Use the inputs; training, evaluating, predicting.
        # ... stop sometime.
    enqueuer.stop()
    ```

    The `enqueuer.get()` should be an infinite stream of data.
    """

    def __init__(self, sequence, use_multiprocessing=False):
        self.sequence = sequence
        self.use_multiprocessing = use_multiprocessing

        global _SEQUENCE_COUNTER
        if _SEQUENCE_COUNTER is None:
            try:
                _SEQUENCE_COUNTER = multiprocessing.Value("i", 0)
            except OSError:
                # In this case the OS does not allow us to use
                # multiprocessing. We resort to an int
                # for enqueuer indexing.
                _SEQUENCE_COUNTER = 0

        if isinstance(_SEQUENCE_COUNTER, int):
            self.uid = _SEQUENCE_COUNTER
            _SEQUENCE_COUNTER += 1
        else:
            # Doing Multiprocessing.Value += x is not process-safe.
            with _SEQUENCE_COUNTER.get_lock():
                self.uid = _SEQUENCE_COUNTER.value
                _SEQUENCE_COUNTER.value += 1

        self.workers = 0
        self.executor_fn = None
        self.queue = None
        self.run_thread = None
        self.stop_signal = None

    def is_running(self):
        return self.stop_signal is not None and not self.stop_signal.is_set()

    def start(self, workers=1, max_queue_size=10):
        """Starts the handler's workers.

        Args:
            workers: Number of workers.
            max_queue_size: queue size
                (when full, workers could block on `put()`)
        """
        if self.use_multiprocessing:
            self.executor_fn = self._get_executor_init(workers)
        else:
            # We do not need the init since it's threads.
            self.executor_fn = lambda _: get_pool_class(False)(workers)
        self.workers = workers
        self.queue = queue.Queue(max_queue_size)
        self.stop_signal = threading.Event()
        self.run_thread = threading.Thread(target=self._run)
        self.run_thread.daemon = True
        self.run_thread.start()

    def _send_sequence(self):
        """Sends current Iterable to all workers."""
        # For new processes that may spawn
        _SHARED_SEQUENCES[self.uid] = self.sequence

    def stop(self, timeout=None):
        """Stops running threads and waits for them to exit, if necessary.

        Should be called by the same thread which called `start()`.

        Args:
            timeout: maximum time to wait on `thread.join()`
        """
        self.stop_signal.set()
        with self.queue.mutex:
            self.queue.queue.clear()
            self.queue.unfinished_tasks = 0
            self.queue.not_full.notify()
        self.run_thread.join(timeout)
        _SHARED_SEQUENCES[self.uid] = None

    def __del__(self):
        if self.is_running():
            self.stop()

    @abstractmethod
    def _run(self):
        """Submits requests to the executor and queues the `Future` objects."""
        raise NotImplementedError

    @abstractmethod
    def _get_executor_init(self, workers):
        """Gets the Pool initializer for multiprocessing.

        Args:
            workers: Number of workers.

        Returns:
            Function, a Function to initialize the pool
        """
        raise NotImplementedError

    @abstractmethod
    def get(self):
        """Creates a generator to extract data from the queue.

        Skip the data if it is `None`.

        Returns:
            Generator yielding tuples `(inputs, targets)`
            or `(inputs, targets, sample_weights)`.
        """
        raise NotImplementedError

@keras_export("keras.utils.OrderedEnqueuer")
class OrderedEnqueuer(SequenceEnqueuer):
    """Builds an Enqueuer from a Sequence.

    Args:
        sequence: A `tf.keras.utils.data_utils.Sequence` object.
        use_multiprocessing: use multiprocessing if True, otherwise threading
        shuffle: whether to shuffle the data at the beginning of each epoch
    """

    def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
        super().__init__(sequence, use_multiprocessing)
        self.shuffle = shuffle

    def _get_executor_init(self, workers):
        """Gets the Pool initializer for multiprocessing.

        Args:
            workers: Number of workers.

        Returns:
            Function, a Function to initialize the pool
        """

        def pool_fn(seqs):
            pool = get_pool_class(True)(
                workers,
                initializer=init_pool_generator,
                initargs=(seqs, None, get_worker_id_queue()),
            )
            _DATA_POOLS.add(pool)
            return pool

        return pool_fn

    def _wait_queue(self):
        """Wait for the queue to be empty."""
        while True:
            time.sleep(0.1)
            if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
                return

    def _run(self):
        """Submits requests to the executor and queues the `Future` objects."""
        sequence = list(range(len(self.sequence)))
        self._send_sequence()  # Share the initial sequence
        while True:
            if self.shuffle:
                random.shuffle(sequence)

            with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
                for i in sequence:
                    if self.stop_signal.is_set():
                        return

                    self.queue.put(
                        executor.apply_async(get_index, (self.uid, i)),
                        block=True,
                    )

                # Done with the current epoch, waiting for the final batches
                self._wait_queue()

                if self.stop_signal.is_set():
                    # We're done
                    return

            # Call the internal on epoch end.
            self.sequence.on_epoch_end()
            self._send_sequence()  # Update the pool

    def get(self):
        """Creates a generator to extract data from the queue.

        Skip the data if it is `None`.

        Yields:
            The next element in the queue, i.e. a tuple
            `(inputs, targets)` or
            `(inputs, targets, sample_weights)`.
        """
        while self.is_running():
            try:
                inputs = self.queue.get(block=True, timeout=5).get()
                if self.is_running():
                    self.queue.task_done()
                if inputs is not None:
                    yield inputs
            except queue.Empty:
                pass
            except Exception as e:
                self.stop()
                raise e

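# Illustrative sketch, not part of the original module: draining a few batches
# from an `OrderedEnqueuer` wrapped around the hypothetical
# `_ExampleArraySequence` defined above, using threads rather than processes.
def _example_ordered_enqueuer():
    seq = _ExampleArraySequence(np.arange(12), np.arange(12) * 2, batch_size=4)
    enqueuer = OrderedEnqueuer(seq, use_multiprocessing=False, shuffle=False)
    enqueuer.start(workers=2, max_queue_size=4)
    try:
        stream = enqueuer.get()
        # Batches arrive in Sequence order because shuffle is disabled.
        batches = [next(stream) for _ in range(len(seq))]
    finally:
        enqueuer.stop()
    return batches
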
def init_pool_generator(gens, random_seed=None, id_queue=None):
    """Initializer function for pool workers.

    Args:
        gens: State which should be made available to worker processes.
        random_seed: An optional value with which to seed child processes.
        id_queue: A multiprocessing Queue of worker ids. This is used to
            indicate that a worker process was created by Keras and can be
            terminated using the cleanup_all_keras_forkpools utility.
    """
    global _SHARED_SEQUENCES
    _SHARED_SEQUENCES = gens

    worker_proc = multiprocessing.current_process()

    # name isn't used for anything, but setting a more descriptive name is
    # helpful when diagnosing orphaned processes.
    worker_proc.name = f"Keras_worker_{worker_proc.name}"

    if random_seed is not None:
        np.random.seed(random_seed + worker_proc.ident)

    if id_queue is not None:
        # If a worker dies during init, the pool will just create a
        # replacement.
        id_queue.put(worker_proc.ident, block=True, timeout=0.1)


def next_sample(uid):
    """Gets the next value from the generator `uid`.

    To allow multiple generators to be used at the same time, we use `uid` to
    get a specific one. A single generator would cause the validation to
    overwrite the training generator.

    Args:
        uid: int, generator identifier

    Returns:
        The next value of generator `uid`.
    """
    return next(_SHARED_SEQUENCES[uid])

@keras_export("keras.utils.GeneratorEnqueuer")
class GeneratorEnqueuer(SequenceEnqueuer):
    """Builds a queue out of a data generator.

    The provided generator can be finite, in which case the class will throw
    a `StopIteration` exception.

    Args:
        generator: a generator function which yields data
        use_multiprocessing: use multiprocessing if True, otherwise threading
        random_seed: Initial seed for workers,
            will be incremented by one for each worker.
    """

    def __init__(self, generator, use_multiprocessing=False, random_seed=None):
        super().__init__(generator, use_multiprocessing)
        self.random_seed = random_seed

    def _get_executor_init(self, workers):
        """Gets the Pool initializer for multiprocessing.

        Args:
            workers: Number of workers.

        Returns:
            A Function to initialize the pool
        """

        def pool_fn(seqs):
            pool = get_pool_class(True)(
                workers,
                initializer=init_pool_generator,
                initargs=(seqs, self.random_seed, get_worker_id_queue()),
            )
            _DATA_POOLS.add(pool)
            return pool

        return pool_fn

    def _run(self):
        """Submits requests to the executor and queues the `Future` objects."""
        self._send_sequence()  # Share the initial generator
        with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
            while True:
                if self.stop_signal.is_set():
                    return

                self.queue.put(
                    executor.apply_async(next_sample, (self.uid,)), block=True
                )

    def get(self):
        """Creates a generator to extract data from the queue.

        Skip the data if it is `None`.

        Yields:
            The next element in the queue, i.e. a tuple
            `(inputs, targets)` or
            `(inputs, targets, sample_weights)`.
        """
        try:
            while self.is_running():
                inputs = self.queue.get(block=True).get()
                self.queue.task_done()
                if inputs is not None:
                    yield inputs
        except StopIteration:
            # Special case for finite generators
            last_ones = []
            while self.queue.qsize() > 0:
                last_ones.append(self.queue.get(block=True))
            # Wait for them to complete
            for f in last_ones:
                f.wait()
            # Keep the good ones
            last_ones = [
                future.get() for future in last_ones if future.successful()
            ]
            for inputs in last_ones:
                if inputs is not None:
                    yield inputs
        except Exception as e:
            self.stop()
            if "generator already executing" in str(e):
                raise RuntimeError(
                    "Your generator is NOT thread-safe. "
                    "Keras requires a thread-safe generator when "
                    "`use_multiprocessing=False, workers > 1`. "
                )
            raise e

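# Illustrative sketch, not part of the original module: feeding an infinite
# generator through a `GeneratorEnqueuer` with threads. The generator is made
# thread-safe so several workers can pull from it concurrently; names and
# shapes are hypothetical.
def _example_generator_enqueuer():
    @threadsafe_generator
    def _batches():
        i = 0
        while True:
            yield np.full((4, 2), i), np.full((4,), i)
            i += 1

    enqueuer = GeneratorEnqueuer(_batches(), use_multiprocessing=False)
    enqueuer.start(workers=2, max_queue_size=4)
    try:
        stream = enqueuer.get()
        first = [next(stream) for _ in range(3)]
    finally:
        enqueuer.stop()
    return first
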
@keras_export(
    "keras.utils.pad_sequences", "keras.preprocessing.sequence.pad_sequences"
)
def pad_sequences(
    sequences,
    maxlen=None,
    dtype="int32",
    padding="pre",
    truncating="pre",
    value=0.0,
):
    """Pads sequences to the same length.

    This function transforms a list (of length `num_samples`)
    of sequences (lists of integers)
    into a 2D Numpy array of shape `(num_samples, num_timesteps)`.
    `num_timesteps` is either the `maxlen` argument if provided,
    or the length of the longest sequence in the list.

    Sequences that are shorter than `num_timesteps`
    are padded with `value` until they are `num_timesteps` long.

    Sequences longer than `num_timesteps` are truncated
    so that they fit the desired length.

    The position where padding or truncation happens is determined by
    the arguments `padding` and `truncating`, respectively.
    Pre-padding or removing values from the beginning of the sequence is the
    default.

    >>> sequence = [[1], [2, 3], [4, 5, 6]]
    >>> tf.keras.utils.pad_sequences(sequence)
    array([[0, 0, 1],
           [0, 2, 3],
           [4, 5, 6]], dtype=int32)

    >>> tf.keras.utils.pad_sequences(sequence, value=-1)
    array([[-1, -1,  1],
           [-1,  2,  3],
           [ 4,  5,  6]], dtype=int32)

    >>> tf.keras.utils.pad_sequences(sequence, padding='post')
    array([[1, 0, 0],
           [2, 3, 0],
           [4, 5, 6]], dtype=int32)

    >>> tf.keras.utils.pad_sequences(sequence, maxlen=2)
    array([[0, 1],
           [2, 3],
           [5, 6]], dtype=int32)

    Args:
        sequences: List of sequences (each sequence is a list of integers).
        maxlen: Optional Int, maximum length of all sequences. If not
            provided, sequences will be padded to the length of the longest
            individual sequence.
        dtype: (Optional, defaults to `"int32"`). Type of the output
            sequences. To pad sequences with variable length strings, you can
            use `object`.
        padding: String, "pre" or "post" (optional, defaults to `"pre"`):
            pad either before or after each sequence.
        truncating: String, "pre" or "post" (optional, defaults to `"pre"`):
            remove values from sequences larger than
            `maxlen`, either at the beginning or at the end of the sequences.
        value: Float or String, padding value. (Optional, defaults to 0.)

    Returns:
        Numpy array with shape `(len(sequences), maxlen)`

    Raises:
        ValueError: In case of invalid values for `truncating` or `padding`,
            or in case of invalid shape for a `sequences` entry.
    """
    if not hasattr(sequences, "__len__"):
        raise ValueError("`sequences` must be iterable.")
    num_samples = len(sequences)

    lengths = []
    sample_shape = ()
    flag = True

    # take the sample shape from the first non empty sequence
    # checking for consistency in the main loop below.

    for x in sequences:
        try:
            lengths.append(len(x))
            if flag and len(x):
                sample_shape = np.asarray(x).shape[1:]
                flag = False
        except TypeError as e:
            raise ValueError(
                "`sequences` must be a list of iterables. "
                f"Found non-iterable: {str(x)}"
            ) from e

    if maxlen is None:
        maxlen = np.max(lengths)

    is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(
        dtype, np.unicode_
    )
    if isinstance(value, str) and dtype != object and not is_dtype_str:
        raise ValueError(
            f"`dtype` {dtype} is not compatible with `value`'s type: "
            f"{type(value)}\nYou should set `dtype=object` for variable "
            "length strings."
        )

    x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
    for idx, s in enumerate(sequences):
        if not len(s):
            continue  # empty list/array was found
        if truncating == "pre":
            trunc = s[-maxlen:]
        elif truncating == "post":
            trunc = s[:maxlen]
        else:
            raise ValueError(f'Truncating type "{truncating}" not understood')

        # check `trunc` has expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError(
                f"Shape of sample {trunc.shape[1:]} of sequence at "
                f"position {idx} is different from expected shape "
                f"{sample_shape}"
            )

        if padding == "post":
            x[idx, : len(trunc)] = trunc
        elif padding == "pre":
            x[idx, -len(trunc) :] = trunc
        else:
            raise ValueError(f'Padding type "{padding}" not understood')
    return x

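# Illustrative sketch, not part of the original module: padding variable
# length string sequences, which requires `dtype=object` together with a
# string padding value.
def _example_pad_string_sequences():
    sequences = [["the", "cat"], ["a", "dog", "barked"]]
    padded = pad_sequences(
        sequences, dtype=object, value="<pad>", padding="post"
    )
    # array([['the', 'cat', '<pad>'],
    #        ['a', 'dog', 'barked']], dtype=object)
    return padded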