Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/tpu_strategy.py: 22%

667 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""TPU Strategy.""" 

16 

17import atexit 

18import collections 

19import contextlib 

20import copy 

21import functools 

22import weakref 

23 

24from absl import logging 

25import numpy as np 

26 

27from tensorflow.python.autograph.core import ag_ctx as autograph_ctx 

28from tensorflow.python.autograph.impl import api as autograph 

29from tensorflow.python.compiler.xla.experimental import xla_sharding 

30from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib 

31from tensorflow.python.distribute import device_util 

32from tensorflow.python.distribute import distribute_lib 

33from tensorflow.python.distribute import distribute_utils 

34from tensorflow.python.distribute import input_lib 

35from tensorflow.python.distribute import input_util 

36from tensorflow.python.distribute import numpy_dataset 

37from tensorflow.python.distribute import reduce_util 

38from tensorflow.python.distribute import tpu_replicated_variable 

39from tensorflow.python.distribute import tpu_util 

40from tensorflow.python.distribute import tpu_values 

41from tensorflow.python.distribute import values 

42from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver 

43from tensorflow.python.distribute.v1 import input_lib as input_lib_v1 

44from tensorflow.python.eager import context 

45from tensorflow.python.eager import def_function 

46from tensorflow.python.eager import function 

47from tensorflow.python.framework import constant_op 

48from tensorflow.python.framework import device as tf_device 

49from tensorflow.python.framework import device_spec 

50from tensorflow.python.framework import dtypes 

51from tensorflow.python.framework import indexed_slices 

52from tensorflow.python.framework import ops 

53from tensorflow.python.framework import sparse_tensor 

54from tensorflow.python.framework import tensor_shape 

55from tensorflow.python.framework import tensor_util 

56from tensorflow.python.ops import array_ops 

57from tensorflow.python.ops import control_flow_ops 

58from tensorflow.python.ops import math_ops 

59from tensorflow.python.ops import resource_variable_ops 

60from tensorflow.python.ops import variables as variables_lib 

61from tensorflow.python.ops.ragged import ragged_tensor 

62from tensorflow.python.tpu import device_assignment as device_assignment_lib # pylint: disable=unused-import 

63from tensorflow.python.tpu import tpu 

64from tensorflow.python.tpu import tpu_hardware_feature 

65from tensorflow.python.tpu import tpu_strategy_util 

66from tensorflow.python.tpu import training_loop 

67from tensorflow.python.tpu.ops import tpu_ops 

68from tensorflow.python.util import deprecation 

69from tensorflow.python.util import nest 

70from tensorflow.python.util import tf_inspect 

71from tensorflow.python.util.tf_export import tf_export 

72 

73 

74_XLA_OP_BY_OP_INPUTS_LIMIT = 200 

75 

76 

77@contextlib.contextmanager 

78def maybe_init_scope(): 

79 if ops.executing_eagerly_outside_functions(): 

80 yield 

81 else: 

82 with ops.init_scope(): 

83 yield 

84 

85 

86def validate_run_function(fn): 

87 """Validate the function passed into strategy.run.""" 

88 

89 # We allow three types of functions/objects passed into TPUStrategy 

90 # run in eager mode: 

91 # 1. a user annotated tf.function 

92 # 2. a ConcreteFunction, this is mostly what you get from loading a saved 

93 # model. 

94 # 3. a callable object and the `__call__` method itself is a tf.function. 

95 # 

96 # Otherwise we return an error, because we don't support eagerly running 

97 # run in TPUStrategy. 

98 

99 if context.executing_eagerly() \ 

100 and not isinstance(fn, def_function.Function) \ 

101 and not isinstance(fn, function.ConcreteFunction) \ 

102 and not (callable(fn) and isinstance(fn.__call__, def_function.Function)): 

103 raise NotImplementedError( 

104 "TPUStrategy.run(fn, ...) does not support pure eager " 

105 "execution. please make sure the function passed into " 

106 "`strategy.run` is a `tf.function` or " 

107 "`strategy.run` is called inside a `tf.function` if " 

108 "eager behavior is enabled.") 

109 

110 

111def _maybe_partial_apply_variables(fn, args, kwargs): 

112 """Inspects arguments to partially apply any DistributedVariable. 

113 

114 This avoids an automatic cast of the current variable value to tensor. 

115 

116 Note that a variable may be captured implicitly with Python scope instead of 

117 passing it to run(), but supporting run() keeps behavior consistent 

118 with MirroredStrategy. 

119 

120 Since positional arguments must be applied from left to right, this function 

121 does some tricky function inspection to move variable positional arguments 

122 into kwargs. As a result of this, we can't support passing Variables as *args, 

123 nor as args to functions which combine both explicit positional arguments and 

124 *args. 

125 

126 Args: 

127 fn: The function to run, as passed to run(). 

128 args: Positional arguments to fn, as passed to run(). 

129 kwargs: Keyword arguments to fn, as passed to run(). 

130 

131 Returns: 

132 A tuple of the function (possibly wrapped), args, kwargs (both 

133 possibly filtered, with members of args possibly moved to kwargs). 

134 If no variables are found, this function is a noop. 

135 

136 Raises: 

137 ValueError: If the function signature makes unsupported use of *args, or if 

138 too many arguments are passed. 

139 """ 

140 

141 def is_distributed_var(x): 

142 flat = nest.flatten(x) 

143 return flat and isinstance(flat[0], values.DistributedVariable) 

144 

145 # We will split kwargs into two dicts, one of which will be applied now. 

146 var_kwargs = {} 

147 nonvar_kwargs = {} 

148 

149 if kwargs: 

150 var_kwargs = {k: v for k, v in kwargs.items() if is_distributed_var(v)} 

151 if var_kwargs: 

152 nonvar_kwargs = { 

153 k: v for k, v in kwargs.items() if not is_distributed_var(v) 

154 } 

155 

156 # Dump the argument names of `fn` to a list. This will include both positional 

157 # and keyword arguments, but since positional arguments come first we can 

158 # look up names of positional arguments by index. 

159 positional_args = [] 

160 index_of_star_args = None 

161 for i, p in enumerate(tf_inspect.signature(fn).parameters.values()): 

162 # Class methods define "self" as first argument, but we don't pass "self". 

163 # Note that this is a heuristic, as a method can name its first argument 

164 # something else, and a function can define a first argument "self" as well. 

165 # In both of these cases, using a Variable will fail with an unfortunate 

166 # error about the number of arguments. 

167 # inspect.is_method() seems not to work here, possibly due to the use of 

168 # tf.function(). 

169 if i == 0 and p.name == "self": 

170 continue 

171 

172 if p.kind == tf_inspect.Parameter.POSITIONAL_OR_KEYWORD: 

173 positional_args.append(p.name) 

174 

175 elif p.kind == tf_inspect.Parameter.VAR_POSITIONAL: 

176 # We'll raise an error later if a variable is passed to *args, since we 

177 # can neither pass it by name nor partially apply it. This case only 

178 # happens once at most. 

179 index_of_star_args = i 

180 

181 elif p.kind == tf_inspect.Parameter.POSITIONAL_ONLY: 

182 # This is a rare Python feature, indicating a / in the arg list. 

183 if var_kwargs or any(is_distributed_var(a) for a in args): 

184 raise ValueError( 

185 "Mixing Variables and positional-only parameters not supported by " 

186 f"TPUStrategy. Received {len(var_kwargs)} DistributedVariables in " 

187 f"**kwargs and {sum(is_distributed_var(a) for a in args)} in *args," 

188 " expected zero for both." 

189 ) 

190 return fn, args, kwargs 

191 

192 star_args = [] 

193 have_seen_var_arg = False 

194 

195 for i, a in enumerate(args): 

196 if is_distributed_var(a): 

197 if index_of_star_args is not None and i >= index_of_star_args: 

198 raise ValueError( 

199 "TPUStrategy.run() cannot handle Variables passed to *args. " 

200 "Either name the function argument, or capture the Variable " 

201 "implicitly.") 

202 if len(positional_args) <= i: 

203 raise ValueError( 

204 "Too many positional arguments passed to call to TPUStrategy.run()." 

205 ) 

206 var_kwargs[positional_args[i]] = a 

207 have_seen_var_arg = True 

208 else: 

209 if index_of_star_args is not None and i >= index_of_star_args: 

210 if have_seen_var_arg: 

211 raise ValueError( 

212 "TPUStrategy.run() cannot handle both Variables and a mix of " 

213 "positional args and *args. Either remove the *args, or capture " 

214 "the Variable implicitly.") 

215 else: 

216 star_args.append(a) 

217 continue 

218 

219 if len(positional_args) <= i: 

220 raise ValueError( 

221 "Too many positional arguments passed to call to TPUStrategy.run()." 

222 ) 

223 nonvar_kwargs[positional_args[i]] = a 

224 

225 if var_kwargs: 

226 return functools.partial(fn, **var_kwargs), star_args, nonvar_kwargs 

227 return fn, args, kwargs 

228 

229 

230@tf_export("distribute.TPUStrategy", v1=[]) 

231class TPUStrategyV2(distribute_lib.Strategy): 

232 """Synchronous training on TPUs and TPU Pods. 

233 

234 To construct a TPUStrategy object, you need to run the 

235 initialization code as below: 

236 

237 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 

238 >>> tf.config.experimental_connect_to_cluster(resolver) 

239 >>> tf.tpu.experimental.initialize_tpu_system(resolver) 

240 >>> strategy = tf.distribute.TPUStrategy(resolver) 

241 

242 While using distribution strategies, the variables created within the 

243 strategy's scope will be replicated across all the replicas and can be kept in 

244 sync using all-reduce algorithms. 

245 

246 To run TF2 programs on TPUs, you can either use `.compile` and 

