Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/applications/resnet_rs.py: 26%

209 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16 

17"""ResNet-RS models for Keras. 

18 

19Reference: 

20- [Revisiting ResNets: Improved Training and Scaling Strategies]( 

21 https://arxiv.org/pdf/2103.07579.pdf) 

22""" 

23import sys 

24from typing import Callable 

25from typing import Dict 

26from typing import List 

27from typing import Union 

28 

29import tensorflow.compat.v2 as tf 

30 

31from keras.src import backend 

32from keras.src import layers 

33from keras.src.applications import imagenet_utils 

34from keras.src.engine import training 

35from keras.src.utils import data_utils 

36from keras.src.utils import layer_utils 

37 

38# isort: off 

39from tensorflow.python.util.tf_export import keras_export 

40 

# Remote directory holding the pretrained ResNet-RS weight files; a file name
# from WEIGHT_HASHES is appended to it to form the download URL.
BASE_WEIGHTS_URL = (
    "https://storage.googleapis.com/"
    "tensorflow/keras-applications/resnet_rs/"
)

44 

# Expected hash (32-character hex digest) for every downloadable weight file.
# Keys follow "resnet-rs-<depth>-i<input size>[_notop].h5"; the "_notop"
# files are the ones fetched when the classification head is not included.
WEIGHT_HASHES = {
    # depth 101
    "resnet-rs-101-i160.h5": "544b3434d00efc199d66e9058c7f3379",
    "resnet-rs-101-i160_notop.h5": "82d5b90c5ce9d710da639d6216d0f979",
    "resnet-rs-101-i192.h5": "eb285be29ab42cf4835ff20a5e3b5d23",
    "resnet-rs-101-i192_notop.h5": "f9a0f6b85faa9c3db2b6e233c4eebb5b",
    # depth 152
    "resnet-rs-152-i192.h5": "8d72a301ed8a6f11a47c4ced4396e338",
    "resnet-rs-152-i192_notop.h5": "5fbf7ac2155cb4d5a6180ee9e3aa8704",
    "resnet-rs-152-i224.h5": "31a46a92ab21b84193d0d71dd8c3d03b",
    "resnet-rs-152-i224_notop.h5": "dc8b2cba2005552eafa3167f00dc2133",
    "resnet-rs-152-i256.h5": "ba6271b99bdeb4e7a9b15c05964ef4ad",
    "resnet-rs-152-i256_notop.h5": "fa79794252dbe47c89130f65349d654a",
    # depth 200
    "resnet-rs-200-i256.h5": "a76930b741884e09ce90fa7450747d5f",
    "resnet-rs-200-i256_notop.h5": "bbdb3994718dfc0d1cd45d7eff3f3d9c",
    # depth 270
    "resnet-rs-270-i256.h5": "20d575825ba26176b03cb51012a367a8",
    "resnet-rs-270-i256_notop.h5": "2c42ecb22e35f3e23d2f70babce0a2aa",
    # depth 350
    "resnet-rs-350-i256.h5": "f4a039dc3c421321b7fc240494574a68",
    "resnet-rs-350-i256_notop.h5": "6e44b55025bbdff8f51692a023143d66",
    "resnet-rs-350-i320.h5": "7ccb858cc738305e8ceb3c0140bee393",
    "resnet-rs-350-i320_notop.h5": "ab0c1f9079d2f85a9facbd2c88aa6079",
    # depth 420
    "resnet-rs-420-i320.h5": "ae0eb9bed39e64fc8d7e0db4018dc7e8",
    "resnet-rs-420-i320_notop.h5": "fe6217c32be8305b1889657172b98884",
    # depth 50
    "resnet-rs-50-i160.h5": "69d9d925319f00a8bdd4af23c04e4102",
    "resnet-rs-50-i160_notop.h5": "90daa68cd26c95aa6c5d25451e095529",
}

69 

# For each supported network depth, the input sizes for which pretrained
# ImageNet weights are listed (used to build the `imagenet-i<size>` names).
DEPTH_TO_WEIGHT_VARIANTS = {
    50: [160],
    101: [160, 192],
    152: [192, 224, 256],
    200: [256],
    270: [256],
    350: [256, 320],
    420: [320],
}

