Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_strategy_util.py: 21%

124 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""TPU specific APIs to be used in conjunction with TPU Strategy.""" 

16 

17import gc 

18 

19from tensorflow.core.protobuf import config_pb2 

20from tensorflow.python.client import session as session_lib 

21from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver 

22from tensorflow.python.eager import context 

23from tensorflow.python.eager import monitoring 

24from tensorflow.python.eager.def_function import function 

25from tensorflow.python.eager.def_function import functions_run_eagerly 

26from tensorflow.python.eager.def_function import run_functions_eagerly 

27from tensorflow.python.framework import device 

28from tensorflow.python.framework import errors 

29from tensorflow.python.framework import ops 

30from tensorflow.python.platform import tf_logging as logging 

31from tensorflow.python.tpu import topology 

32from tensorflow.python.tpu import tpu 

33from tensorflow.python.util import compat 

34from tensorflow.python.util.tf_export import tf_export 

35 

36 

37_INITIALIZED_TPU_SYSTEMS = {} 

38_LOCAL_MASTERS = ("", "local") 

39 

40 

41_tpu_worker_address = monitoring.StringGauge( 

42 "/tensorflow/tpu/worker_address", 

43 "The worker address that the coordinator/client connects to.", "address") 

44 

45 

46@tf_export("tpu.experimental.initialize_tpu_system") 

47def initialize_tpu_system(cluster_resolver=None): 

48 """Initialize the TPU devices. 

49 

50 Args: 

51 cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, 

52 which provides information about the TPU cluster. 

53 Returns: 

54 The tf.tpu.Topology object for the topology of the TPU cluster. If called 

55 inside tf.function, it returns the serialized topology object instead. 

56 

57 Raises: 

58 RuntimeError: If running inside a tf.function. 

59 NotFoundError: If no TPU devices found in eager mode. 

60 """ 

61 

62 # Deallocate all TPU buffers by clearing out eager context caches and

63 # triggering garbage collection, to avoid keeping invalid TPU buffers around

64 # after the TPU system is reinitialized.

65 logging.info("Deallocate tpu buffers before initializing tpu system.") 

66 context.context()._clear_caches() # pylint: disable=protected-access 

67 context.context().clear_kernel_cache() 

68 gc.collect() 

69 

70 job = None 

71 if cluster_resolver is None: 

72 # If no cluster resolver is specified and we are running eagerly, execute the

73 # init ops in the current device scope.

74 if context.executing_eagerly(): 

75 curr_device = device.DeviceSpec.from_string(context.context().device_name) 

76 if curr_device.job is not None: 

77 job = "{}/replica:0/task:0".format(curr_device.job) 

78 

79 cluster_resolver = TPUClusterResolver("") 

80 assert isinstance(cluster_resolver, TPUClusterResolver) 

81 

82 tpu_name = compat.as_text(cluster_resolver._tpu) # pylint: disable=protected-access 

83 if tpu_name in _INITIALIZED_TPU_SYSTEMS: 

84 logging.warning( 

85 "TPU system %s has already been initialized. " 

86 "Reinitializing the TPU can cause previously created " 

87 "variables on TPU to be lost.", tpu_name) 

88 

89 logging.info("Initializing the TPU system: %s", tpu_name) 

90 

91 # This function is written the way it is for the following non-intuitive reason:

92 # tpu.initialize_system creates a dummy op whose sole purpose is to trigger 

93 # DistributedTPURewritePass. This pass actually adds real ops that 

94 # initialize the TPU system. Thus, we can't simply run tpu.initialize_system 

95 # eagerly. We need to wrap it in defun and trigger the rewrite passes on it. 

96 if tpu_name not in _LOCAL_MASTERS: 

97 # Explicitly place the tpu.initialize_system op on the first worker to

98 # avoid the error where the output node matches multiple devices.

99 job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name()) 

100 

101 if context.executing_eagerly(): 

102 @function(autograph=False) 

103 def _tpu_init_fn(): 

104 # In TF1, we usually close chips when compilation fails, to clear the data

105 # in infeed. In TF2, we don't need to do this because infeed is no longer

106 # used, so users can recover from TPU compilation failures more smoothly.

107 # Same for the cancellation of a TPU execution.

108 return tpu.initialize_system( 

109 job=job, 

110 compilation_failure_closes_chips=False, 

111 tpu_cancellation_closes_chips=False) 

112 

113 # The TPU_SYSTEM device must match the device used in tpu.initialize_system 

114 # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM 

115 # devices available. 

116 run_eagerly = functions_run_eagerly() 

117 if run_eagerly: 

118 logging.warning( 

119 "It looks like tf.function behavior was disabled, perhaps using" 

120 " tf.config.run_functions_eagerly." 

121 " tf.tpu.experimental.initialize_tpu_system requires tf.function to" 

122 " work. This primitive will override the disable." 

123 ) 

124 run_functions_eagerly(False) 

125 try: 

126 with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access 

127 output = _tpu_init_fn() 

128 context.async_wait() 

129 except errors.InvalidArgumentError as e: 

130 raise errors.NotFoundError( 

131 None, None, 

132 "TPUs not found in the cluster. Failed in initialization: " 

133 + str(e)) 

134 finally: 

135 if run_eagerly is not None: 

136 run_functions_eagerly(run_eagerly) 

137 # Re-initialize the eager context's logical devices, since the previous TPU state is invalid now.

138 context.context()._initialize_logical_devices() # pylint: disable=protected-access 

139 

140 serialized_topology = output.numpy() 

141 elif not ops.executing_eagerly_outside_functions(): 

142 master = cluster_resolver.master() 

143 cluster_spec = cluster_resolver.cluster_spec() 

144 

145 session_config = config_pb2.ConfigProto(allow_soft_placement=True) 

146 if cluster_spec: 

147 session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) 

148 

149 with ops.Graph().as_default(): 

150 with session_lib.Session(config=session_config, target=master) as sess: 

151 serialized_topology = sess.run(tpu.initialize_system()) 

152 else: 

153 with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access 

154 serialized_topology = tpu.initialize_system( 

155 job=job, compilation_failure_closes_chips=False) 

156 # If initialize_tpu_system is called inside tf.function, we only return 

157 # the serialized topology object as the tf.tpu.Topology object has to be 

158 # constructed in eager mode. 

159 return serialized_topology 

160 

161 logging.info("Finished initializing TPU system.") 

162 tpu_topology = topology.Topology(serialized=serialized_topology) 

163 cluster_resolver.set_tpu_topology(serialized_topology) 

164 _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology 

165 

166 # Record the address of the TPU worker-0 that the coordinator connects to. 

167 # This can be used to associate the TPU worker with the right coordinator when 

168 # aggregating the metrics for the application. An example of the address: 

169 # /bns/mb/borg/mb/bns/chienchunh/chienchunh_group_49640234.1.tfm_train_tpu_worker/0 

170 _tpu_worker_address.get_cell("address").set(cluster_resolver.get_master()) 

171 

172 return tpu_topology 

173 

174 
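
# Example usage (a minimal sketch of a typical TPUStrategy setup; the tpu=""
# argument assumes a local or Colab-style TPU and would be replaced by a real
# TPU name or address in other deployments):
#
#   import tensorflow as tf
#
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
#   tf.config.experimental_connect_to_cluster(resolver)
#   topology = tf.tpu.experimental.initialize_tpu_system(resolver)
#   print("TPU mesh shape:", topology.mesh_shape)  # Topology object in eager mode
#   strategy = tf.distribute.TPUStrategy(resolver)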

175def get_initialized_tpu_systems(): 

176 """Returns all currently initialized TPU systems.

177 

178 Returns:

179 A dictionary keyed by TPU name, with the TPU topology as the value.

180 """ 

181 return _INITIALIZED_TPU_SYSTEMS.copy() 

182 

183 

184@tf_export("tpu.experimental.shutdown_tpu_system") 

185def shutdown_tpu_system(cluster_resolver=None): 

186 """Shuts down the TPU devices. 

187 

188 This will clear all caches, even those that are maintained through sequential 

189 calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation 

190 cache. 

191 

192 Args: 

193 cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, 

194 which provides information about the TPU cluster. 

195 

196 Raises: 

197 RuntimeError: If no TPU devices found for eager execution or if run in a 

198 tf.function. 

199 """ 

200 job = None 

201 if cluster_resolver is None: 

202 # If no cluster resolver is specified and we are running eagerly, execute the

203 # shutdown ops in the current device scope.

204 if context.executing_eagerly(): 

205 curr_device = device.DeviceSpec.from_string(context.context().device_name) 

206 if curr_device.job is not None: 

207 job = "{}/replica:0/task:0".format(curr_device.job) 

208 

209 cluster_resolver = TPUClusterResolver("") 

210 assert isinstance(cluster_resolver, TPUClusterResolver) 

211 

212 tpu_name = compat.as_text(cluster_resolver._tpu) # pylint: disable=protected-access 

213 if tpu_name not in _INITIALIZED_TPU_SYSTEMS: 

214 logging.warning("You are shutting down a TPU system %s that has not been "

215 "initialized.", tpu_name)

216 

217 logging.info("Shutting down the TPU system: %s", tpu_name) 

218 

219 if context.executing_eagerly(): 

220 # This function is written the way it is for the following non-intuitive reason:

221 # tpu.shutdown_system creates a dummy op whose sole purpose is to trigger 

222 # DistributedTPURewritePass. This pass actually adds real ops that 

223 # shut down the TPU system. Thus, we can't simply run tpu.shutdown_system

224 # eagerly. We need to wrap it in defun and trigger the rewrite passes on it. 

225 if tpu_name not in _LOCAL_MASTERS: 

226 # Explicitly place the tpu.shutdown_system op on the first worker to

227 # avoid the error where the output node matches multiple devices.

228 job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name()) 

229 

230 @function(autograph=False) 

231 def _tpu_shutdown_fn(): 

232 tpu.shutdown_system(job=job) 

233 

234 # The TPU_SYSTEM device must match the device used in tpu.shutdown_system 

235 # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM 

236 # devices available. 

237 run_eagerly = functions_run_eagerly() 

238 if run_eagerly: 

239 logging.warning( 

240 "It looks like tf.function behavior was disabled, perhaps using" 

241 " tf.config.run_functions_eagerly." 

242 " tf.tpu.experimental.shutdown_tpu_system requires tf.function to" 

243 " work. This primitive will override the disable." 

244 ) 

245 run_functions_eagerly(False) 

246 try: 

247 with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access 

248 _tpu_shutdown_fn() 

249 finally: 

250 if run_eagerly is not None: 

251 run_functions_eagerly(run_eagerly) 

252 

253 # Clear out the eager context caches since the memory is invalid now. 

254 logging.info("Clearing out eager caches") 

255 context.context()._clear_caches() # pylint: disable=protected-access 

256 context.context().clear_kernel_cache() 

257 elif not ops.executing_eagerly_outside_functions(): 

258 master = cluster_resolver.master() 

259 cluster_spec = cluster_resolver.cluster_spec() 

260 

261 session_config = config_pb2.ConfigProto(allow_soft_placement=True) 

262 if cluster_spec: 

263 session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) 

264 

265 with ops.Graph().as_default(): 

266 with session_lib.Session(config=session_config, target=master) as sess: 

267 sess.run(tpu.shutdown_system()) 

268 else: 

269 raise RuntimeError( 

270 "shutdown_tpu_system is not supported within "

271 "tf.function. You should call shutdown_tpu_system outside of your tf.function."

272 ) 

273 

274 logging.info("Finished shutting down TPU system.") 

275 if tpu_name in _INITIALIZED_TPU_SYSTEMS: 

276 del _INITIALIZED_TPU_SYSTEMS[tpu_name]
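
# Example of tearing down the TPU system between runs (a minimal sketch; it
# assumes the resolver used for initialization is still valid and that the
# calls happen in eager mode, outside any tf.function):
#
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
#   tf.tpu.experimental.initialize_tpu_system(resolver)
#   ...  # run a workload
#   tf.tpu.experimental.shutdown_tpu_system(resolver)    # clears caches and TPU state
#   tf.tpu.experimental.initialize_tpu_system(resolver)  # fresh system for the next run
#
# get_initialized_tpu_systems() (a module-level helper, not exported under tf.*)
# reflects this bookkeeping: a TPU name appears in the returned dictionary only
# while the corresponding system is initialized.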