Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/multi_process_runner.py: 21%

470 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Multi-process runner for testing purpose.""" 

16 

17import collections 

18import contextlib 

19import json 

20import os 

21import signal 

22import sys 

23import threading 

24import time 

25import unittest 

26import weakref 

27 

28from absl import logging 

29import six 

30from six.moves import queue as Queue 

31 

32from tensorflow.python import tf2 

33from tensorflow.python.compat import v2_compat 

34from tensorflow.python.distribute import multi_worker_util 

35from tensorflow.python.distribute import multi_process_lib 

36from tensorflow.python.eager import context 

37from tensorflow.python.framework import test_util 

38from tensorflow.python.util.tf_export import tf_export 

39 

40multiprocessing = multi_process_lib.multiprocessing 

41 

42# pylint: disable=g-import-not-at-top 

43try: 

44 # `faulthandler` is not available in py2. 

45 import faulthandler 

46except ImportError: 

47 faulthandler = None 

48 

49# TODO(b/150264776): Remove after resolving CI issue. 

50try: 

51 import dill 

52except ImportError: 

53 dill = None 

54 

55# TODO(b/150264776): Remove after resolving CI issue. 

56try: 

57 import tblib.pickling_support 

58 # For pickling traceback objects. 

59 tblib.pickling_support.install() 

60except ImportError: 

61 pass 

62 

63 

64# _ProcessStatusInfo contains process status information. When is_successful 

65# attribute is True, the subprocess has ended successfully, or if False, the 

66# exception stack trace info is stored in exc_info to pass on to parent process 

67# to be re-raised. 

68_ProcessStatusInfo = collections.namedtuple( 

69 '_ProcessStatusInfo', 

70 ['task_type', 'task_id', 'is_successful', 'exc_info', 'return_value']) 

71 

72# Information returned from a successful MultiProcessRunner run. 

73MultiProcessRunnerResult = collections.namedtuple('MultiProcessRunnerResult', 

74 ['return_value', 'stdout']) 

75 

76# visible_gpus: If not None, CUDA_VISIBLE_DEVICES is set to visible_gpus. 

77TestEnvironment = collections.namedtuple('TestEnvironment', [ 

78 'task_type', 'task_id', 'cluster_spec', 'rpc_layer', 'grpc_fail_fast', 

79 'v2_enabled', 'executing_eagerly', 'visible_gpus' 

80]) 

81 

82# Resources for communication between worker processes and the main process. 

83# 

84# `process_status_queue` is used by `multi_process_runner` internally for 

85# communication from subprocesses to the parent process for whether it's been 

86# successful, and if not what the error stack trace is. 

87# `parent_to_sub_queue` is used for communications from parent to subprocess. 

88# Currently this is only used to terminate subprocesses. 

89# TODO(rchao): Remove this once subprocess is terminated by SIGKILL. 

90# `streaming_pipe_w` is to stream stdout and stderr from subprocesses to parent 

91# process. 

92# `barrier` is a barrier for the party of all subprocesses. 

93Resources = collections.namedtuple('Resources', [ 

94 'process_status_queue', 'parent_to_sub_queue', 'streaming_pipe_w', 'barrier' 

95]) 

96 

97# Default time out sec is selected so that it's handled before the default 

98# "medium" timeout of the test runs. 

99_DEFAULT_TIMEOUT_SEC = 200 

100 

101# The timeout in seconds to wait to force kill a child process. When a child 

102# process times out we first try to SIGTERM it so that it has a chance to dump 

103# stacktraces. However dumping stacktrace can take a long time. 

104_FORCE_KILL_WAIT_SEC = 30 

105 

106 

class MultiProcessRunner(object):
  """A utility class to start multiple processes to simulate a cluster.

  We need to use multiple processes to simulate a cluster in TF 2.0 tests
  because TF 2.0 has some process-global data structures that have to be
  separated by processes. We also need child processes to test out our fault
  tolerance because shutting down a standard TensorFlow server within its
  process is not supported.

  Typical usage is to construct the runner with a function and a cluster spec,
  call `start()`, and collect results with `join()`.

  Note: the main test program that uses this runner class must run main program
  via `test_main` defined in this file. Using this runner in non-test binaries
  is not supported yet.

  This class is not thread-safe. Child processes will inherit TF2 behavior flag.
  """

122 

  def __init__(self,
               fn,
               cluster_spec,
               rpc_layer=None,
               max_run_time=None,
               grpc_fail_fast=None,
               stream_output=True,
               return_output=False,
               use_dill_for_args=True,
               daemon=False,
               dependence_on_chief=True,
               auto_restart=False,
               share_gpu=True,
               args=None,
               kwargs=None):
    """Instantiation of a `MultiProcessRunner`.

    Args:
      fn: Function to be run on child processes. This will be run on processes
        for all task types.
      cluster_spec: Dict for cluster spec. The utility function
        `tf.__internal__.distribute.multi_process_runner.create_cluster_spec`
        can be conveniently used to create such dict. The following is an
        example of cluster with three workers and two ps's.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"],
         "ps": ["ps0.example.com:2222",
                "ps1.example.com:2222"]}
      rpc_layer: RPC layer to use. Default value is 'grpc'.
      max_run_time: `None` or integer. If not `None`, child processes are forced
        to exit at approximately this many seconds after this utility is called.
        We achieve this through `signal.alarm()` api. Note that this is best
        effort at Python level since Python signal handler does not get executed
        when it runs lower level C/C++ code. So it can be delayed for
        arbitrarily long time. If any of the child process is still running when
        `max_run_time` is up, they will be force-terminated and an
        `UnexpectedSubprocessExitError` may be raised. If `None`, child
        processes are not forced to exit.
      grpc_fail_fast: Whether GRPC connection between processes should fail
        without retrying. Defaults to None, in which case the environment
        variable is not explicitly set.
      stream_output: True if the output/error from the subprocesses should be
        streamed to be printed in parent process' log. Defaults to True.
      return_output: If True, the output/error from the subprocesses should be
        collected to be attached to the resulting namedtuple returned from
        `join()`. The list of output can be retrieved via `stdout` attribute.
        Defaults to False.
      use_dill_for_args: Whether to use dill to pickle `args` and `kwargs`. dill
        can pickle more objects, but doesn't work with types in
        `multiprocessing` library like `Mutex`.
      daemon: Whether to start processes as daemons.
      dependence_on_chief: Whether to terminates the cluster if the chief exits.
        If auto_restart is True, it only terminates the cluster if the chief
        exits with a zero exit code.
      auto_restart: Whether to automatically restart processes that exit with
        non-zero exit code.
      share_gpu: Whether to share GPUs among workers. If False, each worker is
        assigned different GPUs in a roundrobin fashion. This should be True
        whenever possible for better test execution coverage; some situations
        that need it to be False are tests that runs NCCL.
      args: Positional arguments to be sent to `fn` run on subprocesses.
      kwargs: Keyword arguments to be sent to `fn` run on subprocesses.

    Raises:
      RuntimeError: if `multi_process_runner.test_main()` is not called.
      ValueError: if there are more than one chief in the `cluster_spec`.
      SkipTest: if thread sanitizer is enabled (which is incompatible with MPR).
    """
    if test_util.is_tsan_enabled():
      raise unittest.SkipTest(
          'ThreadSanitizer is not compatible with MultiProcessRunner.')

    assert cluster_spec is not None
    if 'chief' in cluster_spec and len(cluster_spec['chief']) > 1:
      raise ValueError('If chief exists in the cluster, there must be at most '
                       'one chief. Current `cluster_spec` has {} chiefs.'
                       .format(len(cluster_spec['chief'])))
    _check_initialization()
    if not callable(fn):
      raise ValueError('fn is not a callable')

    self._fn = fn
    self._cluster_spec = cluster_spec
    # Default to gRPC as the transport between tasks.
    self._rpc_layer = rpc_layer or 'grpc'
    self._max_run_time = max_run_time
    self._grpc_fail_fast = grpc_fail_fast
    self._stream_output = stream_output
    # TODO(rchao): Revisit return_output argument to consider other solution.
    self._return_output = return_output
    self._dependence_on_chief = dependence_on_chief
    self._use_dill_for_args = use_dill_for_args
    self._daemon = daemon
    self._auto_restart = auto_restart
    self._args = args or ()
    self._kwargs = kwargs or {}

    self._share_gpu = share_gpu
    self._total_gpu = len(context.context().list_physical_devices('GPU'))

    # Child processes should have the same v2 and eager behavior.
    self._v2_enabled = tf2.enabled()
    self._executing_eagerly = context.executing_eagerly()

    self._joined = False
    self._process_lock = threading.Lock()
    # Guarded by self._process_lock.
    self._processes = {}
    # Record which processes are terminated. Due to a bug in Python<3.7,
    # terminated processes return 255 exit code, which should cause an exception
    # in join().
    # https://bugs.python.org/issue30589
    # Guarded by self._process_lock.
    self._terminated = set()
    self._reading_threads = []

    self._manager = manager()
    self._process_status_queue = self._manager.Queue()
    self._parent_to_sub_queue = self._manager.Queue()
    # One barrier party per task in the cluster so the subprocess-side
    # `barrier()` synchronizes the whole cluster.
    parties = sum(len(addresses) for addresses in self._cluster_spec.values())
    self._barrier = self._manager.Barrier(parties)

    # We use a queue to collect outputs from worker processes since it's thread
    # safe.
    self._streaming_queue = self._manager.Queue()

    # Lazily created by _start_subprocess_and_reading_thread().
    self._watchdog_thread = None

250 

251 def set_args(self, args=None, kwargs=None): 

252 self._args = args or self._args 

253 self._kwargs = kwargs or self._kwargs 

254 

  def _continuously_readline_from_sub(self, pipe_r, task_type, task_id):
    """Function to continuously read lines from subprocesses.

    Runs in a dedicated thread per subprocess; exits when the subprocess side
    of the pipe is closed.

    Args:
      pipe_r: read end of the pipe the subprocess writes its output to.
      task_type: the task type of the subprocess, used to tag each line.
      task_id: the task id of the subprocess, used to tag each line.
    """
    # `closefd=False` so closing the reader wrapper doesn't close the
    # underlying pipe fd, which is owned by the Pipe object.
    with os.fdopen(pipe_r.fileno(), 'r', closefd=False) as reader:
      for line in reader:
        # Prefix every line with "[task_type-task_id]:" padded for alignment.
        task_string = '[{}-{}]:'.format(task_type, task_id)
        formatted_line = '{} {}'.format(task_string.ljust(14), line)
        if self._stream_output:
          # TODO(rchao): Use a lock here to ensure the printed lines are not
          # broken.
          print(formatted_line, end='', flush=True)
        if self._return_output:
          self._streaming_queue.put(formatted_line)

267 

  def _start_subprocess_and_reading_thread(self,
                                           task_type,
                                           task_id,
                                           cluster_spec=None,
                                           fn=None,
                                           args=None,
                                           kwargs=None):
    """Start a subprocess and a thread the reads lines from the subprocess.

    Args:
      task_type: The task type of the new process, e.g. 'worker'.
      task_id: The task id of the new process.
      cluster_spec: Optional cluster spec for the new process; defaults to
        the runner's cluster spec.
      fn: Optional function to run in the new process. If `None`, the
        runner's `fn` together with its `args`/`kwargs` is used.
      args: Positional arguments for `fn` when `fn` is provided.
      kwargs: Keyword arguments for `fn` when `fn` is provided.

    Raises:
      unittest.SkipTest: if `dill` could not be imported.
    """

    if dill is None:
      raise unittest.SkipTest(
          'TODO(b/150264776): Resolve dependency issue in CI')

    cluster_spec = cluster_spec or self._cluster_spec
    visible_gpus = None
    if not self._share_gpu and self._total_gpu > 0:
      # Assign GPUs in a roundrobin fashion.
      id_in_cluster = multi_worker_util.id_in_cluster(cluster_spec, task_type,
                                                      task_id)
      worker_count = multi_worker_util.worker_count(cluster_spec, task_type)
      visible_gpus = list(range(id_in_cluster, self._total_gpu, worker_count))

    test_env = TestEnvironment(
        task_type=task_type,
        task_id=task_id,
        cluster_spec=cluster_spec,
        rpc_layer=self._rpc_layer,
        grpc_fail_fast=self._grpc_fail_fast,
        v2_enabled=self._v2_enabled,
        executing_eagerly=self._executing_eagerly,
        visible_gpus=visible_gpus,
    )
    # The write end goes to the subprocess; the read end stays in the parent
    # and is drained by a dedicated reader thread below.
    pipe_r, pipe_w = multiprocessing.Pipe(duplex=False)
    resources = Resources(
        process_status_queue=self._process_status_queue,
        parent_to_sub_queue=self._parent_to_sub_queue,
        streaming_pipe_w=pipe_w,
        barrier=self._barrier,
    )
    if fn is None:
      fn, args, kwargs = self._fn, self._args, self._kwargs
    # Always use dill to pickle fn so that we support more callable
    # types, e.g. lambda.
    fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL)
    if self._use_dill_for_args:
      args = dill.dumps(args, dill.HIGHEST_PROTOCOL)
      kwargs = dill.dumps(kwargs, dill.HIGHEST_PROTOCOL)

    p = _Process(
        test_env=test_env,
        target=_ProcFunc(),
        args=(resources, test_env, fn, args, kwargs, self._use_dill_for_args),
        daemon=self._daemon)
    p.start()
    self._processes[(task_type, task_id)] = p
    # A (re)started task is no longer considered deliberately terminated.
    self._terminated.discard((task_type, task_id))

    # For each subprocess, we dedicate a thread continuously reading lines
    # from them.
    thread = threading.Thread(  # pylint: disable=unexpected-keyword-arg
        target=self._continuously_readline_from_sub,
        args=(pipe_r, task_type, task_id))
    thread.start()
    self._reading_threads.append(thread)

    # (Re)start the watchdog if it isn't running, e.g. on first start or
    # after it returned because all processes had exited.
    if self._watchdog_thread is None or not self._watchdog_thread.is_alive():
      self._watchdog_thread = threading.Thread(target=self._process_watchdog)
      self._watchdog_thread.start()

336 

337 def start(self): 

338 """Starts processes, one for each task in `cluster_spec`. 

