Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/optimizer.py: 22% of 410 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base class of optimizer."""
17import abc
18import platform
19import re
21import tensorflow.compat.v2 as tf
22from absl import logging
24from keras.src import backend
25from keras.src import initializers
26from keras.src.dtensor import utils as dtensor_utils
27from keras.src.optimizers import utils as optimizer_utils
28from keras.src.optimizers.schedules import learning_rate_schedule
29from keras.src.utils import tf_utils
31# isort: off
32from tensorflow.python.util.tf_export import keras_export
33from tensorflow.tools.docs import doc_controls
36class _BaseOptimizer(tf.__internal__.tracking.AutoTrackable):
37 """Optimizer base class, which only supports non-distribute use case."""
39 def __init__(
40 self,
41 name,
42 weight_decay=None,
43 clipnorm=None,
44 clipvalue=None,
45 global_clipnorm=None,
46 use_ema=False,
47 ema_momentum=0.99,
48 ema_overwrite_frequency=None,
49 jit_compile=True,
50 **kwargs,
51 ):
52 self.name = name
53 self.weight_decay = weight_decay
54 self.clipnorm = clipnorm
55 self.global_clipnorm = global_clipnorm
56 self.clipvalue = clipvalue
57 self.use_ema = use_ema
58 # Optimizer only benefits from XLA when training on GPU. So if no
59 # GPU is found, we turn off XLA.
60 if (
61 jit_compile
62 and tf_utils.can_jit_compile()
63 and tf.config.list_physical_devices("GPU")
64 ):
65 self.jit_compile = True
66 else:
67 self.jit_compile = False
69 if platform.system() == "Darwin" and platform.processor() == "arm":
70 logging.warning(
71 "At this time, the v2.11+ optimizer "
72 f"`tf.keras.optimizers.{self.__class__.__name__}` runs slowly "
73 "on M1/M2 Macs, please use the legacy Keras optimizer "
74 "instead, located at "
75 f"`tf.keras.optimizers.legacy.{self.__class__.__name__}`."
76 )
78 if use_ema:
79 # Verify the arguments related to EMA.
80 if ema_momentum > 1 or ema_momentum < 0:
81 raise ValueError(
82 "`ema_momentum` must be in the range [0, 1]. "
83 f"Received: ema_momentum={ema_momentum}"
84 )
85 if ema_overwrite_frequency and (
86 not isinstance(ema_overwrite_frequency, int)
87 or ema_overwrite_frequency < 1
88 ):
89 raise ValueError(
90 "`ema_overwrite_frequency` must be an integer > 1 or None. "
91 "Received: ema_overwrite_frequency="
92 f"{ema_overwrite_frequency}"
93 )
94 self.ema_momentum = ema_momentum
95 self.ema_overwrite_frequency = ema_overwrite_frequency
97 if self.clipnorm is not None and self.global_clipnorm is not None:
98 raise ValueError(
99 "At most one of `clipnorm` and `global_clipnorm` can "
100 f"be set. Received: clipnorm={self.clipnorm}, "
101 f"global_clipnorm={self.global_clipnorm}."
102 )
104 self._variables = []
105 self._create_iteration_variable()
106 self._process_kwargs(kwargs)
108 def _create_iteration_variable(self):
109 """Create the iterations counter variable."""
110 with tf.init_scope():
111 # Lift the variable creation to init scope to avoid environment
112 # issue.
113 self._iterations = tf.Variable(
114 0, name="iteration", dtype=tf.int64, trainable=False
115 )
116 self._variables.append(self._iterations)
118 def _process_kwargs(self, kwargs):
119 # Remove the `is_legacy_optimizer` arg, which is for serialization only.
120 kwargs.pop("is_legacy_optimizer", None)
121 lr = kwargs.pop("lr", None)
122 if lr:
123 logging.warning(
124 "`lr` is deprecated in Keras optimizer, please use "
125 "`learning_rate` or use the legacy optimizer, e.g.,"
126 f"tf.keras.optimizers.legacy.{self.__class__.__name__}."
127 )
128 legacy_kwargs = {
129 "decay",
130 "gradient_aggregator",
131 "gradient_transformers",
132 }
133 for k in kwargs:
134 if k in legacy_kwargs:
135 raise ValueError(
136 f"{k} is deprecated in the new Keras optimizer, please "
137 "check the docstring for valid arguments, or use the "
138 "legacy optimizer, e.g., "
139 f"tf.keras.optimizers.legacy.{self.__class__.__name__}."
140 )
141 else:
142 raise TypeError(
143 f"{k} is not a valid argument, kwargs should be empty "
144 " for `optimizer_experimental.Optimizer`."
145 )
147 def _create_or_restore_slot_variable(self, **kwargs):
148 raise ValueError(
149 "You are trying to restore a checkpoint from a legacy Keras "
150 "optimizer into a v2.11+ Optimizer, which can cause "
151 "errors. Please update the optimizer referenced in your code "
152 "to be an instance of "
153 "`tf.keras.optimizers.legacy.Optimizer`, e.g.: "
154 f"`tf.keras.optimizers.legacy.{self.__class__.__name__}`."
155 )
157 def _var_key(self, variable):
158 """Get a unique identifier of the given variable."""
159 # Get the distributed variable if it exists.
160 # TODO(b/199214315): replace _unique_id with ref() after fixing ref()
161 # issues on AggregatingVariable.
162 return variable._unique_id
164 def _deduplicate_sparse_grad(self, grads):
165 """Deduplicate sparse gradient.
167 For sparse gradients, i.e., gradient is of type `tf.IndexedSlices`,
168 it is possible that `gradient.indices` has duplicated indices.
169 This function adds up the values for the duplicated indices, and returns
170 a `tf.IndexedSlices` with unique indices.
171 """
172 processed_grads = []
173 for grad in grads:
174 if isinstance(grad, tf.IndexedSlices):
175 values = grad.values
176 indices = grad.indices
177 unique_indices, new_index_positions = tf.unique(indices)
178 summed_values = tf.math.unsorted_segment_sum(
179 values, new_index_positions, tf.shape(unique_indices)[0]
180 )
181 processed_grads.append(
182 tf.IndexedSlices(
183 summed_values, unique_indices, grad.dense_shape
184 )
185 )
186 else:
187 processed_grads.append(grad)
189 return processed_grads
191 @abc.abstractmethod
192 def update_step(self, gradient, variable):
193 """Function to update variable value based on given gradients.
195 This method must be implemented in customized optimizers.
197 Args:
198 gradient: backpropagated gradient of the given variable.
199 variable: variable whose value needs to be updated.
201 Returns:
202 An `Operation` that applies the specified gradients.
204 """
205 raise NotImplementedError
207 @tf.function(jit_compile=True)
208 def _update_step_xla(self, gradient, variable, key):
209 """A wrapper of `update_step` to enable XLA acceleration.
211 Due to the `tf.function` tracing mechanism, for (gradient, variable) pairs
212 of the same shape and dtype, the execution graph always invokes the first
213 pair it has seen. Thus, we need a `key` argument to make each (gradient,
214 variable) pair unique. In addition, XLA cannot understand string input,
215 so the key is an integer.
217 Args:
218 gradient: backpropagated gradient of the given variable.
219 variable: variable whose value needs to be updated.
220 key (int): a unique key that identifies the variable.
222 Returns:
223 An `Operation` that applies the specified gradients.
224 """
225 return self._update_step(gradient, variable)
227 def _update_step(self, gradient, variable):
228 if getattr(variable, "_unique_id", None) is None:
229 # Variable has no `_unique_id` if called during `model.save()`, in
230 # which case we do not want to update the variable.
231 return
232 if self._var_key(variable) not in self._index_dict:
233 raise KeyError(
234 f"The optimizer cannot recognize variable {variable.name}. "
235 "This usually means you are trying to call the optimizer to "
236 "update different parts of the model separately. Please call "
237 "`optimizer.build(variables)` with the full list of trainable "
238 "variables before the training loop or use legacy optimizer "
239 f"`tf.keras.optimizers.legacy.{self.__class__.__name__}."
240 )
241 self.update_step(gradient, variable)
243 def compute_gradients(self, loss, var_list, tape=None):
244 """Compute gradients of loss on trainable variables.
246 Args:
247 loss: `Tensor` or callable. If a callable, `loss` should take no
248 arguments and return the value to minimize.
249 var_list: list or tuple of `Variable` objects to update to minimize
250 `loss`, or a callable returning the list or tuple of `Variable`
251 objects. Use a callable when the variable list would otherwise be
252 incomplete before `minimize`, since the variables are created the
253 first time `loss` is called.
254 tape: (Optional) `tf.GradientTape`. If `loss` is provided as a
255 `Tensor`, the tape that computed the `loss` must be provided.
257 Returns:
258 A list of (gradient, variable) pairs. Variable is always present, but
259 gradient can be `None`.
260 """
261 if not callable(loss) and tape is None:
262 raise ValueError(
263 "`tape` is required when a `Tensor` loss is passed. "
264 f"Received: loss={loss}, tape={tape}."
265 )
266 if tape is None:
267 tape = tf.GradientTape()
268 if callable(loss):
269 with tape:
270 if not callable(var_list):
271 tape.watch(var_list)
272 loss = loss()
273 if callable(var_list):
274 var_list = var_list()
276 grads = tape.gradient(loss, var_list)
277 return list(zip(grads, var_list))
279 def _clip_gradients(self, grads):
280 clipped_grads = []
281 if self.clipnorm and self.clipnorm > 0:
282 for g in grads:
283 if g is None:
284 clipped_grads.append(g)
285 else:
286 clipped_grads.append(tf.clip_by_norm(g, self.clipnorm))
287 return clipped_grads
289 if self.global_clipnorm and self.global_clipnorm > 0:
290 return tf.clip_by_global_norm(grads, self.global_clipnorm)[0]
292 if self.clipvalue and self.clipvalue > 0:
293 for g in grads:
294 if g is None:
295 clipped_grads.append(g)
296 else:
297 clipped_grads.append(
298 tf.clip_by_value(
299 g,
300 clip_value_min=-self.clipvalue,
301 clip_value_max=self.clipvalue,
302 )
303 )
304 return clipped_grads
306 return grads
308 @property
309 def iterations(self):
310 """The number of training steps this `optimizer` has run.
312 By default, `iterations` is incremented by one every time
313 `apply_gradients()` is called.
314 """
315 return self._iterations
317 @iterations.setter
318 def iterations(self, variable):
319 if getattr(self, "_built", False):
320 raise RuntimeError(
321 "Cannot set `iterations` to a new Variable after "
322 "the Optimizer weights have been created. Here it is "
323 f"attempting to set `iterations` to {variable}."
324 "Usually this means you are trying to set `iterations`"
325 " after calling `apply_gradients()`. Please set "
326 "`iterations` before calling `apply_gradients()`."
327 )
328 self._iterations = variable
330 @property
331 def learning_rate(self):
332 if not hasattr(self, "_learning_rate") or self._learning_rate is None:
333 raise ValueError(
334 "Missing learning rate, please set self.learning_rate at"
335 " optimizer creation time."
336 )
337 lr = self._learning_rate
338 if isinstance(lr, learning_rate_schedule.LearningRateSchedule):
339 # If the optimizer takes in LearningRateSchedule, then each call to
340 # learning_rate would return `self._current_learning_rate`, which is
341 # updated at each call to `apply_gradients`.
342 return self._current_learning_rate
343 return lr
345 @learning_rate.setter
346 def learning_rate(self, learning_rate):
347 if isinstance(
348 learning_rate, learning_rate_schedule.LearningRateSchedule
349 ):
350 self._learning_rate = learning_rate
351 else:
352 if isinstance(
353 self._learning_rate, learning_rate_schedule.LearningRateSchedule
354 ):
355 raise TypeError(
356 "This optimizer was created with a `LearningRateSchedule`"
357 " object as its `learning_rate` constructor argument, "
358 "hence its learning rate is not settable. If you need the"
359 " learning rate to be settable, you should instantiate "
360 "the optimizer with a float `learning_rate` argument."
361 )
362 self._learning_rate.assign(learning_rate)
364 @property
365 @doc_controls.do_not_generate_docs
366 def lr(self):
367 """Alias of `learning_rate()`.
369 `lr()` is heavily called in workflows using `optimizer_v2.OptimizerV2`,
370 so we keep it for backward compabitliy.
371 """
372 return self.learning_rate
374 @lr.setter
375 def lr(self, learning_rate):
376 self.learning_rate = learning_rate
378 def _build_learning_rate(self, learning_rate):
379 with tf.init_scope():
380 if isinstance(
381 learning_rate, learning_rate_schedule.LearningRateSchedule
382 ):
383 # Create a variable to hold the current learning rate.
384 current_learning_rate = tf.convert_to_tensor(
385 learning_rate(self.iterations)
386 )
387 self._current_learning_rate = tf.Variable(
388 current_learning_rate,
389 name="current_learning_rate",
390 dtype=current_learning_rate.dtype,
391 trainable=False,
392 )
393 return learning_rate
395 return tf.Variable(
396 learning_rate,
397 name="learning_rate",
398 dtype=backend.floatx(),
399 trainable=False,
400 )
402 @abc.abstractmethod
403 def build(self, var_list):
404 """Initialize the optimizer's variables, such as momemtum variables.
406 This function has to be implemented by subclass optimizers, and subclass
407 optimizers need to call `super().build(var_list)`.
409 Args:
410 var_list: List of model variables to build optimizers on. For example,
411 SGD optimizer with momentum will store one momentum variable
412 corresponding to each model variable.
413 """
414 if getattr(self, "_built", False):
415 return
416 self._build_index_dict(var_list)
417 if self.use_ema:
418 self._model_variables_moving_average = []
419 for var in var_list:
420 # Make a copy of the model variables, we will use the copy to
421 # store the moving average of model variables.
422 self._model_variables_moving_average.append(
423 self.add_variable_from_reference(
424 var, "average", initial_value=var
425 )
426 )
428 def _build_index_dict(self, var_list):
429 """Build variable to index dictionary.
431 Build a dictionary that maps variable to the index of it in the given
432 var_list.
434 Args:
435 var_list: List of variables to build index dict on.
437 Returns:
438 None
439 """
440 self._index_dict = {}
441 for i, var in enumerate(var_list):
442 var_key = self._var_key(var)
443 self._index_dict[var_key] = i
445 def add_variable(self, shape, dtype=None, initializer="zeros", name=None):
446 """Create an optimizer variable.
448 Args:
449 shape: A list of integers, a tuple of integers, or a 1-D Tensor of
450 type int32. Defaults to scalar if unspecified.
451 dtype: The DType of the optimizer variable to be created. Defaults to
452 `tf.keras.backend.floatx` if unspecified.
453 initializer: string or callable. Initializer instance.
454 name: The name of the optimizer variable to be created.
456 Returns:
457 An optimizer variable, in the format of tf.Variable.
459 """
460 if isinstance(initializer, str):
461 initializer = initializers.get(initializer)
462 if dtype is None:
463 dtype = backend.floatx()
464 if shape is None:
465 shape = []
466 variable = tf.Variable(
467 initial_value=initializer(shape, dtype), name=name, trainable=False
468 )
469 self._variables.append(variable)
470 return variable
472 def add_variable_from_reference(
473 self, model_variable, variable_name, shape=None, initial_value=None
474 ):
475 """Create an optimizer variable from model variable.
477 Create an optimizer variable based on the information of model variable.
478 For example, in SGD optimizer momemtum, for each model variable, a
479 corresponding momemtum variable is created of the same shape and dtype.
481 Args:
482 model_variable: tf.Variable. The corresponding model variable to the
483 optimizer variable to be created.
484 variable_name: String. The name prefix of the optimizer variable to be
485 created. The created variable's name will follow the pattern
486 `{variable_name}/{model_variable.name}`, e.g., `momentum/dense_1`.
487 shape: List or Tuple, defaults to None. The shape of the optimizer
488 variable to be created. If None, the created variable will have the
489 same shape as `model_variable`.
490 initial_value: A Tensor, or Python object convertible to a Tensor,
491 defaults to None. The initial value of the optimizer variable; if
492 None, the initial value defaults to 0.
494 Returns:
495 An optimizer variable.
496 """
497 if initial_value is None:
498 if shape is None:
499 if model_variable.shape.rank is None:
500 # When the rank is None, we cannot get a concrete
501 # `model_variable.shape`, we use dynamic shape.
502 initial_value = tf.zeros_like(
503 model_variable, dtype=model_variable.dtype
504 )
505 else:
506 # We cannot always use `zeros_like`, because in some cases
507 # the shape exists while the values don't.
508 initial_value = tf.zeros(
509 model_variable.shape, dtype=model_variable.dtype
510 )
511 else:
512 initial_value = tf.zeros(shape, dtype=model_variable.dtype)
513 variable = tf.Variable(
514 initial_value=initial_value,
515 name=f"{variable_name}/{model_variable._shared_name}",
516 dtype=model_variable.dtype,
517 trainable=False,
518 )
519 self._variables.append(variable)
520 return variable
522 def minimize(self, loss, var_list, tape=None):
523 """Minimize `loss` by updating `var_list`.
525 This method simply computes gradients using `tf.GradientTape` and calls
526 `apply_gradients()`. If you want to process the gradients before applying
527 them, use `tf.GradientTape` and `apply_gradients()` explicitly instead
528 of using this function.
530 Args:
531 loss: `Tensor` or callable. If a callable, `loss` should take no
532 arguments and return the value to minimize.
533 var_list: list or tuple of `Variable` objects to update to minimize
534 `loss`, or a callable returning the list or tuple of `Variable`
535 objects. Use a callable when the variable list would otherwise be
536 incomplete before `minimize`, since the variables are created the
537 first time `loss` is called.
538 tape: (Optional) `tf.GradientTape`.
540 Returns:
541 None
542 """
543 grads_and_vars = self.compute_gradients(loss, var_list, tape)
544 self.apply_gradients(grads_and_vars)
546 def _compute_current_learning_rate(self):
547 if isinstance(
548 self._learning_rate, learning_rate_schedule.LearningRateSchedule
549 ):
550 # Compute the current learning rate at the beginning of variable
551 # update.
552 if hasattr(self, "_current_learning_rate"):
553 self._current_learning_rate.assign(
554 self._learning_rate(self.iterations)
555 )
556 else:
557 current_learning_rate = tf.convert_to_tensor(
558 self._learning_rate(self.iterations)
559 )
560 self._current_learning_rate = tf.Variable(
561 current_learning_rate,
562 name="current_learning_rate",
563 dtype=current_learning_rate.dtype,
564 trainable=False,
565 )
567 def exclude_from_weight_decay(self, var_list=None, var_names=None):
568 """Exclude variables from weight decay.
570 This method must be called before the optimizer's `build` method is
571 called. You can either pass specific variables to exclude, or provide a
572 list of strings as anchor words; if any of them appears in a variable's
573 name, the variable is excluded.
575 Args:
576 var_list: A list of `tf.Variable`s to exclude from weight decay.
577 var_names: A list of strings. If any string in `var_names` appears
578 in the model variable's name, then this model variable is
579 excluded from weight decay. For example, `var_names=['bias']`
580 excludes all bias variables from weight decay.
581 """
582 if hasattr(self, "_built") and self._built:
583 raise ValueError(
584 "`exclude_from_weight_decay()` can only be configued before "
585 "the optimizer is built."
586 )
588 if var_list:
589 self._exclude_from_weight_decay = [
590 self._var_key(variable) for variable in var_list
591 ]
592 else:
593 self._exclude_from_weight_decay = []
594 self._exclude_from_weight_decay_names = var_names or []
596 def _use_weight_decay(self, variable):
597 exclude_from_weight_decay = getattr(
598 self, "_exclude_from_weight_decay", []
599 )
600 exclude_from_weight_decay_names = getattr(
601 self, "_exclude_from_weight_decay_names", []
602 )
603 variable_id = self._var_key(variable)
604 for exclude_id in exclude_from_weight_decay:
605 if variable_id == exclude_id:
606 return False
607 for name in exclude_from_weight_decay_names:
608 if re.search(name, variable.name) is not None:
609 return False
610 return True
612 def apply_gradients(self, grads_and_vars, name=None):
613 """Apply gradients to variables.
615 Args:
616 grads_and_vars: List of `(gradient, variable)` pairs.
617 name: string, defaults to None. The name of the name scope to
618 use when creating variables. If None, `self.name` will be used.
620 Returns:
621 A `tf.Variable`, representing the current iteration.
623 Raises:
624 TypeError: If `grads_and_vars` is malformed.
625 """
626 self._compute_current_learning_rate()
627 grads_and_vars = list(grads_and_vars)
628 if len(grads_and_vars) == 0:
629 # It is possible that the grad is empty. In this case,
630 # `apply_gradients` is a no-op.
631 return self._iterations
632 grads, trainable_variables = zip(*grads_and_vars)
633 scope_name = name or self.name or "optimizer"
634 with tf.name_scope(scope_name):
635 with tf.init_scope():
636 # Lift variable creation to init scope to avoid environment
637 # issues.
638 self.build(trainable_variables)
639 grads_and_vars = optimizer_utils.filter_empty_gradients(
640 grads_and_vars
641 )
642 if len(list(grads_and_vars)) == 0:
643 # Check again after filtering gradients.
644 return self._iterations
646 grads, trainable_variables = zip(*grads_and_vars)
648 grads = self._clip_gradients(grads)
649 grads = self._deduplicate_sparse_grad(grads)
650 self._apply_weight_decay(trainable_variables)
651 grads_and_vars = list(zip(grads, trainable_variables))
652 iteration = self._internal_apply_gradients(grads_and_vars)
654 # Apply variable constraints after applying gradients.
655 for variable in trainable_variables:
656 if variable.constraint is not None:
657 variable.assign(variable.constraint(variable))
658 return iteration
660 def _apply_weight_decay(self, variables):
661 if self.weight_decay is None:
662 return
663 for variable in variables:
664 if self._use_weight_decay(variable):
665 lr = tf.cast(self.learning_rate, variable.dtype)
666 wd = tf.cast(self.weight_decay, variable.dtype)
667 variable.assign_sub(variable * wd * lr)
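        # Decoupled weight decay: each decayed variable is updated as
        # `var <- var - learning_rate * weight_decay * var`, independently of
        # the gradient. For example, with learning_rate=1 and
        # weight_decay=0.004, a variable holding 2.0 becomes
        # 2.0 - 1 * 0.004 * 2.0 = 1.992 (see the weight decay example in the
        # `Optimizer` class docstring below).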
669 def _internal_apply_gradients(self, grads_and_vars):
670 """Helper function of apply gradients.
672 This is required for separating out distributed training logic.
674 Args:
675 grads_and_vars: List of (gradient, variable) pairs.
676 """
677 if self.jit_compile:
678 for grad, var in grads_and_vars:
679 self._update_step_xla(grad, var, id(self._var_key(var)))
680 else:
681 for grad, var in grads_and_vars:
682 self._update_step(grad, var)
683 return self.iterations.assign_add(1)
685 def _update_model_variables_moving_average(self, var_list):
686 """Update the stored moving average using the latest value."""
687 if self.use_ema:
688 for var, average in zip(
689 var_list, self._model_variables_moving_average
690 ):
691 average.assign(
692 self.ema_momentum * average + (1 - self.ema_momentum) * var
693 )
695 def _overwrite_model_variables_with_average_value(self, var_list):
696 """Overwrite model variables with its moving average."""
697 if len(var_list) != len(self._model_variables_moving_average):
698 raise ValueError(
699 f"The length of model variables ({len(var_list)}) to "
700 "override does not match the length of model variables "
701 "stored in the optimizer "
702 f"({len(self._model_variables_moving_average)}). Please "
703 "check if the optimizer was called on your model."
704 )
705 self._overwrite_model_variables_with_average_value_helper(var_list)
707 def _overwrite_model_variables_with_average_value_helper(self, var_list):
708 """Helper function that overwrites model variables."""
709 for var, average_var in zip(
710 var_list, self._model_variables_moving_average
711 ):
712 var.assign(average_var)
714 def finalize_variable_values(self, var_list):
715 """Set the final value of model's trainable variables.
717 Sometimes there are some extra steps before ending the variable updates,
718 such as overriding the model variables with its average value.
720 Args:
721 var_list: list of model variables.
722 """
723 if self.use_ema:
724 # If the optimizer uses EMA, then when finalizing, we replace the
725 # model variable value with its moving average stored inside
726 # optimizer.
727 self._overwrite_model_variables_with_average_value(var_list)
729 def _serialize_hyperparameter(self, hyperparameter):
730 """Serialize a hyperparameter that can be a numeric or callable."""
731 if isinstance(
732 hyperparameter, learning_rate_schedule.LearningRateSchedule
733 ):
734 return learning_rate_schedule.serialize(hyperparameter)
735 if isinstance(hyperparameter, tf.Variable):
736 return hyperparameter.numpy()
737 if callable(hyperparameter):
738 return hyperparameter()
739 return hyperparameter
741 def get_config(self):
742 """Returns the config of the optimizer.
744 An optimizer config is a Python dictionary (serializable)
745 containing the configuration of an optimizer.
746 The same optimizer can be reinstantiated later
747 (without any saved state) from this configuration.
749 Subclassed optimizers should override this method to include their
750 own hyperparameters.
752 Returns:
753 Python dictionary.
754 """
755 config = {
756 "name": self.name,
757 "weight_decay": self.weight_decay,
758 "clipnorm": self.clipnorm,
759 "global_clipnorm": self.global_clipnorm,
760 "clipvalue": self.clipvalue,
761 "use_ema": self.use_ema,
762 "ema_momentum": self.ema_momentum,
763 "ema_overwrite_frequency": self.ema_overwrite_frequency,
764 "jit_compile": self.jit_compile,
765 "is_legacy_optimizer": False,
766 }
767 return config
769 @classmethod
770 def from_config(cls, config, custom_objects=None):
771 """Creates an optimizer from its config.
773 This method is the reverse of `get_config`, capable of instantiating the
774 same optimizer from the config dictionary.
776 Args:
777 config: A Python dictionary, typically the output of get_config.
778 custom_objects: A Python dictionary mapping names to additional
779 user-defined Python objects needed to recreate this optimizer.
781 Returns:
782 An optimizer instance.
783 """
784 if "learning_rate" in config:
785 if isinstance(config["learning_rate"], dict):
786 config["learning_rate"] = learning_rate_schedule.deserialize(
787 config["learning_rate"], custom_objects=custom_objects
788 )
789 return cls(**config)
791 @property
792 def variables(self):
793 """Returns variables of this optimizer."""
794 return CallableList(self._variables)
796 def set_weights(self, weights):
797 """Set the weights of the optimizer.
799 Args:
800 weights: a list of `tf.Variable`s or numpy arrays, the target values
801 of optimizer variables. It should have the same order as
802 `self._variables`.
803 """
804 if not getattr(self, "_built", False):
805 raise ValueError(
806 "You are calling `set_weights()` on an optimizer that has not "
807 "yet been built. Please call "
808 "`optimizer.build(trainable_variables)` to create the "
809 "optimizer weights before calling `set_weights()`."
810 )
812 for variable, weight in zip(self._variables, weights):
813 if variable.shape != weight.shape:
814 raise ValueError(
815 f"Optimizer variable {self._var_key(variable)} has shape "
816 f"{str(variable.shape)} not compatible with provided "
817 f"weight shape {str(weight.shape)}."
818 )
819 variable.assign(weight)
821 def save_own_variables(self, store):
822 """Get the state of this optimizer object."""
823 for i, variable in enumerate(self.variables):
824 store[str(i)] = variable.numpy()
826 def load_own_variables(self, store):
827 """Set the state of this optimizer object."""
828 if len(store.keys()) != len(self.variables):
829 msg = (
830 f"Skipping variable loading for optimizer '{self.name}', "
831 f"because it has {len(self.variables)} variables whereas "
832 f"the saved optimizer has {len(store.keys())} variables. "
833 )
834 if len(self.variables) == 0:
835 msg += (
836 "This is likely because the optimizer has not been "
837 "called/built yet."
838 )
839 logging.warning(msg)
840 return
841 for i, variable in enumerate(self.variables):
842 variable.assign(store[str(i)])
845base_optimizer_keyword_args = """name: String. The name to use
846 for momentum accumulator weights created by
847 the optimizer.
848 weight_decay: Float, defaults to None. If set, weight decay is applied.
849 clipnorm: Float. If set, the gradient of each weight is individually
850 clipped so that its norm is no higher than this value.
851 clipvalue: Float. If set, the gradient of each weight is clipped
852 element-wise to the range `[-clipvalue, clipvalue]`.
853 global_clipnorm: Float. If set, the gradient of all weights is clipped so
854 that their global norm is no higher than this value.
855 use_ema: Boolean, defaults to False. If True, exponential moving average
856 (EMA) is applied. EMA consists of computing an exponential moving
857 average of the weights of the model (as the weight values change after
858 each training batch), and periodically overwriting the weights with
859 their moving average.
860 ema_momentum: Float, defaults to 0.99. Only used if `use_ema=True`.
861 This is the momentum to use when computing
862 the EMA of the model's weights:
863 `new_average = ema_momentum * old_average + (1 - ema_momentum) *
864 current_variable_value`.
865 ema_overwrite_frequency: Int or None, defaults to None. Only used if
866 `use_ema=True`. Every `ema_overwrite_frequency` steps of iterations,
867 we overwrite the model variables with their moving averages.
868 If None, the optimizer
869 does not overwrite model variables in the middle of training, and you
870 need to explicitly overwrite the variables at the end of training
871 by calling `optimizer.finalize_variable_values()`
872 (which updates the model
873 variables in-place). When using the built-in `fit()` training loop,
874 this happens automatically after the last epoch,
875 and you don't need to do anything.
876 jit_compile: Boolean, defaults to True.
877 If True, the optimizer will use XLA
878 compilation. If no GPU device is found, this flag will be ignored.
879 mesh: optional `tf.experimental.dtensor.Mesh` instance. When provided,
880 the optimizer will be run in DTensor mode, e.g. state
881 tracking variable will be a DVariable, and aggregation/reduction will
882 happen in the global DTensor context.
883 **kwargs: keyword arguments only used for backward compatibility."""
886@keras_export(
887 "keras.optimizers.Optimizer",
888 "keras.optimizers.experimental.Optimizer",
889 v1=[],
890)
891class Optimizer(_BaseOptimizer):
892 """Abstract optimizer base class.
894 This class supports distributed training. If you want to implement your own
895 optimizer, please subclass this class instead of _BaseOptimizer.
897 Args:
898 {{base_optimizer_keyword_args}}
900 ### Usage
902 ```python
903 # Create an optimizer with the desired parameters.
904 opt = keras.optimizers.SGD(learning_rate=0.1)
905 var1, var2 = tf.Variable(1.0), tf.Variable(2.0)
906 # `loss` is a callable that takes no argument and returns the value
907 # to minimize.
908 loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
909 # Call minimize to update the list of variables.
910 opt.minimize(loss, var_list=[var1, var2])
911 ```
913 ### Processing gradients before applying them
915 Calling `minimize()` takes care of both computing the gradients and
916 applying them to the variables. If you want to process the gradients
917 before applying them you can instead use the optimizer in three steps:
919 1. Compute the gradients with `tf.GradientTape`.
920 2. Process the gradients as you wish.
921 3. Apply the processed gradients with `apply_gradients()`.
923 Example:
925 ```python
926 # Create an optimizer.
927 opt = tf.keras.optimizers.experimental.SGD(learning_rate=0.1)
928 var1, var2 = tf.Variable(1.0), tf.Variable(2.0)
930 # Compute the gradients for a list of variables.
931 with tf.GradientTape() as tape:
932 loss = 3 * var1 * var1 + 2 * var2 * var2
933 grads = tape.gradient(loss, [var1, var2])
935 # Process the gradients.
936 grads[0] = grads[0] + 1
938 # Ask the optimizer to apply the gradients on variables.
939 opt.apply_gradients(zip(grads, [var1, var2]))
940 ```
942 ### Dynamic learning rate
944 Dynamic learning rate can be achieved by setting learning rate as a built-in
945 or customized `tf.keras.optimizers.schedules.LearningRateSchedule`.
947 Example:
949 >>> var = tf.Variable(np.random.random(size=(1,)))
950 >>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
951 ... initial_learning_rate=.01, decay_steps=20, decay_rate=.1)
952 >>> opt = tf.keras.optimizers.experimental.SGD(learning_rate=learning_rate)
953 >>> loss = lambda: 3 * var
954 >>> opt.minimize(loss, var_list=[var])
956 ### Gradient clipping
958 Users can clip the gradients before applying them to variables by setting
959 `clipnorm`, `clipvalue` or `global_clipnorm`. Note that at most one of
960 `clipnorm` and `global_clipnorm` can be set.
962 Example:
964 >>> opt = tf.keras.optimizers.experimental.SGD(learning_rate=1, clipvalue=1)
965 >>> var1, var2 = tf.Variable(2.0), tf.Variable(2.0)
966 >>> with tf.GradientTape() as tape:
967 ... loss = 2 * var1 + 2 * var2
968 >>> grads = tape.gradient(loss, [var1, var2])
969 >>> print([grads[0].numpy(), grads[1].numpy()])
970 [2.0, 2.0]
971 >>> opt.apply_gradients(zip(grads, [var1, var2]))
972 >>> # Without clipping, we should get [0, 0], but as gradients are clipped
973 >>> # to have max value 1, we get [1.0, 1.0].
974 >>> print([var1.numpy(), var2.numpy()])
975 [1.0, 1.0]
977 ### Using weight decay
979 Weight decay can boost the model's performance in certain scenarios. Keras
980 has built-in support for weight decay in all optimizers. Users can apply
981 weight decay by setting the `weight_decay` argument.
983 >>> opt = tf.keras.optimizers.experimental.SGD(1, weight_decay=0.004)
984 >>> grads, var1, var2 = tf.zeros(()), tf.Variable(2.0), tf.Variable(2.0)
985 >>> # You can exclude variables from weight decay, in this case we
986 >>> # exclude `var2`.
987 >>> opt.exclude_from_weight_decay(var_list=[var2])
988 >>> opt.apply_gradients(zip([grads, grads], [var1, var2]))
989 >>> print([var1.numpy(), var2.numpy()])
990 [1.992, 2.0]
993 ### Using exponential moving average
995 Empirically it has been found that using the exponential moving average
996 (EMA) of the trained parameters of a deep network achieves better
997 performance than using the trained parameters directly. Keras optimizers
998 allow users to compute this moving average and overwrite the model
999 variables at the desired time.
1001 Example:
1003 ```python
1004 # Create an SGD optimizer with EMA on. `ema_momentum` controls the decay
1005 # rate of the moving average. `ema_momentum=1` means no decay and the stored
1006 # moving average is always the model variable's initial value before training.
1007 # Conversely, `ema_momentum=0` is equivalent to not using EMA.
1008 # `ema_overwrite_frequency=3` means every 3 iterations, we overwrite the
1009 # trainable variables with their moving average values.
1010 opt = tf.keras.optimizers.experimental.SGD(
1011 learning_rate=1,
1012 use_ema=True,
1013 ema_momentum=0.5,
1014 ema_overwrite_frequency=3)
1015 var1, var2 = tf.Variable(2.0), tf.Variable(2.0)
1016 with tf.GradientTape() as tape:
1017 loss = var1 + var2
1018 grads = tape.gradient(loss, [var1, var2])
1019 # First iteration: [var1, var2] = [1.0, 1.0]
1020 opt.apply_gradients(zip(grads, [var1, var2]))
1021 print([var1, var2])
1023 # Second iteration: [var1, var2] = [0.0, 0.0]
1024 opt.apply_gradients(zip(grads, [var1, var2]))
1025 print([var1, var2])
1027 # Third iteration, without EMA, we should see [var1, var2] = [-1.0, -1.0],
1028 # but overwriting results in [var1, var2] = [-0.125, -0.125]. The full
1029 # calculation for the moving average of var1 is:
1030 # var1=2*0.5**3+1*(1-0.5)*0.5**2+0*(1-0.5)*0.5**1+(-1)*(1-0.5)=-0.125.
1031 opt.apply_gradients(zip(grads, [var1, var2]))
1032 print([var1, var2])
1034 ```
1035 When the optimizer is constructed with `use_ema=True`, in a custom training
1036 loop users can explicitly call `finalize_variable_values()` to overwrite
1037 trainable variables with their EMA values. `finalize_variable_values()` is
1038 called by default at the end of `model.fit()`.
1040 ### Use with `tf.distribute.Strategy`
1042 This optimizer class is `tf.distribute.Strategy` aware, which means it
1043 automatically sums gradients across all replicas. To aggregate gradients
1044 yourself, call `apply_gradients` with `skip_gradients_aggregation` set to
1045 True. This is useful if you need to process aggregated gradients.
1047 ```python
1048 # This example is not runnable; it consists of dummy code for a simple
1049 # tutorial.
1050 strategy = tf.distribute.experimental.TPUStrategy()
1052 with strategy.scope():
1053 opt = tf.keras.optimizers.experimental.SGD()
1054 model = magic_function_that_returns_model()
1055 gradients = magic_function_that_returns_gradients()
1056 # Custom logic to aggregate gradients.
1057 gradients = strategy.reduce("SUM", gradients, axis=None)
1058 opt.apply_gradients(zip(gradients, model.trainable_variables),
1059 skip_gradients_aggregation=True)
1060 ```
1062 ### Creating a custom optimizer
1064 If you intend to create your own optimization algorithm, please inherit from
1065 this class and override the following methods:
1067 - `build`: Create your optimizer-related variables, such as `momentums` in
1068 the SGD optimizer.
1069 - `update_step`: Implement your optimizer's updating logic.
1070 - `get_config`: serialization of the optimizer; include all
1071 hyperparameters.
1073 Your optimizer will automatically be compatible with TensorFlow distributed
1074 training if you subclass `optimizer_experimental.Optimizer`, as sketched below.
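 For instance, a minimal custom optimizer could look like the following. This
 is an illustrative sketch only (dense gradients, no sparse handling); the
 class name and the `momentum` hyperparameter are hypothetical.

 ```python
 class MyMomentumOptimizer(tf.keras.optimizers.Optimizer):
     def __init__(self, learning_rate=0.01, momentum=0.9,
                  name="MyMomentumOptimizer", **kwargs):
         super().__init__(name=name, **kwargs)
         self._learning_rate = self._build_learning_rate(learning_rate)
         self.momentum = momentum

     def build(self, var_list):
         super().build(var_list)
         if getattr(self, "_built", False):
             return
         # One momentum slot variable per model variable.
         self.momentums = [
             self.add_variable_from_reference(var, "momentum")
             for var in var_list
         ]
         self._built = True

     def update_step(self, gradient, variable):
         lr = tf.cast(self.learning_rate, variable.dtype)
         m = self.momentums[self._index_dict[self._var_key(variable)]]
         m.assign(self.momentum * m - lr * gradient)
         variable.assign_add(m)

     def get_config(self):
         config = super().get_config()
         config.update({
             "learning_rate": self._serialize_hyperparameter(
                 self._learning_rate
             ),
             "momentum": self.momentum,
         })
         return config
 ```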
1076 """
1078 def __init__(
1079 self,
1080 name,
1081 weight_decay=0,
1082 clipnorm=None,
1083 clipvalue=None,
1084 global_clipnorm=None,
1085 use_ema=False,
1086 ema_momentum=0.99,
1087 ema_overwrite_frequency=None,
1088 jit_compile=True,
1089 **kwargs,
1090 ):
1091 """Create a new Optimizer."""
1092 mesh = kwargs.pop("mesh", None)
1093 self._mesh = mesh
1094 super().__init__(
1095 name,
1096 weight_decay,
1097 clipnorm,
1098 clipvalue,
1099 global_clipnorm,
1100 use_ema,
1101 ema_momentum,
1102 ema_overwrite_frequency,
1103 jit_compile,
1104 **kwargs,
1105 )
1106 self._distribution_strategy = tf.distribute.get_strategy()
1107 self._run_with_dtensor = dtensor_utils.running_with_dtensor_strategy()
1109 def add_variable_from_reference(
1110 self, model_variable, variable_name, shape=None, initial_value=None
1111 ):
1112 if self._mesh:
1113 if initial_value is None:
1114 # Use tf.zeros_like which will propagate the layout information
1115 # from the model weights if any.
1116 initial_value = tf.zeros_like(model_variable)
1117 elif isinstance(initial_value, tf.Tensor):
1118 initial_value = tf.experimental.dtensor.copy_to_mesh(
1119 initial_value,
1120 tf.experimental.dtensor.Layout.replicated(
1121 self._mesh, rank=initial_value.shape.rank
1122 ),
1123 )
1124 variable = tf.experimental.dtensor.DVariable(
1125 initial_value=initial_value,
1126 name=f"{variable_name}/{model_variable._shared_name}",
1127 dtype=model_variable.dtype,
1128 trainable=False,
1129 )
1130 self._variables.append(variable)
1131 return variable
1132 else:
1133 strategy = tf.distribute.get_strategy()
1134 with strategy.extended.colocate_vars_with(model_variable):
1135 return super().add_variable_from_reference(
1136 model_variable, variable_name, shape, initial_value
1137 )
1139 def _create_iteration_variable(self):
1140 if self._mesh:
1141 init_val = tf.constant(0, dtype=tf.int64)
1142 init_val = tf.experimental.dtensor.copy_to_mesh(
1143 init_val,
1144 tf.experimental.dtensor.Layout.replicated(self._mesh, rank=0),
1145 )
1146 with tf.init_scope():
1147 # Lift the variable creation to init scope to avoid environment
1148 # issue.
1149 self._iterations = tf.experimental.dtensor.DVariable(
1150 init_val, name="iteration"
1151 )
1152 self._variables.append(self._iterations)
1153 else:
1154 super()._create_iteration_variable()
1156 def _var_key(self, variable):
1157 """Get a unique identifier of the given variable."""
1159 # Get the distributed variable if it exists.
1160 # TODO(b/197554203): replace _distributed_container() with a public api.
1161 if hasattr(variable, "_distributed_container"):
1162 variable = variable._distributed_container()
1163 elif (
1164 tf_utils.is_extension_type(variable)
1165 and hasattr(variable, "handle")
1166 and hasattr(variable.handle, "_distributed_container")
1167 ):
1168 # For ResourceVariables, the _distributed_container attribute
1169 # is added to their handle tensors.
1170 variable = variable.handle._distributed_container()
1171 return super()._var_key(variable)
1173 def aggregate_gradients(self, grads_and_vars):
1174 """Aggregate gradients on all devices.
1176 By default, we will perform reduce_sum of gradients across devices.
1177 Users can implement their own aggregation logic by overriding this
1178 method.
1180 Args:
1181 grads_and_vars: List of (gradient, variable) pairs.
1183 Returns:
1184 List of (gradient, variable) pairs.
1185 """
1186 if self._mesh or self._run_with_dtensor:
1187 raise NotImplementedError(
1188 "Dtensor doesn't need to manually aggregate gradients"
1189 )
1190 else:
1191 return optimizer_utils.all_reduce_sum_gradients(grads_and_vars)
1193 def apply_gradients(
1194 self,
1195 grads_and_vars,
1196 name=None,
1197 skip_gradients_aggregation=False,
1198 **kwargs,
1199 ):
1200 """Apply gradients to variables.
1202 Args:
1203 grads_and_vars: List of `(gradient, variable)` pairs.
1204 name: string, defaults to None. The name of the name scope to
1205 use when creating variables. If None, `self.name` will be used.
1206 skip_gradients_aggregation: If True, gradient aggregation will not be
1207 performed inside the optimizer. Usually this arg is set to True when you
1208 write custom code aggregating gradients outside the optimizer.
1209 **kwargs: keyword arguments only used for backward compatibility.
1211 Returns:
1212 A `tf.Variable`, representing the current iteration.
1214 Raises:
1215 TypeError: If `grads_and_vars` is malformed.
1216 RuntimeError: If called in a cross-replica context.
1217 """
1218 if self._mesh or self._run_with_dtensor:
1219 # Skip any usage of strategy logic for DTensor
1220 return super().apply_gradients(grads_and_vars, name=name)
1222 # `experimental_aggregate_gradients` is an arg in `apply_gradients` of
1223 # v2 optimizer -- the reverse of `skip_gradients_aggregation`.
1224 # We read it from kwargs for backward compatibility.
1225 experimental_aggregate_gradients = kwargs.pop(
1226 "experimental_aggregate_gradients", True
1227 )
1228 if not skip_gradients_aggregation and experimental_aggregate_gradients:
1229 grads_and_vars = self.aggregate_gradients(grads_and_vars)
1230 return super().apply_gradients(grads_and_vars, name=name)
1232 def _apply_weight_decay(self, variables):
1233 # Apply weight decay in distributed setup.
1234 if self.weight_decay is None:
1235 return
1237 def distributed_apply_weight_decay(distribution, variables, **kwargs):
1238 def weight_decay_fn(variable):
1239 if self._use_weight_decay(variable):
1240 lr = tf.cast(self.learning_rate, variable.dtype)
1241 wd = tf.cast(self.weight_decay, variable.dtype)
1242 variable.assign_sub(variable * wd * lr)
1244 for variable in variables:
1245 distribution.extended.update(
1246 variable, weight_decay_fn, group=False
1247 )
1249 tf.__internal__.distribute.interim.maybe_merge_call(
1250 distributed_apply_weight_decay,
1251 self._distribution_strategy,
1252 variables,
1253 )
1255 def _internal_apply_gradients(self, grads_and_vars):
1256 if self._mesh or self._run_with_dtensor:
1257 # Skip any usage of strategy logic for DTensor
1258 return super()._internal_apply_gradients(grads_and_vars)
1260 return tf.__internal__.distribute.interim.maybe_merge_call(
1261 self._distributed_apply_gradients_fn,
1262 self._distribution_strategy,
1263 grads_and_vars,
1264 )
1266 def _overwrite_model_variables_with_average_value_helper(self, var_list):
1267 """Helper function to _overwrite_model_variables_with_average_value.
1269 This function overwrites variables on each device.
1270 Args:
1271 var_list: list of model variables.
1272 """
1273 if self._mesh or self._run_with_dtensor:
1274 # Skip any usage of strategy logic for DTensor
1275 super()._overwrite_model_variables_with_average_value_helper(
1276 var_list
1277 )
1279 strategy = self._distribution_strategy
1280 # Override model variable by the stored average value on all devices.
1281 for var, average_var in zip(
1282 var_list, self._model_variables_moving_average
1283 ):
1284 strategy.extended.update(
1285 var, lambda a, b: a.assign(b), args=(average_var,)
1286 )
1288 def _build_learning_rate(self, learning_rate):
1289 if not self._mesh:
1290 return super()._build_learning_rate(learning_rate)
1292 # For DTensor
1293 variable_creation = tf.experimental.dtensor.DVariable
1294 init_value_convert_fn = lambda x: tf.experimental.dtensor.copy_to_mesh(
1295 x, tf.experimental.dtensor.Layout.replicated(self._mesh, rank=0)
1296 )
1297 if isinstance(
1298 learning_rate, learning_rate_schedule.LearningRateSchedule
1299 ):
1300 current_learning_rate = tf.convert_to_tensor(
1301 learning_rate(self.iterations)
1302 )
1303 current_learning_rate = init_value_convert_fn(current_learning_rate)
1304 # Create a variable to hold the current learning rate.
1305 # Note that the init value `learning_rate(self.iterations)` should
1306 # have the correct layout information from self.iterations.
1307 self._current_learning_rate = variable_creation(
1308 current_learning_rate,
1309 name="learning_rate",
1310 dtype=tf.float32,
1311 )
1312 return learning_rate
1314 init_val = init_value_convert_fn(
1315 tf.constant(learning_rate, dtype=tf.float32)
1316 )
1317 return variable_creation(
1318 init_val,
1319 name="learning_rate",
1320 dtype=backend.floatx(),
1321 trainable=False,
1322 )
1324 def _update_model_variables_moving_average(self, var_list):
1325 """Update the stored moving average using the latest value."""
1326 if self.use_ema:
1328 def update_average(average, var):
1329 average.assign(
1330 self.ema_momentum * average + (1 - self.ema_momentum) * var
1331 )
1333 for var, average in zip(
1334 var_list, self._model_variables_moving_average
1335 ):
1336 self._distribution_strategy.extended.update(
1337 average, update_average, args=(var,), group=False
1338 )
1340 def _distributed_apply_gradients_fn(
1341 self, distribution, grads_and_vars, **kwargs
1342 ):
1343 """`apply_gradients` using a `DistributionStrategy`."""
1345 def apply_grad_to_update_var(var, grad):
1346 if self.jit_compile:
1347 return self._update_step_xla(grad, var, id(self._var_key(var)))
1348 else:
1349 return self._update_step(grad, var)
1351 for grad, var in grads_and_vars:
1352 distribution.extended.update(
1353 var, apply_grad_to_update_var, args=(grad,), group=False
1354 )
1356 if self.use_ema:
1357 _, var_list = zip(*grads_and_vars)
1358 self._update_model_variables_moving_average(var_list)
1359 if self.ema_overwrite_frequency:
1360 # Only when self.ema_overwrite_frequency is not None, we
1361 # overwrite the model variables.
1362 should_overwrite_model_vars = (
1363 self.iterations + 1
1364 ) % self.ema_overwrite_frequency == 0
1365 tf.cond(
1366 tf.cast(should_overwrite_model_vars, tf.bool),
1367 true_fn=lambda: self._overwrite_model_variables_with_average_value( # noqa: E501
1368 var_list
1369 ),
1370 false_fn=lambda: None,
1371 )
1372 return self.iterations.assign_add(1)
1375class RestoredOptimizer(Optimizer):
1376 def __init__(self):
1377 super().__init__("RestoredOptimizer")
1379 def get_config(self):
1380 raise NotImplementedError(
1381 "Restoring functional Optimizers from SavedModels is not currently "
1382 "supported. Please file a feature request if this limitation "
1383 "bothers you."
1384 )
1387class CallableList(list):
1388 """Temporary shim to support both `opt.variables()` and `opt.variables`."""
1390 def __call__(self):
1391 return self
1394# Register the optimizer so it can be revived when loading from a SavedModel.
1395tf.__internal__.saved_model.load.register_revived_type(
1396 "experimentalOptimizer",
1397 lambda obj: isinstance(obj, Optimizer),
1398 versions=[
1399 tf.__internal__.saved_model.load.VersionedTypeRegistration(
1400 object_factory=lambda proto: RestoredOptimizer(),
1401 version=2,
1402 min_producer_version=1,
1403 min_consumer_version=1,
1404 )
1405 ],
1406)
1408Optimizer.__doc__ = Optimizer.__doc__.replace(
1409 "{{base_optimizer_keyword_args}}", base_optimizer_keyword_args
1410)