Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/mixed_precision/loss_scale_optimizer.py: 30%

569 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains the loss scaling optimizer class.""" 

16 

17import tensorflow.compat.v2 as tf 

18 

19from keras.src import backend 

20from keras.src import optimizers 

21from keras.src.dtensor import utils as dtensor_utils 

22from keras.src.optimizers import optimizer 

23from keras.src.optimizers import utils as optimizer_utils 

24from keras.src.optimizers.legacy import optimizer_v2 

25from keras.src.saving import serialization_lib 

26 

27# isort: off 

28from tensorflow.python.platform import tf_logging 

29from tensorflow.python.util.tf_export import keras_export 

30 

31 

32class _UnwrapPreventer: 

33 """Wrapper that DistributionStrategy will not unwrap. 

34 

35 Typically, DistributionStrategy will unwrap values when going from a cross- 

36 replica context to a replica context via `call_for_each_replica`. This class 

37 is a wrapper that DistributionStrategy will not unwrap, so it can be used to 

38 prevent it from unwrapping a value. 

39 

40 TODO(reedwm): Find/implement a better way of preventing values from being 

41 unwrapped by DistributionStrategy 

42 """ 

43 

44 __slots__ = ["value"] 

45 

46 def __init__(self, value): 

47 self.value = value 

48 

49 

def _is_all_finite(grads):
    """Returns a scalar boolean tensor that is true when every gradient is
    finite.

    `None` entries in `grads` are skipped. For a `tf.IndexedSlices` gradient
    only its `values` tensor is inspected.
    """
    finite_flags = []
    for grad in grads:
        if grad is None:
            continue
        dense = grad.values if isinstance(grad, tf.IndexedSlices) else grad
        finite_flags.append(tf.reduce_all(tf.math.is_finite(dense)))
    return tf.reduce_all(finite_flags)

63 

64 

def _op_in_graph_mode(tensor):
    """Gives back `tensor.op` in graph mode and `tensor` itself in eager mode.

    Some call sites need an op rather than a tensor when building a graph,
    while eager execution has no ops to hand out.

    Args:
      tensor: A tensor.

    Returns:
      The tensor's op in graph mode. The tensor in eager mode.
    """
    return tensor if tf.executing_eagerly() else tensor.op

80 

81 

def _assign_if_finite(var, value):
    """Writes `value` into `var` when `value` is finite; otherwise does
    nothing."""

    def do_assign():
        # In graph mode the cond branch must yield an op, not a tensor.
        return _op_in_graph_mode(var.assign(value))

    value_is_finite = tf.math.is_finite(value)
    return tf.cond(value_is_finite, do_assign, tf.no_op)

89 

90 

def _maybe_warn_about_scaling(
    loss_has_been_scaled, gradients_have_been_unscaled
):
    """Logs a warning when loss scaling or gradient unscaling was skipped.

    Nothing is logged when both flags are truthy. Otherwise exactly one
    warning is emitted through `tf_logging.warning`, worded for whichever of
    the two calls (or both) is missing.
    """
    if loss_has_been_scaled and gradients_have_been_unscaled:
        return

    example_code = """
    with tf.GradientTape() as tape:
      loss = loss_fn()
      scaled_loss = opt.get_scaled_loss(loss)
    scaled_grads = tape.gradient(scaled_loss, vars)
    grads = opt.get_unscaled_gradients(scaled_grads)
    opt.apply_gradients([(grads, var)])"""

    docs_url = "https://www.tensorflow.org/api_docs/python/tf/keras/mixed_precision/LossScaleOptimizer"  # noqa: E501
    # Every variant of the warning ends with the same example and link.
    tail = f"For example:{example_code}\nFor more information, see {docs_url}"

    if loss_has_been_scaled:
        # Only the unscaling call is missing.
        head = (
            "You forgot to call LossScaleOptimizer.get_unscaled_gradients() "
            "before calling LossScaleOptimizer.apply_gradients() (you did call "
            "get_scaled_loss() however). This will likely result in worse "
            "model quality, so please call get_unscaled_gradients() in the "
            "correct place! "
        )
    elif gradients_have_been_unscaled:
        # Only the scaling call is missing.
        head = (
            "You forgot to call LossScaleOptimizer.get_scaled_loss() before "
            "calling LossScaleOptimizer.apply_gradients() (you did call "
            "get_unscaled_gradients() however). This will likely result in "
            "worse model quality, so please call get_scaled_loss() in the "
            "correct place! "
        )
    else:
        # Neither call was made.
        head = (
            "You forgot to call LossScaleOptimizer.get_scaled_loss() and "
            "LossScaleOptimizer.get_unscaled_gradients() before calling "
            "LossScaleOptimizer.apply_gradients(). This will likely result in "
            "worse model quality, so please call them in the correct places! "
        )

    tf_logging.warning(head + tail)

135 

136 

class _DynamicLossScaleState(tf.__internal__.tracking.Trackable):
    """The state of a dynamic loss scale.

    Owns two non-trainable variables: the current loss scale (float32) and a
    counter (int64) of consecutive steps with finite gradients. `update`
    multiplies or divides the loss scale by `multiplier` depending on whether
    the gradients it is given are all finite.
    """

    def __init__(self, initial_loss_scale, growth_steps, multiplier):
        """Creates the dynamic loss scale.

        Args:
          initial_loss_scale: The loss scale to start from. Converted with
            `float()`.
          growth_steps: Number of consecutive finite-gradient steps after
            which the loss scale is multiplied by `multiplier`. Converted
            with `int()`.
          multiplier: Factor the loss scale is multiplied by when it grows
            and divided by when it shrinks. Converted with `float()`.
        """
        super().__init__()
        self._initial_loss_scale = float(initial_loss_scale)
        self._growth_steps = int(growth_steps)
        self._multiplier = float(multiplier)

        # Maps (variable name, graph key) -> variable; filled by _add_weight.
        self._weights = {}
        self._current_loss_scale = self._add_weight(
            name="current_loss_scale",
            dtype=tf.float32,
            initial_value=self._initial_loss_scale,
        )
        # The number of consecutive steps with finite gradients since the last
        # nonfinite gradient or change in loss scale. The name is 'good_steps'
        # for backwards compatibility with older checkpoints.
        self._counter = self._add_weight(
            name="good_steps", dtype=tf.int64, initial_value=0
        )

    @staticmethod
    def _current_graph_key():
        """Returns the default graph's key in graph mode, None in eager mode."""
        if tf.executing_eagerly():
            return None
        return tf.compat.v1.get_default_graph()._graph_key

    def _add_weight(self, name, initial_value, dtype=None):
        """Adds a weight to this loss scale.

        Args:
          name: Variable name.
          initial_value: The variable's initial value.
          dtype: The type of the variable.

        Returns:
          A variable.

        Raises:
          RuntimeError: If a weight with `name` has already been added for
            the current graph (or for eager mode).
        """
        key = (name, self._current_graph_key())
        # Refuse to silently replace an already-registered weight.
        if key in self._weights:
            raise RuntimeError(f"Duplicate variables detected. {key}")

        variable = tf.Variable(
            initial_value=initial_value,
            name=name,
            dtype=dtype,
            trainable=False,
            synchronization=tf.VariableSynchronization.AUTO,
            # Set aggregation to NONE, as loss scaling variables should never
            # be aggregated.
            aggregation=tf.VariableAggregation.NONE,
        )
        self._weights[key] = variable
        self._handle_deferred_dependencies(name=name, trackable=variable)
        backend.track_variable(variable)
        return variable

    def _trackable_children(self, save_type="checkpoint", **kwargs):
        """From Trackable. Gather graph-specific weights to save."""
        graph_key = self._current_graph_key()
        weights = {}
        # Visit weights in name order; keep only those of the current graph.
        for (name, g), v in sorted(
            self._weights.items(), key=lambda i: i[0][0]
        ):
            if g == graph_key:
                weights[name] = v
        weights.update(super()._trackable_children(save_type, **kwargs))
        return weights

    def _lookup_dependency(self, name):
        """From Trackable. Find a weight in the current graph."""
        unconditional = super()._lookup_dependency(name)
        if unconditional is not None:
            return unconditional
        return self._weights.get((name, self._current_graph_key()), None)

    @property
    def initial_loss_scale(self):
        """The initial loss scale, as the float given at construction."""
        return self._initial_loss_scale

    @property
    def growth_steps(self):
        """The growth-step threshold, as the int given at construction."""
        return self._growth_steps

    @property
    def multiplier(self):
        """The growth/shrink factor, as the float given at construction."""
        return self._multiplier

    @property
    def current_loss_scale(self):
        """Returns the current loss scale as a float32 `tf.Variable`."""
        return self._current_loss_scale

    @property
    def counter(self):
        """Returns the counter as an int64 `tf.Variable`."""
        return self._counter

    def __call__(self):
        """Returns the current loss scale as a scalar `float32` tensor."""
        return tf.convert_to_tensor(self._current_loss_scale)

    def update(self, grads):
        """Updates the value of the loss scale.

        Args:
          grads: A nested structure of unscaled gradients, each which is an
            all-reduced gradient of the loss with respect to a weight.

        Returns:
          update_op: In eager mode, None. In graph mode, an op to update the
            loss scale.
          should_apply_gradients: Either a bool or a scalar boolean tensor. If
            False, the caller should skip applying `grads` to the variables
            this step.
        """
        grads = tf.nest.flatten(grads)
        if (
            tf.distribute.has_strategy()
            and tf.distribute.in_cross_replica_context()
        ):
            distribution = tf.distribute.get_strategy()
            is_finite_per_replica = distribution.extended.call_for_each_replica(
                _is_all_finite, args=(grads,)
            )
            # Each replica computed the same `is_finite` value, since `grads`
            # is all-reduced across replicas. Arbitrarily take `is_finite`
            # from the first replica.
            is_finite = distribution.experimental_local_results(
                is_finite_per_replica
            )[0]
        else:
            is_finite = _is_all_finite(grads)

        def update_if_finite_grads():
            """Update assuming the gradients are finite."""

            def incr_loss_scale():
                # Grow the scale (unless that would make it nonfinite) and
                # restart the count of good steps.
                new_loss_scale = self.current_loss_scale * self.multiplier
                return tf.group(
                    _assign_if_finite(self.current_loss_scale, new_loss_scale),
                    self.counter.assign(0),
                )

            return tf.cond(
                self.counter + 1 >= self.growth_steps,
                incr_loss_scale,
                lambda: _op_in_graph_mode(self.counter.assign_add(1)),
            )

        def update_if_not_finite_grads():
            """Update assuming the gradients are nonfinite."""

            # Shrink the scale, but never below 1.
            new_loss_scale = tf.maximum(
                self.current_loss_scale / self.multiplier, 1
            )
            return tf.group(
                self.counter.assign(0),
                self.current_loss_scale.assign(new_loss_scale),
            )

        update_op = tf.cond(
            is_finite, update_if_finite_grads, update_if_not_finite_grads
        )
        should_apply_gradients = is_finite
        return update_op, should_apply_gradients

