Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py: 25%

496 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Version 2 of class Optimizer.""" 

16# pylint: disable=g-bad-name 

17 

18import abc 

19import contextlib 

20import functools 

21import warnings 

22 

23from tensorflow.python.distribute import central_storage_strategy 

24from tensorflow.python.distribute import distribute_lib 

25from tensorflow.python.distribute import parameter_server_strategy 

26from tensorflow.python.distribute import parameter_server_strategy_v2 

27from tensorflow.python.distribute import values as ds_values 

28from tensorflow.python.eager import backprop 

29from tensorflow.python.eager import context 

30from tensorflow.python.framework import dtypes 

31from tensorflow.python.framework import indexed_slices 

32from tensorflow.python.framework import ops 

33from tensorflow.python.framework import tensor_util 

34from tensorflow.python.keras import backend 

35from tensorflow.python.keras import initializers 

36from tensorflow.python.keras.engine import base_layer_utils 

37from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule 

38from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils 

39from tensorflow.python.keras.utils import generic_utils 

40from tensorflow.python.keras.utils import layer_utils 

41from tensorflow.python.keras.utils import tf_inspect 

42from tensorflow.python.keras.utils import tf_utils 

43from tensorflow.python.ops import array_ops 

44from tensorflow.python.ops import control_flow_ops 

45from tensorflow.python.ops import gen_resource_variable_ops 

46from tensorflow.python.ops import gradients 

47from tensorflow.python.ops import math_ops 

48from tensorflow.python.ops import variables as tf_variables 

49from tensorflow.python.saved_model import revived_types 

50from tensorflow.python.trackable import base as trackable 

51from tensorflow.python.util import nest 

52from tensorflow.python.util.tf_export import keras_export 

53 

54 

55_DEFAULT_VALID_DTYPES = frozenset([ 

56 dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64, 

57 dtypes.complex64, dtypes.complex128 

58]) 

59 

60 

61def _deduplicate_indexed_slices(values, indices): 

62 """Sums `values` associated with any non-unique `indices`. 

63 

64 Args: 

65 values: A `Tensor` with rank >= 1. 

66 indices: A one-dimensional integer `Tensor`, indexing into the first 

67 dimension of `values` (as in an IndexedSlices object). 

68 

69 Returns: 

70 A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a 

71 de-duplicated version of `indices` and `summed_values` contains the sum of 

72 `values` slices associated with each unique index. 

73 """ 

74 unique_indices, new_index_positions = array_ops.unique(indices) 

75 summed_values = math_ops.unsorted_segment_sum( 

76 values, new_index_positions, 

77 array_ops.shape(unique_indices)[0]) 

78 return (summed_values, unique_indices) 
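# Editor's note (illustrative sketch, not part of the original file): with
# values = [[1., 1.], [2., 2.], [3., 3.]] and indices = [0, 0, 2], the helper
# above returns summed_values = [[3., 3.], [3., 3.]] and unique_indices = [0, 2];
# rows that share an index are summed together before the sparse update.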

79 

80 

81class NullContextmanager(object): 

82 

83 def __init__(self, *args, **kwargs): 

84 pass 

85 

86 def __enter__(self): 

87 pass 

88 

89 def __exit__(self, type_arg, value_arg, traceback_arg): 

90 return False # False values do not suppress exceptions 

91 

92 

93def name_scope_only_in_function_or_graph(name): 

94 """Internal-only entry point for `name_scope*`. 

95 

96 Enters a compat.v1.name_scope only when in a function or graph, 

97 not when running fully eagerly. 

98 

99 Args: 

100 name: The name argument that is passed to the op function. 

101 

102 Returns: 

103 `name_scope*` context manager. 

104 """ 

105 if not context.executing_eagerly(): 

106 return ops.name_scope_v1(name) 

107 else: 

108 return NullContextmanager() 

109 

110 

111@keras_export("keras.optimizers.Optimizer", metaclass=abc.ABCMeta) 

112class OptimizerV2(trackable.Trackable): 

113 """Base class for Keras optimizers. 

114 

115 You should not use this class directly, but instead instantiate one of its 

116 subclasses such as `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`, etc. 

117 

118 ### Usage 

119 

120 ```python 

121 # Create an optimizer with the desired parameters. 

122 opt = tf.keras.optimizers.SGD(learning_rate=0.1) 

123 # `loss` is a callable that takes no argument and returns the value 

124 # to minimize. 

125 loss = lambda: 3 * var1 * var1 + 2 * var2 * var2 

126 # In graph mode, returns op that minimizes the loss by updating the listed 

127 # variables. 

128 opt_op = opt.minimize(loss, var_list=[var1, var2]) 

129 opt_op.run() 

130 # In eager mode, simply call minimize to update the list of variables. 

131 opt.minimize(loss, var_list=[var1, var2]) 

132 ``` 

133 

134 ### Usage in custom training loops 

135 

136 In Keras models, sometimes variables are created when the model is first 

137 called, instead of at construction time. Examples include 1) sequential models 

138 without a pre-defined input shape, or 2) subclassed models. Pass `var_list` as a 

139 callable in these cases. 

140 

141 Example: 

142 

143 ```python 

144 opt = tf.keras.optimizers.SGD(learning_rate=0.1) 

145 model = tf.keras.Sequential() 

146 model.add(tf.keras.layers.Dense(num_hidden, activation='relu')) 

147 model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid')) 

148 loss_fn = lambda: tf.keras.losses.mse(model(input), output) 

149 var_list_fn = lambda: model.trainable_weights 

150 for input, output in data: 

151 opt.minimize(loss_fn, var_list_fn) 

152 ``` 

153 

154 ### Processing gradients before applying them 

155 

156 Calling `minimize()` takes care of both computing the gradients and 

157 applying them to the variables. If you want to process the gradients 

158 before applying them you can instead use the optimizer in three steps: 

159 

160 1. Compute the gradients with `tf.GradientTape`. 

161 2. Process the gradients as you wish. 

162 3. Apply the processed gradients with `apply_gradients()`. 

163 

164 Example: 

165 

166 ```python 

167 # Create an optimizer. 

168 opt = tf.keras.optimizers.SGD(learning_rate=0.1) 

169 

170 # Compute the gradients for a list of variables. 

171 with tf.GradientTape() as tape: 

172 loss = <call_loss_function> 

173 vars = <list_of_variables> 

174 grads = tape.gradient(loss, vars) 

175 

176 # Process the gradients, for example cap them, etc. 

177 # capped_grads = [MyCapper(g) for g in grads] 

178 processed_grads = [process_gradient(g) for g in grads] 

179 

180 # Ask the optimizer to apply the processed gradients. 

181 opt.apply_gradients(zip(processed_grads, vars)) 

182 ``` 

183 

184 ### Use with `tf.distribute.Strategy` 

185 

186 This optimizer class is `tf.distribute.Strategy` aware, which means it 

187 automatically sums gradients across all replicas. To average gradients, 

188 you divide your loss by the global batch size, which is done 

189 automatically if you use `tf.keras` built-in training or evaluation loops. 

190 See the `reduction` argument of your loss, which should be set to 

191 `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or 

192 `tf.keras.losses.Reduction.SUM` for summing. 

193 

194 To aggregate gradients yourself, call `apply_gradients` with 

195 `experimental_aggregate_gradients` set to False. This is useful if you need to 

196 process aggregated gradients. 

197 

198 If you are not using these and you want to average gradients, you should use 

199 `tf.math.reduce_sum` to add up your per-example losses and then divide by the 

200 global batch size. Note that when using `tf.distribute.Strategy`, the first 

201 component of a tensor's shape is the *replica-local* batch size, which is off 

202 by a factor equal to the number of replicas being used to compute a single 

203 step. As a result, using `tf.math.reduce_mean` will give the wrong answer, 

204 resulting in gradients that can be many times too big. 
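For example, a sketch (`per_example_loss` and `global_batch_size` are
placeholder names for values you already have):

```python
loss = tf.math.reduce_sum(per_example_loss) * (1. / global_batch_size)
```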

205 

206 ### Variable Constraints 

207 

208 All Keras optimizers respect variable constraints. If a constraint function is 

209 passed to any variable, the constraint will be applied to the variable after 

210 the gradient has been applied to it. 

211 Important: If the gradient is a sparse tensor, variable constraints are not supported. 

212 

213 ### Thread Compatibility 

214 

215 The entire optimizer is currently thread compatible, not thread-safe. The user 

216 needs to perform synchronization if necessary. 

217 

218 ### Slots 

219 

220 Many optimizer subclasses, such as `Adam` and `Adagrad`, allocate and manage 

221 additional variables associated with the variables to train. These are called 

222 <i>Slots</i>. Slots have names and you can ask the optimizer for the names of 

223 the slots that it uses. Once you have a slot name you can ask the optimizer 

224 for the variable it created to hold the slot value. 

225 

226 This can be useful if you want to log or debug a training algorithm, report stats 

227 about the slots, etc. 

228 

229 ### Hyperparameters 

230 

231 These are arguments passed to the optimizer subclass constructor 

232 (the `__init__` method), and then passed to `self._set_hyper()`. 

233 They can be either regular Python values (like 1.0), tensors, or 

234 callables. If they are callable, the callable will be called during 

235 `apply_gradients()` to get the value for the hyperparameter. 

236 

237 Hyperparameters can be overwritten through user code: 

238 

239 Example: 

240 

241 ```python 

242 # Create an optimizer with the desired parameters. 

243 opt = tf.keras.optimizers.SGD(learning_rate=0.1) 

244 # `loss` is a callable that takes no argument and returns the value 

245 # to minimize. 

246 loss = lambda: 3 * var1 + 2 * var2 

247 # In eager mode, simply call minimize to update the list of variables. 

248 opt.minimize(loss, var_list=[var1, var2]) 

249 # update learning rate 

250 opt.learning_rate = 0.05 

251 opt.minimize(loss, var_list=[var1, var2]) 

252 ``` 

253 

254 ### Callable learning rate 

255 

256 Optimizer accepts a callable learning rate in two ways. The first way is 

257 through built-in or customized 

258 `tf.keras.optimizers.schedules.LearningRateSchedule`. The schedule will be 

259 called on each iteration with `schedule(iteration)`, where `iteration` is a `tf.Variable` 

260 owned by the optimizer. 

261 

262 Example: 

263 

264 >>> var = tf.Variable(np.random.random(size=(1,))) 

265 >>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay( 

266 ... initial_learning_rate=.01, decay_steps=20, decay_rate=.1) 

267 >>> opt = tf.keras.optimizers.SGD(learning_rate=learning_rate) 

268 >>> loss = lambda: 3 * var 

269 >>> opt.minimize(loss, var_list=[var]) 

270 <tf.Variable... 

271 

272 The second way is through a callable that 

273 takes no arguments. 

274 

275 Example: 

276 

277 >>> var = tf.Variable(np.random.random(size=(1,))) 

278 >>> def lr_callable(): 

279 ... return .1 

280 >>> opt = tf.keras.optimizers.SGD(learning_rate=lr_callable) 

281 >>> loss = lambda: 3 * var 

282 >>> opt.minimize(loss, var_list=[var]) 

283 <tf.Variable... 

284 

285 ### Creating a custom optimizer 

286 

287 If you intend to create your own optimization algorithm, simply inherit from 

288 this class and override the following methods (a minimal sketch follows the list): 

289 

290 - `_resource_apply_dense` (update the variable given a dense gradient 

291 `tf.Tensor`) 

292 - `_resource_apply_sparse` (update the variable given a sparse gradient 

293 `tf.IndexedSlices`. The most common way for this to happen 

294 is if you are taking the gradient through a `tf.gather`.) 

295 - `_create_slots` 

296 (if your optimizer algorithm requires additional variables) 

297 - `get_config` 

298 (serialization of the optimizer; include all hyperparameters) 
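A minimal sketch (for illustration only; `MyMomentumSGD` and its fixed
momentum update rule are hypothetical, but the overridden methods are the
ones listed above):

```python
class MyMomentumSGD(tf.keras.optimizers.Optimizer):

  def __init__(self, learning_rate=0.01, momentum=0.9, name="MyMomentumSGD",
               **kwargs):
    super(MyMomentumSGD, self).__init__(name, **kwargs)
    self._set_hyper("learning_rate", learning_rate)
    self._set_hyper("momentum", momentum)

  def _create_slots(self, var_list):
    # One "momentum" slot variable per trainable variable.
    for var in var_list:
      self.add_slot(var, "momentum")

  def _resource_apply_dense(self, grad, var, apply_state=None):
    lr_t = self._decayed_lr(var.dtype.base_dtype)
    momentum = self._get_hyper("momentum", var.dtype.base_dtype)
    m = self.get_slot(var, "momentum")
    # Chain the slot update into the variable update via a data dependency.
    m_t = m.assign(momentum * m - lr_t * grad)
    return var.assign_add(m_t)

  def get_config(self):
    config = super(MyMomentumSGD, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "momentum": self._serialize_hyperparameter("momentum"),
    })
    return config
```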

299 """ 

300 

301 # Subclasses should set this to True unless they override `apply_gradients` 

302 # with a version that does not have the `experimental_aggregate_gradients` 

303 # argument. Older versions of Keras did not have this argument so custom 

304 # optimizers may have overridden `apply_gradients` without the 

305 # `experimental_aggregate_gradients` argument. Keras only passes 

306 # `experimental_aggregate_gradients` if this attribute is True. 

307 # Note: This attribute will likely be removed in an upcoming release. 

308 _HAS_AGGREGATE_GRAD = False 

309 

310 def __init__(self, 

311 name, 

312 gradient_aggregator=None, 

313 gradient_transformers=None, 

314 **kwargs): 

315 """Create a new Optimizer. 

316 

317 This must be called by the constructors of subclasses. 

318 Note that Optimizer instances should not bind to a single graph, 

319 and so shouldn't keep Tensors as member variables. Generally 

320 you should be able to use the _set_hyper()/_get_hyper() 

321 facility instead. 

322 

323 This class is stateful and thread-compatible. 

324 

325 Example of custom gradient transformations: 

326 

327 ```python 

328 def my_gradient_transformer(grads_and_vars): 

329 # Simple example, double the gradients. 

330 return [(2. * g, v) for g, v in grads_and_vars] 

331 

332 optimizer = tf.keras.optimizers.SGD( 

333 1e-3, gradient_transformers=[my_gradient_transformer]) 

334 ``` 

335 

336 Args: 

337 name: String. The name to use for momentum accumulator weights created 

338 by the optimizer. 

339 gradient_aggregator: The function to use to aggregate gradients across 

340 devices (when using `tf.distribute.Strategy`). If `None`, defaults to 

341 summing the gradients across devices. The function should accept and 

342 return a list of `(gradient, variable)` tuples. 

343 gradient_transformers: Optional. List of functions to use to transform 

344 gradients before applying updates to Variables. The functions are 

345 applied after `gradient_aggregator`. The functions should accept and 

346 return a list of `(gradient, variable)` tuples. 

347 **kwargs: keyword arguments. Allowed arguments are `clipvalue`, 

348 `clipnorm`, `global_clipnorm`. 

349 If `clipvalue` (float) is set, the gradient of each weight 

350 is clipped to be no higher than this value. 

351 If `clipnorm` (float) is set, the gradient of each weight 

352 is individually clipped so that its norm is no higher than this value. 

353 If `global_clipnorm` (float) is set, the gradient of all weights is 

354 clipped so that their global norm is no higher than this value. 

355 

356 Raises: 

357 ValueError: in case of any invalid argument. 

358 """ 

359 allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay", "global_clipnorm"} 

360 for k in kwargs: 

361 if k not in allowed_kwargs: 

362 raise TypeError("Unexpected keyword argument " 

363 "passed to optimizer: " + str(k)) 

364 # checks that all keyword arguments are non-negative. 

365 if kwargs[k] is not None and kwargs[k] < 0: 

366 raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k])) 

367 if k == "lr": 

368 warnings.warn( 

369 "The `lr` argument is deprecated, use `learning_rate` instead.") 

370 

371 self._use_locking = True 

372 self._init_set_name(name) 

373 self._hyper = {} 

374 # dict: {variable name : {slot name : variable}} 

375 self._slots = {} 

376 self._slot_names = [] 

377 self._weights = [] 

378 self._iterations = None 

379 

380 # For implementing Trackable. Stores information about how to restore 

381 # slot variables which have not yet been created 

382 # (trackable._CheckpointPosition objects). 

383 # {slot_name : 

384 # {_var_key(variable_to_train): [checkpoint_position, ... ], ... }, 

385 # ... } 

386 self._deferred_slot_restorations = {} 

387 

388 decay = kwargs.pop("decay", 0.0) 

389 if decay < 0.: 

390 raise ValueError("decay cannot be less than 0: {}".format(decay)) 

391 self._initial_decay = decay 

392 

393 self._hypers_created = False 

394 # Store the distribution strategy object if the optimizer is created inside 

395 # strategy scope, so it could be used to create variables later. 

396 if distribute_lib.has_strategy(): 

397 self._distribution_strategy = distribute_lib.get_strategy() 

398 else: 

399 self._distribution_strategy = None 

400 

401 # Configure gradient transformations. 

402 if gradient_aggregator is None: 

403 gradient_aggregator = optimizer_utils.all_reduce_sum_gradients 

404 self.gradient_aggregator = gradient_aggregator 

405 if gradient_transformers is None: 

406 gradient_transformers = [] 

407 self.gradient_transformers = gradient_transformers 

408 self.clipnorm = kwargs.pop("clipnorm", None) 

409 self.global_clipnorm = kwargs.pop("global_clipnorm", None) 

410 if self.clipnorm is not None and self.global_clipnorm is not None: 

411 raise ValueError("Cannot accept both `clipnorm` and `global_clipnorm`, " 

412 "passed `clipnorm` {}, `global_clipnorm` {}".format( 

413 self.clipnorm, self.global_clipnorm)) 

414 self.clipvalue = kwargs.pop("clipvalue", None) 

415 

416 @property 

417 def clipnorm(self): 

418 """`float` or `None`. If set, clips gradients to a maximum norm.""" 

419 return self._clipnorm 

420 

421 @property 

422 def global_clipnorm(self): 

423 """`float` or `None`. If set, clips gradients so that their global norm does not exceed this value.""" 

424 return self._global_clipnorm 

425 

426 @clipnorm.setter 

427 def clipnorm(self, val): 

428 if val is not None and self.gradient_transformers: 

429 raise ValueError("`clipnorm` cannot be set when `gradient_transformers` " 

430 "is set. Instead, use the `gradient_transformers` to " 

431 "specify clipping and other transformations.") 

432 self._clipnorm = val 

433 self._clipnorm_fn = optimizer_utils.make_gradient_clipnorm_fn( 

434 self._clipnorm) 

435 

436 @global_clipnorm.setter 

437 def global_clipnorm(self, val): 

438 if val is not None and self.gradient_transformers: 

439 raise ValueError("`global_clipnorm` cannot be set when `gradient_transformers` " 

440 "is set. Instead, use the `gradient_transformers` to " 

441 "specify clipping and other transformations.") 

442 self._global_clipnorm = val 

443 self._global_clipnorm_fn = optimizer_utils.make_global_gradient_clipnorm_fn( 

444 self._global_clipnorm) 

445 

446 @property 

447 def clipvalue(self): 

448 """`float` or `None`. If set, clips gradients to a maximum value.""" 

449 return self._clipvalue 

450 

451 @clipvalue.setter 

452 def clipvalue(self, val): 

453 if val is not None and self.gradient_transformers: 

454 raise ValueError("`clipvalue` cannot be set when `gradient_transformers` " 

455 "is set. Instead, use the `gradient_transformers` to " 

456 "specify clipping and other transformations.") 

457 self._clipvalue = val 

458 self._clipvalue_fn = optimizer_utils.make_gradient_clipvalue_fn( 

459 self._clipvalue) 

460 

461 def _transform_loss(self, loss): 

462 """Called in `.minimize` to transform loss before computing gradients.""" 

463 return loss 

464 

465 def _get_gradients(self, tape, loss, var_list, grad_loss=None): 

466 """Called in `minimize` to compute gradients from loss.""" 

467 grads = tape.gradient(loss, var_list, grad_loss) 

468 return list(zip(grads, var_list)) 

469 

470 def _transform_unaggregated_gradients(self, grads_and_vars): 

471 """Called in `apply_gradients` before gradient aggregation.""" 

472 return grads_and_vars 

473 

474 def _aggregate_gradients(self, grads_and_vars): 

475 """Called in `apply_gradients` to aggregate gradients across devices. 

476 

477 Note that user subclasses may override this, so the interface should not be 

478 changed. 

479 

480 Args: 

481 grads_and_vars: List of (gradient, variable) pairs. 

482 

483 Returns: 

484 A list of (aggregated_gradient, variable) pairs. By default, this calls 

485 `self.gradient_aggregator`. 

486 """ 

487 return self.gradient_aggregator(grads_and_vars) 

488 

489 def _transform_gradients(self, grads_and_vars): 

490 """Called in `apply_gradients` after aggregation.""" 

491 if self._clipvalue is not None: 

492 grads_and_vars = self._clipvalue_fn(grads_and_vars) 

493 if self._clipnorm is not None: 

494 grads_and_vars = self._clipnorm_fn(grads_and_vars) 

495 if self._global_clipnorm is not None: 

496 grads_and_vars = self._global_clipnorm_fn(grads_and_vars) 

497 

498 for fn in self.gradient_transformers: 

499 grads_and_vars = fn(grads_and_vars) 

500 return grads_and_vars 
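# Editor's note: the built-in clipping above is applied in a fixed order
# (clipvalue, then clipnorm, then global_clipnorm), and only afterwards are
# any user-supplied `gradient_transformers` run.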

501 

502 def minimize(self, loss, var_list, grad_loss=None, name=None, tape=None): 

503 """Minimize `loss` by updating `var_list`. 

504 

505 This method simply computes the gradients using `tf.GradientTape` and calls 

506 `apply_gradients()`. If you want to process the gradients before applying 

507 them, then use `tf.GradientTape` and call `apply_gradients()` explicitly instead 

508 of using this function. 

509 

510 Args: 

511 loss: `Tensor` or callable. If a callable, `loss` should take no arguments 

512 and return the value to minimize. If a `Tensor`, the `tape` argument 

513 must be passed. 

514 var_list: list or tuple of `Variable` objects to update to minimize 

515 `loss`, or a callable returning the list or tuple of `Variable` objects. 

516 Use callable when the variable list would otherwise be incomplete before 

517 `minimize` since the variables are created the first time `loss` is 

518 called. 

519 grad_loss: (Optional). A `Tensor` holding the gradient computed for 

520 `loss`. 

521 name: (Optional) str. Name for the returned operation. 

522 tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`, 

523 the tape that computed the `loss` must be provided. 

524 

525 Returns: 

526 An `Operation` that updates the variables in `var_list`. The `iterations` 

527 will be automatically increased by 1. 

528 

529 Raises: 

530 ValueError: If some of the variables are not `Variable` objects. 

531 

532 """ 

533 grads_and_vars = self._compute_gradients( 

534 loss, var_list=var_list, grad_loss=grad_loss, tape=tape) 

535 return self.apply_gradients(grads_and_vars, name=name) 

536 

537 def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None): 

538 """Compute gradients of `loss` for the variables in `var_list`. 

539 

540 This is the first part of `minimize()`. It returns a list 

541 of (gradient, variable) pairs where "gradient" is the gradient 

542 for "variable". Note that "gradient" can be a `Tensor`, an 

543 `IndexedSlices`, or `None` if there is no gradient for the 

544 given variable. 

545 

546 Args: 

547 loss: `Tensor` or callable. If a callable, `loss` should take no 

548 arguments and return the value to minimize. If a `Tensor`, the `tape` 

549 argument must be passed. 

550 var_list: list or tuple of `Variable` objects to update to minimize 

551 `loss`, or a callable returning the list or tuple of `Variable` objects. 

552 Use callable when the variable list would otherwise be incomplete before 

553 `minimize` and the variables are created the first time `loss` 

554 is called. 

555 grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. 

556 tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`, 

557 the tape that computed the `loss` must be provided. 

558 

559 Returns: 

560 A list of (gradient, variable) pairs. Variable is always present, but 

561 gradient can be `None`. 

562 

563 Raises: 

564 TypeError: If `var_list` contains anything other than `Variable` objects. 

565 ValueError: If some arguments are invalid, or var_list is None. 

566 """ 

567 # TODO(josh11b): Test that we handle weight decay in a reasonable way. 

568 if not callable(loss) and tape is None: 

569 raise ValueError("`tape` is required when a `Tensor` loss is passed.") 

570 tape = tape if tape is not None else backprop.GradientTape() 

571 

572 if callable(loss): 

573 with tape: 

574 if not callable(var_list): 

575 tape.watch(var_list) 

576 loss = loss() 

577 if callable(var_list): 

578 var_list = var_list() 

579 

580 with tape: 

581 loss = self._transform_loss(loss) 

582 

583 var_list = nest.flatten(var_list) 

584 with ops.name_scope_v2(self._name + "/gradients"): 

585 grads_and_vars = self._get_gradients(tape, loss, var_list, grad_loss) 

586 

587 self._assert_valid_dtypes([ 

588 v for g, v in grads_and_vars 

589 if g is not None and v.dtype != dtypes.resource 

590 ]) 

591 

592 return grads_and_vars 

593 

594 def apply_gradients(self, 

595 grads_and_vars, 

596 name=None, 

597 experimental_aggregate_gradients=True): 

598 """Apply gradients to variables. 

599 

600 This is the second part of `minimize()`. It returns an `Operation` that 

601 applies gradients. 

602 

603 The method sums gradients from all replicas in the presence of 

604 `tf.distribute.Strategy` by default. You can aggregate gradients yourself by 

605 passing `experimental_aggregate_gradients=False`. 

606 

607 Example: 

608 

609 ```python 

610 grads = tape.gradient(loss, vars) 

611 grads = tf.distribute.get_replica_context().all_reduce('sum', grads) 

612 # Processing aggregated gradients. 

613 optimizer.apply_gradients(zip(grads, vars), 

614 experimental_aggregate_gradients=False) 

615 

616 ``` 

617 

618 Args: 

619 grads_and_vars: List of (gradient, variable) pairs. 

620 name: Optional name for the returned operation. Defaults to the name passed 

621 to the `Optimizer` constructor. 

622 experimental_aggregate_gradients: Whether to sum gradients from different 

623 replicas in the presence of `tf.distribute.Strategy`. If False, it's the 

624 user's responsibility to aggregate the gradients. Defaults to True. 

625 

626 Returns: 

627 An `Operation` that applies the specified gradients. The `iterations` 

628 will be automatically increased by 1. 

629 

630 Raises: 

631 TypeError: If `grads_and_vars` is malformed. 

632 ValueError: If none of the variables have gradients. 

633 RuntimeError: If called in a cross-replica context. 

634 """ 

635 grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars) 

636 var_list = [v for (_, v) in grads_and_vars] 

637 

638 with ops.name_scope_v2(self._name): 

639 # Create iteration if necessary. 

640 with ops.init_scope(): 

641 self._create_all_weights(var_list) 

642 

643 if not grads_and_vars: 

644 # Distribution strategy does not support reducing an empty list of 

645 # gradients 

646 return control_flow_ops.no_op() 

647 

648 if distribute_lib.in_cross_replica_context(): 

649 raise RuntimeError( 

650 "`apply_gradients()` cannot be called in cross-replica context. " 

651 "Use `tf.distribute.Strategy.run` to enter replica " 

652 "context.") 

653 

654 strategy = distribute_lib.get_strategy() 

655 if (not experimental_aggregate_gradients and strategy and 

656 isinstance(strategy, 

657 (parameter_server_strategy.ParameterServerStrategyV1, 

658 parameter_server_strategy_v2.ParameterServerStrategyV2, 

659 central_storage_strategy.CentralStorageStrategy, 

660 central_storage_strategy.CentralStorageStrategyV1))): 

661 raise NotImplementedError( 

662 "`experimental_aggregate_gradients=False` is not supported for " 

663 "ParameterServerStrategy and CentralStorageStrategy") 

664 

665 apply_state = self._prepare(var_list) 

666 if experimental_aggregate_gradients: 

667 grads_and_vars = self._transform_unaggregated_gradients(grads_and_vars) 

668 grads_and_vars = self._aggregate_gradients(grads_and_vars) 

669 grads_and_vars = self._transform_gradients(grads_and_vars) 

670 

671 if optimizer_utils.strategy_supports_no_merge_call(): 

672 return self._distributed_apply(strategy, grads_and_vars, name, 

673 apply_state) 

674 else: 

675 return distribute_lib.get_replica_context().merge_call( 

676 functools.partial(self._distributed_apply, apply_state=apply_state), 

677 args=(grads_and_vars,), 

678 kwargs={ 

679 "name": name, 

680 }) 

681 

682 def _distributed_apply(self, distribution, grads_and_vars, name, apply_state): 

683 """`apply_gradients` using a `DistributionStrategy`.""" 

684 

685 def apply_grad_to_update_var(var, grad): 

686 """Apply gradient to variable.""" 

687 if isinstance(var, ops.Tensor): 

688 raise NotImplementedError("Trying to update a Tensor: %s" % (var,)) 

689 

690 apply_kwargs = {} 

691 if isinstance(grad, indexed_slices.IndexedSlices): 

692 if var.constraint is not None: 

693 raise RuntimeError( 

694 "Cannot use a constraint function on a sparse variable.") 

695 if "apply_state" in self._sparse_apply_args: 

696 apply_kwargs["apply_state"] = apply_state 

697 return self._resource_apply_sparse_duplicate_indices( 

698 grad.values, var, grad.indices, **apply_kwargs) 

699 

700 if "apply_state" in self._dense_apply_args: 

701 apply_kwargs["apply_state"] = apply_state 

702 update_op = self._resource_apply_dense(grad, var, **apply_kwargs) 

703 if var.constraint is not None: 

704 with ops.control_dependencies([update_op]): 

705 return var.assign(var.constraint(var)) 

706 else: 

707 return update_op 

708 

709 eagerly_outside_functions = ops.executing_eagerly_outside_functions() 

710 update_ops = [] 

711 with name_scope_only_in_function_or_graph(name or self._name): 

712 for grad, var in grads_and_vars: 

713 # Colocate the update with variables to avoid unnecessary communication 

714 # delays. See b/136304694. 

715 with distribution.extended.colocate_vars_with(var): 

716 with name_scope_only_in_function_or_graph( 

717 "update" if eagerly_outside_functions else "update_" + 

718 var.op.name): 

719 update_op = distribution.extended.update( 

720 var, apply_grad_to_update_var, args=(grad,), group=False) 

721 if distribute_lib.in_cross_replica_context(): 

722 # In cross-replica context, extended.update returns a list of 

723 # update ops from all replicas (group=False). 

724 update_ops.extend(update_op) 

725 else: 

726 # In replica context, extended.update returns the single update op 

727 # of the current replica. 

728 update_ops.append(update_op) 

729 

730 any_symbolic = any(isinstance(i, ops.Operation) or 

731 tf_utils.is_symbolic_tensor(i) for i in update_ops) 

732 if not context.executing_eagerly() or any_symbolic: 

733 # If the current context is graph mode or any of the update ops are 

734 # symbolic then the step update should be carried out under a graph 

735 # context. (eager updates execute immediately) 

736 with backend._current_graph(update_ops).as_default(): # pylint: disable=protected-access 

737 with ops.control_dependencies([control_flow_ops.group(update_ops)]): 

738 return self._iterations.assign_add(1, read_value=False) 

739 

740 return self._iterations.assign_add(1) 

741 

742 def get_gradients(self, loss, params): 

743 """Returns gradients of `loss` with respect to `params`. 

744 

745 Should be used only in legacy v1 graph mode. 

746 

747 Args: 

748 loss: Loss tensor. 

749 params: List of variables. 

750 

751 Returns: 

752 List of gradient tensors. 

753 

754 Raises: 

755 ValueError: In case any gradient cannot be computed (e.g. if gradient 

756 function not implemented). 

757 """ 

758 params = nest.flatten(params) 

759 with backend.get_graph().as_default(), backend.name_scope(self._name + 

760 "/gradients"): 

761 grads = gradients.gradients(loss, params) 

762 for grad, param in zip(grads, params): 

763 if grad is None: 

764 raise ValueError("Variable {} has `None` for gradient. " 

765 "Please make sure that all of your ops have a " 

766 "gradient defined (i.e. are differentiable). " 

767 "Common ops without gradient: " 

768 "K.argmax, K.round, K.eval.".format(param)) 

769 return grads 

770 

771 def get_updates(self, loss, params): 

772 grads = self.get_gradients(loss, params) 

773 grads_and_vars = list(zip(grads, params)) 

774 self._assert_valid_dtypes([ 

775 v for g, v in grads_and_vars 

776 if g is not None and v.dtype != dtypes.resource 

777 ]) 

778 return [self.apply_gradients(grads_and_vars)] 

779 

780 def _set_hyper(self, name, value): 

781 """Set hyper `name` to `value`. `value` can be a callable, tensor, or numeric.""" 

782 if isinstance(value, trackable.Trackable): 

783 self._track_trackable(value, name, overwrite=True) 

784 if name not in self._hyper: 

785 self._hyper[name] = value 

786 else: 

787 prev_value = self._hyper[name] 

788 if (callable(prev_value) 

789 or isinstance(prev_value, 

790 (ops.Tensor, int, float, 

791 learning_rate_schedule.LearningRateSchedule)) 

792 or isinstance(value, learning_rate_schedule.LearningRateSchedule)): 

793 self._hyper[name] = value 

794 else: 

795 backend.set_value(self._hyper[name], value) 

796 

797 def _get_hyper(self, name, dtype=None): 

798 if not self._hypers_created: 

799 self._create_hypers() 

800 value = self._hyper[name] 

801 if isinstance(value, learning_rate_schedule.LearningRateSchedule): 

802 return value 

803 if callable(value): 

804 value = value() 

805 if dtype: 

806 return math_ops.cast(value, dtype) 

807 else: 

808 return value 

809 

810 def _create_slots(self, var_list): 

811 pass 

812 

813 def _create_all_weights(self, var_list): 

814 """Creates all weights, including iterations, hyperparameters and slot vars. 

815 

816 This will add newly created variables to `optimizer.weights`. 

817 

818 New variables are only created when this method is called the first time, or 

819 when called with different variables in the var_list. 

820 

821 Args: 

822 var_list: list or tuple of `Variable` objects that will be updated when 

823 minimizing with this optimizer. 

824 """ 

825 

826 _ = self.iterations 

827 self._create_hypers() 

828 self._create_slots(var_list) 

829 

830 def __getattribute__(self, name): 

831 """Overridden to support hyperparameter access.""" 

832 try: 

833 return super(OptimizerV2, self).__getattribute__(name) 

834 except AttributeError as e: 

835 # Needed to avoid infinite recursion with __setattr__. 

836 if name == "_hyper": 

837 raise e 

838 # Backwards compatibility with Keras optimizers. 

839 if name == "lr": 

840 name = "learning_rate" 

841 if name in self._hyper: 

842 return self._get_hyper(name) 

843 raise e 

844 

845 def __dir__(self): 

846 result = set(super(OptimizerV2, self).__dir__()) 

847 if "_hyper" in result: 

848 result |= self._hyper.keys() 

849 if "learning_rate" in self._hyper.keys(): 

850 result.add("lr") 

851 return list(result) 

852 

853 def __setattr__(self, name, value): 

854 """Override setattr to support dynamic hyperparameter setting.""" 

855 # Backwards compatibility with Keras optimizers. 

856 if name == "lr": 

857 name = "learning_rate" 

858 if hasattr(self, "_hyper") and name in self._hyper: 

859 self._set_hyper(name, value) 

860 else: 

861 super(OptimizerV2, self).__setattr__(name, value) 
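# Editor's note: together, __getattribute__ and __setattr__ above expose
# hyperparameters as attributes, so e.g. `opt.learning_rate = 0.05` (or the
# legacy alias `opt.lr`) routes through _set_hyper()/_get_hyper().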

862 

863 def get_slot_names(self): 

864 """A list of names for this optimizer's slots.""" 

865 return self._slot_names 

866 

867 def add_slot(self, var, slot_name, initializer="zeros", shape=None): 

868 """Add a new slot variable for `var`. 

869 

870 A slot variable is an additional variable associated with `var` to train. 

871 It is allocated and managed by optimizers, e.g. `Adam`. 

872 

873 Args: 

874 var: a `Variable` object. 

875 slot_name: name of the slot variable. 

876 initializer: initializer of the slot variable 

877 shape: (Optional) shape of the slot variable. If not set, it will default 

878 to the shape of `var`. 

879 

880 Returns: 

881 A slot variable. 

882 """ 

883 if slot_name not in self._slot_names: 

884 self._slot_names.append(slot_name) 

885 var_key = _var_key(var) 

886 slot_dict = self._slots.setdefault(var_key, {}) 

887 weight = slot_dict.get(slot_name, None) 

888 if weight is None: 

889 if isinstance(initializer, str) or callable(initializer): 

890 initializer = initializers.get(initializer) 

891 if isinstance( 

892 initializer, 

893 trackable.CheckpointInitialValueCallable) or (shape is not None): 

894 slot_shape = shape 

895 else: 

896 slot_shape = var.shape 

897 initial_value = functools.partial( 

898 initializer, shape=slot_shape, dtype=var.dtype) 

899 else: 

900 initial_value = initializer 

901 

902 with self._distribution_strategy_scope(): 

903 strategy = distribute_lib.get_strategy() 

904 if not strategy.extended.variable_created_in_scope(var): 

905 raise ValueError( 

906 "Trying to create optimizer slot variable under the scope for " 

907 "tf.distribute.Strategy ({}), which is different from the scope " 

908 "used for the original variable ({}). Make sure the slot " 

909 "variables are created under the same strategy scope. This may " 

910 "happen if you're restoring from a checkpoint outside the scope" 

911 .format(strategy, var)) 

912 

913 with strategy.extended.colocate_vars_with(var): 

914 weight = tf_variables.Variable( 

915 name="%s/%s" % (var._shared_name, slot_name), # pylint: disable=protected-access 

916 dtype=var.dtype, 

917 trainable=False, 

918 initial_value=initial_value) 

919 backend.track_variable(weight) 

920 slot_dict[slot_name] = weight 

921 self._restore_slot_variable( 

922 slot_name=slot_name, variable=var, 

923 slot_variable=weight) 

924 self._weights.append(weight) 

925 return weight 

926 

927 def get_slot(self, var, slot_name): 

928 var_key = _var_key(var) 

929 slot_dict = self._slots[var_key] 

930 return slot_dict[slot_name] 

931 
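# Editor's note (data layout inferred from the code below): `_prepare` builds a
# dict keyed by (device, dtype) pairs, e.g.
# {("/gpu:0", tf.float32): {"lr_t": <decayed learning-rate tensor>}}, which is
# later passed to `_resource_apply_*` implementations as `apply_state`.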

932 def _prepare(self, var_list): 

933 keys = set() 

934 for var in var_list: 

935 if isinstance(var, ds_values.DistributedValues): 

936 var_devices = var._devices # pylint: disable=protected-access 

937 else: 

938 var_devices = [var.device] 

939 var_dtype = var.dtype.base_dtype 

940 for var_device in var_devices: 

941 keys.add((var_device, var_dtype)) 

942 

943 apply_state = {} 

944 for var_device, var_dtype in keys: 

945 apply_state[(var_device, var_dtype)] = {} 

946 with ops.device(var_device): 

947 self._prepare_local(var_device, var_dtype, apply_state) 

948 

949 return apply_state 

950 

951 def _prepare_local(self, var_device, var_dtype, apply_state): 

952 if "learning_rate" in self._hyper: 

953 lr_t = array_ops.identity(self._decayed_lr(var_dtype)) 

954 apply_state[(var_device, var_dtype)]["lr_t"] = lr_t 

955 

956 def _fallback_apply_state(self, var_device, var_dtype): 

957 """Compatibility for subclasses that don't pass apply_state through.""" 

958 apply_state = {(var_device, var_dtype): {}} 

959 self._prepare_local(var_device, var_dtype, apply_state) 

960 return apply_state[(var_device, var_dtype)] 

961 

962 def _create_hypers(self): 

963 if self._hypers_created: 

964 return 

965 with self._distribution_strategy_scope(): 

966 # Iterate hyper values deterministically. 

967 for name, value in sorted(self._hyper.items()): 

968 if isinstance(value, 

969 (ops.Tensor, tf_variables.Variable)) or callable(value): 

970 # The check for `callable` covers the usage when `value` is a 

971 # `LearningRateSchedule`, in which case it does not need to create a 

972 # variable. 

973 continue 

974 else: 

975 self._hyper[name] = self.add_weight( 

976 name, 

977 shape=[], 

978 trainable=False, 

979 initializer=value, 

980 aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA) 

981 self._hypers_created = True 

982 

983 @property 

984 def iterations(self): 

985 """Variable. The number of training steps this Optimizer has run.""" 

986 if self._iterations is None: 

987 with self._distribution_strategy_scope(): 

988 self._iterations = self.add_weight( 

989 "iter", 

990 shape=[], 

991 dtype=dtypes.int64, 

992 trainable=False, 

993 aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA) 

994 self._weights.append(self._iterations) 

995 return self._iterations 

996 

997 @iterations.setter 

998 def iterations(self, variable): 

999 if self._iterations is not None: 

1000 raise RuntimeError("Cannot set `iterations` to a new Variable after " 

1001 "the Optimizer weights have been created") 

1002 self._iterations = variable 

1003 self._weights.append(self._iterations) 

1004 

1005 def _decayed_lr(self, var_dtype): 

1006 """Get decayed learning rate as a Tensor with dtype=var_dtype.""" 

1007 lr_t = self._get_hyper("learning_rate", var_dtype) 

1008 if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule): 

1009 local_step = math_ops.cast(self.iterations, var_dtype) 

1010 lr_t = math_ops.cast(lr_t(local_step), var_dtype) 

1011 if self._initial_decay > 0.: 

1012 local_step = math_ops.cast(self.iterations, var_dtype) 

1013 decay_t = math_ops.cast(self._initial_decay, var_dtype) 

1014 lr_t = lr_t / (1. + decay_t * local_step) 

1015 return lr_t 
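# Editor's note: with a constant `decay=d`, the code above yields the classic
# inverse-time schedule, lr_t = lr / (1 + d * iterations), recomputed each step.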

1016 

1017 @abc.abstractmethod 

1018 def get_config(self): 

1019 """Returns the config of the optimizer. 

1020 

1021 An optimizer config is a Python dictionary (serializable) 

1022 containing the configuration of an optimizer. 

1023 The same optimizer can be reinstantiated later 

1024 (without any saved state) from this configuration. 

1025 

1026 Returns: 

1027 Python dictionary. 

1028 """ 

1029 config = {"name": self._name} 

1030 if self.clipnorm is not None: 

1031 config["clipnorm"] = self.clipnorm 

1032 if self.clipvalue is not None: 

1033 config["clipvalue"] = self.clipvalue 

1034 if self.global_clipnorm is not None: 

1035 config["global_clipnorm"] = self.global_clipnorm 

1036 return config 

1037 

1038 @classmethod 

1039 def from_config(cls, config, custom_objects=None): 

1040 """Creates an optimizer from its config. 

1041 

1042 This method is the reverse of `get_config`, 

1043 capable of instantiating the same optimizer from the config 

1044 dictionary. 

1045 

1046 Args: 

1047 config: A Python dictionary, typically the output of get_config. 

1048 custom_objects: A Python dictionary mapping names to additional Python 

1049 objects used to create this optimizer, such as a function used for a 

1050 hyperparameter. 

1051 

1052 Returns: 

1053 An optimizer instance. 

1054 """ 

1055 if "lr" in config: 

1056 config["learning_rate"] = config.pop("lr") 

1057 if "learning_rate" in config: 

1058 if isinstance(config["learning_rate"], dict): 

1059 config["learning_rate"] = learning_rate_schedule.deserialize( 

1060 config["learning_rate"], custom_objects=custom_objects) 

1061 return cls(**config) 
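# Illustrative round trip (editor's sketch):
#   opt2 = type(opt).from_config(opt.get_config())
# recreates an optimizer with the same hyperparameters but none of the
# slot/iteration state (use get_weights()/set_weights() for that).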

1062 

1063 def _serialize_hyperparameter(self, hyperparameter_name): 

1064 """Serialize a hyperparameter that can be a float, callable, or Tensor.""" 

1065 value = self._hyper[hyperparameter_name] 

1066 if isinstance(value, learning_rate_schedule.LearningRateSchedule): 

1067 return learning_rate_schedule.serialize(value) 

1068 if callable(value): 

1069 return value() 

1070 if tensor_util.is_tf_type(value): 

1071 return backend.get_value(value) 

1072 return value 

1073 

1074 def variables(self): 

1075 """Returns the variables of this Optimizer in the order they were created.""" 

1076 return self._weights 

1077 

1078 @property 

1079 def weights(self): 

1080 """Returns the variables of this Optimizer in the order they were created.""" 

1081 return self._weights 

1082 

1083 def get_weights(self): 

1084 """Returns the current weights of the optimizer. 

1085 

1086 The weights of an optimizer are its state (i.e., variables). 

1087 This function returns the weight values associated with this 

1088 optimizer as a list of Numpy arrays. The first value is always the 

1089 iterations count of the optimizer, followed by the optimizer's state 

1090 variables in the order they were created. The returned list can in turn 

1091 be used to load state into similarly parameterized optimizers. 

1092 

1093 For example, the RMSprop optimizer for this simple model returns a list of 

1094 three values: the iteration count, followed by the root-mean-square value 

1095 of the kernel and bias of the single Dense layer: 

1096 

1097 >>> opt = tf.keras.optimizers.RMSprop() 

1098 >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)]) 

1099 >>> m.compile(opt, loss='mse') 

1100 >>> data = np.arange(100).reshape(5, 20) 

1101 >>> labels = np.zeros(5) 

1102 >>> results = m.fit(data, labels) # Training. 

1103 >>> len(opt.get_weights()) 

1104 3 

1105 

1106 Returns: 

1107 Weights values as a list of numpy arrays. 

1108 """ 

1109 params = self.weights 

1110 return backend.batch_get_value(params) 

1111 

1112 # TODO(tanzheny): Maybe share this logic with base_layer. 

1113 def set_weights(self, weights): 

1114 """Set the weights of the optimizer. 

1115 

1116 The weights of an optimizer are its state (i.e., variables). 

1117 This function takes the weight values associated with this 

1118 optimizer as a list of Numpy arrays. The first value is always the 

1119 iterations count of the optimizer, followed by the optimizer's state 

1120 variables in the order they are created. The passed values are used to set 

1121 the new state of the optimizer. 

1122 

1123 For example, the RMSprop optimizer for this simple model takes a list of 

1124 three values: the iteration count, followed by the root-mean-square value 

1125 of the kernel and bias of the single Dense layer: 

1126 

1127 >>> opt = tf.keras.optimizers.RMSprop() 

1128 >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)]) 

1129 >>> m.compile(opt, loss='mse') 

1130 >>> data = np.arange(100).reshape(5, 20) 

1131 >>> labels = np.zeros(5) 

1132 >>> results = m.fit(data, labels) # Training. 

1133 >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])] 

1134 >>> opt.set_weights(new_weights) 

1135 >>> opt.iterations 

1136 <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10> 

1137 

1138 Args: 

1139 weights: weight values as a list of numpy arrays. 

1140 """ 

1141 params = self.weights 

1142 if len(params) != len(weights): 

1143 raise ValueError( 

1144 "You called `set_weights(weights)` on optimizer " + self._name + 

1145 " with a weight list of length " + str(len(weights)) + 

1146 ", but the optimizer was expecting " + str(len(params)) + 

1147 " weights. Provided weights: " + str(weights)[:50] + "...") 

1148 if not params: 

1149 return 

1150 weight_value_tuples = [] 

1151 param_values = backend.batch_get_value(params) 

1152 for pv, p, w in zip(param_values, params, weights): 

1153 if pv.shape != w.shape: 

1154 raise ValueError("Optimizer weight shape " + str(pv.shape) + 

1155 " not compatible with " 

1156 "provided weight shape " + str(w.shape)) 

1157 weight_value_tuples.append((p, w)) 

1158 backend.batch_set_value(weight_value_tuples) 

1159 

1160 def add_weight(self, 

1161 name, 

1162 shape, 

1163 dtype=None, 

1164 initializer="zeros", 

1165 trainable=None, 

1166 synchronization=tf_variables.VariableSynchronization.AUTO, 

1167 aggregation=tf_variables.VariableAggregation.NONE): 

1168 

1169 if dtype is None: 

1170 dtype = dtypes.float32 

1171 if isinstance(initializer, str) or callable(initializer): 

1172 initializer = initializers.get(initializer) 

1173 

1174 if synchronization == tf_variables.VariableSynchronization.ON_READ: 

1175 if trainable: 

1176 raise ValueError( 

1177 "Synchronization value can be set to " 

1178 "VariableSynchronization.ON_READ only for non-trainable variables. " 

1179 "You have specified trainable=True and " 

1180 "synchronization=VariableSynchronization.ON_READ.") 

1181 else: 

1182 # Set trainable to be false when variable is to be synced on read. 

1183 trainable = False 

1184 elif trainable is None: 

1185 trainable = True 

1186 

1187 variable = self._add_variable_with_custom_getter( 

1188 name=name, 

1189 shape=shape, 

1190 getter=base_layer_utils.make_variable, 

1191 overwrite=True, 

1192 initializer=initializer, 

1193 dtype=dtype, 

1194 trainable=trainable, 

1195 use_resource=True, 

1196 synchronization=synchronization, 

1197 aggregation=aggregation) 

1198 backend.track_variable(variable) 

1199 

1200 return variable 

1201 

1202 def _init_set_name(self, name, zero_based=True): 

1203 if not name: 

1204 self._name = backend.unique_object_name( 

1205 generic_utils.to_snake_case(self.__class__.__name__), 

1206 zero_based=zero_based) 

1207 else: 

1208 self._name = name 

1209 

1210 def _assert_valid_dtypes(self, tensors): 

1211 """Asserts tensors are all valid types (see `_valid_dtypes`). 

1212 

1213 Args: 

1214 tensors: Tensors to check. 

1215 

1216 Raises: 

1217 ValueError: If any tensor is not a valid type. 

1218 """ 

1219 valid_dtypes = self._valid_dtypes() 

1220 for t in tensors: 

1221 dtype = t.dtype.base_dtype 

1222 if dtype not in valid_dtypes: 

1223 raise ValueError("Invalid type %r for %s, expected: %s." % 

1224 (dtype, t.name, [v for v in valid_dtypes])) 

1225 

1226 def _valid_dtypes(self): 

1227 """Valid types for loss, variables and gradients. 

1228 

1229 Subclasses should override to allow other float types. 

1230 

1231 Returns: 

1232 Valid types for loss, variables and gradients. 

1233 """ 

1234 return _DEFAULT_VALID_DTYPES 

1235 

1236 def _call_if_callable(self, param): 

1237 """Call the function if param is callable.""" 

1238 return param() if callable(param) else param 

1239 

1240 def _resource_apply_dense(self, grad, handle, apply_state): 

1241 """Add ops to apply dense gradients to the variable `handle`. 

1242 

1243 Args: 

1244 grad: a `Tensor` representing the gradient. 

1245 handle: a `Tensor` of dtype `resource` which points to the variable to be 

1246 updated. 

1247 apply_state: A dict which is used across multiple apply calls. 

1248 

1249 Returns: 

1250 An `Operation` which updates the value of the variable. 

1251 """ 

1252 raise NotImplementedError("Must be implemented in subclasses.") 

1253 

1254 def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices, 

1255 **kwargs): 

1256 """Add ops to apply sparse gradients to `handle`, with repeated indices. 

1257 

1258 Optimizers which override this method must deal with repeated indices. See 

1259 the docstring of `_apply_sparse_duplicate_indices` for details. By default 

1260 the correct behavior, to sum non-unique indices and their associated 

1261 gradients, is enforced by first pre-processing `grad` and `indices` and 

1262 passing them on to `_resource_apply_sparse`. Optimizers which deal correctly 

1263 with duplicate indices may instead override this method to avoid the 

1264 overhead of summing. 

1265 

1266 Args: 

1267 grad: a `Tensor` representing the gradient for the affected indices. 

1268 handle: a `Tensor` of dtype `resource` which points to the variable to be 

1269 updated. 

1270 indices: a `Tensor` of integral type representing the indices for which 

1271 the gradient is nonzero. Indices may be repeated. 

1272 **kwargs: May optionally contain `apply_state` 

1273 

1274 Returns: 

1275 An `Operation` which updates the value of the variable. 

1276 """ 

1277 summed_grad, unique_indices = _deduplicate_indexed_slices( 

1278 values=grad, indices=indices) 

1279 return self._resource_apply_sparse(summed_grad, handle, unique_indices, 

1280 **kwargs) 

1281 

1282 def _resource_apply_sparse(self, grad, handle, indices, apply_state): 

1283 """Add ops to apply sparse gradients to the variable `handle`. 

1284 

1285 Similar to `_apply_sparse`, the `indices` argument to this method has been 

1286 de-duplicated. Optimizers which deal correctly with non-unique indices may 

1287 instead override `_resource_apply_sparse_duplicate_indices` to avoid this 

1288 overhead. 

1289 

1290 Args: 

1291 grad: a `Tensor` representing the gradient for the affected indices. 

1292 handle: a `Tensor` of dtype `resource` which points to the variable to be 

1293 updated. 

1294 indices: a `Tensor` of integral type representing the indices for which 

1295 the gradient is nonzero. Indices are unique. 

1296 apply_state: A dict which is used across multiple apply calls. 

1297 

1298 Returns: 

1299 An `Operation` which updates the value of the variable. 

1300 """ 

1301 raise NotImplementedError("Must be implemented in subclasses.") 

1302 
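# Editor's note: the two helpers below wrap the raw resource scatter ops and
# return the variable's updated value under a control dependency, so callers
# can chain further ops on the result.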

1303 def _resource_scatter_add(self, x, i, v): 

1304 with ops.control_dependencies([ 

1305 gen_resource_variable_ops.ResourceScatterAdd( 

1306 resource=x.handle, indices=i, updates=v) 

1307 ]): 

1308 return x.value() 

1309 

1310 def _resource_scatter_update(self, x, i, v): 

1311 with ops.control_dependencies( 

1312 [gen_resource_variable_ops.ResourceScatterUpdate( 

1313 resource=x.handle, indices=i, updates=v)]): 

1314 return x.value() 

1315 

1316 @property 

1317 @layer_utils.cached_per_instance 

1318 def _dense_apply_args(self): 

1319 return tf_inspect.getfullargspec(self._resource_apply_dense).args 

1320 

1321 @property 

1322 @layer_utils.cached_per_instance 

1323 def _sparse_apply_args(self): 

1324 return tf_inspect.getfullargspec(self._resource_apply_sparse).args 

1325 

1326 # --------------- 

1327 # For implementing the trackable interface 

1328 # --------------- 

1329 

1330 def _restore_slot_variable(self, slot_name, variable, slot_variable): 

1331 """Restore a newly created slot variable's value.""" 

1332 variable_key = _var_key(variable) 

1333 deferred_restorations = self._deferred_slot_restorations.get( 

1334 slot_name, {}).pop(variable_key, []) 

1335 # Iterate over restores, highest restore UID first to minimize the number 

1336 # of assignments. 

1337 deferred_restorations.sort(key=lambda position: position.restore_uid, 

1338 reverse=True) 

1339 for checkpoint_position in deferred_restorations: 

1340 checkpoint_position.restore(slot_variable) 

1341 

1342 def _create_or_restore_slot_variable( 

1343 self, slot_variable_position, slot_name, variable): 

1344 """Restore a slot variable's value, possibly creating it. 

1345 

1346 Called when a variable which has an associated slot variable is created or 

1347 restored. When executing eagerly, we create the slot variable with a 

1348 restoring initializer. 

1349 

1350 No new variables are created when graph building. Instead, 

1351 _restore_slot_variable catches these after normal creation and adds restore 

1352 ops to the graph. This method is nonetheless important when graph building 

1353 for the case when a slot variable has already been created but `variable` 

1354 has just been added to a dependency graph (causing us to realize that the 

1355 slot variable needs to be restored). 

1356 

1357 Args: 

1358 slot_variable_position: A `trackable._CheckpointPosition` object 

1359 indicating the slot variable `Trackable` object to be restored. 

1360 slot_name: The name of this `Optimizer`'s slot to restore into. 

1361 variable: The variable object this slot is being created for. 

1362 """ 

1363 variable_key = _var_key(variable) 

1364 slot_dict = self._slots.get(variable_key, {}) 

1365 slot_variable = slot_dict.get(slot_name, None) 

1366 if (slot_variable is None and context.executing_eagerly() and 

1367 slot_variable_position.is_simple_variable() 

1368 # Defer slot variable creation if there is an active variable creator 

1369 # scope. Generally we'd like to eagerly create/restore slot variables 

1370 # when possible, but this may mean that scopes intended to catch 

1371 # `variable` also catch its eagerly created slot variable 

1372 # unintentionally (specifically make_template would add a dependency on 

1373 # a slot variable if not for this case). Deferring is mostly harmless 

1374 # (aside from double initialization), and makes variable creator scopes 

1375 # behave the same way they do when graph building. 

1376 # 

1377 # One notable case is with distribution strategy, which uses variable 

1378 # creator scope but always desires the `variable` and the slot to use 

1379 # the same scope, thus we can safely eagerly create/restore slot 

1380 # variables. 

1381 and (not ops.get_default_graph()._variable_creator_stack or # pylint: disable=protected-access 

1382 self._distribution_strategy)): 

1383 initializer = trackable.CheckpointInitialValueCallable( 

1384 checkpoint_position=slot_variable_position) 

1385 slot_variable = self.add_slot( 

1386 var=variable, 

1387 initializer=initializer, 

1388 slot_name=slot_name, 

1389 shape=slot_variable_position.value_shape()) 

1390 # Slot variables are not owned by any one object (because we don't want to 

1391 # save the slot variable if the optimizer is saved without the non-slot 

1392 # variable, or if the non-slot variable is saved without the optimizer; 

1393 # it's a dependency hypergraph with edges of the form (optimizer, non-slot 

1394 # variable, variable)). So we don't _track_ slot variables anywhere, and 

1395 # instead special-case this dependency and otherwise pretend it's a normal 

1396 # graph. 

1397 if slot_variable is not None: 

1398 # If we've either made this slot variable, or if we've pulled out an 

1399 # existing slot variable, we should restore it. 

1400 slot_variable_position.restore(slot_variable) 

1401 else: 

1402 # We didn't make the slot variable. Defer restoring until it gets created 

1403 # normally. We keep a list rather than the one with the highest restore 

1404 # UID in case slot variables have their own dependencies, in which case 

1405 # those could differ between restores. 

1406 self._deferred_slot_restorations.setdefault( 

1407 slot_name, {}).setdefault(variable_key, []).append( 

1408 slot_variable_position) 

1409 

1410 @contextlib.contextmanager 

1411 def _distribution_strategy_scope(self): 

1412 """Enters the scope of the `tf.distribute.Strategy` this optimizer was created under, if any.""" 

1413 if self._distribution_strategy and not distribute_lib.has_strategy(): 

1414 with self._distribution_strategy.scope(): 

1415 yield self._distribution_strategy.scope() 

1416 else: 

1417 yield 

1418 

1419 

1420def _var_key(var): 

1421 """Key for representing a primary variable, for looking up slots. 

1422 

1423 In graph mode the name is derived from the var shared name. 

1424 In eager mode the name is derived from the var unique id. 

1425 If distribution strategy exists, get the primary variable first. 

1426 

1427 Args: 

1428 var: the variable. 

1429 

1430 Returns: 

1431 the unique name of the variable. 

1432 """ 

1433 

1434 # pylint: disable=protected-access 

1435 # Get the distributed variable if it exists. 

1436 if hasattr(var, "_distributed_container"): 

1437 var = var._distributed_container() 

1438 if var._in_graph_mode: 

1439 return var._shared_name 

1440 return var._unique_id 

1441 

1442 

1443def _get_slot_key_from_var(var, slot_name): 

1444 """Get the slot key for the variable: var_name/slot_name.""" 

1445 

1446 name = _var_key(var) 

1447 return name + "/" + slot_name 

1448 

1449 

1450class RestoredOptimizer(OptimizerV2): 

1451 """A non-functional Optimizer implementation for checkpoint compatibility. 

1452 

1453 Holds slot variables and hyperparameters when an optimizer is restored from a 

1454 SavedModel. These variables may be referenced in functions along with ops 

1455 created by the original optimizer, but currently we do not support using the 

1456 optimizer object itself (e.g. through `apply_gradients`). 

1457 """ 

1458 # TODO(allenl): Make the restored optimizer functional by tracing its apply 

1459 # methods. 

1460 

1461 def __init__(self): 

1462 super(RestoredOptimizer, self).__init__("RestoredOptimizer") 

1463 self._hypers_created = True 

1464 

1465 def get_config(self): 

1466 # TODO(allenl): Save and restore the Optimizer's config 

1467 raise NotImplementedError( 

1468 "Restoring functional Optimizers from SavedModels is not currently " 

1469 "supported. Please file a feature request if this limitation bothers " 

1470 "you.") 

1471 

1472revived_types.register_revived_type( 

1473 "tf_deprecated_optimizer", 

1474 lambda obj: isinstance(obj, OptimizerV2), 

1475 versions=[revived_types.VersionedTypeRegistration( 

1476 object_factory=lambda proto: RestoredOptimizer(), 

1477 version=1, 

1478 min_producer_version=1, 

1479 min_consumer_version=1, 

1480 setter=RestoredOptimizer._set_hyper # pylint: disable=protected-access 

1481 )])