# Per-depth block-group configuration. Every depth has four groups whose
# "input_filters" are 64, 128, 256 and 512; only "num_repeats" varies, so the
# table below lists just the repeat counts per depth.
BLOCK_ARGS = {
    depth: [
        {"input_filters": width, "num_repeats": repeats}
        for width, repeats in zip((64, 128, 256, 512), repeat_counts)
    ]
    for depth, repeat_counts in (
        (50, (3, 4, 6, 3)),
        (101, (3, 4, 23, 3)),
        (152, (3, 8, 36, 3)),
        (200, (3, 24, 36, 3)),
        (270, (4, 29, 53, 4)),
        (350, (4, 36, 72, 4)),
        (420, (4, 44, 87, 4)),
    )
}

# Serialized initializer config passed as `kernel_initializer` to the Conv2D
# layers created in this module.
CONV_KERNEL_INITIALIZER = dict(
    class_name="VarianceScaling",
    config=dict(
        scale=2.0,
        mode="fan_out",
        distribution="truncated_normal",
    ),
)

131 

# Docstring template shared by the public ResNetRS50 ... ResNetRS420
# constructors; `{name}` is the only placeholder and is filled in at the end
# of the module. Only the arguments those constructors actually accept are
# documented here (the internal `ResNetRS` builder documents its own extra
# arguments).
BASE_DOCSTRING = """Instantiates the {name} architecture.

    Reference:
    [Revisiting ResNets: Improved Training and Scaling Strategies](
    https://arxiv.org/pdf/2103.07579.pdf)

    For image classification use cases, see
    [this page for detailed examples](
    https://keras.io/api/applications/#usage-examples-for-image-classification-models).

    For transfer learning use cases, make sure to read the
    [guide to transfer learning & fine-tuning](
    https://keras.io/guides/transfer_learning/).

    Note: each Keras Application expects a specific kind of input preprocessing.
    For ResNetRS, by default input preprocessing is included as a part of the
    model (as a `Rescaling` layer, followed by a `Normalization` layer for
    3-channel inputs), and thus
    `tf.keras.applications.resnet_rs.preprocess_input` is actually a
    pass-through function. In this use case, ResNetRS models expect their inputs
    to be float tensors of pixels with values in the [0-255] range.
    At the same time, preprocessing as a part of the model can be disabled by
    setting `include_preprocessing` argument to False. With preprocessing
    disabled ResNetRS models expect their inputs to be float tensors that have
    already been rescaled to the [0, 1] range and normalized with the ImageNet
    mean and standard deviation.

    Args:
        include_top: whether to include the fully-connected layer at the top of
            the network.
        weights: one of `None` (random initialization), `'imagenet'`
            (pre-training on ImageNet), or the path to the weights file to be
            loaded.  Note: one model can have multiple imagenet variants
            depending on input shape it was trained with. For input_shape
            224x224 pass `imagenet-i224` as argument. By default, highest input
            shape weights are downloaded.
        classes: optional number of classes to classify images into, only to be
            specified if `include_top` is True, and if no `weights` argument is
            specified.
        input_shape: optional shape tuple. It should have exactly 3 inputs
            channels, and width and height should be no smaller than 32.
            E.g. (200, 200, 3) would be one valid value.
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to
            use as image input for the model.
        pooling: optional pooling mode for feature extraction when `include_top`
            is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional layer.
            - `avg` means that global average pooling
                will be applied to the output of the
                last convolutional layer, and thus
                the output of the model will be a 2D tensor.
            - `max` means that global max pooling will
                be applied.
        classifier_activation: A `str` or callable. The activation function to
            use on the "top" layer. Ignored unless `include_top=True`. Set
            `classifier_activation=None` to return the logits of the "top"
            layer.
        include_preprocessing: Boolean, whether to include the preprocessing
            layer (`Rescaling`) at the bottom of the network. Defaults to
            `True`.  Note: Input image is normalized by ImageNet mean and
            standard deviation.

    Returns:
        A `keras.Model` instance.
"""

206 

207 