314 

315 

316# See LossScaleOptimizer docstring for why this is so big 

317_DEFAULT_INITIAL_SCALE = 2**15 

318_DEFAULT_GROWTH_STEPS = 2000 

319 

320 

321# TODO(b/215389169): Delete this class after `OptimizerV2` is deprecated. 

class LossScaleOptimizerMetaclass(type):
    """Metaclass that routes `BaseLossScaleOptimizer(...)` to a concrete class.

    Constructing a `BaseLossScaleOptimizer` (which is what a user does with
    `tf.keras.mixed_precision.LossScaleOptimizer(opt)`) yields either a
    LossScaleOptimizer or a LossScaleOptimizerV3, chosen by the type of
    `opt`. Constructing any other class that uses this metaclass goes through
    the normal instantiation path.
    """

    def __call__(cls, inner_optimizer, *args, **kwargs):
        if cls is BaseLossScaleOptimizer:
            # Dispatch on the kind of optimizer being wrapped.
            if isinstance(inner_optimizer, optimizer_v2.OptimizerV2):
                return LossScaleOptimizer(inner_optimizer, *args, **kwargs)
            if isinstance(inner_optimizer, optimizer.Optimizer):
                return LossScaleOptimizerV3(inner_optimizer, *args, **kwargs)

            # Anything else is not an optimizer this wrapper understands.
            raise TypeError(
                '"inner_optimizer" must be an instance of '
                "`tf.keras.optimizers.Optimizer` or "
                "`tf.keras.optimizers.experimental.Optimizer`, but got: "
                f"{inner_optimizer}."
            )

        # Concrete subclasses are instantiated the ordinary way.
        return super().__call__(inner_optimizer, *args, **kwargs)

351 

352 

353# TODO(b/215389169): Delete this class after `OptimizerV2` is deprecated. 

354 

355 

@keras_export("keras.mixed_precision.LossScaleOptimizer")
class BaseLossScaleOptimizer(metaclass=LossScaleOptimizerMetaclass):
    """An optimizer that applies loss scaling to prevent numeric underflow.

    Loss scaling is a technique to prevent numeric underflow in intermediate
    gradients when float16 is used. The loss is multiplied (or "scaled") by a
    factor called the "loss scale", which scales the intermediate gradients by
    the same factor. The final gradients are then divided (or "unscaled") by
    the loss scale to bring them back to their original value.

    `LossScaleOptimizer` wraps another optimizer and applies loss scaling to
    it. By default, the loss scale is dynamically updated over time so you do
    not have to choose it yourself. The `minimize` method automatically scales
    the loss, unscales the gradients, and updates the loss scale, so wrapping
    your optimizer with a `LossScaleOptimizer` is all that is needed if you
    use `minimize`. For example:

    >>> opt = tf.keras.optimizers.experimental.SGD(0.25)
    >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
    >>> var = tf.Variable(1.)
    >>> loss_fn = lambda: var ** 2
    >>> # 'minimize' applies loss scaling and updates the loss scale.
    >>> opt.minimize(loss_fn, var_list=[var])
    >>> var.numpy()
    0.5

    If a `tf.GradientTape` is used to compute gradients instead of `minimize`,
    you must scale the loss and gradients manually. This can be done with the
    `LossScaleOptimizer.get_scaled_loss` and
    `LossScaleOptimizer.get_unscaled_gradients` methods. For example:

    >>> with tf.GradientTape() as tape:
    ...   loss = loss_fn()
    ...   scaled_loss = opt.get_scaled_loss(loss)
    >>> scaled_grad = tape.gradient(scaled_loss, var)
    >>> (grad,) = opt.get_unscaled_gradients([scaled_grad])
    >>> opt.apply_gradients([(grad, var)])  # Loss scale is updated here
    >>> var.numpy()
    0.25

    Warning: If you forget to call `get_scaled_loss` or
    `get_unscaled_gradients` (or both) when using a `tf.GradientTape`, the
    model will likely converge to a worse quality. Please make sure you call
    each function exactly once.

    When mixed precision with float16 is used, there is typically no risk of
    underflow affecting model quality if loss scaling is properly used. See
    [the mixed precision guide](
    https://www.tensorflow.org/guide/keras/mixed_precision) for more
    information on how to use mixed precision.

    Args:
      inner_optimizer: The `tf.keras.optimizers.Optimizer` or
        `tf.keras.optimizers.experimental.Optimizer` instance to wrap.
      dynamic: Bool indicating whether dynamic loss scaling is used. Defaults
        to True. If True, the loss scale will be dynamically updated over
        time using an algorithm that keeps the loss scale at approximately
        its optimal value. If False, a single fixed loss scale is used and
        `initial_scale` must be specified, which is used as the loss scale.
        Recommended to keep as True, as choosing a fixed loss scale can be
        tricky. Currently, there is a small performance overhead to dynamic
        loss scaling compared to fixed loss scaling.
      initial_scale: The initial loss scale. If `dynamic` is True, this
        defaults to `2 ** 15`. If `dynamic` is False, this must be specified
        and acts as the sole loss scale, as the loss scale does not change
        over time. When dynamic loss scaling is used, it is better for this
        to be a very high number, because a loss scale that is too high gets
        lowered far more quickly than a loss scale that is too low gets
        raised.
      dynamic_growth_steps: With dynamic loss scaling, every
        `dynamic_growth_steps` steps with finite gradients, the loss scale is
        doubled. Defaults to 2000. If a nonfinite gradient is encountered,
        the count is reset back to zero, gradients are skipped that step, and
        the loss scale is halved. The count can be queried with
        `LossScaleOptimizer.dynamic_counter`. This argument can only be
        specified if `dynamic` is True.

    `LossScaleOptimizer` will occasionally skip applying gradients to the
    variables, in which case the trainable variables will not change that
    step. This is done because the dynamic loss scale will sometimes be raised
    too high, causing overflow in the gradients. Typically, the first 2 to 15
    steps of the model are skipped as the initial loss scale is very high, but
    afterwards steps will only be skipped on average 0.05% of the time (the
    fraction of steps skipped is `1 / dynamic_growth_steps`).

    `LossScaleOptimizer` delegates all public `Optimizer` methods to the inner
    optimizer. Additionally, in methods `minimize` and `get_gradients`, it
    scales the loss and unscales the gradients. In methods `minimize` and
    `apply_gradients`, it additionally updates the loss scale and skips
    applying gradients if any gradient has a nonfinite value.

    ### Hyperparameters

    If wrapping a `tf.keras.optimizers.Optimizer`, hyperparameters can be
    accessed and set on the LossScaleOptimizer, which will be delegated to the
    wrapped optimizer.

    >>> opt = tf.keras.optimizers.legacy.Adam(beta_1=0.8, epsilon=1e-5)
    >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
    >>> opt.beta_1  # Equivalent to `opt.inner_optimizer.beta_1`
    0.8
    >>> opt.beta_1 = 0.7  # Equivalent to `opt.inner_optimizer.beta_1 = 0.7`
    >>> opt.beta_1
    0.7
    >>> opt.inner_optimizer.beta_1
    0.7

    However, accessing or setting non-hyperparameters on the
    LossScaleOptimizer is not delegated to the wrapped optimizer. In an Adam
    optimizer, `beta_1` is a hyperparameter but `epsilon` is not, as the Adam
    optimizer only calls `Optimizer._set_hyper` on `beta_1`.

    >>> opt.inner_optimizer.epsilon
    1e-5
    >>> opt.epsilon
    Traceback (most recent call last):
    ...
    AttributeError: 'LossScaleOptimizer' object has no attribute 'epsilon'
    >>> opt.epsilon = 1e-4  # This does NOT set epsilon on `opt.inner_optimizer`
    >>> opt.inner_optimizer.epsilon
    1e-5

    In the above example, despite epsilon being set on the LossScaleOptimizer,
    the old epsilon value will still be used when training as epsilon was not
    set on the inner optimizer.
    """

    @property
    def dynamic(self):
        """Bool indicating whether dynamic loss scaling is used."""
        raise NotImplementedError()

    @property
    def loss_scale(self):
        """The current loss scale as a float32 scalar tensor."""
        raise NotImplementedError()

    @property
    def dynamic_counter(self):
        """The number of steps since the loss scale was last increased or
        decreased.

        This is None if `LossScaleOptimizer.dynamic` is False.

        The counter is incremented every step. Once it reaches
        `LossScaleOptimizer.dynamic_growth_steps`, the loss scale will be
        doubled and the counter will be reset back to zero. If nonfinite
        gradients are encountered, the loss scale will be halved and the
        counter will be reset back to zero.
        """
        raise NotImplementedError()

    @property
    def initial_scale(self):
        """The initial loss scale.

        If `LossScaleOptimizer.dynamic` is False, this is the same number as
        `LossScaleOptimizer.loss_scale`, as the loss scale never changes.
        """
        raise NotImplementedError()

    @property
    def dynamic_growth_steps(self):
        """The number of steps it takes to increase the loss scale.

        This is None if `LossScaleOptimizer.dynamic` is False.

        Every `dynamic_growth_steps` consecutive steps with finite gradients,
        the loss scale is increased.
        """
        raise NotImplementedError()

    @property
    def inner_optimizer(self):
        """The optimizer that this LossScaleOptimizer is wrapping."""
        raise NotImplementedError()

    def get_scaled_loss(self, loss):
        """Scales the loss by the loss scale.

        This method is only needed if you compute gradients manually, e.g.
        with `tf.GradientTape`. In that case, call this method to scale the
        loss before passing the loss to `tf.GradientTape`. If you use
        `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`,
        loss scaling is automatically applied and this method is unneeded.

        If this method is called, `get_unscaled_gradients` should also be
        called. See the `tf.keras.mixed_precision.LossScaleOptimizer` doc for
        an example.

        Args:
          loss: The loss, which will be multiplied by the loss scale. Can
            either be a tensor or a callable returning a tensor.

        Returns:
          `loss` multiplied by `LossScaleOptimizer.loss_scale`.
        """
        # The concrete implementation lives in `LossScaleOptimizer` or
        # `LossScaleOptimizerV3`; which one is instantiated depends on the
        # type of `inner_optimizer`.
        raise NotImplementedError()

    def get_unscaled_gradients(self, grads):
        """Unscales the gradients by the loss scale.

        This method is only needed if you compute gradients manually, e.g.
        with `tf.GradientTape`. In that case, call this method to unscale the
        gradients after computing them with `tf.GradientTape`. If you use
        `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`,
        loss scaling is automatically applied and this method is unneeded.

        If this method is called, `get_scaled_loss` should also be called. See
        the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an
        example.

        Args:
          grads: A list of tensors, each which will be divided by the loss
            scale. Can have None values, which are ignored.

        Returns:
          A new list the same size as `grads`, where every non-None value in
          `grads` is divided by `LossScaleOptimizer.loss_scale`.
        """
        # The concrete implementation lives in `LossScaleOptimizer` or
        # `LossScaleOptimizerV3`; which one is instantiated depends on the
        # type of `inner_optimizer`.
        raise NotImplementedError()

