Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/ops/tpu_ops.py: 41%

91 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations for TPUs."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.ops import gen_tpu_ops
from tensorflow.python.ops.gen_tpu_ops import *
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
from tensorflow.python.util.tf_export import tf_export


def _create_default_group_assignment():
  num_shards = tpu_function.get_tpu_context().number_of_shards
  if num_shards is None:
    logging.warning(
        "cross_replica_sum should be used within a tpu_shard_context, but "
        "got unset number_of_shards. Assuming 1.")
    num_shards = 1
  group_assignment = [list(range(num_shards))]
  return group_assignment


def all_to_all(x,
               concat_dimension,
               split_dimension,
               split_count,
               group_assignment=None,
               name=None):
  """Exchange data across TPU replicas.

  Args:
    x: The local tensor.
    concat_dimension: The dimension number to concatenate.
    split_dimension: The dimension number to split.
    split_count: The number of splits; this number must equal the subgroup
      size (group_assignment.get_shape()[1]).
    group_assignment: Optional 2d int32 lists with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` concatenated from the data of the different replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()
  return gen_tpu_ops.all_to_all(
      x,
      group_assignment,
      concat_dimension=concat_dimension,
      split_dimension=split_dimension,
      split_count=split_count,
      name=name)
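
# Illustrative usage sketch (not part of the original module): with four
# replicas and a single group, all_to_all splits each replica's tensor into
# four slices along `split_dimension` and concatenates the slices received
# from every replica along `concat_dimension`. The replica count and the
# dimensions below are assumptions for the example.
def _example_all_to_all_exchange(x):
  """Hypothetical helper: exchanges row blocks of `x` across 4 replicas."""
  return all_to_all(
      x,
      concat_dimension=0,
      split_dimension=0,
      split_count=4,
      group_assignment=[[0, 1, 2, 3]])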

@ops.RegisterGradient("AllToAll")
def _all_to_all_grad(op, grad):
  # The gradient of an all-to-all is also an all-to-all, but with the
  # split_dimension and concat_dimension swapped.
  # The gradient with respect to group_assignment is None.
  return [
      gen_tpu_ops.all_to_all(
          grad,
          op.inputs[1],
          concat_dimension=op.get_attr("split_dimension"),
          split_dimension=op.get_attr("concat_dimension"),
          split_count=op.get_attr("split_count")), None
  ]

@tf_export(v1=["tpu.cross_replica_sum"])
def cross_replica_sum(x, group_assignment=None, name=None):
  """Sum the input tensor across replicas according to group_assignment.

  Args:
    x: The local tensor to sum.
    group_assignment: Optional 2d int32 lists with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` which is summed across replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()

  return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
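
# Illustrative usage sketch (not part of the original module): with the
# default group assignment every replica belongs to one group, so the sum
# runs over all shards; dividing by the shard count turns it into a mean.
def _example_cross_replica_mean(x):
  """Hypothetical helper: averages `x` across all replicas in one group."""
  num_shards = tpu_function.get_tpu_context().number_of_shards or 1
  return cross_replica_sum(x) / num_shards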

def collective_permute(x, source_target_pairs, name=None):
  """Permute the input tensor across replicas given source_target_pairs.

  For each source_target_pair <a, b>, we send replica a's input to replica b.
  Each replica id must appear at most once in the source column and at most
  once in the target column. A replica id that does not appear in the target
  column receives a zero tensor with the same shape and dtype as the input x.

  For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
  source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs:
  `[0, A, B, C]`.

  Args:
    x: The local tensor to be permuted.
    source_target_pairs: 2d int lists with shape [num_pairs, 2].
      source_target_pairs[i][0] represents the source replica id and
      source_target_pairs[i][1] represents the target replica id.
    name: Optional op name.

  Returns:
    A `Tensor` which is permuted.
  """
  return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)
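
# Illustrative usage sketch (not part of the original module): shifting data
# one step along a chain of four replicas. Replica 0 is not a target in the
# pair list below, so it receives zeros. The replica count is an assumption
# for the example.
def _example_chain_shift(x):
  """Hypothetical helper: sends replica i's `x` to replica i + 1 (i < 3)."""
  return collective_permute(x, source_target_pairs=[[0, 1], [1, 2], [2, 3]])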

@ops.RegisterGradient("CollectivePermute")
def _collective_permute_grad(op, grad):
  # The gradient of a collective permute operation is also a collective
  # permute, but with source/target pairs reversed. The gradient with respect
  # to input argument `source_target_pairs` is `None`.
  source_target_pairs = op.inputs[1][:, ::-1]
  return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]


@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
  # The gradient of a cross-replica sum is also a cross-replica sum.
  # The gradient with respect to group_assignment is None.
  return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]


# This extra type checking exists to give a more helpful error message.
_SUPPORTED_INFEED_DTYPES = frozenset([
    dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32,
    dtypes.complex64, dtypes.uint32, dtypes.uint8, dtypes.int8
])


@ops.RegisterGradient("TPUEmbeddingActivations")
def _embedding_activations_grad(activations_op, grad_wrt_activations):
  """Saves the gradient of embedding activations ops in a graph collection."""
  g = ops.get_default_graph()
  table_id = activations_op.get_attr("table_id")
  lookup_id = activations_op.get_attr("lookup_id")
  table_gradients = g.get_collection_ref("tpu_embedding_gradients_table_%d" %
                                         table_id)

  if not table_gradients:
    raise RuntimeError(
        "Gradients for TPUEmbedding have been generated in non-training "
        "mode. This is not expected. Consider putting your Optimizer.minimize "
        "code behind the training mode condition check. For Estimator, you "
        "can do \n\n"
        "  if mode == tf.estimator.ModeKeys.TRAIN:\n"
        "    train_op = opt.minimize(loss)\n"
        "\n")

  if lookup_id < 0 or lookup_id >= len(table_gradients):
    raise RuntimeError(
        "Gradients (w.r.t. TPUEmbedding activations) generated for table_id "
        "{} and lookup_id {}. The lookup_id attribute is outside the expected "
        "range [0, {}).".format(table_id, lookup_id, len(table_gradients)))

  if table_gradients[lookup_id] is not None:
    raise RuntimeError(
        "Duplicate gradients (w.r.t. TPUEmbedding activations) generated for "
        "table_id {} and lookup_id {}. This happens when there are multiple "
        "calls to tf.gradients in a graph containing TPU embeddings. "
        "TF cannot identify which gradient to use for updating the embedding "
        "variables. Consider placing tf.stop_gradient around tensors where "
        "variable update is not required. Previous gradients were generated "
        "by the following callstack: {}.".format(
            table_id, lookup_id, table_gradients[lookup_id].op.traceback))

  table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
  return [
      # RegisterGradient requires that a value be returned for all inputs.
      # Since the first argument (tpu_gradient_variable_{table_name}) has
      # shape [1], we return zeros(shape=[1]). The actual gradient w.r.t. the
      # embedding activations (grad_wrt_activations) has the same shape as
      # the activations returned by embedding_activations.
      array_ops.zeros(arg.shape, dtype=dtypes.float32)
      for arg in activations_op.inputs
  ]


def infeed_dequeue(dtype, shape, name=None):
  """A placeholder op for a value that will be fed into the computation.

  Args:
    dtype: A `tf.DType`. The type of elements in the tensor.
    shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`.
    A tensor that will be provided using the infeed mechanism.

  Raises:
    TypeError: If `dtype` is not a supported infeed type.
  """
  if dtype not in _SUPPORTED_INFEED_DTYPES:
    raise TypeError(
        "Operation '{}' has type {} which is not a supported TPU infeed type. "
        "Supported types are: {}".format(name, dtype,
                                         list(_SUPPORTED_INFEED_DTYPES)))

  return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)
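
# Illustrative usage sketch (not part of the original module): the body of a
# TPU computation pulling one batch from infeed. The dtype and shape below
# are assumptions for the example and must match what the host enqueues.
def _example_infeed_step():
  """Hypothetical helper: dequeues a [128, 32] float32 tensor from infeed."""
  return infeed_dequeue(dtypes.float32, shape=[128, 32])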

# pylint: disable=redefined-outer-name
def infeed_dequeue_tuple(dtypes, shapes, name=None):
  """A placeholder op for values fed into the TPU simultaneously as a tuple.

  Args:
    dtypes: A list of `tf.DType`s that has length `>= 1`. The element types
      of each element in `outputs`.
    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). The
      shapes of each tensor in `outputs`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `dtypes`.
    A list of tensors that will be provided using the infeed mechanism.

  Raises:
    TypeError: If a type in `dtypes` is not a supported infeed type.
  """
  for dtype in dtypes:
    if dtype not in _SUPPORTED_INFEED_DTYPES:
      raise TypeError(
          "{} is not a supported TPU infeed type. Supported types are: "
          "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
  return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)


# pylint: enable=redefined-outer-name
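
# Illustrative usage sketch (not part of the original module): dequeuing a
# (features, labels) pair in one op. The dtypes and shapes listed below are
# assumptions for the example and must match what the host enqueues.
def _example_infeed_tuple_step():
  """Hypothetical helper: dequeues a float32 batch and int32 labels."""
  features, labels = infeed_dequeue_tuple(
      dtypes=[dtypes.float32, dtypes.int32], shapes=[[128, 32], [128]])
  return features, labels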

# pylint: disable=protected-access
def send_tpu_embedding_gradients(inputs,
                                 config,
                                 learning_rates=None,
                                 name=None):
  """A placeholder op for feeding per-sample gradients to the embedding layer.

  Args:
    inputs: A TensorList of gradients with which to update embedding tables.
      This argument has the same length and shapes as the return value of
      RecvTPUEmbeddingActivations, but contains gradients of the model's loss
      with respect to the embedding activations. The embedding tables are
      updated from these gradients via the optimizers specified in the TPU
      embedding configuration given to tpu.initialize_system.
    config: Serialized TPUEmbeddingConfiguration proto.
    learning_rates: A TensorList of float32 scalars, one for each dynamic
      learning rate tag: see the comments in
      //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto.
      Multiple tables can share the same dynamic learning rate tag as
      specified in the configuration. If the learning rates for all tables
      are constant, this list should be empty.
    name: A name for the operation (optional).

  Returns:
    A SendTPUEmbeddingGradients operation.
  """
  if learning_rates is None:
    learning_rates = []
  return gen_tpu_ops.send_tpu_embedding_gradients(
      inputs=inputs, learning_rates=learning_rates, config=config, name=name)


send_tpu_embedding_gradients.__doc__ = (
    gen_tpu_ops.send_tpu_embedding_gradients.__doc__)

# pylint: disable=protected-access
def enqueue_tpu_embedding_integer_batch(batch,
                                        device_ordinal,
                                        mode_override=None,
                                        name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    batch: A list of 1D tensors, one for each embedding table, containing the
      indices into the tables.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used; otherwise,
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingIntegerBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(
      batch=batch,
      device_ordinal=device_ordinal,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_integer_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)
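
# Illustrative usage sketch (not part of the original module): enqueueing one
# dense id tensor per embedding table for a single core. The two-table setup
# and the mode are assumptions for the example.
def _example_enqueue_integer_batch(ids_table_0, ids_table_1):
  """Hypothetical helper: enqueues dense ids for two tables on core 0."""
  return enqueue_tpu_embedding_integer_batch(
      batch=[ids_table_0, ids_table_1],
      device_ordinal=0,
      mode_override="train")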

# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_batch(sample_indices,
                                       embedding_indices,
                                       aggregation_weights,
                                       device_ordinal,
                                       combiners=None,
                                       mode_override=None,
                                       name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 1 Tensors specifying the training example
      and feature to which the corresponding embedding_indices and
      aggregation_weights values belong. sample_indices[i] must equal
      b * nf + f, where nf is the number of features from the corresponding
      table, f is in [0, nf), and b is in [0, batch size). Both int32 and
      int64 are allowed, and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per sample --
      i.e., per (training example, feature) -- aggregation weights. Both
      float32 and float64 are allowed and will be converted to float32
      internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used; otherwise,
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      device_ordinal=device_ordinal,
      combiners=combiners,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_sparse_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)
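
# Illustrative usage sketch (not part of the original module): enqueueing a
# single sparse feature for one table on core 0. The input tensors and the
# 'sum' combiner are assumptions for the example.
def _example_enqueue_sparse_batch(sample_indices, embedding_indices,
                                  aggregation_weights):
  """Hypothetical helper: enqueues one sparse feature on core 0."""
  return enqueue_tpu_embedding_sparse_batch(
      sample_indices=[sample_indices],
      embedding_indices=[embedding_indices],
      aggregation_weights=[aggregation_weights],
      device_ordinal=0,
      combiners=["sum"])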

# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 2 Tensors specifying the training example
      to which the corresponding embedding_indices and aggregation_weights
      values belong. It corresponds to sp_ids.indices in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume each embedding_indices belongs to a different sample. Both int32
      and int64 are allowed and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
      Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to sp_weights.values in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume all weights are 1. Both float32 and float64 are allowed and will
      be converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      lookup the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_indices, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      sample_indices. If equal to 0, the corresponding feature is considered
      to be a non-sequence feature. If greater than 0, the corresponding
      feature is a sequence feature with the given maximal length. If None,
      then we assume a list of all zeroes.
    num_features: A list of integers, the size of which is equal to
      sample_indices. If non-empty, entries in this list must be at least 1.
      For each batch element, we will take num_features rows of the input
      tensor for embedding lookup. E.g., when sample_indices is empty, the
      embedding indices must be of shape (batch_size*num_features).
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used; otherwise,
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      combiners=combiners,
      mode_override=mode_override,
      num_features=num_features,
      name=name)


enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)

# pylint: disable=protected-access
def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_splits: A list of rank 1 Tensors specifying the break points for
      splitting embedding_indices and aggregation_weights into rows. It
      corresponds to ids.row_splits in embedding_lookup(), when ids is a
      RaggedTensor. Both int32 and int64 are allowed and will be converted to
      int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to ids.values in embedding_lookup(), when ids is
      a RaggedTensor. Both int32 and int64 are allowed and will be converted
      to int32 internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to the values field of a
      RaggedTensor with the same row_splits as ids in embedding_lookup(),
      when ids is a RaggedTensor. Both float32 and float64 are allowed and
      will be converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      lookup the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_splits, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      sample_splits. If equal to 0, the corresponding feature is considered
      to be a non-sequence feature. If greater than 0, the corresponding
      feature is a sequence feature with the given maximal length. If None,
      then we assume a list of all zeroes.
    num_features: A list of integers, the size of which must be equal to
      sample_splits. If non-empty, entries in this list must be at least 1.
      For each batch element, we will take num_features rows of the input
      tensor for embedding lookup. E.g., when sample_splits is empty, the
      embedding indices must be of shape (batch_size*num_features).
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'training', 'backward_pass_only'}. When set to
      'unspecified', the mode set in TPUEmbeddingConfiguration is used;
      otherwise, mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingRaggedTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
      sample_splits=sample_splits,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      combiners=combiners,
      mode_override=mode_override,
      num_features=num_features,
      name=name)


enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__)

def enqueue_tpu_embedding_arbitrary_tensor_batch(sample_indices_or_row_splits,
                                                 embedding_indices,
                                                 aggregation_weights,
                                                 device_ordinal,
                                                 combiners=None,
                                                 mode_override=None,
                                                 name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices_or_row_splits: A list of rank 1 or 2 Tensors. When rank 2,
      the tensors specify the training example to which the corresponding
      embedding_indices and aggregation_weights values belong. If the size of
      its first dimension is 0, we assume each embedding_indices belongs to a
      different sample. Both int32 and int64 are allowed and will be
      converted to int32 internally. When rank 1, the tensors specify the row
      splits for splitting embedding_indices and aggregation_weights into
      rows. They correspond to ids.row_splits in embedding_lookup(), when ids
      is a RaggedTensor. When enqueueing an N-D ragged tensor, only the last
      dimension is allowed to be ragged; the row splits are a 1-D dense
      tensor. When empty, we assume a dense tensor is passed to the op. Both
      int32 and int64 are allowed and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. Both float32 and float64 are allowed and
      will be converted to float32 internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'training', 'backward_pass_only'}. When set to
      'unspecified', the mode set in TPUEmbeddingConfiguration is used;
      otherwise, mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingArbitraryTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch(
      sample_indices_or_row_splits=sample_indices_or_row_splits,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      device_ordinal=device_ordinal,
      combiners=combiners,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_arbitrary_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch.__doc__)
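
# Illustrative usage sketch (not part of the original module): the
# arbitrary-tensor enqueue accepts dense, sparse, or ragged layouts per
# feature; per the docstring above, an empty indices/row-splits tensor marks
# the feature as dense. The single-feature setup below is an assumption for
# the example.
def _example_enqueue_dense_feature(dense_ids, weights):
  """Hypothetical helper: enqueues one dense feature on core 0."""
  empty_splits = array_ops.zeros([0], dtype=dtypes.int32)
  return enqueue_tpu_embedding_arbitrary_tensor_batch(
      sample_indices_or_row_splits=[empty_splits],
      embedding_indices=[dense_ids],
      aggregation_weights=[weights],
      device_ordinal=0)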