Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/multi_process_lib.py: 42%

73 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Library for multi-process testing.""" 

16 

17import multiprocessing 

18import os 

19import platform 

20import sys 

21import unittest 

22from absl import app 

23from absl import logging 

24 

25from tensorflow.python.eager import test 

26 

27 

def is_oss():
  """Returns whether the test is run under OSS.

  The check looks only at the program path: the run counts as OSS when
  `sys.argv[0]` contains the substring 'bazel'. An empty `sys.argv` is
  reported as not OSS.
  """
  argv = sys.argv
  if not argv:
    return False
  return 'bazel' in argv[0]

31 

32 

def _is_enabled():
  """Returns whether multi-process testing is enabled in this environment.

  Multi-process testing is disabled when any of the following holds:
    * the test runs under OSS (see `is_oss`) with a `--tpu*` command-line arg;
    * the interpreter is Python 3.8 on Linux (TODO(b/171242147));
    * the platform is Windows (`sys.platform == 'win32'`).
  """
  # Note that flags may not be parsed at this point and simply importing the
  # flags module causes a variety of unusual errors, so sys.argv is inspected
  # directly.
  tpu_args = [arg for arg in sys.argv if arg.startswith('--tpu')]
  if is_oss() and tpu_args:
    return False
  # `sys.version_info` has five fields (major, minor, micro, releaselevel,
  # serial), so comparing the whole thing against (3, 8) can never be True.
  # Compare only (major, minor) so the intended 3.8/Linux exclusion applies.
  if sys.version_info[:2] == (3, 8) and platform.system() == 'Linux':
    return False  # TODO(b/171242147)
  return sys.platform != 'win32'

42 

43 

class _AbslProcess:
  """Mixin that makes a process run its `run` method via `absl.app.run`.

  Meant to be listed ahead of a `multiprocessing` process class in the bases.
  At construction time the instance's bound `run` is stashed as `_run_impl`,
  and `run` is replaced by `_run_with_absl`.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Both values are stored as instance attributes, so this monkey-patch is
    # carried over into the spawned process by pickle.
    self._run_impl = self.run
    self.run = self._run_with_absl

  def _run_with_absl(self):
    """Calls the original `run` as the main function of `absl.app.run`."""

    def _absl_main(unused_argv):
      return self._run_impl()

    app.run(_absl_main)

55 

56 

if not _is_enabled():

  class Process(object):
    """A process that skips test (until windows is supported).

    This placeholder is used whenever `_is_enabled()` returns False, so it
    also covers the non-Windows cases that function excludes. Constructing it
    raises `unittest.SkipTest`.
    """

    def __init__(self, *args, **kwargs):
      del args, kwargs
      raise unittest.SkipTest(
          'TODO(b/150264776): Windows is not supported in MultiProcessRunner.')

else:

  class AbslForkServerProcess(_AbslProcess,
                              multiprocessing.context.ForkServerProcess):
    """An absl-compatible Forkserver process.

    Note: Forkserver is not available in windows.
    """

  class AbslForkServerContext(multiprocessing.context.ForkServerContext):
    """Forkserver context whose `Process` is `AbslForkServerProcess`."""
    _name = 'absl_forkserver'
    Process = AbslForkServerProcess  # pylint: disable=invalid-name

  # Note: this rebinds the module-level name `multiprocessing` from the stdlib
  # module to an `AbslForkServerContext` instance.
  multiprocessing = AbslForkServerContext()
  Process = multiprocessing.Process

# Set to True by `test_main()` and reported by `initialized()`.
_test_main_called = False

85 

86 

