# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16"""Synchronize replicas for training."""
17from tensorflow.python.distribute import distribute_lib
18from tensorflow.python.framework import indexed_slices
19from tensorflow.python.framework import ops
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops import control_flow_ops
22from tensorflow.python.ops import data_flow_ops
23from tensorflow.python.ops import state_ops
24from tensorflow.python.ops import variable_v1
25from tensorflow.python.ops import variables
26from tensorflow.python.platform import tf_logging as logging
27from tensorflow.python.training import optimizer
28from tensorflow.python.training import queue_runner
29from tensorflow.python.training import session_manager
30from tensorflow.python.training import session_run_hook
31from tensorflow.python.util import deprecation
32from tensorflow.python.util.tf_export import tf_export


# Please note that the gradients from replicas are averaged instead of summed
# (as in the old sync_replicas_optimizer) so you need to increase the learning
# rate according to the number of replicas. This change is introduced to be
# consistent with how gradients are aggregated (averaged) within a batch in a
# replica.
@tf_export(v1=["train.SyncReplicasOptimizer"])
class SyncReplicasOptimizer(optimizer.Optimizer):
42 """Class to synchronize, aggregate gradients and pass them to the optimizer.
44 This class is deprecated. For synchronous training, please use [Distribution
45 Strategies](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute).
47 In a typical asynchronous training environment, it's common to have some
48 stale gradients. For example, with a N-replica asynchronous training,
49 gradients will be applied to the variables N times independently. Depending
50 on each replica's training speed, some gradients might be calculated from
51 copies of the variable from several steps back (N-1 steps on average). This
52 optimizer avoids stale gradients by collecting gradients from all replicas,
53 averaging them, then applying them to the variables in one shot, after
54 which replicas can fetch the new variables and continue.
56 The following accumulators/queue are created:
58 * N `gradient accumulators`, one per variable to train. Gradients are pushed
59 to them and the chief worker will wait until enough gradients are collected
60 and then average them before applying to variables. The accumulator will
61 drop all stale gradients (more details in the accumulator op).
62 * 1 `token` queue where the optimizer pushes the new global_step value after
63 all variables are updated.
65 The following local variable is created:
66 * `sync_rep_local_step`, one per replica. Compared against the global_step in
67 each accumulator to check for staleness of the gradients.
69 The optimizer adds nodes to the graph to collect gradients and pause the
70 trainers until variables are updated.
71 For the Parameter Server job:
73 1. An accumulator is created for each variable, and each replica pushes the
74 gradients into the accumulators instead of directly applying them to the
75 variables.
76 2. Each accumulator averages once enough gradients (replicas_to_aggregate)
77 have been accumulated.
78 3. Apply the averaged gradients to the variables.
79 4. Only after all variables have been updated, increment the global step.
80 5. Only after step 4, pushes `global_step` in the `token_queue`, once for
81 each worker replica. The workers can now fetch the global step, use it to
82 update its local_step variable and start the next batch. Please note that
83 some workers can consume multiple minibatches, while some may not consume
84 even one. This is because each worker fetches minibatches as long as
85 a token exists. If one worker is stuck for some reason and does not
86 consume a token, another worker can use it.
88 For the replicas:
90 1. Start a step: fetch variables and compute gradients.
91 2. Once the gradients have been computed, push them into gradient
92 accumulators. Each accumulator will check the staleness and drop the stale.
93 3. After pushing all the gradients, dequeue an updated value of global_step
94 from the token queue and record that step to its local_step variable. Note
95 that this is effectively a barrier.
96 4. Start the next batch.
98 ### Usage
100 ```python
101 # Create any optimizer to update the variables, say a simple SGD:
102 opt = GradientDescentOptimizer(learning_rate=0.1)
104 # Wrap the optimizer with sync_replicas_optimizer with 50 replicas: at each
105 # step the optimizer collects 50 gradients before applying to variables.
106 # Note that if you want to have 2 backup replicas, you can change
107 # total_num_replicas=52 and make sure this number matches how many physical
108 # replicas you started in your job.
109 opt = tf.compat.v1.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
110 total_num_replicas=50)
112 # Some models have startup_delays to help stabilize the model but when using
113 # sync_replicas training, set it to 0.
115 # Now you can call `minimize()` or `compute_gradients()` and
116 # `apply_gradients()` normally
117 training_op = opt.minimize(total_loss, global_step=self.global_step)
120 # You can create the hook which handles initialization and queues.
121 sync_replicas_hook = opt.make_session_run_hook(is_chief)
122 ```
124 In the training program, every worker will run the train_op as if not
125 synchronized.
127 ```python
128 with training.MonitoredTrainingSession(
129 master=workers[worker_id].target, is_chief=is_chief,
130 hooks=[sync_replicas_hook]) as mon_sess:
131 while not mon_sess.should_stop():
132 mon_sess.run(training_op)
133 ```
135 To use SyncReplicasOptimizer with an `Estimator`, you need to send
136 sync_replicas_hook while calling the fit.
137 ```python
138 my_estimator = DNNClassifier(..., optimizer=opt)
139 my_estimator.fit(..., hooks=[sync_replicas_hook])
140 ```
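
  If you are not using `make_session_run_hook`, the chief has to run the
  chief-only ops itself. The sketch below is illustrative only and assumes a
  `tf.compat.v1.Session` named `sess` and a `tf.compat.v1.train.Coordinator`
  named `coord` already exist:

  ```python
  # Chief-only setup (a minimal sketch, not the only possible ordering).
  chief_queue_runner = opt.get_chief_queue_runner()
  init_token_op = opt.get_init_tokens_op()

  if is_chief:
    # Initialize the local step and the accumulators' global step.
    sess.run(opt.chief_init_op)
    # Start the queue runner that aggregates gradients and refills tokens.
    chief_queue_runner.create_threads(sess, coord=coord, daemon=True,
                                      start=True)
    # Seed the token queue so replicas can start the first step.
    sess.run(init_token_op)
  else:
    sess.run(opt.local_step_init_op)
  ```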
141 """
143 @deprecation.deprecated(
144 None, "The `SyncReplicaOptimizer` class is deprecated. For synchronous "
145 "training, please use [Distribution Strategies](https://github.com/"
146 "tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute).",
147 warn_once=True)
148 def __init__(self,
149 opt,
150 replicas_to_aggregate,
151 total_num_replicas=None,
152 variable_averages=None,
153 variables_to_average=None,
154 use_locking=False,
155 name="sync_replicas"):
156 """Construct a sync_replicas optimizer.
158 Args:
159 opt: The actual optimizer that will be used to compute and apply the
160 gradients. Must be one of the Optimizer classes.
161 replicas_to_aggregate: number of replicas to aggregate for each variable
162 update.
163 total_num_replicas: Total number of tasks/workers/replicas, could be
164 different from replicas_to_aggregate.
165 If total_num_replicas > replicas_to_aggregate: it is backup_replicas +
166 replicas_to_aggregate.
167 If total_num_replicas < replicas_to_aggregate: Replicas compute
168 multiple batches per update to variables.
169 variable_averages: Optional `ExponentialMovingAverage` object, used to
170 maintain moving averages for the variables passed in
171 `variables_to_average`.
172 variables_to_average: a list of variables that need to be averaged. Only
173 needed if variable_averages is passed in.
174 use_locking: If True use locks for update operation.
175 name: string. Optional name of the returned operation.
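
    For illustration, a hedged sketch of the two `total_num_replicas`
    configurations (the wrapped optimizer `opt` is assumed to exist):

    ```python
    # 2 backup replicas: 52 workers push gradients, but only the first 50
    # received at each step are averaged; the stragglers' gradients are
    # dropped as stale.
    sync_opt = tf.compat.v1.train.SyncReplicasOptimizer(
        opt, replicas_to_aggregate=50, total_num_replicas=52)

    # Fewer workers than gradients to aggregate: 25 workers each contribute
    # roughly two gradients per global step (see `get_init_tokens_op`).
    sync_opt = tf.compat.v1.train.SyncReplicasOptimizer(
        opt, replicas_to_aggregate=50, total_num_replicas=25)
    ```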
176 """
177 if total_num_replicas is None:
178 total_num_replicas = replicas_to_aggregate
180 super(SyncReplicasOptimizer, self).__init__(use_locking, name)
181 logging.info(
182 "SyncReplicasV2: replicas_to_aggregate=%s; total_num_replicas=%s",
183 replicas_to_aggregate, total_num_replicas)
184 self._opt = opt
185 self._replicas_to_aggregate = replicas_to_aggregate
186 self._gradients_applied = False
187 self._variable_averages = variable_averages
188 self._variables_to_average = variables_to_average
189 self._total_num_replicas = total_num_replicas
190 self._tokens_per_step = max(total_num_replicas, replicas_to_aggregate)
191 self._global_step = None
192 self._sync_token_queue = None
194 # The synchronization op will be executed in a queue runner which should
195 # only be executed by one of the replicas (usually the chief).
196 self._chief_queue_runner = None
198 # Remember which accumulator is on which device to set the initial step in
199 # the accumulator to be global step. This list contains list of the
200 # following format: (accumulator, device).
201 self._accumulator_list = []

  def compute_gradients(self, *args, **kwargs):
    """Compute gradients of "loss" for the variables in "var_list".

    This simply wraps the compute_gradients() from the real optimizer. The
    gradients will be aggregated in apply_gradients() so that users can modify
    the gradients, e.g. clip them with a per-replica global norm, if needed.
    Computing the global norm over the aggregated gradients can be harmful, as
    one replica's huge gradients can hurt the gradients from other replicas.

    Args:
      *args: Arguments for compute_gradients().
      **kwargs: Keyword arguments for compute_gradients().

    Returns:
      A list of (gradient, variable) pairs.
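
    For example, a per-replica clipping sketch (the clip threshold, `loss`
    and `global_step` are illustrative, not part of this API):

    ```python
    grads_and_vars = opt.compute_gradients(loss)
    grads, tvars = zip(*grads_and_vars)
    clipped, _ = tf.clip_by_global_norm(grads, clip_norm=5.0)
    train_op = opt.apply_gradients(zip(clipped, tvars),
                                   global_step=global_step)
    ```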
218 """
219 return self._opt.compute_gradients(*args, **kwargs)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This contains most of the synchronization implementation and also wraps the
    apply_gradients() from the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Variable to increment by one after the variables have been
        updated. Required here, since it is used to check the staleness of the
        gradients.
      name: Optional name for the returned operation. Defaults to the name
        passed to the Optimizer constructor.

    Returns:
      train_op: The op to dequeue a token so the replicas can exit this batch
        and start the next one. This is executed by each replica.

    Raises:
      ValueError: If grads_and_vars is empty.
      ValueError: If global_step is not provided, since the staleness cannot
        be checked without it.
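
    A typical call pattern (a sketch; `grads_and_vars` and `global_step` are
    assumed to come from the surrounding training code):

    ```python
    train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
    # Every replica runs `train_op`. Only the chief additionally runs the
    # queue runner from `get_chief_queue_runner()` and `get_init_tokens_op()`,
    # or simply uses `make_session_run_hook(is_chief=True)`.
    ```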
243 """
244 if not grads_and_vars:
245 raise ValueError("Must supply at least one variable")
247 if global_step is None:
248 raise ValueError("Global step is required to check staleness")
250 self._global_step = global_step
251 train_ops = []
252 aggregated_grad = []
253 var_list = []
255 # local_anchor op will be placed on this worker task by default.
256 local_anchor = control_flow_ops.no_op()
257 # Colocating local_step variable prevents it being placed on the PS.
258 distribution_strategy = distribute_lib.get_strategy()
259 with distribution_strategy.extended.colocate_vars_with(local_anchor):
260 self._local_step = variable_v1.VariableV1(
261 initial_value=0,
262 trainable=False,
263 collections=[ops.GraphKeys.LOCAL_VARIABLES],
264 dtype=global_step.dtype.base_dtype,
265 name="sync_rep_local_step")
267 self.local_step_init_op = state_ops.assign(self._local_step, global_step)
268 chief_init_ops = [self.local_step_init_op]
269 self.ready_for_local_init_op = variables.report_uninitialized_variables(
270 variables.global_variables())
272 with ops.name_scope(None, self._name):
273 for grad, var in grads_and_vars:
274 var_list.append(var)
275 with ops.device(var.device):
276 # Dense gradients.
277 if grad is None:
278 aggregated_grad.append(None) # pass-through.
279 continue
280 elif isinstance(grad, ops.Tensor):
281 grad_accum = data_flow_ops.ConditionalAccumulator(
282 grad.dtype,
283 shape=var.get_shape(),
284 shared_name=var.name + "/grad_accum")
285 train_ops.append(grad_accum.apply_grad(
286 grad, local_step=self._local_step))
287 aggregated_grad.append(grad_accum.take_grad(
288 self._replicas_to_aggregate))
289 else:
290 if not isinstance(grad, indexed_slices.IndexedSlices):
291 raise ValueError("Unknown grad type!")
292 grad_accum = data_flow_ops.SparseConditionalAccumulator(
293 grad.dtype, shape=(), shared_name=var.name + "/grad_accum")
294 train_ops.append(grad_accum.apply_indexed_slices_grad(
295 grad, local_step=self._local_step))
296 aggregated_grad.append(grad_accum.take_indexed_slices_grad(
297 self._replicas_to_aggregate))
299 self._accumulator_list.append((grad_accum, var.device))
301 aggregated_grads_and_vars = zip(aggregated_grad, var_list)
303 # sync_op will be assigned to the same device as the global step.
304 with ops.device(global_step.device), ops.name_scope(""):
305 update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
306 global_step)
308 # Create token queue.
309 with ops.device(global_step.device), ops.name_scope(""):
310 sync_token_queue = (
311 data_flow_ops.FIFOQueue(-1,
312 global_step.dtype.base_dtype,
313 shapes=(),
314 name="sync_token_q",
315 shared_name="sync_token_q"))
316 self._sync_token_queue = sync_token_queue
318 with ops.device(global_step.device), ops.name_scope(""):
319 # Replicas have to wait until they can get a token from the token queue.
320 with ops.control_dependencies(train_ops):
321 token = sync_token_queue.dequeue()
322 train_op = state_ops.assign(self._local_step, token)
324 with ops.control_dependencies([update_op]):
325 # Sync_op needs to insert tokens to the token queue at the end of the
326 # step so the replicas can fetch them to start the next step.
327 tokens = array_ops.fill([self._tokens_per_step], global_step)
328 sync_op = sync_token_queue.enqueue_many((tokens,))
330 if self._variable_averages is not None:
331 with ops.control_dependencies([sync_op]), ops.name_scope(""):
332 sync_op = self._variable_averages.apply(
333 self._variables_to_average)
335 self._chief_queue_runner = queue_runner.QueueRunner(
336 sync_token_queue, [sync_op])
337 for accum, dev in self._accumulator_list:
338 with ops.device(dev):
339 chief_init_ops.append(
340 accum.set_global_step(
341 global_step, name="SetGlobalStep"))
342 self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
343 self._gradients_applied = True
344 return train_op

  def get_chief_queue_runner(self):
    """Returns the QueueRunner for the chief to execute.

    This includes the operations to synchronize replicas: aggregate gradients,
    apply them to the variables, increment the global step, and insert tokens
    into the token queue.

    Note that this can only be called after calling apply_gradients(), which
    actually generates this queue runner.

    Returns:
      A `QueueRunner` for the chief to execute.

    Raises:
      ValueError: If this is called before apply_gradients().
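
    If you are not using `make_session_run_hook`, a minimal sketch of running
    the queue runner on the chief (assuming a session `sess` and a
    `tf.compat.v1.train.Coordinator` named `coord` already exist):

    ```python
    chief_queue_runner = opt.get_chief_queue_runner()
    chief_queue_runner.create_threads(sess, coord=coord, daemon=True,
                                      start=True)
    ```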
360 """
361 if self._gradients_applied is False:
362 raise ValueError("Should be called after apply_gradients().")
364 return self._chief_queue_runner

  def get_slot(self, *args, **kwargs):
    """Return a slot named "name" created for "var" by the Optimizer.

    This simply wraps the get_slot() from the actual optimizer.

    Args:
      *args: Arguments for get_slot().
      **kwargs: Keyword arguments for get_slot().

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    return self._opt.get_slot(*args, **kwargs)

  def variables(self):
    """Fetches a list of optimizer variables in the default graph.

    This wraps `variables()` from the actual optimizer. It does not include
    the `SyncReplicasOptimizer`'s local step.

    Returns:
      A list of variables.
    """
    return self._opt.variables()

  def get_slot_names(self, *args, **kwargs):
    """Return a list of the names of slots created by the `Optimizer`.

    This simply wraps the get_slot_names() from the actual optimizer.

    Args:
      *args: Arguments for get_slot_names().
      **kwargs: Keyword arguments for get_slot_names().

    Returns:
      A list of strings.
    """
    return self._opt.get_slot_names(*args, **kwargs)

  def get_init_tokens_op(self, num_tokens=-1):
    """Returns the op to fill the sync_token_queue with the tokens.

    This is supposed to be executed at the beginning of the chief/sync thread
    so that even if the total_num_replicas is less than replicas_to_aggregate,
    the model can still proceed as the replicas can compute multiple steps per
    variable update. Make sure:
    `num_tokens >= replicas_to_aggregate - total_num_replicas`.

    Args:
      num_tokens: Number of tokens to add to the queue.

    Returns:
      An op for the chief/sync replica to fill the token queue.

    Raises:
      ValueError: If this is called before apply_gradients().
      ValueError: If num_tokens is smaller than replicas_to_aggregate -
        total_num_replicas.
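
    For example (numbers are illustrative): with `replicas_to_aggregate=4`
    and `total_num_replicas=2`, each worker blocks on the token queue after
    pushing its first gradient, so the chief must pre-fill at least
    `4 - 2 = 2` tokens for the first step to complete. The default
    `num_tokens=-1` enqueues `replicas_to_aggregate` tokens.

    ```python
    # A hedged sketch, assuming `opt`, `sess` and `is_chief` exist in the
    # surrounding training code.
    init_tokens_op = opt.get_init_tokens_op(num_tokens=2)
    if is_chief:
      sess.run(init_tokens_op)
    ```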
424 """
425 if self._gradients_applied is False:
426 raise ValueError(
427 "get_init_tokens_op() should be called after apply_gradients().")
429 tokens_needed = self._replicas_to_aggregate - self._total_num_replicas
430 if num_tokens == -1:
431 num_tokens = self._replicas_to_aggregate
432 elif num_tokens < tokens_needed:
433 raise ValueError(
434 "Too few tokens to finish the first step: %d (given) vs %d (needed)" %
435 (num_tokens, tokens_needed))
437 if num_tokens > 0:
438 with ops.device(self._global_step.device), ops.name_scope(""):
439 tokens = array_ops.fill([num_tokens], self._global_step)
440 init_tokens = self._sync_token_queue.enqueue_many((tokens,))
441 else:
442 init_tokens = control_flow_ops.no_op(name="no_init_tokens")
444 return init_tokens

  def make_session_run_hook(self, is_chief, num_tokens=-1):
    """Creates a hook to handle SyncReplicasOptimizer initialization ops."""
    return _SyncReplicasOptimizerHook(self, is_chief, num_tokens)


class _SyncReplicasOptimizerHook(session_run_hook.SessionRunHook):
  """A SessionRunHook that handles ops related to SyncReplicasOptimizer."""

  def __init__(self, sync_optimizer, is_chief, num_tokens):
    """Creates a hook to handle SyncReplicasOptimizer initialization ops.

    Args:
      sync_optimizer: `SyncReplicasOptimizer` which this hook will initialize.
      is_chief: `Bool`, whether this is a chief replica or not.
      num_tokens: Number of tokens to add to the queue.
    """
    self._sync_optimizer = sync_optimizer
    self._is_chief = is_chief
    self._num_tokens = num_tokens

  def begin(self):
    if not self._sync_optimizer._gradients_applied:  # pylint: disable=protected-access
      raise ValueError(
          "SyncReplicasOptimizer.apply_gradients should be called before "
          "using the hook.")
    if self._is_chief:
      self._local_init_op = self._sync_optimizer.chief_init_op
      self._ready_for_local_init_op = (
          self._sync_optimizer.ready_for_local_init_op)
      self._q_runner = self._sync_optimizer.get_chief_queue_runner()
      self._init_tokens_op = self._sync_optimizer.get_init_tokens_op(
          self._num_tokens)
    else:
      self._local_init_op = self._sync_optimizer.local_step_init_op
      self._ready_for_local_init_op = (
          self._sync_optimizer.ready_for_local_init_op)
      self._q_runner = None
      self._init_tokens_op = None

  def after_create_session(self, session, coord):
    """Runs SyncReplicasOptimizer initialization ops."""
    local_init_success, msg = session_manager._ready(  # pylint: disable=protected-access
        self._ready_for_local_init_op, session,
        "Model is not ready for SyncReplicasOptimizer local init.")
    if not local_init_success:
      raise RuntimeError(
          "Init operations did not make model ready for SyncReplicasOptimizer "
          "local_init. Init op: %s, error: %s" %
          (self._local_init_op.name, msg))
    session.run(self._local_init_op)
    if self._init_tokens_op is not None:
      session.run(self._init_tokens_op)
    if self._q_runner is not None:
      self._q_runner.create_threads(
          session, coord=coord, daemon=True, start=True)