Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py: 32%

415 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains the loss scaling optimizer class.""" 

16 

17from tensorflow.python.distribute import collective_all_reduce_strategy 

18from tensorflow.python.distribute import distribute_lib 

19from tensorflow.python.distribute import mirrored_strategy 

20from tensorflow.python.distribute import one_device_strategy 

21from tensorflow.python.distribute import tpu_strategy 

22from tensorflow.python.eager import backprop 

23from tensorflow.python.eager import context 

24from tensorflow.python.framework import dtypes 

25from tensorflow.python.framework import indexed_slices 

26from tensorflow.python.framework import ops 

27from tensorflow.python.framework import smart_cond 

28from tensorflow.python.framework import tensor_conversion 

29from tensorflow.python.keras import backend 

30from tensorflow.python.keras import optimizers 

31from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module 

32from tensorflow.python.keras.optimizer_v2 import optimizer_v2 

33from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils 

34from tensorflow.python.ops import cond 

35from tensorflow.python.ops import control_flow_ops 

36from tensorflow.python.ops import math_ops 

37from tensorflow.python.ops import variable_v1 

38from tensorflow.python.ops import variables 

39from tensorflow.python.platform import tf_logging 

40from tensorflow.python.trackable import base as trackable 

41from tensorflow.python.trackable import base_delegate 

42from tensorflow.python.training.experimental import loss_scale as loss_scale_module 

43from tensorflow.python.training.experimental import mixed_precision 

44from tensorflow.python.util import nest 

45from tensorflow.python.util.tf_export import keras_export 

46 

47 

48class _UnwrapPreventer(object): 

49 """Wrapper that DistributionStrategy will not unwrap. 

50 

51 Typically, DistributionStrategy will unwrap values when going from a cross- 

52 replica context to a replica context via `call_for_each_replica`. This class 

53 is a wrapper that DistributionStrategy will not unwrap, so it can be used to 

54 prevent it from unwrapping a value. 

55 

56 TODO(reedwm): Find/implement a better way of preventing values from being 

57 unwrapped by DistributionStrategy 