339 

340 Note that this is best effort by the applicable multiprocessing library, 

341 and it may take up to seconds for a subprocess to be successfully started. 

342 """ 

343 with self._process_lock: 

344 if self._processes: 

345 raise ValueError('MultiProcessRunner already started.') 

346 if self._joined: 

347 raise ValueError('cannot start new processes after' 

348 'MultiProcessRunner.join() is called') 

349 

350 for task_type, addresses in self._cluster_spec.items(): 

351 for task_id, _ in enumerate(addresses): 

352 self._start_subprocess_and_reading_thread(task_type, task_id) 

353 

354 # TODO(rchao): Remove the need of using SIGALRM if possible. At this time, 

355 # without this the tests become very flaky. 

356 if self._max_run_time is not None: 

357 

358 def handler(signum, frame): 

359 del signum, frame 

360 self.terminate_all() 

361 

362 signal.signal(signal.SIGALRM, handler) 

363 signal.alarm(self._max_run_time) 

364 

365 def start_in_process_as(self, as_task_type, as_task_id): 

366 """Start the processes, with the specified task run in main process. 

367 

368 This is similar to `start()` except that the task with task_type 

369 `as_task_type` and task_id `as_task_id` is run in the main process. 

370 This method is particularly useful when debugging tool such as `pdb` is 