247 `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized 

248 training loop by calling `strategy.run` directly. Note that 

249 TPUStrategy doesn't support pure eager execution, so please make sure the 

250 function passed into `strategy.run` is a `tf.function` or 

251 `strategy.run` is called inside a `tf.function` if eager 

252 behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu. 

253 

254 `distribute_datasets_from_function` and 

255 `experimental_distribute_dataset` APIs can be used to distribute the dataset 

256 across the TPU workers when writing your own training loop. If you are using 

257 `fit` and `compile` methods available in `tf.keras.Model`, then Keras will 

258 handle the distribution for you. 

259 

260 An example of writing customized training loop on TPUs: 

261 

262 >>> with strategy.scope(): 

263 ... model = tf.keras.Sequential([ 

264 ... tf.keras.layers.Dense(2, input_shape=(5,)), 

265 ... ]) 

266 ... optimizer = tf.keras.optimizers.SGD(learning_rate=0.1) 

267 

268 >>> def dataset_fn(ctx): 

269 ... x = np.random.random((2, 5)).astype(np.float32) 

270 ... y = np.random.randint(2, size=(2, 1)) 

271 ... dataset = tf.data.Dataset.from_tensor_slices((x, y)) 

272 ... return dataset.repeat().batch(1, drop_remainder=True) 

273 >>> dist_dataset = strategy.distribute_datasets_from_function( 

274 ... dataset_fn) 

275 >>> iterator = iter(dist_dataset) 

276 

277 >>> @tf.function() 

278 ... def train_step(iterator): 

279 ... 

280 ... def step_fn(inputs): 

281 ... features, labels = inputs 

282 ... with tf.GradientTape() as tape: 

283 ... logits = model(features, training=True) 

284 ... loss = tf.keras.losses.sparse_categorical_crossentropy( 

285 ... labels, logits) 

286 ... 

287 ... grads = tape.gradient(loss, model.trainable_variables) 

288 ... optimizer.apply_gradients(zip(grads, model.trainable_variables)) 

289 ... 

290 ... strategy.run(step_fn, args=(next(iterator),)) 

291 

292 >>> train_step(iterator) 

293 

294 For the advanced use cases like model parallelism, you can set 

295 `experimental_device_assignment` argument when creating TPUStrategy to specify 

296 number of replicas and number of logical devices. Below is an example to 

297 initialize TPU system with 2 logical devices and 1 replica. 

298 

299 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 

300 >>> tf.config.experimental_connect_to_cluster(resolver) 

301 >>> topology = tf.tpu.experimental.initialize_tpu_system(resolver) 

302 >>> device_assignment = tf.tpu.experimental.DeviceAssignment.build( 

303 ... topology, 

304 ... computation_shape=[1, 1, 1, 2], 

305 ... num_replicas=1) 

306 >>> strategy = tf.distribute.TPUStrategy( 

307 ... resolver, experimental_device_assignment=device_assignment) 

308 

309 Then you can run a `tf.add` operation only on logical device 0. 

310 

311 >>> @tf.function() 

312 ... def step_fn(inputs): 

313 ... features, _ = inputs 

314 ... output = tf.add(features, features) 

315 ... 

316 ... # Add operation will be executed on logical device 0. 

317 ... output = strategy.experimental_assign_to_logical_device(output, 0) 

318 ... return output 

319 >>> dist_dataset = strategy.distribute_datasets_from_function( 

320 ... dataset_fn) 

321 >>> iterator = iter(dist_dataset) 

322 >>> strategy.run(step_fn, args=(next(iterator),)) 

323 

324 `experimental_spmd_xla_partitioning` enables the experimental XLA SPMD feature 

325 for model parallelism. This flag can reduce the compilation time and HBM 

326 requirements. When running in this mode, every input tensor must either be 

327 partitioned (via `strategy.experimental_split_to_logical_devices`) or fully 

328 replicated (via `strategy.experimental_replicate_to_logical_devices`) to all 

329 logical devices. And calling `strategy.experimental_assign_to_logical_device` 

330 will result in a ValueError in this mode. 

331 """ 

332 

333 def __init__(self, 

334 tpu_cluster_resolver=None, 

335 experimental_device_assignment=None, 

336 experimental_spmd_xla_partitioning=False): 

337 """Synchronous training in TPU donuts or Pods. 

338 

339 Args: 

340 tpu_cluster_resolver: A 

341 `tf.distribute.cluster_resolver.TPUClusterResolver` instance, which 

342 provides information about the TPU cluster. If None, it will assume 

343 running on a local TPU worker. 

344 experimental_device_assignment: Optional 

345 `tf.tpu.experimental.DeviceAssignment` to specify the placement of 

346 replicas on the TPU cluster. 

347 experimental_spmd_xla_partitioning: If True, enable the SPMD (Single 

348 Program Multiple Data) mode in XLA compiler. This flag only affects the 

349 performance of XLA compilation and the HBM requirement of the compiled 

350 TPU program. Ceveat: if this flag is True, calling 

351 `tf.distribute.TPUStrategy.experimental_assign_to_logical_device` will 

352 result in a ValueError. 

353 """ 

354 super(TPUStrategyV2, self).__init__( 

355 TPUExtended( 

356 self, 

357 tpu_cluster_resolver, 

358 device_assignment=experimental_device_assignment, 

359 use_spmd_for_xla_partitioning=experimental_spmd_xla_partitioning, 

360 enable_data_reorder=experimental_device_assignment is not None, 

361 ) 

362 ) 

363 distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy") 

364 distribute_lib.distribution_strategy_replica_gauge.get_cell( 

365 "num_workers").set(self.extended.num_hosts) 

366 distribute_lib.distribution_strategy_replica_gauge.get_cell( 

367 "num_replicas_per_worker").set(self.extended.num_replicas_per_host) 

368 # Packed variable is used to reduce the overhead of function execution. 

369 # For a DistributedVariable, only one variable handle is captured into a 

370 # function graph. It's only supported in eager mode. 

371 self._enable_packed_variable_in_eager_mode = True 

372 

373 def run(self, fn, args=(), kwargs=None, options=None): 

374 """Run the computation defined by `fn` on each TPU replica. 

375 

376 Executes ops specified by `fn` on each replica. If `args` or `kwargs` have 

377 `tf.distribute.DistributedValues`, such as those produced by a 

378 `tf.distribute.DistributedDataset` from 

379 `tf.distribute.Strategy.experimental_distribute_dataset` or 

380 `tf.distribute.Strategy.distribute_datasets_from_function`, 

381 when `fn` is executed on a particular replica, it will be executed with the 

382 component of `tf.distribute.DistributedValues` that correspond to that 

383 replica. 

384 

385 `fn` may call `tf.distribute.get_replica_context()` to access members such 

386 as `all_reduce`. 

387 

388 All arguments in `args` or `kwargs` should either be nest of tensors or 

389 `tf.distribute.DistributedValues` containing tensors or composite tensors. 

390 

391 Example usage: 

392 

393 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 

394 >>> tf.config.experimental_connect_to_cluster(resolver) 

395 >>> tf.tpu.experimental.initialize_tpu_system(resolver) 

396 >>> strategy = tf.distribute.TPUStrategy(resolver) 

397 >>> @tf.function 

398 ... def run(): 

399 ... def value_fn(value_context): 

400 ... return value_context.num_replicas_in_sync 

401 ... distributed_values = ( 

402 ... strategy.experimental_distribute_values_from_function(value_fn)) 

403 ... def replica_fn(input): 

404 ... return input * 2 

405 ... return strategy.run(replica_fn, args=(distributed_values,)) 

406 >>> result = run() 

407 

408 Args: 

409 fn: The function to run. The output must be a `tf.nest` of `Tensor`s. 

410 args: (Optional) Positional arguments to `fn`. 

411 kwargs: (Optional) Keyword arguments to `fn`. 

412 options: (Optional) An instance of `tf.distribute.RunOptions` specifying 

413 the options to run `fn`. 

414 

415 Returns: 

416 Merged return value of `fn` across replicas. The structure of the return 

417 value is the same as the return value from `fn`. Each element in the 

418 structure can either be `tf.distribute.DistributedValues`, `Tensor` 

419 objects, or `Tensor`s (for example, if running on a single replica). 

420 """ 

421 validate_run_function(fn) 

422 

423 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs) 

424 

425 # Note: the target function is converted to graph even when in Eager mode, 

426 # so autograph is on by default here. 

427 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) 

428 options = options or distribute_lib.RunOptions() 

429 return self.extended.tpu_run(fn, args, kwargs, options) 

430 

431 @property 

432 def cluster_resolver(self): 

433 """Returns the cluster resolver associated with this strategy. 

434 

435 `tf.distribute.TPUStrategy` provides the associated 

436 `tf.distribute.cluster_resolver.ClusterResolver`. If the user provides one 

437 in `__init__`, that instance is returned; if the user does not, a default 

438 `tf.distribute.cluster_resolver.TPUClusterResolver` is provided. 

439 """ 

440 return self.extended._tpu_cluster_resolver # pylint: disable=protected-access 

441 

442 def experimental_assign_to_logical_device(self, tensor, logical_device_id): 

443 """Adds annotation that `tensor` will be assigned to a logical device. 

444 

445 This adds an annotation to `tensor` specifying that operations on 

446 `tensor` will be invoked on logical core device id `logical_device_id`. 

447 When model parallelism is used, the default behavior is that all ops 

448 are placed on zero-th logical device. 

449 

450 ```python 

451 

452 # Initializing TPU system with 2 logical devices and 4 replicas. 

453 resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 

454 tf.config.experimental_connect_to_cluster(resolver) 

455 topology = tf.tpu.experimental.initialize_tpu_system(resolver) 

456 device_assignment = tf.tpu.experimental.DeviceAssignment.build( 

457 topology, 

458 computation_shape=[1, 1, 1, 2], 

459 num_replicas=4) 

460 strategy = tf.distribute.TPUStrategy( 

461 resolver, experimental_device_assignment=device_assignment) 

462 iterator = iter(inputs) 

463 

464 @tf.function() 

465 def step_fn(inputs): 

466 output = tf.add(inputs, inputs) 

467 

468 # Add operation will be executed on logical device 0. 

469 output = strategy.experimental_assign_to_logical_device(output, 0) 

470 return output 

471 

472 strategy.run(step_fn, args=(next(iterator),)) 

