Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/applications/resnet_rs.py: 26%
209 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
17"""ResNet-RS models for Keras.
19Reference:
20- [Revisiting ResNets: Improved Training and Scaling Strategies](
21 https://arxiv.org/pdf/2103.07579.pdf)
22"""
23import sys
24from typing import Callable
25from typing import Dict
26from typing import List
27from typing import Union
29import tensorflow.compat.v2 as tf
31from keras.src import backend
32from keras.src import layers
33from keras.src.applications import imagenet_utils
34from keras.src.engine import training
35from keras.src.utils import data_utils
36from keras.src.utils import layer_utils
38# isort: off
39from tensorflow.python.util.tf_export import keras_export
# Root URL under which every pretrained ResNet-RS checkpoint is hosted; the
# file name (e.g. "resnet-rs-50-i160.h5") is appended to it.
BASE_WEIGHTS_URL = (
    "https://storage.googleapis.com/"
    "tensorflow/keras-applications/resnet_rs/"
)

# File name -> expected hash of the downloaded checkpoint. Names follow the
# pattern "resnet-rs-<depth>-i<input size>[_notop].h5"; "_notop" files lack
# the classification head.
WEIGHT_HASHES = {
    # depth 101
    "resnet-rs-101-i160.h5": "544b3434d00efc199d66e9058c7f3379",
    "resnet-rs-101-i160_notop.h5": "82d5b90c5ce9d710da639d6216d0f979",
    "resnet-rs-101-i192.h5": "eb285be29ab42cf4835ff20a5e3b5d23",
    "resnet-rs-101-i192_notop.h5": "f9a0f6b85faa9c3db2b6e233c4eebb5b",
    # depth 152
    "resnet-rs-152-i192.h5": "8d72a301ed8a6f11a47c4ced4396e338",
    "resnet-rs-152-i192_notop.h5": "5fbf7ac2155cb4d5a6180ee9e3aa8704",
    "resnet-rs-152-i224.h5": "31a46a92ab21b84193d0d71dd8c3d03b",
    "resnet-rs-152-i224_notop.h5": "dc8b2cba2005552eafa3167f00dc2133",
    "resnet-rs-152-i256.h5": "ba6271b99bdeb4e7a9b15c05964ef4ad",
    "resnet-rs-152-i256_notop.h5": "fa79794252dbe47c89130f65349d654a",
    # depth 200
    "resnet-rs-200-i256.h5": "a76930b741884e09ce90fa7450747d5f",
    "resnet-rs-200-i256_notop.h5": "bbdb3994718dfc0d1cd45d7eff3f3d9c",
    # depth 270
    "resnet-rs-270-i256.h5": "20d575825ba26176b03cb51012a367a8",
    "resnet-rs-270-i256_notop.h5": "2c42ecb22e35f3e23d2f70babce0a2aa",
    # depth 350
    "resnet-rs-350-i256.h5": "f4a039dc3c421321b7fc240494574a68",
    "resnet-rs-350-i256_notop.h5": "6e44b55025bbdff8f51692a023143d66",
    "resnet-rs-350-i320.h5": "7ccb858cc738305e8ceb3c0140bee393",
    "resnet-rs-350-i320_notop.h5": "ab0c1f9079d2f85a9facbd2c88aa6079",
    # depth 420
    "resnet-rs-420-i320.h5": "ae0eb9bed39e64fc8d7e0db4018dc7e8",
    "resnet-rs-420-i320_notop.h5": "fe6217c32be8305b1889657172b98884",
    # depth 50
    "resnet-rs-50-i160.h5": "69d9d925319f00a8bdd4af23c04e4102",
    "resnet-rs-50-i160_notop.h5": "90daa68cd26c95aa6c5d25451e095529",
}
# For each supported depth, the input resolutions (in pixels) for which a
# pretrained ImageNet checkpoint is published. `ResNetRS` turns these into the
# accepted `weights="imagenet-i<size>"` values and, for plain "imagenet",
# picks the largest one.
DEPTH_TO_WEIGHT_VARIANTS = {
    50: [160],
    101: [160, 192],
    152: [192, 224, 256],
    200: [256],
    270: [256],
    350: [256, 320],
    420: [320],
}
# Per-depth configuration of the four block groups. Every depth uses the same
# filter progression (64, 128, 256, 512); only the number of bottleneck blocks
# repeated in each group differs. Each value is a list of
# {"input_filters": ..., "num_repeats": ...} dicts, one per group.
BLOCK_ARGS = {
    depth: [
        {"input_filters": group_filters, "num_repeats": group_repeats}
        for group_filters, group_repeats in zip((64, 128, 256, 512), repeats)
    ]
    for depth, repeats in (
        (50, (3, 4, 6, 3)),
        (101, (3, 4, 23, 3)),
        (152, (3, 8, 36, 3)),
        (200, (3, 24, 36, 3)),
        (270, (4, 29, 53, 4)),
        (350, (4, 36, 72, 4)),
        (420, (4, 44, 87, 4)),
    )
}
# Serialized Keras initializer config shared by every convolution in this
# file: VarianceScaling(scale=2.0, mode="fan_out",
# distribution="truncated_normal").
CONV_KERNEL_INITIALIZER = dict(
    class_name="VarianceScaling",
    config=dict(
        scale=2.0,
        mode="fan_out",
        distribution="truncated_normal",
    ),
)
# Shared public docstring for ResNetRS50 ... ResNetRS420, formatted with the
# model name. It documents only the arguments those wrapper functions accept;
# `{name}` must stay the only brace pair because the text goes through
# `str.format`.
BASE_DOCSTRING = """Instantiates the {name} architecture.

    Reference:
    - [Revisiting ResNets: Improved Training and Scaling Strategies](
        https://arxiv.org/pdf/2103.07579.pdf)

    For image classification use cases, see
    [this page for detailed examples](
    https://keras.io/api/applications/#usage-examples-for-image-classification-models).

    For transfer learning use cases, make sure to read the
    [guide to transfer learning & fine-tuning](
      https://keras.io/guides/transfer_learning/).

    Note: each Keras Application expects a specific kind of input preprocessing.
    For ResNetRS, by default input preprocessing is included as a part of the
    model (a `Rescaling` layer mapping [0, 255] to [0, 1], followed for
    3-channel inputs by a `Normalization` layer using the ImageNet per-channel
    mean and standard deviation), and thus
    `tf.keras.applications.resnet_rs.preprocess_input` is actually a
    pass-through function. In this use case, ResNetRS models expect their inputs
    to be float tensors of pixels with values in the [0-255] range.
    Preprocessing as a part of the model can be disabled by setting the
    `include_preprocessing` argument to `False`. With preprocessing disabled,
    inputs are fed to the network as-is, so they should already be
    preprocessed the same way: scaled to [0, 1], then normalized with the
    ImageNet per-channel mean and standard deviation.

    Args:
        include_top: whether to include the fully-connected layer at the top of
            the network.
        weights: one of `None` (random initialization), `'imagenet'`
            (pre-training on ImageNet), an explicit ImageNet variant of the
            form `'imagenet-i<size>'`, or the path to the weights file to be
            loaded. One model can have several ImageNet variants depending on
            the input shape it was trained with; e.g. for input shape 224x224
            pass `'imagenet-i224'` (only published for ResNetRS152). Plain
            `'imagenet'` downloads the variant trained with the largest input
            shape.
        classes: optional number of classes to classify images into, only to be
            specified if `include_top` is True, and if no `weights` argument is
            specified. Must be 1000 when loading ImageNet weights with
            `include_top=True`.
        input_shape: optional shape tuple. It should have exactly 3 input
            channels, and width and height should be no smaller than 32.
            E.g. (200, 200, 3) would be one valid value.
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to
            use as image input for the model.
        pooling: optional pooling mode for feature extraction when `include_top`
            is `False`.
            - `None` means that the output of the model will be the 4D tensor
                output of the last convolutional layer.
            - `avg` means that global average pooling will be applied to the
                output of the last convolutional layer, and thus the output of
                the model will be a 2D tensor.
            - `max` means that global max pooling will be applied.
        classifier_activation: A `str` or callable. The activation function to
            use on the "top" layer. Ignored unless `include_top=True`. Set
            `classifier_activation=None` to return the logits of the "top"
            layer. When using a pretrained top layer, only `None` or
            `"softmax"` are accepted.
        include_preprocessing: Boolean, whether to include the preprocessing
            layers (`Rescaling`, plus `Normalization` with the ImageNet mean and
            standard deviation for 3-channel inputs) at the bottom of the
            network. Defaults to `True`.

    Returns:
        A `keras.Model` instance.
"""
def Conv2DFixedPadding(filters, kernel_size, strides, name=None):
    """Conv2D block with fixed padding.

    For `strides > 1` the input is first zero-padded by `fixed_padding` and
    then convolved with "valid" padding, so the padding does not depend on the
    input size. With `strides == 1` the convolution uses "same" padding. The
    convolution has no bias and uses `CONV_KERNEL_INITIALIZER`.

    Args:
        filters: Number of output filters.
        kernel_size: Integer kernel size.
        strides: Integer stride.
        name: Layer name; a unique "conv_<uid>" name is generated when None.

    Returns:
        A callable that applies the (padding +) convolution to a tensor.
    """
    if name is None:
        name = f"conv_{backend.get_uid('conv_')}"

    def apply(inputs):
        x = fixed_padding(inputs, kernel_size) if strides > 1 else inputs
        conv = layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="valid" if strides != 1 else "same",
            use_bias=False,
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            name=name,
        )
        return conv(x)

    return apply
