Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/lazy_adam.py: 40%

42 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variant of the Adam optimizer that handles sparse updates more efficiently.

Compared with the original Adam optimizer, the one in this file can
provide a large improvement in model training throughput for some
applications. However, it provides slightly different semantics than the
original Adam algorithm, and may lead to different empirical results.
"""

import importlib
import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from typeguard import typechecked
from typing import Union, Callable


# Newer TensorFlow releases keep the original Keras optimizer implementations
# under `tf.keras.optimizers.legacy`; subclass that Adam when it is available,
# otherwise fall back to `tf.keras.optimizers.Adam`.
if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
    adam_optimizer_class = tf.keras.optimizers.legacy.Adam
else:
    adam_optimizer_class = tf.keras.optimizers.Adam


@tf.keras.utils.register_keras_serializable(package="Addons")
class LazyAdam(adam_optimizer_class):
    """Variant of the Adam optimizer that handles sparse updates more
    efficiently.

    The original Adam algorithm maintains two moving-average accumulators for
    each trainable variable; the accumulators are updated at every step.
    This class provides lazier handling of gradient updates for sparse
    variables. It only updates moving-average accumulators for sparse variable
    indices that appear in the current batch, rather than updating the
    accumulators for all indices. Compared with the original Adam optimizer,
    it can provide large improvements in model training throughput for some
    applications. However, it provides slightly different semantics than the
    original Adam algorithm, and may lead to different empirical results.

    Note: `amsgrad` is currently not supported and the argument can only be
    `False`.
    """
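    # Note added for clarity: sparse gradients show up as `tf.IndexedSlices`,
    # e.g. for variables updated through `tf.gather` or
    # `tf.keras.layers.Embedding` lookups. Keras routes such gradients to
    # `_resource_apply_sparse` below, which touches only the accumulator rows
    # listed in `indices` (see the usage sketch at the end of this file).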

    @typechecked
    def __init__(
        self,
        learning_rate: Union[FloatTensorLike, Callable] = 0.001,
        beta_1: FloatTensorLike = 0.9,
        beta_2: FloatTensorLike = 0.999,
        epsilon: FloatTensorLike = 1e-7,
        amsgrad: bool = False,
        name: str = "LazyAdam",
        **kwargs,
    ):
        """Constructs a new LazyAdam optimizer.

        Args:
            learning_rate: A `Tensor`, a floating point value, or a schedule
                that is a `tf.keras.optimizers.schedules.LearningRateSchedule`.
                The learning rate.
            beta_1: A `float` value or a constant `float` tensor.
                The exponential decay rate for the 1st moment estimates.
            beta_2: A `float` value or a constant `float` tensor.
                The exponential decay rate for the 2nd moment estimates.
            epsilon: A small constant for numerical stability.
                This epsilon is "epsilon hat" in
                [Adam: A Method for Stochastic Optimization. Kingma et al., 2014]
                (http://arxiv.org/abs/1412.6980) (in the formula just
                before Section 2.1), not the epsilon in Algorithm 1 of the paper.
            amsgrad: `boolean`. Whether to apply the AMSGrad variant of this
                algorithm from the paper "On the Convergence of Adam and Beyond".
                Note that this argument is currently not supported and can only
                be `False`.
            name: Optional name for the operations created when applying
                gradients. Defaults to "LazyAdam".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
                `lr`, `decay`}. `clipnorm` clips gradients by norm; `clipvalue`
                clips gradients by value; `decay` is included for backward
                compatibility to allow time-inverse decay of the learning rate;
                `lr` is included for backward compatibility, and it is
                recommended to use `learning_rate` instead.
        """
        super().__init__(
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            **kwargs,
        )
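    # Illustrative construction (added example; hyperparameter values are
    # arbitrary, not recommendations):
    #
    #     opt = LazyAdam(learning_rate=1e-3, beta_1=0.9, beta_2=0.999,
    #                    epsilon=1e-7, clipnorm=1.0)
    #
    # `clipnorm` is one of the extra keyword arguments accepted via `**kwargs`.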

    def _resource_apply_sparse(self, grad, var, indices):
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        beta_1_t = self._get_hyper("beta_1", var_dtype)
        beta_2_t = self._get_hyper("beta_2", var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_power = tf.math.pow(beta_1_t, local_step)
        beta_2_power = tf.math.pow(beta_2_t, local_step)
        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
        # Fold Adam's bias correction into the step size:
        # \\(lr := lr_t * sqrt(1 - beta2^t) / (1 - beta1^t)\\)
        lr = lr_t * tf.math.sqrt(1 - beta_2_power) / (1 - beta_1_power)

        # \\(m := beta1 * m + (1 - beta1) * g_t\\)
        m = self.get_slot(var, "m")
        m_t_slice = beta_1_t * tf.gather(m, indices) + (1 - beta_1_t) * grad
        m_update_op = self._resource_scatter_update(m, indices, m_t_slice)

        # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
        v = self.get_slot(var, "v")
        v_t_slice = beta_2_t * tf.gather(v, indices) + (1 - beta_2_t) * tf.math.square(
            grad
        )
        v_update_op = self._resource_scatter_update(v, indices, v_t_slice)

        # \\(variable += -learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
        var_slice = lr * m_t_slice / (tf.math.sqrt(v_t_slice) + epsilon_t)
        var_update_op = self._resource_scatter_sub(var, indices, var_slice)

        return tf.group(*[var_update_op, m_update_op, v_update_op])

    def _resource_scatter_update(self, resource, indices, update):
        return self._resource_scatter_operate(
            resource, indices, update, tf.raw_ops.ResourceScatterUpdate
        )

    def _resource_scatter_sub(self, resource, indices, update):
        return self._resource_scatter_operate(
            resource, indices, update, tf.raw_ops.ResourceScatterSub
        )

    def _resource_scatter_operate(self, resource, indices, update, resource_scatter_op):
        resource_update_kwargs = {
            "resource": resource.handle,
            "indices": indices,
            "updates": update,
        }

        return resource_scatter_op(**resource_update_kwargs)

    def get_config(self):
        return super().get_config()
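

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# module). It builds a tiny embedding model whose gradients w.r.t. the
# embedding table are sparse `tf.IndexedSlices`, so `LazyAdam` only updates
# the accumulator rows for the token ids seen in each batch. The model shape
# and the random data below are arbitrary assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    model = tf.keras.Sequential(
        [
            tf.keras.layers.Embedding(input_dim=1000, output_dim=8),
            tf.keras.layers.GlobalAveragePooling1D(),
            tf.keras.layers.Dense(1),
        ]
    )
    model.compile(optimizer=LazyAdam(learning_rate=1e-3), loss="mse")

    # Random integer token ids; only the looked-up embedding rows receive
    # moment-accumulator updates.
    token_ids = np.random.randint(0, 1000, size=(32, 10))
    targets = np.random.rand(32, 1).astype("float32")
    model.fit(token_ids, targets, epochs=1, verbose=0)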