Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_run.py: 18%
242 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Runs a function once per replica in coordinated threads (used for mirrored execution)."""
17import contextlib
18import threading
19import weakref
21from tensorflow.python import pywrap_tfe
22from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
23from tensorflow.python.autograph.impl import api as autograph
24from tensorflow.python.distribute import distribute_lib
25from tensorflow.python.distribute import distribute_utils
26from tensorflow.python.distribute import shared_variable_creator
27from tensorflow.python.eager import context
28from tensorflow.python.eager import def_function
29from tensorflow.python.framework import device as tf_device
30from tensorflow.python.framework import ops
31from tensorflow.python.ops import summary_ops_v2
32from tensorflow.python.ops import variable_scope
33from tensorflow.python.platform import tf_logging as logging
34from tensorflow.python.training import coordinator
35from tensorflow.python.util import traceback_utils
def _is_gpu_device(device):
  """Returns True if the device string `device` parses to device type "GPU"."""
  spec = tf_device.DeviceSpec.from_string(device)
  return spec.device_type == "GPU"
def call_for_each_replica(strategy, fn, args=None, kwargs=None):
  """Invokes `fn` once for every worker device (replica) of `strategy`.

  Wrapping the call to this function inside a `tf.function` is strongly
  recommended; without it the performance is poor.

  Args:
    strategy: `tf.distribute.Strategy`.
    fn: function to call on each worker device.
    args: positional arguments to `fn`; `None` means no positional arguments.
    kwargs: keyword arguments to `fn`; `None` means no keyword arguments.

  Returns:
    Wrapped returned value of `fn` from all replicas.
  """
  args = () if args is None else args
  kwargs = {} if kwargs is None else kwargs

  if not isinstance(fn, def_function.Function):
    # Plain Python callable: run it directly on the replica threads.
    if context.executing_eagerly():
      strategy_name = strategy.__class__.__name__
      logging.log_first_n(
          logging.WARN,
          "Using %s eagerly has significant "
          "overhead currently. We will be working on improving "
          "this in the future, but for now please wrap "
          "`call_for_each_replica` or `experimental_run` or "
          "`run` inside a tf.function to get "
          "the best performance." % strategy_name, 5)
    else:
      # When a tf.function is wrapped to trigger _call_for_each_replica (see
      # the tf.function handling below), AutoGraph stops conversion at
      # _call_for_each_replica itself (TF library functions are allowlisted).
      # Converting here makes sure the Python function originally passed to
      # the tf.function is still converted.
      fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
    return _call_for_each_replica(strategy, fn, args, kwargs)

  # From here on `fn` is a tf.function.
  if fn._jit_compile:  # pylint: disable=protected-access
    # Keep the tf.function decoration in place when `fn` is compiled with XLA
    # and every worker device is a GPU. Collectives handle the cross-device
    # communication in that case, so no merge_call is in the path.
    gpu_flags = [_is_gpu_device(d) for d in strategy.extended.worker_devices]
    if all(gpu_flags):
      return _call_for_each_replica(strategy, fn, args, kwargs)

  per_strategy_cache = _cfer_fn_cache.get(strategy)
  if per_strategy_cache is None:
    per_strategy_cache = weakref.WeakKeyDictionary()
    _cfer_fn_cache[strategy] = per_strategy_cache

  wrapped = per_strategy_cache.get(fn)
  if wrapped is None:
    # `fn` has to be wrapped so that _call_for_each_replica is triggered
    # inside the tf.function. _clone() is used rather than a @tf.function
    # wrapped call_for_each_replica() so that the arguments given to the
    # @tf.function decorator of `fn` are retained.
    def wrapped_fn(*args, **kwargs):
      return call_for_each_replica(strategy, fn.python_function, args, kwargs)

    wrapped = fn._clone(  # pylint: disable=protected-access
        python_function=wrapped_fn)
    per_strategy_cache[fn] = wrapped
  return wrapped(*args, **kwargs)
# Per-strategy cache of the `def_function.Function` clones created by
# `call_for_each_replica`. The outer mapping is keyed weakly by strategy; each
# value is itself a `WeakKeyDictionary` mapping the original tf.function to its
# clone.
_cfer_fn_cache = weakref.WeakKeyDictionary()
@contextlib.contextmanager
def _enter_graph(g, eager, creator_stack=None):
  """Context manager that makes `g` the default graph, optionally eagerly.

  Args:
    g: graph whose `as_default()` scope is entered.
    eager: if truthy, `context.eager_mode()` is entered inside the graph scope.
    creator_stack: if not None, assigned to `g._variable_creator_stack` after
      the scopes are entered. The previous value is not restored on exit.

  Yields:
    None.
  """
  with contextlib.ExitStack() as scopes:
    scopes.enter_context(g.as_default())
    if eager:
      scopes.enter_context(context.eager_mode())
    if creator_stack is not None:
      g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
    yield
@contextlib.contextmanager
def _maybe_enter_eager_mode(eager):
  """Runs the `with` body under `context.eager_mode()` iff `eager` is truthy."""
  scope = context.eager_mode() if eager else contextlib.nullcontext()
  with scope:
    yield
def _cpu_device(device):
  """Returns the string form of `device` with type "CPU" and device index 0."""
  spec = tf_device.DeviceSpec.from_string(device)
  return spec.replace(device_type="CPU", device_index=0).to_string()
class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
  """Raised inside a replica thread to unwind it once a stop was requested.

  `_call_for_each_replica` registers this type with its `Coordinator` as a
  clean-stop exception type.
  """
def _get_thread_local_configuration_callable():
  """Returns a one-element set with the traceback-filtering toggle to replay.

  The element is `traceback_utils.enable_traceback_filtering` when filtering
  is currently enabled, and `traceback_utils.disable_traceback_filtering`
  otherwise.
  """
  if traceback_utils.is_traceback_filtering_enabled():
    return {traceback_utils.enable_traceback_filtering}
  return {traceback_utils.disable_traceback_filtering}
def _call_for_each_replica(distribution, fn, args, kwargs):
  """Run `fn` in separate threads, once per replica/worker device.

  One `_MirroredReplicaThread` is created per entry of
  `distribution.extended.worker_devices`. Because `run_concurrently` is
  hard-coded to False, the main thread releases the replica threads one at a
  time and waits for each to pause before releasing the next. When every
  replica has paused inside `merge_call()`, the main thread runs the merge
  function itself and hands each replica its slice of the result.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per replica, each in its own thread).
    args: positional arguments for `fn`
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas, or None if the
    coordinator requested a stop and `coord.join` did not raise.

  Raises:
    RuntimeError: If fn() calls get_replica_context().merge_call() a different
      number of times from the available devices.
  """
  # TODO(josh11b): Add this option once we add synchronization to variable
  # creation. Until then, this is pretty unsafe to use.
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()

  # `_RequestedStop` is how a paused replica unwinds after a stop; it is
  # registered as a clean-stop type rather than treated as a failure.
  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  # A single store is passed to every replica's variable creator function.
  shared_variable_store = {}
  devices = distribution.extended.worker_devices

  thread_local_callables = _get_thread_local_configuration_callable()

  # TODO(isaprykin): Create these threads once instead of during every call.
  threads = []
  for index in range(len(devices)):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = _MirroredReplicaThread(distribution, coord, index, devices,
                               variable_creator_fn, fn,
                               distribute_utils.caching_scope_local,
                               distribute_utils.select_replica(index, args),
                               distribute_utils.select_replica(index, kwargs),
                               thread_local_callables)
    threads.append(t)

  # Started threads immediately block on their `should_run` event.
  for t in threads:
    t.start()

  # When `fn` starts `should_run` event is set on _MirroredReplicaThread
  # (`MRT`) threads. The execution waits until
  # `MRT.has_paused` is set, which indicates that either `fn` is
  # complete or a `get_replica_context().merge_call()` is called. If `fn` is
  # complete, then `MRT.done` is set to True. Otherwise, arguments
  # of `get_replica_context().merge_call` from all paused threads are grouped
  # and the `merge_fn` is performed. Results of the
  # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
  # Each such `get_replica_context().merge_call` call returns the
  # `MRT.merge_result` for that thread when `MRT.should_run` event
  # is reset again. Execution of `fn` resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        # One entry per replica: True if its `fn` returned during this round.
        done = []
        if run_concurrently:
          # Release every replica first, then wait for each to pause.
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          # Release one replica and wait for it to pause before the next one.
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          # Some replica paused in merge_call(); if another one already
          # finished, the replicas disagree on the number of merge calls.
          if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "replica_context().merge_call() calls.")
          # get_replica_context().merge_call() case
          merge_args = distribute_utils.regroup(
              tuple(t.merge_args for t in threads))
          merge_kwargs = distribute_utils.regroup(
              tuple(t.merge_kwargs for t in threads))
          # We capture the name_scope of the MRT when we call merge_fn
          # to ensure that if we have opened a name scope in the MRT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MRT and assume it is
          # the same for all other MRTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          mtt_captured_var_scope = threads[0].captured_var_scope
          # Capture and merge the control dependencies from all the threads.
          mtt_captured_control_deps = set()
          for t in threads:
            mtt_captured_control_deps.update(t.captured_control_deps)

          # Control is transferred from _MirroredReplicaThread (MRT) to the
          # main thread, i.e., here, to perform `merge_fn`, and thus we
          # preserve the name scope, control dependencies, etc. from MRT at
          # the time `merge_call` is made.
          # One special case is that the `merge_call` is made under an
          # `tf.init_scope` in the MRT. `tf.init_scope` will clear control
          # dependencies, pause gradient tape, and enter the lowest context on
          # the `context_stack` that is not building a graph function. Entering
          # the lowest context could be one of the two things: installation of a
          # graph as the default graph or switch into eager mode. If the former
          # is done and causes `merge_call` to be called in a different graph
          # from the one in which `call_for_each_replica` is called, we do not
          # allow this case (see comment in `_merge_call`) and we would not have
          # arrived here due to the check in `_merge_call`. However, if the
          # latter is done, we want to make sure the main thread enters an
          # eager mode scope as well so that `merge_fn` does not have trouble
          # accessing resources defined in MRT under the same context.
          with ops.name_scope(
              mtt_captured_name_scope), ops.control_dependencies(
                  mtt_captured_control_deps), variable_scope.variable_scope(
                      mtt_captured_var_scope), _maybe_enter_eager_mode(
                          threads[0].merge_call_entered_in_eager):
            # Only the first thread's `merge_fn` is invoked, with the grouped
            # arguments of all replicas.
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          # Give each replica its own component of the merged result; it is
          # returned from that replica's merge_call() on the next release.
          for r, t in enumerate(threads):
            t.merge_result = distribute_utils.select_replica(r, merge_result)
  finally:
    # Release every thread so none stays blocked on `should_run`, then wait
    # for all of them to finish.
    # NOTE(review): `coord.join` presumably re-raises an exception recorded
    # from a replica thread -- confirm against `Coordinator`.
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return distribute_utils.regroup(tuple(t.main_result for t in threads))
class _MirroredReplicaThread(threading.Thread):
  """A thread that runs() a function on a device.

  The constructor snapshots state from the constructing thread (graph, eager
  mode, name/variable scope, summary state, op callbacks, device policy), and
  `run()` re-establishes it before calling the function. The thread is paced
  with the `should_run` and `has_paused` events.
  """

  def __init__(self, dist, coord, replica_id, devices, variable_creator_fn, fn,
               caching_scope, args, kwargs, thread_local_callables=None):
    """Captures the state this thread needs to later run `fn`.

    Args:
      dist: the strategy object; stored as `self.distribution`.
      coord: coordinator consulted for stop requests and used to record
        exceptions raised while running `fn`.
      replica_id: index of this replica; selects the device from `devices`.
      devices: sequence of device strings, indexed by `replica_id`.
      variable_creator_fn: passed to `variable_scope.variable_creator_scope`
        around the call to `fn`.
      fn: function to run in this thread.
      caching_scope: object from which `new_cache_scope_count` and
        `cache_scope_exited_count` are read, if present.
      args: positional arguments for `fn`.
      kwargs: keyword arguments for `fn`.
      thread_local_callables: optional iterable of zero-argument callables
        invoked at the start of `run()`.
    """
    super(_MirroredReplicaThread, self).__init__()
    self.coord = coord
    self.distribution = dist
    self.devices = devices
    self.replica_id = replica_id
    self.replica_id_in_sync_group = (
        dist.extended._get_replica_id_in_sync_group(replica_id))  # pylint: disable=protected-access

    self.variable_creator_fn = variable_creator_fn
    # State needed to run and return the results of `fn`.
    self.main_fn = fn
    self.main_args = args
    self.main_kwargs = kwargs
    self.main_result = None
    self.done = False
    # State needed to run the next merge_call() (if any) requested via
    # ReplicaContext.
    self.merge_fn = None
    self.merge_args = None
    self.merge_kwargs = None
    self.merge_result = None
    self.captured_name_scope = None
    self.captured_var_scope = None
    # Snapshot the caching-scope counters; both stay None if `caching_scope`
    # lacks either attribute.
    try:
      self.caching_scope_entered = caching_scope.new_cache_scope_count
      self.caching_scope_exited = caching_scope.cache_scope_exited_count
    except AttributeError:
      self.caching_scope_entered = None
      self.caching_scope_exited = None

    # We use a threading.Event for the main thread to signal when this
    # thread should start running (`should_run`), and another for
    # this thread to transfer control back to the main thread
    # (`has_paused`, either when it gets to a
    # `get_replica_context().merge_call` or when `fn` returns). In
    # either case the event starts cleared, is signaled by calling
    # set(). The receiving thread waits for the signal by calling
    # wait() and then immediately clearing the event using clear().
    self.should_run = threading.Event()
    self.has_paused = threading.Event()
    # These fields have to do with inheriting various contexts from the
    # parent thread:
    context.ensure_initialized()
    ctx = context.context()
    self.in_eager = ctx.executing_eagerly()
    self.record_thread_local_summary_state()
    self.record_thread_local_eager_context_state()
    self.context_device_policy = (
        pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
            ctx._context_handle))  # pylint: disable=protected-access
    self.graph = ops.get_default_graph()
    # Also remember the graph and eager state that are current under
    # `init_scope`; `run()` enters them before entering `self.graph`.
    with ops.init_scope():
      self._init_in_eager = context.executing_eagerly()
      self._init_graph = ops.get_default_graph()
    # Copy the list so later changes to the graph's stack do not affect the
    # snapshot.
    self._variable_creator_stack = self.graph._variable_creator_stack[:]  # pylint: disable=protected-access
    self._var_scope = variable_scope.get_variable_scope()
    # Adding a "/" at end lets us re-enter this scope later.
    self._name_scope = self.graph.get_name_scope()
    if self._name_scope:
      self._name_scope += "/"
    # Every replica other than replica 0 gets a "replica_<id>/" suffix on its
    # name scope.
    if self.replica_id > 0:
      if not self._name_scope:
        self._name_scope = ""
      self._name_scope += "replica_%d/" % self.replica_id

    self._thread_local_callables = thread_local_callables

  def run(self):
    """Waits to be released, then runs `main_fn` under the captured contexts.

    `has_paused` is always set on the way out, whether `main_fn` returned,
    raised, or was skipped because the coordinator had already stopped.
    """
    self.should_run.wait()
    self.should_run.clear()
    try:
      if self.coord.should_stop():
        return
      self.restore_thread_local_summary_state()
      self.restore_thread_local_callable()
      self.restore_thread_local_eager_context_state()
      # Copy the counters captured in `__init__` onto
      # `distribute_utils.caching_scope_local` (presumably thread-local, so
      # this thread starts from the creator's values -- confirm).
      if (self.caching_scope_entered is not None and
          self.caching_scope_exited is not None):
        distribute_utils.caching_scope_local.new_cache_scope_count = self.caching_scope_entered
        distribute_utils.caching_scope_local.cache_scope_exited_count = self.caching_scope_exited
      # TODO(josh11b): Use current logical device instead of 0 here.
      # `stop_on_exception` is entered first so an exception raised by any
      # later scope or by `main_fn` is handed to the coordinator.
      with self.coord.stop_on_exception(), \
          _enter_graph(self._init_graph, self._init_in_eager), \
          _enter_graph(self.graph, self.in_eager,
                       self._variable_creator_stack), \
          context.device_policy(self.context_device_policy), \
          _MirroredReplicaContext(self.distribution,
                                  self.replica_id_in_sync_group), \
          ops.device(self.devices[self.replica_id]), \
          ops.name_scope(self._name_scope), \
          variable_scope.variable_scope(
              self._var_scope, reuse=self.replica_id > 0), \
          variable_scope.variable_creator_scope(self.variable_creator_fn):
        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
        self.done = True
    finally:
      # Hand control back to the main thread in every case.
      self.has_paused.set()

  def record_thread_local_summary_state(self):
    """Record the thread local summary state in self."""
    # TODO(slebedev): is this still relevant? the referenced bug is closed.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    self._summary_step = summary_state.step
    self._summary_writer = summary_state.writer
    self._summary_recording = summary_state.is_recording
    self._summary_recording_distribution_strategy = (
        summary_state.is_recording_distribution_strategy)

  def restore_thread_local_summary_state(self):
    """Restore thread local summary state from self."""
    # TODO(slebedev): is this still relevant? the referenced bug is closed.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.step = self._summary_step
    summary_state.writer = self._summary_writer
    summary_state.is_recording = self._summary_recording
    summary_state.is_recording_distribution_strategy = (
        self._summary_recording_distribution_strategy)

  def record_thread_local_eager_context_state(self):
    """Record the eager context's thread-local `op_callbacks` in self."""
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    self._eager_context_op_callbacks = eager_context_state.op_callbacks
    # TODO(b/125892694): record other fields in EagerContext.

  def restore_thread_local_eager_context_state(self):
    """Restore the eager context's thread-local `op_callbacks` from self."""
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    eager_context_state.op_callbacks = self._eager_context_op_callbacks
    # TODO(b/125892694): record other fields in EagerContext.

  def restore_thread_local_callable(self):
    """Call every callable in `_thread_local_callables`, if any were given."""
    if self._thread_local_callables:
      for fn in self._thread_local_callables:
        fn()