58 """ 

59 

60 __slots__ = ['value'] 

61 

62 def __init__(self, value): 

63 self.value = value 

64 

65 

def _is_all_finite(grads):
  """Returns a scalar boolean tensor that is True iff every gradient is finite.

  Args:
    grads: An iterable of gradients. `None` entries are skipped, and for
      `IndexedSlices` only the `values` component is inspected.

  Returns:
    The `reduce_all` of the per-gradient finiteness checks.
  """
  finite_flags = []
  for grad in grads:
    if grad is None:
      continue
    if isinstance(grad, indexed_slices.IndexedSlices):
      dense_part = grad.values
    else:
      dense_part = grad
    finite_flags.append(math_ops.reduce_all(math_ops.is_finite(dense_part)))
  return math_ops.reduce_all(finite_flags)

77 

78 

def _op_in_graph_mode(tensor):
  """Returns `tensor.op` when building a graph, or `tensor` itself when eager.

  Some call sites need an op rather than a tensor in graph mode. Eager mode
  has no ops, so the tensor is handed back unchanged there.

  Args:
    tensor: A tensor.

  Returns:
    The tensor's op in graph mode. The tensor in eager mode.
  """
  return tensor if context.executing_eagerly() else tensor.op

94 

95 

def _assign_if_finite(var, value):
  """Assigns `value` to `var` only when `value` is finite; otherwise a no-op."""

  def _do_assign():
    return _op_in_graph_mode(var.assign(value))

  value_is_finite = math_ops.is_finite(value)
  return cond.cond(value_is_finite, _do_assign, control_flow_ops.no_op)

101 

102 

class _DynamicLossScaleState(trackable.Trackable):
  """The state of a dynamic loss scale.

  Holds the fixed hyperparameters (initial scale, growth steps, multiplier)
  together with two non-trainable variables: the current loss scale (float32)
  and a counter of consecutive finite-gradient steps (int64). Variables are
  registered per graph, keyed by `(name, graph_key)`, with a `None` graph key
  in eager mode.
  """

  def __init__(self,
               initial_loss_scale,
               growth_steps,
               multiplier):
    """Creates the dynamic loss scale.

    Args:
      initial_loss_scale: Starting loss scale; converted with `float()`.
      growth_steps: Number of consecutive finite steps after which the loss
        scale is multiplied by `multiplier`; converted with `int()`.
      multiplier: Factor by which the loss scale is grown or shrunk; converted
        with `float()`.
    """
    super(_DynamicLossScaleState, self).__init__()
    self._initial_loss_scale = float(initial_loss_scale)
    self._growth_steps = int(growth_steps)
    self._multiplier = float(multiplier)

    self._weights = {}
    self._current_loss_scale = self._add_weight(
        name='current_loss_scale',
        dtype=dtypes.float32,
        initial_value=self._initial_loss_scale)
    # The number of consecutive steps with finite gradients since the last
    # nonfinite gradient or change in loss scale. The name is 'good_steps' for
    # backwards compatibility with older checkpoints.
    self._counter = self._add_weight(
        name='good_steps', dtype=dtypes.int64, initial_value=0)

  @staticmethod
  def _current_graph_key():
    """Returns the default graph's key, or None when executing eagerly."""
    if context.executing_eagerly():
      return None
    graph = ops.get_default_graph()
    return graph._graph_key  # pylint: disable=protected-access

  def _add_weight(self, name, initial_value, dtype=None):
    """Adds a weight to this loss scale.

    The variable is created non-trainable with aggregation NONE and stored
    under `(name, graph_key)`. If an entry with the same key already exists it
    is replaced; no error is raised.

    Args:
      name: Variable name.
      initial_value: The variable's initial value.
      dtype: The type of the variable.

    Returns:
      A variable.
    """
    variable = variable_v1.VariableV1(
        initial_value=initial_value,
        name=name,
        dtype=dtype,
        trainable=False,
        use_resource=True,
        synchronization=variables.VariableSynchronization.AUTO,
        # Set aggregation to NONE, as loss scaling variables should never be
        # aggregated.
        aggregation=variables.VariableAggregation.NONE)

    key = (name, self._current_graph_key())
    self._weights[key] = variable
    self._handle_deferred_dependencies(name=name, trackable=variable)
    backend.track_variable(variable)
    return variable

  def _trackable_children(self,
                          save_type=trackable.SaveType.CHECKPOINT,
                          **kwargs):
    """From Trackable. Gather graph-specific weights to save."""
    graph_key = self._current_graph_key()
    weights = {}
    # Visit the weights in name order and keep only those that were created in
    # the current graph (or in eager mode, when the key is None).
    for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]):
      if g == graph_key:
        weights[name] = v
    weights.update(
        super(_DynamicLossScaleState,
              self)._trackable_children(save_type, **kwargs))
    return weights

  def _lookup_dependency(self, name):
    """From Trackable. Find a weight in the current graph."""
    unconditional = super(_DynamicLossScaleState, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    return self._weights.get((name, self._current_graph_key()), None)

  @property
  def initial_loss_scale(self):
    """The initial loss scale, as a Python float."""
    return self._initial_loss_scale

  @property
  def growth_steps(self):
    """The number of finite steps between loss scale increases, as an int."""
    return self._growth_steps

  @property
  def multiplier(self):
    """The factor used to grow or shrink the loss scale, as a Python float."""
    return self._multiplier

  @property
  def current_loss_scale(self):
    """Returns the current loss scale as a float32 `tf.Variable`."""
    return self._current_loss_scale

  @property
  def counter(self):
    """Returns the counter as an int64 `tf.Variable`."""
    return self._counter

  def __call__(self):
    """Returns the current loss scale as a scalar `float32` tensor."""
    return tensor_conversion.convert_to_tensor_v2_with_dispatch(
        self._current_loss_scale
    )

  def update(self, grads):
    """Updates the value of the loss scale.

    If every gradient is finite, the counter is incremented; once the
    incremented counter reaches `growth_steps`, the loss scale is multiplied by
    `multiplier` (only if the product is finite) and the counter is reset. If
    any gradient is nonfinite, the loss scale is divided by `multiplier`
    (but never below 1) and the counter is reset.

    Args:
      grads: A nested structure of unscaled gradients, each which is an
        all-reduced gradient of the loss with respect to a weight.

    Returns:
      update_op: In eager mode, None. In graph mode, an op to update the loss
        scale.
      should_apply_gradients: Either a bool or a scalar boolean tensor. If
        False, the caller should skip applying `grads` to the variables this
        step.
    """
    grads = nest.flatten(grads)
    if distribute_lib.has_strategy(
    ) and distribute_lib.in_cross_replica_context():
      distribution = distribute_lib.get_strategy()
      is_finite_per_replica = distribution.extended.call_for_each_replica(
          _is_all_finite, args=(grads,))
      # Each replica computed the same `is_finite` value, since `grads` is
      # all-reduced across replicas. Arbitrarily take `is_finite` from the first
      # replica.
      is_finite = (
          distribution.experimental_local_results(is_finite_per_replica)[0])
    else:
      is_finite = _is_all_finite(grads)

    def update_if_finite_grads():
      """Update assuming the gradients are finite."""

      def incr_loss_scale():
        new_loss_scale = self.current_loss_scale * self.multiplier
        return control_flow_ops.group(
            _assign_if_finite(self.current_loss_scale, new_loss_scale),
            self.counter.assign(0))

      return cond.cond(
          self.counter + 1 >= self.growth_steps,
          incr_loss_scale,
          lambda: _op_in_graph_mode(self.counter.assign_add(1)))

    def update_if_not_finite_grads():
      """Update assuming the gradients are nonfinite."""

      new_loss_scale = math_ops.maximum(
          self.current_loss_scale / self.multiplier, 1)
      return control_flow_ops.group(
          self.counter.assign(0),
          self.current_loss_scale.assign(new_loss_scale))

    update_op = cond.cond(is_finite, update_if_finite_grads,
                          update_if_not_finite_grads)
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients

276 

277 

# Defaults used by LossScaleOptimizer when `dynamic` is True and the
# corresponding constructor argument is left as None.
# See LossScaleOptimizer docstring for why this is so big
_DEFAULT_INITIAL_SCALE = 2 ** 15
_DEFAULT_GROWTH_STEPS = 2000

281 

282 

283# pylint: disable=g-classes-have-attributes 

284@keras_export('keras.mixed_precision.LossScaleOptimizer') 

285class LossScaleOptimizer(base_delegate.DelegatingTrackableMixin, 

286 optimizer_v2.OptimizerV2): 

287 """An optimizer that applies loss scaling to prevent numeric underflow. 

288 

289 Loss scaling is a technique to prevent numeric underflow in intermediate 

290 gradients when float16 is used. To prevent underflow, the loss is multiplied 

291 (or "scaled") by a certain factor called the "loss scale", which causes 

292 intermediate gradients to be scaled by the loss scale as well. The final 

293 gradients are divided (or "unscaled") by the loss scale to bring them back to 

294 their original value. 

295 

296 `LossScaleOptimizer` wraps another optimizer and applies loss scaling to it. 

297 By default, the loss scale is dynamically updated over time so you do not have 

298 to choose the loss scale. The `minimize` method automatically scales the loss, 

299 unscales the gradients, and updates the loss scale so all you have to do is 

300 wrap your optimizer with a `LossScaleOptimizer` if you use `minimize`. For 

301 example: 

302 

303 >>> opt = tf.keras.optimizers.SGD(0.25) 

304 >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt) 

305 >>> var = tf.Variable(1.) 

306 >>> loss_fn = lambda: var ** 2 

307 >>> # 'minimize' applies loss scaling and updates the loss scale.

308 >>> opt.minimize(loss_fn, var_list=var) 

309 >>> var.numpy() 

310 0.5 

311 

312 If a `tf.GradientTape` is used to compute gradients instead of `minimize`, you 

313 must scale the loss and gradients manually. This can be done with the 

314 `LossScaleOptimizer.get_scaled_loss` and 

315 `LossScaleOptimizer.get_unscaled_gradients` methods. For example: 

316 

317 >>> with tf.GradientTape() as tape: 

318 ... loss = loss_fn() 

319 ... scaled_loss = opt.get_scaled_loss(loss) 

320 >>> scaled_grad = tape.gradient(scaled_loss, var) 

321 >>> (grad,) = opt.get_unscaled_gradients([scaled_grad]) 

322 >>> opt.apply_gradients([(grad, var)]) # Loss scale is updated here 

323 >>> var.numpy() 

324 0.25 

325 

326 Warning: If you forget to call `get_scaled_loss` or `get_unscaled_gradients` 

327 (or both) when using a `tf.GradientTape`, the model will likely converge to a 

328 worse quality. Please make sure you call each function exactly once. 

329 

330 When mixed precision with float16 is used, there is typically no risk of 

331 underflow affecting model quality if loss scaling is properly used. See 

332 [the mixed precision guide]( 

333 https://www.tensorflow.org/guide/keras/mixed_precision) for more information 

334 on how to use mixed precision. 

335 

336 Args: 

337 inner_optimizer: The `tf.keras.optimizers.Optimizer` instance to wrap. 

338 dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to 

339 True. If True, the loss scale will be dynamically updated over time using 

340 an algorithm that keeps the loss scale at approximately its optimal value. 

341 If False, a single fixed loss scale is used and `initial_scale` must be 

342 specified, which is used as the loss scale. Recommended to keep as True, 

343 as choosing a fixed loss scale can be tricky. Currently, there is a small 

344 performance overhead to dynamic loss scaling compared to fixed loss 

345 scaling. 

346 initial_scale: The initial loss scale. If `dynamic` is True, this defaults 

347 to `2 ** 15`. If `dynamic` is False, this must be specified and acts as 

348 the sole loss scale, as the loss scale does not change over time. When 

349 dynamic loss scaling is used, it is better for this to be a very high number,

350 because a loss scale that is too high gets lowered far more quickly than a 

351 loss scale that is too low gets raised. 

352 dynamic_growth_steps: With dynamic loss scaling, every 

353 `dynamic_growth_steps` steps with finite gradients, the loss scale is 

354 doubled. Defaults to 2000. If a nonfinite gradient is encountered, the 

355 count is reset back to zero, gradients are skipped that step, and the loss 

356 scale is halved. The count can be queried with 

357 `LossScaleOptimizer.dynamic_counter`. This argument can only be specified 

358 if `dynamic` is True. 

359 

360 `LossScaleOptimizer` will occasionally skip applying gradients to the 

361 variables, in which case the trainable variables will not change that step. 

362 This is done because the dynamic loss scale will sometimes be raised too 

363 high, causing overflow in the gradients. Typically, the first 2 to 15 steps of 

364 the model are skipped as the initial loss scale is very high, but afterwards 

365 steps will only be skipped on average 0.05% of the time (the fraction of steps 

366 skipped is `1 / dynamic_growth_steps`). 

367 

368 `LossScaleOptimizer` delegates all public `Optimizer` methods to the inner 

369 optimizer. Additionally, in methods `minimize` and `get_gradients`, it scales 

370 the loss and unscales the gradients. In methods `minimize` and 

371 `apply_gradients`, it additionally updates the loss scale and skips applying 

372 gradients if any gradient has a nonfinite value. 

373 

374 ### Hyperparameters 

375 

376 Hyperparameters can be accessed and set on the LossScaleOptimizer, which will 

377 be delegated to the wrapped optimizer. 

378 

379 >>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5) 

380 >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt) 

381 >>> opt.beta_1 # Equivalent to `opt.inner_optimizer.beta_1` 

382 0.8 

383 >>> opt.beta_1 = 0.7 # Equivalent to `opt.inner_optimizer.beta_1 = 0.7` 

384 >>> opt.beta_1 

385 0.7 

386 >>> opt.inner_optimizer.beta_1 

387 0.7 

388 

389 However, accessing or setting non-hyperparameters is not delegated to the 

390 LossScaleOptimizer. In an Adam optimizer, `beta_1` is a hyperparameter but 

391 `epsilon` is not, as the Adam optimizer only calls `Optimizer._set_hyper` on 

392 `beta_1`. 

393 

394 >>> opt.inner_optimizer.epsilon 

395 1e-5 

396 >>> opt.epsilon 

397 Traceback (most recent call last): 

398 ... 

399 AttributeError: 'LossScaleOptimizer' object has no attribute 'epsilon' 

400 >>> opt.epsilon = 1e-4 # This does NOT set epsilon on `opt.inner_optimizer` 

401 >>> opt.inner_optimizer.epsilon 

402 1e-5

403 

404 In the above example, despite epsilon being set on the LossScaleOptimizer, the 

405 old epsilon value will still be used when training as epsilon was not set on 

406 the inner optimizer. 

407 """ 

408 

409 _HAS_AGGREGATE_GRAD = True 

410 

411 def __init__(self, inner_optimizer, dynamic=True, initial_scale=None, 

412 dynamic_growth_steps=None): 

413 if not isinstance(inner_optimizer, optimizer_v2.OptimizerV2): 

414 raise TypeError('"inner_optimizer" must be an instance of OptimizerV2, ' 

415 'but got: %s' % inner_optimizer) 

416 if not isinstance(dynamic, bool): 

417 # Catch errors if a user incorrectly passes a string or float to the 

418 # second argument argument, as this is commonly done for 

419 # LossScaleOptimizerV1. 

420 raise TypeError('"dynamic" argument to LossScaleOptimizer.__init__ must ' 

421 'be a bool, but got: %r' % (dynamic,)) 

422 if isinstance(inner_optimizer, LossScaleOptimizer): 

423 raise TypeError('LossScaleOptimizer cannot wrap another ' 

424 'LossScaleOptimizer, but got: %s' % (inner_optimizer,)) 

425 self._raise_if_strategy_unsupported() 

426 if getattr(inner_optimizer, '_is_wrapped_by_loss_scale_optimizer', False): 

427 # TODO(reedwm): Maybe support this. The difficulty is that LSO has the 

428 # same checkpoint format as the inner optimizer, so multiple LSOs wrapping 

429 # the same optimizer causes the checkpointing logic to become confused. 

430 raise ValueError('"inner_optimizer" is already wrapped by a ' 

431 'LossScaleOptimizer. An optimizer can only be wrapped ' 

432 'by a single LossScaleOptimizer') 

433 self._optimizer = inner_optimizer 

434 self._optimizer._is_wrapped_by_loss_scale_optimizer = True 

435 

436 # We don't call super().__init__, since we do not want to call OptimizerV2's 

437 # constructor. 

438 base_delegate.DelegatingTrackableMixin.__init__(self, self._optimizer) 

439 

440 if dynamic: 

441 if initial_scale is None: 

442 initial_scale = _DEFAULT_INITIAL_SCALE 

443 if dynamic_growth_steps is None: 

444 dynamic_growth_steps = _DEFAULT_GROWTH_STEPS 

445 self._loss_scale = _DynamicLossScaleState( 

446 initial_scale, dynamic_growth_steps, multiplier=2) 

447 self._track_trackable(self._loss_scale, 'loss_scale') 

448 else: 

449 if initial_scale is None: 

450 raise ValueError('"initial_scale" must be specified if "dynamic" is ' 

451 'False') 

452 self._loss_scale = float(initial_scale) 

453 if dynamic_growth_steps is not None: 

454 raise ValueError('"dynamic_growth_steps" must be None if "dynamic" ' 

455 'is False, but got: %s' % (dynamic_growth_steps,)) 

456 

457 # To support restoring TensorFlow 2.2 checkpoints. 

458 self._track_trackable(FakeOptimizerForRestoration(self._optimizer), 

459 'base_optimizer') 

460 

461 @property 

462 def dynamic(self): 

463 """Bool indicating whether dynamic loss scaling is used.""" 

464 return isinstance(self._loss_scale, _DynamicLossScaleState) 

465 

466 @property 

467 def loss_scale(self): 

468 """The current loss scale as a float32 scalar tensor.""" 

469 if isinstance(self._loss_scale, _DynamicLossScaleState): 

470 return tensor_conversion.convert_to_tensor_v2_with_dispatch( 

471 self._loss_scale.current_loss_scale 

472 ) 

473 else: 

474 return tensor_conversion.convert_to_tensor_v2_with_dispatch( 

475 self._loss_scale 

476 ) 

477 

478 @property 

479 def dynamic_counter(self): 

480 """The number of steps since the loss scale was last increased or decreased. 

481 

482 This is None if `LossScaleOptimizer.dynamic` is False. 

483 

484 The counter is incremented every step. Once it reaches 

485 `LossScaleOptimizer.dynamic_growth_steps`, the loss scale will be doubled 

486 and the counter will be reset back to zero. If nonfinite gradients are 

487 encountered, the loss scale will be halved and the counter will be reset 

488 back to zero. 

489 """ 

490 if isinstance(self._loss_scale, _DynamicLossScaleState): 

491 return self._loss_scale.counter 

492 else: 

493 return None 

494 

495 @property 

496 def initial_scale(self): 

497 """The initial loss scale. 

498 

499 If `LossScaleOptimizer.dynamic` is False, this is the same number as 

500 `LossScaleOptimizer.loss_scale`, as the loss scale never changes. 

501 """ 

502 if isinstance(self._loss_scale, _DynamicLossScaleState): 

503 return self._loss_scale.initial_loss_scale 

504 else: 

505 return self._loss_scale 

506 

507 @property 

508 def dynamic_growth_steps(self): 

509 """The number of steps it takes to increase the loss scale. 

510 

511 This is None if `LossScaleOptimizer.dynamic` is False. 

512 

513 Every `dynamic_growth_steps` consecutive steps with finite gradients, the 

514 loss scale is increased. 

515 """ 

516 if isinstance(self._loss_scale, _DynamicLossScaleState): 

517 return self._loss_scale.growth_steps 

518 else: 

519 return None 

520 

521 @property 

522 def inner_optimizer(self): 

523 """The optimizer that this LossScaleOptimizer is wrapping.""" 

524 return self._optimizer 

525 

526 def get_scaled_loss(self, loss): 

527 """Scales the loss by the loss scale. 

528 

529 This method is only needed if you compute gradients manually, e.g. with 

530 `tf.GradientTape`. In that case, call this method to scale the loss before 

531 passing the loss to `tf.GradientTape`. If you use 

532 `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss 

533 scaling is automatically applied and this method is unneeded. 

534 

535 If this method is called, `get_unscaled_gradients` should also be called. 

536 See the `tf.keras.mixed_precision.LossScaleOptimizer` doc for 

537 an example. 

538 

539 Args: 

540 loss: The loss, which will be multiplied by the loss scale. Can either be 

541 a tensor or a callable returning a tensor. 

542 

543 Returns: 

544 `loss` multiplied by `LossScaleOptimizer.loss_scale`. 

545 """ 

546 if callable(loss): 

547 def new_loss(): 

548 loss_val = loss() 

549 return loss_val * math_ops.cast(self.loss_scale, loss_val.dtype) 

550 return new_loss 

551 else: 

552 return loss * math_ops.cast(self.loss_scale, loss.dtype) 

553 

554 def get_unscaled_gradients(self, grads): 

555 """Unscales the gradients by the loss scale. 

556 

557 This method is only needed if you compute gradients manually, e.g. with 

558 `tf.GradientTape`. In that case, call this method to unscale the gradients 

559 after computing them with `tf.GradientTape`. If you use 

560 `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss 

561 scaling is automatically applied and this method is unneeded. 

562 

563 If this method is called, `get_scaled_loss` should also be called. See 

564 the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an 

565 example. 

566 

567 Args: 

568 grads: A list of tensors, each which will be divided by the loss scale. 

569 Can have None values, which are ignored. 

570 

571 Returns: 

572 A new list the same size as `grads`, where every non-None value in `grads` 

573 is divided by `LossScaleOptimizer.loss_scale`. 

574 """ 

575 loss_scale_reciprocal = 1. / self.loss_scale 

576 return [ 

577 _multiply_gradient(g, loss_scale_reciprocal) if g is not None else None 

578 for g in grads 

579 ] 

580 

581 def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None): 

