Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_replication.py: 17%

311 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""OutsideCompilation, TPUReplicateContext, and supporting functions.""" 

17 

18from typing import Any, Callable, List, Optional, Text, Tuple, Union 

19from absl import logging 

20from tensorflow.core.framework import attr_value_pb2 

21from tensorflow.python.distribute import device_util 

22from tensorflow.python.distribute import distribute_lib 

23from tensorflow.python.framework import device as pydev 

24from tensorflow.python.framework import errors 

25from tensorflow.python.framework import func_graph 

26from tensorflow.python.framework import ops 

27from tensorflow.python.ops import array_ops 

28from tensorflow.python.ops import control_flow_ops 

29from tensorflow.python.ops import variables 

30from tensorflow.python.tpu import device_assignment as device_assignment_lib 

31from tensorflow.python.tpu.ops import tpu_ops 

32from tensorflow.python.types import core as core_types 

33from tensorflow.python.util import compat 

34from tensorflow.python.util.tf_export import tf_export 

35 

36_MAX_WARNING_LINES = 5 

37_TPU_REPLICATE_ATTR = "_tpu_replicate" 

38_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation" 

39 

40# Operations that indicate some error in the users graph, e.g. a placeholder 

41# that's introduced outside of the infeed. 

42_DENYLISTED_OPS = frozenset([ 

43 "Placeholder", 

44]) 

45 

46 

47# XLA doesn't currently support reading of intermediate tensors, thus some ops 

48# are not supported. 

49_UNSUPPORTED_OPS = frozenset([ 

50 "AudioSummary", 

51 "AudioSummaryV2", 

52 "HistogramSummary", 

53 "ImageSummary", 

54 "MergeSummary", 

55 "Print", 

56 "ScalarSummary", 

57 "TensorSummary", 

58 "TensorSummaryV2", 

59]) 

60 

61 

62def is_tpu_strategy(strategy: Any) -> bool: 

63 is_tpu_strat = lambda k: k.__name__.startswith("TPUStrategy") 

64 clz = strategy.__class__ 

65 return is_tpu_strat(clz) or any(map(is_tpu_strat, clz.__bases__)) 

66 

67 

68def _enclosing_tpu_device_assignment( 

69) -> Optional[device_assignment_lib.DeviceAssignment]: 

70 if not distribute_lib.has_strategy(): 

71 return None 

72 strategy = distribute_lib.get_strategy() 

73 if not is_tpu_strategy(strategy): 

74 return None 

75 return strategy.extended._device_assignment # pylint: disable=protected-access 

76 

77 

78class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): 

79 """A `ControlFlowContext` for nodes inside a TPU computation. 

80 

81 The primary role of `TPUReplicateContext` is to mark operators inside a 

82 tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ 

83 is a unique name. 

84 

85 We use a `ControlFlowContext` to perform the annotation since it integrates 

86 with Tensorflow constructs like ResourceVariables. For example, if a 

87 `ResourceVariable` is constructed inside a tpu.replicate() block, the 

88 `ResourceVariable` implementation can use 

89 `with ops.control_dependencies(None)` to build the variable's definition 

90 outside the replicated computation. 

91 """ 

92 

93 def __init__(self, name: Text, num_replicas: int, pivot: ops.Operation): 

94 """Builds a new TPUReplicateContext. 

95 

96 Args: 

97 name: a unique name for the context, used to populate the `_tpu_replicate` 

98 attribute. 

99 num_replicas: an integer that gives the number of replicas for the 

100 computation. 

101 pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any 

102 inputs will have a control dependency on the pivot node. This ensures 

103 that nodes are correctly included in any enclosing control flow 

104 contexts. 

105 """ 

106 super(TPUReplicateContext, self).__init__() 

107 self._num_replicas = num_replicas 

108 self._outer_device_function_stack = None 

109 self._oc_dev_fn_stack = None 

110 self._outside_compilation_cluster = None 

111 self._outside_compilation_v2_context = None 

112 self._outside_compilation_counter = 0 

113 self._in_gradient_colocation = None 

114 self._gradient_colocation_stack = [] 

115 self._host_compute_core = [] 

116 self._name = name 

117 self._tpu_replicate_attr = attr_value_pb2.AttrValue( 

118 s=compat.as_bytes(self._name) 

119 ) 

120 self._unsupported_ops = [] 

121 self._pivot = pivot 

122 self._replicated_vars = {} 

123 

124 def get_replicated_var_handle(self, 

125 name: Text, 

126 handle_id: Text, 

127 vars_: Union[List[core_types.Tensor], 

128 List[variables.Variable]], 

129 is_mirrored: bool = False, 

130 is_packed: bool = False) -> core_types.Tensor: 

131 """Returns a variable handle for replicated TPU variable 'var'. 

132 

133 This is a method used by an experimental replicated variable implementation 

134 and is not intended as a public API. 

135 

136 Args: 

137 name: The common name of the variable. 

138 handle_id: Unique ID of the variable handle, used as the cache key. 

139 vars_: The replicated TPU variables or handles. 

140 is_mirrored: Whether the variables are mirrored, which guarantees the 

141 values in each replica are always the same. 

142 is_packed: Whether the replicated variables are packed into one variable. 

143 

144 Returns: 

145 The handle of the TPU replicated input node. 

146 """ 

147 device_assignment = _enclosing_tpu_device_assignment() 

148 # We don't need to put device assignment as part of the replicated_vars key 

149 # because each TPUReplicateContext will only have one device assignment. 

150 handle = self._replicated_vars.get(handle_id) 

151 if handle is not None: 

152 return handle 

153 

154 if device_assignment is not None and not is_packed: 

155 # Find a variable copy for each replica in the device assignment. 

156 # Note that the order of devices for replicas for the variable and the 

157 # device assignment might not match. 

158 job_name = pydev.DeviceSpec.from_string(vars_[0].device).job 

159 devices_to_vars = {device_util.canonicalize(v.device): v for v in vars_} 

160 replicated_vars = [] 

161 for replica_id in range(device_assignment.num_replicas): 

162 for logical_core in range(device_assignment.num_cores_per_replica): 

163 device = device_util.canonicalize( 

164 device_assignment.tpu_device( 

165 replica=replica_id, logical_core=logical_core, job=job_name)) 

166 if device in devices_to_vars: 

167 replicated_vars.append(devices_to_vars[device]) 

168 break 

169 else: 

170 raise ValueError( 

171 "Failed to find a variable on any device in replica {} for " 

172 "current device assignment".format(replica_id) 

173 ) 

174 else: 

175 replicated_vars = vars_ 

176 

177 # Builds a TPUReplicatedInput node for the variable, if one does not already 

178 # exist. The TPUReplicatedInput node must belong to the enclosing 

179 # control-flow scope of the TPUReplicateContext. 

180 # TODO(phawkins): consider changing the contract of the TPU encapsulation 

181 # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope 

182 # instead. 

183 

184 _, graph = _enclosing_tpu_context_and_graph() 

185 with graph.as_default(): 

186 # If replicated_vars are variables, get the handles. Note that this can be 

187 # done inside TPUReplicateContext because replicated_vars.handle may 

188 # create new ops. 

189 if isinstance(replicated_vars[0], variables.Variable): 

190 replicated_vars = [v.handle for v in replicated_vars] 

191 # pylint: disable=protected-access 

192 saved_context = graph._get_control_flow_context() 

193 graph._set_control_flow_context(self.outer_context) 

194 handle = tpu_ops.tpu_replicated_input( 

195 replicated_vars, 

196 name=name + "/handle", 

197 is_mirrored_variable=is_mirrored, 

198 is_packed=is_packed) 

199 graph._set_control_flow_context(saved_context) 

200 # pylint: enable=protected-access 

201 self._replicated_vars[handle_id] = handle 

202 return handle 

203 

204 def report_unsupported_operations(self) -> None: 

205 if self._unsupported_ops: 

206 op_str = "\n".join( 

207 " %s (%s)" % (op.type, op.name) for op in 

208 self._unsupported_ops[:_MAX_WARNING_LINES]) 

209 logging.warning("%d unsupported operations found: \n%s", 

210 len(self._unsupported_ops), op_str) 

211 if len(self._unsupported_ops 

212 ) > _MAX_WARNING_LINES: 

213 logging.warning("... and %d more", 

214 (len(self._unsupported_ops) - _MAX_WARNING_LINES)) 

215 

216 def EnterGradientColocation(self, op: ops.Operation, gradient_uid: Text): 

217 if op is not None: 

218 if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access 

219 # If we are in TF 2 functions (control flow V2 functions, or 

220 # tf.function()), we need to attach _xla_outside_compilation attribute 

221 # directly because we are not in TPUReplicateContext. 

222 try: 

223 outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii") 

224 except ValueError: 

225 # The attr was not present: do nothing. 

226 return 

227 parts = outside_attr.split(".") 

228 cluster = parts[0] + "." + gradient_uid 

229 self._outside_compilation_v2_context = OutsideCompilationV2Context( 

230 cluster) 

231 self._outside_compilation_v2_context.Enter() 

232 return 

233 self._gradient_colocation_stack.append(op) 

234 if not self._outside_compilation_cluster: 

235 try: 

236 outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii") 

237 if self._in_gradient_colocation: 

238 raise NotImplementedError( 

239 "Cannot nest gradient colocation operations outside compilation" 

240 ) 

241 if gradient_uid == "__unsupported__": 

242 raise NotImplementedError( 

243 "No gradient_uid calling gradient within outside_compilation") 

244 # When we take the gradient of an op X in an outside_compilation 

245 # cluster C in a forward computation we would like to put the ops 

246 # corresponding to the gradient of X into a new outside_compilation 

247 # cluster C'. However, if we take the gradient of X twice, the second 

248 # one should get yet another new outside_compilation cluster C''. 

249 # 

250 # The mechanism we adopt is to use a 'root_cluster' which is the 

251 # cluster that X was in before we took gradients, and a 'gradient_uid' 

252 # which is different for every invocation of gradients, and put the 

253 # gradient of X in cluster 'root_cluster.gradient_uid'. 

254 # 

255 # When taking a gradient of a gradient, some ops will be colocated 

256 # with Op in the forward pass (e.g., cluster root_cluster) and some in 

257 # the backward pass (e.g., cluster root_cluster.initial_gradient_uid). 

258 # We need all of the grad-of-grad ops to be in the same cluster to 

259 # avoid cyclic dependencies between clusters. We adopt a heuristic 

260 # that puts any op clustered with root_cluster.<xxx> in 

261 # root_cluster.gradient_uid, even if xxx was initial_gradient_uid. 

262 self._in_gradient_colocation = op 

263 parts = outside_attr.split(".") 

264 cluster = parts[0] + "." + gradient_uid 

265 self._EnterOutsideCompilationScope(cluster=cluster) 

266 except ValueError: 

267 # The attr was not present: do nothing. 

268 pass 

269 

270 def ExitGradientColocation(self, op: ops.Operation, gradient_uid: Text): 

271 if op is not None: 

272 if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access 

273 # Inside a TF2 tf.function or control flow graph and `op` was not 

274 # marked to be outside compiled. 

275 assert self._outside_compilation_v2_context is None 

276 return 

277 if self._outside_compilation_v2_context is not None: 

278 # Inside a TF2 tf.function or control flow graph and `op` was 

279 # marked to be outside compiled. 

280 self._outside_compilation_v2_context.Exit() 

281 self._outside_compilation_v2_context = None 

282 return 

283 if not self._gradient_colocation_stack: 

284 raise errors.InternalError( 

285 op.node_def, op, 

286 ("Badly nested gradient colocation: " 

287 + f"empty stack when popping Op {op.name}") 

288 ) 

289 last_op = self._gradient_colocation_stack.pop() 

290 if op is last_op: 

291 if op is self._in_gradient_colocation: 

292 self._in_gradient_colocation = None 

293 self._ExitOutsideCompilationScope() 

294 else: 

295 raise errors.InternalError( 

296 op.node_def, op, 

297 ("Badly nested gradient colocation, " + 

298 f"expected {last_op}, got {op.name}") 

299 ) 

300 

301 def _EnterOutsideCompilationScope(self, cluster: Optional[Text] = None): 

302 

303 class FakeOp(object): 

304 """A helper class to determine the current device. 

305 

306 Supports only the type and device set/get methods needed to run the 

307 graph's _apply_device_function method. 

308 """ 

309 

310 def __init__(self): 

311 self._device = "" 

312 

313 @property 

314 def type(self): 

315 return "FakeOp" 

316 

317 @property 

318 def device(self): 

319 return self._device 

320 

321 def _set_device(self, device): 

322 if isinstance(device, pydev.DeviceSpec): 

323 self._device = device.to_string() 

324 else: 

325 self._device = device 

326 

327 def _set_device_from_string(self, device_str): 

328 self._device = device_str 

329 

330 if self._outside_compilation_cluster: 

331 raise NotImplementedError("Cannot nest outside_compilation clusters") 

332 if cluster: 

333 self._outside_compilation_cluster = cluster 

334 else: 

335 self._outside_compilation_cluster = str(self._outside_compilation_counter) 

336 self._outside_compilation_counter += 1 

337 graph = ops.get_default_graph() 

338 fake_op = FakeOp() 

339 graph._apply_device_functions(fake_op) # pylint: disable=protected-access 

340 device = pydev.DeviceSpec.from_string(fake_op.device) 

341 if (device.device_type == "TPU_REPLICATED_CORE" and 

342 device.device_index is not None): 

343 self._host_compute_core.append(self._outside_compilation_cluster + ":" + 

344 str(device.device_index)) 

345 self._oc_dev_fn_stack = graph._device_function_stack # pylint: disable=protected-access 

346 graph._device_function_stack = self._outer_device_function_stack # pylint: disable=protected-access 

347 

348 def _ExitOutsideCompilationScope(self): 

349 if not self._outside_compilation_cluster: 

350 raise ValueError( 

351 "Attempted to exit outside_compilation scope when not in scope") 

352 self._outside_compilation_cluster = None 

353 graph = ops.get_default_graph() 

354 graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access 

355 

356 def Enter(self) -> None: 

357 if not self._outer_device_function_stack: 

358 # Capture the device function stack at the time of first entry 

359 # since that is the stack that will be used outside_compilation. 

360 graph = ops.get_default_graph() 

361 # pylint: disable=protected-access 

362 self._outer_device_function_stack = graph._device_function_stack.copy() 

363 # pylint: enable=protected-access 

364 super(TPUReplicateContext, self).Enter() 

365 

366 def HostComputeCore(self) -> List[Text]: 

367 return self._host_compute_core 

368 

369 def _RemoveExternalControlEdges( 

370 self, 

371 op: ops.Operation) -> Tuple[List[ops.Operation], List[ops.Operation]]: 

372 """Remove any external control dependency on this op.""" 

373 internal_control_inputs = [] 

374 external_control_inputs = [] 

375 for x in op.control_inputs: 

376 # pylint: disable=protected-access 

377 is_internal_op = False 

378 ctxt = x._get_control_flow_context() 

379 while ctxt is not None: 

380 if ctxt == self: 

381 is_internal_op = True 

382 break 

383 ctxt = ctxt._outer_context 

384 if is_internal_op: 

385 internal_control_inputs.append(x) 

386 else: 

387 external_control_inputs.append(x) 

388 # pylint: enable=protected-access 

389 # pylint: disable=protected-access 

390 op._remove_all_control_inputs() 

391 op._add_control_inputs(internal_control_inputs) 

392 # pylint: enable=protected-access 

393 return internal_control_inputs, external_control_inputs 

394 

395 def AddOp(self, op: ops.Operation) -> None: 

396 # pylint: disable=protected-access 

397 if op.type in _DENYLISTED_OPS: 

398 logging.error( 

399 "Operation of type %s (%s) is not supported on the TPU. " 

400 "Execution will fail if this op is used in the graph. ", op.type, 

401 op.name) 

402 

403 if op.type in _UNSUPPORTED_OPS: 

404 self._unsupported_ops.append(op) 

405 

406 if any(x.dtype._is_ref_dtype for x in op.inputs): 

407 raise NotImplementedError( 

408 f"Non-resource Variables are not supported inside TPU computations " 

409 f"(operator name: {op.name})") 

410 

411 # TensorFlowOpLayer may clone nodes that are in tpu.rewrite()s. It'll add 

412 # the "_cloned" attribute and we should continue in that case. 

413 if (_TPU_REPLICATE_ATTR in op.node_def.attr and 

414 "_cloned" not in op.node_def.attr): 

415 raise ValueError(f"TPU computations cannot be nested on op ({op})") 

416 op._set_attr(_TPU_REPLICATE_ATTR, self._tpu_replicate_attr) 

417 if self._outside_compilation_cluster: 

418 op._set_attr( 

419 _OUTSIDE_COMPILATION_ATTR, 

420 attr_value_pb2.AttrValue( 

421 s=compat.as_bytes(self._outside_compilation_cluster))) 

422 if self._num_replicas > 1 or not self._outside_compilation_cluster: 

423 # Prevent feeding or fetching anything that is being compiled, 

424 # and any replicated outside_compilation Op. 

425 op.graph.prevent_feeding(op) 

426 op.graph.prevent_fetching(op) 

427 

428 # Remove any control edges from outer control flow contexts. These may cause 

429 # mismatched frame errors. 

430 (internal_control_inputs, 

431 external_control_inputs) = self._RemoveExternalControlEdges(op) 

432 

433 if not op.inputs: 

434 # Add a control edge from the control pivot to this op. 

435 if not internal_control_inputs: 

436 # pylint: disable=protected-access 

437 op._add_control_input(self.GetControlPivot()) 

438 # pylint: enable=protected-access 

439 else: 

440 for index in range(len(op.inputs)): 

441 x = op.inputs[index] 

442 real_x = self.AddValue(x) 

443 if real_x is not x: 

444 op._update_input(index, real_x) # pylint: disable=protected-access 

445 

446 if external_control_inputs: 

447 # Use an identity to pull control inputs as data inputs. Note that we 

448 # ignore ops which don't have outputs. TODO(phawkins): fix that. 

449 with ops.control_dependencies(None): 

450 self.Enter() 

451 external_control_inputs = [ 

452 array_ops.identity(x.outputs[0]).op 

453 for x in external_control_inputs 

454 if x.outputs 

455 ] 

456 self.Exit() 

457 # pylint: disable=protected-access 

458 op._add_control_inputs(external_control_inputs) 

459 # pylint: enable=protected-access 

460 

461 # Mark op's outputs as seen by this context and any outer contexts. 

462 output_names = [x.name for x in op.outputs] 

463 context = self 

464 while context is not None: 

465 # pylint: disable=protected-access 

466 context._values.update(output_names) 

467 context = context._outer_context 

468 # pylint: enable=protected-access 

469 

470 if self._outer_context: 

471 self._outer_context.AddInnerOp(op) 

472 

473 def AddValue(self, val: core_types.Tensor) -> core_types.Tensor: 

474 """Add `val` to the current context and its outer context recursively.""" 

475 if not self._outer_context: 

476 return val 

477 

478 if val.name in self._values: 

479 # Use the real value if it comes from outer context. 

480 result = self._external_values.get(val.name) 

481 return val if result is None else result 

482 

483 result = val 

484 self._values.add(val.name) 

485 if self._outer_context: 

486 result = self._outer_context.AddValue(val) 

487 self._values.add(result.name) 

488 

489 self._external_values[val.name] = result 

490 

491 return result 

492 

493 def AddInnerOp(self, op: ops.Operation): 

494 self.AddOp(op) 

495 if self._outer_context: 

496 self._outer_context.AddInnerOp(op) 

497 

498 @property 

499 def grad_state(self): 

500 # Define the gradient loop state associated with the TPUReplicateContext to 

501 # be None: the TPUReplicateContext does not get nested, and the grad_state 

502 # outside the TPUReplicateContext does not affect the graph inside, so the 

503 # grad_state should behave as if this were the top-level gradient state. 

504 return None 

505 

506 @property 

507 def back_prop(self): 

508 """Forwards to the enclosing while context, if any.""" 

509 if self.GetWhileContext(): 

510 return self.GetWhileContext().back_prop 

511 return False 

512 

513 def GetControlPivot(self) -> ops.Operation: 

514 return self._pivot 

515 

516 def RequiresUniqueFunctionRetracing(self): 

517 # More context: b/158152827. TPU stack uses the TPUReplicateContext to 

518 # create replicated variable handles and cluster TPU computations, thus we 

519 # always retrace a tf.function when the wrapped TPUReplicateContext changes. 

520 return True 

521 
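# ---------------------------------------------------------------------------
# Editorial sketch (hedged, not part of the original file): roughly how a
# caller such as tpu.split_compile_and_replicate() is expected to drive a
# TPUReplicateContext. The cluster name, replica count, and `computation`
# argument are illustrative placeholders.
def _example_tpu_replicate_context_usage(computation: Callable[[], Any]) -> Any:
  pivot = control_flow_ops.no_op(name="example_cluster/pivot")
  context = TPUReplicateContext(
      name="example_cluster", num_replicas=2, pivot=pivot)
  context.Enter()
  try:
    # Every op built here passes through AddOp()/AddInnerOp() and is tagged
    # with _tpu_replicate="example_cluster".
    outputs = computation()
  finally:
    context.report_unsupported_operations()
    context.Exit()
  return outputs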

522 

523def _enclosing_tpu_context_and_graph() -> Tuple[Any, Any]: 

524 """Returns the TPUReplicateContext and its associated graph.""" 

525 graph = ops.get_default_graph() 

526 while graph is not None: 

527 # pylint: disable=protected-access 

528 context_ = graph._get_control_flow_context() 

529 # pylint: enable=protected-access 

530 while context_ is not None: 

531 if isinstance(context_, TPUReplicateContext): 

532 return context_, graph 

533 context_ = context_.outer_context 

534 graph = getattr(graph, "outer_graph", None) 

535 raise ValueError("get_replicated_var_handle() called without " 

536 "TPUReplicateContext. This shouldn't happen. Please file " 

537 "a bug.") 

538 

539 

540class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext): 

541 """The context for outside compilation in Tensorflow 2.0. 

542 

543 Every op added in this context will be assigned an _xla_outside_compilation 

544 attribute. 

545 """ 

546 

547 def __init__(self, name: Text): 

548 control_flow_ops.ControlFlowContext.__init__(self) 

549 self._name = name 

550 

551 def AddOp(self, op: ops.Operation) -> None: 

552 if self._outer_context: 

553 self._outer_context.AddOp(op) 

554 # pylint: disable=protected-access 

555 op._set_attr("_xla_outside_compilation", 

556 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))) 

557 # pylint: enable=protected-access 

558 

559 def AddInnerOp(self, op: ops.Operation) -> None: 

560 if self._outer_context: 

561 self._outer_context.AddInnerOp(op) 

562 # pylint: disable=protected-access 

563 op._set_attr("_xla_outside_compilation", 

564 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))) 

565 # pylint: enable=protected-access 

566 

567 def to_control_flow_context_def(self, context_def, export_scope=None): 

568 raise NotImplementedError 

569 
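# ---------------------------------------------------------------------------
# Editorial sketch (hedged, not part of the original file): how an
# OutsideCompilationV2Context stamps ops while it is active. This mirrors the
# FuncGraph branch of outside_compilation() below; `build_fn` is a placeholder.
def _example_outside_compilation_v2(build_fn: Callable[[], Any]) -> Any:
  context = OutsideCompilationV2Context("0")
  context.Enter()
  try:
    # AddOp()/AddInnerOp() attach _xla_outside_compilation="0" to every op
    # created by build_fn.
    result = build_fn()
  finally:
    context.Exit()
  return result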

570 

571@tf_export(v1=["tpu.outside_compilation"]) 

572def outside_compilation(computation: Callable[..., Any], *args, 

573 **kwargs) -> Any: 

574 """Builds part of a computation outside any current TPU replicate scope. 

575 

576 `tf.tpu.outside_compilation()` is used to run ops in `computation` on the CPU 

577 instead of on the TPU. For example, users can run ops that are not 

578 supported on TPUs (e.g. tf.summary.write()) by explicitly placing those 

579 ops on CPUs. The usage of outside compilation below places the ops in 

580 `computation_with_string_ops` on the CPU. 

581 

582 Example usage: 

583 

584 ```python 

585 def computation_with_string_ops(x): 

586 # string types are not supported on TPUs and the ops below must 

587 # run on the CPU instead. 

588 output = tf.strings.format('1{}', x) 

589 return tf.strings.to_number(output) 

590 

591 def tpu_computation(): 

592 # Expected output is 11. 

593 output = tf.tpu.outside_compilation(computation_with_string_ops, 1) 

594 ``` 

595 

596 Outside compilation should be called inside TPUReplicateContext. That is, 

597 `tf.tpu.outside_compilation()` should be called inside a function that is 

598 passed to `tpu.split_compile_and_replicate()` -- this is implied when 

599 outside compilation is invoked inside a function passed to TPUStrategy 

600 `run()`. If invoked outside of TPUReplicateContext, 

601 then this simply returns the result of `computation`, and therefore, 

602 would be a no-op. Note that outside compilation is different from 

603 `tf.distribute.experimental.TPUStrategy.merge_call()` as logic in 

604 outside compilation is replicated and executed separately for each 

605 replica. On the other hand, `merge_call()` requires a `merge_fn` 

606 to aggregate the inputs from different replicas and is executed only 

607 once. 

608 

609 For variables placed on the TPU device, which includes variables created inside 

610 the TPUStrategy scope, outside compilation logic must not include variable 

611 reads or writes. For variables placed on the host, which is the case for variables 

612 created via TPUEstimator, variable reads and writes are only allowed if the variable 

613 is not accessed by any other ops in the TPU computation. Variable reads and writes 

614 from the outside compilation cluster are not visible from the TPU computation and 

615 vice versa. Therefore, if the outside compilation logic contains such host 

616 variable read/write ops and the variables are also accessed by the TPU 

617 computation, this may lead to deadlock. 

618 

619 Internally, `tf.tpu.outside_compilation()` adds outside compilation 

620 attributes to all ops in `computation`. During a later graph pass, the 

621 ops with the outside compilation attribute are extracted out and replicated 

622 into a host-side graph. Inputs to this extracted host-side graph are sent 

623 from the TPU computation graph to the host graph via a pair of XlaSendToHost 

624 and XlaRecvFromHost ops. Note that using `tf.tpu.outside_compilation()` 

625 may result in tensor transfers between TPU and CPU, leading to a non-trivial 

626 performance impact. 

627 

628 Args: 

629 computation: A Python function that builds the computation to place on the 

630 host. 

631 *args: the positional arguments for the computation. 

632 **kwargs: the keyword arguments for the computation. 

633 

634 Returns: 

635 The Tensors returned by computation. 

636 """ 

637 args = [] if args is None else args 

638 graph = ops.get_default_graph() 

639 

640 # If we are in TF 2 functions (control flow V2 functions, or tf.function()), 

641 # we need to attach _xla_outside_compilation attribute directly because we are 

642 # not in TPUReplicateContext. 

643 if isinstance(graph, func_graph.FuncGraph): 

644 try: 

645 tpu_context, _ = _enclosing_tpu_context_and_graph() 

646 except ValueError: 

647 logging.warning( 

648 "Outside compilation attempted outside TPUReplicateContext " 

649 "scope. As no enclosing TPUReplicateContext can be found, " 

650 "returning the result of `computation` as is.") 

651 return computation(*args, **kwargs) 

652 

653 # pylint: disable=protected-access 

654 outside_compilation_name = str(tpu_context._outside_compilation_counter) 

655 tpu_context._outside_compilation_counter = ( 

656 tpu_context._outside_compilation_counter + 1) 

657 # pylint: enable=protected-access 

658 

659 outside_compilation_context = OutsideCompilationV2Context( 

660 outside_compilation_name) 

661 outside_compilation_context.Enter() 

662 args = [] if args is None else args 

663 retval = computation(*args, **kwargs) 

664 outside_compilation_context.Exit() 

665 return retval 

666 

667 # If we are in a TPUReplicateContext, signal that we are now 

668 # outside_compilation 

669 initial_context = graph._get_control_flow_context() # pylint: disable=protected-access 

670 context = initial_context 

671 while context: 

672 if isinstance(context, TPUReplicateContext): 

673 context._EnterOutsideCompilationScope() # pylint: disable=protected-access 

674 context = context.outer_context 

675 

676 retval = computation(*args, **kwargs) 

677 

678 # If we are in a TPUReplicateContext, signal that we are no longer 

679 # outside_compilation 

680 final_context = graph._get_control_flow_context() # pylint: disable=protected-access 

681 if initial_context is not final_context: 

682 raise NotImplementedError( 

683 "Control-flow context cannot be different at start and end of an " 

684 "outside_compilation scope") 

685 context = initial_context 

686 while context: 

687 if isinstance(context, TPUReplicateContext): 

688 context._ExitOutsideCompilationScope() # pylint: disable=protected-access 

689 context = context.outer_context 

690 

691 return retval
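
# ---------------------------------------------------------------------------
# Editorial sketch (hedged, not part of the original file): how
# outside_compilation is typically reached in TF2, from a step function passed
# to TPUStrategy.run(). `strategy`, `model_loss`, `step_counter`, and the
# iterator are placeholders assumed to be built elsewhere.
#
#   def host_call(loss):
#     tf.summary.scalar("loss", loss, step=step_counter)  # host-only op
#     return loss
#
#   def step_fn(features):
#     loss = model_loss(features)                     # compiled for the TPU
#     return tf.compat.v1.tpu.outside_compilation(host_call, loss)
#
#   strategy.run(step_fn, args=(next(iterator),))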