473 ``` 

474 

475 Args: 

476 tensor: Input tensor to annotate. 

477 logical_device_id: Id of the logical core to which the tensor will be 

478 assigned. 

479 

480 Raises: 

481 ValueError: The logical device id presented is not consistent with total 

482 number of partitions specified by the device assignment or the TPUStrategy 

483 is constructed with `experimental_spmd_xla_partitioning=True`. 

484 

485 Returns: 

486 Annotated tensor with identical value as `tensor`. 

487 """ 

488 if self.extended._use_spmd_for_xla_partitioning: # pylint: disable=protected-access 

489 raise ValueError( 

490 "Cannot assign a tensor to a logical device in SPMD mode. To disable " 

491 "SPMD, Please construct the TPUStrategy with " 

492 "`experimental_spmd_xla_partitioning=False`") 

493 

494 num_logical_devices_per_replica = self.extended._tpu_devices.shape[1] # pylint: disable=protected-access 

495 if (logical_device_id < 0 or 

496 logical_device_id >= num_logical_devices_per_replica): 

497 raise ValueError("`logical_core_id` to assign must be lower then total " 

498 "number of logical devices per replica. Received " 

499 "logical device id {} but there are only total of {} " 

500 "logical devices in replica.".format( 

501 logical_device_id, num_logical_devices_per_replica)) 

502 return xla_sharding.assign_device( 

503 tensor, logical_device_id, use_sharding_op=True) 

504 

505 def experimental_split_to_logical_devices(self, tensor, partition_dimensions): 

506 """Adds annotation that `tensor` will be split across logical devices. 

507 

508 This adds an annotation to tensor `tensor` specifying that operations on 

509 `tensor` will be split among multiple logical devices. Tensor `tensor` will 

510 be split across dimensions specified by `partition_dimensions`. 

511 The dimensions of `tensor` must be divisible by corresponding value in 

512 `partition_dimensions`. 

513 

514 For example, for system with 8 logical devices, if `tensor` is an image 

515 tensor with shape (batch_size, width, height, channel) and 

516 `partition_dimensions` is [1, 2, 4, 1], then `tensor` will be split 

517 2 in width dimension and 4 way in height dimension and the split 

518 tensor values will be fed into 8 logical devices. 

519 

520 ```python 

521 # Initializing TPU system with 8 logical devices and 1 replica. 

522 resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 

523 tf.config.experimental_connect_to_cluster(resolver) 

524 topology = tf.tpu.experimental.initialize_tpu_system(resolver) 

525 device_assignment = tf.tpu.experimental.DeviceAssignment.build( 

526 topology, 

527 computation_shape=[1, 2, 2, 2], 

528 num_replicas=1) 

529 # Construct the TPUStrategy. Since we are going to split the image across 

530 # logical devices, here we set `experimental_spmd_xla_partitioning=True` 

531 # so that the partitioning can be compiled in SPMD mode, which usually 

532 # results in faster compilation and smaller HBM requirement if the size of 

533 # input and activation tensors are much bigger than that of the model 

534 # parameters. Note that this flag is suggested but not a hard requirement 

535 # for `experimental_split_to_logical_devices`. 

536 strategy = tf.distribute.TPUStrategy( 

537 resolver, experimental_device_assignment=device_assignment, 

538 experimental_spmd_xla_partitioning=True) 

539 

540 iterator = iter(inputs) 

541 

542 @tf.function() 

543 def step_fn(inputs): 

544 inputs = strategy.experimental_split_to_logical_devices( 

545 inputs, [1, 2, 4, 1]) 

546 

547 # model() function will be executed on 8 logical devices with `inputs` 

548 # split 2 * 4 ways. 

549 output = model(inputs) 

550 return output 

551 

552 strategy.run(step_fn, args=(next(iterator),)) 

553 ``` 

554 Args: 

555 tensor: Input tensor to annotate. 

556 partition_dimensions: An unnested list of integers with the size equal to 

557 rank of `tensor` specifying how `tensor` will be partitioned. The 

558 product of all elements in `partition_dimensions` must be equal to the 

559 total number of logical devices per replica. 

560 

561 Raises: 

562 ValueError: 1) If the size of partition_dimensions does not equal to rank 

563 of `tensor` or 2) if product of elements of `partition_dimensions` does 

564 not match the number of logical devices per replica defined by the 

565 implementing DistributionStrategy's device specification or 

566 3) if a known size of `tensor` is not divisible by corresponding 

567 value in `partition_dimensions`. 

568 

569 Returns: 

570 Annotated tensor with identical value as `tensor`. 

571 """ 

572 num_logical_devices_per_replica = self.extended._tpu_devices.shape[1] # pylint: disable=protected-access 

573 num_partition_splits = np.prod(partition_dimensions) 

574 input_shape = tensor.shape 

575 tensor_rank = len(input_shape) 

576 

577 if tensor_rank != len(partition_dimensions): 

578 raise ValueError("Length of `partition_dimensions` must equal to the " 

579 "rank of `tensor.shape` ({}). Received " 

580 "len(partition_dimensions)={}.".format( 

581 tensor_rank, len(partition_dimensions))) 

582 

583 for dim_index, dim_size in enumerate(input_shape): 

584 if dim_size is None: 

585 continue 

586 

587 split_size = partition_dimensions[dim_index] 

588 if dim_size % split_size != 0: 

589 raise ValueError("Tensor shape at `partition_dimensions[{}]` must be " 

590 "divisible by corresponding value specified " 

591 "by `partition_dimensions` ({}). Received: {}.".format( 

592 dim_index, split_size, dim_size)) 

593 

594 if num_partition_splits != num_logical_devices_per_replica: 

595 raise ValueError( 

596 "The product of `partition_dimensions` should be the same as the " 

597 "number of logical devices (={}). Received `partition_dimensions`={}," 

598 "and their product is {}.".format(num_logical_devices_per_replica, 

599 partition_dimensions, 

600 num_partition_splits)) 

601 

602 tile_assignment = np.arange(num_partition_splits).reshape( 

603 partition_dimensions) 

604 return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True) 

605 

606 def experimental_replicate_to_logical_devices(self, tensor): 

607 """Adds annotation that `tensor` will be replicated to all logical devices. 

608 

609 This adds an annotation to tensor `tensor` specifying that operations on 

610 `tensor` will be invoked on all logical devices. 

611 

612 ```python 

613 # Initializing TPU system with 2 logical devices and 4 replicas. 

614 resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 

615 tf.config.experimental_connect_to_cluster(resolver) 

616 topology = tf.tpu.experimental.initialize_tpu_system(resolver) 

617 device_assignment = tf.tpu.experimental.DeviceAssignment.build( 

618 topology, 

619 computation_shape=[1, 1, 1, 2], 

620 num_replicas=4) 

621 strategy = tf.distribute.TPUStrategy( 

622 resolver, experimental_device_assignment=device_assignment) 

623 

624 iterator = iter(inputs) 

625 

626 @tf.function() 

627 def step_fn(inputs): 

628 images, labels = inputs 

629 images = strategy.experimental_split_to_logical_devices( 

630 inputs, [1, 2, 4, 1]) 

631 

632 # model() function will be executed on 8 logical devices with `inputs` 

633 # split 2 * 4 ways. 

634 output = model(inputs) 

635 

636 # For loss calculation, all logical devices share the same logits 

637 # and labels. 

638 labels = strategy.experimental_replicate_to_logical_devices(labels) 

639 output = strategy.experimental_replicate_to_logical_devices(output) 

640 loss = loss_fn(labels, output) 

641 

642 return loss 

643 

644 strategy.run(step_fn, args=(next(iterator),)) 

645 ``` 

646 Args: 

647 tensor: Input tensor to annotate. 

648 

649 Returns: 

650 Annotated tensor with identical value as `tensor`. 

651 """ 

652 return xla_sharding.replicate(tensor, use_sharding_op=True) 

653 

654 

655@tf_export("distribute.experimental.TPUStrategy", v1=[]) 

656@deprecation.deprecated_endpoints("distribute.experimental.TPUStrategy") 

657class TPUStrategy(distribute_lib.Strategy): 

658 """Synchronous training on TPUs and TPU Pods. 

659 

660 To construct a TPUStrategy object, you need to run the 

661 initialization code as below: 

662 

663 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 

664 >>> tf.config.experimental_connect_to_cluster(resolver) 

665 >>> tf.tpu.experimental.initialize_tpu_system(resolver) 

666 >>> strategy = tf.distribute.experimental.TPUStrategy(resolver) 

667 

668 While using distribution strategies, the variables created within the 

669 strategy's scope will be replicated across all the replicas and can be kept in 

670 sync using all-reduce algorithms. 

671 

672 To run TF2 programs on TPUs, you can either use `.compile` and 

673 `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized 

674 training loop by calling `strategy.run` directly. Note that 

675 TPUStrategy doesn't support pure eager execution, so please make sure the 

676 function passed into `strategy.run` is a `tf.function` or 

677 `strategy.run` is called inside a `tf.function` if eager 

678 behavior is enabled. 

679 """ 

680 

681 def __init__(self, 

682 tpu_cluster_resolver=None, 

683 device_assignment=None): 

684 """Synchronous training in TPU donuts or Pods. 

685 

686 Args: 

687 tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, 

688 which provides information about the TPU cluster. 

689 device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to 

690 specify the placement of replicas on the TPU cluster. 

691 """ 

692 logging.warning( 

693 "`tf.distribute.experimental.TPUStrategy` is deprecated, please use " 

694 "the non-experimental symbol `tf.distribute.TPUStrategy` instead.") 

695 

696 super(TPUStrategy, self).__init__( 

697 TPUExtended( 

698 self, 

699 tpu_cluster_resolver, 

700 device_assignment=device_assignment, 

701 enable_data_reorder=device_assignment is not None, 

702 ) 

703 ) 

704 distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy") 

705 distribute_lib.distribution_strategy_replica_gauge.get_cell( 

706 "num_workers").set(self.extended.num_hosts) 

707 distribute_lib.distribution_strategy_replica_gauge.get_cell( 

708 "num_replicas_per_worker").set(self.extended.num_replicas_per_host) 

709 # Packed variable is used to reduce the overhead of function execution. 

710 # For a DistributedVariable, only one variable handle is captured into a 

711 # function graph. It's only supported in eager mode. 

712 self._enable_packed_variable_in_eager_mode = True 

713 

714 # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this 

715 # can use the default implementation. 

716 # This implementation runs a single step. It does not use infeed or outfeed. 

717 def run(self, fn, args=(), kwargs=None, options=None): 

718 """See base class.""" 

719 validate_run_function(fn) 

720 

721 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs) 

722 

723 # Note: the target function is converted to graph even when in Eager mode, 

724 # so autograph is on by default here. 

725 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) 

726 options = options or distribute_lib.RunOptions() 

727 return self.extended.tpu_run(fn, args, kwargs, options) 

728 

729 @property 

730 def cluster_resolver(self): 

731 """Returns the cluster resolver associated with this strategy. 

732 