def _set_spawn_exe_path():
  """Set the path to the executable for spawned processes.

  This utility searches for the binary the parent process is using, and sets
  the executable of multiprocessing's context accordingly. When `sys.argv[0]`
  is a `.py` file, the binary path is guessed from the bazel output layout and
  the `TEST_TARGET` environment variable, and `sys.argv[0]` is rewritten to
  the guessed path.

  Raises:
    RuntimeError: If the binary path cannot be determined.
  """
  # TODO(b/150264776): This does not work with Windows. Find a solution.
  if sys.argv[0].endswith('.py'):

    def guess_path(package_root):
      """Returns an executable binary path under `package_root`, or None."""
      # If all we have is a python module path, we'll need to make a guess for
      # the actual executable path.
      script = sys.argv[0]
      if 'bazel-out' not in script or package_root not in script:
        return None
      # TEST_TARGET may be unset; treat that as "cannot guess" so the caller
      # raises the documented RuntimeError instead of leaking a KeyError.
      test_target = os.environ.get('TEST_TARGET')
      if not test_target:
        return None
      # Guess the binary path under bazel. For target
      # //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
      # argv[0] is in the form of
      # /.../tensorflow/python/distribute/input_lib_test.py
      # and the binary is
      # /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
      package_root_base = script[:script.rfind(package_root)]
      # Drop the leading '//' of the label and turn the first ':' into '/'.
      binary = test_target[2:].replace(':', '/', 1)
      possible_path = os.path.join(package_root_base, package_root, binary)
      logging.info('Guessed test binary path: %s', possible_path)
      if os.access(possible_path, os.X_OK):
        return possible_path
      return None

    path = guess_path('org_tensorflow') or guess_path('org_keras')
    if not path:
      logging.error(
          'Cannot determine binary path. sys.argv[0]=%s os.environ=%s',
          sys.argv[0], os.environ)
      raise RuntimeError('Cannot determine binary path')
    sys.argv[0] = path
  # Note that this sets the executable for *all* contexts.
  multiprocessing.get_context().set_executable(sys.argv[0])

127 

128 

def _if_spawn_run_and_exit():
  """If spawned process, run requested spawn task and exit. Else a no-op.

  A process is treated as spawned when some argument after `sys.argv[0]` is
  '-c' and the argument immediately following it starts with
  'from multiprocessing.'. In that case `sys.argv` is truncated to its first
  element, the command is executed with `exec`, and the process exits with
  status 0.
  """
  # `multiprocessing` module passes a script "from multiprocessing.x import y"
  # to subprocess, followed by a main function call. We use this to tell if
  # the process is spawned. Examples of x are "forkserver" or
  # "semaphore_tracker".
  interpreter_args = sys.argv[1:]
  if '-c' not in interpreter_args:
    return
  cmd_index = interpreter_args.index('-c') + 1
  # A trailing '-c' with nothing after it cannot be a spawn command.
  if cmd_index >= len(interpreter_args):
    return
  cmd = interpreter_args[cmd_index]
  if not cmd.startswith('from multiprocessing.'):
    return

  # As a subprocess, we disregard all other interpreter command line
  # arguments.
  sys.argv = sys.argv[0:1]

  # Run the specified command - this is expected to be one of:
  # 1. Spawn the process for semaphore tracker.
  # 2. Spawn the initial process for forkserver.
  # 3. Spawn any process as requested by the "spawn" method.
  # NOTE: `cmd` comes straight from this process' command line; only the
  # 'from multiprocessing.' prefix is checked before it is executed.
  exec(cmd)  # pylint: disable=exec-used
  sys.exit(0)  # Semaphore tracker doesn't explicitly sys.exit.

153 

154 

def test_main():
  """Entry point to call from the `__main__` block of a test file.

  Marks the module as initialized (see `initialized()`) and sets
  `TF_FORCE_GPU_ALLOW_GROWTH=true` in the environment. When multi-process
  testing is enabled, it also configures the spawn executable. It then runs
  the requested task and exits if this process is a multiprocessing spawn
  child. Otherwise it hands control to `test.main()`.
  """
  global _test_main_called
  _test_main_called = True

  os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

  multi_process_enabled = _is_enabled()
  if multi_process_enabled:
    _set_spawn_exe_path()
    # A spawned child exits inside this call, so `test.main()` below only
    # runs in a non-spawned process.
    _if_spawn_run_and_exit()

  test.main()

168 

169 

def initialized():
  """Reports whether `test_main()` has run in this process.

  Returns:
    The module-level `_test_main_called` flag, which `test_main()` sets to
    True.
  """
  was_called = _test_main_called
  return was_called