# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various learning rate decay functions."""

import functools

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.exponential_decay"])
def exponential_decay(learning_rate,
                      global_step,
                      decay_steps,
                      decay_rate,
                      staircase=False,
                      name=None):
  """Applies exponential decay to the learning rate.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies an exponential decay function
  to a provided initial learning rate. It requires a `global_step` value to
  compute the decayed learning rate. You can just pass a TensorFlow variable
  that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  decayed_learning_rate = learning_rate *
                          decay_rate ^ (global_step / decay_steps)
  ```

  If the argument `staircase` is `True`, then `global_step / decay_steps` is an
  integer division and the decayed learning rate follows a staircase function.

  Example: decay every 100000 steps with a base of 0.96:

  ```python
  ...
  global_step = tf.Variable(0, trainable=False)
  starter_learning_rate = 0.1
  learning_rate = tf.compat.v1.train.exponential_decay(
      starter_learning_rate, global_step, 100000, 0.96, staircase=True)
  # Passing global_step to minimize() will increment it at each step.
  learning_step = (
      tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
      .minimize(...my loss..., global_step=global_step)
  )
  ```
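
  As a rough hand-worked check of the staircase schedule above (an
  illustration only, not part of the API), the rate stays constant within
  each 100000-step interval:

  ```python
  # Plain-Python evaluation of the formula with staircase=True.
  for step in [0, 100000, 200000]:
    print(step, 0.1 * 0.96 ** (step // 100000))
  # -> 0.1, 0.096, ~0.0922
  ```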

  Args:
    learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
      The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation. Must not be negative.
    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must
      be positive. See the decay computation above.
    decay_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
      The decay rate.
    staircase: Boolean. If `True`, decay the learning rate at discrete
      intervals.
    name: String. Optional name of the operation. Defaults to
      'ExponentialDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
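
  For example, under eager execution the returned callable can be handed
  directly to a `tf.compat.v1.train` optimizer, which accepts a callable
  learning rate and re-evaluates it as `global_step` advances (a minimal
  sketch; the surrounding setup is assumed):

  ```python
  lr_fn = tf.compat.v1.train.exponential_decay(
      0.1, global_step, 100000, 0.96, staircase=True)
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(lr_fn)
  ```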
  """
  decayed_lr = learning_rate_schedule.ExponentialDecay(
      learning_rate, decay_steps, decay_rate, staircase=staircase, name=name)
  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr


@tf_export(v1=["train.piecewise_constant_decay", "train.piecewise_constant"])
def piecewise_constant(x, boundaries, values, name=None):
  """Piecewise constant from boundaries and interval values.

  Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5
  for the next 10000 steps, and 0.1 for any additional steps.

  ```python
  global_step = tf.Variable(0, trainable=False)
  boundaries = [100000, 110000]
  values = [1.0, 0.5, 0.1]
  learning_rate = tf.compat.v1.train.piecewise_constant(
      global_step, boundaries, values)

  # Later, whenever we perform an optimization step, we increment global_step.
  ```
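
  The boundary semantics can be sketched in plain Python to make them
  concrete (an illustration only, not how the op is implemented):

  ```python
  import bisect

  def piecewise(step, boundaries=(100000, 110000), values=(1.0, 0.5, 0.1)):
    # values[i] applies while step <= boundaries[i]; values[-1] afterwards.
    return values[bisect.bisect_left(boundaries, step)]

  piecewise(100000)  # 1.0
  piecewise(100001)  # 0.5
  piecewise(120000)  # 0.1
  ```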

  Args:
    x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
      `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
    boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
      increasing entries, and with all elements having the same type as `x`.
    values: A list of `Tensor`s or `float`s or `int`s that specifies the values
      for the intervals defined by `boundaries`. It should have one more
      element than `boundaries`, and all elements should have the same type.
    name: A string. Optional name of the operation. Defaults to
      'PiecewiseConstant'.

  Returns:
    A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`,
    `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
    and `values[-1]` when `x > boundaries[-1]`.

  Raises:
    ValueError: if the types of `x` and `boundaries` do not match, if the types
      of all `values` do not match, or if the number of elements in the lists
      does not match.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  boundaries = nest.map_structure(
      tensor_conversion.convert_to_tensor_v2_with_dispatch,
      nest.flatten(boundaries),
  )
  values = nest.map_structure(
      tensor_conversion.convert_to_tensor_v2_with_dispatch, nest.flatten(values)
  )
  x_recomp = tensor_conversion.convert_to_tensor_v2_with_dispatch(x)
  # Avoid explicit conversion to x's dtype. This could result in faulty
  # comparisons, for example if floats are converted to integers.
  for i, b in enumerate(boundaries):
    if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
      # We can promote int32 boundaries to int64 without loss of precision.
      # This covers the most common case where the user passes in boundaries
      # as an array of Python integers.
      if (b.dtype.base_dtype == dtypes.int32 and
          x_recomp.dtype.base_dtype == dtypes.int64):
        b = math_ops.cast(b, x_recomp.dtype.base_dtype)
        boundaries[i] = b
      else:
        raise ValueError(
            "Boundaries (%s) must have the same dtype as x (%s)." %
            (b.dtype.base_dtype, x_recomp.dtype.base_dtype))
  for v in values[1:]:
    if v.dtype.base_dtype != values[0].dtype.base_dtype:
      raise ValueError(
          "Values must have elements all with the same dtype (%s vs %s)." %
          (values[0].dtype.base_dtype, v.dtype.base_dtype))
  decayed_lr = learning_rate_schedule.PiecewiseConstantDecay(
      boundaries, values, name=name)
  if not context.executing_eagerly():
    decayed_lr = decayed_lr(x)
  else:
    decayed_lr = functools.partial(decayed_lr, x)
  return decayed_lr


@tf_export(v1=["train.polynomial_decay"])
def polynomial_decay(learning_rate,
                     global_step,
                     decay_steps,
                     end_learning_rate=0.0001,
                     power=1.0,
                     cycle=False,
                     name=None):
  """Applies a polynomial decay to the learning rate.

  It is commonly observed that a monotonically decreasing learning rate, whose
  degree of change is carefully chosen, results in a better performing model.
  This function applies a polynomial decay function to a provided initial
  `learning_rate` to reach an `end_learning_rate` in the given `decay_steps`.

  It requires a `global_step` value to compute the decayed learning rate. You
  can just pass a TensorFlow variable that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  global_step = min(global_step, decay_steps)
  decayed_learning_rate = (learning_rate - end_learning_rate) *
                          (1 - global_step / decay_steps) ^ (power) +
                          end_learning_rate
  ```

  If `cycle` is True then a multiple of `decay_steps` is used, the first one
  that is bigger than `global_step`.

  ```python
  decay_steps = decay_steps * ceil(global_step / decay_steps)
  decayed_learning_rate = (learning_rate - end_learning_rate) *
                          (1 - global_step / decay_steps) ^ (power) +
                          end_learning_rate
  ```

  Example: decay from 0.1 to 0.01 in 10000 steps using sqrt (i.e. power=0.5):

  ```python
  ...
  global_step = tf.Variable(0, trainable=False)
  starter_learning_rate = 0.1
  end_learning_rate = 0.01
  decay_steps = 10000
  learning_rate = tf.compat.v1.train.polynomial_decay(
      starter_learning_rate, global_step, decay_steps, end_learning_rate,
      power=0.5)
  # Passing global_step to minimize() will increment it at each step.
  learning_step = (
      tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
      .minimize(...my loss..., global_step=global_step)
  )
  ```
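
  A rough hand-worked check of the sqrt schedule above (an illustration only):

  ```python
  # (0.1 - 0.01) * (1 - step / 10000) ** 0.5 + 0.01
  # step     0 -> 0.1
  # step  5000 -> ~0.0736
  # step 10000 -> 0.01
  ```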

  Args:
    learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
      The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation. Must not be negative.
    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must
      be positive. See the decay computation above.
    end_learning_rate: A scalar `float32` or `float64` `Tensor` or a Python
      number. The minimal end learning rate.
    power: A scalar `float32` or `float64` `Tensor` or a Python number. The
      power of the polynomial. Defaults to linear, 1.0.
    cycle: A boolean, whether or not it should cycle beyond decay_steps.
    name: String. Optional name of the operation. Defaults to
      'PolynomialDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.PolynomialDecay(
      learning_rate,
      decay_steps,
      end_learning_rate=end_learning_rate,
      power=power,
      cycle=cycle,
      name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr


@tf_export(v1=["train.natural_exp_decay"])
def natural_exp_decay(learning_rate,
                      global_step,
                      decay_steps,
                      decay_rate,
                      staircase=False,
                      name=None):
  """Applies natural exponential decay to the initial learning rate.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies an exponential decay function
  to a provided initial learning rate. It requires a `global_step` value to
  compute the decayed learning rate. You can just pass a TensorFlow variable
  that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  decayed_learning_rate = learning_rate * exp(-decay_rate * global_step /
                                              decay_steps)
  ```

  or, if `staircase` is `True`, as:

  ```python
  decayed_learning_rate = learning_rate * exp(-decay_rate * floor(global_step /
                                              decay_steps))
  ```

  Example: decay exponentially with a decay rate of 0.5:

  ```python
  ...
  global_step = tf.Variable(0, trainable=False)
  learning_rate = 0.1
  decay_steps = 5
  k = 0.5
  learning_rate = tf.compat.v1.train.natural_exp_decay(
      learning_rate, global_step, decay_steps, k)

  # Passing global_step to minimize() will increment it at each step.
  learning_step = (
      tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
      .minimize(...my loss..., global_step=global_step)
  )
  ```
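
  A rough hand-worked check of the continuous schedule above (an illustration
  only):

  ```python
  # 0.1 * exp(-0.5 * step / 5)
  # step  0 -> 0.1
  # step  5 -> ~0.0607
  # step 10 -> ~0.0368
  ```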

  Args:
    learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
      The initial learning rate.
    global_step: A Python number. Global step to use for the decay computation.
      Must not be negative.
    decay_steps: How often to apply decay.
    decay_rate: A Python number. The decay rate.
    staircase: Whether to apply decay in a discrete staircase fashion, as
      opposed to continuously.
    name: String. Optional name of the operation. Defaults to
      'ExponentialTimeDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
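  # Natural exponential decay, learning_rate * exp(-decay_rate * global_step /
  # decay_steps), is expressed below as standard exponential decay with a base
  # of exp(-decay_rate), since exp(-decay_rate) ** (global_step / decay_steps)
  # equals exp(-decay_rate * global_step / decay_steps).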
  natural_exp_rate = math_ops.exp(math_ops.negative(decay_rate))
  decayed_lr = learning_rate_schedule.ExponentialDecay(
      learning_rate,
      decay_steps,
      natural_exp_rate,
      staircase=staircase,
      name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr


@tf_export(v1=["train.inverse_time_decay"])
def inverse_time_decay(learning_rate,
                       global_step,
                       decay_steps,
                       decay_rate,
                       staircase=False,
                       name=None):
  """Applies inverse time decay to the initial learning rate.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies an inverse decay function
  to a provided initial learning rate. It requires a `global_step` value to
  compute the decayed learning rate. You can just pass a TensorFlow variable
  that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  decayed_learning_rate = learning_rate / (1 + decay_rate * global_step /
                                           decay_steps)
  ```

  or, if `staircase` is `True`, as:

  ```python
  decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step /
                                           decay_steps))
  ```

  Example: decay 1/t with a rate of 0.5:

  ```python
  ...
  global_step = tf.Variable(0, trainable=False)
  learning_rate = 0.1
  decay_steps = 1.0
  decay_rate = 0.5
  learning_rate = tf.compat.v1.train.inverse_time_decay(
      learning_rate, global_step, decay_steps, decay_rate)

  # Passing global_step to minimize() will increment it at each step.
  learning_step = (
      tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
      .minimize(...my loss..., global_step=global_step)
  )
  ```
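
  A rough hand-worked check of the continuous schedule above (an illustration
  only):

  ```python
  # 0.1 / (1 + 0.5 * step / 1.0)
  # step 0 -> 0.1
  # step 1 -> ~0.0667
  # step 2 -> 0.05
  ```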

  Args:
    learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
      The initial learning rate.
    global_step: A Python number. Global step to use for the decay computation.
      Must not be negative.
    decay_steps: How often to apply decay.
    decay_rate: A Python number. The decay rate.
    staircase: Whether to apply decay in a discrete staircase fashion, as
      opposed to continuously.
    name: String. Optional name of the operation. Defaults to
      'InverseTimeDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.InverseTimeDecay(
      learning_rate, decay_steps, decay_rate, staircase=staircase, name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr


@tf_export(v1=["train.cosine_decay"])
def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0,
                 name=None):
  """Applies cosine decay to the learning rate.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies a cosine decay function
  to a provided initial learning rate. It requires a `global_step` value to
  compute the decayed learning rate. You can just pass a TensorFlow variable
  that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  global_step = min(global_step, decay_steps)
  cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps))
  decayed = (1 - alpha) * cosine_decay + alpha
  decayed_learning_rate = learning_rate * decayed
  ```

  Example usage:

  ```python
  decay_steps = 1000
  lr_decayed = cosine_decay(learning_rate, global_step, decay_steps)
  ```
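
  A rough hand-worked check of the schedule above with `alpha=0.0` (an
  illustration only):

  ```python
  # learning_rate * 0.5 * (1 + cos(pi * step / 1000))
  # step    0 -> 1.0 * learning_rate
  # step  500 -> 0.5 * learning_rate
  # step 1000 -> 0.0 * learning_rate
  ```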

  Args:
    learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
      The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation.
    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
      Number of steps to decay over.
    alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum
      learning rate value as a fraction of learning_rate.
    name: String. Optional name of the operation. Defaults to 'CosineDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  References:
    Stochastic Gradient Descent with Warm Restarts:
      [Loshchilov et al., 2017]
      (https://openreview.net/forum?id=Skq89Scxx&noteId=Skq89Scxx)
      ([pdf](https://openreview.net/pdf?id=Skq89Scxx))

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.CosineDecay(
      learning_rate, decay_steps, alpha=alpha, name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr


@tf_export(v1=["train.cosine_decay_restarts"])
def cosine_decay_restarts(learning_rate,
                          global_step,
                          first_decay_steps,
                          t_mul=2.0,
                          m_mul=1.0,
                          alpha=0.0,
                          name=None):
  """Applies cosine decay with restarts to the learning rate.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies a cosine decay function with
  restarts to a provided initial learning rate. It requires a `global_step`
  value to compute the decayed learning rate. You can just pass a TensorFlow
  variable that you increment at each training step.

  The function returns the decayed learning rate while taking into account
  possible warm restarts. The learning rate multiplier first decays
  from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
  restart is performed. Each new warm restart runs for `t_mul` times more steps
  and with `m_mul` times smaller initial learning rate.

  Example usage:

  ```python
  first_decay_steps = 1000
  lr_decayed = cosine_decay_restarts(learning_rate, global_step,
                                     first_decay_steps)
  ```
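
  With the defaults `t_mul=2.0` and `m_mul=1.0`, the sketch above restarts at
  steps 1000, 3000, 7000, and so on (an illustration only):

  ```python
  # Period lengths: 1000, 2000, 4000, ... (each t_mul times the previous one).
  # With m_mul=1.0 every restart begins again at the full learning_rate; with
  # m_mul < 1.0 each restart would begin at m_mul times the previous peak.
  ```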

  Args:
    learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
      The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation.
    first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
      Number of steps to decay over.
    t_mul: A scalar `float32` or `float64` `Tensor` or a Python number. Used to
      derive the number of iterations in the i-th period.
    m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
      Used to derive the initial learning rate of the i-th period.
    alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum
      learning rate value as a fraction of the learning_rate.
    name: String. Optional name of the operation. Defaults to 'SGDRDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  References:
    Stochastic Gradient Descent with Warm Restarts:
      [Loshchilov et al., 2017]
      (https://openreview.net/forum?id=Skq89Scxx&noteId=Skq89Scxx)
      ([pdf](https://openreview.net/pdf?id=Skq89Scxx))

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.CosineDecayRestarts(
      learning_rate,
      first_decay_steps,
      t_mul=t_mul,
      m_mul=m_mul,
      alpha=alpha,
      name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr


@tf_export(v1=["train.linear_cosine_decay"])
def linear_cosine_decay(learning_rate,
                        global_step,
                        decay_steps,
                        num_periods=0.5,
                        alpha=0.0,
                        beta=0.001,
                        name=None):
  """Applies linear cosine decay to the learning rate.

  Note that linear cosine decay is more aggressive than cosine decay and
  larger initial learning rates can typically be used.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies a linear cosine decay function
  to a provided initial learning rate. It requires a `global_step` value to
  compute the decayed learning rate. You can just pass a TensorFlow variable
  that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  global_step = min(global_step, decay_steps)
  linear_decay = (decay_steps - global_step) / decay_steps
  cosine_decay = 0.5 * (
      1 + cos(pi * 2 * num_periods * global_step / decay_steps))
  decayed = (alpha + linear_decay) * cosine_decay + beta
  decayed_learning_rate = learning_rate * decayed
  ```

  Example usage:

  ```python
  decay_steps = 1000
  lr_decayed = linear_cosine_decay(learning_rate, global_step, decay_steps)
  ```
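
  A rough hand-worked check of the defaults (`num_periods=0.5`, `alpha=0.0`,
  `beta=0.001`) with `decay_steps=1000`, where `2 * num_periods = 1` (an
  illustration only):

  ```python
  # multiplier = ((1000 - step) / 1000) * 0.5 * (1 + cos(pi * step / 1000))
  #              + 0.001
  # step    0 -> 1.001
  # step  500 -> 0.251
  # step 1000 -> 0.001
  ```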

  Args:
    learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
      The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation.
    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
      Number of steps to decay over.
    num_periods: Number of periods in the cosine part of the decay. See
      computation above.
    alpha: See computation above.
    beta: See computation above.
    name: String. Optional name of the operation. Defaults to
      'LinearCosineDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  References:
    Neural Optimizer Search with Reinforcement Learning:
      [Bello et al., 2017](http://proceedings.mlr.press/v70/bello17a.html)
      ([pdf](http://proceedings.mlr.press/v70/bello17a/bello17a.pdf))
    Stochastic Gradient Descent with Warm Restarts:
      [Loshchilov et al., 2017]
      (https://openreview.net/forum?id=Skq89Scxx&noteId=Skq89Scxx)
      ([pdf](https://openreview.net/pdf?id=Skq89Scxx))

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.LinearCosineDecay(
      learning_rate,
      decay_steps,
      num_periods=num_periods,
      alpha=alpha,
      beta=beta,
      name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr


@tf_export(v1=["train.noisy_linear_cosine_decay"])
def noisy_linear_cosine_decay(learning_rate,
                              global_step,
                              decay_steps,
                              initial_variance=1.0,
                              variance_decay=0.55,
                              num_periods=0.5,
                              alpha=0.0,
                              beta=0.001,
                              name=None):
  """Applies noisy linear cosine decay to the learning rate.

  Note that linear cosine decay is more aggressive than cosine decay and
  larger initial learning rates can typically be used.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies a noisy linear
  cosine decay function to a provided initial learning rate.
  It requires a `global_step` value to compute the decayed learning rate.
  You can just pass a TensorFlow variable that you increment at each
  training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  global_step = min(global_step, decay_steps)
  linear_decay = (decay_steps - global_step) / decay_steps
  cosine_decay = 0.5 * (
      1 + cos(pi * 2 * num_periods * global_step / decay_steps))
  decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta
  decayed_learning_rate = learning_rate * decayed
  ```

  where `eps_t` is 0-centered Gaussian noise with variance
  `initial_variance / (1 + global_step) ** variance_decay`.

  Example usage:

  ```python
  decay_steps = 1000
  lr_decayed = noisy_linear_cosine_decay(
      learning_rate, global_step, decay_steps)
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
      The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation.
    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
      Number of steps to decay over.
    initial_variance: initial variance for the noise. See computation above.
    variance_decay: decay for the noise's variance. See computation above.
    num_periods: Number of periods in the cosine part of the decay. See
      computation above.
    alpha: See computation above.
    beta: See computation above.
    name: String. Optional name of the operation. Defaults to
      'NoisyLinearCosineDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  References:
    Neural Optimizer Search with Reinforcement Learning:
      [Bello et al., 2017](http://proceedings.mlr.press/v70/bello17a.html)
      ([pdf](http://proceedings.mlr.press/v70/bello17a/bello17a.pdf))
    Stochastic Gradient Descent with Warm Restarts:
      [Loshchilov et al., 2017]
      (https://openreview.net/forum?id=Skq89Scxx&noteId=Skq89Scxx)
      ([pdf](https://openreview.net/pdf?id=Skq89Scxx))

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.NoisyLinearCosineDecay(
      learning_rate,
      decay_steps,
      initial_variance=initial_variance,
      variance_decay=variance_decay,
      num_periods=num_periods,
      alpha=alpha,
      beta=beta,
      name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr