# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================

"""Helper library for handling infeed between hosts and TPUs.
"""

import itertools

import numpy as np

from tensorflow.python.compiler.xla.experimental import xla_sharding
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.tpu import tpu_name_util
from tensorflow.python.tpu import tpu_sharding
from tensorflow.python.tpu.ops import tpu_ops

from tensorflow.python.util import nest


def partition_or_replicate_on_host(tensor, dims):
  """Partitions or replicates the input tensor.

  The ops inside this function are placed on the host side.

  Args:
    tensor: The input tensor which will be partitioned or replicated.
    dims: A list of integers describing how to partition the input tensor.

  Returns:
    An iterator of `Tensor`s or a list of partitioned tensors.
  """
  if dims is None:
    return itertools.repeat(tensor)
  dims = np.array(dims)
  output = [tensor]
  shape_list = np.array(tensor.shape.as_list())
  quotients, remainders = np.divmod(shape_list, dims)
  for axis, (quotient, remainder, dim, original_size) in enumerate(
      zip(quotients, remainders, dims, shape_list)):
    if dim <= 1:
      continue
    if remainder > 0:
      # For each dimension, when it cannot be evenly partitioned, XLA assumes
      # tensors are partitioned in a greedy manner by using
      # ceil_ratio(size/dim) first. E.g. 2D tensor with shape (5, 14) and dims
      # are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] =>
      # [[(3, 4), (3, 4), (2, 4), (2, 2)],
      #  [(2, 4), (2, 4), (2, 4), (2, 2)]]
      ceil_ratio = quotient + 1
      num_full_slots, left_over = np.divmod(original_size, ceil_ratio)
      num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over]
      if len(num_or_size_splits) < dim:
        num_or_size_splits += [0] * (dim - len(num_or_size_splits))
      new_output = []
      for x in output:
        new_output.append(
            array_ops.split(
                x, num_or_size_splits=num_or_size_splits, axis=axis))
      output = new_output
    else:
      output = [array_ops.split(x, int(dim), axis=axis) for x in output]
    output = nest.flatten(output)
  return output
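# A minimal usage sketch (illustrative only, not part of the library API).
# With `dims=None` the same tensor is replicated to every consumer via an
# infinite iterator; with concrete dims the tensor is split per axis and the
# nested splits are flattened. Assuming a host-placed tensor:
#
#   x = array_ops.ones([4, 8])
#   parts = partition_or_replicate_on_host(x, dims=[2, 2])
#   # `parts` is a flat list of 4 tensors, each of shape [2, 4], ordered so
#   # that the split along the last axis varies fastest.
#   replicas = partition_or_replicate_on_host(x, dims=None)
#   # `replicas` is an infinite iterator; next(replicas) is `x` itself.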

def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
  """Tags the appropriate XLA sharding attribute onto the dequeued tensor.

  The sharding attribute of the dequeued tensor will be a tuple.

  Args:
    tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    return xla_sharding.replicate(tensor, assign_tuple_sharding=True)
  elif np.prod(dims) == 1:
    return xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True)
  else:
    tile_assignment = np.arange(np.prod(dims)).reshape(dims)
    return xla_sharding.tile(
        tensor=tensor,
        tile_assignment=tile_assignment,
        assign_tuple_sharding=True)


def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
  """Tags the appropriate XLA sharding attribute onto the dequeued tensors.

  Args:
    dequeues: A list of dequeued tensors on TPU.
    dims: A list of integers describing how each tensor is partitioned.

  Returns:
    The same dequeues with the appropriate xla_sharding attribute.
  """
  nest.assert_shallow_structure(dequeues, dims)
  return nest.map_structure_up_to(
      dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims)
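# Illustrative note (a sketch, not exercised by the library itself): for a
# dequeued tensor partitioned with dims=[2, 4], the tile assignment built in
# `_tag_sharding_attribute_for_dequeued_tensor` is
# np.arange(8).reshape([2, 4]), i.e. logical cores 0..3 hold the first row of
# tiles and cores 4..7 the second. dims=None marks the tensor as replicated,
# and dims with a product of 1 assigns the whole tensor to device 0.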

class InfeedQueue(object):
  """A helper object to build a device infeed queue.

  The InfeedQueue builds the host-side and device-side Ops to enqueue and
  dequeue elements, respectively, and ensures that their types and
  shapes match.
  """

  def __init__(self,
               number_of_tuple_elements=None,
               tuple_types=None,
               tuple_shapes=None,
               shard_dimensions=None,
               number_of_partitions=None,
               name=None):
    """Creates a new InfeedQueue with the given configuration.

    The configuration need not be fully specified at creation since it
    can be modified subsequently by methods that set the values
    explicitly or infer them from the shapes of inputs.

    Args:
      number_of_tuple_elements: the number of Tensors fed atomically through the
        queue, must be present unless it can be inferred from other arguments.
      tuple_types: if not None, a list of types of the elements of the queue.
      tuple_shapes: if not None, a list of shapes of the elements of the queue.
      shard_dimensions: if not None, a list of dimensions on which the
        elements of the queue should be sharded during automatic
        parallelization.
      number_of_partitions: if > 1, the infeed dequeue shape will contain
        the full shape that includes all partitions, and the corresponding XLA
        annotation will be added to the infeed dequeue op. In this case the
        infeed is still data parallel: it feeds a per-core batch to each core
        while the XLA computation may be partitioned. Because XLA requires the
        infeed dequeue shape to be the per-replica shape, number_of_partitions
        is needed to calculate the per-replica unpartitioned shape.
      name: the name of the queue.

    Raises:
      ValueError: if number_of_tuple_elements <= 0; or
        number_of_tuple_elements, tuple_types, tuple_shapes, and
        shard_dimensions are all None; or the length of tuple_types,
        tuple_shapes, or shard_dimensions is not equal to
        number_of_tuple_elements; or any element of shard_dimensions
        can't be converted to a Dimension.
      TypeError: if any element of tuple_types or tuple_shapes can't
        be converted to a dtype or TensorShape, respectively.
    """
    self._frozen = False
    self._generated_enqueue_ops = False
    self._generated_dequeue_op = False
    self._name = "InfeedQueue" if name is None else name
    if number_of_partitions is None:
      self._number_of_partitions = 1
    else:
      self._number_of_partitions = number_of_partitions
    if number_of_tuple_elements is None:
      if tuple_types is not None:
        number_of_tuple_elements = len(tuple_types)
      elif tuple_shapes is not None:
        number_of_tuple_elements = len(tuple_shapes)
      elif shard_dimensions is not None:
        number_of_tuple_elements = len(shard_dimensions)
      else:
        raise ValueError(
            "number of tuple elements cannot be inferred from InfeedQueue "
            "constructor")
    if number_of_tuple_elements <= 0:
      raise ValueError(f"number_of_tuple_elements {number_of_tuple_elements} "
                       "must be > 0")
    # Make an empty sharding policy for each tuple element.
    self._sharding_policies = [
        tpu_sharding.ShardingPolicy() for _ in range(number_of_tuple_elements)
    ]
    if tuple_types is not None:
      self.set_tuple_types(tuple_types)
    else:
      self._tuple_types = None
    if tuple_shapes is not None:
      self.set_tuple_shapes(tuple_shapes)
    else:
      self._tuple_shapes = None
    if shard_dimensions is not None:
      self.set_shard_dimensions(shard_dimensions)
    self._validate()
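  # A minimal construction sketch (illustrative only; shapes and dtypes are
  # hypothetical). The configuration can also be left empty here and inferred
  # later from the input tensors:
  #
  #   queue = InfeedQueue(
  #       tuple_types=[dtypes.float32, dtypes.int32],
  #       tuple_shapes=[[128, 224, 224, 3], [128]])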

  def _validate(self):
    """Checks that the configuration is self-consistent.

    Raises:
      ValueError: if the shapes and sharding policies don't match.
    """
    if self.tuple_shapes is not None:
      for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
        # Raise an error if the policy is incompatible with the shape.
        _ = policy.get_sharded_shape(shape)

  @property
  def number_of_tuple_elements(self):
    """Returns the number of InfeedQueue tuple elements."""
    return len(self._sharding_policies)

  @property
  def tuple_types(self):
    """Returns the types of the InfeedQueue tuple elements."""
    return self._tuple_types

  def set_tuple_types(self, tuple_types):
    """Sets the type of each element of the queue.

    tuple_types must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a dtype.

    Args:
      tuple_types: the types of each queue element.

    Raises:
      ValueError: if tuple_types is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_types cannot be converted to a
        dtype.
    """
    if len(tuple_types) != self.number_of_tuple_elements:
      raise ValueError(
          f"tuple_types is {str(tuple_types)}, but must be a list of "
          f"length {self.number_of_tuple_elements}"
      )
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_types, tuple_types):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              f"incompatible type. Frozen types are {str(self._tuple_types)}, "
              f"updated types are {str(tuple_types)}")
    else:
      try:
        self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
      except TypeError as e:
        raise TypeError(
            f"tuple_types is {str(tuple_types)}, but must be a list of "
            f"elements each convertible to dtype: got error {str(e)}") from e

  @property
  def tuple_shapes(self):
    """Returns the shapes of the InfeedQueue tuple elements."""
    return self._tuple_shapes

  def set_tuple_shapes(self, tuple_shapes):
    """Sets the shape of each element of the queue.

    tuple_shapes must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a TensorShape.

    Args:
      tuple_shapes: the shapes of each queue element.

    Raises:
      ValueError: if tuple_shapes is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_shapes cannot be converted to
        a TensorShape.
    """
    if len(tuple_shapes) != self.number_of_tuple_elements:
      raise ValueError(
          f"tuple_shapes is {str(tuple_shapes)}, but must be a list of "
          f"length {self.number_of_tuple_elements}"
      )
    try:
      tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
    except (ValueError, TypeError) as e:
      raise TypeError(
          f"tuple_shapes is {str(tuple_shapes)}, but must be a list of "
          "elements each convertible to TensorShape: got error "
          f"{str(e)}") from e
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              "incompatible shape. Frozen shapes are "
              f"{str(self._tuple_shapes)}, updated shapes are "
              f"{str(tuple_shapes)}")
    else:
      self._tuple_shapes = tuple_shapes
    self._validate()
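  # A short sketch of updating the configuration explicitly (hypothetical
  # values). Once the queue is frozen these setters only accept values that
  # match the frozen configuration:
  #
  #   queue.set_tuple_types([dtypes.float32, dtypes.int32])
  #   queue.set_tuple_shapes([[128, 28, 28, 1], [128]])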

  @property
  def sharding_policies(self):
    """Returns the sharding policies of the InfeedQueue tuple elements."""
    return self._sharding_policies

  @property
  def shard_dimensions(self):
    """Gets the shard dimension of each tuple element.

    Returns:
      A list of length number_of_tuple_elements, where each list entry
      is the shard dimension of that tuple element or None if the
      shard dimension has not been set.
    """
    # The number of shards is always the same for all the policies.
    return [policy.shard_dimension for policy in self._sharding_policies]

  def set_shard_dimensions(self, shard_dimensions):
    """Sets the shard_dimension of each element of the queue.

    shard_dimensions must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a Dimension compatible with self.tuple_shapes.

    Args:
      shard_dimensions: the dimensions of each queue element.

    Raises:
      ValueError: if shard_dimensions is not of length
        self.number_of_tuple_elements; or an element of
        shard_dimensions cannot be converted to a Dimension; or an
        element of shard_dimensions is a Dimension that is out of
        range for the corresponding tuple element shape.
    """
    if len(shard_dimensions) != self.number_of_tuple_elements:
      raise ValueError(f"shard_dimensions is {str(shard_dimensions)}, but must "
                       f"be a list of length {self.number_of_tuple_elements}")
    for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
      policy.set_shard_dimension(dimension)
    self._validate()
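  # Illustrative note (a sketch with hypothetical values): shard dimension 0
  # splits each element along its batch dimension, so a queue configured with
  #
  #   queue.set_shard_dimensions([0, 0])
  #   queue.set_number_of_shards(2)
  #
  # expects two per-shard tuples whose batch sizes sum to the unsharded batch.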

  @property
  def number_of_shards(self):
    """Gets the number of shards to use for the InfeedQueue.

    Returns:
      Number of shards or None if the number of shards has not been set.
    """
    # The number of shards is always the same for all the policies.
    return self._sharding_policies[0].number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards to use for the InfeedQueue.

    Args:
      number_of_shards: number of ways to shard the InfeedQueue.

    Raises:
      ValueError: if number_of_shards is not > 0; or the policies have
        been frozen and number_of_shards was already set to something
        else.
    """
    for policy in self._sharding_policies:
      policy.set_number_of_shards(number_of_shards)
      policy.set_number_of_partitions(self._number_of_partitions)
    self._validate()

  def set_configuration_from_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of Tensors whose types and shapes are used
    to set the queue configuration.

    Args:
      input_tensors: list of Tensors of the same types and shapes as
        the desired queue Tuple.

    Raises:
      ValueError: if input_tensors is not a list of length
        self.number_of_tuple_elements.
    """
    if len(input_tensors) != self.number_of_tuple_elements:
      raise ValueError(f"input_tensors is {str(input_tensors)}, but should be "
                       f"a list of {self.number_of_tuple_elements} Tensors")
    self.set_tuple_shapes([t.shape for t in input_tensors])
    self.set_tuple_types([t.dtype for t in input_tensors])

  def set_configuration_from_sharded_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of lists of Tensors whose types and shapes are used
    to set the queue configuration. The length of the outer list is the number
    of shards required, and each inner list is the tuple of Tensors to use to
    determine the types and shapes of the corresponding shard. This method
    depends on the shard dimension, and calling it freezes the shard policy.

    Args:
      input_tensors: list of lists of Tensors. The outer list length corresponds
        to the desired number of shards, and each inner list gives the types
        and shapes of the tuple elements of the corresponding shard.

    Raises:
      ValueError: if any inner list is not a list of length
        self.number_of_tuple_elements; or the inner lists do not combine to
        form a consistent unsharded shape.
      TypeError: if the types of the Tensors in the inner lists do not match.
    """
    if not self._frozen:
      # Unset the tuple shapes in case the configuration becomes
      # transiently inconsistent.
      self._tuple_shapes = None
    number_of_shards = len(input_tensors)
    self.set_number_of_shards(number_of_shards)
    for t in input_tensors:
      if len(t) != self.number_of_tuple_elements:
        raise ValueError(
            f"input_tensors is {str(input_tensors)} but must be a list of "
            "lists, where each inner list has length "
            f"number_of_tuple_elements={self.number_of_tuple_elements}")
    # Transpose the inputs to make a list of shard shapes for each tuple
    # element.
    sharded_shapes = [[t[i].shape
                       for t in input_tensors]
                      for i in range(self.number_of_tuple_elements)]
    # For each tuple, get the unsharded shape using that tuple's policy.
    unsharded_shapes = [
        policy.get_unsharded_shape(s)
        for (policy, s) in zip(self._sharding_policies, sharded_shapes)
    ]
    self.set_tuple_shapes(unsharded_shapes)
    for i in range(1, self.number_of_shards):
      for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
        if t1.dtype != t2.dtype:
          raise TypeError(
              "types of the tuple elements of input_tensors "
              f"{str(input_tensors)} are not consistent")
    self.set_tuple_types([t.dtype for t in input_tensors[0]])
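  # Worked example (a sketch with hypothetical shapes): with shard dimension 0
  # and two shards, each providing a tuple of ([64, 10] float32, [64] int32),
  # the transposed per-element shard shapes are [[64, 10], [64, 10]] and
  # [[64], [64]], so the inferred unsharded tuple shapes become [128, 10] and
  # [128].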

  def freeze(self):
    """Freezes the InfeedQueue so it can no longer be modified.

    The configuration is implicitly frozen before any host-side or
    device-side Ops are generated. The configuration cannot be frozen
    until the types and shapes of the tuple elements have been set.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set.
    """
    self._frozen = True
    if self._tuple_types is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple types.")
    if self._tuple_shapes is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for shape in self._tuple_shapes:
      if shape.dims is None:
        raise ValueError(
            "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for policy in self._sharding_policies:
      policy.freeze()
    self._validate()

  def generate_dequeue_op(self, tpu_device=0):
    """Generates the device-side Op to dequeue a tuple from the queue.

    Implicitly freezes the queue configuration if it is not already
    frozen, which will raise errors if the shapes and types have not
    been fully specified.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed. If None, no explicit placement will be performed, and it is up
        to the user to call this API from within a proper TPU device scope.
        The XLA code will fail if the TPU dequeue instruction is not bound to
        any device.

    Returns:
      A list of Outputs corresponding to a shard of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    if self._generated_dequeue_op and not ops.inside_function():
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_unpartitioned_shape(policy.get_sharded_shape(shape))
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    if tpu_device is not None:
      with ops.device(tpu_name_util.core(tpu_device)):
        dequeue_op = tpu_ops.infeed_dequeue_tuple(
            dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    else:
      dequeue_op = tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    if self._number_of_partitions <= 1:
      return dequeue_op
    partitions = [
        policy.get_unpartitioned_shape([1] * shape.ndims).as_list()
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    return tag_sharding_attribute_for_dequeued_tensors(dequeue_op, partitions)
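  # A minimal usage sketch (illustrative only; `build_model` is a hypothetical
  # user function). The dequeue op is typically generated inside the replicated
  # TPU computation so each replica reads its own shard:
  #
  #   def tpu_step():
  #     features, labels = queue.generate_dequeue_op(tpu_device=0)
  #     return build_model(features, labels)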

  def _generate_enqueue_op(self,
                           inputs,
                           name_prefix,
                           index,
                           device=None,
                           tpu_ordinal=-1):
    """Generate a host-side Op to enqueue a tuple to the queue.

    If device is None the inputs are all required to have the same
    device specification, and the enqueue Op is colocated with
    inputs[0]. Otherwise the enqueue Op is placed on 'device'.

    Args:
      inputs: a list of Tensors with the types and shapes of the tuple elements.
      name_prefix: the base name for the Op.
      index: the shard index, used to uniquify the Op name.
      device: device to place the Op on, or None if it should be
        colocated with the inputs.
      tpu_ordinal: ordinal of the TPU device on the host to use for
        infeed if device is a CPU device. Should be set to -1 if device
        is a TPU device.

    Returns:
      An Op corresponding to a shard of infeed enqueued at the host,
      suitable for use within a replicated block.

    Raises:
      ValueError: if device is None and inputs do not all have the
        same device specification.
    """
    full_name = "%s/%d" % (name_prefix, index)
    shapes = [t.shape for t in inputs]
    if device is None:
      devices = [t.device for t in inputs]
      for i in range(1, self.number_of_tuple_elements):
        if devices[0] != devices[i]:
          raise ValueError(
              f"input devices for shard {index} are {str(devices)}, but should "
              "all be the same")
      with ops.colocate_with(inputs[0]):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)
    else:
      with ops.device(device):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)

  def generate_enqueue_ops(self,
                           sharded_inputs,
                           tpu_ordinal_function=None,
                           placement_function=None):
    """Generates the host-side Ops to enqueue the shards of a tuple.

    sharded_inputs is a list, one for each shard, of lists of
    Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed
    shard i of the queue. Returns the host-side Ops that must be run to
    enqueue the sharded tuple. The Op for shard i is colocated with the inputs
    for shard i.

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of sharded_inputs, an error
    will be raised.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the outer list
        determines the number of shards. Each inner list indicates the types
        and shapes of the tuples in the corresponding shard.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. tpu_ordinal_function must be
        set if the inputs are placed on CPU devices.
      placement_function: if not None, a function that takes the shard index as
        input and returns the host device where the enqueue op should be
        placed.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    self.freeze()
    if self._generated_enqueue_ops and not ops.inside_function():
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    if tpu_ordinal_function is None:
      tpu_ordinal_function = lambda index: -1
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            tpu_ordinal=tpu_ordinal_function(index),
            device=placement_function(index) if placement_function else None)
        for (shard, index) in zip(sharded_inputs, range(self.number_of_shards))
    ]
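  # A minimal usage sketch (illustrative only; `shard_0` and `shard_1` are
  # hypothetical lists of host-placed Tensors, one tuple per shard). When the
  # inputs live on CPU devices, an ordinal function maps each shard to a TPU
  # core on its host:
  #
  #   enqueue_ops = queue.generate_enqueue_ops(
  #       [shard_0, shard_1],
  #       tpu_ordinal_function=lambda index: index % 8)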

  # TODO(misard) Generalize this to the case of systems that don't
  # have 8 devices per host, and figure out what to do with
  # model-parallelism.
  def _default_placement_function(self, index):
    return "/task:%d/device:CPU:0" % (index // 8)

  def _default_ordinal_function(self, index):
    return index % 8

  # TODO(b/36470756) remove this from tutorials once we have a better story
  # for automatic placement of input pipelines.
  def split_inputs_and_generate_enqueue_ops(self,
                                            inputs,
                                            device_assignment=None,
                                            placement_function=None,
                                            tpu_ordinal_function=None):
    """POORLY-PERFORMING ON MULTI-HOST SYSTEMS.

    Generates the host-side Ops to enqueue a tuple.

    This method performs poorly because it takes an entire input on a single
    host, splits it, and distributes it to all of the cores. It is present only
    to simplify tutorial examples.

    inputs is a list of Tensors to use to feed the queue. Each input is split
    into self.number_of_shards shards. Returns an Op for each shard to enqueue
    the shard. The Op for shard i is placed on device placement_function(i).

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of inputs, an error
    will be raised.

    Args:
      inputs: a list of Tensors which indicates the types and shapes of the
        queue tuple.
      device_assignment: if not `None`, a TPU `DeviceAssignment`. If
        device_assignment is not `None`, but `placement_function` and
        `ordinal_function` are None, then `device_assignment` will be used to
        place infeeds on the first k TPU shards, where k is the number of shards
        in the queue. If all three are `None`, then default placement and
        ordinal functions are used.
      placement_function: if not None, a function that takes the shard
        index as input and returns a device string indicating which
        device the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of inputs are not compatible with the frozen
        configuration.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of inputs are not compatible with the frozen
        configuration.
    """
    if device_assignment is None:
      if placement_function is None:
        placement_function = self._default_placement_function
      if tpu_ordinal_function is None:
        tpu_ordinal_function = self._default_ordinal_function
    else:

      def _placement_function_from_map(index):
        return device_assignment.host_device(replica=index)

      def _ordinal_function_from_map(index):
        return device_assignment.tpu_ordinal(replica=index)

      if placement_function is None:
        placement_function = _placement_function_from_map
      if tpu_ordinal_function is None:
        tpu_ordinal_function = _ordinal_function_from_map
    self.set_configuration_from_input_tensors(inputs)
    self.freeze()
    if self._generated_enqueue_ops and not ops.inside_function():
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    split_name_prefix = "%s/split" % self._name
    if self.number_of_shards == 1:
      transposed_sharded_inputs = [[inp] for inp in inputs]
    else:

      def split_fn(inp, num_shards, axis, name):
        with ops.colocate_with(inp):
          return array_ops.split(inp, num_shards, axis=axis, name=name)

      transposed_sharded_inputs = [
          split_fn(
              inp,
              self.number_of_shards,
              axis=policy.shard_dimension,
              name="%s/%d" % (split_name_prefix, index))
          for (inp, policy, index) in zip(inputs, self._sharding_policies,
                                          range(self.number_of_tuple_elements))
      ]
    sharded_inputs = [[shard[i]
                       for shard in transposed_sharded_inputs]
                      for i in range(self.number_of_shards)]
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            device=placement_function(index),
            tpu_ordinal=tpu_ordinal_function(index))
        for (shard, index) in zip(sharded_inputs, range(self.number_of_shards))
    ]
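  # A minimal single-host sketch (illustrative only; `features`, `labels`, and
  # `my_device_assignment` are hypothetical). Each tensor is split into
  # number_of_shards pieces along its shard dimension and one enqueue op is
  # emitted per shard:
  #
  #   enqueue_ops = queue.split_inputs_and_generate_enqueue_ops(
  #       [features, labels], device_assignment=my_device_assignment)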

class _PartitionedInfeedQueue(InfeedQueue):
  """A helper object to build a device infeed queue with input partition.

  Args:
    number_of_tuple_elements: the number of Tensors fed atomically through the
      queue, must be present unless it can be inferred from other arguments.
    device_assignment: A TPU `DeviceAssignment` which is used to place all the
      partitions to different TPU infeed queues.
    host_id: The id of the host machine.
    input_partition_dims: A nested list/tuple of integers. Each inner
      list/tuple describes how to partition the corresponding input tensor.
    tuple_types: If not None, a list of types of the elements of the queue.
    tuple_shapes: If not None, a list of shapes of the elements of the queue.
    name: The name of the queue.
  """

  def __init__(self,
               number_of_tuple_elements,
               device_assignment,
               host_id,
               input_partition_dims=None,
               tuple_types=None,
               tuple_shapes=None,
               name=None):
    super(_PartitionedInfeedQueue, self).__init__(
        number_of_tuple_elements=number_of_tuple_elements,
        tuple_types=tuple_types,
        tuple_shapes=None,
        shard_dimensions=None,
        name="PartitionedInfeedQueue" if name is None else name)
    self._input_partition_dims = input_partition_dims
    self._host_id = host_id
    self._device_assignment = device_assignment

  def generate_dequeue_op(self, tpu_device=0):
    """Generate TPU dequeue ops.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed.

    Returns:
      A list of Outputs corresponding to a partition of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    if self._generated_dequeue_op and not ops.inside_function():
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_sharded_shape(shape)
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    with ops.device(tpu_name_util.core(tpu_device)):
      values = tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    return tag_sharding_attribute_for_dequeued_tensors(
        values, self._input_partition_dims)

  def generate_enqueue_ops(self, sharded_inputs):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
    """Generates the host-side Ops to enqueue the partitioned inputs.

    sharded_inputs is a list, one for each replica, of lists of
    Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed
    replica i.
    sharded_inputs[i][j] is partitioned by self._input_partition_dims[j].

    For example, if sharded_inputs[i][j] is a 2-D Tensor:
    [[A, B, C, D],
     [E, F, G, H]]
    and self._input_partition_dims[j] is [2, 4], then

    sharded_inputs[i][j] will be partitioned and flattened into:
    [A, B, C, D, E, F, G, H] and fed into the logical core ids:
    [0, 1, 2, 3, 4, 5, 6, 7] respectively.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the
        outer list determines the number of shards. Each inner list indicates
        the types and shapes of the tuples in the corresponding shard.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints; or if the partition dims are invalid.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    number_of_replicas = len(sharded_inputs)
    number_of_tuple_elements = len(sharded_inputs[0])

    assert len(self._input_partition_dims) == number_of_tuple_elements
    enqueue_ops = []

    for replica_index in range(number_of_replicas):
      flattened_inputs = sharded_inputs[replica_index]
      inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs,
                                                 self._input_partition_dims)
      inputs_parted_iters = [
          iter(self._check_dims_and_partition_or_replicate_on_host(x, dims))
          for x, dims in zip(sharded_inputs[replica_index],
                             inputs_part_dims_flat)
      ]

      # Find the replica_id of the host's logical core 0.
      # The self._host_id is guaranteed to contain the logical core 0,
      # even when num_cores_per_replica > num_cores_per_host -- the function
      # caller makes sure that this host_id will be receiving data (calls
      # input_fn).
      replica_id = self._device_assignment.lookup_replicas(
          task_id=self._host_id, logical_core=0)[replica_index]
      for logical_core in range(self._device_assignment.num_cores_per_replica):
        # Place different partitions on different logical cores.
        # Since there can be multiple hosts per replica, we need to find
        # the actual host (device) of this logical core.
        device = self._device_assignment.host_device(
            replica=replica_id, logical_core=logical_core)

        with ops.device(device):
          ordinal = self._device_assignment.tpu_ordinal(
              replica=replica_id, logical_core=logical_core)
          infeed_inputs = []
          for it in inputs_parted_iters:
            input_for_device = next(it, None)
            if input_for_device is not None:
              infeed_inputs.append(input_for_device)

          if infeed_inputs:
            enqueue_ops.append(
                tpu_ops.infeed_enqueue_tuple(
                    inputs=infeed_inputs,
                    shapes=[x.shape for x in infeed_inputs],
                    name="enqueue/replica_{0}/input_{1}".format(
                        replica_index, logical_core),
                    device_ordinal=ordinal))
    return enqueue_ops
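  # Worked example (a sketch with hypothetical values): with a single tuple
  # element of shape [8, 128], input_partition_dims=[[2, 4]], and
  # num_cores_per_replica=8, each replica's tensor is cut into eight [4, 32]
  # slices on the host and enqueued to logical cores 0..7 of that replica; the
  # matching dequeue op then tags the reassembled tensor with a [2, 4] tile
  # sharding.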

  def _check_input_partition_dims(self, tensor, dims):
    """Checks that input partition dims are valid for the `Tensor`.

    Args:
      tensor: Input tensor for partitioning.
      dims: A list of integers describing how to partition the input tensor.

    Raises:
      ValueError: If the tensor can't be partitioned by dims or the
        num_cores_per_replica doesn't match the number of
        partitions (dims.prod()).
    """
    # No partitioning specified, so don't perform further checks.
    if dims is None:
      return

    dims = np.array(dims)

    if (dims < 1).any():
      raise ValueError("All input partition dims must be >= 1.")

    # No partitioning, so don't perform further checks.
    if dims.prod() == 1:
      return

    if dims.prod() != self._device_assignment.num_cores_per_replica:
      raise ValueError(
          "The product of the input partition dims should equal "
          "num_cores_per_replica. (dim = {}, num_cores_per_replica "
          "= {})".format(dims, self._device_assignment.num_cores_per_replica))
    if dims.shape[0] != tensor.shape.ndims:
      raise ValueError(
          "Input partition dims must have the same number of dimensions "
          "as the `Tensor` to be partitioned. (tensor shape = {}, input "
          "partition dims = {}).".format(tensor.shape.as_list(), dims))

    tensor.shape.assert_is_fully_defined()

  def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims):
    """Checks dims and partitions or replicates the input tensor.

    The ops inside this function are placed on the host side.

    Args:
      tensor: The input tensor which will be partitioned or replicated.
      dims: A list of integers describing how to partition the input tensor.

    Returns:
      An iterator of `Tensor`s or a list of partitioned tensors.
    """
    self._check_input_partition_dims(tensor, dims)
    return partition_or_replicate_on_host(tensor, dims)