733 `tf.distribute.experimental.TPUStrategy` provides the 

734 associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user 

735 provides one in `__init__`, that instance is returned; if the user does 

736 not, a default 

737 `tf.distribute.cluster_resolver.TPUClusterResolver` is provided. 

738 """ 

739 return self.extended._tpu_cluster_resolver # pylint: disable=protected-access 

740 

741 

742@tf_export(v1=["distribute.experimental.TPUStrategy"]) 

743class TPUStrategyV1(distribute_lib.StrategyV1): 

744 """TPU distribution strategy implementation.""" 

745 

746 def __init__(self, 

747 tpu_cluster_resolver=None, 

748 steps_per_run=None, 

749 device_assignment=None): 

750 """Initializes the TPUStrategy object. 

751 

752 Args: 

753 tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, 

754 which provides information about the TPU cluster. 

755 steps_per_run: Number of steps to run on device before returning to the 

756 host. Note that this can have side-effects on performance, hooks, 

757 metrics, summaries etc. 

758 This parameter is only used when Distribution Strategy is used with 

759 estimator or keras. 

760 device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to 

761 specify the placement of replicas on the TPU cluster. Currently only 

762 supports the usecase of using a single core within a TPU cluster. 

763 """ 

764 super(TPUStrategyV1, self).__init__(TPUExtended( 

765 self, tpu_cluster_resolver, steps_per_run, device_assignment)) 

766 distribute_lib.distribution_strategy_gauge.get_cell("V1").set("TPUStrategy") 

767 distribute_lib.distribution_strategy_replica_gauge.get_cell( 

768 "num_workers").set(self.extended.num_hosts) 

769 distribute_lib.distribution_strategy_replica_gauge.get_cell( 

770 "num_replicas_per_worker").set(self.extended.num_replicas_per_host) 

771 # Packed variable is used to reduce the overhead of function execution. 

772 # For a DistributedVariable, only one variable handle is captured into a 

773 # function graph. It's only supported in eager mode. 

774 self._enable_packed_variable_in_eager_mode = True 

775 

776 @property 

777 def steps_per_run(self): 

778 """DEPRECATED: use .extended.steps_per_run instead.""" 

779 return self._extended.steps_per_run 

780 

781 # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this 

782 # can use the default implementation. 

783 # This implementation runs a single step. It does not use infeed or outfeed. 

784 def run(self, fn, args=(), kwargs=None, options=None): 

785 """Run `fn` on each replica, with the given arguments. 

786 

787 Executes ops specified by `fn` on each replica. If `args` or `kwargs` have 

788 "per-replica" values, such as those produced by a "distributed `Dataset`", 

789 when `fn` is executed on a particular replica, it will be executed with the 

790 component of those "per-replica" values that correspond to that replica. 

791 

792 `fn` may call `tf.distribute.get_replica_context()` to access members such 

793 as `all_reduce`. 

794 

795 All arguments in `args` or `kwargs` should either be nest of tensors or 

796 per-replica objects containing tensors or composite tensors. 

797 

798 Users can pass strategy specific options to `options` argument. An example 

799 to enable bucketizing dynamic shapes in `TPUStrategy.run` 

800 is: 

801 

802 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 

803 >>> tf.config.experimental_connect_to_cluster(resolver) 

804 >>> tf.tpu.experimental.initialize_tpu_system(resolver) 

805 >>> strategy = tf.distribute.experimental.TPUStrategy(resolver) 

806 

807 >>> options = tf.distribute.RunOptions( 

808 ... experimental_bucketizing_dynamic_shape=True) 

809 

810 >>> dataset = tf.data.Dataset.range( 

811 ... strategy.num_replicas_in_sync, output_type=dtypes.float32).batch( 

812 ... strategy.num_replicas_in_sync, drop_remainder=True) 

813 >>> input_iterator = iter(strategy.experimental_distribute_dataset(dataset)) 

814 

815 >>> @tf.function() 

816 ... def step_fn(inputs): 

817 ... output = tf.reduce_sum(inputs) 

818 ... return output 

819 

820 >>> strategy.run(step_fn, args=(next(input_iterator),), options=options) 

821 

822 Args: 

823 fn: The function to run. The output must be a `tf.nest` of `Tensor`s. 

824 args: (Optional) Positional arguments to `fn`. 

825 kwargs: (Optional) Keyword arguments to `fn`. 

826 options: (Optional) An instance of `tf.distribute.RunOptions` specifying 

827 the options to run `fn`. 

828 

829 Returns: 

830 Merged return value of `fn` across replicas. The structure of the return 

831 value is the same as the return value from `fn`. Each element in the 

832 structure can either be "per-replica" `Tensor` objects or `Tensor`s 

833 (for example, if running on a single replica). 

834 """ 

835 validate_run_function(fn) 

836 

837 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs) 

838 

839 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) 

840 options = options or distribute_lib.RunOptions() 

841 return self.extended.tpu_run(fn, args, kwargs, options) 

842 

843 

844# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1. 

845class TPUExtended(distribute_lib.StrategyExtendedV1): 

846 """Implementation of TPUStrategy.""" 

847 

848 def __init__( 

849 self, 

850 container_strategy, 

851 tpu_cluster_resolver=None, 

852 steps_per_run=None, 

853 device_assignment=None, 

854 use_spmd_for_xla_partitioning=False, 

855 enable_data_reorder=False, 

856 ): 

857 super(TPUExtended, self).__init__(container_strategy) 

858 

859 if tpu_cluster_resolver is None: 

860 tpu_cluster_resolver = TPUClusterResolver("") 

861 

862 if steps_per_run is None: 

863 # TODO(frankchn): Warn when we are being used by DS/Keras and this is 

864 # not specified. 

865 steps_per_run = 1 

866 

867 # `self._tpu_function_cache` is a dict of `tf.function`s, thus if a 

868 # `tf.function` is passed into `strategy.run` in eager mode, the 

869 # `tf.function` won't get retraced. 

870 self._tpu_function_cache = weakref.WeakKeyDictionary() 

871 

872 self._tpu_cluster_resolver = tpu_cluster_resolver 

873 self._tpu_metadata = self._tpu_cluster_resolver.get_tpu_system_metadata() 

874 self._device_assignment = device_assignment 

875 

876 tpu_devices_flat = [ 

877 d.name for d in self._tpu_metadata.devices if "device:TPU:" in d.name] 

878 

879 # `self._tpu_devices` is a two-dimensional NumPy array of strings. It is 

880 # indexed using `[replica_id][logical_device_id]`. 

881 if device_assignment is None: 

882 self._tpu_devices = np.array( 

883 [[d] for d in tpu_devices_flat], dtype=object) 

884 else: 

885 job_name = device_spec.DeviceSpecV2.from_string(tpu_devices_flat[0]).job 

886 

887 tpu_devices = [] 

888 for replica_id in range(device_assignment.num_replicas): 

889 replica_devices = [] 

890 

891 for logical_core in range(device_assignment.num_cores_per_replica): 

892 replica_devices.append( 

893 device_util.canonicalize( 

894 device_assignment.tpu_device( 

895 replica=replica_id, 

896 logical_core=logical_core, 

897 job=job_name))) 

898 

899 tpu_devices.append(replica_devices) 

900 self._tpu_devices = np.array(tpu_devices, dtype=object) 

901 

902 self._host_device = device_util.get_host_for_device(self._tpu_devices[0][0]) 

903 

904 # Preload the data onto the TPUs. Currently we always preload onto logical 

905 # device 0 for each replica. 

906 # TODO(cjfj): Create `InputWorkers` lazily, allowing users to place the 

907 # input onto a different logical device? 

908 self._device_input_worker_devices = collections.OrderedDict() 

909 self._host_input_worker_devices = collections.OrderedDict() 

910 for tpu_device in self._tpu_devices[:, 0]: 

911 host_device = device_util.get_host_for_device(tpu_device) 

912 self._device_input_worker_devices.setdefault(host_device, []) 

913 self._device_input_worker_devices[host_device].append(tpu_device) 

914 self._host_input_worker_devices.setdefault(host_device, []) 

915 self._host_input_worker_devices[host_device].append(host_device) 

916 

917 # Create the replica order based on the assigned device order. 

918 # This replica order will be used to match the IteratorGetNext ops 

919 # with the device assigment. 

920 self._replica_order = ( 

921 self._get_replica_order(self._tpu_devices[:, 0]) 

922 if enable_data_reorder 

923 else None 

924 ) 

925 

926 # TODO(sourabhbajaj): Remove this once performance of running one step 

927 # at a time is comparable to multiple steps. 

928 self.steps_per_run = steps_per_run 

929 self._require_static_shapes = True 

930 

931 self.experimental_enable_get_next_as_optional = True 

932 

933 self._logical_device_stack = [0] 

934 

935 if context.executing_eagerly(): 

936 # In async remote eager, we want to sync the executors before exiting the 

937 # program. 

938 atexit.register(context.async_wait) 

939 

940 # Flag to turn on VariablePolicy. Var policy is deprecated because there is 

941 # another effort unifying DistributedVariables (see values_v2.py). SPMD XLA 

942 # partitioning is not implemented for var policies. 

943 # TODO(b/202048882): remove var policy from TPUStrategy. 

944 self._use_var_policy = not use_spmd_for_xla_partitioning 

945 

946 # Flag to enable XLA SPMD partitioning. 

947 self._use_spmd_for_xla_partitioning = use_spmd_for_xla_partitioning 

948 

949 def _get_replica_order(self, tpu_devices): 

950 """Get the replica order based on the tpu device order. 

951 

952 For example, if the tpu_devices are: 

953 '/job:worker/replica:0/task:0/device:TPU:0', 

954 '/job:worker/replica:0/task:0/device:TPU:2', 

955 '/job:worker/replica:0/task:1/device:TPU:0', 

956 '/job:worker/replica:0/task:1/device:TPU:2', 

957 '/job:worker/replica:0/task:1/device:TPU:6', 

958 '/job:worker/replica:0/task:1/device:TPU:4', 

959 '/job:worker/replica:0/task:0/device:TPU:6', 

960 '/job:worker/replica:0/task:0/device:TPU:4', 

961 

962 the returned replica order will be: 

963 [0, 1, 7, 6, 2, 3, 5, 4] 

964 

965 This replica order will be used to reorder the data returned by the 

966 iterators, 

967 so that they can be placed on the same node as their computation graphs. 

968 

969 Args: 

