Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/critical_section_ops.py: 27%

134 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Critical Section object and execution logic.""" 

16 

17import collections 

18import contextlib 

19import threading 

20 

21from tensorflow.python.eager import context 

22from tensorflow.python.framework import dtypes 

23from tensorflow.python.framework import ops 

24from tensorflow.python.ops import array_ops 

25from tensorflow.python.ops import control_flow_ops 

26from tensorflow.python.ops import gen_resource_variable_ops 

27from tensorflow.python.ops import tensor_array_ops 

28from tensorflow.python.util import nest 

29from tensorflow.python.util import object_identity 

30from tensorflow.python.util.tf_export import tf_export 

31 

32 

# Public API of this module.
__all__ = ["CriticalSection"]


# Graph Keys
# Graph collection to which each `CriticalSection` adds itself when it is
# created while not executing eagerly.
CRITICAL_SECTIONS = "critical_sections"
# Graph collection holding one `_ExecutionSignature` per graph-mode
# `CriticalSection.execute()` call; it is scanned to detect conflicting
# resource access between different critical sections.
CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"

39 

40 

_EXECUTION_SIGNATURE_BASE = None  # placeholder removed below; see class body


class _ExecutionSignature(
    collections.namedtuple(
        "_ExecutionSignature",
        ["op", "handle", "resources", "exclusive_resource_access"])):
  """Record describing one graph-mode `CriticalSection.execute()` call.

  Fields:
    op: The mutex-lock `Operation` created for the execution.
    handle: The mutex handle of the `CriticalSection` that ran the execution.
    resources: List of resource tensors captured by the executed function.
    exclusive_resource_access: Whether the execution asked for exclusive
      access to those resources.
  """


del _EXECUTION_SIGNATURE_BASE

47 

48 

def _identity(x):
  """Return an identity of `x` appropriate for its kind.

  `TensorArray`s use their own `identity()` method, `Operation`s are wrapped
  in a `control_flow_ops.group`, `None` is passed through unchanged when
  executing eagerly, and every other value goes through `array_ops.identity`.
  """
  if isinstance(x, tensor_array_ops.TensorArray):
    result = x.identity()
  elif isinstance(x, ops.Operation):
    result = control_flow_ops.group(x)
  elif context.executing_eagerly() and x is None:
    result = None
  else:
    result = array_ops.identity(x)
  return result

59 

60 

def _get_device_or_colocation(op):
  """Return `op.device` when it is set, else `op`'s colocation attribute.

  When `op.device` is empty (falsy), this falls back to `_get_colocation(op)`,
  which may itself be None.
  """
  device = op.device
  if device:
    return device
  return _get_colocation(op)

63 

64 

def _get_colocation(op):
  """Return the `_class` (colocation) attr of `op`, or None if unavailable.

  A `ValueError` from `get_attr`, or an `AttributeError` because `op` has no
  `get_attr` at all, is treated as "no colocation" and yields None.
  """
  colocation = None
  with contextlib.suppress(ValueError, AttributeError):
    colocation = op.get_attr("_class")
  return colocation

71 

72 

# Thread-local holder whose `value` attribute is the calling thread's stack of
# CriticalSection signatures currently being executed. The attribute is
# created lazily by `_get_critical_section_stack`.
_CRITICAL_SECTION_STACK = threading.local()

74 

75 

def _get_critical_section_stack():
  """Return the calling thread's list of active CriticalSection signatures.

  The list is stored on the thread-local `_CRITICAL_SECTION_STACK` and is
  created empty the first time a given thread asks for it.
  """
  if not hasattr(_CRITICAL_SECTION_STACK, "value"):
    _CRITICAL_SECTION_STACK.value = []
  return _CRITICAL_SECTION_STACK.value

82 

83 

@contextlib.contextmanager
def _push_critical_section_stack(signature):
  """Push a CriticalSection._signature to the thread-local stack.

  If the signature is already on the stack, raise an error because it means
  we're trying to execute inside the same locked CriticalSection, which
  will create a deadlock.

  Args:
    signature: Tuple of the type `CriticalSection._signature`. Uniquely
      identifies a CriticalSection by its `shared_name`, `container`,
      and device.

  Yields:
    Nothing. While the context is active, this thread is recorded as running
    inside `signature`. A nested attempt to lock the same CriticalSection on
    this thread therefore fails fast with ValueError instead of deadlocking.

  Raises:
    ValueError: If the signature is already on the stack.
    RuntimeError: If code running inside the context leaves the stack
      inconsistent, i.e. the top entry on exit is not `signature` or the stack
      has been emptied. The stack is thread-local, so only code on the current
      thread can cause this.
  """
  stack = _get_critical_section_stack()
  if signature in stack:
    raise ValueError(
        f"Attempting to lock a CriticalSection (signature={signature}) in which"
        " we are already running. This is illegal and may cause deadlocks.")
  stack.append(signature)
  try:
    yield
  finally:
    # If something emptied the stack, `stack.pop()` would raise a bare
    # IndexError. Report that case as the documented inconsistency instead.
    received_signature = stack.pop() if stack else None
    if received_signature != signature:
      raise RuntimeError(
          "CriticalSection stack inconsistency: expected signature "
          f"{signature} but received {received_signature}")

119 

120 

@tf_export("CriticalSection")
class CriticalSection:
  """Critical section.

  A `CriticalSection` object is a resource in the graph which executes subgraphs
  in **serial** order. A common example of a subgraph one may wish to run
  exclusively is the one given by the following function:

  ```python
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  def count():
    value = v.read_value()
    with tf.control_dependencies([value]):
      with tf.control_dependencies([v.assign_add(1)]):
        return tf.identity(value)
  ```

  Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
  The snapshot value is returned.

  If multiple workers or threads all execute `count` in parallel, there is no
  guarantee that access to the variable `v` is atomic at any point within
  any thread's calculation of `count`. In fact, even implementing an atomic
  counter that guarantees that the user will see each value `0, 1, ...,` is
  currently impossible.

  The solution is to ensure any access to the underlying resource `v` is
  only processed through a critical section:

  ```python
  cs = CriticalSection()
  f1 = cs.execute(count)
  f2 = cs.execute(count)
  output = f1 + f2
  session.run(output)
  ```
  The two executions of `count` that produce `f1` and `f2` will run serially,
  and updates to `v` will be atomic.

  **NOTES**

  All resource objects, including the critical section and any captured
  variables of functions executed on that critical section, will be
  colocated to the same device (host and cpu/gpu).

  When using multiple critical sections on the same resources, there is no
  guarantee of exclusive access to those resources. This behavior is disallowed
  by default (but see the kwarg `exclusive_resource_access`).

  For example, running the same function in two separate critical sections
  will not ensure serial execution:

  ```python
  v = tf.compat.v1.get_variable("v", initializer=0.0, use_resource=True)
  def accumulate(up):
    x = v.read_value()
    with tf.control_dependencies([x]):
      with tf.control_dependencies([v.assign_add(up)]):
        return tf.identity(x)
  ex1 = CriticalSection().execute(
    lambda: accumulate(1.0), exclusive_resource_access=False)
  ex2 = CriticalSection().execute(
    lambda: accumulate(1.0), exclusive_resource_access=False)
  bad_sum = ex1 + ex2
  sess.run(v.initializer)
  sess.run(bad_sum)  # May return 0.0
  ```
  """

  def __init__(self, name=None, shared_name=None,
               critical_section_def=None, import_scope=None):
    """Creates a critical section.

    Args:
      name: Optional name for the underlying mutex op and its name scope.
      shared_name: Optional shared name of the mutex resource. If None, the
        name scope resolved from `name` is used.
      critical_section_def: Not supported; passing a truthy value raises.
      import_scope: Currently unused.

    Raises:
      ValueError: If `critical_section_def` is given, either on its own or
        together with `name`.
    """
    context.ensure_initialized()
    if critical_section_def and name is not None:
      # The condition tests `name`, so the message must report `name`.
      raise ValueError(f"Arguments critical_section_def={critical_section_def} "
                       f"and name={name} are mutually exclusive. "
                       "Please only specify one of them.")
    if critical_section_def:
      raise ValueError("Argument `critical_section_def` is not supported.")
    else:
      self._init_from_args(name, shared_name)

  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
    """Initialize the CriticalSection from constructor arguments.

    Creates the `MutexV2` handle under `ops.init_scope()` and computes the
    signature used to detect re-entrant locking of this critical section.
    """
    with ops.name_scope(name, "CriticalSection", []) as name:
      with ops.init_scope():
        # pylint: disable=protected-access
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if shared_name is None:
          shared_name = name
        if container is None:
          container = ""
        self._handle = gen_resource_variable_ops.mutex_v2(
            shared_name=shared_name, container=container, name=name)
        # Get a uniquely identifying signature for the handle.
        self._signature = (
            container,
            # If shared_name is empty, a unique CriticalSection is created.
            shared_name or id(self._handle),
            _get_device_or_colocation(self._handle))

    if not context.executing_eagerly():
      ops.add_to_collections(CRITICAL_SECTIONS, self)

  @property
  def name(self):
    """The name of the op that created this critical section's mutex."""
    return self._handle.op.name

  def execute(self, fn, exclusive_resource_access=True, name=None):
    """Execute function `fn()` inside the critical section.

    `fn` should not accept any arguments. To pass extra arguments to `fn`
    when it is called in the critical section, wrap it in a lambda:

    ```python
    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
    ```

    Args:
      fn: The function to execute. Must return at least one tensor.
      exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`. Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.
      name: The name to use when creating the execute operation.

    Returns:
      The values returned from `fn()`, in the same nested structure. Each
      value is passed through an identity that has a control dependency on
      the op consuming this execution's mutex lock.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If another `CriticalSection` has an execution that accesses
        any of the same resources as `fn`, and either this execution or that
        one requested exclusive resource access. Even if
        `exclusive_resource_access` is `False` here, a `ValueError` is raised
        when the other execution was created with
        `exclusive_resource_access=True`. These resource checks are only
        performed when building a graph, not when executing eagerly.
    """
    with ops.name_scope(name, "critical_section_execute", []):
      # Ensure that mutex locking only happens *after* all args and
      # kwargs have been executed. This avoids certain types of deadlocks.
      with _push_critical_section_stack(self._signature):
        lock = gen_resource_variable_ops.mutex_lock(self._handle)

        if not context.executing_eagerly():
          # NOTE(ebrevdo): This is to ensure we don't pick up spurious
          # Operations created by other threads.
          with ops.get_default_graph()._lock:  # pylint: disable=protected-access
            existing_ops = ops.get_default_graph().get_operations()
            with ops.control_dependencies([lock]):
              r = fn()
            # TODO(ebrevdo): If creating critical sections in a python loop,
            # this makes graph creation time quadratic. Revisit if this
            # becomes a problem.
            created_ops = (set(ops.get_default_graph().get_operations())
                           .difference(existing_ops))
        else:
          with ops.control_dependencies([lock]):
            r = fn()

      if not context.executing_eagerly():
        self._add_control_dependencies_to_lock(created_ops, lock.op)

        # captured_resources is a list of resources that are directly
        # accessed only by ops created during fn(), not by any
        # ancestors of those ops in the graph.
        captured_resources = object_identity.ObjectIdentitySet([
            input_ for op in created_ops
            for input_ in op.inputs
            if input_.dtype == dtypes.resource
        ])

        # NOTE(ebrevdo): The only time self._is_self_handle() is True
        # in this call is if one of the recently created ops, within
        # the execute(), themselves attempt to access the
        # CriticalSection. This will cause a deadlock.
        if any(self._is_self_handle(x) for x in captured_resources):
          raise ValueError(
              "Attempting to lock a CriticalSection in which we are "
              f"already running (signature={self._signature}). This is illegal "
              "and may cause deadlocks.")

        self._check_multiple_access_to_resources(
            captured_resources, exclusive_resource_access)

      r_flat = [_identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled(!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))

      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(_identity, r)

      if not context.executing_eagerly():
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(captured_resources),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return outputs

  def _add_control_dependencies_to_lock(self, created_ops, lock_op):
    """To avoid deadlocks, all args must be executed before lock_op.

    Args:
      created_ops: Set of ops created while running `fn()`.
      lock_op: The mutex-lock op that must wait for every input of
        `created_ops` that was not itself created inside `fn()`.
    """
    # Get all arguments (explicit and captured) of all ops created by fn().
    all_args = {input_.op for op in created_ops for input_ in op.inputs}
    all_args.update(
        input_op for op in created_ops for input_op in op.control_inputs)
    # Unfortunately, we can't use sets throughout because TF seems to
    # create new Operation objects for the same op sometimes; and we
    # can't rely on id(op).

    # pylint: disable=protected-access
    all_args_dict = {op._id: op for op in all_args}

    # Remove ops created within fn, or that lock_op already has a
    # control dependency on. Also remove a possible self-loop.
    for op in created_ops:
      all_args_dict.pop(op._id, None)
    for op in lock_op.control_inputs:
      all_args_dict.pop(op._id, None)
    for input_ in lock_op.inputs:
      all_args_dict.pop(input_.op._id, None)
    all_args_dict.pop(lock_op._id, None)

    all_args = all_args_dict.values()

    if not all_args:
      # No control dependencies to add; return early.
      return

    # This group is important: it ensures that any ops in all_args
    # outside the control context of the lock_op (and this fn, which
    # runs in the same context) are added to this context before
    # being added to the control dependencies of lock_op.
    all_args = control_flow_ops.group(*all_args)

    lock_op._add_control_input(all_args)
    # pylint: enable=protected-access

  def _is_self_handle(self, x):
    """Check if the tensor `x` is the same Mutex as `self._handle`."""
    if isinstance(x, ops.EagerTensor):
      return x is self._handle
    return (x.op.type == "MutexV2"
            # blank shared_name means the op will create a unique one.
            and x.op.get_attr("shared_name")
            and (x.op.get_attr("shared_name") ==
                 self._handle.op.get_attr("shared_name"))
            # NOTE(review): if `_get_colocation` returns None for both ops,
            # this clause is True even when the devices differ -- confirm
            # that this is intended.
            and (x.op.device == self._handle.op.device
                 or _get_colocation(x.op) == _get_colocation(self._handle.op)))

  def _check_multiple_access_to_resources(
      self, captured_resources, exclusive_resource_access):
    """Raise if captured_resources are accessed by another CriticalSection.

    Args:
      captured_resources: Set of tensors of type resource.
      exclusive_resource_access: Whether this execution requires exclusive
        resource access.

    Raises:
      ValueError: If any tensors in `captured_resources` are also accessed
        by another `CriticalSection`, and at least one of them requires
        exclusive resource access.
    """
    # Collections and op introspection does not work in eager
    # mode. This is generally ok; since eager mode (as of
    # writing) executes sequentially anyway.
    for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
      if self._is_self_handle(sg.handle):
        # Other executions in the same critical section are allowed.
        continue
      if not (exclusive_resource_access or sg.exclusive_resource_access):
        # Neither execution requested exclusive access.
        continue
      resource_intersection = captured_resources.intersection(sg.resources)
      if resource_intersection:
        raise ValueError(
            "This execution would access resources: "
            f"{list(resource_intersection)}. Either this lock "
            f"(CriticalSection: {self._handle}) or lock '{sg}' "
            f"(CriticalSection: {sg.handle}) requested exclusive resource "
            "access of this resource. Did you mean to call execute with "
            "keyword argument exclusive_resource_access=False?")