def Conv2DFixedPadding(filters, kernel_size, strides, name=None):
    """Return a callable that applies a bias-free Conv2D with fixed padding.

    Args:
        filters: number of output filters of the convolution.
        kernel_size: kernel size, also used to size the explicit padding.
        strides: convolution stride. When greater than 1 the input is first
            padded with `fixed_padding`; "valid" padding is used whenever the
            stride is not exactly 1, "same" otherwise.
        name: layer name; when `None` a unique `conv_<n>` name is generated.

    Returns:
        A function mapping an input tensor to the convolution output.
    """
    if name is None:
        name = "conv_{}".format(backend.get_uid("conv_"))

    def apply(inputs):
        # Strided convolutions get their padding added explicitly up front.
        conv_input = inputs
        if strides > 1:
            conv_input = fixed_padding(conv_input, kernel_size)

        padding_mode = "same" if strides == 1 else "valid"
        conv = layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding_mode,
            use_bias=False,
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            name=name,
        )
        return conv(conv_input)

    return apply

228 

229 

def STEM(
    bn_momentum: float = 0.0,
    bn_epsilon: float = 1e-5,
    activation: str = "relu",
    name=None,
):
    """Return a callable that builds the ResNet-D style stem.

    The stem is four consecutive stages, each a 3x3 `Conv2DFixedPadding`
    followed by batch normalization and an activation. The first and last
    stages use stride 2 (the last one stands in for the usual max pool).

    Args:
        bn_momentum: momentum for every BatchNormalization layer.
        bn_epsilon: epsilon for every BatchNormalization layer.
        activation: activation applied after each batch normalization.
        name: name prefix; when `None` a unique `stem_<n>` prefix is generated.

    Returns:
        A function mapping an input tensor to the stem output.
    """
    if name is None:
        name = "stem_{}".format(backend.get_uid("stem_"))

    # (filters, strides) for stages 1..4, in order.
    stage_specs = ((32, 2), (32, 1), (64, 1), (64, 2))

    def apply(inputs):
        bn_axis = 3 if backend.image_data_format() == "channels_last" else 1

        x = inputs
        for index, (stage_filters, stage_strides) in enumerate(
            stage_specs, start=1
        ):
            x = Conv2DFixedPadding(
                filters=stage_filters,
                kernel_size=3,
                strides=stage_strides,
                name=name + f"_stem_conv_{index}",
            )(x)
            x = layers.BatchNormalization(
                axis=bn_axis,
                momentum=bn_momentum,
                epsilon=bn_epsilon,
                name=name + f"_stem_batch_norm_{index}",
            )(x)
            x = layers.Activation(
                activation, name=name + f"_stem_act_{index}"
            )(x)
        return x

    return apply

294 

295 

def SE(
    in_filters: int,
    se_ratio: float = 0.25,
    expand_ratio: int = 1,
    name=None,
):
    """Return a callable that applies a Squeeze-and-Excitation gate.

    The input is globally average-pooled, reshaped to a 1x1 spatial map,
    passed through a reducing 1x1 convolution (ReLU) and an expanding 1x1
    convolution (sigmoid), and the result multiplies the original input.

    Args:
        in_filters: base filter count; the reduce layer uses
            `max(1, int(in_filters * 4 * se_ratio))` filters and the expand
            layer `4 * in_filters * expand_ratio`.
        se_ratio: reduction ratio of the squeeze step.
        expand_ratio: multiplier for the expand layer's filter count.
        name: name prefix; when `None` a unique `se_<n>` prefix is generated.

    Returns:
        A function mapping an input tensor to the gated tensor.
    """
    # The data format is read once, when the block is configured.
    bn_axis = 3 if backend.image_data_format() == "channels_last" else 1
    if name is None:
        name = "se_{}".format(backend.get_uid("se_"))

    def apply(inputs):
        squeezed = layers.GlobalAveragePooling2D(name=name + "_se_squeeze")(
            inputs
        )
        channels = squeezed.shape[-1]
        # Put the channel dimension where the data format expects it.
        se_shape = (channels, 1, 1) if bn_axis == 1 else (1, 1, channels)
        gate = layers.Reshape(se_shape, name=name + "_se_reshape")(squeezed)

        num_reduced_filters = max(1, int(in_filters * 4 * se_ratio))
        reduce_conv = layers.Conv2D(
            filters=num_reduced_filters,
            kernel_size=[1, 1],
            strides=[1, 1],
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            padding="same",
            use_bias=True,
            activation="relu",
            name=name + "_se_reduce",
        )
        gate = reduce_conv(gate)

        num_expanded_filters = 4 * in_filters * expand_ratio
        expand_conv = layers.Conv2D(
            filters=num_expanded_filters,
            kernel_size=[1, 1],
            strides=[1, 1],
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            padding="same",
            use_bias=True,
            activation="sigmoid",
            name=name + "_se_expand",
        )
        gate = expand_conv(gate)

        return layers.multiply([inputs, gate], name=name + "_se_excite")

    return apply