582 tape = backprop.GradientTape() if tape is None else tape 

583 with tape: 

584 loss = self.get_scaled_loss(loss) 

585 grads_and_vars = self._optimizer._compute_gradients( # pylint: disable=protected-access 

586 loss, 

587 var_list, 

588 grad_loss, 

589 tape=tape) 

590 grads = [g for g, _ in grads_and_vars] 

591 weights = [v for _, v in grads_and_vars] 

592 unscaled_grads = self.get_unscaled_gradients(grads) 

593 return list(zip(unscaled_grads, weights)) 

594 

595 def get_gradients(self, loss, params): 

596 loss = self.get_scaled_loss(loss) 

597 grads = self._optimizer.get_gradients(loss, params) 

598 return self.get_unscaled_gradients(grads) 

599 

600 def _create_all_weights(self, var_list): 

601 self._optimizer._create_all_weights(var_list) # pylint: disable=protected-access 

602 

  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    """Applies gradients via the inner optimizer unless they are nonfinite.

    With a dynamic loss scale, the loss scale state is updated from the
    gradients and the step is applied only if that update reports that the
    gradients should be applied. When the step is skipped, the inner
    optimizer's `iterations` is still incremented. With a fixed loss scale the
    gradients are always applied.

    Args:
      grads_and_vars: Iterable of `(gradient, variable)` pairs. Pairs are first
        passed through `optimizer_utils.filter_empty_gradients`.
      name: Forwarded to the inner optimizer's `apply_gradients`.
      experimental_aggregate_gradients: If True, gradients are transformed and
        aggregated here using the inner optimizer's hooks before the finiteness
        check. The inner optimizer is always called with aggregation disabled.

    Returns:
      The grouped op combining the (possibly skipped) apply step with the loss
      scale update, or the result of `merge_call` when the strategy requires
      running in a cross-replica context.

    Raises:
      ValueError: If called in a cross-replica context, or if the current
        distribution strategy does not support loss scaling.
    """
    if distribute_lib.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')
    # We check for the strategy here despite already checking in the constructor
    # as frequently the optimizer is created outside the strategy's scope.
    self._raise_if_strategy_unsupported()

    grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
    if experimental_aggregate_gradients:
      # We must aggregate the gradients here instead of in
      # self.optimizer.apply_gradients, so that any NaN or Inf gradients are
      # propagated to each replica. If any replica has a NaN or Inf gradient,
      # they must all have a NaN or Inf gradient so that they all skip the step.
      # pylint: disable=protected-access
      grads_and_vars = self._optimizer._transform_unaggregated_gradients(
          grads_and_vars)
      grads_and_vars = self._optimizer._aggregate_gradients(grads_and_vars)
      # pylint: enable=protected-access

    grads_and_vars = tuple(grads_and_vars)
    grads = [g for g, _ in grads_and_vars]
    # We do not want DistributionStrategy to unwrap any MirroredVariables in
    # grads_and_vars, because even in a replica context, the wrapped
    # optimizer expects mirrored variables. So we wrap the variables with an
    # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the
    # MirroredVariables.
    wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars])

    def do_not_apply_fn():
      # Normally self._optimizer.iterations is incremented in
      # self._optimizer.apply_gradients(). Since that is not called in this
      # branch, we increment it here instead.
      return self._optimizer.iterations.assign_add(1, read_value=False)

    def _if_should_apply_grads(grads):
      # Returns an `(update_op, should_apply)` pair. A fixed loss scale has
      # nothing to update and never skips a step.
      if isinstance(self._loss_scale, _DynamicLossScaleState):
        return self._loss_scale.update(grads)
      else:
        return (control_flow_ops.no_op(), True)

    if optimizer_utils.strategy_supports_no_merge_call():
      loss_scale_update_op, should_apply_grads = _if_should_apply_grads(grads)
      def apply_fn():
        return self._apply_gradients(grads, wrapped_vars, name)

      maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                             do_not_apply_fn)
      return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)

    else:

      def _apply_gradients_cross_replica(distribution, grads, wrapped_vars,
                                         name):
        loss_scale_update_op, should_apply_grads = _if_should_apply_grads(grads)

        def apply_fn():
          return distribution.extended.call_for_each_replica(
              self._apply_gradients,
              args=(grads, wrapped_vars, name))

        # Note: We must call this cond() in a cross-replica context.
        # DistributionStrategy does not support having a cond in a replica
        # context with a branch that calls `merge_call`, and
        # self._optimizer.apply_gradients calls `merge_call`.
        maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                               do_not_apply_fn)
        return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)
      return distribute_lib.get_replica_context().merge_call(
          _apply_gradients_cross_replica,
          args=(grads, wrapped_vars, name))