970 tpu_devices (List[str]): A list of tpu device names in the order of 

971 replicas. 

972 

973 Returns: 

974 A list containing the order ids of corresponding TPU devices. 

975 """ 

976 devices_with_ids = [] 

977 for i, tpu_device in enumerate(tpu_devices): 

978 spec = tf_device.DeviceSpec.from_string(tpu_device) 

979 devices_with_ids.append(( 

980 ( 

981 spec.job, 

982 spec.replica, 

983 spec.device_type, 

984 spec.task, 

985 spec.device_index, 

986 ), 

987 i, 

988 )) 

989 return [i for _, i in sorted(devices_with_ids)] 

990 

991 def _validate_colocate_with_variable(self, colocate_with_variable): 

992 distribute_utils.validate_colocate(colocate_with_variable, self) 

993 

994 def _make_dataset_iterator(self, dataset): 

995 """Make iterators for each of the TPU hosts.""" 

996 input_workers = input_lib.InputWorkers( 

997 tuple(self._device_input_worker_devices.items())) 

998 return input_lib_v1.DatasetIterator( 

999 dataset, 

1000 input_workers, 

1001 self._container_strategy(), 

1002 num_replicas_in_sync=self._num_replicas_in_sync) 

1003 

1004 def _make_input_fn_iterator( 

1005 self, 

1006 input_fn, 

1007 replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): 

1008 input_contexts = [] 

1009 input_workers = input_lib.InputWorkers( 

1010 tuple(self._device_input_worker_devices.items())) 

1011 num_workers = input_workers.num_workers 

1012 for i in range(num_workers): 

1013 input_contexts.append( 

1014 distribute_lib.InputContext( 

1015 num_input_pipelines=num_workers, 

1016 input_pipeline_id=i, 

1017 num_replicas_in_sync=self._num_replicas_in_sync)) 

1018 return input_lib_v1.InputFunctionIterator(input_fn, input_workers, 

1019 input_contexts, 

1020 self._container_strategy()) 

1021 

1022 def _experimental_make_numpy_dataset(self, numpy_input, session): 

1023 return numpy_dataset.one_host_numpy_dataset( 

1024 numpy_input, numpy_dataset.SingleDevice(self._host_device), 

1025 session) 

1026 

1027 def _get_input_workers(self, options): 

1028 if not options or options.experimental_fetch_to_device: 

1029 return input_lib.InputWorkers( 

1030 tuple(self._device_input_worker_devices.items())) 

1031 else: 

1032 return input_lib.InputWorkers( 

1033 tuple(self._host_input_worker_devices.items())) 

1034 

1035 def _check_spec(self, element_spec): 

1036 if isinstance(element_spec, values.PerReplicaSpec): 

1037 element_spec = element_spec._component_specs # pylint: disable=protected-access 

1038 specs = nest.flatten_with_joined_string_paths(element_spec) 

1039 for path, spec in specs: 

1040 if isinstance(spec, (sparse_tensor.SparseTensorSpec, 

1041 ragged_tensor.RaggedTensorSpec)): 

1042 raise ValueError( 

1043 "Found tensor {} with spec {}. TPUStrategy does not support " 

1044 "distributed datasets with device prefetch when using sparse or " 

1045 "ragged tensors. If you intend to use sparse or ragged tensors, " 

1046 "please pass a tf.distribute.InputOptions object with " 

1047 "experimental_fetch_to_device set to False to your dataset " 

1048 "distribution function.".format(path, type(spec))) 

1049 

1050 def _experimental_distribute_dataset(self, dataset, options): 

1051 if (options and options.experimental_replication_mode == 

1052 distribute_lib.InputReplicationMode.PER_REPLICA): 

1053 raise NotImplementedError( 

1054 "InputReplicationMode.PER_REPLICA " 

1055 "is only supported in " 

1056 "`experimental_distribute_datasets_from_function`." 

1057 ) 

1058 if options is None or options.experimental_fetch_to_device: 

1059 self._check_spec(dataset.element_spec) 

1060 

1061 return input_util.get_distributed_dataset( 

1062 dataset, 

1063 self._get_input_workers(options), 

1064 self._container_strategy(), 

1065 num_replicas_in_sync=self._num_replicas_in_sync, 

1066 options=options, 

1067 replica_order=self._replica_order, 

1068 ) 

1069 

1070 def _distribute_datasets_from_function(self, dataset_fn, options): 

1071 if (options and options.experimental_replication_mode == 

1072 distribute_lib.InputReplicationMode.PER_REPLICA): 

1073 raise NotImplementedError( 

1074 "InputReplicationMode.PER_REPLICA " 

1075 "is only supported in " 

1076 " `experimental_distribute_datasets_from_function` " 

1077 "of tf.distribute.MirroredStrategy") 

1078 input_workers = self._get_input_workers(options) 

1079 input_contexts = [] 

1080 num_workers = input_workers.num_workers 

1081 for i in range(num_workers): 

1082 input_contexts.append(distribute_lib.InputContext( 

1083 num_input_pipelines=num_workers, 

1084 input_pipeline_id=i, 

1085 num_replicas_in_sync=self._num_replicas_in_sync)) 

1086 

1087 distributed_dataset = input_util.get_distributed_datasets_from_function( 

1088 dataset_fn, 

1089 input_workers, 

1090 input_contexts, 

1091 self._container_strategy(), 

1092 options=options, 

1093 replica_order=self._replica_order, 

1094 ) 

1095 

1096 # We can only check after the dataset_fn is called. 

1097 if options is None or options.experimental_fetch_to_device: 

1098 self._check_spec(distributed_dataset.element_spec) 

1099 return distributed_dataset 

1100 

1101 def _experimental_distribute_values_from_function(self, value_fn): 

1102 per_replica_values = [] 

1103 for replica_id in range(self._num_replicas_in_sync): 

1104 per_replica_values.append( 

1105 value_fn(distribute_lib.ValueContext(replica_id, 

1106 self._num_replicas_in_sync))) 

1107 return distribute_utils.regroup(per_replica_values, always_wrap=True) 

1108 

1109 # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. 

1110 # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have 

1111 # a mechanism to infer the outputs of `fn`. Pending b/110550782. 

1112 def _experimental_run_steps_on_iterator( 

1113 self, fn, multi_worker_iterator, iterations, initial_loop_values=None): 

1114 # Wrap `fn` for repeat. 

1115 if initial_loop_values is None: 

1116 initial_loop_values = {} 

1117 initial_loop_values = nest.flatten(initial_loop_values) 

1118 ctx = input_lib.MultiStepContext() 

1119 

1120 def run_fn(inputs): 

1121 """Single step on the TPU device.""" 

1122 fn_result = fn(ctx, inputs) 

1123 flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) 

1124 if flat_last_step_outputs: 

1125 with ops.control_dependencies([fn_result]): 

1126 return [array_ops.identity(f) for f in flat_last_step_outputs] 

1127 else: 

1128 return fn_result 

1129 

1130 # We capture the control_flow_context at this point, before we run `fn` 

1131 # inside a while_loop and TPU replicate context. This is useful in cases 

1132 # where we might need to exit these contexts and get back to the outer 

1133 # context to do some things, for e.g. create an op which should be 

1134 # evaluated only once at the end of the loop on the host. One such usage 

1135 # is in creating metrics' value op. 

1136 self._outer_control_flow_context = ( 

1137 ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access 

1138 

1139 def rewrite_fn(*args): 

1140 """The rewritten step fn running on TPU.""" 

1141 del args 

1142 

1143 per_replica_inputs = multi_worker_iterator.get_next() 

1144 replicate_inputs = [] 

1145 for replica_id in range(self._num_replicas_in_sync): 

1146 select_replica = lambda x: distribute_utils.select_replica( # pylint: disable=g-long-lambda 

1147 replica_id, x) # pylint: disable=cell-var-from-loop 

1148 replicate_inputs.append((nest.map_structure( 

1149 select_replica, per_replica_inputs),)) 

1150 

1151 replicate_outputs = tpu.replicate( 

1152 run_fn, 

1153 replicate_inputs, 

1154 device_assignment=self._device_assignment, 

1155 xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self 

1156 ._use_spmd_for_xla_partitioning)) 

1157 # If run_fn has tensor outputs, tpu.replicate returns a list of list. We 

1158 # will flatten it in this case. If run_fn has no tensor outputs, 

1159 # tpu.replicate returns a list of no_ops, we will keep the output as it 

1160 # is. 

1161 if isinstance(replicate_outputs[0], list): 

1162 replicate_outputs = nest.flatten(replicate_outputs) 

1163 

1164 return replicate_outputs 

1165 

1166 # TODO(sourabhbajaj): The input to while loop should be based on the 

1167 # output type of the step_fn 

1168 assert isinstance(initial_loop_values, list) 

1169 initial_loop_values = initial_loop_values * self._num_replicas_in_sync 

1170 

1171 # Put the while loop op on TPU host 0. 

1172 with ops.device(self._host_device): 

1173 if self.steps_per_run == 1: 

1174 replicate_outputs = rewrite_fn() 

1175 else: 

1176 replicate_outputs = training_loop.repeat(iterations, rewrite_fn, 

1177 initial_loop_values) 

1178 

1179 del self._outer_control_flow_context 

1180 ctx.run_op = control_flow_ops.group(replicate_outputs) 

1181 

1182 if isinstance(replicate_outputs, list): 

1183 # Filter out any ops from the outputs, typically this would be the case 

1184 # when there were no tensor outputs. 

1185 last_step_tensor_outputs = [ 

1186 x for x in replicate_outputs if not isinstance(x, ops.Operation) 

1187 ] 

1188 

1189 # Outputs are currently of the structure (flattened) 

1190 # [output0_device0, output1_device0, output2_device0, 

1191 # output0_device1, output1_device1, output2_device1, 

1192 # ...] 

1193 # Convert this to the following structure instead: (grouped by output) 

1194 # [[output0_device0, output0_device1], 

1195 # [output1_device0, output1_device1], 

1196 # [output2_device0, output2_device1]] 

1197 output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync 

1198 last_step_tensor_outputs = [ 

1199 last_step_tensor_outputs[i::output_num] for i in range(output_num) 

1200 ] 

1201 else: 

1202 # no tensors returned. 

1203 last_step_tensor_outputs = [] 

1204 

1205 _set_last_step_outputs(ctx, last_step_tensor_outputs) 

1206 return ctx 

1207 

1208 def _call_for_each_replica(self, fn, args, kwargs): 

1209 # TODO(jhseu): Consider making it so call_for_each_replica implies that 

1210 # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly. 

1211 with _TPUReplicaContext(self._container_strategy()): 

1212 return fn(*args, **kwargs) 

1213 

1214 @contextlib.contextmanager 

1215 def experimental_logical_device(self, logical_device_id): 

1216 """Places variables and ops on the specified logical device.""" 

1217 num_logical_devices_per_replica = self._tpu_devices.shape[1] 

1218 if logical_device_id >= num_logical_devices_per_replica: 

1219 raise ValueError( 

1220 "`logical_device_id` not in range (was {}, but there are only {} " 

1221 "logical devices per replica).".format( 

1222 logical_device_id, num_logical_devices_per_replica)) 

1223 

1224 self._logical_device_stack.append(logical_device_id) 

1225 try: 

1226 if tpu_util.enclosing_tpu_context() is None: 

1227 yield 

1228 else: 

1229 with ops.device(tpu.core(logical_device_id)): 

1230 yield 

1231 finally: 

1232 self._logical_device_stack.pop() 

1233 

1234 def _experimental_initialize_system(self): 

1235 """Experimental method added to be used by Estimator. 