def STEM(
    bn_momentum: float = 0.0,
    bn_epsilon: float = 1e-5,
    activation: str = "relu",
    name=None,
):
    """ResNet-D type STEM block.

    Four consecutive 3x3 conv -> BatchNormalization -> activation stages with
    (filters, strides) of (32, 2), (32, 1), (64, 1) and (64, 2). The final
    stride-2 stage replaces the max-pooling layer of the classic ResNet stem.

    Args:
        bn_momentum: Momentum for the BatchNormalization layers.
        bn_epsilon: Epsilon for the BatchNormalization layers.
        activation: Activation applied after each BatchNormalization.
        name: Prefix for layer names; a unique "stem_<uid>" prefix is
            generated when None.

    Returns:
        A callable mapping an input tensor to the stem output tensor.
    """
    if name is None:
        name = f"stem_{backend.get_uid('stem_')}"

    # (filters, strides) for each conv -> BN -> activation stage, in order.
    stages = ((32, 2), (32, 1), (64, 1), (64, 2))

    def apply(inputs):
        bn_axis = 3 if backend.image_data_format() == "channels_last" else 1

        x = inputs
        for index, (stage_filters, stage_strides) in enumerate(stages, start=1):
            x = Conv2DFixedPadding(
                filters=stage_filters,
                kernel_size=3,
                strides=stage_strides,
                name=name + f"_stem_conv_{index}",
            )(x)
            x = layers.BatchNormalization(
                axis=bn_axis,
                momentum=bn_momentum,
                epsilon=bn_epsilon,
                name=name + f"_stem_batch_norm_{index}",
            )(x)
            x = layers.Activation(
                activation, name=name + f"_stem_act_{index}"
            )(x)
        return x

    return apply