582 

583 

584class LossScaleOptimizer( 

585 tf.__internal__.tracking.DelegatingTrackableMixin, 

586 optimizer_v2.OptimizerV2, 

587 BaseLossScaleOptimizer, 

588): 

589 """An optimizer that applies loss scaling to prevent numeric underflow.""" 

590 

591 _HAS_AGGREGATE_GRAD = True 

592 

593 def __init__( 

594 self, 

595 inner_optimizer, 

596 dynamic=True, 

597 initial_scale=None, 

598 dynamic_growth_steps=None, 

599 ): 

600 if not isinstance(inner_optimizer, optimizer_v2.OptimizerV2): 

601 if isinstance(inner_optimizer, optimizer.Optimizer): 

602 # Give better error message if the new experimental optimizer is 

603 # passed. 

604 raise TypeError( 

605 "You passed an instance of the new experimental " 

606 "optimizer, `optimizer.Optimizer`, " 

607 "to LossScaleOptimizer, but " 

608 "only the classic optimizers subclassing from " 

609 "`tf.keras.optimizers.Optimizer` can be passed. Please " 

610 "use `loss_scale_optimizer.LossScaleOptimizerV3` " 

611 "instead of " 

612 "`tf.keras.mixed_precision.LossScaleOptimizer`, " 

613 "as the former supports wrapping " 

614 "instances of the new experimental optimizer. " 

615 f"Got optimizer: {inner_optimizer}" 

616 ) 

617 msg = ( 

618 '"inner_optimizer" must be an instance of ' 

619 "`tf.keras.optimizers.Optimizer`, but got: %s. " 

620 % inner_optimizer 

621 ) 

622 raise TypeError(msg) 

623 if not isinstance(dynamic, bool): 

624 # Catch errors if a user incorrectly passes a string or float to the 

625 # second argument argument, as this was commonly done for the 

626 # now-removed LossScaleOptimizerV1. 

627 raise TypeError( 

628 '"dynamic" argument to LossScaleOptimizer.__init__ must ' 

629 "be a bool, but got: %r" % (dynamic,) 

630 ) 

631 if isinstance(inner_optimizer, LossScaleOptimizer): 

632 raise TypeError( 

633 "LossScaleOptimizer cannot wrap another " 

634 "LossScaleOptimizer, but got: %s" % (inner_optimizer,) 

635 ) 

636 _raise_if_strategy_unsupported() 

637 if getattr( 

638 inner_optimizer, "_is_wrapped_by_loss_scale_optimizer", False 

639 ): 

640 # TODO(reedwm): Maybe support this. The difficulty is that LSO has 

641 # the same checkpoint format as the inner optimizer, so multiple 

642 # LSOs wrapping the same optimizer causes the checkpointing logic to 

643 # become confused. 

644 raise ValueError( 

645 '"inner_optimizer" is already wrapped by a ' 

646 "LossScaleOptimizer. An optimizer can only be wrapped " 

647 "by a single LossScaleOptimizer" 

648 ) 

649 self._optimizer = inner_optimizer 

650 self._optimizer._is_wrapped_by_loss_scale_optimizer = True 

651 

652 # We don't call super().__init__, since we do not want to call 

653 # OptimizerV2's constructor. 

654 tf.__internal__.tracking.DelegatingTrackableMixin.__init__( 

655 self, self._optimizer 

656 ) 

657 

658 if dynamic: 

659 if initial_scale is None: 

660 initial_scale = _DEFAULT_INITIAL_SCALE 

661 if dynamic_growth_steps is None: 

662 dynamic_growth_steps = _DEFAULT_GROWTH_STEPS 

663 self._loss_scale = _DynamicLossScaleState( 

664 initial_scale, dynamic_growth_steps, multiplier=2 

665 ) 

666 self._track_trackable(self._loss_scale, "loss_scale") 

667 else: 

668 if initial_scale is None: 

669 raise ValueError( 

670 '"initial_scale" must be specified if "dynamic" is False' 

671 ) 

672 self._loss_scale = float(initial_scale) 

673 if dynamic_growth_steps is not None: 

674 raise ValueError( 

675 '"dynamic_growth_steps" must be None if "dynamic" ' 

676 "is False, but got: %s" % (dynamic_growth_steps,) 

677 ) 

