Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/preprocessing/image.py: 14%
737 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
17"""Utilities for image preprocessing and augmentation.
19Deprecated: `tf.keras.preprocessing.image` APIs do not operate on tensors and
20are not recommended for new code. Prefer loading data with
21`tf.keras.utils.image_dataset_from_directory`, and then transforming the output
22`tf.data.Dataset` with preprocessing layers. For more information, see the
23tutorials for [loading images](
24https://www.tensorflow.org/tutorials/load_data/images) and [augmenting images](
25https://www.tensorflow.org/tutorials/images/data_augmentation), as well as the
26[preprocessing layer guide](
27https://www.tensorflow.org/guide/keras/preprocessing_layers).
28"""
30import collections
31import multiprocessing
32import os
33import threading
34import warnings
36import numpy as np
38from keras.src import backend
39from keras.src.utils import data_utils
40from keras.src.utils import image_utils
41from keras.src.utils import io_utils
43# isort: off
44from tensorflow.python.util.tf_export import keras_export
46try:
47 import scipy
48 from scipy import linalg # noqa: F401
49 from scipy import ndimage # noqa: F401
50except ImportError:
51 pass
52try:
53 from PIL import ImageEnhance
54except ImportError:
55 ImageEnhance = None
@keras_export("keras.preprocessing.image.Iterator")
class Iterator(data_utils.Sequence):
    """Abstract base for iterators that yield batches of image data.

    Deprecated: `tf.keras.preprocessing.image.Iterator` is not recommended for
    new code. Prefer loading images with
    `tf.keras.utils.image_dataset_from_directory` and transforming the output
    `tf.data.Dataset` with preprocessing layers. For more information, see the
    tutorials for [loading images](
    https://www.tensorflow.org/tutorials/load_data/images) and
    [augmenting images](
    https://www.tensorflow.org/tutorials/images/data_augmentation), as well as
    the [preprocessing layer guide](
    https://www.tensorflow.org/guide/keras/preprocessing_layers).

    Concrete subclasses supply the per-batch work by overriding
    `_get_batches_of_transformed_samples`.

    Batches can be fetched by position through `__getitem__` (the `Sequence`
    interface) or sequentially through `next()`, which draws index slices
    from an endless internal generator.

    Args:
        n: Integer, total number of samples the iterator walks over.
        batch_size: Integer, number of samples per batch.
        shuffle: Boolean, whether the sample order is permuted each time the
            index array is rebuilt.
        seed: Optional integer; when not None, NumPy's global RNG is reseeded
            with `seed + total_batches_seen` before each batch.
    """

    # File extensions (lower-case, no leading dot) accepted as images.
    white_list_formats = ("png", "jpg", "jpeg", "bmp", "ppm", "tif", "tiff")

    def __init__(self, n, batch_size, shuffle, seed):
        self.n = n
        self.batch_size = batch_size
        self.seed = seed
        self.shuffle = shuffle
        # Position of the next batch within the current pass (drives next()).
        self.batch_index = 0
        # Running count of batches produced; folded into the RNG seed.
        self.total_batches_seen = 0
        # Serializes access to the shared index generator in next().
        self.lock = threading.Lock()
        # Sample order for the current pass; built lazily.
        self.index_array = None
        # Creating the generator object does not run its body yet.
        self.index_generator = self._flow_index()

    def _set_index_array(self):
        # Random permutation when shuffling, otherwise the natural 0..n-1 order.
        self.index_array = (
            np.random.permutation(self.n) if self.shuffle else np.arange(self.n)
        )

    def __getitem__(self, idx):
        num_batches = len(self)
        if idx >= num_batches:
            raise ValueError(
                f"Asked to retrieve element {idx}, "
                f"but the Sequence has length {num_batches}"
            )
        if self.seed is not None:
            np.random.seed(self.seed + self.total_batches_seen)
        self.total_batches_seen += 1
        if self.index_array is None:
            self._set_index_array()
        start = self.batch_size * idx
        stop = self.batch_size * (idx + 1)
        return self._get_batches_of_transformed_samples(
            self.index_array[start:stop]
        )

    def __len__(self):
        # Ceiling division: a trailing partial batch still counts as a batch.
        padded = self.n + self.batch_size - 1
        return padded // self.batch_size

    def on_epoch_end(self):
        self._set_index_array()

    def reset(self):
        self.batch_index = 0

    def _flow_index(self):
        # Always begin at the first batch of a pass.
        self.reset()
        while True:
            if self.seed is not None:
                np.random.seed(self.seed + self.total_batches_seen)
            if self.batch_index == 0:
                # New pass: rebuild (and possibly reshuffle) the sample order.
                self._set_index_array()

            # Guard against modulo by zero when there are no samples.
            start = (
                0
                if self.n == 0
                else (self.batch_index * self.batch_size) % self.n
            )
            stop = start + self.batch_size
            # Step to the next batch, or wrap around once this batch reaches
            # the end of the data.
            self.batch_index = self.batch_index + 1 if self.n > stop else 0
            self.total_batches_seen += 1
            yield self.index_array[start:stop]

    def __iter__(self):
        # Allows `for x, y in data_gen.flow(...):`.
        return self

    def __next__(self, *args, **kwargs):
        return self.next(*args, **kwargs)

    def next(self):
        """Returns the next batch (Python 2 style entry point).

        Returns:
            The next batch.
        """
        # Only drawing indices is done under the lock; transforming the
        # images happens outside it so it can run in parallel.
        with self.lock:
            batch_indices = next(self.index_generator)
        return self._get_batches_of_transformed_samples(batch_indices)

    def _get_batches_of_transformed_samples(self, index_array):
        """Builds one batch from the given sample indices.

        Subclasses must override this.

        Args:
            index_array: Array of sample indices to include in batch.

        Returns:
            A batch of transformed samples.
        """
        raise NotImplementedError
181def _iter_valid_files(directory, white_list_formats, follow_links):
182 """Iterates on files with extension.
184 Args:
185 directory: Absolute path to the directory
186 containing files to be counted
187 white_list_formats: Set of strings containing allowed extensions for
188 the files to be counted.
189 follow_links: Boolean, follow symbolic links to subdirectories.
190 Yields:
191 Tuple of (root, filename) with extension in `white_list_formats`.
192 """
194 def _recursive_list(subpath):
195 return sorted(
196 os.walk(subpath, followlinks=follow_links), key=lambda x: x[0]
197 )
199 for root, _, files in _recursive_list(directory):
200 for fname in sorted(files):
201 if fname.lower().endswith(".tiff"):
202 warnings.warn(
203 'Using ".tiff" files with multiple bands '
204 "will cause distortion. Please verify your output."
205 )
206 if fname.lower().endswith(white_list_formats):
207 yield root, fname
def _list_valid_filenames_in_directory(
    directory, white_list_formats, split, class_indices, follow_links
):
    """Lists whitelisted files under one class directory, with class labels.

    Args:
        directory: Absolute path to a class directory. Its base name is used
            as the class name and must be a key of `class_indices` (a
            `KeyError` is raised on the first matching file otherwise).
        white_list_formats: Tuple of allowed file extensions, matched with
            `str.endswith` against the lower-cased file name.
        split: Optional pair of floats `(start, stop)` selecting a fraction of
            the ordered files in the directory, e.g. `split=(0.6, 1.0)` keeps
            only the last 40 percent. A falsy value keeps every file.
        class_indices: Dictionary mapping a class name to its index.
        follow_links: Boolean, whether to follow symbolic links to
            subdirectories.

    Returns:
        A `(classes, filenames)` tuple where:
        classes: list with the class index of each listed file.
        filenames: the paths of the listed files relative to `directory`'s
            parent (e.g., if `directory` is "dataset/class1", the filenames
            will be `["class1/file1.jpg", "class1/file2.jpg", ...]`).
    """
    class_name = os.path.basename(directory)
    matches = _iter_valid_files(directory, white_list_formats, follow_links)
    if split:
        # Slicing by fraction needs the complete, ordered list of matches.
        matches = list(matches)
        total = len(matches)
        lo, hi = int(split[0] * total), int(split[1] * total)
        matches = matches[lo:hi]

    classes, filenames = [], []
    for root, fname in matches:
        classes.append(class_indices[class_name])
        rel_to_class_dir = os.path.relpath(os.path.join(root, fname), directory)
        filenames.append(os.path.join(class_name, rel_to_class_dir))
    return classes, filenames
class BatchFromFilesMixin:
    """Adds methods related to getting batches from filenames.

    It includes the logic to load image files, optionally augment and
    standardize them, and assemble them (plus labels) into batches.

    Classes using this mixin must call `set_processing_attrs` and provide a
    `dtype` and `class_mode` attribute, the `filepaths` and `sample_weight`
    properties, and, depending on `class_mode`, `classes`, `class_indices`
    and `labels`.
    """

    def set_processing_attrs(
        self,
        image_data_generator,
        target_size,
        color_mode,
        data_format,
        save_to_dir,
        save_prefix,
        save_format,
        subset,
        interpolation,
        keep_aspect_ratio,
    ):
        """Sets attributes to use later for processing files into a batch.

        Args:
            image_data_generator: Instance of `ImageDataGenerator`
                to use for random transformations and normalization.
            target_size: tuple of integers, dimensions to resize input images
                to.
            color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`.
                Color mode to read images.
            data_format: String, one of `channels_first`, `channels_last`.
            save_to_dir: Optional directory where to save the pictures
                being yielded, in a viewable format. This is useful
                for visualizing the random transformations being
                applied, for debugging purposes.
            save_prefix: String prefix to use for saving sample
                images (if `save_to_dir` is set).
            save_format: Format to use for saving sample images
                (if `save_to_dir` is set).
            subset: Subset of data (`"training"` or `"validation"`) if
                validation_split is set in ImageDataGenerator.
            interpolation: Interpolation method used to resample the image if
                the target size is different from that of the loaded image.
                Supported methods are "nearest", "bilinear", and "bicubic". If
                PIL version 1.1.3 or newer is installed, "lanczos" is also
                supported. If PIL version 3.4.0 or newer is installed, "box" and
                "hamming" are also supported. By default, "nearest" is used.
            keep_aspect_ratio: Boolean, whether to resize images to a target
                size without aspect ratio distortion. The image is cropped in
                the center with target aspect ratio before resizing.

        Raises:
            ValueError: If `color_mode` or `subset` is not a supported value.
        """
        self.image_data_generator = image_data_generator
        self.target_size = tuple(target_size)
        self.keep_aspect_ratio = keep_aspect_ratio
        # Number of image channels produced by each supported color mode.
        channels_by_mode = {"rgb": 3, "rgba": 4, "grayscale": 1}
        if color_mode not in channels_by_mode:
            raise ValueError(
                f"Invalid color mode: {color_mode}; "
                'expected "rgb", "rgba", or "grayscale".'
            )
        self.color_mode = color_mode
        self.data_format = data_format
        num_channels = channels_by_mode[color_mode]
        if data_format == "channels_last":
            self.image_shape = self.target_size + (num_channels,)
        else:
            self.image_shape = (num_channels,) + self.target_size
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        self.interpolation = interpolation

        if subset is None:
            split = None
        elif subset in ("training", "validation"):
            # The generator's validation fraction decides where the
            # per-directory file lists are cut.
            validation_split = self.image_data_generator._validation_split
            if subset == "validation":
                split = (0, validation_split)
            else:
                split = (validation_split, 1)
        else:
            # Validated before touching the generator so a bad name gives a
            # clear error even when no generator was supplied.
            raise ValueError(
                f"Invalid subset name: {subset}; "
                'expected "training" or "validation"'
            )
        self.split = split
        self.subset = subset

    def _get_batches_of_transformed_samples(self, index_array):
        """Gets a batch of transformed samples.

        Args:
            index_array: Array of sample indices to include in batch.

        Returns:
            `batch_x` alone when `class_mode` is not a label-producing mode,
            otherwise `(batch_x, batch_y)`, extended with the sample weights
            of the batch when `sample_weight` is not None.
        """
        batch_x = np.zeros(
            (len(index_array),) + self.image_shape, dtype=self.dtype
        )
        # build batch of image data
        # self.filepaths is dynamic, is better to call it once outside the loop
        filepaths = self.filepaths
        for i, j in enumerate(index_array):
            img = image_utils.load_img(
                filepaths[j],
                color_mode=self.color_mode,
                target_size=self.target_size,
                interpolation=self.interpolation,
                keep_aspect_ratio=self.keep_aspect_ratio,
            )
            try:
                x = image_utils.img_to_array(img, data_format=self.data_format)
            finally:
                # Pillow images should be closed after `load_img`, but not
                # PIL images; close even if the conversion fails.
                if hasattr(img, "close"):
                    img.close()
            if self.image_data_generator:
                params = self.image_data_generator.get_random_transform(x.shape)
                x = self.image_data_generator.apply_transform(x, params)
                x = self.image_data_generator.standardize(x)
            batch_x[i] = x
        # optionally save augmented images to disk for debugging purposes
        if self.save_to_dir:
            for i, j in enumerate(index_array):
                img = image_utils.array_to_img(
                    batch_x[i], self.data_format, scale=True
                )
                fname = "{prefix}_{index}_{hash}.{format}".format(
                    prefix=self.save_prefix,
                    index=j,
                    # Integer bound: `randint` expects an int, not 1e7.
                    hash=np.random.randint(10**7),
                    format=self.save_format,
                )
                img.save(os.path.join(self.save_to_dir, fname))
        # build batch of labels
        if self.class_mode == "input":
            batch_y = batch_x.copy()
        elif self.class_mode in {"binary", "sparse"}:
            batch_y = np.empty(len(batch_x), dtype=self.dtype)
            for i, n_observation in enumerate(index_array):
                batch_y[i] = self.classes[n_observation]
        elif self.class_mode == "categorical":
            batch_y = np.zeros(
                (len(batch_x), len(self.class_indices)), dtype=self.dtype
            )
            for i, n_observation in enumerate(index_array):
                batch_y[i, self.classes[n_observation]] = 1.0
        elif self.class_mode == "multi_output":
            batch_y = [output[index_array] for output in self.labels]
        elif self.class_mode == "raw":
            batch_y = self.labels[index_array]
        else:
            return batch_x
        if self.sample_weight is None:
            return batch_x, batch_y
        return batch_x, batch_y, self.sample_weight[index_array]

    @property
    def filepaths(self):
        """List of absolute paths to image files."""
        raise NotImplementedError(
            "`filepaths` property method has not "
            "been implemented in {}.".format(type(self).__name__)
        )

    @property
    def labels(self):
        """Class labels of every observation."""
        raise NotImplementedError(
            "`labels` property method has not been implemented in {}.".format(
                type(self).__name__
            )
        )

    @property
    def sample_weight(self):
        """Per-sample weights, or None; must be provided by the subclass."""
        raise NotImplementedError(
            "`sample_weight` property method has not "
            "been implemented in {}.".format(type(self).__name__)
        )
