Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/preprocessing/image.py: 14%

737 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16 

17"""Utilies for image preprocessing and augmentation. 

18 

19Deprecated: `tf.keras.preprocessing.image` APIs do not operate on tensors and 

20are not recommended for new code. Prefer loading data with 

21`tf.keras.utils.image_dataset_from_directory`, and then transforming the output 

22`tf.data.Dataset` with preprocessing layers. For more information, see the 

23tutorials for [loading images]( 

24https://www.tensorflow.org/tutorials/load_data/images) and [augmenting images]( 

25https://www.tensorflow.org/tutorials/images/data_augmentation), as well as the 

26[preprocessing layer guide]( 

27https://www.tensorflow.org/guide/keras/preprocessing_layers). 

28""" 

29 

30import collections 

31import multiprocessing 

32import os 

33import threading 

34import warnings 

35 

36import numpy as np 

37 

38from keras.src import backend 

39from keras.src.utils import data_utils 

40from keras.src.utils import image_utils 

41from keras.src.utils import io_utils 

42 

43# isort: off 

44from tensorflow.python.util.tf_export import keras_export 

45 

46try: 

47 import scipy 

48 from scipy import linalg # noqa: F401 

49 from scipy import ndimage # noqa: F401 

50except ImportError: 

51 pass 

52try: 

53 from PIL import ImageEnhance 

54except ImportError: 

55 ImageEnhance = None 

56 

57 

58@keras_export("keras.preprocessing.image.Iterator") 

59class Iterator(data_utils.Sequence): 

60 """Base class for image data iterators. 

61 

62 Deprecated: `tf.keras.preprocessing.image.Iterator` is not recommended for 

63 new code. Prefer loading images with 

64 `tf.keras.utils.image_dataset_from_directory` and transforming the output 

65 `tf.data.Dataset` with preprocessing layers. For more information, see the 

66 tutorials for [loading images]( 

67 https://www.tensorflow.org/tutorials/load_data/images) and 

68 [augmenting images]( 

69 https://www.tensorflow.org/tutorials/images/data_augmentation), as well as 

70 the [preprocessing layer guide]( 

71 https://www.tensorflow.org/guide/keras/preprocessing_layers). 

72 

73 Every `Iterator` must implement the `_get_batches_of_transformed_samples` 

74 method. 

75 

76 Args: 

77 n: Integer, total number of samples in the dataset to loop over. 

78 batch_size: Integer, size of a batch. 

79 shuffle: Boolean, whether to shuffle the data between epochs. 

80 seed: Random seeding for data shuffling. 

81 """ 

82 

83 white_list_formats = ("png", "jpg", "jpeg", "bmp", "ppm", "tif", "tiff") 

84 

85 def __init__(self, n, batch_size, shuffle, seed): 

86 self.n = n 

87 self.batch_size = batch_size 

88 self.seed = seed 

89 self.shuffle = shuffle 

90 self.batch_index = 0 

91 self.total_batches_seen = 0 

92 self.lock = threading.Lock() 

93 self.index_array = None 

94 self.index_generator = self._flow_index() 

95 

96 def _set_index_array(self): 

97 self.index_array = np.arange(self.n) 

98 if self.shuffle: 

99 self.index_array = np.random.permutation(self.n) 

100 

101 def __getitem__(self, idx): 

102 if idx >= len(self): 

103 raise ValueError( 

104 "Asked to retrieve element {idx}, " 

105 "but the Sequence " 

106 "has length {length}".format(idx=idx, length=len(self)) 

107 ) 

108 if self.seed is not None: 

109 np.random.seed(self.seed + self.total_batches_seen) 

110 self.total_batches_seen += 1 

111 if self.index_array is None: 

112 self._set_index_array() 

113 index_array = self.index_array[ 

114 self.batch_size * idx : self.batch_size * (idx + 1) 

115 ] 

116 return self._get_batches_of_transformed_samples(index_array) 

117 

118 def __len__(self): 

119 return (self.n + self.batch_size - 1) // self.batch_size # round up 

120 

121 def on_epoch_end(self): 

122 self._set_index_array() 

123 

124 def reset(self): 

125 self.batch_index = 0 

126 

127 def _flow_index(self): 

128 # Ensure self.batch_index is 0. 

129 self.reset() 

130 while 1: 

131 if self.seed is not None: 

132 np.random.seed(self.seed + self.total_batches_seen) 

133 if self.batch_index == 0: 

134 self._set_index_array() 

135 

136 if self.n == 0: 

137 # Avoiding modulo by zero error 

138 current_index = 0 

139 else: 

140 current_index = (self.batch_index * self.batch_size) % self.n 

141 if self.n > current_index + self.batch_size: 

142 self.batch_index += 1 

143 else: 

144 self.batch_index = 0 

145 self.total_batches_seen += 1 

146 yield self.index_array[ 

147 current_index : current_index + self.batch_size 

148 ] 

149 

150 def __iter__(self): 

151 # Needed if we want to do something like: 

152 # for x, y in data_gen.flow(...): 

153 return self 

154 

155 def __next__(self, *args, **kwargs): 

156 return self.next(*args, **kwargs) 

157 

158 def next(self): 

159 """For python 2.x. 

160 

161 Returns: 

162 The next batch. 

163 """ 

164 with self.lock: 

165 index_array = next(self.index_generator) 

166 # The transformation of images is not under thread lock 

167 # so it can be done in parallel 

168 return self._get_batches_of_transformed_samples(index_array) 

169 

170 def _get_batches_of_transformed_samples(self, index_array): 

171 """Gets a batch of transformed samples. 

172 

173 Args: 

174 index_array: Array of sample indices to include in batch. 

175 Returns: 

176 A batch of transformed samples. 

177 """ 

178 raise NotImplementedError 

179 

180 

181def _iter_valid_files(directory, white_list_formats, follow_links): 

182 """Iterates on files with extension. 

183 

184 Args: 

185 directory: Absolute path to the directory 

186 containing files to be counted 

187 white_list_formats: Set of strings containing allowed extensions for 

188 the files to be counted. 

189 follow_links: Boolean, follow symbolic links to subdirectories. 

190 Yields: 

191 Tuple of (root, filename) with extension in `white_list_formats`. 

192 """ 

193 

194 def _recursive_list(subpath): 

195 return sorted( 

196 os.walk(subpath, followlinks=follow_links), key=lambda x: x[0] 

197 ) 

198 

199 for root, _, files in _recursive_list(directory): 

200 for fname in sorted(files): 

201 if fname.lower().endswith(".tiff"): 

202 warnings.warn( 

203 'Using ".tiff" files with multiple bands ' 

204 "will cause distortion. Please verify your output." 

205 ) 

206 if fname.lower().endswith(white_list_formats): 

207 yield root, fname 

208 

209 

210def _list_valid_filenames_in_directory( 

211 directory, white_list_formats, split, class_indices, follow_links 

212): 

213 """Lists paths of files in `subdir` with extensions in `white_list_formats`. 

214 

215 Args: 

216 directory: absolute path to a directory containing the files to list. 

217 The directory name is used as class label 

218 and must be a key of `class_indices`. 

219 white_list_formats: set of strings containing allowed extensions for 

220 the files to be counted. 

221 split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into 

222 account a certain fraction of files in each directory. 

223 E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent 

224 of images in each directory. 

225 class_indices: dictionary mapping a class name to its index. 

226 follow_links: boolean, follow symbolic links to subdirectories. 

227 

228 Returns: 

229 classes: a list of class indices 

230 filenames: the path of valid files in `directory`, relative from 

231 `directory`'s parent (e.g., if `directory` is "dataset/class1", 

232 the filenames will be 

233 `["class1/file1.jpg", "class1/file2.jpg", ...]`). 

234 """ 

235 dirname = os.path.basename(directory) 

236 if split: 

237 all_files = list( 

238 _iter_valid_files(directory, white_list_formats, follow_links) 

239 ) 

240 num_files = len(all_files) 

241 start, stop = int(split[0] * num_files), int(split[1] * num_files) 

242 valid_files = all_files[start:stop] 

243 else: 

244 valid_files = _iter_valid_files( 

245 directory, white_list_formats, follow_links 

246 ) 

247 classes = [] 

248 filenames = [] 

249 for root, fname in valid_files: 

250 classes.append(class_indices[dirname]) 

251 absolute_path = os.path.join(root, fname) 

252 relative_path = os.path.join( 

253 dirname, os.path.relpath(absolute_path, directory) 

254 ) 

255 filenames.append(relative_path) 

256 

257 return classes, filenames 

258 

259 

260class BatchFromFilesMixin: 

261 """Adds methods related to getting batches from filenames. 

262 

263 It includes the logic to transform image files to batches. 

264 """ 

265 

266 def set_processing_attrs( 

267 self, 

268 image_data_generator, 

269 target_size, 

270 color_mode, 

271 data_format, 

272 save_to_dir, 

273 save_prefix, 

274 save_format, 

275 subset, 

276 interpolation, 

277 keep_aspect_ratio, 

278 ): 

279 """Sets attributes to use later for processing files into a batch. 

280 

281 Args: 

282 image_data_generator: Instance of `ImageDataGenerator` 

283 to use for random transformations and normalization. 

284 target_size: tuple of integers, dimensions to resize input images 

285 to. 

286 color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`. 

287 Color mode to read images. 

288 data_format: String, one of `channels_first`, `channels_last`. 

289 save_to_dir: Optional directory where to save the pictures 

290 being yielded, in a viewable format. This is useful 

291 for visualizing the random transformations being 

292 applied, for debugging purposes. 

293 save_prefix: String prefix to use for saving sample 

294 images (if `save_to_dir` is set). 

295 save_format: Format to use for saving sample images 

296 (if `save_to_dir` is set). 

297 subset: Subset of data (`"training"` or `"validation"`) if 

298 validation_split is set in ImageDataGenerator. 

299 interpolation: Interpolation method used to resample the image if 

300 the target size is different from that of the loaded image. 

301 Supported methods are "nearest", "bilinear", and "bicubic". If 

302 PIL version 1.1.3 or newer is installed, "lanczos" is also 

303 supported. If PIL version 3.4.0 or newer is installed, "box" and 

304 "hamming" are also supported. By default, "nearest" is used. 

305 keep_aspect_ratio: Boolean, whether to resize images to a target 

306 size without aspect ratio distortion. The image is cropped in 

307 the center with target aspect ratio before resizing. 

308 """ 

309 self.image_data_generator = image_data_generator 

310 self.target_size = tuple(target_size) 

311 self.keep_aspect_ratio = keep_aspect_ratio 

312 if color_mode not in {"rgb", "rgba", "grayscale"}: 

313 raise ValueError( 

314 "Invalid color mode:", 

315 color_mode, 

316 '; expected "rgb", "rgba", or "grayscale".', 

317 ) 

318 self.color_mode = color_mode 

319 self.data_format = data_format 

320 if self.color_mode == "rgba": 

321 if self.data_format == "channels_last": 

322 self.image_shape = self.target_size + (4,) 

323 else: 

324 self.image_shape = (4,) + self.target_size 

325 elif self.color_mode == "rgb": 

326 if self.data_format == "channels_last": 

327 self.image_shape = self.target_size + (3,) 

328 else: 

329 self.image_shape = (3,) + self.target_size 

330 else: 

331 if self.data_format == "channels_last": 

332 self.image_shape = self.target_size + (1,) 

333 else: 

334 self.image_shape = (1,) + self.target_size 

335 self.save_to_dir = save_to_dir 

336 self.save_prefix = save_prefix 

337 self.save_format = save_format 

338 self.interpolation = interpolation 

339 if subset is not None: 

340 validation_split = self.image_data_generator._validation_split 

341 if subset == "validation": 

342 split = (0, validation_split) 

343 elif subset == "training": 

344 split = (validation_split, 1) 

345 else: 

346 raise ValueError( 

347 "Invalid subset name: %s;" 

348 'expected "training" or "validation"' % (subset,) 

349 ) 

350 else: 

351 split = None 

352 self.split = split 

353 self.subset = subset 

354 

355 def _get_batches_of_transformed_samples(self, index_array): 

356 """Gets a batch of transformed samples. 

357 

358 Args: 

359 index_array: Array of sample indices to include in batch. 

360 Returns: 

361 A batch of transformed samples. 

362 """ 

363 batch_x = np.zeros( 

364 (len(index_array),) + self.image_shape, dtype=self.dtype 

365 ) 

366 # build batch of image data 

367 # self.filepaths is dynamic, is better to call it once outside the loop 

368 filepaths = self.filepaths 

369 for i, j in enumerate(index_array): 

370 img = image_utils.load_img( 

371 filepaths[j], 

372 color_mode=self.color_mode, 

373 target_size=self.target_size, 

374 interpolation=self.interpolation, 

375 keep_aspect_ratio=self.keep_aspect_ratio, 

376 ) 

377 x = image_utils.img_to_array(img, data_format=self.data_format) 

378 # Pillow images should be closed after `load_img`, 

379 # but not PIL images. 

380 if hasattr(img, "close"): 

381 img.close() 

382 if self.image_data_generator: 

383 params = self.image_data_generator.get_random_transform(x.shape) 

384 x = self.image_data_generator.apply_transform(x, params) 

385 x = self.image_data_generator.standardize(x) 

386 batch_x[i] = x 

387 # optionally save augmented images to disk for debugging purposes 

388 if self.save_to_dir: 

389 for i, j in enumerate(index_array): 

390 img = image_utils.array_to_img( 

391 batch_x[i], self.data_format, scale=True 

392 ) 

393 fname = "{prefix}_{index}_{hash}.{format}".format( 

394 prefix=self.save_prefix, 

395 index=j, 

396 hash=np.random.randint(1e7), 

397 format=self.save_format, 

398 ) 

399 img.save(os.path.join(self.save_to_dir, fname)) 

400 # build batch of labels 

401 if self.class_mode == "input": 

402 batch_y = batch_x.copy() 

403 elif self.class_mode in {"binary", "sparse"}: 

404 batch_y = np.empty(len(batch_x), dtype=self.dtype) 

405 for i, n_observation in enumerate(index_array): 

406 batch_y[i] = self.classes[n_observation] 

407 elif self.class_mode == "categorical": 

408 batch_y = np.zeros( 

409 (len(batch_x), len(self.class_indices)), dtype=self.dtype 

410 ) 

411 for i, n_observation in enumerate(index_array): 

412 batch_y[i, self.classes[n_observation]] = 1.0 

413 elif self.class_mode == "multi_output": 

414 batch_y = [output[index_array] for output in self.labels] 

415 elif self.class_mode == "raw": 

416 batch_y = self.labels[index_array] 

417 else: 

418 return batch_x 

419 if self.sample_weight is None: 

420 return batch_x, batch_y 

421 else: 

422 return batch_x, batch_y, self.sample_weight[index_array] 

423 

424 @property 

425 def filepaths(self): 

426 """List of absolute paths to image files.""" 

427 raise NotImplementedError( 

428 "`filepaths` property method has not " 

429 "been implemented in {}.".format(type(self).__name__) 

430 ) 

431 

432 @property 

433 def labels(self): 

434 """Class labels of every observation.""" 

435 raise NotImplementedError( 

436 "`labels` property method has not been implemented in {}.".format( 

437 type(self).__name__ 

438 ) 

439 ) 

440 

441 @property 

442 def sample_weight(self): 

443 raise NotImplementedError( 

444 "`sample_weight` property method has not " 

445 "been implemented in {}.".format(type(self).__name__) 

446 ) 

447 

448 

449@keras_export("keras.preprocessing.image.DirectoryIterator") 

450class DirectoryIterator(BatchFromFilesMixin, Iterator): 

451 """Iterator capable of reading images from a directory on disk. 

452 

453 Deprecated: `tf.keras.preprocessing.image.DirectoryIterator` is not 

454 recommended for new code. Prefer loading images with 

455 `tf.keras.utils.image_dataset_from_directory` and transforming the output 

456 `tf.data.Dataset` with preprocessing layers. For more information, see the 

457 tutorials for [loading images]( 

458 https://www.tensorflow.org/tutorials/load_data/images) and 

459 [augmenting images]( 

460 https://www.tensorflow.org/tutorials/images/data_augmentation), as well as 

461 the [preprocessing layer guide]( 

462 https://www.tensorflow.org/guide/keras/preprocessing_layers). 

463 

464 Args: 

465 directory: Path to the directory to read images from. Each subdirectory 

466 in this directory will be considered to contain images from one class, 

467 or alternatively you could specify class subdirectories via the 

468 `classes` argument. 

469 image_data_generator: Instance of `ImageDataGenerator` to use for random 

470 transformations and normalization. 

471 target_size: tuple of integers, dimensions to resize input images to. 

472 color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`. Color mode to read 

473 images. 

474 classes: Optional list of strings, names of subdirectories containing 

475 images from each class (e.g. `["dogs", "cats"]`). It will be computed 

476 automatically if not set. 

477 class_mode: Mode for yielding the targets: 

478 - `"binary"`: binary targets (if there are only two classes), 

479 - `"categorical"`: categorical targets, 

480 - `"sparse"`: integer targets, 

481 - `"input"`: targets are images identical to input images (mainly 

482 used to work with autoencoders), 

483 - `None`: no targets get yielded (only input images are yielded). 

484 batch_size: Integer, size of a batch. 

485 shuffle: Boolean, whether to shuffle the data between epochs. 

486 seed: Random seed for data shuffling. 

487 data_format: String, one of `channels_first`, `channels_last`. 

488 save_to_dir: Optional directory where to save the pictures being 

489 yielded, in a viewable format. This is useful for visualizing the 

490 random transformations being applied, for debugging purposes. 

491 save_prefix: String prefix to use for saving sample images (if 

492 `save_to_dir` is set). 

493 save_format: Format to use for saving sample images (if `save_to_dir` is 

494 set). 

495 subset: Subset of data (`"training"` or `"validation"`) if 

496 validation_split is set in ImageDataGenerator. 

497 interpolation: Interpolation method used to resample the image if the 

498 target size is different from that of the loaded image. Supported 

499 methods are "nearest", "bilinear", and "bicubic". If PIL version 1.1.3 

500 or newer is installed, "lanczos" is also supported. If PIL version 

501 3.4.0 or newer is installed, "box" and "hamming" are also supported. 

502 By default, "nearest" is used. 

503 keep_aspect_ratio: Boolean, whether to resize images to a target size 

504 without aspect ratio distortion. The image is cropped in the center 

505 with target aspect ratio before resizing. 

506 dtype: Dtype to use for generated arrays. 

507 """ 

508 

509 allowed_class_modes = {"categorical", "binary", "sparse", "input", None} 

510 

511 def __init__( 

512 self, 

513 directory, 

514 image_data_generator, 

515 target_size=(256, 256), 

516 color_mode="rgb", 

517 classes=None, 

518 class_mode="categorical", 

519 batch_size=32, 

520 shuffle=True, 

521 seed=None, 

522 data_format=None, 

523 save_to_dir=None, 

524 save_prefix="", 

525 save_format="png", 

526 follow_links=False, 

527 subset=None, 

528 interpolation="nearest", 

529 keep_aspect_ratio=False, 

530 dtype=None, 

531 ): 

532 if data_format is None: 

533 data_format = backend.image_data_format() 

534 if dtype is None: 

535 dtype = backend.floatx() 

536 super().set_processing_attrs( 

537 image_data_generator, 

538 target_size, 

539 color_mode, 

540 data_format, 

541 save_to_dir, 

542 save_prefix, 

543 save_format, 

544 subset, 

545 interpolation, 

546 keep_aspect_ratio, 

547 ) 

548 self.directory = directory 

549 self.classes = classes 

550 if class_mode not in self.allowed_class_modes: 

551 raise ValueError( 

552 "Invalid class_mode: {}; expected one of: {}".format( 

553 class_mode, self.allowed_class_modes 

554 ) 

555 ) 

556 self.class_mode = class_mode 

557 self.dtype = dtype 

558 # First, count the number of samples and classes. 

559 self.samples = 0 

560 

561 if not classes: 

562 classes = [] 

563 for subdir in sorted(os.listdir(directory)): 

564 if os.path.isdir(os.path.join(directory, subdir)): 

565 classes.append(subdir) 

566 self.num_classes = len(classes) 

567 self.class_indices = dict(zip(classes, range(len(classes)))) 

568 

569 pool = multiprocessing.pool.ThreadPool() 

570 

571 # Second, build an index of the images 

572 # in the different class subfolders. 

573 results = [] 

574 self.filenames = [] 

575 i = 0 

576 for dirpath in (os.path.join(directory, subdir) for subdir in classes): 

577 results.append( 

578 pool.apply_async( 

579 _list_valid_filenames_in_directory, 

580 ( 

581 dirpath, 

582 self.white_list_formats, 

583 self.split, 

584 self.class_indices, 

585 follow_links, 

586 ), 

587 ) 

588 ) 

589 classes_list = [] 

590 for res in results: 

591 classes, filenames = res.get() 

592 classes_list.append(classes) 

593 self.filenames += filenames 

594 self.samples = len(self.filenames) 

595 self.classes = np.zeros((self.samples,), dtype="int32") 

596 for classes in classes_list: 

597 self.classes[i : i + len(classes)] = classes 

598 i += len(classes) 

599 

600 io_utils.print_msg( 

601 f"Found {self.samples} images belonging to " 

602 f"{self.num_classes} classes." 

603 ) 

604 pool.close() 

605 pool.join() 

606 self._filepaths = [ 

607 os.path.join(self.directory, fname) for fname in self.filenames 

608 ] 

609 super().__init__(self.samples, batch_size, shuffle, seed) 

610 

611 @property 

612 def filepaths(self): 

613 return self._filepaths 

614 

615 @property 

616 def labels(self): 

617 return self.classes 

618 

619 @property # mixin needs this property to work 

620 def sample_weight(self): 

621 # no sample weights will be returned 

622 return None 

623 

624 

625@keras_export("keras.preprocessing.image.NumpyArrayIterator") 

626class NumpyArrayIterator(Iterator): 

627 """Iterator yielding data from a Numpy array. 

628 

629 Deprecated: `tf.keras.preprocessing.image.NumpyArrayIterator` is not 

630 recommended for new code. Prefer loading images with 

631 `tf.keras.utils.image_dataset_from_directory` and transforming the output 

632 `tf.data.Dataset` with preprocessing layers. For more information, see the 

633 tutorials for [loading images]( 

634 https://www.tensorflow.org/tutorials/load_data/images) and 

635 [augmenting images]( 

636 https://www.tensorflow.org/tutorials/images/data_augmentation), as well as 

637 the [preprocessing layer guide]( 

638 https://www.tensorflow.org/guide/keras/preprocessing_layers). 

639 

640 Args: 

641 x: Numpy array of input data or tuple. If tuple, the second elements is 

642 either another numpy array or a list of numpy arrays, each of which 

643 gets passed through as an output without any modifications. 

644 y: Numpy array of targets data. 

645 image_data_generator: Instance of `ImageDataGenerator` to use for random 

646 transformations and normalization. 

647 batch_size: Integer, size of a batch. 

648 shuffle: Boolean, whether to shuffle the data between epochs. 

649 sample_weight: Numpy array of sample weights. 

650 seed: Random seed for data shuffling. 

651 data_format: String, one of `channels_first`, `channels_last`. 

652 save_to_dir: Optional directory where to save the pictures being 

653 yielded, in a viewable format. This is useful for visualizing the 

654 random transformations being applied, for debugging purposes. 

655 save_prefix: String prefix to use for saving sample images (if 

656 `save_to_dir` is set). 

657 save_format: Format to use for saving sample images (if `save_to_dir` is 

658 set). 

659 subset: Subset of data (`"training"` or `"validation"`) if 

660 validation_split is set in ImageDataGenerator. 

661 ignore_class_split: Boolean (default: False), ignore difference 

662 in number of classes in labels across train and validation 

663 split (useful for non-classification tasks) 

664 dtype: Dtype to use for the generated arrays. 

665 """ 

666 

667 def __init__( 

668 self, 

669 x, 

670 y, 

671 image_data_generator, 

672 batch_size=32, 

673 shuffle=False, 

674 sample_weight=None, 

675 seed=None, 

676 data_format=None, 

677 save_to_dir=None, 

678 save_prefix="", 

679 save_format="png", 

680 subset=None, 

681 ignore_class_split=False, 

682 dtype=None, 

683 ): 

684 if data_format is None: 

685 data_format = backend.image_data_format() 

686 if dtype is None: 

687 dtype = backend.floatx() 

688 self.dtype = dtype 

689 if isinstance(x, tuple) or isinstance(x, list): 

690 if not isinstance(x[1], list): 

691 x_misc = [np.asarray(x[1])] 

692 else: 

693 x_misc = [np.asarray(xx) for xx in x[1]] 

694 x = x[0] 

695 for xx in x_misc: 

696 if len(x) != len(xx): 

697 raise ValueError( 

698 "All of the arrays in `x` " 

699 "should have the same length. " 

700 "Found a pair with: len(x[0]) = %s, len(x[?]) = %s" 

701 % (len(x), len(xx)) 

702 ) 

703 else: 

704 x_misc = [] 

705 

706 if y is not None and len(x) != len(y): 

707 raise ValueError( 

708 "`x` (images tensor) and `y` (labels) " 

709 "should have the same length. " 

710 "Found: x.shape = %s, y.shape = %s" 

711 % (np.asarray(x).shape, np.asarray(y).shape) 

712 ) 

713 if sample_weight is not None and len(x) != len(sample_weight): 

714 raise ValueError( 

715 "`x` (images tensor) and `sample_weight` " 

716 "should have the same length. " 

717 "Found: x.shape = %s, sample_weight.shape = %s" 

718 % (np.asarray(x).shape, np.asarray(sample_weight).shape) 

719 ) 

720 if subset is not None: 

721 if subset not in {"training", "validation"}: 

722 raise ValueError( 

723 "Invalid subset name:", 

724 subset, 

725 '; expected "training" or "validation".', 

726 ) 

727 split_idx = int(len(x) * image_data_generator._validation_split) 

728 

729 if ( 

730 y is not None 

731 and not ignore_class_split 

732 and not np.array_equal( 

733 np.unique(y[:split_idx]), np.unique(y[split_idx:]) 

734 ) 

735 ): 

736 raise ValueError( 

737 "Training and validation subsets " 

738 "have different number of classes after " 

739 "the split. If your numpy arrays are " 

740 "sorted by the label, you might want " 

741 "to shuffle them." 

742 ) 

743 

744 if subset == "validation": 

745 x = x[:split_idx] 

746 x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc] 

747 if y is not None: 

748 y = y[:split_idx] 

749 else: 

750 x = x[split_idx:] 

751 x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc] 

752 if y is not None: 

753 y = y[split_idx:] 

754 

755 self.x = np.asarray(x, dtype=self.dtype) 

756 self.x_misc = x_misc 

757 if self.x.ndim != 4: 

758 raise ValueError( 

759 "Input data in `NumpyArrayIterator` " 

760 "should have rank 4. You passed an array " 

761 "with shape", 

762 self.x.shape, 

763 ) 

764 channels_axis = 3 if data_format == "channels_last" else 1 

765 if self.x.shape[channels_axis] not in {1, 3, 4}: 

766 warnings.warn( 

767 'NumpyArrayIterator is set to use the data format convention "' 

768 + data_format 

769 + '" (channels on axis ' 

770 + str(channels_axis) 

771 + "), i.e. expected either 1, 3, or 4 channels on axis " 

772 + str(channels_axis) 

773 + ". However, it was passed an array with shape " 

774 + str(self.x.shape) 

775 + " (" 

776 + str(self.x.shape[channels_axis]) 

777 + " channels)." 

778 ) 

779 if y is not None: 

780 self.y = np.asarray(y) 

781 else: 

782 self.y = None 

783 if sample_weight is not None: 

784 self.sample_weight = np.asarray(sample_weight) 

785 else: 

786 self.sample_weight = None 

787 self.image_data_generator = image_data_generator 

788 self.data_format = data_format 

789 self.save_to_dir = save_to_dir 

790 self.save_prefix = save_prefix 

791 self.save_format = save_format 

792 super().__init__(x.shape[0], batch_size, shuffle, seed) 

793 

794 def _get_batches_of_transformed_samples(self, index_array): 

795 batch_x = np.zeros( 

796 tuple([len(index_array)] + list(self.x.shape)[1:]), dtype=self.dtype 

797 ) 

798 for i, j in enumerate(index_array): 

799 x = self.x[j] 

800 params = self.image_data_generator.get_random_transform(x.shape) 

801 x = self.image_data_generator.apply_transform( 

802 x.astype(self.dtype), params 

803 ) 

804 x = self.image_data_generator.standardize(x) 

805 batch_x[i] = x 

806 

807 if self.save_to_dir: 

808 for i, j in enumerate(index_array): 

809 img = image_utils.array_to_img( 

810 batch_x[i], self.data_format, scale=True 

811 ) 

812 fname = "{prefix}_{index}_{hash}.{format}".format( 

813 prefix=self.save_prefix, 

814 index=j, 

815 hash=np.random.randint(1e4), 

816 format=self.save_format, 

817 ) 

818 img.save(os.path.join(self.save_to_dir, fname)) 

819 batch_x_miscs = [xx[index_array] for xx in self.x_misc] 

820 output = (batch_x if not batch_x_miscs else [batch_x] + batch_x_miscs,) 

821 if self.y is None: 

822 return output[0] 

823 output += (self.y[index_array],) 

824 if self.sample_weight is not None: 

825 output += (self.sample_weight[index_array],) 

826 return output 

827 

828 

829def validate_filename(filename, white_list_formats): 

830 """Check if a filename refers to a valid file. 

831 

832 Args: 

833 filename: String, absolute path to a file 

834 white_list_formats: Set, allowed file extensions 

835 Returns: 

836 A boolean value indicating if the filename is valid or not 

837 """ 

838 return filename.lower().endswith(white_list_formats) and os.path.isfile( 

839 filename 

840 ) 

841 

842 

843class DataFrameIterator(BatchFromFilesMixin, Iterator): 

844 """Iterator capable of reading images from a directory as a dataframe. 

845 

846 Args: 

847 dataframe: Pandas dataframe containing the filepaths relative to 

848 `directory` (or absolute paths if `directory` is None) of the images 

849 in a string column. It should include other column/s depending on the 

850 `class_mode`: - if `class_mode` is `"categorical"` (default value) it 

851 must include the `y_col` column with the class/es of each image. 

852 Values in column can be string/list/tuple if a single class or 

853 list/tuple if multiple classes. 

854 - if `class_mode` is `"binary"` or `"sparse"` it must include the 

855 given `y_col` column with class values as strings. 

856 - if `class_mode` is `"raw"` or `"multi_output"` it should contain 

857 the columns specified in `y_col`. 

858 - if `class_mode` is `"input"` or `None` no extra column is needed. 

859 directory: string, path to the directory to read images from. If `None`, 

860 data in `x_col` column should be absolute paths. 

861 image_data_generator: Instance of `ImageDataGenerator` to use for random 

862 transformations and normalization. If None, no transformations and 

863 normalizations are made. 

864 x_col: string, column in `dataframe` that contains the filenames (or 

865 absolute paths if `directory` is `None`). 

866 y_col: string or list, column/s in `dataframe` that has the target data. 

867 weight_col: string, column in `dataframe` that contains the sample 

868 weights. Default: `None`. 

869 target_size: tuple of integers, dimensions to resize input images to. 

870 color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`. Color mode to read 

871 images. 

872 classes: Optional list of strings, classes to use (e.g. `["dogs", 

873 "cats"]`). If None, all classes in `y_col` will be used. 

874 class_mode: one of "binary", "categorical", "input", "multi_output", 

875 "raw", "sparse" or None. Default: "categorical". 

876 Mode for yielding the targets: 

877 - `"binary"`: 1D numpy array of binary labels, 

878 - `"categorical"`: 2D numpy array of one-hot encoded labels. 

879 Supports multi-label output. 

880 - `"input"`: images identical to input images (mainly used to work 

881 with autoencoders), 

882 - `"multi_output"`: list with the values of the different columns, 

883 - `"raw"`: numpy array of values in `y_col` column(s), 

884 - `"sparse"`: 1D numpy array of integer labels, - `None`, no targets 

885 are returned (the generator will only yield batches of image data, 

886 which is useful to use in `model.predict()`). 

887 batch_size: Integer, size of a batch. 

888 shuffle: Boolean, whether to shuffle the data between epochs. 

889 seed: Random seed for data shuffling. 

890 data_format: String, one of `channels_first`, `channels_last`. 

891 save_to_dir: Optional directory where to save the pictures being 

892 yielded, in a viewable format. This is useful for visualizing the 

893 random transformations being applied, for debugging purposes. 

894 save_prefix: String prefix to use for saving sample images (if 

895 `save_to_dir` is set). 

896 save_format: Format to use for saving sample images (if `save_to_dir` is 

897 set). 

898 subset: Subset of data (`"training"` or `"validation"`) if 

899 validation_split is set in ImageDataGenerator. 

900 interpolation: Interpolation method used to resample the image if the 

901 target size is different from that of the loaded image. Supported 

902 methods are "nearest", "bilinear", and "bicubic". If PIL version 1.1.3 

903 or newer is installed, "lanczos" is also supported. If PIL version 

904 3.4.0 or newer is installed, "box" and "hamming" are also supported. 

905 By default, "nearest" is used. 

906 keep_aspect_ratio: Boolean, whether to resize images to a target size 

907 without aspect ratio distortion. The image is cropped in the center 

908 with target aspect ratio before resizing. 

909 dtype: Dtype to use for the generated arrays. 

910 validate_filenames: Boolean, whether to validate image filenames in 

911 `x_col`. If `True`, invalid images will be ignored. Disabling this 

912 option can lead to speed-up in the instantiation of this class. 

913 Default: `True`. 

914 """ 

915 

916 allowed_class_modes = { 

917 "binary", 

918 "categorical", 

919 "input", 

920 "multi_output", 

921 "raw", 

922 "sparse", 

923 None, 

924 } 

925 

926 def __init__( 

927 self, 

928 dataframe, 

929 directory=None, 

930 image_data_generator=None, 

931 x_col="filename", 

932 y_col="class", 

933 weight_col=None, 

934 target_size=(256, 256), 

935 color_mode="rgb", 

936 classes=None, 

937 class_mode="categorical", 

938 batch_size=32, 

939 shuffle=True, 

940 seed=None, 

941 data_format="channels_last", 

942 save_to_dir=None, 

943 save_prefix="", 

944 save_format="png", 

945 subset=None, 

946 interpolation="nearest", 

947 keep_aspect_ratio=False, 

948 dtype="float32", 

949 validate_filenames=True, 

950 ): 

951 super().set_processing_attrs( 

952 image_data_generator, 

953 target_size, 

954 color_mode, 

955 data_format, 

956 save_to_dir, 

957 save_prefix, 

958 save_format, 

959 subset, 

960 interpolation, 

961 keep_aspect_ratio, 

962 ) 

963 df = dataframe.copy() 

964 self.directory = directory or "" 

965 self.class_mode = class_mode 

966 self.dtype = dtype 

967 # check that inputs match the required class_mode 

968 self._check_params(df, x_col, y_col, weight_col, classes) 

969 if ( 

970 validate_filenames 

971 ): # check which image files are valid and keep them 

972 df = self._filter_valid_filepaths(df, x_col) 

973 if class_mode not in ["input", "multi_output", "raw", None]: 

974 df, classes = self._filter_classes(df, y_col, classes) 

975 num_classes = len(classes) 

976 # build an index of all the unique classes 

977 self.class_indices = dict(zip(classes, range(len(classes)))) 

978 # retrieve only training or validation set 

979 if self.split: 

980 num_files = len(df) 

981 start = int(self.split[0] * num_files) 

982 stop = int(self.split[1] * num_files) 

983 df = df.iloc[start:stop, :] 

984 # get labels for each observation 

985 if class_mode not in ["input", "multi_output", "raw", None]: 

986 self.classes = self.get_classes(df, y_col) 

987 self.filenames = df[x_col].tolist() 

988 self._sample_weight = df[weight_col].values if weight_col else None 

989 

990 if class_mode == "multi_output": 

991 self._targets = [np.array(df[col].tolist()) for col in y_col] 

992 if class_mode == "raw": 

993 self._targets = df[y_col].values 

994 self.samples = len(self.filenames) 

995 validated_string = ( 

996 "validated" if validate_filenames else "non-validated" 

997 ) 

998 if class_mode in ["input", "multi_output", "raw", None]: 

999 io_utils.print_msg( 

1000 f"Found {self.samples} {validated_string} image filenames." 

1001 ) 

1002 else: 

1003 io_utils.print_msg( 

1004 f"Found {self.samples} {validated_string} image filenames " 

1005 f"belonging to {num_classes} classes." 

1006 ) 

1007 self._filepaths = [ 

1008 os.path.join(self.directory, fname) for fname in self.filenames 

1009 ] 

1010 super().__init__(self.samples, batch_size, shuffle, seed) 

1011 

1012 def _check_params(self, df, x_col, y_col, weight_col, classes): 

1013 # check class mode is one of the currently supported 

1014 if self.class_mode not in self.allowed_class_modes: 

1015 raise ValueError( 

1016 "Invalid class_mode: {}; expected one of: {}".format( 

1017 self.class_mode, self.allowed_class_modes 

1018 ) 

1019 ) 

1020 # check that y_col has several column names if class_mode is 

1021 # multi_output 

1022 if (self.class_mode == "multi_output") and not isinstance(y_col, list): 

1023 raise TypeError( 

1024 'If class_mode="{}", y_col must be a list. Received {}.'.format( 

1025 self.class_mode, type(y_col).__name__ 

1026 ) 

1027 ) 

1028 # check that filenames/filepaths column values are all strings 

1029 if not all(df[x_col].apply(lambda x: isinstance(x, str))): 

1030 raise TypeError( 

1031 f"All values in column x_col={x_col} must be strings." 

1032 ) 

1033 # check labels are string if class_mode is binary or sparse 

1034 if self.class_mode in {"binary", "sparse"}: 

1035 if not all(df[y_col].apply(lambda x: isinstance(x, str))): 

1036 raise TypeError( 

1037 'If class_mode="{}", y_col="{}" column ' 

1038 "values must be strings.".format(self.class_mode, y_col) 

1039 ) 

1040 # check that if binary there are only 2 different classes 

1041 if self.class_mode == "binary": 

1042 if classes: 

1043 classes = set(classes) 

1044 if len(classes) != 2: 

1045 raise ValueError( 

1046 'If class_mode="binary" there must be 2 ' 

1047 "classes. {} class/es were given.".format(len(classes)) 

1048 ) 

1049 elif df[y_col].nunique() != 2: 

1050 raise ValueError( 

1051 'If class_mode="binary" there must be 2 classes. ' 

1052 "Found {} classes.".format(df[y_col].nunique()) 

1053 ) 

1054 # check values are string, list or tuple if class_mode is categorical 

1055 if self.class_mode == "categorical": 

1056 types = (str, list, tuple) 

1057 if not all(df[y_col].apply(lambda x: isinstance(x, types))): 

1058 raise TypeError( 

1059 'If class_mode="{}", y_col="{}" column ' 

1060 "values must be type string, list or tuple.".format( 

1061 self.class_mode, y_col 

1062 ) 

1063 ) 

1064 # raise warning if classes are given but will be unused 

1065 if classes and self.class_mode in { 

1066 "input", 

1067 "multi_output", 

1068 "raw", 

1069 None, 

1070 }: 

1071 warnings.warn( 

1072 '`classes` will be ignored given the class_mode="{}"'.format( 

1073 self.class_mode 

1074 ) 

1075 ) 

1076 # check that if weight column that the values are numerical 

1077 if weight_col and not issubclass(df[weight_col].dtype.type, np.number): 

1078 raise TypeError(f"Column weight_col={weight_col} must be numeric.") 

1079 

1080 def get_classes(self, df, y_col): 

1081 labels = [] 

1082 for label in df[y_col]: 

1083 if isinstance(label, (list, tuple)): 

1084 labels.append([self.class_indices[lbl] for lbl in label]) 

1085 else: 

1086 labels.append(self.class_indices[label]) 

1087 return labels 

1088 

1089 @staticmethod 

1090 def _filter_classes(df, y_col, classes): 

1091 df = df.copy() 

1092 

1093 def remove_classes(labels, classes): 

1094 if isinstance(labels, (list, tuple)): 

1095 labels = [cls for cls in labels if cls in classes] 

1096 return labels or None 

1097 elif isinstance(labels, str): 

1098 return labels if labels in classes else None 

1099 else: 

1100 raise TypeError( 

1101 "Expect string, list or tuple " 

1102 "but found {} in {} column ".format(type(labels), y_col) 

1103 ) 

1104 

1105 if classes: 

1106 # prepare for membership lookup 

1107 classes = list(collections.OrderedDict.fromkeys(classes).keys()) 

1108 df[y_col] = df[y_col].apply(lambda x: remove_classes(x, classes)) 

1109 else: 

1110 classes = set() 

1111 for v in df[y_col]: 

1112 if isinstance(v, (list, tuple)): 

1113 classes.update(v) 

1114 else: 

1115 classes.add(v) 

1116 classes = sorted(classes) 

1117 return df.dropna(subset=[y_col]), classes 

1118 

1119 def _filter_valid_filepaths(self, df, x_col): 

1120 """Keep only dataframe rows with valid filenames. 

1121 

1122 Args: 

1123 df: Pandas dataframe containing filenames in a column 

1124 x_col: string, column in `df` that contains the filenames or 

1125 filepaths 

1126 Returns: 

1127 absolute paths to image files 

1128 """ 

1129 filepaths = df[x_col].map( 

1130 lambda fname: os.path.join(self.directory, fname) 

1131 ) 

1132 mask = filepaths.apply( 

1133 validate_filename, args=(self.white_list_formats,) 

1134 ) 

1135 n_invalid = (~mask).sum() 

1136 if n_invalid: 

1137 warnings.warn( 

1138 'Found {} invalid image filename(s) in x_col="{}". ' 

1139 "These filename(s) will be ignored.".format(n_invalid, x_col) 

1140 ) 

1141 return df[mask] 

1142 

1143 @property 

1144 def filepaths(self): 

1145 return self._filepaths 

1146 

1147 @property 

1148 def labels(self): 

1149 if self.class_mode in {"multi_output", "raw"}: 

1150 return self._targets 

1151 else: 

1152 return self.classes 

1153 

1154 @property 

1155 def sample_weight(self): 

1156 return self._sample_weight 

1157 

1158 

1159def flip_axis(x, axis): 

1160 x = np.asarray(x).swapaxes(axis, 0) 

1161 x = x[::-1, ...] 

1162 x = x.swapaxes(0, axis) 

1163 return x 

1164 

1165 

1166@keras_export("keras.preprocessing.image.ImageDataGenerator") 

1167class ImageDataGenerator: 

1168 """Generate batches of tensor image data with real-time data augmentation. 

1169 

1170 Deprecated: `tf.keras.preprocessing.image.ImageDataGenerator` is not 

1171 recommended for new code. Prefer loading images with 

1172 `tf.keras.utils.image_dataset_from_directory` and transforming the output 

1173 `tf.data.Dataset` with preprocessing layers. For more information, see the 

1174 tutorials for [loading images]( 

1175 https://www.tensorflow.org/tutorials/load_data/images) and 

1176 [augmenting images]( 

1177 https://www.tensorflow.org/tutorials/images/data_augmentation), as well as 

1178 the [preprocessing layer guide]( 

1179 https://www.tensorflow.org/guide/keras/preprocessing_layers). 

1180 

1181 The data will be looped over (in batches). 

1182 

1183 Args: 

1184 featurewise_center: Boolean. Set input mean to 0 over the dataset, 

1185 feature-wise. 

1186 samplewise_center: Boolean. Set each sample mean to 0. 

1187 featurewise_std_normalization: Boolean. Divide inputs by std of the 

1188 dataset, feature-wise. 

1189 samplewise_std_normalization: Boolean. Divide each input by its std. 

1190 zca_epsilon: epsilon for ZCA whitening. Default is 1e-6. 

1191 zca_whitening: Boolean. Apply ZCA whitening. 

1192 rotation_range: Int. Degree range for random rotations. 

1193 width_shift_range: Float, 1-D array-like or int 

1194 - float: fraction of total width, if < 1, or pixels if >= 1. 

1195 - 1-D array-like: random elements from the array. 

1196 - int: integer number of pixels from interval `(-width_shift_range, 

1197 +width_shift_range)` - With `width_shift_range=2` possible values 

1198 are integers `[-1, 0, +1]`, same as with `width_shift_range=[-1, 

1199 0, +1]`, while with `width_shift_range=1.0` possible values are 

1200 floats in the interval [-1.0, +1.0). 

1201 height_shift_range: Float, 1-D array-like or int 

1202 - float: fraction of total height, if < 1, or pixels if >= 1. 

1203 - 1-D array-like: random elements from the array. 

1204 - int: integer number of pixels from interval `(-height_shift_range, 

1205 +height_shift_range)` - With `height_shift_range=2` possible 

1206 values are integers `[-1, 0, +1]`, same as with 

1207 `height_shift_range=[-1, 0, +1]`, while with 

1208 `height_shift_range=1.0` possible values are floats in the 

1209 interval [-1.0, +1.0). 

1210 brightness_range: Tuple or list of two floats. Range for picking a 

1211 brightness shift value from. 

1212 shear_range: Float. Shear Intensity (Shear angle in counter-clockwise 

1213 direction in degrees) 

1214 zoom_range: Float or [lower, upper]. Range for random zoom. If a float, 

1215 `[lower, upper] = [1-zoom_range, 1+zoom_range]`. 

1216 channel_shift_range: Float. Range for random channel shifts. 

1217 fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}. Default 

1218 is 'nearest'. Points outside the boundaries of the input are filled 

1219 according to the given mode: 

1220 - 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k) 

1221 - 'nearest': aaaaaaaa|abcd|dddddddd 

1222 - 'reflect': abcddcba|abcd|dcbaabcd 

1223 - 'wrap': abcdabcd|abcd|abcdabcd 

1224 cval: Float or Int. Value used for points outside the boundaries when 

1225 `fill_mode = "constant"`. 

1226 horizontal_flip: Boolean. Randomly flip inputs horizontally. 

1227 vertical_flip: Boolean. Randomly flip inputs vertically. 

1228 rescale: rescaling factor. Defaults to None. If None or 0, no rescaling 

1229 is applied, otherwise we multiply the data by the value provided 

1230 (after applying all other transformations). 

1231 preprocessing_function: function that will be applied on each input. The 

1232 function will run after the image is resized and augmented. 

1233 The function should take one argument: one image (Numpy tensor with 

1234 rank 3), and should output a Numpy tensor with the same shape. 

1235 data_format: Image data format, either "channels_first" or 

1236 "channels_last". "channels_last" mode means that the images should 

1237 have shape `(samples, height, width, channels)`, "channels_first" mode 

1238 means that the images should have shape `(samples, channels, height, 

1239 width)`. It defaults to the `image_data_format` value found in your 

1240 Keras config file at `~/.keras/keras.json`. If you never set it, then 

1241 it will be "channels_last". 

1242 validation_split: Float. Fraction of images reserved for validation 

1243 (strictly between 0 and 1). 

1244 dtype: Dtype to use for the generated arrays. 

1245 

1246 Raises: 

1247 ValueError: If the value of the argument, `data_format` is other than 

1248 `"channels_last"` or `"channels_first"`. 

1249 ValueError: If the value of the argument, `validation_split` > 1 

1250 or `validation_split` < 0. 

1251 

1252 Examples: 

1253 

1254 Example of using `.flow(x, y)`: 

1255 

1256 ```python 

1257 (x_train, y_train), (x_test, y_test) = cifar10.load_data() 

1258 y_train = utils.to_categorical(y_train, num_classes) 

1259 y_test = utils.to_categorical(y_test, num_classes) 

1260 datagen = ImageDataGenerator( 

1261 featurewise_center=True, 

1262 featurewise_std_normalization=True, 

1263 rotation_range=20, 

1264 width_shift_range=0.2, 

1265 height_shift_range=0.2, 

1266 horizontal_flip=True, 

1267 validation_split=0.2) 

1268 # compute quantities required for featurewise normalization 

1269 # (std, mean, and principal components if ZCA whitening is applied) 

1270 datagen.fit(x_train) 

1271 # fits the model on batches with real-time data augmentation: 

1272 model.fit(datagen.flow(x_train, y_train, batch_size=32, 

1273 subset='training'), 

1274 validation_data=datagen.flow(x_train, y_train, 

1275 batch_size=8, subset='validation'), 

1276 steps_per_epoch=len(x_train) / 32, epochs=epochs) 

1277 # here's a more "manual" example 

1278 for e in range(epochs): 

1279 print('Epoch', e) 

1280 batches = 0 

1281 for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32): 

1282 model.fit(x_batch, y_batch) 

1283 batches += 1 

1284 if batches >= len(x_train) / 32: 

1285 # we need to break the loop by hand because 

1286 # the generator loops indefinitely 

1287 break 

1288 ``` 

1289 

1290 Example of using `.flow_from_directory(directory)`: 

1291 

1292 ```python 

1293 train_datagen = ImageDataGenerator( 

1294 rescale=1./255, 

1295 shear_range=0.2, 

1296 zoom_range=0.2, 

1297 horizontal_flip=True) 

1298 test_datagen = ImageDataGenerator(rescale=1./255) 

1299 train_generator = train_datagen.flow_from_directory( 

1300 'data/train', 

1301 target_size=(150, 150), 

1302 batch_size=32, 

1303 class_mode='binary') 

1304 validation_generator = test_datagen.flow_from_directory( 

1305 'data/validation', 

1306 target_size=(150, 150), 

1307 batch_size=32, 

1308 class_mode='binary') 

1309 model.fit( 

1310 train_generator, 

1311 steps_per_epoch=2000, 

1312 epochs=50, 

1313 validation_data=validation_generator, 

1314 validation_steps=800) 

1315 ``` 

1316 

1317 Example of transforming images and masks together. 

1318 

1319 ```python 

1320 # we create two instances with the same arguments 

1321 data_gen_args = dict(featurewise_center=True, 

1322 featurewise_std_normalization=True, 

1323 rotation_range=90, 

1324 width_shift_range=0.1, 

1325 height_shift_range=0.1, 

1326 zoom_range=0.2) 

1327 image_datagen = ImageDataGenerator(**data_gen_args) 

1328 mask_datagen = ImageDataGenerator(**data_gen_args) 

1329 # Provide the same seed and keyword arguments to the fit and flow methods 

1330 seed = 1 

1331 image_datagen.fit(images, augment=True, seed=seed) 

1332 mask_datagen.fit(masks, augment=True, seed=seed) 

1333 image_generator = image_datagen.flow_from_directory( 

1334 'data/images', 

1335 class_mode=None, 

1336 seed=seed) 

1337 mask_generator = mask_datagen.flow_from_directory( 

1338 'data/masks', 

1339 class_mode=None, 

1340 seed=seed) 

1341 # combine generators into one which yields image and masks 

1342 train_generator = zip(image_generator, mask_generator) 

1343 model.fit( 

1344 train_generator, 

1345 steps_per_epoch=2000, 

1346 epochs=50) 

1347 ``` 

1348 """ 

1349 

1350 def __init__( 

1351 self, 

1352 featurewise_center=False, 

1353 samplewise_center=False, 

1354 featurewise_std_normalization=False, 

1355 samplewise_std_normalization=False, 

1356 zca_whitening=False, 

1357 zca_epsilon=1e-6, 

1358 rotation_range=0, 

1359 width_shift_range=0.0, 

1360 height_shift_range=0.0, 

1361 brightness_range=None, 

1362 shear_range=0.0, 

1363 zoom_range=0.0, 

1364 channel_shift_range=0.0, 

1365 fill_mode="nearest", 

1366 cval=0.0, 

1367 horizontal_flip=False, 

1368 vertical_flip=False, 

1369 rescale=None, 

1370 preprocessing_function=None, 

1371 data_format=None, 

1372 validation_split=0.0, 

1373 interpolation_order=1, 

1374 dtype=None, 

1375 ): 

1376 if data_format is None: 

1377 data_format = backend.image_data_format() 

1378 if dtype is None: 

1379 dtype = backend.floatx() 

1380 

1381 self.featurewise_center = featurewise_center 

1382 self.samplewise_center = samplewise_center 

1383 self.featurewise_std_normalization = featurewise_std_normalization 

1384 self.samplewise_std_normalization = samplewise_std_normalization 

1385 self.zca_whitening = zca_whitening 

1386 self.zca_epsilon = zca_epsilon 

1387 self.rotation_range = rotation_range 

1388 self.width_shift_range = width_shift_range 

1389 self.height_shift_range = height_shift_range 

1390 self.shear_range = shear_range 

1391 self.zoom_range = zoom_range 

1392 self.channel_shift_range = channel_shift_range 

1393 self.fill_mode = fill_mode 

1394 self.cval = cval 

1395 self.horizontal_flip = horizontal_flip 

1396 self.vertical_flip = vertical_flip 

1397 self.rescale = rescale 

1398 self.preprocessing_function = preprocessing_function 

1399 self.dtype = dtype 

1400 self.interpolation_order = interpolation_order 

1401 

1402 if data_format not in {"channels_last", "channels_first"}: 

1403 raise ValueError( 

1404 '`data_format` should be `"channels_last"` ' 

1405 "(channel after row and column) or " 

1406 '`"channels_first"` (channel before row and column). ' 

1407 "Received: %s" % data_format 

1408 ) 

1409 self.data_format = data_format 

1410 if data_format == "channels_first": 

1411 self.channel_axis = 1 

1412 self.row_axis = 2 

1413 self.col_axis = 3 

1414 if data_format == "channels_last": 

1415 self.channel_axis = 3 

1416 self.row_axis = 1 

1417 self.col_axis = 2 

1418 if validation_split and not 0 < validation_split < 1: 

1419 raise ValueError( 

1420 "`validation_split` must be strictly between 0 and 1. " 

1421 " Received: %s" % validation_split 

1422 ) 

1423 self._validation_split = validation_split 

1424 

1425 self.mean = None 

1426 self.std = None 

1427 self.zca_whitening_matrix = None 

1428 

1429 if isinstance(zoom_range, (float, int)): 

1430 self.zoom_range = [1 - zoom_range, 1 + zoom_range] 

1431 elif len(zoom_range) == 2 and all( 

1432 isinstance(val, (float, int)) for val in zoom_range 

1433 ): 

1434 self.zoom_range = [zoom_range[0], zoom_range[1]] 

1435 else: 

1436 raise ValueError( 

1437 "`zoom_range` should be a float or " 

1438 "a tuple or list of two floats. " 

1439 "Received: %s" % (zoom_range,) 

1440 ) 

1441 if zca_whitening: 

1442 if not featurewise_center: 

1443 self.featurewise_center = True 

1444 warnings.warn( 

1445 "This ImageDataGenerator specifies " 

1446 "`zca_whitening`, which overrides " 

1447 "setting of `featurewise_center`." 

1448 ) 

1449 if featurewise_std_normalization: 

1450 self.featurewise_std_normalization = False 

1451 warnings.warn( 

1452 "This ImageDataGenerator specifies " 

1453 "`zca_whitening` " 

1454 "which overrides setting of" 

1455 "`featurewise_std_normalization`." 

1456 ) 

1457 if featurewise_std_normalization: 

1458 if not featurewise_center: 

1459 self.featurewise_center = True 

1460 warnings.warn( 

1461 "This ImageDataGenerator specifies " 

1462 "`featurewise_std_normalization`, " 

1463 "which overrides setting of " 

1464 "`featurewise_center`." 

1465 ) 

1466 if samplewise_std_normalization: 

1467 if not samplewise_center: 

1468 self.samplewise_center = True 

1469 warnings.warn( 

1470 "This ImageDataGenerator specifies " 

1471 "`samplewise_std_normalization`, " 

1472 "which overrides setting of " 

1473 "`samplewise_center`." 

1474 ) 

1475 if brightness_range is not None: 

1476 if ( 

1477 not isinstance(brightness_range, (tuple, list)) 

1478 or len(brightness_range) != 2 

1479 ): 

1480 raise ValueError( 

1481 "`brightness_range should be tuple or list of two floats. " 

1482 "Received: %s" % (brightness_range,) 

1483 ) 

1484 self.brightness_range = brightness_range 

1485 

1486 def flow( 

1487 self, 

1488 x, 

1489 y=None, 

1490 batch_size=32, 

1491 shuffle=True, 

1492 sample_weight=None, 

1493 seed=None, 

1494 save_to_dir=None, 

1495 save_prefix="", 

1496 save_format="png", 

1497 ignore_class_split=False, 

1498 subset=None, 

1499 ): 

1500 """Takes data & label arrays, generates batches of augmented data. 

1501 

1502 Args: 

1503 x: Input data. Numpy array of rank 4 or a tuple. If tuple, the first 

1504 element should contain the images and the second element another 

1505 numpy array or a list of numpy arrays that gets passed to the 

1506 output without any modifications. Can be used to feed the model 

1507 miscellaneous data along with the images. In case of grayscale 

1508 data, the channels axis of the image array should have value 1, in 

1509 case of RGB data, it should have value 3, and in case of RGBA 

1510 data, it should have value 4. 

1511 y: Labels. 

1512 batch_size: Int (default: 32). 

1513 shuffle: Boolean (default: True). 

1514 sample_weight: Sample weights. 

1515 seed: Int (default: None). 

1516 save_to_dir: None or str (default: None). This allows you to 

1517 optionally specify a directory to which to save the augmented 

1518 pictures being generated (useful for visualizing what you are 

1519 doing). 

1520 save_prefix: Str (default: `''`). Prefix to use for filenames of 

1521 saved pictures (only relevant if `save_to_dir` is set). 

1522 save_format: one of "png", "jpeg", "bmp", "pdf", "ppm", "gif", 

1523 "tif", "jpg" (only relevant if `save_to_dir` is set). Default: 

1524 "png". 

1525 ignore_class_split: Boolean (default: False), ignore difference 

1526 in number of classes in labels across train and validation 

1527 split (useful for non-classification tasks) 

1528 subset: Subset of data (`"training"` or `"validation"`) if 

1529 `validation_split` is set in `ImageDataGenerator`. 

1530 

1531 Returns: 

1532 An `Iterator` yielding tuples of `(x, y)` 

1533 where `x` is a numpy array of image data 

1534 (in the case of a single image input) or a list 

1535 of numpy arrays (in the case with 

1536 additional inputs) and `y` is a numpy array 

1537 of corresponding labels. If 'sample_weight' is not None, 

1538 the yielded tuples are of the form `(x, y, sample_weight)`. 

1539 If `y` is None, only the numpy array `x` is returned. 

1540 Raises: 

1541 ValueError: If the value of the argument `subset` is other than 

1542 "training" or "validation". 

1543 

1544 """ 

1545 return NumpyArrayIterator( 

1546 x, 

1547 y, 

1548 self, 

1549 batch_size=batch_size, 

1550 shuffle=shuffle, 

1551 sample_weight=sample_weight, 

1552 seed=seed, 

1553 data_format=self.data_format, 

1554 save_to_dir=save_to_dir, 

1555 save_prefix=save_prefix, 

1556 save_format=save_format, 

1557 ignore_class_split=ignore_class_split, 

1558 subset=subset, 

1559 dtype=self.dtype, 

1560 ) 

1561 
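A minimal usage sketch for `flow` (illustrative only; the random arrays below stand in for real images and labels):

    import numpy as np
    import tensorflow as tf

    x_train = np.random.rand(64, 32, 32, 3).astype("float32")  # fake images
    y_train = np.random.randint(0, 10, size=(64,))              # fake labels

    datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        rotation_range=15, width_shift_range=0.1, horizontal_flip=True
    )
    for batch_x, batch_y in datagen.flow(x_train, y_train, batch_size=16):
        print(batch_x.shape, batch_y.shape)  # (16, 32, 32, 3) (16,)
        break  # the iterator loops forever, so stop after one batch
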

1562 def flow_from_directory( 

1563 self, 

1564 directory, 

1565 target_size=(256, 256), 

1566 color_mode="rgb", 

1567 classes=None, 

1568 class_mode="categorical", 

1569 batch_size=32, 

1570 shuffle=True, 

1571 seed=None, 

1572 save_to_dir=None, 

1573 save_prefix="", 

1574 save_format="png", 

1575 follow_links=False, 

1576 subset=None, 

1577 interpolation="nearest", 

1578 keep_aspect_ratio=False, 

1579 ): 

1580 """Takes the path to a directory & generates batches of augmented data. 

1581 

1582 Args: 

1583 directory: string, path to the target directory. It should contain 

1584 one subdirectory per class. Any PNG, JPG, BMP, PPM or TIF images 

1585 inside each of the subdirectories' directory trees will be included 

1586 in the generator. See [this script]( 

1587 https://gist.github.com/fchollet/0830affa1f7f19fd47b06d4cf89ed44d) 

1588 for more details. 

1589 target_size: Tuple of integers `(height, width)`, defaults to `(256, 

1590 256)`. The dimensions to which all images found will be resized. 

1591 color_mode: One of "grayscale", "rgb", "rgba". Default: "rgb". 

1592 Whether the images will be converted to have 1, 3, or 4 channels. 

1593 classes: Optional list of class subdirectories (e.g. `['dogs', 

1594 'cats']`). Default: None. If not provided, the list of classes 

1595 will be automatically inferred from the subdirectory 

1596 names/structure under `directory`, where each subdirectory will be 

1597 treated as a different class (and the order of the classes, which 

1598 will map to the label indices, will be alphanumeric). The 

1599 dictionary containing the mapping from class names to class 

1600 indices can be obtained via the attribute `class_indices`. 

1601 class_mode: One of "categorical", "binary", "sparse", 

1602 "input", or None. Default: "categorical". 

1603 Determines the type of label arrays that are returned: 

1604 - "categorical" will be 2D one-hot encoded labels, 

1605 - "binary" will be 1D binary labels, 

1606 - "sparse" will be 1D integer labels, 

1607 - "input" will be images identical 

1608 to input images (mainly used to work with autoencoders). 

1609 - If None, no labels are returned 

1610 (the generator will only yield batches of image data, 

1611 which is useful to use with `model.predict()`). 

1612 Please note that in case of class_mode None, 

1613 the data still needs to reside in a subdirectory 

1614 of `directory` for it to work correctly. 

1615 batch_size: Size of the batches of data (default: 32). 

1616 shuffle: Whether to shuffle the data (default: True) If set to 

1617 False, sorts the data in alphanumeric order. 

1618 seed: Optional random seed for shuffling and transformations. 

1619 save_to_dir: None or str (default: None). This allows you to 

1620 optionally specify a directory to which to save the augmented 

1621 pictures being generated (useful for visualizing what you are 

1622 doing). 

1623 save_prefix: Str. Prefix to use for filenames of saved pictures 

1624 (only relevant if `save_to_dir` is set). 

1625 save_format: one of "png", "jpeg", "bmp", "pdf", "ppm", "gif", 

1626 "tif", "jpg" (only relevant if `save_to_dir` is set). Default: 

1627 "png". 

1628 follow_links: Whether to follow symlinks inside 

1629 class subdirectories (default: False). 

1630 subset: Subset of data (`"training"` or `"validation"`) if 

1631 `validation_split` is set in `ImageDataGenerator`. 

1632 interpolation: Interpolation method used to resample the image if 

1633 the target size is different from that of the loaded image. 

1634 Supported methods are `"nearest"`, `"bilinear"`, and `"bicubic"`. 

1635 If PIL version 1.1.3 or newer is installed, `"lanczos"` is also 

1636 supported. If PIL version 3.4.0 or newer is installed, `"box"` and 

1637 `"hamming"` are also supported. By default, `"nearest"` is used. 

1638 keep_aspect_ratio: Boolean, whether to resize images to a target 

1639 size without aspect ratio distortion. The image is cropped in 

1640 the center with target aspect ratio before resizing. 

1641 

1642 Returns: 

1643 A `DirectoryIterator` yielding tuples of `(x, y)` 

1644 where `x` is a numpy array containing a batch 

1645 of images with shape `(batch_size, *target_size, channels)` 

1646 and `y` is a numpy array of corresponding labels. 

1647 """ 

1648 return DirectoryIterator( 

1649 directory, 

1650 self, 

1651 target_size=target_size, 

1652 color_mode=color_mode, 

1653 keep_aspect_ratio=keep_aspect_ratio, 

1654 classes=classes, 

1655 class_mode=class_mode, 

1656 data_format=self.data_format, 

1657 batch_size=batch_size, 

1658 shuffle=shuffle, 

1659 seed=seed, 

1660 save_to_dir=save_to_dir, 

1661 save_prefix=save_prefix, 

1662 save_format=save_format, 

1663 follow_links=follow_links, 

1664 subset=subset, 

1665 interpolation=interpolation, 

1666 dtype=self.dtype, 

1667 ) 

1668 
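A sketch of `flow_from_directory`, assuming a hypothetical `data/train/<class>/*.jpg` layout on disk:

    import tensorflow as tf

    datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0 / 255)
    train_gen = datagen.flow_from_directory(
        "data/train",            # hypothetical path; one subdirectory per class
        target_size=(150, 150),
        batch_size=32,
        class_mode="categorical",
    )
    print(train_gen.class_indices)  # e.g. {'cats': 0, 'dogs': 1}
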

1669 def flow_from_dataframe( 

1670 self, 

1671 dataframe, 

1672 directory=None, 

1673 x_col="filename", 

1674 y_col="class", 

1675 weight_col=None, 

1676 target_size=(256, 256), 

1677 color_mode="rgb", 

1678 classes=None, 

1679 class_mode="categorical", 

1680 batch_size=32, 

1681 shuffle=True, 

1682 seed=None, 

1683 save_to_dir=None, 

1684 save_prefix="", 

1685 save_format="png", 

1686 subset=None, 

1687 interpolation="nearest", 

1688 validate_filenames=True, 

1689 **kwargs, 

1690 ): 

1691 """Takes the dataframe and the path to a directory + generates batches. 

1692 

1693 The generated batches contain augmented/normalized data. 

1694 

1695 **A simple tutorial can be found **[here]( 

1696 http://bit.ly/keras_flow_from_dataframe). 

1697 

1698 Args: 

1699 dataframe: Pandas dataframe containing the filepaths relative to 

1700 `directory` (or absolute paths if `directory` is None) of the 

1701 images in a string column. It should include other column/s 

1702 depending on the `class_mode`: 

1703 - if `class_mode` is `"categorical"` (default value) it must 

1704 include the `y_col` column with the class/es of each image. 

1705 Values in column can be string/list/tuple if a single class 

1706 or list/tuple if multiple classes. 

1707 - if `class_mode` is `"binary"` or `"sparse"` it must include 

1708 the given `y_col` column with class values as strings. 

1709 - if `class_mode` is `"raw"` or `"multi_output"` it should 

1710 contain the columns specified in `y_col`. 

1711 - if `class_mode` is `"input"` or `None` no extra column is 

1712 needed. 

1713 directory: string, path to the directory to read images from. If 

1714 `None`, data in `x_col` column should be absolute paths. 

1715 x_col: string, column in `dataframe` that contains the filenames (or 

1716 absolute paths if `directory` is `None`). 

1717 y_col: string or list, column/s in `dataframe` that has the target 

1718 data. 

1719 weight_col: string, column in `dataframe` that contains the sample 

1720 weights. Default: `None`. 

1721 target_size: tuple of integers `(height, width)`, default: `(256, 

1722 256)`. The dimensions to which all images found will be resized. 

1723 color_mode: one of "grayscale", "rgb", "rgba". Default: "rgb". 

1724 Whether the images will be converted to have 1, 3, or 4 color 

1725 channels. 

1726 classes: optional list of classes (e.g. `['dogs', 'cats']`). Default 

1727 is None. If not provided, the list of classes will be 

1728 automatically inferred from the `y_col` (and the order of the classes, 

1729 which will map to the label indices, will be alphanumeric). The dictionary containing 

1730 the mapping from class names to class indices can be obtained via 

1731 the attribute `class_indices`. 

1732 class_mode: one of "binary", "categorical", "input", "multi_output", 

1733 "raw", "sparse" or None. Default: "categorical". 

1734 Mode for yielding the targets: 

1735 - `"binary"`: 1D numpy array of binary labels, 

1736 - `"categorical"`: 2D numpy array of one-hot encoded labels. 

1737 Supports multi-label output. 

1738 - `"input"`: images identical to input images (mainly used to 

1739 work with autoencoders), 

1740 - `"multi_output"`: list with the values of the different 

1741 columns, 

1742 - `"raw"`: numpy array of values in `y_col` column(s), 

1743 - `"sparse"`: 1D numpy array of integer labels, 

1744 - `None`, no targets are returned (the generator will only yield 

1745 batches of image data, which is useful to use in 

1746 `model.predict()`). 

1747 batch_size: size of the batches of data (default: 32). 

1748 shuffle: whether to shuffle the data (default: True) 

1749 seed: optional random seed for shuffling and transformations. 

1750 save_to_dir: None or str (default: None). This allows you to 

1751 optionally specify a directory to which to save the augmented 

1752 pictures being generated (useful for visualizing what you are 

1753 doing). 

1754 save_prefix: str. Prefix to use for filenames of saved pictures 

1755 (only relevant if `save_to_dir` is set). 

1756 save_format: one of "png", "jpeg", "bmp", "pdf", "ppm", "gif", 

1757 "tif", "jpg" (only relevant if `save_to_dir` is set). Default: 

1758 "png". 

1759 subset: Subset of data (`"training"` or `"validation"`) if 

1760 `validation_split` is set in `ImageDataGenerator`. 

1761 interpolation: Interpolation method used to resample the image if 

1762 the target size is different from that of the loaded image. 

1763 Supported methods are `"nearest"`, `"bilinear"`, and `"bicubic"`. 

1764 If PIL version 1.1.3 or newer is installed, `"lanczos"` is also 

1765 supported. If PIL version 3.4.0 or newer is installed, `"box"` and 

1766 `"hamming"` are also supported. By default, `"nearest"` is used. 

1767 validate_filenames: Boolean, whether to validate image filenames in 

1768 `x_col`. If `True`, invalid images will be ignored. Disabling this 

1769 option can lead to speed-up in the execution of this function. 

1770 Defaults to `True`. 

1771 **kwargs: legacy arguments for raising deprecation warnings. 

1772 

1773 Returns: 

1774 A `DataFrameIterator` yielding tuples of `(x, y)` 

1775 where `x` is a numpy array containing a batch 

1776 of images with shape `(batch_size, *target_size, channels)` 

1777 and `y` is a numpy array of corresponding labels. 

1778 """ 

1779 if "has_ext" in kwargs: 

1780 warnings.warn( 

1781 "has_ext is deprecated, filenames in the dataframe have " 

1782 "to match the exact filenames on disk.", 

1783 DeprecationWarning, 

1784 ) 

1785 if "sort" in kwargs: 

1786 warnings.warn( 

1787 "sort is deprecated, batches will be created in the " 

1788 "same order as the filenames provided if shuffle " 

1789 "is set to False.", 

1790 DeprecationWarning, 

1791 ) 

1792 if class_mode == "other": 

1793 warnings.warn( 

1794 '`class_mode` "other" is deprecated, please use ' 

1795 '`class_mode` "raw".', 

1796 DeprecationWarning, 

1797 ) 

1798 class_mode = "raw" 

1799 if "drop_duplicates" in kwargs: 

1800 warnings.warn( 

1801 "drop_duplicates is deprecated, you can drop duplicates " 

1802 "by using the pandas.DataFrame.drop_duplicates method.", 

1803 DeprecationWarning, 

1804 ) 

1805 

1806 return DataFrameIterator( 

1807 dataframe, 

1808 directory, 

1809 self, 

1810 x_col=x_col, 

1811 y_col=y_col, 

1812 weight_col=weight_col, 

1813 target_size=target_size, 

1814 color_mode=color_mode, 

1815 classes=classes, 

1816 class_mode=class_mode, 

1817 data_format=self.data_format, 

1818 batch_size=batch_size, 

1819 shuffle=shuffle, 

1820 seed=seed, 

1821 save_to_dir=save_to_dir, 

1822 save_prefix=save_prefix, 

1823 save_format=save_format, 

1824 subset=subset, 

1825 interpolation=interpolation, 

1826 validate_filenames=validate_filenames, 

1827 dtype=self.dtype, 

1828 ) 

1829 
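A sketch of `flow_from_dataframe`, assuming a hypothetical dataframe whose `filename` column is relative to an `images/` directory:

    import pandas as pd
    import tensorflow as tf

    df = pd.DataFrame(
        {"filename": ["a.jpg", "b.jpg"], "class": ["cat", "dog"]}
    )
    datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0 / 255)
    gen = datagen.flow_from_dataframe(
        df,
        directory="images",      # hypothetical path containing the listed files
        x_col="filename",
        y_col="class",
        target_size=(128, 128),
        class_mode="categorical",
        batch_size=2,
    )
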

1830 def standardize(self, x): 

1831 """Applies the normalization configuration in-place to a batch of 

1832 inputs. 

1833 

1834 `x` is changed in-place since the function is mainly used internally 

1835 to standardize images and feed them to your network. If a copy of `x` 

1836 were created instead, it would incur a significant performance cost. 

1837 If you want to apply this method without changing the input in-place 

1838 you can call the method on a copy of the input: 

1839 

1840 standardize(np.copy(x)) 

1841 

1842 Args: 

1843 x: Batch of inputs to be normalized. 

1844 

1845 Returns: 

1846 The inputs, normalized. 

1847 """ 

1848 if self.preprocessing_function: 

1849 x = self.preprocessing_function(x) 

1850 if self.rescale: 

1851 x *= self.rescale 

1852 if self.samplewise_center: 

1853 x -= np.mean(x, keepdims=True) 

1854 if self.samplewise_std_normalization: 

1855 x /= np.std(x, keepdims=True) + 1e-6 

1856 

1857 if self.featurewise_center: 

1858 if self.mean is not None: 

1859 x -= self.mean 

1860 else: 

1861 warnings.warn( 

1862 "This ImageDataGenerator specifies " 

1863 "`featurewise_center`, but it hasn't " 

1864 "been fit on any training data. Fit it " 

1865 "first by calling `.fit(numpy_data)`." 

1866 ) 

1867 if self.featurewise_std_normalization: 

1868 if self.std is not None: 

1869 x /= self.std + 1e-6 

1870 else: 

1871 warnings.warn( 

1872 "This ImageDataGenerator specifies " 

1873 "`featurewise_std_normalization`, " 

1874 "but it hasn't " 

1875 "been fit on any training data. Fit it " 

1876 "first by calling `.fit(numpy_data)`." 

1877 ) 

1878 if self.zca_whitening: 

1879 if self.zca_whitening_matrix is not None: 

1880 flat_x = x.reshape(-1, np.prod(x.shape[-3:])) 

1881 white_x = flat_x @ self.zca_whitening_matrix 

1882 x = np.reshape(white_x, x.shape) 

1883 else: 

1884 warnings.warn( 

1885 "This ImageDataGenerator specifies " 

1886 "`zca_whitening`, but it hasn't " 

1887 "been fit on any training data. Fit it " 

1888 "first by calling `.fit(numpy_data)`." 

1889 ) 

1890 return x 

1891 
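Because `standardize` modifies its argument in place, pass a copy when the original batch must be preserved. A sketch with random data:

    import numpy as np
    import tensorflow as tf

    datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        samplewise_center=True, samplewise_std_normalization=True
    )
    batch = np.random.rand(4, 8, 8, 3).astype("float32")
    normalized = datagen.standardize(np.copy(batch))  # `batch` stays untouched
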

1892 def get_random_transform(self, img_shape, seed=None): 

1893 """Generates random parameters for a transformation. 

1894 

1895 Args: 

1896 img_shape: Tuple of integers. 

1897 Shape of the image that is transformed. 

1898 seed: Random seed. 

1899 

1900 Returns: 

1901 A dictionary containing randomly chosen parameters describing the 

1902 transformation. 

1903 """ 

1904 img_row_axis = self.row_axis - 1 

1905 img_col_axis = self.col_axis - 1 

1906 

1907 if seed is not None: 

1908 np.random.seed(seed) 

1909 

1910 if self.rotation_range: 

1911 theta = np.random.uniform(-self.rotation_range, self.rotation_range) 

1912 else: 

1913 theta = 0 

1914 

1915 if self.height_shift_range: 

1916 try: # 1-D array-like or int 

1917 tx = np.random.choice(self.height_shift_range) 

1918 tx *= np.random.choice([-1, 1]) 

1919 except ValueError: # floating point 

1920 tx = np.random.uniform( 

1921 -self.height_shift_range, self.height_shift_range 

1922 ) 

1923 if np.max(self.height_shift_range) < 1: 

1924 tx *= img_shape[img_row_axis] 

1925 else: 

1926 tx = 0 

1927 

1928 if self.width_shift_range: 

1929 try: # 1-D array-like or int 

1930 ty = np.random.choice(self.width_shift_range) 

1931 ty *= np.random.choice([-1, 1]) 

1932 except ValueError: # floating point 

1933 ty = np.random.uniform( 

1934 -self.width_shift_range, self.width_shift_range 

1935 ) 

1936 if np.max(self.width_shift_range) < 1: 

1937 ty *= img_shape[img_col_axis] 

1938 else: 

1939 ty = 0 

1940 

1941 if self.shear_range: 

1942 shear = np.random.uniform(-self.shear_range, self.shear_range) 

1943 else: 

1944 shear = 0 

1945 

1946 if self.zoom_range[0] == 1 and self.zoom_range[1] == 1: 

1947 zx, zy = 1, 1 

1948 else: 

1949 zx, zy = np.random.uniform( 

1950 self.zoom_range[0], self.zoom_range[1], 2 

1951 ) 

1952 

1953 flip_horizontal = (np.random.random() < 0.5) * self.horizontal_flip 

1954 flip_vertical = (np.random.random() < 0.5) * self.vertical_flip 

1955 

1956 channel_shift_intensity = None 

1957 if self.channel_shift_range != 0: 

1958 channel_shift_intensity = np.random.uniform( 

1959 -self.channel_shift_range, self.channel_shift_range 

1960 ) 

1961 

1962 brightness = None 

1963 if self.brightness_range is not None: 

1964 brightness = np.random.uniform( 

1965 self.brightness_range[0], self.brightness_range[1] 

1966 ) 

1967 

1968 transform_parameters = { 

1969 "theta": theta, 

1970 "tx": tx, 

1971 "ty": ty, 

1972 "shear": shear, 

1973 "zx": zx, 

1974 "zy": zy, 

1975 "flip_horizontal": flip_horizontal, 

1976 "flip_vertical": flip_vertical, 

1977 "channel_shift_intensity": channel_shift_intensity, 

1978 "brightness": brightness, 

1979 } 

1980 

1981 return transform_parameters 

1982 

1983 def apply_transform(self, x, transform_parameters): 

1984 """Applies a transformation to an image according to given parameters. 

1985 

1986 Args: 

1987 x: 3D tensor, single image. 

1988 transform_parameters: Dictionary with string - parameter pairs 

1989 describing the transformation. 

1990 Currently, the following parameters 

1991 from the dictionary are used: 

1992 - `'theta'`: Float. Rotation angle in degrees. 

1993 - `'tx'`: Float. Shift in the x direction. 

1994 - `'ty'`: Float. Shift in the y direction. 

1995 - `'shear'`: Float. Shear angle in degrees. 

1996 - `'zx'`: Float. Zoom in the x direction. 

1997 - `'zy'`: Float. Zoom in the y direction. 

1998 - `'flip_horizontal'`: Boolean. Horizontal flip. 

1999 - `'flip_vertical'`: Boolean. Vertical flip. 

2000 - `'channel_shift_intensity'`: Float. Channel shift intensity. 

2001 - `'brightness'`: Float. Brightness shift intensity. 

2002 

2003 Returns: 

2004 A transformed version of the input (same shape). 

2005 """ 

2006 # x is a single image, so it doesn't have image number at index 0 

2007 img_row_axis = self.row_axis - 1 

2008 img_col_axis = self.col_axis - 1 

2009 img_channel_axis = self.channel_axis - 1 

2010 

2011 x = apply_affine_transform( 

2012 x, 

2013 transform_parameters.get("theta", 0), 

2014 transform_parameters.get("tx", 0), 

2015 transform_parameters.get("ty", 0), 

2016 transform_parameters.get("shear", 0), 

2017 transform_parameters.get("zx", 1), 

2018 transform_parameters.get("zy", 1), 

2019 row_axis=img_row_axis, 

2020 col_axis=img_col_axis, 

2021 channel_axis=img_channel_axis, 

2022 fill_mode=self.fill_mode, 

2023 cval=self.cval, 

2024 order=self.interpolation_order, 

2025 ) 

2026 

2027 if transform_parameters.get("channel_shift_intensity") is not None: 

2028 x = apply_channel_shift( 

2029 x, 

2030 transform_parameters["channel_shift_intensity"], 

2031 img_channel_axis, 

2032 ) 

2033 

2034 if transform_parameters.get("flip_horizontal", False): 

2035 x = flip_axis(x, img_col_axis) 

2036 

2037 if transform_parameters.get("flip_vertical", False): 

2038 x = flip_axis(x, img_row_axis) 

2039 

2040 if transform_parameters.get("brightness") is not None: 

2041 x = apply_brightness_shift( 

2042 x, transform_parameters["brightness"], False 

2043 ) 

2044 

2045 return x 

2046 

2047 def random_transform(self, x, seed=None): 

2048 """Applies a random transformation to an image. 

2049 

2050 Args: 

2051 x: 3D tensor, single image. 

2052 seed: Random seed. 

2053 

2054 Returns: 

2055 A randomly transformed version of the input (same shape). 

2056 """ 

2057 params = self.get_random_transform(x.shape, seed) 

2058 return self.apply_transform(x, params) 

2059 
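A sketch showing the two-step form (draw parameters, then apply them) and the equivalent one-call `random_transform`, on a random channels-last image:

    import numpy as np
    import tensorflow as tf

    datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        rotation_range=30, zoom_range=0.2, horizontal_flip=True
    )
    img = np.random.rand(64, 64, 3).astype("float32")  # single image

    params = datagen.get_random_transform(img.shape, seed=42)
    out_a = datagen.apply_transform(img, params)    # `params` can be reused

    out_b = datagen.random_transform(img, seed=42)  # same draw in one call
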

2060 def fit(self, x, augment=False, rounds=1, seed=None): 

2061 """Fits the data generator to some sample data. 

2062 

2063 This computes the internal data stats related to the 

2064 data-dependent transformations, based on an array of sample data. 

2065 

2066 Only required if `featurewise_center` or 

2067 `featurewise_std_normalization` or `zca_whitening` are set to True. 

2068 

2069 When `rescale` is set to a value, rescaling is applied to 

2070 sample data before computing the internal data stats. 

2071 

2072 Args: 

2073 x: Sample data. Should have rank 4. 

2074 In case of grayscale data, 

2075 the channels axis should have value 1, in case 

2076 of RGB data, it should have value 3, and in case 

2077 of RGBA data, it should have value 4. 

2078 augment: Boolean (default: False). 

2079 Whether to fit on randomly augmented samples. 

2080 rounds: Int (default: 1). 

2081 If using data augmentation (`augment=True`), 

2082 this is how many augmentation passes over the data to use. 

2083 seed: Int (default: None). Random seed. 

2084 """ 

2085 x = np.asarray(x, dtype=self.dtype) 

2086 if x.ndim != 4: 

2087 raise ValueError( 

2088 "Input to `.fit()` should have rank 4. Got array with shape: " 

2089 + str(x.shape) 

2090 ) 

2091 if x.shape[self.channel_axis] not in {1, 3, 4}: 

2092 warnings.warn( 

2093 "Expected input to be images (as Numpy array) " 

2094 'following the data format convention "' 

2095 + self.data_format 

2096 + '" (channels on axis ' 

2097 + str(self.channel_axis) 

2098 + "), i.e. expected either 1, 3 or 4 channels on axis " 

2099 + str(self.channel_axis) 

2100 + ". However, it was passed an array with shape " 

2101 + str(x.shape) 

2102 + " (" 

2103 + str(x.shape[self.channel_axis]) 

2104 + " channels)." 

2105 ) 

2106 

2107 if seed is not None: 

2108 np.random.seed(seed) 

2109 

2110 x = np.copy(x) 

2111 if self.rescale: 

2112 x *= self.rescale 

2113 

2114 if augment: 

2115 ax = np.zeros( 

2116 tuple([rounds * x.shape[0]] + list(x.shape)[1:]), 

2117 dtype=self.dtype, 

2118 ) 

2119 for r in range(rounds): 

2120 for i in range(x.shape[0]): 

2121 ax[i + r * x.shape[0]] = self.random_transform(x[i]) 

2122 x = ax 

2123 

2124 if self.featurewise_center: 

2125 self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis)) 

2126 broadcast_shape = [1, 1, 1] 

2127 broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis] 

2128 self.mean = np.reshape(self.mean, broadcast_shape) 

2129 x -= self.mean 

2130 

2131 if self.featurewise_std_normalization: 

2132 self.std = np.std(x, axis=(0, self.row_axis, self.col_axis)) 

2133 broadcast_shape = [1, 1, 1] 

2134 broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis] 

2135 self.std = np.reshape(self.std, broadcast_shape) 

2136 x /= self.std + 1e-6 

2137 

2138 if self.zca_whitening: 

2139 n = len(x) 

2140 flat_x = np.reshape(x, (n, -1)) 

2141 

2142 u, s, _ = np.linalg.svd(flat_x.T, full_matrices=False) 

2143 s_inv = np.sqrt(n) / (s + self.zca_epsilon) 

2144 self.zca_whitening_matrix = (u * s_inv).dot(u.T) 

2145 

2146 
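A sketch of fitting featurewise statistics on sample data before streaming batches (random data stands in for a real training set):

    import numpy as np
    import tensorflow as tf

    x_train = np.random.rand(100, 32, 32, 3).astype("float32")
    datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        featurewise_center=True, featurewise_std_normalization=True
    )
    datagen.fit(x_train)       # computes datagen.mean and datagen.std
    print(datagen.mean.shape)  # (1, 1, 3) for channels_last input
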

2147@keras_export("keras.preprocessing.image.random_rotation") 

2148def random_rotation( 

2149 x, 

2150 rg, 

2151 row_axis=1, 

2152 col_axis=2, 

2153 channel_axis=0, 

2154 fill_mode="nearest", 

2155 cval=0.0, 

2156 interpolation_order=1, 

2157): 

2158 """Performs a random rotation of a Numpy image tensor. 

2159 

2160 Deprecated: `tf.keras.preprocessing.image.random_rotation` does not operate 

2161 on tensors and is not recommended for new code. Prefer 

2162 `tf.keras.layers.RandomRotation` which provides equivalent functionality as 

2163 a preprocessing layer. For more information, see the tutorial for 

2164 [augmenting images]( 

2165 https://www.tensorflow.org/tutorials/images/data_augmentation), as well as 

2166 the [preprocessing layer guide]( 

2167 https://www.tensorflow.org/guide/keras/preprocessing_layers). 

2168 

2169 Args: 

2170 x: Input tensor. Must be 3D. 

2171 rg: Rotation range, in degrees. 

2172 row_axis: Index of axis for rows in the input tensor. 

2173 col_axis: Index of axis for columns in the input tensor. 

2174 channel_axis: Index of axis for channels in the input tensor. 

2175 fill_mode: Points outside the boundaries of the input 

2176 are filled according to the given mode 

2177 (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 

2178 cval: Value used for points outside the boundaries 

2179 of the input if `mode='constant'`. 

2180 interpolation_order: int, order of spline interpolation. 

2181 see `ndimage.interpolation.affine_transform` 

2182 

2183 Returns: 

2184 Rotated Numpy image tensor. 

2185 """ 

2186 theta = np.random.uniform(-rg, rg) 

2187 x = apply_affine_transform( 

2188 x, 

2189 theta=theta, 

2190 row_axis=row_axis, 

2191 col_axis=col_axis, 

2192 channel_axis=channel_axis, 

2193 fill_mode=fill_mode, 

2194 cval=cval, 

2195 order=interpolation_order, 

2196 ) 

2197 return x 

2198 
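A sketch of `random_rotation` on a channels-last image; the axis arguments must be given explicitly because the defaults assume channels-first data:

    import numpy as np
    import tensorflow as tf

    img = np.random.rand(64, 64, 3)  # (rows, cols, channels)
    rotated = tf.keras.preprocessing.image.random_rotation(
        img, rg=40, row_axis=0, col_axis=1, channel_axis=2
    )
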

2199 

2200@keras_export("keras.preprocessing.image.random_shift") 

2201def random_shift( 

2202 x, 

2203 wrg, 

2204 hrg, 

2205 row_axis=1, 

2206 col_axis=2, 

2207 channel_axis=0, 

2208 fill_mode="nearest", 

2209 cval=0.0, 

2210 interpolation_order=1, 

2211): 

2212 """Performs a random spatial shift of a Numpy image tensor. 

2213 

2214 Deprecated: `tf.keras.preprocessing.image.random_shift` does not operate on 

2215 tensors and is not recommended for new code. Prefer 

2216 `tf.keras.layers.RandomTranslation` which provides equivalent functionality 

2217 as a preprocessing layer. For more information, see the tutorial for 

2218 [augmenting images]( 

2219 https://www.tensorflow.org/tutorials/images/data_augmentation), as well as 

2220 the [preprocessing layer guide]( 

2221 https://www.tensorflow.org/guide/keras/preprocessing_layers). 

2222 

2223 Args: 

2224 x: Input tensor. Must be 3D. 

2225 wrg: Width shift range, as a float fraction of the width. 

2226 hrg: Height shift range, as a float fraction of the height. 

2227 row_axis: Index of axis for rows in the input tensor. 

2228 col_axis: Index of axis for columns in the input tensor. 

2229 channel_axis: Index of axis for channels in the input tensor. 

2230 fill_mode: Points outside the boundaries of the input 

2231 are filled according to the given mode 

2232 (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 

2233 cval: Value used for points outside the boundaries 

2234 of the input if `mode='constant'`. 

2235 interpolation_order: int, order of spline interpolation. 

2236 see `ndimage.interpolation.affine_transform` 

2237 

2238 Returns: 

2239 Shifted Numpy image tensor. 

2240 """ 

2241 h, w = x.shape[row_axis], x.shape[col_axis] 

2242 tx = np.random.uniform(-hrg, hrg) * h 

2243 ty = np.random.uniform(-wrg, wrg) * w 

2244 x = apply_affine_transform( 

2245 x, 

2246 tx=tx, 

2247 ty=ty, 

2248 row_axis=row_axis, 

2249 col_axis=col_axis, 

2250 channel_axis=channel_axis, 

2251 fill_mode=fill_mode, 

2252 cval=cval, 

2253 order=interpolation_order, 

2254 ) 

2255 return x 

2256 

2257 

2258@keras_export("keras.preprocessing.image.random_shear") 

2259def random_shear( 

2260 x, 

2261 intensity, 

2262 row_axis=1, 

2263 col_axis=2, 

2264 channel_axis=0, 

2265 fill_mode="nearest", 

2266 cval=0.0, 

2267 interpolation_order=1, 

2268): 

2269 """Performs a random spatial shear of a Numpy image tensor. 

2270 

2271 Args: 

2272 x: Input tensor. Must be 3D. 

2273 intensity: Transformation intensity in degrees. 

2274 row_axis: Index of axis for rows in the input tensor. 

2275 col_axis: Index of axis for columns in the input tensor. 

2276 channel_axis: Index of axis for channels in the input tensor. 

2277 fill_mode: Points outside the boundaries of the input 

2278 are filled according to the given mode 

2279 (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 

2280 cval: Value used for points outside the boundaries 

2281 of the input if `mode='constant'`. 

2282 interpolation_order: int, order of spline interpolation. 

2283 see `ndimage.interpolation.affine_transform` 

2284 

2285 Returns: 

2286 Sheared Numpy image tensor. 

2287 """ 

2288 shear = np.random.uniform(-intensity, intensity) 

2289 x = apply_affine_transform( 

2290 x, 

2291 shear=shear, 

2292 row_axis=row_axis, 

2293 col_axis=col_axis, 

2294 channel_axis=channel_axis, 

2295 fill_mode=fill_mode, 

2296 cval=cval, 

2297 order=interpolation_order, 

2298 ) 

2299 return x 

2300 

2301 

2302@keras_export("keras.preprocessing.image.random_zoom") 

2303def random_zoom( 

2304 x, 

2305 zoom_range, 

2306 row_axis=1, 

2307 col_axis=2, 

2308 channel_axis=0, 

2309 fill_mode="nearest", 

2310 cval=0.0, 

2311 interpolation_order=1, 

2312): 

2313 """Performs a random spatial zoom of a Numpy image tensor. 

2314 

2315 Deprecated: `tf.keras.preprocessing.image.random_zoom` does not operate on 

2316 tensors and is not recommended for new code. Prefer 

2317 `tf.keras.layers.RandomZoom` which provides equivalent functionality as 

2318 a preprocessing layer. For more information, see the tutorial for 

2319 [augmenting images]( 

2320 https://www.tensorflow.org/tutorials/images/data_augmentation), as well as 

2321 the [preprocessing layer guide]( 

2322 https://www.tensorflow.org/guide/keras/preprocessing_layers). 

2323 

2324 Args: 

2325 x: Input tensor. Must be 3D. 

2326 zoom_range: Tuple of floats; zoom range for width and height. 

2327 row_axis: Index of axis for rows in the input tensor. 

2328 col_axis: Index of axis for columns in the input tensor. 

2329 channel_axis: Index of axis for channels in the input tensor. 

2330 fill_mode: Points outside the boundaries of the input 

2331 are filled according to the given mode 

2332 (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 

2333 cval: Value used for points outside the boundaries 

2334 of the input if `mode='constant'`. 

2335 interpolation_order: int, order of spline interpolation. 

2336 see `ndimage.interpolation.affine_transform` 

2337 

2338 Returns: 

2339 Zoomed Numpy image tensor. 

2340 

2341 Raises: 

2342 ValueError: if `zoom_range` isn't a tuple. 

2343 """ 

2344 if len(zoom_range) != 2: 

2345 raise ValueError( 

2346 "`zoom_range` should be a tuple or list of two floats. Received: %s" 

2347 % (zoom_range,) 

2348 ) 

2349 

2350 if zoom_range[0] == 1 and zoom_range[1] == 1: 

2351 zx, zy = 1, 1 

2352 else: 

2353 zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2) 

2354 x = apply_affine_transform( 

2355 x, 

2356 zx=zx, 

2357 zy=zy, 

2358 row_axis=row_axis, 

2359 col_axis=col_axis, 

2360 channel_axis=channel_axis, 

2361 fill_mode=fill_mode, 

2362 cval=cval, 

2363 order=interpolation_order, 

2364 ) 

2365 return x 

2366 

2367 
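The shift, shear and zoom helpers follow the same calling convention; a sketch on a channels-last image (SciPy required):

    import numpy as np
    from tensorflow.keras.preprocessing import image as kp_image

    img = np.random.rand(64, 64, 3)
    shifted = kp_image.random_shift(
        img, wrg=0.1, hrg=0.1, row_axis=0, col_axis=1, channel_axis=2
    )
    sheared = kp_image.random_shear(
        img, intensity=20, row_axis=0, col_axis=1, channel_axis=2
    )
    zoomed = kp_image.random_zoom(
        img, zoom_range=(0.8, 1.2), row_axis=0, col_axis=1, channel_axis=2
    )
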

2368@keras_export("keras.preprocessing.image.apply_channel_shift") 

2369def apply_channel_shift(x, intensity, channel_axis=0): 

2370 """Performs a channel shift. 

2371 

2372 Args: 

2373 x: Input tensor. Must be 3D. 

2374 intensity: Transformation intensity. 

2375 channel_axis: Index of axis for channels in the input tensor. 

2376 

2377 Returns: 

2378 Numpy image tensor. 

2379 """ 

2380 x = np.rollaxis(x, channel_axis, 0) 

2381 min_x, max_x = np.min(x), np.max(x) 

2382 channel_images = [ 

2383 np.clip(x_channel + intensity, min_x, max_x) for x_channel in x 

2384 ] 

2385 x = np.stack(channel_images, axis=0) 

2386 x = np.rollaxis(x, 0, channel_axis + 1) 

2387 return x 

2388 

2389 

2390@keras_export("keras.preprocessing.image.random_channel_shift") 

2391def random_channel_shift(x, intensity_range, channel_axis=0): 

2392 """Performs a random channel shift. 

2393 

2394 Args: 

2395 x: Input tensor. Must be 3D. 

2396 intensity_range: Transformation intensity. 

2397 channel_axis: Index of axis for channels in the input tensor. 

2398 

2399 Returns: 

2400 Numpy image tensor. 

2401 """ 

2402 intensity = np.random.uniform(-intensity_range, intensity_range) 

2403 return apply_channel_shift(x, intensity, channel_axis=channel_axis) 

2404 
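A sketch contrasting the deterministic and random channel-shift helpers on a channels-last image:

    import numpy as np
    import tensorflow as tf

    img = np.random.rand(32, 32, 3) * 255
    # Add a fixed +20 to every channel, clipped to the image's original range.
    shifted = tf.keras.preprocessing.image.apply_channel_shift(
        img, intensity=20, channel_axis=2
    )
    # Draw the shift uniformly from [-30, 30] instead.
    jittered = tf.keras.preprocessing.image.random_channel_shift(
        img, intensity_range=30, channel_axis=2
    )
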

2405 

2406@keras_export("keras.preprocessing.image.apply_brightness_shift") 

2407def apply_brightness_shift(x, brightness, scale=True): 

2408 """Performs a brightness shift. 

2409 

2410 Args: 

2411 x: Input tensor. Must be 3D. 

2412 brightness: Float. The new brightness value. 

2413 scale: Whether to rescale the image such that minimum and maximum values 

2414 are 0 and 255 respectively. Default: True. 

2415 

2416 Returns: 

2417 Numpy image tensor. 

2418 

2419 Raises: 

2420 ImportError: if PIL is not available. 

2421 """ 

2422 if ImageEnhance is None: 

2423 raise ImportError( 

2424 "Using brightness shifts requires PIL. Install PIL or Pillow." 

2425 ) 

2426 x_min, x_max = np.min(x), np.max(x) 

2427 local_scale = (x_min < 0) or (x_max > 255) 

2428 x = image_utils.array_to_img(x, scale=local_scale or scale) 

2429 imgenhancer_Brightness = ImageEnhance.Brightness(x) 

2430 x = imgenhancer_Brightness.enhance(brightness) 

2431 x = image_utils.img_to_array(x) 

2432 if not scale and local_scale: 

2433 x = x / 255 * (x_max - x_min) + x_min 

2434 return x 

2435 

2436 

2437@keras_export("keras.preprocessing.image.random_brightness") 

2438def random_brightness(x, brightness_range, scale=True): 

2439 """Performs a random brightness shift. 

2440 

2441 Deprecated: `tf.keras.preprocessing.image.random_brightness` does not 

2442 operate on tensors and is not recommended for new code. Prefer 

2443 `tf.keras.layers.RandomBrightness` which provides equivalent functionality 

2444 as a preprocessing layer. For more information, see the tutorial for 

2445 [augmenting images]( 

2446 https://www.tensorflow.org/tutorials/images/data_augmentation), as well as 

2447 the [preprocessing layer guide]( 

2448 https://www.tensorflow.org/guide/keras/preprocessing_layers). 

2449 

2450 Args: 

2451 x: Input tensor. Must be 3D. 

2452 brightness_range: Tuple of floats; brightness range. 

2453 scale: Whether to rescale the image such that minimum and maximum values 

2454 are 0 and 255 respectively. Default: True. 

2455 

2456 Returns: 

2457 Numpy image tensor. 

2458 

2459 Raises: 

2460 ValueError: if `brightness_range` isn't a tuple. 

2461 """ 

2462 if len(brightness_range) != 2: 

2463 raise ValueError( 

2464 "`brightness_range` should be a tuple or list of two floats. " 

2465 "Received: %s" % (brightness_range,) 

2466 ) 

2467 

2468 u = np.random.uniform(brightness_range[0], brightness_range[1]) 

2469 return apply_brightness_shift(x, u, scale) 

2470 

2471 
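A sketch of the brightness helpers (Pillow must be installed); a factor above 1 brightens the image and a factor below 1 darkens it:

    import numpy as np
    import tensorflow as tf

    img = np.random.rand(32, 32, 3) * 255
    brighter = tf.keras.preprocessing.image.apply_brightness_shift(img, 1.3)
    jittered = tf.keras.preprocessing.image.random_brightness(img, (0.7, 1.3))
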

2472def transform_matrix_offset_center(matrix, x, y): 

2473 o_x = float(x) / 2 - 0.5 

2474 o_y = float(y) / 2 - 0.5 

2475 offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]]) 

2476 reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]]) 

2477 transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix) 

2478 return transform_matrix 

2479 

2480 
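What `transform_matrix_offset_center` does, sketched directly in NumPy: the transform is conjugated with translation matrices so it acts about the image centre rather than the top-left corner:

    import numpy as np

    h, w = 64, 64
    theta = np.deg2rad(30)
    rotation = np.array(
        [[np.cos(theta), -np.sin(theta), 0],
         [np.sin(theta), np.cos(theta), 0],
         [0, 0, 1]]
    )
    o_x, o_y = h / 2 - 0.5, w / 2 - 0.5
    offset = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
    reset = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
    centered = offset @ rotation @ reset  # same composition as the helper above
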

2481@keras_export("keras.preprocessing.image.apply_affine_transform") 

2482def apply_affine_transform( 

2483 x, 

2484 theta=0, 

2485 tx=0, 

2486 ty=0, 

2487 shear=0, 

2488 zx=1, 

2489 zy=1, 

2490 row_axis=1, 

2491 col_axis=2, 

2492 channel_axis=0, 

2493 fill_mode="nearest", 

2494 cval=0.0, 

2495 order=1, 

2496): 

2497 """Applies an affine transformation specified by the parameters given. 

2498 

2499 Args: 

2500 x: 3D numpy array - a 2D image with one or more channels. 

2501 theta: Rotation angle in degrees. 

2502 tx: Width shift. 

2503 ty: Height shift. 

2504 shear: Shear angle in degrees. 

2505 zx: Zoom in x direction. 

2506 zy: Zoom in y direction. 

2507 row_axis: Index of axis for rows (aka Y axis) in the input 

2508 image. Direction: left to right. 

2509 col_axis: Index of axis for columns (aka X axis) in the input 

2510 image. Direction: top to bottom. 

2511 channel_axis: Index of axis for channels in the input image. 

2512 fill_mode: Points outside the boundaries of the input 

2513 are filled according to the given mode 

2514 (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 

2515 cval: Value used for points outside the boundaries 

2516 of the input if `mode='constant'`. 

2517 order: int, order of interpolation 

2518 

2519 Returns: 

2520 The transformed version of the input. 

2521 

2522 Raises: 

2523 ImportError: if SciPy is not available. 

2524 """ 

2525 if scipy is None: 

2526 raise ImportError("Image transformations require SciPy. Install SciPy.") 

2527 

2528 # Input sanity checks: 

2529 # 1. x must be a 2D image with one or more channels (i.e., a 3D tensor) 

2530 # 2. channels must be either first or last dimension 

2531 if np.unique([row_axis, col_axis, channel_axis]).size != 3: 

2532 raise ValueError( 

2533 "'row_axis', 'col_axis', and 'channel_axis' must be distinct" 

2534 ) 

2535 

2536 # shall we support negative indices? 

2537 valid_indices = set([0, 1, 2]) 

2538 actual_indices = set([row_axis, col_axis, channel_axis]) 

2539 if actual_indices != valid_indices: 

2540 raise ValueError( 

2541 f"Invalid axis indices: {actual_indices - valid_indices}" 

2542 ) 

2543 

2544 if x.ndim != 3: 

2545 raise ValueError("Input arrays must be multi-channel 2D images.") 

2546 if channel_axis not in [0, 2]: 

2547 raise ValueError( 

2548 "Channels are only allowed at the first or last dimension." 

2549 ) 

2550 

2551 transform_matrix = None 

2552 if theta != 0: 

2553 theta = np.deg2rad(theta) 

2554 rotation_matrix = np.array( 

2555 [ 

2556 [np.cos(theta), -np.sin(theta), 0], 

2557 [np.sin(theta), np.cos(theta), 0], 

2558 [0, 0, 1], 

2559 ] 

2560 ) 

2561 transform_matrix = rotation_matrix 

2562 

2563 if tx != 0 or ty != 0: 

2564 shift_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) 

2565 if transform_matrix is None: 

2566 transform_matrix = shift_matrix 

2567 else: 

2568 transform_matrix = np.dot(transform_matrix, shift_matrix) 

2569 

2570 if shear != 0: 

2571 shear = np.deg2rad(shear) 

2572 shear_matrix = np.array( 

2573 [[1, -np.sin(shear), 0], [0, np.cos(shear), 0], [0, 0, 1]] 

2574 ) 

2575 if transform_matrix is None: 

2576 transform_matrix = shear_matrix 

2577 else: 

2578 transform_matrix = np.dot(transform_matrix, shear_matrix) 

2579 

2580 if zx != 1 or zy != 1: 

2581 zoom_matrix = np.array([[zx, 0, 0], [0, zy, 0], [0, 0, 1]]) 

2582 if transform_matrix is None: 

2583 transform_matrix = zoom_matrix 

2584 else: 

2585 transform_matrix = np.dot(transform_matrix, zoom_matrix) 

2586 

2587 if transform_matrix is not None: 

2588 h, w = x.shape[row_axis], x.shape[col_axis] 

2589 transform_matrix = transform_matrix_offset_center( 

2590 transform_matrix, h, w 

2591 ) 

2592 x = np.rollaxis(x, channel_axis, 0) 

2593 

2594 # Matrix construction assumes that coordinates are x, y (in that order). 

2595 # However, regular numpy arrays use y,x (aka i,j) indexing. 

2596 # Possible solution is: 

2597 # 1. Swap the x and y axes. 

2598 # 2. Apply transform. 

2599 # 3. Swap the x and y axes again to restore image-like data ordering. 

2600 # Mathematically, it is equivalent to the following transformation: 

2601 # M' = PMP, where P is the permutation matrix, M is the original 

2602 # transformation matrix. 

2603 if col_axis > row_axis: 

2604 transform_matrix[:, [0, 1]] = transform_matrix[:, [1, 0]] 

2605 transform_matrix[[0, 1]] = transform_matrix[[1, 0]] 

2606 final_affine_matrix = transform_matrix[:2, :2] 

2607 final_offset = transform_matrix[:2, 2] 

2608 

2609 channel_images = [ 

2610 ndimage.interpolation.affine_transform( 

2611 x_channel, 

2612 final_affine_matrix, 

2613 final_offset, 

2614 order=order, 

2615 mode=fill_mode, 

2616 cval=cval, 

2617 ) 

2618 for x_channel in x 

2619 ] 

2620 x = np.stack(channel_images, axis=0) 

2621 x = np.rollaxis(x, 0, channel_axis + 1) 

2622 return x 

2623
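A closing sketch combining several parameters of `apply_affine_transform` on a channels-last image (SciPy required):

    import numpy as np
    import tensorflow as tf

    img = np.random.rand(64, 64, 3)
    out = tf.keras.preprocessing.image.apply_affine_transform(
        img,
        theta=15,                 # rotate by 15 degrees
        tx=5, ty=-3,              # translate
        zx=0.9, zy=0.9,           # zoom
        row_axis=0, col_axis=1, channel_axis=2,
        fill_mode="nearest",
    )
    print(out.shape)              # (64, 64, 3), same shape as the input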