342 

343 

def BottleneckBlock(
    filters: int,
    strides: int,
    use_projection: bool,
    bn_momentum: float = 0.0,
    bn_epsilon: float = 1e-5,
    activation: str = "relu",
    se_ratio: float = 0.25,
    survival_probability: float = 0.8,
    name=None,
):
    """Return a callable that builds one bottleneck residual block.

    The main path is 1x1 conv -> 3x3 conv (carrying `strides`) -> 1x1 conv
    with `filters * 4` outputs, each followed by batch normalization, then an
    optional SE gate and an optional Dropout. The result is added to the
    shortcut and passed through a final activation.

    Args:
        filters: filter count of the first two convolutions.
        strides: stride of the 3x3 convolution and of the shortcut projection.
        use_projection: whether the shortcut is projected to `filters * 4`
            channels (average-pooled first when `strides == 2`).
        bn_momentum: momentum for every BatchNormalization layer.
        bn_epsilon: epsilon for every BatchNormalization layer.
        activation: activation used inside the block and on its output.
        se_ratio: SE ratio; the SE gate is applied only when `0 < se_ratio < 1`.
        survival_probability: value handed to the `Dropout` layer; the layer is
            skipped when it is falsy.
        name: name prefix; when `None` a unique `block_0_<n>` prefix is
            generated.

    Returns:
        A function mapping an input tensor to the block output.
    """
    if name is None:
        name = "block_0_{}".format(backend.get_uid("block_0_"))

    def apply(inputs):
        bn_axis = 3 if backend.image_data_format() == "channels_last" else 1

        def batch_norm(suffix):
            # All batch-norm layers of the block share the same settings.
            return layers.BatchNormalization(
                axis=bn_axis,
                momentum=bn_momentum,
                epsilon=bn_epsilon,
                name=name + suffix,
            )

        # Shortcut branch.
        shortcut = inputs
        if use_projection:
            filters_out = filters * 4
            if strides == 2:
                # Downsample by pooling, then project with a stride-1 conv.
                projection_input = layers.AveragePooling2D(
                    pool_size=(2, 2),
                    strides=(2, 2),
                    padding="same",
                    name=name + "_projection_pooling",
                )(inputs)
                projection_strides = 1
            else:
                projection_input = inputs
                projection_strides = strides

            shortcut = Conv2DFixedPadding(
                filters=filters_out,
                kernel_size=1,
                strides=projection_strides,
                name=name + "_projection_conv",
            )(projection_input)
            shortcut = batch_norm("_projection_batch_norm")(shortcut)

        # Main branch, first conv layer.
        x = Conv2DFixedPadding(
            filters=filters, kernel_size=1, strides=1, name=name + "_conv_1"
        )(inputs)
        # NOTE: unlike its siblings this suffix has no leading underscore; it
        # is reproduced as-is because changing it would rename the layer.
        x = batch_norm("batch_norm_1")(x)
        x = layers.Activation(activation, name=name + "_act_1")(x)

        # Second conv layer (the only one that carries the stride).
        x = Conv2DFixedPadding(
            filters=filters,
            kernel_size=3,
            strides=strides,
            name=name + "_conv_2",
        )(x)
        x = batch_norm("_batch_norm_2")(x)
        x = layers.Activation(activation, name=name + "_act_2")(x)

        # Third conv layer expands back to four times the filter count.
        x = Conv2DFixedPadding(
            filters=filters * 4, kernel_size=1, strides=1, name=name + "_conv_3"
        )(x)
        x = batch_norm("_batch_norm_3")(x)

        if 0 < se_ratio < 1:
            x = SE(filters, se_ratio=se_ratio, name=name + "_se")(x)

        # Drop connect.
        # NOTE(review): the value is passed as Dropout's first positional
        # argument, which Keras documents as the drop rate, while the
        # parameter name suggests a keep probability; confirm the intent.
        if survival_probability:
            x = layers.Dropout(
                survival_probability,
                noise_shape=(None, 1, 1, 1),
                name=name + "_drop",
            )(x)

        merged = layers.Add()([x, shortcut])
        return layers.Activation(activation, name=name + "_output_act")(merged)

    return apply