@keras_export("keras.preprocessing.image.DirectoryIterator")
class DirectoryIterator(BatchFromFilesMixin, Iterator):
    """Iterator capable of reading images from a directory on disk.

    Deprecated: `tf.keras.preprocessing.image.DirectoryIterator` is not
    recommended for new code. Prefer loading images with
    `tf.keras.utils.image_dataset_from_directory` and transforming the output
    `tf.data.Dataset` with preprocessing layers. For more information, see the
    tutorials for [loading images](
    https://www.tensorflow.org/tutorials/load_data/images) and
    [augmenting images](
    https://www.tensorflow.org/tutorials/images/data_augmentation), as well as
    the [preprocessing layer guide](
    https://www.tensorflow.org/guide/keras/preprocessing_layers).

    Args:
        directory: Path to the directory to read images from. Each subdirectory
            in this directory will be considered to contain images from one class,
            or alternatively you could specify class subdirectories via the
            `classes` argument.
        image_data_generator: Instance of `ImageDataGenerator` to use for random
            transformations and normalization.
        target_size: tuple of integers, dimensions to resize input images to.
        color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`. Color mode to read
            images.
        classes: Optional list of strings, names of subdirectories containing
            images from each class (e.g. `["dogs", "cats"]`). It will be computed
            automatically if not set.
        class_mode: Mode for yielding the targets:
            - `"binary"`: binary targets (if there are only two classes),
            - `"categorical"`: categorical targets,
            - `"sparse"`: integer targets,
            - `"input"`: targets are images identical to input images (mainly
                used to work with autoencoders),
            - `None`: no targets get yielded (only input images are yielded).
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seed for data shuffling.
        data_format: String, one of `channels_first`, `channels_last`.
        save_to_dir: Optional directory where to save the pictures being
            yielded, in a viewable format. This is useful for visualizing the
            random transformations being applied, for debugging purposes.
        save_prefix: String prefix to use for saving sample images (if
            `save_to_dir` is set).
        save_format: Format to use for saving sample images (if `save_to_dir` is
            set).
        follow_links: Boolean, whether to follow symbolic links to
            subdirectories while scanning for images.
        subset: Subset of data (`"training"` or `"validation"`) if
            validation_split is set in ImageDataGenerator.
        interpolation: Interpolation method used to resample the image if the
            target size is different from that of the loaded image. Supported
            methods are "nearest", "bilinear", and "bicubic". If PIL version 1.1.3
            or newer is installed, "lanczos" is also supported. If PIL version
            3.4.0 or newer is installed, "box" and "hamming" are also supported.
            By default, "nearest" is used.
        keep_aspect_ratio: Boolean, whether to resize images to a target size
            without aspect ratio distortion. The image is cropped in the center
            with target aspect ratio before resizing.
        dtype: Dtype to use for generated arrays.
    """

    allowed_class_modes = {"categorical", "binary", "sparse", "input", None}

    def __init__(
        self,
        directory,
        image_data_generator,
        target_size=(256, 256),
        color_mode="rgb",
        classes=None,
        class_mode="categorical",
        batch_size=32,
        shuffle=True,
        seed=None,
        data_format=None,
        save_to_dir=None,
        save_prefix="",
        save_format="png",
        follow_links=False,
        subset=None,
        interpolation="nearest",
        keep_aspect_ratio=False,
        dtype=None,
    ):
        # `multiprocessing.pool` is a submodule that a bare
        # `import multiprocessing` does not guarantee to load, so import the
        # pool class explicitly instead of relying on another module having
        # done so.
        from multiprocessing.pool import ThreadPool

        if data_format is None:
            data_format = backend.image_data_format()
        if dtype is None:
            dtype = backend.floatx()
        super().set_processing_attrs(
            image_data_generator,
            target_size,
            color_mode,
            data_format,
            save_to_dir,
            save_prefix,
            save_format,
            subset,
            interpolation,
            keep_aspect_ratio,
        )
        self.directory = directory
        self.classes = classes
        if class_mode not in self.allowed_class_modes:
            raise ValueError(
                "Invalid class_mode: {}; expected one of: {}".format(
                    class_mode, self.allowed_class_modes
                )
            )
        self.class_mode = class_mode
        self.dtype = dtype
        # First, count the number of samples and classes.
        self.samples = 0

        if not classes:
            # Every immediate subdirectory, in sorted order, is a class.
            classes = [
                subdir
                for subdir in sorted(os.listdir(directory))
                if os.path.isdir(os.path.join(directory, subdir))
            ]
        self.num_classes = len(classes)
        self.class_indices = dict(zip(classes, range(len(classes))))

        # Second, build an index of the images in the different class
        # subfolders, scanning each class directory on a worker thread.
        self.filenames = []
        labels_per_dir = []
        pool = ThreadPool()
        try:
            results = [
                pool.apply_async(
                    _list_valid_filenames_in_directory,
                    (
                        os.path.join(directory, subdir),
                        self.white_list_formats,
                        self.split,
                        self.class_indices,
                        follow_links,
                    ),
                )
                for subdir in classes
            ]
            for res in results:
                dir_labels, dir_filenames = res.get()
                labels_per_dir.append(dir_labels)
                self.filenames += dir_filenames
        finally:
            # Release the worker threads even if a scan raised.
            pool.close()
            pool.join()

        self.samples = len(self.filenames)
        self.classes = np.zeros((self.samples,), dtype="int32")
        offset = 0
        for dir_labels in labels_per_dir:
            self.classes[offset : offset + len(dir_labels)] = dir_labels
            offset += len(dir_labels)

        io_utils.print_msg(
            f"Found {self.samples} images belonging to "
            f"{self.num_classes} classes."
        )
        self._filepaths = [
            os.path.join(self.directory, fname) for fname in self.filenames
        ]
        super().__init__(self.samples, batch_size, shuffle, seed)

    @property
    def filepaths(self):
        """Paths of the discovered images, joined onto `self.directory`."""
        return self._filepaths

    @property
    def labels(self):
        """Class index of every discovered image."""
        return self.classes

    @property  # mixin needs this property to work
    def sample_weight(self):
        # no sample weights will be returned
        return None