676 

677 def _apply_gradients(self, grads, wrapped_vars, name): 

678 # Pass experimental_aggregate_gradients=False since LossScaleOptimizer 

679 # already aggregated the gradients. 

680 # TODO(reedwm): This will raise a fairly cryptic error message if 

681 # self._optimizer.apply_gradients does not take 

682 # experimental_aggregate_gradients. 

683 return self._optimizer.apply_gradients( 

684 list(zip(grads, wrapped_vars.value)), name, 

685 experimental_aggregate_gradients=False) 

686 

687 def get_config(self): 

688 serialized_optimizer = optimizers.serialize(self._optimizer) 

689 return { 

690 'inner_optimizer': serialized_optimizer, 

691 'dynamic': self.dynamic, 

692 'initial_scale': self.initial_scale, 

693 'dynamic_growth_steps': self.dynamic_growth_steps, 

694 } 

695 

696 @classmethod 

697 def from_config(cls, config, custom_objects=None): 

698 config = config.copy() # Make a copy, since we mutate config 

699 if 'loss_scale' in config: 

700 # If loss_scale is in config, we assume we are deserializing a 

701 # LossScaleOptimizer from TF 2.3 or below. We convert the config so it 

702 # can be deserialized in the current LossScaleOptimizer. 

703 loss_scale = keras_loss_scale_module.deserialize( 

704 config.pop('loss_scale')) 