678 

679 # Used to track whether get_scaled_loss() and get_unscaled_gradients() 

680 # have been called 

681 self._loss_has_been_scaled = False 

682 self._gradients_have_been_unscaled = False 

683 

684 # To support restoring TensorFlow 2.2 checkpoints. 

685 self._track_trackable( 

686 FakeOptimizerForRestoration(self._optimizer), "base_optimizer" 

687 ) 

688 

689 @property 

690 def dynamic(self): 

691 return isinstance(self._loss_scale, _DynamicLossScaleState) 

692 

693 @property 

694 def loss_scale(self): 

695 if isinstance(self._loss_scale, _DynamicLossScaleState): 

696 return tf.convert_to_tensor(self._loss_scale.current_loss_scale) 

697 else: 

698 return tf.convert_to_tensor(self._loss_scale) 

699 

700 @property 

701 def dynamic_counter(self): 

702 if isinstance(self._loss_scale, _DynamicLossScaleState): 

703 return self._loss_scale.counter 

704 else: 

705 return None 

706 

707 @property 

708 def initial_scale(self): 

709 if isinstance(self._loss_scale, _DynamicLossScaleState): 

710 return self._loss_scale.initial_loss_scale 

711 else: 

712 return self._loss_scale 

713 

714 @property 

715 def dynamic_growth_steps(self): 

716 if isinstance(self._loss_scale, _DynamicLossScaleState): 

717 return self._loss_scale.growth_steps 

718 else: 

719 return None 

720 

721 @property 

722 def inner_optimizer(self): 

723 return self._optimizer 

724 

725 def get_scaled_loss(self, loss): 

726 self._loss_has_been_scaled = True 

727 if callable(loss): 

728 

729 def new_loss(): 

730 loss_val = loss() 

731 return loss_val * tf.cast(self.loss_scale, loss_val.dtype) 

732 

733 return new_loss 

734 else: 

735 return loss * tf.cast(self.loss_scale, loss.dtype) 

736 

737 def get_unscaled_gradients(self, grads): 

738 self._gradients_have_been_unscaled = True 

739 loss_scale_reciprocal = 1.0 / self.loss_scale 

740 return [ 

741 _multiply_gradient(g, loss_scale_reciprocal) 

742 if g is not None 

743 else None 

744 for g in grads 

745 ] 

746 

747 def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None): 

748 tape = tf.GradientTape() if tape is None else tape 

749 with tape: 

750 loss = self.get_scaled_loss(loss) 

751 grads_and_vars = self._optimizer._compute_gradients( 

752 loss, var_list, grad_loss, tape=tape 

753 ) 

754 grads = [g for g, _ in grads_and_vars] 

755 weights = [v for _, v in grads_and_vars] 

756 unscaled_grads = self.get_unscaled_gradients(grads) 

757 return list(zip(unscaled_grads, weights)) 

758 

759 def get_gradients(self, loss, params): 

760 loss = self.get_scaled_loss(loss) 

761 grads = self._optimizer.get_gradients(loss, params) 

762 return self.get_unscaled_gradients(grads) 

763 

764 def _create_all_weights(self, var_list): 

765 self._optimizer._create_all_weights(var_list) 

766 

    def apply_gradients(
        self, grads_and_vars, name=None, experimental_aggregate_gradients=True
    ):
        """Applies gradients through the inner optimizer, unless skipped.

        With dynamic loss scaling, the loss-scale state is updated from the
        gradients and reports whether they should be applied. With a fixed
        loss scale the gradients are always applied. When they are not
        applied, the inner optimizer's `iterations` is still incremented.

        Args:
          grads_and_vars: Iterable of `(gradient, variable)` pairs. Pairs
            are first passed through
            `optimizer_utils.filter_empty_gradients`.
          name: Forwarded to `self._apply_gradients`.
          experimental_aggregate_gradients: If true, gradients are
            transformed and aggregated here using the inner optimizer's
            `_transform_unaggregated_gradients` and `_aggregate_gradients`.

        Returns:
          A `tf.group` of the apply-or-skip op and the loss-scale update op,
          produced directly or through `merge_call`, depending on whether
          the strategy supports no-merge-call.

        Raises:
          ValueError: If called in a cross-replica context.
        """
        if tf.distribute.in_cross_replica_context():
            raise ValueError(
                "apply_gradients() must be called in a replica context."
            )
        # We check for the strategy here despite already checking in the
        # constructor as frequently the optimizer is created outside the
        # strategy's scope.
        _raise_if_strategy_unsupported()
        # Logs a warning if the caller skipped loss scaling or gradient
        # unscaling.
        _maybe_warn_about_scaling(
            self._loss_has_been_scaled, self._gradients_have_been_unscaled
        )

        grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
        if experimental_aggregate_gradients:
            # We must aggregate the gradients here instead of in
            # self.optimizer.apply_gradients, so that any NaN or Inf gradients
            # are propagated to each replica. If any replica has a NaN or Inf
            # gradient, they must all have a NaN or Inf gradient so that they
            # all skip the step.
            grads_and_vars = self._optimizer._transform_unaggregated_gradients(
                grads_and_vars
            )
            grads_and_vars = self._optimizer._aggregate_gradients(
                grads_and_vars
            )

        # Materialize once so the pairs can be traversed twice below.
        grads_and_vars = tuple(grads_and_vars)
        grads = [g for g, _ in grads_and_vars]
        # We do not want DistributionStrategy to unwrap any MirroredVariables in
        # grads_and_vars, because even in a replica context, the wrapped
        # optimizer expects mirrored variables. So we wrap the variables with an
        # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the
        # MirroredVariables.
        wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars])

        def do_not_apply_fn():
            # Normally self._optimizer.iterations is incremented in
            # self._optimizer.apply_gradients(). Since that is not called in
            # this branch, we increment it here instead.
            return self._optimizer.iterations.assign_add(1, read_value=False)

        def _if_should_apply_grads(grads):
            # Returns (loss_scale_update_op, should_apply_grads). Only the
            # dynamic state can veto a step; a fixed scale always applies.
            if isinstance(self._loss_scale, _DynamicLossScaleState):
                return self._loss_scale.update(grads)
            else:
                return (tf.no_op(), True)

        if tf.__internal__.distribute.strategy_supports_no_merge_call():
            # No merge_call needed: decide and apply in the current context.
            loss_scale_update_op, should_apply_grads = _if_should_apply_grads(
                grads
            )

            def apply_fn():
                return self._apply_gradients(grads, wrapped_vars, name)

            maybe_apply_op = tf.__internal__.smart_cond.smart_cond(
                should_apply_grads, apply_fn, do_not_apply_fn
            )
            return tf.group(maybe_apply_op, loss_scale_update_op)

        else:

            def _apply_gradients_cross_replica(
                distribution, grads, wrapped_vars, name
            ):
                # Runs in a cross-replica context via merge_call below.
                (
                    loss_scale_update_op,
                    should_apply_grads,
                ) = _if_should_apply_grads(grads)

                def apply_fn():
                    return distribution.extended.call_for_each_replica(
                        self._apply_gradients, args=(grads, wrapped_vars, name)
                    )

                # Note: We must call this cond() in a cross-replica context.
                # DistributionStrategy does not support having a cond in a
                # replica context with a branch that calls `merge_call`, and
                # self._optimizer.apply_gradients calls `merge_call`.
                maybe_apply_op = tf.__internal__.smart_cond.smart_cond(
                    should_apply_grads, apply_fn, do_not_apply_fn
                )
                return tf.group(maybe_apply_op, loss_scale_update_op)

            return tf.distribute.get_replica_context().merge_call(
                _apply_gradients_cross_replica, args=(grads, wrapped_vars, name)
            )

857 

def _apply_gradients(self, grads, wrapped_vars, name):
    """Apply already-aggregated gradients through the wrapped optimizer.

    Args:
        grads: Gradients, in the same order as the wrapped variables.
        wrapped_vars: Object whose `.value` attribute holds the variables
            (an `_UnwrapPreventer` in the caller).
        name: Forwarded to the wrapped optimizer's `apply_gradients`.

    Returns:
        Whatever the wrapped optimizer's `apply_gradients` returns.
    """
    grads_and_vars = list(zip(grads, wrapped_vars.value))
    # LossScaleOptimizer has already aggregated the gradients, so the inner
    # optimizer is told not to aggregate them again.
    # TODO(reedwm): This will raise a fairly cryptic error message if
    # self._optimizer.apply_gradients does not take
    # experimental_aggregate_gradients.
    inner = self._optimizer
    return inner.apply_gradients(
        grads_and_vars,
        name=name,
        experimental_aggregate_gradients=False,
    )

869 