@keras_export("keras.preprocessing.image.NumpyArrayIterator")
class NumpyArrayIterator(Iterator):
    """Iterator yielding data from a Numpy array.

    Deprecated: `tf.keras.preprocessing.image.NumpyArrayIterator` is not
    recommended for new code. Prefer loading images with
    `tf.keras.utils.image_dataset_from_directory` and transforming the output
    `tf.data.Dataset` with preprocessing layers. For more information, see the
    tutorials for [loading images](
    https://www.tensorflow.org/tutorials/load_data/images) and
    [augmenting images](
    https://www.tensorflow.org/tutorials/images/data_augmentation), as well as
    the [preprocessing layer guide](
    https://www.tensorflow.org/guide/keras/preprocessing_layers).

    Args:
        x: Numpy array of input data or tuple. If tuple, the second elements is
            either another numpy array or a list of numpy arrays, each of which
            gets passed through as an output without any modifications. The
            image data must have rank 4.
        y: Numpy array of targets data.
        image_data_generator: Instance of `ImageDataGenerator` to use for random
            transformations and normalization.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        sample_weight: Numpy array of sample weights.
        seed: Random seed for data shuffling.
        data_format: String, one of `channels_first`, `channels_last`.
        save_to_dir: Optional directory where to save the pictures being
            yielded, in a viewable format. This is useful for visualizing the
            random transformations being applied, for debugging purposes.
        save_prefix: String prefix to use for saving sample images (if
            `save_to_dir` is set).
        save_format: Format to use for saving sample images (if `save_to_dir` is
            set).
        subset: Subset of data (`"training"` or `"validation"`) if
            validation_split is set in ImageDataGenerator.
        ignore_class_split: Boolean (default: False), ignore difference
            in number of classes in labels across train and validation
            split (useful for non-classification tasks)
        dtype: Dtype to use for the generated arrays.

    Raises:
        ValueError: On mismatched lengths of `x`, `y`, `sample_weight` or the
            extra arrays, an invalid `subset`, a class mismatch between the
            training and validation splits, or image data not of rank 4.
    """

    def __init__(
        self,
        x,
        y,
        image_data_generator,
        batch_size=32,
        shuffle=False,
        sample_weight=None,
        seed=None,
        data_format=None,
        save_to_dir=None,
        save_prefix="",
        save_format="png",
        subset=None,
        ignore_class_split=False,
        dtype=None,
    ):
        if data_format is None:
            data_format = backend.image_data_format()
        if dtype is None:
            dtype = backend.floatx()
        self.dtype = dtype
        # NOTE: any list (not only a tuple) is treated as the
        # `(images, extra_inputs)` pair form described above.
        if isinstance(x, (tuple, list)):
            if not isinstance(x[1], list):
                x_misc = [np.asarray(x[1])]
            else:
                x_misc = [np.asarray(xx) for xx in x[1]]
            x = x[0]
            for xx in x_misc:
                if len(x) != len(xx):
                    raise ValueError(
                        "All of the arrays in `x` "
                        "should have the same length. "
                        "Found a pair with: len(x[0]) = %s, len(x[?]) = %s"
                        % (len(x), len(xx))
                    )
        else:
            x_misc = []

        if y is not None and len(x) != len(y):
            raise ValueError(
                "`x` (images tensor) and `y` (labels) "
                "should have the same length. "
                "Found: x.shape = %s, y.shape = %s"
                % (np.asarray(x).shape, np.asarray(y).shape)
            )
        if sample_weight is not None and len(x) != len(sample_weight):
            raise ValueError(
                "`x` (images tensor) and `sample_weight` "
                "should have the same length. "
                "Found: x.shape = %s, sample_weight.shape = %s"
                % (np.asarray(x).shape, np.asarray(sample_weight).shape)
            )
        if subset is not None:
            if subset not in {"training", "validation"}:
                raise ValueError(
                    f"Invalid subset name: {subset}; "
                    'expected "training" or "validation".'
                )
            # The first `split_idx` samples form the validation subset.
            split_idx = int(len(x) * image_data_generator._validation_split)

            if (
                y is not None
                and not ignore_class_split
                and not np.array_equal(
                    np.unique(y[:split_idx]), np.unique(y[split_idx:])
                )
            ):
                raise ValueError(
                    "Training and validation subsets "
                    "have different number of classes after "
                    "the split. If your numpy arrays are "
                    "sorted by the label, you might want "
                    "to shuffle them."
                )

            if subset == "validation":
                x = x[:split_idx]
                x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc]
                if y is not None:
                    y = y[:split_idx]
            else:
                x = x[split_idx:]
                x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc]
                if y is not None:
                    y = y[split_idx:]

        self.x = np.asarray(x, dtype=self.dtype)
        self.x_misc = x_misc
        if self.x.ndim != 4:
            raise ValueError(
                "Input data in `NumpyArrayIterator` "
                "should have rank 4. You passed an array "
                f"with shape {self.x.shape}"
            )
        channels_axis = 3 if data_format == "channels_last" else 1
        if self.x.shape[channels_axis] not in {1, 3, 4}:
            warnings.warn(
                'NumpyArrayIterator is set to use the data format convention "'
                + data_format
                + '" (channels on axis '
                + str(channels_axis)
                + "), i.e. expected either 1, 3, or 4 channels on axis "
                + str(channels_axis)
                + ". However, it was passed an array with shape "
                + str(self.x.shape)
                + " ("
                + str(self.x.shape[channels_axis])
                + " channels)."
            )
        self.y = np.asarray(y) if y is not None else None
        self.sample_weight = (
            np.asarray(sample_weight) if sample_weight is not None else None
        )
        self.image_data_generator = image_data_generator
        self.data_format = data_format
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        # Use the converted array: `x` itself may be a nested list without a
        # `.shape` attribute even though `np.asarray` accepted it.
        super().__init__(self.x.shape[0], batch_size, shuffle, seed)

    def _get_batches_of_transformed_samples(self, index_array):
        """Gets a batch of augmented, standardized samples.

        Args:
            index_array: Array of sample indices to include in batch.

        Returns:
            The image batch (or `[images] + extra_arrays` when extra inputs
            were given) alone when `y` is None; otherwise a tuple of it and
            the batch targets, plus the batch sample weights when present.
        """
        batch_x = np.zeros(
            (len(index_array),) + self.x.shape[1:], dtype=self.dtype
        )
        for i, j in enumerate(index_array):
            x = self.x[j]
            params = self.image_data_generator.get_random_transform(x.shape)
            x = self.image_data_generator.apply_transform(
                x.astype(self.dtype), params
            )
            x = self.image_data_generator.standardize(x)
            batch_x[i] = x

        if self.save_to_dir:
            for i, j in enumerate(index_array):
                img = image_utils.array_to_img(
                    batch_x[i], self.data_format, scale=True
                )
                fname = "{prefix}_{index}_{hash}.{format}".format(
                    prefix=self.save_prefix,
                    index=j,
                    # Integer bound: `randint` expects an int, not 1e4.
                    hash=np.random.randint(10**4),
                    format=self.save_format,
                )
                img.save(os.path.join(self.save_to_dir, fname))
        batch_x_miscs = [xx[index_array] for xx in self.x_misc]
        output = (batch_x if not batch_x_miscs else [batch_x] + batch_x_miscs,)
        if self.y is None:
            return output[0]
        output += (self.y[index_array],)
        if self.sample_weight is not None:
            output += (self.sample_weight[index_array],)
        return output
def validate_filename(filename, white_list_formats):
    """Tells whether `filename` is an existing file with an allowed extension.

    The extension test is done first (case-insensitively, via
    `str.endswith`), so the filesystem is only consulted for candidate names.

    Args:
        filename: String, absolute path to a file.
        white_list_formats: Tuple of allowed file extensions.

    Returns:
        True if the name ends with an allowed extension and refers to a
        regular file, False otherwise.
    """
    has_allowed_extension = filename.lower().endswith(white_list_formats)
    if not has_allowed_extension:
        return False
    return os.path.isfile(filename)
843class DataFrameIterator(BatchFromFilesMixin, Iterator):
844 """Iterator capable of reading images from a directory as a dataframe.
846 Args:
847 dataframe: Pandas dataframe containing the filepaths relative to
848 `directory` (or absolute paths if `directory` is None) of the images
849 in a string column. It should include other column/s depending on the
850 `class_mode`: - if `class_mode` is `"categorical"` (default value) it
851 must include the `y_col` column with the class/es of each image.
852 Values in column can be string/list/tuple if a single class or
853 list/tuple if multiple classes.
854 - if `class_mode` is `"binary"` or `"sparse"` it must include the
855 given `y_col` column with class values as strings.
856 - if `class_mode` is `"raw"` or `"multi_output"` it should contain
857 the columns specified in `y_col`.
858 - if `class_mode` is `"input"` or `None` no extra column is needed.
859 directory: string, path to the directory to read images from. If `None`,
860 data in `x_col` column should be absolute paths.
861 image_data_generator: Instance of `ImageDataGenerator` to use for random
862 transformations and normalization. If None, no transformations and
863 normalizations are made.
864 x_col: string, column in `dataframe` that contains the filenames (or
865 absolute paths if `directory` is `None`).
866 y_col: string or list, column/s in `dataframe` that has the target data.
867 weight_col: string, column in `dataframe` that contains the sample
868 weights. Default: `None`.
869 target_size: tuple of integers, dimensions to resize input images to.
870 color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`. Color mode to read
871 images.
872 classes: Optional list of strings, classes to use (e.g. `["dogs",
873 "cats"]`). If None, all classes in `y_col` will be used.
874 class_mode: one of "binary", "categorical", "input", "multi_output",
875 "raw", "sparse" or None. Default: "categorical".
876 Mode for yielding the targets:
877 - `"binary"`: 1D numpy array of binary labels,
878 - `"categorical"`: 2D numpy array of one-hot encoded labels.
879 Supports multi-label output.
880 - `"input"`: images identical to input images (mainly used to work
881 with autoencoders),
882 - `"multi_output"`: list with the values of the different columns,
883 - `"raw"`: numpy array of values in `y_col` column(s),
884 - `"sparse"`: 1D numpy array of integer labels, - `None`, no targets
885 are returned (the generator will only yield batches of image data,
886 which is useful to use in `model.predict()`).
887 batch_size: Integer, size of a batch.
888 shuffle: Boolean, whether to shuffle the data between epochs.
889 seed: Random seed for data shuffling.
890 data_format: String, one of `channels_first`, `channels_last`.
891 save_to_dir: Optional directory where to save the pictures being
892 yielded, in a viewable format. This is useful for visualizing the
893 random transformations being applied, for debugging purposes.
894 save_prefix: String prefix to use for saving sample images (if
895 `save_to_dir` is set).
896 save_format: Format to use for saving sample images (if `save_to_dir` is
897 set).
898 subset: Subset of data (`"training"` or `"validation"`) if
899 validation_split is set in ImageDataGenerator.
900 interpolation: Interpolation method used to resample the image if the
901 target size is different from that of the loaded image. Supported
902 methods are "nearest", "bilinear", and "bicubic". If PIL version 1.1.3
903 or newer is installed, "lanczos" is also supported. If PIL version
904 3.4.0 or newer is installed, "box" and "hamming" are also supported.
905 By default, "nearest" is used.
906 keep_aspect_ratio: Boolean, whether to resize images to a target size
907 without aspect ratio distortion. The image is cropped in the center
908 with target aspect ratio before resizing.
909 dtype: Dtype to use for the generated arrays.
910 validate_filenames: Boolean, whether to validate image filenames in
911 `x_col`. If `True`, invalid images will be ignored. Disabling this
912 option can lead to speed-up in the instantiation of this class.
913 Default: `True`.
914 """
# Every `class_mode` accepted by this iterator; `None` means "yield images only".
allowed_class_modes = {"binary", "categorical", "input", "multi_output", "raw", "sparse", None}
def __init__(
    self,
    dataframe,
    directory=None,
    image_data_generator=None,
    x_col="filename",
    y_col="class",
    weight_col=None,
    target_size=(256, 256),
    color_mode="rgb",
    classes=None,
    class_mode="categorical",
    batch_size=32,
    shuffle=True,
    seed=None,
    data_format="channels_last",
    save_to_dir=None,
    save_prefix="",
    save_format="png",
    subset=None,
    interpolation="nearest",
    keep_aspect_ratio=False,
    dtype="float32",
    validate_filenames=True,
):
    """Build the iterator from a dataframe of image filenames and targets.

    Steps, in order:
      1. store the processing attributes on the parent mixin;
      2. validate the arguments against a copy of `dataframe`;
      3. optionally drop rows whose image path is invalid;
      4. for class-index modes, drop rows without a known class and build
         `class_indices`;
      5. keep only the training/validation slice when a split is set;
      6. compute per-sample targets, weights and full file paths;
      7. initialize the underlying batch iterator.
    """
    super().set_processing_attrs(
        image_data_generator,
        target_size,
        color_mode,
        data_format,
        save_to_dir,
        save_prefix,
        save_format,
        subset,
        interpolation,
        keep_aspect_ratio,
    )
    frame = dataframe.copy()
    self.directory = directory or ""
    self.class_mode = class_mode
    self.dtype = dtype
    # Fails fast on an unsupported class_mode or badly typed columns.
    self._check_params(frame, x_col, y_col, weight_col, classes)

    # Modes whose targets are derived from class names (binary, categorical,
    # sparse); the remaining modes never touch `classes`.
    uses_class_indices = class_mode not in ["input", "multi_output", "raw", None]

    if validate_filenames:
        frame = self._filter_valid_filepaths(frame, x_col)
    if uses_class_indices:
        frame, classes = self._filter_classes(frame, y_col, classes)
        num_classes = len(classes)
        # Position of each class name in `classes` is its integer index.
        self.class_indices = {name: index for index, name in enumerate(classes)}

    if self.split:
        # `self.split` holds the (start, stop) fractions of the subset.
        total = len(frame)
        first = int(self.split[0] * total)
        last = int(self.split[1] * total)
        frame = frame.iloc[first:last, :]

    if uses_class_indices:
        self.classes = self.get_classes(frame, y_col)
    self.filenames = frame[x_col].tolist()
    self._sample_weight = frame[weight_col].values if weight_col else None

    if class_mode == "multi_output":
        self._targets = [np.array(frame[column].tolist()) for column in y_col]
    elif class_mode == "raw":
        self._targets = frame[y_col].values

    self.samples = len(self.filenames)
    validated_string = "validated" if validate_filenames else "non-validated"
    summary = f"Found {self.samples} {validated_string} image filenames"
    if uses_class_indices:
        summary += f" belonging to {num_classes} classes."
    else:
        summary += "."
    io_utils.print_msg(summary)

    self._filepaths = [
        os.path.join(self.directory, name) for name in self.filenames
    ]
    super().__init__(self.samples, batch_size, shuffle, seed)
