# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Critical Section object and execution logic."""

import collections
import contextlib
import threading

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export


__all__ = ["CriticalSection"]


# Graph Keys
CRITICAL_SECTIONS = "critical_sections"
CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"


class _ExecutionSignature(
    collections.namedtuple("_ExecutionSignature",
                           ("op", "handle",
                            "resources", "exclusive_resource_access"))):
  """A class storing an `ExecuteInCriticalResource` op and associated attrs."""
  pass


def _identity(x):
  """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`."""
  if isinstance(x, tensor_array_ops.TensorArray):
    return x.identity()
  elif isinstance(x, ops.Operation):
    return control_flow_ops.group(x)
  elif context.executing_eagerly() and x is None:
    return None
  else:
    return array_ops.identity(x)


def _get_device_or_colocation(op):
  return op.device or _get_colocation(op)


def _get_colocation(op):
  """Get colocation symbol from op, if any."""
  try:
    return op.get_attr("_class")
  except (ValueError, AttributeError):
    return None


_CRITICAL_SECTION_STACK = threading.local()


def _get_critical_section_stack():
  try:
    return _CRITICAL_SECTION_STACK.value
  except AttributeError:
    _CRITICAL_SECTION_STACK.value = []
    return _CRITICAL_SECTION_STACK.value


@contextlib.contextmanager
def _push_critical_section_stack(signature):
  """Push a CriticalSection._signature to the thread-local stack.

  If the signature is already on the stack, raise an error because it means
  we're trying to execute inside the same locked CriticalSection, which
  will create a deadlock.

  Args:
    signature: Tuple of the type `CriticalSection._signature`. Uniquely
      identifies a CriticalSection by its `shared_name`, `container`,
      and device.

  Yields:
    An empty value. The context is guaranteed to run without deadlock.

  Raises:
    ValueError: If the signature is already on the stack.
    RuntimeError: If another thread or function modifies the current stack
      entry during the yield.
  """
  stack = _get_critical_section_stack()
  if signature in stack:
    raise ValueError(
        f"Attempting to lock a CriticalSection (signature={signature}) in which"
        " we are already running. This is illegal and may cause deadlocks.")
  stack.append(signature)
  try:
    yield
  finally:
    received_signature = stack.pop()
    if received_signature != signature:
      raise RuntimeError(
          "CriticalSection stack inconsistency: expected signature "
          f"{signature} but received {received_signature}")


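# A minimal usage sketch of the guard above (illustrative only; the signature
# tuple below is made up and not part of the module API). Re-entering the same
# signature on one thread raises immediately instead of deadlocking at runtime:
#
#   _sig = ("", "my_critical_section", "/device:CPU:0")
#   with _push_critical_section_stack(_sig):
#     ...  # build ops while this section is considered "held" on this thread
#     with _push_critical_section_stack(_sig):  # raises ValueError
#       pass

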
@tf_export("CriticalSection")
class CriticalSection:
  """Critical section.

  A `CriticalSection` object is a resource in the graph which executes subgraphs
  in **serial** order. A common example of a subgraph one may wish to run
  exclusively is the one given by the following function:

  ```python
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  def count():
    value = v.read_value()
    with tf.control_dependencies([value]):
      with tf.control_dependencies([v.assign_add(1)]):
        return tf.identity(value)
  ```

  Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
  The snapshot value is returned.

  If multiple workers or threads all execute `count` in parallel, there is no
  guarantee that access to the variable `v` is atomic at any point within
  any thread's calculation of `count`. In fact, even implementing an atomic
  counter that guarantees that the user will see each value `0, 1, ...,` is
  currently impossible.

  The solution is to ensure any access to the underlying resource `v` is
  only processed through a critical section:

  ```python
  cs = CriticalSection()
  f1 = cs.execute(count)
  f2 = cs.execute(count)
  output = f1 + f2
  session.run(output)
  ```
  The functions `f1` and `f2` will be executed serially, and updates to `v`
  will be atomic.

  **NOTES**

  All resource objects, including the critical section and any captured
  variables of functions executed on that critical section, will be
  colocated to the same device (host and cpu/gpu).

  When using multiple critical sections on the same resources, there is no
  guarantee of exclusive access to those resources. This behavior is disallowed
  by default (but see the kwarg `exclusive_resource_access`).

  For example, running the same function in two separate critical sections
  will not ensure serial execution:

  ```python
  v = tf.compat.v1.get_variable("v", initializer=0.0, use_resource=True)
  def accumulate(up):
    x = v.read_value()
    with tf.control_dependencies([x]):
      with tf.control_dependencies([v.assign_add(up)]):
        return tf.identity(x)
  ex1 = CriticalSection().execute(
      lambda: accumulate(1.0), exclusive_resource_access=False)
  ex2 = CriticalSection().execute(
      lambda: accumulate(1.0), exclusive_resource_access=False)
  bad_sum = ex1 + ex2
  sess.run(v.initializer)
  sess.run(bad_sum)  # May return 0.0
  ```
  """

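  # A brief TF2-style sketch of the counter example above (illustrative only,
  # not part of the API; assumes eager execution, where `execute` runs `fn`
  # immediately while holding the underlying mutex):
  #
  #   v = tf.Variable(0.0)
  #   cs = tf.CriticalSection()
  #
  #   def count():
  #     value = v.read_value()
  #     with tf.control_dependencies([value]):
  #       with tf.control_dependencies([v.assign_add(1.0)]):
  #         return tf.identity(value)
  #
  #   cs.execute(count)  # 0.0
  #   cs.execute(count)  # 1.0
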
  def __init__(self, name=None, shared_name=None,
               critical_section_def=None, import_scope=None):
    """Creates a critical section."""
    context.ensure_initialized()
    if critical_section_def and name is not None:
      raise ValueError(f"Arguments critical_section_def={critical_section_def} "
                       f"and name={name} are mutually exclusive. "
                       "Please only specify one of them.")
    if critical_section_def:
      raise ValueError("Argument `critical_section_def` is not supported.")
    else:
      self._init_from_args(name, shared_name)

  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
    """Initialize the CriticalSection from constructor arguments."""
    with ops.name_scope(name, "CriticalSection", []) as name:
      with ops.init_scope():
        # pylint: disable=protected-access
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if shared_name is None:
          shared_name = name
        if container is None:
          container = ""
        self._handle = gen_resource_variable_ops.mutex_v2(
            shared_name=shared_name, container=container, name=name)
        # Get a uniquely identifying signature for the handle.
        self._signature = (
            container,
            # If shared_name is empty, a unique CriticalSection is created.
            shared_name or id(self._handle),
            _get_device_or_colocation(self._handle))

    if not context.executing_eagerly():
      ops.add_to_collections(CRITICAL_SECTIONS, self)

  @property
  def name(self):
    return self._handle.op.name

  def execute(self, fn, exclusive_resource_access=True, name=None):
    """Execute function `fn()` inside the critical section.

    `fn` should not accept any arguments. To pass extra arguments when
    calling `fn` in the critical section, create a lambda:

    ```python
    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
    ```

    Args:
      fn: The function to execute. Must return at least one tensor.
      exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`. Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.
      name: The name to use when creating the execute operation.

    Returns:
      The tensors returned from `fn()`.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access == True` and
        another `CriticalSection` has an execution requesting the same
        resources as `fn`. Note that a `ValueError` is raised even if
        `exclusive_resource_access` is `False`, when the conflicting execution
        in the other `CriticalSection` was created with
        `exclusive_resource_access=True`.
    """
    with ops.name_scope(name, "critical_section_execute", []):
      # Ensure that mutex locking only happens *after* all args and
      # kwargs have been executed. This avoids certain types of deadlocks.
      with _push_critical_section_stack(self._signature):
        lock = gen_resource_variable_ops.mutex_lock(self._handle)

        if not context.executing_eagerly():
          # NOTE(ebrevdo): This is to ensure we don't pick up spurious
          # Operations created by other threads.
          with ops.get_default_graph()._lock:  # pylint: disable=protected-access
            existing_ops = ops.get_default_graph().get_operations()
            with ops.control_dependencies([lock]):
              r = fn()
            # TODO(ebrevdo): If creating critical sections in a python loop,
            # this makes graph creation time quadratic. Revisit if this
            # becomes a problem.
            created_ops = (set(ops.get_default_graph().get_operations())
                           .difference(existing_ops))
        else:
          with ops.control_dependencies([lock]):
            r = fn()

      if not context.executing_eagerly():
        self._add_control_dependencies_to_lock(created_ops, lock.op)

        # captured_resources is a list of resources that are directly
        # accessed only by ops created during fn(), not by any
        # ancestors of those ops in the graph.
        captured_resources = object_identity.ObjectIdentitySet([
            input_ for op in created_ops
            for input_ in op.inputs
            if input_.dtype == dtypes.resource
        ])

        # NOTE(ebrevdo): The only time self._is_self_handle() is True
        # in this call is if one of the recently created ops, within
        # the execute(), themselves attempt to access the
        # CriticalSection. This will cause a deadlock.
        if any(self._is_self_handle(x) for x in captured_resources):
          raise ValueError(
              "Attempting to lock a CriticalSection in which we are "
              f"already running (signature={self._signature}). This is illegal "
              "and may cause deadlocks.")

        self._check_multiple_access_to_resources(
            captured_resources, exclusive_resource_access)

      r_flat = [_identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled(!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))

      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(_identity, r)

      if not context.executing_eagerly():
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(captured_resources),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return outputs

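  # Sketch (illustrative only): nesting `execute` calls on the *same*
  # CriticalSection is rejected by `_push_critical_section_stack` above,
  # since the inner call would otherwise wait forever on the outer lock:
  #
  #   cs = CriticalSection()
  #   cs.execute(lambda: cs.execute(lambda: array_ops.identity(1.0)))
  #   # -> ValueError: Attempting to lock a CriticalSection ... already running
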
  def _add_control_dependencies_to_lock(self, created_ops, lock_op):
    """To avoid deadlocks, all args must be executed before lock_op."""
    # Get all arguments (explicit and captured) of all ops created by fn().
    all_args = set([input_.op for op in created_ops for input_ in op.inputs])
    all_args.update(
        input_op for op in created_ops for input_op in op.control_inputs)
    # Unfortunately, we can't use sets throughout because TF seems to
    # create new Operation objects for the same op sometimes; and we
    # can't rely on id(op).

    # pylint: disable=protected-access
    all_args_dict = dict((op._id, op) for op in all_args)

    # Remove ops created within fn, or that lock_op already has a
    # control dependency on. Also remove a possible self-loop.
    for op in created_ops:
      all_args_dict.pop(op._id, None)
    for op in lock_op.control_inputs:
      all_args_dict.pop(op._id, None)
    for input_ in lock_op.inputs:
      all_args_dict.pop(input_.op._id, None)
    all_args_dict.pop(lock_op._id, None)

    all_args = all_args_dict.values()

    if not all_args:
      # No control dependencies to add; return early.
      return

    # This group is important: it ensures that any ops in all_args
    # outside the control context of the lock_op (and this fn, which
    # runs in the same context) are added to this context before
    # being added to the control dependencies of lock_op.
    all_args = control_flow_ops.group(*all_args)

    lock_op._add_control_input(all_args)
    # pylint: enable=protected-access

  def _is_self_handle(self, x):
    """Check if the tensor `x` is the same Mutex as `self._handle`."""
    if isinstance(x, ops.EagerTensor):
      return x is self._handle
    return (x.op.type == "MutexV2"
            # blank shared_name means the op will create a unique one.
            and x.op.get_attr("shared_name")
            and (x.op.get_attr("shared_name") ==
                 self._handle.op.get_attr("shared_name"))
            and (x.op.device == self._handle.op.device
                 or _get_colocation(x.op) == _get_colocation(self._handle.op)))

  def _check_multiple_access_to_resources(
      self, captured_resources, exclusive_resource_access):
    """Raise if captured_resources are accessed by another CriticalSection.

    Args:
      captured_resources: Set of tensors of type resource.
      exclusive_resource_access: Whether this execution requires exclusive
        resource access.

    Raises:
      ValueError: If any tensors in `captured_resources` are also accessed
        by another `CriticalSection`, and at least one of them requires
        exclusive resource access.
    """
    # Collections and op introspection do not work in eager mode. This is
    # generally ok, since eager mode (as of this writing) executes
    # sequentially anyway.
    for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
      if self._is_self_handle(sg.handle):
        # Other executions in the same critical section are allowed.
        continue
      if not (exclusive_resource_access or sg.exclusive_resource_access):
        # Neither execution requested exclusive access.
        continue
      resource_intersection = captured_resources.intersection(sg.resources)
      if resource_intersection:
        raise ValueError(
            "This execution would access resources: "
            f"{list(resource_intersection)}. Either this lock "
            f"(CriticalSection: {self._handle}) or lock '{sg}' "
            f"(CriticalSection: {sg.handle}) requested exclusive resource "
            "access of this resource. Did you mean to call execute with "
            "keyword argument exclusive_resource_access=False?")


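# Sketch (illustrative only, during graph construction, e.g. inside a
# `tf.compat.v1.Graph` used as default): two *different* CriticalSections
# touching the same resource variable trip the check above when either
# execution keeps the default exclusive_resource_access=True. The
# `resource_variable_ops` module is a hypothetical extra import here:
#
#   v = resource_variable_ops.ResourceVariable(0.0, name="v")
#   CriticalSection().execute(v.read_value)
#   CriticalSection().execute(v.read_value)  # ValueError: ... requested
#                                            # exclusive resource access ...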