449 

450 

def BlockGroup(
    filters,
    strides,
    num_repeats,
    se_ratio: float = 0.25,
    bn_epsilon: float = 1e-5,
    bn_momentum: float = 0.0,
    activation: str = "relu",
    survival_probability: float = 0.8,
    name=None,
):
    """Return a callable that stacks the bottleneck blocks of one group.

    The first block is always created; it uses a projection shortcut and the
    given `strides`. Blocks 1 .. `num_repeats - 1` use stride 1 and no
    projection.

    Args:
        filters: filter count passed to every block.
        strides: stride of the first block.
        num_repeats: total number of blocks in the group.
        se_ratio: SE ratio passed to every block.
        bn_epsilon: batch-norm epsilon passed to every block.
        bn_momentum: batch-norm momentum passed to every block.
        activation: activation passed to every block.
        survival_probability: drop-connect value passed to every block.
        name: name prefix; when `None` a unique `block_group_<n>` prefix is
            generated.

    Returns:
        A function mapping an input tensor to the group output.
    """
    if name is None:
        name = "block_group_{}".format(backend.get_uid("block_group_"))

    # Settings that are identical for every block in the group.
    shared_kwargs = dict(
        filters=filters,
        se_ratio=se_ratio,
        bn_epsilon=bn_epsilon,
        bn_momentum=bn_momentum,
        activation=activation,
        survival_probability=survival_probability,
    )

    def apply(inputs):
        # Only the first block per block_group uses projection shortcut and
        # strides.
        x = BottleneckBlock(
            strides=strides,
            use_projection=True,
            name=name + "_block_0_",
            **shared_kwargs,
        )(inputs)

        for block_index in range(1, num_repeats):
            x = BottleneckBlock(
                strides=1,
                use_projection=False,
                name=name + f"_block_{block_index}_",
                **shared_kwargs,
            )(x)
        return x

    return apply

497 

498 

def get_survival_probability(init_rate, block_num, total_blocks):
    """Scale `init_rate` linearly by the block's position.

    Args:
        init_rate: base rate to scale.
        block_num: index of the current block group.
        total_blocks: divisor used to normalize `block_num`.

    Returns:
        `init_rate * block_num / total_blocks`, computed in floating point.
    """
    scaled_rate = init_rate * float(block_num)
    return scaled_rate / total_blocks

502 

503 

def allow_bigger_recursion(target_limit: int):
    """Raise the interpreter recursion limit to at least `target_limit`.

    The limit is only ever increased; if it is already at or above
    `target_limit` nothing changes.

    Args:
        target_limit: minimum recursion limit required by the caller.
    """
    if sys.getrecursionlimit() >= target_limit:
        return
    sys.setrecursionlimit(target_limit)

509 

510 

def fixed_padding(inputs, kernel_size):
    """Zero-pad the spatial dimensions by an amount set only by the kernel.

    A total of `kernel_size - 1` is added per spatial dimension, split as
    evenly as possible with any extra unit going to the end.

    Args:
        inputs: tensor to pad.
        kernel_size: kernel size of the convolution that follows.

    Returns:
        The padded tensor.
    """
    total = kernel_size - 1
    leading = total // 2
    trailing = total - leading
    per_dimension = (leading, trailing)

    # Use ZeroPadding as to avoid TFOpLambda layer
    pad_layer = layers.ZeroPadding2D(padding=(per_dimension, per_dimension))
    return pad_layer(inputs)

524 

525 