class _MirroredReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext for synchronized replica."""

  def _merge_call(self, fn, args, kwargs):
    """`merge_call()` implementation for synchronized replica.

    The calling replica thread records `fn` and its arguments on itself, pauses,
    and lets the main thread take over. Once all replicas have paused, the main
    thread invokes `fn` with the grouped arguments; the replica thread then
    resumes and returns its share of the result.

    See `_call_for_each_replica` for the logic in the main thread.

    Args:
      fn: a function that is called in cross replica context with grouped
        arguments from each replica. `fn` should return grouped values.
      args: positional arguments to `fn`.
      kwargs: keyword arguments to `fn`.

    Returns:
      Return value of `fn` for the current replica.

    Raises:
      RuntimeError: when merge_call happens in a different graph, e.g. in a
        different tf.function, which is not supported now.
      _RequestedStop: when stop is requested.
    """
    replica_thread = threading.current_thread()
    assert isinstance(replica_thread, _MirroredReplicaThread)
    graph = replica_thread.graph

    # Publish the merge request on the thread object for the main thread.
    replica_thread.merge_fn = fn
    replica_thread.merge_args = args
    replica_thread.merge_kwargs = kwargs

    # A trailing "/" lets the main thread re-enter this name scope later.
    name_scope = graph.get_name_scope()
    replica_thread.captured_name_scope = (
        (name_scope + "/") if name_scope else name_scope)

    replica_thread.captured_var_scope = variable_scope.get_variable_scope()
    replica_thread.captured_control_deps = (
        graph._current_control_dependencies())  # pylint: disable=protected-access

    replica_thread.merge_call_entered_in_eager = (
        context.context().executing_eagerly())

    # `merge_call` must run under the same graph that `_call_for_each_replica`
    # was called under. There are 3 ways it can end up in a different one:
    #
    #   1. The `fn` passed to `_call_for_each_replica` is decorated with
    #   `tf.function` and contains a `merge_call`. MirroredStrategy traces a
    #   separate function per thread (per device) and each trace takes a
    #   shared lock, so the first thread never releases the lock and the
    #   remaining replica threads cannot trace their own functions. This is
    #   addressed by always converting
    #   `_call_for_each_replica(tf.function(f))` to
    #   `tf.function(_call_for_each_replica(f))` in
    #   `MirroredStrategy._call_for_each_replica`.
    #
    #   2. The `fn` passed to `_call_for_each_replica` contains a nested
    #   `tf.function` with a `merge_call` inside it. Each thread can trace its
    #   own function, but the `merge_fn` passed to `merge_call` runs in the
    #   main thread (where `_call_for_each_replica` is executed) and cannot
    #   access tensors that come from different graphs.
    #
    #   3. The `fn` passed to `_call_for_each_replica` contains a control-flow
    #   statement with a `merge_call` in its body, and `fn` or
    #   `_call_for_each_replica` is decorated with `tf.function`. The
    #   control-flow body gets its own graph, so as in #2 the `merge_fn`
    #   executed in the main thread cannot access tensors from other graphs.
    #
    # An error is raised for #2 and #3.
    if ops.get_default_graph() != graph:
      raise RuntimeError(
          "`merge_call` called while defining a new graph or a tf.function."
          " This can often happen if the function `fn` passed to"
          " `strategy.run()` contains a nested `@tf.function`, and the nested "
          "`@tf.function` contains a synchronization point, such as aggregating"
          " gradients (e.g, optimizer.apply_gradients), or if the function `fn`"
          " uses a control flow statement which contains a synchronization"
          " point in the body. Such behaviors are not yet supported. Instead,"
          " please avoid nested `tf.function`s or control flow statements that"
          " may potentially cross a synchronization boundary, for example,"
          " wrap the `fn` passed to `strategy.run` or the entire `strategy.run`"
          " inside a `tf.function` or move the control flow out of `fn`. If"
          " you are subclassing a `tf.keras.Model`, please avoid decorating"
          " overridden methods `test_step` and `train_step` in `tf.function`.")

    # Hand control to the main thread and block until it releases us again.
    replica_thread.has_paused.set()
    resume_signal = replica_thread.should_run
    resume_signal.wait()
    resume_signal.clear()
    if replica_thread.coord.should_stop():
      raise _RequestedStop()
    replica_thread.merge_call_entered_in_eager = None
    return replica_thread.merge_result

  @property
  def devices(self):
    """One-element list with this replica's `worker_devices_by_replica` entry."""
    distribute_lib.require_replica_context(self)
    devices_by_replica = self._strategy.extended.worker_devices_by_replica
    return [devices_by_replica[self._replica_id_in_sync_group]]