Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/optimizer.py: 22%
396 statements
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""Base class for optimizers."""
17# pylint: disable=g-bad-name
19import abc
21from tensorflow.python.distribute import distribute_lib
22from tensorflow.python.distribute import distribute_utils
23from tensorflow.python.distribute import reduce_util as ds_reduce_util
24from tensorflow.python.eager import backprop
25from tensorflow.python.eager import context
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import indexed_slices
28from tensorflow.python.framework import ops
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import gradients
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import resource_variable_ops
34from tensorflow.python.ops import state_ops
35from tensorflow.python.ops import variable_v1
36from tensorflow.python.ops import variables
37from tensorflow.python.trackable import base as trackable
38from tensorflow.python.training import slot_creator
39from tensorflow.python.util import nest
40from tensorflow.python.util.tf_export import tf_export
43def get_filtered_grad_fn(grad_fn):
44 # `distributed_context.join()` requires that its arguments are parallel
45 # across threads, and in particular that `grads_and_vars` has the same
46 # variables in the same order.
48 # When computing gradients in eager mode with multiple threads, you
49 # can get extra variables with a gradient of `None`. This happens when
50 # those variables are accessed in another thread during the gradient
51 # computation. To get a consistent set of variables, we filter out
52 # those with `None` gradients.
53 def filtered_grad_fn(*args, **kwargs):
54 return [(g, v) for g, v in grad_fn(*args, **kwargs) if g is not None]
56 return filtered_grad_fn
59def _deduplicate_indexed_slices(values, indices):
60 """Sums `values` associated with any non-unique `indices`.
62 Args:
63 values: A `Tensor` with rank >= 1.
64 indices: A one-dimensional integer `Tensor`, indexing into the first
65 dimension of `values` (as in an IndexedSlices object).
66 Returns:
67 A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
68 de-duplicated version of `indices` and `summed_values` contains the sum of
69 `values` slices associated with each unique index.
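  For example (illustrative), `values=[[1.], [2.], [3.]]` with
  `indices=[0, 0, 2]` yields `summed_values=[[3.], [3.]]` and
  `unique_indices=[0, 2]`.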
70 """
71 unique_indices, new_index_positions = array_ops.unique(indices)
72 summed_values = math_ops.unsorted_segment_sum(
73 values, new_index_positions,
74 array_ops.shape(unique_indices)[0])
75 return (summed_values, unique_indices)
78def _var_key(var):
79 """Returns slot key for `var`."""
80 # pylint: disable=protected-access
81 var = distribute_utils.value_container(var)
82 if (distribute_utils.is_distributed_variable(var) and
83 not ops.executing_eagerly_outside_functions()):
84 return (var.graph, var._shared_name)
85 if hasattr(var, "op"):
86 return (var.op.graph, var.op.name)
87 return var._unique_id
88 # pylint: enable=protected-access
91class _OptimizableVariable(metaclass=abc.ABCMeta):
92 """Interface for abstracting over variables in the optimizers."""
94 @abc.abstractmethod
95 def target(self):
96 """Returns the optimization target for this variable."""
97 raise NotImplementedError("Calling an abstract method.")
99 @abc.abstractmethod
100 def update_op(self, optimizer, g):
101 """Returns the update ops for updating the variable."""
102 raise NotImplementedError("Calling an abstract method.")
105class _RefVariableProcessor(_OptimizableVariable):
106 """Processor for Variable."""
108 def __init__(self, v):
109 self._v = v
111 def __str__(self):
112 return "<_RefVariableProcessor(%s)>" % self._v
114 def target(self):
115 return self._v._ref() # pylint: disable=protected-access
117 def update_op(self, optimizer, g):
118 if isinstance(g, ops.Tensor):
119 update_op = optimizer._apply_dense(g, self._v) # pylint: disable=protected-access
120 if self._v.constraint is not None:
121 with ops.control_dependencies([update_op]):
122 return self._v.assign(self._v.constraint(self._v))
123 else:
124 return update_op
125 else:
126 assert isinstance(g, indexed_slices.IndexedSlices), (
127 "Gradient ", g, " is neither a tensor nor IndexedSlices.")
128 if self._v.constraint is not None:
129 raise RuntimeError(
130 "Cannot use a constraint function on a sparse variable.")
131 # pylint: disable=protected-access
132 return optimizer._apply_sparse_duplicate_indices(g, self._v)
135class _DenseReadResourceVariableProcessor(_OptimizableVariable):
136 """Processor for dense ResourceVariables."""
138 def __init__(self, v):
139 self._v = v
141 def target(self):
142 return self._v
144 def update_op(self, optimizer, g):
145 # pylint: disable=protected-access
146 update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0])
147 if self._v.constraint is not None:
148 with ops.control_dependencies([update_op]):
149 return self._v.assign(self._v.constraint(self._v))
150 else:
151 return update_op
154class _DenseResourceVariableProcessor(_OptimizableVariable):
155 """Processor for dense ResourceVariables."""
157 def __init__(self, v):
158 self._v = v
160 def target(self):
161 return self._v
163 def update_op(self, optimizer, g):
164 # pylint: disable=protected-access
165 if isinstance(g, indexed_slices.IndexedSlices):
166 if self._v.constraint is not None:
167 raise RuntimeError(
168 "Cannot use a constraint function on a sparse variable.")
169 return optimizer._resource_apply_sparse_duplicate_indices(
170 g.values, self._v, g.indices)
171 update_op = optimizer._resource_apply_dense(g, self._v)
172 if self._v.constraint is not None:
173 with ops.control_dependencies([update_op]):
174 return self._v.assign(self._v.constraint(self._v))
175 else:
176 return update_op
179class _TensorProcessor(_OptimizableVariable):
180 """Processor for ordinary Tensors.
182 Even though a Tensor can't really be updated, sometimes it is useful to
183 compute the gradients with respect to a Tensor using the optimizer. Updating
184 the Tensor is, of course, unsupported.
185 """
187 def __init__(self, v):
188 self._v = v
190 def target(self):
191 return self._v
193 def update_op(self, optimizer, g):
194 raise NotImplementedError("Trying to update a Tensor ", self._v)
197def _get_processor(v):
198 """The processor of v."""
199 if context.executing_eagerly():
200 if isinstance(v, ops.Tensor):
201 return _TensorProcessor(v)
202 else:
203 return _DenseResourceVariableProcessor(v)
204 if resource_variable_ops.is_resource_variable(v) and not v._in_graph_mode: # pylint: disable=protected-access
205 # True if and only if `v` was initialized eagerly.
206 return _DenseResourceVariableProcessor(v)
207 if v.op.type == "VarHandleOp":
208 return _DenseResourceVariableProcessor(v)
209 if isinstance(v, variables.Variable):
210 return _RefVariableProcessor(v)
211 if isinstance(v, ops.Tensor):
212 return _TensorProcessor(v)
213 raise NotImplementedError("Trying to optimize unsupported type ", v)
216@tf_export(v1=["train.Optimizer"])
217class Optimizer(
218 # Optimizers inherit from Trackable rather than AutoTrackable
219 # since they do most of their dependency management themselves (slot
220 # variables are special-cased, and non-slot variables are keyed to graphs).
221 trackable.Trackable):
222 """Base class for optimizers.
224 This class defines the API to add Ops to train a model. You never use this
225 class directly, but instead instantiate one of its subclasses such as
226 `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
228 ### Usage
230 ```python
231 # Create an optimizer with the desired parameters.
232 opt = GradientDescentOptimizer(learning_rate=0.1)
233 # Add Ops to the graph to minimize a cost by updating a list of variables.
234 # "cost" is a Tensor, and the list of variables contains tf.Variable
235 # objects.
236 opt_op = opt.minimize(cost, var_list=<list of variables>)
237 ```
239 In the training program you will just have to run the returned Op.
241 ```python
242 # Execute opt_op to do one step of training:
243 opt_op.run()
244 ```
246 ### Processing gradients before applying them.
248 Calling `minimize()` takes care of both computing the gradients and
249 applying them to the variables. If you want to process the gradients
250 before applying them you can instead use the optimizer in three steps:
252 1. Compute the gradients with `compute_gradients()`.
253 2. Process the gradients as you wish.
254 3. Apply the processed gradients with `apply_gradients()`.
256 Example:
258 ```python
259 # Create an optimizer.
260 opt = GradientDescentOptimizer(learning_rate=0.1)
262 # Compute the gradients for a list of variables.
263 grads_and_vars = opt.compute_gradients(loss, <list of variables>)
265 # grads_and_vars is a list of tuples (gradient, variable). Do whatever you
266 # need to the 'gradient' part, for example cap them, etc.
267 capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
269 # Ask the optimizer to apply the capped gradients.
270 opt.apply_gradients(capped_grads_and_vars)
271 ```
273 ### Gating Gradients
275 Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
276 argument that controls the degree of parallelism during the application of
277 the gradients.
279 The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
281 <b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides
282 the maximum parallelism in execution, at the cost of some non-reproducibility
283 in the results. For example the two gradients of `matmul` depend on the input
284 values: With `GATE_NONE` one of the gradients could be applied to one of the
285 inputs _before_ the other gradient is computed resulting in non-reproducible
286 results.
288 <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
289 they are used. This prevents race conditions for Ops that generate gradients
290 for multiple inputs where the gradients depend on the inputs.
292 <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
293 before any one of them is used. This provides the least parallelism but can
294 be useful if you want to process all gradients before applying any of them.
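  For example (an illustrative sketch; `loss` and `var_list` are placeholders
  for your own tensor and variables):

  ```python
  opt = GradientDescentOptimizer(learning_rate=0.1)
  # Compute all gradients before any of them is used.
  grads_and_vars = opt.compute_gradients(
      loss, var_list, gate_gradients=opt.GATE_GRAPH)
  train_op = opt.apply_gradients(grads_and_vars)
  ```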
296 ### Slots
298 Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
299 allocate and manage additional variables associated with the variables to
300 train. These are called <i>Slots</i>. Slots have names and you can ask the
301 optimizer for the names of the slots that it uses. Once you have a slot name
302 you can ask the optimizer for the variable it created to hold the slot value.
304 This can be useful if you want to log or debug a training algorithm, report stats
305 about the slots, etc.
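  For example (illustrative; assumes a `MomentumOptimizer` and a variable
  `var` that has already been passed to `minimize()` or `apply_gradients()`):

  ```python
  opt = MomentumOptimizer(learning_rate=0.1, momentum=0.9)
  # ... build and run the training op ...
  print(opt.get_slot_names())          # e.g. ['momentum']
  momentum_slot = opt.get_slot(var, 'momentum')
  ```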
307 @compatibility(TF2)
308 `tf.compat.v1.train.Optimizer` can be used in eager mode and `tf.function`,
309 but it is not recommended. Please use the subclasses of
310 `tf.keras.optimizers.Optimizer` instead in TF2. Please see [Basic training
311 loops](https://www.tensorflow.org/guide/basic_training_loops) or
312 [Writing a training loop from scratch]
313 (https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch)
314 for examples.
316 If your TF1 code contains a `tf.compat.v1.train.Optimizer` symbol, whether it
317 is used with or without a `tf.estimator.Estimator`, you cannot simply replace
318 that with the corresponding `tf.keras.optimizers.Optimizer`s. To migrate to
319 TF2, it is advised that the whole training program used with `Estimator` be
320 migrated to Keras `Model.fit`-based or TF2 custom training loops.
322 #### Structural Mapping to Native TF2
324 Before:
326 ```python
327 sgd_op = tf.compat.v1.train.GradientDescentOptimizer(3.0)
328 opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
329 opt_op.run(session=session)
330 ```
332 After:
334 ```python
335 sgd = tf.keras.optimizers.SGD(3.0)
336 sgd.minimize(cost_fn, [var0, var1])
337 ```
339 #### How to Map Arguments
341 | TF1 Arg Name | TF2 Arg Name | Note |
342 | :-------------------- | :-------------- | :------------------------- |
343 | `use_locking` | Not supported | - |
344 | `name` | `name` | - |
346 #### Before & After Usage Example
348 Before:
350 >>> g = tf.compat.v1.Graph()
351 >>> with g.as_default():
352 ... var0 = tf.compat.v1.Variable([1.0, 2.0])
353 ... var1 = tf.compat.v1.Variable([3.0, 4.0])
354 ... cost = 5 * var0 + 3 * var1
355 ... global_step = tf.compat.v1.Variable(
356 ... tf.compat.v1.zeros([], tf.compat.v1.int64), name='global_step')
357 ... init_op = tf.compat.v1.initialize_all_variables()
358 ... sgd_op = tf.compat.v1.train.GradientDescentOptimizer(3.0)
359 ... opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
360 >>> session = tf.compat.v1.Session(graph=g)
361 >>> session.run(init_op)
362 >>> opt_op.run(session=session)
363 >>> print(session.run(var0))
364 [-14. -13.]
367 After:
368 >>> var0 = tf.Variable([1.0, 2.0])
369 >>> var1 = tf.Variable([3.0, 4.0])
370 >>> cost_fn = lambda: 5 * var0 + 3 * var1
371 >>> sgd = tf.keras.optimizers.SGD(3.0)
372 >>> sgd.minimize(cost_fn, [var0, var1])
373 >>> print(var0.numpy())
374 [-14. -13.]
376 @end_compatibility
379 """
381 # Values for gate_gradients.
382 GATE_NONE = 0
383 GATE_OP = 1
384 GATE_GRAPH = 2
386 def __init__(self, use_locking, name):
387 """Create a new Optimizer.
389 This must be called by the constructors of subclasses.
391 Args:
392 use_locking: Bool. If True, use locks to prevent concurrent updates
393 to variables.
394 name: A non-empty string. The name to use for accumulators created
395 for the optimizer.
397 Raises:
398 ValueError: If name is malformed.
399 """
400 if not name:
401 raise ValueError("Must specify the optimizer name")
402 self._use_locking = use_locking
403 self._name = name
404 # Dictionary of slots.
405 # {slot_name :
406 # {_var_key(variable_to_train): slot_for_the_variable, ... },
407 # ... }
408 self._slots = {}
409 self._non_slot_dict = {}
410 # For implementing Trackable. Stores information about how to restore
411 # slot variables which have not yet been created
412 # (trackable._CheckpointPosition objects).
413 # {slot_name :
414 # {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
415 # ... }
416 self._deferred_slot_restorations = {}
418 # TODO(isaprykin): When using a DistributionStrategy, and when an
419 # optimizer is created in each replica, it might be dangerous to
420 # rely on some Optimizer methods. When such methods are called on a
421 # per-replica optimizer, an exception needs to be thrown. We do
422 # allow creation per-replica optimizers however, because the
423 # compute_gradients()->apply_gradients() sequence is safe.
425 def get_name(self):
426 return self._name
428 def minimize(self, loss, global_step=None, var_list=None,
429 gate_gradients=GATE_OP, aggregation_method=None,
430 colocate_gradients_with_ops=False, name=None,
431 grad_loss=None):
432 """Add operations to minimize `loss` by updating `var_list`.
434 This method simply combines calls to `compute_gradients()` and
435 `apply_gradients()`. If you want to process the gradients before applying
436 them, call `compute_gradients()` and `apply_gradients()` explicitly instead
437 of using this function.
439 Args:
440 loss: A `Tensor` containing the value to minimize.
441 global_step: Optional `Variable` to increment by one after the
442 variables have been updated.
443 var_list: Optional list or tuple of `Variable` objects to update to
444 minimize `loss`. Defaults to the list of variables collected in
445 the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
446 gate_gradients: How to gate the computation of gradients. Can be
447 `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
448 aggregation_method: Specifies the method used to combine gradient terms.
449 Valid values are defined in the class `AggregationMethod`.
450 colocate_gradients_with_ops: If True, try colocating gradients with
451 the corresponding op.
452 name: Optional name for the returned operation.
453 grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
455 Returns:
456 An Operation that updates the variables in `var_list`. If `global_step`
457 was not `None`, that operation also increments `global_step`.
459 Raises:
460 ValueError: If some of the variables are not `Variable` objects.
462 @compatibility(eager)
463 When eager execution is enabled, `loss` should be a Python function that
464 takes no arguments and computes the value to be minimized. Minimization (and
465 gradient computation) is done with respect to the elements of `var_list` if
466 not None, else with respect to any trainable variables created during the
467 execution of the `loss` function. `gate_gradients`, `aggregation_method`,
468 `colocate_gradients_with_ops` and `grad_loss` are ignored when eager
469 execution is enabled.
470 @end_compatibility
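    For example, under eager execution (an illustrative sketch; `w` and the
    loss function are placeholders):

    ```python
    w = tf.Variable([1.0, 2.0])
    opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1)
    # `loss` must be a callable when eager execution is enabled.
    opt.minimize(lambda: tf.reduce_sum(w * w), var_list=[w])
    ```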
471 """
472 grads_and_vars = self.compute_gradients(
473 loss, var_list=var_list, gate_gradients=gate_gradients,
474 aggregation_method=aggregation_method,
475 colocate_gradients_with_ops=colocate_gradients_with_ops,
476 grad_loss=grad_loss)
478 vars_with_grad = [v for g, v in grads_and_vars if g is not None]
479 if not vars_with_grad:
480 raise ValueError(
481 "No gradients provided for any variable, check your graph for ops"
482 " that do not support gradients, between variables %s and loss %s." %
483 ([str(v) for _, v in grads_and_vars], loss))
485 return self.apply_gradients(grads_and_vars, global_step=global_step,
486 name=name)
488 def compute_gradients(self, loss, var_list=None,
489 gate_gradients=GATE_OP,
490 aggregation_method=None,
491 colocate_gradients_with_ops=False,
492 grad_loss=None):
493 """Compute gradients of `loss` for the variables in `var_list`.
495 This is the first part of `minimize()`. It returns a list
496 of (gradient, variable) pairs where "gradient" is the gradient
497 for "variable". Note that "gradient" can be a `Tensor`, an
498 `IndexedSlices`, or `None` if there is no gradient for the
499 given variable.
501 @compatibility(TF2)
502 `tf.keras.optimizers.Optimizer` in TF2 does not provide a
503 `compute_gradients` method, and you should use a `tf.GradientTape` to
504 obtain the gradients:
506 ```python
507 @tf.function
508 def train_step(inputs):
509 batch_data, labels = inputs
510 with tf.GradientTape() as tape:
511 predictions = model(batch_data, training=True)
512 loss = tf.keras.losses.CategoricalCrossentropy(
513 reduction=tf.keras.losses.Reduction.NONE)(labels, predictions)
514 gradients = tape.gradient(loss, model.trainable_variables)
515 optimizer.apply_gradients(zip(gradients, model.trainable_variables))
516 ```
518 Args:
519 loss: A Tensor containing the value to minimize or a callable taking
520 no arguments which returns the value to minimize. When eager execution
521 is enabled it must be a callable.
522 var_list: Optional list or tuple of `tf.Variable` to update to minimize
523 `loss`. Defaults to the list of variables collected in the graph
524 under the key `GraphKeys.TRAINABLE_VARIABLES`.
525 gate_gradients: How to gate the computation of gradients. Can be
526 `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
527 aggregation_method: Specifies the method used to combine gradient terms.
528 Valid values are defined in the class `AggregationMethod`.
529 colocate_gradients_with_ops: If True, try colocating gradients with
530 the corresponding op.
531 grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
533 Returns:
534 A list of (gradient, variable) pairs. Variable is always present, but
535 gradient can be `None`.
537 Raises:
538 TypeError: If `var_list` contains anything other than `Variable` objects.
539 ValueError: If some arguments are invalid.
540 RuntimeError: If called with eager execution enabled and `loss` is
541 not callable.
543 @compatibility(eager)
544 When eager execution is enabled, `gate_gradients`, `aggregation_method`,
545 and `colocate_gradients_with_ops` are ignored.
546 @end_compatibility
547 """
548 if callable(loss):
549 with backprop.GradientTape() as tape:
550 if var_list is not None:
551 tape.watch(var_list)
552 loss_value = loss()
554 # Scale loss if using a "mean" loss reduction and multiple replicas.
555 # Have to be careful to call distribute_utils.get_loss_reduction()
556 # *after* loss() is evaluated, so we know what loss reduction it uses.
557 # TODO(josh11b): Test that we handle weight decay in a reasonable way.
558 loss_value = self._scale_loss(loss_value)
560 if var_list is None:
561 var_list = tape.watched_variables()
562 # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
563 # to be executed.
564 with ops.control_dependencies([loss_value]):
565 grads = tape.gradient(loss_value, var_list, grad_loss)
566 return list(zip(grads, var_list))
568 # Non-callable/Tensor loss case
569 if context.executing_eagerly():
570 raise RuntimeError(
571 "`loss` passed to Optimizer.compute_gradients should "
572 "be a function when eager execution is enabled.")
574 # Scale loss if using a "mean" loss reduction and multiple replicas.
575 loss = self._scale_loss(loss)
577 if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
578 Optimizer.GATE_GRAPH]:
579 raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
580 "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
581 gate_gradients)
582 self._assert_valid_dtypes([loss])
583 if grad_loss is not None:
584 self._assert_valid_dtypes([grad_loss])
585 if var_list is None:
586 var_list = (
587 variables.trainable_variables() +
588 ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
589 else:
590 var_list = nest.flatten(var_list)
591 # pylint: disable=protected-access
592 var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
593 # pylint: enable=protected-access
594 processors = [_get_processor(v) for v in var_list]
595 if not var_list:
596 raise ValueError("No variables to optimize.")
597 var_refs = [p.target() for p in processors]
598 grads = gradients.gradients(
599 loss, var_refs, grad_ys=grad_loss,
600 gate_gradients=(gate_gradients == Optimizer.GATE_OP),
601 aggregation_method=aggregation_method,
602 colocate_gradients_with_ops=colocate_gradients_with_ops)
603 if gate_gradients == Optimizer.GATE_GRAPH:
604 grads = control_flow_ops.tuple(grads)
605 grads_and_vars = list(zip(grads, var_list))
606 self._assert_valid_dtypes(
607 [v for g, v in grads_and_vars
608 if g is not None and v.dtype != dtypes.resource])
609 return grads_and_vars
611 @staticmethod
612 def _scale_loss(loss_value):
613 ops.get_default_graph()._is_loss_scaled_by_optimizer = False # pylint: disable=protected-access
614 if distribute_utils.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
615 num_replicas = distribute_lib.get_strategy().num_replicas_in_sync
616 if num_replicas > 1:
617 loss_value *= (1. / num_replicas)
618 ops.get_default_graph()._is_loss_scaled_by_optimizer = True # pylint: disable=protected-access
619 return loss_value
621 def apply_gradients(
622 self,
623 grads_and_vars,
624 global_step=None,
625 name=None,
626 skip_gradients_aggregation=False,
627 ):
628 """Apply gradients to variables.
630 This is the second part of `minimize()`. It returns an `Operation` that
631 applies gradients.
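    For example, to clip gradients before applying them (an illustrative
    sketch; `loss` and `var_list` are placeholders):

    ```python
    opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1)
    grads_and_vars = opt.compute_gradients(loss, var_list)
    clipped = [(tf.clip_by_norm(g, 1.0), v)
               for g, v in grads_and_vars if g is not None]
    train_op = opt.apply_gradients(clipped)
    ```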
633 @compatibility(TF2)
634 #### How to Map Arguments
636 | TF1 Arg Name | TF2 Arg Name | Note |
637 | :-------------------- | :-------------- | :------------------------- |
638 | `grads_and_vars` | `grads_and_vars` | - |
639 | `global_step` | Not supported. | Use `optimizer.iterations` |
640 | `name` | `name` | - |
642 Args:
643 grads_and_vars: List of (gradient, variable) pairs as returned by
644 `compute_gradients()`.
645 global_step: Optional `Variable` to increment by one after the variables
646 have been updated.
647 name: Optional name for the returned operation. Defaults to the name
648 passed to the `Optimizer` constructor.
649 skip_gradients_aggregation: If True, gradient aggregation will not be
650 performed inside the optimizer. Usually this arg is set to True when you
651 write custom code that aggregates gradients outside the optimizer.
653 Returns:
654 An `Operation` that applies the specified gradients. If `global_step`
655 was not None, that operation also increments `global_step`.
657 Raises:
658 TypeError: If `grads_and_vars` is malformed.
659 ValueError: If none of the variables have gradients.
660 RuntimeError: If you should use `_distributed_apply()` instead.
661 """
662 # This is a default implementation of apply_gradients() that can be shared
663 # by most optimizers. It relies on the subclass implementing the following
664 # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
666 # TODO(isaprykin): Get rid of `has_strategy()` check by
667 # always calling _distributed_apply(), using the default distribution
668 # as needed.
669 if distribute_lib.has_strategy() and not skip_gradients_aggregation:
670 # Handle DistributionStrategy case.
671 if distribute_lib.in_cross_replica_context():
672 raise RuntimeError("Use `_distributed_apply()` instead of "
673 "`apply_gradients()` in a cross-replica context.")
675 grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
676 return distribute_lib.get_replica_context().merge_call(
677 self._distributed_apply, args=(grads_and_vars, global_step, name))
679 # No DistributionStrategy case.
680 grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works.
681 if not grads_and_vars:
682 raise ValueError("No variables provided.")
683 converted_grads_and_vars = []
684 for g, v in grads_and_vars:
685 if g is not None:
686 try:
687 # Convert the grad to Tensor or IndexedSlices if necessary.
688 g = indexed_slices.convert_to_tensor_or_indexed_slices(g)
689 except TypeError:
690 raise TypeError(
691 "Gradient must be convertible to a Tensor"
692 " or IndexedSlices, or None: %s" % g)
693 if not isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)):
694 raise TypeError(
695 "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
696 p = _get_processor(v)
697 converted_grads_and_vars.append((g, v, p))
699 converted_grads_and_vars = tuple(converted_grads_and_vars)
700 var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
701 if not var_list:
702 raise ValueError("No gradients provided for any variable: %s." %
703 ([str(v) for _, v, _ in converted_grads_and_vars],))
704 with ops.init_scope():
705 self._create_slots(var_list)
706 update_ops = []
707 with ops.name_scope(name, self._name, skip_on_eager=False) as name:
708 self._prepare()
709 for grad, var, processor in converted_grads_and_vars:
710 if grad is None:
711 continue
712 # We colocate all ops created in _apply_dense or _apply_sparse
713 # on the same device as the variable.
714 # TODO(apassos): figure out how to get the variable name here.
715 if (context.executing_eagerly() or
716 resource_variable_ops.is_resource_variable(var)
717 and not var._in_graph_mode): # pylint: disable=protected-access
718 scope_name = ""
719 else:
720 scope_name = var.op.name
721 with ops.name_scope(
722 "update_" + scope_name,
723 skip_on_eager=False), ops.colocate_with(var):
724 update_ops.append(processor.update_op(self, grad))
725 if global_step is None:
726 apply_updates = self._finish(update_ops, name)
727 else:
728 with ops.control_dependencies([self._finish(update_ops, "update")]):
729 with ops.colocate_with(global_step):
730 if isinstance(
731 global_step, resource_variable_ops.BaseResourceVariable):
732 # TODO(apassos): the implicit read in assign_add is slow; consider
733 # making it less so.
734 apply_updates = resource_variable_ops.assign_add_variable_op(
735 global_step.handle,
736 ops.convert_to_tensor(1, dtype=global_step.dtype),
737 name=name)
738 else:
739 apply_updates = state_ops.assign_add(global_step, 1, name=name)
741 if not context.executing_eagerly():
742 if isinstance(apply_updates, ops.Tensor):
743 apply_updates = apply_updates.op
744 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
745 if apply_updates not in train_op:
746 train_op.append(apply_updates)
748 return apply_updates
750 def _distributed_apply(self,
751 distribution,
752 grads_and_vars,
753 global_step=None,
754 name=None):
755 """A version of `apply_gradients` for cross-replica context.
757 This is a version of `apply_gradients()` for when you are using a
758 `DistributionStrategy` and are in a cross-replica context. If in a
759 replica context, use `apply_gradients()` as normal.
761 Args:
762 distribution: A `DistributionStrategy` object.
763 grads_and_vars: List of (gradient, variable) pairs as returned by
764 `compute_gradients()`, and then aggregated across replicas.
765 global_step: Optional (mirrored) `Variable` to increment by one
766 after the variables have been updated.
767 name: Optional name for the returned operation. Defaults to the
768 name passed to the `Optimizer` constructor.
770 Returns:
771 An `Operation` that applies the specified gradients across all
772 replicas. If `global_step` was not None, that operation also
773 increments `global_step`.
774 """
775 reduced_grads = distribution.extended.batch_reduce_to(
776 ds_reduce_util.ReduceOp.SUM, grads_and_vars)
777 var_list = [v for _, v in grads_and_vars]
778 grads_and_vars = zip(reduced_grads, var_list)
780 # Note that this is called in a cross-replica context.
781 with ops.init_scope():
782 self._create_slots(var_list)
784 def update(v, g):
785 """Apply gradients to a replica variable."""
786 assert v is not None
788 try:
789 # Convert the grad to Tensor or IndexedSlices if necessary.
790 g = indexed_slices.convert_to_tensor_or_indexed_slices(g)
791 except TypeError:
792 raise TypeError("Gradient must be convertible to a Tensor"
793 " or IndexedSlices, or None: %s" % g)
794 if not isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)):
795 raise TypeError(
796 "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
797 p = _get_processor(v)
799 if context.executing_eagerly() or (
800 resource_variable_ops.is_resource_variable(v) and
801 not v._in_graph_mode): # pylint: disable=protected-access
802 scope_name = v.name.split(":")[0]
803 else:
804 scope_name = v.op.name
806 # device_policy is set because non-mirrored tensors will be read in
807 # `update_op`; for example, `_resource_apply_dense` reads `lr_t`,
808 # `beta1_t` and `beta2_t`.
809 with ops.name_scope("update_" + scope_name):
810 return p.update_op(self, g)
812 with ops.name_scope(name, self._name) as name:
813 self._prepare()
815 update_ops = [
816 op
817 for grad, var in grads_and_vars
818 for op in distribution.extended.update(
819 var, update, args=(grad,), group=False)
820 ]
822 def finish(self, update_ops):
823 return self._finish(update_ops, "update")
825 non_slot_devices = distribution.extended.non_slot_devices(var_list)
826 finish_updates = distribution.extended.update_non_slot(
827 non_slot_devices, finish, args=(self, update_ops), group=False)
828 if global_step is None:
829 apply_updates = distribution.group(finish_updates, name=name)
830 else:
831 with ops.control_dependencies(finish_updates):
832 apply_updates = distribution.extended.update(
833 global_step, state_ops.assign_add, args=(1,),
834 kwargs={"name": name})
836 if not context.executing_eagerly():
837 if isinstance(apply_updates, ops.Tensor):
838 apply_updates = apply_updates.op
839 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
840 if apply_updates not in train_op:
841 train_op.append(apply_updates)
843 return apply_updates
845 def get_slot(self, var, name):
846 """Return a slot named `name` created for `var` by the Optimizer.
848 Some `Optimizer` subclasses use additional variables. For example
849 `Momentum` and `Adagrad` use variables to accumulate updates. This method
850 gives access to these `Variable` objects if for some reason you need them.
852 Use `get_slot_names()` to get the list of slot names created by the
853 `Optimizer`.
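    For example (illustrative; `var` must already have been passed to the
    optimizer via `minimize()` or `apply_gradients()`):

    ```python
    opt = tf.compat.v1.train.AdagradOptimizer(learning_rate=0.1)
    # ... after the training op has been built ...
    accumulator = opt.get_slot(var, 'accumulator')
    ```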
855 Args:
856 var: A variable passed to `minimize()` or `apply_gradients()`.
857 name: A string.
859 Returns:
860 The `Variable` for the slot if it was created, `None` otherwise.
861 """
862 named_slots = self._slots.get(name, None)
863 if not named_slots:
864 return None
865 slot = named_slots.get(_var_key(var), None)
866 if (distribute_utils.is_distributed_variable(slot) and
867 not distribute_utils.is_distributed_variable(var)):
868 # Make sure var and slot are either both DistributedVariable, or both
869 # per replica variables.
870 slot = slot._get_on_device_or_primary() # pylint: disable=protected-access
871 return slot
873 def get_slot_names(self):
874 """Return a list of the names of slots created by the `Optimizer`.
876 See `get_slot()`.
878 Returns:
879 A list of strings.
880 """
881 return sorted(self._slots.keys())
883 def variables(self):
884 """A list of variables which encode the current state of `Optimizer`.
886 Includes slot variables and additional global variables created by the
887 optimizer in the current default graph.
889 Returns:
890 A list of variables.
891 """
892 current_graph = ops.get_default_graph()
894 def _from_current_graph(variable):
895 if variable._in_graph_mode: # pylint: disable=protected-access
896 return variable.op.graph is current_graph
897 else:
898 # No variable.op in eager mode. We don't expect lots of eager graphs,
899 # but behavior should be consistent with graph mode.
900 return variable._graph_key == current_graph._graph_key # pylint: disable=protected-access
902 optimizer_variables = [v for v in self._non_slot_variables()
903 if _from_current_graph(v)]
904 for _, variable_dict in self._slots.items():
905 for _, slot_for_variable in variable_dict.items():
906 if _from_current_graph(slot_for_variable):
907 optimizer_variables.append(slot_for_variable)
908 # Sort variables by name so that the return is deterministic.
909 return sorted(optimizer_variables, key=lambda v: v.name)
911 def _create_non_slot_variable(self, initial_value, name, colocate_with):
912 """Add an extra variable, not associated with a slot."""
913 # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
914 eager = ops.executing_eagerly_outside_functions()
915 graph = None if eager else colocate_with.graph
917 key = (name, graph)
918 v = self._non_slot_dict.get(key, None)
919 if v is None:
920 self._maybe_initialize_trackable()
921 distribution_strategy = distribute_lib.get_strategy()
922 with distribution_strategy.extended.colocate_vars_with(colocate_with):
923 if eager:
924 restored_initial_value = self._preload_simple_restoration(
925 name=name)
926 if restored_initial_value is not None:
927 initial_value = restored_initial_value
928 v = variable_v1.VariableV1(
929 initial_value, name=name, trainable=False,
930 use_resource=resource_variable_ops.is_resource_variable(
931 colocate_with))
932 # Restore this variable by name if necessary, but don't add a
933 # Trackable dependency. Optimizers return the current graph's
934 # non-slot variables from _checkpoint_dependencies explicitly rather
935 # than unconditionally adding dependencies (since there may be multiple
936 # non-slot variables with the same name in different graphs, trying to
937 # save all of them would result in errors).
938 self._handle_deferred_dependencies(name=name, trackable=v)
939 self._non_slot_dict[key] = v
941 return v
943 def _trackable_children(self,
944 save_type=trackable.SaveType.CHECKPOINT,
945 **kwargs):
946 """From Trackable. Gather graph-specific non-slot variables to save."""
947 current_graph_non_slot_variables = {}
948 current_graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
949 for (name, _), variable_object in sorted(self._non_slot_dict.items(),
950 # Avoid comparing graphs
951 key=lambda item: item[0][0]):
952 # Skip checking for graph key for eager mode since there's only one graph.
953 # This is necessary because there are cases where _trackable_children() is
954 called in a different thread from the main thread (e.g., async
955 # checkpoint) and hence the default graph key would be different.
956 if (context.executing_eagerly()
957 or variable_object._graph_key == current_graph_key): # pylint: disable=protected-access
958 current_graph_non_slot_variables[name] = variable_object
959 current_graph_non_slot_variables.update(
960 super()._trackable_children(save_type, **kwargs)
961 )
962 return current_graph_non_slot_variables
964 def _lookup_dependency(self, name):
965 """From Trackable. Find a non-slot variable in the current graph."""
966 unconditional = super()._lookup_dependency(name)
967 if unconditional is not None:
968 return unconditional
969 graph = None if context.executing_eagerly() else ops.get_default_graph()
970 return self._get_non_slot_variable(name, graph=graph)
972 def _get_non_slot_variable(self, name, graph=None):
973 non_slot = self._non_slot_dict.get((name, graph), None)
974 if distribute_utils.value_container(non_slot) is not non_slot:
975 # This is a mirrored non-slot. In order to enable code like `_finish`
976 # to assign to a non-slot, return the current context replica.
977 return non_slot.get()
978 else:
979 return non_slot
981 def _non_slot_variables(self):
982 """Additional variables created by the `Optimizer`.
984 Returns:
985 A list or tuple of variables.
986 """
987 return self._non_slot_dict.values()
989 def _assert_valid_dtypes(self, tensors):
990 """Asserts tensors are all valid types (see `_valid_dtypes`).
992 Args:
993 tensors: Tensors to check.
995 Raises:
996 ValueError: If any tensor is not a valid type.
997 """
998 valid_dtypes = self._valid_dtypes()
999 for t in tensors:
1000 dtype = t.dtype.base_dtype
1001 if dtype not in valid_dtypes:
1002 raise ValueError(
1003 "Invalid type %r for %s, expected: %s." % (
1004 dtype, t.name, [v for v in valid_dtypes]))
1006 # --------------
1007 # Methods to be implemented by subclasses if they want to use the
1008 # inherited implementation of apply_gradients() or compute_gradients().
1009 # --------------
1010 def _valid_dtypes(self):
1011 """Valid types for loss, variables and gradients.
1013 Subclasses should override to allow other float types.
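    A subclass could override it, for example, like this (a hypothetical
    sketch allowing complex gradients as well):

    ```python
    def _valid_dtypes(self):
      return set([dtypes.float16, dtypes.bfloat16, dtypes.float32,
                  dtypes.float64, dtypes.complex64])
    ```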
1015 Returns:
1016 Valid types for loss, variables and gradients.
1017 """
1018 return set(
1019 [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])
1021 def _create_slots(self, var_list):
1022 """Create all slots needed by the variables.
1024 Args:
1025 var_list: A list of `Variable` objects.
1026 """
1027 # No slots needed by default
1028 pass
1030 def _prepare(self):
1031 """Create all needed tensors before applying gradients.
1033 This is called with the name_scope using the "name" that
1034 users have chosen for the application of gradients.
1035 """
1036 pass
1038 def _apply_dense(self, grad, var):
1039 """Add ops to apply dense gradients to `var`.
1041 Args:
1042 grad: A `Tensor`.
1043 var: A `Variable` object.
1045 Returns:
1046 An `Operation`.
1047 """
1048 raise NotImplementedError()
1050 def _resource_apply_dense(self, grad, handle):
1051 """Add ops to apply dense gradients to the variable `handle`.
1053 Args:
1054 grad: a `Tensor` representing the gradient.
1055 handle: a `Tensor` of dtype `resource` which points to the variable
1056 to be updated.
1058 Returns:
1059 An `Operation` which updates the value of the variable.
1060 """
1061 raise NotImplementedError()
1063 def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
1064 """Add ops to apply sparse gradients to `handle`, with repeated indices.
1066 Optimizers which override this method must deal with repeated indices. See
1067 the docstring of `_apply_sparse_duplicate_indices` for details. By default
1068 the correct behavior, to sum non-unique indices and their associated
1069 gradients, is enforced by first pre-processing `grad` and `indices` and
1070 passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
1071 with duplicate indices may instead override this method to avoid the
1072 overhead of summing.
1074 Args:
1075 grad: a `Tensor` representing the gradient for the affected indices.
1076 handle: a `Tensor` of dtype `resource` which points to the variable
1077 to be updated.
1078 indices: a `Tensor` of integral type representing the indices for
1079 which the gradient is nonzero. Indices may be repeated.
1081 Returns:
1082 An `Operation` which updates the value of the variable.
1083 """
1084 summed_grad, unique_indices = _deduplicate_indexed_slices(
1085 values=grad, indices=indices)
1086 return self._resource_apply_sparse(summed_grad, handle, unique_indices)
1088 def _resource_apply_sparse(self, grad, handle, indices):
1089 """Add ops to apply sparse gradients to the variable `handle`.
1091 Similar to `_apply_sparse`, the `indices` argument to this method has been
1092 de-duplicated. Optimizers which deal correctly with non-unique indices may
1093 instead override `_resource_apply_sparse_duplicate_indices` to avoid this
1094 overhead.
1096 Args:
1097 grad: a `Tensor` representing the gradient for the affected indices.
1098 handle: a `Tensor` of dtype `resource` which points to the variable
1099 to be updated.
1100 indices: a `Tensor` of integral type representing the indices for
1101 which the gradient is nonzero. Indices are unique.
1103 Returns:
1104 An `Operation` which updates the value of the variable.
1105 """
1106 raise NotImplementedError()
1108 def _apply_sparse_duplicate_indices(self, grad, var):
1109 """Add ops to apply sparse gradients to `var`, with repeated sparse indices.
1111 Optimizers which override this method must deal with IndexedSlices objects
1112 such as the following:
1114 IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])
1116 The correct interpretation is:
1118 IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])
1120 Many optimizers deal incorrectly with repeated indices when updating based
1121 on sparse gradients (e.g. summing squares rather than squaring the sum, or
1122 applying momentum terms multiple times). Adding first is always the correct
1123 behavior, so this is enforced here by reconstructing the IndexedSlices to
1124 have only unique indices, then calling _apply_sparse.
1126 Optimizers which deal correctly with repeated indices may instead override
1127 this method to avoid the overhead of summing indices.
1129 Args:
1130 grad: `IndexedSlices`.
1131 var: A `Variable` object.
1133 Returns:
1134 An `Operation`.
1135 """
1136 summed_values, unique_indices = _deduplicate_indexed_slices(
1137 values=grad.values, indices=grad.indices)
1138 gradient_no_duplicate_indices = indexed_slices.IndexedSlices(
1139 indices=unique_indices,
1140 values=summed_values,
1141 dense_shape=grad.dense_shape)
1142 return self._apply_sparse(gradient_no_duplicate_indices, var)
1144 def _apply_sparse(self, grad, var):
1145 """Add ops to apply sparse gradients to `var`.
1147 The IndexedSlices object passed to `grad` in this function is by default
1148 pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
1149 indices (see its docstring for details). Optimizers which can tolerate or
1150 have correct special cases for duplicate sparse indices may override
1151 `_apply_sparse_duplicate_indices` instead of this function, avoiding that
1152 overhead.
1154 Args:
1155 grad: `IndexedSlices`, with no repeated indices.
1156 var: A `Variable` object.
1158 Returns:
1159 An `Operation`.
1160 """
1161 raise NotImplementedError()
1163 def _finish(self, update_ops, name_scope):
1164 """Do what is needed to finish the update.
1166 This is called with the `name_scope` using the "name" that
1167 users have chosen for the application of gradients.
1169 Args:
1170 update_ops: List of `Operation` objects to update variables. This list
1171 contains the values returned by the `_apply_dense()` and
1172 `_apply_sparse()` calls.
1173 name_scope: String. Name to use for the returned operation.
1175 Returns:
1176 The operation to apply updates.
1177 """
1178 return control_flow_ops.group(*update_ops, name=name_scope)
1180 # --------------
1181 # Utility methods for subclasses.
1182 # --------------
1184 def _slot_dict(self, slot_name):
1185 """Returns a dict for caching slots created under the given name.
1187 Args:
1188 slot_name: Name for the slot.
1190 Returns:
1191 A dict that maps primary `Variable` objects to the slot created
1192 for that variable, under the given slot name.
1193 """
1194 named_slots = self._slots.get(slot_name, None)
1195 if named_slots is None:
1196 named_slots = {}
1197 self._slots[slot_name] = named_slots
1198 return named_slots
1200 def _get_or_make_slot(self, var, val, slot_name, op_name):
1201 """Find or create a slot for a variable.
1203 Args:
1204 var: A `Variable` object.
1205 val: A `Tensor`. The initial value of the slot.
1206 slot_name: Name for the slot.
1207 op_name: Name to use when scoping the Variable that
1208 needs to be created for the slot.
1210 Returns:
1211 A `Variable` object.
1212 """
1213 named_slots = self._slot_dict(slot_name)
1214 if _var_key(var) not in named_slots:
1215 new_slot_variable = slot_creator.create_slot(
1216 var, val, op_name, copy_xla_sharding=True)
1217 self._restore_slot_variable(
1218 slot_name=slot_name, variable=var,
1219 slot_variable=new_slot_variable)
1220 named_slots[_var_key(var)] = new_slot_variable
1221 return named_slots[_var_key(var)]
1223 def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
1224 slot_name, op_name):
1225 """Find or create a slot for a variable, using an Initializer.
1227 Args:
1228 var: A `Variable` object.
1229 initializer: An `Initializer`. The initial value of the slot.
1230 shape: Shape of the initial value of the slot.
1231 dtype: Type of the value of the slot.
1232 slot_name: Name for the slot.
1233 op_name: Name to use when scoping the Variable that
1234 needs to be created for the slot.
1236 Returns:
1237 A `Variable` object.
1238 """
1239 named_slots = self._slot_dict(slot_name)
1240 if _var_key(var) not in named_slots:
1241 new_slot_variable = slot_creator.create_slot_with_initializer(
1242 var, initializer, shape, dtype, op_name, copy_xla_sharding=True)
1243 self._restore_slot_variable(
1244 slot_name=slot_name, variable=var,
1245 slot_variable=new_slot_variable)
1246 named_slots[_var_key(var)] = new_slot_variable
1247 return named_slots[_var_key(var)]
1249 def _zeros_slot(self, var, slot_name, op_name):
1250 """Find or create a slot initialized with 0.0.
1252 Args:
1253 var: A `Variable` object.
1254 slot_name: Name for the slot.
1255 op_name: Name to use when scoping the Variable that
1256 needs to be created for the slot.
1258 Returns:
1259 A `Variable` object.
1260 """
1261 named_slots = self._slot_dict(slot_name)
1262 if _var_key(var) not in named_slots:
1263 new_slot_variable = slot_creator.create_zeros_slot(
1264 var, op_name, copy_xla_sharding=True)
1265 self._restore_slot_variable(
1266 slot_name=slot_name, variable=var,
1267 slot_variable=new_slot_variable)
1268 named_slots[_var_key(var)] = new_slot_variable
1269 return named_slots[_var_key(var)]
1271 # --------------
1272 # For implementing the Trackable interface.
1273 # --------------
1275 def _restore_slot_variable(self, slot_name, variable, slot_variable):
1276 """Restore a newly created slot variable's value."""
1277 variable_key = _var_key(variable)
1278 deferred_restorations = self._deferred_slot_restorations.get(
1279 slot_name, {}).pop(variable_key, [])
1280 # Iterate over restores, highest restore UID first to minimize the number
1281 # of assignments.
1282 deferred_restorations.sort(key=lambda position: position.restore_uid,
1283 reverse=True)
1284 for checkpoint_position in deferred_restorations:
1285 checkpoint_position.restore(slot_variable)
1287 def _create_or_restore_slot_variable(
1288 self, slot_variable_position, slot_name, variable):
1289 """Restore a slot variable's value, possibly creating it.
1291 Called when a variable which has an associated slot variable is created or
1292 restored. When executing eagerly, we create the slot variable with a
1293 restoring initializer.
1295 No new variables are created when graph building. Instead,
1296 _restore_slot_variable catches these after normal creation and adds restore
1297 ops to the graph. This method is nonetheless important when graph building
1298 for the case when a slot variable has already been created but `variable`
1299 has just been added to a dependency graph (causing us to realize that the
1300 slot variable needs to be restored).
1302 Args:
1303 slot_variable_position: A `trackable._CheckpointPosition` object
1304 indicating the slot variable `Trackable` object to be restored.
1305 slot_name: The name of this `Optimizer`'s slot to restore into.
1306 variable: The variable object this slot is being created for.
1307 """
1308 named_slots = self._slot_dict(slot_name)
1309 variable_key = _var_key(variable)
1310 slot_variable = named_slots.get(variable_key, None)
1311 if (slot_variable is None and context.executing_eagerly() and
1312 slot_variable_position.is_simple_variable()
1313 # Defer slot variable creation if there is an active variable creator
1314 # scope. Generally we'd like to eagerly create/restore slot variables
1315 # when possible, but this may mean that scopes intended to catch
1316 # `variable` also catch its eagerly created slot variable
1317 # unintentionally (specifically make_template would add a dependency on
1318 # a slot variable if not for this case). Deferring is mostly harmless
1319 # (aside from double initialization), and makes variable creator scopes
1320 # behave the same way they do when graph building.
1321 and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
1322 initializer = trackable.CheckpointInitialValueCallable(
1323 checkpoint_position=slot_variable_position)
1324 # CheckpointInitialValueCallable will ignore the shape and dtype
1325 # parameters but they must be passed.
1326 slot_variable = self._get_or_make_slot_with_initializer(
1327 var=variable,
1328 initializer=initializer,
1329 shape=variable.shape,
1330 dtype=variable.dtype,
1331 slot_name=slot_name,
1332 op_name=self._name)
1333 # Slot variables are not owned by any one object (because we don't want to
1334 # save the slot variable if the optimizer is saved without the non-slot
1335 # variable, or if the non-slot variable is saved without the optimizer;
1336 # it's a dependency hypergraph with edges of the form (optimizer, non-slot
1337 # variable, variable)). So we don't _track_ slot variables anywhere, and
1338 # instead special-case this dependency and otherwise pretend it's a normal
1339 # graph.
1340 if slot_variable is not None:
1341 # If we've either made this slot variable, or if we've pulled out an
1342 # existing slot variable, we should restore it.
1343 slot_variable_position.restore(slot_variable)
1344 else:
1345 # We didn't make the slot variable. Defer restoring until it gets created
1346 # normally. We keep a list rather than the one with the highest restore
1347 # UID in case slot variables have their own dependencies, in which case
1348 # those could differ between restores.
1349 self._deferred_slot_restorations.setdefault(
1350 slot_name, {}).setdefault(variable_key, []).append(
1351 slot_variable_position)
1353 def _call_if_callable(self, param):
1354 """Call the function if param is callable."""
1355 return param() if callable(param) else param