def get_config(self):
    """Return a config dict from which this optimizer can be re-created.

    The dict holds the serialized inner optimizer together with the
    `dynamic`, `initial_scale` and `dynamic_growth_steps` settings.
    """
    config = {"inner_optimizer": optimizers.serialize(self._optimizer)}
    config["dynamic"] = self.dynamic
    config["initial_scale"] = self.initial_scale
    config["dynamic_growth_steps"] = self.dynamic_growth_steps
    return config

878 

879 @classmethod 

def from_config(cls, config, custom_objects=None):
    """Create a LossScaleOptimizer from a config dict.

    Args:
        config: Config dict. It may use the current format (with an
            "inner_optimizer" key) or the older format that has
            "loss_scale" and "optimizer" keys.
        custom_objects: Forwarded to `optimizers.deserialize` when the inner
            optimizer has to be deserialized.

    Returns:
        A new instance of `cls`.

    Raises:
        ValueError: If an old-format config holds a DynamicLossScale whose
            multiplier is not 2, or a loss scale that is neither a
            FixedLossScale nor a DynamicLossScale.
    """
    # Work on a copy: keys are popped and rewritten below.
    config = config.copy()
    if "loss_scale" in config:
        # If loss_scale is in config, we assume we are deserializing a
        # LossScaleOptimizer from TF 2.3 or below. We convert the config so
        # it can be deserialized in the current LossScaleOptimizer.
        legacy_scale = serialization_lib.deserialize_keras_object(
            config.pop("loss_scale"),
            module_objects={
                "FixedLossScale": tf.compat.v1.mixed_precision.FixedLossScale,  # noqa: E501
                "DynamicLossScale": tf.compat.v1.mixed_precision.DynamicLossScale,  # noqa: E501
            },
            printable_module_name="loss scale",
        )

        is_fixed = isinstance(
            legacy_scale, tf.compat.v1.mixed_precision.FixedLossScale
        )
        if is_fixed:
            config["dynamic"] = False
            config["initial_scale"] = legacy_scale._loss_scale_value
        elif isinstance(
            legacy_scale, tf.compat.v1.mixed_precision.DynamicLossScale
        ):
            config["dynamic"] = True
            config["initial_scale"] = legacy_scale.initial_loss_scale
            config["dynamic_growth_steps"] = legacy_scale.increment_period
            if legacy_scale.multiplier != 2:
                raise ValueError(
                    "Cannot deserialize LossScaleOptimizer with a "
                    "DynamicLossScale whose multiplier is not 2. Got "
                    "DynamicLossScale: %s" % (legacy_scale,)
                )
        else:
            raise ValueError(
                "Serialized LossScaleOptimizers with a LossScale that is "
                "neither a FixedLossScale nor a DynamicLossScale can no "
                "longer be deserialized"
            )
        # The old format stored the wrapped optimizer under "optimizer".
        config["inner_optimizer"] = config.pop("optimizer")

    # The entry is either an optimizer instance or its serialized form.
    inner_entry = config.pop("inner_optimizer")
    if isinstance(inner_entry, optimizer_v2.OptimizerV2):
        inner_optimizer = inner_entry
    else:
        inner_optimizer = optimizers.deserialize(
            inner_entry,
            custom_objects=custom_objects,
            use_legacy_optimizer=True,
        )
    return cls(inner_optimizer, **config)

929 

930 # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer 

931 # below. 

932 

933 @property 

def iterations(self):
    """Return the `iterations` attribute of the wrapped optimizer."""
    inner = self._optimizer
    return inner.iterations

936 

937 @iterations.setter 

def iterations(self, variable):
    """Set the `iterations` attribute of the wrapped optimizer."""
    inner = self._optimizer
    inner.iterations = variable

940 

def get_slot_names(self):
    """Return the slot names reported by the wrapped optimizer."""
    inner = self._optimizer
    return inner.get_slot_names()

943 

def variables(self):
    """Return the result of the wrapped optimizer's `variables()` call."""
    inner = self._optimizer
    return inner.variables()

946 

947 @property 

def weights(self):
    """Return the `weights` attribute of the wrapped optimizer."""
    inner = self._optimizer
    return inner.weights

950 

def get_weights(self):
    """Return the result of the wrapped optimizer's `get_weights()` call."""
    inner = self._optimizer
    return inner.get_weights()

953 

def set_weights(self, weights):
    """Forward `weights` to the wrapped optimizer's `set_weights()`.

    Returns:
        Whatever the wrapped optimizer's `set_weights()` returns.
    """
    inner = self._optimizer
    return inner.set_weights(weights)

956 

957 @property 

def clipnorm(self):
    """Return the `clipnorm` attribute of the wrapped optimizer."""
    inner = self._optimizer
    return inner.clipnorm

960 

961 @clipnorm.setter 

def clipnorm(self, val):
    """Set the `clipnorm` attribute of the wrapped optimizer."""
    inner = self._optimizer
    inner.clipnorm = val

964 

965 @property 

def global_clipnorm(self):
    """Return the `global_clipnorm` attribute of the wrapped optimizer."""
    inner = self._optimizer
    return inner.global_clipnorm

968 

969 @global_clipnorm.setter 

def global_clipnorm(self, val):
    """Set the `global_clipnorm` attribute of the wrapped optimizer."""
    inner = self._optimizer
    inner.global_clipnorm = val

972 

973 @property 

def clipvalue(self):
    """Return the `clipvalue` attribute of the wrapped optimizer."""
    inner = self._optimizer
    return inner.clipvalue

976 

977 @clipvalue.setter 

def clipvalue(self, val):
    """Set the `clipvalue` attribute of the wrapped optimizer."""
    inner = self._optimizer
    inner.clipvalue = val

980 

def _aggregate_gradients(self, grads_and_vars):
    """Forward `grads_and_vars` to the wrapped optimizer's aggregation."""
    inner = self._optimizer
    return inner._aggregate_gradients(grads_and_vars)

983 

def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Forward a slot-variable restore to the wrapped optimizer.

    The three arguments are passed through positionally and the wrapped
    optimizer's return value is returned unchanged.
    """
    inner = self._optimizer
    return inner._restore_slot_variable(slot_name, variable, slot_variable)

990 

def _create_or_restore_slot_variable(
    self, slot_variable_position, slot_name, variable
):
    """Forward slot-variable creation/restoration to the wrapped optimizer.

    The three arguments are passed through positionally and the wrapped
    optimizer's return value is returned unchanged.
    """
    inner = self._optimizer
    return inner._create_or_restore_slot_variable(
        slot_variable_position, slot_name, variable
    )

997 

def get_slot(self, var, slot_name):
    """Return the wrapped optimizer's slot `slot_name` for `var`."""
    inner = self._optimizer
    return inner.get_slot(var, slot_name)

1000 

def add_slot(self, var, slot_name, initializer="zeros"):
    """Forward slot creation to the wrapped optimizer's `add_slot`.

    Args:
        var: Passed through as the first positional argument.
        slot_name: Passed through as the second positional argument.
        initializer: Passed through as the third positional argument;
            defaults to "zeros".

    Returns:
        Whatever the wrapped optimizer's `add_slot` returns.
    """
    inner = self._optimizer
    return inner.add_slot(var, slot_name, initializer)

1003 

def __getattribute__(self, name):
    """Look up `name`, falling back to the inner optimizer's hyperparameters.

    Normal attribute lookup is tried first. If it fails and `name` (with
    "lr" treated as "learning_rate") is a key of `self._optimizer._hyper`,
    the value from `self._optimizer._get_hyper` is returned. Otherwise the
    original AttributeError is re-raised.
    """
    try:
        return object.__getattribute__(self, name)
    except AttributeError as e:
        # The fallback below reads self._optimizer; failing fast on these
        # two names avoids infinite recursion.
        if name in ("_optimizer", "_hyper"):
            raise e

        # Delegate hyperparameter accesses to inner optimizer.
        hyper_name = "learning_rate" if name == "lr" else name
        inner = self._optimizer
        if hyper_name in inner._hyper:
            return inner._get_hyper(hyper_name)
        raise e

1018 

def __dir__(self):
    """List attribute names, including delegated hyperparameter names.

    When `_optimizer` is among the regular names, the keys of
    `self._optimizer._hyper` are added, plus the alias "lr" if
    "learning_rate" is one of those keys.
    """
    names = set(super().__dir__())
    if "_optimizer" in names:
        names.update(self._optimizer._hyper.keys())
        if "learning_rate" in self._optimizer._hyper.keys():
            names.add("lr")
    return list(names)

1026 

