# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Variant of the Adam optimizer that handles sparse updates more efficiently.
17Compared with the original Adam optimizer, the one in this file can
18provide a large improvement in model training throughput for some
19applications. However, it provides slightly different semantics than the
20original Adam algorithm, and may lead to different empirical results.
21"""

import importlib
import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from typeguard import typechecked
from typing import Union, Callable
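

# Keras 2.11 moved the original optimizer implementations under the
# `tf.keras.optimizers.legacy` namespace; prefer that class when it exists so
# the `_resource_apply_sparse` override below is still honored.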
if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
    adam_optimizer_class = tf.keras.optimizers.legacy.Adam
else:
    adam_optimizer_class = tf.keras.optimizers.Adam


@tf.keras.utils.register_keras_serializable(package="Addons")
class LazyAdam(adam_optimizer_class):
    """Variant of the Adam optimizer that handles sparse updates more
    efficiently.

    The original Adam algorithm maintains two moving-average accumulators for
    each trainable variable; the accumulators are updated at every step.
    This class provides lazier handling of gradient updates for sparse
    variables. It only updates moving-average accumulators for sparse variable
    indices that appear in the current batch, rather than updating the
    accumulators for all indices. Compared with the original Adam optimizer,
    it can provide large improvements in model training throughput for some
    applications. However, it provides slightly different semantics than the
    original Adam algorithm, and may lead to different empirical results.

    Note: `amsgrad` is currently not supported and the argument can only be
    `False`.
    """

    @typechecked
    def __init__(
        self,
        learning_rate: Union[FloatTensorLike, Callable] = 0.001,
        beta_1: FloatTensorLike = 0.9,
        beta_2: FloatTensorLike = 0.999,
        epsilon: FloatTensorLike = 1e-7,
        amsgrad: bool = False,
        name: str = "LazyAdam",
        **kwargs,
    ):
67 """Constructs a new LazyAdam optimizer.
69 Args:
70 learning_rate: A `Tensor` or a floating point value. or a schedule
71 that is a `tf.keras.optimizers.schedules.LearningRateSchedule`
72 The learning rate.
73 beta_1: A `float` value or a constant `float` tensor.
74 The exponential decay rate for the 1st moment estimates.
75 beta_2: A `float` value or a constant `float` tensor.
76 The exponential decay rate for the 2nd moment estimates.
77 epsilon: A small constant for numerical stability.
78 This epsilon is "epsilon hat" in
79 [Adam: A Method for Stochastic Optimization. Kingma et al., 2014]
80 (http://arxiv.org/abs/1412.6980) (in the formula just
81 before Section 2.1), not the epsilon in Algorithm 1 of the paper.
82 amsgrad: `boolean`. Whether to apply AMSGrad variant of this
83 algorithm from the paper "On the Convergence of Adam and beyond".
84 Note that this argument is currently not supported and the
85 argument can only be `False`.
86 name: Optional name for the operations created when applying
87 gradients. Defaults to "LazyAdam".
88 **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
89 `lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue`
90 is clip gradients by value, `decay` is included for backward
91 compatibility to allow time inverse decay of learning rate. `lr`
92 is included for backward compatibility, recommended to use
93 `learning_rate` instead.
94 """
        super().__init__(
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            **kwargs,
        )

    def _resource_apply_sparse(self, grad, var, indices):
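        # Lazy semantics: only the rows of `var` (and of its `m`/`v` slots)
        # selected by `indices` are read and updated below; rows that do not
        # appear in the current sparse gradient keep their previous moment
        # estimates and are not touched this step, unlike dense Adam.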
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        beta_1_t = self._get_hyper("beta_1", var_dtype)
        beta_2_t = self._get_hyper("beta_2", var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_power = tf.math.pow(beta_1_t, local_step)
        beta_2_power = tf.math.pow(beta_2_t, local_step)
        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
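        # Bias-corrected step size, lr_t * sqrt(1 - beta_2^t) / (1 - beta_1^t),
        # the efficient form of the bias correction from Kingma & Ba (2014).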
        lr = lr_t * tf.math.sqrt(1 - beta_2_power) / (1 - beta_1_power)

        # \\(m := beta1 * m + (1 - beta1) * g_t\\)
        m = self.get_slot(var, "m")
        m_t_slice = beta_1_t * tf.gather(m, indices) + (1 - beta_1_t) * grad
        m_update_op = self._resource_scatter_update(m, indices, m_t_slice)

        # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
        v = self.get_slot(var, "v")
        v_t_slice = beta_2_t * tf.gather(v, indices) + (1 - beta_2_t) * tf.math.square(
            grad
        )
        v_update_op = self._resource_scatter_update(v, indices, v_t_slice)

        # \\(variable += -learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
        var_slice = lr * m_t_slice / (tf.math.sqrt(v_t_slice) + epsilon_t)
        var_update_op = self._resource_scatter_sub(var, indices, var_slice)

        return tf.group(*[var_update_op, m_update_op, v_update_op])

    def _resource_scatter_update(self, resource, indices, update):
        return self._resource_scatter_operate(
            resource, indices, update, tf.raw_ops.ResourceScatterUpdate
        )

    def _resource_scatter_sub(self, resource, indices, update):
        return self._resource_scatter_operate(
            resource, indices, update, tf.raw_ops.ResourceScatterSub
        )

    def _resource_scatter_operate(self, resource, indices, update, resource_scatter_op):
        resource_update_kwargs = {
            "resource": resource.handle,
            "indices": indices,
            "updates": update,
        }

        return resource_scatter_op(**resource_update_kwargs)

    def get_config(self):
        return super().get_config()
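

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition, not part of the upstream
    # module): LazyAdam is a drop-in replacement for Adam and mainly pays off
    # when gradients are sparse, e.g. for large embedding lookups. The toy
    # model and random data below are hypothetical.
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Embedding(input_dim=1000, output_dim=16),
            tf.keras.layers.GlobalAveragePooling1D(),
            tf.keras.layers.Dense(1),
        ]
    )
    model.compile(optimizer=LazyAdam(learning_rate=1e-3), loss="mse")
    x = tf.random.uniform((32, 10), maxval=1000, dtype=tf.int32)
    y = tf.random.uniform((32, 1))
    model.fit(x, y, epochs=1, verbose=0)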