Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/optimizer.py: 22%

410 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Base class of optimizer.""" 

16 

17import abc 

18import platform 

19import re 

20 

21import tensorflow.compat.v2 as tf 

22from absl import logging 

23 

24from keras.src import backend 

25from keras.src import initializers 

26from keras.src.dtensor import utils as dtensor_utils 

27from keras.src.optimizers import utils as optimizer_utils 

28from keras.src.optimizers.schedules import learning_rate_schedule 

29from keras.src.utils import tf_utils 

30 

31# isort: off 

32from tensorflow.python.util.tf_export import keras_export 

33from tensorflow.tools.docs import doc_controls 

34 

35 

36class _BaseOptimizer(tf.__internal__.tracking.AutoTrackable): 

37 """Optimizer base class, which only supports non-distribute use case.""" 

38 

39 def __init__( 

40 self, 

41 name, 

42 weight_decay=None, 

43 clipnorm=None, 

44 clipvalue=None, 

45 global_clipnorm=None, 

46 use_ema=False, 

47 ema_momentum=0.99, 

48 ema_overwrite_frequency=None, 

49 jit_compile=True, 

50 **kwargs, 

51 ): 

52 self.name = name 

53 self.weight_decay = weight_decay 

54 self.clipnorm = clipnorm 

55 self.global_clipnorm = global_clipnorm 

56 self.clipvalue = clipvalue 

57 self.use_ema = use_ema 

58 # Optimizer only benefits from XLA when training on GPU. So if no 

59 # GPU is found, we turn off XLA. 

60 if ( 

61 jit_compile 

62 and tf_utils.can_jit_compile() 

63 and tf.config.list_physical_devices("GPU") 

64 ): 

65 self.jit_compile = True 

66 else: 

67 self.jit_compile = False 

68 

69 if platform.system() == "Darwin" and platform.processor() == "arm": 

70 logging.warning( 

71 "At this time, the v2.11+ optimizer " 

72 f"`tf.keras.optimizers.{self.__class__.__name__}` runs slowly " 

73 "on M1/M2 Macs, please use the legacy Keras optimizer " 

74 "instead, located at " 

75 f"`tf.keras.optimizers.legacy.{self.__class__.__name__}`." 

76 ) 

77 

78 if use_ema: 

79 # Verify the arguments related to EMA. 

80 if ema_momentum > 1 or ema_momentum < 0: 

81 raise ValueError( 

82 "`ema_momentum` must be in the range [0, 1]. " 

83 f"Received: ema_momentum={ema_momentum}" 

84 ) 

85 if ema_overwrite_frequency and ( 

86 not isinstance(ema_overwrite_frequency, int) 

87 or ema_overwrite_frequency < 1 

88 ): 

89 raise ValueError( 

90 "`ema_overwrite_frequency` must be an integer > 1 or None. " 

91 "Received: ema_overwrite_frequency=" 

92 f"{ema_overwrite_frequency}" 

93 ) 

94 self.ema_momentum = ema_momentum 

95 self.ema_overwrite_frequency = ema_overwrite_frequency 

96 

97 if self.clipnorm is not None and self.global_clipnorm is not None: 

98 raise ValueError( 

99 "At most one of `clipnorm` and `global_clipnorm` can " 

100 f"be set. Received: clipnorm={self.clipnorm}, " 

101 f"global_clipnorm={self.global_clipnorm}." 

102 ) 

103 

104 self._variables = [] 

105 self._create_iteration_variable() 

106 self._process_kwargs(kwargs) 

107 

108 def _create_iteration_variable(self): 

109 """Create the iterations counter variable.""" 

110 with tf.init_scope(): 

111 # Lift the variable creation to init scope to avoid environment 

112 # issue. 

113 self._iterations = tf.Variable( 

114 0, name="iteration", dtype=tf.int64, trainable=False 

115 ) 

116 self._variables.append(self._iterations) 

117 

118 def _process_kwargs(self, kwargs): 

119 # Remove the `is_legacy_optimizer` arg, which is for serialization only. 

120 kwargs.pop("is_legacy_optimizer", None) 

121 lr = kwargs.pop("lr", None) 

122 if lr: 

123 logging.warning( 

124 "`lr` is deprecated in Keras optimizer, please use " 

125 "`learning_rate` or use the legacy optimizer, e.g.," 

126 f"tf.keras.optimizers.legacy.{self.__class__.__name__}." 

127 ) 

128 legacy_kwargs = { 

129 "decay", 

130 "gradient_aggregator", 

131 "gradient_transformers", 

132 } 

133 for k in kwargs: 

134 if k in legacy_kwargs: 

135 raise ValueError( 

136 f"{k} is deprecated in the new Keras optimizer, please " 

137 "check the docstring for valid arguments, or use the " 

138 "legacy optimizer, e.g., " 

139 f"tf.keras.optimizers.legacy.{self.__class__.__name__}." 

140 ) 

141 else: 

142 raise TypeError( 

143 f"{k} is not a valid argument, kwargs should be empty " 

144 " for `optimizer_experimental.Optimizer`." 

145 ) 

146 

147 def _create_or_restore_slot_variable(self, **kwargs): 

148 raise ValueError( 

149 "You are trying to restore a checkpoint from a legacy Keras " 

150 "optimizer into a v2.11+ Optimizer, which can cause " 

151 "errors. Please update the optimizer referenced in your code " 

152 "to be an instance of " 

153 "`tf.keras.optimizers.legacy.Optimizer`, e.g.: " 

154 f"`tf.keras.optimizers.legacy.{self.__class__.__name__}`." 

155 ) 

156 

157 def _var_key(self, variable): 

158 """Get a unique identifier of the given variable.""" 

159 # Get the distributed variable if it exists. 

160 # TODO(b/199214315): replace _unique_id with ref() after fixing ref() 

161 # issues on AggregatingVariable. 

162 return variable._unique_id 

163 

164 def _deduplicate_sparse_grad(self, grads): 

165 """Deduplicate sparse gradient. 

166 

167 For sparse gradients, i.e., gradient is of type `tf.IndexedSlices`, 

168 it is possible that `gradient.indices` has duplicated indices. 

169 This function adds up values for the duplicated indices, and returns 

170 a `tf.IndexedSlices` with indices of unique values. 

171 """ 

172 processed_grads = [] 

173 for grad in grads: 

174 if isinstance(grad, tf.IndexedSlices): 

175 values = grad.values 

176 indices = grad.indices 

177 unique_indices, new_index_positions = tf.unique(indices) 

178 summed_values = tf.math.unsorted_segment_sum( 

179 values, new_index_positions, tf.shape(unique_indices)[0] 

180 ) 

181 processed_grads.append( 

182 tf.IndexedSlices( 

183 summed_values, unique_indices, grad.dense_shape 

184 ) 

185 ) 

186 else: 

187 processed_grads.append(grad) 

188 

189 return processed_grads 

190 
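    # Illustrative sketch (editor-added, not part of the original source): when
    # a sparse gradient carries duplicated indices, the values at the repeated
    # index are summed. The numbers below are hypothetical.
    #
    #   grad = tf.IndexedSlices(
    #       values=tf.constant([[1.0], [2.0], [3.0]]),
    #       indices=tf.constant([0, 0, 2]),
    #       dense_shape=tf.constant([4, 1]),
    #   )
    #   [deduped] = optimizer._deduplicate_sparse_grad([grad])
    #   # deduped.indices -> [0, 2], deduped.values -> [[3.0], [3.0]]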

191 @abc.abstractmethod 

192 def update_step(self, gradient, variable): 

193 """Function to update variable value based on given gradients. 

194 

195 This method must be implemented in customized optimizers. 

196 

197 Args: 

198 gradient: backpropagated gradient of the given variable. 

199 variable: variable whose value needs to be updated. 

200 

201 Returns: 

202 An `Operation` that applies the specified gradients. 

203 

204 """ 

205 raise NotImplementedError 

206 

207 @tf.function(jit_compile=True) 

208 def _update_step_xla(self, gradient, variable, key): 

209 """A wrapper of `update_step` to enable XLA acceleration. 

210 

211 Due to the `tf.function` tracing mechanism, for (gradient, variable) pairs

212 of the same shape and dtype, the execution graph always invokes the first

213 pair it has seen. Thus, we need a `key` argument to make each (gradient,

214 variable) pair unique. In addition, XLA cannot understand string input,

215 so the key is an integer.

216 

217 Args: 

218 gradient: backpropagated gradient of the given variable. 

219 variable: variable whose value needs to be updated. 

220 key (int): a unique key that identifies the variable. 

221 

222 Returns: 

223 An `Operation` that applies the specified gradients. 

224 """ 

225 return self._update_step(gradient, variable) 

226 

227 def _update_step(self, gradient, variable): 

228 if getattr(variable, "_unique_id", None) is None: 

229 # Variable has no `_unique_id` if called during `model.save()`, in 

230 # which case we do not want to update the variable. 

231 return 

232 if self._var_key(variable) not in self._index_dict: 

233 raise KeyError( 

234 f"The optimizer cannot recognize variable {variable.name}. " 

235 "This usually means you are trying to call the optimizer to " 

236 "update different parts of the model separately. Please call " 

237 "`optimizer.build(variables)` with the full list of trainable " 

238 "variables before the training loop or use legacy optimizer " 

239 f"`tf.keras.optimizers.legacy.{self.__class__.__name__}." 

240 ) 

241 self.update_step(gradient, variable) 

242 

243 def compute_gradients(self, loss, var_list, tape=None): 

244 """Compute gradients of loss on trainable variables. 

245 

246 Args: 

247 loss: `Tensor` or callable. If a callable, `loss` should take no 

248 arguments and return the value to minimize. 

249 var_list: list or tuple of `Variable` objects to update to minimize 

250 `loss`, or a callable returning the list or tuple of `Variable` 

251 objects. Use callable when the variable list would otherwise be 

252 incomplete before `minimize` since the variables are created at the 

253 first time `loss` is called. 

254 tape: (Optional) `tf.GradientTape`. If `loss` is provided as a 

255 `Tensor`, the tape that computed the `loss` must be provided. 

256 

257 Returns: 

258 A list of (gradient, variable) pairs. Variable is always present, but 

259 gradient can be `None`. 

260 """ 

261 if not callable(loss) and tape is None: 

262 raise ValueError( 

263 "`tape` is required when a `Tensor` loss is passed. " 

264 f"Received: loss={loss}, tape={tape}." 

265 ) 

266 if tape is None: 

267 tape = tf.GradientTape() 

268 if callable(loss): 

269 with tape: 

270 if not callable(var_list): 

271 tape.watch(var_list) 

272 loss = loss() 

273 if callable(var_list): 

274 var_list = var_list() 

275 

276 grads = tape.gradient(loss, var_list) 

277 return list(zip(grads, var_list)) 

278 
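    # Usage sketch (editor-added; `opt` and `var` are hypothetical names): with
    # a callable loss no tape is needed, while a `Tensor` loss must come with
    # the tape that produced it.
    #
    #   var = tf.Variable(2.0)
    #   grads_and_vars = opt.compute_gradients(lambda: var ** 2, [var])
    #
    #   with tf.GradientTape() as tape:
    #       loss = var ** 2
    #   grads_and_vars = opt.compute_gradients(loss, [var], tape=tape)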

279 def _clip_gradients(self, grads): 

280 clipped_grads = [] 

281 if self.clipnorm and self.clipnorm > 0: 

282 for g in grads: 

283 if g is None: 

284 clipped_grads.append(g) 

285 else: 

286 clipped_grads.append(tf.clip_by_norm(g, self.clipnorm)) 

287 return clipped_grads 

288 

289 if self.global_clipnorm and self.global_clipnorm > 0: 

290 return tf.clip_by_global_norm(grads, self.global_clipnorm)[0] 

291 

292 if self.clipvalue and self.clipvalue > 0: 

293 for g in grads: 

294 if g is None: 

295 clipped_grads.append(g) 

296 else: 

297 clipped_grads.append( 

298 tf.clip_by_value( 

299 g, 

300 clip_value_min=-self.clipvalue, 

301 clip_value_max=self.clipvalue, 

302 ) 

303 ) 

304 return clipped_grads 

305 

306 return grads 

307 
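    # Note (editor-added): only the first configured clipping mode is applied,
    # in the order checked above (`clipnorm`, then `global_clipnorm`, then
    # `clipvalue`). A small worked example of per-gradient norm clipping with
    # `clipnorm=1.0`:
    #
    #   tf.clip_by_norm(tf.constant([3.0, 4.0]), 1.0)  # -> [0.6, 0.8]
    #
    # The gradient's L2 norm is 5.0, so every entry is scaled by 1/5.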

308 @property 

309 def iterations(self): 

310 """The number of training steps this `optimizer` has run. 

311 

312 By default, iterations would be incremented by one every time 

313 `apply_gradients()` is called. 

314 """ 

315 return self._iterations 

316 

317 @iterations.setter 

318 def iterations(self, variable): 

319 if getattr(self, "_built", False): 

320 raise RuntimeError( 

321 "Cannot set `iterations` to a new Variable after " 

322 "the Optimizer weights have been created. Here it is " 

323 f"attempting to set `iterations` to {variable}." 

324 "Usually this means you are trying to set `iterations`" 

325 " after calling `apply_gradients()`. Please set " 

326 "`iterations` before calling `apply_gradients()`." 

327 ) 

328 self._iterations = variable 

329 

330 @property 

331 def learning_rate(self): 

332 if not hasattr(self, "_learning_rate") or self._learning_rate is None: 

333 raise ValueError( 

334 "Missing learning rate, please set self.learning_rate at" 

335 " optimizer creation time." 

336 ) 

337 lr = self._learning_rate 

338 if isinstance(lr, learning_rate_schedule.LearningRateSchedule): 

339 # If the optimizer takes in LearningRateSchedule, then each call to 

340 # learning_rate would return `self._current_learning_rate`, which is 

341 # updated at each call to `apply_gradients`. 

342 return self._current_learning_rate 

343 return lr 

344 

345 @learning_rate.setter 

346 def learning_rate(self, learning_rate): 

347 if isinstance( 

348 learning_rate, learning_rate_schedule.LearningRateSchedule 

349 ): 

350 self._learning_rate = learning_rate 

351 else: 

352 if isinstance( 

353 self._learning_rate, learning_rate_schedule.LearningRateSchedule 

354 ): 

355 raise TypeError( 

356 "This optimizer was created with a `LearningRateSchedule`" 

357 " object as its `learning_rate` constructor argument, " 

358 "hence its learning rate is not settable. If you need the" 

359 " learning rate to be settable, you should instantiate " 

360 "the optimizer with a float `learning_rate` argument." 

361 ) 

362 self._learning_rate.assign(learning_rate) 

363 

364 @property 

365 @doc_controls.do_not_generate_docs 

366 def lr(self): 

367 """Alias of `learning_rate()`. 

368 

369 `lr` is heavily used in workflows based on `optimizer_v2.OptimizerV2`,

370 so we keep it for backward compatibility.

371 """ 

372 return self.learning_rate 

373 

374 @lr.setter 

375 def lr(self, learning_rate): 

376 self.learning_rate = learning_rate 

377 

378 def _build_learning_rate(self, learning_rate): 

379 with tf.init_scope(): 

380 if isinstance( 

381 learning_rate, learning_rate_schedule.LearningRateSchedule 

382 ): 

383 # Create a variable to hold the current learning rate. 

384 current_learning_rate = tf.convert_to_tensor( 

385 learning_rate(self.iterations) 

386 ) 

387 self._current_learning_rate = tf.Variable( 

388 current_learning_rate, 

389 name="current_learning_rate", 

390 dtype=current_learning_rate.dtype, 

391 trainable=False, 

392 ) 

393 return learning_rate 

394 

395 return tf.Variable( 

396 learning_rate, 

397 name="learning_rate", 

398 dtype=backend.floatx(), 

399 trainable=False, 

400 ) 

401 

402 @abc.abstractmethod 

403 def build(self, var_list): 

404 """Initialize the optimizer's variables, such as momemtum variables. 

405 

406 This function has to be implemented by subclass optimizers, and subclass 

407 optimizers need to call `super().build(var_list)`. 

408 

409 Args: 

410 var_list: List of model variables to build optimizers on. For example, 

411 SGD optimizer with momentum will store one momentum variable 

412 corresponding to each model variable. 

413 """ 

414 if getattr(self, "_built", False): 

415 return 

416 self._build_index_dict(var_list) 

417 if self.use_ema: 

418 self._model_variables_moving_average = [] 

419 for var in var_list: 

420 # Make a copy of the model variables, we will use the copy to 

421 # store the moving average of model variables. 

422 self._model_variables_moving_average.append( 

423 self.add_variable_from_reference( 

424 var, "average", initial_value=var 

425 ) 

426 ) 

427 

428 def _build_index_dict(self, var_list): 

429 """Build variable to index dictionary. 

430 

431 Build a dictionary that maps variable to the index of it in the given 

432 var_list. 

433 

434 Args: 

435 var_list: List of variables to build index dict on. 

436 

437 Returns: 

438 None 

439 """ 

440 self._index_dict = {} 

441 for i, var in enumerate(var_list): 

442 var_key = self._var_key(var) 

443 self._index_dict[var_key] = i 

444 

445 def add_variable(self, shape, dtype=None, initializer="zeros", name=None): 

446 """Create an optimizer variable. 

447 

448 Args: 

449 shape: A list of integers, a tuple of integers, or a 1-D Tensor of 

450 type int32. Defaults to scalar if unspecified. 

451 dtype: The DType of the optimizer variable to be created. Defaults to 

452 `tf.keras.backend.floatx` if unspecified. 

453 initializer: string or callable. Initializer instance. 

454 name: The name of the optimizer variable to be created. 

455 

456 Returns: 

457 An optimizer variable, in the format of tf.Variable. 

458 

459 """ 

460 if isinstance(initializer, str): 

461 initializer = initializers.get(initializer) 

462 if dtype is None: 

463 dtype = backend.floatx() 

464 if shape is None: 

465 shape = [] 

466 variable = tf.Variable( 

467 initial_value=initializer(shape, dtype), name=name, trainable=False 

468 ) 

469 self._variables.append(variable) 

470 return variable 

471 
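    # Usage sketch (editor-added; the variable name is hypothetical): a scalar
    # optimizer variable initialized to zero with the default `floatx` dtype.
    #
    #   accumulator = opt.add_variable(shape=[], name="loss_scale_accum")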

472 def add_variable_from_reference( 

473 self, model_variable, variable_name, shape=None, initial_value=None 

474 ): 

475 """Create an optimizer variable from model variable. 

476 

477 Create an optimizer variable based on the information of model variable. 

478 For example, in the SGD optimizer with momentum, for each model variable, a

479 corresponding momentum variable is created of the same shape and dtype.

480 

481 Args: 

482 model_variable: tf.Variable. The corresponding model variable to the 

483 optimizer variable to be created. 

484 variable_name: String. The name prefix of the optimizer variable to be 

485 created. The created variable's name will follow the pattern

486 `{variable_name}/{model_variable.name}`, e.g., `momentum/dense_1`.

487 shape: List or Tuple, defaults to None. The shape of the optimizer 

488 variable to be created. If None, the created variable will have the 

489 same shape as `model_variable`. 

490 initial_value: A Tensor, or Python object convertible to a Tensor, 

491 defaults to None. The initial value of the optimizer variable. If

492 None, the initial value defaults to 0.

493 

494 Returns: 

495 An optimizer variable. 

496 """ 

497 if initial_value is None: 

498 if shape is None: 

499 if model_variable.shape.rank is None: 

500 # When the rank is None, we cannot get a concrete 

501 # `model_variable.shape`, we use dynamic shape. 

502 initial_value = tf.zeros_like( 

503 model_variable, dtype=model_variable.dtype 

504 ) 

505 else: 

506 # We cannot always use `zeros_like`, because in some cases

507 # the shape exists while the values don't.

508 initial_value = tf.zeros( 

509 model_variable.shape, dtype=model_variable.dtype 

510 ) 

511 else: 

512 initial_value = tf.zeros(shape, dtype=model_variable.dtype) 

513 variable = tf.Variable( 

514 initial_value=initial_value, 

515 name=f"{variable_name}/{model_variable._shared_name}", 

516 dtype=model_variable.dtype, 

517 trainable=False, 

518 ) 

519 self._variables.append(variable) 

520 return variable 

521 
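    # Usage sketch (editor-added; attribute names are hypothetical): inside a
    # subclass `build()`, create one zero-initialized slot per model variable
    # with the same shape and dtype as the model variable.
    #
    #   def build(self, var_list):
    #       super().build(var_list)
    #       self.momentums = [
    #           self.add_variable_from_reference(var, "momentum")
    #           for var in var_list
    #       ]
    #       self._built = True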

522 def minimize(self, loss, var_list, tape=None): 

523 """Minimize `loss` by updating `var_list`. 

524 

525 This method simply computes gradients using `tf.GradientTape` and calls

526 `apply_gradients()`. If you want to process the gradients before applying

527 them, call `tf.GradientTape` and `apply_gradients()` explicitly instead

528 of using this function.

529 

530 Args: 

531 loss: `Tensor` or callable. If a callable, `loss` should take no 

532 arguments and return the value to minimize. 

533 var_list: list or tuple of `Variable` objects to update to minimize 

534 `loss`, or a callable returning the list or tuple of `Variable` 

535 objects. Use callable when the variable list would otherwise be 

536 incomplete before `minimize` since the variables are created at the 

537 first time `loss` is called. 

538 tape: (Optional) `tf.GradientTape`. 

539 

540 Returns: 

541 None 

542 """ 

543 grads_and_vars = self.compute_gradients(loss, var_list, tape) 

544 self.apply_gradients(grads_and_vars) 

545 

546 def _compute_current_learning_rate(self): 

547 if isinstance( 

548 self._learning_rate, learning_rate_schedule.LearningRateSchedule 

549 ): 

550 # Compute the current learning rate at the beginning of variable 

551 # update. 

552 if hasattr(self, "_current_learning_rate"): 

553 self._current_learning_rate.assign( 

554 self._learning_rate(self.iterations) 

555 ) 

556 else: 

557 current_learning_rate = tf.convert_to_tensor( 

558 self._learning_rate(self.iterations) 

559 ) 

560 self._current_learning_rate = tf.Variable( 

561 current_learning_rate, 

562 name="current_learning_rate", 

563 dtype=current_learning_rate.dtype, 

564 trainable=False, 

565 ) 

566 

567 def exclude_from_weight_decay(self, var_list=None, var_names=None): 

568 """Exclude variables from weight decay. 

569 

570 This method must be called before the optimizer's `build` method is 

571 called. You can either pass specific variables to exclude, or pass a

572 list of strings: if any of those strings appears in a variable's

573 name, the variable is excluded.

574 

575 Args: 

576 var_list: A list of `tf.Variable`s to exclude from weight decay. 

577 var_names: A list of strings. If any string in `var_names` appears

578 in a model variable's name, then that model variable is

579 excluded from weight decay. For example, `var_names=['bias']` 

580 excludes all bias variables from weight decay. 

581 """ 

582 if hasattr(self, "_built") and self._built: 

583 raise ValueError( 

584 "`exclude_from_weight_decay()` can only be configued before " 

585 "the optimizer is built." 

586 ) 

587 

588 if var_list: 

589 self._exclude_from_weight_decay = [ 

590 self._var_key(variable) for variable in var_list 

591 ] 

592 else: 

593 self._exclude_from_weight_decay = [] 

594 self._exclude_from_weight_decay_names = var_names or [] 

595 
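    # Usage sketch (editor-added; optimizer choice is illustrative): call this
    # right after constructing the optimizer, before it is built. Entries in
    # `var_names` are matched with `re.search`, so they behave as patterns.
    #
    #   opt = tf.keras.optimizers.experimental.AdamW(weight_decay=0.004)
    #   opt.exclude_from_weight_decay(var_names=["bias", "layer_norm"])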

596 def _use_weight_decay(self, variable): 

597 exclude_from_weight_decay = getattr( 

598 self, "_exclude_from_weight_decay", [] 

599 ) 

600 exclude_from_weight_decay_names = getattr( 

601 self, "_exclude_from_weight_decay_names", [] 

602 ) 

603 variable_id = self._var_key(variable) 

604 for exclude_id in exclude_from_weight_decay: 

605 if variable_id == exclude_id: 

606 return False 

607 for name in exclude_from_weight_decay_names: 

608 if re.search(name, variable.name) is not None: 

609 return False 

610 return True 

611 

612 def apply_gradients(self, grads_and_vars, name=None): 

613 """Apply gradients to variables. 

614 

615 Args: 

616 grads_and_vars: List of `(gradient, variable)` pairs. 

617 name: string, defaults to None. The name of the namescope to 

618 use when creating variables. If None, `self.name` will be used. 

619 

620 Returns: 

621 A `tf.Variable`, representing the current iteration. 

622 

623 Raises: 

624 TypeError: If `grads_and_vars` is malformed. 

625 """ 

626 self._compute_current_learning_rate() 

627 grads_and_vars = list(grads_and_vars) 

628 if len(grads_and_vars) == 0: 

629 # It is possible that the grad is empty. In this case, 

630 # `apply_gradients` is a no-op. 

631 return self._iterations 

632 grads, trainable_variables = zip(*grads_and_vars) 

633 scope_name = name or self.name or "optimizer" 

634 with tf.name_scope(scope_name): 

635 with tf.init_scope(): 

636 # Lift variable creation to init scope to avoid environment 

637 # issues. 

638 self.build(trainable_variables) 

639 grads_and_vars = optimizer_utils.filter_empty_gradients( 

640 grads_and_vars 

641 ) 

642 if len(list(grads_and_vars)) == 0: 

643 # Check again after filtering gradients. 

644 return self._iterations 

645 

646 grads, trainable_variables = zip(*grads_and_vars) 

647 

648 grads = self._clip_gradients(grads) 

649 grads = self._deduplicate_sparse_grad(grads) 

650 self._apply_weight_decay(trainable_variables) 

651 grads_and_vars = list(zip(grads, trainable_variables)) 

652 iteration = self._internal_apply_gradients(grads_and_vars) 

653 

654 # Apply variable constraints after applying gradients. 

655 for variable in trainable_variables: 

656 if variable.constraint is not None: 

657 variable.assign(variable.constraint(variable)) 

658 return iteration 

659 

660 def _apply_weight_decay(self, variables): 

661 if self.weight_decay is None: 

662 return 

663 for variable in variables: 

664 if self._use_weight_decay(variable): 

665 lr = tf.cast(self.learning_rate, variable.dtype) 

666 wd = tf.cast(self.weight_decay, variable.dtype) 

667 variable.assign_sub(variable * wd * lr) 

668 
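    # Worked example (editor-added, illustrative numbers): decoupled weight
    # decay subtracts `variable * weight_decay * learning_rate` before the
    # gradient update. With `learning_rate=1`, `weight_decay=0.004` and a
    # variable value of 2.0, the decay step subtracts 2.0 * 0.004 * 1 = 0.008,
    # leaving 1.992, the value shown in the weight-decay example in the
    # `Optimizer` class docstring.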

669 def _internal_apply_gradients(self, grads_and_vars): 

670 """Helper function of apply gradients. 

671 

672 This is required for separating out distributed training logic. 

673 

674 Args: 

675 grads_and_vars: List of (gradient, variable) pairs. 

676 """ 

677 if self.jit_compile: 

678 for grad, var in grads_and_vars: 

679 self._update_step_xla(grad, var, id(self._var_key(var))) 

680 else: 

681 for grad, var in grads_and_vars: 

682 self._update_step(grad, var) 

683 return self.iterations.assign_add(1) 

684 

685 def _update_model_variables_moving_average(self, var_list): 

686 """Update the stored moving average using the latest value.""" 

687 if self.use_ema: 

688 for var, average in zip( 

689 var_list, self._model_variables_moving_average 

690 ): 

691 average.assign( 

692 self.ema_momentum * average + (1 - self.ema_momentum) * var 

693 ) 

694 
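    # Worked example (editor-added, illustrative numbers): with
    # `ema_momentum=0.5`, a stored average of 2.0 and a current variable value
    # of 1.0, the new average is 0.5 * 2.0 + (1 - 0.5) * 1.0 = 1.5, matching
    # the first EMA step in the `Optimizer` class docstring example.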

695 def _overwrite_model_variables_with_average_value(self, var_list): 

696 """Overwrite model variables with its moving average.""" 

697 if len(var_list) != len(self._model_variables_moving_average): 

698 raise ValueError( 

699 f"The length of model variables ({len(var_list)}) to " 

700 "override does not match the length of model variables " 

701 "stored in the optimizer " 

702 f"({len(self._model_variables_moving_average)}). Please " 

703 "check if the optimizer was called on your model." 

704 ) 

705 self._overwrite_model_variables_with_average_value_helper(var_list) 

706 

707 def _overwrite_model_variables_with_average_value_helper(self, var_list): 

708 """Helper function that overwrites model variables.""" 

709 for var, average_var in zip( 

710 var_list, self._model_variables_moving_average 

711 ): 

712 var.assign(average_var) 

713 

714 def finalize_variable_values(self, var_list): 

715 """Set the final value of model's trainable variables. 

716 

717 Sometimes there are extra steps before ending the variable updates,

718 such as overwriting the model variables with their moving average values.

719 

720 Args: 

721 var_list: list of model variables. 

722 """ 

723 if self.use_ema: 

724 # If the optimizer uses EMA, then when finalizing, we replace the 

725 # model variable value with its moving average stored inside 

726 # optimizer. 

727 self._overwrite_model_variables_with_average_value(var_list) 

728 

729 def _serialize_hyperparameter(self, hyperparameter): 

730 """Serialize a hyperparameter that can be a numeric or callable.""" 

731 if isinstance( 

732 hyperparameter, learning_rate_schedule.LearningRateSchedule 

733 ): 

734 return learning_rate_schedule.serialize(hyperparameter) 

735 if isinstance(hyperparameter, tf.Variable): 

736 return hyperparameter.numpy() 

737 if callable(hyperparameter): 

738 return hyperparameter() 

739 return hyperparameter 

740 

741 def get_config(self): 

742 """Returns the config of the optimizer. 

743 

744 An optimizer config is a Python dictionary (serializable) 

745 containing the configuration of an optimizer. 

746 The same optimizer can be reinstantiated later 

747 (without any saved state) from this configuration. 

748 

749 Subclass optimizers should override this method to include other

750 hyperparameters. 

751 

752 Returns: 

753 Python dictionary. 

754 """ 

755 config = { 

756 "name": self.name, 

757 "weight_decay": self.weight_decay, 

758 "clipnorm": self.clipnorm, 

759 "global_clipnorm": self.global_clipnorm, 

760 "clipvalue": self.clipvalue, 

761 "use_ema": self.use_ema, 

762 "ema_momentum": self.ema_momentum, 

763 "ema_overwrite_frequency": self.ema_overwrite_frequency, 

764 "jit_compile": self.jit_compile, 

765 "is_legacy_optimizer": False, 

766 } 

767 return config 

768 

769 @classmethod 

770 def from_config(cls, config, custom_objects=None): 

771 """Creates an optimizer from its config. 

772 

773 This method is the reverse of `get_config`, capable of instantiating the 

774 same optimizer from the config dictionary. 

775 

776 Args: 

777 config: A Python dictionary, typically the output of get_config. 

778 custom_objects: A Python dictionary mapping names to additional 

779 user-defined Python objects needed to recreate this optimizer. 

780 

781 Returns: 

782 An optimizer instance. 

783 """ 

784 if "learning_rate" in config: 

785 if isinstance(config["learning_rate"], dict): 

786 config["learning_rate"] = learning_rate_schedule.deserialize( 

787 config["learning_rate"], custom_objects=custom_objects 

788 ) 

789 return cls(**config) 

790 
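    # Usage sketch (editor-added; `opt` is a hypothetical instance):
    # `get_config`/`from_config` round-trip the optimizer's hyperparameters,
    # not its slot variables or iteration count.
    #
    #   config = opt.get_config()
    #   restored = type(opt).from_config(config)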

791 @property 

792 def variables(self): 

793 """Returns variables of this optimizer.""" 

794 return CallableList(self._variables) 

795 

796 def set_weights(self, weights): 

797 """Set the weights of the optimizer. 

798 

799 Args: 

800 weights: a list of `tf.Variable`s or numpy arrays, the target values 

801 of optimizer variables. It should have the same order as 

802 `self._variables`. 

803 """ 

804 if not getattr(self, "_built", False): 

805 raise ValueError( 

806 "You are calling `set_weights()` on an optimizer that has not " 

807 "yet been built. Please call " 

808 "`optimizer.build(trainable_variables)` to create the " 

809 "optimizer weights before calling `set_weights()`." 

810 ) 

811 

812 for variable, weight in zip(self._variables, weights): 

813 if variable.shape != weight.shape: 

814 raise ValueError( 

815 f"Optimizer variable {self._var_key(variable)} has shape " 

816 f"{str(variable.shape)} not compatible with provided " 

817 f"weight shape {str(weight.shape)}." 

818 ) 

819 variable.assign(weight) 

820 
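    # Usage sketch (editor-added; names are hypothetical): copy the state of
    # one built optimizer into another built on the same variable list. The
    # weights must follow the order of `self._variables`.
    #
    #   target_opt.build(model.trainable_variables)
    #   target_opt.set_weights([v.numpy() for v in source_opt.variables])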

821 def save_own_variables(self, store): 

822 """Get the state of this optimizer object.""" 

823 for i, variable in enumerate(self.variables): 

824 store[str(i)] = variable.numpy() 

825 

826 def load_own_variables(self, store): 

827 """Set the state of this optimizer object.""" 

828 if len(store.keys()) != len(self.variables): 

829 msg = ( 

830 f"Skipping variable loading for optimizer '{self.name}', " 

831 f"because it has {len(self.variables)} variables whereas " 

832 f"the saved optimizer has {len(store.keys())} variables. " 

833 ) 

834 if len(self.variables) == 0: 

835 msg += ( 

836 "This is likely because the optimizer has not been " 

837 "called/built yet." 

838 ) 

839 logging.warning(msg) 

840 return 

841 for i, variable in enumerate(self.variables): 

842 variable.assign(store[str(i)]) 

843 

844 

845base_optimizer_keyword_args = """name: String. The name to use 

846 for momentum accumulator weights created by 

847 the optimizer. 

848 weight_decay: Float, defaults to None. If set, weight decay is applied. 

849 clipnorm: Float. If set, the gradient of each weight is individually 

850 clipped so that its norm is no higher than this value. 

851 clipvalue: Float. If set, the gradient of each weight is clipped element-wise

852 so that its absolute value is no higher than this value.

853 global_clipnorm: Float. If set, the gradient of all weights is clipped so 

854 that their global norm is no higher than this value. 

855 use_ema: Boolean, defaults to False. If True, exponential moving average 

856 (EMA) is applied. EMA consists of computing an exponential moving 

857 average of the weights of the model (as the weight values change after 

858 each training batch), and periodically overwriting the weights with 

859 their moving average. 

860 ema_momentum: Float, defaults to 0.99. Only used if `use_ema=True`. 

861 This is the momentum to use when computing 

862 the EMA of the model's weights: 

863 `new_average = ema_momentum * old_average + (1 - ema_momentum) * 

864 current_variable_value`. 

865 ema_overwrite_frequency: Int or None, defaults to None. Only used if 

866 `use_ema=True`. Every `ema_overwrite_frequency` steps of iterations, 

867 we overwrite the model variable by its moving average. 

868 If None, the optimizer 

869 does not overwrite model variables in the middle of training, and you 

870 need to explicitly overwrite the variables at the end of training 

871 by calling `optimizer.finalize_variable_values()` 

872 (which updates the model 

873 variables in-place). When using the built-in `fit()` training loop, 

874 this happens automatically after the last epoch, 

875 and you don't need to do anything. 

876 jit_compile: Boolean, defaults to True. 

877 If True, the optimizer will use XLA 

878 compilation. If no GPU device is found, this flag will be ignored. 

879 mesh: optional `tf.experimental.dtensor.Mesh` instance. When provided, 

880 the optimizer will be run in DTensor mode, e.g. state 

881 tracking variable will be a DVariable, and aggregation/reduction will 

882 happen in the global DTensor context. 

883 **kwargs: keyword arguments only used for backward compatibility.""" 

884 

885 

886@keras_export( 

887 "keras.optimizers.Optimizer", 

888 "keras.optimizers.experimental.Optimizer", 

889 v1=[], 

890) 

891class Optimizer(_BaseOptimizer): 

892 """Abstract optimizer base class. 

893 

894 This class supports distributed training. If you want to implement your own 

895 optimizer, please subclass this class instead of _BaseOptimizer. 

896 

897 Args: 

898 {{base_optimizer_keyword_args}} 

899 

900 ### Usage 

901 

902 ```python 

903 # Create an optimizer with the desired parameters. 

904 opt = keras.optimizers.SGD(learning_rate=0.1) 

905 var1, var2 = tf.Variable(1.0), tf.Variable(2.0) 

906 # `loss` is a callable that takes no argument and returns the value 

907 # to minimize. 

908 loss = lambda: 3 * var1 * var1 + 2 * var2 * var2 

909 # Call minimize to update the list of variables. 

910 opt.minimize(loss, var_list=[var1, var2]) 

911 ``` 

912 

913 ### Processing gradients before applying them 

914 

915 Calling `minimize()` takes care of both computing the gradients and 

916 applying them to the variables. If you want to process the gradients 

917 before applying them you can instead use the optimizer in three steps: 

918 

919 1. Compute the gradients with `tf.GradientTape`. 

920 2. Process the gradients as you wish. 

921 3. Apply the processed gradients with `apply_gradients()`. 

922 

923 Example: 

924 

925 ```python 

926 # Create an optimizer. 

927 opt = tf.keras.optimizers.experimental.SGD(learning_rate=0.1) 

928 var1, var2 = tf.Variable(1.0), tf.Variable(2.0) 

929 

930 # Compute the gradients for a list of variables. 

931 with tf.GradientTape() as tape: 

932 loss = 3 * var1 * var1 + 2 * var2 * var2 

933 grads = tape.gradient(loss, [var1, var2]) 

934 

935 # Process the gradients. 

936 grads[0] = grads[0] + 1 

937 

938 # Ask the optimizer to apply the gradients on variables. 

939 opt.apply_gradients(zip(grads, [var1, var2])) 

940 ``` 

941 

942 ### Dynamic learning rate 

943 

944 Dynamic learning rate can be achieved by setting learning rate as a built-in 

945 or customized `tf.keras.optimizers.schedules.LearningRateSchedule`. 

946 

947 Example: 

948 

949 >>> var = tf.Variable(np.random.random(size=(1,))) 

950 >>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay( 

951 ... initial_learning_rate=.01, decay_steps=20, decay_rate=.1) 

952 >>> opt = tf.keras.optimizers.experimental.SGD(learning_rate=learning_rate) 

953 >>> loss = lambda: 3 * var 

954 >>> opt.minimize(loss, var_list=[var]) 

955 

956 ### Gradients clipping 

957 

958 Users can clip the gradients before applying them to variables by setting

959 `clipnorm`, `clipvalue` or `global_clipnorm`. Note that at most one of

960 `clipnorm` and `global_clipnorm` can be set.

961 

962 Example: 

963 

964 >>> opt = tf.keras.optimizers.experimental.SGD(learning_rate=1, clipvalue=1) 

965 >>> var1, var2 = tf.Variable(2.0), tf.Variable(2.0) 

966 >>> with tf.GradientTape() as tape: 

967 ... loss = 2 * var1 + 2 * var2 

968 >>> grads = tape.gradient(loss, [var1, var2]) 

969 >>> print([grads[0].numpy(), grads[1].numpy()]) 

970 [2.0, 2.0] 

971 >>> opt.apply_gradients(zip(grads, [var1, var2])) 

972 >>> # Without clipping, we should get [0, 0], but as gradients are clipped 

973 >>> # to have max value 1, we get [1.0, 1.0]. 

974 >>> print([var1.numpy(), var2.numpy()]) 

975 [1.0, 1.0] 

976 

977 ### Using weight decay. 

978 

979 Weight decay in certain scenarios can boost the model's performance. Keras 

980 has built-in support for weight decay in all optimizers. Users can apply 

981 weight decay by setting the `weight_decay` argument.

982 

983 >>> opt = tf.keras.optimizers.experimental.SGD(1, weight_decay=0.004) 

984 >>> grads, var1, var2 = tf.zeros(()), tf.Variable(2.0), tf.Variable(2.0) 

985 >>> # You can exclude variables from weight decay, in this case we 

986 >>> # exclude `var2`. 

987 >>> opt.exclude_from_weight_decay(var_list=[var2]) 

988 >>> opt.apply_gradients(zip([grads, grads], [var1, var2])) 

989 >>> print([var1.numpy(), var2.numpy()]) 

990 [1.992, 2.0] 

991 

992 

993 ### Using exponential moving average. 

994 

995 Empirically it has been found that using the exponential moving average 

996 (EMA) of the trained parameters of a deep network achieves better

997 performance than using the trained parameters directly. Keras optimizers

998 allow users to compute this moving average and overwrite the model

999 variables at the desired time.

1000 

1001 Example: 

1002 

1003 ```python 

1004 # Create an SGD optimizer with EMA on. `ema_momentum` controls the decay 

1005 # rate of the moving average. `ema_momentum=1` means no decay and the stored 

1006 # moving average is always the model variable's initial value before training.

1007 # Conversely, `ema_momentum=0` is equivalent to not using EMA.

1008 # `ema_overwrite_frequency=3` means every 3 iterations, we overwrite the 

1009 # trainable variables with their moving average values. 

1010 opt = tf.keras.optimizers.experimental.SGD( 

1011 learning_rate=1, 

1012 use_ema=True, 

1013 ema_momentum=0.5, 

1014 ema_overwrite_frequency=3) 

1015 var1, var2 = tf.Variable(2.0), tf.Variable(2.0) 

1016 with tf.GradientTape() as tape: 

1017 loss = var1 + var2 

1018 grads = tape.gradient(loss, [var1, var2]) 

1019 # First iteration: [var1, var2] = [1.0, 1.0] 

1020 opt.apply_gradients(zip(grads, [var1, var2])) 

1021 print([var1, var2]) 

1022 

1023 # Second iteration: [var1, var2] = [0.0, 0.0] 

1024 opt.apply_gradients(zip(grads, [var1, var2])) 

1025 print([var1, var2]) 

1026 

1027 # Third iteration, without EMA, we should see [var1, var2] = [-1.0, -1.0], 

1028 # but overwriting results in [var1, var2] = [-0.125, -0.125]. The full 

1029 # calculation for the moving average of var1 is: 

1030 # var1=2*0.5**3+1*(1-0.5)*0.5**2+0*(1-0.5)*0.5**1+(-1)*(1-0.5)=-0.125. 

1031 opt.apply_gradients(zip(grads, [var1, var2])) 

1032 print([var1, var2]) 

1033 

1034 ``` 

1035 When the optimizer is constructed with `use_ema=True`, in a custom training

1036 loop users can explicitly call `finalize_variable_values()` to overwrite the

1037 trainable variables with their EMA values. `finalize_variable_values()` is

1038 called by default at the end of `model.fit()`.

1039 

1040 ### Use with `tf.distribute.Strategy` 

1041 

1042 This optimizer class is `tf.distribute.Strategy` aware, which means it 

1043 automatically sums gradients across all replicas. To aggregate gradients 

1044 yourself, call `apply_gradients` with `skip_gradients_aggregation` set to

1045 True. This is useful if you need to process aggregated gradients. 

1046 

1047 ```python 

1048 # This example is not runnable, it consists of dummy code for simple 

1049 # tutorial. 

1050 strategy = tf.distribute.experimental.TPUStrategy() 

1051 

1052 with strategy.scope(): 

1053 opt = tf.keras.optimizers.experimental.SGD() 

1054 model = magic_function_that_returns_model() 

1055 gradients = magic_function_that_returns_gradients() 

1056 # Custom logic to aggregate gradients. 

1057 gradients = strategy.reduce("SUM", gradients, axis=None) 

1058 opt.apply_gradients(zip(gradients, model.trainable_variables), 

1059 skip_gradients_aggregation=True)

1060 ``` 

1061 

1062 ### Creating a custom optimizer 

1063 

1064 If you intend to create your own optimization algorithm, please inherit from 

1065 this class and override the following methods: 

1066 

1067 - `build`: Create your optimizer-related variables, such as `momentums` in 

1068 SGD optimizer. 

1069 - `update_step`: Implement your optimizer's updating logic. 

1070 - `get_config`: serialization of the optimizer; include all

1071 hyperparameters.

1072 

1073 Your optimizer will automatically be compatible with TensorFlow distributed

1074 training if you subclass `optimizer_experimental.Optimizer`.

1075 
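 For illustration only, here is a minimal sketch of a plain gradient-descent
 optimizer built on these hooks. The class name `MySGD` is hypothetical and
 not part of Keras; it only shows where `build`, `update_step` and
 `get_config` fit together.

 ```python
 class MySGD(tf.keras.optimizers.experimental.Optimizer):
     def __init__(self, learning_rate=0.01, name="MySGD", **kwargs):
         super().__init__(name=name, **kwargs)
         self._learning_rate = self._build_learning_rate(learning_rate)

     def build(self, var_list):
         # Plain gradient descent keeps no extra slot variables.
         super().build(var_list)
         self._built = True

     def update_step(self, gradient, variable):
         lr = tf.cast(self.learning_rate, variable.dtype)
         if isinstance(gradient, tf.IndexedSlices):
             # Sparse gradient: subtract only at the touched indices.
             variable.scatter_sub(
                 tf.IndexedSlices(gradient.values * lr, gradient.indices)
             )
         else:
             variable.assign_sub(gradient * lr)

     def get_config(self):
         config = super().get_config()
         config.update(
             {
                 "learning_rate": self._serialize_hyperparameter(
                     self._learning_rate
                 )
             }
         )
         return config
 ```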

1076 """ 

1077 

1078 def __init__( 

1079 self, 

1080 name, 

1081 weight_decay=0, 

1082 clipnorm=None, 

1083 clipvalue=None, 

1084 global_clipnorm=None, 

1085 use_ema=False, 

1086 ema_momentum=0.99, 

1087 ema_overwrite_frequency=None, 

1088 jit_compile=True, 

1089 **kwargs, 

1090 ): 

1091 """Create a new Optimizer.""" 

1092 mesh = kwargs.pop("mesh", None) 

1093 self._mesh = mesh 

1094 super().__init__( 

1095 name, 

1096 weight_decay, 

1097 clipnorm, 

1098 clipvalue, 

1099 global_clipnorm, 

1100 use_ema, 

1101 ema_momentum, 

1102 ema_overwrite_frequency, 

1103 jit_compile, 

1104 **kwargs, 

1105 ) 

1106 self._distribution_strategy = tf.distribute.get_strategy() 

1107 self._run_with_dtensor = dtensor_utils.running_with_dtensor_strategy() 

1108 

1109 def add_variable_from_reference( 

1110 self, model_variable, variable_name, shape=None, initial_value=None 

1111 ): 

1112 if self._mesh: 

1113 if initial_value is None: 

1114 # Use tf.zeros_like which will propagate the layout information 

1115 # from the model weights if any. 

1116 initial_value = tf.zeros_like(model_variable) 

1117 elif isinstance(initial_value, tf.Tensor): 

1118 initial_value = tf.experimental.dtensor.copy_to_mesh( 

1119 initial_value, 

1120 tf.experimental.dtensor.Layout.replicated( 

1121 self._mesh, rank=initial_value.shape.rank 

1122 ), 

1123 ) 

1124 variable = tf.experimental.dtensor.DVariable( 

1125 initial_value=initial_value, 

1126 name=f"{variable_name}/{model_variable._shared_name}", 

1127 dtype=model_variable.dtype, 

1128 trainable=False, 

1129 ) 

1130 self._variables.append(variable) 

1131 return variable 

1132 else: 

1133 strategy = tf.distribute.get_strategy() 

1134 with strategy.extended.colocate_vars_with(model_variable): 

1135 return super().add_variable_from_reference( 

1136 model_variable, variable_name, shape, initial_value 

1137 ) 

1138 

1139 def _create_iteration_variable(self): 

1140 if self._mesh: 

1141 init_val = tf.constant(0, dtype=tf.int64) 

1142 init_val = tf.experimental.dtensor.copy_to_mesh( 

1143 init_val, 

1144 tf.experimental.dtensor.Layout.replicated(self._mesh, rank=0), 

1145 ) 

1146 with tf.init_scope(): 

1147 # Lift the variable creation to init scope to avoid environment 

1148 # issue. 

1149 self._iterations = tf.experimental.dtensor.DVariable( 

1150 init_val, name="iteration" 

1151 ) 

1152 self._variables.append(self._iterations) 

1153 else: 

1154 super()._create_iteration_variable() 

1155 

1156 def _var_key(self, variable): 

1157 """Get a unique identifier of the given variable.""" 

1158 

1159 # Get the distributed variable if it exists. 

1160 # TODO(b/197554203): replace _distributed_container() with a public api. 

1161 if hasattr(variable, "_distributed_container"): 

1162 variable = variable._distributed_container() 

1163 elif ( 

1164 tf_utils.is_extension_type(variable) 

1165 and hasattr(variable, "handle") 

1166 and hasattr(variable.handle, "_distributed_container") 

1167 ): 

1168 # For ResourceVariables, the _distributed_container attribute 

1169 # is added to their handle tensors. 

1170 variable = variable.handle._distributed_container() 

1171 return super()._var_key(variable) 

1172 

1173 def aggregate_gradients(self, grads_and_vars): 

1174 """Aggregate gradients on all devices. 

1175 

1176 By default, we will perform reduce_sum of gradients across devices. 

1177 Users can implement their own aggregation logic by overriding this 

1178 method. 

1179 

1180 Args: 

1181 grads_and_vars: List of (gradient, variable) pairs. 

1182 

1183 Returns: 

1184 List of (gradient, variable) pairs. 

1185 """ 

1186 if self._mesh or self._run_with_dtensor: 

1187 raise NotImplementedError( 

1188 "Dtensor doesn't need to manually aggregate gradients" 

1189 ) 

1190 else: 

1191 return optimizer_utils.all_reduce_sum_gradients(grads_and_vars) 

1192 
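    # Override sketch (editor-added, hypothetical subclass method; dense
    # gradients assumed): average gradients across replicas instead of summing
    # them, by reusing the sum reduction above and rescaling by the number of
    # replicas in sync.
    #
    #   def aggregate_gradients(self, grads_and_vars):
    #       grads_and_vars = super().aggregate_gradients(grads_and_vars)
    #       n = self._distribution_strategy.num_replicas_in_sync
    #       return [
    #           (g / n if g is not None else None, v)
    #           for g, v in grads_and_vars
    #       ]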

1193 def apply_gradients( 

1194 self, 

1195 grads_and_vars, 

1196 name=None, 

1197 skip_gradients_aggregation=False, 

1198 **kwargs, 

1199 ): 

1200 """Apply gradients to variables. 

1201 

1202 Args: 

1203 grads_and_vars: List of `(gradient, variable)` pairs. 

1204 name: string, defaults to None. The name of the namescope to 

1205 use when creating variables. If None, `self.name` will be used. 

1206 skip_gradients_aggregation: If True, gradient aggregation will not be

1207 performed inside the optimizer. Usually this arg is set to True when you

1208 write custom code that aggregates gradients outside the optimizer.

1209 **kwargs: keyword arguments only used for backward compatibility. 

1210 

1211 Returns: 

1212 A `tf.Variable`, representing the current iteration. 

1213 

1214 Raises: 

1215 TypeError: If `grads_and_vars` is malformed. 

1216 RuntimeError: If called in a cross-replica context. 

1217 """ 

1218 if self._mesh or self._run_with_dtensor: 

1219 # Skip any usage of strategy logic for DTensor 

1220 return super().apply_gradients(grads_and_vars, name=name) 

1221 

1222 # `experimental_aggregate_gradients` is an arg in `apply_gradients` of 

1223 # v2 optimizer -- the reverse of `skip_gradients_aggregation`. 

1224 # We read it from kwargs for backward compatibility. 

1225 experimental_aggregate_gradients = kwargs.pop( 

1226 "experimental_aggregate_gradients", True 

1227 ) 

1228 if not skip_gradients_aggregation and experimental_aggregate_gradients: 

1229 grads_and_vars = self.aggregate_gradients(grads_and_vars) 

1230 return super().apply_gradients(grads_and_vars, name=name) 

1231 

1232 def _apply_weight_decay(self, variables): 

1233 # Apply weight decay in distributed setup. 

1234 if self.weight_decay is None: 

1235 return 

1236 

1237 def distributed_apply_weight_decay(distribution, variables, **kwargs): 

1238 def weight_decay_fn(variable): 

1239 if self._use_weight_decay(variable): 

1240 lr = tf.cast(self.learning_rate, variable.dtype) 

1241 wd = tf.cast(self.weight_decay, variable.dtype) 

1242 variable.assign_sub(variable * wd * lr) 

1243 

1244 for variable in variables: 

1245 distribution.extended.update( 

1246 variable, weight_decay_fn, group=False 

1247 ) 

1248 

1249 tf.__internal__.distribute.interim.maybe_merge_call( 

1250 distributed_apply_weight_decay, 

1251 self._distribution_strategy, 

1252 variables, 

1253 ) 

1254 

1255 def _internal_apply_gradients(self, grads_and_vars): 

1256 if self._mesh or self._run_with_dtensor: 

1257 # Skip any usage of strategy logic for DTensor 

1258 return super()._internal_apply_gradients(grads_and_vars) 

1259 

1260 return tf.__internal__.distribute.interim.maybe_merge_call( 

1261 self._distributed_apply_gradients_fn, 

1262 self._distribution_strategy, 

1263 grads_and_vars, 

1264 ) 

1265 

1266 def _overwrite_model_variables_with_average_value_helper(self, var_list): 

1267 """Helper function to _overwrite_model_variables_with_average_value. 

1268 

1269 This function overwrites variables on each device. 

1270 Args: 

1271 var_list: list of model variables. 

1272 """ 

1273 if self._mesh or self._run_with_dtensor: 

1274 # Skip any usage of strategy logic for DTensor 

1275 return super()._overwrite_model_variables_with_average_value_helper(

1276 var_list 

1277 ) 

1278 

1279 strategy = self._distribution_strategy 

1280 # Override model variable by the stored average value on all devices. 

1281 for var, average_var in zip( 

1282 var_list, self._model_variables_moving_average 

1283 ): 

1284 strategy.extended.update( 

1285 var, lambda a, b: a.assign(b), args=(average_var,) 

1286 ) 

1287 

1288 def _build_learning_rate(self, learning_rate): 

1289 if not self._mesh: 

1290 return super()._build_learning_rate(learning_rate) 

1291 

1292 # For DTensor 

1293 variable_creation = tf.experimental.dtensor.DVariable 

1294 init_value_convert_fn = lambda x: tf.experimental.dtensor.copy_to_mesh( 

1295 x, tf.experimental.dtensor.Layout.replicated(self._mesh, rank=0) 

1296 ) 

1297 if isinstance( 

1298 learning_rate, learning_rate_schedule.LearningRateSchedule 

1299 ): 

1300 current_learning_rate = tf.convert_to_tensor( 

1301 learning_rate(self.iterations) 

1302 ) 

1303 current_learning_rate = init_value_convert_fn(current_learning_rate) 

1304 # Create a variable to hold the current learning rate. 

1305 # Note that the init value `learning_rate(self.iterations)` should 

1306 # have the correct layout information from self.iterations. 

1307 self._current_learning_rate = variable_creation( 

1308 current_learning_rate, 

1309 name="learning_rate", 

1310 dtype=tf.float32, 

1311 ) 

1312 return learning_rate 

1313 

1314 init_val = init_value_convert_fn( 

1315 tf.constant(learning_rate, dtype=tf.float32) 

1316 ) 

1317 return variable_creation( 

1318 init_val, 

1319 name="learning_rate", 

1320 dtype=backend.floatx(), 

1321 trainable=False, 

1322 ) 

1323 

1324 def _update_model_variables_moving_average(self, var_list): 

1325 """Update the stored moving average using the latest value.""" 

1326 if self.use_ema: 

1327 

1328 def update_average(average, var): 

1329 average.assign( 

1330 self.ema_momentum * average + (1 - self.ema_momentum) * var 

1331 ) 

1332 

1333 for var, average in zip( 

1334 var_list, self._model_variables_moving_average 

1335 ): 

1336 self._distribution_strategy.extended.update( 

1337 average, update_average, args=(var,), group=False 

1338 ) 

1339 

1340 def _distributed_apply_gradients_fn( 

1341 self, distribution, grads_and_vars, **kwargs 

1342 ): 

1343 """`apply_gradients` using a `DistributionStrategy`.""" 

1344 

1345 def apply_grad_to_update_var(var, grad): 

1346 if self.jit_compile: 

1347 return self._update_step_xla(grad, var, id(self._var_key(var))) 

1348 else: 

1349 return self._update_step(grad, var) 

1350 

1351 for grad, var in grads_and_vars: 

1352 distribution.extended.update( 

1353 var, apply_grad_to_update_var, args=(grad,), group=False 

1354 ) 

1355 

1356 if self.use_ema: 

1357 _, var_list = zip(*grads_and_vars) 

1358 self._update_model_variables_moving_average(var_list) 

1359 if self.ema_overwrite_frequency: 

1360 # Only when self.ema_overwrite_frequency is not None, we 

1361 # overwrite the model variables. 

1362 should_overwrite_model_vars = ( 

1363 self.iterations + 1 

1364 ) % self.ema_overwrite_frequency == 0 

1365 tf.cond( 

1366 tf.cast(should_overwrite_model_vars, tf.bool), 

1367 true_fn=lambda: self._overwrite_model_variables_with_average_value( # noqa: E501 

1368 var_list 

1369 ), 

1370 false_fn=lambda: None, 

1371 ) 

1372 return self.iterations.assign_add(1) 

1373 

1374 

1375class RestoredOptimizer(Optimizer): 

1376 def __init__(self): 

1377 super().__init__("RestoredOptimizer") 

1378 

1379 def get_config(self): 

1380 raise NotImplementedError( 

1381 "Restoring functional Optimizers from SavedModels is not currently " 

1382 "supported. Please file a feature request if this limitation " 

1383 "bothers you." 

1384 ) 

1385 

1386 

1387class CallableList(list): 

1388 """Temporary shim to support both `opt.variables()` and `opt.variables`.""" 

1389 

1390 def __call__(self): 

1391 return self 

1392 

1393 

1394# Register the optimizer for loading from saved_model purpose. 

1395tf.__internal__.saved_model.load.register_revived_type( 

1396 "experimentalOptimizer", 

1397 lambda obj: isinstance(obj, Optimizer), 

1398 versions=[ 

1399 tf.__internal__.saved_model.load.VersionedTypeRegistration( 

1400 object_factory=lambda proto: RestoredOptimizer(), 

1401 version=2, 

1402 min_producer_version=1, 

1403 min_consumer_version=1, 

1404 ) 

1405 ], 

1406) 

1407 

1408Optimizer.__doc__ = Optimizer.__doc__.replace( 

1409 "{{base_optimizer_keyword_args}}", base_optimizer_keyword_args 

1410) 

1411