Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/legacy/optimizer_v2.py: 21%

512 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Version 2 of class Optimizer.""" 

16 

17 

18import abc 

19import contextlib 

20import functools 

21import warnings 

22from copy import deepcopy 

23 

24import tensorflow.compat.v2 as tf 

25 

26from keras.src import backend 

27from keras.src import initializers 

28from keras.src.engine import base_layer_utils 

29from keras.src.optimizers import utils as optimizer_utils 

30from keras.src.optimizers.schedules import learning_rate_schedule 

31from keras.src.utils import generic_utils 

32from keras.src.utils import layer_utils 

33from keras.src.utils import tf_inspect 

34from keras.src.utils import tf_utils 

35 

36# isort: off 

37from tensorflow.python.util.tf_export import keras_export 

38 

39keras_optimizers_gauge = tf.__internal__.monitoring.BoolGauge( 

40 "/tensorflow/api/keras/optimizers", "keras optimizer usage", "method" 

41) 

42 

43_DEFAULT_VALID_DTYPES = frozenset( 

44 [ 

45 tf.float16, 

46 tf.bfloat16, 

47 tf.float32, 

48 tf.float64, 

49 tf.complex64, 

50 tf.complex128, 

51 ] 

52) 

53 

54 

55def _deduplicate_indexed_slices(values, indices): 

56 """Sums `values` associated with any non-unique `indices`. 

57 

58 Args: 

59 values: A `Tensor` with rank >= 1. 

60 indices: A one-dimensional integer `Tensor`, indexing into the first 

61 dimension of `values` (as in an IndexedSlices object). 

62 

63 Returns: 

64 A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a 

65 de-duplicated version of `indices` and `summed_values` contains the sum of 

66 `values` slices associated with each unique index. 

67 """ 

68 unique_indices, new_index_positions = tf.unique(indices) 

69 summed_values = tf.math.unsorted_segment_sum( 

70 values, new_index_positions, tf.shape(unique_indices)[0] 

71 ) 

72 return (summed_values, unique_indices) 

73 
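# Illustrative sketch (not in the original source) of what
# `_deduplicate_indexed_slices` computes: with duplicate indices, the slices
# are summed per unique index, in order of first appearance. In eager mode:
#
#   values = tf.constant([[1., 1.], [2., 2.], [3., 3.]])
#   indices = tf.constant([0, 2, 0])
#   summed, unique = _deduplicate_indexed_slices(values, indices)
#   # summed -> [[4., 4.], [2., 2.]]   unique -> [0, 2]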

74 

75class NullContextmanager: 

76 def __init__(self, *args, **kwargs): 

77 pass 

78 

79 def __enter__(self): 

80 pass 

81 

82 def __exit__(self, type_arg, value_arg, traceback_arg): 

83 return False # False values do not suppress exceptions 

84 

85 

86def name_scope_only_in_function_or_graph(name): 

87 """Internal-only entry point for `name_scope*`. 

88 

89 Enters a compat.v1.name_scope only when in a function or graph, 

90 not when running fully eagerly. 

91 

92 Args: 

93 name: The name argument that is passed to the op function. 

94 

95 Returns: 

96 `name_scope*` context manager. 

97 """ 

98 if not tf.executing_eagerly(): 

99 return tf.name_scope(name) 

100 else: 

101 return NullContextmanager() 

102 

103 

104@keras_export( 

105 "keras.optimizers.legacy.Optimizer", 

106 v1=["keras.optimizers.Optimizer", "keras.optimizers.legacy.Optimizer"], 

107) 

108class OptimizerV2(tf.__internal__.tracking.Trackable): 

109 """Base class for legacy Keras optimizers. 

110 

111 You should not use this class directly, but instead instantiate one of its 

112 subclasses such as `tf.keras.optimizers.legacy.SGD`, 

113 `tf.keras.optimizers.legacy.Adam`, etc. 

114 

115 This is the default Keras optimizer base class up to and including v2.10. 

116 In v2.11 and later, `tf.keras.optimizers.Optimizer` 

117 points to a new base class implementation. The legacy class won't be 

118 deleted in the future and will continue to be available at 

119 `tf.keras.optimizers.legacy.Optimizer`. 

120 

121 ### Usage 

122 

123 ```python 

124 # Create an optimizer with the desired parameters. 

125 opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1) 

126 # `loss` is a callable that takes no argument and returns the value 

127 # to minimize. 

128 var1 = tf.Variable(2.0) 

129 var2 = tf.Variable(5.0) 

130 loss = lambda: 3 * var1 * var1 + 2 * var2 * var2 

131 # In graph mode, returns op that minimizes the loss by updating the listed 

132 # variables. 

133 opt_op = opt.minimize(loss, var_list=[var1, var2]) 

134 opt_op.run() 

135 # In eager mode, simply call minimize to update the list of variables. 

136 opt.minimize(loss, var_list=[var1, var2]) 

137 ``` 

138 

139 ### Usage in custom training loops 

140 

141 In Keras models, variables are sometimes created when the model is first 

142 called instead of at construction time. Examples include 1) sequential models 

143 without a pre-defined input shape, or 2) subclassed models. Pass `var_list` as 

144 a callable in these cases. 

145 

146 Example: 

147 

148 ```python 

149 opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1) 

150 model = tf.keras.Sequential() 

151 model.add(tf.keras.layers.Dense(num_hidden, activation='relu')) 

152 model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid')) 

153 loss_fn = lambda: tf.keras.losses.mse(model(input), output) 

154 var_list_fn = lambda: model.trainable_weights 

155 for input, output in data: 

156 opt.minimize(loss_fn, var_list_fn) 

157 ``` 

158 

159 ### Processing gradients before applying them 

160 

161 Calling `minimize()` takes care of both computing the gradients and 

162 applying them to the variables. If you want to process the gradients 

163 before applying them you can instead use the optimizer in three steps: 

164 

165 1. Compute the gradients with `tf.GradientTape`. 

166 2. Process the gradients as you wish. 

167 3. Apply the processed gradients with `apply_gradients()`. 

168 

169 Example: 

170 

171 ```python 

172 # Create an optimizer. 

173 opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1) 

174 

175 # Compute the gradients for a list of variables. 

176 with tf.GradientTape() as tape: 

177 loss = <call_loss_function> 

178 vars = <list_of_variables> 

179 grads = tape.gradient(loss, vars) 

180 

181 # Process the gradients, for example cap them, etc. 

182 # capped_grads = [MyCapper(g) for g in grads] 

183 processed_grads = [process_gradient(g) for g in grads] 

184 

185 # Ask the optimizer to apply the processed gradients. 

186 opt.apply_gradients(zip(processed_grads, var_list)) 

187 ``` 

188 

189 ### Use with `tf.distribute.Strategy` 

190 

191 This optimizer class is `tf.distribute.Strategy` aware, which means it 

192 automatically sums gradients across all replicas. To average gradients, 

193 you divide your loss by the global batch size, which is done 

194 automatically if you use `tf.keras` built-in training or evaluation loops. 

195 See the `reduction` argument of your loss, which should be set to 

196 `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or 

197 `tf.keras.losses.Reduction.SUM` for summing. 

198 

199 To aggregate gradients yourself, call `apply_gradients` with 

200 `experimental_aggregate_gradients` set to False. This is useful if you need 

201 to process aggregated gradients. 

202 

203 If you are not using these and you want to average gradients, you should use 

204 `tf.math.reduce_sum` to add up your per-example losses and then divide by 

205 the global batch size. Note that when using `tf.distribute.Strategy`, the 

206 first component of a tensor's shape is the *replica-local* batch size, which 

207 is off by a factor equal to the number of replicas being used to compute a 

208 single step. As a result, using `tf.math.reduce_mean` will give the wrong 

209 answer, resulting in gradients that can be many times too big. 

210 

211 ### Variable Constraints 

212 

213 All Keras optimizers respect variable constraints. If a constraint function 

214 is passed to any variable, the constraint is applied to that variable after 

215 its gradient has been applied. 

216 Important: if the gradient is a sparse tensor, variable constraints are not 

217 supported. 

218 

219 ### Thread Compatibility 

220 

221 The entire optimizer is currently thread-compatible, not thread-safe. The 

222 user needs to perform synchronization if necessary. 

223 

224 ### Slots 

225 

226 Many optimizer subclasses, such as `Adam` and `Adagrad`, allocate and manage 

227 additional variables associated with the variables to train. These are 

228 called *slots*. Slots have names, and you can ask the optimizer for the 

229 names of the slots that it uses. Once you have a slot name, you can ask the 

230 optimizer for the variable it created to hold the slot value. 

231 

232 This can be useful if you want to log or debug a training algorithm, report 

233 stats about the slots, etc. 

234 

235 ### Hyperparameters 

236 

237 These are arguments passed to the optimizer subclass constructor 

238 (the `__init__` method), and then passed to `self._set_hyper()`. 

239 They can be either regular Python values (like 1.0), tensors, or 

240 callables. If they are callable, the callable will be called during 

241 `apply_gradients()` to get the value for the hyperparameter. 

242 

243 Hyperparameters can be overwritten through user code: 

244 

245 Example: 

246 

247 ```python 

248 # Create an optimizer with the desired parameters. 

249 opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1) 

250 # `loss` is a callable that takes no argument and returns the value 

251 # to minimize. 

252 loss = lambda: 3 * var1 + 2 * var2 

253 # In eager mode, simply call minimize to update the list of variables. 

254 opt.minimize(loss, var_list=[var1, var2]) 

255 # update learning rate 

256 opt.learning_rate = 0.05 

257 opt.minimize(loss, var_list=[var1, var2]) 

258 ``` 

259 

260 ### Callable learning rate 

261 

262 The optimizer accepts a callable learning rate in two ways. The first is 

263 through a built-in or custom 

264 `tf.keras.optimizers.schedules.LearningRateSchedule`. The schedule is 

265 called on each iteration as `schedule(iteration)`, where `iteration` is a 

266 `tf.Variable` (the step counter) owned by the optimizer. 

267 

268 Example: 

269 

270 >>> var = tf.Variable(np.random.random(size=(1,))) 

271 >>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay( 

272 ... initial_learning_rate=.01, decay_steps=20, decay_rate=.1) 

273 >>> opt = tf.keras.optimizers.legacy.SGD(learning_rate=learning_rate) 

274 >>> loss = lambda: 3 * var 

275 >>> opt.minimize(loss, var_list=[var]) 

276 <tf.Variable... 

277 

278 The second way is through a callable function that 

279 does not accept any arguments. 

280 

281 Example: 

282 

283 >>> var = tf.Variable(np.random.random(size=(1,))) 

284 >>> def lr_callable(): 

285 ... return .1 

286 >>> opt = tf.keras.optimizers.legacy.SGD(learning_rate=lr_callable) 

287 >>> loss = lambda: 3 * var 

288 >>> opt.minimize(loss, var_list=[var]) 

289 <tf.Variable... 

290 

291 ### Creating a custom optimizer 

292 

293 If you intend to create your own optimization algorithm, simply inherit from 

294 this class and override the following methods: 

295 

296 - `_resource_apply_dense` (update variable given gradient tensor is a 

297 dense `tf.Tensor`) 

298 - `_resource_apply_sparse` (update variable given gradient tensor is a 

299 sparse `tf.IndexedSlices`. The most common way for this to happen 

300 is if you are taking the gradient through a `tf.gather`.) 

301 - `_create_slots` 

302 (if your optimizer algorithm requires additional variables) 

303 - `get_config` 

304 (serialization of the optimizer; include all hyperparameters; see the sketch after this docstring) 

305 """ 
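# Illustrative sketch (not in the original source) of a minimal custom
# optimizer built from the overrides listed at the end of the docstring above.
# The name `MySGD` is hypothetical; sparse updates are omitted for brevity.
#
#   class MySGD(OptimizerV2):
#       def __init__(self, learning_rate=0.01, name="MySGD", **kwargs):
#           super().__init__(name, **kwargs)
#           self._set_hyper("learning_rate", learning_rate)
#
#       def _resource_apply_dense(self, grad, var, apply_state=None):
#           lr = self._decayed_lr(var.dtype.base_dtype)
#           return var.assign_sub(lr * grad, use_locking=self._use_locking)
#
#       def get_config(self):
#           config = super().get_config()
#           config["learning_rate"] = self._serialize_hyperparameter(
#               "learning_rate"
#           )
#           return config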

306 

307 # Subclasses should set this to True unless they override `apply_gradients` 

308 # with a version that does not have the `experimental_aggregate_gradients` 

309 # argument. Older versions of Keras did not have this argument so custom 

310 # optimizers may have overridden `apply_gradients` without the 

311 # `experimental_aggregate_gradients` argument. Keras only passes 

312 # `experimental_aggregate_gradients` if this attribute is True. 

313 # Note: This attribute will likely be removed in an upcoming release. 

314 _HAS_AGGREGATE_GRAD = False 

315 

316 def __init__( 

317 self, 

318 name, 

319 gradient_aggregator=None, 

320 gradient_transformers=None, 

321 **kwargs, 

322 ): 

323 """Create a new Optimizer. 

324 

325 This must be called by the constructors of subclasses. 

326 Note that Optimizer instances should not bind to a single graph, 

327 and so shouldn't keep Tensors as member variables. Generally 

328 you should be able to use the _set_hyper()/_get_hyper() 

329 facility instead. 

330 

331 This class is stateful and thread-compatible. 

332 

333 Example of custom gradient transformations: 

334 

335 ```python 

336 def my_gradient_transformer(grads_and_vars): 

337 # Simple example, double the gradients. 

338 return [(2. * g, v) for g, v in grads_and_vars] 

339 

340 optimizer = tf.keras.optimizers.legacy.SGD( 

341 1e-3, gradient_transformers=[my_gradient_transformer]) 

342 ``` 

343 

344 Args: 

345 name: String. The name to use for momentum accumulator weights created 

346 by the optimizer. 

347 gradient_aggregator: The function to use to aggregate gradients across 

348 devices (when using `tf.distribute.Strategy`). If `None`, defaults 

349 to summing the gradients across devices. The function should accept 

350 and return a list of `(gradient, variable)` tuples. 

351 gradient_transformers: Optional. List of functions to use to transform 

352 gradients before applying updates to Variables. The functions are 

353 applied after `gradient_aggregator`. The functions should accept and 

354 return a list of `(gradient, variable)` tuples. 

355 **kwargs: keyword arguments. Allowed arguments are `clipvalue`, 

356 `clipnorm`, `global_clipnorm`. 

357 If `clipvalue` (float) is set, the gradient of each weight 

358 is clipped to be no higher than this value. 

359 If `clipnorm` (float) is set, the gradient of each weight 

360 is individually clipped so that its norm is no higher than this 

361 value. If `global_clipnorm` (float) is set the gradient of all 

362 weights is clipped so that their global norm is no higher than this 

363 value. 

364 

365 Raises: 

366 ValueError: in case of any invalid argument. 

367 """ 

368 # Instrument optimizer usages 

369 keras_optimizers_gauge.get_cell(self.__class__.__name__).set(True) 

370 

371 allowed_kwargs = { 

372 "clipnorm", 

373 "clipvalue", 

374 "lr", 

375 "decay", 

376 "global_clipnorm", 

377 } 

378 for k in kwargs: 

379 if k not in allowed_kwargs: 

380 raise TypeError( 

381 "Unexpected keyword argument " 

382 f"passed to optimizer: {str(k)}. Allowed kwargs are " 

383 f"{allowed_kwargs}." 

384 ) 

385 # checks that all keyword arguments are non-negative. 

386 if kwargs[k] is not None and kwargs[k] < 0: 

387 raise ValueError(f"Expected {k} >= 0, received: {kwargs[k]}") 

388 if k == "lr": 

389 warnings.warn( 

390 "The `lr` argument is deprecated, " 

391 "use `learning_rate` instead.", 

392 stacklevel=2, 

393 ) 

394 

395 self._use_locking = True 

396 self._init_set_name(name) 

397 self._hyper = {} 

398 # dict: {variable name : {slot name : variable}} 

399 self._slots = {} 

400 self._slot_names = [] 

401 self._weights = [] 

402 self._iterations = None 

403 

404 # For implementing Trackable. Stores information about how to restore 

405 # slot variables which have not yet been created 

406 # (trackable._CheckpointPosition objects). 

407 # {slot_name : 

408 # {_var_key(variable_to_train): [checkpoint_position, ... ], ... }, 

409 # ... } 

410 self._deferred_slot_restorations = {} 

411 

412 decay = kwargs.pop("decay", 0.0) 

413 if decay < 0.0: 

414 raise ValueError( 

415 f"decay cannot be less than 0. Received: decay={decay}." 

416 ) 

417 self._initial_decay = decay 

418 

419 self._hypers_created = False 

420 # Store the distribution strategy object if the optimizer is created 

421 # inside strategy scope, so it could be used to create variables later. 

422 if tf.distribute.has_strategy(): 

423 self._distribution_strategy = tf.distribute.get_strategy() 

424 else: 

425 self._distribution_strategy = None 

426 

427 # Configure gradient transformations. 

428 if gradient_aggregator is None: 

429 gradient_aggregator = optimizer_utils.all_reduce_sum_gradients 

430 self.gradient_aggregator = gradient_aggregator 

431 if gradient_transformers is None: 

432 gradient_transformers = [] 

433 self.gradient_transformers = gradient_transformers 

434 self.clipnorm = kwargs.pop("clipnorm", None) 

435 self.global_clipnorm = kwargs.pop("global_clipnorm", None) 

436 if self.clipnorm is not None and self.global_clipnorm is not None: 

437 raise ValueError( 

438 "Cannot accept both `clipnorm` and `global_clipnorm`. " 

439 "Received: `clipnorm`={}, `global_clipnorm`={}.".format( 

440 self.clipnorm, self.global_clipnorm 

441 ) 

442 ) 

443 self.clipvalue = kwargs.pop("clipvalue", None) 

444 

445 def __deepcopy__(self, memo): 

446 cls = self.__class__ 

447 result = cls.__new__(cls) 

448 memo[id(self)] = result 

449 for k, v in self.__dict__.items(): 

450 # DistributionStrategy singleton cannot be serialized 

451 if k == "_distribution_strategy": 

452 continue 

453 setattr(result, k, deepcopy(v, memo)) 

454 result._distribution_strategy = self._distribution_strategy 

455 return result 

456 

457 @property 

458 def clipnorm(self): 

459 """`float` or `None`. If set, clips gradients to a maximum norm.""" 

460 return self._clipnorm 

461 

462 @property 

463 def global_clipnorm(self): 

464 """`float` or `None`. 

465 

466 If set, clips gradients to a maximum norm. 

467 

468 Check `tf.clip_by_global_norm` for more details. 

469 """ 

470 return self._global_clipnorm 

471 

472 @clipnorm.setter 

473 def clipnorm(self, val): 

474 if val is not None and self.gradient_transformers: 

475 raise ValueError( 

476 "`clipnorm` cannot be set when `gradient_transformers` " 

477 "is set. Instead, use the `gradient_transformers` to " 

478 "specify clipping and other transformations. Received: " 

479 f"val={val}, " 

480 f"gradient_transformers={self.gradient_transformers}." 

481 ) 

482 self._clipnorm = val 

483 self._clipnorm_fn = optimizer_utils.make_gradient_clipnorm_fn( 

484 self._clipnorm 

485 ) 

486 

487 @global_clipnorm.setter 

488 def global_clipnorm(self, val): 

489 if val is not None and self.gradient_transformers: 

490 raise ValueError( 

491 "`global_clipnorm` cannot be set when " 

492 "`gradient_transformers` " 

493 "is set. Instead, use the `gradient_transformers` to " 

494 "specify clipping and other transformations. Received: " 

495 f"val={val}, " 

496 f"gradient_transformers={self.gradient_transformers}." 

497 ) 

498 self._global_clipnorm = val 

499 self._global_clipnorm_fn = ( 

500 optimizer_utils.make_global_gradient_clipnorm_fn( 

501 self._global_clipnorm 

502 ) 

503 ) 

504 

505 @property 

506 def clipvalue(self): 

507 """`float` or `None`. If set, clips gradients to a maximum value.""" 

508 return self._clipvalue 

509 

510 @clipvalue.setter 

511 def clipvalue(self, val): 

512 if val is not None and self.gradient_transformers: 

513 raise ValueError( 

514 "`clipvalue` cannot be set when `gradient_transformers` " 

515 "is set. Instead, use the `gradient_transformers` to " 

516 "specify clipping and other transformations. Received: " 

517 f"val={val}, " 

518 f"gradient_transformers={self.gradient_transformers}." 

519 ) 

520 self._clipvalue = val 

521 self._clipvalue_fn = optimizer_utils.make_gradient_clipvalue_fn( 

522 self._clipvalue 

523 ) 

524 

525 def _transform_loss(self, loss): 

526 """Called in `.minimize` to transform loss before computing 

527 gradients.""" 

528 return loss 

529 

530 def _get_gradients(self, tape, loss, var_list, grad_loss=None): 

531 """Called in `minimize` to compute gradients from loss.""" 

532 grads = tape.gradient(loss, var_list, grad_loss) 

533 return list(zip(grads, var_list)) 

534 

535 def _transform_unaggregated_gradients(self, grads_and_vars): 

536 """Called in `apply_gradients` before gradient aggregation.""" 

537 return grads_and_vars 

538 

539 def _aggregate_gradients(self, grads_and_vars): 

540 """Called in `apply_gradients` to aggregate gradients across devices. 

541 

542 Note that user subclasses may override this, so the interface should not 

543 be changed. 

544 

545 Args: 

546 grads_and_vars: List of (gradient, variable) pairs. 

547 

548 Returns: 

549 A list of (aggregated_gradient, variable) pairs. By default, this 

550 calls `self.gradient_aggregator`. 

551 """ 

552 return self.gradient_aggregator(grads_and_vars) 

553 

554 def _transform_gradients(self, grads_and_vars): 

555 """Called in `apply_gradients` after aggregation.""" 

556 if self._clipvalue is not None: 

557 grads_and_vars = self._clipvalue_fn(grads_and_vars) 

558 if self._clipnorm is not None: 

559 grads_and_vars = self._clipnorm_fn(grads_and_vars) 

560 if self._global_clipnorm is not None: 

561 grads_and_vars = self._global_clipnorm_fn(grads_and_vars) 

562 

563 for fn in self.gradient_transformers: 

564 grads_and_vars = fn(grads_and_vars) 

565 return grads_and_vars 

566 

567 def minimize(self, loss, var_list, grad_loss=None, name=None, tape=None): 

568 """Minimize `loss` by updating `var_list`. 

569 

570 This method simply computes the gradients using `tf.GradientTape` and calls 

571 `apply_gradients()`. If you want to process the gradients before applying 

572 them, use `tf.GradientTape` and `apply_gradients()` explicitly instead 

573 of using this function. 

574 

575 Args: 

576 loss: `Tensor` or callable. If a callable, `loss` should take no 

577 arguments and return the value to minimize. If a `Tensor`, the 

578 `tape` argument must be passed. 

579 var_list: list or tuple of `Variable` objects to update to minimize 

580 `loss`, or a callable returning the list or tuple of `Variable` 

581 objects. Use a callable when the variable list would otherwise be 

582 incomplete before `minimize`, since the variables are created the 

583 first time `loss` is called. 

584 grad_loss: (Optional). A `Tensor` holding the gradient computed for 

585 `loss`. 

586 name: (Optional) str. Name for the returned operation. 

587 tape: (Optional) `tf.GradientTape`. If `loss` is provided as a 

588 `Tensor`, the tape that computed the `loss` must be provided. 

589 

590 Returns: 

591 An `Operation` that updates the variables in `var_list`. The 

592 `iterations` will be automatically increased by 1. 

593 

594 Raises: 

595 ValueError: If some of the variables are not `Variable` objects. 

596 

597 """ 

598 grads_and_vars = self._compute_gradients( 

599 loss, var_list=var_list, grad_loss=grad_loss, tape=tape 

600 ) 

601 return self.apply_gradients(grads_and_vars, name=name) 

602 
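# Illustrative use of `minimize()` above with a precomputed `Tensor` loss (not
# in the original source); the tape that produced the loss must then be passed
# explicitly:
#
#   var = tf.Variable(2.0)
#   opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
#   with tf.GradientTape() as tape:
#       loss = var ** 2
#   opt.minimize(loss, var_list=[var], tape=tape)  # var: 2.0 - 0.1 * 4.0 = 1.6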

603 def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None): 

604 """Compute gradients of `loss` for the variables in `var_list`. 

605 

606 This is the first part of `minimize()`. It returns a list 

607 of (gradient, variable) pairs where "gradient" is the gradient 

608 for "variable". Note that "gradient" can be a `Tensor`, an 

609 `IndexedSlices`, or `None` if there is no gradient for the 

610 given variable. 

611 

612 Args: 

613 loss: `Tensor` or callable. If a callable, `loss` should take no 

614 arguments and return the value to minimize. If a `Tensor`, the 

615 `tape` argument must be passed. 

616 var_list: list or tuple of `Variable` objects to update to minimize 

617 `loss`, or a callable returning the list or tuple of `Variable` 

618 objects. Use a callable when the variable list would otherwise be 

619 incomplete before `minimize`, since the variables are created the 

620 first time `loss` is called. 

621 grad_loss: Optional. A `Tensor` holding the gradient computed for 

622 `loss`. 

623 tape: (Optional) `tf.GradientTape`. If `loss` is provided as a 

624 `Tensor`, the tape that computed the `loss` must be provided. 

625 

626 Returns: 

627 A list of (gradient, variable) pairs. Variable is always present, but 

628 gradient can be `None`. 

629 

630 Raises: 

631 TypeError: If `var_list` contains anything other than `Variable` 

632 objects. 

633 ValueError: If some arguments are invalid, or var_list is None. 

634 """ 

635 # TODO(joshl): Test that we handle weight decay in a reasonable way. 

636 if not callable(loss) and tape is None: 

637 raise ValueError( 

638 "`tape` is required when a `Tensor` loss is passed. " 

639 f"Received: loss={loss}, tape={tape}." 

640 ) 

641 tape = tape if tape is not None else tf.GradientTape() 

642 

643 if callable(loss): 

644 with tape: 

645 if not callable(var_list): 

646 tape.watch(var_list) 

647 loss = loss() 

648 if callable(var_list): 

649 var_list = var_list() 

650 

651 with tape: 

652 loss = self._transform_loss(loss) 

653 

654 var_list = tf.nest.flatten(var_list) 

655 with tf.name_scope(self._name + "/gradients"): 

656 grads_and_vars = self._get_gradients( 

657 tape, loss, var_list, grad_loss 

658 ) 

659 

660 self._assert_valid_dtypes( 

661 [ 

662 v 

663 for g, v in grads_and_vars 

664 if g is not None and v.dtype != tf.resource 

665 ] 

666 ) 

667 

668 return grads_and_vars 

669 

670 def apply_gradients( 

671 self, grads_and_vars, name=None, experimental_aggregate_gradients=True 

672 ): 

673 """Apply gradients to variables. 

674 

675 This is the second part of `minimize()`. It returns an `Operation` that 

676 applies gradients. 

677 

678 The method sums gradients from all replicas in the presence of 

679 `tf.distribute.Strategy` by default. You can aggregate gradients 

680 yourself by passing `experimental_aggregate_gradients=False`. 

681 

682 Example: 

683 

684 ```python 

685 grads = tape.gradient(loss, vars) 

686 grads = tf.distribute.get_replica_context().all_reduce('sum', grads) 

687 # Processing aggregated gradients. 

688 optimizer.apply_gradients(zip(grads, vars), 

689 experimental_aggregate_gradients=False) 

690 

691 ``` 

692 

693 Args: 

694 grads_and_vars: List of (gradient, variable) pairs. 

695 name: Optional name for the returned operation. When `None`, uses the 

696 name passed to the `Optimizer` constructor. Defaults to `None`. 

697 experimental_aggregate_gradients: Whether to sum gradients from 

698 different replicas in the presence of `tf.distribute.Strategy`. If 

699 False, it is the user's responsibility to aggregate the gradients. 

700 Defaults to `True`. 

701 

702 Returns: 

703 An `Operation` that applies the specified gradients. The `iterations` 

704 will be automatically increased by 1. 

705 

706 Raises: 

707 TypeError: If `grads_and_vars` is malformed. 

708 ValueError: If none of the variables have gradients. 

709 RuntimeError: If called in a cross-replica context. 

710 """ 

711 grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars) 

712 var_list = [v for (_, v) in grads_and_vars] 

713 

714 with tf.name_scope(self._name): 

715 # Create iteration if necessary. 

716 with tf.init_scope(): 

717 self._create_all_weights(var_list) 

718 

719 if not grads_and_vars: 

720 # Distribution strategy does not support reducing an empty list 

721 # of gradients 

722 return tf.no_op() 

723 

724 if tf.distribute.in_cross_replica_context(): 

725 raise RuntimeError( 

726 "`apply_gradients()` cannot be called in cross-replica " 

727 "context. Use `tf.distribute.Strategy.run` to enter " 

728 "replica context. For more information, please see the " 

729 "docstring of `tf.distribute.get_replica_context`." 

730 ) 

731 

732 strategy = tf.distribute.get_strategy() 

733 if ( 

734 not experimental_aggregate_gradients 

735 and strategy 

736 and isinstance( 

737 strategy, 

738 ( 

739 tf.compat.v1.distribute.experimental.ParameterServerStrategy, # noqa: E501 

740 tf.distribute.experimental.ParameterServerStrategy, 

741 tf.distribute.experimental.CentralStorageStrategy, 

742 tf.compat.v1.distribute.experimental.CentralStorageStrategy, # noqa: E501 

743 ), 

744 ) 

745 ): 

746 raise NotImplementedError( 

747 "`experimental_aggregate_gradients=False` is not supported " 

748 "for ParameterServerStrategy and CentralStorageStrategy. " 

749 f"Used: strategy={strategy}." 

750 ) 

751 

752 apply_state = self._prepare(var_list) 

753 if experimental_aggregate_gradients: 

754 grads_and_vars = self._transform_unaggregated_gradients( 

755 grads_and_vars 

756 ) 

757 grads_and_vars = self._aggregate_gradients(grads_and_vars) 

758 grads_and_vars = self._transform_gradients(grads_and_vars) 

759 

760 return tf.__internal__.distribute.interim.maybe_merge_call( 

761 functools.partial( 

762 self._distributed_apply, apply_state=apply_state 

763 ), 

764 strategy, 

765 grads_and_vars, 

766 name=name, 

767 ) 

768 

769 def _distributed_apply( 

770 self, distribution, grads_and_vars, apply_state, name 

771 ): 

772 """`apply_gradients` using a `DistributionStrategy`.""" 

773 

774 def apply_grad_to_update_var(var, grad): 

775 """Apply gradient to variable.""" 

776 if isinstance(var, tf.Tensor): 

777 raise NotImplementedError( 

778 "Updating a `Tensor` is not implemented. " 

779 f"Received: var={var}." 

780 ) 

781 

782 apply_kwargs = {} 

783 if isinstance(grad, tf.IndexedSlices): 

784 if var.constraint is not None: 

785 raise RuntimeError( 

786 "Cannot use a constraint function on a sparse " 

787 f"variable. Received: grad={grad}, " 

788 f"var.constraint={var.constraint}." 

789 ) 

790 if "apply_state" in self._sparse_apply_args: 

791 apply_kwargs["apply_state"] = apply_state 

792 return self._resource_apply_sparse_duplicate_indices( 

793 grad.values, var, grad.indices, **apply_kwargs 

794 ) 

795 

796 if "apply_state" in self._dense_apply_args: 

797 apply_kwargs["apply_state"] = apply_state 

798 update_op = self._resource_apply_dense(grad, var, **apply_kwargs) 

799 if var.constraint is not None: 

800 with tf.control_dependencies([update_op]): 

801 return var.assign(var.constraint(var)) 

802 else: 

803 return update_op 

804 

805 eagerly_outside_functions = ( 

806 tf.compat.v1.executing_eagerly_outside_functions() 

807 ) 

808 update_ops = [] 

809 with name_scope_only_in_function_or_graph(name or self._name): 

810 for grad, var in grads_and_vars: 

811 # Colocate the update with variables to avoid unnecessary 

812 # communication delays. See b/136304694. 

813 with distribution.extended.colocate_vars_with(var): 

814 with name_scope_only_in_function_or_graph( 

815 "update" 

816 if eagerly_outside_functions 

817 else "update_" + var.op.name 

818 ): 

819 update_op = distribution.extended.update( 

820 var, 

821 apply_grad_to_update_var, 

822 args=(grad,), 

823 group=False, 

824 ) 

825 if tf.distribute.in_cross_replica_context(): 

826 # In cross-replica context, extended.update returns 

827 # a list of update ops from all replicas 

828 # (group=False). 

829 update_ops.extend(update_op) 

830 else: 

831 # In replica context, extended.update returns the 

832 # single update op of the current replica. 

833 update_ops.append(update_op) 

834 

835 any_symbolic = any( 

836 isinstance(i, tf.Operation) or tf_utils.is_symbolic_tensor(i) 

837 for i in update_ops 

838 ) 

839 if not tf.executing_eagerly() or any_symbolic: 

840 # If the current context is graph mode or any of the update ops 

841 # are symbolic then the step update should be carried out under 

842 # a graph context. (eager updates execute immediately) 

843 with backend._current_graph(update_ops).as_default(): 

844 with tf.control_dependencies([tf.group(update_ops)]): 

845 return self.iterations.assign_add(1, read_value=False) 

846 

847 return self.iterations.assign_add(1) 

848 

849 def get_gradients(self, loss, params): 

850 """Returns gradients of `loss` with respect to `params`. 

851 

852 Should be used only in legacy v1 graph mode. 

853 

854 Args: 

855 loss: Loss tensor. 

856 params: List of variables. 

857 

858 Returns: 

859 List of gradient tensors. 

860 

861 Raises: 

862 ValueError: In case any gradient cannot be computed (e.g. if gradient 

863 function not implemented). 

864 """ 

865 params = tf.nest.flatten(params) 

866 with backend.get_graph().as_default(), backend.name_scope( 

867 self._name + "/gradients" 

868 ): 

869 grads = tf.compat.v1.gradients(loss, params) 

870 for grad, param in zip(grads, params): 

871 if grad is None: 

872 raise ValueError( 

873 "Variable {} has `None` for gradient. " 

874 "Please make sure that all of your ops have a " 

875 "gradient defined (i.e. are differentiable). " 

876 "Common ops without gradient: " 

877 "K.argmax, K.round, K.eval.".format(param) 

878 ) 

879 return grads 

880 

881 def get_updates(self, loss, params): 

882 grads = self.get_gradients(loss, params) 

883 grads_and_vars = list(zip(grads, params)) 

884 self._assert_valid_dtypes( 

885 [ 

886 v 

887 for g, v in grads_and_vars 

888 if g is not None and v.dtype != tf.resource 

889 ] 

890 ) 

891 return [self.apply_gradients(grads_and_vars)] 

892 

893 def _set_hyper(self, name, value): 

894 """Set hyper `name` to `value`, which can be a callable, tensor, or numeric.""" 

895 if isinstance(value, tf.__internal__.tracking.Trackable): 

896 self._track_trackable(value, name, overwrite=True) 

897 if name not in self._hyper: 

898 self._hyper[name] = value 

899 else: 

900 prev_value = self._hyper[name] 

901 if ( 

902 callable(prev_value) 

903 or isinstance( 

904 prev_value, 

905 ( 

906 tf.Tensor, 

907 int, 

908 float, 

909 learning_rate_schedule.LearningRateSchedule, 

910 ), 

911 ) 

912 or isinstance( 

913 value, learning_rate_schedule.LearningRateSchedule 

914 ) 

915 ): 

916 self._hyper[name] = value 

917 else: 

918 backend.set_value(self._hyper[name], value) 

919 

920 def _get_hyper(self, name, dtype=None): 

921 if not self._hypers_created: 

922 self._create_hypers() 

923 value = self._hyper[name] 

924 if isinstance(value, learning_rate_schedule.LearningRateSchedule): 

925 return value 

926 if callable(value): 

927 value = value() 

928 if dtype: 

929 return tf.cast(value, dtype) 

930 else: 

931 return value 

932 

933 def _create_slots(self, var_list): 

934 pass 

935 

936 def _create_slots_for_sharded_variables(self, var_list): 

937 """Add ShardedVariables to slots to later reconstruct for checkpointing. 

938 

939 ShardedVariables don't have slot variables created for them; their 

940 shards do. This function allows users to call get_slot with a 

941 ShardedVariable input and receive a ShardedVariable output containing 

942 the appropriate slot vars. 

943 

944 Iterate over the variables to find shards, and aggregate the sharded 

945 containers in a set. Add these ShardedVariables to _slots so that 

946 get_slot can retrieve the proper slot variables for their component 

947 shards, and reconstruct those into a ShardedVariable. 

948 

949 Args: 

950 var_list: list or tuple of `Variable` objects that will be minimized 

951 using this optimizer. 

952 """ 

953 sharded_vars = set() 

954 for var in var_list: 

955 if getattr(var, "_sharded_container", False): 

956 sharded_vars.add(var._sharded_container()) 

957 

958 for sharded_var in sharded_vars: 

959 sharded_key = _var_key(sharded_var) 

960 slot_dict = {} 

961 for slot in self.get_slot_names(): 

962 slot_dict[slot] = sharded_var 

963 self._slots[sharded_key] = slot_dict 

964 

965 def _create_all_weights(self, var_list): 

966 """Creates all weights, including iterations, hyperparameters and slot 

967 vars. 

968 

969 This will add newly created variables to `optimizer.weights`. 

970 

971 New variables are only created when this method is called the first 

972 time, or when called with different variables in the var_list. 

973 

974 Args: 

975 var_list: list or tuple of `Variable` objects that will be minimized 

976 using this optimizer. 

977 """ 

978 

979 _ = self.iterations 

980 self._create_hypers() 

981 self._create_slots(var_list) 

982 self._create_slots_for_sharded_variables(var_list) 

983 

984 def __getattribute__(self, name): 

985 """Overridden to support hyperparameter access.""" 

986 try: 

987 return super().__getattribute__(name) 

988 except AttributeError as e: 

989 # Needed to avoid infinite recursion with __setattr__. 

990 if name == "_hyper": 

991 raise e 

992 # Backwards compatibility with Keras optimizers. 

993 if name == "lr": 

994 name = "learning_rate" 

995 if name in self._hyper: 

996 return self._get_hyper(name) 

997 raise e 

998 

999 def __dir__(self): 

1000 result = set(super().__dir__()) 

1001 if "_hyper" in result: 

1002 result |= self._hyper.keys() 

1003 if "learning_rate" in self._hyper.keys(): 

1004 result.add("lr") 

1005 return list(result) 

1006 

1007 def __setattr__(self, name, value): 

1008 """Override setattr to support dynamic hyperparameter setting.""" 

1009 # Backwards compatibility with Keras optimizers. 

1010 if name == "lr": 

1011 name = "learning_rate" 

1012 if hasattr(self, "_hyper") and name in self._hyper: 

1013 self._set_hyper(name, value) 

1014 else: 

1015 super().__setattr__(name, value) 

1016 
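# Illustrative example (not in the original source) of the hyperparameter
# attribute access implemented by `__getattribute__` and `__setattr__` above,
# including the legacy `lr` alias:
#
#   opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
#   opt.lr          # resolves to the "learning_rate" hyper (a variable holding 0.1)
#   opt.lr = 0.05   # same as `opt.learning_rate = 0.05`; updates the hyper in place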

1017 def get_slot_names(self): 

1018 """A list of names for this optimizer's slots.""" 

1019 return self._slot_names 

1020 

1021 def add_slot(self, var, slot_name, initializer="zeros", shape=None): 

1022 """Add a new slot variable for `var`. 

1023 

1024 A slot variable is an additional variable associated with `var` to 

1025 train. It is allocated and managed by optimizers, e.g. `Adam`. 

1026 

1027 Args: 

1028 var: a `Variable` object. 

1029 slot_name: name of the slot variable. 

1030 initializer: initializer of the slot variable 

1031 shape: (Optional) shape of the slot variable. If not set, it will 

1032 default to the shape of `var`. 

1033 

1034 Returns: 

1035 A slot variable. 

1036 """ 

1037 if slot_name not in self._slot_names: 

1038 self._slot_names.append(slot_name) 

1039 var_key = _var_key(var) 

1040 slot_dict = self._slots.setdefault(var_key, {}) 

1041 weight = slot_dict.get(slot_name, None) 

1042 if weight is None: 

1043 if isinstance(initializer, str) or callable(initializer): 

1044 initializer = initializers.get(initializer) 

1045 if isinstance( 

1046 initializer, 

1047 tf.__internal__.tracking.CheckpointInitialValueCallable, 

1048 ) or (shape is not None): 

1049 slot_shape = shape 

1050 else: 

1051 slot_shape = var.shape 

1052 initial_value = functools.partial( 

1053 initializer, shape=slot_shape, dtype=var.dtype 

1054 ) 

1055 else: 

1056 initial_value = initializer 

1057 

1058 with self._distribution_strategy_scope(): 

1059 strategy = tf.distribute.get_strategy() 

1060 if not strategy.extended.variable_created_in_scope(var): 

1061 raise ValueError( 

1062 "Trying to create optimizer slot variable under the " 

1063 "scope for tf.distribute.Strategy ({}), which is " 

1064 "different from the scope used for the original " 

1065 "variable ({}). Make sure the slot variables are " 

1066 "created under the same strategy scope. This may " 

1067 "happen if you're restoring from a checkpoint " 

1068 "outside the scope.".format(strategy, var) 

1069 ) 

1070 

1071 with strategy.extended.colocate_vars_with(var): 

1072 weight = tf.Variable( 

1073 name=f"{var._shared_name}/{slot_name}", 

1074 dtype=var.dtype, 

1075 trainable=False, 

1076 initial_value=initial_value, 

1077 ) 

1078 backend.track_variable(weight) 

1079 slot_dict[slot_name] = weight 

1080 self._restore_slot_variable( 

1081 slot_name=slot_name, variable=var, slot_variable=weight 

1082 ) 

1083 self._weights.append(weight) 

1084 return weight 

1085 

1086 def get_slot(self, var, slot_name): 

1087 var_key = _var_key(var) 

1088 slot_dict = self._slots[var_key] 

1089 slot_variable = slot_dict[slot_name] 

1090 if isinstance( 

1091 slot_variable, tf.__internal__.distribute.ShardedVariable 

1092 ): 

1093 # Construct a ShardedVariable that points to the input 

1094 # ShardedVariable's component shard's slot variables. 

1095 shard_vars = [] 

1096 for shard in slot_variable.variables: 

1097 slot_shard = self.get_slot(shard, slot_name) 

1098 shard_vars.append(slot_shard) 

1099 slot_variable = tf.__internal__.distribute.ShardedVariable( 

1100 shard_vars, name=slot_variable.name 

1101 ) 

1102 return slot_variable 

1103 
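# Illustrative sketch (not in the original source) of how a subclass typically
# pairs `add_slot` with `get_slot` above. `MyMomentum` is hypothetical and uses
# fixed coefficients for brevity:
#
#   class MyMomentum(OptimizerV2):
#       def _create_slots(self, var_list):
#           for var in var_list:
#               self.add_slot(var, "momentum")
#
#       def _resource_apply_dense(self, grad, var, apply_state=None):
#           m = self.get_slot(var, "momentum")
#           m_t = m.assign(0.9 * m + grad, use_locking=self._use_locking)
#           return var.assign_sub(0.01 * m_t, use_locking=self._use_locking)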

1104 def _prepare(self, var_list): 

1105 keys = set() 

1106 for var in var_list: 

1107 if isinstance(var, tf.distribute.DistributedValues): 

1108 var_devices = var._devices 

1109 else: 

1110 var_devices = [var.device] 

1111 var_dtype = var.dtype.base_dtype 

1112 for var_device in var_devices: 

1113 keys.add((var_device, var_dtype)) 

1114 

1115 apply_state = {} 

1116 for var_device, var_dtype in keys: 

1117 apply_state[(var_device, var_dtype)] = {} 

1118 with tf.device(var_device): 

1119 self._prepare_local(var_device, var_dtype, apply_state) 

1120 

1121 return apply_state 

1122 

1123 def _prepare_local(self, var_device, var_dtype, apply_state): 

1124 if "learning_rate" in self._hyper: 

1125 lr_t = tf.identity(self._decayed_lr(var_dtype)) 

1126 apply_state[(var_device, var_dtype)]["lr_t"] = lr_t 

1127 

1128 def _fallback_apply_state(self, var_device, var_dtype): 

1129 """Compatibility for subclasses that don't pass apply_state through.""" 

1130 apply_state = {(var_device, var_dtype): {}} 

1131 self._prepare_local(var_device, var_dtype, apply_state) 

1132 return apply_state[(var_device, var_dtype)] 

1133 
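# Illustrative sketch (not in the original source) of how a subclass with a
# "momentum" hyperparameter might extend `_prepare_local` to cache
# per-(device, dtype) constants in `apply_state`:
#
#   def _prepare_local(self, var_device, var_dtype, apply_state):
#       super()._prepare_local(var_device, var_dtype, apply_state)
#       apply_state[(var_device, var_dtype)]["momentum"] = tf.identity(
#           self._get_hyper("momentum", var_dtype)
#       )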

1134 def _create_hypers(self): 

1135 if self._hypers_created: 

1136 return 

1137 with self._distribution_strategy_scope(): 

1138 # Iterate hyper values deterministically. 

1139 for name, value in sorted(self._hyper.items()): 

1140 if isinstance(value, (tf.Tensor, tf.Variable)) or callable( 

1141 value 

1142 ): 

1143 # The check for `callable` covers the usage when `value` is 

1144 # a `LearningRateSchedule`, in which case it does not need 

1145 # to create a variable. 

1146 continue 

1147 else: 

1148 self._hyper[name] = self.add_weight( 

1149 name, 

1150 shape=[], 

1151 trainable=False, 

1152 initializer=value, 

1153 aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, 

1154 ) 

1155 self._hypers_created = True 

1156 

1157 @property 

1158 def iterations(self): 

1159 """Variable. The number of training steps this Optimizer has run.""" 

1160 if self._iterations is None: 

1161 with self._distribution_strategy_scope(): 

1162 self._iterations = self.add_weight( 

1163 "iter", 

1164 shape=[], 

1165 dtype=tf.int64, 

1166 trainable=False, 

1167 aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, 

1168 ) 

1169 self._weights.append(self._iterations) 

1170 return self._iterations 

1171 

1172 @iterations.setter 

1173 def iterations(self, variable): 

1174 if self._iterations is not None: 

1175 raise RuntimeError( 

1176 "Cannot set `iterations` to a new Variable after " 

1177 "the Optimizer weights have been created. Here it is " 

1178 f"attempting to set `iterations` to {variable}." 

1179 ) 

1180 self._iterations = variable 

1181 self._weights.append(self._iterations) 

1182 

1183 def _decayed_lr(self, var_dtype): 

1184 """Get decayed learning rate as a Tensor with dtype=var_dtype.""" 

1185 lr_t = self._get_hyper("learning_rate", var_dtype) 

1186 if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule): 

1187 local_step = tf.cast(self.iterations, var_dtype) 

1188 lr_t = tf.cast(lr_t(local_step), var_dtype) 

1189 if self._initial_decay > 0.0: 

1190 local_step = tf.cast(self.iterations, var_dtype) 

1191 decay_t = tf.cast(self._initial_decay, var_dtype) 

1192 lr_t = lr_t / (1.0 + decay_t * local_step) 

1193 return lr_t 

1194 
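# Illustrative note (not in the original source): with the legacy `decay`
# kwarg, `_decayed_lr` applies inverse time decay. For example,
# learning_rate=0.1 and decay=0.01 give 0.1 / (1 + 0.01 * 100) = 0.05 at
# iteration 100.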

1195 @abc.abstractmethod 

1196 def get_config(self): 

1197 """Returns the config of the optimizer. 

1198 

1199 An optimizer config is a Python dictionary (serializable) 

1200 containing the configuration of an optimizer. 

1201 The same optimizer can be reinstantiated later 

1202 (without any saved state) from this configuration. 

1203 

1204 Returns: 

1205 Python dictionary. 

1206 """ 

1207 config = {"name": self._name} 

1208 if self.clipnorm is not None: 

1209 config["clipnorm"] = self.clipnorm 

1210 if self.clipvalue is not None: 

1211 config["clipvalue"] = self.clipvalue 

1212 if self.global_clipnorm is not None: 

1213 config["global_clipnorm"] = self.global_clipnorm 

1214 return config 

1215 

1216 @classmethod 

1217 def from_config(cls, config, custom_objects=None): 

1218 """Creates an optimizer from its config. 

1219 

1220 This method is the reverse of `get_config`, 

1221 capable of instantiating the same optimizer from the config 

1222 dictionary. 

1223 

1224 Args: 

1225 config: A Python dictionary, typically the output of get_config. 

1226 custom_objects: A Python dictionary mapping names to additional 

1227 Python objects used to create this optimizer, such as a function 

1228 used for a hyperparameter. 

1229 

1230 Returns: 

1231 An optimizer instance. 

1232 """ 

1233 if "lr" in config: 

1234 config["learning_rate"] = config.pop("lr") 

1235 if "learning_rate" in config: 

1236 if isinstance(config["learning_rate"], dict): 

1237 config["learning_rate"] = learning_rate_schedule.deserialize( 

1238 config["learning_rate"], custom_objects=custom_objects 

1239 ) 

1240 return cls(**config) 

1241 
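# Illustrative config round-trip (not in the original source) using
# `get_config`/`from_config` above:
#
#   opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1, momentum=0.9)
#   config = opt.get_config()
#   restored = tf.keras.optimizers.legacy.SGD.from_config(config)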

1242 def _serialize_hyperparameter(self, hyperparameter_name): 

1243 """Serialize a hyperparameter that can be a float, callable, or 

1244 Tensor.""" 

1245 value = self._hyper[hyperparameter_name] 

1246 if isinstance(value, learning_rate_schedule.LearningRateSchedule): 

1247 return learning_rate_schedule.serialize(value) 

1248 if callable(value): 

1249 return value() 

1250 if tf.is_tensor(value): 

1251 return backend.get_value(value) 

1252 return value 

1253 

1254 def variables(self): 

1255 """Returns variables of this Optimizer based on the order created.""" 

1256 return self._weights 

1257 

1258 @property 

1259 def weights(self): 

1260 """Returns variables of this Optimizer based on the order created.""" 

1261 return self._weights 

1262 

1263 def get_weights(self): 

1264 """Returns the current weights of the optimizer. 

1265 

1266 The weights of an optimizer are its state (i.e., variables). 

1267 This function returns the weight values associated with this 

1268 optimizer as a list of Numpy arrays. The first value is always the 

1269 iterations count of the optimizer, followed by the optimizer's state 

1270 variables in the order they were created. The returned list can in turn 

1271 be used to load state into similarly parameterized optimizers. 

1272 

1273 For example, the RMSprop optimizer for this simple model returns a list 

1274 of three values: the iteration count, followed by the root-mean-square 

1275 value of the kernel and bias of the single Dense layer: 

1276 

1277 >>> opt = tf.keras.optimizers.legacy.RMSprop() 

1278 >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)]) 

1279 >>> m.compile(opt, loss='mse') 

1280 >>> data = np.arange(100).reshape(5, 20) 

1281 >>> labels = np.zeros(5) 

1282 >>> results = m.fit(data, labels) # Training. 

1283 >>> len(opt.get_weights()) 

1284 3 

1285 

1286 Returns: 

1287 Weights values as a list of numpy arrays. 

1288 """ 

1289 params = self.weights 

1290 return backend.batch_get_value(params) 

1291 

1292 # TODO(tanzheny): Maybe share this logic with base_layer. 

1293 def set_weights(self, weights): 

1294 """Set the weights of the optimizer. 

1295 

1296 The weights of an optimizer are its state (i.e., variables). 

1297 This function takes the weight values associated with this 

1298 optimizer as a list of Numpy arrays. The first value is always the 

1299 iterations count of the optimizer, followed by the optimizer's state 

1300 variables in the order they are created. The passed values are used to 

1301 set the new state of the optimizer. 

1302 

1303 For example, the RMSprop optimizer for this simple model takes a list of 

1304 three values: the iteration count, followed by the root-mean-square 

1305 value of the kernel and bias of the single Dense layer: 

1306 

1307 >>> opt = tf.keras.optimizers.legacy.RMSprop() 

1308 >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)]) 

1309 >>> m.compile(opt, loss='mse') 

1310 >>> data = np.arange(100).reshape(5, 20) 

1311 >>> labels = np.zeros(5) 

1312 >>> results = m.fit(data, labels) # Training. 

1313 >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])] 

1314 >>> opt.set_weights(new_weights) 

1315 >>> opt.iterations 

1316 <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10> 

1317 

1318 Args: 

1319 weights: weight values as a list of numpy arrays. 

1320 """ 

1321 params = self.weights 

1322 if len(params) != len(weights): 

1323 raise ValueError( 

1324 f"You called `set_weights(weights)` on optimizer {self._name} " 

1325 f"with a weight list of length {str(len(weights))}, " 

1326 f"but the optimizer was expecting {str(len(params))} " 

1327 f"weights. Provided weights: {str(weights)[:50]}..." 

1328 ) 

1329 if not params: 

1330 return 

1331 weight_value_tuples = [] 

1332 param_values = backend.batch_get_value(params) 

1333 for pv, p, w in zip(param_values, params, weights): 

1334 if pv.shape != w.shape: 

1335 raise ValueError( 

1336 f"Optimizer weight shape {str(pv.shape)} " 

1337 "not compatible with " 

1338 f"provided weight shape {str(w.shape)}." 

1339 ) 

1340 weight_value_tuples.append((p, w)) 

1341 backend.batch_set_value(weight_value_tuples) 

1342 

1343 def add_weight( 

1344 self, 

1345 name, 

1346 shape, 

1347 dtype=None, 

1348 initializer="zeros", 

1349 trainable=None, 

1350 synchronization=tf.VariableSynchronization.AUTO, 

1351 aggregation=tf.VariableAggregation.NONE, 

1352 ): 

1353 

1354 if dtype is None: 

1355 dtype = tf.float32 

1356 if isinstance(initializer, str) or callable(initializer): 

1357 initializer = initializers.get(initializer) 

1358 

1359 if synchronization == tf.VariableSynchronization.ON_READ: 

1360 if trainable: 

1361 raise ValueError( 

1362 "Synchronization value can be set to " 

1363 "VariableSynchronization.ON_READ only for non-trainable " 

1364 "variables. You have specified trainable=True and " 

1365 "synchronization=VariableSynchronization.ON_READ." 

1366 ) 

1367 else: 

1368 # Set trainable to be false when variable is to be synced on 

1369 # read. 

1370 trainable = False 

1371 elif trainable is None: 

1372 trainable = True 

1373 

1374 variable = self._add_variable_with_custom_getter( 

1375 name=name, 

1376 shape=shape, 

1377 getter=base_layer_utils.make_variable, 

1378 overwrite=True, 

1379 initializer=initializer, 

1380 dtype=dtype, 

1381 trainable=trainable, 

1382 use_resource=True, 

1383 synchronization=synchronization, 

1384 aggregation=aggregation, 

1385 ) 

1386 backend.track_variable(variable) 

1387 

1388 return variable 

1389 

1390 def _init_set_name(self, name, zero_based=True): 

1391 if not name: 

1392 self._name = backend.unique_object_name( 

1393 generic_utils.to_snake_case(self.__class__.__name__), 

1394 zero_based=zero_based, 

1395 ) 

1396 else: 

1397 self._name = name 

1398 

1399 def _assert_valid_dtypes(self, tensors): 

1400 """Asserts tensors are all valid types (see `_valid_dtypes`). 

1401 

1402 Args: 

1403 tensors: Tensors to check. 

1404 

1405 Raises: 

1406 ValueError: If any tensor is not a valid type. 

1407 """ 

1408 valid_dtypes = self._valid_dtypes() 

1409 for t in tensors: 

1410 dtype = t.dtype.base_dtype 

1411 if dtype not in valid_dtypes: 

1412 raise ValueError( 

1413 "Invalid type {} for {}, expected: {}.".format( 

1414 dtype, t.name, [v for v in valid_dtypes] 

1415 ) 

1416 ) 

1417 

1418 def _valid_dtypes(self): 

1419 """Valid types for loss, variables and gradients. 

1420 

1421 Subclasses should override to allow other float types. 

1422 

1423 Returns: 

1424 Valid types for loss, variables and gradients. 

1425 """ 

1426 return _DEFAULT_VALID_DTYPES 

1427 

1428 def _call_if_callable(self, param): 

1429 """Call the function if param is callable.""" 

1430 return param() if callable(param) else param 

1431 

1432 def _resource_apply_dense(self, grad, handle, apply_state): 

1433 """Add ops to apply dense gradients to the variable `handle`. 

1434 

1435 Args: 

1436 grad: a `Tensor` representing the gradient. 

1437 handle: a `Tensor` of dtype `resource` which points to the variable to 

1438 be updated. 

1439 apply_state: A dict which is used across multiple apply calls. 

1440 

1441 Returns: 

1442 An `Operation` which updates the value of the variable. 

1443 """ 

1444 raise NotImplementedError( 

1445 "`_resource_apply_dense` must be implemented in subclasses." 

1446 ) 

1447 

1448 def _resource_apply_sparse_duplicate_indices( 

1449 self, grad, handle, indices, **kwargs 

1450 ): 

1451 """Add ops to apply sparse gradients to `handle`, with repeated indices. 

1452 

1453 Optimizers which override this method must deal with repeated indices. 

1454 See the docstring of `_apply_sparse_duplicate_indices` for details. By 

1455 default the correct behavior, to sum non-unique indices and their 

1456 associated gradients, is enforced by first pre-processing `grad` and 

1457 `indices` and passing them on to `_resource_apply_sparse`. Optimizers 

1458 which deal correctly with duplicate indices may instead override this 

1459 method to avoid the overhead of summing. 

1460 

1461 Args: 

1462 grad: a `Tensor` representing the gradient for the affected indices. 

1463 handle: a `Tensor` of dtype `resource` which points to the variable to 

1464 be updated. 

1465 indices: a `Tensor` of integral type representing the indices for 

1466 which the gradient is nonzero. Indices may be repeated. 

1467 **kwargs: May optionally contain `apply_state` 

1468 

1469 Returns: 

1470 An `Operation` which updates the value of the variable. 

1471 """ 

1472 summed_grad, unique_indices = _deduplicate_indexed_slices( 

1473 values=grad, indices=indices 

1474 ) 

1475 return self._resource_apply_sparse( 

1476 summed_grad, handle, unique_indices, **kwargs 

1477 ) 

1478 

1479 def _resource_apply_sparse(self, grad, handle, indices, apply_state): 

1480 """Add ops to apply sparse gradients to the variable `handle`. 

1481 

1482 Similar to `_apply_sparse`, the `indices` argument to this method has 

1483 been de-duplicated. Optimizers which deal correctly with non-unique 

1484 indices may instead override `_resource_apply_sparse_duplicate_indices` 

1485 to avoid this overhead. 

1486 

1487 Args: 

1488 grad: a `Tensor` representing the gradient for the affected indices. 

1489 handle: a `Tensor` of dtype `resource` which points to the variable to 

1490 be updated. 

1491 indices: a `Tensor` of integral type representing the indices for 

1492 which the gradient is nonzero. Indices are unique. 

1493 apply_state: A dict which is used across multiple apply calls. 

1494 

1495 Returns: 

1496 An `Operation` which updates the value of the variable. 

1497 """ 

1498 raise NotImplementedError( 

1499 "`_resource_apply_sparse` must be implemented in subclasses." 

1500 ) 

1501 

1502 def _resource_scatter_add(self, x, i, v): 

1503 with tf.control_dependencies( 

1504 [ 

1505 tf.raw_ops.ResourceScatterAdd( 

1506 resource=x.handle, indices=i, updates=v 

1507 ) 

1508 ] 

1509 ): 

1510 return x.value() 

1511 

1512 def _resource_scatter_update(self, x, i, v): 

1513 with tf.control_dependencies( 

1514 [ 

1515 tf.raw_ops.ResourceScatterUpdate( 

1516 resource=x.handle, indices=i, updates=v 

1517 ) 

1518 ] 

1519 ): 

1520 return x.value() 

1521 

1522 @property 

1523 @layer_utils.cached_per_instance 

1524 def _dense_apply_args(self): 

1525 return tf_inspect.getfullargspec(self._resource_apply_dense).args 

1526 

1527 @property 

1528 @layer_utils.cached_per_instance 

1529 def _sparse_apply_args(self): 

1530 return tf_inspect.getfullargspec(self._resource_apply_sparse).args 

1531 

1532 # --------------- 

1533 # For implementing the trackable interface 

1534 # --------------- 

1535 

1536 def _restore_slot_variable(self, slot_name, variable, slot_variable): 

1537 """Restore a newly created slot variable's value.""" 

1538 variable_key = _var_key(variable) 

1539 deferred_restorations = self._deferred_slot_restorations.get( 

1540 slot_name, {} 

1541 ).pop(variable_key, []) 

1542 # Iterate over restores, highest restore UID first to minimize the 

1543 # number of assignments. 

1544 deferred_restorations.sort( 

1545 key=lambda position: position.restore_uid, reverse=True 

1546 ) 

1547 for checkpoint_position in deferred_restorations: 

1548 checkpoint_position.restore(slot_variable) 

1549 

1550 def _create_or_restore_slot_variable( 

1551 self, slot_variable_position, slot_name, variable 

1552 ): 

1553 """Returns the slot variable that should have a value restored into it. 

1554 

1555 It is up to the caller to restore the value into the slot variable if a 

1556 valid slot variable is returned. 

1557 

1558 Called when a variable which has an associated slot variable is created 

1559 or restored. When executing eagerly, we create the slot variable with a 

1560 restoring initializer. 

1561 

1562 No new variables are created when graph building. Instead, 

1563 _restore_slot_variable catches these after normal creation and adds 

1564 restore ops to the graph. This method is nonetheless important when 

1565 graph building for the case when a slot variable has already been 

1566 created but `variable` has just been added to a dependency graph 

1567 (causing us to realize that the slot variable needs to be restored). 

1568 

1569 Args: 

1570 slot_variable_position: A `trackable._CheckpointPosition` object 

1571 indicating the slot variable `Trackable` object to be restored. 

1572 slot_name: The name of this `Optimizer`'s slot to restore into. 

1573 variable: The variable object this slot is being created for. 

1574 

1575 Returns: 

1576 A slot variable that should have a value restored into it, or None if 

1577 a slot variable should not be restored at this time. 

1578 """ 

1579 variable_key = _var_key(variable) 

1580 slot_dict = self._slots.get(variable_key, {}) 

1581 slot_variable = slot_dict.get(slot_name, None) 

1582 if ( 

1583 slot_variable is None 

1584 and tf.executing_eagerly() 

1585 and slot_variable_position.is_simple_variable() 

1586 # Defer slot variable creation if there is an active variable 

1587 # creator scope. Generally we'd like to eagerly create/restore slot 

1588 # variables when possible, but this may mean that scopes intended to 

1589 # catch `variable` also catch its eagerly created slot variable 

1590 # unintentionally (specifically make_template would add a dependency 

1591 # on a slot variable if not for this case). Deferring is mostly 

1592 # harmless (aside from double initialization), and makes variable 

1593 # creator scopes behave the same way they do when graph building. 

1594 # 

1595 # One notable case is with distribution strategy, which uses 

1596 # variable creator scope but always desires the `variable` and the 

1597 # slot to use the same scope, thus we can safely eagerly 

1598 # create/restore slot variables. 

1599 and ( 

1600 not tf.compat.v1.get_default_graph()._variable_creator_stack 

1601 or self._distribution_strategy 

1602 ) 

1603 ): 

1604 initializer = ( 

1605 tf.__internal__.tracking.CheckpointInitialValueCallable( 

1606 checkpoint_position=slot_variable_position 

1607 ) 

1608 ) 

1609 slot_variable = self.add_slot( 

1610 var=variable, 

1611 initializer=initializer, 

1612 slot_name=slot_name, 

1613 shape=slot_variable_position.value_shape(), 

1614 ) 

1615 # Slot variables are not owned by any one object (because we don't 

1616 # want to save the slot variable if the optimizer is saved without 

1617 # the non-slot variable, or if the non-slot variable is saved 

1618 # without the optimizer; it's a dependency hypergraph with edges of 

1619 # the form (optimizer, non-slot variable, variable)). So we don't 

1620 # _track_ slot variables anywhere, and instead special-case this 

1621 # dependency and otherwise pretend it's a normal graph. 

1622 if slot_variable is not None: 

1623 # For sharded variables, we need the logic in get_slot to combine 

1624 # slot variables for its shards 

1625 if (slot_variable is variable) and ( 

1626 isinstance(variable, tf.__internal__.distribute.ShardedVariable) 

1627 ): 

1628 return self.get_slot(variable, slot_name) 

1629 # If we've either made this slot variable, or if we've pulled out an 

1630 # existing slot variable, we should restore it. 

1631 return slot_variable 

1632 else: 

1633 # We didn't make the slot variable. Defer restoring until it gets 

1634 # created normally. We keep a list rather than the one with the 

1635 # highest restore UID in case slot variables have their own 

1636 # dependencies, in which case those could differ between restores. 

1637 self._deferred_slot_restorations.setdefault( 

1638 slot_name, {} 

1639 ).setdefault(variable_key, []).append(slot_variable_position) 

1640 return None 

1641 
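    # [Editorial illustration, not part of the original module] A hedged
    # end-to-end sketch of the behavior the two trackable hooks above support:
    # slot variables round-trip through `tf.train.Checkpoint` even when the
    # optimizer is restored before its slots exist. The legacy-optimizer path
    # `tf.keras.optimizers.legacy.SGD` and the checkpoint prefix are
    # assumptions made for this example:
    #
    #     opt = tf.keras.optimizers.legacy.SGD(momentum=0.9)
    #     var = tf.Variable(1.0)
    #     opt.apply_gradients([(tf.constant(0.5), var)])  # creates "momentum" slot
    #     path = tf.train.Checkpoint(opt=opt, var=var).save("/tmp/opt_ckpt")
    #
    #     new_opt = tf.keras.optimizers.legacy.SGD(momentum=0.9)
    #     new_var = tf.Variable(0.0)
    #     tf.train.Checkpoint(opt=new_opt, var=new_var).restore(path)
    #     # Eagerly, _create_or_restore_slot_variable creates the slot with a
    #     # restoring initializer; otherwise the restore is deferred until the
    #     # slot is created normally and picked up by _restore_slot_variable.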

1642 @contextlib.contextmanager 

1643 def _distribution_strategy_scope(self): 

1644 """Returns the `tf.distribute.Strategy` this optimizer was created 

1645 under.""" 

1646 if self._distribution_strategy and not tf.distribute.has_strategy(): 

1647 with self._distribution_strategy.scope(): 

1648 yield self._distribution_strategy.scope() 

1649 else: 

1650 yield 

1651 

1652 
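# [Editorial illustration, not part of the original module] A minimal,
# hypothetical subclass showing how the abstract `_resource_apply_dense` /
# `_resource_apply_sparse` contract documented above is typically filled in.
# It reuses `tf` and `OptimizerV2` from this module; plain SGD is used only to
# keep the sketch short, and it is not the implementation shipped by Keras.
class _ExampleSGD(OptimizerV2):
    """Editorial sketch: plain SGD built on the OptimizerV2 hooks above."""

    def __init__(self, learning_rate=0.01, name="ExampleSGD", **kwargs):
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", learning_rate)

    def _resource_apply_dense(self, grad, handle, apply_state=None):
        # var -= lr * grad, applied to the whole variable.
        lr_t = self._get_hyper("learning_rate", handle.dtype.base_dtype)
        return handle.assign_sub(lr_t * grad)

    def _resource_apply_sparse(self, grad, handle, indices, apply_state=None):
        # `indices` arrive de-duplicated (see
        # `_resource_apply_sparse_duplicate_indices`), so a single scatter-add
        # suffices: var[indices] -= lr * grad.
        lr_t = self._get_hyper("learning_rate", handle.dtype.base_dtype)
        return self._resource_scatter_add(handle, indices, -lr_t * grad)

    def get_config(self):
        config = super().get_config()
        config.update(
            {"learning_rate": self._serialize_hyperparameter("learning_rate")}
        )
        return config


# Using scatter-add for the sparse path (rather than densifying the gradient)
# keeps the update cost proportional to the number of non-zero rows, which is
# the main reason the dense and sparse hooks are kept separate.
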

1653def _var_key(var): 

1654 """Key for representing a primary variable, for looking up slots. 

1655 

1656 In graph mode the name is derived from the var shared name. 

1657 In eager mode the name is derived from the var unique id. 

1658 If a distribution strategy exists, the primary variable is looked up first. 

1659 

1660 Args: 

1661 var: the variable. 

1662 

1663 Returns: 

1664 the unique name of the variable. 

1665 """ 

1666 

1667 # Get the distributed variable if it exists. 

1668 if hasattr(var, "_distributed_container"): 

1669 var = var._distributed_container() 

1670 elif ( 

1671 tf_utils.is_extension_type(var) 

1672 and hasattr(var, "handle") 

1673 and hasattr(var.handle, "_distributed_container") 

1674 ): 

1675 # For ResourceVariables, the _distributed_container attribute 

1676 # is added to their handle tensors. 

1677 var = var.handle._distributed_container() 

1678 if getattr(var, "_in_graph_mode", False): 

1679 return var._shared_name 

1680 return var._unique_id 

1681 

1682 

1683def _get_slot_key_from_var(var, slot_name): 

1684 """Get the slot key for the variable: var_name/slot_name.""" 

1685 

1686 name = _var_key(var) 

1687 return name + "/" + slot_name 

1688 

1689 
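# [Editorial note] Taken together, the two helpers above yield keys of the
# form "<var key>/<slot name>": in graph mode the variable's shared name is
# used (e.g. "kernel/momentum"), while in eager mode a runtime-unique id
# string is used instead, so the exact eager key is not stable across runs.
#
#     v = tf.Variable(1.0, name="kernel")
#     _get_slot_key_from_var(v, "momentum")  # graph mode: "kernel/momentum"
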

1690class RestoredOptimizer(OptimizerV2): 

1691 """A non-functional Optimizer implementation for checkpoint compatibility. 

1692 

1693 Holds slot variables and hyperparameters when an optimizer is restored from 

1694 a SavedModel. These variables may be referenced in functions along with ops 

1695 created by the original optimizer, but currently we do not support using the 

1696 optimizer object itself (e.g. through `apply_gradients`). 

1697 """ 

1698 

1699 # TODO(allenl): Make the restored optimizer functional by tracing its apply 

1700 # methods. 

1701 

1702 def __init__(self): 

1703 super().__init__("RestoredOptimizer") 

1704 self._hypers_created = True 

1705 

1706 def get_config(self): 

1707 # TODO(allenl): Save and restore the Optimizer's config 

1708 raise NotImplementedError( 

1709 "Restoring functional Optimizers from SavedModels is not currently " 

1710 "supported. Please file a feature request if this limitation " 

1711 "bothers you." 

1712 ) 

1713 

1714 

1715tf.__internal__.saved_model.load.register_revived_type( 

1716 "optimizer", 

1717 lambda obj: isinstance(obj, OptimizerV2), 

1718 versions=[ 

1719 tf.__internal__.saved_model.load.VersionedTypeRegistration( 

1720 object_factory=lambda proto: RestoredOptimizer(), 

1721 version=2, 

1722 min_producer_version=1, 

1723 min_consumer_version=1, 

1724 setter=RestoredOptimizer._set_hyper, 

1725 ) 

1726 ], 

1727) 

1728