def SE(
    in_filters: int, se_ratio: float = 0.25, expand_ratio: int = 1, name=None
):
    """Squeeze and Excitation block.

    Global-average-pools the input, passes the pooled vector through a ReLU
    1x1 "reduce" convolution and a sigmoid 1x1 "expand" convolution, and
    multiplies the input by the result.

    Args:
        in_filters: Base filter count. The expand convolution produces
            `4 * in_filters * expand_ratio` channels, which must match the
            channel count of the input for the final multiply.
        se_ratio: Ratio applied to `4 * in_filters` to size the reduce
            convolution (at least 1 filter).
        expand_ratio: Extra multiplier for the expand convolution's filters.
        name: Prefix for layer names; a unique "se_<uid>" prefix is generated
            when None.

    Returns:
        A callable mapping an input tensor to the re-weighted tensor.
    """
    channels_first = backend.image_data_format() != "channels_last"
    if name is None:
        name = f"se_{backend.get_uid('se_')}"

    def apply(inputs):
        squeezed = layers.GlobalAveragePooling2D(name=name + "_se_squeeze")(
            inputs
        )
        channels = squeezed.shape[-1]
        # Restore the two spatial axes (as size 1) on the correct side.
        target_shape = (channels, 1, 1) if channels_first else (1, 1, channels)
        x = layers.Reshape(target_shape, name=name + "_se_reshape")(squeezed)

        reduced_filters = max(1, int(in_filters * 4 * se_ratio))
        reduce_conv = layers.Conv2D(
            filters=reduced_filters,
            kernel_size=[1, 1],
            strides=[1, 1],
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            padding="same",
            use_bias=True,
            activation="relu",
            name=name + "_se_reduce",
        )
        x = reduce_conv(x)

        expand_conv = layers.Conv2D(
            filters=4 * in_filters * expand_ratio,
            kernel_size=[1, 1],
            strides=[1, 1],
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            padding="same",
            use_bias=True,
            activation="sigmoid",
            name=name + "_se_expand",
        )
        x = expand_conv(x)

        return layers.multiply([inputs, x], name=name + "_se_excite")

    return apply