1012 def _check_params(self, df, x_col, y_col, weight_col, classes):
1013 # check class mode is one of the currently supported
1014 if self.class_mode not in self.allowed_class_modes:
1015 raise ValueError(
1016 "Invalid class_mode: {}; expected one of: {}".format(
1017 self.class_mode, self.allowed_class_modes
1018 )
1019 )
1020 # check that y_col has several column names if class_mode is
1021 # multi_output
1022 if (self.class_mode == "multi_output") and not isinstance(y_col, list):
1023 raise TypeError(
1024 'If class_mode="{}", y_col must be a list. Received {}.'.format(
1025 self.class_mode, type(y_col).__name__
1026 )
1027 )
1028 # check that filenames/filepaths column values are all strings
1029 if not all(df[x_col].apply(lambda x: isinstance(x, str))):
1030 raise TypeError(
1031 f"All values in column x_col={x_col} must be strings."
1032 )
1033 # check labels are string if class_mode is binary or sparse
1034 if self.class_mode in {"binary", "sparse"}:
1035 if not all(df[y_col].apply(lambda x: isinstance(x, str))):
1036 raise TypeError(
1037 'If class_mode="{}", y_col="{}" column '
1038 "values must be strings.".format(self.class_mode, y_col)
1039 )
1040 # check that if binary there are only 2 different classes
1041 if self.class_mode == "binary":
1042 if classes:
1043 classes = set(classes)
1044 if len(classes) != 2:
1045 raise ValueError(
1046 'If class_mode="binary" there must be 2 '
1047 "classes. {} class/es were given.".format(len(classes))
1048 )
1049 elif df[y_col].nunique() != 2:
1050 raise ValueError(
1051 'If class_mode="binary" there must be 2 classes. '
1052 "Found {} classes.".format(df[y_col].nunique())
1053 )
1054 # check values are string, list or tuple if class_mode is categorical
1055 if self.class_mode == "categorical":
1056 types = (str, list, tuple)
1057 if not all(df[y_col].apply(lambda x: isinstance(x, types))):
1058 raise TypeError(
1059 'If class_mode="{}", y_col="{}" column '
1060 "values must be type string, list or tuple.".format(
1061 self.class_mode, y_col
1062 )
1063 )
1064 # raise warning if classes are given but will be unused
1065 if classes and self.class_mode in {
1066 "input",
1067 "multi_output",
1068 "raw",
1069 None,
1070 }:
1071 warnings.warn(
1072 '`classes` will be ignored given the class_mode="{}"'.format(
1073 self.class_mode
1074 )
1075 )
1076 # check that if weight column that the values are numerical
1077 if weight_col and not issubclass(df[weight_col].dtype.type, np.number):
1078 raise TypeError(f"Column weight_col={weight_col} must be numeric.")
def get_classes(self, df, y_col):
    """Translate the labels of `df[y_col]` into class indices.

    A string label becomes `self.class_indices[label]`; a list/tuple label
    becomes a list of the indices of its elements. An unknown label raises
    `KeyError`.
    """
    index_of = self.class_indices

    def to_index(label):
        if isinstance(label, (list, tuple)):
            return [index_of[name] for name in label]
        return index_of[label]

    return [to_index(label) for label in df[y_col]]
1089 @staticmethod
1090 def _filter_classes(df, y_col, classes):
1091 df = df.copy()
1093 def remove_classes(labels, classes):
1094 if isinstance(labels, (list, tuple)):
1095 labels = [cls for cls in labels if cls in classes]
1096 return labels or None
1097 elif isinstance(labels, str):
1098 return labels if labels in classes else None
1099 else:
1100 raise TypeError(
1101 "Expect string, list or tuple "
1102 "but found {} in {} column ".format(type(labels), y_col)
1103 )
1105 if classes:
1106 # prepare for membership lookup
1107 classes = list(collections.OrderedDict.fromkeys(classes).keys())
1108 df[y_col] = df[y_col].apply(lambda x: remove_classes(x, classes))
1109 else:
1110 classes = set()
1111 for v in df[y_col]:
1112 if isinstance(v, (list, tuple)):
1113 classes.update(v)
1114 else:
1115 classes.add(v)
1116 classes = sorted(classes)
1117 return df.dropna(subset=[y_col]), classes
def _filter_valid_filepaths(self, df, x_col):
    """Keep only dataframe rows with valid filenames.

    Each value of `x_col` is joined to `self.directory` and checked with
    `validate_filename` against `self.white_list_formats`. Rows that fail
    the check are dropped, and a warning reports how many were ignored.

    Args:
        df: Pandas dataframe containing filenames in a column.
        x_col: string, column in `df` that contains the filenames or
            filepaths.

    Returns:
        The rows of `df` (with all of its columns) whose filename is valid.
    """
    filepaths = df[x_col].map(
        lambda fname: os.path.join(self.directory, fname)
    )
    # Force a boolean mask. On an empty dataframe `apply` may return an
    # empty Series of a non-bool dtype. `~` can fail on such a Series, and
    # `df[...]` would treat it as a column selection instead of a row filter.
    mask = filepaths.apply(
        validate_filename, args=(self.white_list_formats,)
    ).astype(bool)
    n_invalid = (~mask).sum()
    if n_invalid:
        warnings.warn(
            'Found {} invalid image filename(s) in x_col="{}". '
            "These filename(s) will be ignored.".format(n_invalid, x_col)
        )
    return df[mask]
@property
def filepaths(self):
    """Full image paths: `directory` joined with each retained filename."""
    paths = self._filepaths
    return paths
@property
def labels(self):
    """Targets of the retained samples.

    For the `"multi_output"` and `"raw"` class modes this is the target data
    read from `y_col` (`self._targets`). For every other mode it is
    `self.classes`, the class index (or list of indices) of each sample.
    NOTE(review): `__init__` only assigns `self.classes` for the
    class-index modes; confirm whether `"input"`/None are expected to reach
    this branch.
    """
    returns_raw_targets = self.class_mode in {"multi_output", "raw"}
    return self._targets if returns_raw_targets else self.classes
@property
def sample_weight(self):
    """Values of `weight_col` for the retained rows, or None if unset."""
    weights = self._sample_weight
    return weights
def flip_axis(x, axis):
    """Return `x` as a numpy array with the order along `axis` reversed."""
    # Bring `axis` to the front, reverse it there, then move it back.
    leading = np.swapaxes(np.asarray(x), axis, 0)
    return np.swapaxes(leading[::-1], 0, axis)