1236 

1237 This is a private method only to be used by Estimator. Other frameworks 

1238 should directly be calling `tf.tpu.experimental.initialize_tpu_system` 

1239 """ 

1240 tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver) 

1241 

1242 def _create_variable(self, next_creator, **kwargs): 

1243 """Create a TPUMirroredVariable. See `DistributionStrategy.scope`.""" 

1244 if kwargs.pop("skip_mirrored_creator", False): 

1245 return next_creator(**kwargs) 

1246 

1247 colocate_with = kwargs.pop("colocate_with", None) 

1248 if colocate_with is None: 

1249 devices = self._tpu_devices[:, self._logical_device_stack[-1]] 

1250 elif isinstance(colocate_with, numpy_dataset.SingleDevice): 

1251 with ops.device(colocate_with.device): 

1252 return next_creator(**kwargs) 

1253 else: 

1254 devices = colocate_with._devices # pylint: disable=protected-access 

1255 

1256 num_replicas, num_cores_per_replica = self._tpu_devices.shape 

1257 

1258 def _create_mirrored_tpu_variables(**kwargs): 

1259 """Returns a list of `tf.Variable`s. 

1260 

1261 The list contains `number_replicas` `tf.Variable`s and can be used to 

1262 initialize a `TPUMirroredVariable`. 

1263 

1264 Args: 

1265 **kwargs: the keyword arguments for creating a variable 

1266 """ 

1267 initial_value = None 

1268 value_list = [] 

1269 for i, d in enumerate(devices): 

1270 with ops.device(d): 

1271 if i == 0: 

1272 initial_value = kwargs["initial_value"] 

1273 # Note: some v1 code expects variable initializer creation to happen 

1274 # inside a init_scope. 

1275 with maybe_init_scope(): 

1276 initial_value = initial_value() if callable( 

1277 initial_value) else initial_value 

1278 

1279 if i > 0: 

1280 # Give replicas meaningful distinct names: 

1281 var0name = value_list[0].name.split(":")[0] 

1282 # We append a / to variable names created on replicas with id > 0 to 

1283 # ensure that we ignore the name scope and instead use the given 

1284 # name as the absolute name of the variable. 

1285 kwargs["name"] = "%s/replica_%d/" % (var0name, i) 

1286 kwargs["initial_value"] = initial_value 

1287 

1288 with context.device_policy(context.DEVICE_PLACEMENT_SILENT): 

1289 v = next_creator(**kwargs) 

1290 

1291 assert not isinstance(v, tpu_values.TPUMirroredVariable) 

1292 value_list.append(v) 

1293 return value_list 

1294 

1295 def _create_mirrored_tpu_replicated_variables(**kwargs): 

1296 """Returns a list of `TPUReplicatedVariable`s. 

1297 

1298 The list consists of `num_replicas` `TPUReplicatedVariable`s and can be 

1299 used to initialize a `TPUMirroredVariable`. Each `TPUReplicatedVariable` 

1300 contains a list of `tf.Variable`s which are replicated to 

1301 `num_cores_per_replica` logical cores to enable XLA SPMD compilation. 

1302 

1303 Args: 

1304 **kwargs: the keyword arguments for creating a variable 