371 needed in some specific task. Note that since this method is blocking until 

372 that specific task exits, additional actions would need a thread to be 

373 called: 

374 

375 ```python 

376 def fn(): 

377 # user code to be run 

378 import pdb; pdb.set_trace() 

379 

380 def follow_ups(): 

381 time.sleep(5) 

382 mpr.start_single_process( 

383 task_type='evaluator', 

384 task_id=0) 

385 

386 mpr = multi_process_runner.MultiProcessRunner( 

387 fn, 

388 multi_worker_test_base.create_cluster_spec( 

389 has_chief=True, num_workers=1)) 

390 threading.Thread(target=follow_ups).start() 

391 mpr.start_in_process_as(as_task_type='chief', as_task_id=0) 

392 mpr.join() 

393 ``` 

394 

395 Note that if `return_output=True`, the logs/stdout by task 

396 run by the main process is not available in result.stdout. 

397 

398 Args: 

399 as_task_type: The task type to be run in the main process. 

400 as_task_id: The task id to be run in the main process. 

401 """ 

402 if self._processes: 

403 raise ValueError('MultiProcessRunner already started.') 

404 with self._process_lock: 

405 if self._joined: 

406 raise ValueError('cannot start new processes after' 

407 'MultiProcessRunner.join() is called') 

408 for task_type, addresses in self._cluster_spec.items(): 

409 for task_id, _ in enumerate(addresses): 

410 if not (task_type == as_task_type and task_id == as_task_id): 

411 self._start_subprocess_and_reading_thread(task_type, task_id) 

412 

413 _set_tf_config(as_task_type, as_task_id, self._cluster_spec, 

414 self._rpc_layer) 

415 self._fn(*self._args, **self._kwargs) 

416 

417 def start_single_process(self, 

418 task_type, 

419 task_id, 

420 cluster_spec=None, 

421 fn=None, 

422 args=None, 

423 kwargs=None): 

424 """Starts a single process. 

425 

426 This starts a process in the cluster with the task type, task id, and the 

427 process function (`fn`). If process function is `None`, the function 

428 provided at `__init__` will be used. If `cluster_spec` is `None`, the 

429 cluster spec provided at `__init__` will be used. 

430 

431 TODO(rchao): It is meant that all subprocesses will be updated with the new 

432 cluster spec, but this has yet to be implemented. At this time only the 

433 newly started subprocess picks up this updated cluster spec. 

434 

435 Args: 

436 task_type: The task type. 

437 task_id: The task id. 

438 cluster_spec: The cluster spec to be used on the newly started 

439 process. If `None`, the cluster spec provided at `__init__` will be 

440 used. 

441 fn: The process function to be run on the newly started 

442 process. If specified, specify `args` and `kwargs` as well. If `None`, 

443 the function provided at `__init__` will be used. 

444 args: Optional positional arguments to be supplied in `fn`. 

445 kwargs: Optional keyword arguments to be supplied in `fn`. 

446 """ 

447 with self._process_lock: 

448 if self._joined: 

449 raise ValueError('cannot start new processes after' 

450 'MultiProcessRunner.join() is called') 

451 self._start_subprocess_and_reading_thread( 

452 task_type, 

453 task_id, 

454 cluster_spec=cluster_spec, 

455 fn=fn, 

456 args=args or (), 

457 kwargs=kwargs or {}) 

458 

459 def _queue_to_list(self, queue_to_convert): 

460 """Convert `queue.Queue` to `list`.""" 

461 list_to_return = [] 

462 # Calling `queue.empty()` is not reliable. 

463 while True: 

464 try: 

465 list_to_return.append(queue_to_convert.get(block=False)) 

466 except Queue.Empty: 

467 break 

468 return list_to_return 

469 

470 def _get_process_statuses(self): 

471 # One worker may have multiple statuses. We only keep the last one. 

472 statuses = {} 

473 for status in self._queue_to_list(self._process_status_queue): 

474 statuses[(status.task_type, status.task_id)] = status 

475 return statuses 

476 

477 def get_process_id(self, task_type, task_id): 

478 """Returns the subprocess id given the task type and task id.""" 

479 with self._process_lock: 

480 p = self._processes.get((task_type, task_id), None) 

481 return p.pid if p else None 

482 

483 def get_process_exit_code(self, task_type, task_id): 

484 """Returns the subprocess exit code given the task type and task id. 

485 

486 Args: 

487 task_type: The task type. 

488 task_id: The task id. 

489 

490 Returns: 

491 The subprocess exit code; `None` if the subprocess has not exited yet. 

492 

493 Raises: 

494 KeyError: If the corresponding subprocess is not found with `task_type` 

495 and `task_id`. 

496 """ 

497 with self._process_lock: 

498 p = self._processes[(task_type, task_id)] 

499 return p.exitcode if p else None 

500 

501 def process_exists(self, task_type, task_id): 

502 """Returns whether the subprocess still exists given the task type and id. 

503 

504 Args: 

505 task_type: The task type. 

506 task_id: The task id. 

507 

508 Returns: 

509 Boolean; whether the subprocess still exists. If the subprocess has 

510 exited, this returns False. 

