Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_run.py: 18%

242 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Class MirroredStrategy implementing tf.distribute.Strategy.""" 

16 

17import contextlib 

18import threading 

19import weakref 

20 

21from tensorflow.python import pywrap_tfe 

22from tensorflow.python.autograph.core import ag_ctx as autograph_ctx 

23from tensorflow.python.autograph.impl import api as autograph 

24from tensorflow.python.distribute import distribute_lib 

25from tensorflow.python.distribute import distribute_utils 

26from tensorflow.python.distribute import shared_variable_creator 

27from tensorflow.python.eager import context 

28from tensorflow.python.eager import def_function 

29from tensorflow.python.framework import device as tf_device 

30from tensorflow.python.framework import ops 

31from tensorflow.python.ops import summary_ops_v2 

32from tensorflow.python.ops import variable_scope 

33from tensorflow.python.platform import tf_logging as logging 

34from tensorflow.python.training import coordinator 

35from tensorflow.python.util import traceback_utils 

36 

37 

def _is_gpu_device(device):
  """Returns whether the device string `device` parses to a "GPU" device."""
  spec = tf_device.DeviceSpec.from_string(device)
  return spec.device_type == "GPU"

40 

41 

def call_for_each_replica(strategy, fn, args=None, kwargs=None):
  """Invoke `fn` once per replica (worker device) of `strategy`.

  Wrapping the call to this function in a `tf.function` is strongly
  recommended; running it eagerly is slow.

  Args:
    strategy: `tf.distribute.Strategy`.
    fn: function to call on each worker device.
    args: positional arguments to `fn`; `None` means none.
    kwargs: keyword arguments to `fn`; `None` means none.

  Returns:
    Wrapped returned value of `fn` from all replicas.
  """
  args = () if args is None else args
  kwargs = {} if kwargs is None else kwargs

  if not isinstance(fn, def_function.Function):
    # Plain Python callable: run it directly on the replica threads.
    if context.executing_eagerly():
      logging.log_first_n(
          logging.WARN, "Using %s eagerly has significant "
          "overhead currently. We will be working on improving "
          "this in the future, but for now please wrap "
          "`call_for_each_replica` or `experimental_run` or "
          "`run` inside a tf.function to get "
          "the best performance." % strategy.__class__.__name__, 5)
    else:
      # We get here, among other ways, from `wrapped_fn` below. AutoGraph
      # stops conversion at _call_for_each_replica itself (TF library
      # functions are allowlisted), so convert the user's Python function
      # explicitly to make sure it is still transformed.
      fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
    return _call_for_each_replica(strategy, fn, args, kwargs)

  # From here on `fn` is a tf.function.
  if fn._jit_compile:  # pylint: disable=protected-access
    # An XLA-compiled function running on GPUs only keeps its tf.function
    # decoration: collectives do the cross-device communication, so no
    # merge_call is in the path.
    gpu_flags = [_is_gpu_device(d) for d in strategy.extended.worker_devices]
    if all(gpu_flags):
      return _call_for_each_replica(strategy, fn, args, kwargs)

  per_strategy_cache = _cfer_fn_cache.get(strategy)
  if per_strategy_cache is None:
    per_strategy_cache = weakref.WeakKeyDictionary()
    _cfer_fn_cache[strategy] = per_strategy_cache

  wrapped = per_strategy_cache.get(fn)
  if wrapped is None:
    # Lift the tf.function out: the clone's Python function triggers
    # _call_for_each_replica inside the tf.function. _clone() is used rather
    # than a fresh @tf.function around call_for_each_replica() so that the
    # arguments given to fn's @tf.function decorator are retained.
    def wrapped_fn(*args, **kwargs):
      return call_for_each_replica(strategy, fn.python_function, args, kwargs)

    wrapped = fn._clone(  # pylint: disable=protected-access
        python_function=wrapped_fn)
    per_strategy_cache[fn] = wrapped
  return wrapped(*args, **kwargs)

103 

104 

# Per-strategy cache of the tf.function wrappers built by
# `call_for_each_replica`. The outer WeakKeyDictionary is keyed by strategy;
# each value is an inner WeakKeyDictionary that maps the original
# `def_function.Function` to its clone whose Python function runs
# `call_for_each_replica`.
_cfer_fn_cache = weakref.WeakKeyDictionary()

107 

108 

@contextlib.contextmanager
def _enter_graph(g, eager, creator_stack=None):
  """Makes `g` the default graph, optionally in eager mode, for the block.

  Args:
    g: graph to install as the default via `g.as_default()`.
    eager: if true, `context.eager_mode()` is additionally entered, after the
      graph has been made the default.
    creator_stack: when not None, assigned to `g._variable_creator_stack` once
      the scopes have been entered. The previous value is not restored on
      exit.

  Yields:
    None.
  """
  with contextlib.ExitStack() as scopes:
    scopes.enter_context(g.as_default())
    if eager:
      scopes.enter_context(context.eager_mode())
    if creator_stack is not None:
      g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
    yield

122 

123 

@contextlib.contextmanager
def _maybe_enter_eager_mode(eager):
  """Runs the managed block under `context.eager_mode()` iff `eager` is true."""
  if not eager:
    # Nothing to enter; just run the block.
    yield
    return
  with context.eager_mode():
    yield

131 

132 

def _cpu_device(device):
  """Returns `device` re-targeted to CPU index 0, as a device string.

  Only the device type and device index of the parsed spec are replaced.
  """
  spec = tf_device.DeviceSpec.from_string(device)
  return spec.replace(device_type="CPU", device_index=0).to_string()

137 

138 

class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
  """Signals that a replica thread should unwind because a stop was requested.

  Passed to the `Coordinator` as a clean-stop exception type.
  """

141 

142 

def _get_thread_local_configuration_callable():
  """Returns a set holding the callable that replays traceback filtering.

  The set contains `enable_traceback_filtering` when filtering is currently
  enabled in the calling thread and `disable_traceback_filtering` otherwise.
  """
  filtering_enabled = traceback_utils.is_traceback_filtering_enabled()
  replay = (
      traceback_utils.enable_traceback_filtering
      if filtering_enabled else traceback_utils.disable_traceback_filtering)
  return {replay}

149 

150 

def _call_for_each_replica(distribution, fn, args, kwargs):
  """Run `fn` in separate threads, once per replica/worker device.

  One `_MirroredReplicaThread` is created per entry of
  `distribution.extended.worker_devices`. The calling thread then drives the
  replica threads: it resumes them, waits for each to pause (either because
  `fn` returned or because it reached a `merge_call`), and, when all of them
  paused in a `merge_call`, runs the merge function itself and hands each
  thread its share of the result.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per replica, each in its own thread).
    args: positional arguments for `fn`
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas, or `None` when the
    coordinator reports `should_stop()` while the replicas are being driven.

  Raises:
    RuntimeError: If fn() calls get_replica_context().merge_call() a different
        number of times on different replicas, i.e. some replicas finished
        while others are still paused in a `merge_call`. The error is raised
        inside `coord.stop_on_exception()`; presumably it reaches the caller
        through `coord.join` -- confirm against `Coordinator`.
  """
  # TODO(josh11b): Add this option once we add synchronization to variable
  # creation.  Until then, this is pretty unsafe to use.
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()

  # `_RequestedStop` is how replica threads unwind after a stop request, so it
  # is registered as a clean-stop exception type.
  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  # One store is shared by the variable creator functions of all replicas.
  shared_variable_store = {}
  devices = distribution.extended.worker_devices

  thread_local_callables = _get_thread_local_configuration_callable()

  # TODO(isaprykin): Create these threads once instead of during every call.
  threads = []
  for index in range(len(devices)):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = _MirroredReplicaThread(distribution, coord, index, devices,
                               variable_creator_fn, fn,
                               distribute_utils.caching_scope_local,
                               distribute_utils.select_replica(index, args),
                               distribute_utils.select_replica(index, kwargs),
                               thread_local_callables)
    threads.append(t)

  # Started threads block on `should_run` until the loop below resumes them.
  for t in threads:
    t.start()

  # When `fn` starts `should_run` event is set on _MirroredReplicaThread
  # (`MRT`) threads. The execution waits until
  # `MRT.has_paused` is set, which indicates that either `fn` is
  # complete or a `get_replica_context().merge_call()` is called.  If `fn` is
  # complete, then `MRT.done` is set to True.  Otherwise, arguments
  # of `get_replica_context().merge_call` from all paused threads are grouped
  # and the `merge_fn` is performed.  Results of the
  # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
  # Each such `get_replica_context().merge_call` call returns the
  # `MRT.merge_result` for that thread when `MRT.should_run` event
  # is set again. Execution of `fn` resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          # Resume every replica first, then wait for all of them to pause.
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          # Resume one replica at a time and wait for it to pause before
          # moving on to the next one.
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "replica_context().merge_call() calls.")
          # get_replica_context().merge_call() case
          merge_args = distribute_utils.regroup(
              tuple(t.merge_args for t in threads))
          merge_kwargs = distribute_utils.regroup(
              tuple(t.merge_kwargs for t in threads))
          # We capture the name_scope of the MRT when we call merge_fn
          # to ensure that if we have opened a name scope in the MRT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MRT and assume it is
          # the same for all other MRTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          mtt_captured_var_scope = threads[0].captured_var_scope
          # Capture and merge the control dependencies from all the threads.
          mtt_captured_control_deps = set()
          for t in threads:
            mtt_captured_control_deps.update(t.captured_control_deps)

          # Control is transferred from _MirroredReplicaThread (MRT) to the
          # main thread, i.e., here, to perform `merge_fn`, and thus we
          # preserve the name scope, control dependencies, etc. from MRT at
          # the time `merge_call` is made.
          # One special case is that the `merge_call` is made under an
          # `tf.init_scope` in the MRT. `tf.init_scope` will clear control
          # dependencies, pause gradient tape, and enter the lowest context on
          # the `context_stack` that is not building a graph function. Entering
          # the lowest context could be one of the two things: installation of a
          # graph as the default graph or switch into eager mode. If the former
          # is done and causes `merge_call` to be called in a different graph
          # from the one in which `call_for_each_replica` is called, we do not
          # allow this case (see comment in `_merge_call`) and we would not have
          # arrived here due to the check in `_merge_call`. However, if the
          # latter is done, we want to make sure the main thread enters an
          # eager mode scope as well so that `merge_fn` does not have trouble
          # accessing resources defined in MRT under the same context.
          with ops.name_scope(
              mtt_captured_name_scope), ops.control_dependencies(
                  mtt_captured_control_deps), variable_scope.variable_scope(
                      mtt_captured_var_scope), _maybe_enter_eager_mode(
                          threads[0].merge_call_entered_in_eager):
            # The merge function of the first replica is used for all replicas.
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          # Give each replica its own component of the merged result; the
          # next loop iteration resumes the threads.
          for r, t in enumerate(threads):
            t.merge_result = distribute_utils.select_replica(r, merge_result)
  finally:
    # Wake every replica thread so that none stays blocked on `should_run`,
    # then wait for all of them through the coordinator.
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return distribute_utils.regroup(tuple(t.main_result for t in threads))

287 

288 

class _MirroredReplicaThread(threading.Thread):
  """A thread that runs `fn` for one replica on that replica's device.

  The thread and the thread that created it hand control back and forth with
  two events: the creator sets `should_run` to let this thread proceed, and
  this thread sets `has_paused` when `fn` has returned (or failed) or when it
  has reached a `merge_call`.
  """

  def __init__(self, dist, coord, replica_id, devices, variable_creator_fn, fn,
               caching_scope, args, kwargs, thread_local_callables=None):
    """Captures the creating thread's contexts and the work to run.

    Args:
      dist: the distribution strategy; passed to `_MirroredReplicaContext`.
      coord: coordinator used for stop requests and exception reporting.
      replica_id: index of this replica; selects `devices[replica_id]`.
      devices: device strings, indexed by replica id.
      variable_creator_fn: creator installed via `variable_creator_scope`
        while `fn` runs.
      fn: the function to run in this thread.
      caching_scope: object whose `new_cache_scope_count` and
        `cache_scope_exited_count` are snapshotted when present.
      args: positional arguments for `fn`.
      kwargs: keyword arguments for `fn`.
      thread_local_callables: optional iterable of zero-argument callables
        invoked at the start of `run()`.
    """
    super().__init__()
    self.coord = coord
    self.distribution = dist
    self.devices = devices
    self.replica_id = replica_id
    self.replica_id_in_sync_group = (
        dist.extended._get_replica_id_in_sync_group(replica_id))  # pylint: disable=protected-access

    self.variable_creator_fn = variable_creator_fn
    # State needed to run and return the results of `fn`.
    self.main_fn = fn
    self.main_args = args
    self.main_kwargs = kwargs
    self.main_result = None
    self.done = False
    # State needed to run the next merge_call() (if any) requested via
    # ReplicaContext.
    self.merge_fn = None
    self.merge_args = None
    self.merge_kwargs = None
    self.merge_result = None
    self.captured_name_scope = None
    self.captured_var_scope = None
    # The two fields below are filled in by `_merge_call` and read by the
    # coordinating thread. Define them up front, like the fields above, so
    # they always exist. `captured_control_deps` is fed to `set.update`, hence
    # an empty iterable rather than None.
    self.captured_control_deps = ()
    self.merge_call_entered_in_eager = None
    try:
      self.caching_scope_entered = caching_scope.new_cache_scope_count
      self.caching_scope_exited = caching_scope.cache_scope_exited_count
    except AttributeError:
      # The counters have not been set on `caching_scope`; nothing to carry
      # over into this thread.
      self.caching_scope_entered = None
      self.caching_scope_exited = None

    # We use a thread.Event for the main thread to signal when this
    # thread should start running (`should_run`), and another for
    # this thread to transfer control back to the main thread
    # (`has_paused`, either when it gets to a
    # `get_replica_context().merge_call` or when `fn` returns). In
    # either case the event starts cleared, is signaled by calling
    # set(). The receiving thread waits for the signal by calling
    # wait() and then immediately clearing the event using clear().
    self.should_run = threading.Event()
    self.has_paused = threading.Event()
    # These fields have to do with inheriting various contexts from the
    # parent thread:
    context.ensure_initialized()
    ctx = context.context()
    self.in_eager = ctx.executing_eagerly()
    self.record_thread_local_summary_state()
    self.record_thread_local_eager_context_state()
    self.context_device_policy = (
        pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
            ctx._context_handle))  # pylint: disable=protected-access
    self.graph = ops.get_default_graph()
    # Also remember which graph / eager mode `init_scope` resolves to.
    with ops.init_scope():
      self._init_in_eager = context.executing_eagerly()
      self._init_graph = ops.get_default_graph()
    # Copy the stack so that this thread gets its own list.
    self._variable_creator_stack = self.graph._variable_creator_stack[:]  # pylint: disable=protected-access
    self._var_scope = variable_scope.get_variable_scope()
    # Adding a "/" at end lets us re-enter this scope later.
    self._name_scope = self.graph.get_name_scope()
    if self._name_scope:
      self._name_scope += "/"
    if self.replica_id > 0:
      # Every replica but the first gets a "replica_<id>/" name scope suffix.
      if not self._name_scope:
        self._name_scope = ""
      self._name_scope += "replica_%d/" % self.replica_id

    self._thread_local_callables = thread_local_callables

  def run(self):
    """Waits to be resumed, then runs `main_fn` under the captured contexts.

    `has_paused` is always set on the way out, whether `main_fn` returned,
    raised, or was never run because a stop had already been requested.
    `done` becomes True only when `main_fn` returned normally.
    """
    self.should_run.wait()
    self.should_run.clear()
    try:
      if self.coord.should_stop():
        return
      self.restore_thread_local_summary_state()
      self.restore_thread_local_callable()
      self.restore_thread_local_eager_context_state()
      if (self.caching_scope_entered is not None and
          self.caching_scope_exited is not None):
        distribute_utils.caching_scope_local.new_cache_scope_count = self.caching_scope_entered
        distribute_utils.caching_scope_local.cache_scope_exited_count = self.caching_scope_exited
      # TODO(josh11b): Use current logical device instead of 0 here.
      with self.coord.stop_on_exception(), \
          _enter_graph(self._init_graph, self._init_in_eager), \
          _enter_graph(self.graph, self.in_eager,
                       self._variable_creator_stack), \
          context.device_policy(self.context_device_policy), \
          _MirroredReplicaContext(self.distribution,
                                  self.replica_id_in_sync_group), \
          ops.device(self.devices[self.replica_id]), \
          ops.name_scope(self._name_scope), \
          variable_scope.variable_scope(
              self._var_scope, reuse=self.replica_id > 0), \
          variable_scope.variable_creator_scope(self.variable_creator_fn):
        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
        self.done = True
    finally:
      # Hand control back to the waiting thread in every case.
      self.has_paused.set()

  def record_thread_local_summary_state(self):
    """Record the thread local summary state in self."""
    # TODO(slebedev): is this still relevant? the referenced bug is closed.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    self._summary_step = summary_state.step
    self._summary_writer = summary_state.writer
    self._summary_recording = summary_state.is_recording
    self._summary_recording_distribution_strategy = (
        summary_state.is_recording_distribution_strategy)

  def restore_thread_local_summary_state(self):
    """Restore thread local summary state from self."""
    # TODO(slebedev): is this still relevant? the referenced bug is closed.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.step = self._summary_step
    summary_state.writer = self._summary_writer
    summary_state.is_recording = self._summary_recording
    summary_state.is_recording_distribution_strategy = (
        self._summary_recording_distribution_strategy)

  def record_thread_local_eager_context_state(self):
    """Record the thread local eager context `op_callbacks` in self."""
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    self._eager_context_op_callbacks = eager_context_state.op_callbacks
    # TODO(b/125892694): record other fields in EagerContext.

  def restore_thread_local_eager_context_state(self):
    """Restore the thread local eager context `op_callbacks` from self."""
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    eager_context_state.op_callbacks = self._eager_context_op_callbacks
    # TODO(b/125892694): restore other fields in EagerContext.

  def restore_thread_local_callable(self):
    """Invoke every configured thread-local callable (none when unset)."""
    for fn in self._thread_local_callables or ():
      fn()

428 

429 

class _MirroredReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext for synchronized replica."""

  def _merge_call(self, fn, args, kwargs):
    """`merge_call()` implementation for synchronized replica.

    This pauses the current replica thread and passes `fn` and its arguments to
    the main thread. The main thread will wait until all replicas pause, then
    invoke `fn` with grouped arguments. The current replica thread will continue
    after `fn` completes.

    See `_call_for_each_replica` for the logic in the main thread.

    Args:
      fn: a function that is called in cross replica context with grouped
        arguments from each replica. `fn` should return grouped values.
      args: positional arguments to `fn`.
      kwargs: keyword arguments to `fn`.

    Returns:
      Return value of `fn` for the current replica.

    Raises:
      RuntimeError: when merge_call happens in a different graph, e.g. in a
        different tf.function, which is not supported now.
      _RequestedStop: when stop is requested.

    """
    t = threading.current_thread()
    assert isinstance(t, _MirroredReplicaThread)
    # Publish the merge request on the thread object, where the main thread
    # reads it.
    t.merge_fn = fn
    t.merge_args = args
    t.merge_kwargs = kwargs
    t.captured_name_scope = t.graph.get_name_scope()
    # Adding a "/" at end lets us re-enter this scope later.
    if t.captured_name_scope:
      t.captured_name_scope += "/"

    t.captured_var_scope = variable_scope.get_variable_scope()
    t.captured_control_deps = t.graph._current_control_dependencies()  # pylint: disable=protected-access

    t.merge_call_entered_in_eager = context.context().executing_eagerly()

    # It is problematic if `merge_call` is called under a different graph other
    # than the one that `_call_for_each_replica` is called under, there are
    # 3 cases this can happen:
    #
    # 1. The `fn` passed to `_call_for_each_replica` is decorated with
    # `tf.function` and there is a `merge_call` in `fn`. Since
    # MirroredStrategy traces a separate function per thread (per device),
    # and each trace takes a shared lock, the lock is never released by the
    # first thread and subsequent replica threads cannot proceed to trace
    # their own functions. This issue is addressed by always converting
    # `_call_for_each_replica(tf.function(f))` to
    # `tf.function(_call_for_each_replica(f))` in
    # `MirroredStrategy._call_for_each_replica`.
    #
    # 2. The `fn` passed to `_call_for_each_replica` contains a nested
    # `tf.function`, and there is a `merge_call` in the nested `tf.function`.
    # In this case each thread can successfully trace its own function, but
    # since the `merge_fn` passed to `merge_call` is executed in the main
    # thread (where `_call_for_each_replica` is executed), it can't access
    # the tensors that come from different graphs.
    #
    # 3. The `fn` passed to `_call_for_each_replica` contains a control-flow
    # statement, and there is a `merge_call` inside the control-flow body,
    # `fn` or `_call_for_each_replica` is decorated with `tf.function`.
    # Control flow statement creates a separate graph for its body, similar
    # to #2, `merge_fn` executed in the main thread can't access the
    # tensors that come from different graphs.
    #
    # We raise an error for #2 and #3.
    if ops.get_default_graph() != t.graph:
      raise RuntimeError(
          "`merge_call` called while defining a new graph or a tf.function."
          " This can often happen if the function `fn` passed to"
          " `strategy.run()` contains a nested `@tf.function`, and the nested "
          "`@tf.function` contains a synchronization point, such as aggregating"
          " gradients (e.g, optimizer.apply_gradients), or if the function `fn`"
          " uses a control flow statement which contains a synchronization"
          " point in the body. Such behaviors are not yet supported. Instead,"
          " please avoid nested `tf.function`s or control flow statements that"
          " may potentially cross a synchronization boundary, for example,"
          " wrap the `fn` passed to `strategy.run` or the entire `strategy.run`"
          " inside a `tf.function` or move the control flow out of `fn`. If"
          " you are subclassing a `tf.keras.Model`, please avoid decorating"
          " overridden methods `test_step` and `train_step` in `tf.function`.")

    # Hand control to the main thread and block until it resumes this thread.
    t.has_paused.set()
    t.should_run.wait()
    t.should_run.clear()
    # The thread may have been woken only to shut down; unwind in that case.
    if t.coord.should_stop():
      raise _RequestedStop()
    t.merge_call_entered_in_eager = None
    return t.merge_result

  @property
  def devices(self):
    """Single-element list with this replica's `worker_devices_by_replica` entry.

    `distribute_lib.require_replica_context(self)` is called first.
    """
    distribute_lib.require_replica_context(self)
    return [
        self._strategy.extended.worker_devices_by_replica[
            self._replica_id_in_sync_group]
    ]

530 self._strategy.extended.worker_devices_by_replica[ 

531 self._replica_id_in_sync_group] 

532 ]