705 if isinstance(loss_scale, loss_scale_module.FixedLossScale): 

706 config['dynamic'] = False 

707 config['initial_scale'] = loss_scale._loss_scale_value # pylint: disable=protected-access 

708 elif isinstance(loss_scale, loss_scale_module.DynamicLossScale): 

709 config['dynamic'] = True 

710 config['initial_scale'] = loss_scale.initial_loss_scale 

711 config['dynamic_growth_steps'] = loss_scale.increment_period 

712 if loss_scale.multiplier != 2: 

713 raise ValueError('Cannot deserialize LossScaleOptimizer with a ' 

714 'DynamicLossScale whose multiplier is not 2. Got ' 

715 'DynamicLossScale: %s' % (loss_scale,)) 

716 else: 

717 raise ValueError( 

718 'Serialized LossScaleOptimizers with a LossScale that is neither a ' 

719 'FixedLossScale nor a DynamicLossScale can no longer be ' 

720 'deserialized') 

721 config['inner_optimizer'] = config.pop('optimizer') 

722 config['inner_optimizer'] = optimizers.deserialize( 

723 config['inner_optimizer'], custom_objects=custom_objects) 

724 return cls(**config) 

725 

726 def _raise_if_strategy_unsupported(self): 

727 if not strategy_supports_loss_scaling(): 

728 strategy = distribute_lib.get_strategy() 

729 if isinstance(strategy, 

730 (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1, 

731 tpu_strategy.TPUStrategyV2)): 

732 raise ValueError( 

733 'Loss scaling is not supported with TPUStrategy. Loss scaling is ' 

734 'unnecessary with TPUs, since they support bfloat16 instead of ' 

735 'float16 and bfloat16 does not require loss scaling. You should ' 

736 'remove the use of the LossScaleOptimizer when TPUs are used.') 

737 else: 

738 raise ValueError('Loss scaling is not supported with the ' 

739 'tf.distribute.Strategy: %s. Try using a different ' 

740 'Strategy, e.g. a MirroredStrategy' % 

741 strategy.__class__.__name__) 

742 

743 # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer 

744 # below. 

745 

@property
def iterations(self):
  """The `iterations` attribute of the wrapped optimizer."""
  inner_optimizer = self._optimizer
  return inner_optimizer.iterations

@iterations.setter
def iterations(self, variable):
  # Stored on the wrapped optimizer, not on this wrapper.
  inner_optimizer = self._optimizer
  inner_optimizer.iterations = variable

753 

def get_slot_names(self):
  """Returns the slot names reported by the wrapped optimizer."""
  inner_optimizer = self._optimizer
  return inner_optimizer.get_slot_names()

756 

def variables(self):
  """Returns the result of calling `variables()` on the wrapped optimizer."""
  inner_optimizer = self._optimizer
  return inner_optimizer.variables()

759 

@property
def weights(self):
  """The `weights` attribute of the wrapped optimizer."""
  inner_optimizer = self._optimizer
  return inner_optimizer.weights

763 

def get_weights(self):
  """Returns the result of `get_weights()` on the wrapped optimizer."""
  inner_optimizer = self._optimizer
  return inner_optimizer.get_weights()

766 

def set_weights(self, weights):
  """Forwards `weights` to `set_weights()` of the wrapped optimizer.

  Args:
    weights: Passed through unchanged to the wrapped optimizer.

  Returns:
    Whatever the wrapped optimizer's `set_weights()` returns.
  """
  inner_optimizer = self._optimizer
  return inner_optimizer.set_weights(weights)

769 

@property
def clipnorm(self):
  """The `clipnorm` attribute of the wrapped optimizer."""
  inner_optimizer = self._optimizer
  return inner_optimizer.clipnorm