511 """ 

512 return self.get_process_exit_code(task_type, task_id) is None 

513 

  def _process_watchdog(self):
    """Simulates a cluster management system.

    - If auto_restart is True, it restarts processes that exit with a non-zero
      exit code. Note that when join() times out it overrides auto_restart to
      False.
    - If dependence_on_chief is True, it terminates all processes once the chief
      exits. If auto_restart is also True, it only terminates all processes if
      the chief exit with a zero exit code, otherwise it restarts the chief.

    This runs in self._watchdog_thread.
    """
    while True:
      # Poll process states roughly once a second.
      time.sleep(1)
      with self._process_lock:
        chief = self._processes.get(('chief', 0), None)
        # Terminate the cluster when _dependence_on_chief is True if either:
        # - chief has exited with zero exit code.
        # - chief has exited with non-zero exit code and self._auto_restart is
        #   False.
        if chief and self._dependence_on_chief and chief.exitcode is not None:
          if chief.exitcode == 0 or (not self._auto_restart):
            for p in self._processes.values():
              # Give other processes a chance to exit on their own.
              p.join(timeout=3)
            self._terminate_all()
            for p in self._processes.values():
              p.join()
            return

        # Auto restart failed processes if self._auto_restart is True.
        if self._auto_restart:
          has_failure = False
          for (task_type, task_id), p in self._processes.items():
            if p.exitcode is not None and p.exitcode != 0:
              has_failure = True
              logging.info('Restarting failed %s-%d', task_type, task_id)
              self._start_subprocess_and_reading_thread(task_type, task_id)
          # Skip the all-exited check this round; a restart just happened.
          if has_failure:
            continue

        # Exit the thread if all processes have exited at this point.
        if all(p.exitcode is not None for p in self._processes.values()):
          return

558 

  def _reraise_if_subprocess_error(self, process_statuses):
    """Re-raises in the parent any exception recorded by a subprocess.

    Args:
      process_statuses: dict mapping (task_type, task_id) to the latest
        `_ProcessStatusInfo` reported by that task.

    Raises:
      Exception: the exception a failed subprocess recorded, re-raised with
        its original traceback and with an `mpr_result` attribute attached
        carrying the (partial) run results.
    """
    for process_status in process_statuses.values():
      assert isinstance(process_status, _ProcessStatusInfo)
      if not process_status.is_successful:
        # Attach partial results so the caller can still inspect stdout and
        # return values while handling the re-raised error.
        process_status.exc_info[1].mpr_result = self._get_mpr_result(
            process_statuses)
        six.reraise(*process_status.exc_info)

566 

  def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
    """Joins all the processes with timeout.

    If any of the subprocesses does not exit approximately after `timeout`
    seconds has passed after `join` call, this raises a
    `SubprocessTimeoutError`.

    Note: At timeout, it uses SIGTERM to terminate the subprocesses, in order to
    log the stack traces of the subprocesses when they exit. However, this
    results in timeout when the test runs with tsan (thread sanitizer); if tsan
    is being run on the test targets that rely on timeout to assert information,
    `MultiProcessRunner.terminate_all()` must be called after `join()`, before
    the test exits, so the subprocesses are terminated with SIGKILL, and data
    race is removed.

    Args:
      timeout: optional integer or `None`. If provided as an integer, and not
        all processes report status within roughly `timeout` seconds, a
        `SubprocessTimeoutError` exception will be raised. If `None`, `join` never
        times out.

    Returns:
      A `MultiProcessRunnerResult` object, which has two attributes,
      `return_value` and `stdout`. `return_value` always contains a list of
      return values from the subprocesses, although the order is not meaningful.
      If `return_output` argument is True at `__init__`, `stdout` is available
      that contains a list of all messages from subprocesses' stdout and stderr.

    Raises:
      SubprocessTimeoutError: if not all processes report status approximately
        within `timeout` seconds. When this is raised, a
        `MultiProcessRunnerResult` object can be retrieved by
        `SubprocessTimeoutError`'s mpr_result attribute, which has the same
        structure as above 'Returns' section describes.
      UnexpectedSubprocessExitError: If any of the subprocesses did not exit
        properly (for example, they exit on SIGTERM or SIGKILL signal). When
        this is raised, a `MultiProcessRunnerResult` object can be retrieved by
        `UnexpectedSubprocessExitError`'s mpr_result attribute, which has the
        same structure as above 'Returns' section describes. If `max_run_time`
        is not `None`, it is expected that some subprocesses may be
        force-killed when `max_run_time` is up, and this is raised in those
        cases.
      Exception: if there is an Exception propagated from any subprocess. When
        this is raised, a `MultiProcessRunnerResult` object can be retrieved by
        `UnexpectedSubprocessExitError`'s mpr_result attribute, which has the
        same structure as above 'Returns' section describes.
    """
    if timeout and not isinstance(timeout, int):
      raise ValueError('`timeout` must be an integer or `None`.')
    with self._process_lock:
      if self._joined:
        raise ValueError("MultiProcessRunner can't be joined twice.")
      self._joined = True

    # The watchdog exits once all subprocesses have exited, so waiting on it
    # is equivalent to waiting for the cluster to finish.
    self._watchdog_thread.join(timeout)
    if self._watchdog_thread.is_alive():
      # Timeout. Force termination to dump worker processes stack trace.
      with self._process_lock:
        self._auto_restart = False
      logging.error('Timeout when joining for child processes. Terminating...')
      self.terminate_all(sig=signal.SIGTERM)
      # Wait for the processes to terminate by themselves first, so they have a
      # chance to dump stacktraces. After _FORCE_KILL_WAIT_SEC, we SIGKILL them.
      self._watchdog_thread.join(_FORCE_KILL_WAIT_SEC)
      if self._watchdog_thread.is_alive():
        logging.error('Timeout when waiting for child processes to '
                      'print stacktrace. Sending SIGKILL...')
        self.terminate_all()
        self._watchdog_thread.join()
      # Prefer surfacing a subprocess's own error over the generic timeout.
      process_statuses = self._get_process_statuses()
      self._reraise_if_subprocess_error(process_statuses)
      raise SubprocessTimeoutError(
          'One or more subprocesses timed out, where timeout was set to {}s. '
          'Please change the `timeout` argument for '
          '`MultiProcessRunner.join()` or `multi_process_runner.run()` '
          'if it should be adjusted.'.format(timeout),
          self._get_mpr_result(process_statuses))

    for (task_type, task_id), p in self._processes.items():
      logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode)

    process_statuses = self._get_process_statuses()
    self._reraise_if_subprocess_error(process_statuses)

    # Checking all the processes that are expected to exit properly.
    for (task_type, task_id), p in self._processes.items():
      # Successfully exiting process has exit code 0. We ignore processes that
      # are terminated.
      assert p.exitcode is not None
      if (p.exitcode > 0 and (task_type, task_id) not in self._terminated):
        raise UnexpectedSubprocessExitError(
            'Subprocess %s-%d exited with exit code %s. See logs for details.'
            % (task_type, task_id, p.exitcode),
            self._get_mpr_result(process_statuses))

    logging.info('Joining log reading threads.')
    for thread in self._reading_threads:
      thread.join()
    logging.info('Joined log reading threads.')

    # Clear the alarm.
    signal.alarm(0)

    return self._get_mpr_result(process_statuses)

671 

672 def _get_mpr_result(self, process_statuses): 

673 stdout = self._queue_to_list(self._streaming_queue) 

674 return_values = [] 

675 for process_status in process_statuses.values(): 

676 if process_status.return_value is not None: 

677 return_values.append(process_status.return_value) 

678 return MultiProcessRunnerResult(stdout=stdout, return_value=return_values) 

679 

  def terminate(self, task_type, task_id):
    """Terminates the process with `task_type` and `task_id`.

    If auto_restart=True, the terminated task will be restarted unless the chief
    has already exited with zero exit code.

    Args:
      task_type: the task type.
      task_id: the task id.

    Raises:
      ValueError: if no subprocess exists for `task_type` and `task_id`.
    """
    with self._process_lock:
      p = self._processes.get((task_type, task_id), None)
      if p is None:
        raise ValueError('{}-{} does not exist'.format(task_type, task_id))
      # Mark as deliberately terminated so join() won't treat the resulting
      # non-zero exit code as a failure.
      self._terminated.add((task_type, task_id))
      # TODO(crccw): change to use Process.terminate() as well.
      self._parent_to_sub_queue.put('terminate {} {}'.format(
          task_type, task_id))
      p.join()

700 

  def _terminate_all(self, sig=None):
    """Terminates all subprocesses.

    The caller is required to hold self._process_lock.

    Args:
      sig: the signal used to terminate the process. The default is SIGKILL.
    """

    # Use SIGKILL as default. In systems where that's unavailable such as
    # windows, use SIGTERM.
    sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM)
    for (task_type, task_id), p in self._processes.items():
      if p.exitcode is not None:
        # Already-exited processes are skipped and not marked as terminated.
        logging.info('%s-%d has already exited. Not terminating.', task_type,
                     task_id)
        continue
      try:
        os.kill(p.pid, sig)
        self._terminated.add((task_type, task_id))
        logging.info('%s-%d terminated with signal %r.', task_type, task_id,
                     sig)
      except ProcessLookupError:
        # The process exited between the exitcode check and the kill.
        logging.info('Attempting to kill %s-%d but it does not exist.',
                     task_type, task_id)

726 

  def terminate_all(self, sig=None):
    """Terminates all subprocesses.

    Thread-safe wrapper over `_terminate_all`, which requires the process
    lock to be held.

    Args:
      sig: the signal used to terminate the processes; `_terminate_all`
        defaults to SIGKILL (SIGTERM where SIGKILL is unavailable).
    """
    with self._process_lock:
      self._terminate_all(sig)

731 

732 

class _Process(multi_process_lib.Process):
  """A modified `multiprocessing.Process` that can set up environment variables."""

  # TODO(crccw): consider moving other logics in _ProcFunc to _Process.

  def __init__(self, test_env, **kwargs):
    """Stores the test environment and intercepts `run`."""
    super(_Process, self).__init__(**kwargs)
    self._test_env = test_env
    # Remember the inherited `run` and substitute a wrapper that configures
    # the environment before invoking it.
    self._actual_run = getattr(self, 'run')
    self.run = self._run_with_setenv

  def _run_with_setenv(self):
    """Sets env vars from the test environment, then runs the real `run`."""
    # We need to set environment variables before doing anything because
    # setenv() is not thread-safe.
    env = self._test_env
    if env.grpc_fail_fast is not None:
      os.environ['GRPC_FAIL_FAST'] = str(env.grpc_fail_fast)
    if env.visible_gpus:
      os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
          str(gpu_id) for gpu_id in env.visible_gpus)
    _set_tf_config(env.task_type, env.task_id, env.cluster_spec, env.rpc_layer)
    return self._actual_run()

756 

757 

758class _ProcFunc(object): 

759 """Represents a callable to run in a subprocess.""" 

760 

761 @contextlib.contextmanager 

762 def _runtime_mode(self, executing_eagerly): 

763 if executing_eagerly: 

764 with context.eager_mode(): 

765 yield 

766 else: 

767 with context.graph_mode(): 

768 yield 

769 

  def _message_checking_func(self, task_type, task_id):
    """A function that regularly checks messages from parent process.

    Blocks until a 'terminate {task_type} {task_id}' message addressed to this
    task arrives, then reports a successful status and hard-exits the process.

    Args:
      task_type: the task type of this subprocess.
      task_id: the task id of this subprocess.
    """
    # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess.
    while True:
      try:
        message = self._resources.parent_to_sub_queue.get(block=False)

        # Currently the only possible message is termination.
        if not message.startswith('terminate'):
          raise ValueError('Unrecognized message: {}'.format(message))

        if message == 'terminate {} {}'.format(task_type, task_id):
          break
        else:
          # If the message is not targeting this process, put it back to the
          # queue.
          self._resources.parent_to_sub_queue.put(message)
          time.sleep(1)
      except Queue.Empty:
        time.sleep(0.1)
    # Report success first, so the parent doesn't treat this requested
    # termination as a failure.
    self._resources.process_status_queue.put(
        _ProcessStatusInfo(
            task_type=task_type,
            task_id=task_id,
            is_successful=True,
            exc_info=None,
            return_value=None))
    # `os._exit(1)` is used to more reliably terminate a subprocess.
    os._exit(1)  # pylint: disable=protected-access

799 

800 def _close_streaming(self): 

801 """Close stdout, stderr and streaming pipe. 

