Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/legacy/optimizer_v2.py: 21%
512 statements
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Version 2 of class Optimizer."""
18import abc
19import contextlib
20import functools
21import warnings
22from copy import deepcopy
24import tensorflow.compat.v2 as tf
26from keras.src import backend
27from keras.src import initializers
28from keras.src.engine import base_layer_utils
29from keras.src.optimizers import utils as optimizer_utils
30from keras.src.optimizers.schedules import learning_rate_schedule
31from keras.src.utils import generic_utils
32from keras.src.utils import layer_utils
33from keras.src.utils import tf_inspect
34from keras.src.utils import tf_utils
36# isort: off
37from tensorflow.python.util.tf_export import keras_export
39keras_optimizers_gauge = tf.__internal__.monitoring.BoolGauge(
40 "/tensorflow/api/keras/optimizers", "keras optimizer usage", "method"
41)
43_DEFAULT_VALID_DTYPES = frozenset(
44 [
45 tf.float16,
46 tf.bfloat16,
47 tf.float32,
48 tf.float64,
49 tf.complex64,
50 tf.complex128,
51 ]
52)
55def _deduplicate_indexed_slices(values, indices):
56 """Sums `values` associated with any non-unique `indices`.
58 Args:
59 values: A `Tensor` with rank >= 1.
60 indices: A one-dimensional integer `Tensor`, indexing into the first
61 dimension of `values` (as in an IndexedSlices object).
63 Returns:
64 A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
65 de-duplicated version of `indices` and `summed_values` contains the sum of
66 `values` slices associated with each unique index.
67 """
68 unique_indices, new_index_positions = tf.unique(indices)
69 summed_values = tf.math.unsorted_segment_sum(
70 values, new_index_positions, tf.shape(unique_indices)[0]
71 )
72 return (summed_values, unique_indices)
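
# --- Illustrative sketch, not part of the original module. It demonstrates the
# documented behavior of `_deduplicate_indexed_slices`: values that share an
# index are summed, and the unique indices are returned. The helper name below
# is an assumption introduced here purely for illustration.
def _example_deduplicate_indexed_slices():
    values = tf.constant([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
    indices = tf.constant([0, 2, 0])
    summed_values, unique_indices = _deduplicate_indexed_slices(values, indices)
    # summed_values == [[4.0, 4.0], [2.0, 2.0]]; unique_indices == [0, 2]
    return summed_values, unique_indices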
75class NullContextmanager:
76 def __init__(self, *args, **kwargs):
77 pass
79 def __enter__(self):
80 pass
82 def __exit__(self, type_arg, value_arg, traceback_arg):
83 return False # False values do not suppress exceptions
86def name_scope_only_in_function_or_graph(name):
87 """Internal-only entry point for `name_scope*`.
89 Enters a compat.v1.name_scope only when in a function or graph,
90 not when running fully eagerly.
92 Args:
93 name: The name argument that is passed to the op function.
95 Returns:
96 `name_scope*` context manager.
97 """
98 if not tf.executing_eagerly():
99 return tf.name_scope(name)
100 else:
101 return NullContextmanager()
104@keras_export(
105 "keras.optimizers.legacy.Optimizer",
106 v1=["keras.optimizers.Optimizer", "keras.optimizers.legacy.Optimizer"],
107)
108class OptimizerV2(tf.__internal__.tracking.Trackable):
109 """Base class for legacy Keras optimizers.
111 You should not use this class directly, but instead instantiate one of its
112 subclasses such as `tf.keras.optimizers.legacy.SGD`,
113 `tf.keras.optimizers.legacy.Adam`, etc.
 115 This was the default Keras optimizer base class up to and including v2.10.
116 In v2.11 and later, `tf.keras.optimizers.Optimizer`
117 points to a new base class implementation. The legacy class won't be
118 deleted in the future and will continue to be available at
119 `tf.keras.optimizers.legacy.Optimizer`.
121 ### Usage
123 ```python
124 # Create an optimizer with the desired parameters.
125 opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
126 # `loss` is a callable that takes no argument and returns the value
127 # to minimize.
128 var1 = tf.Variable(2.0)
129 var2 = tf.Variable(5.0)
130 loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
131 # In graph mode, returns op that minimizes the loss by updating the listed
132 # variables.
133 opt_op = opt.minimize(loss, var_list=[var1, var2])
134 opt_op.run()
135 # In eager mode, simply call minimize to update the list of variables.
136 opt.minimize(loss, var_list=[var1, var2])
137 ```
139 ### Usage in custom training loops
 141 In Keras models, variables are sometimes created when the model is first
 142 called, instead of at construction time. Examples include 1) sequential models
 143 without a pre-defined input shape, or 2) subclassed models. Pass `var_list` as a
 144 callable in these cases.
146 Example:
148 ```python
149 opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
150 model = tf.keras.Sequential()
151 model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
152 model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
153 loss_fn = lambda: tf.keras.losses.mse(model(input), output)
154 var_list_fn = lambda: model.trainable_weights
155 for input, output in data:
156 opt.minimize(loss_fn, var_list_fn)
157 ```
159 ### Processing gradients before applying them
161 Calling `minimize()` takes care of both computing the gradients and
162 applying them to the variables. If you want to process the gradients
163 before applying them you can instead use the optimizer in three steps:
165 1. Compute the gradients with `tf.GradientTape`.
166 2. Process the gradients as you wish.
167 3. Apply the processed gradients with `apply_gradients()`.
169 Example:
171 ```python
172 # Create an optimizer.
173 opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
175 # Compute the gradients for a list of variables.
176 with tf.GradientTape() as tape:
177 loss = <call_loss_function>
178 vars = <list_of_variables>
179 grads = tape.gradient(loss, vars)
181 # Process the gradients, for example cap them, etc.
182 # capped_grads = [MyCapper(g) for g in grads]
183 processed_grads = [process_gradient(g) for g in grads]
185 # Ask the optimizer to apply the processed gradients.
 186 opt.apply_gradients(zip(processed_grads, vars))
187 ```
189 ### Use with `tf.distribute.Strategy`
191 This optimizer class is `tf.distribute.Strategy` aware, which means it
192 automatically sums gradients across all replicas. To average gradients,
193 you divide your loss by the global batch size, which is done
194 automatically if you use `tf.keras` built-in training or evaluation loops.
 195 See the `reduction` argument of your loss, which should be set to
 196 `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
 197 `tf.keras.losses.Reduction.SUM` for summing.
199 To aggregate gradients yourself, call `apply_gradients` with
200 `experimental_aggregate_gradients` set to False. This is useful if you need
201 to process aggregated gradients.
203 If you are not using these and you want to average gradients, you should use
204 `tf.math.reduce_sum` to add up your per-example losses and then divide by
205 the global batch size. Note that when using `tf.distribute.Strategy`, the
206 first component of a tensor's shape is the *replica-local* batch size, which
207 is off by a factor equal to the number of replicas being used to compute a
208 single step. As a result, using `tf.math.reduce_mean` will give the wrong
209 answer, resulting in gradients that can be many times too big.
211 ### Variable Constraints
 213 All Keras optimizers respect variable constraints. If a constraint function is
 214 passed to any variable, the constraint will be applied to the variable after
 215 the gradient has been applied to it.
 216 Important: If the gradient is a sparse tensor, variable constraints are not
 217 supported.
219 ### Thread Compatibility
221 The entire optimizer is currently thread compatible, not thread-safe. The
222 user needs to perform synchronization if necessary.
224 ### Slots
 226 Many optimizer subclasses, such as `Adam` and `Adagrad`, allocate and manage
 227 additional variables associated with the variables to train. These are
 228 called *Slots*. Slots have names and you can ask the optimizer for the
 229 names of the slots that it uses. Once you have a slot name you can ask the
 230 optimizer for the variable it created to hold the slot value.
 232 This can be useful if you want to log or debug a training algorithm, report
 233 stats about the slots, etc. (An illustrative note follows this docstring.)
235 ### Hyperparameters
237 These are arguments passed to the optimizer subclass constructor
238 (the `__init__` method), and then passed to `self._set_hyper()`.
239 They can be either regular Python values (like 1.0), tensors, or
240 callables. If they are callable, the callable will be called during
 241 `apply_gradients()` to get the value for the hyperparameter.
243 Hyperparameters can be overwritten through user code:
245 Example:
247 ```python
248 # Create an optimizer with the desired parameters.
249 opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
250 # `loss` is a callable that takes no argument and returns the value
251 # to minimize.
252 loss = lambda: 3 * var1 + 2 * var2
253 # In eager mode, simply call minimize to update the list of variables.
254 opt.minimize(loss, var_list=[var1, var2])
255 # update learning rate
256 opt.learning_rate = 0.05
257 opt.minimize(loss, var_list=[var1, var2])
258 ```
260 ### Callable learning rate
262 Optimizer accepts a callable learning rate in two ways. The first way is
263 through built-in or customized
264 `tf.keras.optimizers.schedules.LearningRateSchedule`. The schedule will be
 265 called on each iteration with `schedule(iteration)`, where `iteration` is a
 266 `tf.Variable` owned by the optimizer.
268 Example:
270 >>> var = tf.Variable(np.random.random(size=(1,)))
271 >>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
272 ... initial_learning_rate=.01, decay_steps=20, decay_rate=.1)
273 >>> opt = tf.keras.optimizers.legacy.SGD(learning_rate=learning_rate)
274 >>> loss = lambda: 3 * var
275 >>> opt.minimize(loss, var_list=[var])
276 <tf.Variable...
278 The second way is through a callable function that
279 does not accept any arguments.
281 Example:
283 >>> var = tf.Variable(np.random.random(size=(1,)))
284 >>> def lr_callable():
285 ... return .1
286 >>> opt = tf.keras.optimizers.legacy.SGD(learning_rate=lr_callable)
287 >>> loss = lambda: 3 * var
288 >>> opt.minimize(loss, var_list=[var])
289 <tf.Variable...
291 ### Creating a custom optimizer
 293 If you intend to create your own optimization algorithm, simply inherit from
 294 this class and override the following methods (an illustrative sketch follows the class definition below):
 296 - `_resource_apply_dense` (update a variable when the gradient tensor is a
 297 dense `tf.Tensor`)
 298 - `_resource_apply_sparse` (update a variable when the gradient tensor is a
 299 sparse `tf.IndexedSlices`. The most common way for this to happen
 300 is if you are taking the gradient through a `tf.gather`.)
301 - `_create_slots`
302 (if your optimizer algorithm requires additional variables)
303 - `get_config`
 304 (serialization of the optimizer, including all hyperparameters)
305 """
307 # Subclasses should set this to True unless they override `apply_gradients`
308 # with a version that does not have the `experimental_aggregate_gradients`
309 # argument. Older versions of Keras did not have this argument so custom
310 # optimizers may have overridden `apply_gradients` without the
311 # `experimental_aggregate_gradients` argument. Keras only passes
312 # `experimental_aggregate_gradients` if this attribute is True.
313 # Note: This attribute will likely be removed in an upcoming release.
314 _HAS_AGGREGATE_GRAD = False
316 def __init__(
317 self,
318 name,
319 gradient_aggregator=None,
320 gradient_transformers=None,
321 **kwargs,
322 ):
323 """Create a new Optimizer.
325 This must be called by the constructors of subclasses.
326 Note that Optimizer instances should not bind to a single graph,
327 and so shouldn't keep Tensors as member variables. Generally
 328 you should be able to use the _set_hyper()/_get_hyper()
329 facility instead.
331 This class is stateful and thread-compatible.
333 Example of custom gradient transformations:
335 ```python
336 def my_gradient_transformer(grads_and_vars):
337 # Simple example, double the gradients.
338 return [(2. * g, v) for g, v in grads_and_vars]
340 optimizer = tf.keras.optimizers.legacy.SGD(
341 1e-3, gradient_transformers=[my_gradient_transformer])
342 ```
344 Args:
345 name: String. The name to use for momentum accumulator weights created
346 by the optimizer.
347 gradient_aggregator: The function to use to aggregate gradients across
348 devices (when using `tf.distribute.Strategy`). If `None`, defaults
349 to summing the gradients across devices. The function should accept
350 and return a list of `(gradient, variable)` tuples.
351 gradient_transformers: Optional. List of functions to use to transform
352 gradients before applying updates to Variables. The functions are
353 applied after `gradient_aggregator`. The functions should accept and
354 return a list of `(gradient, variable)` tuples.
 355 **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
 356 `clipnorm` and `global_clipnorm`, plus the legacy `lr` (a deprecated alias for `learning_rate`) and `decay` (inverse time decay of the learning rate).
 357 If `clipvalue` (float) is set, the gradient of each weight
 358 is clipped to be no higher than this value.
 359 If `clipnorm` (float) is set, the gradient of each weight
 360 is individually clipped so that its norm is no higher than this
 361 value. If `global_clipnorm` (float) is set, the gradient of all
 362 weights is clipped so that their global norm is no higher than this
 363 value. (A usage sketch follows this docstring.)
365 Raises:
366 ValueError: in case of any invalid argument.
367 """
368 # Instrument optimizer usages
369 keras_optimizers_gauge.get_cell(self.__class__.__name__).set(True)
371 allowed_kwargs = {
372 "clipnorm",
373 "clipvalue",
374 "lr",
375 "decay",
376 "global_clipnorm",
377 }
378 for k in kwargs:
379 if k not in allowed_kwargs:
380 raise TypeError(
381 "Unexpected keyword argument "
382 f"passed to optimizer: {str(k)}. Allowed kwargs are "
383 f"{allowed_kwargs}."
384 )
 385 # Check that the keyword argument value is non-negative.
386 if kwargs[k] is not None and kwargs[k] < 0:
387 raise ValueError(f"Expected {k} >= 0, received: {kwargs[k]}")
388 if k == "lr":
389 warnings.warn(
390 "The `lr` argument is deprecated, "
391 "use `learning_rate` instead.",
392 stacklevel=2,
393 )
395 self._use_locking = True
396 self._init_set_name(name)
397 self._hyper = {}
398 # dict: {variable name : {slot name : variable}}
399 self._slots = {}
400 self._slot_names = []
401 self._weights = []
402 self._iterations = None
404 # For implementing Trackable. Stores information about how to restore
405 # slot variables which have not yet been created
406 # (trackable._CheckpointPosition objects).
407 # {slot_name :
408 # {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
409 # ... }
410 self._deferred_slot_restorations = {}
412 decay = kwargs.pop("decay", 0.0)
413 if decay < 0.0:
414 raise ValueError(
415 f"decay cannot be less than 0. Received: decay={decay}."
416 )
417 self._initial_decay = decay
419 self._hypers_created = False
420 # Store the distribution strategy object if the optimizer is created
421 # inside strategy scope, so it could be used to create variables later.
422 if tf.distribute.has_strategy():
423 self._distribution_strategy = tf.distribute.get_strategy()
424 else:
425 self._distribution_strategy = None
427 # Configure gradient transformations.
428 if gradient_aggregator is None:
429 gradient_aggregator = optimizer_utils.all_reduce_sum_gradients
430 self.gradient_aggregator = gradient_aggregator
431 if gradient_transformers is None:
432 gradient_transformers = []
433 self.gradient_transformers = gradient_transformers
434 self.clipnorm = kwargs.pop("clipnorm", None)
435 self.global_clipnorm = kwargs.pop("global_clipnorm", None)
436 if self.clipnorm is not None and self.global_clipnorm is not None:
437 raise ValueError(
438 "Cannot accept both `clipnorm` and `global_clipnorm`. "
439 "Received: `clipnorm`={}, `global_clipnorm`={}.".format(
440 self.clipnorm, self.global_clipnorm
441 )
442 )
443 self.clipvalue = kwargs.pop("clipvalue", None)
445 def __deepcopy__(self, memo):
446 cls = self.__class__
447 result = cls.__new__(cls)
448 memo[id(self)] = result
449 for k, v in self.__dict__.items():
450 # DistributionStrategy singleton cannot be serialized
451 if k == "_distribution_strategy":
452 continue
453 setattr(result, k, deepcopy(v, memo))
454 result._distribution_strategy = self._distribution_strategy
455 return result
457 @property
458 def clipnorm(self):
459 """`float` or `None`. If set, clips gradients to a maximum norm."""
460 return self._clipnorm
462 @property
463 def global_clipnorm(self):
464 """`float` or `None`.
466 If set, clips gradients to a maximum norm.
468 Check `tf.clip_by_global_norm` for more details.
469 """
470 return self._global_clipnorm
472 @clipnorm.setter
473 def clipnorm(self, val):
474 if val is not None and self.gradient_transformers:
475 raise ValueError(
476 "`clipnorm` cannot be set when `gradient_transformers` "
477 "is set. Instead, use the `gradient_transformers` to "
478 "specify clipping and other transformations. Received: "
479 f"val={val}, "
480 f"gradient_transformers={self.gradient_transformers}."
481 )
482 self._clipnorm = val
483 self._clipnorm_fn = optimizer_utils.make_gradient_clipnorm_fn(
484 self._clipnorm
485 )
487 @global_clipnorm.setter
488 def global_clipnorm(self, val):
489 if val is not None and self.gradient_transformers:
490 raise ValueError(
491 "`global_clipnorm` cannot be set when "
492 "`gradient_transformers` "
493 "is set. Instead, use the `gradient_transformers` to "
494 "specify clipping and other transformations. Received: "
495 f"val={val}, "
496 f"gradient_transformers={self.gradient_transformers}."
497 )
498 self._global_clipnorm = val
499 self._global_clipnorm_fn = (
500 optimizer_utils.make_global_gradient_clipnorm_fn(
501 self._global_clipnorm
502 )
503 )
505 @property
506 def clipvalue(self):
507 """`float` or `None`. If set, clips gradients to a maximum value."""
508 return self._clipvalue
510 @clipvalue.setter
511 def clipvalue(self, val):
512 if val is not None and self.gradient_transformers:
513 raise ValueError(
514 "`clipvalue` cannot be set when `gradient_transformers` "
515 "is set. Instead, use the `gradient_transformers` to "
516 "specify clipping and other transformations. Received: "
517 f"val={val}, "
518 f"gradient_transformers={self.gradient_transformers}."
519 )
520 self._clipvalue = val
521 self._clipvalue_fn = optimizer_utils.make_gradient_clipvalue_fn(
522 self._clipvalue
523 )
525 def _transform_loss(self, loss):
526 """Called in `.minimize` to transform loss before computing
527 gradients."""
528 return loss
530 def _get_gradients(self, tape, loss, var_list, grad_loss=None):
531 """Called in `minimize` to compute gradients from loss."""
532 grads = tape.gradient(loss, var_list, grad_loss)
533 return list(zip(grads, var_list))
535 def _transform_unaggregated_gradients(self, grads_and_vars):
536 """Called in `apply_gradients` before gradient aggregation."""
537 return grads_and_vars
539 def _aggregate_gradients(self, grads_and_vars):
540 """Called in `apply_gradients` to aggregate gradients across devices.
542 Note that user subclasses may override this, so the interface should not
543 be changed.
545 Args:
546 grads_and_vars: List of (gradient, variable) pairs.
548 Returns:
 549 A list of (aggregated_gradient, variable) pairs. By default, this
550 calls `self.gradient_aggregator`.
551 """
552 return self.gradient_aggregator(grads_and_vars)
554 def _transform_gradients(self, grads_and_vars):
555 """Called in `apply_gradients` after aggregation."""
556 if self._clipvalue is not None:
557 grads_and_vars = self._clipvalue_fn(grads_and_vars)
558 if self._clipnorm is not None:
559 grads_and_vars = self._clipnorm_fn(grads_and_vars)
560 if self._global_clipnorm is not None:
561 grads_and_vars = self._global_clipnorm_fn(grads_and_vars)
563 for fn in self.gradient_transformers:
564 grads_and_vars = fn(grads_and_vars)
565 return grads_and_vars
567 def minimize(self, loss, var_list, grad_loss=None, name=None, tape=None):
568 """Minimize `loss` by updating `var_list`.
 570 This method simply computes gradients using `tf.GradientTape` and calls
 571 `apply_gradients()`. If you want to process the gradients before applying
 572 them, use `tf.GradientTape` and `apply_gradients()` explicitly instead
 573 of using this function.
575 Args:
576 loss: `Tensor` or callable. If a callable, `loss` should take no
577 arguments and return the value to minimize. If a `Tensor`, the
578 `tape` argument must be passed.
579 var_list: list or tuple of `Variable` objects to update to minimize
580 `loss`, or a callable returning the list or tuple of `Variable`
 581 objects. Use a callable when the variable list would otherwise be
 582 incomplete before `minimize`, since the variables are created the
 583 first time `loss` is called.
584 grad_loss: (Optional). A `Tensor` holding the gradient computed for
585 `loss`.
586 name: (Optional) str. Name for the returned operation.
587 tape: (Optional) `tf.GradientTape`. If `loss` is provided as a
588 `Tensor`, the tape that computed the `loss` must be provided.
590 Returns:
591 An `Operation` that updates the variables in `var_list`. The
592 `iterations` will be automatically increased by 1.
594 Raises:
595 ValueError: If some of the variables are not `Variable` objects.
597 """
598 grads_and_vars = self._compute_gradients(
599 loss, var_list=var_list, grad_loss=grad_loss, tape=tape
600 )
601 return self.apply_gradients(grads_and_vars, name=name)
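    # A minimal sketch (not from the original source) of the `tape` path
    # documented above: when `loss` is passed as a `Tensor`, the tape that
    # computed it must be supplied.
    #
    #   var = tf.Variable(1.0)
    #   with tf.GradientTape() as tape:
    #       loss = 3.0 * var
    #   opt.minimize(loss, var_list=[var], tape=tape)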
603 def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
604 """Compute gradients of `loss` for the variables in `var_list`.
606 This is the first part of `minimize()`. It returns a list
607 of (gradient, variable) pairs where "gradient" is the gradient
608 for "variable". Note that "gradient" can be a `Tensor`, an
609 `IndexedSlices`, or `None` if there is no gradient for the
610 given variable.
612 Args:
613 loss: `Tensor` or callable. If a callable, `loss` should take no
614 arguments and return the value to minimize. If a `Tensor`, the
615 `tape` argument must be passed.
616 var_list: list or tuple of `Variable` objects to update to minimize
617 `loss`, or a callable returning the list or tuple of `Variable`
 618 objects. Use a callable when the variable list would otherwise be
 619 incomplete before `minimize`, since the variables are created the
 620 first time `loss` is called.
621 grad_loss: Optional. A `Tensor` holding the gradient computed for
622 `loss`.
623 tape: (Optional) `tf.GradientTape`. If `loss` is provided as a
624 `Tensor`, the tape that computed the `loss` must be provided.
626 Returns:
627 A list of (gradient, variable) pairs. Variable is always present, but
628 gradient can be `None`.
630 Raises:
 631 TypeError: If `var_list` contains anything other than `Variable`
 632 objects.
 633 ValueError: If some arguments are invalid, or `var_list` is None.
634 """
635 # TODO(joshl): Test that we handle weight decay in a reasonable way.
636 if not callable(loss) and tape is None:
637 raise ValueError(
638 "`tape` is required when a `Tensor` loss is passed. "
639 f"Received: loss={loss}, tape={tape}."
640 )
641 tape = tape if tape is not None else tf.GradientTape()
643 if callable(loss):
644 with tape:
645 if not callable(var_list):
646 tape.watch(var_list)
647 loss = loss()
648 if callable(var_list):
649 var_list = var_list()
651 with tape:
652 loss = self._transform_loss(loss)
654 var_list = tf.nest.flatten(var_list)
655 with tf.name_scope(self._name + "/gradients"):
656 grads_and_vars = self._get_gradients(
657 tape, loss, var_list, grad_loss
658 )
660 self._assert_valid_dtypes(
661 [
662 v
663 for g, v in grads_and_vars
664 if g is not None and v.dtype != tf.resource
665 ]
666 )
668 return grads_and_vars
670 def apply_gradients(
671 self, grads_and_vars, name=None, experimental_aggregate_gradients=True
672 ):
673 """Apply gradients to variables.
675 This is the second part of `minimize()`. It returns an `Operation` that
676 applies gradients.
678 The method sums gradients from all replicas in the presence of
679 `tf.distribute.Strategy` by default. You can aggregate gradients
680 yourself by passing `experimental_aggregate_gradients=False`.
682 Example:
684 ```python
685 grads = tape.gradient(loss, vars)
686 grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
687 # Processing aggregated gradients.
688 optimizer.apply_gradients(zip(grads, vars),
689 experimental_aggregate_gradients=False)
691 ```
693 Args:
694 grads_and_vars: List of (gradient, variable) pairs.
695 name: Optional name for the returned operation. When `None`, uses the
696 name passed to the `Optimizer` constructor. Defaults to `None`.
697 experimental_aggregate_gradients: Whether to sum gradients from
698 different replicas in the presence of `tf.distribute.Strategy`. If
 699 False, it is the user's responsibility to aggregate the gradients.
 700 Defaults to `True`.
702 Returns:
703 An `Operation` that applies the specified gradients. The `iterations`
704 will be automatically increased by 1.
706 Raises:
707 TypeError: If `grads_and_vars` is malformed.
708 ValueError: If none of the variables have gradients.
709 RuntimeError: If called in a cross-replica context.
710 """
711 grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
712 var_list = [v for (_, v) in grads_and_vars]
714 with tf.name_scope(self._name):
715 # Create iteration if necessary.
716 with tf.init_scope():
717 self._create_all_weights(var_list)
719 if not grads_and_vars:
720 # Distribution strategy does not support reducing an empty list
721 # of gradients
722 return tf.no_op()
724 if tf.distribute.in_cross_replica_context():
725 raise RuntimeError(
726 "`apply_gradients() cannot be called in cross-replica "
727 "context. Use `tf.distribute.Strategy.run` to enter "
728 "replica context. For more information, please see the "
729 "docstring of `tf.distribute.get_replica_context`."
730 )
732 strategy = tf.distribute.get_strategy()
733 if (
734 not experimental_aggregate_gradients
735 and strategy
736 and isinstance(
737 strategy,
738 (
739 tf.compat.v1.distribute.experimental.ParameterServerStrategy, # noqa: E501
740 tf.distribute.experimental.ParameterServerStrategy,
741 tf.distribute.experimental.CentralStorageStrategy,
742 tf.compat.v1.distribute.experimental.CentralStorageStrategy, # noqa: E501
743 ),
744 )
745 ):
746 raise NotImplementedError(
747 "`experimental_aggregate_gradients=False is not supported "
748 "for ParameterServerStrategy and CentralStorageStrategy. "
749 f"Used: strategy={strategy}."
750 )
752 apply_state = self._prepare(var_list)
753 if experimental_aggregate_gradients:
754 grads_and_vars = self._transform_unaggregated_gradients(
755 grads_and_vars
756 )
757 grads_and_vars = self._aggregate_gradients(grads_and_vars)
758 grads_and_vars = self._transform_gradients(grads_and_vars)
760 return tf.__internal__.distribute.interim.maybe_merge_call(
761 functools.partial(
762 self._distributed_apply, apply_state=apply_state
763 ),
764 strategy,
765 grads_and_vars,
766 name=name,
767 )
769 def _distributed_apply(
770 self, distribution, grads_and_vars, apply_state, name
771 ):
772 """`apply_gradients` using a `DistributionStrategy`."""
774 def apply_grad_to_update_var(var, grad):
775 """Apply gradient to variable."""
776 if isinstance(var, tf.Tensor):
777 raise NotImplementedError(
778 "Updating a `Tensor` is not implemented. "
779 f"Received: var={var}."
780 )
782 apply_kwargs = {}
783 if isinstance(grad, tf.IndexedSlices):
784 if var.constraint is not None:
785 raise RuntimeError(
786 "Cannot use a constraint function on a sparse "
787 f"variable. Received: grad={grad}, "
788 f"var.constraint={var.constraint}."
789 )
790 if "apply_state" in self._sparse_apply_args:
791 apply_kwargs["apply_state"] = apply_state
792 return self._resource_apply_sparse_duplicate_indices(
793 grad.values, var, grad.indices, **apply_kwargs
794 )
796 if "apply_state" in self._dense_apply_args:
797 apply_kwargs["apply_state"] = apply_state
798 update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
799 if var.constraint is not None:
800 with tf.control_dependencies([update_op]):
801 return var.assign(var.constraint(var))
802 else:
803 return update_op
805 eagerly_outside_functions = (
806 tf.compat.v1.executing_eagerly_outside_functions()
807 )
808 update_ops = []
809 with name_scope_only_in_function_or_graph(name or self._name):
810 for grad, var in grads_and_vars:
811 # Colocate the update with variables to avoid unnecessary
812 # communication delays. See b/136304694.
813 with distribution.extended.colocate_vars_with(var):
814 with name_scope_only_in_function_or_graph(
815 "update"
816 if eagerly_outside_functions
817 else "update_" + var.op.name
818 ):
819 update_op = distribution.extended.update(
820 var,
821 apply_grad_to_update_var,
822 args=(grad,),
823 group=False,
824 )
825 if tf.distribute.in_cross_replica_context():
826 # In cross-replica context, extended.update returns
827 # a list of update ops from all replicas
828 # (group=False).
829 update_ops.extend(update_op)
830 else:
 831 # In replica context, extended.update returns the
 832 # single update op of the current replica.
833 update_ops.append(update_op)
835 any_symbolic = any(
836 isinstance(i, tf.Operation) or tf_utils.is_symbolic_tensor(i)
837 for i in update_ops
838 )
839 if not tf.executing_eagerly() or any_symbolic:
840 # If the current context is graph mode or any of the update ops
841 # are symbolic then the step update should be carried out under
842 # a graph context. (eager updates execute immediately)
843 with backend._current_graph(update_ops).as_default():
844 with tf.control_dependencies([tf.group(update_ops)]):
845 return self.iterations.assign_add(1, read_value=False)
847 return self.iterations.assign_add(1)
849 def get_gradients(self, loss, params):
850 """Returns gradients of `loss` with respect to `params`.
852 Should be used only in legacy v1 graph mode.
854 Args:
855 loss: Loss tensor.
856 params: List of variables.
858 Returns:
859 List of gradient tensors.
861 Raises:
862 ValueError: In case any gradient cannot be computed (e.g. if gradient
863 function not implemented).
864 """
865 params = tf.nest.flatten(params)
866 with backend.get_graph().as_default(), backend.name_scope(
867 self._name + "/gradients"
868 ):
869 grads = tf.compat.v1.gradients(loss, params)
870 for grad, param in zip(grads, params):
871 if grad is None:
872 raise ValueError(
873 "Variable {} has `None` for gradient. "
874 "Please make sure that all of your ops have a "
875 "gradient defined (i.e. are differentiable). "
876 "Common ops without gradient: "
877 "K.argmax, K.round, K.eval.".format(param)
878 )
879 return grads
881 def get_updates(self, loss, params):
882 grads = self.get_gradients(loss, params)
883 grads_and_vars = list(zip(grads, params))
884 self._assert_valid_dtypes(
885 [
886 v
887 for g, v in grads_and_vars
888 if g is not None and v.dtype != tf.resource
889 ]
890 )
891 return [self.apply_gradients(grads_and_vars)]
893 def _set_hyper(self, name, value):
894 """set hyper `name` to value. value can be callable, tensor, numeric."""
895 if isinstance(value, tf.__internal__.tracking.Trackable):
896 self._track_trackable(value, name, overwrite=True)
897 if name not in self._hyper:
898 self._hyper[name] = value
899 else:
900 prev_value = self._hyper[name]
901 if (
902 callable(prev_value)
903 or isinstance(
904 prev_value,
905 (
906 tf.Tensor,
907 int,
908 float,
909 learning_rate_schedule.LearningRateSchedule,
910 ),
911 )
912 or isinstance(
913 value, learning_rate_schedule.LearningRateSchedule
914 )
915 ):
916 self._hyper[name] = value
917 else:
918 backend.set_value(self._hyper[name], value)
920 def _get_hyper(self, name, dtype=None):
921 if not self._hypers_created:
922 self._create_hypers()
923 value = self._hyper[name]
924 if isinstance(value, learning_rate_schedule.LearningRateSchedule):
925 return value
926 if callable(value):
927 value = value()
928 if dtype:
929 return tf.cast(value, dtype)
930 else:
931 return value
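    # Hedged usage sketch (not from the original source): subclasses typically
    # register hyperparameters in `__init__` and read them back, cast to the
    # variable dtype, when applying gradients, e.g.
    #
    #   self._set_hyper("momentum", momentum)           # in __init__
    #   m_t = self._get_hyper("momentum", var_dtype)    # in _resource_apply_dense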
933 def _create_slots(self, var_list):
934 pass
936 def _create_slots_for_sharded_variables(self, var_list):
937 """Add ShardedVariables to slots to later reconstruct for checkpointing.
939 ShardedVariables don't have slot variables created for them; their
940 shards do. This function allows users to call get_slot with a
941 ShardedVariable input and receive a ShardedVariable output containing
942 the appropriate slot vars.
944 Iterate over the variables to find shards, and aggregate the sharded
945 containers in a set. Add these ShardedVariables to _slots so that
946 get_slot can retrieve the proper slot variables for their component
947 shards, and reconstruct those into a ShardedVariable.
949 Args:
950 var_list: list or tuple of `Variable` objects that will be minimized
951 using this optimizer.
952 """
953 sharded_vars = set()
954 for var in var_list:
955 if getattr(var, "_sharded_container", False):
956 sharded_vars.add(var._sharded_container())
958 for sharded_var in sharded_vars:
959 sharded_key = _var_key(sharded_var)
960 slot_dict = {}
961 for slot in self.get_slot_names():
962 slot_dict[slot] = sharded_var
963 self._slots[sharded_key] = slot_dict
965 def _create_all_weights(self, var_list):
966 """Creates all weights, including iterations, hyperparameters and slot
967 vars.
969 This will add newly created variables to `optimizer.weights`.
971 New variables are only created when this method is called the first
972 time, or when called with different variables in the var_list.
974 Args:
975 var_list: list or tuple of `Variable` objects that will be minimized
976 using this optimizer.
977 """
979 _ = self.iterations
980 self._create_hypers()
981 self._create_slots(var_list)
982 self._create_slots_for_sharded_variables(var_list)
984 def __getattribute__(self, name):
985 """Overridden to support hyperparameter access."""
986 try:
987 return super().__getattribute__(name)
988 except AttributeError as e:
989 # Needed to avoid infinite recursion with __setattr__.
990 if name == "_hyper":
991 raise e
992 # Backwards compatibility with Keras optimizers.
993 if name == "lr":
994 name = "learning_rate"
995 if name in self._hyper:
996 return self._get_hyper(name)
997 raise e
999 def __dir__(self):
1000 result = set(super().__dir__())
1001 if "_hyper" in result:
1002 result |= self._hyper.keys()
1003 if "learning_rate" in self._hyper.keys():
1004 result.add("lr")
1005 return list(result)
1007 def __setattr__(self, name, value):
1008 """Override setattr to support dynamic hyperparameter setting."""
1009 # Backwards compatibility with Keras optimizers.
1010 if name == "lr":
1011 name = "learning_rate"
1012 if hasattr(self, "_hyper") and name in self._hyper:
1013 self._set_hyper(name, value)
1014 else:
1015 super().__setattr__(name, value)
1017 def get_slot_names(self):
1018 """A list of names for this optimizer's slots."""
1019 return self._slot_names
1021 def add_slot(self, var, slot_name, initializer="zeros", shape=None):
1022 """Add a new slot variable for `var`.
 1024 A slot variable is an additional variable associated with the `var`
 1025 being trained. It is allocated and managed by optimizers, e.g. `Adam`.
1027 Args:
1028 var: a `Variable` object.
1029 slot_name: name of the slot variable.
1030 initializer: initializer of the slot variable
1031 shape: (Optional) shape of the slot variable. If not set, it will
1032 default to the shape of `var`.
1034 Returns:
1035 A slot variable.
1036 """
1037 if slot_name not in self._slot_names:
1038 self._slot_names.append(slot_name)
1039 var_key = _var_key(var)
1040 slot_dict = self._slots.setdefault(var_key, {})
1041 weight = slot_dict.get(slot_name, None)
1042 if weight is None:
1043 if isinstance(initializer, str) or callable(initializer):
1044 initializer = initializers.get(initializer)
1045 if isinstance(
1046 initializer,
1047 tf.__internal__.tracking.CheckpointInitialValueCallable,
1048 ) or (shape is not None):
1049 slot_shape = shape
1050 else:
1051 slot_shape = var.shape
1052 initial_value = functools.partial(
1053 initializer, shape=slot_shape, dtype=var.dtype
1054 )
1055 else:
1056 initial_value = initializer
1058 with self._distribution_strategy_scope():
1059 strategy = tf.distribute.get_strategy()
1060 if not strategy.extended.variable_created_in_scope(var):
1061 raise ValueError(
1062 "Trying to create optimizer slot variable under the "
1063 "scope for tf.distribute.Strategy ({}), which is "
1064 "different from the scope used for the original "
1065 "variable ({}). Make sure the slot variables are "
1066 "created under the same strategy scope. This may "
1067 "happen if you're restoring from a checkpoint "
1068 "outside the scope.".format(strategy, var)
1069 )
1071 with strategy.extended.colocate_vars_with(var):
1072 weight = tf.Variable(
1073 name=f"{var._shared_name}/{slot_name}",
1074 dtype=var.dtype,
1075 trainable=False,
1076 initial_value=initial_value,
1077 )
1078 backend.track_variable(weight)
1079 slot_dict[slot_name] = weight
1080 self._restore_slot_variable(
1081 slot_name=slot_name, variable=var, slot_variable=weight
1082 )
1083 self._weights.append(weight)
1084 return weight
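    # A hedged sketch (not part of this file) of how a subclass typically calls
    # `add_slot` from its `_create_slots` override, e.g. a momentum-style
    # optimizer:
    #
    #   def _create_slots(self, var_list):
    #       for var in var_list:
    #           self.add_slot(var, "momentum")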
1086 def get_slot(self, var, slot_name):
1087 var_key = _var_key(var)
1088 slot_dict = self._slots[var_key]
1089 slot_variable = slot_dict[slot_name]
1090 if isinstance(
1091 slot_variable, tf.__internal__.distribute.ShardedVariable
1092 ):
1093 # Construct a ShardedVariable that points to the input
1094 # ShardedVariable's component shard's slot variables.
1095 shard_vars = []
1096 for shard in slot_variable.variables:
1097 slot_shard = self.get_slot(shard, slot_name)
1098 shard_vars.append(slot_shard)
1099 slot_variable = tf.__internal__.distribute.ShardedVariable(
1100 shard_vars, name=slot_variable.name
1101 )
1102 return slot_variable
1104 def _prepare(self, var_list):
1105 keys = set()
1106 for var in var_list:
1107 if isinstance(var, tf.distribute.DistributedValues):
1108 var_devices = var._devices
1109 else:
1110 var_devices = [var.device]
1111 var_dtype = var.dtype.base_dtype
1112 for var_device in var_devices:
1113 keys.add((var_device, var_dtype))
1115 apply_state = {}
1116 for var_device, var_dtype in keys:
1117 apply_state[(var_device, var_dtype)] = {}
1118 with tf.device(var_device):
1119 self._prepare_local(var_device, var_dtype, apply_state)
1121 return apply_state
1123 def _prepare_local(self, var_device, var_dtype, apply_state):
1124 if "learning_rate" in self._hyper:
1125 lr_t = tf.identity(self._decayed_lr(var_dtype))
1126 apply_state[(var_device, var_dtype)]["lr_t"] = lr_t
1128 def _fallback_apply_state(self, var_device, var_dtype):
1129 """Compatibility for subclasses that don't pass apply_state through."""
1130 apply_state = {(var_device, var_dtype): {}}
1131 self._prepare_local(var_device, var_dtype, apply_state)
1132 return apply_state[(var_device, var_dtype)]
1134 def _create_hypers(self):
1135 if self._hypers_created:
1136 return
1137 with self._distribution_strategy_scope():
1138 # Iterate hyper values deterministically.
1139 for name, value in sorted(self._hyper.items()):
1140 if isinstance(value, (tf.Tensor, tf.Variable)) or callable(
1141 value
1142 ):
1143 # The check for `callable` covers the usage when `value` is
1144 # a `LearningRateSchedule`, in which case it does not need
1145 # to create a variable.
1146 continue
1147 else:
1148 self._hyper[name] = self.add_weight(
1149 name,
1150 shape=[],
1151 trainable=False,
1152 initializer=value,
1153 aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
1154 )
1155 self._hypers_created = True
1157 @property
1158 def iterations(self):
1159 """Variable. The number of training steps this Optimizer has run."""
1160 if self._iterations is None:
1161 with self._distribution_strategy_scope():
1162 self._iterations = self.add_weight(
1163 "iter",
1164 shape=[],
1165 dtype=tf.int64,
1166 trainable=False,
1167 aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
1168 )
1169 self._weights.append(self._iterations)
1170 return self._iterations
1172 @iterations.setter
1173 def iterations(self, variable):
1174 if self._iterations is not None:
1175 raise RuntimeError(
1176 "Cannot set `iterations` to a new Variable after "
1177 "the Optimizer weights have been created. Here it is "
1178 f"attempting to set `iterations` to {variable}."
1179 )
1180 self._iterations = variable
1181 self._weights.append(self._iterations)
1183 def _decayed_lr(self, var_dtype):
1184 """Get decayed learning rate as a Tensor with dtype=var_dtype."""
1185 lr_t = self._get_hyper("learning_rate", var_dtype)
1186 if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
1187 local_step = tf.cast(self.iterations, var_dtype)
1188 lr_t = tf.cast(lr_t(local_step), var_dtype)
1189 if self._initial_decay > 0.0:
1190 local_step = tf.cast(self.iterations, var_dtype)
1191 decay_t = tf.cast(self._initial_decay, var_dtype)
1192 lr_t = lr_t / (1.0 + decay_t * local_step)
1193 return lr_t
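    # Worked example (illustrative only) of the inverse time decay applied
    # above: with learning_rate=0.1, decay=0.01 and iterations=100, the
    # effective rate is 0.1 / (1 + 0.01 * 100) = 0.05.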
1195 @abc.abstractmethod
1196 def get_config(self):
1197 """Returns the config of the optimizer.
1199 An optimizer config is a Python dictionary (serializable)
1200 containing the configuration of an optimizer.
1201 The same optimizer can be reinstantiated later
1202 (without any saved state) from this configuration.
1204 Returns:
1205 Python dictionary.
1206 """
1207 config = {"name": self._name}
1208 if self.clipnorm is not None:
1209 config["clipnorm"] = self.clipnorm
1210 if self.clipvalue is not None:
1211 config["clipvalue"] = self.clipvalue
1212 if self.global_clipnorm is not None:
1213 config["global_clipnorm"] = self.global_clipnorm
1214 return config
1216 @classmethod
1217 def from_config(cls, config, custom_objects=None):
1218 """Creates an optimizer from its config.
1220 This method is the reverse of `get_config`,
1221 capable of instantiating the same optimizer from the config
1222 dictionary.
1224 Args:
1225 config: A Python dictionary, typically the output of get_config.
1226 custom_objects: A Python dictionary mapping names to additional
1227 Python objects used to create this optimizer, such as a function
1228 used for a hyperparameter.
1230 Returns:
1231 An optimizer instance.
1232 """
1233 if "lr" in config:
1234 config["learning_rate"] = config.pop("lr")
1235 if "learning_rate" in config:
1236 if isinstance(config["learning_rate"], dict):
1237 config["learning_rate"] = learning_rate_schedule.deserialize(
1238 config["learning_rate"], custom_objects=custom_objects
1239 )
1240 return cls(**config)
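    # Illustrative round trip (not from the original source), assuming `opt` is
    # a legacy SGD instance:
    #
    #   config = opt.get_config()
    #   new_opt = tf.keras.optimizers.legacy.SGD.from_config(config)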
1242 def _serialize_hyperparameter(self, hyperparameter_name):
1243 """Serialize a hyperparameter that can be a float, callable, or
1244 Tensor."""
1245 value = self._hyper[hyperparameter_name]
1246 if isinstance(value, learning_rate_schedule.LearningRateSchedule):
1247 return learning_rate_schedule.serialize(value)
1248 if callable(value):
1249 return value()
1250 if tf.is_tensor(value):
1251 return backend.get_value(value)
1252 return value
1254 def variables(self):
1255 """Returns variables of this Optimizer based on the order created."""
1256 return self._weights
1258 @property
1259 def weights(self):
1260 """Returns variables of this Optimizer based on the order created."""
1261 return self._weights
1263 def get_weights(self):
1264 """Returns the current weights of the optimizer.
 1266 The weights of an optimizer are its state (i.e., variables).
1267 This function returns the weight values associated with this
1268 optimizer as a list of Numpy arrays. The first value is always the
1269 iterations count of the optimizer, followed by the optimizer's state
1270 variables in the order they were created. The returned list can in turn
1271 be used to load state into similarly parameterized optimizers.
1273 For example, the RMSprop optimizer for this simple model returns a list
 1274 of three values: the iteration count, followed by the root-mean-square
1275 value of the kernel and bias of the single Dense layer:
1277 >>> opt = tf.keras.optimizers.legacy.RMSprop()
1278 >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1279 >>> m.compile(opt, loss='mse')
1280 >>> data = np.arange(100).reshape(5, 20)
1281 >>> labels = np.zeros(5)
1282 >>> results = m.fit(data, labels) # Training.
1283 >>> len(opt.get_weights())
1284 3
1286 Returns:
1287 Weights values as a list of numpy arrays.
1288 """
1289 params = self.weights
1290 return backend.batch_get_value(params)
1292 # TODO(tanzheny): Maybe share this logic with base_layer.
1293 def set_weights(self, weights):
1294 """Set the weights of the optimizer.
 1296 The weights of an optimizer are its state (i.e., variables).
1297 This function takes the weight values associated with this
1298 optimizer as a list of Numpy arrays. The first value is always the
1299 iterations count of the optimizer, followed by the optimizer's state
1300 variables in the order they are created. The passed values are used to
1301 set the new state of the optimizer.
1303 For example, the RMSprop optimizer for this simple model takes a list of
 1304 three values: the iteration count, followed by the root-mean-square
1305 value of the kernel and bias of the single Dense layer:
1307 >>> opt = tf.keras.optimizers.legacy.RMSprop()
1308 >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1309 >>> m.compile(opt, loss='mse')
1310 >>> data = np.arange(100).reshape(5, 20)
1311 >>> labels = np.zeros(5)
1312 >>> results = m.fit(data, labels) # Training.
1313 >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])]
1314 >>> opt.set_weights(new_weights)
1315 >>> opt.iterations
1316 <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10>
1318 Args:
1319 weights: weight values as a list of numpy arrays.
1320 """
1321 params = self.weights
1322 if len(params) != len(weights):
1323 raise ValueError(
1324 f"You called `set_weights(weights)` on optimizer {self._name} "
1325 f"with a weight list of length {str(len(weights))}, "
1326 f"but the optimizer was expecting {str(len(params))} "
1327 f"weights. Provided weights: {str(weights)[:50]}..."
1328 )
1329 if not params:
1330 return
1331 weight_value_tuples = []
1332 param_values = backend.batch_get_value(params)
1333 for pv, p, w in zip(param_values, params, weights):
1334 if pv.shape != w.shape:
1335 raise ValueError(
1336 f"Optimizer weight shape {str(pv.shape)} "
1337 "not compatible with "
1338 f"provided weight shape {str(w.shape)}."
1339 )
1340 weight_value_tuples.append((p, w))
1341 backend.batch_set_value(weight_value_tuples)
1343 def add_weight(
1344 self,
1345 name,
1346 shape,
1347 dtype=None,
1348 initializer="zeros",
1349 trainable=None,
1350 synchronization=tf.VariableSynchronization.AUTO,
1351 aggregation=tf.VariableAggregation.NONE,
1352 ):
1354 if dtype is None:
1355 dtype = tf.float32
1356 if isinstance(initializer, str) or callable(initializer):
1357 initializer = initializers.get(initializer)
1359 if synchronization == tf.VariableSynchronization.ON_READ:
1360 if trainable:
1361 raise ValueError(
1362 "Synchronization value can be set to "
1363 "VariableSynchronization.ON_READ only for non-trainable "
1364 "variables. You have specified trainable=True and "
1365 "synchronization=VariableSynchronization.ON_READ."
1366 )
1367 else:
1368 # Set trainable to be false when variable is to be synced on
1369 # read.
1370 trainable = False
1371 elif trainable is None:
1372 trainable = True
1374 variable = self._add_variable_with_custom_getter(
1375 name=name,
1376 shape=shape,
1377 getter=base_layer_utils.make_variable,
1378 overwrite=True,
1379 initializer=initializer,
1380 dtype=dtype,
1381 trainable=trainable,
1382 use_resource=True,
1383 synchronization=synchronization,
1384 aggregation=aggregation,
1385 )
1386 backend.track_variable(variable)
1388 return variable
1390 def _init_set_name(self, name, zero_based=True):
1391 if not name:
1392 self._name = backend.unique_object_name(
1393 generic_utils.to_snake_case(self.__class__.__name__),
1394 zero_based=zero_based,
1395 )
1396 else:
1397 self._name = name
1399 def _assert_valid_dtypes(self, tensors):
1400 """Asserts tensors are all valid types (see `_valid_dtypes`).
1402 Args:
1403 tensors: Tensors to check.
1405 Raises:
1406 ValueError: If any tensor is not a valid type.
1407 """
1408 valid_dtypes = self._valid_dtypes()
1409 for t in tensors:
1410 dtype = t.dtype.base_dtype
1411 if dtype not in valid_dtypes:
1412 raise ValueError(
1413 "Invalid type {} for {}, expected: {}.".format(
1414 dtype, t.name, [v for v in valid_dtypes]
1415 )
1416 )
1418 def _valid_dtypes(self):
1419 """Valid types for loss, variables and gradients.
1421 Subclasses should override to allow other float types.
1423 Returns:
1424 Valid types for loss, variables and gradients.
1425 """
1426 return _DEFAULT_VALID_DTYPES
1428 def _call_if_callable(self, param):
1429 """Call the function if param is callable."""
1430 return param() if callable(param) else param
1432 def _resource_apply_dense(self, grad, handle, apply_state):
1433 """Add ops to apply dense gradients to the variable `handle`.
1435 Args:
1436 grad: a `Tensor` representing the gradient.
1437 handle: a `Tensor` of dtype `resource` which points to the variable to
1438 be updated.
1439 apply_state: A dict which is used across multiple apply calls.
1441 Returns:
1442 An `Operation` which updates the value of the variable.
1443 """
1444 raise NotImplementedError(
1445 "`_resource_apply_dense` must be implemented in subclasses."
1446 )
1448 def _resource_apply_sparse_duplicate_indices(
1449 self, grad, handle, indices, **kwargs
1450 ):
1451 """Add ops to apply sparse gradients to `handle`, with repeated indices.
1453 Optimizers which override this method must deal with repeated indices.
1454 See the docstring of `_apply_sparse_duplicate_indices` for details. By
1455 default the correct behavior, to sum non-unique indices and their
1456 associated gradients, is enforced by first pre-processing `grad` and
1457 `indices` and passing them on to `_resource_apply_sparse`. Optimizers
1458 which deal correctly with duplicate indices may instead override this
1459 method to avoid the overhead of summing.
1461 Args:
1462 grad: a `Tensor` representing the gradient for the affected indices.
1463 handle: a `Tensor` of dtype `resource` which points to the variable to
1464 be updated.
1465 indices: a `Tensor` of integral type representing the indices for
1466 which the gradient is nonzero. Indices may be repeated.
1467 **kwargs: May optionally contain `apply_state`
1469 Returns:
1470 An `Operation` which updates the value of the variable.
1471 """
1472 summed_grad, unique_indices = _deduplicate_indexed_slices(
1473 values=grad, indices=indices
1474 )
1475 return self._resource_apply_sparse(
1476 summed_grad, handle, unique_indices, **kwargs
1477 )
1479 def _resource_apply_sparse(self, grad, handle, indices, apply_state):
1480 """Add ops to apply sparse gradients to the variable `handle`.
1482 Similar to `_apply_sparse`, the `indices` argument to this method has
1483 been de-duplicated. Optimizers which deal correctly with non-unique
1484 indices may instead override `_resource_apply_sparse_duplicate_indices`
1485 to avoid this overhead.
1487 Args:
1488 grad: a `Tensor` representing the gradient for the affected indices.
1489 handle: a `Tensor` of dtype `resource` which points to the variable to
1490 be updated.
1491 indices: a `Tensor` of integral type representing the indices for
1492 which the gradient is nonzero. Indices are unique.
1493 apply_state: A dict which is used across multiple apply calls.
1495 Returns:
1496 An `Operation` which updates the value of the variable.
1497 """
1498 raise NotImplementedError(
1499 "`_resource_apply_sparse` Must be implemented in subclasses."
1500 )
1502 def _resource_scatter_add(self, x, i, v):
1503 with tf.control_dependencies(
1504 [
1505 tf.raw_ops.ResourceScatterAdd(
1506 resource=x.handle, indices=i, updates=v
1507 )
1508 ]
1509 ):
1510 return x.value()
1512 def _resource_scatter_update(self, x, i, v):
1513 with tf.control_dependencies(
1514 [
1515 tf.raw_ops.ResourceScatterUpdate(
1516 resource=x.handle, indices=i, updates=v
1517 )
1518 ]
1519 ):
1520 return x.value()
1522 @property
1523 @layer_utils.cached_per_instance
1524 def _dense_apply_args(self):
1525 return tf_inspect.getfullargspec(self._resource_apply_dense).args
1527 @property
1528 @layer_utils.cached_per_instance
1529 def _sparse_apply_args(self):
1530 return tf_inspect.getfullargspec(self._resource_apply_sparse).args
1532 # ---------------
1533 # For implementing the trackable interface
1534 # ---------------
1536 def _restore_slot_variable(self, slot_name, variable, slot_variable):
1537 """Restore a newly created slot variable's value."""
1538 variable_key = _var_key(variable)
1539 deferred_restorations = self._deferred_slot_restorations.get(
1540 slot_name, {}
1541 ).pop(variable_key, [])
1542 # Iterate over restores, highest restore UID first to minimize the
1543 # number of assignments.
1544 deferred_restorations.sort(
1545 key=lambda position: position.restore_uid, reverse=True
1546 )
1547 for checkpoint_position in deferred_restorations:
1548 checkpoint_position.restore(slot_variable)
1550 def _create_or_restore_slot_variable(
1551 self, slot_variable_position, slot_name, variable
1552 ):
1553 """Returns the slot variable that should have a value restored into it.
1555 It is up to the caller to restore the value into the slot variable if a
1556 valid slot variable is returned.
1558 Called when a variable which has an associated slot variable is created
1559 or restored. When executing eagerly, we create the slot variable with a
1560 restoring initializer.
1562 No new variables are created when graph building. Instead,
1563 _restore_slot_variable catches these after normal creation and adds
1564 restore ops to the graph. This method is nonetheless important when
1565 graph building for the case when a slot variable has already been
1566 created but `variable` has just been added to a dependency graph
1567 (causing us to realize that the slot variable needs to be restored).
1569 Args:
1570 slot_variable_position: A `trackable._CheckpointPosition` object
1571 indicating the slot variable `Trackable` object to be restored.
1572 slot_name: The name of this `Optimizer`'s slot to restore into.
1573 variable: The variable object this slot is being created for.
1575 Returns:
1576 A slot variable that should have a value restored into it, or None if
1577 a slot variable should not be restored at this time.
1578 """
1579 variable_key = _var_key(variable)
1580 slot_dict = self._slots.get(variable_key, {})
1581 slot_variable = slot_dict.get(slot_name, None)
1582 if (
1583 slot_variable is None
1584 and tf.executing_eagerly()
1585 and slot_variable_position.is_simple_variable()
1586 # Defer slot variable creation if there is an active variable
1587 # creator scope. Generally we'd like to eagerly create/restore slot
1588 # variables when possible, but this may mean that scopes intended to
1589 # catch `variable` also catch its eagerly created slot variable
1590 # unintentionally (specifically make_template would add a dependency
1591 # on a slot variable if not for this case). Deferring is mostly
1592 # harmless (aside from double initialization), and makes variable
1593 # creator scopes behave the same way they do when graph building.
1594 #
1595 # One notable case is with distribution strategy, which uses
1596 # variable creator scope but always desires the `variable` and the
1597 # slot to use the same scope, thus we can safely eagerly
1598 # create/restore slot variables.
1599 and (
1600 not tf.compat.v1.get_default_graph()._variable_creator_stack
1601 or self._distribution_strategy
1602 )
1603 ):
1604 initializer = (
1605 tf.__internal__.tracking.CheckpointInitialValueCallable(
1606 checkpoint_position=slot_variable_position
1607 )
1608 )
1609 slot_variable = self.add_slot(
1610 var=variable,
1611 initializer=initializer,
1612 slot_name=slot_name,
1613 shape=slot_variable_position.value_shape(),
1614 )
1615 # Slot variables are not owned by any one object (because we don't
1616 # want to save the slot variable if the optimizer is saved without
1617 # the non-slot variable, or if the non-slot variable is saved
1618 # without the optimizer; it's a dependency hypergraph with edges of
1619 # the form (optimizer, non-slot variable, variable)). So we don't
1620 # _track_ slot variables anywhere, and instead special-case this
1621 # dependency and otherwise pretend it's a normal graph.
1622 if slot_variable is not None:
1623 # For sharded variables, we need the logic in get_slot to combine
1624 # slot variables for its shards
1625 if (slot_variable is variable) and (
1626 isinstance(variable, tf.__internal__.distribute.ShardedVariable)
1627 ):
1628 return self.get_slot(variable, slot_name)
1629 # If we've either made this slot variable, or if we've pulled out an
1630 # existing slot variable, we should restore it.
1631 return slot_variable
1632 else:
1633 # We didn't make the slot variable. Defer restoring until it gets
1634 # created normally. We keep a list rather than the one with the
1635 # highest restore UID in case slot variables have their own
1636 # dependencies, in which case those could differ between restores.
1637 self._deferred_slot_restorations.setdefault(
1638 slot_name, {}
1639 ).setdefault(variable_key, []).append(slot_variable_position)
1640 return None
1642 @contextlib.contextmanager
1643 def _distribution_strategy_scope(self):
1644 """Returns the `tf.distribute.Strategy` this optimizer was created
1645 under."""
1646 if self._distribution_strategy and not tf.distribute.has_strategy():
1647 with self._distribution_strategy.scope():
1648 yield self._distribution_strategy.scope()
1649 else:
1650 yield
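
# A minimal sketch, not part of the original module, of the subclassing recipe
# described in the `OptimizerV2` docstring: override `_resource_apply_dense`,
# `_resource_apply_sparse` and `get_config`. The class name and its
# hyperparameter handling below are illustrative assumptions, not Keras API.
class _ExampleGradientDescent(OptimizerV2):
    def __init__(self, learning_rate=0.01, name="ExampleGradientDescent", **kwargs):
        super().__init__(name, **kwargs)
        # Mirrors the usual legacy-optimizer pattern of honoring the `lr` alias.
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))

    def _resource_apply_dense(self, grad, handle, apply_state=None):
        # Plain gradient descent step: var <- var - lr * grad.
        lr_t = self._decayed_lr(handle.dtype.base_dtype)
        return handle.assign_sub(lr_t * grad, use_locking=self._use_locking)

    def _resource_apply_sparse(self, grad, handle, indices, apply_state=None):
        # Sparse update via the scatter-add helper defined on the base class.
        lr_t = self._decayed_lr(handle.dtype.base_dtype)
        return self._resource_scatter_add(handle, indices, -lr_t * grad)

    def get_config(self):
        config = super().get_config()
        config["learning_rate"] = self._serialize_hyperparameter("learning_rate")
        return config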
1653def _var_key(var):
1654 """Key for representing a primary variable, for looking up slots.
1656 In graph mode the name is derived from the var shared name.
1657 In eager mode the name is derived from the var unique id.
1658 If distribution strategy exists, get the primary variable first.
1660 Args:
1661 var: the variable.
1663 Returns:
 1664 a unique key for the variable (shared name in graph mode, unique id in eager mode).
1665 """
1667 # Get the distributed variable if it exists.
1668 if hasattr(var, "_distributed_container"):
1669 var = var._distributed_container()
1670 elif (
1671 tf_utils.is_extension_type(var)
1672 and hasattr(var, "handle")
1673 and hasattr(var.handle, "_distributed_container")
1674 ):
1675 # For ResourceVariables, the _distributed_container attribute
1676 # is added to their handle tensors.
1677 var = var.handle._distributed_container()
1678 if getattr(var, "_in_graph_mode", False):
1679 return var._shared_name
1680 return var._unique_id
1683def _get_slot_key_from_var(var, slot_name):
1684 """Get the slot key for the variable: var_name/slot_name."""
1686 name = _var_key(var)
1687 return name + "/" + slot_name
1690class RestoredOptimizer(OptimizerV2):
1691 """A non-functional Optimizer implementation for checkpoint compatibility.
1693 Holds slot variables and hyperparameters when an optimizer is restored from
1694 a SavedModel. These variables may be referenced in functions along with ops
1695 created by the original optimizer, but currently we do not support using the
1696 optimizer object itself (e.g. through `apply_gradients`).
1697 """
1699 # TODO(allenl): Make the restored optimizer functional by tracing its apply
1700 # methods.
1702 def __init__(self):
1703 super().__init__("RestoredOptimizer")
1704 self._hypers_created = True
1706 def get_config(self):
1707 # TODO(allenl): Save and restore the Optimizer's config
1708 raise NotImplementedError(
1709 "Restoring functional Optimizers from SavedModels is not currently "
1710 "supported. Please file a feature request if this limitation "
1711 "bothers you."
1712 )
1715tf.__internal__.saved_model.load.register_revived_type(
1716 "optimizer",
1717 lambda obj: isinstance(obj, OptimizerV2),
1718 versions=[
1719 tf.__internal__.saved_model.load.VersionedTypeRegistration(
1720 object_factory=lambda proto: RestoredOptimizer(),
1721 version=2,
1722 min_producer_version=1,
1723 min_consumer_version=1,
1724 setter=RestoredOptimizer._set_hyper,
1725 )
1726 ],
1727)