def __setattr__(self, name, value):
    """Set `name`, delegating hyperparameters to the inner optimizer.

    "lr" is treated as "learning_rate". If the attribute does not already
    exist on this object and the name is a key of
    `self._optimizer._hyper`, the value is stored with
    `self._optimizer._set_hyper`. Otherwise it is set on this object.
    """
    if name == "lr":
        name = "learning_rate"

    # Delegate setting hyperparameter to inner optimizer if the attribute
    # does not exist on the LossScaleOptimizer
    if name == "iterations":
        # We cannot check for the 'iterations' attribute as it cannot be set
        # after it is accessed.
        has_attribute = True
    else:
        try:
            object.__getattribute__(self, name)
        except AttributeError:
            has_attribute = False
        else:
            has_attribute = True

    # The operand order is kept on purpose: self._optimizer is only read
    # when the name is not "_optimizer" itself.
    delegate_to_inner = (
        name != "_optimizer"
        and name in self._optimizer._hyper
        and not has_attribute
    )
    if delegate_to_inner:
        self._optimizer._set_hyper(name, value)
        return
    super().__setattr__(name, value)

1048 

1049 # Explicitly delegate learning_rate. Normally hyperparameters are delegated 

1050 # in __getattribute__, but if a hyperparameter is not in 

1051 # self._optimizer._hyper (e.g. because self._optimizer itself wraps another 

1052 # optimizer), then it won't be delegated. Since learning_rate is a very 

1053 # commonly accessed hyperparameter, we delegate it here. 

1054 @property 

def learning_rate(self):
    """Return the `learning_rate` attribute of the wrapped optimizer."""
    inner = self._optimizer
    return inner.learning_rate

1057 

1058 @learning_rate.setter 

def learning_rate(self, value):
    """Set the `learning_rate` attribute of the wrapped optimizer."""
    inner = self._optimizer
    inner.learning_rate = value

1061 

1062 @property 

def lr(self):
    """Alias getter: return the wrapped optimizer's `learning_rate`."""
    inner = self._optimizer
    return inner.learning_rate

1065 

1066 @lr.setter 

def lr(self, value):
    """Alias setter: assign `value` to the wrapped optimizer's `lr`."""
    inner = self._optimizer
    inner.lr = value

1069 

1070 # We do not override some OptimizerV2 methods. For each, we describe why we 

1071 # do not delegate them to self._optimizer: 

1072 # * get_updates: get_updates() calls get_gradients(). Since we override 

1073 # get_gradients(), we cannot delegate get_updates() to self._optimizer, 

1074 # otherwise the overridden get_gradients() method would not be called. 

1075 # Luckily, get_updates() does not access any OptimizerV2 fields, so 

1076 # inheriting the OptimizerV2 version works fine. 

1077 # * minimize: We don't delegate for a similar reason as get_updates(): it calls

1078 # both self._compute_gradients() and self.apply_gradients(), and both need 

1079 # to have the LossScaleOptimizer version called. 

1080 

1081 # TODO(reedwm): Maybe throw an error if mixed precision is used without this 

1082 # optimizer being used. 

1083 

1084 