@clipnorm.setter
def clipnorm(self, val):
  # Stored on the wrapped optimizer, not on this wrapper.
  inner_optimizer = self._optimizer
  inner_optimizer.clipnorm = val

777 

@property
def global_clipnorm(self):
  """The `global_clipnorm` attribute of the wrapped optimizer."""
  inner_optimizer = self._optimizer
  return inner_optimizer.global_clipnorm

@global_clipnorm.setter
def global_clipnorm(self, val):
  # Stored on the wrapped optimizer, not on this wrapper.
  inner_optimizer = self._optimizer
  inner_optimizer.global_clipnorm = val

785 

@property
def clipvalue(self):
  """The `clipvalue` attribute of the wrapped optimizer."""
  inner_optimizer = self._optimizer
  return inner_optimizer.clipvalue

@clipvalue.setter
def clipvalue(self, val):
  # Stored on the wrapped optimizer, not on this wrapper.
  inner_optimizer = self._optimizer
  inner_optimizer.clipvalue = val

793 

def _aggregate_gradients(self, grads_and_vars):
  """Forwards `grads_and_vars` to the wrapped optimizer's aggregation.

  Args:
    grads_and_vars: Passed through unchanged to the wrapped optimizer.

  Returns:
    Whatever the wrapped optimizer's `_aggregate_gradients()` returns.
  """
  aggregate_fn = self._optimizer._aggregate_gradients  # pylint: disable=protected-access
  return aggregate_fn(grads_and_vars)

796 

def _restore_slot_variable(self, slot_name, variable, slot_variable):
  """Forwards slot restoration to the wrapped optimizer.

  Args:
    slot_name: Passed through unchanged to the wrapped optimizer.
    variable: Passed through unchanged to the wrapped optimizer.
    slot_variable: Passed through unchanged to the wrapped optimizer.

  Returns:
    Whatever the wrapped optimizer's `_restore_slot_variable()` returns.
  """
  restore_fn = self._optimizer._restore_slot_variable  # pylint: disable=protected-access
  return restore_fn(slot_name, variable, slot_variable)

800 

def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
                                     variable):
  """Forwards slot creation/restoration to the wrapped optimizer.

  Args:
    slot_variable_position: Passed through unchanged to the wrapped optimizer.
    slot_name: Passed through unchanged to the wrapped optimizer.
    variable: Passed through unchanged to the wrapped optimizer.

  Returns:
    Whatever the wrapped optimizer's `_create_or_restore_slot_variable()`
    returns.
  """
  create_or_restore_fn = self._optimizer._create_or_restore_slot_variable  # pylint: disable=protected-access
  return create_or_restore_fn(slot_variable_position, slot_name, variable)

805 

def get_slot(self, var, slot_name):
  """Returns the result of `get_slot(var, slot_name)` on the wrapped optimizer.

  Args:
    var: Passed through unchanged to the wrapped optimizer.
    slot_name: Passed through unchanged to the wrapped optimizer.
  """
  inner_optimizer = self._optimizer
  return inner_optimizer.get_slot(var, slot_name)

808 

def add_slot(self, var, slot_name, initializer='zeros'):
  """Forwards `add_slot` to the wrapped optimizer.

  Args:
    var: Passed through unchanged to the wrapped optimizer.
    slot_name: Passed through unchanged to the wrapped optimizer.
    initializer: Passed through positionally; defaults to 'zeros'.

  Returns:
    Whatever the wrapped optimizer's `add_slot()` returns.
  """
  inner_optimizer = self._optimizer
  return inner_optimizer.add_slot(var, slot_name, initializer)

811 

def __getattribute__(self, name):
  """Looks up `name`, falling back to the wrapped optimizer's hyperparameters.

  Normal attribute lookup is tried first. If it fails, `name` (with 'lr'
  treated as 'learning_rate') is looked up in `self._optimizer._hyper` and,
  when present, fetched with `self._optimizer._get_hyper`.

  Raises:
    AttributeError: The original lookup error, if `name` is neither an
      attribute of this object nor a key of the wrapped optimizer's `_hyper`.
  """
  try:
    return object.__getattribute__(self, name)
  except AttributeError as lookup_error:
    # The fallback below itself reads `_optimizer` and `_hyper`; re-raise for
    # those names to avoid infinite recursion.
    if name in ('_optimizer', '_hyper'):
      raise lookup_error

    # Delegate hyperparameter accesses to inner optimizer.
    hyper_name = 'learning_rate' if name == 'lr' else name
    inner_optimizer = self._optimizer
    if hyper_name in inner_optimizer._hyper:
      return inner_optimizer._get_hyper(hyper_name)
    raise lookup_error

826 

def __dir__(self):
  """Lists attribute names, including delegated hyperparameter names.

  Returns:
    A list with the names from the superclass `__dir__`, plus, when
    `_optimizer` is among them, the keys of `self._optimizer._hyper` and
    'lr' if 'learning_rate' is one of those keys.
  """
  names = set(super(LossScaleOptimizer, self).__dir__())
  if '_optimizer' in names:
    hyper_names = self._optimizer._hyper.keys()
    names |= hyper_names
    if 'learning_rate' in hyper_names:
      names.add('lr')
  return list(names)

834 

def __setattr__(self, name, value):
  """Sets an attribute, delegating hyperparameters to the wrapped optimizer.

  'lr' is treated as 'learning_rate'. The value is stored with
  `self._optimizer._set_hyper` when the name is a key of
  `self._optimizer._hyper` and is not already an attribute of this object;
  otherwise it is set through the superclass `__setattr__`.
  """
  attr_name = 'learning_rate' if name == 'lr' else name

  if attr_name == 'iterations':
    # We cannot check for the 'iterations' attribute as it cannot be set after
    # it is accessed, so it is always treated as already existing.
    already_defined = True
  else:
    try:
      object.__getattribute__(self, attr_name)
    except AttributeError:
      already_defined = False
    else:
      already_defined = True

  # The order of these checks is significant: `self._optimizer` must not be
  # read while `_optimizer` itself is being assigned.
  delegate_to_inner = (attr_name != '_optimizer'
                       and attr_name in self._optimizer._hyper
                       and not already_defined)
  if delegate_to_inner:
    self._optimizer._set_hyper(attr_name, value)
  else:
    super(LossScaleOptimizer, self).__setattr__(attr_name, value)

853 

# learning_rate is delegated explicitly. Hyperparameters are normally
# delegated in __getattribute__, but that only covers names present in
# self._optimizer._hyper; a hyperparameter missing from it (e.g. because
# self._optimizer itself wraps another optimizer) would not be delegated.
# learning_rate is accessed very commonly, so it gets its own property.
@property
def learning_rate(self):
  """The `learning_rate` attribute of the wrapped optimizer."""
  inner_optimizer = self._optimizer
  return inner_optimizer.learning_rate

@learning_rate.setter
def learning_rate(self, value):
  # Stored on the wrapped optimizer, not on this wrapper.
  inner_optimizer = self._optimizer
  inner_optimizer.learning_rate = value