1166@keras_export("keras.preprocessing.image.ImageDataGenerator")
1167class ImageDataGenerator:
1168 """Generate batches of tensor image data with real-time data augmentation.
1170 Deprecated: `tf.keras.preprocessing.image.ImageDataGenerator` is not
1171 recommended for new code. Prefer loading images with
1172 `tf.keras.utils.image_dataset_from_directory` and transforming the output
1173 `tf.data.Dataset` with preprocessing layers. For more information, see the
1174 tutorials for [loading images](
1175 https://www.tensorflow.org/tutorials/load_data/images) and
1176 [augmenting images](
1177 https://www.tensorflow.org/tutorials/images/data_augmentation), as well as
1178 the [preprocessing layer guide](
1179 https://www.tensorflow.org/guide/keras/preprocessing_layers).
1181 The data will be looped over (in batches).
1183 Args:
1184 featurewise_center: Boolean. Set input mean to 0 over the dataset,
1185 feature-wise.
1186 samplewise_center: Boolean. Set each sample mean to 0.
1187 featurewise_std_normalization: Boolean. Divide inputs by std of the
1188 dataset, feature-wise.
1189 samplewise_std_normalization: Boolean. Divide each input by its std.
1190 zca_epsilon: epsilon for ZCA whitening. Default is 1e-6.
1191 zca_whitening: Boolean. Apply ZCA whitening.
1192 rotation_range: Int. Degree range for random rotations.
1193 width_shift_range: Float, 1-D array-like or int
1194 - float: fraction of total width, if < 1, or pixels if >= 1.
1195 - 1-D array-like: random elements from the array.
1196 - int: integer number of pixels from interval `(-width_shift_range,
1197 +width_shift_range)` - With `width_shift_range=2` possible values
1198 are integers `[-1, 0, +1]`, same as with `width_shift_range=[-1,
1199 0, +1]`, while with `width_shift_range=1.0` possible values are
1200 floats in the interval [-1.0, +1.0).
1201 height_shift_range: Float, 1-D array-like or int
1202 - float: fraction of total height, if < 1, or pixels if >= 1.
1203 - 1-D array-like: random elements from the array.
1204 - int: integer number of pixels from interval `(-height_shift_range,
1205 +height_shift_range)` - With `height_shift_range=2` possible
1206 values are integers `[-1, 0, +1]`, same as with
1207 `height_shift_range=[-1, 0, +1]`, while with
1208 `height_shift_range=1.0` possible values are floats in the
1209 interval [-1.0, +1.0).
1210 brightness_range: Tuple or list of two floats. Range for picking a
1211 brightness shift value from.
1212 shear_range: Float. Shear Intensity (Shear angle in counter-clockwise
1213 direction in degrees)
1214 zoom_range: Float or [lower, upper]. Range for random zoom. If a float,
1215 `[lower, upper] = [1-zoom_range, 1+zoom_range]`.
1216 channel_shift_range: Float. Range for random channel shifts.
1217 fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}. Default
1218 is 'nearest'. Points outside the boundaries of the input are filled
1219 according to the given mode:
1220 - 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
1221 - 'nearest': aaaaaaaa|abcd|dddddddd
1222 - 'reflect': abcddcba|abcd|dcbaabcd
1223 - 'wrap': abcdabcd|abcd|abcdabcd
1224 cval: Float or Int. Value used for points outside the boundaries when
1225 `fill_mode = "constant"`.
1226 horizontal_flip: Boolean. Randomly flip inputs horizontally.
1227 vertical_flip: Boolean. Randomly flip inputs vertically.
1228 rescale: rescaling factor. Defaults to None. If None or 0, no rescaling
1229 is applied, otherwise we multiply the data by the value provided
1230 (after applying all other transformations).
1231 preprocessing_function: function that will be applied on each input. The
1232 function will run after the image is resized and augmented.
1233 The function should take one argument: one image (Numpy tensor with
1234 rank 3), and should output a Numpy tensor with the same shape.
1235 data_format: Image data format, either "channels_first" or
1236 "channels_last". "channels_last" mode means that the images should
1237 have shape `(samples, height, width, channels)`, "channels_first" mode
1238 means that the images should have shape `(samples, channels, height,
1239 width)`. It defaults to the `image_data_format` value found in your
1240 Keras config file at `~/.keras/keras.json`. If you never set it, then
1241 it will be "channels_last".
1242 validation_split: Float. Fraction of images reserved for validation
1243 (strictly between 0 and 1).
1244 dtype: Dtype to use for the generated arrays.
1246 Raises:
1247 ValueError: If the value of the argument, `data_format` is other than
1248 `"channels_last"` or `"channels_first"`.
1249 ValueError: If the value of the argument, `validation_split` > 1
1250 or `validation_split` < 0.
1252 Examples:
1254 Example of using `.flow(x, y)`:
1256 ```python
1257 (x_train, y_train), (x_test, y_test) = cifar10.load_data()
1258 y_train = utils.to_categorical(y_train, num_classes)
1259 y_test = utils.to_categorical(y_test, num_classes)
1260 datagen = ImageDataGenerator(
1261 featurewise_center=True,
1262 featurewise_std_normalization=True,
1263 rotation_range=20,
1264 width_shift_range=0.2,
1265 height_shift_range=0.2,
1266 horizontal_flip=True,
1267 validation_split=0.2)
1268 # compute quantities required for featurewise normalization
1269 # (std, mean, and principal components if ZCA whitening is applied)
1270 datagen.fit(x_train)
1271 # fits the model on batches with real-time data augmentation:
1272 model.fit(datagen.flow(x_train, y_train, batch_size=32,
1273 subset='training'),
1274 validation_data=datagen.flow(x_train, y_train,
1275 batch_size=8, subset='validation'),
1276 steps_per_epoch=len(x_train) / 32, epochs=epochs)
1277 # here's a more "manual" example
1278 for e in range(epochs):
1279 print('Epoch', e)
1280 batches = 0
1281 for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
1282 model.fit(x_batch, y_batch)
1283 batches += 1
1284 if batches >= len(x_train) / 32:
1285 # we need to break the loop by hand because
1286 # the generator loops indefinitely
1287 break
1288 ```
1290 Example of using `.flow_from_directory(directory)`:
1292 ```python
1293 train_datagen = ImageDataGenerator(
1294 rescale=1./255,
1295 shear_range=0.2,
1296 zoom_range=0.2,
1297 horizontal_flip=True)
1298 test_datagen = ImageDataGenerator(rescale=1./255)
1299 train_generator = train_datagen.flow_from_directory(
1300 'data/train',
1301 target_size=(150, 150),
1302 batch_size=32,
1303 class_mode='binary')
1304 validation_generator = test_datagen.flow_from_directory(
1305 'data/validation',
1306 target_size=(150, 150),
1307 batch_size=32,
1308 class_mode='binary')
1309 model.fit(
1310 train_generator,
1311 steps_per_epoch=2000,
1312 epochs=50,
1313 validation_data=validation_generator,
1314 validation_steps=800)
1315 ```
1317 Example of transforming images and masks together.
1319 ```python
1320 # we create two instances with the same arguments
1321 data_gen_args = dict(featurewise_center=True,
1322 featurewise_std_normalization=True,
1323 rotation_range=90,
1324 width_shift_range=0.1,
1325 height_shift_range=0.1,
1326 zoom_range=0.2)
1327 image_datagen = ImageDataGenerator(**data_gen_args)
1328 mask_datagen = ImageDataGenerator(**data_gen_args)
1329 # Provide the same seed and keyword arguments to the fit and flow methods
1330 seed = 1
1331 image_datagen.fit(images, augment=True, seed=seed)
1332 mask_datagen.fit(masks, augment=True, seed=seed)
1333 image_generator = image_datagen.flow_from_directory(
1334 'data/images',
1335 class_mode=None,
1336 seed=seed)
1337 mask_generator = mask_datagen.flow_from_directory(
1338 'data/masks',
1339 class_mode=None,
1340 seed=seed)
1341 # combine generators into one which yields image and masks
1342 train_generator = zip(image_generator, mask_generator)
1343 model.fit(
1344 train_generator,
1345 steps_per_epoch=2000,
1346 epochs=50)
1347 ```
1348 """
def __init__(
    self,
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    zca_epsilon=1e-6,
    rotation_range=0,
    width_shift_range=0.0,
    height_shift_range=0.0,
    brightness_range=None,
    shear_range=0.0,
    zoom_range=0.0,
    channel_shift_range=0.0,
    fill_mode="nearest",
    cval=0.0,
    horizontal_flip=False,
    vertical_flip=False,
    rescale=None,
    preprocessing_function=None,
    data_format=None,
    validation_split=0.0,
    interpolation_order=1,
    dtype=None,
):
    """Store and validate the augmentation configuration.

    The arguments are documented on the class. `interpolation_order` is
    stored unchanged on `self.interpolation_order`. When None,
    `data_format` defaults to `backend.image_data_format()` and `dtype` to
    `backend.floatx()`.

    Raises:
        ValueError: if `data_format` is not `"channels_last"` or
            `"channels_first"`; if `validation_split` is non-zero and not
            strictly between 0 and 1; if `zoom_range` is neither a real
            number nor a pair of real numbers; or if `brightness_range` is
            not a tuple or list of length 2.
    """
    import numbers  # lets `zoom_range` accept numpy scalars (np.float32)

    if data_format is None:
        data_format = backend.image_data_format()
    if dtype is None:
        dtype = backend.floatx()

    self.featurewise_center = featurewise_center
    self.samplewise_center = samplewise_center
    self.featurewise_std_normalization = featurewise_std_normalization
    self.samplewise_std_normalization = samplewise_std_normalization
    self.zca_whitening = zca_whitening
    self.zca_epsilon = zca_epsilon
    self.rotation_range = rotation_range
    self.width_shift_range = width_shift_range
    self.height_shift_range = height_shift_range
    self.shear_range = shear_range
    self.zoom_range = zoom_range
    self.channel_shift_range = channel_shift_range
    self.fill_mode = fill_mode
    self.cval = cval
    self.horizontal_flip = horizontal_flip
    self.vertical_flip = vertical_flip
    self.rescale = rescale
    self.preprocessing_function = preprocessing_function
    self.dtype = dtype
    self.interpolation_order = interpolation_order

    if data_format not in {"channels_last", "channels_first"}:
        raise ValueError(
            '`data_format` should be `"channels_last"` '
            "(channel after row and column) or "
            '`"channels_first"` (channel before row and column). '
            "Received: %s" % data_format
        )
    self.data_format = data_format
    # Axis positions within a 4D batch (samples first).
    if data_format == "channels_first":
        self.channel_axis = 1
        self.row_axis = 2
        self.col_axis = 3
    else:  # "channels_last"
        self.channel_axis = 3
        self.row_axis = 1
        self.col_axis = 2
    if validation_split and not 0 < validation_split < 1:
        raise ValueError(
            "`validation_split` must be strictly between 0 and 1. "
            "Received: %s" % validation_split
        )
    self._validation_split = validation_split

    self.mean = None
    self.std = None
    self.zca_whitening_matrix = None

    def is_number_pair(value):
        # True for a sized container holding exactly two real numbers.
        # Values without a length (e.g. None) are rejected rather than
        # raising a TypeError from `len`.
        try:
            return len(value) == 2 and all(
                isinstance(val, numbers.Real) for val in value
            )
        except TypeError:
            return False

    # `numbers.Real` covers int, float, bool and numpy floating/integer
    # scalars (e.g. np.float32, which is not a subclass of `float`).
    if isinstance(zoom_range, numbers.Real):
        self.zoom_range = [1 - zoom_range, 1 + zoom_range]
    elif is_number_pair(zoom_range):
        self.zoom_range = [zoom_range[0], zoom_range[1]]
    else:
        raise ValueError(
            "`zoom_range` should be a float or "
            "a tuple or list of two floats. "
            "Received: %s" % (zoom_range,)
        )
    if zca_whitening:
        if not featurewise_center:
            self.featurewise_center = True
            warnings.warn(
                "This ImageDataGenerator specifies "
                "`zca_whitening`, which overrides "
                "setting of `featurewise_center`."
            )
        if featurewise_std_normalization:
            self.featurewise_std_normalization = False
            warnings.warn(
                "This ImageDataGenerator specifies "
                "`zca_whitening`, "
                "which overrides setting of "
                "`featurewise_std_normalization`."
            )
    if featurewise_std_normalization:
        if not featurewise_center:
            self.featurewise_center = True
            warnings.warn(
                "This ImageDataGenerator specifies "
                "`featurewise_std_normalization`, "
                "which overrides setting of "
                "`featurewise_center`."
            )
    if samplewise_std_normalization:
        if not samplewise_center:
            self.samplewise_center = True
            warnings.warn(
                "This ImageDataGenerator specifies "
                "`samplewise_std_normalization`, "
                "which overrides setting of "
                "`samplewise_center`."
            )
    if brightness_range is not None:
        if (
            not isinstance(brightness_range, (tuple, list))
            or len(brightness_range) != 2
        ):
            raise ValueError(
                "`brightness_range` should be tuple or list of two floats. "
                "Received: %s" % (brightness_range,)
            )
    self.brightness_range = brightness_range
def flow(
    self,
    x,
    y=None,
    batch_size=32,
    shuffle=True,
    sample_weight=None,
    seed=None,
    save_to_dir=None,
    save_prefix="",
    save_format="png",
    ignore_class_split=False,
    subset=None,
):
    """Wrap in-memory arrays in an iterator of augmented batches.

    Args:
        x: Input data: a numpy array of rank 4, or a tuple whose first
            element holds the images and whose second element is a numpy
            array (or list of numpy arrays) passed to the output unmodified,
            e.g. to feed the model extra inputs alongside the images. The
            channel axis must have size 1 for grayscale, 3 for RGB and 4 for
            RGBA data.
        y: Labels.
        batch_size: Int (default: 32).
        shuffle: Boolean (default: True).
        sample_weight: Sample weights.
        seed: Int (default: None).
        save_to_dir: None or str (default: None). Optional directory where
            the augmented pictures are saved, e.g. to visualize what the
            augmentation does.
        save_prefix: Str (default: `''`). Filename prefix of the saved
            pictures (only relevant if `save_to_dir` is set).
        save_format: one of "png", "jpeg", "bmp", "pdf", "ppm", "gif",
            "tif", "jpg" (only relevant if `save_to_dir` is set). Default:
            "png".
        ignore_class_split: Boolean (default: False). Ignore a difference
            in the number of classes in the labels between the training and
            validation splits (useful for non-classification tasks).
        subset: Subset of data (`"training"` or `"validation"`) if
            `validation_split` is set in `ImageDataGenerator`.

    Returns:
        A `NumpyArrayIterator` yielding tuples of `(x, y)`, where `x` is a
        numpy array of image data (or a list of numpy arrays when extra
        inputs are given) and `y` is a numpy array of the matching labels.
        If `sample_weight` is not None, the tuples are
        `(x, y, sample_weight)`. If `y` is None, only `x` is yielded.

    Raises:
        ValueError: If `subset` is neither "training" nor "validation".
    """
    iterator_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        sample_weight=sample_weight,
        seed=seed,
        data_format=self.data_format,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        ignore_class_split=ignore_class_split,
        subset=subset,
        dtype=self.dtype,
    )
    return NumpyArrayIterator(x, y, self, **iterator_options)
def flow_from_directory(
    self,
    directory,
    target_size=(256, 256),
    color_mode="rgb",
    classes=None,
    class_mode="categorical",
    batch_size=32,
    shuffle=True,
    seed=None,
    save_to_dir=None,
    save_prefix="",
    save_format="png",
    follow_links=False,
    subset=None,
    interpolation="nearest",
    keep_aspect_ratio=False,
):
    """Takes the path to a directory & generates batches of augmented data.

    Args:
        directory: string, path to the target directory. It should contain
            one subdirectory per class. Any PNG, JPG, BMP, PPM or TIF images
            inside each subdirectory's directory tree will be included in
            the generator. See [this script](
            https://gist.github.com/fchollet/0830affa1f7f19fd47b06d4cf89ed44d)
            for more details.
        target_size: Tuple of integers `(height, width)`, defaults to `(256,
            256)`. The dimensions to which all images found will be resized.
        color_mode: One of "grayscale", "rgb", "rgba". Default: "rgb".
            Whether the images will be converted to have 1, 3, or 4 channels.
        classes: Optional list of class subdirectories (e.g. `['dogs',
            'cats']`). Default: None. If not provided, the list of classes
            will be automatically inferred from the subdirectory
            names/structure under `directory`, where each subdirectory will be
            treated as a different class (and the order of the classes, which
            will map to the label indices, will be alphanumeric). The
            dictionary containing the mapping from class names to class
            indices can be obtained via the attribute `class_indices`.
        class_mode: One of "categorical", "binary", "sparse",
            "input", or None. Default: "categorical".
            Determines the type of label arrays that are returned:
            - "categorical" will be 2D one-hot encoded labels,
            - "binary" will be 1D binary labels,
            - "sparse" will be 1D integer labels,
            - "input" will be images identical
              to input images (mainly used to work with autoencoders).
            - If None, no labels are returned
              (the generator will only yield batches of image data,
              which is useful to use with `model.predict()`).
              Please note that in case of class_mode None,
              the data still needs to reside in a subdirectory
              of `directory` for it to work correctly.
        batch_size: Size of the batches of data (default: 32).
        shuffle: Whether to shuffle the data (default: True). If set to
            False, sorts the data in alphanumeric order.
        seed: Optional random seed for shuffling and transformations.
        save_to_dir: None or str (default: None). This allows you to
            optionally specify a directory to which to save the augmented
            pictures being generated (useful for visualizing what you are
            doing).
        save_prefix: Str. Prefix to use for filenames of saved pictures
            (only relevant if `save_to_dir` is set).
        save_format: one of "png", "jpeg", "bmp", "pdf", "ppm", "gif",
            "tif", "jpg" (only relevant if `save_to_dir` is set). Default:
            "png".
        follow_links: Whether to follow symlinks inside
            class subdirectories (default: False).
        subset: Subset of data (`"training"` or `"validation"`) if
            `validation_split` is set in `ImageDataGenerator`.
        interpolation: Interpolation method used to resample the image if
            the target size is different from that of the loaded image.
            Supported methods are `"nearest"`, `"bilinear"`, and `"bicubic"`.
            If PIL version 1.1.3 or newer is installed, `"lanczos"` is also
            supported. If PIL version 3.4.0 or newer is installed, `"box"` and
            `"hamming"` are also supported. By default, `"nearest"` is used.
        keep_aspect_ratio: Boolean, whether to resize images to a target
            size without aspect ratio distortion. The image is cropped in
            the center with target aspect ratio before resizing.

    Returns:
        A `DirectoryIterator` yielding tuples of `(x, y)`
        where `x` is a numpy array containing a batch
        of images with shape `(batch_size, *target_size, channels)`
        and `y` is a numpy array of corresponding labels.
    """
    # The generator's own `data_format` and `dtype` are forwarded so every
    # batch matches this ImageDataGenerator's configuration.
    return DirectoryIterator(
        directory,
        self,
        target_size=target_size,
        color_mode=color_mode,
        keep_aspect_ratio=keep_aspect_ratio,
        classes=classes,
        class_mode=class_mode,
        data_format=self.data_format,
        batch_size=batch_size,
        shuffle=shuffle,
        seed=seed,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        follow_links=follow_links,
        subset=subset,
        interpolation=interpolation,
        dtype=self.dtype,
    )