802 

803 We need to explicitly close them since Tensorflow may take a while to exit, 

804 so that the reading threads in the main process can exit more quickly. 

805 """ 

806 sys.stdout.flush() 

807 sys.stderr.flush() 

808 sys.stdout.close() 

809 sys.stderr.close() 

810 self._resources.streaming_pipe_w.close() 

811 

812 def __call__(self, resources, test_env, fn, args, kwargs, use_dill_for_args): 

813 """The wrapper function that actually gets run in child process(es).""" 

814 

815 global _barrier 

816 

817 self._resources = resources 

818 _barrier = self._resources.barrier 

819 fn = dill.loads(fn) 

820 if use_dill_for_args: 

821 args = dill.loads(args) 

822 kwargs = dill.loads(kwargs) 

823 

824 if faulthandler is not None: 

825 faulthandler.enable() 

826 faulthandler.register(signal.SIGTERM, chain=True) 

827 

828 # All logging should go to stderr to be streamed to the main process. 

829 logging.set_stderrthreshold(logging.DEBUG) 

830 

831 # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so 

832 # print() and logging.*() write directly to `streaming_pipe_w`. 

833 # Unfortunately since we cannot prepend task_type and task_id information to 

834 # the streamed logs we will need a thread per subprocess to distinguish 

835 # where the piece of message is from. 

836 os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno()) 

837 os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno()) 

838 

839 pid = os.getpid() 

840 logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid, 

841 test_env.task_type, test_env.task_id) 

842 logging.info('TF_CONFIG: %r', os.environ['TF_CONFIG']) 

843 

844 # The thread will be dedicated to checking messages from the parent process. 

845 threading.Thread( # pylint: disable=unexpected-keyword-arg 

846 target=self._message_checking_func, 

847 args=(test_env.task_type, test_env.task_id), 

848 daemon=True).start() 

849 

850 if test_env.v2_enabled: 

851 v2_compat.enable_v2_behavior() 

852 

853 with self._runtime_mode(test_env.executing_eagerly): 

854 info = _run_contained(test_env.task_type, test_env.task_id, fn, args, 

855 kwargs) 

856 self._resources.process_status_queue.put(info) 

857 

858 # Re-raise the exception in addition to reporting it to the parent 

859 # process, so that even if `--test_timeout` flag is set and the 

860 # error doesn't make it to be shown in parent process before bazel's 

861 # timeout, the log would still show what happens in this subprocess, 

862 # instead of silently suppressing the error due to early bazel 

863 # timeout. Raising an error in the subprocess produces stack trace in 

864 # the log, but the program continues running. 

865 if not info.is_successful: 

866 six.reraise(*info.exc_info) 

867 

868 self._close_streaming() 

869 

870 # Exit with code 0 as it's considered successful exit at this point. 

871 sys.exit(0) 

872 

873 

# Active MultiProcessPoolRunner. We need to shut them down when the program
# exits, and this is by setting the `tearDownModule` of the module containing
# `__main__`. Note this is set in both the parent process and the subprocesses.
# A WeakSet is used so that registration alone does not keep a pool runner
# alive after its last strong reference is dropped.
_active_pool_runners = weakref.WeakSet()

878 

879 

def _shutdown_all_pool_runners():
  """Shuts down every pool runner still registered in `_active_pool_runners`."""
  for runner in _active_pool_runners:
    runner.shutdown()

883 

884 

def is_oss():
  """Returns whether the test is run under OSS."""
  # Under OSS the test binary is launched through bazel, whose name shows up
  # in argv[0]; guard against an empty argv.
  argv = sys.argv
  return bool(argv) and 'bazel' in argv[0]

888 

889 

class MultiProcessPoolRunner(object):
  """A utility class to start a process pool to simulate a cluster.

  It's similar to MultiProcessRunner, but uses a pool of processes to avoid the
  expensive initialization cost of Tensorflow.
  """

  def __init__(self, cluster_spec, initializer=None, share_gpu=True):
    """Creates a multi-process pool runner.

    Args:
      cluster_spec: Dict for cluster spec. The following is an example of
        cluster with three workers.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"]}
      initializer: a callable to be called at the startup of worker processes.
      share_gpu: Whether to share GPUs among workers. If False, each worker is
        assigned different GPUs in a roundrobin fashion.

    Raises:
      RuntimeError: if `multi_process_runner.test_main()` is not called.
      ValueError: if there are more than one chief in the `cluster_spec`.
    """
    # Register for shutdown at program exit (see `_shutdown_all_pool_runners`).
    _active_pool_runners.add(self)
    self._cluster_spec = cluster_spec
    self._initializer = initializer
    self._share_gpu = share_gpu
    # Maps (task_type, task_id) -> parent end of the worker's Pipe.
    self._conn = {}
    # The underlying MultiProcessRunner; created lazily in `_start`.
    self._runner = None

  def __del__(self):
    # Best-effort cleanup when the pool runner is garbage collected.
    self.shutdown()

  def shutdown(self):
    """Shuts down the worker pool."""
    # Closing the pipes signals EOF to the workers, which then exit their
    # receive loops (see `_pool_runner_worker`).
    for conn in self._conn.values():
      conn.close()
    self._conn = {}
    if self._runner is not None:
      try:
        self._runner.join()
      except Exception as e:  # pylint: disable=broad-except
        logging.error(
            'Ignoring exception when shutting down MultiProcessPoolRunner: %s',
            e)
      self._runner = None

  def _start(self):
    """Starts the worker pool."""
    # We need different arguments for different processes so we're passing a
    # no-op fn here and use start_single_process instead.

    if dill is None:
      raise unittest.SkipTest(
          'TODO(b/150264776): Resolve dependency issue in CI')

    self._runner = MultiProcessRunner(
        fn=lambda: None,
        cluster_spec=self._cluster_spec,
        use_dill_for_args=False,
        share_gpu=self._share_gpu)
    # The initializer crosses the process boundary dill-serialized; it is
    # deserialized and invoked inside `_pool_runner_worker`.
    if self._initializer:
      initializer = dill.dumps(self._initializer, dill.HIGHEST_PROTOCOL)
    else:
      initializer = None
    for task_type, addresses in self._cluster_spec.items():
      for task_id, _ in enumerate(addresses):
        # One duplex pipe per worker: conn1 stays in the parent, conn2 goes
        # to the subprocess.
        conn1, conn2 = multiprocessing.Pipe(duplex=True)
        self._conn[(task_type, task_id)] = conn1
        self._runner.start_single_process(
            task_type,
            task_id,
            fn=_pool_runner_worker,
            args=(task_type, task_id, initializer, conn2))

  def run(self, fn, args=None, kwargs=None):
    """Runs `fn` with `args` and `kwargs` on all jobs.

    Args:
      fn: The function to be run.
      args: Optional positional arguments to be supplied in `fn`.
      kwargs: Optional keyword arguments to be supplied in `fn`.

    Returns:
      A list of return values.
    """
    _check_initialization()
    # TODO(b/150264776): skip in OSS until it's implemented.
    multi_process_lib.Process()
    if self._runner is None:
      self._start()

    # Broadcast the serialized task to every worker in the pool.
    fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL)
    for conn in self._conn.values():
      conn.send((fn, args or [], kwargs or {}))

    process_statuses = []
    for (task_type, task_id), conn in self._conn.items():
      logging.info('Waiting for the result from %s-%d', task_type, task_id)
      try:
        process_statuses.append(conn.recv())
      except EOFError:
        # This shouldn't happen due to exceptions in fn. This usually
        # means bugs in the runner.
        self.shutdown()
        raise RuntimeError('Unexpected EOF. Worker process may have died. '
                           'Please report a bug')

    return_values = []
    for process_status in process_statuses:
      assert isinstance(process_status, _ProcessStatusInfo)
      # Re-raise the first worker failure in the parent process.
      if not process_status.is_successful:
        six.reraise(*process_status.exc_info)
      if process_status.return_value is not None:
        return_values.append(process_status.return_value)

    return return_values

1008 

1009 

def _pool_runner_worker(task_type, task_id, initializer, conn):
  """Function that runs on the workers in a pool.

  It listens for callables to run and returns the result until `conn` is
  closed. It captures the exceptions during executing the callable and returns
  them through `conn`.

  Args:
    task_type: the task type.
    task_id: the task index.
    initializer: an optional dill-serialized callable executed once at startup.
    conn: a multiprocessing.Connection object to listen for tasks and send
      results.
  """
  if initializer:
    dill.loads(initializer)()
  while True:
    try:
      serialized_fn, call_args, call_kwargs = conn.recv()
    except EOFError:
      # The parent closed its end of the pipe; no more work will arrive.
      break
    status = _run_contained(task_type, task_id, dill.loads(serialized_fn),
                            call_args, call_kwargs)
    # Flush the streams so output reaches the parent before the result does.
    sys.stdout.flush()
    sys.stderr.flush()
    conn.send(status)

1037 

1038 

def _run_contained(task_type, task_id, fn, args, kwargs):
  """Runs `fn` with `args` and `kwargs`.

  The function returns _ProcessStatusInfo which captures the return value and
  the exception.

  Args:
    task_type: the task type.
    task_id: the task index.
    fn: the function to be run.
    args: optional positional arguments to be supplied in `fn`.
    kwargs: optional keyword arguments to be supplied in `fn`.

  Returns:
    a _ProcessStatusInfo.
  """
  is_successful = False
  return_value = None
  exc_info = None
  try:
    return_value = fn(*args, **kwargs)
    is_successful = True
  # If `fn` ends up exiting with `sys.exit()`, the `SystemExit` is not
  # handled here — only `Exception` subclasses are captured.
  except Exception:  # pylint: disable=broad-except
    exc_info = sys.exc_info()
  # Single exit point: report whichever of return value / exception was set.
  return _ProcessStatusInfo(
      task_type=task_type,
      task_id=task_id,
      is_successful=is_successful,
      exc_info=exc_info,
      return_value=return_value)

1079 

1080 

@tf_export('__internal__.distribute.multi_process_runner'
           '.SubprocessTimeoutError',
           v1=[])
class SubprocessTimeoutError(RuntimeError):
  """An error that indicates there is at least one subprocess timing out.

  When this is raised, a namedtuple object representing the multi-process run
  result can be retrieved by
  `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`'s
  `mpr_result` attribute. See
  `tf.__internal__.distribute.multi_process_runner.run` for more information.
  """

  def __init__(self, msg, mpr_result):
    super().__init__(msg)
    # Attach the partial run result so callers can inspect it after catching.
    self.mpr_result = mpr_result

1097 

1098 

@tf_export('__internal__.distribute.multi_process_runner'
           '.UnexpectedSubprocessExitError',
           v1=[])
class UnexpectedSubprocessExitError(RuntimeError):
  """An error indicating there is at least one subprocess with unexpected exit.

  When this is raised, a namedtuple object representing the multi-process run
  result can be retrieved by
  `tf.__internal__.distribute.multi_process_runner
  .UnexpectedSubprocessExitError`'s
  `mpr_result` attribute. See
  `tf.__internal__.distribute.multi_process_runner.run` for more information.
  """

  def __init__(self, msg, mpr_result):
    super().__init__(msg)
    # Attach the partial run result so callers can inspect it after catching.
    self.mpr_result = mpr_result

1116 

1117 

@tf_export(
    '__internal__.distribute.multi_process_runner.NotInitializedError', v1=[])
class NotInitializedError(RuntimeError):
  """An error indicating `multi_process_runner.run` is used without init.

  When this is raised, user is supposed to call
  `tf.__internal__.distribute.multi_process_runner.test_main()` within
  `if __name__ == '__main__':` block to properly initialize
  `multi_process_runner.run`.
  """

1129 

1130 

def _check_initialization():
  """Raises `NotInitializedError` unless `multi_process_lib` is initialized."""
  if multi_process_lib.initialized():
    return
  raise NotInitializedError(
      '`multi_process_runner` is not initialized. '
      'Please call `tf.__internal__.distribute.multi_process_runner.'
      'test_main()` within `if __name__ == \'__main__\':` block '
      'in your python module to properly initialize '
      '`multi_process_runner`.')

1139 

1140 

1141def _set_tf_config(task_type, task_id, cluster_spec, rpc_layer=None): 

1142 """Set TF_CONFIG environment variable.""" 

1143 tf_config_dict = { 

1144 'cluster': cluster_spec, 

1145 'task': { 

1146 'type': task_type, 

1147 'index': task_id, 

1148 }, 

1149 } 

1150 if rpc_layer is not None: 

1151 tf_config_dict['rpc_layer'] = rpc_layer 

1152 os.environ['TF_CONFIG'] = json.dumps(tf_config_dict) 

1153 

1154 

@tf_export('__internal__.distribute.multi_process_runner.run', v1=[])
def run(fn,
        cluster_spec,
        rpc_layer=None,
        max_run_time=None,
        return_output=False,
        timeout=_DEFAULT_TIMEOUT_SEC,
        args=None,
        kwargs=None):
  """Run `fn` in multiple processes according to `cluster_spec`.

  Given a callable `fn`, `tf.__internal__.distribute.multi_process_runner.run`
  launches multiple processes, each of which runs `fn`. These processes are
  referred to as "subprocesses" or "child processes". Each of those subprocesses
  will have their `TF_CONFIG` environment variable set, according to
  `cluster_spec` and their task types. The stdout of the subprocesses are
  streamed to the main process' and thus available in logs (if `stream_output`
  is True), with [type-id] prefix.

  `tf.__internal__.distribute.multi_process_runner.run` will block until all
  subprocesses have successfully exited, and return a namedtuple object that
  represents the run result. This object has a `return_value` attribute, which
  is a list that contains subprocesses `fn`'s return values, for those
  subprocesses that successfully returned from `fn`. The order of `return_value`
  list is not meaningful. If an optional arg `return_output` (default to False)
  is set to True, the namedtuple object will have an additional attribute
  `stdout`, which is a list containing the stdout of the subprocesses. If any
  subprocess' `fn` ends up raising an error, that error will be reraised from
  `tf.__internal__.distribute.multi_process_runner.run`, and the aforementioned
  namedtuple object will be available through the exception's
  `mpr_result` attribute.

  This utility is used for simulating running TensorFlow programs across
  multiple task types, and each of the task type may contain more than one task
  (except for "chief" where more than one task is prohibited). Test coverage of
  multi-worker training is the main application of this utility, where code
  written for multi-worker training can be realistically covered in unit tests.

  Any test module that uses
  `tf.__internal__.distribute.multi_process_runner.run()` must call
  `tf.__internal__.distribute.multi_process_runner.test_main()` instead of
  regular `test.main()` inside `if __name__ == '__main__':` block for proper
  initialization.

  Args:
    fn: Function to be run on child processes. This will be run on processes for
      all task types.
    cluster_spec: Dict for cluster spec. The utility function
      `tf.__internal__.distribute.multi_process_runner.create_cluster_spec` can
      be conveniently used to create such dict. The following is an example of
      cluster with three workers and two ps's.
      {"worker": ["worker0.example.com:2222",
                  "worker1.example.com:2222",
                  "worker2.example.com:2222"],
       "ps": ["ps0.example.com:2222",
              "ps1.example.com:2222"]}
    rpc_layer: RPC layer to use. Default value is 'grpc'.
    max_run_time: `None` or integer. If not `None`, child processes are forced
      to exit at approximately this many seconds after this utility is called.
      We achieve this through `signal.alarm()` api. Note that this is best
      effort at Python level since Python signal handler does not get executed
      when it runs lower level C/C++ code. So it can be delayed for arbitrarily
      long time. If any of the child process is still running when
      `max_run_time` is up, they will be force-terminated and an
      `tf.__internal__.distribute.multi_process_runner
      .UnexpectedSubprocessExitError`
      may be raised. If `None`, child processes are not forced to exit.
    return_output: If True, the output/error from the subprocesses should be
      collected to be attached to the resulting namedtuple returned from this
      utility. The list of output can be retrieved via `stdout` attribute.
      Defaults to False.
    timeout: optional integer or `None`. If provided as an integer, and not all
      processes report status within roughly `timeout` seconds, a
      `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`
      exception will be raised. If `None`,
      `tf.__internal__.distribute.multi_process_runner.run` never times out.
      Defaults to the constant `_DEFAULT_TIMEOUT_SEC` defined in
      `multi_process_runner` module.
    args: Positional arguments to be sent to `fn` run on subprocesses.
    kwargs: Keyword arguments to be sent to `fn` run on subprocesses.

  Returns:
    A namedtuple object, which has two attributes,
    `return_value` and `stdout`. `return_value` always contains a list of
    return values from the subprocesses, although the order is not meaningful.
    If `return_output` argument is True, `stdout` is available that contains a
    list of all messages from subprocesses' stdout and stderr, and the order
    is mostly chronological.

  Raises:
    RuntimeError: if
      `tf.__internal__.distribute.multi_process_runner.test_main()` is
      not called in test's `if __name__ == '__main__':` block.
    ValueError: if there are more than one chief in the `cluster_spec`.
    tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError: if
      not all processes report status approximately
      within `timeout` seconds. When this is raised, a
      namedtuple object can be retrieved by
      `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`'s
      `mpr_result` attribute, which has the same
      structure as above 'Returns' section describes.
    tf.__internal__.distribute.multi_process_runner
    .UnexpectedSubprocessExitError:
      If any of the subprocesses did not exit
      properly (for example, they exit on SIGTERM or SIGKILL signal). When
      this is raised, a namedtuple object can be retrieved by
      `tf.__internal__.distribute.multi_process_runner
      .UnexpectedSubprocessExitError`'s
      `mpr_result` attribute, which has the
      same structure as above 'Returns' section describes. If `max_run_time`
      is not `None`, it is expected that some subprocesses may be
      force-killed when `max_run_time` is up, and this is raised in those
      cases.
    Exception: if there is an Exception propagated from any subprocess. When
      this is raised, a namedtuple object can be retrieved by
      `tf.__internal__.distribute.multi_process_runner
      .UnexpectedSubprocessExitError`
      `mpr_result` attribute, which has the
      same structure as above 'Returns' section describes.

  Examples:

  ```python
  class SimpleMultiProcessTest(tf.test.TestCase):

    def test_simple_printing_and_return(self):

      def fn():
        resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

        # This will print "[chief-0]: Task type: chief , task id: 0"
        # for chief, for example.
        logging.info('Task type: %s, task id: %d',
                     resolver.task_type, resolver.task_id)

        return resolver.task_type

      result = tf.__internal__.distribute.multi_process_runner.run(
          fn=fn,
          cluster_spec=(
              tf.__internal__
              .distribute.multi_process_runner.create_cluster_spec(
                  has_chief=True, num_workers=2)))
      assert sorted(result.return_value) == ['chief', 'worker', 'worker']

    def test_error_from_fn(self):

      def fn():
        resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
        raise ValueError('Task type {}, task id {} is errors out'.format(
            resolver.task_type, resolver.task_id))

      with self.assertRaisesRegexp(ValueError,
                                   'Task type worker, task id 0 is errors out'):
        cluster_spec = (
            tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
                num_workers=1))
        tf.__internal__.distribute.multi_process_runner.run(
            fn=fn, cluster_spec=cluster_spec)


  if __name__ == '__main__':
    tf.__internal__.distribute.multi_process_runner.test_main()
  ```
  """
  # Thin wrapper: construct a MultiProcessRunner, start it, and block on join.
  runner = MultiProcessRunner(
      fn,
      cluster_spec,
      rpc_layer,
      max_run_time=max_run_time,
      return_output=return_output,
      args=args,
      kwargs=kwargs)
  runner.start()
  return runner.join(timeout)

1330 

1331 

# This is set by MultiProcessRunner in worker processes: `_ProcFunc.__call__`
# assigns the shared barrier here, so it remains `None` in the parent process.
_barrier = None

1334 

1335 

@tf_export('__internal__.distribute.multi_process_runner.get_barrier', v1=[])
def get_barrier():
  """Returns a `multiprocessing.Barrier` for `multi_process_runner.run`.

  `tf.__internal__.distribute.multi_process_runner.get_barrier()` returns
  a `multiprocessing.Barrier` object which can be used within `fn` of
  `tf.__internal__.distribute.multi_process_runner` to wait with
  `barrier.wait()` call until all other tasks have also reached the
  `barrier.wait()` call, before they can proceed individually.

  Note that all tasks (subprocesses) have to reach `barrier.wait()` call to
  proceed. Currently it is not supported to block on only a subset of tasks
  in the cluster.

  Example:
  ```python

  def fn():
    some_work_to_be_done_by_all_tasks()

    tf.__internal__.distribute.multi_process_runner.get_barrier().wait()

    # The barrier guarantees that at this point, all tasks have finished
    # `some_work_to_be_done_by_all_tasks()`
    some_other_work_to_be_done_by_all_tasks()

  result = tf.__internal__.distribute.multi_process_runner.run(
      fn=fn,
      cluster_spec=(
          tf.__internal__
          .distribute.multi_process_runner.create_cluster_spec(
              num_workers=2)))
  ```


  Returns:
    A `multiprocessing.Barrier` for `multi_process_runner.run`.

  Raises:
    ValueError: if called in the main process, where `_barrier` was never set.
  """
  # `_barrier` is only assigned in subprocesses (see `_ProcFunc.__call__`).
  if _barrier is None:
    raise ValueError(
        'barrier is not defined. It is likely because you are calling '
        'get_barrier() in the main process. get_barrier() can only be called '
        'in the subprocesses.'
    )
  return _barrier

1381 

1382 

# Lazily-created singleton `multiprocessing.Manager`; see `manager()`.
_manager = None
# Guards the lazy creation of `_manager` across threads.
_manager_lock = threading.Lock()

1385 

1386 

def manager():
  """Returns the multiprocessing manager object for concurrency tools.

  The manager object is useful as it controls a server process that holds
  the python objects that can be shared across processes. This can be used
  for parent-subprocess communication:

  ```python
  manager = multi_process_runner.manager()
  some_event_happening_in_subprocess = manager.Event()
  mpr = multi_process_runner.MultiProcessRunner(fn, cluster_spec,
      args=(some_event_happening_in_subprocess,))
  mpr.start()
  some_event_happening_in_subprocess.wait()
  # Do something that only should after some event happens in subprocess.
  ```

  Note that the user of multi_process_runner should not create additional
  `multiprocessing.Manager()` objects; doing so can result in segfault in
  some cases.

  This method should only be called after multi_process_runner.test_main() is
  called.

  Returns:
    The singleton `multiprocessing.Manager` instance, created on first call.
  """
  _check_initialization()
  global _manager
  # Create the singleton under the lock so concurrent first calls cannot
  # create two manager server processes.
  with _manager_lock:
    if _manager is None:
      _manager = multiprocessing.Manager()
    return _manager

1417 

1418 

@tf_export('__internal__.distribute.multi_process_runner.test_main', v1=[])
def test_main():
  """Main function to be called within `__main__` of a test file.

  Any test module that uses
  `tf.__internal__.distribute.multi_process_runner.run()`
  must call this instead of regular `test.main()` inside
  `if __name__ == '__main__':` block, or an error will be raised when
  `tf.__internal__.distribute.multi_process_runner.run()` is used. This method
  takes care of needed initialization for launching multiple subprocesses.

  Example:
  ```python
  class MyTestClass(tf.test.TestCase):
    def testSomething(self):
      # Testing code making use of
      # `tf.__internal__.distribute.multi_process_runner.run()`.

  if __name__ == '__main__':
    tf.__internal__.distribute.multi_process_runner.test_main()
  ```
  """
  # Inject tearDownModule() to shut down all pool runners. Active pool runners
  # will block the program from exiting. This is necessary for global pool
  # runners. We tried atexit in the past, and it doesn't work in some
  # deployment.
  main_module = sys.modules['__main__']
  old_tear_down_module = getattr(main_module, 'tearDownModule', None)

  def tear_down_module():
    # Shut down pools first, then chain to any pre-existing tearDownModule.
    _shutdown_all_pool_runners()
    if old_tear_down_module is not None:
      old_tear_down_module()

  main_module.tearDownModule = tear_down_module
  multi_process_lib.test_main()