def ResNetRS(
    depth: int,
    input_shape=None,
    bn_momentum=0.0,
    bn_epsilon=1e-5,
    activation: str = "relu",
    se_ratio=0.25,
    dropout_rate=0.25,
    drop_connect_rate=0.2,
    include_top=True,
    block_args: List[Dict[str, int]] = None,
    model_name="resnet-rs",
    pooling=None,
    weights="imagenet",
    input_tensor=None,
    classes=1000,
    classifier_activation: Union[str, Callable] = "softmax",
    include_preprocessing=True,
):
    """Build a ResNet-RS model from the given parameters.

    Args:
        depth: depth of the network; must be a key of
            `DEPTH_TO_WEIGHT_VARIANTS` (and of `BLOCK_ARGS` when `block_args`
            is not given).
        input_shape: optional shape tuple. It should have exactly 3 input
            channels, and width and height should be no smaller than 32,
            e.g. `(200, 200, 3)`.
        bn_momentum: momentum for the BatchNormalization layers.
        bn_epsilon: epsilon for the BatchNormalization layers.
        activation: activation function used throughout the network.
        se_ratio: Squeeze-and-Excitation ratio.
        dropout_rate: dropout rate before the final classifier layer.
        drop_connect_rate: base rate for the dropout at skip connections.
        include_top: whether to include the fully-connected layer at the top
            of the network.
        block_args: list of dicts with `input_filters` and `num_repeats`, one
            per block group. Defaults to `BLOCK_ARGS[depth]`.
        model_name: name of the model, also used to build the weight file
            name.
        pooling: optional pooling mode when `include_top` is `False`:
            `None` keeps the 4D output of the last convolutional block,
            `avg` applies global average pooling, `max` applies global max
            pooling.
        weights: `None` (random initialization), `'imagenet'` (the variant
            trained with the largest input size available for `depth`),
            `'imagenet-i<size>'` for a specific variant, or the path to a
            weights file.
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
            to use as image input for the model.
        classes: number of classes; must be 1000 when ImageNet weights are
            used together with `include_top`.
        classifier_activation: a `str` or callable used on the "top" layer.
            Ignored unless `include_top=True`. `None` returns the logits.
        include_preprocessing: whether to prepend the in-model preprocessing
            (a `Rescaling` layer, plus ImageNet mean/variance `Normalization`
            for 3-channel inputs).

    Returns:
        A `tf.keras.Model` instance.

    Raises:
        ValueError: in case of invalid argument for `weights`, or invalid
            input shape.
        ValueError: if `classifier_activation` is not `softmax` or `None`
            when using a pretrained top layer.
    """
    # Validate parameters
    variant_sizes = DEPTH_TO_WEIGHT_VARIANTS[depth]
    if weights == "imagenet":
        # `imagenet` argument without explicit weights input size.
        # Picking weights trained with biggest available shape
        weights = f"{weights}-i{max(variant_sizes)}"

    weights_allow_list = [f"imagenet-i{size}" for size in variant_sizes]
    is_listed_or_none = weights in {*weights_allow_list, None}
    if not is_listed_or_none and not tf.io.gfile.exists(weights):
        raise ValueError(
            "The `weights` argument should be either "
            "`None` (random initialization), `'imagenet'` "
            "(pre-training on ImageNet, with highest available input shape),"
            " or the path to the weights file to be loaded. "
            f"For ResNetRS{depth} the following weight variants are "
            f"available {weights_allow_list} (default=highest)."
            f" Received weights={weights}"
        )

    # `weights` is not reassigned below, so this answer stays valid.
    uses_listed_weights = weights in weights_allow_list
    if uses_listed_weights and include_top and classes != 1000:
        raise ValueError(
            "If using `weights` as `'imagenet'` or any "
            f"of {weights_allow_list} "
            "with `include_top` as true, `classes` should be 1000. "
            f"Received classes={classes}"
        )

    input_shape = imagenet_utils.obtain_input_shape(
        input_shape,
        default_size=224,
        min_size=32,
        data_format=backend.image_data_format(),
        require_flatten=include_top,
        weights=weights,
    )

    # Define input tensor
    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    elif backend.is_keras_tensor(input_tensor):
        img_input = input_tensor
    else:
        img_input = layers.Input(tensor=input_tensor, shape=input_shape)

    bn_axis = 3 if backend.image_data_format() == "channels_last" else 1

    x = img_input

    if include_preprocessing:
        # `input_shape` has no batch dimension, hence the `- 1`.
        num_channels = input_shape[bn_axis - 1]
        x = layers.Rescaling(scale=1.0 / 255)(x)
        if num_channels == 3:
            x = layers.Normalization(
                mean=[0.485, 0.456, 0.406],
                variance=[0.229**2, 0.224**2, 0.225**2],
                axis=bn_axis,
            )(x)

    # Build stem
    x = STEM(
        bn_momentum=bn_momentum, bn_epsilon=bn_epsilon, activation=activation
    )(x)

    # Build blocks
    if block_args is None:
        block_args = BLOCK_ARGS[depth]

    total_blocks = len(block_args) + 1
    # Block groups are numbered from 2; only the first keeps stride 1.
    for group_number, group_args in enumerate(block_args, start=2):
        survival_probability = get_survival_probability(
            init_rate=drop_connect_rate,
            block_num=group_number,
            total_blocks=total_blocks,
        )

        x = BlockGroup(
            filters=group_args["input_filters"],
            activation=activation,
            strides=(1 if group_number == 2 else 2),
            num_repeats=group_args["num_repeats"],
            se_ratio=se_ratio,
            bn_momentum=bn_momentum,
            bn_epsilon=bn_epsilon,
            survival_probability=survival_probability,
            name=f"BlockGroup{group_number}_",
        )(x)

    # Build head:
    if include_top:
        x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
        if dropout_rate > 0:
            x = layers.Dropout(dropout_rate, name="top_dropout")(x)

        imagenet_utils.validate_activation(classifier_activation, weights)
        x = layers.Dense(
            classes, activation=classifier_activation, name="predictions"
        )(x)
    elif pooling == "avg":
        x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
    elif pooling == "max":
        x = layers.GlobalMaxPooling2D(name="max_pool")(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is None:
        inputs = img_input
    else:
        inputs = layer_utils.get_source_inputs(input_tensor)

    # Create model.
    model = training.Model(inputs, x, name=model_name)

    # Download weights
    if uses_listed_weights:
        size_tag = weights.split("-")[-1]  # e. g. "i160"
        weights_name = f"{model_name}-{size_tag}"
        if not include_top:
            weights_name += "_notop"

        filename = f"{weights_name}.h5"
        weights_path = data_utils.get_file(
            fname=filename,
            origin=BASE_WEIGHTS_URL + filename,
            cache_subdir="models",
            file_hash=WEIGHT_HASHES[filename],
        )
        model.load_weights(weights_path)
    elif weights is not None:
        model.load_weights(weights)

    return model

734 

735 

@keras_export(
    "keras.applications.resnet_rs.ResNetRS50", "keras.applications.ResNetRS50"
)
def ResNetRS50(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS50 model (replaced by BASE_DOCSTRING at import)."""
    # Settings fixed for this depth; everything else is forwarded unchanged.
    depth_settings = dict(
        depth=50,
        model_name="resnet-rs-50",
        dropout_rate=0.25,
        drop_connect_rate=0.0,
    )
    return ResNetRS(
        input_shape=input_shape,
        input_tensor=input_tensor,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_settings,
    )

764 

765 

@keras_export(
    "keras.applications.resnet_rs.ResNetRS101", "keras.applications.ResNetRS101"
)
def ResNetRS101(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS101 model (replaced by BASE_DOCSTRING at import)."""
    # Settings fixed for this depth; everything else is forwarded unchanged.
    depth_settings = dict(
        depth=101,
        model_name="resnet-rs-101",
        dropout_rate=0.25,
        drop_connect_rate=0.0,
    )
    return ResNetRS(
        input_shape=input_shape,
        input_tensor=input_tensor,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_settings,
    )

794 

795 

@keras_export(
    "keras.applications.resnet_rs.ResNetRS152", "keras.applications.ResNetRS152"
)
def ResNetRS152(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS152 model (replaced by BASE_DOCSTRING at import)."""
    # Settings fixed for this depth; everything else is forwarded unchanged.
    depth_settings = dict(
        depth=152,
        model_name="resnet-rs-152",
        dropout_rate=0.25,
        drop_connect_rate=0.0,
    )
    return ResNetRS(
        input_shape=input_shape,
        input_tensor=input_tensor,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_settings,
    )

824 

825 

@keras_export(
    "keras.applications.resnet_rs.ResNetRS200", "keras.applications.ResNetRS200"
)
def ResNetRS200(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS200 model (replaced by BASE_DOCSTRING at import)."""
    # Settings fixed for this depth; everything else is forwarded unchanged.
    depth_settings = dict(
        depth=200,
        model_name="resnet-rs-200",
        dropout_rate=0.25,
        drop_connect_rate=0.1,
    )
    return ResNetRS(
        input_shape=input_shape,
        input_tensor=input_tensor,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_settings,
    )

854 

855 

@keras_export(
    "keras.applications.resnet_rs.ResNetRS270", "keras.applications.ResNetRS270"
)
def ResNetRS270(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS-270 model (replaced by BASE_DOCSTRING at import)."""
    # The recursion limit is raised to at least 1300 before building.
    allow_bigger_recursion(1300)

    # Settings fixed for this depth; everything else is forwarded unchanged.
    depth_settings = dict(
        depth=270,
        model_name="resnet-rs-270",
        dropout_rate=0.25,
        drop_connect_rate=0.1,
    )
    return ResNetRS(
        input_shape=input_shape,
        input_tensor=input_tensor,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_settings,
    )

885 

886 

@keras_export(
    "keras.applications.resnet_rs.ResNetRS350", "keras.applications.ResNetRS350"
)
def ResNetRS350(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS350 model (replaced by BASE_DOCSTRING at import)."""
    # The recursion limit is raised to at least 1500 before building.
    allow_bigger_recursion(1500)

    # Settings fixed for this depth; everything else is forwarded unchanged.
    depth_settings = dict(
        depth=350,
        model_name="resnet-rs-350",
        dropout_rate=0.4,
        drop_connect_rate=0.1,
    )
    return ResNetRS(
        input_shape=input_shape,
        input_tensor=input_tensor,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_settings,
    )

916 

917 

@keras_export(
    "keras.applications.resnet_rs.ResNetRS420", "keras.applications.ResNetRS420"
)
def ResNetRS420(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS420 model (replaced by BASE_DOCSTRING at import)."""
    # The recursion limit is raised to at least 1800 before building.
    allow_bigger_recursion(1800)

    # Settings fixed for this depth; everything else is forwarded unchanged.
    depth_settings = dict(
        depth=420,
        model_name="resnet-rs-420",
        dropout_rate=0.4,
        drop_connect_rate=0.1,
    )
    return ResNetRS(
        input_shape=input_shape,
        input_tensor=input_tensor,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_settings,
    )

947 

948 

949@keras_export("keras.applications.resnet_rs.preprocess_input") 

def preprocess_input(x, data_format=None):
    """A placeholder method for backward compatibility.

    The preprocessing logic has been included in the ResNetRS model
    implementation, so users are no longer required to call this method to
    normalize the input data. This method does nothing and is only kept as a
    placeholder to align the API surface between old and new versions of the
    model.

    Args:
        x: A floating point `numpy.array` or a `tf.Tensor`.
        data_format: Optional data format of the image tensor/array. Accepted
            for API compatibility only; it is ignored by this function.

    Returns:
        Unchanged `numpy.array` or `tf.Tensor`.
    """
    # Pass-through: the input is returned as the very same object.
    return x

970 

971 

@keras_export("keras.applications.resnet_rs.decode_predictions")
def decode_predictions(preds, top=5):
    # Thin delegate to the shared ImageNet helper; its docstring is copied
    # onto this function right after the definition.
    decoded = imagenet_utils.decode_predictions(preds, top=top)
    return decoded

975 

976 

# The local wrapper reuses the documentation of the helper it delegates to.
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

# Render the shared docstring template once per public constructor.
for _doc_name, _doc_target in (
    ("ResNetRS50", ResNetRS50),
    ("ResNetRS101", ResNetRS101),
    ("ResNetRS152", ResNetRS152),
    ("ResNetRS200", ResNetRS200),
    ("ResNetRS270", ResNetRS270),
    ("ResNetRS350", ResNetRS350),
    ("ResNetRS420", ResNetRS420),
):
    _doc_target.__doc__ = BASE_DOCSTRING.format(name=_doc_name)
# Keep the module namespace free of the loop variables.
del _doc_name, _doc_target

986