866 

@property
def lr(self):
  """Alias for the `learning_rate` attribute of the wrapped optimizer."""
  return self._optimizer.learning_rate

@lr.setter
def lr(self, value):
  # Write the same attribute the getter reads, so that a get after a set
  # round-trips even when the wrapped optimizer does not alias `lr` to
  # `learning_rate` itself.
  self._optimizer.learning_rate = value

874 

875 # We do not override some OptimizerV2 methods. For each, we describe why we do 

876 # not delegate them to self._optimizer: 

877 # * get_updates: get_updates() calls get_gradients(). Since we override 

878 # get_gradients(), we cannot delegate get_updates() to self._optimizer, 

879 # otherwise the overridden get_gradients() method would not be called. 

880 # Luckily, get_updates() does not access any OptimizerV2 fields, so 

881 # inheriting the OptimizerV2 version works fine. 

882 # * minimize: We don't delegate for a similar reason as get_updates(): it calls 

883 # both self._compute_gradients() and self.apply_gradients(), and both need 

884 # to have the LossScaleOptimizer version called. 

885 

886 # TODO(reedwm): Maybe throw an error if mixed precision is used without this 

887 # optimizer being used. 

888 

889 

@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
class LossScaleOptimizerV1(LossScaleOptimizer):
  """A deprecated optimizer that applies loss scaling.

  Warning: This class is deprecated and will be removed in a future version of
  TensorFlow. Please use the non-experimental class
  `tf.keras.mixed_precision.LossScaleOptimizer` instead.

  This class is identical to the non-experimental
  `keras.mixed_precision.LossScaleOptimizer` except its constructor takes
  different arguments. For this class (the experimental version), the
  constructor takes a `loss_scale` argument. For the non-experimental class,
  the constructor encodes the loss scaling information in multiple arguments.
  Note that unlike this class, the non-experimental class does not accept a
  `tf.compat.v1.mixed_precision.LossScale`, which is deprecated.

  If you currently use this class, you should switch to the non-experimental
  `tf.keras.mixed_precision.LossScaleOptimizer` instead. We show several
  examples of converting the use of the experimental class to the equivalent
  non-experimental class.

  >>> # In all of the examples below, `opt1` and `opt2` are identical
  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale='dynamic')
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD())
  >>> assert opt1.get_config() == opt2.get_config()

  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale=123)
  >>> # dynamic=False indicates to use fixed loss scaling. initial_scale=123
  >>> # refers to the initial loss scale, which is the single fixed loss scale
  >>> # when dynamic=False.
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), dynamic=False, initial_scale=123)
  >>> assert opt1.get_config() == opt2.get_config()

  >>> loss_scale = tf.compat.v1.mixed_precision.experimental.DynamicLossScale(
  ...     initial_loss_scale=2048, increment_period=500)
  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale=loss_scale)
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), initial_scale=2048,
  ...     dynamic_growth_steps=500)
  >>> assert opt1.get_config() == opt2.get_config()

  Make sure to also switch from this class to the non-experimental class in
  isinstance checks, if you have any. If you do not do this, your model may run
  into hard-to-debug issues, as the experimental `LossScaleOptimizer` subclasses
  the non-experimental `LossScaleOptimizer`, but not vice versa. It is safe to
  switch isinstance checks to the non-experimental `LossScaleOptimizer` even
  before using the non-experimental `LossScaleOptimizer`.

  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale='dynamic')
  >>> # The experimental class subclasses the non-experimental class
  >>> isinstance(opt1, tf.keras.mixed_precision.LossScaleOptimizer)
  True
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD())
  >>> # The non-experimental class does NOT subclass the experimental class.
  >>> isinstance(opt2, tf.keras.mixed_precision.experimental.LossScaleOptimizer)
  False

  Args:
    optimizer: The Optimizer instance to wrap.
    loss_scale: The loss scale to scale the loss and gradients. This can
      either be an int/float to use a fixed loss scale, the string "dynamic"
      to use dynamic loss scaling, or an instance of a LossScale. The string
      "dynamic" is equivalent to passing `DynamicLossScale()`, and passing an
      int/float is equivalent to passing a FixedLossScale with the given loss
      scale. If a DynamicLossScale is passed, DynamicLossScale.multiplier must
      be 2 (the default).
  """

  def __init__(self, optimizer, loss_scale):
    deprecation_notice = (
        'tf.keras.mixed_precision.experimental.LossScaleOptimizer is '
        'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer '
        'instead. ')

    # A serialized loss scale is turned back into a LossScale object first.
    if isinstance(loss_scale, dict):
      loss_scale = keras_loss_scale_module.deserialize(loss_scale)

    # A plain number and a FixedLossScale both select fixed loss scaling; they
    # differ only in where the scale value comes from.
    use_fixed_scale = True
    fixed_scale = None
    if isinstance(loss_scale, (int, float)):
      fixed_scale = loss_scale
    elif isinstance(loss_scale, loss_scale_module.FixedLossScale):
      fixed_scale = loss_scale._loss_scale_value  # pylint: disable=protected-access
    else:
      use_fixed_scale = False

    if use_fixed_scale:
      tf_logging.warning(
          deprecation_notice + 'For example:\n'
          ' opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt, dynamic=False, initial_scale={})'.format(fixed_scale))
      super(LossScaleOptimizerV1, self).__init__(
          optimizer, dynamic=False, initial_scale=fixed_scale)
      return

    if loss_scale == 'dynamic':
      tf_logging.warning(
          deprecation_notice + 'For example:\n'
          ' opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt)')
      super(LossScaleOptimizerV1, self).__init__(optimizer)
      return

    if isinstance(loss_scale, loss_scale_module.DynamicLossScale):
      # Only non-default settings are forwarded and shown in the example.
      init_kwargs = {}
      example_arguments = ''
      if loss_scale.initial_loss_scale != _DEFAULT_INITIAL_SCALE:
        init_kwargs['initial_scale'] = loss_scale.initial_loss_scale
        example_arguments += (', initial_scale=%s' %
                              loss_scale.initial_loss_scale)
      if loss_scale.increment_period != _DEFAULT_GROWTH_STEPS:
        init_kwargs['dynamic_growth_steps'] = loss_scale.increment_period
        example_arguments += (', dynamic_growth_steps=%s' %
                              loss_scale.increment_period)
      if loss_scale.multiplier != 2:
        raise ValueError('When passing a DynamicLossScale to "loss_scale", '
                         'DynamicLossScale.multiplier must be 2. Got: %s'
                         % (loss_scale,))
      tf_logging.warning(
          deprecation_notice +
          'Note that the non-experimental LossScaleOptimizer does not take a '
          'DynamicLossScale but instead takes the dynamic configuration '
          'directly in the constructor. For example:\n'
          ' opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt{})\n'.format(example_arguments))
      super(LossScaleOptimizerV1, self).__init__(optimizer, **init_kwargs)
      return

    if isinstance(loss_scale, loss_scale_module.LossScale):
      raise TypeError('Passing a LossScale that is not a FixedLossScale or a '
                      'DynamicLossScale is no longer supported. Got: {}'
                      .format(loss_scale))

    raise ValueError('Invalid value passed to loss_scale. loss_scale '
                     'must be the string "dynamic" (recommended), an int, '
                     'a float, a FixedLossScale, or a DynamicLossScale. Got '
                     'value: {}'.format(loss_scale))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds an instance from a config dict in either supported layout.

    Args:
      config: A config dict. It either has a 'loss_scale' key, or has the
        'dynamic', 'initial_scale', 'dynamic_growth_steps' and
        'inner_optimizer' keys.
      custom_objects: Forwarded to `optimizers.deserialize`.

    Returns:
      The result of `cls(**config)` after converting `config`.

    Raises:
      ValueError: If `config` holds a DynamicLossScale whose multiplier is
        not 2.
    """
    config = config.copy()  # Make a copy, since we mutate config

    # If loss_scale is in config, we assume we are deserializing a
    # LossScaleOptimizer from TF 2.3 or below. Otherwise, we assume we are
    # deserializing a LossScaleOptimizer from TF 2.4 or above.
    if 'loss_scale' in config:
      legacy_scale = keras_loss_scale_module.deserialize(config['loss_scale'])
      config['loss_scale'] = legacy_scale
      if (isinstance(legacy_scale, loss_scale_module.DynamicLossScale)
          and legacy_scale.multiplier != 2):
        raise ValueError('Cannot deserialize LossScaleOptimizer with a '
                         'DynamicLossScale whose multiplier is not 2. Got '
                         'DynamicLossScale: %s' % (legacy_scale,))
      config['optimizer'] = optimizers.deserialize(
          config['optimizer'], custom_objects=custom_objects)
      return cls(**config)

    # We convert the config, as generated by LossScaleOptimizer.get_config, to
    # a version that can be passed to LossScaleOptimizerV1.__init__
    if config['dynamic']:
      restored_scale = loss_scale_module.DynamicLossScale(
          config['initial_scale'], config['dynamic_growth_steps'], multiplier=2)
    else:
      restored_scale = loss_scale_module.FixedLossScale(
          config['initial_scale'])
    config['loss_scale'] = restored_scale

    # These keys are not accepted by LossScaleOptimizerV1.__init__.
    for stale_key in ('dynamic', 'initial_scale', 'dynamic_growth_steps'):
      del config[stale_key]
    config['optimizer'] = optimizers.deserialize(
        config.pop('inner_optimizer'), custom_objects=custom_objects)
    return cls(**config)

1062 

1063 

class FakeOptimizerForRestoration(trackable.Trackable):
  """A stand-in optimizer that lets TensorFlow 2.2 checkpoints be restored.

  The checkpoint format for LossScaleOptimizers changed after TF 2.2, and this
  class exists so that TF 2.2 checkpoints can still be restored in newer
  versions of TensorFlow.

  In TF 2.2, LossScaleOptimizer tracked the wrapped optimizer by calling the
  following in LossScaleOptimizer.__init__

  ```
  self._track_trackable(self._optimizer, 'base_optimizer')
  ```

  so a dependency from the LossScaleOptimizer to the wrapped optimizer was
  stored in the checkpoint. Now, the checkpoint format with a
  LossScaleOptimizer is the same as the format without one, except that the
  loss scale is also stored. There is therefore no dependency from the
  LossScaleOptimizer to the wrapped optimizer. Instead, from a checkpoint's
  perspective, the LossScaleOptimizer acts as if it is the wrapped optimizer,
  by overriding all Trackable methods and delegating them to the wrapped
  optimizer.

  To allow restoring TF 2.2 checkpoints, LossScaleOptimizer adds a dependency
  on this class instead of on the inner optimizer. When restored, this class
  restores the slot variables of the inner optimizer. Since this class has no
  variables, it does not affect the checkpoint when saved.
  """

  def __init__(self, optimizer):
    # The optimizer whose slot variables are restored through this object.
    self._optimizer = optimizer

  def get_slot_names(self):
    """Returns the slot names reported by the wrapped optimizer."""
    inner_optimizer = self._optimizer
    return inner_optimizer.get_slot_names()

  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
                                       variable):
    """Forwards slot creation/restoration to the wrapped optimizer."""
    create_or_restore_fn = self._optimizer._create_or_restore_slot_variable  # pylint: disable=protected-access
    return create_or_restore_fn(slot_variable_position, slot_name, variable)

1102 

1103 

# Register `LossScaleOptimizerV1` with `mixed_precision` as the loss scale
# wrapper for `optimizer_v2.OptimizerV2`.
mixed_precision.register_loss_scale_wrapper(
    optimizer_v2.OptimizerV2, LossScaleOptimizerV1)

1106 

1107 

def _multiply_gradient(gradient, scale):
  """Multiply a (possibly sparse) gradient by the given scale factor.

  Args:
    gradient: The gradient to scale; may be an `IndexedSlices`.
    scale: The scale factor. It is cast to `gradient.dtype` before use.

  Returns:
    `gradient * scale`, or, for an `IndexedSlices`, a new `IndexedSlices`
    with scaled values and the same indices and dense shape.
  """
  typed_scale = math_ops.cast(scale, gradient.dtype)
  if not isinstance(gradient, indexed_slices.IndexedSlices):
    return gradient * typed_scale

  # Only the values are scaled; indices and dense shape are carried over.
  scaled_values = gradient.values * typed_scale
  return indexed_slices.IndexedSlices(
      scaled_values, gradient.indices, dense_shape=gradient.dense_shape)

1118 

1119 

def strategy_supports_loss_scaling():
  """Returns True if the current Strategy supports loss scaling."""
  if not distribute_lib.has_strategy():
    return True

  active_strategy = distribute_lib.get_strategy()
  # Strategies are supported if either there is only one replica or if variables
  # are replicated per device. Otherwise, the current model.fit() implementation
  # and most custom training loops incorrectly unscale the gradients. Currently,
  # gradients are unscaled once per compute replica, but they should be unscaled
  # once per variable replica. When there is one variable replica for each
  # compute replica, this works fine, but otherwise issues will occur.
  # TODO(reedwm): Support all strategies.
  supported_strategy_types = (
      collective_all_reduce_strategy.CollectiveAllReduceStrategy,
      collective_all_reduce_strategy.CollectiveAllReduceStrategyV1,
      one_device_strategy.OneDeviceStrategy,
      one_device_strategy.OneDeviceStrategyV1,
      mirrored_strategy.MirroredStrategy,
      mirrored_strategy.MirroredStrategyV1,
  )
  return isinstance(active_strategy, supported_strategy_types)