1669 def flow_from_dataframe(
1670 self,
1671 dataframe,
1672 directory=None,
1673 x_col="filename",
1674 y_col="class",
1675 weight_col=None,
1676 target_size=(256, 256),
1677 color_mode="rgb",
1678 classes=None,
1679 class_mode="categorical",
1680 batch_size=32,
1681 shuffle=True,
1682 seed=None,
1683 save_to_dir=None,
1684 save_prefix="",
1685 save_format="png",
1686 subset=None,
1687 interpolation="nearest",
1688 validate_filenames=True,
1689 **kwargs,
1690 ):
1691 """Takes the dataframe and the path to a directory + generates batches.
1693 The generated batches contain augmented/normalized data.
1695 **A simple tutorial can be found **[here](
1696 http://bit.ly/keras_flow_from_dataframe).
1698 Args:
1699 dataframe: Pandas dataframe containing the filepaths relative to
1700 `directory` (or absolute paths if `directory` is None) of the
1701 images in a string column. It should include other column/s
1702 depending on the `class_mode`:
1703 - if `class_mode` is `"categorical"` (default value) it must
1704 include the `y_col` column with the class/es of each image.
1705 Values in column can be string/list/tuple if a single class
1706 or list/tuple if multiple classes.
1707 - if `class_mode` is `"binary"` or `"sparse"` it must include
1708 the given `y_col` column with class values as strings.
1709 - if `class_mode` is `"raw"` or `"multi_output"` it should
1710 contain the columns specified in `y_col`.
1711 - if `class_mode` is `"input"` or `None` no extra column is
1712 needed.
1713 directory: string, path to the directory to read images from. If
1714 `None`, data in `x_col` column should be absolute paths.
1715 x_col: string, column in `dataframe` that contains the filenames (or
1716 absolute paths if `directory` is `None`).
1717 y_col: string or list, column/s in `dataframe` that has the target
1718 data.
1719 weight_col: string, column in `dataframe` that contains the sample
1720 weights. Default: `None`.
1721 target_size: tuple of integers `(height, width)`, default: `(256,
1722 256)`. The dimensions to which all images found will be resized.
1723 color_mode: one of "grayscale", "rgb", "rgba". Default: "rgb".
1724 Whether the images will be converted to have 1 or 3 color
1725 channels.
1726 classes: optional list of classes (e.g. `['dogs', 'cats']`). Default
1727 is None. If not provided, the list of classes will be
1728 automatically inferred from the `y_col`, which will map to the
1729 label indices, will be alphanumeric). The dictionary containing
1730 the mapping from class names to class indices can be obtained via
1731 the attribute `class_indices`.
1732 class_mode: one of "binary", "categorical", "input", "multi_output",
1733 "raw", sparse" or None. Default: "categorical".
1734 Mode for yielding the targets:
1735 - `"binary"`: 1D numpy array of binary labels,
1736 - `"categorical"`: 2D numpy array of one-hot encoded labels.
1737 Supports multi-label output.
1738 - `"input"`: images identical to input images (mainly used to
1739 work with autoencoders),
1740 - `"multi_output"`: list with the values of the different
1741 columns,
1742 - `"raw"`: numpy array of values in `y_col` column(s),
1743 - `"sparse"`: 1D numpy array of integer labels,
1744 - `None`, no targets are returned (the generator will only yield
1745 batches of image data, which is useful to use in
1746 `model.predict()`).
1747 batch_size: size of the batches of data (default: 32).
1748 shuffle: whether to shuffle the data (default: True)
1749 seed: optional random seed for shuffling and transformations.
1750 save_to_dir: None or str (default: None). This allows you to
1751 optionally specify a directory to which to save the augmented
1752 pictures being generated (useful for visualizing what you are
1753 doing).
1754 save_prefix: str. Prefix to use for filenames of saved pictures
1755 (only relevant if `save_to_dir` is set).
1756 save_format: one of "png", "jpeg", "bmp", "pdf", "ppm", "gif",
1757 "tif", "jpg" (only relevant if `save_to_dir` is set). Default:
1758 "png".
1759 subset: Subset of data (`"training"` or `"validation"`) if
1760 `validation_split` is set in `ImageDataGenerator`.
1761 interpolation: Interpolation method used to resample the image if
1762 the target size is different from that of the loaded image.
1763 Supported methods are `"nearest"`, `"bilinear"`, and `"bicubic"`.
1764 If PIL version 1.1.3 or newer is installed, `"lanczos"` is also
1765 supported. If PIL version 3.4.0 or newer is installed, `"box"` and
1766 `"hamming"` are also supported. By default, `"nearest"` is used.
1767 validate_filenames: Boolean, whether to validate image filenames in
1768 `x_col`. If `True`, invalid images will be ignored. Disabling this
1769 option can lead to speed-up in the execution of this function.
1770 Defaults to `True`.
1771 **kwargs: legacy arguments for raising deprecation warnings.
1773 Returns:
1774 A `DataFrameIterator` yielding tuples of `(x, y)`
1775 where `x` is a numpy array containing a batch
1776 of images with shape `(batch_size, *target_size, channels)`
1777 and `y` is a numpy array of corresponding labels.
1778 """
1779 if "has_ext" in kwargs:
1780 warnings.warn(
1781 "has_ext is deprecated, filenames in the dataframe have "
1782 "to match the exact filenames in disk.",
1783 DeprecationWarning,
1784 )
1785 if "sort" in kwargs:
1786 warnings.warn(
1787 "sort is deprecated, batches will be created in the"
1788 "same order than the filenames provided if shuffle"
1789 "is set to False.",
1790 DeprecationWarning,
1791 )
1792 if class_mode == "other":
1793 warnings.warn(
1794 '`class_mode` "other" is deprecated, please use '
1795 '`class_mode` "raw".',
1796 DeprecationWarning,
1797 )
1798 class_mode = "raw"
1799 if "drop_duplicates" in kwargs:
1800 warnings.warn(
1801 "drop_duplicates is deprecated, you can drop duplicates "
1802 "by using the pandas.DataFrame.drop_duplicates method.",
1803 DeprecationWarning,
1804 )
1806 return DataFrameIterator(
1807 dataframe,
1808 directory,
1809 self,
1810 x_col=x_col,
1811 y_col=y_col,
1812 weight_col=weight_col,
1813 target_size=target_size,
1814 color_mode=color_mode,
1815 classes=classes,
1816 class_mode=class_mode,
1817 data_format=self.data_format,
1818 batch_size=batch_size,
1819 shuffle=shuffle,
1820 seed=seed,
1821 save_to_dir=save_to_dir,
1822 save_prefix=save_prefix,
1823 save_format=save_format,
1824 subset=subset,
1825 interpolation=interpolation,
1826 validate_filenames=validate_filenames,
1827 dtype=self.dtype,
1828 )
def standardize(self, x):
    """Normalize inputs in place using this generator's configuration.

    The array is modified in place because this method is mostly called
    internally on freshly built arrays, and copying would be costly. To
    keep the caller's data untouched, pass a copy instead:

        standardize(np.copy(x))

    Steps run in this order:
    1. `preprocessing_function`
    2. `rescale`
    3. sample-wise centering and scaling
    4. feature-wise centering and scaling (statistics from `.fit()`)
    5. ZCA whitening

    Feature-wise steps and ZCA whitening only warn, and are skipped, when
    their statistics have not been computed yet.

    Args:
        x: Batch of inputs to be normalized.

    Returns:
        The normalized inputs.
    """
    preprocess = self.preprocessing_function
    if preprocess:
        # The callback may hand back a new array; later steps use that one.
        x = preprocess(x)
    if self.rescale:
        x *= self.rescale

    # Sample-wise statistics are taken over the whole of `x`.
    if self.samplewise_center:
        x -= np.mean(x, keepdims=True)
    if self.samplewise_std_normalization:
        x /= np.std(x, keepdims=True) + 1e-6

    if self.featurewise_center:
        if self.mean is None:
            warnings.warn(
                "This ImageDataGenerator specifies "
                "`featurewise_center`, but it hasn't "
                "been fit on any training data. Fit it "
                "first by calling `.fit(numpy_data)`."
            )
        else:
            x -= self.mean

    if self.featurewise_std_normalization:
        if self.std is None:
            warnings.warn(
                "This ImageDataGenerator specifies "
                "`featurewise_std_normalization`, "
                "but it hasn't "
                "been fit on any training data. Fit it "
                "first by calling `.fit(numpy_data)`."
            )
        else:
            x /= self.std + 1e-6

    if self.zca_whitening:
        whitening = self.zca_whitening_matrix
        if whitening is None:
            warnings.warn(
                "This ImageDataGenerator specifies "
                "`zca_whitening`, but it hasn't "
                "been fit on any training data. Fit it "
                "first by calling `.fit(numpy_data)`."
            )
        else:
            # Flatten the trailing three axes into one feature vector per
            # row, whiten, then restore the original shape.
            n_features = np.prod(x.shape[-3:])
            whitened = x.reshape(-1, n_features) @ whitening
            x = np.reshape(whitened, x.shape)
    return x
def get_random_transform(self, img_shape, seed=None):
    """Draw random parameters for one augmentation of an image.

    Randomness comes from NumPy's global random state, which is reseeded
    with `seed` when one is given.

    Args:
        img_shape: Tuple of integers. Shape of the image that is
            transformed (a single image, without a batch axis).
        seed: Random seed.

    Returns:
        A dictionary describing the transformation, with the keys
        `"theta"`, `"tx"`, `"ty"`, `"shear"`, `"zx"`, `"zy"`,
        `"flip_horizontal"`, `"flip_vertical"`,
        `"channel_shift_intensity"` and `"brightness"`.
    """
    # `img_shape` has no batch axis, hence the -1 on the generator's axes.
    row_axis = self.row_axis - 1
    col_axis = self.col_axis - 1

    if seed is not None:
        np.random.seed(seed)

    def draw_shift(shift_range, axis):
        # Int or 1-D array-like: pick one of the allowed values and give
        # it a random sign. A float makes `np.random.choice` raise
        # ValueError, in which case the shift is drawn uniformly instead.
        try:
            shift = np.random.choice(shift_range)
            shift *= np.random.choice([-1, 1])
        except ValueError:
            shift = np.random.uniform(-shift_range, shift_range)
        # Ranges below 1 are fractions of the image size along `axis`.
        if np.max(shift_range) < 1:
            shift *= img_shape[axis]
        return shift

    # The draws below must stay in this exact order so a given seed keeps
    # producing the same parameters.
    theta = (
        np.random.uniform(-self.rotation_range, self.rotation_range)
        if self.rotation_range
        else 0
    )
    # NOTE(review): `tx` is drawn from `height_shift_range` and scaled by
    # the row count, while `apply_affine_transform` documents `tx` as the
    # width shift; confirm which axis each range is meant to move.
    tx = (
        draw_shift(self.height_shift_range, row_axis)
        if self.height_shift_range
        else 0
    )
    ty = (
        draw_shift(self.width_shift_range, col_axis)
        if self.width_shift_range
        else 0
    )
    shear = (
        np.random.uniform(-self.shear_range, self.shear_range)
        if self.shear_range
        else 0
    )

    zoom_low, zoom_high = self.zoom_range[0], self.zoom_range[1]
    if zoom_low == 1 and zoom_high == 1:
        zx, zy = 1, 1
    else:
        zx, zy = np.random.uniform(zoom_low, zoom_high, 2)

    # One draw per flip even when flipping is disabled (the flag is then
    # falsy).
    flip_horizontal = (np.random.random() < 0.5) * self.horizontal_flip
    flip_vertical = (np.random.random() < 0.5) * self.vertical_flip

    channel_shift_intensity = (
        np.random.uniform(-self.channel_shift_range, self.channel_shift_range)
        if self.channel_shift_range != 0
        else None
    )
    brightness = (
        np.random.uniform(self.brightness_range[0], self.brightness_range[1])
        if self.brightness_range is not None
        else None
    )

    return {
        "theta": theta,
        "tx": tx,
        "ty": ty,
        "shear": shear,
        "zx": zx,
        "zy": zy,
        "flip_horizontal": flip_horizontal,
        "flip_vertical": flip_vertical,
        "channel_shift_intensity": channel_shift_intensity,
        "brightness": brightness,
    }
