1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module for `ClusterCoordinator` and relevant cluster-worker related library.
16
17This is currently under development and the API is subject to change.
18"""
19
20import collections
21import contextlib
22import os
23import re
24import threading
25import time
26import weakref
27
28from six.moves import queue
29
30from tensorflow.python.distribute.coordinator import coordinator_context
31from tensorflow.python.distribute.coordinator import metric_utils
32from tensorflow.python.distribute.coordinator import remote_value
33from tensorflow.python.distribute.coordinator import utils
34from tensorflow.python.distribute.coordinator import values as values_lib
35from tensorflow.python.distribute.coordinator import watchdog
36from tensorflow.python.eager import cancellation
37from tensorflow.python.eager import context
38from tensorflow.python.eager import def_function
39from tensorflow.python.eager import executor
40from tensorflow.python.eager import function as tf_function
41from tensorflow.python.framework import errors
42from tensorflow.python.framework import func_graph
43from tensorflow.python.framework import ops
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.util import nest
46from tensorflow.python.util.tf_export import tf_export
47
48# Maximum time for failed worker to come back is 1 hour
49_WORKER_MAXIMUM_RECOVERY_SEC = 3600
50# How often to poll task states from the coordination service. In testing, a
51# value of 1 led to some spurious reports of unavailability, so a higher value
52# is used. Refer to the discussion in b/249134783 for more.
53_POLL_FREQ_IN_SEC = 5
54
55# Maximum size for queued closures, "infinite" if set to 0.
56# When the maximum queue size is reached, further schedule calls will become
57# blocking until some previously queued closures are executed on workers.
58# Note that using an "infinite" queue size can take a non-trivial portion of
59# memory, and even lead to coordinator OOM. Modify the size to a smaller value
60# for coordinator with constrained memory resource (only recommended for
61# advanced users). Also used in unit tests to ensure the correctness when the
62# queue is full.
63_CLOSURE_QUEUE_MAX_SIZE = 256 * 1024
64
65# RPC error message from PS
66_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"
67
68# InvalidArgumentError (unknown device) will not have "GRPC error..." string.
69_JOB_WORKER_STRING_IDENTIFIER = "/job:worker"
70
71
72RemoteValueStatus = remote_value.RemoteValueStatus
73RemoteValue = remote_value.RemoteValue
74RemoteValueImpl = values_lib.RemoteValueImpl
75PerWorkerValues = values_lib.PerWorkerValues
76
77
78class ClosureInputError(Exception):
79 """Wrapper for errors from resource building.
80
81 When a closure starts, it first checks for errors in any of its inputs, which
82 are RemoteValues from resource closures. If there were any errors, it wraps
83 the exception in this class and raises so it can be handled by the worker
84 failure handler.
85
86 Attributes:
87 original_exception:
88 """
89
90 def __init__(self, original_exception):
91 # Avoid doubly-nested errors
92 if isinstance(original_exception,
93 (ClosureInputError, ClosureAbortedError)):
94 self.original_exception = original_exception.original_exception
95 else:
96 self.original_exception = original_exception
97 message = ("Input has an error, the original exception is %r, "
98 "error message is %s." %
99 (self.original_exception, str(self.original_exception)))
100 super().__init__(message)
101 self.with_traceback(original_exception.__traceback__)
102
103
104class ClosureAbortedError(Exception):
105 """Wrapper for errors from training closures, to attach to resource closures.
106
107 This wrapper is used when a dependent training closure fails to set errors on
108 its required resource closures.
109
110 Attributes:
111 original_exception: The Exception to wrap
112 """
113
114 def __init__(self, original_exception):
115 # Avoid doubly-nested errors
116 if isinstance(original_exception,
117 (ClosureInputError, ClosureAbortedError)):
118 self.original_exception = original_exception.original_exception
119 else:
120 self.original_exception = original_exception
121 message = ("Other function has an execution error, as a result, the "
122 "current value is not available. The original exception is %r, "
123 "error message is %s." %
124 (self.original_exception, str(self.original_exception)))
125 super().__init__(message)
126 self.with_traceback(original_exception.__traceback__)
127
128
129class PSUnavailableError(errors.UnavailableError):
130 """Specifies that a parameter server is the unavailable task."""
131
132 def __init__(self, original_exception):
133 assert isinstance(original_exception, errors.UnavailableError)
134 # TF Errors should have init args set as attributes for serialization.
135 self.original_exception = original_exception
136 super().__init__(
137 original_exception.node_def,
138 original_exception.op,
139 original_exception.message,
140 )
141
142
143def _get_error_from_remote_values(structure):
144 """Attempts to return errors from `RemoteValue`s. Rebuilds them if needed."""
145 errors_in_structure = []
146
147 def _get_error(val):
148 if isinstance(val, RemoteValue):
149 error = val._get_error() # pylint: disable=protected-access
150 if error:
151 errors_in_structure.append(error)
152
153 nest.map_structure(_get_error, structure)
154 if errors_in_structure:
155 return errors_in_structure[0]
156 else:
157 return None
158
159
160def _maybe_as_type_spec(val):
161 if isinstance(val, (RemoteValue, PerWorkerValues)):
162 if val._type_spec is None: # pylint: disable=protected-access
163 raise ValueError("Output of a scheduled function that is not "
164 "tf.function cannot be the input of another function.")
165 return val._type_spec # pylint: disable=protected-access
166 else:
167 return val
168
169
170def _select_worker_slice(worker_id, structured):
171 """Selects the worker slice of each of the items in `structured`."""
172
173 def _get(x):
174 return x._values[worker_id] if isinstance(x, PerWorkerValues) else x # pylint: disable=protected-access
175
176 return nest.map_structure(_get, structured)
177
178
179def _disallow_remote_value_as_input(structured):
180 """Raises if any element of `structured` is a RemoteValue."""
181
182 def _raise_if_remote_value(x):
183 if isinstance(x, RemoteValue):
184 raise ValueError(
185 "`tf.distribute.experimental.coordinator.RemoteValue` used "
186 "as an input to scheduled function is not yet "
187 "supported.")
188
189 nest.map_structure(_raise_if_remote_value, structured)
190
191
192class Closure(object):
193 """Hold a function to be scheduled and its arguments."""
194
195 def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
196 if not callable(function):
197 raise ValueError("Function passed to `ClusterCoordinator.schedule` must "
198 "be a callable object.")
199 self._args = args or ()
200 self._kwargs = kwargs or {}
201
202 _disallow_remote_value_as_input(self._args)
203 _disallow_remote_value_as_input(self._kwargs)
204
205 if isinstance(function, def_function.Function):
206 replica_args = _select_worker_slice(0, self._args)
207 replica_kwargs = _select_worker_slice(0, self._kwargs)
208
209 # Note: no need to handle function registration failure since this kind of
210 # failure will not raise exceptions as designed in the runtime. The
211 # coordinator has to rely on subsequent operations that raise to catch
212 # function registration failure.
213
214 # Record the function tracing overhead. Note that we pass in the tracing
215 # count of the def_function.Function as a state tracker, so that metrics
216 # will only record the time for actual function tracing (i.e., excluding
217 # function cache lookups).
218 with metric_utils.monitored_timer(
219 "function_tracing", state_tracker=function._get_tracing_count): # pylint: disable=protected-access
220 self._concrete_function = function.get_concrete_function(
221 *nest.map_structure(_maybe_as_type_spec, replica_args),
222 **nest.map_structure(_maybe_as_type_spec, replica_kwargs))
223 elif isinstance(function, tf_function.ConcreteFunction):
224 self._concrete_function = function
225
226 if hasattr(self, "_concrete_function"):
227 # If we have a concrete function, we get to retrieve the output type spec
228 # via the structured_output.
229 self._output_type_spec = func_graph.convert_structure_to_signature(
230 self._concrete_function.structured_outputs)
231 self._function = cancellation_mgr.get_cancelable_function(
232 self._concrete_function)
233 else:
234 # Otherwise (i.e. what is passed in is a regular python function), we have
235 # no such information.
236 self._output_type_spec = None
237 self._function = function
238
239 self._output_remote_value_ref = None
240
241 def build_output_remote_value(self):
242 if self._output_remote_value_ref is None:
243 ret = RemoteValueImpl(None, self._output_type_spec)
244 self._output_remote_value_ref = weakref.ref(ret)
245 return ret
246 else:
247 raise ValueError(
248 "The output of the Closure cannot be built more than once.")
249
250 def maybe_call_with_output_remote_value(self, method):
251 if self._output_remote_value_ref is None:
252 return None
253 output_remote_value = self._output_remote_value_ref()
254 if output_remote_value is not None:
255 return method(output_remote_value)
256 return None
257
258 def mark_cancelled(self):
259 e = errors.CancelledError(
260 None, None, "The corresponding function is "
261 "cancelled. Please reschedule the function.")
262 self.maybe_call_with_output_remote_value(lambda r: r._set_error(e)) # pylint: disable=protected-access
263
264 def execute_on(self, worker):
265 """Executes the closure on the given worker.
266
267 Args:
268 worker: a `Worker` object.
269 """
270 replica_args = _select_worker_slice(worker.worker_index, self._args)
271 replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)
272
273 e = (
274 _get_error_from_remote_values(replica_args) or
275 _get_error_from_remote_values(replica_kwargs))
276 if e:
277 if not isinstance(e, ClosureInputError):
278 e = ClosureInputError(e)
279 raise e
280
281 with ops.device(worker.device_name):
282 with context.executor_scope(worker.executor):
283 with coordinator_context.with_dispatch_context(worker):
284 with metric_utils.monitored_timer("closure_execution"):
285 output_values = self._function(
286 *nest.map_structure(coordinator_context.maybe_get_remote_value,
287 replica_args),
288 **nest.map_structure(coordinator_context.maybe_get_remote_value,
289 replica_kwargs))
290 self.maybe_call_with_output_remote_value(
291 lambda r: r._set_values(output_values)) # pylint: disable=protected-access
292
293
294class ResourceClosure(Closure):
295
296 def build_output_remote_value(self):
297 if self._output_remote_value_ref is None:
298 # We need to remember the Closure object in the `RemoteValue` here.
299 ret = RemoteValueImpl(self, self._output_type_spec)
300 self._output_remote_value_ref = weakref.ref(ret)
301 return ret
302 else:
303 return self._output_remote_value_ref()
304
305
306class _CoordinatedClosureQueue(object):
307 """Manage a queue of closures, inflight count and errors from execution.
308
309 This class is thread-safe.
310 """
311
312 def __init__(self):
313 # `self._inflight_closure_count` only tracks the number of inflight closures
314 # that are "in generation". Once an error occurs, error generation is
315 # incremented and all subsequent arriving closures (from inflight) are
316 # considered "out of generation".
317 self._inflight_closure_count = 0
318
319 self._queue_lock = threading.Lock()
320
321 # Condition indicating that all pending closures (either queued or inflight)
322 # have been processed, failed, or cancelled.
323 self._stop_waiting_condition = threading.Condition(self._queue_lock)
324
325 # Condition indicating that an item becomes available in queue (not empty).
326 self._closures_queued_condition = threading.Condition(self._queue_lock)
327 self._should_process_closures = True
328
329 # Condition indicating that a queue slot becomes available (not full).
330 # Note that even with "infinite" queue size, there is still a "practical"
331 # size limit for the queue depending on host memory capacity, and thus the
332 # queue will eventually become full with a lot of enqueued closures.
333 self._queue_free_slot_condition = threading.Condition(self._queue_lock)
334
335 # Condition indicating there is no inflight closures.
336 self._no_inflight_closure_condition = threading.Condition(self._queue_lock)
337
338 # Use to cancel in-flight closures.
339 self._cancellation_mgr = cancellation.CancellationManager()
340
341 if _CLOSURE_QUEUE_MAX_SIZE <= 0:
342 logging.warning(
343 "In a `ClusterCoordinator`, creating an infinite closure queue can "
344 "consume a significant amount of memory and even lead to OOM.")
345 self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
346 self._tagged_queue = collections.defaultdict(queue.Queue)
347 self._error = None
348
349 # The following is a lock to make sure when `wait` is called and before it
350 # returns no `put` can be executed during this period. It is because `wait`
351 # won't know what to do with newly put closures. This lock adds an cutoff
352 # for `wait` so that closures put into the queue while waiting would not be
353 # taken responsible by this `wait`.
354 #
355 # We cannot reuse the `self._queue_lock` since when `wait` waits for a
356 # condition, the `self._queue_lock` will be released.
357 #
358 # We don't use a reader/writer's lock on purpose to reduce the complexity
359 # of the code.
360 self._put_wait_lock = threading.Lock()
361
362 self._watchdog = watchdog.WatchDog(on_triggered=self._on_watchdog_timeout)
363
364 def _on_watchdog_timeout(self):
365 logging.info("inflight_closure_count is %d", self._inflight_closure_count)
366 logging.info("current error is %s:%r", self._error, self._error)
367
368 def stop(self):
369 with self._queue_lock:
370 self._should_process_closures = False
371 self._cancellation_mgr.start_cancel()
372 self._closures_queued_condition.notify_all()
373 self._watchdog.stop()
374
375 def _cancel_all_closures(self):
376 """Clears the queue and sets remaining closures cancelled error.
377
378 This method expects self._queue_lock to be held prior to entry.
379 """
380 self._cancellation_mgr.start_cancel()
381 logging.info("Canceling all closures: waiting for inflight closures to "
382 "finish")
383 while self._inflight_closure_count > 0:
384 self._no_inflight_closure_condition.wait()
385 logging.info("Canceling all closures: canceling remaining closures on the "
386 "queue")
387 while True:
388 try:
389 closure = self._queue.get(block=False)
390 self._queue_free_slot_condition.notify()
391 closure.mark_cancelled()
392 except queue.Empty:
393 break
394 # The cancellation manager cannot be reused once cancelled. After all
395 # closures (queued or inflight) are cleaned up, recreate the cancellation
396 # manager with clean state.
397 # Note on thread-safety: this is triggered when one of theses
398 # ClusterCoordinator APIs are called: `schedule`, `wait`, and `done`. At the
399 # same time, no new closures can be constructed (which reads the
400 # _cancellation_mgr to get cancellable functions).
401 self._cancellation_mgr = cancellation.CancellationManager()
402
403 def _raise_if_error(self):
404 """Raises the error if one exists.
405
406 If an error exists, cancel the closures in queue, raises it, and clear
407 the error.
408
409 This method expects self._queue_lock to be held prior to entry.
410 """
411 if self._error:
412 logging.error("Start cancelling closures due to error %r: %s",
413 self._error, self._error)
414 self._cancel_all_closures()
415 try:
416 raise self._error # pylint: disable=raising-bad-type
417 finally:
418 self._error = None
419
420 def put(self, closure, tag=None):
421 """Put a closure into the queue for later execution.
422
423 If `mark_failed` was called before `put`, the error from the first
424 invocation of `mark_failed` will be raised.
425
426 Args:
427 closure: The `Closure` to put into the queue.
428 tag: if not None, put into a queue with the given tag.
429 """
430 closure.tag = tag
431 if tag is not None:
432 with self._queue_lock:
433 self._tagged_queue[tag].put(closure, block=False)
434 self._closures_queued_condition.notify_all()
435 else:
436 with self._put_wait_lock, self._queue_lock:
437 self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
438 self._queue.put(closure, block=False)
439 self._raise_if_error()
440 self._closures_queued_condition.notify()
441
442 def get(self, timeout=None, tag=None):
443 """Return a closure from the queue to be executed.
444
445 It will try to fetch an item from the queue with the given tag. If this
446 queue is empty, it will then check the global queue.
447
448 Args:
449 timeout: timeout when waiting for a closure to be put.
450 tag: optional tag to specify which queue to query first before querying
451 the global queue.
452
453 Returns:
454 a closure or None after timeout.
455 """
456 with self._queue_lock:
457 while (self._should_process_closures and self._queue.empty() and
458 (tag is None or self._tagged_queue[tag].empty())):
459 if not self._closures_queued_condition.wait(timeout=timeout):
460 return None
461 if not self._should_process_closures:
462 return None
463 if tag is not None and not self._tagged_queue[tag].empty():
464 closure = self._tagged_queue[tag].get(block=False)
465 return closure
466 closure = self._queue.get(block=False)
467 assert closure.tag is None
468 assert tag is None or self._tagged_queue[tag].empty()
469 self._queue_free_slot_condition.notify()
470 self._inflight_closure_count += 1
471 return closure
472
473 def mark_finished(self):
474 """Let the queue know that a closure has been successfully executed."""
475 with self._queue_lock:
476 if self._inflight_closure_count < 1:
477 raise AssertionError("There is no inflight closures to mark_finished.")
478 self._inflight_closure_count -= 1
479 if self._inflight_closure_count == 0:
480 self._no_inflight_closure_condition.notify_all()
481 if self._queue.empty() and self._inflight_closure_count == 0:
482 self._stop_waiting_condition.notify_all()
483 self._watchdog.report_closure_done()
484
485 def put_back(self, closure):
486 """Put the closure back into the queue as it was not properly executed."""
487 assert closure.tag is None
488 with self._queue_lock:
489 if self._inflight_closure_count < 1:
490 raise AssertionError("There is no inflight closures to put_back.")
491 if self._error:
492 closure.mark_cancelled()
493 else:
494 self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
495 self._queue.put(closure, block=False)
496 self._closures_queued_condition.notify()
497 self._inflight_closure_count -= 1
498 if self._inflight_closure_count == 0:
499 self._no_inflight_closure_condition.notify_all()
500
501 def wait(self, timeout=None):
502 """Wait for all closures to be finished before returning.
503
504 If `mark_failed` was called before or during `wait`, the error from the
505 first invocation of `mark_failed` will be raised.
506
507 Args:
508 timeout: A float specifying a timeout for the wait in seconds.
509
510 Returns:
511 True unless the given timeout expired, in which case it returns False.
512 """
513 with self._put_wait_lock, self._queue_lock:
514 logging.info("Waiting for all global closures to be finished.")
515 while (not self._error and
516 (not self._queue.empty() or self._inflight_closure_count > 0)):
517 if not self._stop_waiting_condition.wait(timeout=timeout):
518 return False
519 self._raise_if_error()
520 return True
521
522 def mark_failed(self, e):
523 """Sets error and unblocks any wait() call."""
524 with self._queue_lock:
525 # TODO(yuefengz): maybe record all failure and give users more
526 # information?
527 if self._inflight_closure_count < 1:
528 raise AssertionError("There is no inflight closures to mark_failed.")
529 if self._error is None:
530 self._error = e
531 self._inflight_closure_count -= 1
532 if self._inflight_closure_count == 0:
533 self._no_inflight_closure_condition.notify_all()
534 self._stop_waiting_condition.notify_all()
535
536 def done(self):
537 """Returns true if the queue is empty and there is no inflight closure.
538
539 If `mark_failed` was called before `done`, the error from the first
540 invocation of `mark_failed` will be raised.
541 """
542 with self._queue_lock:
543 self._raise_if_error()
544 return self._queue.empty() and self._inflight_closure_count == 0
545
546 def clear_tag_unlocked(self, tag):
547 self._tagged_queue[tag] = queue.Queue()
548
549
550class CoordinationServicePreemptionHandler(object):
551 """Handles preemptions of workers and parameter servers.
552
553 Starts a thread to regularly poll the coordination service (hosted on PS 0)
554 for task states. When a worker's task state reflects an error, it inspects the
555 error. If the error is recoverable (i.e. a preemption), it waits for the
556 worker to recover, then updates the server def. Otherwise, it raises the error
557 to the user.
558
559 A worker error is detected to be recoverable if it is the result of missing a
560 heartbeat that workers regularly send to the coordination service.
561
562 The thread also checks for parameter server errors. If these are detected, the
563 thread and coordinator shutdown. To resume training in this case, the whole
564 job must be restarted and resumed from the latest checkpoint.
565 """
566
567 def __init__(self, server_def, cluster):
568 self._server_def = server_def
569 self._cluster = cluster
570 self._cluster_update_lock = threading.Lock()
571 self._cluster_due_for_update_or_finish = threading.Event()
572 self._worker_up_cond = threading.Condition(self._cluster_update_lock)
573
574 self._next_task_state_cond = threading.Condition()
575 self._task_states = None
576
577 self._error_from_recovery = None
578 self._should_preemption_thread_run = True
579 self._task_state_poller_thread = utils.RepeatedTimer(
580 interval=_POLL_FREQ_IN_SEC,
581 function=self._get_task_states)
582 self._preemption_handler_thread = threading.Thread(
583 target=self._preemption_handler,
584 name="WorkerPreemptionHandler",
585 daemon=True)
586 self._preemption_handler_thread.start()
587
588 self._num_workers = self._cluster._num_workers
589 self._num_ps = self._cluster._num_ps
590
591 def stop(self):
592 """Ensure the worker preemption thread is closed."""
593 self._task_state_poller_thread.stop()
594 self._should_preemption_thread_run = False
595 with self._cluster_update_lock:
596 self._cluster_due_for_update_or_finish.set()
597 # TODO(yuefengz): The preemption handler thread shouldn't be terminated
598 # asynchronously since it touches eager context which is a process-wide
599 # singleton. The problem is in OSS unit tests will time out.
600
601 @contextlib.contextmanager
602 def wait_on_failure(self,
603 on_failure_fn=None,
604 on_transient_failure_fn=None,
605 on_recovery_fn=None,
606 worker_device_name="(unknown)"):
607 """Catches errors during closure execution and handles them.
608
609 Args:
610 on_failure_fn: an optional function to run if preemption happens.
611 on_transient_failure_fn: an optional function to run if transient failure
612 happens.
613 on_recovery_fn: an optional function to run when a worker is recovered
614 from preemption.
615 worker_device_name: the device name of the worker instance that is passing
616 through the failure.
617
618 Yields:
619 None.
620 """
621 assert self._should_preemption_thread_run
622 try:
623 yield
624 except (errors.OpError, ClosureInputError,
625 ClosureAbortedError) as e:
626 # The next state could reflect stale heartbeats, so wait for two rounds.
627 # Example:
628 # - Worker sends healthy heartbeat at T=0.
629 # - Coordination service receives healthy heartbeat at T=0.
630 # - Worker gets preempted at T=0.1.
631 # - Coordinator catches error at T=0.2, and waits here for next states.
632 # - Coordinator polls states at T=1.9. Heartbeat time has not elapsed yet,
633 # so coordination service does not know it is down yet.
634 # - Coordination service learns of worker unavailability at T=2, the next
635 # heartbeat.
636 # - Coordinator polls states at T=3.9 and learns of worker unavailability.
637 with self._next_task_state_cond:
638 # Give some buffer time to make sure task states are updated during the
639 # wait interval
640 self._next_task_state_cond.wait(_POLL_FREQ_IN_SEC * 1.25)
641 with self._next_task_state_cond:
642 self._next_task_state_cond.wait(_POLL_FREQ_IN_SEC * 1.25)
643
644 # Check for coordination service failure
645 if not self._task_states:
646 self._log_ps_failure_and_raise(e, 0)
647
648 worker_states = self._task_states[:self._num_workers]
649 ps_states = self._task_states[self._num_workers:]
650
651 # Check for PS failure
652 if any(ps_states):
653 failed_ps_index = [
654 ix for ix, ps_state in enumerate(ps_states) if ps_state
655 ]
656 self._log_ps_failure_and_raise(e, failed_ps_index[0])
657
658 # Check for preemption of this worker
659 worker_ix = int(worker_device_name.split(":")[-1])
660 if worker_states[worker_ix]:
661 # Raise error if all closures are being cancelled
662 if self._cluster.closure_queue._cancellation_mgr.is_cancelled: # pylint: disable=protected-access
663 if isinstance(e, errors.CancelledError):
664 raise e
665 # It's possible the caught error `e` here is due to worker preemption
666 # and is thus not a `CancelledError`, because a different
667 # unrecoverable error on another worker caused closure cancellation,
668 # while this thread was waiting for task states. So raise a new
669 # CancelledError.
670 else:
671 raise errors.CancelledError(
672 None, None, "The corresponding function was cancelled while "
673 "attempting to recover from worker failure.")
674 # Else, preemption
675 self._handle_failure_and_recovery(e, on_failure_fn,
676 on_transient_failure_fn,
677 on_recovery_fn, worker_device_name)
678 return
679
680 # else, if timeout: log
681 if self._cluster._record_and_ignore_transient_timeouts(e): # pylint: disable=protected-access
682 logging.error(
683 "Remote function on worker %s failed with %r:%s\n"
684 "This derived error is ignored and not reported to users.",
685 worker_device_name, e, e)
686 if on_transient_failure_fn:
687 on_transient_failure_fn()
688 return
689 raise e
690
691 def _handle_failure_and_recovery(self,
692 e,
693 on_failure_fn,
694 on_transient_failure_fn,
695 on_recovery_fn,
696 worker_device_name):
697 """Call failure fn, wait for cluster to recover, then call recovery fn.
698
699 Args:
700 e: the Exception thrown during closure execution.
701 on_failure_fn: an optional function to run if preemption happens.
702 on_transient_failure_fn: an optional function to run if transient failure
703 happens.
704 on_recovery_fn: an optional function to run when a worker is recovered
705 from preemption.
706 worker_device_name: the device name of the worker instance that is passing
707 through the failure.
708 """
709 if on_failure_fn:
710 on_failure_fn(e)
711 # update server def
712 with self._cluster_update_lock:
713 self._cluster_due_for_update_or_finish.set()
714 self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
715 if self._error_from_recovery:
716 # TODO(yuefengz): there is only one worker that will get this error.
717 # Ideally we should let all workers notified by `_worker_up_cond` get
718 # this error.
719 try:
720 raise self._error_from_recovery
721 finally:
722 self._error_from_recovery = None
723 logging.info("Worker %s has been recovered.", worker_device_name)
724
725 if on_recovery_fn:
726 logging.info("Worker %s calling on_recovery_fn", worker_device_name)
727 with self.wait_on_failure(
728 on_recovery_fn=on_recovery_fn,
729 on_transient_failure_fn=on_transient_failure_fn,
730 worker_device_name=worker_device_name):
731 on_recovery_fn()
732
733 def _log_ps_failure_and_raise(self, e, ps_index):
734 logging.info("Parameter server failure detected at PS task %d", ps_index)
735 self.stop()
736 raise PSUnavailableError(e)
737
738 def _get_task_states(self):
739 try:
740 self._task_states = context.context().get_task_states(
741 [("worker", self._num_workers), ("ps", self._num_ps)]
742 )
743 except errors.UnavailableError:
744 # Coordination service is down
745 self._task_states = None
746 with self._next_task_state_cond:
747 self._next_task_state_cond.notify_all()
748
749 def _preemption_handler(self):
750 """A loop that handles preemption.
751
752 This loop waits for signal of worker preemption and upon worker preemption,
753 it waits until all workers are back and updates the cluster about the
754 restarted workers.
755 """
756 assert self._should_preemption_thread_run
757 while True:
758 self._cluster_due_for_update_or_finish.wait()
759 if not self._should_preemption_thread_run:
760 logging.info("Stopping the failure handing thread.")
761 break
762
763 with self._cluster_update_lock:
764 try:
765 # TODO(haoyuzhang): support partial cluster recovery
766 logging.info("Cluster now being recovered.")
767 context.context().update_server_def(self._server_def)
768
769 # Cluster updated successfully, clear the update signal, and notify
770 # all workers that they are recovered from failure.
771 logging.info("Cluster successfully recovered.")
772 self._notify_cluster_update()
773 except Exception as e: # pylint: disable=broad-except
774 logging.info("Error occurred while updating server def: %s", e)
775 # Wait for the next set of states from the task state poller
776 with self._next_task_state_cond:
777 self._next_task_state_cond.wait(_POLL_FREQ_IN_SEC * 2)
778 # If a PS is preempted, set the error
779 if not self._task_states:
780 self._error_from_recovery = e
781 else:
782 ps_states = self._task_states[self._num_workers:]
783 # Check for PS failure
784 if any(ps_states):
785 self._error_from_recovery = e
786 # Else, likely another worker failed. Just log and retry
787 self._notify_cluster_update()
788 # NOTE: Since the first RPC (GetStatus) of update_server_def is
789 # currently blocking by default, error should only happen if:
790 # (1) More workers failed while waiting for the previous workers to
791 # come back;
792 # (2) Worker failed when exchanging subsequent RPCs after the first
793 # RPC returns.
794 # Consider adding backoff retry logic if we see the error logged
795 # too frequently.
796 logging.error("Cluster update failed with error: %s. Retrying...", e)
797
798 def _notify_cluster_update(self):
799 self._worker_up_cond.notify_all()
800 # The check for _should_preemption_thread_run is necessary since the
801 # `stop` may have already set _cluster_due_for_update_or_finish.
802 if self._should_preemption_thread_run:
803 self._cluster_due_for_update_or_finish.clear()
804
805
806class WorkerPreemptionHandler(object):
807 """Handles worker preemptions."""
808
809 def __init__(self, server_def, cluster):
810 self._server_def = server_def
811 self._cluster = cluster
812 self._cluster_update_lock = threading.Lock()
813 self._cluster_due_for_update_or_finish = threading.Event()
814 self._worker_up_cond = threading.Condition(self._cluster_update_lock)
815 self._error_from_recovery = None
816 self._should_preemption_thread_run = True
817 self._preemption_handler_thread = threading.Thread(
818 target=self._preemption_handler,
819 name="WorkerPreemptionHandler",
820 daemon=True)
821 self._preemption_handler_thread.start()
822
823 def stop(self):
824 """Ensure the worker preemption thread is closed."""
825 self._should_preemption_thread_run = False
826 with self._cluster_update_lock:
827 self._cluster_due_for_update_or_finish.set()
828 # TODO(yuefengz): The preemption handler thread shouldn't be terminated
829 # asynchronously since it touches eager context which is a process-wide
830 # singleton. The problem is in OSS unit tests will time out.
831
832 def _validate_preemption_failure(self, e):
833 """Validates that the given exception represents worker preemption."""
834
835 # Only categorize the failure as a worker preemption if the cancellation
836 # manager did not attempt to cancel the blocking operations.
837 if _is_worker_failure(e) and (
838 not self._cluster.closure_queue._cancellation_mgr.is_cancelled): # pylint: disable=protected-access
839 return
840 raise e
841
842 @contextlib.contextmanager
843 def wait_on_failure(self,
844 on_failure_fn=None,
845 on_transient_failure_fn=None,
846 on_recovery_fn=None,
847 worker_device_name="(unknown)"):
848 """Catches worker preemption error and wait until failed workers are back.
849
850 Args:
851 on_failure_fn: an optional function to run if preemption happens.
852 on_transient_failure_fn: an optional function to run if transient failure
853 happens.
854 on_recovery_fn: an optional function to run when a worker is recovered
855 from preemption.
856 worker_device_name: the device name of the worker instance that is passing
857 through the failure.
858
859 Yields:
860 None.
861 """
862 assert self._should_preemption_thread_run
863 try:
864 yield
865 except (errors.OpError, ClosureInputError,
866 ClosureAbortedError, TypeError) as e:
867 # If the error is due to temporary connectivity issues between worker and
868 # ps, put back closure, ignore error and do not mark worker as failure.
869 if self._cluster._record_and_ignore_transient_ps_failure(e): # pylint: disable=protected-access
870 logging.error(
871 "Remote function on worker %s failed with %r:%s\n"
872 "It is treated as a transient connectivity failure for now.",
873 worker_device_name, e, e)
874 if on_transient_failure_fn:
875 on_transient_failure_fn()
876 return
877
878 # If the error is due to temporary connectivity issues that cause the
879 # server-side RPCs to be cancelled, TF might not abort the step and the
880 # closure might timeout. The coordinator ignores certain amount of such
881 # failures without marking worker as failure.
882 if self._cluster._record_and_ignore_transient_timeouts(e): # pylint: disable=protected-access
883 logging.error(
884 "Remote function on worker %s failed with %r:%s\n"
885 "This derived error is ignored and not reported to users.",
886 worker_device_name, e, e)
887 if on_transient_failure_fn:
888 on_transient_failure_fn()
889 return
890
891 # Ignoring derived CancelledErrors to tolerate transient failures in
892 # PS-worker communication, which initially exposed as an UnavailableError
893 # and then lead to sub-function cancellation, subsequently getting
894 # reported from worker to chief as CancelledError.
895 # We do not mark either worker or PS as failed due to only CancelledError.
896 # If there are real (non-transient) failures, they must also be reported
897 # as other errors (UnavailableError most likely) in closure executions.
898 if isinstance(e, errors.CancelledError) and "/job:" in str(e):
899 logging.error(
900 "Remote function on worker %s failed with %r:%s\n"
901 "This derived error is ignored and not reported to users.",
902 worker_device_name, e, e)
903 if on_transient_failure_fn:
904 on_transient_failure_fn()
905 return
906
907 # This reraises the error, if it's not considered recoverable; otherwise,
908 # the following failure recovery logic run. At this time, only worker
909 # unavailability is recoverable. PS unavailability as well as other
910 # errors in the user function is not recoverable.
911 self._validate_preemption_failure(e)
912
913 logging.error("Worker %s failed with %r:%s", worker_device_name, e, e)
914 if on_failure_fn:
915 on_failure_fn(e)
916
917 with self._cluster_update_lock:
918 self._cluster_due_for_update_or_finish.set()
919 self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
920 if self._error_from_recovery:
921 # TODO(yuefengz): there is only one worker that will get this error.
922 # Ideally we shuold let all workers notified by `_worker_up_cond` get
923 # this error.
924 try:
925 raise self._error_from_recovery
926 finally:
927 self._error_from_recovery = None
928 logging.info("Worker %s has been recovered.", worker_device_name)
929
930 if on_recovery_fn:
931 logging.info("Worker %s calling on_recovery_fn", worker_device_name)
932 with self.wait_on_failure(
933 on_recovery_fn=on_recovery_fn,
934 on_transient_failure_fn=on_transient_failure_fn,
935 worker_device_name=worker_device_name):
936 on_recovery_fn()
937
938 def _preemption_handler(self):
939 """A loop that handles preemption.
940
941 This loop waits for signal of worker preemption and upon worker preemption,
942 it waits until all workers are back and updates the cluster about the
943 restarted workers.
944 """
945 assert self._should_preemption_thread_run
946 while True:
947 self._cluster_due_for_update_or_finish.wait()
948 if not self._should_preemption_thread_run:
949 logging.info("Stopping the failure handing thread.")
950 break
951
952 with self._cluster_update_lock:
953 try:
954 # TODO(haoyuzhang): support partial cluster recovery
955 logging.info("Cluster now being recovered.")
956 with metric_utils.monitored_timer("server_def_update"):
957 context.context().update_server_def(self._server_def)
958
959 # Cluster updated successfully, clear the update signal, and notify
960 # all workers that they are recovered from failure.
961 logging.info("Cluster successfully recovered.")
962 self._worker_up_cond.notify_all()
963 # The check for _should_preemption_thread_run is necessary since the
964 # `stop` may have already set _cluster_due_for_update_or_finish.
965 if self._should_preemption_thread_run:
966 self._cluster_due_for_update_or_finish.clear()
967 except Exception as e: # pylint: disable=broad-except
968 logging.info("Error occurred while updating server def: %s", e)
969 try:
970 self._validate_preemption_failure(e)
971 except Exception as ps_e: # pylint: disable=broad-except
972 logging.info("Error that occurred while updating server def is not "
973 "a worker failure. So set it as _error_from_recovery")
974 # In this case, a parameter server fails. So we raise this error to
975 # the caller of `wait_on_failure`.
976 self._error_from_recovery = ps_e
977 self._worker_up_cond.notify_all()
978 if self._should_preemption_thread_run:
979 self._cluster_due_for_update_or_finish.clear()
980 # NOTE: Since the first RPC (GetStatus) of update_server_def is
981 # currently blocking by default, error should only happen if:
982 # (1) More workers failed while waiting for the previous workers to
983 # come back;
984 # (2) Worker failed when exchanging subsequent RPCs after the first
985 # RPC returns.
986 # Consider adding backoff retry logic if we see the error logged
987 # too frequently.
988 logging.error("Cluster update failed with error: %s. Retrying...", e)
989
990
991class Worker(object):
992 """A worker in a cluster.
993
994 Attributes:
995 worker_index: The index of the worker in the cluster.
996 device_name: The device string of the worker, e.g. "/job:worker/task:1".
997 executor: The worker's executor for remote function execution.
998 failure_handler: The failure handler used to handler worker preemption
999 failure.
1000 """
1001
1002 def __init__(self, worker_index, device_name, cluster):
1003 self.worker_index = worker_index
1004 self.device_name = device_name
1005 self.executor = executor.new_executor(enable_async=False)
1006 self.failure_handler = cluster.failure_handler
1007 self._cluster = cluster
1008 self._resource_tracking_lock = threading.Lock()
1009 self._resource_remote_value_refs = []
1010 self._is_dead_with_error = None
1011 self._should_worker_thread_run = True
1012
1013 # Worker threads need to start after `Worker`'s initialization.
1014 threading.Thread(target=self._process_queue,
1015 name="WorkerClosureProcessingLoop-%d" % self.worker_index,
1016 daemon=True).start()
1017
1018 def stop(self):
1019 """Ensure the worker thread is closed."""
1020 self._should_worker_thread_run = False
1021
1022 def _schedule_resource(self, closure):
1023 self._cluster.closure_queue.put(closure, tag=self.worker_index)
1024
1025 def _set_resources_aborted(self, e):
1026 """Set the resource ABORTED and add an error to it."""
1027 # TODO(yuefengz): maybe we can query whether a tensor is valid or not
1028 # instead of marking a tensor aborted?
1029 logging.info("[Worker %d] Clearing all resources.", self.worker_index)
1030 for weakref_resource in self._resource_remote_value_refs:
1031 resource = weakref_resource()
1032 if resource:
1033 # It is important to set an error on an aborted RemoteValue from a
1034 # ResourceClosure because its failure will not trigger the worker thread
1035 # to raise error immediately and the worker may continue executing
1036 # closures taking it as an input. The error will then be correctly
1037 # reported to users.
1038 resource._set_aborted(ClosureAbortedError(e)) # pylint: disable=protected-access
1039
1040 def _on_closure_failure(self, closure, e):
1041 logging.info("[Worker %d] Putting back a closure after it failed.",
1042 self.worker_index)
1043 self._cluster.closure_queue.put_back(closure)
1044
1045 with self._resource_tracking_lock:
1046 self._is_dead_with_error = e
1047 self._set_resources_aborted(e)
1048
1049 def _on_resource_closure_failure(self, e):
1050 """Clear tagged queue to ensure resource closures are rebuilt.
1051
1052 Args:
1053 e: The exception arisen from the resource closure.
1054 """
1055 logging.info("[Worker %d] Clearing tagged queue after resource closure "
1056 "failure.", self.worker_index)
1057 with self._resource_tracking_lock:
1058 self._is_dead_with_error = e
1059 # No locking on queue is needed since
1060 # * get will not happen concurrently here.
1061 # * put to the specific tagged queue will be guarded by
1062 # `self._resource_tracking_lock`.
1063 self._cluster.closure_queue.clear_tag_unlocked(self.worker_index)
1064 self._set_resources_aborted(e)
1065
1066 def _on_worker_recovery(self):
1067 logging.info("[Worker %d] calling _on_worker_recovery", self.worker_index)
1068 with self._resource_tracking_lock:
1069 for weakref_resource in self._resource_remote_value_refs:
1070 resource = weakref_resource()
1071 if resource:
1072 self._schedule_resource(resource._closure) # pylint: disable=protected-access
1073 self._is_dead_with_error = False
1074
1075 def _process_closure(self, closure):
1076 """Runs a closure with preemption handling."""
1077 try:
1078 with self.failure_handler.wait_on_failure(
1079 on_failure_fn=lambda e: self._on_closure_failure(closure, e),
1080 on_transient_failure_fn=(
1081 lambda: self._cluster.closure_queue.put_back(closure)),
1082 on_recovery_fn=self._on_worker_recovery,
1083 worker_device_name=self.device_name):
1084 closure.execute_on(self)
1085 with metric_utils.monitored_timer("remote_value_fetch"):
1086 # Copy the remote tensor to local (the coordinator) in case worker
1087 # becomes unavailable at a later time.
1088 closure.maybe_call_with_output_remote_value(lambda r: r.get())
1089 self._cluster.closure_queue.mark_finished()
1090 except Exception as e: # pylint: disable=broad-except
1091 # Avoid logging the derived cancellation error
1092 if not isinstance(e, errors.CancelledError):
1093 logging.error(
1094 " /job:worker/task:%d encountered the following error when "
1095 "processing closure: %r:%s", self.worker_index, e, e)
1096 closure.maybe_call_with_output_remote_value(lambda r: r._set_error(e)) # pylint: disable=protected-access
1097 self._cluster.closure_queue.mark_failed(e)
1098
1099 def _process_resource_closure(self, closure):
1100 """Run the given resource closure with preemption handling."""
1101 assert closure.tag == self.worker_index
1102 try:
1103 with self.failure_handler.wait_on_failure(
1104 on_failure_fn=self._on_resource_closure_failure,
1105 on_transient_failure_fn=(
1106 lambda: self._process_resource_closure(closure)),
1107 on_recovery_fn=self._on_worker_recovery,
1108 worker_device_name=self.device_name):
1109 closure.execute_on(self)
1110 except Exception as e: # pylint: disable=broad-except
1111 # Avoid logging the derived cancellation error
1112 logging.info("[Worker %d] got an exception when processing resource "
1113 "closure", self.worker_index)
1114 if not isinstance(e, errors.CancelledError):
1115 logging.error(
1116 " /job:worker/task:%d encountered the following error when "
1117 "processing resource closure: %r:%s", self.worker_index, e, e)
1118 closure.maybe_call_with_output_remote_value(lambda r: r._set_error(e)) # pylint: disable=protected-access
1119
1120 def _maybe_delay(self):
1121 """Delay if corresponding env vars are set."""
1122 # If the following two env vars variables are set. Scheduling for workers
1123 # will start in a staggered manner. Worker i will wait for
1124 # `TF_COORDINATOR_SCHEDULE_START_DELAY` * i seconds, not exceeding
1125 # `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX`.
1126 delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0"))
1127 delay_secs *= self.worker_index
1128 delay_cap = int(
1129 os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0"))
1130 if delay_cap:
1131 delay_secs = min(delay_secs, delay_cap)
1132 if delay_secs > 0:
1133 logging.info(" Worker %d sleeping for %d seconds before running function",
1134 self.worker_index, delay_secs)
1135 time.sleep(delay_secs)
1136
1137 def _process_queue(self):
1138 """Function running in a worker thread to process closure queues."""
1139 self._maybe_delay()
1140 while self._should_worker_thread_run:
1141 closure = self._cluster.closure_queue.get(tag=self.worker_index)
1142 if not self._should_worker_thread_run or closure is None:
1143 if closure is not None:
1144 closure.mark_cancelled()
1145 return
1146 if isinstance(closure, ResourceClosure):
1147 self._process_resource_closure(closure)
1148 else:
1149 self._process_closure(closure)
1150 # To properly stop the worker and preemption threads, it is important that
1151 # `ClusterCoordinator` object is not held onto so its `__del__` can be
1152 # called. By removing the reference to the `closure` that has already been
1153 # processed, we ensure that the `closure` object is released, while
1154 # getting the next `closure` at above `self._cluster.closure_queue.get()`
1155 # call.
1156 del closure
1157
1158 def create_resource(self, function, args=None, kwargs=None):
1159 """Synchronously creates a per-worker resource represented by a `RemoteValue`.
1160
1161 Args:
1162 function: the resource function to be run remotely. It should be a
1163 `tf.function`, a concrete function or a Python function.
1164 args: positional arguments to be passed to the function.
1165 kwargs: keyword arguments to be passed to the function.
1166
1167 Returns:
1168 one or several RemoteValue objects depending on the function return
1169 values.
1170 """
1171 # Some notes about the concurrency: currently all the activities related to
1172 # the same worker such as creating resources, setting resources' aborted
1173 # status, and executing closures happen on the same thread. This allows us
1174 # to have simpler logic of concurrency.
1175
1176 closure = ResourceClosure(
1177 function,
1178 self._cluster.resource_cancellation_mgr,
1179 args=args,
1180 kwargs=kwargs)
1181 resource_remote_value = closure.build_output_remote_value()
1182 with self._resource_tracking_lock:
1183 self._register_resource(resource_remote_value)
1184 if self._is_dead_with_error:
1185 resource_remote_value._set_aborted( # pylint: disable=protected-access
1186 ClosureAbortedError(self._is_dead_with_error))
1187 else:
1188 self._schedule_resource(closure)
1189 return resource_remote_value
1190
1191 def _register_resource(self, resource_remote_value):
1192 if not isinstance(resource_remote_value, RemoteValue):
1193 raise ValueError("Resource being registered is not of type "
1194 "`tf.distribute.experimental.coordinator.RemoteValue`.")
1195 self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))
1196
1197
1198class Cluster(object):
1199 """A cluster with workers.
1200
1201 We assume all function errors are fatal and based on this assumption our
1202 error reporting logic is:
1203 1) Both `schedule` and `join` can raise a non-retryable error which is the
1204 first error seen by the coordinator from any previously scheduled functions.
1205 2) When an error is raised, there is no guarantee on how many previously
1206 scheduled functions have been executed; functions that have not been executed
1207 will be thrown away and marked as cancelled.
1208 3) After an error is raised, the internal state of error will be cleared.
1209 I.e. functions can continue to be scheduled and subsequent calls of `schedule`
1210 or `join` will not raise the same error again.
1211
1212 Attributes:
1213 failure_handler: The failure handler used to handler worker preemption
1214 failure.
1215 workers: a list of `Worker` objects in the cluster.
1216 closure_queue: the global Closure queue.
1217 resource_cancellation_mgr: the cancellation manager used to cancel resource
1218 closures.
1219 """
1220
1221 def __init__(self, strategy):
1222 """Initializes the cluster instance."""
1223
1224 self._num_workers = strategy._num_workers
1225 self._num_ps = strategy._num_ps
1226
1227 # Ignore PS failures reported by workers due to transient connection errors.
1228 # Transient connectivity issues between workers and PS are relayed by the
1229 # workers to the coordinator, leading the coordinator to believe that there
1230 # are PS failures. The difference between transient vs. permanent PS failure
1231 # is the number of reports from the workers. When this env var is set to a
1232 # positive integer K, the coordinator ignores up to K reports of a failed PS
1233 # task, i.e., only when there are more than K trials of executing closures
1234 # fail due to errors from the same PS instance do we consider the PS
1235 # instance encounters a failure.
1236 # TODO(b/164279603): Remove this workaround when the underlying connectivity
1237 # issue in gRPC server is resolved.
1238 self._transient_ps_failures_threshold = int(
1239 os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
1240 self._potential_ps_failures_lock = threading.Lock()
1241 self._potential_ps_failures_count = [0] * self._num_ps
1242
1243 # Ignore worker timeouts due to transient connection errors.
1244 # Transient connectivity issues might cause the server side to unexpectedly
1245 # cancel RPC handling logic, leading to closure execution timeouts. When
1246 # the _transient_timeout_threshold is set to a positive number, the cluster
1247 # coordinator ignores DeadlineExceeded errors from workers for the specified
1248 # times before raising the error to users.
1249 self._transient_timeouts_threshold = int(
1250 os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_TIMEOUTS",
1251 self._num_workers // 10))
1252 self._transient_timeouts_lock = threading.Lock()
1253 self._transient_timeouts_count = 0
1254
1255 self.closure_queue = _CoordinatedClosureQueue()
1256 # Set this environment variable to use an experimental
1257 # integration with the runtime coordination service to aid in failure
1258 # detection and handling. This will not affect the functionality of
1259 # the strategy or cluster coordinator, but is off by default.
1260 if os.getenv("TF_PSS_ENABLE_COORDINATION_SERVICE"):
1261 self.failure_handler = CoordinationServicePreemptionHandler(
1262 context.get_server_def(), self,
1263 )
1264 else:
1265 self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
1266 self)
1267 worker_device_strings = [
1268 "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
1269 ]
1270 self.workers = [
1271 Worker(i, w, self) for i, w in enumerate(worker_device_strings)
1272 ]
1273
1274 # Cancellation manager for all resource closures.
1275 self.resource_cancellation_mgr = cancellation.CancellationManager()
1276
1277 def stop(self):
1278 """Stop worker, worker preemption threads, and the closure queue."""
1279 logging.info("Stopping cluster, starting with failure handler")
1280 self.failure_handler.stop()
1281
1282 logging.info("Stopping workers")
1283 for worker in self.workers:
1284 worker.stop()
1285 logging.info("Stopping queue")
1286 self.closure_queue.stop()
1287 logging.info("Start cancelling remote resource-building functions")
1288 self.resource_cancellation_mgr.start_cancel()
1289
1290 def _record_and_ignore_transient_ps_failure(self, e):
1291 """Records potential PS failures and return if failure should be ignored."""
1292 if self._transient_ps_failures_threshold <= 0 or not _is_ps_failure(e):
1293 return False
1294
1295 ps_tasks = _extract_failed_ps_instances(str(e))
1296 with self._potential_ps_failures_lock:
1297 for t in ps_tasks:
1298 self._potential_ps_failures_count[t] += 1
1299 # The number of UnavailableError encountered on this PS task exceeds the
1300 # maximum number of ignored error
1301 if (self._potential_ps_failures_count[t] >=
1302 self._transient_ps_failures_threshold):
1303 return False
1304 return True
1305
1306 def _record_and_ignore_transient_timeouts(self, e):
1307 """Records observed timeout error and return if it should be ignored."""
1308 if self._transient_timeouts_threshold <= 0:
1309 return False
1310 if not isinstance(e, errors.DeadlineExceededError):
1311 return False
1312 with self._transient_timeouts_lock:
1313 self._transient_timeouts_count += 1
1314 if self._transient_timeouts_count >= self._transient_timeouts_threshold:
1315 return False
1316 return True
1317
1318 def schedule(self, function, args, kwargs):
1319 """Schedules `function` to be dispatched to a worker for execution.
1320
1321 Args:
1322 function: The function to be dispatched to a worker for execution
1323 asynchronously.
1324 args: Positional arguments for `fn`.
1325 kwargs: Keyword arguments for `fn`.
1326
1327 Returns:
1328 A `RemoteValue` object.
1329 """
1330 closure = Closure(
1331 function,
1332 self.closure_queue._cancellation_mgr, # pylint: disable=protected-access
1333 args=args,
1334 kwargs=kwargs)
1335 ret = closure.build_output_remote_value()
1336 self.closure_queue.put(closure)
1337 return ret
1338
1339 def join(self):
1340 """Blocks until all scheduled functions are executed."""
1341 self.closure_queue.wait()
1342
1343 def done(self):
1344 """Returns true if all scheduled functions are executed."""
1345 return self.closure_queue.done()
1346
1347
1348@tf_export("distribute.experimental.coordinator.ClusterCoordinator",
1349 "distribute.coordinator.ClusterCoordinator", v1=[])
1350class ClusterCoordinator(object):
1351 """An object to schedule and coordinate remote function execution.
1352
1353 This class is used to create fault-tolerant resources and dispatch functions
1354 to remote TensorFlow servers.
1355
1356 Currently, this class is not supported to be used in a standalone manner. It
1357 should be used in conjunction with a `tf.distribute` strategy that is designed
1358 to work with it. The `ClusterCoordinator` class currently only works
1359 `tf.distribute.experimental.ParameterServerStrategy`.
1360
1361 __The `schedule`/`join` APIs__
1362
1363 The most important APIs provided by this class is the `schedule`/`join` pair.
1364 The `schedule` API is non-blocking in that it queues a `tf.function` and
1365 returns a `RemoteValue` immediately. The queued functions will be dispatched
1366 to remote workers in background threads and their `RemoteValue`s will be
1367 filled asynchronously. Since `schedule` doesn’t require worker assignment, the
1368 `tf.function` passed in can be executed on any available worker. If the worker
1369 it is executed on becomes unavailable before its completion, it will be
1370 migrated to another worker. Because of this fact and function execution is not
1371 atomic, a function may be executed more than once.
1372
1373 __Handling Task Failure__
1374
  This class, when used with
  `tf.distribute.experimental.ParameterServerStrategy`, comes with built-in
  fault tolerance for worker failures. That is, when some workers cannot be
  reached from the coordinator for any reason, training progress continues to
  be made with the remaining workers. Upon recovery of a failed worker, it
  will be added back for function execution after datasets created by
  `create_per_worker_dataset` are re-built on it.
1382
1383 When a parameter server fails, a `tf.errors.UnavailableError` is raised by
1384 `schedule`, `join` or `done`. In this case, in addition to bringing back the
1385 failed parameter server, users should restart the coordinator so that it
1386 reconnects to workers and parameter servers, re-creates the variables, and
1387 loads checkpoints. If the coordinator fails, after the user brings it back,
1388 the program will automatically connect to workers and parameter servers, and
1389 continue the progress from a checkpoint.
1390
  It is thus essential that the user's program periodically saves a checkpoint
  file and restores it at the start of the program. If a
  `tf.keras.optimizers.Optimizer` is checkpointed, after restoring from a
  checkpoint, its `iterations` property roughly indicates the number of steps
  that have been made. This can be used to decide how many epochs and steps
  remain until training completes.
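
  As a rough sketch only (one possible structure, assuming `strategy` and
  `coordinator` are created as above and that `train_step`,
  `per_worker_iterator`, `num_epochs` and `steps_per_epoch` are defined
  elsewhere):

  ```python
  with strategy.scope():
    model = ...            # Build the model and its variables.
    optimizer = tf.keras.optimizers.SGD()

  checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
  manager = tf.train.CheckpointManager(
      checkpoint, directory="/tmp/ckpt", max_to_keep=3)
  if manager.latest_checkpoint:
    checkpoint.restore(manager.latest_checkpoint)

  # After a restore, `optimizer.iterations` roughly tells how many steps have
  # already been taken, so training can resume from the right epoch.
  starting_epoch = optimizer.iterations.numpy() // steps_per_epoch

  for _ in range(starting_epoch, num_epochs):
    for _ in range(steps_per_epoch):
      coordinator.schedule(train_step, args=(per_worker_iterator,))
    coordinator.join()
    manager.save()
  ```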
1397
1398 See `tf.distribute.experimental.ParameterServerStrategy` docstring for an
1399 example usage of this API.
1400
  This is currently under development, and the API as well as the
  implementation are subject to change.
1403 """
1404
1405 def __new__(cls, strategy):
    # `ClusterCoordinator` is kept as a single instance for a given `Strategy`.
1407 # TODO(rchao): Needs a lock for thread-safety
1408 if strategy._cluster_coordinator is None:
1409 strategy._cluster_coordinator = super(
1410 ClusterCoordinator, cls).__new__(cls)
1411 return strategy._cluster_coordinator
1412
1413 def __init__(self, strategy):
1414 """Initialization of a `ClusterCoordinator` instance.
1415
1416 Args:
1417 strategy: a supported `tf.distribute.Strategy` object. Currently, only
1418 `tf.distribute.experimental.ParameterServerStrategy` is supported.
1419
1420 Raises:
1421 ValueError: if the strategy being used is not supported.
1422 """
1423 if not getattr(self, "_has_initialized", False):
1424 if not hasattr(strategy, "_is_parameter_server_strategy_v2"):
1425 raise ValueError(
1426 "Only `tf.distribute.experimental.ParameterServerStrategy` "
1427 "is supported to work with "
1428 "`tf.distribute.experimental.coordinator.ClusterCoordinator` "
1429 "currently.")
1430 self._strategy = strategy
1431 self.strategy.extended._used_with_coordinator = True
1432 self._cluster = Cluster(strategy)
1433 self._has_initialized = True
1434
1435 def __del__(self):
1436 logging.info("ClusterCoordinator destructor: stopping cluster")
1437 self._cluster.stop()
1438
1439 @property
1440 def strategy(self):
1441 """Returns the `Strategy` associated with the `ClusterCoordinator`."""
1442 return self._strategy
1443
1444 def schedule(self, fn, args=None, kwargs=None):
1445 """Schedules `fn` to be dispatched to a worker for asynchronous execution.
1446
1447 This method is non-blocking in that it queues the `fn` which will be
1448 executed later and returns a
1449 `tf.distribute.experimental.coordinator.RemoteValue` object immediately.
1450 `fetch` can be called on it to wait for the function execution to finish
1451 and retrieve its output from a remote worker. On the other hand, call
1452 `tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait for
1453 all scheduled functions to finish.
1454
1455 `schedule` guarantees that `fn` will be executed on a worker at least once;
1456 it could be more than once if its corresponding worker fails in the middle
    of its execution. Note that since a worker can fail at any point while
    executing the function, it is possible that the function is partially
    executed, but `tf.distribute.experimental.coordinator.ClusterCoordinator`
    guarantees that in those events, the function will eventually be executed
    on an available worker.
1462
    If any previously scheduled function raises an error, `schedule` will raise
    any one of those errors, and clear the errors collected so far. In this
    case, some of the previously scheduled functions may not have been
    executed. Users can call `fetch` on the returned
    `tf.distribute.experimental.coordinator.RemoteValue` to inspect whether
    they have executed, failed, or been cancelled, and reschedule the
    corresponding function if needed.
1470
1471 When `schedule` raises, it guarantees that there is no function that is
1472 still being executed.
1473
    At this time, there is no support for assigning functions to particular
    workers, or for prioritizing certain workers.
1476
    `args` and `kwargs` are the arguments passed into `fn` when `fn` is
    executed on a worker. They can be
    `tf.distribute.experimental.coordinator.PerWorkerValues`, in which case the
    argument will be substituted with the corresponding component on the
    target worker. Arguments that are not
    `tf.distribute.experimental.coordinator.PerWorkerValues` will be passed
    into `fn` as-is. Currently,
    `tf.distribute.experimental.coordinator.RemoteValue` is not supported as
    an input in `args` or `kwargs`.
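
    As an illustration only (a minimal sketch assuming `strategy` and
    `coordinator` have been created as in the class docstring), a per-worker
    iterator is a `PerWorkerValues` and is substituted with the worker-local
    iterator, while a plain tensor is passed to every worker unchanged:

    ```python
    @tf.function
    def worker_fn(iterator, scale):
      return next(iterator) * scale

    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(
          lambda _: tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat())

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iterator = iter(per_worker_dataset)  # A `PerWorkerValues`.

    # The iterator is replaced by the corresponding worker-local iterator when
    # the function runs; `tf.constant(10)` is passed to `worker_fn` as-is.
    result = coordinator.schedule(
        worker_fn, args=(per_worker_iterator, tf.constant(10)))
    ```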
1485
1486 Args:
      fn: A `tf.function`; the function to be dispatched to a worker for
        asynchronous execution. Regular Python functions are not supported.
1490 args: Positional arguments for `fn`.
1491 kwargs: Keyword arguments for `fn`.
1492
1493 Returns:
1494 A `tf.distribute.experimental.coordinator.RemoteValue` object that
1495 represents the output of the function scheduled.
1496
1497 Raises:
1498 Exception: one of the exceptions caught by the coordinator from any
1499 previously scheduled function, since the last time an error was thrown
1500 or since the beginning of the program.
1501 """
1502 if not isinstance(fn,
1503 (def_function.Function, tf_function.ConcreteFunction)):
1504 raise TypeError(
1505 "`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`"
1506 " only accepts a `tf.function` or a concrete function.")
1507 # Slot variables are usually created during function tracing time; thus
1508 # `schedule` needs to be called within the `strategy.scope()`.
1509 with self.strategy.scope():
1510 self.strategy.extended._being_scheduled = True # pylint: disable=protected-access
1511 schedule_remote_value = self._cluster.schedule(
1512 fn, args=args, kwargs=kwargs)
1513 self.strategy.extended._being_scheduled = False # pylint: disable=protected-access
1514 return schedule_remote_value
1515
1516 def join(self):
1517 """Blocks until all the scheduled functions have finished execution.
1518
1519 If any previously scheduled function raises an error, `join` will fail by
    raising any one of those errors, and clear the errors collected so far. If
    this happens, some of the previously scheduled functions may not have been
    executed. Users can call `fetch` on the returned
    `tf.distribute.experimental.coordinator.RemoteValue` to inspect whether
    they have executed, failed, or been cancelled. If some that have been
    cancelled need to be rescheduled, users should call `schedule` with the
    function again.
1526
1527 When `join` returns or raises, it guarantees that there is no function that
1528 is still being executed.
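
    As a sketch of one way to use this (assuming `train_step`,
    `per_worker_iterator` and `num_steps` are defined elsewhere, and that
    rescheduling a partially executed step is acceptable for the application):

    ```python
    remote_values = []
    for _ in range(num_steps):
      remote_values.append(
          coordinator.schedule(train_step, args=(per_worker_iterator,)))
    try:
      coordinator.join()
    except Exception:  # pylint: disable=broad-except
      # One of the scheduled functions raised; the errors collected so far
      # have been cleared at this point.
      for rv in remote_values:
        try:
          rv.fetch()  # Raises if that function failed or was cancelled.
        except Exception:  # pylint: disable=broad-except
          coordinator.schedule(train_step, args=(per_worker_iterator,))
      coordinator.join()
    ```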
1529
1530 Raises:
      Exception: one of the exceptions caught by the coordinator from any
1532 previously scheduled function since the last time an error was thrown or
1533 since the beginning of the program.
1534 """
1535 self._cluster.join()
1536
1537 def done(self):
1538 """Returns whether all the scheduled functions have finished execution.
1539
1540 If any previously scheduled function raises an error, `done` will fail by
1541 raising any one of those errors.
1542
1543 When `done` returns True or raises, it guarantees that there is no function
1544 that is still being executed.
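
    For illustration (a minimal sketch, with `train_step`,
    `per_worker_iterator` and `num_steps` assumed to be defined elsewhere):

    ```python
    import time

    for _ in range(num_steps):
      coordinator.schedule(train_step, args=(per_worker_iterator,))
    while not coordinator.done():
      time.sleep(10)  # E.g. report metrics or write summaries here.
    ```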
1545
1546 Returns:
      Whether all the scheduled functions have finished execution.

    Raises:
      Exception: one of the exceptions caught by the coordinator from any
1550 previously scheduled function since the last time an error was thrown or
1551 since the beginning of the program.
1552 """
1553 return self._cluster.done()
1554
1555 def create_per_worker_dataset(self, dataset_fn):
    """Creates a dataset on each worker.

    This creates a dataset on each worker from the input, which can be either
    a `tf.data.Dataset`, a `tf.distribute.DistributedDataset` or a function
    that returns a dataset, and returns an object that represents the
    collection of those individual datasets. Calling `iter` on such a
    collection of datasets returns a
    `tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
    collection of iterators, where the iterators have been placed on the
    respective workers.
1565
    Calling `next` on a `PerWorkerValues` of iterators is unsupported. The
1567 iterator is meant to be passed as an argument into
1568 `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When
1569 the scheduled function is about to be executed by a worker, the
1570 function will receive the individual iterator that corresponds to the
1571 worker. The `next` method can be called on an iterator inside a
1572 scheduled function when the iterator is an input of the function.
1573
    Currently the `schedule` method assumes workers are all the same and thus
    assumes the datasets on different workers are the same, except they may be
    shuffled differently if they contain a `dataset.shuffle` operation and a
    random seed is not set. Because of this, we also recommend that the
    datasets be repeated indefinitely and that a finite number of steps be
    scheduled, instead of relying on the `OutOfRangeError` from a dataset; see
    the sketch after the example below.

1582 Example:
1583
1584 ```python
1585 strategy = tf.distribute.experimental.ParameterServerStrategy(
1586 cluster_resolver=...)
1587 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
1588 strategy=strategy)
1589
1590 @tf.function
1591 def worker_fn(iterator):
1592 return next(iterator)
1593
1594 def per_worker_dataset_fn():
1595 return strategy.distribute_datasets_from_function(
1596 lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))
1597
1598 per_worker_dataset = coordinator.create_per_worker_dataset(
1599 per_worker_dataset_fn)
1600 per_worker_iter = iter(per_worker_dataset)
1601 remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
1602 assert remote_value.fetch() == 3
1603 ```
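
    Building on the example above, a rough sketch of the recommended pattern
    of repeating the dataset and scheduling a fixed number of steps (rather
    than relying on `OutOfRangeError`):

    ```python
    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(
          lambda x: tf.data.Dataset.from_tensor_slices([3] * 3).repeat())

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iter = iter(per_worker_dataset)
    for _ in range(10):  # Schedule a fixed number of steps.
      coordinator.schedule(worker_fn, args=(per_worker_iter,))
    coordinator.join()
    ```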
1604
1605 Args:
1606 dataset_fn: The dataset function that returns a dataset. This is to be
1607 executed on the workers.
1608
1609 Returns:
      An object that represents the collection of those individual
      datasets. `iter` is expected to be called on this object, which returns
      a `tf.distribute.experimental.coordinator.PerWorkerValues` of the
      iterators (that are placed on the workers).
1614 """
1615 return values_lib.get_per_worker_dataset(dataset_fn, self)
1616
1617 def _create_per_worker_resources(self, fn, args=None, kwargs=None):
1618 """Synchronously create resources on the workers.
1619
1620 The resources are represented by
1621 `tf.distribute.experimental.coordinator.RemoteValue`s.
1622
1623 Args:
1624 fn: The function to be dispatched to all workers for execution
1625 asynchronously.
1626 args: Positional arguments for `fn`.
1627 kwargs: Keyword arguments for `fn`.
1628
1629 Returns:
1630 A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which
1631 wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue`
1632 objects.
1633 """
1634 results = []
1635 for w in self._cluster.workers:
1636 results.append(w.create_resource(fn, args=args, kwargs=kwargs))
1637 return PerWorkerValues(tuple(results))
1638
1639 def fetch(self, val):
1640 """Blocking call to fetch results from the remote values.
1641
1642 This is a wrapper around
1643 `tf.distribute.experimental.coordinator.RemoteValue.fetch` for a
    `RemoteValue` structure; it returns the execution results of the
    `RemoteValue`s. If they are not ready, this call blocks until they are.
1646
1647 Example:
1648 ```python
1649 strategy = ...
1650 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
1651 strategy)
1652
1653 def dataset_fn():
1654 return tf.data.Dataset.from_tensor_slices([1, 1, 1])
1655
1656 with strategy.scope():
1657 v = tf.Variable(initial_value=0)
1658
1659 @tf.function
1660 def worker_fn(iterator):
1661 def replica_fn(x):
1662 v.assign_add(x)
1663 return v.read_value()
1664 return strategy.run(replica_fn, args=(next(iterator),))
1665
1666 distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
1667 distributed_iterator = iter(distributed_dataset)
1668 result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
1669 assert coordinator.fetch(result) == 1
1670 ```
1671
1672 Args:
      val: The value to fetch the results from. If this is a structure of
        `tf.distribute.experimental.coordinator.RemoteValue`, `fetch()` will be
        called on each individual
        `tf.distribute.experimental.coordinator.RemoteValue` to get the result.
1677
1678 Returns:
1679 If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a
1680 structure of `tf.distribute.experimental.coordinator.RemoteValue`s,
1681 return the fetched `tf.distribute.experimental.coordinator.RemoteValue`
    values immediately if they are available, or block the call until they are
    available, and return the fetched
    `tf.distribute.experimental.coordinator.RemoteValue` values with the same
    structure. If `val` is of any other type, it is returned as-is.
1686 """
1687
1688 def _maybe_fetch(val):
1689 if isinstance(val, RemoteValue):
1690 return val.fetch()
1691 else:
1692 return val
1693
1694 # TODO(yuefengz): we should fetch values in a batch.
1695 return nest.map_structure(_maybe_fetch, val)
1696
1697
1698def _extract_failed_ps_instances(err_msg):
  """Returns potentially failing PS task ids parsed from an error message."""
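  # Illustration: an error message containing "/job:ps/replica:0/task:0" and
  # "/job:ps/replica:0/task:3" would yield {0, 3}.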
1700 tasks = re.findall("/job:ps/replica:0/task:[0-9]+", err_msg)
1701 return set(int(t.split(":")[-1]) for t in tasks)
1702
1703
1704def _is_ps_failure(error):
1705 """Whether the error is considered a parameter server failure."""
1706 if isinstance(error, PSUnavailableError):
1707 return True
1708
  # For a `ClosureInputError` or `ClosureAbortedError`, extract
1710 # the original error and assess it accordingly.
1711 if isinstance(error, (ClosureInputError, ClosureAbortedError)):
1712 error = error.original_exception
1713
1714 if _RPC_ERROR_FROM_PS not in str(error):
1715 return False
1716
1717 if isinstance(error, (errors.UnavailableError, errors.AbortedError)):
1718 return True
1719
1720 # The following error could happen when the remote task fails and restarts
1721 # in a very short interval during which no RPCs were exchanged to detect the
  # failure. In that case, gRPC allows a channel (which is different from a
  # connection) to be reused for a replaced server listening at the same
  # address.
1724 if isinstance(error, errors.InvalidArgumentError):
1725 if ("unknown device" in str(error).lower() or
1726 "Unable to find the relevant tensor remote_handle" in str(error)):
1727 return True
1728
1729 return False
1730
1731
1732def _handle_graph_execution_error_as_worker_failure():
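  # Opt-in escape hatch: setting, e.g., TF_PS_HANDLE_UNKNOWN_ERROR=1 in the
  # coordinator's environment makes `_is_worker_failure` below treat
  # "Graph execution error" `UnknownError`s as worker failures.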
1733 return int(os.environ.get("TF_PS_HANDLE_UNKNOWN_ERROR", "0")) > 0
1734
1735
1736def _is_worker_failure(error):
1737 """Whether the error is considered a worker failure."""
1738
1739 # TODO(b/216666282): Understand why worker failure can manifest as a
1740 # "Graph execution error" `UnknownError`.
1741 if (_handle_graph_execution_error_as_worker_failure() and
1742 isinstance(error, errors.UnknownError) and
1743 "Graph execution error" in str(error)):
1744 logging.info(f"Handling {type(error)}: {str(error)} as worker failure.")
1745 return True
1746
  # For a `ClosureInputError` or `ClosureAbortedError`, extract
1748 # the original error and assess it accordingly.
1749 if isinstance(error, (ClosureInputError, ClosureAbortedError)):
1750 error = error.original_exception
1751
1752 if _JOB_WORKER_STRING_IDENTIFIER not in str(error):
1753 return False
1754 if _RPC_ERROR_FROM_PS in str(error):
1755 return False
1756
1757 # TODO(haoyuzhang): Consider using special status code if error from a
1758 # remote is derived from RPC errors originated from other hosts.
1759 if isinstance(error, (errors.UnavailableError, errors.AbortedError)):
1760 return True
1761
1762 # The following error could happen when the remote task fails and restarts
1763 # in a very short interval during which no RPCs were exchanged to detect the
  # failure. In that case, gRPC allows a channel (which is different from a
  # connection) to be reused for a replaced server listening at the same
  # address.
1766 if isinstance(error, errors.InvalidArgumentError):
1767 if ("unknown device" in str(error).lower() or
1768 "Primary device is not remote" in str(error) or
1769 "Unable to find the relevant tensor remote_handle" in str(error)):
1770 return True
1771
1772 # TODO(b/162541228): The following 2 types of errors are very rare and only
1773 # observed in large-scale testing. The types of errors should be reduced.
1774 # This could happen when the function registration fails. In the observed
1775 # cases this only happens to the dataset related functions.
1776 if isinstance(error, errors.NotFoundError):
1777 if ("is neither a type of a primitive operation nor a name of a function "
1778 "registered" in str(error)):
1779 return True
1780
1781 # NOTE(b/179061495): During worker preemptions, if multiple functions are
1782 # running concurrently (especially with subfunctions spanning chief/PS),
1783 # CancelledError can be returned due to chief/PS cancelling outstanding RPCs
1784 # to the failing workers.
1785 if isinstance(error, errors.CancelledError):
1786 return True
1787
  # This can occur when preparing closures for execution during exact
  # evaluation, because the iterator creation, which occurs within the
  # tf.function, needs to access the worker device, and so fails if the worker
  # is down.
1792 if isinstance(error, TypeError) and "Binding inputs to tf.function" in str(
1793 error):
1794 return True
1795
1796 return False