class LossScaleOptimizerV3(
    tf.__internal__.tracking.DelegatingTrackableMixin,
    optimizer.Optimizer,
    BaseLossScaleOptimizer,
):
    """An optimizer that applies loss scaling to prevent numeric underflow.

    This is a copy of the `mixed_precision.LossScaleOptimizer` class
    defined above, except it subclasses and wraps the new experimental Optimizer
    class instead of the `tf.keras.optimizers.Optimizer` class. Some of the
    methods this class defines and calls are different compared to
    LossScaleOptimizer due to the differences between the two Optimizer base
    classes. Additionally, this class does not support the legacy graph mode,
    but LossScaleOptimizer does.

    Since the new experimental Optimizer does not have a hyperparameter concept,
    LossScaleOptimizerV3 does not delegate arbitrary hyperparameter accesses to
    the inner optimizer, unlike LossScaleOptimizer. LossScaleOptimizerV3 does
    delegate the "learning_rate" attribute, however.
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(
        self,
        inner_optimizer,
        dynamic=True,
        initial_scale=None,
        dynamic_growth_steps=None,
    ):
        """Wrap `inner_optimizer` with dynamic or fixed loss scaling.

        Args:
            inner_optimizer: The `optimizer.Optimizer` instance to wrap.
            dynamic: Bool. If True a `_DynamicLossScaleState` is used; if
                False the loss scale is the fixed float `initial_scale`.
            initial_scale: Initial (or fixed) loss scale. Required when
                `dynamic` is False; defaults to `_DEFAULT_INITIAL_SCALE`
                when `dynamic` is True.
            dynamic_growth_steps: Growth-steps setting of the dynamic loss
                scale. Must be None when `dynamic` is False; defaults to
                `_DEFAULT_GROWTH_STEPS` when `dynamic` is True.

        Raises:
            TypeError: If `inner_optimizer` is not an `optimizer.Optimizer`,
                if it is itself a LossScaleOptimizerV3, or if `dynamic` is
                not a bool.
            ValueError: If `inner_optimizer` is already wrapped by a loss
                scale optimizer, or if the `initial_scale` /
                `dynamic_growth_steps` arguments are inconsistent with
                `dynamic=False`.
        """
        if not isinstance(inner_optimizer, optimizer.Optimizer):
            if isinstance(inner_optimizer, optimizer_v2.OptimizerV2):
                # Give better error message if the OptimizerV2 class is passed
                # instead of the new experimental optimizer.
                raise TypeError(
                    "You passed a `tf.keras.optimizers.Optimizer` instance to "
                    "LossScaleOptimizerV3, but only the new experimental "
                    "optimizer defined in "
                    "keras/src/optimizers/optimizer.py can be "
                    "passed. Please use "
                    "`tf.keras.mixed_precision.LossScaleOptimizer` "
                    "instead of LossScaleOptimizerV3, as the former supports "
                    "`tf.keras.optimizers.Optimizer`s. Got optimizer: "
                    f"{inner_optimizer}"
                )
            raise TypeError(
                '"inner_optimizer" must be an instance of '
                f"Optimizer, but got: {inner_optimizer}."
            )
        if not isinstance(dynamic, bool):
            # Catch errors if a user incorrectly passes a string or float to the
            # second argument, as this was commonly done for the
            # now-removed LossScaleOptimizerV1.
            raise TypeError(
                '"dynamic" argument to LossScaleOptimizer.__init__ must '
                f"be a bool, but got: {repr(dynamic)}"
            )
        if isinstance(inner_optimizer, LossScaleOptimizerV3):
            raise TypeError(
                "LossScaleOptimizer cannot wrap another "
                f"LossScaleOptimizer, but got: {inner_optimizer}"
            )
        _raise_if_strategy_unsupported()
        if getattr(
            inner_optimizer, "_is_wrapped_by_loss_scale_optimizer", False
        ):
            # TODO(reedwm): Maybe support this. The difficulty is that LSO has
            # the same checkpoint format as the inner optimizer, so multiple
            # LSOs wrapping the same optimizer causes the checkpointing logic to
            # become confused.
            raise ValueError(
                '"inner_optimizer" is already wrapped by a '
                "LossScaleOptimizer. An optimizer can only be wrapped "
                "by a single LossScaleOptimizer"
            )
        self._optimizer = inner_optimizer
        self._optimizer._is_wrapped_by_loss_scale_optimizer = True

        # We don't call super().__init__, since we do not want to call
        # Optimizer's constructor.
        tf.__internal__.tracking.DelegatingTrackableMixin.__init__(
            self, self._optimizer
        )

        if dynamic:
            if initial_scale is None:
                initial_scale = _DEFAULT_INITIAL_SCALE
            if dynamic_growth_steps is None:
                dynamic_growth_steps = _DEFAULT_GROWTH_STEPS
            self._loss_scale = _DynamicLossScaleState(
                initial_scale, dynamic_growth_steps, multiplier=2
            )
            self._track_trackable(self._loss_scale, "loss_scale")
        else:
            if initial_scale is None:
                raise ValueError(
                    '"initial_scale" must be specified if "dynamic" is False'
                )
            self._loss_scale = float(initial_scale)
            if dynamic_growth_steps is not None:
                raise ValueError(
                    '"dynamic_growth_steps" must be None if "dynamic" '
                    f"is False, but got: {dynamic_growth_steps}"
                )

        # Used to track whether get_scaled_loss() and get_unscaled_gradients()
        # have been called
        self._loss_has_been_scaled = False
        self._gradients_have_been_unscaled = False

    @property
    def dynamic(self):
        """True if the loss scale is a `_DynamicLossScaleState`."""
        return isinstance(self._loss_scale, _DynamicLossScaleState)

    @property
    def loss_scale(self):
        """The current loss scale, converted to a tensor."""
        if isinstance(self._loss_scale, _DynamicLossScaleState):
            return tf.convert_to_tensor(self._loss_scale.current_loss_scale)
        else:
            return tf.convert_to_tensor(self._loss_scale)

    @property
    def dynamic_counter(self):
        """The dynamic loss scale state's `counter`, or None if fixed."""
        if isinstance(self._loss_scale, _DynamicLossScaleState):
            return self._loss_scale.counter
        else:
            return None

    @property
    def initial_scale(self):
        """The initial loss scale (the fixed scale when not dynamic)."""
        if isinstance(self._loss_scale, _DynamicLossScaleState):
            return self._loss_scale.initial_loss_scale
        else:
            return self._loss_scale

    @property
    def dynamic_growth_steps(self):
        """The dynamic loss scale state's `growth_steps`, or None if fixed."""
        if isinstance(self._loss_scale, _DynamicLossScaleState):
            return self._loss_scale.growth_steps
        else:
            return None

    @property
    def inner_optimizer(self):
        """The wrapped optimizer."""
        return self._optimizer

    def get_scaled_loss(self, loss):
        """Multiply `loss` by the current loss scale.

        Args:
            loss: A loss value, or a callable returning one.

        Returns:
            The scaled loss, or, if `loss` is callable, a callable that
            returns the scaled loss when invoked.
        """
        self._loss_has_been_scaled = True
        if callable(loss):

            def new_loss():
                loss_val = loss()
                return loss_val * tf.cast(self.loss_scale, loss_val.dtype)

            return new_loss
        else:
            return loss * tf.cast(self.loss_scale, loss.dtype)

    def get_unscaled_gradients(self, grads):
        """Multiply each gradient by the reciprocal of the loss scale.

        Args:
            grads: Iterable of gradients; None entries are kept as None.

        Returns:
            A list of unscaled gradients in the same order as `grads`.
        """
        self._gradients_have_been_unscaled = True
        loss_scale_reciprocal = 1.0 / self.loss_scale
        return [
            _multiply_gradient(g, loss_scale_reciprocal)
            if g is not None
            else None
            for g in grads
        ]

    def compute_gradients(self, loss, var_list, tape=None):
        """Compute unscaled gradients of `loss` for `var_list`.

        The loss is scaled inside the tape context, gradients are computed
        by the wrapped optimizer, and the result is unscaled again.

        Args:
            loss: Loss value or callable, passed to `get_scaled_loss`.
            var_list: Passed to the wrapped optimizer's `compute_gradients`.
            tape: Optional `tf.GradientTape`; a new one is created if None.

        Returns:
            A list of `(unscaled_gradient, variable)` pairs.
        """
        tape = tf.GradientTape() if tape is None else tape
        with tape:
            loss = self.get_scaled_loss(loss)
        grads_and_vars = self._optimizer.compute_gradients(
            loss, var_list, tape=tape
        )
        grads = [g for g, _ in grads_and_vars]
        weights = [v for _, v in grads_and_vars]
        unscaled_grads = self.get_unscaled_gradients(grads)
        return list(zip(unscaled_grads, weights))

    def apply_gradients(
        self, grads_and_vars, skip_gradients_aggregation=False, **kwargs
    ):
        """Apply gradients, skipping the step if the loss scale says so.

        With a dynamic loss scale, the loss scale state is updated from the
        gradients and decides whether the gradients are applied. When they
        are not applied, only the inner optimizer's `iterations` counter is
        incremented. This method does not return a value.

        Args:
            grads_and_vars: Iterable of `(gradient, variable)` pairs.
            skip_gradients_aggregation: If True, gradients are not
                aggregated here.
            **kwargs: Only `experimental_aggregate_gradients` is read (for
                backward compatibility); setting it to False also disables
                aggregation.

        Raises:
            ValueError: If called in a cross-replica context, or if the
                current distribution strategy does not support loss scaling.
        """
        if tf.distribute.in_cross_replica_context():
            raise ValueError(
                "apply_gradients() must be called in a replica context."
            )
        # We check for the strategy here despite already checking in the
        # constructor as frequently the optimizer is created outside the
        # strategy's scope.
        _raise_if_strategy_unsupported()
        _maybe_warn_about_scaling(
            self._loss_has_been_scaled, self._gradients_have_been_unscaled
        )

        grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
        # `experimental_aggregate_gradients` is an arg in `apply_gradients` of
        # v2 optimizer -- the reverse of `skip_gradients_aggregation`.
        # We read it from kwargs for backward compatibility.
        experimental_aggregate_gradients = kwargs.pop(
            "experimental_aggregate_gradients", True
        )
        run_with_dtensor = (
            # `_run_with_dtensor` is for dtensor based strategy scope, and
            # `_mesh` is when user explicitly specify the mesh setting for
            # optimizer.
            self._optimizer._run_with_dtensor
            or self._optimizer._mesh
        )

        if (
            not skip_gradients_aggregation
            and experimental_aggregate_gradients
            and not run_with_dtensor
        ):
            # We must aggregate the gradients here instead of in
            # self.optimizer.apply_gradients, so that any NaN or Inf gradients
            # are propagated to each replica. If any replica has a NaN or Inf
            # gradient, they must all have a NaN or Inf gradient so that they
            # all skip the step.
            grads_and_vars = self._optimizer.aggregate_gradients(grads_and_vars)

        grads_and_vars = tuple(grads_and_vars)
        grads = [g for g, _ in grads_and_vars]
        # We do not want DistributionStrategy to unwrap any MirroredVariables in
        # grads_and_vars, because even in a replica context, the wrapped
        # optimizer expects mirrored variables. So we wrap the variables with an
        # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the
        # MirroredVariables.
        wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars])

        def do_not_apply_fn():
            # Normally self._optimizer.iterations is incremented in
            # self._optimizer.apply_gradients(). Since that is not called in
            # this branch, we increment it here instead.
            self._optimizer.iterations.assign_add(1, read_value=False)

        def _if_should_apply_grads(grads):
            # With a dynamic loss scale, update() returns a pair; only its
            # second element (whether to apply the gradients) is used here.
            if isinstance(self._loss_scale, _DynamicLossScaleState):
                _, should_apply_grad = self._loss_scale.update(grads)
                return should_apply_grad
            else:
                return True

        if tf.__internal__.distribute.strategy_supports_no_merge_call():
            should_apply_grads = _if_should_apply_grads(grads)

            def apply_fn():
                return self._apply_gradients(grads, wrapped_vars)

            tf.__internal__.smart_cond.smart_cond(
                should_apply_grads, apply_fn, do_not_apply_fn
            )
        else:

            def _apply_gradients_cross_replica(
                distribution, grads, wrapped_vars
            ):
                should_apply_grads = _if_should_apply_grads(grads)

                def apply_fn():
                    distribution.extended.call_for_each_replica(
                        self._apply_gradients, args=(grads, wrapped_vars)
                    )

                # Note: We must call this cond() in a cross-replica context.
                # DistributionStrategy does not support having a cond in a
                # replica context with a branch that calls `merge_call`, and
                # self._optimizer.apply_gradients calls `merge_call`.
                tf.__internal__.smart_cond.smart_cond(
                    should_apply_grads, apply_fn, do_not_apply_fn
                )

            tf.distribute.get_replica_context().merge_call(
                _apply_gradients_cross_replica, args=(grads, wrapped_vars)
            )

    def _apply_gradients(self, grads, wrapped_vars):
        """Apply already-aggregated gradients through the inner optimizer."""
        # Pass skip_gradients_aggregation=True since LossScaleOptimizer
        # already aggregated the gradients.
        self._optimizer.apply_gradients(
            list(zip(grads, wrapped_vars.value)),
            skip_gradients_aggregation=True,
        )

    def get_config(self):
        """Return a config dict from which this optimizer can be re-created."""
        serialized_optimizer = optimizers.serialize(self._optimizer)
        return {
            "inner_optimizer": serialized_optimizer,
            "dynamic": self.dynamic,
            "initial_scale": self.initial_scale,
            "dynamic_growth_steps": self.dynamic_growth_steps,
        }

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Create an instance from a config dict.

        `config["inner_optimizer"]` may be an `optimizer.Optimizer`
        instance or a serialized optimizer, which is deserialized with
        `use_legacy_optimizer=False`.
        """
        config = config.copy()  # Make a copy, since we mutate config
        if isinstance(config["inner_optimizer"], optimizer.Optimizer):
            inner_optimizer = config["inner_optimizer"]
        else:
            inner_optimizer = optimizers.deserialize(
                config["inner_optimizer"],
                custom_objects=custom_objects,
                use_legacy_optimizer=False,
            )
        del config["inner_optimizer"]
        return cls(inner_optimizer, **config)

    # The members below delegate to the wrapped optimizer.

    @property
    def iterations(self):
        return self._optimizer.iterations

    @iterations.setter
    def iterations(self, variable):
        self._optimizer.iterations = variable

    @property
    def variables(self):
        return self._optimizer.variables

    def build(self, var_list):
        return self._optimizer.build(var_list)

    @property
    def learning_rate(self):
        return self._optimizer.learning_rate

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        self._optimizer.learning_rate = learning_rate

    @property
    def use_ema(self):
        return self._optimizer.use_ema

    @use_ema.setter
    def use_ema(self, use_ema):
        self._optimizer.use_ema = use_ema

    @property
    def ema_momentum(self):
        return self._optimizer.ema_momentum

    @ema_momentum.setter
    def ema_momentum(self, ema_momentum):
        self._optimizer.ema_momentum = ema_momentum

    def finalize_variable_values(self, var_list):
        self._optimizer.finalize_variable_values(var_list)

1433 

1434 

class FakeOptimizerForRestoration(tf.__internal__.tracking.Trackable):
    """Stand-in optimizer that lets TensorFlow 2.2 checkpoints be restored.

    The LossScaleOptimizer checkpoint format changed after TF 2.2, and this
    class exists so that TF 2.2 checkpoints can still be restored in newer
    versions of TensorFlow.

    In TF 2.2, LossScaleOptimizer.__init__ tracked the wrapped optimizer with

    ```
    self._track_trackable(self._optimizer, 'base_optimizer')
    ```

    so the checkpoint stored a dependency from the LossScaleOptimizer to the
    wrapped optimizer. Today the checkpoint format with a LossScaleOptimizer
    is the same as the format without one, except that the loss scale is
    also stored. There is therefore no dependency from the
    LossScaleOptimizer to the wrapped optimizer: from a checkpoint's
    perspective the LossScaleOptimizer acts as if it were the wrapped
    optimizer, by overriding all Trackable methods and delegating them to
    the wrapped optimizer.

    To allow restoring TF 2.2 checkpoints, LossScaleOptimizer adds a
    dependency on this class instead of on the inner optimizer. When
    restored, this class restores the slot variables of the inner optimizer
    instead. Since this class has no variables, it does not affect the
    checkpoint when saved.
    """

    def __init__(self, optimizer):
        # The real optimizer whose slot variables get restored.
        self._optimizer = optimizer

    def get_slot_names(self):
        """Return the slot names of the wrapped optimizer."""
        target = self._optimizer
        return target.get_slot_names()

    def _create_or_restore_slot_variable(
        self, slot_variable_position, slot_name, variable
    ):
        """Forward slot creation/restoration to the wrapped optimizer."""
        target = self._optimizer
        return target._create_or_restore_slot_variable(
            slot_variable_position, slot_name, variable
        )

1476 

1477 

def _create_loss_scale_optimizer_from_v1_loss_scale(optimizer, loss_scale):
    """Creates an LSO from a tf.compat.v1.mixed_precision.LossScale.

    This is only used to pass to
    `tf.__internal__.mixed_precision.register_loss_scale_wrapper` below, which
    is called so that
    `tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite` can
    wrap a Keras optimizer with a LossScaleOptimizer.

    Args:
      optimizer: An OptimizerV2 instance.
      loss_scale: A `tf.compat.v1.mixed_precision.LossScale` instance. An
        int, a float, or the string "dynamic" is also accepted.

    Returns:
      A LossScaleOptimizer that wraps `optimizer` and uses the same loss scaling
      algorithm as `loss_scale`.

    Raises:
      ValueError: If `loss_scale` is a DynamicLossScale whose multiplier is
        not 2, or is not one of the accepted kinds of value.
      TypeError: If `loss_scale` is a LossScale that is neither a
        FixedLossScale nor a DynamicLossScale.
    """
    # Plain number: fixed loss scale with that value.
    if isinstance(loss_scale, (int, float)):
        return LossScaleOptimizer(
            optimizer, dynamic=False, initial_scale=loss_scale
        )

    # V1 FixedLossScale: fixed loss scale taken from the object.
    if isinstance(loss_scale, tf.compat.v1.mixed_precision.FixedLossScale):
        fixed_value = loss_scale._loss_scale_value
        return LossScaleOptimizer(
            optimizer, dynamic=False, initial_scale=fixed_value
        )

    # The string "dynamic": LossScaleOptimizer with its default arguments.
    if loss_scale == "dynamic":
        return LossScaleOptimizer(optimizer)

    # V1 DynamicLossScale: copy its settings; only multiplier 2 is supported.
    if isinstance(loss_scale, tf.compat.v1.mixed_precision.DynamicLossScale):
        if loss_scale.multiplier != 2:
            raise ValueError(
                'When passing a DynamicLossScale to "loss_scale", '
                "DynamicLossScale.multiplier must be 2. Got: "
                f"{loss_scale}"
            )
        return LossScaleOptimizer(
            optimizer,
            initial_scale=loss_scale.initial_loss_scale,
            dynamic_growth_steps=loss_scale.increment_period,
        )

    # Any other LossScale subclass is rejected explicitly.
    if isinstance(loss_scale, tf.compat.v1.mixed_precision.LossScale):
        raise TypeError(
            "Passing a LossScale that is not a FixedLossScale or a "
            f"DynamicLossScale is not supported. Got: {loss_scale}"
        )

    raise ValueError(
        "Invalid value passed to loss_scale. loss_scale "
        'must be the string "dynamic" (recommended), an int, '
        "a float, a FixedLossScale, or a DynamicLossScale. Got "
        f"value: {loss_scale}"
    )

1530 

1531 

# Import-time registration: associate OptimizerV2 with the factory above and
# with LossScaleOptimizer, so that
# `tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite` can wrap
# a Keras optimizer with a LossScaleOptimizer.
tf.__internal__.mixed_precision.register_loss_scale_wrapper(
    optimizer_v2.OptimizerV2,
    _create_loss_scale_optimizer_from_v1_loss_scale,
    LossScaleOptimizer,
)

1537 

1538 

def _multiply_gradient(gradient, scale):
    """Multiply a (possibly sparse) gradient by the given scale factor.

    Args:
        gradient: A gradient; `tf.IndexedSlices` is handled separately.
        scale: Scale factor; it is cast to `gradient.dtype` first.

    Returns:
        The scaled gradient. For `tf.IndexedSlices` input, a new
        `tf.IndexedSlices` with scaled values and the same indices and
        dense shape.
    """
    factor = tf.cast(scale, gradient.dtype)
    if not isinstance(gradient, tf.IndexedSlices):
        return gradient * factor
    # Sparse case: only the values are scaled.
    return tf.IndexedSlices(
        gradient.values * factor,
        gradient.indices,
        dense_shape=gradient.dense_shape,
    )

1550 

1551 

def strategy_supports_loss_scaling():
    """Returns True if the current Strategy supports loss scaling.

    With no strategy active, loss scaling is always supported. Otherwise the
    current strategy must be one of the listed strategy classes, or a
    DTensor strategy must be in use.
    """
    if not tf.distribute.has_strategy():
        return True
    current = tf.distribute.get_strategy()
    # Strategies are supported if either there is only one replica or if
    # variables are replicated per device. Otherwise, the current model.fit()
    # implementation and most custom training loops incorrectly unscale the
    # gradients. Currently, gradients are unscaled once per compute replica, but
    # they should be unscaled once per variable replica. When there is one
    # variable replica for each compute replica, this works fine, but otherwise
    # issues will occur.
    # TODO(reedwm): Support all strategies.
    supported_types = (
        tf.distribute.MultiWorkerMirroredStrategy,
        tf.compat.v1.distribute.experimental.MultiWorkerMirroredStrategy,  # noqa: E501
        tf.distribute.OneDeviceStrategy,
        tf.compat.v1.distribute.OneDeviceStrategy,
        tf.distribute.MirroredStrategy,
        tf.compat.v1.distribute.MirroredStrategy,
    )
    if isinstance(current, supported_types):
        return True
    return dtensor_utils.running_with_dtensor_strategy()

1579 

1580 

def _raise_if_strategy_unsupported():
    """Raise an exception if the current strategy doesn't support loss
    scaling.

    Raises:
        ValueError: If `strategy_supports_loss_scaling()` is falsy. TPU
            strategies get a dedicated message.
    """
    if strategy_supports_loss_scaling():
        return

    current = tf.distribute.get_strategy()
    tpu_strategy_types = (
        tf.distribute.experimental.TPUStrategy,
        tf.compat.v1.distribute.experimental.TPUStrategy,
        tf.distribute.TPUStrategy,
    )
    if isinstance(current, tpu_strategy_types):
        raise ValueError(
            "Loss scaling is not supported with TPUStrategy. Loss scaling "
            "is unnecessary with TPUs, since they support bfloat16 instead "
            "of float16 and bfloat16 does not require loss scaling. You "
            "should remove the use of the LossScaleOptimizer when TPUs are "
            "used."
        )
    raise ValueError(
        "Loss scaling is not supported with the "
        "tf.distribute.Strategy: "
        f"{current.__class__.__name__}. Try using a different "
        "Strategy, e.g. a MirroredStrategy"
    )

1608