Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu.py: 19%

500 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ====================================== 

15 

16"""Library of TPU helper functions.""" 

17 

18import collections 

19import enum 

20from typing import Any, Callable, Iterable, List, Optional, Text, Tuple, Union 

21 

22from absl import logging 

23import numpy as np 

24 

25from tensorflow.compiler.tf2xla.python import xla as tf2xla 

26from tensorflow.core.framework import attr_value_pb2 

27from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding 

28from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as embedding_pb2 

29from tensorflow.python import tf2 

30from tensorflow.python.compiler.xla import xla 

31from tensorflow.python.framework import auto_control_deps 

32from tensorflow.python.framework import composite_tensor 

33from tensorflow.python.framework import config 

34from tensorflow.python.framework import constant_op 

35from tensorflow.python.framework import dtypes 

36from tensorflow.python.framework import func_graph 

37from tensorflow.python.framework import function 

38from tensorflow.python.framework import ops 

39from tensorflow.python.framework import tensor_shape 

40from tensorflow.python.ops import array_ops 

41from tensorflow.python.ops import array_ops_stack 

42from tensorflow.python.ops import cond 

43from tensorflow.python.ops import control_flow_ops 

44from tensorflow.python.ops import math_ops 

45from tensorflow.python.ops import variable_scope 

46from tensorflow.python.tpu import device_assignment as device_assignment_lib 

47from tensorflow.python.tpu import tensor_tracer 

48from tensorflow.python.tpu import tpu_feed 

49from tensorflow.python.tpu import tpu_function 

50from tensorflow.python.tpu import tpu_name_util 

51from tensorflow.python.tpu import tpu_replication 

52from tensorflow.python.tpu.ops import tpu_ops 

53from tensorflow.python.types import core as core_types 

54from tensorflow.python.util import compat 

55from tensorflow.python.util import nest 

56from tensorflow.python.util import object_identity 

57from tensorflow.python.util import traceback_utils 

58from tensorflow.python.util import variable_utils 

59from tensorflow.python.util.tf_export import tf_export 

60 

61 

62ops.NotDifferentiable("TPUReplicatedInput") 

63 

64# Ops which can be safely pruned from XLA compile if they have no consumers. 

65# These ops should also have no inputs. 

66_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"]) 

67 

68_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite" 

69_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status" 

70_PIVOT_FOR_CLUSTER = "_pivot_for_cluster" 

71 

72 

73core = tpu_name_util.core 

74 

75 

76def _tpu_system_device_name(job: Optional[Text]) -> Text: 

77 """Returns the device name for the TPU_SYSTEM device of `job`.""" 

78 if job is None: 

79 return "/device:TPU_SYSTEM:0" 

80 else: 

81 return "/job:%s/device:TPU_SYSTEM:0" % job 

82 

83 

84@tf_export(v1=["tpu.initialize_system"]) 

85def initialize_system( 

86 embedding_config: Optional[embedding_pb2.TPUEmbeddingConfiguration] = None, 

87 job: Optional[Text] = None, 

88 compilation_failure_closes_chips: bool = True, 

89 tpu_cancellation_closes_chips: Optional[bool] = None, 

90) -> core_types.Tensor: 

91 """Initializes a distributed TPU system for use with TensorFlow. 

92 

93 Args: 

94 embedding_config: If not None, a `TPUEmbeddingConfiguration` proto 

95 describing the desired configuration of the hardware embedding lookup 

96 tables. If embedding_config is None, no hardware embeddings can be used. 

97 job: The job (the XXX in TensorFlow device specification /job:XXX) that 

98 contains the TPU devices that will be initialized. If job=None it is 

99 assumed there is only one job in the TensorFlow flock, and an error will 

100 be returned if this assumption does not hold. 

101 compilation_failure_closes_chips: Set the configuration whether 

102 we want to close TPU chips when there is a compilation failure. 

103 tpu_cancellation_closes_chips: Set the configuration whether 

104 we want to close TPU chips when a TPU execution is cancelled. If the value 

105 is None, the behavior will be determined by the command line flag 

106 `tpu_cancellation_closes_chips` for the TPU worker. WARNING: this argument 

107 only applies to TFRT TPU runtime. 

108 Returns: 

109 A serialized `TopologyProto` that describes the TPU system. Note: 

110 the topology must be evaluated using `Session.run` before it can be used. 

111 """ 

112 config_string = ("" if embedding_config is None else 

113 embedding_config.SerializeToString()) 

114 

115 # The enum is defined in core/tpu/kernels/tpu_execute_op_options.h. 

116 tpu_cancellation_closes_chips_enum = 0 

117 if tpu_cancellation_closes_chips is not None: 

118 if tpu_cancellation_closes_chips: 

119 tpu_cancellation_closes_chips_enum = 1 

120 else: 

121 tpu_cancellation_closes_chips_enum = 2 

122 

123 with ops.device(_tpu_system_device_name(job)): 

124 topology = tpu_ops.configure_distributed_tpu( 

125 compilation_failure_closes_chips=compilation_failure_closes_chips, 

126 tpu_cancellation_closes_chips=tpu_cancellation_closes_chips_enum, 

127 ) 

128 

129 if embedding_config is None: 

130 return topology 

131 

132 # This set of control dependencies is needed as this function is expected to 

133 # return an op which will return the topology when executed, but we need to 

134 # call the embedding initialization op between initializing the TPU and 

135 # returning the topology. 

136 with ops.control_dependencies([topology]): 

137 embedding_init = tpu_ops.configure_tpu_embedding(config=config_string) 

138 with ops.control_dependencies([embedding_init]): 

139 return array_ops.identity(topology, name="tpu_init_identity") 

140 

141 

142def initialize_system_for_tpu_embedding( 

143 embedding_config: embedding_pb2.TPUEmbeddingConfiguration, 

144 job: Optional[Text] = None, 

145) -> ops.Operation: 

146 """Initializes a distributed TPU Embedding system for use with TensorFlow. 

147 

148 The following two are equivalent: 

149 1. initialize_system() with embedding_config. 

150 2. initialize_system() without embedding_config, then 

151 initialize_system_for_tpu_embedding(). 

152 initialize_system() should not be called with embedding_config if 

153 initialize_system_for_tpu_embedding() is meant to be called later. 

154 

155 Args: 

156 embedding_config: a `TPUEmbeddingConfiguration` proto describing the desired 

157 configuration of the hardware embedding lookup tables. 

158 job: The job (the XXX in TensorFlow device specification /job:XXX) that 

159 contains the TPU devices that will be initialized. If job=None it is 

160 assumed there is only one job in the TensorFlow flock, and an error will 

161 be returned if this assumption does not hold. 

162 

163 Returns: 

164 A no-op. 

165 """ 

166 config_string = embedding_config.SerializeToString() 

167 with ops.device(_tpu_system_device_name(job)): 

168 return tpu_ops.configure_tpu_embedding(config=config_string) 

169 

170 

171@tf_export(v1=["tpu.shutdown_system"]) 

172def shutdown_system(job: Optional[Text] = None) -> ops.Operation: 

173 """Shuts down a running a distributed TPU system. 

174 

175 Args: 

176 job: The job (the XXX in TensorFlow device specification /job:XXX) that 

177 contains the TPU devices that will be shutdown. If job=None it is 

178 assumed there is only one job in the TensorFlow flock, and an error will 

179 be returned if this assumption does not hold. 

180 """ 

181 with ops.device(_tpu_system_device_name(job)): 

182 shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu() 

183 return shutdown_distributed_tpu 

184 

185 

186@auto_control_deps.register_acd_resource_resolver 

187def tpu_replicated_input_resolver( 

188 op: ops.Operation, 

189 resource_reads: object_identity.ObjectIdentitySet, 

190 resource_writes: object_identity.ObjectIdentitySet) -> bool: 

191 """Replaces TPUReplicatedInput outputs with its inputs in resource_inputs.""" 

192 # Ignore TPUReplicatedInput for ACD purposes since we will be directly adding 

193 # control deps on the replicated inputs. 

194 if op.type == "TPUReplicatedInput": 

195 if resource_reads or resource_writes: 

196 resource_reads.clear() 

197 resource_writes.clear() 

198 return True 

199 else: 

200 return False 

201 # Replace tensors in `resource_inputs` which are outputs of TPUReplicatedInput 

202 # with the actual replicated inputs. This allows ACD to correct add control 

203 # deps when there are multiple calls to `run` in a 

204 # `tf.function`. 

205 def replace_with_unreplicated_resources(resource_inputs): 

206 """Replaces handles in `resource_inputs` with their unreplicated inputs.""" 

207 to_remove = [] 

208 to_add = [] 

209 for resource in resource_inputs: 

210 if resource.op.type == "TPUReplicatedInput": 

211 to_remove.append(resource) 

212 to_add.extend(resource.op.inputs) 

213 for t in to_remove: 

214 resource_inputs.discard(t) 

215 resource_inputs.update(to_add) 

216 return to_add or to_remove 

217 

218 return bool(replace_with_unreplicated_resources(resource_reads) or 

219 replace_with_unreplicated_resources(resource_writes)) 

220 

221 

222@tf_export(v1=["tpu.PaddingSpec"]) 

223class PaddingSpec(enum.IntEnum): 

224 """Represents the type of padding policies for tpu.replicate.""" 

225 # By default the policy is set to AUTO, the dynamic input shape dimension will 

226 # be pad to maximum of all the replicas. 

227 AUTO = 0 

228 # Bucketize the dynamic input shape dimension into a power of 2. 

229 POWER_OF_TWO = 1 

230 

231 

232@tf_export("tpu.XLAOptions") 

233class XLAOptions( 

234 collections.namedtuple("XLAOptions", [ 

235 "use_spmd_for_xla_partitioning", 

236 "enable_xla_dynamic_padder", 

237 ])): 

238 """XLA compilation options. 

239 

240 Attributes: 

241 use_spmd_for_xla_partitioning: Boolean. Whether to use XLA's SPMD 

242 partitioner instead of MPMD partitioner when compiler partitioning is 

243 requested. 

244 enable_xla_dynamic_padder: Boolean. Whether to enable XLA dynamic padder 

245 infrastructure to handle dynamic shapes inputs inside XLA. True by 

246 default. Disabling this may cause correctness issues with dynamic shapes 

247 inputs, as XLA will just assume the inputs are with padded shapes. However 

248 users can optionally set it to False to improve device time if masking is 

249 already handled in the user side. 

250 """ 

251 

252 def __new__(cls, 

253 use_spmd_for_xla_partitioning=True, 

254 enable_xla_dynamic_padder=True): 

255 return super(XLAOptions, cls).__new__(cls, use_spmd_for_xla_partitioning, 

256 enable_xla_dynamic_padder) 

257 

258 

259@tf_export(v1=["tpu.replicate"]) 

260@traceback_utils.filter_traceback 

261def replicate( 

262 computation: Callable[..., Any], 

263 inputs: Optional[List[List[core_types.Tensor]]] = None, 

264 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 

265 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 

266 name: Optional[Text] = None, 

267 maximum_shapes: Optional[Any] = None, 

268 padding_spec: Optional[PaddingSpec] = None, 

269 xla_options: Optional[XLAOptions] = None) -> List[Any]: 

270 """Builds a graph operator that runs a replicated TPU computation. 

271 

272 Example for the basic usage that `inputs` has static shape: 

273 

274 ```python 

275 

276 def computation(x): 

277 x = x + 1 

278 return tf.math.reduce_mean(x) 

279 

280 x = tf.convert_to_tensor([1., 2., 3.]) 

281 y = tf.convert_to_tensor([4., 5., 6.]) 

282 tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]]) 

283 ``` 

284 

285 If the `inputs` has dynamic shapes and you would like to automatically 

286 bucketize the inputs to avoid XLA recompilation. See the advanced example 

287 below: 

288 

289 ```python 

290 

291 def computation(x): 

292 x = x + 1 

293 return tf.math.reduce_mean(x) 

294 

295 # Assume input tensors in two replicas `x` and `y` both have dynamic shape 

296 # ([None, 2]). 

297 tf.compat.v1.tpu.replicate( 

298 computation, 

299 inputs=[x, y], 

300 maximum_shapes=[tf.TensorShape([None, None])], 

301 padding_spec=tf.compat.v1.tpu.PaddingSpec.POWER_OF_TWO) 

302 ``` 

303 

304 Args: 

305 computation: A Python function that builds the computation to replicate. 

306 inputs: A list of lists of input tensors or `None` (equivalent to 

307 `[[]]`), indexed by `[replica_num][input_num]`. All replicas must 

308 have the same number of inputs. Each input can be a nested structure 

309 containing values that are convertible to tensors. Note that passing an 

310 N-dimension list of compatible values will result in a N-dimension list of 

311 scalar tensors rather than a single Rank-N tensors. If you need different 

312 behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. 

313 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 

314 of arguments as inputs to computation. 

315 device_assignment: If not `None`, a `DeviceAssignment` describing the 

316 mapping between logical cores in the computation with physical cores in 

317 the TPU topology. Uses a default device assignment if `None`. The 

318 `DeviceAssignment` may be omitted if each replica of the computation uses 

319 only one core, and there is either only one replica, or the number of 

320 replicas is equal to the number of cores in the TPU system. 

321 name: (Deprecated) Does nothing. 

322 maximum_shapes: A nested structure of tf.TensorShape representing the shape 

323 to which the respective component of each input element in each replica 

324 should be padded. Any unknown dimensions (e.g. 

325 tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like 

326 object) will be padded to the maximum size of that dimension over all 

327 replicas. The structure of `maximum_shapes` needs to be the same as 

328 `inputs[0]`. 

329 padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the 

330 padding policy when the `inputs` to `tpu.replicate` is dynamic. 

331 One usage is to enable automatic bucketizing on the inputs by setting the 

332 value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the 

333 recompilation in the XLA side. 

334 xla_options: An instance of `tpu.XLAOptions` which indicates the options 

335 passed to XLA compiler. Use `None` for default options. 

336 Returns: 

337 A list of outputs, indexed by `[replica_num]` each output can be a nested 

338 structure same as what computation() returns with a few exceptions. 

339 

340 Exceptions include: 

341 1) None output: a NoOp would be returned which control-depends on 

342 computation. 

343 2) Single value output: A tuple containing the value would be returned. 

344 3) Operation-only outputs: a NoOp would be returned which 

345 control-depends on computation. 

346 TODO(b/121383831): Investigate into removing these special cases. 

347 

348 Raises: 

349 ValueError: If all replicas do not have equal numbers of input tensors. 

350 ValueError: If the number of inputs per replica does not match 

351 the number of formal parameters to `computation`. 

352 ValueError: If the static `inputs` dimensions don't match with the values 

353 given in `maximum_shapes`. 

354 ValueError: If the structure of inputs per replica does not match 

355 the structure of `maximum_shapes`. 

356 """ 

357 return split_compile_and_replicate( 

358 computation, 

359 inputs, 

360 infeed_queue, 

361 device_assignment, 

362 name, 

363 maximum_shapes=maximum_shapes, 

364 padding_spec=padding_spec, 

365 xla_options=xla_options)[1] 

366 

367 

368def _ceil_to_pow_of_n(x, n): 

369 """Ceil input `x` to power of `n`.""" 

370 x = math_ops.cast(x, dtypes.float32) 

371 lognx = math_ops.log(x) / math_ops.log(n * 1.0) 

372 lognx = math_ops.ceil(lognx) 

373 result = math_ops.pow(n * 1.0, lognx) 

374 result = math_ops.cast(result, dtypes.int32) 

375 return result 

376 

377 

378def _pad_all_input( 

379 inputs: Iterable[core_types.Tensor], 

380 padded_shapes: List[Optional[tensor_shape.TensorShape]], 

381 padding_spec: PaddingSpec 

382) -> Tuple[List[List[Any]], List[dynamic_padding.PaddingMap]]: 

383 """Pad all input tensors given padded_shapes. 

384 

385 The real shape tensors will be concatenated with the padded original inputs. 

386 

387 Args: 

388 inputs: The original inputs. 

389 padded_shapes: A list of padded shapes for each input. If an entry is None, 

390 no padding is performed. 

391 padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the 

392 padding policy when the `inputs` to `tf.tpu.replicate` is dynamic. 

393 One usage is to enable automatic bucketizing on the inputs by setting the 

394 value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the 

395 recompilation in the XLA side. 

396 

397 Returns: 

398 The padded inputs and a PaddingMap list which maps the padded input 

399 dimension to the real shape argument index. 

400 """ 

401 # maximum_static_shapes[idx][i] indicates the maximum static size of ith 

402 # dimension of the idx input among all the replicas. 

403 maximum_static_shapes = [] 

404 # need_padding[idx][i] indicates whether the ith dimension of the idx input 

405 # needs padding. 

406 need_padding = [] 

407 input_shape_tensors = [] 

408 for core_idx, inputs_per_core in enumerate(inputs): 

409 for idx, input_tensor in enumerate(inputs_per_core): 

410 input_shape = input_tensor.get_shape().as_list() 

411 if core_idx == 0: 

412 input_shape_tensors.append([]) 

413 maximum_static_shapes.append(input_shape) 

414 need_padding.append(np.full_like(input_shape, False, dtype=bool)) 

415 else: 

416 for i, s in enumerate(input_shape): 

417 if s is None or s != maximum_static_shapes[idx][i]: 

418 need_padding[idx][i] = True 

419 maximum_static_shapes[idx] = max(input_shape, 

420 maximum_static_shapes[idx]) 

421 

422 # Append _POST_DEVICE_REWRITE_ATTR attributes to the real shape ops. 

423 real_input_shape = array_ops.shape(input_tensor) 

424 real_input_shape.op._set_attr( # pylint: disable=protected-access 

425 _POST_DEVICE_REWRITE_ATTR, 

426 attr_value_pb2.AttrValue(b=True)) 

427 input_shape_tensors[idx].append(real_input_shape) 

428 

429 maximum_shapes = [] 

430 for shapes_per_input in input_shape_tensors: 

431 maximum_shapes.append( 

432 math_ops.reduce_max(array_ops_stack.stack(shapes_per_input), axis=0)) 

433 

434 padded_inputs = [] 

435 real_shapes = [] 

436 padding_maps = [] 

437 for core_idx, inputs_per_core in enumerate(inputs): 

438 padded_inputs.append([]) 

439 real_shapes.append([]) 

440 real_shape_idx = len(inputs_per_core) - 1 

441 for idx, input_tensor in enumerate(inputs_per_core): 

442 input_shape_tensor = input_shape_tensors[idx][core_idx] 

443 input_shape = input_tensor.get_shape().as_list() 

444 padded_shape = padded_shapes[idx] 

445 

446 # If we have no padded_shape, then skip padding. 

447 if any(need_padding[idx]) and padded_shape is not None: 

448 for i, s in enumerate(input_shape): 

449 if need_padding[idx][i]: 

450 if core_idx == 0: 

451 real_shape_idx += 1 

452 padding_map = dynamic_padding.PaddingMap() 

453 padding_map.arg_index = idx 

454 padding_map.shape_index = i 

455 padding_map.padding_arg_index = real_shape_idx 

456 padding_maps.append(padding_map) 

457 real_shapes[core_idx].append( 

458 math_ops.cast(input_shape_tensor[i], dtypes.int32)) 

459 

460 paddings = [] 

461 for i, s in enumerate(padded_shape.dims): 

462 if need_padding[idx][i]: 

463 # The minimum padded dimension size is 2 as XLA doesn't support size 

464 # 1 dynamic size. 

465 minimum_dynamic_dim_size = 2 

466 if s.value is not None: 

467 # Pad to the given maximum value. 

468 max_dim_size = max(s.value, minimum_dynamic_dim_size) 

469 else: 

470 # If maximum value is not given, then pad to the maximum dimension 

471 # among all the cores. 

472 max_dim_size = math_ops.maximum(maximum_shapes[idx][i], 

473 minimum_dynamic_dim_size) 

474 if padding_spec == PaddingSpec.POWER_OF_TWO: 

475 max_dim_size = _ceil_to_pow_of_n(max_dim_size, 2) 

476 # Pad to the given maximum value. 

477 padding = [0, max_dim_size - input_shape_tensor[i]] 

478 else: 

479 padding = [0, 0] 

480 paddings.append(padding) 

481 

482 if input_tensor.get_shape().is_fully_defined(): 

483 # TODO(rxsang): This is a hack to make sure padded_input has dynamic 

484 # shapes, so any tf.size/tf.shape op performed on it won't be constant 

485 # folded. Do we have better ways to do it? 

486 padded_input = cond.cond( 

487 array_ops.constant(True), 

488 lambda: array_ops.pad(input_tensor, paddings), # pylint: disable=cell-var-from-loop 

489 lambda: input_tensor) 

490 else: 

491 padded_input = array_ops.pad(input_tensor, paddings) 

492 

493 # Append _POST_DEVICE_REWRITE_ATTR attributes to all padded inputs. 

494 padded_input.op._set_attr( # pylint: disable=protected-access 

495 _POST_DEVICE_REWRITE_ATTR, 

496 attr_value_pb2.AttrValue(b=True)) 

497 

498 padded_inputs[core_idx].append(padded_input) 

499 else: 

500 padded_inputs[core_idx].append(input_tensor) 

501 

502 num_replicas = len(padded_inputs) 

503 for i in range(num_replicas): 

504 padded_inputs[i].extend(real_shapes[i]) 

505 

506 return padded_inputs, padding_maps 

507 

508 

509def _flatten_and_filter_composite(maybe_composite, non_composite_output, 

510 composite_output=None): 

511 """For an input, replaced the input by a tuple if the input is composite. 

512 

513 If `maybe_composite` is not composite, return the parameter 

514 `non_composite_output` otherwise return a tuple which consists of the value of 

515 the parameter `composite_output` the same number of times as there are 

516 components of the composite tensor. 

517 

518 This is useful for computing a mask when flattening nested data with 

519 `expand_composites=True`. For example 

520 

521 ```python 

522 nest.flatten(data, expand_composites=True) 

523 ``` 

524 

525 and 

526 

527 ```python 

528 nest.flatten(nest.map( 

529 data, lambda x: _flatten_and_filter_composite(x, False, True))) 

530 ``` 

531 

532 will have the same length and second will be True if the tensor in the first 

533 is derived from a expanding a composite tensor. 

534 

535 Args: 

536 maybe_composite: A value to test for being a composite tensor. 

537 non_composite_output: The value to return when `maybe_composite` is not a 

538 composite. 

539 composite_output: the value to fill the output tuple with if 

540 `maybe_composite` is a composite. 

541 

542 Returns: 

543 `non_composite_output` or a tuple with multiple copies of 

544 `composite_output`. 

545 """ 

546 

547 if isinstance(maybe_composite, composite_tensor.CompositeTensor): 

548 num_components = len(nest.flatten(maybe_composite, expand_composites=True)) 

549 return (composite_output,) * num_components 

550 return non_composite_output 

551 

552 

553def split_compile_and_replicate( 

554 computation: Callable[..., Any], 

555 inputs: Optional[List[List[core_types.Tensor]]] = None, 

556 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 

557 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 

558 name: Optional[Text] = None, 

559 use_tpu: bool = True, 

560 maximum_shapes: Optional[Any] = None, 

561 padding_spec: Optional[PaddingSpec] = None, 

562 xla_options: Optional[XLAOptions] = None, 

563) -> List[List[core_types.Tensor]]: 

564 """Builds graph operators that runs compilation and replicated computation. 

565 

566 This is a lower level interface than replicate that returns a separate compile 

567 and execute output tensor. In the generated graph the compile op feeds into 

568 the execute op and no additional compilation is incurred when running the 

569 compile op before the execute op. The compile op returns additional 

570 information about the compilation but does not return the compiled program. 

571 

572 Args: 

573 computation: A Python function that builds the computation to replicate. 

574 inputs: A list of lists of input tensors or `None` (equivalent to 

575 `[[]]`), indexed by `[replica_num][input_num]`. All replicas must 

576 have the same number of inputs. Each input can be a nested structure 

577 containing values that are convertible to tensors. Note that passing an 

578 N-dimension list of compatible values will result in a N-dimension list of 

579 scalar tensors rather than a single Rank-N tensors. If you need different 

580 behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. 

581 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 

582 of arguments as inputs to computation. 

583 device_assignment: If not `None`, a `DeviceAssignment` describing the 

584 mapping between logical cores in the computation with physical cores in 

585 the TPU topology. Uses a default device assignment if `None`. The 

586 `DeviceAssignment` may be omitted if each replica of the computation uses 

587 only one core, and there is either only one replica, or the number of 

588 replicas is equal to the number of cores in the TPU system. 

589 name: (Deprecated) Does nothing. 

590 use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU 

591 backends. Currently, only supports a default placement (computation is 

592 placed on GPU if one is available, and on CPU if not). 

593 maximum_shapes: A nested structure of tf.TensorShape representing the shape 

594 to which the respective component of each input element in each replica 

595 should be padded. Any unknown dimensions (e.g. 

596 tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like 

597 object) will be padded to the maximum size of that dimension over all 

598 replicas. The structure of `maximum_shapes` needs to be the same as 

599 `inputs[0]`. 

600 padding_spec: An enum specified by `tf.tpu.PaddingSpec`. This describes the 

601 padding policy when the `inputs` to `tf.tpu.replicate` is dynamic. 

602 One usage is to enable automatic bucketizing on the inputs by setting the 

603 value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the 

604 recompilation in the XLA side. 

605 xla_options: An instance of `tpu.XLAOptions` which indicates the options 

606 passed to XLA compiler. Use `None` for default options. 

607 

608 Returns: 

609 A list of lists with the first list corresponding to the compile op and the 

610 second a list of output tensors, indexed by `[replica_num][output_num]`. 

611 Raises: 

612 ValueError: If all replicas do not have equal numbers of input tensors. 

613 ValueError: If the number of inputs per replica does not match 

614 the number of formal parameters to `computation`. 

615 ValueError: If the static `inputs` dimensions don't match with the values 

616 given in `maximum_shapes`. 

617 ValueError: If the structure of inputs per replica does not match 

618 the structure of `maximum_shapes`. 

619 """ 

620 del name 

621 inputs = [[]] if inputs is None else inputs 

622 xla_options = xla_options or XLAOptions() 

623 

624 metadata_kwargs = {} 

625 if device_assignment is not None: 

626 # Turn the Numpy array into a flattened list so we can pass it as an 

627 # operator attribute. 

628 metadata_kwargs = { 

629 "topology": 

630 device_assignment.topology.serialized(), 

631 "device_assignment": 

632 device_assignment.core_assignment.flatten().tolist() 

633 } 

634 metadata_kwargs["num_cores_per_replica"] = ( 

635 device_assignment.num_cores_per_replica) 

636 

637 # This entry is used for enabling automatic outside compilation. 

638 metadata_kwargs["allow_soft_placement"] = config.get_soft_device_placement() 

639 if config.get_soft_device_placement(): 

640 logging.info("Automatic outside compilation is enabled. " 

641 "Ops without XLA kernels will be automatically " 

642 "placed on CPU.") 

643 

644 if not isinstance(inputs, list): 

645 raise TypeError("tpu.replicate() inputs must be a list of lists/tuples, " 

646 f"received {type(inputs)}") 

647 if any(not isinstance(inp, (list, tuple)) for inp in inputs): 

648 raise TypeError( 

649 "tpu.replicate() inputs must be a list of lists/tuples, " 

650 f"received types: {[type(inp) for inp in inputs]}") 

651 

652 num_replicas = len(inputs) 

653 

654 # No replicas? Nothing to do. 

655 if num_replicas == 0: 

656 return [] 

657 

658 # Checks all replicas have the same structure. 

659 for i in range(1, num_replicas): 

660 nest.assert_same_structure(inputs[0], inputs[i]) 

661 

662 # Explicitly read variables. 

663 inputs = variable_utils.convert_variables_to_tensors(inputs) 

664 # Flatten inputs. This structure may contain None values, which will be 

665 # handled later. 

666 flat_inputs_with_nones = [ 

667 nest.flatten(per_replica_input, expand_composites=True) 

668 for per_replica_input in inputs 

669 ] 

670 # Mask parallel to one replica's inputs with True for tensors coming from 

671 # composites. 

672 is_composite = nest.flatten(nest.map_structure( 

673 lambda x: _flatten_and_filter_composite(x, False, True), inputs[0])) 

674 

675 # Converts inputs to Tensors, replacing Nones with a placeholder 0 since 

676 # tpu_ops.tpu_replicated_input() can't handle non-Tensor values. 

677 flat_inputs = [] 

678 for inp in flat_inputs_with_nones: 

679 flat_inputs.append([ 

680 constant_op.constant(0) if x is None else ops.convert_to_tensor(x) 

681 for x in inp 

682 ]) 

683 

684 # Verifies that all replicas have matching numbers and types of inputs 

685 flat_input_types = [x.dtype for x in flat_inputs[0]] 

686 input_arity = len(inputs[0]) 

687 flat_input_arity = len(flat_input_types) 

688 for i in range(num_replicas): 

689 if len(inputs[i]) != input_arity: 

690 raise ValueError("Replicas must have the same number of inputs. " 

691 "Replica 0 had {} inputs, replica {} had {} " 

692 "inputs.".format(input_arity, i, len(inputs[i]))) 

693 

694 types = [x.dtype for x in flat_inputs[i]] 

695 if types != flat_input_types: 

696 raise ValueError("Replicas must have matching input types. Replica 0 had " 

697 "input types {}, replica {} had input types {}".format( 

698 flat_input_types, i, types)) 

699 

700 arg_error = xla.check_function_argument_count( 

701 computation, input_arity, infeed_queue) 

702 if arg_error is not None: 

703 if infeed_queue is None: 

704 raise TypeError( 

705 "Supplied computation cannot be called with the specified inputs. " 

706 f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]}, " 

707 f"but the computation needs {arg_error}") 

708 else: 

709 raise TypeError( 

710 "Supplied computation cannot be called with the specified inputs. " 

711 f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]} ", 

712 f"and {infeed_queue.number_of_tuple_elements} additional inputs " 

713 f"from infeed, but the computation needs {arg_error}") 

714 

715 dynamic_shape_inputs = False 

716 if maximum_shapes: 

717 if infeed_queue: 

718 raise ValueError( 

719 "Dynamic input shapes are not supported with infeed queues") 

720 

721 # Make sure maximum_shapes has the same structure as inputs. 

722 nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False) 

723 

724 # Flatten padded shapes: 

725 # For composite tensor components, we don't want to pad them. For each 

726 # entry of maximum_shapes that corresponds to a composite tensor, replace it 

727 # by a tuple of Nones of the same length as the number of components of the 

728 # composite tensor. When we flatten a second time, this makes 

729 # flat_maximum_shapes have the same length as flat_inputs[i]. We can then 

730 # avoid padding these tensors. The assumption is that they will be used by 

731 # outside compilation or that the components are statically shaped and will 

732 # be used by tpu compatible ops. 

733 flat_maximum_shapes = nest.flatten( 

734 [_flatten_and_filter_composite(x, y) 

735 for x, y in zip(nest.flatten(inputs[0]), 

736 nest.flatten(maximum_shapes))]) 

737 flat_maximum_shapes = [ 

738 tensor_shape.TensorShape(s) if s is not None else None 

739 for s in flat_maximum_shapes 

740 ] 

741 nest.assert_same_structure(flat_inputs[0], flat_maximum_shapes, 

742 check_types=False) 

743 

744 unpadded_inputs = flat_inputs 

745 flat_inputs, padding_maps = _pad_all_input(unpadded_inputs, 

746 flat_maximum_shapes, 

747 padding_spec) 

748 if padding_maps: 

749 dynamic_shape_inputs = True 

750 logging.info("TPU has inputs with dynamic shapes: %s", inputs[0]) 

751 

752 metadata_kwargs["step_marker_location"] = getattr( 

753 computation, "step_marker_location", "STEP_MARK_AT_ENTRY") 

754 metadata_kwargs["use_spmd_for_xla_partitioning"] = \ 

755 xla_options.use_spmd_for_xla_partitioning 

756 

757 graph = ops.get_default_graph() 

758 

759 # Fan-in: Builds a TPUReplicatedInput node for each input. 

760 flat_replicated_inputs = [] 

761 for i in range(0, len(flat_inputs[0])): 

762 replicas = [flat_inputs[replica][i] for replica in range(num_replicas)] 

763 flat_replicated_inputs.append( 

764 tpu_ops.tpu_replicated_input( 

765 replicas, name="input{}".format(i))) 

766 if isinstance(graph, func_graph.FuncGraph): 

767 # When we are in Tensorflow 2.0 function, 'graph' will be a FuncGraph 

768 # object. If both outside graph and this function have a TPU cluster, 

769 # they will have the same cluster name and it will cause problems (because 

770 # we lower functional ops in Tensorflow 2.0). Append function name to 

771 # 'cluster_name' to avoid cluster name collision. 

772 cluster_name = graph.unique_name("cluster_" + graph.name) 

773 else: 

774 cluster_name = graph.unique_name("cluster") 

775 pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") 

776 pivot._set_attr(_PIVOT_FOR_CLUSTER, # pylint: disable=protected-access 

777 attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))) 

778 context = tpu_replication.TPUReplicateContext( 

779 name=cluster_name, num_replicas=num_replicas, pivot=pivot) 

780 try: 

781 context.Enter() 

782 

783 metadata = tpu_ops.tpu_replicate_metadata( 

784 num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs) 

785 

786 with tpu_function.tpu_shard_context( 

787 num_replicas), ops.control_dependencies([metadata]): 

788 

789 if dynamic_shape_inputs and xla_options.enable_xla_dynamic_padder: 

790 for padding_map in padding_maps: 

791 input_shape = flat_replicated_inputs[padding_map.arg_index].shape 

792 flat_replicated_inputs[ 

793 padding_map.arg_index] = tf2xla.set_dynamic_dimension_size( 

794 flat_replicated_inputs[padding_map.arg_index], 

795 padding_map.shape_index, 

796 flat_replicated_inputs[padding_map.padding_arg_index]) 

797 flat_replicated_inputs[padding_map.arg_index].set_shape(input_shape) 

798 

799 # Add identity ops so even unused inputs are "consumed" by the 

800 # computation. This is to avoid orphaned TPUReplicatedInput nodes. 

801 # TODO(phawkins): consider instead pruning unused TPUReplicatedInput 

802 # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. 

803 flat_replicated_inputs = [ 

804 array_ops.identity(x, name="replicated_input_{}".format(i)) 

805 for i, x in enumerate(flat_replicated_inputs) 

806 ] 

807 for i, composite in zip(flat_replicated_inputs, is_composite): 

808 # pylint: disable=protected-access 

809 # Add an attribute to the identity node so that they could be removed in 

810 # encapsulate TPU computation pass if unused. However we don't remove 

811 # inputs when dynamic padding is enabled. 

812 # TODO(rxsang): Use other ways except argument index in padding_map so 

813 # outside compilation can work with dynamic padding correctly. 

814 if not dynamic_shape_inputs or composite: 

815 i.op._set_attr("_tpu_input_identity", 

816 attr_value_pb2.AttrValue(b=True)) 

817 # pylint: enable=protected-access 

818 

819 # Clobber replicated placeholders with Nones. 

820 computation_inputs = [ 

821 None if inp is None else replicated for replicated, inp in zip( 

822 flat_replicated_inputs, flat_inputs_with_nones[0]) 

823 ] 

824 

825 # Unflatten the computation inputs to match original input structure. 

826 computation_inputs = nest.pack_sequence_as( 

827 structure=inputs[0], 

828 flat_sequence=computation_inputs[:flat_input_arity], 

829 expand_composites=True) 

830 

831 # If there is an infeed queue, adds the dequeued values to the 

832 # computation's inputs. 

833 if infeed_queue is not None: 

834 infeed_queue.set_number_of_shards(num_replicas) 

835 for t in infeed_queue.generate_dequeue_op(): 

836 computation_inputs.append(t) 

837 

838 # Only resource variables work inside a TPU computation, so turn on 

839 # resource variables for the computation. 

840 # TODO(phawkins): consider removing this code. It will 

841 # be less confusing to clients if they knowingly choose to use resource 

842 # variables. 

843 # Partitioned variables is not supported (b/112311320). 

844 vscope = variable_scope.get_variable_scope() 

845 saved_use_resource = vscope.use_resource 

846 saved_custom_getter = vscope.custom_getter 

847 

848 def custom_getter(getter, name, *args, **kwargs): 

849 """Variables on TPU have a few restrictions.""" 

850 partitioner = kwargs.get("partitioner", None) 

851 if partitioner is not None: 

852 kwargs["partitioner"] = None 

853 logging.warning( 

854 "Partitioned variables are not supported on TPU. Got " 

855 "`partitioner` that is %s for variable %s. " 

856 "Setting `partitioner` to `None`.", partitioner, name) 

857 if saved_custom_getter is None: 

858 return getter(name, *args, **kwargs) 

859 else: 

860 return saved_custom_getter(getter, name, *args, **kwargs) 

861 

862 vscope.set_use_resource(True) 

863 vscope.set_custom_getter(custom_getter) 

864 

865 outputs = computation(*computation_inputs) 

866 

867 vscope.set_use_resource(saved_use_resource) 

868 vscope.set_custom_getter(saved_custom_getter) 

869 

870 outputs = variable_utils.convert_variables_to_tensors(outputs) 

871 

872 need_spmd_partitioning = ( 

873 xla_options.use_spmd_for_xla_partitioning and 

874 device_assignment is not None and 

875 device_assignment.num_cores_per_replica > 1) 

876 outputs_is_flat = xla.is_flat(outputs) 

877 if outputs_is_flat: 

878 output_tensors, control_deps, pack_template = _postprocess_flat_outputs( 

879 outputs, need_spmd_partitioning) 

880 else: 

881 output_tensors, control_deps, pack_template = ( 

882 _postprocess_non_flat_outputs(outputs, need_spmd_partitioning)) 

883 

884 if tensor_tracer.TensorTracer.is_enabled(): 

885 if tf2.enabled(): 

886 logging.warn("TF API ver >= 2.0 detected. " 

887 "Tensor Tracer v1 is not enabled.") 

888 else: 

889 tt = tensor_tracer.TensorTracer() 

890 output_tensors = tt.trace_tpu(ops.get_default_graph(), 

891 output_tensors, control_deps, 

892 num_replicas) 

893 

894 context.ExitResult(output_tensors) 

895 finally: 

896 context.report_unsupported_operations() 

897 context.Exit() 

898 host_compute_core = context.HostComputeCore() 

899 

900 if host_compute_core: 

901 attr_value = attr_value_pb2.AttrValue() 

902 attr_value.list.s.extend(compat.as_bytes(x) for x in host_compute_core) 

903 metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access 

904 

905 with ops.control_dependencies([metadata]): 

906 if use_tpu: 

907 compile_status = tpu_ops.tpu_compilation_result() 

908 op = compile_status.op 

909 attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) 

910 op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access 

911 else: 

912 compile_status = control_flow_ops.no_op(name="compilation_status") 

913 

914 if not output_tensors: 

915 # Returns a list of NoOps dependent on the replication Op, indexed by 

916 # [replica_num]. 

917 return [ 

918 compile_status, 

919 [ 

920 control_flow_ops.group(control_deps, name="shard_%d" % i) 

921 for i in range(num_replicas) 

922 ] 

923 ] 

924 

925 # Fan-out: Builds a TPUReplicatedOutput node for each output. 

926 replicated_outputs = [[] for i in range(num_replicas)] 

927 for i, t in enumerate(output_tensors): 

928 

929 # None values returned by the computation can't be sent to 

930 # tpu_ops.tpu_replicated_output(), we handle them specially here. We can 

931 # avoid the placeholder 0 routine required on the inputs since outputs are 

932 # replicated per-tensor, not per-replica, so we can skip replication. 

933 if t is None: 

934 for replica in range(num_replicas): 

935 replicated_outputs[replica].append(None) 

936 continue 

937 

938 # Fan-out: Builds a TPUReplicatedOutput node for each output. 

939 ys = tpu_ops.tpu_replicated_output( 

940 t, num_replicas, name="output{}".format(i)) 

941 

942 # Wraps the outputs in identity operators so the names of any possible 

943 # `fetch` nodes are preserved by the replication rewrite. 

944 with ops.control_dependencies(control_deps): 

945 for replica in range(num_replicas): 

946 replicated_outputs[replica].append( 

947 array_ops.identity( 

948 ys[replica], name="output_%d_shard_%d" % (i, replica))) 

949 

950 replicated_outputs = [ 

951 nest.pack_sequence_as(pack_template, replica_outs, expand_composites=True) 

952 for replica_outs in replicated_outputs 

953 ] 

954 

955 return [compile_status, replicated_outputs] 

956 

957 

958def _postprocess_flat_outputs( 

959 outputs: Any, 

960 need_spmd_partitioning: bool 

961) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]: 

962 """Validates non-flat outputs, add backs device assignments and other attrs. 

963 

964 Args: 

965 outputs: Output from `computation` inside `tpu.rewrite`. 

966 need_spmd_partitioning: Whether XLA SPMD partitioning is needed. 

967 

968 Returns: 

969 - Tensors extracted from outputs. 

970 - Operations extracted from outputs. 

971 - A pack template for use with nest.pack_sequence_as to pack the tensors. 

972 """ 

973 # Following code segment is to preserve legacy behavior. Previously we only 

974 # supported flat outputs and thus for consistency it was nice to convert even 

975 # single element into a tuple. But now that we support arbitrary output 

976 # structure, this is no longer necessary. 

977 # TODO(b/121383831): Migrate all legacy use cases and delete this special 

978 # case. 

979 # If the computation returns `None`, make it an empty tuple. 

980 if outputs is None: 

981 outputs = tuple() 

982 

983 # For legacy / backwards compatibility reasons we return a list for "flat" 

984 # output values (even if the user's flat return value was a different type or 

985 # even just a scalar value) so use nest.flatten to compute a flat list pack 

986 # template. 

987 pack_template = nest.flatten(outputs, expand_composites=False) 

988 

989 # Even though outputs is already "flat", we flatten any composites so their 

990 # component tensors can be tagged and replicated. The pack_template will be 

991 # used by the caller to repack the composite tensors. 

992 outputs = nest.flatten(outputs, expand_composites=True) 

993 

994 # Append `no_op` here so that fetching any return value of this function 

995 # will trigger TPUExecute node. 

996 outputs += (control_flow_ops.no_op(),) 

997 

998 maybe_convert = lambda x: None if x is None else ops.convert_to_tensor(x) 

999 try: 

1000 if need_spmd_partitioning: 

1001 outputs = [ 

1002 o if isinstance(o, ops.Operation) else maybe_convert(o) 

1003 for o in outputs 

1004 ] 

1005 else: 

1006 with ops.device(core(0)): 

1007 outputs = [ 

1008 o if isinstance(o, ops.Operation) else maybe_convert(o) 

1009 for o in outputs 

1010 ] 

1011 except Exception as e: 

1012 raise ValueError( 

1013 "TPU function return values must all either be Operations or " 

1014 f"convertible to Tensors. Got error: {e}") 

1015 

1016 # Separates the returned Operations and Tensors. 

1017 output_operations = [o for o in outputs if isinstance(o, ops.Operation)] 

1018 output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] 

1019 

1020 if outputs != output_tensors + output_operations: 

1021 raise ValueError( 

1022 "TPU functions must return zero-or more Tensor values followed by " 

1023 "zero or more Operations.") 

1024 

1025 # Trim operations off the end of the pack template. output_operations has 1 

1026 # extra element due to the no-op that is added. 

1027 if len(output_operations) > 1: 

1028 pack_template = pack_template[:1 - len(output_operations)] 

1029 

1030 # Wraps outputs in Identity ops. Otherwise a replicated input copied 

1031 # straight to an output would bypass the replicate(). This would be bad 

1032 # because the TPUReplicatedInput/TPUReplicatedOutput operator would not 

1033 # be rewritten away, leading to a runtime error. 

1034 # TODO(phawkins): extend the rewrite to elide these nodes instead. 

1035 new_output_tensors = [] 

1036 for t in output_tensors: 

1037 if t is None: 

1038 new_output_tensors.append(None) 

1039 elif need_spmd_partitioning: 

1040 o = array_ops.identity(t) 

1041 # pylint: disable=protected-access 

1042 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 

1043 # pylint: enable=protected-access 

1044 new_output_tensors.append(o) 

1045 else: 

1046 with ops.device(t.device if t.device else core(0)): 

1047 o = array_ops.identity(t) 

1048 # pylint: disable=protected-access 

1049 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 

1050 # pylint: enable=protected-access 

1051 new_output_tensors.append(o) 

1052 return new_output_tensors, output_operations, pack_template 

1053 

1054 

1055def _postprocess_non_flat_outputs( 

1056 outputs: Any, 

1057 need_spmd_partitioning: bool 

1058) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]: 

1059 """Validates non-flat outputs, add backs device assignments and other attrs. 

1060 

1061 Args: 

1062 outputs: Output from `computation` inside `tpu.rewrite`. 

1063 need_spmd_partitioning: Whether XLA SPMD partitioning is needed. 

1064 

1065 Returns: 

1066 - Tensors extracted from outputs. 

1067 - An empty Operations list because Operations are not allowed in non-flat 

1068 outputs. 

1069 - A pack template for use with nest.pack_sequence_as to pack the tensors. 

1070 """ 

1071 

1072 # Flatten output items. 

1073 flat_outputs = nest.flatten(outputs, expand_composites=True) 

1074 

1075 # Convert all non-None non-Operation outputs to Tensors. 

1076 for i, o in enumerate(flat_outputs): 

1077 if o is None: 

1078 flat_outputs[i] = None 

1079 continue 

1080 

1081 if isinstance(o, ops.Operation): 

1082 raise ValueError( 

1083 "tpu.rewrite does not support Operation as return value in non-flat " 

1084 "output structure. You can set returned Operations as control " 

1085 "dependencies of returned Tensors so Operations are triggered when " 

1086 f'Tensors are evaluated. Operation found: "{o.name}"') 

1087 

1088 try: 

1089 o = ops.convert_to_tensor(o) 

1090 except Exception as e: 

1091 raise ValueError( 

1092 "TPU function return values must all either be Operations or " 

1093 f'convertible to Tensors. Got error: "{e}"') 

1094 

1095 # Wraps outputs in Identity ops. Otherwise a replicated input copied 

1096 # straight to an output would bypass the replicate(). This would be bad 

1097 # because the TPUReplicatedInput/TPUReplicatedOutput operator would not 

1098 # be rewritten away, leading to a runtime error. 

1099 # TODO(phawkins): extend the rewrite to elide these nodes instead. 

1100 if need_spmd_partitioning: 

1101 o = array_ops.identity(o) 

1102 # pylint: disable=protected-access 

1103 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 

1104 # pylint: enable=protected-access 

1105 flat_outputs[i] = array_ops.identity(o) 

1106 else: 

1107 with ops.device(o.device if o.device else core(0)): 

1108 o = array_ops.identity(o) 

1109 # pylint: disable=protected-access 

1110 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 

1111 # pylint: enable=protected-access 

1112 flat_outputs[i] = array_ops.identity(o) 

1113 

1114 # All flat_outputs are Tensors, and no Operations. 

1115 return flat_outputs, [], outputs 

1116 

1117 

1118def split_compile_and_shard( 

1119 computation: Callable[..., Any], 

1120 inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None, 

1121 num_shards: int = 1, 

1122 input_shard_axes: Optional[List[int]] = None, 

1123 outputs_from_all_shards: Union[bool, List[bool]] = True, 

1124 output_shard_axes: Optional[List[int]] = None, 

1125 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 

1126 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 

1127 name: Optional[Text] = None, 

1128 xla_options: Optional[XLAOptions] = None, 

1129 ) -> Tuple[ops.Operation, List[core_types.Tensor]]: 

1130 """Shards `computation` for parallel execution. 

1131 

1132 `inputs` must be a list of Tensors or None (equivalent to an empty list), each 

1133 of which has a corresponding split axis (from `input_shard_axes`). Each input 

1134 is split into `num_shards` pieces along the corresponding axis, and 

1135 computation is applied to each shard in parallel. 

1136 

1137 Tensors are broadcast to all shards if they are lexically captured by 

1138 `computation`. e.g., 

1139 

1140 x = tf.constant(7) 

1141 def computation(): 

1142 return x + 3 

1143 ... = shard(computation, ...) 

1144 

1145 If `outputs_from_all_shards` is true, the outputs from all shards of 

1146 `computation` are concatenated back together along their `output_shard_axes`. 

1147 Otherwise, each output is taken from an arbitrary shard. 

1148 

1149 Inputs and outputs of the computation must be at least rank-1 Tensors. 

1150 

1151 Args: 

1152 computation: A Python function that builds a computation to apply to each 

1153 shard of the input. 

1154 inputs: A list of input tensors or None (equivalent to an empty list). Each 

1155 input tensor has a corresponding shard axes, given by `input_shard_axes`, 

1156 which must have size divisible by `num_shards`. 

1157 num_shards: The number of shards. 

1158 input_shard_axes: A list of dimensions along which to shard `inputs`, or 

1159 `None`. `None` means "shard all inputs along dimension 0". If not `None`, 

1160 there must be one dimension per input. 

1161 outputs_from_all_shards: Boolean or list of boolean. For each output, if 

1162 `True`, outputs from all shards are concatenated along the corresponding 

1163 `output_shard_axes` entry. Otherwise, each output is taken 

1164 from an arbitrary shard. If the argument is a boolean, the argument's 

1165 value is used for each output. 

1166 output_shard_axes: A list of dimensions along which to concatenate the 

1167 outputs of `computation`, or `None`. `None` means "concatenate all outputs 

1168 along dimension 0". If not `None`, there must be one dimension per output. 

1169 Ignored if `outputs_from_all_shards` is False. 

1170 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs 

1171 of `computation`. 

1172 device_assignment: If not `None`, a `DeviceAssignment` describing the 

1173 mapping between logical cores in the computation with physical cores in 

1174 the TPU topology. Uses a default device assignment if `None`. The 

1175 `DeviceAssignment` may be omitted if each shard of the computation uses 

1176 only one core, and there is either only one shard, or the number of shards 

1177 is equal to the number of cores in the TPU system. 

1178 name: (Deprecated) Does nothing. 

1179 xla_options: An instance of `tpu.XLAOptions` which indicates the options 

1180 passed to XLA compiler. Use `None` for default options. 

1181 Returns: 

1182 A tuple of (compile op, [output tensors]). 

1183 Raises: 

1184 ValueError: If num_shards <= 0 

1185 ValueError: If len(input_shard_axes) != len(inputs) 

1186 ValueError: If len(output_shard_axes) != len(outputs from `computation`) 

1187 """ 

1188 # TODO(phawkins): consider adding support for broadcasting Tensors passed as 

1189 # inputs. 

1190 

1191 if num_shards <= 0: 

1192 raise ValueError( 

1193 f"num_shards must be a positive integer. Received {num_shards}") 

1194 

1195 inputs = [] if inputs is None else inputs 

1196 if not isinstance(inputs, list): 

1197 raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None. " 

1198 f"Received {type(inputs)}") 

1199 

1200 # Converts inputs to Tensors. 

1201 inputs = [ops.convert_to_tensor(x) for x in inputs] 

1202 

1203 if input_shard_axes is None: 

1204 input_shard_axes = [0] * len(inputs) 

1205 if len(inputs) != len(input_shard_axes): 

1206 raise ValueError("Length of input_shard_axes must be equal to the number " 

1207 f"of inputs. Received {len(inputs)} inputs and " 

1208 f"{len(input_shard_axes)} input_shard_axes.") 

1209 

1210 if inputs: 

1211 # Splits the `inputs` along the corresponding `input_shard_axes`, giving 

1212 # lists with layout [input][shard] 

1213 split_inputs = [ 

1214 array_ops.split(x, num_shards, axis=axis) 

1215 for (axis, x) in zip(input_shard_axes, inputs)] 

1216 

1217 # Transposes the input lists to have layout [shard][input] 

1218 transposed_inputs = [list(i) for i in zip(*split_inputs)] 

1219 else: 

1220 transposed_inputs = [[]] * num_shards 

1221 

1222 compile_op, outputs = split_compile_and_replicate( 

1223 computation, 

1224 transposed_inputs, 

1225 infeed_queue=infeed_queue, 

1226 device_assignment=device_assignment, 

1227 name=name, 

1228 xla_options=xla_options) 

1229 

1230 # There must be at least one shard since num_shards > 0. 

1231 # TODO(b/36647078) remove disable when pylint bug is fixed. 

1232 # pylint: disable=indexing-exception 

1233 if isinstance(outputs[0], ops.Operation): 

1234 # pylint: enable=indexing-exception 

1235 # There were no outputs from the computation and replicate returned a list 

1236 # of NoOps with control dependencies on the computation. Return the first 

1237 # one so it can be used as a control dependency or fetch node. 

1238 # TODO(b/36647078) remove disable when pylint bug is fixed. 

1239 # pylint: disable=indexing-exception 

1240 return compile_op, [outputs[0]] 

1241 # pylint: enable=indexing-exception 

1242 

1243 # TODO(b/36647078) remove disable when pylint bug is fixed. 

1244 # pylint: disable=indexing-exception 

1245 num_outputs = len(outputs[0]) 

1246 # pylint: enable=indexing-exception 

1247 

1248 if output_shard_axes is None: 

1249 output_shard_axes = [0] * num_outputs 

1250 if num_outputs != len(output_shard_axes): 

1251 raise ValueError("Length of output_shard_axes must be equal to the number " 

1252 f"of outputs. Received {num_outputs} outputs " 

1253 f"and {len(output_shard_axes)} output_shard_axes.") 

1254 

1255 if isinstance(outputs_from_all_shards, bool): 

1256 outputs_from_all_shards = [outputs_from_all_shards] * num_outputs 

1257 

1258 if num_outputs != len(outputs_from_all_shards): 

1259 raise ValueError( 

1260 "Length of outputs_from_all_shards must be equal to the number of " 

1261 f"outputs. Received {num_outputs} outputs and " 

1262 f"{len(outputs_from_all_shards)} outputs_from_all_shards.") 

1263 

1264 results = [] 

1265 for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards, 

1266 zip(*outputs)): 

1267 if all_shards: 

1268 # Concatenate all of the outputs together (use stack for scalars). 

1269 shape = x[0].shape 

1270 is_scalar = shape is not None and (shape.ndims == 0) 

1271 results.append((array_ops_stack.stack(list(x)) if is_scalar 

1272 else array_ops.concat(list(x), axis=axis))) 

1273 else: 

1274 # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. 

1275 results.append(x[0]) 

1276 

1277 return compile_op, results 

1278 

1279 
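# Illustrative sketch, not part of the original module: the split/transpose
# above turns a per-input list of shards into a per-shard list of inputs.
# A pure-Python trace, assuming two inputs and num_shards == 3:
#
#   split_inputs      = [[a0, a1, a2],   # input 0 split into 3 shards
#                        [b0, b1, b2]]   # input 1 split into 3 shards
#   transposed_inputs = [list(i) for i in zip(*split_inputs)]
#   # -> [[a0, b0], [a1, b1], [a2, b2]]  # one argument list per shard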

1280@tf_export(v1=["tpu.shard"]) 

1281@traceback_utils.filter_traceback 

1282def shard( 

1283 computation: Callable[..., Any], 

1284 inputs: Optional[List[core_types.Tensor]] = None, 

1285 num_shards: int = 1, 

1286 input_shard_axes: Optional[List[int]] = None, 

1287 outputs_from_all_shards: Union[bool, List[bool]] = True, 

1288 output_shard_axes: Optional[List[int]] = None, 

1289 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 

1290 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 

1291 name: Optional[Text] = None, 

1292 xla_options: Optional[XLAOptions] = None) -> List[core_types.Tensor]: 

1293 """Shards `computation` for parallel execution. 

1294 

1295 `inputs` must be a list of Tensors or None (equivalent to an empty list), each 

1296 of which has a corresponding split axis (from `input_shard_axes`). Each input 

1297 is split into `num_shards` pieces along the corresponding axis, and 

1298 computation is applied to each shard in parallel. 

1299 

1300 Tensors are broadcast to all shards if they are lexically captured by 

1301 `computation`. e.g., 

1302 

1303 x = tf.constant(7) 

1304 def computation(): 

1305 return x + 3 

1306 ... = shard(computation, ...) 

1307 

1308 TODO(phawkins): consider adding support for broadcasting Tensors passed 

1309 as inputs. 

1310 

1311 If `outputs_from_all_shards` is true, the outputs from all shards of 

1312 `computation` are concatenated back together along their `output_shard_axes`. 

1313 Otherwise, each output is taken from an arbitrary shard. 

1314 

1315 Inputs and outputs of the computation must be at least rank-1 Tensors. 

1316 

1317 Args: 

1318 computation: A Python function that builds a computation to apply to each 

1319 shard of the input. 

1320 inputs: A list of input tensors or None (equivalent to an empty list). Each 

1321 input tensor has a corresponding shard axis, given by `input_shard_axes`, 

1322 which must have size divisible by `num_shards`. 

1323 num_shards: The number of shards. 

1324 input_shard_axes: A list of dimensions along which to shard `inputs`, or 

1325 `None`. `None` means "shard all inputs along dimension 0". If not `None`, 

1326 there must be one dimension per input. 

1327 outputs_from_all_shards: Boolean or list of boolean. For each output, if 

1328 `True`, outputs from all shards are concatenated along the corresponding 

1329 `output_shard_axes` entry. Otherwise, each output is taken 

1330 from an arbitrary shard. If the argument is a boolean, the argument's 

1331 value is used for each output. 

1332 output_shard_axes: A list of dimensions along which to concatenate the 

1333 outputs of `computation`, or `None`. `None` means "concatenate all outputs 

1334 along dimension 0". If not `None`, there must be one dimension per output. 

1335 Ignored if `outputs_from_all_shards` is False. 

1336 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs 

1337 of `computation`. 

1338 device_assignment: If not `None`, a `DeviceAssignment` describing the 

1339 mapping between logical cores in the computation and physical cores in 

1340 the TPU topology. Uses a default device assignment if `None`. The 

1341 `DeviceAssignment` may be omitted if each shard of the computation uses 

1342 only one core, and there is either only one shard, or the number of shards 

1343 is equal to the number of cores in the TPU system. 

1344 name: (Deprecated) Does nothing. 

1345 xla_options: An instance of `tpu.XLAOptions` which indicates the options 

1346 passed to the XLA compiler. Use `None` for default options. 

1347 Returns: 

1348 A list of output tensors. 

1349 Raises: 

1350 ValueError: If num_shards <= 0 

1351 ValueError: If len(input_shard_axes) != len(inputs) 

1352 ValueError: If len(output_shard_axes) != len(outputs from `computation`) 

1353 """ 

1354 return split_compile_and_shard( 

1355 computation, 

1356 inputs=inputs, 

1357 num_shards=num_shards, 

1358 input_shard_axes=input_shard_axes, 

1359 outputs_from_all_shards=outputs_from_all_shards, 

1360 output_shard_axes=output_shard_axes, 

1361 infeed_queue=infeed_queue, 

1362 device_assignment=device_assignment, 

1363 name=name, 

1364 xla_options=xla_options)[1] 

1365 

1366 
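# Illustrative usage sketch, not part of the original module. It assumes TF1
# graph mode, an already-initialized TPU system, and uses the public export
# `tf.compat.v1.tpu.shard`; the helper name `_example_shard_matmul` is
# hypothetical.
def _example_shard_matmul():
  import tensorflow.compat.v1 as tf  # assumed available in the environment

  def matmul_fn(x):
    w = tf.ones([16, 4])
    return tf.matmul(x, w)

  x = tf.placeholder(tf.float32, shape=[8, 16])
  # Each of the 2 shards receives a [4, 16] slice of `x`; the per-shard
  # [4, 4] results are concatenated back into a single [8, 4] tensor.
  outputs = tf.tpu.shard(matmul_fn, inputs=[x], num_shards=2)
  return outputs[0]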

1367@tf_export(v1=["tpu.batch_parallel"]) 

1368@traceback_utils.filter_traceback 

1369def batch_parallel( 

1370 computation: Callable[..., Any], 

1371 inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None, 

1372 num_shards: int = 1, 

1373 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 

1374 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 

1375 name: Optional[Text] = None, 

1376 xla_options: Optional[XLAOptions] = None): 

1377 """Shards `computation` along the batch dimension for parallel execution. 

1378 

1379 Convenience wrapper around shard(). 

1380 

1381 `inputs` must be a list of Tensors or None (equivalent to an empty list). 

1382 Each input is split into `num_shards` pieces along the 0-th dimension, and 

1383 computation is applied to each shard in parallel. 

1384 

1385 Tensors are broadcast to all shards if they are lexically captured by 

1386 `computation`. e.g., 

1387 

1388 x = tf.constant(7) 

1389 def computation(): 

1390 return x + 3 

1391 ... = batch_parallel(computation, ...) 

1392 

1393 The outputs from all shards are concatenated back together along their 0-th 

1394 dimension. 

1395 

1396 Inputs and outputs of the computation must be at least rank-1 Tensors. 

1397 

1398 Args: 

1399 computation: A Python function that builds a computation to apply to each 

1400 shard of the input. 

1401 inputs: A list of input tensors or None (equivalent to an empty list). The 

1402 0-th dimension of each Tensor must have size divisible by `num_shards`. 

1403 num_shards: The number of shards. 

1404 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 

1405 of arguments as inputs to `computation`. 

1406 device_assignment: If not `None`, a `DeviceAssignment` describing the 

1407 mapping between logical cores in the computation and physical cores in 

1408 the TPU topology. Uses a default device assignment if `None`. The 

1409 `DeviceAssignment` may be omitted if each shard of the computation uses 

1410 only one core, and there is either only one shard, or the number of shards 

1411 is equal to the number of cores in the TPU system. 

1412 name: (Deprecated) Does nothing. 

1413 xla_options: An instance of `tpu.XLAOptions` which indicates the options 

1414 passed to the XLA compiler. Use `None` for default options. 

1415 Returns: 

1416 A list of output tensors. 

1417 Raises: 

1418 ValueError: If `num_shards <= 0` 

1419 """ 

1420 return shard( 

1421 computation, 

1422 inputs, 

1423 num_shards=num_shards, 

1424 infeed_queue=infeed_queue, 

1425 device_assignment=device_assignment, 

1426 name=name, 

1427 xla_options=xla_options) 

1428 

1429 
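# Illustrative sketch, not part of the original module: since `batch_parallel`
# only forwards to `shard`, the two calls below are expected to build the same
# computation. TF1 graph mode and an initialized TPU system are assumed; the
# helper name and its arguments are hypothetical.
def _example_batch_parallel_equivalence(embed_fn, features):
  import tensorflow.compat.v1 as tf  # assumed available in the environment

  # Splits `features` into 4 pieces along dimension 0 and concatenates the
  # per-shard outputs back along dimension 0.
  out_a = tf.tpu.batch_parallel(embed_fn, inputs=[features], num_shards=4)
  # The same computation spelled out through `shard` with explicit axes.
  out_b = tf.tpu.shard(
      embed_fn,
      inputs=[features],
      num_shards=4,
      input_shard_axes=[0],
      output_shard_axes=[0])
  return out_a, out_b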

1430@tf_export(v1=["tpu.rewrite"]) 

1431@traceback_utils.filter_traceback 

1432def rewrite( 

1433 computation: Callable[..., Any], 

1434 inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None, 

1435 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 

1436 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 

1437 name: Optional[Text] = None, 

1438 xla_options: Optional[XLAOptions] = None) -> Any: 

1439 """Rewrites `computation` for execution on a TPU system. 

1440 

1441 Args: 

1442 computation: A Python function that builds a computation to apply to the 

1443 input. If the function takes n inputs, 'inputs' should be a list of n 

1444 tensors. 

1445 

1446 `computation` may return a list of operations and tensors. Tensors must 

1447 come before operations in the returned list. The return value of 

1448 `rewrite` is a list of tensors corresponding to the tensors from the 

1449 output of `computation`. 

1450 

1451 All `Operation`s constructed during `computation` will be executed when 

1452 evaluating any of the returned output tensors, not only the operations that are returned. 

1453 inputs: A list of input tensors or `None` (equivalent to an empty list). 

1454 Each input can be a nested structure containing values that are 

1455 convertible to tensors. Note that passing an N-dimension list of 

1456 compatible values will result in an N-dimension list of scalar tensors 

1457 rather than a single rank-N tensor. If you need different behavior, 

1458 convert part of inputs to tensors with `tf.convert_to_tensor`. 

1459 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 

1460 of arguments as inputs to `computation`. 

1461 device_assignment: if not `None`, a `DeviceAssignment` describing the 

1462 mapping between logical cores in the computation and physical cores in 

1463 the TPU topology. May be omitted for a single-core computation, in which 

1464 case the core attached to task 0, TPU device 0 is used. 

1465 name: (Deprecated) Does nothing. 

1466 xla_options: An instance of `tpu.XLAOptions` which indicates the options 

1467 passed to the XLA compiler. Use `None` for default options. 

1468 Returns: 

1469 Same data structure as if computation(*inputs) is called directly with some 

1470 exceptions for correctness. Exceptions include: 

1471 1) None output: a NoOp would be returned which control-depends on 

1472 computation. 

1473 2) Single value output: A tuple containing the value would be returned. 

1474 3) Operation-only outputs: a NoOp would be returned which 

1475 control-depends on computation. 

1476 TODO(b/121383831): Investigate into removing these special cases. 

1477 """ 

1478 # TODO(b/36647078) remove disable when pylint bug is fixed. 

1479 # pylint: disable=indexing-exception 

1480 return replicate( 

1481 computation, 

1482 None if inputs is None else [inputs], 

1483 infeed_queue=infeed_queue, 

1484 device_assignment=device_assignment, 

1485 name=name, 

1486 xla_options=xla_options)[0] 

1487 # pylint: enable=indexing-exception 

1488 
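# Illustrative usage sketch, not part of the original module: a single-core
# use of the public export `tf.compat.v1.tpu.rewrite`. TF1 graph mode and an
# initialized TPU system are assumed; the helper name is hypothetical.
def _example_rewrite_add():
  import tensorflow.compat.v1 as tf  # assumed available in the environment

  def add_fn(x, y):
    return x + y

  x = tf.constant([1.0, 2.0])
  y = tf.constant([3.0, 4.0])
  # Per the docstring above, a single-value output comes back wrapped in a
  # tuple, so the sum is `result[0]`.
  result = tf.tpu.rewrite(add_fn, inputs=[x, y])
  return result[0]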

1489 # Operations that indicate some error in the user's inference graph. 

1490 

1491 

1492_DENYLISTED_INFERENCE_OPS = set([ 

1493 "ReadVariableOp", 

1494 "AssignVariableOp", 

1495 "AssignAddVariableOp", 

1496 "AssignSubVariableOp", 

1497 "VarHandleOp", 

1498 "Variable", 

1499 "VariableV2", 

1500]) 

1501 

1502 

1503def under_tpu_inference_context() -> bool: 

1504 """Check if it is currently under `_TPUInferenceContext`.""" 

1505 graph = ops.get_default_graph() 

1506 while graph: 

1507 context = graph._get_control_flow_context() # pylint: disable=protected-access 

1508 while context: 

1509 if isinstance(context, _TPUInferenceContext): 

1510 return True 

1511 context = context.outer_context 

1512 if isinstance(graph, function._FuncGraph): # pylint: disable=protected-access 

1513 graph = graph._outer_graph # pylint: disable=protected-access 

1514 elif isinstance(graph, func_graph.FuncGraph): 

1515 graph = graph.outer_graph 

1516 else: 

1517 return False 

1518 

1519 

1520class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): 

1521 """A `ControlFlowContext` for nodes inside a TPU inference computation. 

1522 

1523 The primary role of `_TPUInferenceContext` is to indicate the mode of 

1524 operation and possibly sanity check operators inside a 

1525 tpu.rewrite_for_inference() computation. 

1526 """ 

1527 

1528 def __init__(self, name: Text, check_ops: bool = True): 

1529 super(_TPUInferenceContext, self).__init__() 

1530 self._name = name 

1531 self._check_ops = check_ops 

1532 

1533 def AddOp(self, op): 

1534 self._AddOpInternal(op) 

1535 

1536 def _AddOpInternal(self, op): 

1537 # pylint: disable=protected-access 

1538 if self._check_ops and op.type in _DENYLISTED_INFERENCE_OPS: 

1539 raise NotImplementedError( 

1540 f"Operation of type {op.type} ({op.name}) is not supported on the " 

1541 "TPU for inference. Execution will fail if this op is used in the " 

1542 "graph. Make sure your variables are using variable_scope.") 

1543 if self._outer_context: 

1544 self._outer_context.AddInnerOp(op) 

1545 

1546 def AddValue(self, val): 

1547 result = val 

1548 if self._outer_context: 

1549 result = self._outer_context.AddValue(val) 

1550 return result 

1551 

1552 def AddInnerOp(self, op): 

1553 self._AddOpInternal(op) 

1554 

1555 @property 

1556 def grad_state(self): 

1557 return None 

1558 

1559 

1560def validate_inference_rewrite_for_variables(graph: ops.Graph): 

1561 """Validates whether rewrite_for_inference() 'worked' for variables. 

1562 

1563 The rewrite_for_inference() method is supposed to append GuaranteeConstOps 

1564 after ReadVariableOps, but this mechanism works only if you are using 

1565 tf.compat.v1.get_variable() to create and access variables in your tpu 

1566 computation. This validation method can be called immediately after calling 

1567 tpu.rewrite_for_inference() to check whether GuaranteeConstOps were added 

1568 to the graph. 

1569 

1570 Typical usages: 

1571 tpu.validate_inference_rewrite_for_variables( 

1572 tf.compat.v1.get_default_graph()) 

1573 

1574 tpu.validate_inference_rewrite_for_variables(sess.graph) 

1575 

1576 Args: 

1577 graph: The graph which needs to be validated. 

1578 Raises: 

1579 RuntimeError: if validation failed. 

1580 """ 

1581 if not any(x.type == "GuaranteeConst" for x in graph.get_operations()): 

1582 raise RuntimeError( 

1583 "No GuaranteeConst ops found in the graph after running " 

1584 "tpu.rewrite_for_inference(...). Please check that you are using " 

1585 "tf.get_variable() to create and access variables in your tpu " 

1586 "computation.") 

1587 

1588 

1589def rewrite_for_inference( 

1590 computation: Callable[..., Any], 

1591 inputs: Optional[List[core_types.Tensor]] = None, 

1592 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 

1593 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 

1594 name: Optional[Text] = None) -> List[core_types.Tensor]: 

1595 """Rewrites `computation` for inference on a TPU system. 

1596 

1597 In addition to 'rewriting' the computation to run on a TPU, if your 

1598 computation uses variables, this function moves the ReadVariableOps outside 

1599 the TPU computation and adds GuaranteeConst ops just after them. 

1600 This mechanism works only if you are using tf.compat.v1.get_variable() to 

1601 create and access variables in your tpu computation. You can validate 

1602 whether this worked, by calling validate_inference_rewrite_for_variables() 

1603 method immediately after this method to check whether GuaranteeConstOps 

1604 where added to the graph. 

1605 

1606 Args: 

1607 computation: A Python function that builds a computation to apply to the 

1608 input. If the function takes n inputs, 'inputs' should be a list of n 

1609 tensors. If the function returns m outputs, rewrite will return a list of 

1610 m tensors. 

1611 inputs: A list of input tensors or `None` (equivalent to an empty list). 

1612 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 

1613 of arguments as inputs to `computation`. 

1614 device_assignment: if not `None`, a `DeviceAssignment` describing the 

1615 mapping between logical cores in the computation and physical cores in 

1616 the TPU topology. May be omitted for a single-core computation, in which 

1617 case the core attached to task 0, TPU device 0 is used. 

1618 name: The name of the operator. 

1619 Returns: 

1620 A list of output tensors. 

1621 """ 

1622 

1623 def guarantee_const_getter(getter, name, *args, **kwargs): 

1624 with ops.control_dependencies(None): 

1625 return array_ops.guarantee_const( 

1626 getter(name, *args, **kwargs), name=name + "/GuaranteeConst") 

1627 

1628 def wrapped_computation(*args, **kwargs): 

1629 """Execute computation under `_TPUInferenceContext`.""" 

1630 context = _TPUInferenceContext( 

1631 name=ops.get_default_graph().unique_name("rewrite_for_inference")) 

1632 try: 

1633 context.Enter() 

1634 

1635 vscope = variable_scope.get_variable_scope() 

1636 prev_custom_getter = vscope.custom_getter 

1637 prev_caching_device = vscope.caching_device 

1638 vscope.set_custom_getter(guarantee_const_getter) 

1639 vscope.set_caching_device(lambda op: op.device) 

1640 

1641 result = computation(*args, **kwargs) 

1642 

1643 vscope.set_custom_getter(prev_custom_getter) 

1644 vscope.set_caching_device(prev_caching_device) 

1645 finally: 

1646 context.Exit() 

1647 return result 

1648 

1649 # pylint: disable=undefined-variable 

1650 return rewrite( 

1651 wrapped_computation, 

1652 inputs=inputs, 

1653 infeed_queue=infeed_queue, 

1654 device_assignment=device_assignment, 

1655 name=name) 

1656 # pylint: enable=undefined-variable 

1657 

1658 
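# Illustrative usage sketch, not part of the original module: rewriting an
# inference graph and immediately validating that GuaranteeConst ops were
# inserted. TF1 graph mode, variables created via tf.compat.v1.get_variable(),
# and an initialized TPU system are assumed; the helper name and its arguments
# are hypothetical.
def _example_rewrite_for_inference(serve_fn, features):
  import tensorflow.compat.v1 as tf  # assumed available in the environment

  outputs = rewrite_for_inference(serve_fn, inputs=[features])
  # Raises RuntimeError if no GuaranteeConst ops were added, e.g. because the
  # variables in `serve_fn` were not created through tf.get_variable().
  validate_inference_rewrite_for_variables(tf.get_default_graph())
  return outputs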

1659def prune_unconnected_ops_from_xla(prune_graph: ops.Graph): 

1660 """Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE. 

1661 

1662 Args: 

1663 prune_graph: A tensorflow graph from which we wish to prune unconnected ops 

1664 as listed in _UNCONNECTED_OPS_TO_PRUNE. In general, these ops should have 

1665 no inputs and no consumers. These can often be left behind due to graph 

1666 construction rewiring (for instance TF-Hub). While they never execute, 

1667 they will cause XLA compilation to fail, so we exclude them from XLA 

1668 compilation by removing the tpu_replicate attribute. 

1669 """ 

1670 # Scan over the top level graph and all function graphs. 

1671 for graph in [prune_graph] + [ 

1672 f for f in prune_graph._functions.values() # pylint: disable=protected-access 

1673 ]: 

1674 if not isinstance(graph, ops.Graph): 

1675 continue 

1676 for op in graph.get_operations(): 

1677 if op.type not in _UNCONNECTED_OPS_TO_PRUNE: 

1678 continue 

1679 outputs_consumed = False 

1680 for output in op.outputs: 

1681 if output.consumers(): 

1682 outputs_consumed = True 

1683 break 

1684 if not outputs_consumed: 

1685 logging.info( 

1686 "Pruning OP %s of type %s from XLA Compile due to " 

1687 "it being disconnected.", op.name, op.type) 

1688 op._clear_attr(tpu_replication._TPU_REPLICATE_ATTR) # pylint: disable=protected-access
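# Illustrative note, not part of the original module: the loop above treats an
# op as connected as soon as any of its outputs has a consumer. The same check
# written as a one-liner for a hypothetical `op`:
#
#   is_disconnected = not any(output.consumers() for output in op.outputs)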