def BottleneckBlock(
    filters: int,
    strides: int,
    use_projection: bool,
    bn_momentum: float = 0.0,
    bn_epsilon: float = 1e-5,
    activation: str = "relu",
    se_ratio: float = 0.25,
    survival_probability: float = 0.8,
    name=None,
):
    """Bottleneck block variant for residual networks with BN.

    Residual branch: 1x1 conv (`filters`) -> 3x3 conv (`filters`, `strides`)
    -> 1x1 conv (`4 * filters`), each followed by BatchNormalization. The
    first two are also followed by `activation`. Optional Squeeze-and-Excitation
    and per-sample Dropout follow. The branch is added to the shortcut and
    passed through `activation`.

    Args:
        filters: Filters of the first two convolutions; the block outputs
            `4 * filters` channels.
        strides: Stride of the 3x3 convolution and of the projection shortcut.
        use_projection: If True, the shortcut is projected to `4 * filters`
            channels and batch-normalized. For `strides == 2` it uses 2x2
            average pooling followed by a stride-1 1x1 conv; otherwise it uses
            a 1x1 conv with `strides`. If False, the input is added unchanged.
        bn_momentum: Momentum for the BatchNormalization layers.
        bn_epsilon: Epsilon for the BatchNormalization layers.
        activation: Activation function name.
        se_ratio: Squeeze-and-Excitation ratio; SE is applied only when
            `0 < se_ratio < 1`.
        survival_probability: Despite its name, passed as the *rate* of a
            `Dropout` with `noise_shape=(None, 1, 1, 1)`, i.e. the probability
            of dropping the whole residual branch per sample. A falsy value
            (e.g. 0) disables the Dropout layer.
        name: Prefix for layer names; a unique "block_0_<uid>" prefix is
            generated when None.

    Returns:
        A callable mapping an input tensor to the block output tensor.
    """
    if name is None:
        uid = backend.get_uid("block_0_")
        name = f"block_0_{uid}"

    def apply(inputs):
        bn_axis = 3 if backend.image_data_format() == "channels_last" else 1

        def conv(tensor, n_filters, kernel_size, conv_strides, suffix):
            # Conv2DFixedPadding named `<name><suffix>`.
            return Conv2DFixedPadding(
                filters=n_filters,
                kernel_size=kernel_size,
                strides=conv_strides,
                name=name + suffix,
            )(tensor)

        def batch_norm(tensor, suffix):
            # BatchNormalization over the channel axis named `<name><suffix>`.
            return layers.BatchNormalization(
                axis=bn_axis,
                momentum=bn_momentum,
                epsilon=bn_epsilon,
                name=name + suffix,
            )(tensor)

        # Shortcut branch.
        shortcut = inputs
        if use_projection:
            if strides == 2:
                # Downsample by average pooling, then project with a
                # stride-1 1x1 convolution.
                pooled = layers.AveragePooling2D(
                    pool_size=(2, 2),
                    strides=(2, 2),
                    padding="same",
                    name=name + "_projection_pooling",
                )(inputs)
                shortcut = conv(pooled, filters * 4, 1, 1, "_projection_conv")
            else:
                shortcut = conv(
                    inputs, filters * 4, 1, strides, "_projection_conv"
                )
            shortcut = batch_norm(shortcut, "_projection_batch_norm")

        # Residual branch, first 1x1 convolution.
        x = conv(inputs, filters, 1, 1, "_conv_1")
        # NOTE: "batch_norm_1" lacks the leading underscore used by the other
        # layers; kept as-is because it is an existing, user-visible name.
        x = batch_norm(x, "batch_norm_1")
        x = layers.Activation(activation, name=name + "_act_1")(x)

        # Strided 3x3 convolution.
        x = conv(x, filters, 3, strides, "_conv_2")
        x = batch_norm(x, "_batch_norm_2")
        x = layers.Activation(activation, name=name + "_act_2")(x)

        # Final 1x1 expansion to 4 * filters channels.
        x = conv(x, filters * 4, 1, 1, "_conv_3")
        x = batch_norm(x, "_batch_norm_3")

        if 0 < se_ratio < 1:
            x = SE(filters, se_ratio=se_ratio, name=name + "_se")(x)

        # Stochastic depth: drop the whole residual branch per sample.
        if survival_probability:
            x = layers.Dropout(
                survival_probability,
                noise_shape=(None, 1, 1, 1),
                name=name + "_drop",
            )(x)

        x = layers.Add()([x, shortcut])
        return layers.Activation(activation, name=name + "_output_act")(x)

    return apply
