Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_embedding_v2.py: 17% of 488 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mid level API for TPU Embeddings."""

import functools
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union

from absl import logging

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework.tensor_shape import TensorShape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.saved_model import registration
from tensorflow.python.saved_model import save_context
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import base
from tensorflow.python.types import internal as internal_types
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


_HOOK_KEY = "TPUEmbedding_saveable"
_NAME_KEY = "_tpu_embedding_layer"


class TPUEmbeddingVariable(sharded_variable.ShardedVariableMixin):
  """A ShardedVariable class for TPU."""

  @property
  def _in_graph_mode(self):
    return self.variables[0]._in_graph_mode  # pylint: disable=protected-access


def _add_key_attr(op, name):
  op._set_attr(_NAME_KEY, attr_value_pb2.AttrValue(s=compat.as_bytes(name)))  # pylint: disable=protected-access


@tf_export("tpu.experimental.embedding.TPUEmbedding")

class TPUEmbedding(autotrackable.AutoTrackable):
  """The TPUEmbedding mid level API.

  NOTE: When instantiated under a TPUStrategy, this class can only be created
  once per call to `tf.tpu.experimental.initialize_tpu_system`. If you wish to
  re-initialize the embedding engine you must re-initialize the tpu as well.
  Doing this will clear any variables from TPU, so ensure you have
  checkpointed before you do this. If further instances of the class are
  needed, set the `initialize_tpu_embedding` argument to `False`.

  This class can be used to support training large embeddings on TPU. When
  creating an instance of this class, you must specify the complete set of
  tables and features you expect to lookup in those tables. See the
  documentation of `tf.tpu.experimental.embedding.TableConfig` and
  `tf.tpu.experimental.embedding.FeatureConfig` for more details on the
  complete set of options. We will cover the basic usage here.

  NOTE: multiple `FeatureConfig` objects can use the same `TableConfig` object,
  allowing different features to share the same table:

  ```python
  table_config_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  table_config_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  feature_config = {
      'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_two)}
  ```

  There are two modes under which the `TPUEmbedding` class can be used. Which
  one applies depends on whether the class was created under a `TPUStrategy`
  scope or not.

  Under `TPUStrategy`, we allow access to the methods `enqueue`, `dequeue` and
  `apply_gradients`. We will show examples below of how to use these to train
  and evaluate your model. Under CPU, we only have access to the
  `embedding_tables` property, which allows access to the embedding tables so
  that you can use them to run model evaluation/prediction on CPU.

  First let's look at the `TPUStrategy` mode. Initial setup looks like:

  ```python
  strategy = tf.distribute.TPUStrategy(...)
  with strategy.scope():
    embedding = tf.tpu.experimental.embedding.TPUEmbedding(
        feature_config=feature_config,
        optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  ```

  When creating a distributed dataset that is to be passed to the enqueue
  operation a special input option must be specified:

  ```python
  distributed_dataset = (
      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_fetch_to_device=False))
  dataset_iterator = iter(distributed_dataset)
  ```

  Different feature inputs can have different shapes. For dense and sparse
  tensors, rank 2 and above is supported. For ragged tensors, although only
  rank 2 is supported, you can specify the output shape to be rank 2 and
  above. The output shape specified in the FeatureConfig has the first
  priority, the input shape passed to the build method has the second
  priority, and the input shapes auto detected from the input features have
  the lowest priority. The latter two will be converted to output shapes by
  omitting the last dimension. If a lower priority source yields output
  shapes which don't match the higher priority one, a ValueError will be
  raised. Only when the higher priority one has undefined output shapes can
  the lower priority one override it.
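
  For example, here is a hedged sketch of how the first two priority levels
  interact (the `output_shape` argument is taken from the `FeatureConfig`
  documentation; the shapes are illustrative only):

  ```python
  feature_config = {
      # First priority: the output shape is pinned to [16] in the config.
      'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one, output_shape=[16]),
      # No output shape set here; it may come from a build() call or be
      # auto detected from the input features.
      'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one)}
  ```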

  NOTE: All batches passed to the layer can have different input shapes. But
  these input shapes need to match the output shapes set by either the
  `FeatureConfig` or the build method, except for ragged tensors. Only a 2D
  ragged tensor with its output shape set to higher dimensions is allowed, as
  long as the total number of elements matches. All subsequent calls must
  have the same input shapes. In the event that the input shapes cannot be
  automatically determined by the enqueue method, you must call
  the build method with the input shapes or provide output shapes in the
  `FeatureConfig` to initialize the layer.

  To use this API on TPU you should use a custom training loop. Below is an
  example of a training and evaluation step:

  ```python
  @tf.function
  def training_step(dataset_iterator, num_steps):
    def tpu_step(tpu_features):
      with tf.GradientTape() as tape:
        activations = embedding.dequeue()
        tape.watch(activations)
        model_output = model(activations)
        loss = ...  # some function of labels and model_output

      embedding_gradients = tape.gradient(loss, activations)
      embedding.apply_gradients(embedding_gradients)
      # Insert your model gradient and optimizer application here

    for _ in tf.range(num_steps):
      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features, ))

  @tf.function
  def evaluation_step(dataset_iterator, num_steps):
    def tpu_step(tpu_features):
      activations = embedding.dequeue()
      model_output = model(activations)
      # Insert your evaluation code here.

    for _ in tf.range(num_steps):
      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=False)
      strategy.run(tpu_step, args=(tpu_features, ))
  ```

  NOTE: The calls to `enqueue` have `training` set to `True` when
  `embedding.apply_gradients` is used and set to `False` when
  `embedding.apply_gradients` is not present in the function. If you don't
  follow this pattern you may cause an error to be raised or the tpu may
  deadlock.

  In the above examples, we assume that the user has a dataset which returns
  a tuple where the first element of the tuple matches the structure of what
  was passed as the `feature_config` argument to the object initializer. Also
  we utilize `tf.range` to get a `tf.while_loop` in order to increase
  performance.

  When checkpointing your model, you should include your
  `tf.tpu.experimental.embedding.TPUEmbedding` object in the checkpoint. It
  is a trackable object and saving it will save the embedding tables and
  their optimizer slot variables:

  ```python
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.save(...)
  ```

  On CPU, only the `embedding_tables` property is usable. This will allow you
  to restore a checkpoint to the object and have access to the table
  variables:

  ```python
  model = model_fn(...)
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.restore(...)

  tables = embedding.embedding_tables
  ```

  You can now use these tables in functions like `tf.nn.embedding_lookup` to
  perform your embedding lookup and pass the result to your model.
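
  A minimal sketch of such a CPU lookup (`feature_ids` here is an assumed
  `tf.Tensor` of ids for `table_config_one`):

  ```python
  embeddings = tf.nn.embedding_lookup(
      params=tables[table_config_one], ids=feature_ids)
  ```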

  """

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer],  # pylint:disable=protected-access
      pipeline_execution_with_tensor_core: bool = False):
    """Creates the TPUEmbedding mid level API object.

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(
          feature_config=tf.tpu.experimental.embedding.FeatureConfig(
              table=tf.tpu.experimental.embedding.TableConfig(
                  dim=...,
                  vocabulary_size=...)))
    ```

    Args:
      feature_config: A nested structure of
        `tf.tpu.experimental.embedding.FeatureConfig` configs.
      optimizer: An instance of one of `tf.tpu.experimental.embedding.SGD`,
        `tf.tpu.experimental.embedding.Adagrad` or
        `tf.tpu.experimental.embedding.Adam`. When not created under a
        TPUStrategy, this may be set to None to avoid the creation of the
        optimizer slot variables; this is useful for reducing memory
        consumption when exporting the model for serving, where slot
        variables aren't needed.
      pipeline_execution_with_tensor_core: If True, the TPU embedding
        computations will overlap with the TensorCore computations (and hence
        will be one step old). Set to True for improved performance.

    Raises:
      ValueError: If optimizer is not one of
        tf.tpu.experimental.embedding.(SGD, Adam or Adagrad) or None when
        created under a TPUStrategy.
    """
    self._strategy = distribute_lib.get_strategy()
    self._using_tpu = isinstance(self._strategy, (tpu_strategy.TPUStrategy,
                                                  tpu_strategy.TPUStrategyV2))
    self._pipeline_execution_with_tensor_core = (
        pipeline_execution_with_tensor_core)

    self._feature_config = feature_config
    self._output_shapes = []
    for feature in nest.flatten(feature_config):
      self._output_shapes.append(feature.output_shape)

    # The TPU embedding ops are slightly inconsistent with how they refer to
    # tables:
    # * The enqueue op takes a parallel list of tensors for input; one of
    #   those is the table id for the feature, which matches the integer
    #   index of the table in the proto created by _create_config_proto().
    # * The recv_tpu_embedding_activations op emits lookups per table in the
    #   order from the config proto.
    # * The send_tpu_embedding_gradients op expects input tensors to be per
    #   table in the same order as the config proto.
    # * Per optimizer load and retrieve ops are specified per table and take
    #   the table name rather than the table id.
    # Thus we must fix a common order for tables and ensure they have unique
    # names.

    # Set the table order here to the order of the first occurrence of the
    # table in a feature provided by the user. The order of this struct must
    # be fixed to provide the user with deterministic behavior over multiple
    # instantiations.
    self._table_config = []
    for feature in nest.flatten(feature_config):
      if feature.table not in self._table_config:
        self._table_config.append(feature.table)

    # Ensure tables have unique names. Also error check the optimizer, as we
    # specifically don't do that in the TableConfig class to allow high level
    # APIs that are built on this to use strings/other classes to represent
    # optimizers (before they are passed to this class).
    table_names = []
    for i, table in enumerate(self._table_config):
      if table.optimizer is None:
        # TODO(bfontain) Should we allow some sort of optimizer merging here?
        table.optimizer = optimizer
      if ((table.optimizer is not None or self._using_tpu) and
          not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)):  # pylint: disable=protected-access
        raise ValueError("{} is an unsupported optimizer class. Please pass an "
                         "instance of one of the optimizer classes under "
                         "tf.tpu.experimental.embedding.".format(
                             type(table.optimizer)))
      if table.name is None:
        table.name = "table_{}".format(i)
      if table.name in table_names:
        raise ValueError("Tables must have a unique name. "
                         f"Multiple tables with name {table.name} found.")
      table_names.append(table.name)

    if self._using_tpu:
      # Extract a list of callable learning rates, also in fixed order. Each
      # table in the config proto will get an index into this list, and we
      # will pass this list in the same order after evaluation to the
      # send_tpu_embedding_gradients op.
      self._dynamic_learning_rates = []
      for table in self._table_config:
        if (callable(table.optimizer.learning_rate) and
            table.optimizer.learning_rate not in self._dynamic_learning_rates):
          self._dynamic_learning_rates.append(table.optimizer.learning_rate)

      # We need the list of host devices for the load/retrieve operations.
      self._hosts = tpu_embedding_v2_utils.get_list_of_hosts(self._strategy)

    self._built = False
    self._verify_output_shapes_on_enqueue = True

  def build(self, per_replica_input_shapes=None, per_replica_batch_size=None):  # pylint:disable=g-bare-generic
    """Create the underlying variables and initializes the TPU for embeddings.

    This method creates the underlying variables (including slot variables).
    If created under a TPUStrategy, this will also initialize the TPU for
    embeddings.

    This function will automatically get called by enqueue, which will try to
    determine your output shapes. If this fails, you must manually
    call this method before you call enqueue.
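
    As a hedged sketch, a manual build call using the dict-style
    `feature_config` from the class documentation (`global_batch_size` is an
    assumed value from your input pipeline):

    ```python
    per_replica_batch_size = (
        global_batch_size // strategy.num_replicas_in_sync)
    embedding.build(per_replica_input_shapes={
        'feature_one': tf.TensorShape([per_replica_batch_size, 1]),
        'feature_two': tf.TensorShape([per_replica_batch_size, 1]),
        'feature_three': tf.TensorShape([per_replica_batch_size, 1])})
    ```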

    Args:
      per_replica_input_shapes: A nested structure of the per replica input
        shapes that matches the structure of the feature config. The input
        shapes should be the same as the input shape of the feature (except
        for ragged tensors). Note that it is fixed and the same per replica
        input shapes must be used for both training and evaluation. If you
        want to calculate this from the global input shapes, you can use the
        `num_replicas_in_sync` property of your strategy object. May be set
        to None if not created under a TPUStrategy.
      per_replica_batch_size: (Deprecated) The per replica batch size that
        you intend to use. Note that it is fixed and the same batch size must
        be used for both training and evaluation. If you want to calculate
        this from the global batch size, you can use the
        `num_replicas_in_sync` property of your strategy object. May be set
        to None if not created under a TPUStrategy.

    Raises:
      ValueError: If per_replica_input_shapes is inconsistent with the output
        shapes stored in the feature config or the output shapes derived from
        the input shapes are not fully defined.
      RuntimeError: If tpu embedding is already initialized on TPU.
    """
    if self._built:
      return

    if self._using_tpu:
      # If the tpu embedding is already initialized on TPU, raise a runtime
      # error. The logic below is not added in
      # `initialize_system_for_tpu_embedding` because doing exception control
      # flow in graph mode is difficult.
      if tpu_ops.is_tpu_embedding_initialized():
        raise RuntimeError(
            "TPU is already initialized for embeddings. This may be caused by "
            "using multiple TPUEmbedding instances in a TPU scope which is "
            "unsupported")
      self._get_and_update_output_shapes_from_input(per_replica_input_shapes,
                                                    per_replica_batch_size)

      self._config_proto = self._create_config_proto()

      logging.info("Initializing TPU Embedding engine.")
      tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto)

      @def_function.function
      def load_config():
        tpu.initialize_system_for_tpu_embedding(self._config_proto)

      load_config()
      logging.info("Done initializing TPU Embedding engine.")

    # Create and load variables and slot variables into the TPU.
    # Note that this is a dict of dicts. Keys to the first dict are table
    # names. We would prefer to use TableConfigs, but then these variables
    # won't be properly tracked by the tracking API.
    self._variables = self._create_variables_and_slots()

    self._built = True

    # This is internally conditioned on self._built and self._using_tpu.
    self._load_variables()

  def _maybe_build(self,
                   output_shapes: Optional[Union[List[int], Iterable]] = None):  # pylint:disable=g-bare-generic
    if not self._built:
      # This can be called while tracing a function, so we wrap the
      # initialization code with init_scope so it runs eagerly. This means
      # that it will not be included in the function graph generated by
      # tracing, so we can be sure that we only initialize the TPU for
      # embeddings exactly once.
      with ops.init_scope():
        self.build(output_shapes)

  def _get_and_update_output_shapes_from_input(
      self,
      per_replica_input_shapes: Optional[List[TensorShape]] = None,
      per_replica_batch_size: Optional[int] = None):
    """Get and update the per replica output shapes from the input."""
    per_replica_output_shapes = None
    if per_replica_batch_size and per_replica_input_shapes is None:
      logging.warning(
          "per_replica_batch_size argument will be deprecated, please specify "
          "all the input shapes using per_replica_input_shapes argument.")
      per_replica_output_shapes = self._get_output_shapes_from_batch_size(
          per_replica_batch_size)

    # Update the input shapes if provided.
    if per_replica_input_shapes is not None:
      if isinstance(per_replica_input_shapes, int):
        logging.warning(
            "Passing batch size to per_replica_input_shapes argument will be"
            " deprecated, please specify all the input shapes using"
            " per_replica_input_shapes argument.")
        per_replica_output_shapes = self._get_output_shapes_from_batch_size(
            per_replica_input_shapes)
      else:
        nest.assert_same_structure(
            nest.flatten(per_replica_input_shapes),
            nest.flatten(self._feature_config))

        # Convert the nested structure to a list.
        per_replica_input_shapes = nest.flatten(per_replica_input_shapes)

        per_replica_output_shapes = self._get_output_shapes_from_input_shapes(
            per_replica_input_shapes)

    if per_replica_output_shapes is not None:

      # Check the output shapes against the existing output shapes setting.
      self._check_output_shapes(per_replica_output_shapes)

      # Update the output shapes with the existing output shapes setting.
      # This is necessary because the output shapes might be missing from
      # the feature config, in which case the user can set them by:
      # 1. calling the build method, or
      # 2. having them auto detected when calling the dequeue method for
      #    the first time, in which case the dequeue method will call the
      #    build method with the output shapes.
      # Either of these two situations will lead to an update of the existing
      # output shapes.
      self._update_output_shapes(per_replica_output_shapes)

    # Check if the output shapes are fully defined. This is required in order
    # to set them in the feature descriptor field of the tpu embedding config
    # proto.
    self._check_output_shapes_fully_defined()

  def _get_output_shapes_from_input_shapes(
      self, input_shapes: List[TensorShape]) -> List[TensorShape]:
    """Get output shapes from the flattened input shapes list."""
    output_shapes = []
    for input_shape, feature in zip(input_shapes,
                                    nest.flatten(self._feature_config)):
      if input_shape.rank is None or input_shape.rank < 1:
        raise ValueError(
            "Received input tensor of shape {}. Rank must be 1 and above"
            .format(input_shape))
      # Update the input shape with the max sequence length. Only update when
      # 1. The input feature is a 2D ragged or sparse tensor.
      # 2. The output shape is not set in the feature config and the max
      #    sequence length is set.
      if (len(input_shape) == 2 and input_shape[-1] != 1 and
          not feature.output_shape and feature.max_sequence_length > 0):
        input_shape_list = input_shape.as_list()
        input_shape_list.insert(
            len(input_shape_list) - 1, feature.max_sequence_length)
        input_shape = TensorShape(input_shape_list)
      if input_shape.rank == 1:
        output_shapes.append(input_shape)
      else:
        output_shapes.append(input_shape[:-1])
    return output_shapes

  @property
  def embedding_tables(
      self
  ) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]:
    """Returns a dict of embedding tables, keyed by `TableConfig`.

    This property only works when the `TPUEmbedding` object is created under
    a non-TPU strategy. This is intended to be used for CPU based lookup when
    creating a serving checkpoint.

    Returns:
      A dict of embedding tables, keyed by `TableConfig`.

    Raises:
      RuntimeError: If the object was created under a `TPUStrategy`.
    """
    # We don't support returning tables on TPU due to their sharded nature
    # and the fact that when using a TPUStrategy:
    # 1. Variables are stale and are only updated when a checkpoint is made.
    # 2. Updating the variables won't affect the actual tables on the TPU.
    if self._using_tpu:
      if save_context.in_save_context():
        return {table: self._variables[table.name]["parameters"].variables[0]
                for table in self._table_config}
      raise RuntimeError("Unable to retrieve embedding tables when using a TPU "
                         "strategy. If you need access, save your model, "
                         "create this object under a CPU strategy and restore.")

    self._maybe_build(None)

    # Only return the tables and not the slot variables. On CPU these are
    # honest tf.Variables.
    return {table: self._variables[table.name]["parameters"]
            for table in self._table_config}

  def _create_config_proto(
      self
  ) -> tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration:
    """Creates the TPUEmbeddingConfiguration proto.

    This proto is used to initialize the TPU embedding engine.

    Returns:
      A TPUEmbeddingConfiguration proto.
    """

    config_proto = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()

    # Map each callable dynamic learning rate to its index in the list.
    # The learning rate index is the index of the dynamic learning rate for
    # this table (if it exists) in the list we created at initialization. We
    # don't simply create one learning rate index per table as this has
    # extremely bad performance characteristics. The more separate
    # optimization configurations we have, the worse the performance will be.
    learning_rate_index = {r: i for i, r in enumerate(
        self._dynamic_learning_rates)}
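    # For example (an illustrative note, not from the original source): if
    # two tables share the same callable learning rate fn, __init__ stored fn
    # only once in self._dynamic_learning_rates (via its `not in` check), so
    # both tables map to the same learning rate index here and fn is
    # evaluated just once per step.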

    for table in self._table_config:
      table._set_table_descriptor(  # pylint: disable=protected-access
          config_proto.table_descriptor.add(),
          self._strategy.extended.num_hosts,
          learning_rate_index)

    table_to_id = {table: i for i, table in enumerate(self._table_config)}

    # Set the feature descriptor field in the config proto.
    for feature, output_shape in zip(
        nest.flatten(self._feature_config), self._output_shapes):
      feature_descriptor = config_proto.feature_descriptor.add()

      if feature.name:
        feature_descriptor.name = feature.name

      feature_descriptor.table_id = table_to_id[feature.table]
      # The input shape of the feature is the actual shape of the input
      # tensor except the last dimension, because the last dimension will
      # always be reduced.
      feature_descriptor.input_shape.extend(output_shape.as_list())

    # Always set mode to training; we override the mode during enqueue.
    config_proto.mode = (
        tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.TRAINING)

    config_proto.num_hosts = self._strategy.extended.num_hosts
    config_proto.num_tensor_cores = self._strategy.num_replicas_in_sync

    # TODO(bfontain): Allow users to pick MOD for the host sharding.
    config_proto.sharding_strategy = (
        tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.DIV_DEFAULT)
    config_proto.pipeline_execution_with_tensor_core = (
        self._pipeline_execution_with_tensor_core)

    return config_proto

  def apply_gradients(self, gradients, name: Optional[Text] = None):
    """Applies the gradient update to the embedding tables.

    If a gradient of `None` is passed in any position of the nested
    structure, then a gradient update with a zero gradient is applied for
    that feature. For optimizers like SGD or Adagrad, this is the same as
    applying no update at all. For lazy Adam and other sparsely applied
    optimizers with decay, ensure you understand the effect of applying a
    zero gradient.
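
    A hedged sketch of skipping the update for one feature by passing `None`
    (the gradient names are illustrative, using the dict-style
    `feature_config` from the class documentation):

    ```python
    embedding.apply_gradients({
        'feature_one': grad_one,    # tf.Tensor of shape (*output_shape, dim)
        'feature_two': None,        # replaced internally by a zero gradient
        'feature_three': grad_three})
    ```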

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

    distributed_dataset = (
        strategy.distribute_datasets_from_function(
            dataset_fn=...,
            options=tf.distribute.InputOptions(
                experimental_fetch_to_device=False))
    dataset_iterator = iter(distributed_dataset)

    @tf.function
    def training_step():
      def tpu_step(tpu_features):
        with tf.GradientTape() as tape:
          activations = embedding.dequeue()
          tape.watch(activations)

          loss = ...  # some computation involving activations

        embedding_gradients = tape.gradient(loss, activations)
        embedding.apply_gradients(embedding_gradients)

      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features, ))

    training_step()
    ```

    Args:
      gradients: A nested structure of gradients, with structure matching the
        `feature_config` passed to this object.
      name: A name for the underlying op.

    Raises:
      RuntimeError: If called when the object wasn't created under a
        `TPUStrategy`, or if not built (either by manually calling build or
        calling enqueue).
      ValueError: If a non-`tf.Tensor` non-`None` gradient is passed in, or a
        `tf.Tensor` of the incorrect shape is passed in. Also if
        the size of any sequence in `gradients` does not match the
        corresponding sequence in `feature_config`.
      TypeError: If the type of any sequence in `gradients` does not match
        the corresponding sequence in `feature_config`.
    """
    if not self._using_tpu:
      raise RuntimeError("apply_gradients is not valid when TPUEmbedding "
                         "object is not created under a TPUStrategy.")

    if not self._built:
      raise RuntimeError("apply_gradients called on unbuilt TPUEmbedding "
                         "object. Please either call enqueue first or manually "
                         "call the build method.")

    nest.assert_same_structure(self._feature_config, gradients)
    updated_gradients = []
    for (path, gradient), feature, output_shape in zip(
        nest.flatten_with_joined_string_paths(gradients),
        nest.flatten(self._feature_config), self._output_shapes):
      full_output_shape = list(output_shape) + [feature.table.dim]
      if gradient is not None and not isinstance(gradient, ops.Tensor):
        raise ValueError(
            f"found non-tensor type: {type(gradient)} at path {path}.")
      if gradient is not None:
        if gradient.shape != full_output_shape:
          raise ValueError("Found gradient of shape {} at path {}. Expected "
                           "shape {}.".format(gradient.shape, path,
                                              full_output_shape))
      else:
        # No gradient for this feature. Since we must give a gradient for all
        # features, pass in a zero tensor here. Note that this is not correct
        # for all optimizers.
        logging.warning(
            "No gradient passed for feature %s, sending zero "
            "gradient. This may not be correct behavior for certain "
            "optimizers like Adam.", path)
        gradient = array_ops.zeros(full_output_shape, dtype=dtypes.float32)
      # Some gradients can be passed from an op whose shape is not correctly
      # set. This ensures that the shape of the gradient is correctly set.
      updated_gradients.append(
          array_ops.reshape(gradient, shape=gradient.shape))
    op = tpu_ops.send_tpu_embedding_gradients(
        inputs=updated_gradients,
        learning_rates=[
            math_ops.cast(fn(), dtype=dtypes.float32)
            for fn in self._dynamic_learning_rates
        ],
        config=self._config_proto.SerializeToString())

    # Apply the name tag to the op.
    if name is not None:
      _add_key_attr(op, name)

  def dequeue(self, name: Optional[Text] = None):
    """Get the embedding results.

    Returns a nested structure of `tf.Tensor` objects, matching the structure
    of the `feature_config` argument to the `TPUEmbedding` class. The output
    shape of the tensors is `(*output_shape, dim)`, where `dim` is the
    dimension of the corresponding `TableConfig`. The output_shape can be set
    in three places:
    1. The FeatureConfig provided in the __init__ function.
    2. The per_replica_output_shapes passed by directly calling the build
       method after initializing the tpu embedding class.
    3. Auto detected from the shapes of the input features.
    The priority of these places is exactly that order.

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

    distributed_dataset = (
        strategy.distribute_datasets_from_function(
            dataset_fn=...,
            options=tf.distribute.InputOptions(
                experimental_fetch_to_device=False))
    dataset_iterator = iter(distributed_dataset)

    @tf.function
    def training_step():
      def tpu_step(tpu_features):
        with tf.GradientTape() as tape:
          activations = embedding.dequeue()
          tape.watch(activations)

          loss = ...  # some computation involving activations

        embedding_gradients = tape.gradient(loss, activations)
        embedding.apply_gradients(embedding_gradients)

      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features, ))

    training_step()
    ```

    Args:
      name: A name for the underlying op.

    Returns:
      A nested structure of tensors, with the same structure as
      `feature_config` passed to this instance of the `TPUEmbedding` object.

    Raises:
      RuntimeError: If called when the object wasn't created under a
        `TPUStrategy`, or if not built (either by manually calling build or
        calling enqueue).
    """
    if not self._using_tpu:
      raise RuntimeError("dequeue is not valid when TPUEmbedding object is "
                         "not created under a TPUStrategy.")

    if not self._built:
      raise RuntimeError("dequeue called on unbuilt TPUEmbedding object. "
                         "Please either call enqueue first or manually call "
                         "the build method.")

    # The activations returned by this op are per feature.
    activations = tpu_ops.recv_tpu_embedding_activations(
        num_outputs=len(self._config_proto.feature_descriptor),
        config=self._config_proto.SerializeToString())

    # Apply the name tag to the op.
    if name is not None:
      _add_key_attr(activations[0].op, name)

    # Pack the list back into the same nested structure as the features.
    return nest.pack_sequence_as(self._feature_config, activations)

  def _create_variables_and_slots(
      self
  ) -> Dict[Text, Dict[Text, tf_variables.Variable]]:
    """Create variables for TPU embeddings.

    Note that under TPUStrategy this will ensure that all creations happen
    within a variable creation scope of the sharded variable creator.

    Returns:
      A dict of dicts. The outer dict is keyed by the table names and the
      inner dicts are keyed by 'parameters' and the slot variable names.
    """

    def create_variables(table):
      """Create all variables."""
      variable_shape = (table.vocabulary_size, table.dim)

      def getter(name, shape, dtype, initializer, trainable):
        del shape
        # _add_variable_with_custom_getter clears the shape sometimes, so we
        # take the global shape from outside the getter.
        initial_value = functools.partial(initializer, variable_shape,
                                          dtype=dtype)
        return tf_variables.Variable(
            name=name,
            initial_value=initial_value,
            shape=variable_shape,
            dtype=dtype,
            trainable=trainable)

      def variable_creator(name, initializer, trainable=True):
        # Use add_variable_with_custom_getter here so that we take advantage
        # of the checkpoint loading to allow restore before the variables get
        # created, which avoids double initialization.
        return self._add_variable_with_custom_getter(
            name=name,
            initializer=initializer,
            shape=variable_shape,
            dtype=dtypes.float32,
            getter=getter,
            trainable=trainable)

      parameters = variable_creator(table.name, table.initializer,
                                    trainable=not self._using_tpu)

      def slot_creator(name, initializer):
        return variable_creator(table.name + "/" + name,
                                initializer,
                                False)

      if table.optimizer is not None:
        slot_vars = table.optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
      else:
        slot_vars = {}
      slot_vars["parameters"] = parameters
      return slot_vars

    # Store tables based on name rather than TableConfig as we can't track
    # through dicts with non-string keys, i.e. we won't be able to save.
    variables = {}
    for table in self._table_config:
      if not self._using_tpu:
        variables[table.name] = create_variables(table)
      else:
        with variable_scope.variable_creator_scope(
            make_sharded_variable_creator(self._hosts)):
          variables[table.name] = create_variables(table)

    return variables

  def _load_variables(self):
    # Only load the variables if we are:
    # 1) Using TPU
    # 2) Variables are created
    # 3) Not in save context (except if running eagerly)
    if self._using_tpu and self._built and not (
        not context.executing_eagerly() and save_context.in_save_context()):
      _load_variables_impl(self._config_proto.SerializeToString(),
                           self._hosts,
                           self._variables,
                           self._table_config)

  def _retrieve_variables(self):
    # Only retrieve the variables if we are:
    # 1) Using TPU
    # 2) Variables are created
    # 3) Not in save context (except if running eagerly)
    if self._using_tpu and self._built and not (
        not context.executing_eagerly() and save_context.in_save_context()):
      _retrieve_variables_impl(self._config_proto.SerializeToString(),
                               self._hosts,
                               self._variables,
                               self._table_config)

  # Some helper functions for the below enqueue function.
  def _add_data_for_tensor(self, tensor, weight, indices, values, weights,
                           int_zeros, float_zeros, path):
    if weight is not None:
      raise ValueError(
          "Weight specified for dense input {}, which is not allowed. "
          "Weight will always be 1 in this case.".format(path))
    # For tensors, there are no indices and no weights.
    indices.append(int_zeros)
    values.append(math_ops.cast(array_ops.reshape(tensor, [-1]), dtypes.int64))
    weights.append(float_zeros)

  def _add_data_for_sparse_tensor(self, tensor, weight, indices, values,
                                  weights, int_zeros, float_zeros, path,
                                  feature):
    sample_indices = math_ops.cast(tensor.indices, dtypes.int32)
    if tensor.shape.rank == 2:
      if not feature.output_shape and feature.max_sequence_length > 0:
        # Add one dimension to the last axis.
        sample_indices = array_ops.pad(
            sample_indices, paddings=[[0, 0], [0, 1]])
    else:
      if feature.max_sequence_length > 0:
        logging.warning(
            (
                "Input tensor is rank %d which is above 2, the"
                " max_sequence_length setting will be ignored."
            ),
            tensor.shape.rank,
        )
    indices.append(sample_indices)
    values.append(math_ops.cast(tensor.values, dtypes.int64))
    # If we have weights they must be a SparseTensor.
    if weight is not None:
      if not isinstance(weight, sparse_tensor.SparseTensor):
        raise ValueError("Weight for {} is of type {}, which does not match "
                         "the input type, which is SparseTensor.".format(
                             path, type(weight)))
      weights.append(math_ops.cast(weight.values, dtypes.float32))
    else:
      weights.append(float_zeros)

  def _add_data_for_ragged_tensor(self, tensor, weight, row_splits, values,
                                  weights, int_zeros, float_zeros, path,
                                  feature):
    row_splits.append(math_ops.cast(tensor.row_splits, dtypes.int32))
    values.append(math_ops.cast(tensor.values, dtypes.int64))
    # If we have weights they must be a RaggedTensor.
    if weight is not None:
      if not isinstance(weight, ragged_tensor.RaggedTensor):
        raise ValueError("Weight for {} is of type {}, which does not match "
                         "the input type, which is RaggedTensor.".format(
                             path, type(weight)))
      weights.append(math_ops.cast(weight.values, dtypes.float32))
    else:
      weights.append(float_zeros)

  def _generate_enqueue_op(
      self,
      flat_inputs: List[internal_types.NativeObject],
      flat_weights: List[Optional[internal_types.NativeObject]],
      flat_features: List[tpu_embedding_v2_utils.FeatureConfig],
      device_ordinal: int,
      mode_override: Text
  ) -> ops.Operation:
    """Outputs the enqueue op given the inputs and weights.

    Args:
      flat_inputs: A list of input tensors.
      flat_weights: A list of input weights (or None) of the same length as
        flat_inputs.
      flat_features: A list of FeatureConfigs of the same length as
        flat_inputs.
      device_ordinal: The device to create the enqueue op for.
      mode_override: A tensor containing the string "train" or "inference".

    Returns:
      The enqueue op.
    """
    # Combiners are per table, listed in the same order as the table order.
    combiners = [table.combiner for table in self._table_config]

    # These parallel arrays will be the inputs to the enqueue op.
    # sample_indices for sparse, row_splits for ragged.
    indices_or_row_splits = []
    values = []
    weights = []

    # We have to supply an empty/zero tensor in a list position where we
    # don't have data (e.g. indices for standard Tensor input, weight when no
    # weight is specified). We create one op here per call, so that we reduce
    # the graph size.
    int_zeros = array_ops.zeros((0,), dtype=dtypes.int32)
    float_zeros = array_ops.zeros((0,), dtype=dtypes.float32)

    # In the following loop we insert casts so that everything is either
    # int32 or float32. This is because op inputs which are lists of tensors
    # must be of the same type within the list. Moreover the CPU
    # implementations of these ops cast to these types anyway, so we don't
    # lose any data by casting early.
    for inp, weight, (path, feature) in zip(
        flat_inputs, flat_weights, flat_features):
      if isinstance(inp, ops.Tensor):
        self._add_data_for_tensor(inp, weight, indices_or_row_splits, values,
                                  weights, int_zeros, float_zeros, path)
      elif isinstance(inp, sparse_tensor.SparseTensor):
        self._add_data_for_sparse_tensor(inp, weight, indices_or_row_splits,
                                         values, weights, int_zeros,
                                         float_zeros, path, feature)
      elif isinstance(inp, ragged_tensor.RaggedTensor):
        self._add_data_for_ragged_tensor(inp, weight, indices_or_row_splits,
                                         values, weights, int_zeros,
                                         float_zeros, path, feature)
      else:
        raise ValueError("Input {} is of unknown type {}. Please only pass "
                         "Tensor, SparseTensor or RaggedTensor as input to "
                         "enqueue.".format(path, type(inp)))

    return tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch(
        sample_indices_or_row_splits=indices_or_row_splits,
        embedding_indices=values,
        aggregation_weights=weights,
        mode_override=mode_override,
        device_ordinal=device_ordinal,
        combiners=combiners)

  def _raise_error_for_incorrect_control_flow_context(self):
    """Raises an error if we are not in the TPUReplicateContext."""
    # Do not allow any XLA control flow (i.e. control flow in between a
    # TPUStrategy's run call and the call to this function), as we can't
    # extract the enqueue from the head when in XLA control flow.
    graph = ops.get_default_graph()
    in_tpu_ctx = False
    while graph is not None:
      ctx = graph._get_control_flow_context()  # pylint: disable=protected-access
      while ctx is not None:
        if isinstance(ctx, tpu_replication.TPUReplicateContext):
          in_tpu_ctx = True
          break
        ctx = ctx.outer_context
      if in_tpu_ctx:
        break
      graph = getattr(graph, "outer_graph", None)
    if graph != ops.get_default_graph() and in_tpu_ctx:
      raise RuntimeError(
          "Current graph {} does not match graph which contains "
          "TPUReplicateContext {}. This is most likely due to the fact that "
          "enqueueing embedding data is called inside control flow or a "
          "tf.function inside `strategy.run`. This is not supported because "
          "outside compilation fails to extract the enqueue ops as the head "
          "of a computation.".format(ops.get_default_graph(), graph))
    return in_tpu_ctx

  def _raise_error_for_non_direct_inputs(self, features):
    """Checks all tensors in features to see if they are a direct input."""

    # expand_composites here is important: as composite tensors pass through
    # tpu.replicate, they get 'flattened' into their component tensors and
    # then repacked before being passed to the tpu function. This means that
    # it is the component tensors which are produced by an op with the
    # "_tpu_input_identity" attribute.
    for path, input_tensor in nest.flatten_with_joined_string_paths(
        features, expand_composites=True):
      if input_tensor.op.type == "Placeholder":
        continue
      try:
        is_input = input_tensor.op.get_attr("_tpu_input_identity")
      except ValueError:
        is_input = False
      if not is_input:
        raise ValueError(
            "Received input tensor {} which is the output of op {} (type {}) "
            "which does not have the `_tpu_input_identity` attr. Please "
            "ensure that the inputs to this layer are taken directly from "
            "the arguments of the function called by "
            "strategy.run. Two possible causes are: dynamic batch size "
            "support or you are using a keras layer and are not passing "
            "tensors which match the dtype of the `tf.keras.Input`s. "
            "If you are triggering dynamic batch size support, you can "
            "disable it by passing tf.distribute.RunOptions("
            "experimental_enable_dynamic_batch_size=False) to the options "
            "argument of strategy.run().".format(path,
                                                 input_tensor.op.name,
                                                 input_tensor.op.type))

  def _raise_error_for_inputs_not_on_cpu(self, flat_inputs, flat_paths):
    """Checks all tensors in features to see if they are placed on the CPU."""

    def check_device(path, device_string):
      spec = tf_device.DeviceSpec.from_string(device_string)
      if spec.device_type == "TPU":
        raise ValueError(
            "Received input tensor {} which is on a TPU input device {}. "
            "Input tensors for TPU embeddings must be placed on the CPU. "
            "Please ensure that your dataset is prefetching tensors to the "
            "host by setting the 'experimental_fetch_to_device' option of "
            "the dataset distribution function. See the documentation of "
            "the enqueue method for an example.".format(path, device_string))

    # expand_composites here is important, we need to check the device of
    # each underlying tensor.
    for input_tensor, input_path in zip(flat_inputs, flat_paths):
      if nest.is_nested_or_composite(input_tensor):
        input_tensors = nest.flatten(input_tensor, expand_composites=True)
      else:
        input_tensors = [input_tensor]
      for t in input_tensors:
        if (t.op.type == "Identity" and
            t.op.inputs[0].op.type == "TPUReplicatedInput"):
          for tensor in t.op.inputs[0].op.inputs:
            check_device(input_path, tensor.device)
        else:
          check_device(input_path, t.device)

  def enqueue(
      self,
      features,
      weights=None,
      training: bool = True,
      name: Optional[Text] = None,
      device: Optional[Text] = None):
    """Enqueues id tensors for embedding lookup.

    This function enqueues a structure of features to be looked up in the
    embedding tables. We expect that the input shapes of each of the tensors
    in features matches the output shapes set via FeatureConfig or the build
    method (if any). Otherwise, the output shapes will be auto detected based
    on the input shapes, using the max_sequence_length or output shape
    setting in the FeatureConfig. Note that the output shapes are based on
    the per replica batch size. If your input dataset is batched to the
    global batch size and you use `tf.distribute.TPUStrategy`'s
    `experimental_distribute_dataset`, or if you use
    `distribute_datasets_from_function` and batch to the per core batch size
    computed by the context passed to your input function, the output shapes
    should match automatically.

    The output shapes are auto detected as follows (see the example sketch
    after this list):
    1. For a dense tensor of rank 2 or above, make sure the tensor has its
       last dimension as 1. The output shape will be the input shape
       excluding the last dimension.
    2. For a sparse tensor, make sure the tensor has rank 2 or above.
       a. If the feature config has max_sequence_length equal to 0 or an
          output shape set (in which case the max_sequence_length setting
          will be ignored), the output shape will be the input shape
          excluding the last dimension.
       b. Otherwise, if the tensor is rank 2, the output shape will be the
          input shape with the last dimension set to max_sequence_length. If
          the tensor is above rank 2, the output shape will be the input
          shape excluding the last dimension, with the last dimension of the
          output shape set to max_sequence_length.
    3. For a ragged tensor, make sure the tensor has rank 2.
       a. If the feature config has max_sequence_length equal to 0 or an
          output shape set (in which case the max_sequence_length setting
          will be ignored), the output shape will be the input shape
          excluding the last dimension.
       b. Otherwise, the output shape will be the input shape excluding the
          last dimension, with the last dimension of the output shape set to
          max_sequence_length.
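
    A hedged example of rule 2b above (shapes are illustrative only,
    assuming a `feature_config` with the single key 'feature_one' and
    `max_sequence_length=5`): a rank 2 `tf.SparseTensor` with per replica
    input shape `[16, 8]` and no output shape set produces an output shape of
    `[16, 5]`, so the dequeued activation for that feature has shape
    `[16, 5, dim]`:

    ```python
    sparse_feature = tf.SparseTensor(
        indices=..., values=..., dense_shape=[16, 8])
    embedding.enqueue({'feature_one': sparse_feature}, training=True)
    ```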

1127 

1128 ```python 

1129 strategy = tf.distribute.TPUStrategy(...) 

1130 with strategy.scope(): 

1131 embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) 

1132 

1133 distributed_dataset = ( 

1134 strategy.distribute_datasets_from_function( 

1135 dataset_fn=..., 

1136 options=tf.distribute.InputOptions( 

1137 experimental_fetch_to_device=False)) 

1138 dataset_iterator = iter(distributed_dataset) 

1139 

1140 @tf.function 

1141 def training_step(): 

1142 def tpu_step(tpu_features): 

1143 with tf.GradientTape() as tape: 

1144 activations = embedding.dequeue() 

1145 tape.watch(activations) 

1146 

1147 loss = ... # some computation involving activations 

1148 

1149 embedding_gradients = tape.gradient(loss, activations) 

1150 embedding.apply_gradients(embedding_gradients) 

1151 

1152 embedding_features, tpu_features = next(dataset_iterator) 

1153 embedding.enqueue(embedding_features, training=True) 

1154 strategy.run(tpu_step, args=(tpu_features,)) 

1155 

1156 training_step() 

1157 ``` 

1158 

1159 NOTE: You should specify `training=True` when using 

1160 `embedding.apply_gradients` as above and `training=False` when not using 

1161 `embedding.apply_gradients` (e.g. for frozen embeddings or when doing 

1162 evaluation). 

1163 

1164 For finer grained control, in the above example the line 

1165 

1166 ``` 

1167 embedding.enqueue(embedding_features, training=True) 

1168 ``` 

1169 

1170 may be replaced with 

1171 

1172 ``` 

1173 per_core_embedding_features = self.strategy.experimental_local_results( 

1174 embedding_features) 

1175 

1176 def per_core_enqueue(ctx): 

1177 core_id = ctx.replica_id_in_sync_group 

1178 device = strategy.extended.worker_devices[core_id] 

1179 embedding.enqueue(per_core_embedding_features[core_id], 

1180 device=device) 

1181 

1182 strategy.experimental_distribute_values_from_function( 

1183 per_core_queue_inputs) 

1184 ``` 

1185 

1186 Args: 

1187 features: A nested structure of `tf.Tensor`s, `tf.SparseTensor`s or 

1188 `tf.RaggedTensor`s, with the same structure as `feature_config`. Inputs 

1189 will be downcast to `tf.int32`. Only one type out of `tf.SparseTensor` 

1190 or `tf.RaggedTensor` is supported per call. 

1191 weights: If not `None`, a nested structure of `tf.Tensor`s, 

1192 `tf.SparseTensor`s or `tf.RaggedTensor`s, matching the above, except 

1193 that the tensors should be of float type (and they will be downcast to 

1194 `tf.float32`). For `tf.SparseTensor`s we assume the `indices` are the 

1195 same for the parallel entries from `features` and similarly for 

1196 `tf.RaggedTensor`s we assume the row_splits are the same. 

1197 training: Defaults to `True`. If `False`, enqueue the batch as inference 

1198 batch (forward pass only). Do not call `apply_gradients` when this is 

1199 `False` as this may lead to a deadlock. 

1200 name: A name for the underlying op. 

1201 device: The device name (e.g. '/task:0/device:TPU:2') where this batch 

1202 should be enqueued. This should be set if and only if features is not a 

1203 `tf.distribute.DistributedValues` and enqueue is not being called 

1204 inside a TPU context (e.g. inside `TPUStrategy.run`). 

1205 

1206 Raises: 

1207 ValueError: When called inside a strategy.run call and input is not 

1208 directly taken from the args of the `strategy.run` call. Also if 

1209 the size of any sequence in `features` does not match corresponding 

1210 sequence in `feature_config`. Similarly for `weights`, if not `None`. 

1211 If input shapes of features is unequal or different from a previous 

1212 call. 

1213 RuntimeError: When called inside a strategy.run call and inside XLA 

1214 control flow. If batch_size is not able to be determined and build was 

1215 not called. 

1216 TypeError: If the type of any sequence in `features` does not match 

1217 corresponding sequence in `feature_config`. Similarly for `weights`, if 

1218 not `None`. 

1219 """ 

1220 if not self._using_tpu: 

1221 raise RuntimeError("enqueue is not valid when TPUEmbedding object is not " 

1222 "created under a TPUStrategy.") 

1223 

1224 in_tpu_context = self._raise_error_for_incorrect_control_flow_context() 

1225 

1226 nest.assert_same_structure(self._feature_config, features) 

1227 

1228 if not self._verify_output_shapes_on_enqueue: 

1229 if not self._output_shapes or not self._built: 

1230 raise ValueError( 

1231 "Configured not to check output shapes on each enqueue() call; please " 

1232 "ensure build() was called with output shapes to initialize " 

1233 "the TPU for embeddings.") 

1234 else: 

1235 input_shapes = self._get_input_shapes(features, in_tpu_context) 

1236 

1237 self._maybe_build(input_shapes) 

1238 # If is already built, we still need to check if the output shapes matches 

1239 # with the previous ones. 

1240 self._check_output_shapes( 

1241 self._get_output_shapes_from_input_shapes(input_shapes)) 

1242 

1243 flat_inputs = nest.flatten(features) 

1244 flat_weights = [None] * len(flat_inputs) 

1245 if weights is not None: 

1246 nest.assert_same_structure(self._feature_config, weights) 

1247 flat_weights = nest.flatten(weights) 

1248 flat_features = nest.flatten_with_joined_string_paths(self._feature_config) 

1249 flat_paths, _ = zip(*flat_features) 

1250 

1251 self._raise_error_for_inputs_not_on_cpu(flat_inputs, flat_paths) 

1252 # If we are in a tpu_context, automatically apply outside compilation. 

1253 if in_tpu_context: 

1254 self._raise_error_for_non_direct_inputs(features) 

1255 

1256 def generate_enqueue_ops(): 

1257 """Generate enqueue ops for outside compilation.""" 

1258 # Note that we use array_ops.where_v2 rather than a python if so that 

1259 # the op is explicitly created and the constant ops are both in the graph 

1260 # even though we don't expect training to be a tensor (and thus generate 

1261 # control flow automatically). This is needed to make it easier to 

1262 # rewrite the graph later if we need to fix which mode is used. 

1263 mode_override = array_ops.where_v2(training, 

1264 constant_op.constant("train"), 

1265 constant_op.constant("inference")) 

1266 # Device ordinal is -1 here, a later rewrite will fix this once the op 

1267 # is expanded by outside compilation. 

1268 enqueue_op = self._generate_enqueue_op( 

1269 flat_inputs, flat_weights, flat_features, device_ordinal=-1, 

1270 mode_override=mode_override) 

1271 

1272 # Apply the name tag to the op. 

1273 if name is not None: 

1274 _add_key_attr(enqueue_op, name) 

1275 

1276 tpu_replication.outside_compilation(generate_enqueue_ops) 

1277 

1278 elif device is None: 

1279 mode_override = "train" if training else "inference" 

1280 # We generate enqueue ops per device, so we need to gather all the 

1281 # features for a single device into a dict. 

1282 # We rely here on the fact that the devices in the PerReplica value occur 

1283 # in the same (standard) order as self._strategy.extended.worker_devices. 

1284 enqueue_ops = [] 

1285 for replica_id in range(self._strategy.num_replicas_in_sync): 

1286 replica_inputs = distribute_utils.select_replica(replica_id, 

1287 flat_inputs) 

1288 replica_weights = distribute_utils.select_replica(replica_id, 

1289 flat_weights) 

1290 tpu_device = self._strategy.extended.worker_devices[replica_id] 

1291 # TPU device strings look like /job:worker/replica:0/task:0/device:TPU:0; 

1292 # the device ordinal is the last number. 

1293 device_ordinal = ( 

1294 tf_device.DeviceSpec.from_string(tpu_device).device_index) 

1295 

1296 with ops.device(device_util.get_host_for_device(tpu_device)): 

1297 enqueue_op = self._generate_enqueue_op( 

1298 replica_inputs, replica_weights, flat_features, 

1299 device_ordinal=device_ordinal, mode_override=mode_override) 

1300 

1301 # Apply the name tag to the op. 

1302 if name is not None: 

1303 _add_key_attr(enqueue_op, name) 

1304 enqueue_ops.append(enqueue_op) 

1305 else: 

1306 mode_override = "train" if training else "inference" 

1307 device_spec = tf_device.DeviceSpec.from_string(device) 

1308 if device_spec.device_type != "TPU": 

1309 raise ValueError( 

1310 "Non-TPU device {} passed to enqueue.".format(device)) 

1311 

1312 with ops.device(device_util.get_host_for_device(device)): 

1313 enqueue_op = self._generate_enqueue_op( 

1314 flat_inputs, flat_weights, flat_features, 

1315 device_ordinal=device_spec.device_index, 

1316 mode_override=mode_override) 

1317 

1318 # Apply the name tag to the op. 

1319 if name is not None: 

1320 _add_key_attr(enqueue_op, name) 

1321 
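# --- Illustrative sketch, not part of this module ---
# How enqueue/dequeue are typically paired under TPUStrategy, following the
# pattern the docstring above describes. `dataset_fn`, `feature_config`, and
# `model` are hypothetical placeholders the reader must supply.
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,  # assumed dict of FeatureConfig
      optimizer=tf.tpu.experimental.embedding.SGD(learning_rate=0.05))

dist_dataset = strategy.distribute_datasets_from_function(
    dataset_fn,  # assumed: yields (embedding_features, dense_features)
    options=tf.distribute.InputOptions(experimental_fetch_to_device=False))
dataset_iterator = iter(dist_dataset)

@tf.function
def training_step():
  def tpu_step(tpu_features):
    with tf.GradientTape() as tape:
      # Dequeue the activations for the batch enqueued on the host below.
      activations = embedding.dequeue()
      tape.watch(activations)
      loss = model(activations, tpu_features)  # hypothetical model fn
    embedding.apply_gradients(tape.gradient(loss, activations))

  embedding_features, tpu_features = next(dataset_iterator)
  # Enqueue on the host; inputs must come directly from the iterator.
  embedding.enqueue(embedding_features, training=True)
  strategy.run(tpu_step, args=(tpu_features,))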

1322 def _get_input_shapes(self, tensors, 

1323 in_tpu_context: bool) -> List[TensorShape]: 

1324 """Get the input shapes from the input tensor.""" 

1325 input_shapes = [] 

1326 for (path, maybe_tensor), feature in zip( 

1327 nest.flatten_with_joined_string_paths(tensors), 

1328 nest.flatten(self._feature_config)): 

1329 if not in_tpu_context: 

1330 tensor = distribute_utils.select_replica(0, maybe_tensor) 

1331 else: 

1332 tensor = maybe_tensor 

1333 

1334 if isinstance(tensor, ops.Tensor): 

1335 input_shapes.append( 

1336 self._get_input_shape_for_tensor(tensor, feature, path)) 

1337 elif isinstance(tensor, sparse_tensor.SparseTensor): 

1338 input_shapes.append( 

1339 self._get_input_shape_for_sparse_tensor(tensor, feature, path)) 

1340 elif isinstance(tensor, ragged_tensor.RaggedTensor): 

1341 input_shapes.append( 

1342 self._get_input_shape_for_ragged_tensor(tensor, feature, path)) 

1343 return input_shapes 

1344 

1345 def _get_input_shape_for_tensor(self, tensor, feature, path) -> TensorShape: 

1346 """Get the input shape for the dense tensor.""" 

1347 shape = tensor.shape.as_list() 

1348 if len(shape) < 1: 

1349 raise ValueError("Only rank 1 and above dense tensor is supported," 

1350 " find rank {} sparse tensor for input {}".format( 

1351 len(shape), path)) 

1352 if len(shape) > 1 and shape[-1] != 1: 

1353 raise ValueError( 

1354 "Rank 2 or above dense tensor should have last dimension as 1 " 

1355 "as the last dimension will always be reduced. " 

1356 "Instead got dense tensor as shape {}".format(shape)) 

1357 return TensorShape(shape) 

1358 

1359 def _get_input_shape_for_sparse_tensor(self, tensor, feature, 

1360 path) -> TensorShape: 

1361 """Get the input shape for the sparse tensor.""" 

1362 shape = tensor.shape.as_list() 

1363 # Only sparse tensors of rank 2 and above are supported. 

1364 if len(shape) < 2: 

1365 raise ValueError("Only rank 2 and above sparse tensors are supported," 

1366 " found rank {} sparse tensor for input {}".format( 

1367 len(shape), path)) 

1368 if not feature.output_shape and feature.max_sequence_length > 0: 

1369 # If max_sequence_length is set and the output shape for the FeatureConfig 

1370 # is not set, we modify the shape of the input feature. Only rank 2 

1371 # feature output shapes are modified. 

1372 if len(shape) == 2: 

1373 # If the sparse tensor is 2D and max_sequence_length is set, 

1374 # we need to add one dimension to the input feature. 

1375 shape.insert(len(shape) - 1, feature.max_sequence_length) 

1376 

1377 return TensorShape(shape) 

1378 

1379 def _get_input_shape_for_ragged_tensor(self, tensor, feature, 

1380 path) -> TensorShape: 

1381 """Get the input shape for the ragged tensor.""" 

1382 shape = tensor.shape.as_list() 

1383 # Only rank 2 ragged tensors are supported. 

1384 if len(shape) != 2: 

1385 raise ValueError("Only rank 2 ragged tensors are supported," 

1386 " found rank {} ragged tensor for input {}".format( 

1387 len(shape), path)) 

1388 if not feature.output_shape and feature.max_sequence_length > 0: 

1389 # If the max_sequence_length is set and the output shape for FeatureConfig 

1390 # is not set, add the sequence length as the second-to-last dimension of 

1391 # the ragged tensor. 

1392 shape.insert(len(shape) - 1, feature.max_sequence_length) 

1393 

1394 return TensorShape(shape) 

1395 
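# --- Illustrative sketch, not part of this module ---
# The shape rule above in isolation: for a rank-2 sparse or ragged input of
# shape [batch, x] whose FeatureConfig sets max_sequence_length but leaves
# output_shape unset, the sequence length is inserted before the last dim.
shape = [32, 10]         # [batch_size, x]
max_sequence_length = 8  # assumed FeatureConfig.max_sequence_length
shape.insert(len(shape) - 1, max_sequence_length)
assert shape == [32, 8, 10]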

1396 def _update_output_shapes(self, incoming_output_shapes: List[TensorShape]): 

1397 """Update the existing output shapes based on the new output shapes. 

1398 

1399 The existing output shapes always take priority over the new incoming 

1400 output shapes. 

1401 Args: 

1402 incoming_output_shapes: nested structure of TensorShape to override the 

1403 existing output shapes. 

1404 """ 

1405 nest.assert_same_structure(self._output_shapes, incoming_output_shapes) 

1406 updated_output_shapes = [] 

1407 for old_output_shape, incoming_output_shape in zip(self._output_shapes, 

1408 incoming_output_shapes): 

1409 if old_output_shape: 

1410 updated_output_shapes.append(old_output_shape) 

1411 else: 

1412 updated_output_shapes.append(incoming_output_shape) 

1413 self._output_shapes = updated_output_shapes 

1414 

1415 def _check_output_shapes(self, incoming_output_shapes: List[TensorShape]): 

1416 """Check the incoming output shapes against the output shapes stored.""" 

1417 # The incoming output shapes should have the same structure as the existing 

1418 # output shapes. 

1419 nest.assert_same_structure(self._output_shapes, incoming_output_shapes) 

1420 

1421 for (path, _), old_output_shape, incoming_output_shape in zip( 

1422 nest.flatten_with_joined_string_paths(self._feature_config), 

1423 self._output_shapes, incoming_output_shapes): 

1424 # First check if both shapes are not None. 

1425 if old_output_shape and incoming_output_shape: 

1426 # We skip the check when the incoming output shape is rank 1 or 2 and 

1427 # the rank of the old output shape is larger. This can happen for 

1428 # (sequence) ragged tensors; we push the check down to the enqueue op. 

1429 if (len(incoming_output_shape) == 1 or len(incoming_output_shape) 

1430 == 2) and len(old_output_shape) > len(incoming_output_shape): 

1431 continue 

1432 if len(old_output_shape) != len( 

1433 incoming_output_shape) or not self._is_tensor_shape_match( 

1434 old_output_shape, incoming_output_shape): 

1435 raise ValueError( 

1436 f"Inconsistent shape founded for input feature {path}, " 

1437 f"Output shape is set to be {old_output_shape}, " 

1438 f"But got incoming output shape {incoming_output_shape}") 

1439 

1440 def _check_output_shapes_fully_defined(self): 

1441 """Check if the output shape is fully defined.""" 

1442 for (path, _), output_shape in zip( 

1443 nest.flatten_with_joined_string_paths(self._feature_config), 

1444 self._output_shapes): 

1445 if not output_shape.is_fully_defined(): 

1446 raise ValueError( 

1447 f"Input Feature {path} has output shape set as " 

1448 f"{output_shape} which is not fully defined. " 

1449 "Please specify the fully defined shape in either FeatureConfig " 

1450 "or for the build method.") 

1451 

1452 def _is_tensor_shape_match(self, shape_a: TensorShape, 

1453 shape_b: TensorShape) -> bool: 

1454 """Check if shape b matches with shape a.""" 

1455 for s_a, s_b in zip(shape_a.as_list(), shape_b.as_list()): 

1456 if s_a and s_b and s_a != s_b: 

1457 return False 

1458 return True 

1459 
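# --- Illustrative sketch, not part of this module ---
# The wildcard semantics of the match above: an unknown (None) dimension on
# either side matches anything; only known, unequal dimensions mismatch.
from tensorflow.python.framework.tensor_shape import TensorShape

def shapes_match(shape_a, shape_b):  # standalone restatement for illustration
  for s_a, s_b in zip(shape_a.as_list(), shape_b.as_list()):
    if s_a and s_b and s_a != s_b:
      return False
  return True

assert shapes_match(TensorShape([32, None]), TensorShape([32, 8]))
assert not shapes_match(TensorShape([32, 8]), TensorShape([16, 8]))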

1460 def _get_output_shapes_from_batch_size(self, per_replica_batch_size): 

1461 """Get the output shapes from the batch size.""" 

1462 output_shapes = [] 

1463 for feature in nest.flatten(self._feature_config): 

1464 if not feature.output_shape and feature.max_sequence_length > 0: 

1465 output_shapes.append( 

1466 TensorShape([per_replica_batch_size, feature.max_sequence_length])) 

1467 else: 

1468 output_shapes.append(TensorShape(per_replica_batch_size)) 

1469 return output_shapes 

1470 

1471 def _create_copy_for_async_checkpoint( 

1472 self, feature_config, optimizer, pipeline_execution_with_tensor_core): 

1473 """Create a TPUEmbedding copy for checkpoint/async_checkpoint_helper.py.""" 

1474 return TPUEmbedding( 

1475 feature_config=feature_config, 

1476 optimizer=optimizer, 

1477 pipeline_execution_with_tensor_core=pipeline_execution_with_tensor_core) 

1478 

1479 

1480@def_function.function 

1481def _load_variables_impl( 

1482 config: Text, 

1483 hosts: List[Tuple[int, Text]], 

1484 variables: Dict[Text, Dict[Text, tf_variables.Variable]], 

1485 table_config: tpu_embedding_v2_utils.TableConfig): 

1486 """Load embedding tables to onto TPU for each table and host. 

1487 

1488 Args: 

1489 config: A serialized TPUEmbeddingConfiguration proto. 

1490 hosts: A list of CPU devices, one per host. 

1491 variables: A dictionary of dictionaries of TPUEmbeddingVariables. First key 

1492 is the table name, second key is 'parameters' or the optimizer slot name. 

1493 table_config: A list of tf.tpu.experimental.embedding.TableConfig objects. 

1494 """ 

1495 def select_fn(host_id): 

1496 

1497 def select_or_zeros(x): 

1498 if host_id >= len(x.variables): 

1499 # In the edge case where we have more hosts than variables, due to using 

1500 # a small number of rows, we load zeros for the later hosts. We copy 

1501 # the shape of the first host's variables, which we assume is defined 

1502 # because TableConfig guarantees at least one row. 

1503 return array_ops.zeros_like(x.variables[0]) 

1504 return x.variables[host_id] 

1505 

1506 return select_or_zeros 

1507 

1508 for host_id, host in enumerate(hosts): 

1509 with ops.device(host): 

1510 host_variables = nest.map_structure(select_fn(host_id), variables) 

1511 for table in table_config: 

1512 table.optimizer._load()( # pylint: disable=protected-access 

1513 table_name=table.name, 

1514 num_shards=len(hosts), 

1515 shard_id=host_id, 

1516 config=config, 

1517 **host_variables[table.name]) 

1518 # Ensure that only the first table/first host gets a config so that we 

1519 # don't bloat the graph by attaching this large string to each op. 

1520 # We have num tables * num hosts of these ops, so for models with a large 

1521 # number of tables training on a large slice, this can be an issue. 

1522 config = None 

1523 
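# --- Illustrative sketch, not part of this module ---
# The host-selection rule above in isolation: with more hosts than shards
# (possible for very small tables), later hosts load zeros shaped like the
# first shard. `FakeSharded` is a stand-in for TPUEmbeddingVariable.
import numpy as np

class FakeSharded:
  def __init__(self, variables):
    self.variables = variables

def select_or_zeros(x, host_id):
  if host_id >= len(x.variables):
    return np.zeros_like(x.variables[0])
  return x.variables[host_id]

sharded = FakeSharded([np.ones((2, 4)), np.ones((1, 4))])  # 2 shards only
assert select_or_zeros(sharded, 0).shape == (2, 4)  # real shard for host 0
assert (select_or_zeros(sharded, 3) == 0).all()     # zeros for extra host 3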

1524 

1525@def_function.function 

1526def _retrieve_variables_impl( 

1527 config: Text, 

1528 hosts: List[Tuple[int, Text]], 

1529 variables: Dict[Text, Dict[Text, tf_variables.Variable]], 

1530 table_config: tpu_embedding_v2_utils.TableConfig): 

1531 """Retrieve embedding tables from TPU to host memory. 

1532 

1533 Args: 

1534 config: A serialized TPUEmbeddingConfiguration proto. 

1535 hosts: A list of all the host CPU devices. 

1536 variables: A dictionary of dictionaries of TPUEmbeddingVariables. First key 

1537 is the table name, second key is 'parameters' or the optimizer slot name. 

1538 table_config: A list of tf.tpu.experimental.embedding.TableConfig objects. 

1539 """ 

1540 for host_id, host in enumerate(hosts): 

1541 with ops.device(host): 

1542 for table in table_config: 

1543 retrieved = table.optimizer._retrieve()( # pylint: disable=protected-access 

1544 table_name=table.name, 

1545 num_shards=len(hosts), 

1546 shard_id=host_id, 

1547 config=config) 

1548 # When there are no slot variables (e.g. with SGD) this returns a 

1549 # single tensor rather than a tuple. In this case we put the tensor in 

1550 # a list to make the following code easier to write. 

1551 if not isinstance(retrieved, tuple): 

1552 retrieved = (retrieved,) 

1553 

1554 for i, slot in enumerate(["parameters"] + 

1555 table.optimizer._slot_names()): # pylint: disable=protected-access 

1556 # We must assign the CPU variables the values of tensors that were 

1557 # returned from the TPU. 

1558 sharded_var = variables[table.name][slot] 

1559 if host_id < len(sharded_var.variables): 

1560 # In the edge case where we have more hosts than variables, due to 

1561 # using a small number of rows, we skip the later hosts. 

1562 sharded_var.variables[host_id].assign(retrieved[i]) 

1563 # Ensure that only the first table/first host gets a config so that we 

1564 # don't bloat the graph by attaching this large string to each op. 

1565 # We have num tables * num hosts of these ops, so for models with a large 

1566 # number of tables training on a large slice, this can be an issue. 

1567 config = None 

1568 
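# --- Illustrative sketch, not part of this module ---
# The tuple normalization above in isolation: optimizers without slot
# variables (e.g. SGD) return a single tensor from retrieve, which is
# wrapped in a tuple so that indexing by slot position still works.
retrieved = "params"  # stand-in for a single retrieved tensor
if not isinstance(retrieved, tuple):
  retrieved = (retrieved,)
slot_names = ["parameters"]  # SGD contributes no extra slot names
assert dict(zip(slot_names, retrieved)) == {"parameters": "params"}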

1569 

1570def _save_callback(trackables, **unused_kwargs): 

1571 for trackable in trackables.values(): 

1572 trackable._retrieve_variables() # pylint: disable=protected-access 

1573 return [] 

1574 

1575 

1576def _restore_callback(trackables, **unused_kwargs): 

1577 for trackable in trackables.values(): 

1578 trackable._load_variables() # pylint: disable=protected-access 

1579 

1580 

1581registration.register_tf_checkpoint_saver( 

1582 "TPUEmbeddingCallback", 

1583 predicate=lambda x: isinstance(x, TPUEmbedding), 

1584 save_fn=_save_callback, 

1585 restore_fn=_restore_callback, 

1586 # Set strict_predicate_restore to `False` because the isinstance 

1587 # predicate check does not pass after a TPUEmbedding object is loaded from 

1588 # SavedModel. 

1589 strict_predicate_restore=False 

1590) 

1591 

1592 

1593def extract_variable_info( 

1594 kwargs) -> Tuple[Text, Tuple[int, ...], dtypes.DType, Callable[[], Any]]: 

1595 """Extracts the variable creation attributes from the kwargs. 

1596 

1597 Args: 

1598 kwargs: a dict of keyword arguments that were passed to a variable creator 

1599 scope. 

1600 

1601 Returns: 

1602 A tuple of variable name, shape, dtype, initialization function. 

1603 """ 

1604 if (isinstance(kwargs["initial_value"], functools.partial) and ( 

1605 "shape" in kwargs["initial_value"].keywords or 

1606 kwargs["initial_value"].args)): 

1607 # Sometimes shape is passed positionally, sometimes it's passed as a kwarg. 

1608 if "shape" in kwargs["initial_value"].keywords: 

1609 shape = kwargs["initial_value"].keywords["shape"] 

1610 else: 

1611 shape = kwargs["initial_value"].args[0] 

1612 return (kwargs["name"], shape, 

1613 kwargs["initial_value"].keywords.get("dtype", kwargs["dtype"]), 

1614 kwargs["initial_value"].func) 

1615 elif "shape" not in kwargs or kwargs["shape"] is None or not callable( 

1616 kwargs["initial_value"]): 

1617 raise ValueError( 

1618 "Unable to extract initializer function and shape from {}. Please " 

1619 "either pass a function that expects a shape and dtype as the " 

1620 "initial value for your variable or functools.partial object with " 

1621 "the shape and dtype kwargs set. This is needed so that we can " 

1622 "initialize the shards of the ShardedVariable locally.".format( 

1623 kwargs["initial_value"])) 

1624 else: 

1625 return (kwargs["name"], kwargs["shape"], kwargs["dtype"], 

1626 kwargs["initial_value"]) 

1627 
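# --- Illustrative sketch, not part of this module ---
# The two initial_value forms the helper above accepts, with assumed kwargs
# shaped like what a variable-creator scope would pass:
import functools
import numpy as np

def initializer(shape, dtype=None):
  return np.zeros(shape, dtype=np.float32)

# Form 1: a functools.partial already carrying the shape (and maybe dtype);
# the helper pulls shape/dtype/func out of the partial.
kwargs_partial = {
    "name": "table", "dtype": "float32",
    "initial_value": functools.partial(initializer, shape=(8, 4)),
}

# Form 2: a bare callable plus explicit shape/dtype kwargs.
kwargs_callable = {
    "name": "table", "shape": (8, 4), "dtype": "float32",
    "initial_value": initializer,
}
# Anything else (e.g. a constant tensor with no shape kwarg) triggers the
# ValueError above.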

1628 

1629def make_sharded_variable_creator( 

1630 hosts: List[Text]) -> Callable[..., TPUEmbeddingVariable]: 

1631 """Makes a sharded variable creator given a list of hosts. 

1632 

1633 Args: 

1634 hosts: a list of tensorflow devices on which to shard the tensors. 

1635 

1636 Returns: 

1637 A variable creator function. 

1638 """ 

1639 

1640 def sharded_variable_creator( 

1641 next_creator: Callable[..., tf_variables.Variable], *args, **kwargs): 

1642 """The sharded variable creator.""" 

1643 kwargs["skip_mirrored_creator"] = True 

1644 

1645 num_hosts = len(hosts) 

1646 name, shape, dtype, unwrapped_initial_value = extract_variable_info(kwargs) 

1647 initial_value = kwargs["initial_value"] 

1648 rows = shape[0] 

1649 cols = shape[1] 

1650 partial_partition = rows % num_hosts 

1651 full_rows_per_host = rows // num_hosts 

1652 # We partition as if we were using MOD sharding: each of the `num_hosts` 

1653 # hosts gets at least `full_rows_per_host` rows, and the first 

1654 # `partial_partition` hosts get an additional row when the number of rows 

1655 # is not cleanly divisible. Note that `full_rows_per_host` may be zero. 

1656 partitions = ( 

1657 [full_rows_per_host + 1] * partial_partition 

1658 + [full_rows_per_host] * (num_hosts - partial_partition)) 

1659 variables = [] 

1660 sharding_aware = "shard_info" in tf_inspect.getargspec(initial_value).args 

1661 

1662 # Keep track of offset for sharding aware initializers. 

1663 offset = 0 

1664 kwargs["dtype"] = dtype 

1665 for i, p in enumerate(partitions): 

1666 if p == 0: 

1667 # Skip variable creation for empty partitions, resulting from the edge 

1668 # case of 'rows < num_hosts'. This is safe because both load/restore 

1669 # can handle the missing values. 

1670 continue 

1671 with ops.device(hosts[i]): 

1672 kwargs["name"] = "{}_{}".format(name, i) 

1673 kwargs["shape"] = (p, cols) 

1674 if sharding_aware: 

1675 shard_info = base.ShardInfo(kwargs["shape"], (offset, 0)) 

1676 kwargs["initial_value"] = functools.partial( 

1677 initial_value, shard_info=shard_info) 

1678 offset += p 

1679 else: 

1680 kwargs["initial_value"] = functools.partial( 

1681 unwrapped_initial_value, kwargs["shape"], dtype=dtype) 

1682 variables.append(next_creator(*args, **kwargs)) 

1683 return TPUEmbeddingVariable(variables, name=name) 

1684 return sharded_variable_creator
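# --- Illustrative sketch, not part of this module ---
# The row-partitioning arithmetic above in isolation: rows are spread
# mod-sharding style, with the remainder going to the first hosts.
def row_partitions(rows, num_hosts):
  partial = rows % num_hosts
  full = rows // num_hosts
  return [full + 1] * partial + [full] * (num_hosts - partial)

assert row_partitions(10, 4) == [3, 3, 2, 2]
# Fewer rows than hosts: the trailing zero-size partitions are skipped at
# variable-creation time, and load/restore tolerate the missing shards.
assert row_partitions(2, 4) == [1, 1, 0, 0]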