1305 """ 

1306 initial_value = kwargs["initial_value"] 

1307 # Note: some v1 code expects variable initializer creation to happen 

1308 # inside a init_scope. 

1309 with maybe_init_scope(): 

1310 initial_value = initial_value() if callable( 

1311 initial_value) else initial_value 

1312 

1313 mirrored_replicated_var_list = [] 

1314 

1315 for replica_id in range(num_replicas): 

1316 replicated_var_list = [] 

1317 for logic_core_id in range(num_cores_per_replica): 

1318 with ops.device(self._tpu_devices[replica_id][logic_core_id]): 

1319 kwargs["initial_value"] = initial_value 

1320 v = next_creator(**kwargs) 

1321 replicated_var_list.append(v) 

1322 replica_name = "{}/r:{}".format(kwargs["name"], replica_id) 

1323 tpu_replicated_var = tpu_replicated_variable.TPUReplicatedVariable( 

1324 variables=replicated_var_list, name=replica_name) 

1325 

1326 mirrored_replicated_var_list.append(tpu_replicated_var) 

1327 return mirrored_replicated_var_list 

1328 

1329 if self._use_spmd_for_xla_partitioning and num_cores_per_replica > 1: 

1330 real_creator = _create_mirrored_tpu_replicated_variables 

1331 else: 

1332 real_creator = _create_mirrored_tpu_variables 

1333 

1334 return distribute_utils.create_mirrored_variable( 

1335 self._container_strategy(), real_creator, 

1336 distribute_utils.TPU_VARIABLE_CLASS_MAPPING, 

1337 distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs) 

1338 

1339 def _resource_creator_scope(self): 

1340 

1341 def lookup_creator(next_creator, *args, **kwargs): 

1342 host_to_table = collections.OrderedDict() 

1343 for host_device in self._device_input_worker_devices.keys(): 

1344 with ops.device(host_device): 

1345 host_to_table[host_device] = next_creator(*args, **kwargs) 

1346 

1347 return values.PerWorkerResource(self._container_strategy(), host_to_table) 

1348 

1349 # TODO(b/194362531): Define creator(s) for other resources. 

1350 return ops.resource_creator_scope("StaticHashTable", lookup_creator) 

1351 

1352 def _gather_to_implementation(self, value, destinations, axis, options): 

1353 if not isinstance(value, values.DistributedValues): 

1354 return value 

1355 

1356 value_list = list(value.values) 

1357 # pylint: disable=protected-access 

1358 if isinstance( 

1359 value, 

1360 values.DistributedVariable) and value._packed_variable is not None: 

1361 value_list = list( 

1362 value._packed_variable.on_device(d) 

1363 for d in value._packed_variable.devices) 

1364 # pylint: enable=protected-access 

1365 

1366 # Currently XLA op by op mode has a limit for the number of inputs for a 

1367 # single op, thus we break one `add_n` op into a group of `add_n` ops to 

1368 # work around the constraint. 

1369 if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT: 

1370 output = array_ops.concat(value_list, axis=axis) 

1371 else: 

1372 output = array_ops.concat( 

1373 value_list[:_XLA_OP_BY_OP_INPUTS_LIMIT], axis=axis) 

1374 for i in range(_XLA_OP_BY_OP_INPUTS_LIMIT, len(value_list), 

1375 _XLA_OP_BY_OP_INPUTS_LIMIT - 1): 

1376 output = array_ops.concat( 

1377 [output] + value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT - 1], 

1378 axis=axis) 

1379 

1380 output = self._broadcast_output(destinations, output) 

1381 return output 

1382 

1383 def _broadcast_output(self, destinations, output): 

1384 devices = cross_device_ops_lib.get_devices_from(destinations) 

1385 

1386 if len(devices) == 1: 

1387 # If necessary, copy to requested destination. 

1388 dest_canonical = device_util.canonicalize(devices[0]) 

1389 host_canonical = device_util.canonicalize(self._host_device) 

1390 

1391 if dest_canonical != host_canonical: 

1392 with ops.device(dest_canonical): 

1393 output = array_ops.identity(output) 

1394 else: 

1395 output = cross_device_ops_lib.simple_broadcast(output, destinations) 

1396 

1397 return output 

1398 

1399 def _reduce_to(self, reduce_op, value, destinations, options): 

1400 if (isinstance(value, values.DistributedValues) or 

1401 tensor_util.is_tf_type(value) 

1402 ) and tpu_util.enclosing_tpu_context() is not None: 

1403 if reduce_op == reduce_util.ReduceOp.MEAN: 

1404 # TODO(jhseu): Revisit once we support model-parallelism. 

1405 # scalar_mul maintains the type of value: tensor or IndexedSlices. 

1406 value = math_ops.scalar_mul((1./self._num_replicas_in_sync), value) 

1407 elif reduce_op != reduce_util.ReduceOp.SUM: 

1408 raise NotImplementedError( 

1409 f"`reduce_op`={reduce_op} is not supported. Currently we only " 

1410 "support ReduceOp.SUM and ReduceOp.MEAN in TPUStrategy.") 

1411 return tpu_ops.cross_replica_sum(value) 

1412 

1413 if not isinstance(value, values.DistributedValues): 

1414 # This function handles reducing values that are not PerReplica or 

1415 # Mirrored values. For example, the same value could be present on all 

1416 # replicas in which case `value` would be a single value or value could 

1417 # be 0. 

1418 return cross_device_ops_lib.reduce_non_distributed_value( 

1419 reduce_op, value, destinations, self._num_replicas_in_sync) 

1420 

1421 value_list = value.values 

1422 # pylint: disable=protected-access 

1423 if isinstance( 

1424 value, 

1425 values.DistributedVariable) and value._packed_variable is not None: 

1426 value_list = tuple( 

1427 value._packed_variable.on_device(d) 

1428 for d in value._packed_variable.devices) 

1429 # pylint: enable=protected-access 

1430 

1431 # Currently XLA op by op mode has a limit for the number of inputs for a 

1432 # single op, thus we break one `add_n` op into a group of `add_n` ops to 

1433 # work around the constraint. 

1434 # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`. 

1435 if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT: 

1436 output = math_ops.add_n(value_list) 

1437 else: 

1438 output = array_ops.zeros_like(value_list[0], dtype=value_list[0].dtype) 

1439 for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT): 

1440 output += math_ops.add_n(value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT]) 

1441 

1442 if reduce_op == reduce_util.ReduceOp.MEAN: 

1443 output *= (1. / len(value_list)) 

1444 

1445 output = self._broadcast_output(destinations, output) 

1446 return output 

1447 

1448 def _update(self, var, fn, args, kwargs, group): 

1449 assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance( 

1450 var, resource_variable_ops.BaseResourceVariable) 

1451 if tpu_util.enclosing_tpu_context() is not None: 

1452 if group: 

1453 return fn(var, *args, **kwargs) 

1454 else: 

1455 return (fn(var, *args, **kwargs),) 

1456 

1457 # Inside `tf.function`, we don't expand PackedVariable in python as it will 

1458 # be expanded later during function instantiation in the runtime. 

1459 packed_var = var._packed_variable # pylint: disable=protected-access 

1460 if packed_var is not None and not context.executing_eagerly(): 

1461 if group: 

1462 return fn(packed_var, *args, **kwargs) 

1463 else: 

1464 return (fn(packed_var, *args, **kwargs),) 

1465 

1466 # Otherwise, we revert to MirroredStrategy behavior and update the variable 

1467 # on each replica directly. 

1468 updates = [] 

1469 values_and_devices = [] 

1470 if packed_var is not None: 

1471 for device in packed_var.devices: 

1472 values_and_devices.append((packed_var, device)) 

1473 else: 

1474 for value in var.values: 

1475 values_and_devices.append((value, value.device)) 

1476 

1477 if (var.synchronization != variables_lib.VariableSynchronization.ON_READ and 

1478 var.aggregation != variables_lib.VariableAggregation.NONE): 

1479 distribute_utils.assert_mirrored(args) 

1480 distribute_utils.assert_mirrored(kwargs) 

1481 for i, value_and_device in enumerate(values_and_devices): 

1482 value = value_and_device[0] 

1483 device = value_and_device[1] 

1484 name = "update_%d" % i 

1485 with ops.device(device), \ 

1486 distribute_lib.UpdateContext(i), \ 

1487 ops.name_scope(name): 

1488 # If args and kwargs are not mirrored, the value is returned as is. 

1489 updates.append( 

1490 fn(value, *distribute_utils.select_replica(i, args), 

1491 **distribute_utils.select_replica(i, kwargs))) 

1492 return distribute_utils.update_regroup(self, updates, group) 

1493 

1494 def read_var(self, var): 

1495 assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance( 

1496 var, resource_variable_ops.BaseResourceVariable) 

1497 return var.read_value() 

1498 

1499 def value_container(self, value): 

1500 return value 

1501 

1502 def _broadcast_to(self, tensor, destinations): 

1503 del destinations 

1504 # This is both a fast path for Python constants, and a way to delay 

1505 # converting Python values to a tensor until we know what type it 

1506 # should be converted to. Otherwise we have trouble with: 

1507 # global_step.assign_add(1) 

1508 # since the `1` gets broadcast as an int32 but global_step is int64. 

1509 if isinstance(tensor, (float, int)): 

1510 return tensor 

1511 if tpu_util.enclosing_tpu_context() is not None: 

1512 broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)] 

1513 result = tpu_ops.all_to_all( 

1514 broadcast_tensor, 

1515 concat_dimension=0, 

1516 split_dimension=0, 

1517 split_count=self._num_replicas_in_sync) 

1518 

1519 # This uses the broadcasted value from the first replica because the only 

1520 # caller of this is for ONLY_FIRST_REPLICA variables aggregation. 

1521 return result[0] 

1522 return tensor 

1523 

1524 @property 

1525 def num_hosts(self): 

1526 if self._device_assignment is None: 

1527 return self._tpu_metadata.num_hosts 

1528 

1529 return len(set([self._device_assignment.host_device(r) 

1530 for r in range(self._device_assignment.num_replicas)])) 

1531 

1532 @property 

1533 def num_replicas_per_host(self): 

1534 if self._device_assignment is None: 

1535 return self._tpu_metadata.num_of_cores_per_host 

1536 

1537 # TODO(sourabhbajaj): Remove this method we use inputs and remove infeed 

1538 # as the computation of num_replicas_per_host is not a constant 

1539 # when using device_assignment. This is a temporary workaround to support 

1540 # StatefulRNN as everything is 1 in that case. 

1541 # This method needs to take host_id as input for correct computation. 

1542 max_models_per_host = (self._tpu_metadata.num_of_cores_per_host // 

1543 self._device_assignment.num_cores_per_replica) 

1544 return min(self._device_assignment.num_replicas, max_models_per_host) 

1545 

1546 @property 

1547 def _num_replicas_in_sync(self): 

1548 if self._device_assignment is None: 

1549 return self._tpu_metadata.num_cores 

1550 return self._device_assignment.num_replicas 

1551 

1552 @property 

1553 def experimental_between_graph(self): 

1554 return False 

1555 

1556 @property 

1557 def experimental_should_init(self): 

1558 return True 

1559 

1560 @property 

1561 def should_checkpoint(self): 

1562 return True 

1563 

1564 @property 

1565 def should_save_summary(self): 

1566 return True 

1567 

1568 @property 

1569 def worker_devices(self): 

1570 return tuple(self._tpu_devices[:, self._logical_device_stack[-1]]) 

1571 

1572 @property 

1573 def parameter_devices(self): 

1574 return self.worker_devices 

1575 

1576 @property 

1577 def tpu_hardware_feature(self): 

1578 """Return the `tf.tpu.experimental.HardwareFeature` class.""" 

1579 return tpu_hardware_feature.HardwareFeature( 

1580 self._tpu_cluster_resolver.tpu_hardware_feature) 

1581 

1582 def non_slot_devices(self, var_list): 

1583 return self._host_device 

1584 

1585 def _update_non_slot(self, colocate_with, fn, args, kwargs, group): 

1586 del colocate_with 

1587 with ops.device(self._host_device), distribute_lib.UpdateContext(None): 

1588 result = fn(*args, **kwargs) 

1589 if group: 

1590 return result 

1591 else: 

1592 return nest.map_structure(self._local_results, result) 

1593 

1594 def _configure(self, 

1595 session_config=None, 

1596 cluster_spec=None, 

1597 task_type=None, 

1598 task_id=None): 

1599 del cluster_spec, task_type, task_id 

1600 if session_config: 

1601 session_config.CopyFrom(self._update_config_proto(session_config)) 

1602 

1603 def _update_config_proto(self, config_proto): 

1604 updated_config = copy.deepcopy(config_proto) 

1605 updated_config.isolate_session_state = True 

1606 cluster_spec = self._tpu_cluster_resolver.cluster_spec() 

1607 if cluster_spec: 

1608 updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) 

1609 return updated_config 

1610 

1611 # TODO(priyag): Delete this once all strategies use global batch size. 

1612 @property 

1613 def _global_batch_size(self): 

1614 """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. 

1615 

1616 `make_input_fn_iterator` assumes per-replica batching. 

1617 

1618 Returns: 

1619 Boolean. 

