Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/coordinator/cluster_coordinator.py: 18%

687 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Module for `ClusterCoordinator` and relevant cluster-worker related library. 

16 

17This is currently under development and the API is subject to change. 

18""" 

19 

20import collections 

21import contextlib 

22import os 

23import re 

24import threading 

25import time 

26import weakref 

27 

28from six.moves import queue 

29 

30from tensorflow.python.distribute.coordinator import coordinator_context 

31from tensorflow.python.distribute.coordinator import metric_utils 

32from tensorflow.python.distribute.coordinator import remote_value 

33from tensorflow.python.distribute.coordinator import utils 

34from tensorflow.python.distribute.coordinator import values as values_lib 

35from tensorflow.python.distribute.coordinator import watchdog 

36from tensorflow.python.eager import cancellation 

37from tensorflow.python.eager import context 

38from tensorflow.python.eager import def_function 

39from tensorflow.python.eager import executor 

40from tensorflow.python.eager import function as tf_function 

41from tensorflow.python.framework import errors 

42from tensorflow.python.framework import func_graph 

43from tensorflow.python.framework import ops 

44from tensorflow.python.platform import tf_logging as logging 

45from tensorflow.python.util import nest 

46from tensorflow.python.util.tf_export import tf_export 

47 

48# Maximum time for failed worker to come back is 1 hour 

49_WORKER_MAXIMUM_RECOVERY_SEC = 3600 

50# How often to poll task states from the coordination service. In testing, a 

51# value of 1 led to some spurious reports of unavailability, so a higher value 

52# is used. Refer to the discussion in b/249134783 for more. 

53_POLL_FREQ_IN_SEC = 5 

54 

55# Maximum size for queued closures, "infinite" if set to 0. 

56# When the maximum queue size is reached, further schedule calls will become 

57# blocking until some previously queued closures are executed on workers. 

58# Note that using an "infinite" queue size can take a non-trivial portion of 

59# memory, and even lead to coordinator OOM. Modify the size to a smaller value 

60# for coordinator with constrained memory resource (only recommended for 

61# advanced users). Also used in unit tests to ensure the correctness when the 

62# queue is full. 

63_CLOSURE_QUEUE_MAX_SIZE = 256 * 1024 

64 

65# RPC error message from PS 

66_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps" 

67 

68# InvalidArgumentError (unknown device) will not have "GRPC error..." string. 

69_JOB_WORKER_STRING_IDENTIFIER = "/job:worker" 

70 

71 

72RemoteValueStatus = remote_value.RemoteValueStatus 

73RemoteValue = remote_value.RemoteValue 

74RemoteValueImpl = values_lib.RemoteValueImpl 

75PerWorkerValues = values_lib.PerWorkerValues 

76 

77 

78class ClosureInputError(Exception): 

79 """Wrapper for errors from resource building. 

80 

81 When a closure starts, it first checks for errors in any of its inputs, which 

82 are RemoteValues from resource closures. If there were any errors, it wraps 

83 the exception in this class and raises it so it can be handled by the worker

84 failure handler. 

85 

86 Attributes: 

87 original_exception: The Exception to wrap.

88 """ 

89 

90 def __init__(self, original_exception): 

91 # Avoid doubly-nested errors 

92 if isinstance(original_exception, 

93 (ClosureInputError, ClosureAbortedError)): 

94 self.original_exception = original_exception.original_exception 

95 else: 

96 self.original_exception = original_exception 

97 message = ("Input has an error, the original exception is %r, " 

98 "error message is %s." % 

99 (self.original_exception, str(self.original_exception))) 

100 super().__init__(message) 

101 self.with_traceback(original_exception.__traceback__) 

102 

103 

104class ClosureAbortedError(Exception): 

105 """Wrapper for errors from training closures, to attach to resource closures. 

106 

107 This wrapper is used when a dependent training closure fails, in order to

108 set errors on its required resource closures.

109 

110 Attributes: 

111 original_exception: The Exception to wrap 

112 """ 

113 

114 def __init__(self, original_exception): 

115 # Avoid doubly-nested errors 

116 if isinstance(original_exception, 

117 (ClosureInputError, ClosureAbortedError)): 

118 self.original_exception = original_exception.original_exception 

119 else: 

120 self.original_exception = original_exception 

121 message = ("Other function has an execution error, as a result, the " 

122 "current value is not available. The original exception is %r, " 

123 "error message is %s." % 

124 (self.original_exception, str(self.original_exception))) 

125 super().__init__(message) 

126 self.with_traceback(original_exception.__traceback__) 

127 

128 

129class PSUnavailableError(errors.UnavailableError): 

130 """Specifies that a parameter server is the unavailable task.""" 

131 

132 def __init__(self, original_exception): 

133 assert isinstance(original_exception, errors.UnavailableError) 

134 # TF Errors should have init args set as attributes for serialization. 

135 self.original_exception = original_exception 

136 super().__init__( 

137 original_exception.node_def, 

138 original_exception.op, 

139 original_exception.message, 

140 ) 

141 

142 

143def _get_error_from_remote_values(structure): 

144 """Attempts to return errors from `RemoteValue`s. Rebuilds them if needed.""" 

145 errors_in_structure = [] 

146 

147 def _get_error(val): 

148 if isinstance(val, RemoteValue): 

149 error = val._get_error() # pylint: disable=protected-access 

150 if error: 

151 errors_in_structure.append(error) 

152 

153 nest.map_structure(_get_error, structure) 

154 if errors_in_structure: 

155 return errors_in_structure[0] 

156 else: 

157 return None 

158 

159 

160def _maybe_as_type_spec(val): 

161 if isinstance(val, (RemoteValue, PerWorkerValues)): 

162 if val._type_spec is None: # pylint: disable=protected-access 

163 raise ValueError("Output of a scheduled function that is not " 

164 "tf.function cannot be the input of another function.") 

165 return val._type_spec # pylint: disable=protected-access 

166 else: 

167 return val 

168 

169 

170def _select_worker_slice(worker_id, structured): 

171 """Selects the worker slice of each of the items in `structured`.""" 

172 

173 def _get(x): 

174 return x._values[worker_id] if isinstance(x, PerWorkerValues) else x # pylint: disable=protected-access 

175 

176 return nest.map_structure(_get, structured) 

177 

178 

179def _disallow_remote_value_as_input(structured): 

180 """Raises if any element of `structured` is a RemoteValue.""" 

181 

182 def _raise_if_remote_value(x): 

183 if isinstance(x, RemoteValue): 

184 raise ValueError( 

185 "`tf.distribute.experimental.coordinator.RemoteValue` used " 

186 "as an input to scheduled function is not yet " 

187 "supported.") 

188 

189 nest.map_structure(_raise_if_remote_value, structured) 

190 

191 

192class Closure(object): 

193 """Hold a function to be scheduled and its arguments.""" 

194 

195 def __init__(self, function, cancellation_mgr, args=None, kwargs=None): 

196 if not callable(function): 

197 raise ValueError("Function passed to `ClusterCoordinator.schedule` must " 

198 "be a callable object.") 

199 self._args = args or () 

200 self._kwargs = kwargs or {} 

201 

202 _disallow_remote_value_as_input(self._args) 

203 _disallow_remote_value_as_input(self._kwargs) 

204 

205 if isinstance(function, def_function.Function): 

206 replica_args = _select_worker_slice(0, self._args) 

207 replica_kwargs = _select_worker_slice(0, self._kwargs) 

208 

209 # Note: no need to handle function registration failure since this kind of 

210 # failure will not raise exceptions as designed in the runtime. The 

211 # coordinator has to rely on subsequent operations that raise to catch 

212 # function registration failure. 

213 

214 # Record the function tracing overhead. Note that we pass in the tracing 

215 # count of the def_function.Function as a state tracker, so that metrics 

216 # will only record the time for actual function tracing (i.e., excluding 

217 # function cache lookups). 

218 with metric_utils.monitored_timer( 

219 "function_tracing", state_tracker=function._get_tracing_count): # pylint: disable=protected-access 

220 self._concrete_function = function.get_concrete_function( 

221 *nest.map_structure(_maybe_as_type_spec, replica_args), 

222 **nest.map_structure(_maybe_as_type_spec, replica_kwargs)) 

223 elif isinstance(function, tf_function.ConcreteFunction): 

224 self._concrete_function = function 

225 

226 if hasattr(self, "_concrete_function"): 

227 # If we have a concrete function, we get to retrieve the output type spec 

228 # via the structured_output. 

229 self._output_type_spec = func_graph.convert_structure_to_signature( 

230 self._concrete_function.structured_outputs) 

231 self._function = cancellation_mgr.get_cancelable_function( 

232 self._concrete_function) 

233 else: 

234 # Otherwise (i.e. what is passed in is a regular python function), we have 

235 # no such information. 

236 self._output_type_spec = None 

237 self._function = function 

238 

239 self._output_remote_value_ref = None 

240 

241 def build_output_remote_value(self): 

242 if self._output_remote_value_ref is None: 

243 ret = RemoteValueImpl(None, self._output_type_spec) 

244 self._output_remote_value_ref = weakref.ref(ret) 

245 return ret 

246 else: 

247 raise ValueError( 

248 "The output of the Closure cannot be built more than once.") 

249 

250 def maybe_call_with_output_remote_value(self, method): 

251 if self._output_remote_value_ref is None: 

252 return None 

253 output_remote_value = self._output_remote_value_ref() 

254 if output_remote_value is not None: 

255 return method(output_remote_value) 

256 return None 

257 

258 def mark_cancelled(self): 

259 e = errors.CancelledError( 

260 None, None, "The corresponding function is " 

261 "cancelled. Please reschedule the function.") 

262 self.maybe_call_with_output_remote_value(lambda r: r._set_error(e)) # pylint: disable=protected-access 

263 

264 def execute_on(self, worker): 

265 """Executes the closure on the given worker. 

266 

267 Args: 

268 worker: a `Worker` object. 

269 """ 

270 replica_args = _select_worker_slice(worker.worker_index, self._args) 

271 replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs) 

272 

273 e = ( 

274 _get_error_from_remote_values(replica_args) or 

275 _get_error_from_remote_values(replica_kwargs)) 

276 if e: 

277 if not isinstance(e, ClosureInputError): 

278 e = ClosureInputError(e) 

279 raise e 

280 

281 with ops.device(worker.device_name): 

282 with context.executor_scope(worker.executor): 

283 with coordinator_context.with_dispatch_context(worker): 

284 with metric_utils.monitored_timer("closure_execution"): 

285 output_values = self._function( 

286 *nest.map_structure(coordinator_context.maybe_get_remote_value, 

287 replica_args), 

288 **nest.map_structure(coordinator_context.maybe_get_remote_value, 

289 replica_kwargs)) 

290 self.maybe_call_with_output_remote_value( 

291 lambda r: r._set_values(output_values)) # pylint: disable=protected-access 

292 

293 

294class ResourceClosure(Closure): 

295 
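  # Note: unlike the base `Closure`, the `RemoteValue` built here keeps a
  # reference back to this closure, so the resource can be rescheduled
  # (rebuilt) on a worker after that worker recovers from a failure.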

296 def build_output_remote_value(self): 

297 if self._output_remote_value_ref is None: 

298 # We need to remember the Closure object in the `RemoteValue` here. 

299 ret = RemoteValueImpl(self, self._output_type_spec) 

300 self._output_remote_value_ref = weakref.ref(ret) 

301 return ret 

302 else: 

303 return self._output_remote_value_ref() 

304 

305 

306class _CoordinatedClosureQueue(object): 

307 """Manage a queue of closures, inflight count and errors from execution. 

308 

309 This class is thread-safe. 

310 """ 

311 

312 def __init__(self): 

313 # `self._inflight_closure_count` only tracks the number of inflight closures 

314 # that are "in generation". Once an error occurs, error generation is 

315 # incremented and all subsequent arriving closures (from inflight) are 

316 # considered "out of generation". 

317 self._inflight_closure_count = 0 

318 

319 self._queue_lock = threading.Lock() 

320 

321 # Condition indicating that all pending closures (either queued or inflight) 

322 # have been processed, failed, or cancelled. 

323 self._stop_waiting_condition = threading.Condition(self._queue_lock) 

324 

325 # Condition indicating that an item becomes available in queue (not empty). 

326 self._closures_queued_condition = threading.Condition(self._queue_lock) 

327 self._should_process_closures = True 

328 

329 # Condition indicating that a queue slot becomes available (not full). 

330 # Note that even with "infinite" queue size, there is still a "practical" 

331 # size limit for the queue depending on host memory capacity, and thus the 

332 # queue will eventually become full with a lot of enqueued closures. 

333 self._queue_free_slot_condition = threading.Condition(self._queue_lock) 

334 

335 # Condition indicating there is no inflight closures. 

336 self._no_inflight_closure_condition = threading.Condition(self._queue_lock) 

337 

338 # Use to cancel in-flight closures. 

339 self._cancellation_mgr = cancellation.CancellationManager() 

340 

341 if _CLOSURE_QUEUE_MAX_SIZE <= 0: 

342 logging.warning( 

343 "In a `ClusterCoordinator`, creating an infinite closure queue can " 

344 "consume a significant amount of memory and even lead to OOM.") 

345 self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE) 

346 self._tagged_queue = collections.defaultdict(queue.Queue) 

347 self._error = None 

348 

349 # The following is a lock to make sure when `wait` is called and before it 

350 # returns no `put` can be executed during this period. It is because `wait` 

351 # won't know what to do with newly put closures. This lock adds a cutoff

352 # for `wait` so that closures put into the queue while waiting are not the

353 # responsibility of this `wait` call.

354 # 

355 # We cannot reuse the `self._queue_lock` since when `wait` waits for a 

356 # condition, the `self._queue_lock` will be released. 

357 # 

358 # We don't use a reader/writer's lock on purpose to reduce the complexity 

359 # of the code. 

360 self._put_wait_lock = threading.Lock() 

361 

362 self._watchdog = watchdog.WatchDog(on_triggered=self._on_watchdog_timeout) 

363 

364 def _on_watchdog_timeout(self): 

365 logging.info("inflight_closure_count is %d", self._inflight_closure_count) 

366 logging.info("current error is %s:%r", self._error, self._error) 

367 

368 def stop(self): 

369 with self._queue_lock: 

370 self._should_process_closures = False 

371 self._cancellation_mgr.start_cancel() 

372 self._closures_queued_condition.notify_all() 

373 self._watchdog.stop() 

374 

375 def _cancel_all_closures(self): 

376 """Clears the queue and sets a cancellation error on remaining closures.

377 

378 This method expects self._queue_lock to be held prior to entry. 

379 """ 

380 self._cancellation_mgr.start_cancel() 

381 logging.info("Canceling all closures: waiting for inflight closures to " 

382 "finish") 

383 while self._inflight_closure_count > 0: 

384 self._no_inflight_closure_condition.wait() 

385 logging.info("Canceling all closures: canceling remaining closures on the " 

386 "queue") 

387 while True: 

388 try: 

389 closure = self._queue.get(block=False) 

390 self._queue_free_slot_condition.notify() 

391 closure.mark_cancelled() 

392 except queue.Empty: 

393 break 

394 # The cancellation manager cannot be reused once cancelled. After all 

395 # closures (queued or inflight) are cleaned up, recreate the cancellation 

396 # manager with clean state. 

397 # Note on thread-safety: this is triggered when one of these

398 # ClusterCoordinator APIs are called: `schedule`, `wait`, and `done`. At the 

399 # same time, no new closures can be constructed (which reads the 

400 # _cancellation_mgr to get cancellable functions). 

401 self._cancellation_mgr = cancellation.CancellationManager() 

402 

403 def _raise_if_error(self): 

404 """Raises the error if one exists. 

405 

406 If an error exists, cancels the closures in the queue, raises it, and clears

407 the error.

408 

409 This method expects self._queue_lock to be held prior to entry. 

410 """ 

411 if self._error: 

412 logging.error("Start cancelling closures due to error %r: %s", 

413 self._error, self._error) 

414 self._cancel_all_closures() 

415 try: 

416 raise self._error # pylint: disable=raising-bad-type 

417 finally: 

418 self._error = None 

419 

420 def put(self, closure, tag=None): 

421 """Put a closure into the queue for later execution. 

422 

423 If `mark_failed` was called before `put`, the error from the first 

424 invocation of `mark_failed` will be raised. 

425 

426 Args: 

427 closure: The `Closure` to put into the queue. 

428 tag: if not None, put into a queue with the given tag. 

429 """ 

430 closure.tag = tag 

431 if tag is not None: 

432 with self._queue_lock: 

433 self._tagged_queue[tag].put(closure, block=False) 

434 self._closures_queued_condition.notify_all() 

435 else: 

436 with self._put_wait_lock, self._queue_lock: 

437 self._queue_free_slot_condition.wait_for(lambda: not self._queue.full()) 

438 self._queue.put(closure, block=False) 

439 self._raise_if_error() 

440 self._closures_queued_condition.notify() 

441 

442 def get(self, timeout=None, tag=None): 

443 """Return a closure from the queue to be executed. 

444 

445 It will try to fetch an item from the queue with the given tag. If this 

446 queue is empty, it will then check the global queue. 

447 

448 Args: 

449 timeout: timeout when waiting for a closure to be put. 

450 tag: optional tag to specify which queue to query first before querying 

451 the global queue. 

452 

453 Returns: 

454 a closure or None after timeout. 

455 """ 

456 with self._queue_lock: 

457 while (self._should_process_closures and self._queue.empty() and 

458 (tag is None or self._tagged_queue[tag].empty())): 

459 if not self._closures_queued_condition.wait(timeout=timeout): 

460 return None 

461 if not self._should_process_closures: 

462 return None 

463 if tag is not None and not self._tagged_queue[tag].empty(): 

464 closure = self._tagged_queue[tag].get(block=False) 

465 return closure 

466 closure = self._queue.get(block=False) 

467 assert closure.tag is None 

468 assert tag is None or self._tagged_queue[tag].empty() 

469 self._queue_free_slot_condition.notify() 

470 self._inflight_closure_count += 1 

471 return closure 

472 

473 def mark_finished(self): 

474 """Let the queue know that a closure has been successfully executed.""" 

475 with self._queue_lock: 

476 if self._inflight_closure_count < 1: 

477 raise AssertionError("There are no inflight closures to mark_finished.")

478 self._inflight_closure_count -= 1 

479 if self._inflight_closure_count == 0: 

480 self._no_inflight_closure_condition.notify_all() 

481 if self._queue.empty() and self._inflight_closure_count == 0: 

482 self._stop_waiting_condition.notify_all() 

483 self._watchdog.report_closure_done() 

484 

485 def put_back(self, closure): 

486 """Put the closure back into the queue as it was not properly executed.""" 

487 assert closure.tag is None 

488 with self._queue_lock: 

489 if self._inflight_closure_count < 1: 

490 raise AssertionError("There are no inflight closures to put_back.")

491 if self._error: 

492 closure.mark_cancelled() 

493 else: 

494 self._queue_free_slot_condition.wait_for(lambda: not self._queue.full()) 

495 self._queue.put(closure, block=False) 

496 self._closures_queued_condition.notify() 

497 self._inflight_closure_count -= 1 

498 if self._inflight_closure_count == 0: 

499 self._no_inflight_closure_condition.notify_all() 

500 

501 def wait(self, timeout=None): 

502 """Wait for all closures to be finished before returning. 

503 

504 If `mark_failed` was called before or during `wait`, the error from the 

505 first invocation of `mark_failed` will be raised. 

506 

507 Args: 

508 timeout: A float specifying a timeout for the wait in seconds. 

509 

510 Returns: 

511 True unless the given timeout expired, in which case it returns False. 

512 """ 

513 with self._put_wait_lock, self._queue_lock: 

514 logging.info("Waiting for all global closures to be finished.") 

515 while (not self._error and 

516 (not self._queue.empty() or self._inflight_closure_count > 0)): 

517 if not self._stop_waiting_condition.wait(timeout=timeout): 

518 return False 

519 self._raise_if_error() 

520 return True 

521 

522 def mark_failed(self, e): 

523 """Sets error and unblocks any wait() call.""" 

524 with self._queue_lock: 

525 # TODO(yuefengz): maybe record all failure and give users more 

526 # information? 

527 if self._inflight_closure_count < 1: 

528 raise AssertionError("There are no inflight closures to mark_failed.")

529 if self._error is None: 

530 self._error = e 

531 self._inflight_closure_count -= 1 

532 if self._inflight_closure_count == 0: 

533 self._no_inflight_closure_condition.notify_all() 

534 self._stop_waiting_condition.notify_all() 

535 

536 def done(self): 

537 """Returns true if the queue is empty and there is no inflight closure. 

538 

539 If `mark_failed` was called before `done`, the error from the first 

540 invocation of `mark_failed` will be raised. 

541 """ 

542 with self._queue_lock: 

543 self._raise_if_error() 

544 return self._queue.empty() and self._inflight_closure_count == 0 

545 

546 def clear_tag_unlocked(self, tag): 

547 self._tagged_queue[tag] = queue.Queue() 

548 

549 
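# A minimal sketch (not part of this module) of how the coordinator and worker
# threads drive a `_CoordinatedClosureQueue`, assuming `q` is an instance and
# `closure` was built with `Closure(fn, q._cancellation_mgr)`:
#
#   q.put(closure)                # coordinator thread; may block when full
#   c = q.get(tag=worker_index)   # worker thread; tagged queue checked first
#   try:
#     c.execute_on(worker)
#     q.mark_finished()           # success: decrements the inflight count
#   except Exception as e:        # pylint: disable=broad-except
#     q.mark_failed(e)            # surfaces to the next put/wait/done call
#   q.wait()                      # coordinator blocks until the queue drains
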

550class CoordinationServicePreemptionHandler(object): 

551 """Handles preemptions of workers and parameter servers. 

552 

553 Starts a thread to regularly poll the coordination service (hosted on PS 0) 

554 for task states. When a worker's task state reflects an error, it inspects the 

555 error. If the error is recoverable (i.e. a preemption), it waits for the 

556 worker to recover, then updates the server def. Otherwise, it raises the error 

557 to the user. 

558 

559 A worker error is considered recoverable if it is the result of a missed

560 heartbeat; workers regularly send heartbeats to the coordination service.

561 

562 The thread also checks for parameter server errors. If these are detected, the 

563 thread and coordinator shut down. To resume training in this case, the whole

564 job must be restarted and resumed from the latest checkpoint. 

565 """ 
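  # Note: this handler is only used when the TF_PSS_ENABLE_COORDINATION_SERVICE
  # environment variable is set (see `Cluster.__init__` below); otherwise the
  # coordinator falls back to `WorkerPreemptionHandler`.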

566 

567 def __init__(self, server_def, cluster): 

568 self._server_def = server_def 

569 self._cluster = cluster 

570 self._cluster_update_lock = threading.Lock() 

571 self._cluster_due_for_update_or_finish = threading.Event() 

572 self._worker_up_cond = threading.Condition(self._cluster_update_lock) 

573 

574 self._next_task_state_cond = threading.Condition() 

575 self._task_states = None 

576 

577 self._error_from_recovery = None 

578 self._should_preemption_thread_run = True 

579 self._task_state_poller_thread = utils.RepeatedTimer( 

580 interval=_POLL_FREQ_IN_SEC, 

581 function=self._get_task_states) 

582 self._preemption_handler_thread = threading.Thread( 

583 target=self._preemption_handler, 

584 name="WorkerPreemptionHandler", 

585 daemon=True) 

586 self._preemption_handler_thread.start() 

587 

588 self._num_workers = self._cluster._num_workers 

589 self._num_ps = self._cluster._num_ps 

590 

591 def stop(self): 

592 """Ensure the worker preemption thread is closed.""" 

593 self._task_state_poller_thread.stop() 

594 self._should_preemption_thread_run = False 

595 with self._cluster_update_lock: 

596 self._cluster_due_for_update_or_finish.set() 

597 # TODO(yuefengz): The preemption handler thread shouldn't be terminated 

598 # asynchronously since it touches eager context which is a process-wide 

599 # singleton. The problem is that, in OSS, unit tests will time out.

600 

601 @contextlib.contextmanager 

602 def wait_on_failure(self, 

603 on_failure_fn=None, 

604 on_transient_failure_fn=None, 

605 on_recovery_fn=None, 

606 worker_device_name="(unknown)"): 

607 """Catches errors during closure execution and handles them. 

608 

609 Args: 

610 on_failure_fn: an optional function to run if preemption happens. 

611 on_transient_failure_fn: an optional function to run if transient failure 

612 happens. 

613 on_recovery_fn: an optional function to run when a worker is recovered 

614 from preemption. 

615 worker_device_name: the device name of the worker instance that is passing 

616 through the failure. 

617 

618 Yields: 

619 None. 

620 """ 

621 assert self._should_preemption_thread_run 

622 try: 

623 yield 

624 except (errors.OpError, ClosureInputError, 

625 ClosureAbortedError) as e: 

626 # The next state could reflect stale heartbeats, so wait for two rounds. 

627 # Example: 

628 # - Worker sends healthy heartbeat at T=0. 

629 # - Coordination service receives healthy heartbeat at T=0. 

630 # - Worker gets preempted at T=0.1. 

631 # - Coordinator catches error at T=0.2, and waits here for next states. 

632 # - Coordinator polls states at T=1.9. Heartbeat time has not elapsed yet, 

633 # so coordination service does not know it is down yet. 

634 # - Coordination service learns of worker unavailability at T=2, the next 

635 # heartbeat. 

636 # - Coordinator polls states at T=3.9 and learns of worker unavailability. 

637 with self._next_task_state_cond: 

638 # Give some buffer time to make sure task states are updated during the 

639 # wait interval 

640 self._next_task_state_cond.wait(_POLL_FREQ_IN_SEC * 1.25) 

641 with self._next_task_state_cond: 

642 self._next_task_state_cond.wait(_POLL_FREQ_IN_SEC * 1.25) 
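        # With _POLL_FREQ_IN_SEC = 5, each wait above lasts up to 6.25 seconds,
        # so the coordinator waits roughly 12.5 seconds in total, which covers
        # the two polling rounds in the timeline above.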

643 

644 # Check for coordination service failure 

645 if not self._task_states: 

646 self._log_ps_failure_and_raise(e, 0) 

647 

648 worker_states = self._task_states[:self._num_workers] 

649 ps_states = self._task_states[self._num_workers:] 

650 

651 # Check for PS failure 

652 if any(ps_states): 

653 failed_ps_index = [ 

654 ix for ix, ps_state in enumerate(ps_states) if ps_state 

655 ] 

656 self._log_ps_failure_and_raise(e, failed_ps_index[0]) 

657 

658 # Check for preemption of this worker 

659 worker_ix = int(worker_device_name.split(":")[-1]) 

660 if worker_states[worker_ix]: 

661 # Raise error if all closures are being cancelled 

662 if self._cluster.closure_queue._cancellation_mgr.is_cancelled: # pylint: disable=protected-access 

663 if isinstance(e, errors.CancelledError): 

664 raise e 

665 # It's possible the caught error `e` here is due to worker preemption 

666 # and is thus not a `CancelledError`, because a different 

667 # unrecoverable error on another worker caused closure cancellation, 

668 # while this thread was waiting for task states. So raise a new 

669 # CancelledError. 

670 else: 

671 raise errors.CancelledError( 

672 None, None, "The corresponding function was cancelled while " 

673 "attempting to recover from worker failure.") 

674 # Else, preemption 

675 self._handle_failure_and_recovery(e, on_failure_fn, 

676 on_transient_failure_fn, 

677 on_recovery_fn, worker_device_name) 

678 return 

679 

680 # else, if timeout: log 

681 if self._cluster._record_and_ignore_transient_timeouts(e): # pylint: disable=protected-access 

682 logging.error( 

683 "Remote function on worker %s failed with %r:%s\n" 

684 "This derived error is ignored and not reported to users.", 

685 worker_device_name, e, e) 

686 if on_transient_failure_fn: 

687 on_transient_failure_fn() 

688 return 

689 raise e 

690 

691 def _handle_failure_and_recovery(self, 

692 e, 

693 on_failure_fn, 

694 on_transient_failure_fn, 

695 on_recovery_fn, 

696 worker_device_name): 

697 """Call failure fn, wait for cluster to recover, then call recovery fn. 

698 

699 Args: 

700 e: the Exception thrown during closure execution. 

701 on_failure_fn: an optional function to run if preemption happens. 

702 on_transient_failure_fn: an optional function to run if transient failure 

703 happens. 

704 on_recovery_fn: an optional function to run when a worker is recovered 

705 from preemption. 

706 worker_device_name: the device name of the worker instance that is passing 

707 through the failure. 

708 """ 

709 if on_failure_fn: 

710 on_failure_fn(e) 

711 # update server def 

712 with self._cluster_update_lock: 

713 self._cluster_due_for_update_or_finish.set() 

714 self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC) 

715 if self._error_from_recovery: 

716 # TODO(yuefengz): there is only one worker that will get this error. 

717 # Ideally we should let all workers notified by `_worker_up_cond` get 

718 # this error. 

719 try: 

720 raise self._error_from_recovery 

721 finally: 

722 self._error_from_recovery = None 

723 logging.info("Worker %s has been recovered.", worker_device_name) 

724 

725 if on_recovery_fn: 

726 logging.info("Worker %s calling on_recovery_fn", worker_device_name) 

727 with self.wait_on_failure( 

728 on_recovery_fn=on_recovery_fn, 

729 on_transient_failure_fn=on_transient_failure_fn, 

730 worker_device_name=worker_device_name): 

731 on_recovery_fn() 

732 

733 def _log_ps_failure_and_raise(self, e, ps_index): 

734 logging.info("Parameter server failure detected at PS task %d", ps_index) 

735 self.stop() 

736 raise PSUnavailableError(e) 

737 

738 def _get_task_states(self): 

739 try: 

740 self._task_states = context.context().get_task_states( 

741 [("worker", self._num_workers), ("ps", self._num_ps)] 

742 ) 

743 except errors.UnavailableError: 

744 # Coordination service is down 

745 self._task_states = None 

746 with self._next_task_state_cond: 

747 self._next_task_state_cond.notify_all() 

748 

749 def _preemption_handler(self): 

750 """A loop that handles preemption. 

751 

752 This loop waits for signal of worker preemption and upon worker preemption, 

753 it waits until all workers are back and updates the cluster about the 

754 restarted workers. 

755 """ 

756 assert self._should_preemption_thread_run 

757 while True: 

758 self._cluster_due_for_update_or_finish.wait() 

759 if not self._should_preemption_thread_run: 

760 logging.info("Stopping the failure handling thread.")

761 break 

762 

763 with self._cluster_update_lock: 

764 try: 

765 # TODO(haoyuzhang): support partial cluster recovery 

766 logging.info("Cluster now being recovered.") 

767 context.context().update_server_def(self._server_def) 

768 

769 # Cluster updated successfully, clear the update signal, and notify 

770 # all workers that they are recovered from failure. 

771 logging.info("Cluster successfully recovered.") 

772 self._notify_cluster_update() 

773 except Exception as e: # pylint: disable=broad-except 

774 logging.info("Error occurred while updating server def: %s", e) 

775 # Wait for the next set of states from the task state poller 

776 with self._next_task_state_cond: 

777 self._next_task_state_cond.wait(_POLL_FREQ_IN_SEC * 2) 

778 # If a PS is preempted, set the error 

779 if not self._task_states: 

780 self._error_from_recovery = e 

781 else: 

782 ps_states = self._task_states[self._num_workers:] 

783 # Check for PS failure 

784 if any(ps_states): 

785 self._error_from_recovery = e 

786 # Else, likely another worker failed. Just log and retry 

787 self._notify_cluster_update() 

788 # NOTE: Since the first RPC (GetStatus) of update_server_def is 

789 # currently blocking by default, error should only happen if: 

790 # (1) More workers failed while waiting for the previous workers to 

791 # come back; 

792 # (2) Worker failed when exchanging subsequent RPCs after the first 

793 # RPC returns. 

794 # Consider adding backoff retry logic if we see the error logged 

795 # too frequently. 

796 logging.error("Cluster update failed with error: %s. Retrying...", e) 

797 

798 def _notify_cluster_update(self): 

799 self._worker_up_cond.notify_all() 

800 # The check for _should_preemption_thread_run is necessary since the 

801 # `stop` may have already set _cluster_due_for_update_or_finish. 

802 if self._should_preemption_thread_run: 

803 self._cluster_due_for_update_or_finish.clear() 

804 

805 

806class WorkerPreemptionHandler(object): 

807 """Handles worker preemptions.""" 

808 

809 def __init__(self, server_def, cluster): 

810 self._server_def = server_def 

811 self._cluster = cluster 

812 self._cluster_update_lock = threading.Lock() 

813 self._cluster_due_for_update_or_finish = threading.Event() 

814 self._worker_up_cond = threading.Condition(self._cluster_update_lock) 

815 self._error_from_recovery = None 

816 self._should_preemption_thread_run = True 

817 self._preemption_handler_thread = threading.Thread( 

818 target=self._preemption_handler, 

819 name="WorkerPreemptionHandler", 

820 daemon=True) 

821 self._preemption_handler_thread.start() 

822 

823 def stop(self): 

824 """Ensure the worker preemption thread is closed.""" 

825 self._should_preemption_thread_run = False 

826 with self._cluster_update_lock: 

827 self._cluster_due_for_update_or_finish.set() 

828 # TODO(yuefengz): The preemption handler thread shouldn't be terminated 

829 # asynchronously since it touches eager context which is a process-wide 

830 # singleton. The problem is that, in OSS, unit tests will time out.

831 

832 def _validate_preemption_failure(self, e): 

833 """Validates that the given exception represents worker preemption.""" 

834 

835 # Only categorize the failure as a worker preemption if the cancellation 

836 # manager did not attempt to cancel the blocking operations. 

837 if _is_worker_failure(e) and ( 

838 not self._cluster.closure_queue._cancellation_mgr.is_cancelled): # pylint: disable=protected-access 

839 return 

840 raise e 

841 

842 @contextlib.contextmanager 

843 def wait_on_failure(self, 

844 on_failure_fn=None, 

845 on_transient_failure_fn=None, 

846 on_recovery_fn=None, 

847 worker_device_name="(unknown)"): 

848 """Catches worker preemption errors and waits until failed workers are back.

849 

850 Args: 

851 on_failure_fn: an optional function to run if preemption happens. 

852 on_transient_failure_fn: an optional function to run if transient failure 

853 happens. 

854 on_recovery_fn: an optional function to run when a worker is recovered 

855 from preemption. 

856 worker_device_name: the device name of the worker instance that is passing 

857 through the failure. 

858 

859 Yields: 

860 None. 

861 """ 

862 assert self._should_preemption_thread_run 

863 try: 

864 yield 

865 except (errors.OpError, ClosureInputError, 

866 ClosureAbortedError, TypeError) as e: 

867 # If the error is due to temporary connectivity issues between worker and 

868 # ps, put the closure back, ignore the error and do not mark the worker as failed.

869 if self._cluster._record_and_ignore_transient_ps_failure(e): # pylint: disable=protected-access 

870 logging.error( 

871 "Remote function on worker %s failed with %r:%s\n" 

872 "It is treated as a transient connectivity failure for now.", 

873 worker_device_name, e, e) 

874 if on_transient_failure_fn: 

875 on_transient_failure_fn() 

876 return 

877 

878 # If the error is due to temporary connectivity issues that cause the 

879 # server-side RPCs to be cancelled, TF might not abort the step and the 

880 # closure might time out. The coordinator ignores a certain number of such

881 # failures without marking the worker as failed.

882 if self._cluster._record_and_ignore_transient_timeouts(e): # pylint: disable=protected-access 

883 logging.error( 

884 "Remote function on worker %s failed with %r:%s\n" 

885 "This derived error is ignored and not reported to users.", 

886 worker_device_name, e, e) 

887 if on_transient_failure_fn: 

888 on_transient_failure_fn() 

889 return 

890 

891 # Ignoring derived CancelledErrors to tolerate transient failures in 

892 # PS-worker communication, which is initially exposed as an UnavailableError

893 # and then leads to sub-function cancellation, subsequently getting

894 # reported from worker to chief as CancelledError. 

895 # We do not mark either worker or PS as failed due to only CancelledError. 

896 # If there are real (non-transient) failures, they must also be reported 

897 # as other errors (UnavailableError most likely) in closure executions. 

898 if isinstance(e, errors.CancelledError) and "/job:" in str(e): 

899 logging.error( 

900 "Remote function on worker %s failed with %r:%s\n" 

901 "This derived error is ignored and not reported to users.", 

902 worker_device_name, e, e) 

903 if on_transient_failure_fn: 

904 on_transient_failure_fn() 

905 return 

906 

907 # This reraises the error if it's not considered recoverable; otherwise,

908 # the following failure recovery logic runs. At this time, only worker

909 # unavailability is recoverable. PS unavailability, as well as other

910 # errors in the user function, is not recoverable.

911 self._validate_preemption_failure(e) 

912 

913 logging.error("Worker %s failed with %r:%s", worker_device_name, e, e) 

914 if on_failure_fn: 

915 on_failure_fn(e) 

916 

917 with self._cluster_update_lock: 

918 self._cluster_due_for_update_or_finish.set() 

919 self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC) 

920 if self._error_from_recovery: 

921 # TODO(yuefengz): there is only one worker that will get this error. 

922 # Ideally we should let all workers notified by `_worker_up_cond` get

923 # this error. 

924 try: 

925 raise self._error_from_recovery 

926 finally: 

927 self._error_from_recovery = None 

928 logging.info("Worker %s has been recovered.", worker_device_name) 

929 

930 if on_recovery_fn: 

931 logging.info("Worker %s calling on_recovery_fn", worker_device_name) 

932 with self.wait_on_failure( 

933 on_recovery_fn=on_recovery_fn, 

934 on_transient_failure_fn=on_transient_failure_fn, 

935 worker_device_name=worker_device_name): 

936 on_recovery_fn() 

937 

938 def _preemption_handler(self): 

939 """A loop that handles preemption. 

940 

941 This loop waits for signal of worker preemption and upon worker preemption, 

942 it waits until all workers are back and updates the cluster about the 

943 restarted workers. 

944 """ 

945 assert self._should_preemption_thread_run 

946 while True: 

947 self._cluster_due_for_update_or_finish.wait() 

948 if not self._should_preemption_thread_run: 

949 logging.info("Stopping the failure handling thread.")

950 break 

951 

952 with self._cluster_update_lock: 

953 try: 

954 # TODO(haoyuzhang): support partial cluster recovery 

955 logging.info("Cluster now being recovered.") 

956 with metric_utils.monitored_timer("server_def_update"): 

957 context.context().update_server_def(self._server_def) 

958 

959 # Cluster updated successfully, clear the update signal, and notify 

960 # all workers that they are recovered from failure. 

961 logging.info("Cluster successfully recovered.") 

962 self._worker_up_cond.notify_all() 

963 # The check for _should_preemption_thread_run is necessary since the 

964 # `stop` may have already set _cluster_due_for_update_or_finish. 

965 if self._should_preemption_thread_run: 

966 self._cluster_due_for_update_or_finish.clear() 

967 except Exception as e: # pylint: disable=broad-except 

968 logging.info("Error occurred while updating server def: %s", e) 

969 try: 

970 self._validate_preemption_failure(e) 

971 except Exception as ps_e: # pylint: disable=broad-except 

972 logging.info("Error that occurred while updating server def is not " 

973 "a worker failure. So set it as _error_from_recovery") 

974 # In this case, a parameter server fails. So we raise this error to 

975 # the caller of `wait_on_failure`. 

976 self._error_from_recovery = ps_e 

977 self._worker_up_cond.notify_all() 

978 if self._should_preemption_thread_run: 

979 self._cluster_due_for_update_or_finish.clear() 

980 # NOTE: Since the first RPC (GetStatus) of update_server_def is 

981 # currently blocking by default, error should only happen if: 

982 # (1) More workers failed while waiting for the previous workers to 

983 # come back; 

984 # (2) Worker failed when exchanging subsequent RPCs after the first 

985 # RPC returns. 

986 # Consider adding backoff retry logic if we see the error logged 

987 # too frequently. 

988 logging.error("Cluster update failed with error: %s. Retrying...", e) 

989 

990 

991class Worker(object): 

992 """A worker in a cluster. 

993 

994 Attributes: 

995 worker_index: The index of the worker in the cluster. 

996 device_name: The device string of the worker, e.g. "/job:worker/task:1". 

997 executor: The worker's executor for remote function execution. 

998 failure_handler: The failure handler used to handle worker preemption

999 failures.

1000 """ 

1001 

1002 def __init__(self, worker_index, device_name, cluster): 

1003 self.worker_index = worker_index 

1004 self.device_name = device_name 

1005 self.executor = executor.new_executor(enable_async=False) 

1006 self.failure_handler = cluster.failure_handler 

1007 self._cluster = cluster 

1008 self._resource_tracking_lock = threading.Lock() 

1009 self._resource_remote_value_refs = [] 

1010 self._is_dead_with_error = None 

1011 self._should_worker_thread_run = True 

1012 

1013 # Worker threads need to start after `Worker`'s initialization. 

1014 threading.Thread(target=self._process_queue, 

1015 name="WorkerClosureProcessingLoop-%d" % self.worker_index, 

1016 daemon=True).start() 

1017 

1018 def stop(self): 

1019 """Ensure the worker thread is closed.""" 

1020 self._should_worker_thread_run = False 

1021 

1022 def _schedule_resource(self, closure): 

1023 self._cluster.closure_queue.put(closure, tag=self.worker_index) 

1024 

1025 def _set_resources_aborted(self, e): 

1026 """Set each resource to ABORTED and attach an error to it."""

1027 # TODO(yuefengz): maybe we can query whether a tensor is valid or not 

1028 # instead of marking a tensor aborted? 

1029 logging.info("[Worker %d] Clearing all resources.", self.worker_index) 

1030 for weakref_resource in self._resource_remote_value_refs: 

1031 resource = weakref_resource() 

1032 if resource: 

1033 # It is important to set an error on an aborted RemoteValue from a 

1034 # ResourceClosure because its failure will not trigger the worker thread 

1035 # to raise an error immediately, and the worker may continue executing

1036 # closures taking it as an input. The error will then be correctly 

1037 # reported to users. 

1038 resource._set_aborted(ClosureAbortedError(e)) # pylint: disable=protected-access 

1039 

1040 def _on_closure_failure(self, closure, e): 

1041 logging.info("[Worker %d] Putting back a closure after it failed.", 

1042 self.worker_index) 

1043 self._cluster.closure_queue.put_back(closure) 

1044 

1045 with self._resource_tracking_lock: 

1046 self._is_dead_with_error = e 

1047 self._set_resources_aborted(e) 

1048 

1049 def _on_resource_closure_failure(self, e): 

1050 """Clear tagged queue to ensure resource closures are rebuilt. 

1051 

1052 Args: 

1053 e: The exception raised by the resource closure.

1054 """ 

1055 logging.info("[Worker %d] Clearing tagged queue after resource closure " 

1056 "failure.", self.worker_index) 

1057 with self._resource_tracking_lock: 

1058 self._is_dead_with_error = e 

1059 # No locking on queue is needed since 

1060 # * get will not happen concurrently here. 

1061 # * put to the specific tagged queue will be guarded by 

1062 # `self._resource_tracking_lock`. 

1063 self._cluster.closure_queue.clear_tag_unlocked(self.worker_index) 

1064 self._set_resources_aborted(e) 

1065 

1066 def _on_worker_recovery(self): 

1067 logging.info("[Worker %d] calling _on_worker_recovery", self.worker_index) 

1068 with self._resource_tracking_lock: 

1069 for weakref_resource in self._resource_remote_value_refs: 

1070 resource = weakref_resource() 

1071 if resource: 

1072 self._schedule_resource(resource._closure) # pylint: disable=protected-access 

1073 self._is_dead_with_error = False 

1074 

1075 def _process_closure(self, closure): 

1076 """Runs a closure with preemption handling.""" 

1077 try: 

1078 with self.failure_handler.wait_on_failure( 

1079 on_failure_fn=lambda e: self._on_closure_failure(closure, e), 

1080 on_transient_failure_fn=( 

1081 lambda: self._cluster.closure_queue.put_back(closure)), 

1082 on_recovery_fn=self._on_worker_recovery, 

1083 worker_device_name=self.device_name): 

1084 closure.execute_on(self) 

1085 with metric_utils.monitored_timer("remote_value_fetch"): 

1086 # Copy the remote tensor to local (the coordinator) in case worker 

1087 # becomes unavailable at a later time. 

1088 closure.maybe_call_with_output_remote_value(lambda r: r.get()) 

1089 self._cluster.closure_queue.mark_finished() 

1090 except Exception as e: # pylint: disable=broad-except 

1091 # Avoid logging the derived cancellation error 

1092 if not isinstance(e, errors.CancelledError): 

1093 logging.error( 

1094 " /job:worker/task:%d encountered the following error when " 

1095 "processing closure: %r:%s", self.worker_index, e, e) 

1096 closure.maybe_call_with_output_remote_value(lambda r: r._set_error(e)) # pylint: disable=protected-access 

1097 self._cluster.closure_queue.mark_failed(e) 

1098 

1099 def _process_resource_closure(self, closure): 

1100 """Run the given resource closure with preemption handling.""" 

1101 assert closure.tag == self.worker_index 

1102 try: 

1103 with self.failure_handler.wait_on_failure( 

1104 on_failure_fn=self._on_resource_closure_failure, 

1105 on_transient_failure_fn=( 

1106 lambda: self._process_resource_closure(closure)), 

1107 on_recovery_fn=self._on_worker_recovery, 

1108 worker_device_name=self.device_name): 

1109 closure.execute_on(self) 

1110 except Exception as e: # pylint: disable=broad-except 

1111 # Avoid logging the derived cancellation error 

1112 logging.info("[Worker %d] got an exception when processing resource " 

1113 "closure", self.worker_index) 

1114 if not isinstance(e, errors.CancelledError): 

1115 logging.error( 

1116 " /job:worker/task:%d encountered the following error when " 

1117 "processing resource closure: %r:%s", self.worker_index, e, e) 

1118 closure.maybe_call_with_output_remote_value(lambda r: r._set_error(e)) # pylint: disable=protected-access 

1119 

1120 def _maybe_delay(self): 

1121 """Delay if corresponding env vars are set.""" 

1122 # If the following two environment variables are set, scheduling for workers

1123 # will start in a staggered manner. Worker i will wait for

1124 # `TF_COORDINATOR_SCHEDULE_START_DELAY` * i seconds, not exceeding 

1125 # `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX`. 

1126 delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0")) 

1127 delay_secs *= self.worker_index 

1128 delay_cap = int( 

1129 os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0")) 

1130 if delay_cap: 

1131 delay_secs = min(delay_secs, delay_cap) 

1132 if delay_secs > 0: 

1133 logging.info(" Worker %d sleeping for %d seconds before running function", 

1134 self.worker_index, delay_secs) 

1135 time.sleep(delay_secs) 
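    # For example (illustrative values only): with
    # TF_COORDINATOR_SCHEDULE_START_DELAY=2 and
    # TF_COORDINATOR_SCHEDULE_START_DELAY_MAX=10, worker 7 would sleep
    # min(2 * 7, 10) = 10 seconds before processing its first closure.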

1136 

1137 def _process_queue(self): 

1138 """Function running in a worker thread to process closure queues.""" 

1139 self._maybe_delay() 

1140 while self._should_worker_thread_run: 

1141 closure = self._cluster.closure_queue.get(tag=self.worker_index) 

1142 if not self._should_worker_thread_run or closure is None: 

1143 if closure is not None: 

1144 closure.mark_cancelled() 

1145 return 

1146 if isinstance(closure, ResourceClosure): 

1147 self._process_resource_closure(closure) 

1148 else: 

1149 self._process_closure(closure) 

1150 # To properly stop the worker and preemption threads, it is important that 

1151 # `ClusterCoordinator` object is not held onto so its `__del__` can be 

1152 # called. By removing the reference to the `closure` that has already been 

1153 # processed, we ensure that the `closure` object is released, while 

1154 # getting the next `closure` at above `self._cluster.closure_queue.get()` 

1155 # call. 

1156 del closure 

1157 

1158 def create_resource(self, function, args=None, kwargs=None): 

1159 """Synchronously creates a per-worker resource represented by a `RemoteValue`. 

1160 

1161 Args: 

1162 function: the resource function to be run remotely. It should be a 

1163 `tf.function`, a concrete function or a Python function. 

1164 args: positional arguments to be passed to the function. 

1165 kwargs: keyword arguments to be passed to the function. 

1166 

1167 Returns: 

1168 one or several RemoteValue objects depending on the function return 

1169 values. 

1170 """ 

1171 # Some notes about concurrency: currently, all activities related to the

1172 # same worker, such as creating resources, setting resources' aborted

1173 # status, and executing closures, happen on the same thread. This keeps

1174 # the concurrency logic simple.

1175 

1176 closure = ResourceClosure( 

1177 function, 

1178 self._cluster.resource_cancellation_mgr, 

1179 args=args, 

1180 kwargs=kwargs) 

1181 resource_remote_value = closure.build_output_remote_value() 

1182 with self._resource_tracking_lock: 

1183 self._register_resource(resource_remote_value) 

1184 if self._is_dead_with_error: 

1185 resource_remote_value._set_aborted( # pylint: disable=protected-access 

1186 ClosureAbortedError(self._is_dead_with_error)) 

1187 else: 

1188 self._schedule_resource(closure) 

1189 return resource_remote_value 

1190 

1191 def _register_resource(self, resource_remote_value): 

1192 if not isinstance(resource_remote_value, RemoteValue): 

1193 raise ValueError("Resource being registered is not of type " 

1194 "`tf.distribute.experimental.coordinator.RemoteValue`.") 

1195 self._resource_remote_value_refs.append(weakref.ref(resource_remote_value)) 

1196 

1197 

1198class Cluster(object): 

1199 """A cluster with workers. 

1200 

1201 We assume all function errors are fatal, and based on this assumption our

1202 error reporting logic is: 

1203 1) Both `schedule` and `join` can raise a non-retryable error which is the 

1204 first error seen by the coordinator from any previously scheduled functions. 

1205 2) When an error is raised, there is no guarantee on how many previously 

1206 scheduled functions have been executed; functions that have not been executed 

1207 will be thrown away and marked as cancelled. 

1208 3) After an error is raised, the internal error state will be cleared.

1209 I.e. functions can continue to be scheduled and subsequent calls of `schedule` 

1210 or `join` will not raise the same error again. 

1211 

1212 Attributes: 

1213 failure_handler: The failure handler used to handle worker preemption

1214 failures.

1215 workers: a list of `Worker` objects in the cluster. 

1216 closure_queue: the global Closure queue. 

1217 resource_cancellation_mgr: the cancellation manager used to cancel resource 

1218 closures. 

1219 """ 

1220 

1221 def __init__(self, strategy): 

1222 """Initializes the cluster instance.""" 

1223 

1224 self._num_workers = strategy._num_workers 

1225 self._num_ps = strategy._num_ps 

1226 

1227 # Ignore PS failures reported by workers due to transient connection errors. 

1228 # Transient connectivity issues between workers and PS are relayed by the 

1229 # workers to the coordinator, leading the coordinator to believe that there 

1230 # are PS failures. The difference between transient vs. permanent PS failure 

1231 # is the number of reports from the workers. When this env var is set to a 

1232 # positive integer K, the coordinator ignores up to K reports of a failed PS 

1233 # task, i.e., only when more than K closure executions fail due to errors

1234 # from the same PS instance do we consider that PS instance to have

1235 # encountered a failure.

1236 # TODO(b/164279603): Remove this workaround when the underlying connectivity 

1237 # issue in gRPC server is resolved. 

1238 self._transient_ps_failures_threshold = int( 

1239 os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3)) 

1240 self._potential_ps_failures_lock = threading.Lock() 

1241 self._potential_ps_failures_count = [0] * self._num_ps 
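    # Illustration (derived from `_record_and_ignore_transient_ps_failure`
    # below): with the default threshold of 3, the first two UnavailableErrors
    # traced to a given PS task are ignored; the third is treated as a real
    # PS failure.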

1242 

1243 # Ignore worker timeouts due to transient connection errors. 

1244 # Transient connectivity issues might cause the server side to unexpectedly 

1245 # cancel RPC handling logic, leading to closure execution timeouts. When 

1246 # the _transient_timeout_threshold is set to a positive number, the cluster 

1247 # coordinator ignores DeadlineExceeded errors from workers the specified

1248 # number of times before raising the error to users.

1249 self._transient_timeouts_threshold = int( 

1250 os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_TIMEOUTS", 

1251 self._num_workers // 10)) 

1252 self._transient_timeouts_lock = threading.Lock() 

1253 self._transient_timeouts_count = 0 
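    # Illustration (derived from `_record_and_ignore_transient_timeouts`
    # below): with 40 workers the default threshold is 4, so the first three
    # DeadlineExceededErrors are ignored and the fourth is raised to users.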

1254 

1255 self.closure_queue = _CoordinatedClosureQueue() 

1256 # Set this environment variable to use an experimental 

1257 # integration with the runtime coordination service to aid in failure 

1258 # detection and handling. This will not affect the functionality of 

1259 # the strategy or cluster coordinator, but is off by default. 

1260 if os.getenv("TF_PSS_ENABLE_COORDINATION_SERVICE"): 

1261 self.failure_handler = CoordinationServicePreemptionHandler( 

1262 context.get_server_def(), self, 

1263 ) 

1264 else: 

1265 self.failure_handler = WorkerPreemptionHandler(context.get_server_def(), 

1266 self) 

1267 worker_device_strings = [ 

1268 "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers) 

1269 ] 

1270 self.workers = [ 

1271 Worker(i, w, self) for i, w in enumerate(worker_device_strings) 

1272 ] 

1273 

1274 # Cancellation manager for all resource closures. 

1275 self.resource_cancellation_mgr = cancellation.CancellationManager() 

1276 

1277 def stop(self): 

1278 """Stop workers, worker preemption threads, and the closure queue."""

1279 logging.info("Stopping cluster, starting with failure handler") 

1280 self.failure_handler.stop() 

1281 

1282 logging.info("Stopping workers") 

1283 for worker in self.workers: 

1284 worker.stop() 

1285 logging.info("Stopping queue") 

1286 self.closure_queue.stop() 

1287 logging.info("Start cancelling remote resource-building functions") 

1288 self.resource_cancellation_mgr.start_cancel() 

1289 

1290 def _record_and_ignore_transient_ps_failure(self, e): 

1291 """Records potential PS failures; returns whether to ignore the failure."""

1292 if self._transient_ps_failures_threshold <= 0 or not _is_ps_failure(e): 

1293 return False 

1294 

1295 ps_tasks = _extract_failed_ps_instances(str(e)) 

1296 with self._potential_ps_failures_lock: 

1297 for t in ps_tasks: 

1298 self._potential_ps_failures_count[t] += 1 

1299 # The number of UnavailableError encountered on this PS task exceeds the 

1300 # maximum number of ignored error 

1301 if (self._potential_ps_failures_count[t] >= 

1302 self._transient_ps_failures_threshold): 

1303 return False 

1304 return True 

1305 

1306 def _record_and_ignore_transient_timeouts(self, e): 

1307 """Records an observed timeout error; returns whether to ignore it."""

1308 if self._transient_timeouts_threshold <= 0: 

1309 return False 

1310 if not isinstance(e, errors.DeadlineExceededError): 

1311 return False 

1312 with self._transient_timeouts_lock: 

1313 self._transient_timeouts_count += 1 

1314 if self._transient_timeouts_count >= self._transient_timeouts_threshold: 

1315 return False 

1316 return True 

1317 

1318 def schedule(self, function, args, kwargs): 

1319 """Schedules `function` to be dispatched to a worker for execution. 

1320 

1321 Args: 

1322 function: The function to be dispatched to a worker for execution 

1323 asynchronously. 

1324 args: Positional arguments for `function`.

1325 kwargs: Keyword arguments for `function`.

1326 

1327 Returns: 

1328 A `RemoteValue` object. 

1329 """ 

1330 closure = Closure( 

1331 function, 

1332 self.closure_queue._cancellation_mgr, # pylint: disable=protected-access 

1333 args=args, 

1334 kwargs=kwargs) 

1335 ret = closure.build_output_remote_value() 

1336 self.closure_queue.put(closure) 

1337 return ret 

1338 

1339 def join(self): 

1340 """Blocks until all scheduled functions are executed.""" 

1341 self.closure_queue.wait() 

1342 

1343 def done(self): 

1344 """Returns true if all scheduled functions are executed.""" 

1345 return self.closure_queue.done() 

1346 

1347 

1348@tf_export("distribute.experimental.coordinator.ClusterCoordinator", 

1349 "distribute.coordinator.ClusterCoordinator", v1=[]) 

1350class ClusterCoordinator(object): 

1351 """An object to schedule and coordinate remote function execution. 

1352 

1353 This class is used to create fault-tolerant resources and dispatch functions 

1354 to remote TensorFlow servers. 

1355 

1356 Currently, this class cannot be used in a standalone manner. It should be

1357 used in conjunction with a `tf.distribute` strategy that is designed to work

1358 with it. The `ClusterCoordinator` class currently only works with

1359 `tf.distribute.experimental.ParameterServerStrategy`.

1360 

1361 __The `schedule`/`join` APIs__ 

1362 

1363 The most important APIs provided by this class are the `schedule`/`join`

1364 pair. The `schedule` API is non-blocking in that it queues a `tf.function`

1365 and returns a `RemoteValue` immediately. The queued functions will be

1366 dispatched to remote workers in background threads and their `RemoteValue`s

1367 will be filled asynchronously. Since `schedule` doesn’t require worker

1368 assignment, the `tf.function` passed in can be executed on any available

1369 worker. If the worker it is executed on becomes unavailable before its

1370 completion, it will be migrated to another worker. Because of this, and because

1371 function execution is not atomic, a function may be executed more than once.

1372 
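A minimal sketch of the `schedule`/`join` workflow (the `step_fn` function and
the number of scheduled steps below are illustrative placeholders):

```python
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=strategy)

@tf.function
def step_fn():
  return tf.constant(1)

# `schedule` returns immediately with a `RemoteValue` for each call.
results = [coordinator.schedule(step_fn) for _ in range(10)]
# `join` blocks until all ten scheduled calls have finished.
coordinator.join()
```
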

1373 __Handling Task Failure__ 

1374 

1375 This class, when used with

1376 `tf.distribute.experimental.ParameterServerStrategy`, comes with built-in

1377 fault tolerance for worker failures. That is, when some workers cannot be

1378 reached from the coordinator for any reason, training progress continues to

1379 be made with the remaining workers. Upon recovery of a failed worker, it will

1380 be added back for function execution after the datasets created by

1381 `create_per_worker_dataset` are re-built on it.

1382 

1383 When a parameter server fails, a `tf.errors.UnavailableError` is raised by 

1384 `schedule`, `join` or `done`. In this case, in addition to bringing back the 

1385 failed parameter server, users should restart the coordinator so that it 

1386 reconnects to workers and parameter servers, re-creates the variables, and 

1387 loads checkpoints. If the coordinator fails, after the user brings it back, 

1388 the program will automatically connect to workers and parameter servers, and 

1389 continue the progress from a checkpoint. 

1390 

1391 It is thus essential that the user's program saves a checkpoint file

1392 periodically, and restores it at the start of the program. If a

1393 `tf.keras.optimizers.Optimizer` is checkpointed, after restoring from a

1394 checkpoint, its `iterations` property roughly indicates the number of steps

1395 that have already been run. This can be used to decide how many epochs and

1396 steps are still needed before training completes.

1397 
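For instance, a training loop could follow this rough checkpointing sketch
(`model`, `optimizer`, `model_dir`, `per_step_fn`, and the step counts are
placeholders, and `coordinator` is an existing `ClusterCoordinator`):

```python
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(checkpoint, model_dir, max_to_keep=3)
# Restores the latest checkpoint if one exists; otherwise starts from scratch.
checkpoint.restore(manager.latest_checkpoint)
start_step = int(optimizer.iterations.numpy())
for step in range(start_step, total_steps):
  coordinator.schedule(per_step_fn)
  if (step + 1) % steps_per_checkpoint == 0:
    coordinator.join()
    manager.save()
```
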

1398 See `tf.distribute.experimental.ParameterServerStrategy` docstring for an 

1399 example usage of this API. 

1400 

1401 This is currently under development, and both the API and the implementation

1402 are subject to change.

1403 """ 

1404 

1405 def __new__(cls, strategy): 

1406 # `ClusterCoordinator` is kept as a single instance for a given `Strategy`.

1407 # TODO(rchao): Needs a lock for thread-safety 

1408 if strategy._cluster_coordinator is None: 

1409 strategy._cluster_coordinator = super( 

1410 ClusterCoordinator, cls).__new__(cls) 

1411 return strategy._cluster_coordinator 
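# Illustrative note (a sketch, not part of the original module): because of the
# caching above, constructing a coordinator twice with the same strategy yields
# the same object, e.g.
#   assert ClusterCoordinator(strategy) is ClusterCoordinator(strategy)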

1412 

1413 def __init__(self, strategy): 

1414 """Initialization of a `ClusterCoordinator` instance. 

1415 

1416 Args: 

1417 strategy: a supported `tf.distribute.Strategy` object. Currently, only 

1418 `tf.distribute.experimental.ParameterServerStrategy` is supported. 

1419 

1420 Raises: 

1421 ValueError: if the strategy being used is not supported. 

1422 """ 

1423 if not getattr(self, "_has_initialized", False): 

1424 if not hasattr(strategy, "_is_parameter_server_strategy_v2"): 

1425 raise ValueError( 

1426 "Only `tf.distribute.experimental.ParameterServerStrategy` " 

1427 "is supported to work with " 

1428 "`tf.distribute.experimental.coordinator.ClusterCoordinator` " 

1429 "currently.") 

1430 self._strategy = strategy 

1431 self.strategy.extended._used_with_coordinator = True 

1432 self._cluster = Cluster(strategy) 

1433 self._has_initialized = True 

1434 

1435 def __del__(self): 

1436 logging.info("ClusterCoordinator destructor: stopping cluster") 

1437 self._cluster.stop() 

1438 

1439 @property 

1440 def strategy(self): 

1441 """Returns the `Strategy` associated with the `ClusterCoordinator`.""" 

1442 return self._strategy 

1443 

1444 def schedule(self, fn, args=None, kwargs=None): 

1445 """Schedules `fn` to be dispatched to a worker for asynchronous execution. 

1446 

1447 This method is non-blocking in that it queues `fn` for later execution and

1448 returns a

1449 `tf.distribute.experimental.coordinator.RemoteValue` object immediately.

1450 `fetch` can be called on it to wait for the function execution to finish

1451 and retrieve its output from a remote worker. Alternatively, call

1452 `tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait for

1453 all scheduled functions to finish.

1454 

1455 `schedule` guarantees that `fn` will be executed on a worker at least once;

1456 it could be more than once if its corresponding worker fails in the middle

1457 of its execution. Note that since a worker can fail at any point while

1458 executing the function, the function may be partially

1459 executed, but `tf.distribute.experimental.coordinator.ClusterCoordinator`

1460 guarantees that in those events, the function will eventually be executed on

1461 an available worker.

1462 

1463 If any previously scheduled function raises an error, `schedule` will raise

1464 any one of those errors, and clear the errors collected so far. In that case,

1465 some of the previously scheduled functions may not have been executed.

1466 Users can call `fetch` on the returned

1467 `tf.distribute.experimental.coordinator.RemoteValue` to inspect whether they

1468 have executed, failed, or been cancelled, and reschedule the corresponding

1469 function if needed.

1470 

1471 When `schedule` raises, it guarantees that there is no function that is 

1472 still being executed. 

1473 

1474 At this time, there is no support of worker assignment for function 

1475 execution, or priority of the workers. 

1476 

1477 `args` and `kwargs` are the arguments passed into `fn` when `fn` is

1478 executed on a worker. They can be

1479 `tf.distribute.experimental.coordinator.PerWorkerValues`, in which case

1480 the argument will be substituted with the corresponding component on the

1481 target worker. Arguments that are not

1482 `tf.distribute.experimental.coordinator.PerWorkerValues` will be passed into

1483 `fn` as-is. Currently, `tf.distribute.experimental.coordinator.RemoteValue`

1484 is not supported as an input in `args` or `kwargs`.

1485 
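For example, a per-worker iterator can be passed as a `PerWorkerValues`
argument (a minimal sketch; `dataset_fn` and `train_step` are placeholder
functions):

```python
per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
# Each worker that runs `train_step` receives its own iterator component.
result = coordinator.schedule(train_step, args=(per_worker_iterator,))
```
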

1486 Args: 

1487 fn: A `tf.function`; the function to be dispatched to a worker for

1488 asynchronous execution. Regular Python functions are not supported to be

1489 scheduled.

1490 args: Positional arguments for `fn`. 

1491 kwargs: Keyword arguments for `fn`. 

1492 

1493 Returns: 

1494 A `tf.distribute.experimental.coordinator.RemoteValue` object that 

1495 represents the output of the function scheduled. 

1496 

1497 Raises: 

1498 Exception: one of the exceptions caught by the coordinator from any 

1499 previously scheduled function, since the last time an error was thrown 

1500 or since the beginning of the program. 

1501 """ 

1502 if not isinstance(fn, 

1503 (def_function.Function, tf_function.ConcreteFunction)): 

1504 raise TypeError( 

1505 "`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`" 

1506 " only accepts a `tf.function` or a concrete function.") 

1507 # Slot variables are usually created during function tracing time; thus 

1508 # `schedule` needs to be called within the `strategy.scope()`. 

1509 with self.strategy.scope(): 

1510 self.strategy.extended._being_scheduled = True # pylint: disable=protected-access 

1511 schedule_remote_value = self._cluster.schedule( 

1512 fn, args=args, kwargs=kwargs) 

1513 self.strategy.extended._being_scheduled = False # pylint: disable=protected-access 

1514 return schedule_remote_value 

1515 

1516 def join(self): 

1517 """Blocks until all the scheduled functions have finished execution. 

1518 

1519 If any previously scheduled function raises an error, `join` will fail by

1520 raising any one of those errors, and clear the errors collected so far. If

1521 this happens, some of the previously scheduled functions may not have been

1522 executed. Users can call `fetch` on the returned

1523 `tf.distribute.experimental.coordinator.RemoteValue` to inspect whether they

1524 have executed, failed, or been cancelled. If some that have been cancelled

1525 need to be rescheduled, users should call `schedule` with the function again.

1526 

1527 When `join` returns or raises, it guarantees that there is no function that 

1528 is still being executed. 

1529 

1530 Raises: 

1531 Exception: one of the exceptions caught by the coordinator from any

1532 previously scheduled function since the last time an error was thrown or 

1533 since the beginning of the program. 

1534 """ 

1535 self._cluster.join() 

1536 

1537 def done(self): 

1538 """Returns whether all the scheduled functions have finished execution. 

1539 

1540 If any previously scheduled function raises an error, `done` will fail by 

1541 raising any one of those errors. 

1542 

1543 When `done` returns True or raises, it guarantees that there is no function 

1544 that is still being executed. 

1545 
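For example, `done` can be used to poll for completion without blocking
(a minimal sketch; `coordinator` is an existing `ClusterCoordinator` and the
sleep interval is arbitrary):

```python
import time

while not coordinator.done():
  time.sleep(1)
# All scheduled functions have finished (or an error has been raised).
```
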

1546 Returns: 

1547 Whether all the scheduled functions have finished execution. 

1548 Raises: 

1549 Exception: one of the exceptions caught by the coordinator from any

1550 previously scheduled function since the last time an error was thrown or 

1551 since the beginning of the program. 

1552 """ 

1553 return self._cluster.done() 

1554 

1555 def create_per_worker_dataset(self, dataset_fn): 

1556 """Create dataset on each worker. 

1557 

1558 This creates a dataset on each worker from the input, which can be either a

1559 `tf.data.Dataset`, a `tf.distribute.DistributedDataset`, or a function that

1560 returns a dataset, and returns an object that represents the collection of

1561 those individual datasets. Calling `iter` on such a collection of datasets

1562 returns a `tf.distribute.experimental.coordinator.PerWorkerValues`, which is

1563 a collection of iterators, where the iterators have been placed on the

1564 respective workers.

1565 

1566 Calling `next` on a `PerWorkerValues` of iterators is unsupported. The

1567 iterator is meant to be passed as an argument into 

1568 `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When 

1569 the scheduled function is about to be executed by a worker, the 

1570 function will receive the individual iterator that corresponds to the 

1571 worker. The `next` method can be called on an iterator inside a 

1572 scheduled function when the iterator is an input of the function. 

1573 

1574 Currently the `schedule` method assumes workers are all the same and thus 

1575 assumes the datasets on different workers are the same, except they may be 

1576 shuffled differently if they contain a `dataset.shuffle` operation and a 

1577 random seed is not set. Because of this, we also recommend that the datasets

1578 be repeated indefinitely and that a finite number of steps be scheduled,

1579 instead of relying on the `OutOfRangeError` from a dataset.

1580 

1581 

1582 Example: 

1583 

1584 ```python 

1585 strategy = tf.distribute.experimental.ParameterServerStrategy( 

1586 cluster_resolver=...) 

1587 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( 

1588 strategy=strategy) 

1589 

1590 @tf.function 

1591 def worker_fn(iterator): 

1592 return next(iterator) 

1593 

1594 def per_worker_dataset_fn(): 

1595 return strategy.distribute_datasets_from_function( 

1596 lambda x: tf.data.Dataset.from_tensor_slices([3] * 3)) 

1597 

1598 per_worker_dataset = coordinator.create_per_worker_dataset( 

1599 per_worker_dataset_fn) 

1600 per_worker_iter = iter(per_worker_dataset) 

1601 remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,)) 

1602 assert remote_value.fetch() == 3 

1603 ``` 

1604 

1605 Args: 

1606 dataset_fn: The dataset function that returns a dataset. This is to be 

1607 executed on the workers. 

1608 

1609 Returns: 

1610 An object that represents the collection of those individual 

1611 datasets. Calling `iter` on this object returns

1612 a `tf.distribute.experimental.coordinator.PerWorkerValues` of the

1613 iterators (which are placed on the workers).

1614 """ 

1615 return values_lib.get_per_worker_dataset(dataset_fn, self) 

1616 

1617 def _create_per_worker_resources(self, fn, args=None, kwargs=None): 

1618 """Synchronously create resources on the workers. 

1619 

1620 The resources are represented by 

1621 `tf.distribute.experimental.coordinator.RemoteValue`s. 

1622 

1623 Args: 

1624 fn: The function to be dispatched to all workers for execution 

1625 asynchronously. 

1626 args: Positional arguments for `fn`. 

1627 kwargs: Keyword arguments for `fn`. 

1628 

1629 Returns: 

1630 A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which 

1631 wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue` 

1632 objects. 

1633 """ 

1634 results = [] 

1635 for w in self._cluster.workers: 

1636 results.append(w.create_resource(fn, args=args, kwargs=kwargs)) 

1637 return PerWorkerValues(tuple(results)) 

1638 

1639 def fetch(self, val): 

1640 """Blocking call to fetch results from the remote values. 

1641 

1642 This is a wrapper around 

1643 `tf.distribute.experimental.coordinator.RemoteValue.fetch` for a 

1644 `RemoteValue` structure; it returns the execution results of 

1645 `RemoteValue`s. If they are not yet ready, the call blocks until they are.

1646 

1647 Example: 

1648 ```python 

1649 strategy = ... 

1650 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( 

1651 strategy) 

1652 

1653 def dataset_fn(): 

1654 return tf.data.Dataset.from_tensor_slices([1, 1, 1]) 

1655 

1656 with strategy.scope(): 

1657 v = tf.Variable(initial_value=0) 

1658 

1659 @tf.function 

1660 def worker_fn(iterator): 

1661 def replica_fn(x): 

1662 v.assign_add(x) 

1663 return v.read_value() 

1664 return strategy.run(replica_fn, args=(next(iterator),)) 

1665 

1666 distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn) 

1667 distributed_iterator = iter(distributed_dataset) 

1668 result = coordinator.schedule(worker_fn, args=(distributed_iterator,)) 

1669 assert coordinator.fetch(result) == 1 

1670 ``` 

1671 

1672 Args: 

1673 val: The value to fetch the results from. If this is a structure of

1674 `tf.distribute.experimental.coordinator.RemoteValue`, `fetch()` will be

1675 called on each individual

1676 `tf.distribute.experimental.coordinator.RemoteValue` to get its result.

1677 

1678 Returns: 

1679 If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a 

1680 structure of `tf.distribute.experimental.coordinator.RemoteValue`s, 

1681 return the fetched `tf.distribute.experimental.coordinator.RemoteValue` 

1682 values immediately if they are available, or block the call until they are 

1683 available, and return the fetched 

1684 `tf.distribute.experimental.coordinator.RemoteValue` values with the same 

1685 structure. If `val` is of any other type, it is returned as-is.

1686 """ 

1687 

1688 def _maybe_fetch(val): 

1689 if isinstance(val, RemoteValue): 

1690 return val.fetch() 

1691 else: 

1692 return val 

1693 

1694 # TODO(yuefengz): we should fetch values in a batch. 

1695 return nest.map_structure(_maybe_fetch, val) 

1696 

1697 

1698def _extract_failed_ps_instances(err_msg): 

1699 """Returns a set of potentially failed PS instances from the error message."""

1700 tasks = re.findall("/job:ps/replica:0/task:[0-9]+", err_msg) 

1701 return set(int(t.split(":")[-1]) for t in tasks) 
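# An illustrative check of the helper above (commented out; not part of the
# original module): given an error message that mentions
# "/job:ps/replica:0/task:0" and "/job:ps/replica:0/task:2", it returns {0, 2}.
#
#   assert _extract_failed_ps_instances(
#       "RPC to /job:ps/replica:0/task:0 and /job:ps/replica:0/task:2 failed"
#   ) == {0, 2}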

1702 

1703 

1704def _is_ps_failure(error): 

1705 """Whether the error is considered a parameter server failure.""" 

1706 if isinstance(error, PSUnavailableError): 

1707 return True 

1708 

1709 # For a `ClosureInputError` or `ClosureAbortedError`, extract

1710 # the original error and assess it accordingly. 

1711 if isinstance(error, (ClosureInputError, ClosureAbortedError)): 

1712 error = error.original_exception 

1713 

1714 if _RPC_ERROR_FROM_PS not in str(error): 

1715 return False 

1716 

1717 if isinstance(error, (errors.UnavailableError, errors.AbortedError)): 

1718 return True 

1719 

1720 # The following error could happen when the remote task fails and restarts 

1721 # in a very short interval during which no RPCs were exchanged to detect the 

1722 # failure. In that case, gRPC allows a channel (which is different from a

1723 # connection) to be reused for a replaced server listening on the same address.

1724 if isinstance(error, errors.InvalidArgumentError): 

1725 if ("unknown device" in str(error).lower() or 

1726 "Unable to find the relevant tensor remote_handle" in str(error)): 

1727 return True 

1728 

1729 return False 

1730 

1731 

1732def _handle_graph_execution_error_as_worker_failure(): 

1733 return int(os.environ.get("TF_PS_HANDLE_UNKNOWN_ERROR", "0")) > 0 

1734 

1735 

1736def _is_worker_failure(error): 

1737 """Whether the error is considered a worker failure.""" 

1738 

1739 # TODO(b/216666282): Understand why worker failure can manifest as a 

1740 # "Graph execution error" `UnknownError`. 

1741 if (_handle_graph_execution_error_as_worker_failure() and 

1742 isinstance(error, errors.UnknownError) and 

1743 "Graph execution error" in str(error)): 

1744 logging.info(f"Handling {type(error)}: {str(error)} as worker failure.") 

1745 return True 

1746 

1747 # For a `ClosureInputError` or `ClosureAbortedError`, extract

1748 # the original error and assess it accordingly. 

1749 if isinstance(error, (ClosureInputError, ClosureAbortedError)): 

1750 error = error.original_exception 

1751 

1752 if _JOB_WORKER_STRING_IDENTIFIER not in str(error): 

1753 return False 

1754 if _RPC_ERROR_FROM_PS in str(error): 

1755 return False 

1756 

1757 # TODO(haoyuzhang): Consider using special status code if error from a 

1758 # remote is derived from RPC errors originated from other hosts. 

1759 if isinstance(error, (errors.UnavailableError, errors.AbortedError)): 

1760 return True 

1761 

1762 # The following error could happen when the remote task fails and restarts 

1763 # in a very short interval during which no RPCs were exchanged to detect the 

1764 # failure. In that case, gRPC allows a channel (which is different from a

1765 # connection) to be reused for a replaced server listening on the same address.

1766 if isinstance(error, errors.InvalidArgumentError): 

1767 if ("unknown device" in str(error).lower() or 

1768 "Primary device is not remote" in str(error) or 

1769 "Unable to find the relevant tensor remote_handle" in str(error)): 

1770 return True 

1771 

1772 # TODO(b/162541228): The following 2 types of errors are very rare and only 

1773 # observed in large-scale testing. The types of errors should be reduced. 

1774 # This could happen when the function registration fails. In the observed 

1775 # cases this only happens to dataset-related functions.

1776 if isinstance(error, errors.NotFoundError): 

1777 if ("is neither a type of a primitive operation nor a name of a function " 

1778 "registered" in str(error)): 

1779 return True 

1780 

1781 # NOTE(b/179061495): During worker preemptions, if multiple functions are 

1782 # running concurrently (especially with subfunctions spanning chief/PS), 

1783 # CancelledError can be returned due to chief/PS cancelling outstanding RPCs 

1784 # to the failing workers. 

1785 if isinstance(error, errors.CancelledError): 

1786 return True 

1787 

1788 # This can occur when preparing closures for execution during exact

1789 # evaluation, because the iterator creation, which occurs within the

1790 # tf.function, needs to access the worker device and thus fails if the worker

1791 # is down.

1792 if isinstance(error, TypeError) and "Binding inputs to tf.function" in str( 

1793 error): 

1794 return True 

1795 

1796 return False