Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py: 25%
496 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Version 2 of class Optimizer."""
16# pylint: disable=g-bad-name
18import abc
19import contextlib
20import functools
21import warnings
23from tensorflow.python.distribute import central_storage_strategy
24from tensorflow.python.distribute import distribute_lib
25from tensorflow.python.distribute import parameter_server_strategy
26from tensorflow.python.distribute import parameter_server_strategy_v2
27from tensorflow.python.distribute import values as ds_values
28from tensorflow.python.eager import backprop
29from tensorflow.python.eager import context
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import indexed_slices
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_util
34from tensorflow.python.keras import backend
35from tensorflow.python.keras import initializers
36from tensorflow.python.keras.engine import base_layer_utils
37from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
38from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
39from tensorflow.python.keras.utils import generic_utils
40from tensorflow.python.keras.utils import layer_utils
41from tensorflow.python.keras.utils import tf_inspect
42from tensorflow.python.keras.utils import tf_utils
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import control_flow_ops
45from tensorflow.python.ops import gen_resource_variable_ops
46from tensorflow.python.ops import gradients
47from tensorflow.python.ops import math_ops
48from tensorflow.python.ops import variables as tf_variables
49from tensorflow.python.saved_model import revived_types
50from tensorflow.python.trackable import base as trackable
51from tensorflow.python.util import nest
52from tensorflow.python.util.tf_export import keras_export
# Dtypes treated as valid by default: half, bfloat16, single and double
# precision floats, plus both complex types.
_DEFAULT_VALID_DTYPES = frozenset((
    dtypes.float16,
    dtypes.bfloat16,
    dtypes.float32,
    dtypes.float64,
    dtypes.complex64,
    dtypes.complex128,
))
def _deduplicate_indexed_slices(values, indices):
  """Collapses repeated `indices` by summing their `values` slices.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor` whose entries index the first
      dimension of `values` (as in an `IndexedSlices` object).

  Returns:
    A `(summed_values, unique_indices)` tuple. `unique_indices` holds each
    distinct entry of `indices` once, and `summed_values` holds, for each of
    them, the sum of the `values` slices that referred to it.
  """
  # `segment_ids` is the second output of `unique`; it is used below to route
  # every slice of `values` to the output row of its de-duplicated index.
  unique_indices, segment_ids = array_ops.unique(indices)
  num_unique = array_ops.shape(unique_indices)[0]
  summed_values = math_ops.unsorted_segment_sum(
      values, segment_ids, num_unique)
  return summed_values, unique_indices
class NullContextmanager(object):
  """Context manager that does nothing on entry and nothing on exit.

  Any positional or keyword constructor arguments are accepted and ignored.
  """

  def __init__(self, *args, **kwargs):
    del args, kwargs  # Intentionally ignored.

  def __enter__(self):
    return None

  def __exit__(self, type_arg, value_arg, traceback_arg):
    # A falsy return value lets an in-flight exception keep propagating.
    return False
def name_scope_only_in_function_or_graph(name):
  """Internal-only entry point for `name_scope*`.

  A `compat.v1` name scope is entered only while building a function or a
  graph; when running fully eagerly a do-nothing context manager is returned
  instead.

  Args:
    name: The name argument that is passed to the op function.

  Returns:
    `name_scope*` context manager.
  """
  if context.executing_eagerly():
    # Fully eager execution: hand back the no-op manager.
    return NullContextmanager()
  return ops.name_scope_v1(name)
111@keras_export("keras.optimizers.Optimizer", metaclass=abc.ABCMeta)
112class OptimizerV2(trackable.Trackable):
113 """Base class for Keras optimizers.
115 You should not use this class directly, but instead instantiate one of its
116 subclasses such as `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`, etc.
118 ### Usage
120 ```python
121 # Create an optimizer with the desired parameters.
122 opt = tf.keras.optimizers.SGD(learning_rate=0.1)
123 # `loss` is a callable that takes no argument and returns the value
124 # to minimize.
125 loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
126 # In graph mode, returns op that minimizes the loss by updating the listed
127 # variables.
128 opt_op = opt.minimize(loss, var_list=[var1, var2])
129 opt_op.run()
130 # In eager mode, simply call minimize to update the list of variables.
131 opt.minimize(loss, var_list=[var1, var2])
132 ```
134 ### Usage in custom training loops
136 In Keras models, sometimes variables are created when the model is first
137 called, instead of construction time. Examples include 1) sequential models
138 without input shape pre-defined, or 2) subclassed models. Pass var_list as
139 callable in these cases.
141 Example:
143 ```python
144 opt = tf.keras.optimizers.SGD(learning_rate=0.1)
145 model = tf.keras.Sequential()
146 model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
147 model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
148 loss_fn = lambda: tf.keras.losses.mse(model(input), output)
149 var_list_fn = lambda: model.trainable_weights
150 for input, output in data:
151 opt.minimize(loss_fn, var_list_fn)
152 ```
154 ### Processing gradients before applying them
156 Calling `minimize()` takes care of both computing the gradients and
157 applying them to the variables. If you want to process the gradients
158 before applying them you can instead use the optimizer in three steps:
160 1. Compute the gradients with `tf.GradientTape`.
161 2. Process the gradients as you wish.
162 3. Apply the processed gradients with `apply_gradients()`.
164 Example:
166 ```python
167 # Create an optimizer.
168 opt = tf.keras.optimizers.SGD(learning_rate=0.1)
170 # Compute the gradients for a list of variables.
171 with tf.GradientTape() as tape:
172 loss = <call_loss_function>
173 vars = <list_of_variables>
174 grads = tape.gradient(loss, vars)
176 # Process the gradients, for example cap them, etc.
177 # capped_grads = [MyCapper(g) for g in grads]
178 processed_grads = [process_gradient(g) for g in grads]
180 # Ask the optimizer to apply the processed gradients.
181 opt.apply_gradients(zip(processed_grads, vars))
182 ```
184 ### Use with `tf.distribute.Strategy`
186 This optimizer class is `tf.distribute.Strategy` aware, which means it
187 automatically sums gradients across all replicas. To average gradients,
188 you divide your loss by the global batch size, which is done
189 automatically if you use `tf.keras` built-in training or evaluation loops.
190 See the `reduction` argument of your loss which should be set to
191 `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
192 `tf.keras.losses.Reduction.SUM` for not.
194 To aggregate gradients yourself, call `apply_gradients` with
195 `experimental_aggregate_gradients` set to False. This is useful if you need to
196 process aggregated gradients.
198 If you are not using these and you want to average gradients, you should use
199 `tf.math.reduce_sum` to add up your per-example losses and then divide by the
200 global batch size. Note that when using `tf.distribute.Strategy`, the first
201 component of a tensor's shape is the *replica-local* batch size, which is off
202 by a factor equal to the number of replicas being used to compute a single
203 step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
204 resulting in gradients that can be many times too big.
206 ### Variable Constraints
208 All Keras optimizers respect variable constraints. If constraint function is
209 passed to any variable, the constraint will be applied to the variable after
210 the gradient has been applied to the variable.
211 Important: If gradient is sparse tensor, variable constraint is not supported.
213 ### Thread Compatibility
215 The entire optimizer is currently thread compatible, not thread-safe. The user
216 needs to perform synchronization if necessary.
218 ### Slots
220 Many optimizer subclasses, such as `Adam` and `Adagrad` allocate and manage
221 additional variables associated with the variables to train. These are called
222 <i>Slots</i>. Slots have names and you can ask the optimizer for the names of
223 the slots that it uses. Once you have a slot name you can ask the optimizer
224 for the variable it created to hold the slot value.
226 This can be useful if you want to log or debug a training algorithm, report stats
227 about the slots, etc.
229 ### Hyperparameters
231 These are arguments passed to the optimizer subclass constructor
232 (the `__init__` method), and then passed to `self._set_hyper()`.
233 They can be either regular Python values (like 1.0), tensors, or
234 callables. If they are callable, the callable will be called during
235 `apply_gradients()` to get the value for the hyper parameter.
237 Hyperparameters can be overwritten through user code:
239 Example:
241 ```python
242 # Create an optimizer with the desired parameters.
243 opt = tf.keras.optimizers.SGD(learning_rate=0.1)
244 # `loss` is a callable that takes no argument and returns the value
245 # to minimize.
246 loss = lambda: 3 * var1 + 2 * var2
247 # In eager mode, simply call minimize to update the list of variables.
248 opt.minimize(loss, var_list=[var1, var2])
249 # update learning rate
250 opt.learning_rate = 0.05
251 opt.minimize(loss, var_list=[var1, var2])
252 ```
254 ### Callable learning rate
256 Optimizer accepts a callable learning rate in two ways. The first way is
257 through built-in or customized
258 `tf.keras.optimizers.schedules.LearningRateSchedule`. The schedule will be
259 called on each iteration with `schedule(iteration)`, a `tf.Variable`
260 owned by the optimizer.
262 Example:
264 >>> var = tf.Variable(np.random.random(size=(1,)))
265 >>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
266 ... initial_learning_rate=.01, decay_steps=20, decay_rate=.1)
267 >>> opt = tf.keras.optimizers.SGD(learning_rate=learning_rate)
268 >>> loss = lambda: 3 * var
269 >>> opt.minimize(loss, var_list=[var])
270 <tf.Variable...
272 The second way is through a callable function that
273 does not accept any arguments.
275 Example:
277 >>> var = tf.Variable(np.random.random(size=(1,)))
278 >>> def lr_callable():
279 ... return .1
280 >>> opt = tf.keras.optimizers.SGD(learning_rate=lr_callable)
281 >>> loss = lambda: 3 * var
282 >>> opt.minimize(loss, var_list=[var])
283 <tf.Variable...
285 ### Creating a custom optimizer
287 If you intend to create your own optimization algorithm, simply inherit from
288 this class and override the following methods:
290 - `_resource_apply_dense` (update variable given gradient tensor is a dense
291 `tf.Tensor`)
292 - `_resource_apply_sparse` (update variable given gradient tensor is a
293 sparse `tf.IndexedSlices`. The most common way for this to happen
294 is if you are taking the gradient through a `tf.gather`.)
295 - `_create_slots`
296 (if your optimizer algorithm requires additional variables)
297 - `get_config`
298 (serialization of the optimizer, include all hyper parameters)
299 """
301 # Subclasses should set this to True unless they override `apply_gradients`
302 # with a version that does not have the `experimental_aggregate_gradients`
303 # argument. Older versions of Keras did not have this argument so custom
304 # optimizers may have overridden `apply_gradients` without the
305 # `experimental_aggregate_gradients` argument. Keras only passes
306 # `experimental_aggregate_gradients` if this attribute is True.
307 # Note: This attribute will likely be removed in an upcoming release.
308 _HAS_AGGREGATE_GRAD = False
def __init__(self,
             name,
             gradient_aggregator=None,
             gradient_transformers=None,
             **kwargs):
  """Create a new Optimizer.

  Subclass constructors must call this. Optimizer instances should not bind
  to a single graph, so they should not keep Tensors as member variables;
  generally the `_set_hyper()`/`state.get_hyper()` facility can be used
  instead.

  This class is stateful and thread-compatible.

  Example of custom gradient transformations:

  ```python
  def my_gradient_transformer(grads_and_vars):
    # Simple example, double the gradients.
    return [(2. * g, v) for g, v in grads_and_vars]

  optimizer = tf.keras.optimizers.SGD(
      1e-3, gradient_transformers=[my_gradient_transformer])
  ```

  Args:
    name: String. The name to use for momentum accumulator weights created
      by the optimizer.
    gradient_aggregator: The function to use to aggregate gradients across
      devices (when using `tf.distribute.Strategy`). If `None`, defaults to
      summing the gradients across devices. The function should accept and
      return a list of `(gradient, variable)` tuples.
    gradient_transformers: Optional. List of functions to use to transform
      gradients before applying updates to Variables. The functions are
      applied after `gradient_aggregator`. The functions should accept and
      return a list of `(gradient, variable)` tuples.
    **kwargs: keyword arguments. Accepted keys are `clipvalue`, `clipnorm`,
      `global_clipnorm`, `decay` and the deprecated `lr`.
      If `clipvalue` (float) is set, the gradient of each weight
      is clipped to be no higher than this value.
      If `clipnorm` (float) is set, the gradient of each weight
      is individually clipped so that its norm is no higher than this value.
      If `global_clipnorm` (float) is set the gradient of all weights is
      clipped so that their global norm is no higher than this value.

  Raises:
    TypeError: if an unexpected keyword argument is passed.
    ValueError: if a keyword argument is negative, or if both `clipnorm` and
      `global_clipnorm` are given.
  """
  allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay", "global_clipnorm"}
  for key, val in kwargs.items():
    if key not in allowed_kwargs:
      raise TypeError("Unexpected keyword argument "
                      "passed to optimizer: " + str(key))
    # Every supplied keyword value must be non-negative; `None` is skipped.
    if val is not None and val < 0:
      raise ValueError("Expected {} >= 0, received: {}".format(key, val))
    if key == "lr":
      warnings.warn(
          "The `lr` argument is deprecated, use `learning_rate` instead.")

  self._use_locking = True
  self._init_set_name(name)
  self._hyper = {}
  # dict: {variable name : {slot name : variable}}
  self._slots = {}
  self._slot_names = []
  self._weights = []
  self._iterations = None

  # For implementing Trackable. Holds the information needed to restore slot
  # variables that have not been created yet
  # (trackable._CheckpointPosition objects).
  #  {slot_name :
  #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
  #   ... }
  self._deferred_slot_restorations = {}

  decay = kwargs.pop("decay", 0.0)
  if decay < 0.:
    raise ValueError("decay cannot be less than 0: {}".format(decay))
  self._initial_decay = decay

  self._hypers_created = False
  # Remember the distribution strategy when the optimizer is built inside a
  # strategy scope, so that it can be used to create variables later.
  self._distribution_strategy = (
      distribute_lib.get_strategy() if distribute_lib.has_strategy() else None)

  # Configure gradient transformations. `gradient_transformers` has to be
  # assigned before the clipping attributes because their setters read it.
  self.gradient_aggregator = (
      optimizer_utils.all_reduce_sum_gradients
      if gradient_aggregator is None else gradient_aggregator)
  self.gradient_transformers = (
      [] if gradient_transformers is None else gradient_transformers)
  self.clipnorm = kwargs.pop("clipnorm", None)
  self.global_clipnorm = kwargs.pop("global_clipnorm", None)
  both_norms_given = (
      self.clipnorm is not None and self.global_clipnorm is not None)
  if both_norms_given:
    raise ValueError("Cannot accept both `clipnorm` and `global_clipnorm`, "
                     "passed `clipnorm` {}, `global_clipnorm` {}".format(
                         self.clipnorm, self.global_clipnorm))
  self.clipvalue = kwargs.pop("clipvalue", None)
@property
def clipnorm(self):
  """`float` or `None`. If set, clips each gradient to a maximum norm."""
  return self._clipnorm

@property
def global_clipnorm(self):
  """`float` or `None`. If set, clips gradients to a maximum global norm."""
  return self._global_clipnorm

@clipnorm.setter
def clipnorm(self, val):
  """Sets `clipnorm` and rebuilds the matching clipping function.

  Args:
    val: `float` or `None`.

  Raises:
    ValueError: if `val` is not `None` while `gradient_transformers` is
      non-empty.
  """
  if val is not None and self.gradient_transformers:
    raise ValueError("`clipnorm` cannot be set when `gradient_transformers` "
                     "is set. Instead, use the `gradient_transformers` to "
                     "specify clipping and other transformations.")
  self._clipnorm = val
  self._clipnorm_fn = optimizer_utils.make_gradient_clipnorm_fn(
      self._clipnorm)

@global_clipnorm.setter
def global_clipnorm(self, val):
  """Sets `global_clipnorm` and rebuilds the matching clipping function.

  Args:
    val: `float` or `None`.

  Raises:
    ValueError: if `val` is not `None` while `gradient_transformers` is
      non-empty.
  """
  if val is not None and self.gradient_transformers:
    # The message names `global_clipnorm` (it previously said `clipnorm`) so
    # that the error points at the attribute that was actually assigned.
    raise ValueError("`global_clipnorm` cannot be set when "
                     "`gradient_transformers` is set. Instead, use the "
                     "`gradient_transformers` to specify clipping and other "
                     "transformations.")
  self._global_clipnorm = val
  self._global_clipnorm_fn = optimizer_utils.make_global_gradient_clipnorm_fn(
      self._global_clipnorm)
@property
def clipvalue(self):
  """`float` or `None`. If set, clips gradients to a maximum value."""
  return self._clipvalue

@clipvalue.setter
def clipvalue(self, val):
  """Sets `clipvalue` and rebuilds the matching clipping function.

  Raises:
    ValueError: if `val` is not `None` while `gradient_transformers` is
      non-empty.
  """
  if val is not None:
    if self.gradient_transformers:
      raise ValueError(
          "`clipvalue` cannot be set when `gradient_transformers` is set. "
          "Instead, use the `gradient_transformers` to specify clipping and "
          "other transformations.")
  self._clipvalue = val
  self._clipvalue_fn = optimizer_utils.make_gradient_clipvalue_fn(
      self._clipvalue)
461 def _transform_loss(self, loss):
462 """Called in `.minimize` to transform loss before computing gradients."""
463 return loss
465 def _get_gradients(self, tape, loss, var_list, grad_loss=None):
466 """Called in `minimize` to compute gradients from loss."""
467 grads = tape.gradient(loss, var_list, grad_loss)
468 return list(zip(grads, var_list))
470 def _transform_unaggregated_gradients(self, grads_and_vars):
471 """Called in `apply_gradients` before gradient aggregation."""
472 return grads_and_vars
474 def _aggregate_gradients(self, grads_and_vars):
475 """Called in `apply_gradients` to aggregate gradients across devices.
477 Note that user subclasses may override this, so the interface should not be
478 changed.
480 Args:
481 grads_and_vars: List of (gradient, variable) pairs.
483 Returns:
484 A list of (aggregrated_gradient, variable) pairs. By default, this calls
485 `self.gradient_aggregator`.
486 """
487 return self.gradient_aggregator(grads_and_vars)
489 def _transform_gradients(self, grads_and_vars):
490 """Called in `apply_gradients` after aggregation."""
491 if self._clipvalue is not None:
492 grads_and_vars = self._clipvalue_fn(grads_and_vars)
493 if self._clipnorm is not None:
494 grads_and_vars = self._clipnorm_fn(grads_and_vars)
495 if self._global_clipnorm is not None:
496 grads_and_vars = self._global_clipnorm_fn(grads_and_vars)
498 for fn in self.gradient_transformers:
499 grads_and_vars = fn(grads_and_vars)
500 return grads_and_vars
def minimize(self, loss, var_list, grad_loss=None, name=None, tape=None):
  """Minimize `loss` by updating `var_list`.

  The gradients are obtained from `_compute_gradients()` and then passed to
  `apply_gradients()`. To process the gradients before they are applied, use
  `tf.GradientTape` and `apply_gradients()` explicitly instead of this
  method.

  Args:
    loss: `Tensor` or callable. If a callable, `loss` should take no arguments
      and return the value to minimize. If a `Tensor`, the `tape` argument
      must be passed.
    var_list: list or tuple of `Variable` objects to update to minimize
      `loss`, or a callable returning the list or tuple of `Variable` objects.
      Use callable when the variable list would otherwise be incomplete before
      `minimize` since the variables are created at the first time `loss` is
      called.
    grad_loss: (Optional). A `Tensor` holding the gradient computed for
      `loss`.
    name: (Optional) str. Name for the returned operation.
    tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
      the tape that computed the `loss` must be provided.

  Returns:
    An `Operation` that updates the variables in `var_list`. The `iterations`
    will be automatically increased by 1.

  Raises:
    ValueError: If some of the variables are not `Variable` objects.
  """
  return self.apply_gradients(
      self._compute_gradients(
          loss, var_list=var_list, grad_loss=grad_loss, tape=tape),
      name=name)
537 def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
538 """Compute gradients of `loss` for the variables in `var_list`.
540 This is the first part of `minimize()`. It returns a list
541 of (gradient, variable) pairs where "gradient" is the gradient
542 for "variable". Note that "gradient" can be a `Tensor`, an
543 `IndexedSlices`, or `None` if there is no gradient for the
544 given variable.
546 Args:
547 loss: `Tensor` or callable. If a callable, `loss` should take no
548 arguments and return the value to minimize. If a `Tensor`, the `tape`
549 argument must be passed.
550 var_list: list or tuple of `Variable` objects to update to minimize
551 `loss`, or a callable returning the list or tuple of `Variable` objects.
552 Use callable when the variable list would otherwise be incomplete before
553 `minimize` and the variables are created at the first time when `loss`
554 is called.
555 grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
556 tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
557 the tape that computed the `loss` must be provided.
559 Returns:
560 A list of (gradient, variable) pairs. Variable is always present, but
561 gradient can be `None`.
563 Raises:
564 TypeError: If `var_list` contains anything else than `Variable` objects.
565 ValueError: If some arguments are invalid, or var_list is None.
566 """
567 # TODO(josh11b): Test that we handle weight decay in a reasonable way.
568 if not callable(loss) and tape is None:
569 raise ValueError("`tape` is required when a `Tensor` loss is passed.")
570 tape = tape if tape is not None else backprop.GradientTape()
572 if callable(loss):
573 with tape:
574 if not callable(var_list):
575 tape.watch(var_list)
576 loss = loss()
577 if callable(var_list):
578 var_list = var_list()
580 with tape:
581 loss = self._transform_loss(loss)
583 var_list = nest.flatten(var_list)
584 with ops.name_scope_v2(self._name + "/gradients"):
585 grads_and_vars = self._get_gradients(tape, loss, var_list, grad_loss)
587 self._assert_valid_dtypes([
588 v for g, v in grads_and_vars
589 if g is not None and v.dtype != dtypes.resource
590 ])
592 return grads_and_vars
def apply_gradients(self,
                    grads_and_vars,
                    name=None,
                    experimental_aggregate_gradients=True):
  """Apply gradients to variables.

  This is the second part of `minimize()`. It returns an `Operation` that
  applies gradients.

  By default the method sums gradients from all replicas in the presence of
  `tf.distribute.Strategy`. Pass `experimental_aggregate_gradients=False` to
  aggregate the gradients yourself.

  Example:

  ```python
  grads = tape.gradient(loss, vars)
  grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
  # Processing aggregated gradients.
  optimizer.apply_gradients(zip(grads, vars),
      experimental_aggregate_gradients=False)

  ```

  Args:
    grads_and_vars: List of (gradient, variable) pairs.
    name: Optional name for the returned operation. Default to the name passed
      to the `Optimizer` constructor.
    experimental_aggregate_gradients: Whether to sum gradients from different
      replicas in the presence of `tf.distribute.Strategy`. If False, it is
      the user's responsibility to aggregate the gradients. Default to True.

  Returns:
    An `Operation` that applies the specified gradients. The `iterations`
    will be automatically increased by 1.

  Raises:
    TypeError: If `grads_and_vars` is malformed.
    ValueError: If none of the variables have gradients.
    RuntimeError: If called in a cross-replica context.
    NotImplementedError: If `experimental_aggregate_gradients` is False under
      a parameter-server or central-storage strategy.
  """
  grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
  var_list = [var for _, var in grads_and_vars]

  with ops.name_scope_v2(self._name):
    # Create iterations, hyperparameters and slots if necessary.
    with ops.init_scope():
      self._create_all_weights(var_list)

    if not grads_and_vars:
      # Distribution strategy does not support reducing an empty list of
      # gradients
      return control_flow_ops.no_op()

    if distribute_lib.in_cross_replica_context():
      raise RuntimeError(
          "`apply_gradients() cannot be called in cross-replica context. "
          "Use `tf.distribute.Strategy.run` to enter replica "
          "context.")

    strategy = distribute_lib.get_strategy()
    if not experimental_aggregate_gradients and strategy:
      strategies_requiring_aggregation = (
          parameter_server_strategy.ParameterServerStrategyV1,
          parameter_server_strategy_v2.ParameterServerStrategyV2,
          central_storage_strategy.CentralStorageStrategy,
          central_storage_strategy.CentralStorageStrategyV1,
      )
      if isinstance(strategy, strategies_requiring_aggregation):
        raise NotImplementedError(
            "`experimental_aggregate_gradients=False is not supported for "
            "ParameterServerStrategy and CentralStorageStrategy")

    apply_state = self._prepare(var_list)
    if experimental_aggregate_gradients:
      grads_and_vars = self._transform_unaggregated_gradients(grads_and_vars)
      grads_and_vars = self._aggregate_gradients(grads_and_vars)
    grads_and_vars = self._transform_gradients(grads_and_vars)

    if optimizer_utils.strategy_supports_no_merge_call():
      return self._distributed_apply(strategy, grads_and_vars, name,
                                     apply_state)

    # Otherwise route the update through `merge_call` on the replica context.
    replica_context = distribute_lib.get_replica_context()
    apply_fn = functools.partial(
        self._distributed_apply, apply_state=apply_state)
    return replica_context.merge_call(
        apply_fn, args=(grads_and_vars,), kwargs={"name": name})
def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
  """`apply_gradients` using a `DistributionStrategy`.

  Every (gradient, variable) pair is updated through
  `distribution.extended.update`, after which `self._iterations` is
  incremented by one.

  Args:
    distribution: The strategy object whose `extended` API performs the
      updates.
    grads_and_vars: Iterable of (gradient, variable) pairs.
    name: Optional scope name; `self._name` is used when it is falsy.
    apply_state: Forwarded as the `apply_state` keyword to the dense/sparse
      apply methods when they list that argument.

  Returns:
    The result of incrementing `self._iterations`.
  """

  def apply_grad_to_update_var(var, grad):
    """Apply gradient to variable."""
    if isinstance(var, ops.Tensor):
      raise NotImplementedError("Trying to update a Tensor ", var)

    if isinstance(grad, indexed_slices.IndexedSlices):
      if var.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      sparse_kwargs = {}
      if "apply_state" in self._sparse_apply_args:
        sparse_kwargs["apply_state"] = apply_state
      return self._resource_apply_sparse_duplicate_indices(
          grad.values, var, grad.indices, **sparse_kwargs)

    dense_kwargs = {}
    if "apply_state" in self._dense_apply_args:
      dense_kwargs["apply_state"] = apply_state
    update_op = self._resource_apply_dense(grad, var, **dense_kwargs)
    if var.constraint is None:
      return update_op
    # The constraint is applied only after the dense update has run.
    with ops.control_dependencies([update_op]):
      return var.assign(var.constraint(var))

  eagerly_outside_functions = ops.executing_eagerly_outside_functions()
  update_ops = []
  with name_scope_only_in_function_or_graph(name or self._name):
    for grad, var in grads_and_vars:
      # Colocate the update with variables to avoid unnecessary communication
      # delays. See b/136304694.
      with distribution.extended.colocate_vars_with(var):
        if eagerly_outside_functions:
          update_scope = "update"
        else:
          update_scope = "update_" + var.op.name
        with name_scope_only_in_function_or_graph(update_scope):
          update_op = distribution.extended.update(
              var, apply_grad_to_update_var, args=(grad,), group=False)
          if distribute_lib.in_cross_replica_context():
            # In cross-replica context, extended.update returns a list of
            # update ops from all replicas (group=False).
            update_ops.extend(update_op)
          else:
            # In replica context, extended.update return the single update op
            # of current replica.
            update_ops.append(update_op)

    any_symbolic = any(
        isinstance(op, ops.Operation) or tf_utils.is_symbolic_tensor(op)
        for op in update_ops)
    if context.executing_eagerly() and not any_symbolic:
      # Eager updates have already executed; just bump the counter.
      return self._iterations.assign_add(1)

    # In graph mode, or when any update op is symbolic, the step update has
    # to be carried out under a graph context, after all the update ops.
    with backend._current_graph(update_ops).as_default():  # pylint: disable=protected-access
      with ops.control_dependencies([control_flow_ops.group(update_ops)]):
        return self._iterations.assign_add(1, read_value=False)
def get_gradients(self, loss, params):
  """Returns gradients of `loss` with respect to `params`.

  Should be used only in legacy v1 graph mode.

  Args:
    loss: Loss tensor.
    params: List of variables; it is flattened before use.

  Returns:
    List of gradient tensors.

  Raises:
    ValueError: In case any gradient cannot be computed (e.g. if gradient
      function not implemented), i.e. comes back as `None`.
  """
  params = nest.flatten(params)
  scope_name = self._name + "/gradients"
  with backend.get_graph().as_default(), backend.name_scope(scope_name):
    grads = gradients.gradients(loss, params)
    for grad, param in zip(grads, params):
      if grad is not None:
        continue
      raise ValueError("Variable {} has `None` for gradient. "
                       "Please make sure that all of your ops have a "
                       "gradient defined (i.e. are differentiable). "
                       "Common ops without gradient: "
                       "K.argmax, K.round, K.eval.".format(param))
  return grads
def get_updates(self, loss, params):
  """Builds the update for `params` from the gradients of `loss`.

  Gradients come from `get_gradients()`. Variables that received a gradient
  and are not of resource dtype are passed to `_assert_valid_dtypes()`.

  Args:
    loss: Loss tensor.
    params: List of variables.

  Returns:
    A one-element list holding the result of `apply_gradients()`.
  """
  grads_and_vars = list(zip(self.get_gradients(loss, params), params))
  vars_with_grads = [
      var for grad, var in grads_and_vars
      if grad is not None and var.dtype != dtypes.resource
  ]
  self._assert_valid_dtypes(vars_with_grads)
  return [self.apply_gradients(grads_and_vars)]
def _set_hyper(self, name, value):
  """Sets hyperparameter `name` to `value`.

  `value` can be callable, tensor, numeric. A `Trackable` value is also
  registered through `_track_trackable`.

  An existing entry is replaced when the stored value is callable, a
  `Tensor`, an `int`, a `float` or a `LearningRateSchedule`, or when the new
  value is a `LearningRateSchedule`. In every other case the new value is
  written into the stored object with `backend.set_value`.
  """
  if isinstance(value, trackable.Trackable):
    self._track_trackable(value, name, overwrite=True)

  if name not in self._hyper:
    self._hyper[name] = value
    return

  prev_value = self._hyper[name]
  rebindable_types = (ops.Tensor, int, float,
                      learning_rate_schedule.LearningRateSchedule)
  should_rebind = (
      callable(prev_value)
      or isinstance(prev_value, rebindable_types)
      or isinstance(value, learning_rate_schedule.LearningRateSchedule))
  if should_rebind:
    self._hyper[name] = value
  else:
    backend.set_value(prev_value, value)
def _get_hyper(self, name, dtype=None):
  """Returns the current value of hyperparameter `name`.

  Hyperparameters are created first if that has not happened yet. A
  `LearningRateSchedule` is returned as is (neither called nor cast). Any
  other callable is called to obtain the value, and the value is cast to
  `dtype` when `dtype` is truthy.
  """
  if not self._hypers_created:
    self._create_hypers()

  value = self._hyper[name]
  if isinstance(value, learning_rate_schedule.LearningRateSchedule):
    return value

  if callable(value):
    value = value()
  return math_ops.cast(value, dtype) if dtype else value
810 def _create_slots(self, var_list):
811 pass
813 def _create_all_weights(self, var_list):
814 """Creates all weights, including iterations, hyperparameters and slot vars.
816 This will add newly created variables to `optimizer.weights`.
818 New variables are only created when this method is called the first time, or
819 when called with different variables in the var_list.
821 Args:
822 var_list: list or tuple of `Variable` objects that will be minimized
823 using this optimizer.
824 """
826 _ = self.iterations
827 self._create_hypers()
828 self._create_slots(var_list)
def __getattribute__(self, name):
  """Attribute lookup that falls back to hyperparameters.

  Normal attribute resolution is tried first. If it fails, `name` (with the
  legacy alias `lr` mapped to `learning_rate`) is looked up in `self._hyper`.
  """
  try:
    return super(OptimizerV2, self).__getattribute__(name)
  except AttributeError as error:
    # A missing `_hyper` must propagate: consulting `self._hyper` below would
    # otherwise recurse forever (see `__setattr__`).
    if name == "_hyper":
      raise error
    # Backwards compatibility with Keras optimizers.
    hyper_name = "learning_rate" if name == "lr" else name
    if hyper_name in self._hyper:
      return self._get_hyper(hyper_name)
    raise error
def __dir__(self):
  """List attribute names, including hyperparameter names and `lr`."""
  names = set(super(OptimizerV2, self).__dir__())
  if "_hyper" in names:
    names.update(self._hyper.keys())
    # `lr` is exposed as an alias whenever `learning_rate` is a hyperparameter.
    if "learning_rate" in self._hyper.keys():
      names.add("lr")
  return list(names)
def __setattr__(self, name, value):
  """Attribute assignment that routes known hyperparameters to `_set_hyper`."""
  # Backwards compatibility with Keras optimizers: `lr` means `learning_rate`.
  target = "learning_rate" if name == "lr" else name
  if hasattr(self, "_hyper") and target in self._hyper:
    self._set_hyper(target, value)
    return
  super(OptimizerV2, self).__setattr__(target, value)
def get_slot_names(self):
  """Return the list of slot names registered on this optimizer."""
  slot_names = self._slot_names
  return slot_names
def add_slot(self, var, slot_name, initializer="zeros", shape=None):
  """Add a new slot variable for `var`.

  A slot variable is an additional variable associated with `var` to train.
  It is allocated and managed by optimizers, e.g. `Adam`. If the slot already
  exists for `var`, the existing variable is returned and nothing is created.

  Args:
    var: a `Variable` object.
    slot_name: name of the slot variable.
    initializer: initializer of the slot variable.
    shape: (Optional) shape of the slot variable. If not set, it will default
      to the shape of `var`.

  Returns:
    A slot variable.

  Raises:
    ValueError: if the current distribution strategy reports that `var` was
      not created in its scope.
  """
  if slot_name not in self._slot_names:
    self._slot_names.append(slot_name)

  key = _var_key(var)
  slots_for_var = self._slots.setdefault(key, {})
  existing = slots_for_var.get(slot_name, None)
  if existing is not None:
    return existing

  if isinstance(initializer, str) or callable(initializer):
    initializer = initializers.get(initializer)
    restoring = isinstance(initializer,
                           trackable.CheckpointInitialValueCallable)
    # The explicit `shape` is used for checkpoint-restoring initializers and
    # whenever one was supplied; otherwise the slot mirrors `var`.
    slot_shape = shape if (restoring or (shape is not None)) else var.shape
    initial_value = functools.partial(
        initializer, shape=slot_shape, dtype=var.dtype)
  else:
    initial_value = initializer

  with self._distribution_strategy_scope():
    strategy = distribute_lib.get_strategy()
    if not strategy.extended.variable_created_in_scope(var):
      raise ValueError(
          "Trying to create optimizer slot variable under the scope for "
          "tf.distribute.Strategy ({}), which is different from the scope "
          "used for the original variable ({}). Make sure the slot "
          "variables are created under the same strategy scope. This may "
          "happen if you're restoring from a checkpoint outside the scope"
          .format(strategy, var))

    with strategy.extended.colocate_vars_with(var):
      slot = tf_variables.Variable(
          name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
          dtype=var.dtype,
          trainable=False,
          initial_value=initial_value)

  backend.track_variable(slot)
  slots_for_var[slot_name] = slot
  # Apply any checkpoint restorations that were waiting for this slot.
  self._restore_slot_variable(
      slot_name=slot_name, variable=var, slot_variable=slot)
  self._weights.append(slot)
  return slot
def get_slot(self, var, slot_name):
  """Return the slot variable `slot_name` stored for `var`.

  Both lookups are plain subscripts into `self._slots`, so an unknown variable
  or slot name is not handled here.
  """
  slots_for_var = self._slots[_var_key(var)]
  return slots_for_var[slot_name]
932 def _prepare(self, var_list):
933 keys = set()
934 for var in var_list:
935 if isinstance(var, ds_values.DistributedValues):
936 var_devices = var._devices # pylint: disable=protected-access
937 else:
938 var_devices = [var.device]
939 var_dtype = var.dtype.base_dtype
940 for var_device in var_devices:
941 keys.add((var_device, var_dtype))
943 apply_state = {}
944 for var_device, var_dtype in keys:
945 apply_state[(var_device, var_dtype)] = {}
946 with ops.device(var_device):
947 self._prepare_local(var_device, var_dtype, apply_state)
949 return apply_state
951 def _prepare_local(self, var_device, var_dtype, apply_state):
952 if "learning_rate" in self._hyper:
953 lr_t = array_ops.identity(self._decayed_lr(var_dtype))
954 apply_state[(var_device, var_dtype)]["lr_t"] = lr_t
956 def _fallback_apply_state(self, var_device, var_dtype):
957 """Compatibility for subclasses that don't pass apply_state through."""
958 apply_state = {(var_device, var_dtype): {}}
959 self._prepare_local(var_device, var_dtype, apply_state)
960 return apply_state[(var_device, var_dtype)]
962 def _create_hypers(self):
963 if self._hypers_created:
964 return
965 with self._distribution_strategy_scope():
966 # Iterate hyper values deterministically.
967 for name, value in sorted(self._hyper.items()):
968 if isinstance(value,
969 (ops.Tensor, tf_variables.Variable)) or callable(value):
970 # The check for `callable` covers the usage when `value` is a
971 # `LearningRateSchedule`, in which case it does not need to create a
972 # variable.
973 continue
974 else:
975 self._hyper[name] = self.add_weight(
976 name,
977 shape=[],
978 trainable=False,
979 initializer=value,
980 aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
981 self._hypers_created = True
@property
def iterations(self):
  """Variable. The number of training steps this Optimizer has run.

  The variable is created on first access and appended to `self._weights`.
  """
  if self._iterations is not None:
    return self._iterations
  with self._distribution_strategy_scope():
    self._iterations = self.add_weight(
        "iter",
        shape=[],
        dtype=dtypes.int64,
        trainable=False,
        aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
  self._weights.append(self._iterations)
  return self._iterations

@iterations.setter
def iterations(self, variable):
  """Install `variable` as the iteration counter.

  Raises:
    RuntimeError: if an iteration counter is already set.
  """
  if self._iterations is None:
    self._iterations = variable
    self._weights.append(self._iterations)
    return
  raise RuntimeError("Cannot set `iterations` to a new Variable after "
                     "the Optimizer weights have been created")
def _decayed_lr(self, var_dtype):
  """Get decayed learning rate as a Tensor with dtype=var_dtype.

  A `LearningRateSchedule` is evaluated at the current iteration count. When
  `self._initial_decay` is positive the rate is additionally divided by
  `1 + decay * iterations`.
  """
  learning_rate = self._get_hyper("learning_rate", var_dtype)
  if isinstance(learning_rate, learning_rate_schedule.LearningRateSchedule):
    schedule_step = math_ops.cast(self.iterations, var_dtype)
    learning_rate = math_ops.cast(learning_rate(schedule_step), var_dtype)
  if self._initial_decay > 0.:
    decay_step = math_ops.cast(self.iterations, var_dtype)
    decay = math_ops.cast(self._initial_decay, var_dtype)
    learning_rate = learning_rate / (1. + decay * decay_step)
  return learning_rate
@abc.abstractmethod
def get_config(self):
  """Returns the config of the optimizer.

  An optimizer config is a Python dictionary (serializable)
  containing the configuration of an optimizer.
  The same optimizer can be reinstantiated later
  (without any saved state) from this configuration.

  Returns:
    Python dictionary with the optimizer name plus whichever of `clipnorm`,
    `clipvalue` and `global_clipnorm` are not `None`.
  """
  config = {"name": self._name}
  for option in ("clipnorm", "clipvalue", "global_clipnorm"):
    setting = getattr(self, option)
    if setting is not None:
      config[option] = setting
  return config
@classmethod
def from_config(cls, config, custom_objects=None):
  """Creates an optimizer from its config.

  This method is the reverse of `get_config`,
  capable of instantiating the same optimizer from the config
  dictionary. The dictionary passed in is not modified.

  Args:
    config: A Python dictionary, typically the output of get_config.
    custom_objects: A Python dictionary mapping names to additional Python
      objects used to create this optimizer, such as a function used for a
      hyperparameter.

  Returns:
    An optimizer instance.
  """
  # Work on a shallow copy: the legacy-key rename and the schedule
  # deserialization below would otherwise rewrite the caller's dictionary.
  config = dict(config)
  # Backwards compatibility: accept the legacy `lr` key.
  if "lr" in config:
    config["learning_rate"] = config.pop("lr")
  learning_rate = config.get("learning_rate")
  # A dict-valued learning rate is a serialized schedule.
  if isinstance(learning_rate, dict):
    config["learning_rate"] = learning_rate_schedule.deserialize(
        learning_rate, custom_objects=custom_objects)
  return cls(**config)
def _serialize_hyperparameter(self, hyperparameter_name):
  """Serialize a hyperparameter that can be a float, callable, or Tensor.

  Returns:
    The serialized schedule for a `LearningRateSchedule`, the call result for
    any other callable, the backend value for a TF type, else the raw value.
  """
  hyper = self._hyper[hyperparameter_name]
  if isinstance(hyper, learning_rate_schedule.LearningRateSchedule):
    serialized = learning_rate_schedule.serialize(hyper)
  elif callable(hyper):
    serialized = hyper()
  elif tensor_util.is_tf_type(hyper):
    serialized = backend.get_value(hyper)
  else:
    serialized = hyper
  return serialized
def variables(self):
  """Return this optimizer's variables, in the order they were created."""
  created_in_order = self._weights
  return created_in_order
@property
def weights(self):
  """This optimizer's variables, in the order they were created."""
  created_in_order = self._weights
  return created_in_order
def get_weights(self):
  """Returns the current weights of the optimizer.

  The weights of an optimizer are its state (i.e., variables). This function
  returns the values of those variables as a list of Numpy arrays. The first
  value is always the iterations count of the optimizer, followed by the
  optimizer's state variables in the order they were created. The returned
  list can in turn be used to load state into similarly parameterized
  optimizers.

  For example, the RMSprop optimizer for this simple model returns a list of
  three values-- the iteration count, followed by the root-mean-square value
  of the kernel and bias of the single Dense layer:

  >>> opt = tf.keras.optimizers.RMSprop()
  >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
  >>> m.compile(opt, loss='mse')
  >>> data = np.arange(100).reshape(5, 20)
  >>> labels = np.zeros(5)
  >>> results = m.fit(data, labels)  # Training.
  >>> len(opt.get_weights())
  3

  Returns:
      Weights values as a list of numpy arrays.
  """
  return backend.batch_get_value(self.weights)
# TODO(tanzheny): Maybe share this logic with base_layer.
def set_weights(self, weights):
  """Set the weights of the optimizer.

  The weights of an optimizer are its state (i.e., variables). This function
  takes the values for those variables as a list of Numpy arrays. The first
  value is always the iterations count of the optimizer, followed by the
  optimizer's state variables in the order they are created. The passed
  values are used to set the new state of the optimizer.

  For example, the RMSprop optimizer for this simple model takes a list of
  three values-- the iteration count, followed by the root-mean-square value
  of the kernel and bias of the single Dense layer:

  >>> opt = tf.keras.optimizers.RMSprop()
  >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
  >>> m.compile(opt, loss='mse')
  >>> data = np.arange(100).reshape(5, 20)
  >>> labels = np.zeros(5)
  >>> results = m.fit(data, labels)  # Training.
  >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])]
  >>> opt.set_weights(new_weights)
  >>> opt.iterations
  <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10>

  Args:
      weights: weight values as a list of numpy arrays.

  Raises:
      ValueError: if the number of values differs from the number of
        optimizer weights, or if any value's shape differs from the shape of
        the weight it is meant for.
  """
  current = self.weights
  if len(current) != len(weights):
    raise ValueError(
        "You called `set_weights(weights)` on optimizer " + self._name +
        " with a weight list of length " + str(len(weights)) +
        ", but the optimizer was expecting " + str(len(current)) +
        " weights. Provided weights: " + str(weights)[:50] + "...")
  if not current:
    return

  current_values = backend.batch_get_value(current)
  assignments = []
  for present, variable, incoming in zip(current_values, current, weights):
    # Every shape is validated before anything is written.
    if present.shape != incoming.shape:
      raise ValueError("Optimizer weight shape " + str(present.shape) +
                       " not compatible with "
                       "provided weight shape " + str(incoming.shape))
    assignments.append((variable, incoming))
  backend.batch_set_value(assignments)
def add_weight(self,
               name,
               shape,
               dtype=None,
               initializer="zeros",
               trainable=None,
               synchronization=tf_variables.VariableSynchronization.AUTO,
               aggregation=tf_variables.VariableAggregation.NONE):
  """Create, track and return a new variable owned by this optimizer.

  Args:
    name: name of the variable.
    shape: shape of the variable.
    dtype: dtype of the variable; `dtypes.float32` when `None`.
    initializer: string or callable resolved through `initializers.get`;
      any other value is passed on unchanged.
    trainable: whether the variable is trainable. `None` becomes `False` for
      `ON_READ` synchronization and `True` otherwise.
    synchronization: a `VariableSynchronization` value.
    aggregation: a `VariableAggregation` value.

  Returns:
    The created variable.

  Raises:
    ValueError: if `trainable` is truthy together with `ON_READ`
      synchronization.
  """
  dtype = dtypes.float32 if dtype is None else dtype
  if isinstance(initializer, str) or callable(initializer):
    initializer = initializers.get(initializer)

  synced_on_read = (
      synchronization == tf_variables.VariableSynchronization.ON_READ)
  if synced_on_read and trainable:
    raise ValueError(
        "Synchronization value can be set to "
        "VariableSynchronization.ON_READ only for non-trainable variables. "
        "You have specified trainable=True and "
        "synchronization=VariableSynchronization.ON_READ.")
  if synced_on_read:
    # Set trainable to be false when variable is to be synced on read.
    trainable = False
  elif trainable is None:
    trainable = True

  new_variable = self._add_variable_with_custom_getter(
      name=name,
      shape=shape,
      getter=base_layer_utils.make_variable,
      overwrite=True,
      initializer=initializer,
      dtype=dtype,
      trainable=trainable,
      use_resource=True,
      synchronization=synchronization,
      aggregation=aggregation)
  backend.track_variable(new_variable)
  return new_variable
1202 def _init_set_name(self, name, zero_based=True):
1203 if not name:
1204 self._name = backend.unique_object_name(
1205 generic_utils.to_snake_case(self.__class__.__name__),
1206 zero_based=zero_based)
1207 else:
1208 self._name = name
1210 def _assert_valid_dtypes(self, tensors):
1211 """Asserts tensors are all valid types (see `_valid_dtypes`).
1213 Args:
1214 tensors: Tensors to check.
1216 Raises:
1217 ValueError: If any tensor is not a valid type.
1218 """
1219 valid_dtypes = self._valid_dtypes()
1220 for t in tensors:
1221 dtype = t.dtype.base_dtype
1222 if dtype not in valid_dtypes:
1223 raise ValueError("Invalid type %r for %s, expected: %s." %
1224 (dtype, t.name, [v for v in valid_dtypes]))
def _valid_dtypes(self):
  """Return the dtypes accepted for loss, variables and gradients.

  Subclasses should override this to allow other float types.

  Returns:
    The module-level default collection of valid dtypes.
  """
  default_dtypes = _DEFAULT_VALID_DTYPES
  return default_dtypes
1236 def _call_if_callable(self, param):
1237 """Call the function if param is callable."""
1238 return param() if callable(param) else param
1240 def _resource_apply_dense(self, grad, handle, apply_state):
1241 """Add ops to apply dense gradients to the variable `handle`.
1243 Args:
1244 grad: a `Tensor` representing the gradient.
1245 handle: a `Tensor` of dtype `resource` which points to the variable to be
1246 updated.
1247 apply_state: A dict which is used across multiple apply calls.
1249 Returns:
1250 An `Operation` which updates the value of the variable.
1251 """
1252 raise NotImplementedError("Must be implemented in subclasses.")
def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices,
                                             **kwargs):
  """Add ops to apply sparse gradients to `handle`, with repeated indices.

  Optimizers which override this method must deal with repeated indices. See
  the docstring of `_apply_sparse_duplicate_indices` for details. By default
  the correct behavior, to sum non-unique indices and their associated
  gradients, is enforced by first pre-processing `grad` and `indices` and
  passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
  with duplicate indices may instead override this method to avoid the
  overhead of summing.

  Args:
    grad: a `Tensor` representing the gradient for the affected indices.
    handle: a `Tensor` of dtype `resource` which points to the variable to be
      updated.
    indices: a `Tensor` of integral type representing the indices for which
      the gradient is nonzero. Indices may be repeated.
    **kwargs: May optionally contain `apply_state`

  Returns:
    An `Operation` which updates the value of the variable.
  """
  deduplicated = _deduplicate_indexed_slices(values=grad, indices=indices)
  summed_grad, unique_indices = deduplicated
  return self._resource_apply_sparse(
      summed_grad, handle, unique_indices, **kwargs)
1282 def _resource_apply_sparse(self, grad, handle, indices, apply_state):
1283 """Add ops to apply sparse gradients to the variable `handle`.
1285 Similar to `_apply_sparse`, the `indices` argument to this method has been
1286 de-duplicated. Optimizers which deal correctly with non-unique indices may
1287 instead override `_resource_apply_sparse_duplicate_indices` to avoid this
1288 overhead.
1290 Args:
1291 grad: a `Tensor` representing the gradient for the affected indices.
1292 handle: a `Tensor` of dtype `resource` which points to the variable to be
1293 updated.
1294 indices: a `Tensor` of integral type representing the indices for which
1295 the gradient is nonzero. Indices are unique.
1296 apply_state: A dict which is used across multiple apply calls.
1298 Returns:
1299 An `Operation` which updates the value of the variable.
1300 """
1301 raise NotImplementedError("Must be implemented in subclasses.")
def _resource_scatter_add(self, x, i, v):
  """Scatter-add `v` into `x` at indices `i` and return `x.value()`.

  The value is read under a control dependency on the scatter op.
  """
  scatter_add_op = gen_resource_variable_ops.ResourceScatterAdd(
      resource=x.handle, indices=i, updates=v)
  with ops.control_dependencies([scatter_add_op]):
    return x.value()
def _resource_scatter_update(self, x, i, v):
  """Scatter-update `x` at indices `i` with `v` and return `x.value()`.

  The value is read under a control dependency on the scatter op.
  """
  scatter_update_op = gen_resource_variable_ops.ResourceScatterUpdate(
      resource=x.handle, indices=i, updates=v)
  with ops.control_dependencies([scatter_update_op]):
    return x.value()
@property
@layer_utils.cached_per_instance
def _dense_apply_args(self):
  """Argument names in the signature of `self._resource_apply_dense`."""
  dense_spec = tf_inspect.getfullargspec(self._resource_apply_dense)
  return dense_spec.args
@property
@layer_utils.cached_per_instance
def _sparse_apply_args(self):
  """Argument names in the signature of `self._resource_apply_sparse`."""
  sparse_spec = tf_inspect.getfullargspec(self._resource_apply_sparse)
  return sparse_spec.args
1326 # ---------------
1327 # For implementing the trackable interface
1328 # ---------------
def _restore_slot_variable(self, slot_name, variable, slot_variable):
  """Restore a newly created slot variable's value.

  Any restorations deferred for `(slot_name, variable)` are removed from
  `self._deferred_slot_restorations` and applied to `slot_variable`.
  """
  key = _var_key(variable)
  pending_for_slot = self._deferred_slot_restorations.get(slot_name, {})
  pending = pending_for_slot.pop(key, [])
  # Iterate over restores, highest restore UID first to minimize the number
  # of assignments.
  newest_first = sorted(
      pending, key=lambda position: position.restore_uid, reverse=True)
  for position in newest_first:
    position.restore(slot_variable)
def _create_or_restore_slot_variable(
    self, slot_variable_position, slot_name, variable):
  """Restore a slot variable's value, possibly creating it.

  Called when a variable which has an associated slot variable is created or
  restored. When executing eagerly, we create the slot variable with a
  restoring initializer.

  No new variables are created when graph building. Instead,
  _restore_slot_variable catches these after normal creation and adds restore
  ops to the graph. This method is nonetheless important when graph building
  for the case when a slot variable has already been created but `variable`
  has just been added to a dependency graph (causing us to realize that the
  slot variable needs to be restored).

  Args:
    slot_variable_position: A `trackable._CheckpointPosition` object
      indicating the slot variable `Trackable` object to be restored.
    slot_name: The name of this `Optimizer`'s slot to restore into.
    variable: The variable object this slot is being created for.
  """
  key = _var_key(variable)
  slot = self._slots.get(key, {}).get(slot_name, None)

  # Defer slot variable creation if there is an active variable creator
  # scope. Generally we'd like to eagerly create/restore slot variables
  # when possible, but this may mean that scopes intended to catch
  # `variable` also catch its eagerly created slot variable
  # unintentionally (specifically make_template would add a dependency on
  # a slot variable if not for this case). Deferring is mostly harmless
  # (aside from double initialization), and makes variable creator scopes
  # behave the same way they do when graph building.
  #
  # One notable case is with distribution strategy, which uses variable
  # creator scope but always desires the `variable` and the slot to use
  # the same scope, thus we can safely eagerly create/restore slot
  # variables.
  create_now = (
      slot is None
      and context.executing_eagerly()
      and slot_variable_position.is_simple_variable()
      and (not ops.get_default_graph()._variable_creator_stack  # pylint: disable=protected-access
           or self._distribution_strategy))
  if create_now:
    restoring_initializer = trackable.CheckpointInitialValueCallable(
        checkpoint_position=slot_variable_position)
    slot = self.add_slot(
        var=variable,
        initializer=restoring_initializer,
        slot_name=slot_name,
        shape=slot_variable_position.value_shape())
    # Slot variables are not owned by any one object (because we don't want to
    # save the slot variable if the optimizer is saved without the non-slot
    # variable, or if the non-slot variable is saved without the optimizer;
    # it's a dependency hypergraph with edges of the form (optimizer, non-slot
    # variable, variable)). So we don't _track_ slot variables anywhere, and
    # instead special-case this dependency and otherwise pretend it's a normal
    # graph.

  if slot is not None:
    # If we've either made this slot variable, or if we've pulled out an
    # existing slot variable, we should restore it.
    slot_variable_position.restore(slot)
    return

  # We didn't make the slot variable. Defer restoring until it gets created
  # normally. We keep a list rather than the one with the highest restore
  # UID in case slot variables have their own dependencies, in which case
  # those could differ between restores.
  deferred_for_slot = self._deferred_slot_restorations.setdefault(
      slot_name, {})
  deferred_for_slot.setdefault(key, []).append(slot_variable_position)
1410 @contextlib.contextmanager
1411 def _distribution_strategy_scope(self):
1412 """Returns the `tf.distribute.Strategy` this optimizer was created under."""
1413 if self._distribution_strategy and not distribute_lib.has_strategy():
1414 with self._distribution_strategy.scope():
1415 yield self._distribution_strategy.scope()
1416 else:
1417 yield
1420def _var_key(var):
1421 """Key for representing a primary variable, for looking up slots.
1423 In graph mode the name is derived from the var shared name.
1424 In eager mode the name is derived from the var unique id.
1425 If distribution strategy exists, get the primary variable first.
1427 Args:
1428 var: the variable.
1430 Returns:
1431 the unique name of the variable.
1432 """
1434 # pylint: disable=protected-access
1435 # Get the distributed variable if it exists.
1436 if hasattr(var, "_distributed_container"):
1437 var = var._distributed_container()
1438 if var._in_graph_mode:
1439 return var._shared_name
1440 return var._unique_id
def _get_slot_key_from_var(var, slot_name):
  """Build the slot key `var_name/slot_name` for `var`."""
  return _var_key(var) + "/" + slot_name
class RestoredOptimizer(OptimizerV2):
  """A non-functional Optimizer implementation for checkpoint compatibility.

  Holds slot variables and hyperparameters when an optimizer is restored from
  a SavedModel. These variables may be referenced in functions along with ops
  created by the original optimizer, but currently we do not support using
  the optimizer object itself (e.g. through `apply_gradients`).
  """
  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.

  def __init__(self):
    """Initialise under the fixed name and mark hyperparameters as created."""
    super(RestoredOptimizer, self).__init__("RestoredOptimizer")
    # With this flag set, `_create_hypers` returns without creating weights.
    self._hypers_created = True

  def get_config(self):
    """Always raises: a restored optimizer has no config to return.

    Raises:
      NotImplementedError: always.
    """
    # TODO(allenl): Save and restore the Optimizer's config
    message = (
        "Restoring functional Optimizers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")
    raise NotImplementedError(message)
# Register `RestoredOptimizer` as the revived type for the
# "tf_deprecated_optimizer" identifier. The predicate matches any `OptimizerV2`
# instance, and the factory ignores the proto and builds an empty
# `RestoredOptimizer`, with `_set_hyper` supplied as the setter.
revived_types.register_revived_type(
    "tf_deprecated_optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[
        revived_types.VersionedTypeRegistration(
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            object_factory=lambda proto: RestoredOptimizer(),
            setter=RestoredOptimizer._set_hyper,  # pylint: disable=protected-access
        )
    ])