1620 """ 

1621 return True 

1622 

1623 def tpu_run(self, fn, args, kwargs, options=None): 

1624 func = self._tpu_function_creator(fn, options) 

1625 return func(args, kwargs) 

1626 

1627 def _tpu_function_creator(self, fn, options): 

1628 if context.executing_eagerly() and fn in self._tpu_function_cache: 

1629 return self._tpu_function_cache[fn] 

1630 

1631 strategy = self._container_strategy() 

1632 

1633 def tpu_function(args, kwargs): 

1634 """TF Function used to replicate the user computation.""" 

1635 logging.vlog(1, 

1636 "`TPUStrategy.run` is called with [args: %s] [kwargs: %s]", 

1637 args, kwargs) 

1638 

1639 if kwargs is None: 

1640 kwargs = {} 

1641 

1642 # Used to re-structure flattened output tensors from `tpu.replicate()` 

1643 # into a structured format. 

1644 result = [[]] 

1645 

1646 def replicated_fn(replica_id, replica_args, replica_kwargs): 

1647 """Wraps user function to provide replica ID and `Tensor` inputs.""" 

1648 with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id): 

1649 result[0] = fn(*replica_args, **replica_kwargs) 

1650 return result[0] 

1651 

1652 replicate_inputs = [] # By replica. 

1653 for i in range(strategy.num_replicas_in_sync): 

1654 replicate_inputs.append( 

1655 [constant_op.constant(i, dtype=dtypes.int32), 

1656 distribute_utils.select_replica(i, args), 

1657 distribute_utils.select_replica(i, kwargs)]) 

1658 

1659 # Construct and pass `maximum_shapes` so that we could support dynamic 

1660 # shapes using dynamic padder. 

1661 if options.experimental_enable_dynamic_batch_size and replicate_inputs: 

1662 maximum_shapes = [] 

1663 flattened_list = nest.flatten(replicate_inputs[0]) 

1664 for input_tensor in flattened_list: 

1665 if tensor_util.is_tf_type(input_tensor): 

1666 rank = input_tensor.shape.rank 

1667 else: 

1668 rank = np.ndim(input_tensor) 

1669 if rank is None: 

1670 raise ValueError( 

1671 "input tensor {} to TPUStrategy.run() has unknown rank, " 

1672 "which is not allowed".format(input_tensor)) 

1673 maximum_shape = tensor_shape.TensorShape([None] * rank) 

1674 maximum_shapes.append(maximum_shape) 

1675 maximum_shapes = nest.pack_sequence_as(replicate_inputs[0], 

1676 maximum_shapes) 

1677 else: 

1678 maximum_shapes = None 

1679 

1680 if options.experimental_bucketizing_dynamic_shape: 

1681 padding_spec = tpu.PaddingSpec.POWER_OF_TWO 

1682 else: 

1683 padding_spec = None 

1684 

1685 with strategy.scope(): 

1686 xla_options = options.experimental_xla_options or tpu.XLAOptions( 

1687 use_spmd_for_xla_partitioning=self._use_spmd_for_xla_partitioning) 

1688 replicate_outputs = tpu.replicate( 

1689 replicated_fn, 

1690 replicate_inputs, 

1691 device_assignment=self._device_assignment, 

1692 maximum_shapes=maximum_shapes, 

1693 padding_spec=padding_spec, 

1694 xla_options=xla_options) 

1695 

1696 # Remove all no ops that may have been added during 'tpu.replicate()' 

1697 filter_ops = lambda x: [o for o in x if not isinstance(o, ops.Operation)] 

1698 if isinstance(result[0], list): 

1699 result[0] = filter_ops(result[0]) 

1700 

1701 # Workaround for `tpu.replicate` behaviour when single `Tensor` returned. 

1702 if result[0] is None or isinstance(result[0], ops.Operation): 

1703 replicate_outputs = [None] * len(replicate_outputs) 

1704 else: 

1705 replicate_outputs = [ 

1706 nest.pack_sequence_as(result[0], filter_ops(nest.flatten(output))) 

1707 for output in replicate_outputs 

1708 ] 

1709 return distribute_utils.regroup(replicate_outputs) 

1710 

1711 if context.executing_eagerly(): 

1712 tpu_function = def_function.function(tpu_function) 

1713 self._tpu_function_cache[fn] = tpu_function 

1714 return tpu_function 

1715 

1716 def _in_multi_worker_mode(self): 

1717 """Whether this strategy indicates working in multi-worker settings.""" 

1718 # TPUStrategy has different distributed training structure that the whole 

1719 # cluster should be treated as single worker from higher-level (e.g. Keras) 

1720 # library's point of view. 

1721 # TODO(rchao): Revisit this as we design a fault-tolerance solution for 

1722 # TPUStrategy. 

1723 return False 

1724 

1725 def _get_local_replica_id(self, replica_id_in_sync_group): 

1726 return replica_id_in_sync_group 

1727 

1728 

1729def _make_axis_nonnegative(axis, rank): 

1730 # Convert a potentially negative `axis` to a non-negative one. 

1731 if isinstance(axis, int): 

1732 if axis >= 0: 

1733 return axis 

1734 else: 

1735 return axis + rank 

1736 else: 

1737 return array_ops.where_v2( 

1738 math_ops.greater_equal(axis, 0), 

1739 axis, 

1740 axis + rank) 

1741 

1742 

1743# List of Tensor dtypes supported by cross_replica_sum(). 

1744_DTYPES_SUPPORTED_BY_CROSS_REPLICA_SUM = ( 

1745 dtypes.bfloat16, 

1746 dtypes.float16, 

1747 dtypes.float32, 

1748 dtypes.float64, 

1749 dtypes.int32, 

1750 dtypes.uint32, 

1751) 

1752 

1753 

1754class _TPUReplicaContext(distribute_lib.ReplicaContext): 

1755 """Replication Context class for TPU Strategy.""" 

1756 

1757 # TODO(sourabhbajaj): Call for each replica should be updating this. 

1758 # TODO(b/118385803): Always properly initialize replica_id. 

1759 def __init__(self, strategy, replica_id_in_sync_group=0): 

1760 distribute_lib.ReplicaContext.__init__( 

1761 self, strategy, replica_id_in_sync_group=replica_id_in_sync_group) 

1762 

1763 @property 

1764 def devices(self): 

1765 distribute_lib.require_replica_context(self) 

1766 ds = self._strategy 

1767 replica_id = tensor_util.constant_value(self.replica_id_in_sync_group) 

1768 

1769 if replica_id is None: # Non-constant `Tensor` inside `tpu.replicate`. 

1770 # TODO(cjfj): Return other devices when model parallelism is supported. 

1771 return (tpu.core(0),) 

1772 else: 

1773 return (ds.extended.worker_devices[replica_id],) 

1774 

1775 def experimental_logical_device(self, logical_device_id): 

1776 """Places variables and ops on the specified logical device.""" 

1777 return self.strategy.extended.experimental_logical_device(logical_device_id) 

1778 

1779 def _compute_all_gather_output_shape(self, value_shape, value_rank, axis): 

1780 if isinstance(value_rank, int): 

1781 output_shape = list(value_shape) 

1782 output_shape[axis] *= self.num_replicas_in_sync 

1783 else: 

1784 output_shape = array_ops.where_v2( 

1785 math_ops.equal(math_ops.range(value_rank), axis), 

1786 value_shape * context.num_replicas_in_sync, 

1787 value_shape) 

1788 return output_shape 

1789 

1790 def all_gather(self, value, axis, experimental_hints=None): 

1791 del experimental_hints 

1792 for v in nest.flatten(value): 

1793 if isinstance(v, indexed_slices.IndexedSlices): 

1794 raise NotImplementedError("all_gather does not support IndexedSlices") 

1795 

1796 def _all_gather_tensor(value, axis): 

1797 value = ops.convert_to_tensor(value) 

1798 

1799 # Compute the shape and rank and rank of the input tensor. Use static 

1800 # shapes when possible to help with shape inference in graph mode, but 

1801 # fall back on dynamic shapes when necessary. 

1802 if value.shape.rank is None: 

1803 value_rank = array_ops.rank(value) 

1804 value_shape = array_ops.shape(value) 

1805 else: 

1806 value_rank = value.shape.rank 

1807 value_shape = value.shape.as_list() 

1808 value_shape_tensor = array_ops.shape(value) 

1809 for i in range(len(value_shape)): 

1810 if value_shape[i] is None: 

1811 value_shape[i] = value_shape_tensor[i] 

1812 

1813 # In the code below, we will insert a new "replica" dimension immediately 

1814 # *before* `axis`. To ensure that it's inserted before and not after, we 

1815 # must make `axis` non-negative. 

1816 axis = _make_axis_nonnegative(axis, value_rank) 

1817 

1818 # Create a list or 1D int Tensor such as 

1819 # [1, 1, ..., 1, num_replicas_in_sync, 1, ..., 1], 

1820 # which is equal to `num_replicas_in_sync` at index `axis` 

1821 # and is equal to 1 everywhere else. 

1822 if isinstance(value_rank, int): 

1823 replica_broadcast_shape = [1] * (value_rank + 1) 

1824 replica_broadcast_shape[axis] = self.num_replicas_in_sync 

1825 else: 

1826 replica_broadcast_shape = array_ops.where_v2( 

1827 math_ops.equal(math_ops.range(value_rank+1), axis), 

1828 self.num_replicas_in_sync, 

1829 1) 

1830 

1831 output_shape = self._compute_all_gather_output_shape( 

1832 value_shape, value_rank, axis) 

1833 

1834 if value.dtype in _DTYPES_SUPPORTED_BY_CROSS_REPLICA_SUM: 

1835 # optimized all_gather implementation based on cross_replica_sum(). 

1836 replica_id_mask = array_ops.one_hot( 

1837 self.replica_id_in_sync_group, self.num_replicas_in_sync) 

1838 replica_id_mask = array_ops.reshape( 

1839 replica_id_mask, replica_broadcast_shape) 

1840 replica_id_mask = math_ops.cast(replica_id_mask, value.dtype) 

1841 

1842 gathered_value = array_ops.expand_dims(value, axis) * replica_id_mask 

1843 gathered_value = self.all_reduce( 

1844 reduce_util.ReduceOp.SUM, gathered_value) 

1845 return array_ops.reshape(gathered_value, output_shape) 

1846 else: 

1847 # value.dtype isn't supported by cross_replica_sum(), so we fall back 

1848 # on a less efficient implementation based on all_to_all(). 

1849 

1850 # The underlying AllToAllOp first do a split of the input value and then 

1851 # cross-replica communication and concatenation of the result. So we 

1852 # concatenate the local tensor here first. 

1853 inputs = array_ops.expand_dims(value, axis=axis) 

1854 inputs = array_ops.tile(inputs, replica_broadcast_shape) 

1855 unordered_output = tpu_ops.all_to_all( 

1856 inputs, 

1857 concat_dimension=axis, 

1858 split_dimension=axis, 

1859 split_count=self.num_replicas_in_sync) 

1860 

1861 # Re-order since xla.replica_id and ReplicaContext.replica_id mismatch. 

1862 # Start by computing a permutation -- a 1D Tensor which maps 

1863 # tensor[xla.replica_id] = ReplicaContext.replica_id 

1864 concat_replica_id = array_ops.reshape( 

1865 self.replica_id_in_sync_group, [1]) 

1866 concat_replica_id = array_ops.tile( 

1867 concat_replica_id, [self.num_replicas_in_sync]) 

1868 xla_to_replica_context_id = tpu_ops.all_to_all( 

1869 concat_replica_id, 

1870 concat_dimension=0, 

1871 split_dimension=0, 

1872 split_count=self.num_replicas_in_sync) 

1873 

1874 # Now invert the mapping to get 

1875 # tensor[ReplicaContext.replica_id] = xla.replica_id 

1876 replica_context_to_xla_id = math_ops.argmax( 

1877 array_ops.one_hot(xla_to_replica_context_id, 

1878 self.num_replicas_in_sync), 

1879 axis=0) 

1880 

1881 # Reorder the output elements so that they're sorted based on 

1882 # ReplicaContext.replica_id instead of xla.replica_id. 

1883 sorted_with_extra_dim = array_ops.gather( 

1884 unordered_output, replica_context_to_xla_id, axis=axis) 

1885 return array_ops.reshape(sorted_with_extra_dim, output_shape) 

1886 

1887 ys = [_all_gather_tensor(t, axis=axis) for t in nest.flatten(value)] 

1888 return nest.pack_sequence_as(value, ys) 

1889 

1890 

1891def _set_last_step_outputs(ctx, last_step_tensor_outputs): 

1892 """Sets the last step outputs on the given context.""" 

1893 # Convert replicate_outputs to the original dict structure of 

1894 # last_step_outputs. 

1895 last_step_tensor_outputs_dict = nest.pack_sequence_as( 

1896 ctx.last_step_outputs, last_step_tensor_outputs) 

1897 

1898 for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access 

1899 output = last_step_tensor_outputs_dict[name] 

1900 # For outputs that aren't reduced, return a PerReplica of all values. Else 

1901 # take the first value from the list as each value should be the same. 

1902 if reduce_op is None: 

1903 last_step_tensor_outputs_dict[name] = values.PerReplica(output) 

1904 else: 

1905 # TODO(priyag): Should this return the element or a list with 1 element 

1906 last_step_tensor_outputs_dict[name] = output[0] 

1907 ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access