# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Multi-process runner for testing purpose."""
17import collections
18import contextlib
19import json
20import os
21import signal
22import sys
23import threading
24import time
25import unittest
26import weakref
28from absl import logging
29import six
30from six.moves import queue as Queue
32from tensorflow.python import tf2
33from tensorflow.python.compat import v2_compat
34from tensorflow.python.distribute import multi_worker_util
35from tensorflow.python.distribute import multi_process_lib
36from tensorflow.python.eager import context
37from tensorflow.python.framework import test_util
38from tensorflow.python.util.tf_export import tf_export
40multiprocessing = multi_process_lib.multiprocessing
42# pylint: disable=g-import-not-at-top
43try:
44 # `faulthandler` is not available in py2.
45 import faulthandler
46except ImportError:
47 faulthandler = None
# TODO(b/150264776): Remove after resolving CI issue.
try:
  import dill
except ImportError:
  dill = None

# TODO(b/150264776): Remove after resolving CI issue.
try:
  import tblib.pickling_support
  # For pickling traceback objects.
  tblib.pickling_support.install()
except ImportError:
  pass


# _ProcessStatusInfo contains process status information. When the
# `is_successful` attribute is True, the subprocess has ended successfully; if
# False, the exception stack trace is stored in `exc_info`, to be passed on to
# the parent process and re-raised there.
_ProcessStatusInfo = collections.namedtuple(
    '_ProcessStatusInfo',
    ['task_type', 'task_id', 'is_successful', 'exc_info', 'return_value'])

# Information returned from a successful MultiProcessRunner run.
MultiProcessRunnerResult = collections.namedtuple('MultiProcessRunnerResult',
                                                  ['return_value', 'stdout'])

# visible_gpus: If not None, CUDA_VISIBLE_DEVICES is set to visible_gpus.
TestEnvironment = collections.namedtuple('TestEnvironment', [
    'task_type', 'task_id', 'cluster_spec', 'rpc_layer', 'grpc_fail_fast',
    'v2_enabled', 'executing_eagerly', 'visible_gpus'
])

# Resources for communication between worker processes and the main process.
#
# `process_status_queue` is used internally by `multi_process_runner` for
# communication from subprocesses to the parent process, reporting whether a
# subprocess was successful and, if not, what the error stack trace is.
# `parent_to_sub_queue` is used for communication from parent to subprocesses.
# Currently this is only used to terminate subprocesses.
# TODO(rchao): Remove this once subprocess is terminated by SIGKILL.
# `streaming_pipe_w` is used to stream stdout and stderr from subprocesses to
# the parent process.
# `barrier` is a barrier for the party of all subprocesses.
Resources = collections.namedtuple('Resources', [
    'process_status_queue', 'parent_to_sub_queue', 'streaming_pipe_w', 'barrier'
])

# The default timeout in seconds is selected so that it's handled before the
# default "medium" timeout of the test runs.
_DEFAULT_TIMEOUT_SEC = 200

# The timeout in seconds to wait before force-killing a child process. When a
# child process times out, we first SIGTERM it so that it has a chance to dump
# stacktraces. However, dumping stacktraces can take a long time.
_FORCE_KILL_WAIT_SEC = 30


class MultiProcessRunner(object):
  """A utility class to start multiple processes to simulate a cluster.

  We need to use multiple processes to simulate a cluster in TF 2.0 tests
  because TF 2.0 has some process-global data structures that have to be
  separated by processes. We also need child processes to test out our fault
  tolerance, because shutting down a standard TensorFlow server within its
  process is not supported.

  Note: the main test program that uses this runner class must run the main
  program via `test_main` defined in this file. Using this runner in non-test
  binaries is not supported yet.

  This class is not thread-safe. Child processes will inherit the TF2 behavior
  flag.
  """

  def __init__(self,
               fn,
               cluster_spec,
               rpc_layer=None,
               max_run_time=None,
               grpc_fail_fast=None,
               stream_output=True,
               return_output=False,
               use_dill_for_args=True,
               daemon=False,
               dependence_on_chief=True,
               auto_restart=False,
               share_gpu=True,
               args=None,
               kwargs=None):
    """Instantiation of a `MultiProcessRunner`.

    Args:
      fn: Function to be run on child processes. This will be run on processes
        for all task types.
      cluster_spec: Dict for cluster spec. The utility function
        `tf.__internal__.distribute.multi_process_runner.create_cluster_spec`
        can be conveniently used to create such dict. The following is an
        example of a cluster with three workers and two ps's.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"],
         "ps": ["ps0.example.com:2222",
                "ps1.example.com:2222"]}
      rpc_layer: RPC layer to use. Default value is 'grpc'.
      max_run_time: `None` or integer. If not `None`, child processes are
        forced to exit at approximately this many seconds after this utility
        is called. We achieve this through the `signal.alarm()` API. Note that
        this is best effort at the Python level, since the Python signal
        handler does not get executed while lower-level C/C++ code is running;
        it can therefore be delayed for an arbitrarily long time. If any of
        the child processes are still running when `max_run_time` is up, they
        will be force-terminated and an `UnexpectedSubprocessExitError` may be
        raised. If `None`, child processes are not forced to exit.
      grpc_fail_fast: Whether gRPC connections between processes should fail
        without retrying. Defaults to None, in which case the environment
        variable is not explicitly set.
      stream_output: True if the output/error from the subprocesses should be
        streamed to be printed in the parent process' log. Defaults to True.
      return_output: If True, the output/error from the subprocesses should be
        collected to be attached to the resulting namedtuple returned from
        `join()`. The list of output can be retrieved via the `stdout`
        attribute. Defaults to False.
      use_dill_for_args: Whether to use dill to pickle `args` and `kwargs`.
        dill can pickle more objects, but doesn't work with types in the
        `multiprocessing` library like `Mutex`.
      daemon: Whether to start processes as daemons.
      dependence_on_chief: Whether to terminate the cluster if the chief
        exits. If auto_restart is True, it only terminates the cluster if the
        chief exits with a zero exit code.
      auto_restart: Whether to automatically restart processes that exit with
        a non-zero exit code.
      share_gpu: Whether to share GPUs among workers. If False, each worker is
        assigned different GPUs in a round-robin fashion. This should be True
        whenever possible for better test execution coverage; some situations
        that need it to be False are tests that run NCCL.
      args: Positional arguments to be sent to `fn` run on subprocesses.
      kwargs: Keyword arguments to be sent to `fn` run on subprocesses.

    Raises:
      RuntimeError: if `multi_process_runner.test_main()` is not called.
      ValueError: if there are more than one chief in the `cluster_spec`.
      SkipTest: if thread sanitizer is enabled (which is incompatible with
        MPR).
    """
    if test_util.is_tsan_enabled():
      raise unittest.SkipTest(
          'ThreadSanitizer is not compatible with MultiProcessRunner.')

    assert cluster_spec is not None
    if 'chief' in cluster_spec and len(cluster_spec['chief']) > 1:
      raise ValueError('If chief exists in the cluster, there must be at most '
                       'one chief. Current `cluster_spec` has {} chiefs.'
                       .format(len(cluster_spec['chief'])))
    _check_initialization()
    if not callable(fn):
      raise ValueError('fn is not a callable')

    self._fn = fn
    self._cluster_spec = cluster_spec
    self._rpc_layer = rpc_layer or 'grpc'
    self._max_run_time = max_run_time
    self._grpc_fail_fast = grpc_fail_fast
    self._stream_output = stream_output
    # TODO(rchao): Revisit return_output argument to consider other solution.
    self._return_output = return_output
    self._dependence_on_chief = dependence_on_chief
    self._use_dill_for_args = use_dill_for_args
    self._daemon = daemon
    self._auto_restart = auto_restart
    self._args = args or ()
    self._kwargs = kwargs or {}

    self._share_gpu = share_gpu
    self._total_gpu = len(context.context().list_physical_devices('GPU'))

    # Child processes should have the same v2 and eager behavior.
    self._v2_enabled = tf2.enabled()
    self._executing_eagerly = context.executing_eagerly()

    self._joined = False
    self._process_lock = threading.Lock()
    # Guarded by self._process_lock.
    self._processes = {}
    # Record which processes are terminated. Due to a bug in Python<3.7,
    # terminated processes return 255 exit code, which should cause an
    # exception in join().
    # https://bugs.python.org/issue30589
    # Guarded by self._process_lock.
    self._terminated = set()
    self._reading_threads = []

    self._manager = manager()
    self._process_status_queue = self._manager.Queue()
    self._parent_to_sub_queue = self._manager.Queue()
    parties = sum(len(addresses) for addresses in self._cluster_spec.values())
    self._barrier = self._manager.Barrier(parties)

    # We use a queue to collect outputs from worker processes since it's
    # thread safe.
    self._streaming_queue = self._manager.Queue()

    self._watchdog_thread = None
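
  # A minimal usage sketch (hypothetical `fn`; assumes the enclosing test
  # module calls `multi_process_runner.test_main()` in its
  # `if __name__ == '__main__':` block):
  #
  #   mpr = MultiProcessRunner(
  #       fn,
  #       multi_worker_test_base.create_cluster_spec(
  #           has_chief=True, num_workers=2),
  #       return_output=True)
  #   mpr.start()
  #   result = mpr.join()         # blocks until all subprocesses exit
  #   print(result.return_value)  # list of `fn` return values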

  def set_args(self, args=None, kwargs=None):
    self._args = args or self._args
    self._kwargs = kwargs or self._kwargs

  def _continuously_readline_from_sub(self, pipe_r, task_type, task_id):
    """Function to continuously read lines from subprocesses."""
    with os.fdopen(pipe_r.fileno(), 'r', closefd=False) as reader:
      for line in reader:
        task_string = '[{}-{}]:'.format(task_type, task_id)
        formatted_line = '{} {}'.format(task_string.ljust(14), line)
        if self._stream_output:
          # TODO(rchao): Use a lock here to ensure the printed lines are not
          # broken.
          print(formatted_line, end='', flush=True)
        if self._return_output:
          self._streaming_queue.put(formatted_line)
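
  # For example, a line printed by worker 0 arrives in the parent process'
  # log as (the task prefix is left-justified to 14 characters; the message
  # text is hypothetical):
  #
  #   [worker-0]:    Epoch 1 finished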

  def _start_subprocess_and_reading_thread(self,
                                           task_type,
                                           task_id,
                                           cluster_spec=None,
                                           fn=None,
                                           args=None,
                                           kwargs=None):
    """Start a subprocess and a thread that reads lines from the subprocess."""

    if dill is None:
      raise unittest.SkipTest(
          'TODO(b/150264776): Resolve dependency issue in CI')

    cluster_spec = cluster_spec or self._cluster_spec
    visible_gpus = None
    if not self._share_gpu and self._total_gpu > 0:
      # Assign GPUs in a round-robin fashion.
      id_in_cluster = multi_worker_util.id_in_cluster(cluster_spec, task_type,
                                                      task_id)
      worker_count = multi_worker_util.worker_count(cluster_spec, task_type)
      visible_gpus = list(range(id_in_cluster, self._total_gpu, worker_count))
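      # A worked example (not from the original source): with 4 GPUs and 2
      # workers, worker 0 gets `range(0, 4, 2)` -> GPUs [0, 2] and worker 1
      # gets `range(1, 4, 2)` -> GPUs [1, 3], so no GPU is shared.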

    test_env = TestEnvironment(
        task_type=task_type,
        task_id=task_id,
        cluster_spec=cluster_spec,
        rpc_layer=self._rpc_layer,
        grpc_fail_fast=self._grpc_fail_fast,
        v2_enabled=self._v2_enabled,
        executing_eagerly=self._executing_eagerly,
        visible_gpus=visible_gpus,
    )
    pipe_r, pipe_w = multiprocessing.Pipe(duplex=False)
    resources = Resources(
        process_status_queue=self._process_status_queue,
        parent_to_sub_queue=self._parent_to_sub_queue,
        streaming_pipe_w=pipe_w,
        barrier=self._barrier,
    )
    if fn is None:
      fn, args, kwargs = self._fn, self._args, self._kwargs
    # Always use dill to pickle fn so that we support more callable
    # types, e.g. lambda.
    fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL)
    if self._use_dill_for_args:
      args = dill.dumps(args, dill.HIGHEST_PROTOCOL)
      kwargs = dill.dumps(kwargs, dill.HIGHEST_PROTOCOL)

    p = _Process(
        test_env=test_env,
        target=_ProcFunc(),
        args=(resources, test_env, fn, args, kwargs, self._use_dill_for_args),
        daemon=self._daemon)
    p.start()
    self._processes[(task_type, task_id)] = p
    self._terminated.discard((task_type, task_id))

    # For each subprocess, we dedicate a thread that continuously reads lines
    # from it.
    thread = threading.Thread(  # pylint: disable=unexpected-keyword-arg
        target=self._continuously_readline_from_sub,
        args=(pipe_r, task_type, task_id))
    thread.start()
    self._reading_threads.append(thread)

    if self._watchdog_thread is None or not self._watchdog_thread.is_alive():
      self._watchdog_thread = threading.Thread(target=self._process_watchdog)
      self._watchdog_thread.start()

  def start(self):
    """Starts processes, one for each task in `cluster_spec`.

    Note that this is best effort by the applicable multiprocessing library,
    and it may take several seconds for a subprocess to be successfully
    started.
    """
    with self._process_lock:
      if self._processes:
        raise ValueError('MultiProcessRunner already started.')
      if self._joined:
        raise ValueError('cannot start new processes after '
                         'MultiProcessRunner.join() is called')

      for task_type, addresses in self._cluster_spec.items():
        for task_id, _ in enumerate(addresses):
          self._start_subprocess_and_reading_thread(task_type, task_id)

    # TODO(rchao): Remove the need of using SIGALRM if possible. At this time,
    # without this the tests become very flaky.
    if self._max_run_time is not None:

      def handler(signum, frame):
        del signum, frame
        self.terminate_all()

      signal.signal(signal.SIGALRM, handler)
      signal.alarm(self._max_run_time)

  def start_in_process_as(self, as_task_type, as_task_id):
    """Start the processes, with the specified task run in the main process.

    This is similar to `start()` except that the task with task_type
    `as_task_type` and task_id `as_task_id` is run in the main process.
    This method is particularly useful when a debugging tool such as `pdb` is
    needed in some specific task. Note that since this method blocks until
    that specific task exits, any follow-up actions need to be run on a
    separate thread:

    ```python
    def fn():
      # user code to be run
      import pdb; pdb.set_trace()

    def follow_ups():
      time.sleep(5)
      mpr.start_single_process(
          task_type='evaluator',
          task_id=0)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=1))
    threading.Thread(target=follow_ups).start()
    mpr.start_in_process_as(as_task_type='chief', as_task_id=0)
    mpr.join()
    ```

    Note that if `return_output=True`, the logs/stdout of the task
    run by the main process are not available in result.stdout.

    Args:
      as_task_type: The task type to be run in the main process.
      as_task_id: The task id to be run in the main process.
    """
    if self._processes:
      raise ValueError('MultiProcessRunner already started.')
    with self._process_lock:
      if self._joined:
        raise ValueError('cannot start new processes after '
                         'MultiProcessRunner.join() is called')
      for task_type, addresses in self._cluster_spec.items():
        for task_id, _ in enumerate(addresses):
          if not (task_type == as_task_type and task_id == as_task_id):
            self._start_subprocess_and_reading_thread(task_type, task_id)

    _set_tf_config(as_task_type, as_task_id, self._cluster_spec,
                   self._rpc_layer)
    self._fn(*self._args, **self._kwargs)

  def start_single_process(self,
                           task_type,
                           task_id,
                           cluster_spec=None,
                           fn=None,
                           args=None,
                           kwargs=None):
    """Starts a single process.

    This starts a process in the cluster with the task type, task id, and the
    process function (`fn`). If the process function is `None`, the function
    provided at `__init__` will be used. If `cluster_spec` is `None`, the
    cluster spec provided at `__init__` will be used.

    TODO(rchao): It is meant that all subprocesses will be updated with the
    new cluster spec, but this has yet to be implemented. At this time only
    the newly started subprocess picks up this updated cluster spec.

    Args:
      task_type: The task type.
      task_id: The task id.
      cluster_spec: The cluster spec to be used on the newly started
        process. If `None`, the cluster spec provided at `__init__` will be
        used.
      fn: The process function to be run on the newly started
        process. If specified, specify `args` and `kwargs` as well. If `None`,
        the function provided at `__init__` will be used.
      args: Optional positional arguments to be supplied to `fn`.
      kwargs: Optional keyword arguments to be supplied to `fn`.
    """
    with self._process_lock:
      if self._joined:
        raise ValueError('cannot start new processes after '
                         'MultiProcessRunner.join() is called')
      self._start_subprocess_and_reading_thread(
          task_type,
          task_id,
          cluster_spec=cluster_spec,
          fn=fn,
          args=args or (),
          kwargs=kwargs or {})

  def _queue_to_list(self, queue_to_convert):
    """Convert `queue.Queue` to `list`."""
    list_to_return = []
    # Calling `queue.empty()` is not reliable.
    while True:
      try:
        list_to_return.append(queue_to_convert.get(block=False))
      except Queue.Empty:
        break
    return list_to_return

  def _get_process_statuses(self):
    # One worker may have multiple statuses. We only keep the last one.
    statuses = {}
    for status in self._queue_to_list(self._process_status_queue):
      statuses[(status.task_type, status.task_id)] = status
    return statuses

  def get_process_id(self, task_type, task_id):
    """Returns the subprocess id given the task type and task id."""
    with self._process_lock:
      p = self._processes.get((task_type, task_id), None)
    return p.pid if p else None

  def get_process_exit_code(self, task_type, task_id):
    """Returns the subprocess exit code given the task type and task id.

    Args:
      task_type: The task type.
      task_id: The task id.

    Returns:
      The subprocess exit code; `None` if the subprocess has not exited yet.

    Raises:
      KeyError: If the corresponding subprocess is not found with `task_type`
        and `task_id`.
    """
    with self._process_lock:
      p = self._processes[(task_type, task_id)]
    return p.exitcode if p else None

  def process_exists(self, task_type, task_id):
    """Returns whether the subprocess still exists given the task type and id.

    Args:
      task_type: The task type.
      task_id: The task id.

    Returns:
      Boolean; whether the subprocess still exists. If the subprocess has
      exited, this returns False.
    """
    return self.get_process_exit_code(task_type, task_id) is None
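
  # For example (hypothetical usage): poll a task until it exits, then read
  # its exit code:
  #
  #   while mpr.process_exists('worker', 0):
  #     time.sleep(1)
  #   exit_code = mpr.get_process_exit_code('worker', 0)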

  def _process_watchdog(self):
    """Simulates a cluster management system.

    - If auto_restart is True, it restarts processes that exit with a non-zero
      exit code. Note that when join() times out it overrides auto_restart to
      False.
    - If dependence_on_chief is True, it terminates all processes once the
      chief exits. If auto_restart is also True, it only terminates all
      processes if the chief exits with a zero exit code, otherwise it
      restarts the chief.

    This runs in self._watchdog_thread.
    """
    while True:
      time.sleep(1)
      with self._process_lock:
        chief = self._processes.get(('chief', 0), None)
        # Terminate the cluster when _dependence_on_chief is True if either:
        # - chief has exited with zero exit code.
        # - chief has exited with non-zero exit code and self._auto_restart is
        #   False.
        if chief and self._dependence_on_chief and chief.exitcode is not None:
          if chief.exitcode == 0 or (not self._auto_restart):
            for p in self._processes.values():
              # Give other processes a chance to exit on their own.
              p.join(timeout=3)
            self._terminate_all()
            for p in self._processes.values():
              p.join()
            return

        # Auto restart failed processes if self._auto_restart is True.
        if self._auto_restart:
          has_failure = False
          for (task_type, task_id), p in self._processes.items():
            if p.exitcode is not None and p.exitcode != 0:
              has_failure = True
              logging.info('Restarting failed %s-%d', task_type, task_id)
              self._start_subprocess_and_reading_thread(task_type, task_id)
          if has_failure:
            continue

        # Exit the thread if all processes have exited at this point.
        if all(p.exitcode is not None for p in self._processes.values()):
          return

  def _reraise_if_subprocess_error(self, process_statuses):
    for process_status in process_statuses.values():
      assert isinstance(process_status, _ProcessStatusInfo)
      if not process_status.is_successful:
        process_status.exc_info[1].mpr_result = self._get_mpr_result(
            process_statuses)
        six.reraise(*process_status.exc_info)

  def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
    """Joins all the processes with timeout.

    If any of the subprocesses does not exit within approximately `timeout`
    seconds after the `join` call, this raises a `SubprocessTimeoutError`.

    Note: At timeout, it uses SIGTERM to terminate the subprocesses, in order
    to log the stack traces of the subprocesses when they exit. However, this
    results in timeout when the test runs with tsan (thread sanitizer); if
    tsan is being run on the test targets that rely on timeout to assert
    information, `MultiProcessRunner.terminate_all()` must be called after
    `join()`, before the test exits, so the subprocesses are terminated with
    SIGKILL, and the data race is removed.

    Args:
      timeout: optional integer or `None`. If provided as an integer, and not
        all processes report status within roughly `timeout` seconds, a
        `SubprocessTimeoutError` exception will be raised. If `None`, `join`
        never times out.

    Returns:
      A `MultiProcessRunnerResult` object, which has two attributes,
      `return_value` and `stdout`. `return_value` always contains a list of
      return values from the subprocesses, although the order is not
      meaningful. If the `return_output` argument is True at `__init__`,
      `stdout` is available, containing a list of all messages from the
      subprocesses' stdout and stderr.

    Raises:
      SubprocessTimeoutError: if not all processes report status approximately
        within `timeout` seconds. When this is raised, a
        `MultiProcessRunnerResult` object can be retrieved by
        `SubprocessTimeoutError`'s mpr_result attribute, which has the same
        structure as the above 'Returns' section describes.
      UnexpectedSubprocessExitError: If any of the subprocesses did not exit
        properly (for example, they exit on SIGTERM or SIGKILL signal). When
        this is raised, a `MultiProcessRunnerResult` object can be retrieved
        by `UnexpectedSubprocessExitError`'s mpr_result attribute, which has
        the same structure as the above 'Returns' section describes. If
        `max_run_time` is not `None`, it is expected that some subprocesses
        may be force-killed when `max_run_time` is up, and this is raised in
        those cases.
      Exception: if there is an Exception propagated from any subprocess. When
        this is raised, a `MultiProcessRunnerResult` object can be retrieved
        by `UnexpectedSubprocessExitError`'s mpr_result attribute, which has
        the same structure as the above 'Returns' section describes.
    """
    if timeout and not isinstance(timeout, int):
      raise ValueError('`timeout` must be an integer or `None`.')
    with self._process_lock:
      if self._joined:
        raise ValueError("MultiProcessRunner can't be joined twice.")
      self._joined = True

    self._watchdog_thread.join(timeout)
    if self._watchdog_thread.is_alive():
      # Timeout. Force termination to dump worker processes stack trace.
      with self._process_lock:
        self._auto_restart = False
      logging.error('Timeout when joining for child processes. Terminating...')
      self.terminate_all(sig=signal.SIGTERM)
      # Wait for the processes to terminate by themselves first, so they have
      # a chance to dump stacktraces. After _FORCE_KILL_WAIT_SEC, we SIGKILL
      # them.
      self._watchdog_thread.join(_FORCE_KILL_WAIT_SEC)
      if self._watchdog_thread.is_alive():
        logging.error('Timeout when waiting for child processes to '
                      'print stacktrace. Sending SIGKILL...')
        self.terminate_all()
        self._watchdog_thread.join()
      process_statuses = self._get_process_statuses()
      self._reraise_if_subprocess_error(process_statuses)
      raise SubprocessTimeoutError(
          'One or more subprocesses timed out, where timeout was set to {}s. '
          'Please change the `timeout` argument for '
          '`MultiProcessRunner.join()` or `multi_process_runner.run()` '
          'if it should be adjusted.'.format(timeout),
          self._get_mpr_result(process_statuses))

    for (task_type, task_id), p in self._processes.items():
      logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode)

    process_statuses = self._get_process_statuses()
    self._reraise_if_subprocess_error(process_statuses)

    # Checking all the processes that are expected to exit properly.
    for (task_type, task_id), p in self._processes.items():
      # Successfully exiting process has exit code 0. We ignore processes that
      # are terminated.
      assert p.exitcode is not None
      if (p.exitcode > 0 and (task_type, task_id) not in self._terminated):
        raise UnexpectedSubprocessExitError(
            'Subprocess %s-%d exited with exit code %s. See logs for details.'
            % (task_type, task_id, p.exitcode),
            self._get_mpr_result(process_statuses))

    logging.info('Joining log reading threads.')
    for thread in self._reading_threads:
      thread.join()
    logging.info('Joined log reading threads.')

    # Clear the alarm.
    signal.alarm(0)

    return self._get_mpr_result(process_statuses)

  def _get_mpr_result(self, process_statuses):
    stdout = self._queue_to_list(self._streaming_queue)
    return_values = []
    for process_status in process_statuses.values():
      if process_status.return_value is not None:
        return_values.append(process_status.return_value)
    return MultiProcessRunnerResult(stdout=stdout, return_value=return_values)

  def terminate(self, task_type, task_id):
    """Terminates the process with `task_type` and `task_id`.

    If auto_restart=True, the terminated task will be restarted unless the
    chief has already exited with a zero exit code.

    Args:
      task_type: the task type.
      task_id: the task id.
    """
    with self._process_lock:
      p = self._processes.get((task_type, task_id), None)
      if p is None:
        raise ValueError('{}-{} does not exist'.format(task_type, task_id))
      self._terminated.add((task_type, task_id))
      # TODO(crccw): change to use Process.terminate() as well.
      self._parent_to_sub_queue.put('terminate {} {}'.format(
          task_type, task_id))
      p.join()

  def _terminate_all(self, sig=None):
    """Terminates all subprocesses.

    The caller is required to hold self._process_lock.

    Args:
      sig: the signal used to terminate the process. The default is SIGKILL.
    """

    # Use SIGKILL as default. On systems where that's unavailable, such as
    # Windows, use SIGTERM.
    sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM)
    for (task_type, task_id), p in self._processes.items():
      if p.exitcode is not None:
        logging.info('%s-%d has already exited. Not terminating.', task_type,
                     task_id)
        continue
      try:
        os.kill(p.pid, sig)
        self._terminated.add((task_type, task_id))
        logging.info('%s-%d terminated with signal %r.', task_type, task_id,
                     sig)
      except ProcessLookupError:
        logging.info('Attempting to kill %s-%d but it does not exist.',
                     task_type, task_id)

  def terminate_all(self, sig=None):
    """Terminates all subprocesses."""
    with self._process_lock:
      self._terminate_all(sig)
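
  # A hedged sketch of exercising fault tolerance (hypothetical `fn` and
  # cluster spec): with auto_restart=True, a task terminated here is
  # restarted by the watchdog as long as the chief has not exited with a
  # zero exit code.
  #
  #   mpr = MultiProcessRunner(
  #       fn, cluster_spec, auto_restart=True, dependence_on_chief=True)
  #   mpr.start()
  #   mpr.terminate('worker', 0)  # worker-0 is restarted by the watchdog
  #   result = mpr.join()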


class _Process(multi_process_lib.Process):
  """A modified `multiprocessing.Process` that can set up environment variables."""

  # TODO(crccw): consider moving other logic in _ProcFunc to _Process.

  def __init__(self, test_env, **kwargs):
    super(_Process, self).__init__(**kwargs)
    self._test_env = test_env
    self._actual_run = getattr(self, 'run')
    self.run = self._run_with_setenv

  def _run_with_setenv(self):
    # We need to set environment variables before doing anything because
    # setenv() is not thread-safe.
    test_env = self._test_env
    if test_env.grpc_fail_fast is not None:
      os.environ['GRPC_FAIL_FAST'] = str(test_env.grpc_fail_fast)
    if test_env.visible_gpus:
      os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
          [str(i) for i in test_env.visible_gpus])
    _set_tf_config(test_env.task_type, test_env.task_id, test_env.cluster_spec,
                   test_env.rpc_layer)
    return self._actual_run()


class _ProcFunc(object):
  """Represents a callable to run in a subprocess."""

  @contextlib.contextmanager
  def _runtime_mode(self, executing_eagerly):
    if executing_eagerly:
      with context.eager_mode():
        yield
    else:
      with context.graph_mode():
        yield

  def _message_checking_func(self, task_type, task_id):
    """A function that regularly checks messages from the parent process."""
    # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess.
    while True:
      try:
        message = self._resources.parent_to_sub_queue.get(block=False)

        # Currently the only possible message is termination.
        if not message.startswith('terminate'):
          raise ValueError('Unrecognized message: {}'.format(message))

        if message == 'terminate {} {}'.format(task_type, task_id):
          break
        else:
          # If the message is not targeting this process, put it back to the
          # queue.
          self._resources.parent_to_sub_queue.put(message)
          time.sleep(1)
      except Queue.Empty:
        time.sleep(0.1)
    self._resources.process_status_queue.put(
        _ProcessStatusInfo(
            task_type=task_type,
            task_id=task_id,
            is_successful=True,
            exc_info=None,
            return_value=None))
    # `os._exit(1)` is used to more reliably terminate a subprocess.
    os._exit(1)  # pylint: disable=protected-access

  def _close_streaming(self):
    """Close stdout, stderr and the streaming pipe.

    We need to explicitly close them since TensorFlow may take a while to
    exit, and closing them lets the reading threads in the main process exit
    more quickly.
    """
    sys.stdout.flush()
    sys.stderr.flush()
    sys.stdout.close()
    sys.stderr.close()
    self._resources.streaming_pipe_w.close()

  def __call__(self, resources, test_env, fn, args, kwargs, use_dill_for_args):
    """The wrapper function that actually gets run in child process(es)."""

    global _barrier

    self._resources = resources
    _barrier = self._resources.barrier
    fn = dill.loads(fn)
    if use_dill_for_args:
      args = dill.loads(args)
      kwargs = dill.loads(kwargs)

    if faulthandler is not None:
      faulthandler.enable()
      faulthandler.register(signal.SIGTERM, chain=True)

    # All logging should go to stderr to be streamed to the main process.
    logging.set_stderrthreshold(logging.DEBUG)

    # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so
    # print() and logging.*() write directly to `streaming_pipe_w`.
    # Unfortunately, since we cannot prepend task_type and task_id information
    # to the streamed logs, we need a thread per subprocess to distinguish
    # which subprocess each message comes from.
    os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno())
    os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno())

    pid = os.getpid()
    logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid,
                 test_env.task_type, test_env.task_id)
    logging.info('TF_CONFIG: %r', os.environ['TF_CONFIG'])

    # The thread will be dedicated to checking messages from the parent process.
    threading.Thread(  # pylint: disable=unexpected-keyword-arg
        target=self._message_checking_func,
        args=(test_env.task_type, test_env.task_id),
        daemon=True).start()

    if test_env.v2_enabled:
      v2_compat.enable_v2_behavior()

    with self._runtime_mode(test_env.executing_eagerly):
      info = _run_contained(test_env.task_type, test_env.task_id, fn, args,
                            kwargs)
      self._resources.process_status_queue.put(info)

      # Re-raise the exception in addition to reporting it to the parent
      # process, so that even if the `--test_timeout` flag is set and the
      # error doesn't make it to be shown in the parent process before bazel's
      # timeout, the log would still show what happened in this subprocess,
      # instead of silently suppressing the error due to early bazel
      # timeout. Raising an error in the subprocess produces a stack trace in
      # the log, but the program continues running.
      if not info.is_successful:
        six.reraise(*info.exc_info)

      self._close_streaming()

    # Exit with code 0 as it's considered successful exit at this point.
    sys.exit(0)


# Active MultiProcessPoolRunners. We need to shut them down when the program
# exits, and this is done by setting the `tearDownModule` of the module
# containing `__main__`. Note this is set in both the parent process and the
# subprocesses.
_active_pool_runners = weakref.WeakSet()


def _shutdown_all_pool_runners():
  for pool in _active_pool_runners:
    pool.shutdown()


def is_oss():
  """Returns whether the test is run under OSS."""
  return len(sys.argv) >= 1 and 'bazel' in sys.argv[0]


class MultiProcessPoolRunner(object):
  """A utility class to start a process pool to simulate a cluster.

  It's similar to MultiProcessRunner, but uses a pool of processes to avoid
  the expensive initialization cost of TensorFlow.
  """

  def __init__(self, cluster_spec, initializer=None, share_gpu=True):
    """Creates a multi-process pool runner.

    Args:
      cluster_spec: Dict for cluster spec. The following is an example of a
        cluster with three workers.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"]}
      initializer: a callable to be called at the startup of worker processes.
      share_gpu: Whether to share GPUs among workers. If False, each worker is
        assigned different GPUs in a round-robin fashion.

    Raises:
      RuntimeError: if `multi_process_runner.test_main()` is not called.
      ValueError: if there are more than one chief in the `cluster_spec`.
    """
    _active_pool_runners.add(self)
    self._cluster_spec = cluster_spec
    self._initializer = initializer
    self._share_gpu = share_gpu
    self._conn = {}
    self._runner = None

  def __del__(self):
    self.shutdown()

  def shutdown(self):
    """Shuts down the worker pool."""
    for conn in self._conn.values():
      conn.close()
    self._conn = {}
    if self._runner is not None:
      try:
        self._runner.join()
      except Exception as e:  # pylint: disable=broad-except
        logging.error(
            'Ignoring exception when shutting down MultiProcessPoolRunner: %s',
            e)
      self._runner = None

  def _start(self):
    """Starts the worker pool."""
    # We need different arguments for different processes so we're passing a
    # no-op fn here and use start_single_process instead.

    if dill is None:
      raise unittest.SkipTest(
          'TODO(b/150264776): Resolve dependency issue in CI')

    self._runner = MultiProcessRunner(
        fn=lambda: None,
        cluster_spec=self._cluster_spec,
        use_dill_for_args=False,
        share_gpu=self._share_gpu)
    if self._initializer:
      initializer = dill.dumps(self._initializer, dill.HIGHEST_PROTOCOL)
    else:
      initializer = None
    for task_type, addresses in self._cluster_spec.items():
      for task_id, _ in enumerate(addresses):
        conn1, conn2 = multiprocessing.Pipe(duplex=True)
        self._conn[(task_type, task_id)] = conn1
        self._runner.start_single_process(
            task_type,
            task_id,
            fn=_pool_runner_worker,
            args=(task_type, task_id, initializer, conn2))

  def run(self, fn, args=None, kwargs=None):
    """Runs `fn` with `args` and `kwargs` on all jobs.

    Args:
      fn: The function to be run.
      args: Optional positional arguments to be supplied to `fn`.
      kwargs: Optional keyword arguments to be supplied to `fn`.

    Returns:
      A list of return values.
    """
    _check_initialization()
    # TODO(b/150264776): skip in OSS until it's implemented.
    multi_process_lib.Process()
    if self._runner is None:
      self._start()

    fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL)
    for conn in self._conn.values():
      conn.send((fn, args or [], kwargs or {}))

    process_statuses = []
    for (task_type, task_id), conn in self._conn.items():
      logging.info('Waiting for the result from %s-%d', task_type, task_id)
      try:
        process_statuses.append(conn.recv())
      except EOFError:
        # This shouldn't happen due to exceptions in fn. This usually
        # means bugs in the runner.
        self.shutdown()
        raise RuntimeError('Unexpected EOF. Worker process may have died. '
                           'Please report a bug.')

    return_values = []
    for process_status in process_statuses:
      assert isinstance(process_status, _ProcessStatusInfo)
      if not process_status.is_successful:
        six.reraise(*process_status.exc_info)
      if process_status.return_value is not None:
        return_values.append(process_status.return_value)

    return return_values
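

# A minimal usage sketch for MultiProcessPoolRunner (hypothetical `fn` and
# `other_fn`; assumes the enclosing test module calls
# `multi_process_runner.test_main()`):
#
#   pool = MultiProcessPoolRunner(
#       multi_worker_test_base.create_cluster_spec(num_workers=2))
#   return_values = pool.run(fn)  # workers stay alive across run() calls
#   return_values = pool.run(other_fn, args=(1, 2))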


def _pool_runner_worker(task_type, task_id, initializer, conn):
  """Function that runs on the workers in a pool.

  It listens for callables to run and returns the results until `conn` is
  closed. It captures the exceptions raised while executing the callable and
  returns them through `conn`.

  Args:
    task_type: the task type.
    task_id: the task index.
    initializer: a callable to execute during startup.
    conn: a multiprocessing.Connection object to listen for tasks and send
      results.
  """
  if initializer:
    initializer = dill.loads(initializer)
    initializer()
  while True:
    try:
      fn, args, kwargs = conn.recv()
    except EOFError:
      break
    fn = dill.loads(fn)
    info = _run_contained(task_type, task_id, fn, args, kwargs)
    sys.stdout.flush()
    sys.stderr.flush()
    conn.send(info)
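

# A hedged sketch of the wire protocol between `MultiProcessPoolRunner.run`
# and `_pool_runner_worker`, mirroring the recv/send pair above:
#
#   payload = (dill.dumps(fn, dill.HIGHEST_PROTOCOL), args or [], kwargs or {})
#   conn.send(payload)    # received by `conn.recv()` in the worker loop
#   status = conn.recv()  # a `_ProcessStatusInfo` sent back via conn.send(info)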


def _run_contained(task_type, task_id, fn, args, kwargs):
  """Runs `fn` with `args` and `kwargs`.

  The function returns a _ProcessStatusInfo which captures the return value
  and the exception.

  Args:
    task_type: the task type.
    task_id: the task index.
    fn: the function to be run.
    args: optional positional arguments to be supplied to `fn`.
    kwargs: optional keyword arguments to be supplied to `fn`.

  Returns:
    a _ProcessStatusInfo.

  """
  is_successful = False
  return_value = None
  exc_info = None
  try:
    return_value = fn(*args, **kwargs)
    is_successful = True
    return _ProcessStatusInfo(
        task_type=task_type,
        task_id=task_id,
        is_successful=is_successful,
        exc_info=exc_info,
        return_value=return_value)

  # If `fn` ends up exiting with `sys.exit()`, the `SystemExit` is not
  # handled here.
  except Exception:  # pylint: disable=broad-except
    exc_info = sys.exc_info()
    return _ProcessStatusInfo(
        task_type=task_type,
        task_id=task_id,
        is_successful=is_successful,
        exc_info=exc_info,
        return_value=return_value)


@tf_export('__internal__.distribute.multi_process_runner'
           '.SubprocessTimeoutError',
           v1=[])
class SubprocessTimeoutError(RuntimeError):
  """An error that indicates there is at least one subprocess timing out.

  When this is raised, a namedtuple object representing the multi-process run
  result can be retrieved by
  `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`'s
  `mpr_result` attribute. See
  `tf.__internal__.distribute.multi_process_runner.run` for more information.
  """

  def __init__(self, msg, mpr_result):
    super(SubprocessTimeoutError, self).__init__(msg)
    self.mpr_result = mpr_result


@tf_export('__internal__.distribute.multi_process_runner'
           '.UnexpectedSubprocessExitError',
           v1=[])
class UnexpectedSubprocessExitError(RuntimeError):
  """An error indicating there is at least one subprocess with unexpected exit.

  When this is raised, a namedtuple object representing the multi-process run
  result can be retrieved by
  `tf.__internal__.distribute.multi_process_runner
  .UnexpectedSubprocessExitError`'s
  `mpr_result` attribute. See
  `tf.__internal__.distribute.multi_process_runner.run` for more information.
  """

  def __init__(self, msg, mpr_result):
    super(UnexpectedSubprocessExitError, self).__init__(msg)
    self.mpr_result = mpr_result


@tf_export(
    '__internal__.distribute.multi_process_runner.NotInitializedError', v1=[])
class NotInitializedError(RuntimeError):
  """An error indicating `multi_process_runner.run` is used without init.

  When this is raised, user is supposed to call
  `tf.__internal__.distribute.multi_process_runner.test_main()` within
  `if __name__ == '__main__':` block to properly initialize
  `multi_process_runner.run`.
  """
  pass


def _check_initialization():
  if not multi_process_lib.initialized():
    raise NotInitializedError(
        '`multi_process_runner` is not initialized. '
        'Please call `tf.__internal__.distribute.multi_process_runner.'
        'test_main()` within `if __name__ == \'__main__\':` block '
        'in your python module to properly initialize '
        '`multi_process_runner`.')


def _set_tf_config(task_type, task_id, cluster_spec, rpc_layer=None):
  """Set TF_CONFIG environment variable."""
  tf_config_dict = {
      'cluster': cluster_spec,
      'task': {
          'type': task_type,
          'index': task_id,
      },
  }
  if rpc_layer is not None:
    tf_config_dict['rpc_layer'] = rpc_layer
  os.environ['TF_CONFIG'] = json.dumps(tf_config_dict)
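

# For illustration (a worked example, not from the original source):
# `_set_tf_config('worker', 1, {'worker': ['host0:2222', 'host1:2222']},
# rpc_layer='grpc')` leaves this JSON in `os.environ['TF_CONFIG']`:
#
#   {"cluster": {"worker": ["host0:2222", "host1:2222"]},
#    "task": {"type": "worker", "index": 1},
#    "rpc_layer": "grpc"}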


@tf_export('__internal__.distribute.multi_process_runner.run', v1=[])
def run(fn,
        cluster_spec,
        rpc_layer=None,
        max_run_time=None,
        return_output=False,
        timeout=_DEFAULT_TIMEOUT_SEC,
        args=None,
        kwargs=None):
  """Run `fn` in multiple processes according to `cluster_spec`.

  Given a callable `fn`, `tf.__internal__.distribute.multi_process_runner.run`
  launches multiple processes, each of which runs `fn`. These processes are
  referred to as "subprocesses" or "child processes". Each of those
  subprocesses will have their `TF_CONFIG` environment variable set, according
  to `cluster_spec` and their task types. The stdout of the subprocesses is
  streamed to the main process' log and is thus available there (if
  `stream_output` is True), with a [type-id] prefix.

  `tf.__internal__.distribute.multi_process_runner.run` will block until all
  subprocesses have successfully exited, and return a namedtuple object that
  represents the run result. This object has a `return_value` attribute, which
  is a list that contains the return values of the subprocesses' `fn`, for
  those subprocesses that successfully returned from `fn`. The order of the
  `return_value` list is not meaningful. If an optional arg `return_output`
  (defaults to False) is set to True, the namedtuple object will have an
  additional attribute `stdout`, which is a list containing the stdout of the
  subprocesses. If any subprocess' `fn` ends up raising an error, that error
  will be reraised from
  `tf.__internal__.distribute.multi_process_runner.run`, and the
  aforementioned namedtuple object will be available through the exception's
  `mpr_result` attribute.

  This utility is used for simulating running TensorFlow programs across
  multiple task types, and each of the task types may contain more than one
  task (except for "chief" where more than one task is prohibited). Test
  coverage of multi-worker training is the main application of this utility,
  where code written for multi-worker training can be realistically covered
  in unit tests.

  Any test module that uses
  `tf.__internal__.distribute.multi_process_runner.run()` must call
  `tf.__internal__.distribute.multi_process_runner.test_main()` instead of
  regular `test.main()` inside `if __name__ == '__main__':` block for proper
  initialization.

  Args:
    fn: Function to be run on child processes. This will be run on processes
      for all task types.
    cluster_spec: Dict for cluster spec. The utility function
      `tf.__internal__.distribute.multi_process_runner.create_cluster_spec`
      can be conveniently used to create such dict. The following is an
      example of a cluster with three workers and two ps's.
      {"worker": ["worker0.example.com:2222",
                  "worker1.example.com:2222",
                  "worker2.example.com:2222"],
       "ps": ["ps0.example.com:2222",
              "ps1.example.com:2222"]}
    rpc_layer: RPC layer to use. Default value is 'grpc'.
    max_run_time: `None` or integer. If not `None`, child processes are forced
      to exit at approximately this many seconds after this utility is called.
      We achieve this through the `signal.alarm()` API. Note that this is best
      effort at the Python level, since the Python signal handler does not get
      executed while lower-level C/C++ code is running; it can therefore be
      delayed for an arbitrarily long time. If any of the child processes are
      still running when `max_run_time` is up, they will be force-terminated
      and a `tf.__internal__.distribute.multi_process_runner
      .UnexpectedSubprocessExitError`
      may be raised. If `None`, child processes are not forced to exit.
    return_output: If True, the output/error from the subprocesses should be
      collected to be attached to the resulting namedtuple returned from this
      utility. The list of output can be retrieved via the `stdout` attribute.
      Defaults to False.
    timeout: optional integer or `None`. If provided as an integer, and not
      all processes report status within roughly `timeout` seconds, a
      `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`
      exception will be raised. If `None`,
      `tf.__internal__.distribute.multi_process_runner.run` never times out.
      Defaults to the constant `_DEFAULT_TIMEOUT_SEC` defined in the
      `multi_process_runner` module.
    args: Positional arguments to be sent to `fn` run on subprocesses.
    kwargs: Keyword arguments to be sent to `fn` run on subprocesses.

  Returns:
    A namedtuple object, which has two attributes,
    `return_value` and `stdout`. `return_value` always contains a list of
    return values from the subprocesses, although the order is not meaningful.
    If the `return_output` argument is True, `stdout` is available, containing
    a list of all messages from the subprocesses' stdout and stderr, and the
    order is mostly chronological.

  Raises:
    RuntimeError: if
      `tf.__internal__.distribute.multi_process_runner.test_main()` is
      not called in the test's `if __name__ == '__main__':` block.
    ValueError: if there are more than one chief in the `cluster_spec`.
    tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError: if
      not all processes report status approximately
      within `timeout` seconds. When this is raised, a
      namedtuple object can be retrieved by
      `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`'s
      `mpr_result` attribute, which has the same
      structure as the above 'Returns' section describes.
    tf.__internal__.distribute.multi_process_runner
    .UnexpectedSubprocessExitError:
      If any of the subprocesses did not exit
      properly (for example, they exit on SIGTERM or SIGKILL signal). When
      this is raised, a namedtuple object can be retrieved by
      `tf.__internal__.distribute.multi_process_runner
      .UnexpectedSubprocessExitError`'s
      `mpr_result` attribute, which has the
      same structure as the above 'Returns' section describes. If
      `max_run_time` is not `None`, it is expected that some subprocesses may
      be force-killed when `max_run_time` is up, and this is raised in those
      cases.
    Exception: if there is an Exception propagated from any subprocess. When
      this is raised, a namedtuple object can be retrieved by
      `tf.__internal__.distribute.multi_process_runner
      .UnexpectedSubprocessExitError`'s
      `mpr_result` attribute, which has the
      same structure as the above 'Returns' section describes.

  Examples:

  ```python
  class SimpleMultiProcessTest(tf.test.TestCase):

    def test_simple_printing_and_return(self):

      def fn():
        resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

        # This will print "[chief-0]: Task type: chief , task id: 0"
        # for chief, for example.
        logging.info('Task type: %s, task id: %d',
                     resolver.task_type, resolver.task_id)

        return resolver.task_type

      result = tf.__internal__.distribute.multi_process_runner.run(
          fn=fn,
          cluster_spec=(
              tf.__internal__
              .distribute.multi_process_runner.create_cluster_spec(
                  has_chief=True, num_workers=2)))
      assert sorted(result.return_value) == ['chief', 'worker', 'worker']

    def test_error_from_fn(self):

      def fn():
        resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
        raise ValueError('Task type {}, task id {} errors out'.format(
            resolver.task_type, resolver.task_id))

      with self.assertRaisesRegexp(ValueError,
                                   'Task type worker, task id 0 errors out'):
        cluster_spec = (
            tf.__internal__.distribute.multi_process_runner
            .create_cluster_spec(num_workers=1))
        tf.__internal__.distribute.multi_process_runner.run(
            fn=fn, cluster_spec=cluster_spec)


  if __name__ == '__main__':
    tf.__internal__.distribute.multi_process_runner.test_main()
  ```
  """
  runner = MultiProcessRunner(
      fn,
      cluster_spec,
      rpc_layer,
      max_run_time=max_run_time,
      return_output=return_output,
      args=args,
      kwargs=kwargs)
  runner.start()
  return runner.join(timeout)


# This is set by MultiProcessRunner in worker processes.
_barrier = None


@tf_export('__internal__.distribute.multi_process_runner.get_barrier', v1=[])
def get_barrier():
  """Returns a `multiprocessing.Barrier` for `multi_process_runner.run`.

  `tf.__internal__.distribute.multi_process_runner.get_barrier()` returns
  a `multiprocessing.Barrier` object which can be used within `fn` of
  `tf.__internal__.distribute.multi_process_runner` to wait with
  `barrier.wait()` call until all other tasks have also reached the
  `barrier.wait()` call, before they can proceed individually.

  Note that all tasks (subprocesses) have to reach `barrier.wait()` call to
  proceed. Currently it is not supported to block on only a subset of tasks
  in the cluster.

  Example:
  ```python

  def fn():
    some_work_to_be_done_by_all_tasks()

    tf.__internal__.distribute.multi_process_runner.get_barrier().wait()

    # The barrier guarantees that at this point, all tasks have finished
    # `some_work_to_be_done_by_all_tasks()`
    some_other_work_to_be_done_by_all_tasks()

  result = tf.__internal__.distribute.multi_process_runner.run(
      fn=fn,
      cluster_spec=(
          tf.__internal__
          .distribute.multi_process_runner.create_cluster_spec(
              num_workers=2)))
  ```

  Returns:
    A `multiprocessing.Barrier` for `multi_process_runner.run`.
  """
  if _barrier is None:
    raise ValueError(
        'barrier is not defined. It is likely because you are calling '
        'get_barrier() in the main process. get_barrier() can only be called '
        'in the subprocesses.'
    )
  return _barrier


_manager = None
_manager_lock = threading.Lock()


def manager():
  """Returns the multiprocessing manager object for concurrency tools.

  The manager object is useful as it controls a server process that holds
  the python objects that can be shared across processes. This can be used
  for parent-subprocess communication:

  ```python
  manager = multi_process_runner.manager()
  some_event_happening_in_subprocess = manager.Event()
  mpr = multi_process_runner.MultiProcessRunner(fn, cluster_spec,
      args=(some_event_happening_in_subprocess,))
  mpr.start()
  some_event_happening_in_subprocess.wait()
  # Do something that should only happen after some event occurs in the
  # subprocess.
  ```

  Note that the user of multi_process_runner should not create additional
  `multiprocessing.Manager()` objects; doing so can result in segfault in
  some cases.

  This method should only be called after multi_process_runner.test_main() is
  called.
  """
  _check_initialization()
  global _manager
  with _manager_lock:
    if _manager is None:
      _manager = multiprocessing.Manager()
    return _manager


@tf_export('__internal__.distribute.multi_process_runner.test_main', v1=[])
def test_main():
  """Main function to be called within `__main__` of a test file.

  Any test module that uses
  `tf.__internal__.distribute.multi_process_runner.run()`
  must call this instead of regular `test.main()` inside
  `if __name__ == '__main__':` block, or an error will be raised when
  `tf.__internal__.distribute.multi_process_runner.run()` is used. This
  method takes care of the initialization needed for launching multiple
  subprocesses.

  Example:
  ```python
  class MyTestClass(tf.test.TestCase):
    def testSomething(self):
      # Testing code making use of
      # `tf.__internal__.distribute.multi_process_runner.run()`.

  if __name__ == '__main__':
    tf.__internal__.distribute.multi_process_runner.test_main()
  ```
  """
  # Inject tearDownModule() to shut down all pool runners. Active pool runners
  # will block the program from exiting. This is necessary for global pool
  # runners. We tried atexit in the past, and it doesn't work in some
  # deployments.
  old_tear_down_module = getattr(sys.modules['__main__'], 'tearDownModule',
                                 None)

  def tear_down_module():
    _shutdown_all_pool_runners()
    if old_tear_down_module is not None:
      old_tear_down_module()

  setattr(sys.modules['__main__'], 'tearDownModule', tear_down_module)
  multi_process_lib.test_main()