def BlockGroup(
    filters,
    strides,
    num_repeats,
    se_ratio: float = 0.25,
    bn_epsilon: float = 1e-5,
    bn_momentum: float = 0.0,
    activation: str = "relu",
    survival_probability: float = 0.8,
    name=None,
):
    """Create one group of blocks for the ResNet model.

    The group is a chain of `BottleneckBlock`s. The first block uses a
    projection shortcut and the group's `strides`; the remaining
    `num_repeats - 1` blocks use stride 1 and identity shortcuts. At least
    one block is always created.

    Args:
        filters: Base filter count passed to every block.
        strides: Stride of the first block.
        num_repeats: Total number of blocks in the group.
        se_ratio: Squeeze-and-Excitation ratio for every block.
        bn_epsilon: Epsilon for the BatchNormalization layers.
        bn_momentum: Momentum for the BatchNormalization layers.
        activation: Activation function name.
        survival_probability: Forwarded unchanged to every block.
        name: Prefix for layer names; a unique "block_group_<uid>" prefix is
            generated when None. Block `i` is named `<name>_block_<i>_`.

    Returns:
        A callable mapping an input tensor to the group output tensor.
    """
    if name is None:
        uid = backend.get_uid("block_group_")
        name = f"block_group_{uid}"

    def apply(inputs):
        def make_block(index):
            # Only the first block of the group projects and strides.
            is_first = index == 0
            return BottleneckBlock(
                filters=filters,
                strides=strides if is_first else 1,
                use_projection=is_first,
                se_ratio=se_ratio,
                bn_epsilon=bn_epsilon,
                bn_momentum=bn_momentum,
                activation=activation,
                survival_probability=survival_probability,
                name=name + f"_block_{index}_",
            )

        x = make_block(0)(inputs)
        for index in range(1, num_repeats):
            x = make_block(index)(x)
        return x

    return apply
def get_survival_probability(init_rate, block_num, total_blocks):
    """Scale `init_rate` linearly by the block's relative position.

    Returns `init_rate * block_num / total_blocks`. `ResNetRS` passes its
    `drop_connect_rate` as `init_rate`, and `BottleneckBlock` uses the result
    as a `Dropout` rate. So, despite the name, the value is a drop
    probability that grows with depth.

    Raises:
        ZeroDivisionError: if `total_blocks` is 0.
    """
    scaled_rate = init_rate * float(block_num)
    return scaled_rate / total_blocks
def allow_bigger_recursion(target_limit: int):
    """Raise Python's recursion limit to at least `target_limit`.

    Called before building the larger ResNet-RS variants. The limit is never
    lowered: if the current limit already reaches `target_limit`, nothing is
    changed.
    """
    if sys.getrecursionlimit() >= target_limit:
        return
    sys.setrecursionlimit(target_limit)
def fixed_padding(inputs, kernel_size):
    """Pad the input along the spatial dimensions independently of input size.

    Adds `kernel_size - 1` zero rows/columns in total to each spatial axis:
    `(kernel_size - 1) // 2` before and the remainder after. A `ZeroPadding2D`
    layer is used to avoid a `TFOpLambda` layer in the model.
    """
    total = kernel_size - 1
    before = total // 2
    after = total - before
    per_axis = (before, after)
    return layers.ZeroPadding2D(padding=(per_axis, per_axis))(inputs)
