Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/multi_process_lib.py: 42%
73 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library for multi-process testing."""
17import multiprocessing
18import os
19import platform
20import sys
21import unittest
22from absl import app
23from absl import logging
25from tensorflow.python.eager import test
def is_oss():
  """Returns whether the test is run under OSS.

  The check is a heuristic on the program name: the run is treated as OSS
  when `sys.argv[0]` contains the substring 'bazel'.

  Returns:
    True if `sys.argv` is non-empty and its first entry contains 'bazel',
    False otherwise.
  """
  return bool(sys.argv) and 'bazel' in sys.argv[0]
def _is_enabled():
  """Returns whether the forkserver-based multi-process support is enabled.

  Returns:
    False when any `--tpu*` command-line argument is present under OSS, when
    running on Python 3.8 on Linux (TODO(b/171242147)), or when
    `sys.platform` is 'win32'; True otherwise.
  """
  # Note that flags may not be parsed at this point and simply importing the
  # flags module causes a variety of unusual errors.
  tpu_args = [arg for arg in sys.argv if arg.startswith('--tpu')]
  if tpu_args and is_oss():
    return False
  # `sys.version_info` has five fields (major, minor, micro, releaselevel,
  # serial), so comparing the whole tuple to `(3, 8)` is never true. Compare
  # only the (major, minor) prefix.
  if sys.version_info[:2] == (3, 8) and platform.system() == 'Linux':
    return False  # TODO(b/171242147)
  return sys.platform != 'win32'
class _AbslProcess:
  """A process that runs using absl.app.run.

  Written as a mixin: `__init__` forwards all arguments to the next class in
  the MRO, then swaps the instance's `run` for a wrapper that invokes the
  original `run` through `app.run`.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Monkey-patch that is carried over into the spawned process by pickle.
    self._run_impl = self.run
    self.run = self._run_with_absl

  def _run_with_absl(self):
    """Runs the original `run` implementation under `app.run`."""
    # `app.run` passes one positional argument to its callable; it is ignored.
    app.run(lambda _: self._run_impl())
if _is_enabled():

  class AbslForkServerProcess(_AbslProcess,
                              multiprocessing.context.ForkServerProcess):
    """An absl-compatible Forkserver process.

    Note: Forkserver is not available in windows.
    """

  class AbslForkServerContext(multiprocessing.context.ForkServerContext):
    """A forkserver context whose `Process` is `AbslForkServerProcess`."""
    _name = 'absl_forkserver'
    Process = AbslForkServerProcess  # pylint: disable=invalid-name

  # From this point on, the module-level name `multiprocessing` is rebound to
  # the context instance and no longer refers to the stdlib module.
  multiprocessing = AbslForkServerContext()
  Process = multiprocessing.Process

else:

  class Process(object):
    """A process that skips test (until windows is supported)."""

    def __init__(self, *args, **kwargs):
      # Arguments are accepted for signature compatibility and discarded.
      del args, kwargs
      raise unittest.SkipTest(
          'TODO(b/150264776): Windows is not supported in MultiProcessRunner.')
# Set to True by `test_main()`; reported by `initialized()`.
_test_main_called = False
def _set_spawn_exe_path():
  """Set the path to the executable for spawned processes.

  This utility searches for the binary the parent process is using, and sets
  the executable of multiprocessing's context accordingly. When `sys.argv[0]`
  ends with '.py', the binary path is guessed from the bazel output layout
  and the `TEST_TARGET` environment variable, and `sys.argv[0]` is
  overwritten with the guessed path.

  Raises:
    RuntimeError: If the binary path cannot be determined.
  """
  # TODO(b/150264776): This does not work with Windows. Find a solution.
  if sys.argv[0].endswith('.py'):

    def guess_path(package_root):
      """Returns the guessed executable path under `package_root`, or None."""
      script_path = sys.argv[0]
      # If all we have is a python module path, we'll need to make a guess for
      # the actual executable path.
      if 'bazel-out' not in script_path or package_root not in script_path:
        return None
      test_target = os.environ.get('TEST_TARGET')
      if not test_target:
        # Without a target label there is nothing to guess from. Returning
        # None lets the caller raise the documented RuntimeError instead of a
        # KeyError, and avoids resolving an empty label to a directory.
        return None
      # Guess the binary path under bazel. For target
      # //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
      # argv[0] is in the form of
      # /.../tensorflow/python/distribute/input_lib_test.py
      # and the binary is
      # /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
      package_root_base = script_path[:script_path.rfind(package_root)]
      # Drop the leading '//' and turn the first ':' into a path separator.
      binary = test_target[2:].replace(':', '/', 1)
      possible_path = os.path.join(package_root_base, package_root, binary)
      logging.info('Guessed test binary path: %s', possible_path)
      if os.access(possible_path, os.X_OK):
        return possible_path
      return None

    path = guess_path('org_tensorflow') or guess_path('org_keras')
    if path is None:
      logging.error(
          'Cannot determine binary path. sys.argv[0]=%s os.environ=%s',
          sys.argv[0], os.environ)
      raise RuntimeError('Cannot determine binary path')
    sys.argv[0] = path
  # Note that this sets the executable for *all* contexts.
  multiprocessing.get_context().set_executable(sys.argv[0])
def _if_spawn_run_and_exit():
  """If spawned process, run requested spawn task and exit. Else a no-op.

  A process counts as spawned when `sys.argv` contains '-c' (after the
  program name) immediately followed by a command that starts with
  'from multiprocessing.'. In that case `sys.argv` is truncated to the
  program name, the command is executed, and the process exits with status 0.
  """

  # `multiprocessing` module passes a script "from multiprocessing.x import y"
  # to subprocess, followed by a main function call. We use this to tell if
  # the process is spawned. Examples of x are "forkserver" or
  # "semaphore_tracker".
  try:
    # Search from index 1 so the program name itself is never mistaken for
    # the flag.
    flag_index = sys.argv.index('-c', 1)
  except ValueError:
    return
  cmd_index = flag_index + 1
  # A trailing '-c' with no command after it is not a spawn request.
  if cmd_index >= len(sys.argv):
    return
  cmd = sys.argv[cmd_index]
  if not cmd.startswith('from multiprocessing.'):
    return

  # As a subprocess, we disregard all other interpreter command line
  # arguments.
  sys.argv = sys.argv[0:1]

  # Run the specified command - this is expected to be one of:
  # 1. Spawn the process for semaphore tracker.
  # 2. Spawn the initial process for forkserver.
  # 3. Spawn any process as requested by the "spawn" method.
  exec(cmd)  # pylint: disable=exec-used
  sys.exit(0)  # Semaphore tracker doesn't explicitly sys.exit.
def test_main():
  """Main function to be called within `__main__` of a test file.

  Marks the module as initialized, sets `TF_FORCE_GPU_ALLOW_GROWTH=true` in
  the environment, and, when multi-process support is enabled, configures the
  spawn executable and runs the spawn task if this process was spawned by
  `multiprocessing`. Otherwise hands control to `test.main()`.
  """
  global _test_main_called
  _test_main_called = True

  os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

  if _is_enabled():
    _set_spawn_exe_path()
    # Exits the process here when it is a spawned child.
    _if_spawn_run_and_exit()

  # Only runs test.main() if not spawned process.
  test.main()
def initialized():
  """Returns whether the module is initialized.

  Returns:
    The module-level `_test_main_called` flag, which `test_main()` sets to
    True when it is called.
  """
  return _test_main_called