def apply_transform(self, x, transform_parameters):
    """Apply a given set of transformation parameters to one image.

    Steps run in a fixed order:
    1. affine transform (rotation, shift, shear, zoom)
    2. channel shift
    3. horizontal flip
    4. vertical flip
    5. brightness shift

    Missing keys fall back to no-op values.

    Args:
        x: 3D tensor, single image.
        transform_parameters: Dictionary mapping parameter names to
            values. Recognized keys:
            - `'theta'`: Float. Rotation angle in degrees.
            - `'tx'`: Float. Shift in the x direction.
            - `'ty'`: Float. Shift in the y direction.
            - `'shear'`: Float. Shear angle in degrees.
            - `'zx'`: Float. Zoom in the x direction.
            - `'zy'`: Float. Zoom in the y direction.
            - `'flip_horizontal'`: Boolean. Horizontal flip.
            - `'flip_vertical'`: Boolean. Vertical flip.
            - `'channel_shift_intensity'`: Float. Channel shift intensity.
            - `'brightness'`: Float. Brightness shift intensity.

    Returns:
        A transformed version of the input (same shape).
    """
    params = transform_parameters
    # A single image has no batch axis, so each generator axis index moves
    # down by one.
    row_axis = self.row_axis - 1
    col_axis = self.col_axis - 1
    channel_axis = self.channel_axis - 1

    x = apply_affine_transform(
        x,
        theta=params.get("theta", 0),
        tx=params.get("tx", 0),
        ty=params.get("ty", 0),
        shear=params.get("shear", 0),
        zx=params.get("zx", 1),
        zy=params.get("zy", 1),
        row_axis=row_axis,
        col_axis=col_axis,
        channel_axis=channel_axis,
        fill_mode=self.fill_mode,
        cval=self.cval,
        order=self.interpolation_order,
    )

    channel_shift = params.get("channel_shift_intensity")
    if channel_shift is not None:
        x = apply_channel_shift(x, channel_shift, channel_axis)

    for flag, axis in (
        ("flip_horizontal", col_axis),
        ("flip_vertical", row_axis),
    ):
        if params.get(flag, False):
            x = flip_axis(x, axis)

    brightness = params.get("brightness")
    if brightness is not None:
        x = apply_brightness_shift(x, brightness, scale=False)

    return x
def random_transform(self, x, seed=None):
    """Randomly augment a single image.

    Parameters are drawn with `get_random_transform` and then applied with
    `apply_transform`.

    Args:
        x: 3D tensor, single image.
        seed: Random seed.

    Returns:
        A randomly transformed version of the input (same shape).
    """
    return self.apply_transform(x, self.get_random_transform(x.shape, seed))
def fit(self, x, augment=False, rounds=1, seed=None):
    """Compute data-dependent normalization statistics from sample data.

    Only needed when `featurewise_center`, `featurewise_std_normalization`
    or `zca_whitening` is enabled. Depending on those flags, this sets
    `self.mean`, `self.std` and `self.zca_whitening_matrix`. When `rescale`
    is set, the samples are rescaled before any statistic is computed. The
    caller's array is not modified: a copy is taken first.

    Args:
        x: Sample data of rank 4. The channels axis is expected to hold 1
            (grayscale), 3 (RGB) or 4 (RGBA) channels; any other size only
            triggers a warning.
        augment: Boolean (default: False). Whether to compute the
            statistics on randomly augmented copies of the samples.
        rounds: Int (default: 1). Number of augmentation passes over the
            data when `augment=True`.
        seed: Int (default: None). Random seed.

    Raises:
        ValueError: if `x` does not have rank 4.
    """
    samples = np.asarray(x, dtype=self.dtype)
    if samples.ndim != 4:
        raise ValueError(
            "Input to `.fit()` should have rank 4. Got array with shape: "
            + str(samples.shape)
        )
    n_channels = samples.shape[self.channel_axis]
    if n_channels not in {1, 3, 4}:
        warnings.warn(
            "Expected input to be images (as Numpy array) "
            'following the data format convention "'
            + self.data_format
            + '" (channels on axis '
            + str(self.channel_axis)
            + "), i.e. expected either 1, 3 or 4 channels on axis "
            + str(self.channel_axis)
            + ". However, it was passed an array with shape "
            + str(samples.shape)
            + " ("
            + str(n_channels)
            + " channels)."
        )

    if seed is not None:
        np.random.seed(seed)

    samples = np.copy(samples)
    if self.rescale:
        samples *= self.rescale

    if augment:
        n_samples = samples.shape[0]
        augmented = np.zeros(
            (rounds * n_samples,) + samples.shape[1:], dtype=self.dtype
        )
        # Pass r over the data fills rows [r * n, (r + 1) * n).
        slot = 0
        for _ in range(rounds):
            for image in samples:
                augmented[slot] = self.random_transform(image)
                slot += 1
        samples = augmented

    stat_axes = (0, self.row_axis, self.col_axis)

    def per_channel(values):
        # Reshape per-channel statistics so they broadcast over one image.
        shape = [1, 1, 1]
        shape[self.channel_axis - 1] = samples.shape[self.channel_axis]
        return np.reshape(values, shape)

    if self.featurewise_center:
        self.mean = per_channel(np.mean(samples, axis=stat_axes))
        samples -= self.mean

    if self.featurewise_std_normalization:
        self.std = per_channel(np.std(samples, axis=stat_axes))
        samples /= self.std + 1e-6

    if self.zca_whitening:
        n = len(samples)
        flat = np.reshape(samples, (n, -1))
        u, s, _ = np.linalg.svd(flat.T, full_matrices=False)
        s_inv = np.sqrt(n) / (s + self.zca_epsilon)
        self.zca_whitening_matrix = (u * s_inv).dot(u.T)
2147@keras_export("keras.preprocessing.image.random_rotation")
def random_rotation(
    x,
    rg,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    interpolation_order=1,
):
    """Rotate a Numpy image tensor by a random angle.

    The angle is drawn uniformly from `[-rg, rg]` degrees.

    Deprecated: `tf.keras.preprocessing.image.random_rotation` does not operate
    on tensors and is not recommended for new code. Prefer
    `tf.keras.layers.RandomRotation` which provides equivalent functionality as
    a preprocessing layer. For more information, see the tutorial for
    [augmenting images](
    https://www.tensorflow.org/tutorials/images/data_augmentation), as well as
    the [preprocessing layer guide](
    https://www.tensorflow.org/guide/keras/preprocessing_layers).

    Args:
        x: Input tensor. Must be 3D.
        rg: Rotation range, in degrees.
        row_axis: Index of axis for rows in the input tensor.
        col_axis: Index of axis for columns in the input tensor.
        channel_axis: Index of axis for channels in the input tensor.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.
        interpolation_order: int, order of spline interpolation.
            see `ndimage.interpolation.affine_transform`

    Returns:
        Rotated Numpy image tensor.
    """
    angle = np.random.uniform(-rg, rg)
    return apply_affine_transform(
        x,
        theta=angle,
        row_axis=row_axis,
        col_axis=col_axis,
        channel_axis=channel_axis,
        fill_mode=fill_mode,
        cval=cval,
        order=interpolation_order,
    )
2200@keras_export("keras.preprocessing.image.random_shift")
def random_shift(
    x,
    wrg,
    hrg,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    interpolation_order=1,
):
    """Shift a Numpy image tensor by a random spatial offset.

    Both offsets are drawn uniformly from `[-range, range]` and multiplied
    by the matching image dimension, so they end up in pixels.

    Deprecated: `tf.keras.preprocessing.image.random_shift` does not operate on
    tensors and is not recommended for new code. Prefer
    `tf.keras.layers.RandomTranslation` which provides equivalent functionality
    as a preprocessing layer. For more information, see the tutorial for
    [augmenting images](
    https://www.tensorflow.org/tutorials/images/data_augmentation), as well as
    the [preprocessing layer guide](
    https://www.tensorflow.org/guide/keras/preprocessing_layers).

    Args:
        x: Input tensor. Must be 3D.
        wrg: Width shift range, as a float fraction of the width.
        hrg: Height shift range, as a float fraction of the height.
        row_axis: Index of axis for rows in the input tensor.
        col_axis: Index of axis for columns in the input tensor.
        channel_axis: Index of axis for channels in the input tensor.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.
        interpolation_order: int, order of spline interpolation.
            see `ndimage.interpolation.affine_transform`

    Returns:
        Shifted Numpy image tensor.
    """
    n_rows = x.shape[row_axis]
    n_cols = x.shape[col_axis]
    # NOTE(review): the height-scaled draw is passed as `tx`, which
    # `apply_affine_transform` documents as the width shift (and vice versa
    # for `ty`); confirm which axis each range is meant to move.
    # The height draw must come before the width draw to keep seeded runs
    # reproducible.
    tx = np.random.uniform(-hrg, hrg) * n_rows
    ty = np.random.uniform(-wrg, wrg) * n_cols
    return apply_affine_transform(
        x,
        tx=tx,
        ty=ty,
        row_axis=row_axis,
        col_axis=col_axis,
        channel_axis=channel_axis,
        fill_mode=fill_mode,
        cval=cval,
        order=interpolation_order,
    )
2258@keras_export("keras.preprocessing.image.random_shear")
def random_shear(
    x,
    intensity,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    interpolation_order=1,
):
    """Shear a Numpy image tensor by a random angle.

    The shear angle is drawn uniformly from `[-intensity, intensity]`
    degrees.

    Args:
        x: Input tensor. Must be 3D.
        intensity: Transformation intensity in degrees.
        row_axis: Index of axis for rows in the input tensor.
        col_axis: Index of axis for columns in the input tensor.
        channel_axis: Index of axis for channels in the input tensor.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.
        interpolation_order: int, order of spline interpolation.
            see `ndimage.interpolation.affine_transform`

    Returns:
        Sheared Numpy image tensor.
    """
    shear_angle = np.random.uniform(-intensity, intensity)
    return apply_affine_transform(
        x,
        shear=shear_angle,
        row_axis=row_axis,
        col_axis=col_axis,
        channel_axis=channel_axis,
        fill_mode=fill_mode,
        cval=cval,
        order=interpolation_order,
    )