def ResNetRS(
    depth: int,
    input_shape=None,
    bn_momentum=0.0,
    bn_epsilon=1e-5,
    activation: str = "relu",
    se_ratio=0.25,
    dropout_rate=0.25,
    drop_connect_rate=0.2,
    include_top=True,
    block_args: List[Dict[str, int]] = None,
    model_name="resnet-rs",
    pooling=None,
    weights="imagenet",
    input_tensor=None,
    classes=1000,
    classifier_activation: Union[str, Callable] = "softmax",
    include_preprocessing=True,
):
    """Build Resnet-RS model, given provided parameters.

    Args:
        depth: Depth of ResNet network. Must be a key of
            `DEPTH_TO_WEIGHT_VARIANTS` (50, 101, 152, 200, 270, 350 or 420).
        input_shape: optional shape tuple. It should have exactly 3 input
            channels, and width and height should be no smaller than 32. E.g.
            (200, 200, 3) would be one valid value.
        bn_momentum: Momentum parameter for Batch Normalization layers.
        bn_epsilon: Epsilon parameter for Batch Normalization layers.
        activation: activation function.
        se_ratio: Squeeze and Excitation layer ratio.
        dropout_rate: dropout rate before final classifier layer.
        drop_connect_rate: base stochastic-depth drop rate; block group `i`
            (0-based) drops its residual branches with rate
            `drop_connect_rate * (i + 2) / (len(block_args) + 1)`.
        include_top: whether to include the fully-connected layer at the top of
            the network.
        block_args: list of dicts with keys `"input_filters"` and
            `"num_repeats"`, one per block group. Defaults to
            `BLOCK_ARGS[depth]`.
        model_name: name of the returned model. It does not affect which
            pretrained weights file is downloaded (that is derived from
            `depth`).
        pooling: optional pooling mode for feature extraction when `include_top`
            is `False`.
            - `None` means that the output of the model will be the 4D tensor
                output of the last convolutional layer.
            - `avg` means that global average pooling will be applied to the
                output of the last convolutional layer, and thus the output of
                the model will be a 2D tensor.
            - `max` means that global max pooling will be applied.
        weights: one of `None` (random initialization), `'imagenet'`
            (pre-training on ImageNet), an explicit ImageNet variant
            `'imagenet-i<size>'`, or the path to the weights file to be
            loaded. One depth can have several ImageNet variants depending on
            the input shape they were trained with (see
            `DEPTH_TO_WEIGHT_VARIANTS`); e.g. for input shape 224x224 pass
            `'imagenet-i224'`. Plain `'imagenet'` selects the variant trained
            with the largest input shape.
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to
            use as image input for the model.
        classes: optional number of classes to classify images into, only to be
            specified if `include_top` is True, and if no `weights` argument is
            specified. Must be 1000 with ImageNet weights and `include_top`.
        classifier_activation: A `str` or callable. The activation function to
            use on the "top" layer. Ignored unless `include_top=True`. Set
            `classifier_activation=None` to return the logits of the "top" layer.
        include_preprocessing: Boolean, whether to include the preprocessing
            layers at the bottom of the network. Defaults to `True`. Inputs
            are rescaled from [0, 255] to [0, 1] and, for 3-channel inputs,
            normalized with the ImageNet mean and standard deviation.

    Returns:
        A `tf.keras.Model` instance.

    Raises:
        KeyError: if `depth` is not a supported depth.
        ValueError: in case of invalid argument for `weights`, or invalid input
            shape.
        ValueError: if `classifier_activation` is not `softmax` or `None` when
            using a pretrained top layer.
    """
    # Validate parameters. An unsupported `depth` raises KeyError here.
    available_weight_variants = DEPTH_TO_WEIGHT_VARIANTS[depth]
    if weights == "imagenet":
        # `imagenet` argument without explicit weights input size:
        # pick the weights trained with the biggest available shape.
        max_input_shape = max(available_weight_variants)
        weights = f"{weights}-i{max_input_shape}"

    weights_allow_list = [f"imagenet-i{x}" for x in available_weight_variants]
    if not (
        weights in {*weights_allow_list, None} or tf.io.gfile.exists(weights)
    ):
        raise ValueError(
            "The `weights` argument should be either "
            "`None` (random initialization), `'imagenet'` "
            "(pre-training on ImageNet, with highest available input shape),"
            " or the path to the weights file to be loaded. "
            f"For ResNetRS{depth} the following weight variants are "
            f"available {weights_allow_list} (default=highest)."
            f" Received weights={weights}"
        )

    if weights in weights_allow_list and include_top and classes != 1000:
        raise ValueError(
            "If using `weights` as `'imagenet'` or any "
            f"of {weights_allow_list} "
            "with `include_top` as true, `classes` should be 1000. "
            f"Received classes={classes}"
        )

    input_shape = imagenet_utils.obtain_input_shape(
        input_shape,
        default_size=224,
        min_size=32,
        data_format=backend.image_data_format(),
        require_flatten=include_top,
        weights=weights,
    )
    # Define input tensor
    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    else:
        if not backend.is_keras_tensor(input_tensor):
            img_input = layers.Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

    bn_axis = 3 if backend.image_data_format() == "channels_last" else 1

    x = img_input

    if include_preprocessing:
        # `input_shape` has no batch axis, hence the `- 1`.
        num_channels = input_shape[bn_axis - 1]
        x = layers.Rescaling(scale=1.0 / 255)(x)
        if num_channels == 3:
            x = layers.Normalization(
                mean=[0.485, 0.456, 0.406],
                variance=[0.229**2, 0.224**2, 0.225**2],
                axis=bn_axis,
            )(x)

    # Build stem
    x = STEM(
        bn_momentum=bn_momentum, bn_epsilon=bn_epsilon, activation=activation
    )(x)

    # Build blocks
    if block_args is None:
        block_args = BLOCK_ARGS[depth]

    for i, args in enumerate(block_args):
        # Despite its name, this value is used as the stochastic-depth drop
        # rate inside BottleneckBlock; it grows linearly with the group index.
        survival_probability = get_survival_probability(
            init_rate=drop_connect_rate,
            block_num=i + 2,
            total_blocks=len(block_args) + 1,
        )

        x = BlockGroup(
            filters=args["input_filters"],
            activation=activation,
            strides=(1 if i == 0 else 2),
            num_repeats=args["num_repeats"],
            se_ratio=se_ratio,
            bn_momentum=bn_momentum,
            bn_epsilon=bn_epsilon,
            survival_probability=survival_probability,
            name=f"BlockGroup{i + 2}_",
        )(x)

    # Build head:
    if include_top:
        x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
        if dropout_rate > 0:
            x = layers.Dropout(dropout_rate, name="top_dropout")(x)

        imagenet_utils.validate_activation(classifier_activation, weights)
        x = layers.Dense(
            classes, activation=classifier_activation, name="predictions"
        )(x)
    else:
        if pooling == "avg":
            x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
        elif pooling == "max":
            x = layers.GlobalMaxPooling2D(name="max_pool")(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = layer_utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input

    # Create model.
    model = training.Model(inputs, x, name=model_name)

    # Download weights
    if weights in weights_allow_list:
        weights_input_shape = weights.split("-")[-1]  # e. g. "i160"
        # The hosted files are named after the architecture depth. Deriving
        # the name from `depth` (not `model_name`) means a custom
        # `model_name` cannot produce a name missing from WEIGHT_HASHES
        # (KeyError) or select another depth's weights. For the public
        # ResNetRS<depth> wrappers the result is unchanged.
        weights_name = f"resnet-rs-{depth}-{weights_input_shape}"
        if not include_top:
            weights_name += "_notop"

        filename = f"{weights_name}.h5"
        download_url = BASE_WEIGHTS_URL + filename
        weights_path = data_utils.get_file(
            fname=filename,
            origin=download_url,
            cache_subdir="models",
            file_hash=WEIGHT_HASHES[filename],
        )
        model.load_weights(weights_path)

    elif weights is not None:
        model.load_weights(weights)

    return model
@keras_export(
    "keras.applications.resnet_rs.ResNetRS50",
    "keras.applications.ResNetRS50",
)
def ResNetRS50(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Instantiate ResNet-RS with depth 50.

    Uses `drop_connect_rate=0.0` and `dropout_rate=0.25`; all other arguments
    are forwarded to `ResNetRS`.
    """
    architecture = dict(
        depth=50,
        model_name="resnet-rs-50",
        drop_connect_rate=0.0,
        dropout_rate=0.25,
    )
    return ResNetRS(
        include_top=include_top,
        weights=weights,
        classes=classes,
        input_shape=input_shape,
        input_tensor=input_tensor,
        pooling=pooling,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **architecture,
    )
@keras_export(
    "keras.applications.resnet_rs.ResNetRS101",
    "keras.applications.ResNetRS101",
)
def ResNetRS101(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Instantiate ResNet-RS with depth 101.

    Uses `drop_connect_rate=0.0` and `dropout_rate=0.25`; all other arguments
    are forwarded to `ResNetRS`.
    """
    architecture = dict(
        depth=101,
        model_name="resnet-rs-101",
        drop_connect_rate=0.0,
        dropout_rate=0.25,
    )
    return ResNetRS(
        include_top=include_top,
        weights=weights,
        classes=classes,
        input_shape=input_shape,
        input_tensor=input_tensor,
        pooling=pooling,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **architecture,
    )
@keras_export(
    "keras.applications.resnet_rs.ResNetRS152",
    "keras.applications.ResNetRS152",
)
def ResNetRS152(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Instantiate ResNet-RS with depth 152.

    Uses `drop_connect_rate=0.0` and `dropout_rate=0.25`; all other arguments
    are forwarded to `ResNetRS`.
    """
    architecture = dict(
        depth=152,
        model_name="resnet-rs-152",
        drop_connect_rate=0.0,
        dropout_rate=0.25,
    )
    return ResNetRS(
        include_top=include_top,
        weights=weights,
        classes=classes,
        input_shape=input_shape,
        input_tensor=input_tensor,
        pooling=pooling,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **architecture,
    )
@keras_export(
    "keras.applications.resnet_rs.ResNetRS200",
    "keras.applications.ResNetRS200",
)
def ResNetRS200(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Instantiate ResNet-RS with depth 200.

    Uses `drop_connect_rate=0.1` and `dropout_rate=0.25`; all other arguments
    are forwarded to `ResNetRS`.
    """
    architecture = dict(
        depth=200,
        model_name="resnet-rs-200",
        drop_connect_rate=0.1,
        dropout_rate=0.25,
    )
    return ResNetRS(
        include_top=include_top,
        weights=weights,
        classes=classes,
        input_shape=input_shape,
        input_tensor=input_tensor,
        pooling=pooling,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **architecture,
    )
@keras_export(
    "keras.applications.resnet_rs.ResNetRS270",
    "keras.applications.ResNetRS270",
)
def ResNetRS270(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Instantiate ResNet-RS with depth 270.

    Raises the recursion limit to at least 1300 first, then uses
    `drop_connect_rate=0.1` and `dropout_rate=0.25`; all other arguments are
    forwarded to `ResNetRS`.
    """
    architecture = dict(
        depth=270,
        model_name="resnet-rs-270",
        drop_connect_rate=0.1,
        dropout_rate=0.25,
    )
    allow_bigger_recursion(1300)
    return ResNetRS(
        include_top=include_top,
        weights=weights,
        classes=classes,
        input_shape=input_shape,
        input_tensor=input_tensor,
        pooling=pooling,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **architecture,
    )
@keras_export(
    "keras.applications.resnet_rs.ResNetRS350",
    "keras.applications.ResNetRS350",
)
def ResNetRS350(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Instantiate ResNet-RS with depth 350.

    Raises the recursion limit to at least 1500 first, then uses
    `drop_connect_rate=0.1` and `dropout_rate=0.4`; all other arguments are
    forwarded to `ResNetRS`.
    """
    architecture = dict(
        depth=350,
        model_name="resnet-rs-350",
        drop_connect_rate=0.1,
        dropout_rate=0.4,
    )
    allow_bigger_recursion(1500)
    return ResNetRS(
        include_top=include_top,
        weights=weights,
        classes=classes,
        input_shape=input_shape,
        input_tensor=input_tensor,
        pooling=pooling,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **architecture,
    )
@keras_export(
    "keras.applications.resnet_rs.ResNetRS420",
    "keras.applications.ResNetRS420",
)
def ResNetRS420(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Instantiate ResNet-RS with depth 420.

    Raises the recursion limit to at least 1800 first, then uses
    `drop_connect_rate=0.1` and `dropout_rate=0.4`; all other arguments are
    forwarded to `ResNetRS`.
    """
    architecture = dict(
        depth=420,
        model_name="resnet-rs-420",
        drop_connect_rate=0.1,
        dropout_rate=0.4,
    )
    allow_bigger_recursion(1800)
    return ResNetRS(
        include_top=include_top,
        weights=weights,
        classes=classes,
        input_shape=input_shape,
        input_tensor=input_tensor,
        pooling=pooling,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **architecture,
    )
@keras_export("keras.applications.resnet_rs.preprocess_input")
def preprocess_input(x, data_format=None):
    """A placeholder method for backward compatibility.

    The preprocessing logic has been included in the ResNetRS model
    implementation. Users are no longer required to call this method to
    normalize the input data. This method does nothing and is only kept as a
    placeholder to align the API surface between old and new versions of the
    model.

    Args:
        x: A floating point `numpy.array` or a `tf.Tensor`.
        data_format: Optional data format of the image tensor/array. Ignored;
            accepted only for API compatibility.

    Returns:
        `x`, unchanged (the same object that was passed in).
    """
    return x
@keras_export("keras.applications.resnet_rs.decode_predictions")
def decode_predictions(preds, top=5):
    """Decode ImageNet predictions via `imagenet_utils.decode_predictions`."""
    decoded = imagenet_utils.decode_predictions(preds, top=top)
    return decoded
# Public docstrings: `decode_predictions` reuses the shared ImageNet helper's
# documentation, and every ResNetRS<depth> builder gets BASE_DOCSTRING
# formatted with its own name.
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

for _builder, _display_name in (
    (ResNetRS50, "ResNetRS50"),
    (ResNetRS101, "ResNetRS101"),
    (ResNetRS152, "ResNetRS152"),
    (ResNetRS200, "ResNetRS200"),
    (ResNetRS270, "ResNetRS270"),
    (ResNetRS350, "ResNetRS350"),
    (ResNetRS420, "ResNetRS420"),
):
    _builder.__doc__ = BASE_DOCSTRING.format(name=_display_name)
# Do not leak the loop variables into the module namespace.
del _builder, _display_name