Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/auto_control_deps.py: 19%

230 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""AutomaticControlDependencies and related functionality.""" 

16 

17import collections 

18import enum 

19 

20from tensorflow.core.framework import attr_value_pb2 

21from tensorflow.python.eager import context 

22from tensorflow.python.framework import auto_control_deps_utils as utils 

23from tensorflow.python.framework import dtypes as dtypes_module 

24from tensorflow.python.framework import indexed_slices 

25from tensorflow.python.framework import op_def_registry 

26from tensorflow.python.framework import ops 

27from tensorflow.python.framework import registry 

28from tensorflow.python.framework import sparse_tensor 

29from tensorflow.python.ops import array_ops 

30from tensorflow.python.ops import control_flow_ops 

31from tensorflow.python.ops import control_flow_util 

32from tensorflow.python.ops import tensor_array_ops 

33from tensorflow.python.util import nest 

34from tensorflow.python.util import object_identity 

35from tensorflow.python.util import tf_decorator 

36 

# LINT.IfChange
# Op types that should not run in program order, e.g. because they need to run
# asynchronously to avoid deadlock. `op_is_stateful` does not count ops of
# these types as stateful.
ASYNC_STATEFUL_OPS = frozenset({
    "CollectiveGather",
    "CollectiveReduce",
    "CollectiveBcastSend",
    "CollectiveBcastSendV2",
    "CollectiveBcastRecv",
    "CollectiveBcastRecvV2",
    "NcclAllReduce",
    # "Send" is deliberately absent: we want it to be added as a control
    # output so that it is not pruned.
    "Recv",
    "CollectiveInitializeCommunicator",
    "CollectiveAssignGroupV2",
})

55 

# Legacy random op types. `op_is_stateful` does not count them as stateful.
LEGACY_RANDOM_OPS = frozenset({
    # These may be used in variable initializers, so their execution should
    # not be made dependent on other stateful operations. Although tf.Variables
    # may be created in sequence according to program order, their
    # initialization happens outside of program order: in graph mode it
    # happens by calling a grouped initializer operation, and in eager mode the
    # initialization is lifted out of the tf.function and executed the first
    # time the function is executed.
    #
    # Unless there is a specific dependency between the initializers
    # themselves (e.g. one initializer depends on a Variable whose value depends
    # on another initializer), the initialization can happen in any order as
    # long as it happens before the associated Variable read operations.
    #
    # In general the randomness of legacy random operations is only guaranteed
    # when both a graph-level and an op-level seed are provided. Ordering of
    # the same op across multiple iterations of a while_loop is specifically
    # not guaranteed (see below).
    #
    # There is a possible race condition inside while_loop where the same
    # random OpKernel instantiation is reused across multiple steps of the
    # loop. Since legacy Random OpKernels have an internal rng state,
    # automatic dependency tracking across loop steps would likely fix this
    # race, and for that case this denylist is problematic. However, automatic
    # dependency tracking inside while loops is not currently supported, and
    # there are no other examples of OpKernel reuse (each OpKernel is
    # associated with a unique op in graph mode). This denylist therefore has
    # no effect on the aforementioned behavior.
    #
    # TODO(ebrevdo,skyewm): Modify the check against this denylist to
    # only occur when the op is inside a "variable initialization scope"; and
    # add proper autodeps inside while_loops that respects this updated check.
    "RandomUniform",
    "RandomUniformInt",
    "RandomStandardNormal",
    "ParameterizedTruncatedNormal",
    "TruncatedNormal",
    "RandomShuffle",
    "Multinomial",
    "RandomGamma",
    "RandomGammaGrad",
    "RandomPoisson",
    "RandomPoissonV2",
})

101 

# Stateful op types that must run but whose relative order does not matter.
# `AutomaticControlDependencies.__exit__` passes every such op to
# `run_independently`. The op is then added to the must-run set and skipped by
# the per-resource ordering logic.
MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS = frozenset({
    "InfeedEnqueue",
    "InfeedEnqueueTuple",
    "EnqueueTPUEmbeddingSparseBatch",
    "EnqueueTPUEmbeddingIntegerBatch",
    "EnqueueTPUEmbeddingSparseTensorBatch",
    "EnqueueTPUEmbeddingRaggedTensorBatch",
    "EnqueueTPUEmbeddingArbitraryTensorBatch",
    "DynamicEnqueueTPUEmbeddingArbitraryTensorBatch",
})

112 

# These ops are order-insensitive and should in theory run, but at the moment
# they either always have the necessary data dependencies, or have workarounds
# in existing code that would break when adding new control deps. This
# inconsistency should eventually be fixed, but it would be more effective to
# retire the list instead. `op_is_stateful` does not count ops of these types
# as stateful.
SKIPPED_ORDER_INSENSITIVE_STATEFUL_OPS = frozenset({
    "CudnnRNN",
    "CudnnRNNBackprop",
    "CudnnRNNV2",
    "CudnnRNNV3",
    "CudnnRNNBackpropV2",
    "CudnnRNNBackpropV3",
    "RestoreV2",
    "SaveV2",
})
# LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc)

129 

130# Op types that are marked as stateless, but should be allowlisted to add auto 

131# control dependencies. 

132_ALLOWLIST_STATELESS_OPS = [ 

133 # As TPU collective ops are blocking, if there are more than one collective 

134 # op in the function, we need to make sure different collectives ops are 

135 # scheduled in certain orders. Otherwise if at the same time all the 

136 # replicas are launching different collective ops/programs, it may cause 

137 # deadlock. 

138 "AllToAll", 

139 "CrossReplicaSum", 

140 "CollectivePermute", 

141] 

142 

143 

def op_is_stateful(op):
  """Returns whether ACD should treat `op` as a stateful op.

  Op types in `_ALLOWLIST_STATELESS_OPS` always count as stateful. Any other
  op counts only if it is marked stateful and its type is in none of
  `ASYNC_STATEFUL_OPS`, `LEGACY_RANDOM_OPS` or
  `SKIPPED_ORDER_INSENSITIVE_STATEFUL_OPS`.

  Args:
    op: The `Operation` to classify.

  Returns:
    A bool.
  """
  op_type = op.type
  if op_type in _ALLOWLIST_STATELESS_OPS:
    return True
  # pylint: disable=protected-access
  if not op._is_stateful:
    return False
  excluded_groups = (ASYNC_STATEFUL_OPS, LEGACY_RANDOM_OPS,
                     SKIPPED_ORDER_INSENSITIVE_STATEFUL_OPS)
  return not any(op_type in group for group in excluded_groups)

152 

153 

class ResourceType(enum.Enum):
  """How an op accesses a resource input, as yielded by `_get_resource_inputs`."""
  READ_ONLY = "read-only"
  READ_WRITE = "read-write"

157 

158 

def collective_manager_ids_from_op(op):
  """Returns the CollectiveManager IDs attached to `op` (empty list if none).

  CollectiveManager adds collective and no_op operations tagged with an ID,
  unique to the manager object. This function extracts those IDs. Only two op
  types are inspected:
    * `CollectiveReduce`: the single `_collective_manager_id` attr, wrapped in
      a one-element list.
    * `StatefulPartitionedCall`: the value stored under the
      `utils.COLLECTIVE_MANAGER_IDS` attr, returned as-is.

  Args:
    op: `Operation` to get the collective manager IDs from.

  Returns:
    List of CollectiveManager IDs used by the op. The list is empty (never
    `None`) when the op type is not one of the above, or when `op.get_attr`
    raises `ValueError` for the relevant attr.
  """
  if op.type == "CollectiveReduce":
    try:
      return [op.get_attr("_collective_manager_id")]
    except ValueError:
      # Attr absent: presumably the op was not created by a CollectiveManager.
      return []
  if op.type == "StatefulPartitionedCall":
    try:
      return op.get_attr(utils.COLLECTIVE_MANAGER_IDS)
    except ValueError:
      # Attr absent: the call carries no collective manager IDs.
      return []
  return []

183 

184 

class AutomaticControlDependencies(object):
  """Context manager to automatically add control dependencies.

  Code under this context manager will act as if a sensible set of control
  dependencies were present. More specifically:
    1. All stateful ops in the scope will execute (with the exception of ops in
       ASYNC_STATEFUL_OPS, LEGACY_RANDOM_OPS and
       SKIPPED_ORDER_INSENSITIVE_STATEFUL_OPS, which `op_is_stateful` does not
       count as stateful). Stateful ops whose type is in
       `utils.RESOURCE_READ_OPS` are only forced to run when one of their
       outputs has consumers. Ops whose type has no registered OpDef are always
       forced to run.
    2. Stateful ops which modify the same resource will execute in program order

  Ops created inside a while loop are skipped entirely. In eager mode,
  entering and exiting the context does nothing.

  Note: creating variables in an automatic control dependencies context is not
  supported (the value of the variables will never change as they will keep
  getting reinitialized).

  NOT THREAD SAFE
  """

  def __init__(self):
    # Tensors registered via `mark_as_return`. On exit, each of them gets
    # control inputs on the ops which must run.
    self._returned_tensors = object_identity.ObjectIdentitySet()
    # Ops which must run. Extended (never cleared) by every `__exit__`.
    self.ops_which_must_run = set()
    # Ops registered via `run_independently`.
    self._independent_ops = []

  def mark_as_return(self, tensor):
    """Acts like identity but marks the `Tensor` as a return value.

    This will possibly return a copy of the `Tensor`. Usage:

    ```
      with AutomaticControlDependencies() as a:
       ...
       t = a.mark_as_return(t)
      _ = ...(t...)  # i.e. it's safe to use t here
    ```

    Args:
      tensor: the `Tensor` to be marked

    Returns:
      a copy of the `Tensor`.
    """
    # For composite values, identity is applied to the component tensors and
    # the composite is rebuilt around the copies. `dense_shape` is passed
    # through unmarked.
    if isinstance(tensor, indexed_slices.IndexedSlices):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return indexed_slices.IndexedSlices(
          values, indices, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, sparse_tensor.SparseTensor):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return sparse_tensor.SparseTensor(
          indices, values, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, tensor_array_ops.TensorArray):
      # Only the flow tensor is marked. The TensorArray is rebuilt on top of
      # the new flow.
      flow = array_ops.identity(tensor.flow)
      self._returned_tensors.add(flow)
      return tensor_array_ops.build_ta_with_new_flow(tensor, flow)
    # We want to make the return values depend on the stateful operations, but
    # we don't want to introduce a cycle, so we make the return value the result
    # of a new identity operation that the stateful operations definitely don't
    # depend on.
    tensor = array_ops.identity(tensor)
    self._returned_tensors.add(tensor)
    return tensor

  def run_independently(self, op):
    """Marks the given op as independent.

    Overrides any other rule for the op.

    Independent ops are guaranteed to execute before the return values, but
    are allowed to run in parallel with everything else. Use in programs which
    can guarantee that an op has side effects that don't affect any other op.

    Args:
      op: An operation
    """
    self._independent_ops.append(op)
    # The decision is also recorded on the op itself as an attr.
    op._set_attr("_independent_side_effects", attr_value_pb2.AttrValue(b=True))  # pylint: disable=protected-access

  def __enter__(self):
    if context.executing_eagerly():
      return self
    # This code assumes no other thread is adding ops to the graph while
    # we're adding ops to the graph.
    # TODO(apassos): Fix this by locking the graph or using a temporary
    # graph (but that would mess up devices and collections at least,
    # probably other things as well).
    g = ops.get_default_graph()
    self._graph = g
    g._add_control_dependencies = True  # pylint: disable=protected-access
    g.experimental_acd_manager = self
    # Remember how many ops already exist, so that `__exit__` only processes
    # ops created inside this context.
    self._n_operations = g.num_operations()
    return self

  def _process_switch(self, switch_op, ops_which_must_run,
                      last_write_to_resource, merge_for_resource):
    """Processes a switch node for a resource input.

    When tensorflow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor gets created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node. This
    merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the resource
         tensor

      2. Any op which uses a resource output of the switch node executes before
         the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    Processing is skipped if a merge has already been recorded for the
    switch's first output.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run
      last_write_to_resource: map from resource tensor to last op updating
        it
      merge_for_resource: map from resource tensor to merge which must follow
        all usages of it.
    """
    # pylint: disable=protected-access
    inp = switch_op.inputs[0]
    input_id = ops.tensor_id(inp)
    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
      self._process_switch(inp.op, ops_which_must_run, last_write_to_resource,
                           merge_for_resource)
    output = switch_op.outputs[0]
    output_id = ops.tensor_id(output)
    # Already processed (e.g. reached again through a nested switch).
    if output_id in merge_for_resource:
      return
    new_merge = control_flow_ops.merge(
        switch_op.outputs, name="artificial_merge")
    new_merge[0].op._control_flow_context = (
        switch_op._control_flow_context.outer_context)
    # Ensures the merge always runs
    ops_which_must_run.add(new_merge[0].op)
    if input_id in last_write_to_resource:
      # Ensures the switch executes after the previous op using the resource.
      switch_op._add_control_input(last_write_to_resource[input_id])
    # Ensure the next op outside the cond happens after the merge.
    last_write_to_resource[input_id] = new_merge[0].op
    if input_id in merge_for_resource:
      merge_for_resource[input_id]._add_control_input(new_merge[0].op)
    for o in switch_op.outputs:
      # Ensures the merge will execute after all ops inside the cond
      merge_for_resource[ops.tensor_id(o)] = new_merge[0].op

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # pylint: disable=protected-access
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Within the automatic control dependency context, the default graph"
          f" cannot change. Upon entry it was {self._graph}, but on exit it"
          f" changed to {ops.get_default_graph()}")

    # Restore the graph's ACD flag to that of the enclosing graph (or disable
    # it when there is none) and detach this manager from the graph.
    outer_graph = getattr(self._graph, "outer_graph", None)
    if outer_graph is not None:
      self._graph._add_control_dependencies = outer_graph._add_control_dependencies
    else:
      self._graph._add_control_dependencies = False
    self._graph.experimental_acd_manager = None

    # map from resource tensor to the last op which wrote to it
    last_write_to_resource = {}
    # map from resource tensor to the list of reads from it since the last
    # write or since the beginning of the function.
    reads_since_last_write_to_resource = collections.defaultdict(list)
    # CollectiveManager manager_ids within a particular function call should not
    # be needed outside of that function call. So we keep them separate (though
    # the general idea of the maps is the same, in the future, we'll need to
    # correctly thread the control output outside).
    # Map from collective manager scope to the last op which used it
    collective_manager_scopes_opened = {}
    collective_manager_scopes_used = {}
    # set of conditional and loop exits
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_write_to_resource).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a node from an untaken control flow branch) that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()

      if op.type in MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS:
        # This will add it to self._independent_ops, but also mark it with an
        # attribute.
        self.run_independently(op)

      # Independent ops must run but skip all ordering logic below.
      if op in self._independent_ops:
        ops_which_must_run.add(op)
        continue

      # Ensure stateful ops run.
      # Read-only ops are added to control outputs if the read value is
      # consumed. This covers the case when the read value is returned from
      # the function since that goes through a tf.identity in mark_as_return.
      if ((op_def_registry.get(op.type) is None) or
          (op_is_stateful(op) and
           (op.type not in utils.RESOURCE_READ_OPS or
            any(output.consumers() for output in op.outputs)))):
        ops_which_must_run.add(op)

      # Make a note of all opened manager_ids.
      if op.type == "NoOp":
        try:
          collective_manager_scopes_opened[op.get_attr(
              "_collective_manager_id")] = op
        except ValueError:
          pass
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      # TODO(mdan): Don't do this. Write a transform to chains instead.
      # See core/common_runtime/control_flow_deps_to_chains.cc.
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)
          for inp in o.inputs:
            input_id = ops.tensor_id(inp)
            if input_id in last_write_to_resource:
              last_write_to_resource[input_id] = op
        # The merge now stands in for everything that had to run before it.
        ops_which_must_run = set([op])
        continue

      resource_inputs = set()
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_write_to_resource.
      for inp, resource_type in _get_resource_inputs(op):
        is_read = resource_type == ResourceType.READ_ONLY
        input_id = ops.tensor_id(inp)

        # If the op receives the same resource tensor twice as an input, we skip
        # to avoid the op getting a control dependency on itself.
        if input_id in resource_inputs:
          continue

        resource_inputs.add(input_id)
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_write_to_resource, merge_for_resource)
        is_building_function = op.graph.building_function
        # Ensure uses of resources are serialized
        if input_id in last_write_to_resource:
          if is_building_function or (
              last_write_to_resource[input_id]._control_flow_context
              is op._control_flow_context):
            control_inputs.add(last_write_to_resource[input_id])
        # Ensure merges happen after the closing of a cond block
        if input_id in merge_for_resource:
          merge_for_resource[input_id]._add_control_input(op)
        if is_read:
          reads_since_last_write_to_resource[input_id].append(op)
        else:
          # A write waits for every read since the previous write, then
          # becomes the new "last write" for this resource.
          control_inputs.update(reads_since_last_write_to_resource[input_id])
          reads_since_last_write_to_resource[input_id] = []
          last_write_to_resource[input_id] = op

      # Stateful ops with no resource inputs (outside any control flow
      # context) are chained to each other under the `None` key.
      if (op_is_stateful(op) and not resource_inputs
          and op._control_flow_context is None):
        if None in last_write_to_resource:
          op._add_control_input(last_write_to_resource[None])
        last_write_to_resource[None] = op

      # Ensure ordering of collective ops
      manager_ids = collective_manager_ids_from_op(op)
      for manager_id in manager_ids:
        if manager_id in collective_manager_scopes_opened:
          # Chain this function call if the scope was opened.
          op._add_control_input(collective_manager_scopes_opened[manager_id])
          collective_manager_scopes_opened[manager_id] = op
        else:
          # If this op is in a scope not created here, create a chain starting
          # at this op.
          if manager_id in collective_manager_scopes_used:
            op._add_control_input(collective_manager_scopes_used[manager_id])
          collective_manager_scopes_used[manager_id] = op

      # NOTE: `is_building_function` is only assigned inside the resource loop
      # above. It is safe to read here because `control_inputs` can only be
      # non-empty if that loop ran for this op, and `and` short-circuits.
      if control_inputs and not is_building_function:
        control_inputs = [
            c for c in control_inputs
            if c._control_flow_context is op._control_flow_context
        ]

      op._add_control_inputs(control_inputs)

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)

    # NOTE(review): `control_output_op` is only created when the first
    # returned tensor's graph is building a function. This appears to assume
    # all returned tensors share one graph; otherwise a later tensor could
    # receive `[None]` as control inputs. TODO confirm.
    control_output_op = None
    for idx, r in enumerate(
        nest.flatten(list(self._returned_tensors), expand_composites=True)):
      if self.ops_which_must_run:
        updated_ops_which_must_run = []
        if r.graph.building_function:
          # There may be many stateful ops in the graph. Adding them as
          # control inputs to each function output could create excessive
          # control edges in the graph. Thus we create an intermediate No-op to
          # chain the control dependencies between stateful ops and function
          # outputs.
          if idx == 0:
            control_output_op = control_flow_ops.no_op()
            control_output_op._add_control_inputs(self.ops_which_must_run)
          updated_ops_which_must_run = [control_output_op]
        else:
          updated_ops_which_must_run = [
              o for o in self.ops_which_must_run
              if o._control_flow_context is r.op._control_flow_context
          ]
        r.op._add_control_inputs(updated_ops_which_must_run)

    self.collective_manager_ids_used = collective_manager_scopes_used

544 

545 

# Resolver callables, added via `register_acd_resource_resolver` and consulted
# by `_get_resource_inputs`.
_acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers")


def register_acd_resource_resolver(f):
  """Adds `f` to the ACD resource-resolver registry and returns it unchanged.

  Meant to be used as a decorator. When ACD resolves the resources touched by
  an op, every registered resolver is called with the op and its current sets
  of resource reads and writes. The calls repeat until one full pass over the
  registry reports no change. A resolver must update the sets of resource
  reads and writes in place and return True if it changed either of them,
  False otherwise.

  Example:
    @register_acd_resource_resolver
    def identity_resolver(op, resource_reads, resource_writes):
      # op: The `Operation` being processed by ACD currently.
      # resource_reads: An `ObjectIdentitySet` of read-only resources.
      # resource_writes: An `ObjectIdentitySet` of read-write resources.
      def update(resource_inputs):
        to_remove = []
        to_add = []
        for resource in resource_inputs:
          if resource.op.type == "Identity":
            to_remove.append(resource)
            to_add.extend(resource.op.inputs)
        for t in to_remove:
          resource_inputs.discard(t)
        resource_inputs.update(to_add)
        return to_add or to_remove
      return update(resource_reads) or update(resource_writes)

  Args:
    f: Python function with signature
      (Operation, ObjectIdentitySet, ObjectIdentitySet) -> bool

  Returns:
    The same function `f`, now registered.
  """
  _acd_resource_resolvers_registry.register(f)
  return f

585 

586 

@register_acd_resource_resolver
def _identity_resolver(op, resource_reads, resource_writes):
  """Replaces Identity outputs with their inputs in both resource sets.

  Every resource tensor produced by an `Identity` op, in either set, is
  replaced in place by that op's input(s). Chains of Identity ops are
  unwound one level per call. `_get_resource_inputs` keeps calling resolvers
  until none reports a change.

  Args:
    op: The `Operation` being processed. Unused.
    resource_reads: `ObjectIdentitySet` of read-only resource tensors; updated
      in place.
    resource_writes: `ObjectIdentitySet` of read-write resource tensors;
      updated in place.

  Returns:
    True if either set was modified, False otherwise.
  """
  del op

  def update(resource_inputs):
    # Collect first, then mutate: the set must not change while iterated.
    to_remove = [r for r in resource_inputs if r.op.type == "Identity"]
    to_add = [t for r in to_remove for t in r.op.inputs]
    for t in to_remove:
      resource_inputs.discard(t)
    resource_inputs.update(to_add)
    return bool(to_remove or to_add)

  # Rewrite both sets on every call (no `or` short-circuit). Otherwise writes
  # would be left unresolved whenever reads changed, costing an extra
  # saturation pass in `_get_resource_inputs`.
  reads_updated = update(resource_reads)
  writes_updated = update(resource_writes)
  return reads_updated or writes_updated

604 

605 

def _get_resource_inputs(op):
  """Yields `(resource_tensor, ResourceType)` pairs for resources `op` touches.

  The initial read/write split comes from
  `utils.get_read_write_resource_inputs`. Every registered resolver is then
  applied, pass after pass, until a full pass over the registry reports no
  change. Whenever a resolver reports a change, resources that appear in the
  write set are conservatively dropped from the read set.

  A resource handle that is not written to is reported as read-only; there is
  no special marker for an unused resource. All reads are yielded before all
  writes.
  """
  reads, writes = utils.get_read_write_resource_inputs(op)
  changed = True
  while changed:
    changed = False
    for name in _acd_resource_resolvers_registry.list():
      resolver = _acd_resource_resolvers_registry.lookup(name)
      # Resolvers return a truthy value when they modified reads or writes.
      # TODO(srbs): An alternate would be to just compare the old and new set
      # but that may not be as fast.
      if resolver(op, reads, writes):
        reads = reads.difference(writes)
        changed = True

  for tensor in reads:
    yield (tensor, ResourceType.READ_ONLY)
  for tensor in writes:
    yield (tensor, ResourceType.READ_WRITE)

629 

630 

def automatic_control_dependencies(f):
  """Decorates `f` so its body runs under `AutomaticControlDependencies`.

  Each call of the returned function runs `f` inside a fresh
  `AutomaticControlDependencies` context. Every flattened leaf of the result
  is passed through `mark_as_return`, so that:
    1. All stateful ops in f run when the result of f runs
    2. Updates to the same resources happen in order.

  Args:
    f: the function to be wrapped.

  Returns:
    The wrapper, built with `tf_decorator.make_decorator(f, ...)`. It returns
    `f`'s result with the same nested structure, each leaf replaced by its
    marked copy.
  """

  def wrapper(*args, **kwargs):
    with AutomaticControlDependencies() as acd:
      result = f(*args, **kwargs)
      marked_leaves = []
      for leaf in nest.flatten(result):
        marked_leaves.append(acd.mark_as_return(leaf))
      return nest.pack_sequence_as(result, marked_leaves)

  return tf_decorator.make_decorator(f, wrapper)