2302@keras_export("keras.preprocessing.image.random_zoom")
def random_zoom(
    x,
    zoom_range,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    interpolation_order=1,
):
    """Performs a random spatial zoom of a Numpy image tensor.

    Deprecated: `tf.keras.preprocessing.image.random_zoom` does not operate on
    tensors and is not recommended for new code. Prefer
    `tf.keras.layers.RandomZoom` which provides equivalent functionality as
    a preprocessing layer. For more information, see the tutorial for
    [augmenting images](
    https://www.tensorflow.org/tutorials/images/data_augmentation), as well as
    the [preprocessing layer guide](
    https://www.tensorflow.org/guide/keras/preprocessing_layers).

    Args:
        x: Input tensor. Must be 3D.
        zoom_range: Tuple of floats; zoom range for width and height.
            The x and y zoom factors are drawn independently and uniformly
            from `[zoom_range[0], zoom_range[1]]`.
        row_axis: Index of axis for rows in the input tensor.
        col_axis: Index of axis for columns in the input tensor.
        channel_axis: Index of axis for channels in the input tensor.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.
        interpolation_order: int, order of spline interpolation.
            see `ndimage.interpolation.affine_transform`

    Returns:
        Zoomed Numpy image tensor.

    Raises:
        ValueError: if `zoom_range` is not a tuple/list of exactly two
            values, including when it is a scalar.
    """
    # `len()` raises TypeError on scalars; report those with the same
    # documented ValueError as wrong-length sequences.
    try:
        num_bounds = len(zoom_range)
    except TypeError:
        num_bounds = None
    if num_bounds != 2:
        raise ValueError(
            "`zoom_range` should be a tuple or list of two floats. Received: %s"
            % (zoom_range,)
        )

    if zoom_range[0] == 1 and zoom_range[1] == 1:
        # Identity zoom: no random draw is needed.
        zx, zy = 1, 1
    else:
        zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2)
    x = apply_affine_transform(
        x,
        zx=zx,
        zy=zy,
        row_axis=row_axis,
        col_axis=col_axis,
        channel_axis=channel_axis,
        fill_mode=fill_mode,
        cval=cval,
        order=interpolation_order,
    )
    return x
2368@keras_export("keras.preprocessing.image.apply_channel_shift")
def apply_channel_shift(x, intensity, channel_axis=0):
    """Performs a channel shift.

    `intensity` is added to every channel. The result is clipped to the
    `[min, max]` value range of the input image, computed over all
    channels.

    Args:
        x: Input tensor. Must be 3D.
        intensity: Transformation intensity.
        channel_axis: Index of axis for channels in the input tensor.
            Negative indices count from the end.

    Returns:
        Numpy image tensor with the same axis layout as `x`.
    """
    # Bring channels to the front so they can be iterated over. Use
    # `np.moveaxis` for both directions: restoring with
    # `np.rollaxis(x, 0, channel_axis + 1)` leaves channels first when
    # `channel_axis` is negative.
    x = np.moveaxis(x, channel_axis, 0)
    min_x, max_x = np.min(x), np.max(x)
    channel_images = [
        np.clip(x_channel + intensity, min_x, max_x) for x_channel in x
    ]
    x = np.stack(channel_images, axis=0)
    return np.moveaxis(x, 0, channel_axis)
2390@keras_export("keras.preprocessing.image.random_channel_shift")
def random_channel_shift(x, intensity_range, channel_axis=0):
    """Shift all channels of an image by one random intensity.

    The intensity is drawn uniformly from
    `[-intensity_range, intensity_range]` and applied with
    `apply_channel_shift`.

    Args:
        x: Input tensor. Must be 3D.
        intensity_range: Transformation intensity.
        channel_axis: Index of axis for channels in the input tensor.

    Returns:
        Numpy image tensor.
    """
    shift = np.random.uniform(-intensity_range, intensity_range)
    return apply_channel_shift(x, intensity=shift, channel_axis=channel_axis)
2406@keras_export("keras.preprocessing.image.apply_brightness_shift")
def apply_brightness_shift(x, brightness, scale=True):
    """Performs a brightness shift.

    The array is converted to a PIL image, brightened with
    `ImageEnhance.Brightness`, and converted back to an array.

    Args:
        x: Input tensor. Must be 3D.
        brightness: Float. Enhancement factor passed to PIL's
            `ImageEnhance.Brightness(...).enhance`.
        scale: Whether to rescale the image such that minimum and maximum
            values are 0 and 255 respectively. Default: True.

    Returns:
        Numpy image tensor.

    Raises:
        ImportError: if PIL is not available.
    """
    if ImageEnhance is None:
        raise ImportError(
            "Using brightness shifts requires PIL. Install PIL or Pillow."
        )
    x_min, x_max = np.min(x), np.max(x)
    # Values outside [0, 255] are always converted with `scale=True`, even
    # if the caller asked for `scale=False`.
    local_scale = (x_min < 0) or (x_max > 255)
    img = image_utils.array_to_img(x, scale=local_scale or scale)
    img = ImageEnhance.Brightness(img).enhance(brightness)
    x = image_utils.img_to_array(img)
    if not scale and local_scale:
        # Map the [0, 255] result back onto the input's original range.
        x = x / 255 * (x_max - x_min) + x_min
    return x
2437@keras_export("keras.preprocessing.image.random_brightness")
def random_brightness(x, brightness_range, scale=True):
    """Performs a random brightness shift.

    Deprecated: `tf.keras.preprocessing.image.random_brightness` does not
    operate on tensors and is not recommended for new code. Prefer
    `tf.keras.layers.RandomBrightness` which provides equivalent functionality
    as a preprocessing layer. For more information, see the tutorial for
    [augmenting images](
    https://www.tensorflow.org/tutorials/images/data_augmentation), as well as
    the [preprocessing layer guide](
    https://www.tensorflow.org/guide/keras/preprocessing_layers).

    Args:
        x: Input tensor. Must be 3D.
        brightness_range: Tuple of floats; brightness range. The factor is
            drawn uniformly from `[brightness_range[0], brightness_range[1]]`.
        scale: Whether to rescale the image such that minimum and maximum
            values are 0 and 255 respectively. Default: True.

    Returns:
        Numpy image tensor.

    Raises:
        ValueError: if `brightness_range` is not a tuple/list of exactly two
            values, including when it is a scalar.
    """
    # `len()` raises TypeError on scalars; report those with the same
    # documented ValueError as wrong-length sequences.
    try:
        num_bounds = len(brightness_range)
    except TypeError:
        num_bounds = None
    if num_bounds != 2:
        raise ValueError(
            "`brightness_range` should be a tuple or list of two floats. "
            "Received: %s" % (brightness_range,)
        )

    u = np.random.uniform(brightness_range[0], brightness_range[1])
    return apply_brightness_shift(x, u, scale)
def transform_matrix_offset_center(matrix, x, y):
    """Re-express a 3x3 homogeneous transform about an image centre.

    `matrix` is conjugated with a translation so it acts about the point
    `(x / 2 - 0.5, y / 2 - 0.5)` instead of the origin. That point is the
    centre of an image with `x` pixels along the first coordinate and `y`
    pixels along the second.

    Args:
        matrix: 3x3 homogeneous transform matrix.
        x: Image size along the first coordinate.
        y: Image size along the second coordinate.

    Returns:
        The 3x3 matrix `T(c) . matrix . T(-c)`, where `T(c)` translates by
        the centre `c`.
    """
    center_x = float(x) / 2 - 0.5
    center_y = float(y) / 2 - 0.5
    to_center = np.array([[1, 0, center_x], [0, 1, center_y], [0, 0, 1]])
    from_center = np.array([[1, 0, -center_x], [0, 1, -center_y], [0, 0, 1]])
    return to_center.dot(matrix).dot(from_center)
2481@keras_export("keras.preprocessing.image.apply_affine_transform")
def apply_affine_transform(
    x,
    theta=0,
    tx=0,
    ty=0,
    shear=0,
    zx=1,
    zy=1,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    order=1,
):
    """Applies an affine transformation specified by the parameters given.

    The individual transforms are built as 3x3 homogeneous matrices in
    (x, y) = (column, row) coordinates. They are multiplied in the order
    rotation, shift, shear, zoom, re-centred on the middle of the image,
    and applied to every channel with `scipy.ndimage.affine_transform`.

    Args:
        x: 3D numpy array - a 2D image with one or more channels.
        theta: Rotation angle in degrees.
        tx: Width shift, i.e. shift along the column (X) axis.
        ty: Height shift, i.e. shift along the row (Y) axis.
        shear: Shear angle in degrees.
        zx: Zoom in x direction.
        zy: Zoom in y direction.
        row_axis: Index of axis for rows (aka Y axis) in the input
            image. Direction: top to bottom.
        col_axis: Index of axis for columns (aka X axis) in the input
            image. Direction: left to right.
        channel_axis: Index of axis for channels in the input image.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.
        order: int, order of spline interpolation.

    Returns:
        The transformed version of the input. `x` itself is returned when
        the parameters describe the identity transform.

    Raises:
        ImportError: if SciPy is not available.
        ValueError: if the axis indices are not a permutation of
            `(0, 1, 2)`, if `x` is not 3D, or if `channel_axis` is neither
            0 nor 2.
    """
    if scipy is None:
        raise ImportError("Image transformations require SciPy. Install SciPy.")

    # Input sanity checks:
    # 1. x must be a 2D image with one or more channels (i.e., a 3D tensor).
    # 2. Channels must be either the first or the last dimension.
    if np.unique([row_axis, col_axis, channel_axis]).size != 3:
        raise ValueError(
            "'row_axis', 'col_axis', and 'channel_axis' must be distinct"
        )

    # Only non-negative axis indices are accepted.
    valid_indices = set([0, 1, 2])
    actual_indices = set([row_axis, col_axis, channel_axis])
    if actual_indices != valid_indices:
        raise ValueError(
            f"Invalid axis' indices: {actual_indices - valid_indices}"
        )

    if x.ndim != 3:
        raise ValueError("Input arrays must be multi-channel 2D images.")
    if channel_axis not in [0, 2]:
        raise ValueError(
            "Channels are allowed at the first and last dimensions."
        )

    transform_matrix = None
    if theta != 0:
        theta = np.deg2rad(theta)
        rotation_matrix = np.array(
            [
                [np.cos(theta), -np.sin(theta), 0],
                [np.sin(theta), np.cos(theta), 0],
                [0, 0, 1],
            ]
        )
        transform_matrix = rotation_matrix

    if tx != 0 or ty != 0:
        shift_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = shift_matrix
        else:
            transform_matrix = np.dot(transform_matrix, shift_matrix)

    if shear != 0:
        shear = np.deg2rad(shear)
        shear_matrix = np.array(
            [[1, -np.sin(shear), 0], [0, np.cos(shear), 0], [0, 0, 1]]
        )
        if transform_matrix is None:
            transform_matrix = shear_matrix
        else:
            transform_matrix = np.dot(transform_matrix, shear_matrix)

    if zx != 1 or zy != 1:
        zoom_matrix = np.array([[zx, 0, 0], [0, zy, 0], [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = zoom_matrix
        else:
            transform_matrix = np.dot(transform_matrix, zoom_matrix)

    if transform_matrix is not None:
        h, w = x.shape[row_axis], x.shape[col_axis]
        # The matrices above live in (x, y) = (column, row) space, so the
        # pivot must be given as (width, height). Passing (height, width)
        # puts the pivot off-centre whenever h != w. For example, a
        # 180-degree rotation of a 2x4 image would then sample rows 3 and 2,
        # which do not exist.
        transform_matrix = transform_matrix_offset_center(
            transform_matrix, w, h
        )
        x = np.rollaxis(x, channel_axis, 0)

        # Matrix construction assumes that coordinates are x, y (in that
        # order). However, regular numpy arrays use y, x (aka i, j)
        # indexing. Possible solution is:
        # 1. Swap the x and y axes.
        # 2. Apply transform.
        # 3. Swap the x and y axes again to restore image-like data ordering.
        # Mathematically, it is equivalent to the following transformation:
        # M' = PMP, where P is the permutation matrix, M is the original
        # transformation matrix.
        if col_axis > row_axis:
            transform_matrix[:, [0, 1]] = transform_matrix[:, [1, 0]]
            transform_matrix[[0, 1]] = transform_matrix[[1, 0]]
        final_affine_matrix = transform_matrix[:2, :2]
        final_offset = transform_matrix[:2, 2]

        # `scipy.ndimage.interpolation` is a deprecated alias namespace; the
        # same function lives directly on `scipy.ndimage`.
        channel_images = [
            ndimage.affine_transform(
                x_channel,
                final_affine_matrix,
                final_offset,
                order=order,
                mode=fill_mode,
                cval=cval,
            )
            for x_channel in x
        ]
        x = np.stack(channel_images, axis=0)
        x = np.rollaxis(x, 0, channel_axis + 1)
    return x