Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_embedding.py: 19%

894 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""TPU embedding APIs.""" 

16 

17import collections 

18import copy 

19import math 

20import re 

21from typing import Optional 

22 

23from tensorflow.core.protobuf.tpu import optimization_parameters_pb2 

24from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc 

25from tensorflow.python.eager import context 

26from tensorflow.python.framework import dtypes 

27from tensorflow.python.framework import ops 

28from tensorflow.python.ops import array_ops 

29from tensorflow.python.ops import control_flow_ops 

30from tensorflow.python.ops import init_ops 

31from tensorflow.python.ops import math_ops 

32from tensorflow.python.ops import partitioned_variables 

33from tensorflow.python.ops import state_ops 

34from tensorflow.python.ops import variable_scope 

35from tensorflow.python.platform import tf_logging as logging 

36from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib 

37from tensorflow.python.tpu.ops import tpu_ops 

38from tensorflow.python.util.tf_export import tf_export 

39 

40TRAINING = elc.TPUEmbeddingConfiguration.TRAINING 

41INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE 

42 

43 

44# TODO(shizhiw): a more future-proof way is to have optimization_parameter such 

45# as AdagradParameters etc instead of learning_rate. 

46class TableConfig( 

47 collections.namedtuple('TableConfig', [ 

48 'vocabulary_size', 

49 'dimension', 

50 'initializer', 

51 'combiner', 

52 'hot_id_replication', 

53 'learning_rate', 

54 'learning_rate_fn', 

55 'optimization_parameters', 

56 ])): 

57 """Embedding table configuration.""" 

58 

59 def __new__(cls, 

60 vocabulary_size, 

61 dimension, 

62 initializer=None, 

63 combiner='mean', 

64 hot_id_replication=False, 

65 learning_rate=None, 

66 learning_rate_fn=None, 

67 optimization_parameters=None): 

68 """Embedding table configuration. 

69 

70 Args: 

71 vocabulary_size: Number of vocabulary entries (rows) in the table.

72 dimension: The embedding dimension. 

73 initializer: A variable initializer function to be used in embedding 

74 variable initialization. If not specified, defaults to 

75 `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and standard 

76 deviation `1/sqrt(dimension)`. 

77 combiner: A string specifying how to reduce if there are multiple entries 

78 in a single row. Currently 'mean', 'sqrtn', 'sum' and None are 

79 supported, with 'mean' the default. 'sqrtn' often achieves good 

80 accuracy, in particular with bag-of-words columns. For more information, 

81 see `tf.nn.embedding_lookup_sparse`. None is only valid for dense rather 

82 than sparse tensors. 

83 hot_id_replication: If true, enables hot id replication, which can make 

84 embedding lookups faster if there are some hot rows in the table. 

85 learning_rate: float, static learning rate for this table. If 

86 learning_rate and learning_rate_fn are both `None`, static learning rate 

87 as specified in local `optimization_parameters` will be used. In case 

88 local `optimization_parameters` is `None`, global 

89 `optimization_parameters` in `TPUEmbedding` constructor will be used. 

90 `learning_rate_fn` must be `None` if `learning_rate` is not `None`.

91 learning_rate_fn: a function giving the dynamic learning rate for this table.

92 This function will be passed the current global step. If learning_rate 

93 and learning_rate_fn are both `None`, static learning rate as specified 

94 in `optimization_parameters` is used. `learning_rate` must be `None` if 

95 `learning_rate_fn` is not `None`.

96 optimization_parameters: `AdagradParameters`, `AdamParameters`, 

97 `StochasticGradientDescentParameters`. Specifies the table-level optimizer.

98 If it's `None` global optimizer in `TPUEmbedding` constructor is used. 

99 

100 Returns: 

101 `TableConfig`. 

102 

103 Raises: 

104 ValueError: if `vocabulary_size` is not a positive integer.

105 ValueError: if `dimension` is not a positive integer.

106 ValueError: if `initializer` is specified and is not callable. 

107 ValueError: if `combiner` is not supported. 

108 ValueError: if `learning_rate` and `learning_rate_fn` are both not 

109 `None`. 

110 """ 

111 if not isinstance(vocabulary_size, int) or vocabulary_size < 1: 

112 raise ValueError(f'vocabulary_size must be >= 1. '

113 f'Received: {vocabulary_size}.') 

114 

115 if not isinstance(dimension, int) or dimension < 1: 

116 raise ValueError( 

117 f'dimension must be a positive int. Received: {dimension}.') 

118 

119 if (initializer is not None) and (not callable(initializer)): 

120 raise ValueError(f'initializer must be callable if specified. ' 

121 f'Received: {initializer}.') 

122 if initializer is None: 

123 initializer = init_ops.truncated_normal_initializer( 

124 mean=0.0, stddev=1 / math.sqrt(dimension)) 

125 

126 if combiner not in ('mean', 'sum', 'sqrtn', None): 

127 raise ValueError(f'combiner must be "mean", "sum", "sqrtn" or None. ' 

128 f'Received: {combiner}.') 

129 

130 if learning_rate is not None and learning_rate_fn is not None: 

131 raise ValueError('At most one of learning_rate and learning_rate_fn ' 

132 'can be set (not None). Received: {} and {}'.format(

133 learning_rate, learning_rate_fn)) 

134 

135 if optimization_parameters is not None: 

136 if not isinstance(optimization_parameters, _OptimizationParameters): 

137 raise ValueError(f'`optimization_parameters` must inherit from ' 

138 f'`_OptimizationParameters`. ' 

139 f'Received: `type(optimization_parameters)`=' 

140 f'{type(optimization_parameters)}.') 

141 

142 return super().__new__(cls, vocabulary_size, dimension, initializer, 

143 combiner, hot_id_replication, learning_rate, 

144 learning_rate_fn, optimization_parameters) 

145 
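# The block below is an illustrative sketch, not part of the library: it shows
# how a TableConfig might be constructed. The vocabulary size, dimension and
# initializer are assumed values chosen purely for demonstration.
def _example_table_config():
  return TableConfig(
      vocabulary_size=1000,
      dimension=16,
      initializer=init_ops.truncated_normal_initializer(stddev=0.25),
      combiner='sum')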

146 

147class FeatureConfig( 

148 collections.namedtuple('FeatureConfig', 

149 ['table_id', 'max_sequence_length', 'weight_key'])): 

150 """Feature configuration.""" 

151 

152 def __new__(cls, table_id, max_sequence_length=0, weight_key=None): 

153 """Feature configuration. 

154 

155 Args: 

156 table_id: Which table the feature uses for embedding lookups.

157 max_sequence_length: If positive, the feature is a sequence feature with 

158 the corresponding maximum sequence length. If the sequence is longer 

159 than this, it will be truncated. If 0, the feature is not a sequence 

160 feature. 

161 weight_key: If using weights for the combiner, this key specifies which 

162 input feature contains the weights. 

163 

164 Returns: 

165 `FeatureConfig`. 

166 

167 Raises: 

168 ValueError: if `max_sequence_length` is not an integer or is negative.

169 """ 

170 if not isinstance(max_sequence_length, int) or max_sequence_length < 0: 

171 raise ValueError(f'max_sequence_length must be zero or a positive int, ' 

172 f'got {max_sequence_length}.') 

173 

174 return super().__new__(cls, table_id, max_sequence_length, weight_key) 

175 
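# Illustrative sketch only: a FeatureConfig whose ids are looked up in the
# table registered under the (assumed) name 'user', treated as a sequence
# feature of at most 8 ids per example.
def _example_feature_config():
  return FeatureConfig(table_id='user', max_sequence_length=8)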

176 

177class EnqueueData( 

178 collections.namedtuple( 

179 'EnqueueData', 

180 ['embedding_indices', 'sample_indices', 'aggregation_weights'])): 

181 """Data to be enqueued through generate_enqueue_ops().""" 

182 

183 def __new__(cls, 

184 embedding_indices, 

185 sample_indices=None, 

186 aggregation_weights=None): 

187 """Data to be enqueued through generate_enqueue_ops(). 

188 

189 Args: 

190 embedding_indices: A rank 1 Tensor, indices into the embedding tables. It 

191 corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32 

192 and int64 are allowed and will be converted to int32 internally. 

193 sample_indices: A rank 2 Tensor specifying the training example to which 

194 the corresponding embedding_indices and aggregation_weights values 

195 belong. It corresponds to sp_ids.indices in embedding_lookup_sparse(). 

196 If it is None, we assume each embedding_indices belongs to a different 

197 sample. Both int32 and int64 are allowed and will be converted to int32 

198 internally. 

199 aggregation_weights: A rank 1 Tensor containing aggregation weights. It 

200 corresponds to sp_weights.values in embedding_lookup_sparse(). If it is 

201 None, we assume all weights are 1. Both float32 and float64 are allowed 

202 and will be converted to float32 internally. 

203 

204 Returns: 

205 An EnqueueData tuple. 

206 

207 """ 

208 return super().__new__(cls, embedding_indices, sample_indices, 

209 aggregation_weights) 

210 

211 @staticmethod 

212 def from_sparse_tensor(sp_tensor, weights=None): 

213 return EnqueueData( 

214 sp_tensor.values, 

215 sp_tensor.indices, 

216 aggregation_weights=weights.values if weights is not None else None) 

217 
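# Sketch of typical usage (the id values below are assumed example data):
# wrapping a SparseTensor of ids into EnqueueData for generate_enqueue_ops().
def _example_enqueue_data():
  from tensorflow.python.framework import sparse_tensor
  sp_ids = sparse_tensor.SparseTensor(
      indices=[[0, 0], [1, 0], [1, 1]],
      values=[3, 1, 4],
      dense_shape=[2, 2])
  return EnqueueData.from_sparse_tensor(sp_ids)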

218 

219class RaggedEnqueueData( 

220 collections.namedtuple( 

221 'RaggedEnqueueData', 

222 ['embedding_indices', 'row_splits', 'aggregation_weights'])): 

223 """RaggedTensor Data to be enqueued through generate_enqueue_ops().""" 

224 

225 def __new__(cls, 

226 embedding_indices, 

227 row_splits=None, 

228 aggregation_weights=None): 

229 """Data to be enqueued through generate_enqueue_ops(). 

230 

231 Args: 

232 embedding_indices: A rank 1 Tensor, indices into the embedding tables. It 

233 corresponds to ids.values in embedding_lookup(), when ids is a 

234 RaggedTensor. Both int32 and int64 are allowed and will be converted to 

235 int32 internally. 

236 row_splits: A rank 1 Tensor specifying the break points for

237 splitting embedding_indices and aggregation_weights. It corresponds to 

238 ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both 

239 int32 and int64 are allowed and will be converted to int32 internally. 

240 aggregation_weights: A rank 1 Tensor containing per training example 

241 aggregation weights. It corresponds to the values field of a 

242 RaggedTensor with the same row_splits as ids in embedding_lookup(), when 

243 ids is a RaggedTensor. 

244 

245 Returns: 

246 A RaggedEnqueueData tuple.

247 

248 """ 

249 return super().__new__(cls, embedding_indices, row_splits, 

250 aggregation_weights) 

251 

252 @staticmethod 

253 def from_ragged_tensor(rg_tensor, weights=None): 

254 return RaggedEnqueueData( 

255 rg_tensor.values, 

256 rg_tensor.row_splits, 

257 aggregation_weights=weights.values if weights is not None else None) 

258 

259 

260def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list): 

261 """Convenient function for generate_enqueue_ops(). 

262 

263 Args: 

264 sp_tensors_list: a list of dictionaries mapping feature names (strings)

265 to SparseTensor. Each dictionary is for one TPU core. Dictionaries for the

266 same host should be contiguous in the list.

267 

268 Returns: 

269 enqueue_datas_list: a list of dictionaries mapping feature names

270 (strings) to EnqueueData. Each dictionary is for one

271 TPU core. Dictionaries for the same host should be contiguous

272 in the list.

273 

274 """ 

275 enqueue_datas_list = [] 

276 for sp_tensors in sp_tensors_list: 

277 enqueue_datas = collections.OrderedDict( 

278 (k, EnqueueData.from_sparse_tensor(v)) for k, v in sp_tensors.items()) 

279 enqueue_datas_list.append(enqueue_datas) 

280 return enqueue_datas_list 

281 
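# Usage sketch (feature name 'watched' and the per-core SparseTensors are
# assumed placeholders): building the per-core list of EnqueueData dicts that
# TPUEmbedding.generate_enqueue_ops() expects.
def _example_enqueue_datas_list(watched_core0, watched_core1):
  sp_tensors_list = [
      {'watched': watched_core0},  # SparseTensor of ids for TPU core 0.
      {'watched': watched_core1},  # SparseTensor of ids for TPU core 1.
  ]
  return get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list)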

282 

283def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list): 

284 """Convenient function for generate_enqueue_ops(). 

285 

286 Args: 

287 rg_tensors_list: a list of dictionaries mapping feature names (strings)

288 to RaggedTensor. Each dictionary is for one TPU core. Dictionaries for the

289 same host should be contiguous in the list.

290 

291 Returns: 

292 enqueue_datas_list: a list of dictionaries mapping feature names

293 (strings) to RaggedEnqueueData. Each dictionary is for one

294 TPU core. Dictionaries for the same host should be contiguous

295 in the list.

296 

297 """ 

298 enqueue_datas_list = [] 

299 for rg_tensors in rg_tensors_list: 

300 enqueue_datas = collections.OrderedDict( 

301 (k, RaggedEnqueueData.from_ragged_tensor(v)) 

302 for k, v in rg_tensors.items()) 

303 enqueue_datas_list.append(enqueue_datas) 

304 return enqueue_datas_list 

305 
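# Sketch of the RaggedTensor analogue of the helper above; the ids are assumed
# example data and 'watched' is an assumed feature name.
def _example_ragged_enqueue_datas_list():
  from tensorflow.python.ops.ragged import ragged_factory_ops
  ids = ragged_factory_ops.constant([[3, 1], [4]])  # One row per example.
  return get_enqueue_datas_list_from_ragged_tensors_list([{'watched': ids}])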

306 

307AdamSlotVariableNames = collections.namedtuple('AdamSlotVariableNames', 

308 ['m', 'v']) 

309 

310AdagradSlotVariableNames = collections.namedtuple('AdagradSlotVariableNames', 

311 ['accumulator']) 

312 

313MomentumSlotVariableNames = collections.namedtuple('MomentumSlotVariableNames', 

314 ['momenta']) 

315 

316AdagradMomentumSlotVariableNames = collections.namedtuple( 

317 'AdagradMomentumSlotVariableNames', ['accumulator', 'momenta']) 

318 

319RMSPropSlotVariableNames = collections.namedtuple('RMSPropSlotVariableNames', 

320 ['ms', 'mom']) 

321 

322ProximalAdagradSlotVariableNames = collections.namedtuple( 

323 'ProximalAdagradSlotVariableNames', ['accumulator']) 

324 

325FtrlSlotVariableNames = collections.namedtuple('FtrlSlotVariableNames', 

326 ['accumulator', 'linear']) 

327 

328ProximalYogiSlotVariableNames = collections.namedtuple( 

329 'ProximalYogiSlotVariableNames', ['v', 'm']) 

330 

331FrequencyEstimatorSlotVariableNames = collections.namedtuple( 

332 'FrequencyEstimatorSlotVariableNames', ['last_hit_step']) 

333 

334AdamSlotVariables = collections.namedtuple('AdamSlotVariables', ['m', 'v']) 

335 

336MomentumSlotVariables = collections.namedtuple('MomentumSlotVariables', 

337 ['momenta']) 

338 

339AdagradMomentumSlotVariables = collections.namedtuple( 

340 'AdagradMomentumSlotVariables', ['accumulator', 'momenta']) 

341 

342RMSPropSlotVariables = collections.namedtuple('RMSPropSlotVariables', 

343 ['ms', 'mom']) 

344 

345AdagradSlotVariables = collections.namedtuple('AdagradSlotVariables', 

346 ['accumulator']) 

347 

348ProximalAdagradSlotVariables = collections.namedtuple( 

349 'ProximalAdagradSlotVariables', ['accumulator']) 

350 

351FtrlSlotVariable = collections.namedtuple('FtrlSlotVariable', 

352 ['accumulator', 'linear']) 

353 

354ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables', 

355 ['v', 'm']) 

356 

357FrequencyEstimatorSlotVariables = collections.namedtuple( 

358 'FrequencyEstimatorSlotVariables', ['last_hit_step']) 

359 

360VariablesAndOps = collections.namedtuple('VariablesAndOps', [ 

361 'embedding_variables_by_table', 'slot_variables_by_table', 'load_ops', 

362 'retrieve_ops' 

363]) 

364 

365 

366class _OptimizationParameters: 

367 """Parameters common to all optimizations.""" 

368 

369 def __init__( 

370 self, 

371 learning_rate: float, 

372 use_gradient_accumulation: bool, 

373 clip_weight_min: Optional[float], 

374 clip_weight_max: Optional[float], 

375 weight_decay_factor: Optional[float], 

376 multiply_weight_decay_factor_by_learning_rate: Optional[bool], 

377 clip_gradient_min: Optional[float] = None, 

378 clip_gradient_max: Optional[float] = None, 

379 ): 

380 self.learning_rate = learning_rate 

381 self.use_gradient_accumulation = use_gradient_accumulation 

382 self.clip_weight_min = clip_weight_min 

383 self.clip_weight_max = clip_weight_max 

384 self.weight_decay_factor = weight_decay_factor 

385 self.multiply_weight_decay_factor_by_learning_rate = ( 

386 multiply_weight_decay_factor_by_learning_rate) 

387 self.clip_gradient_min = clip_gradient_min 

388 self.clip_gradient_max = clip_gradient_max 

389 

390 if not use_gradient_accumulation and (clip_gradient_min is not None or 

391 clip_gradient_max is not None): 

392 raise ValueError('When using gradient clipping limits, gradient ' 

393 'accumulation must be enabled.') 

394 

395 

396@tf_export(v1=['tpu.experimental.AdagradParameters']) 

397class AdagradParameters(_OptimizationParameters): 

398 """Optimization parameters for Adagrad with TPU embeddings. 

399 

400 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the 

401 `optimization_parameters` argument to set the optimizer and its parameters. 

402 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` 

403 for more details. 

404 

405 ``` 

406 estimator = tf.estimator.tpu.TPUEstimator( 

407 ... 

408 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(

409 ... 

410 optimization_parameters=tf.tpu.experimental.AdagradParameters(0.1), 

411 ...)) 

412 ``` 

413 

414 """ 

415 

416 def __init__( 

417 self, 

418 learning_rate: float, 

419 initial_accumulator: float = 0.1, 

420 use_gradient_accumulation: bool = True, 

421 clip_weight_min: Optional[float] = None, 

422 clip_weight_max: Optional[float] = None, 

423 weight_decay_factor: Optional[float] = None, 

424 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, 

425 clip_gradient_min: Optional[float] = None, 

426 clip_gradient_max: Optional[float] = None, 

427 ): 

428 """Optimization parameters for Adagrad. 

429 

430 Args: 

431 learning_rate: used for updating embedding table. 

432 initial_accumulator: initial accumulator for Adagrad. 

433 use_gradient_accumulation: setting this to `False` makes embedding 

434 gradients calculation less accurate but faster. Please see 

435 `optimization_parameters.proto` for details. 

436 clip_weight_min: the minimum value to clip by; None means -infinity. 

437 clip_weight_max: the maximum value to clip by; None means +infinity. 

438 weight_decay_factor: amount of weight decay to apply; None means that the 

439 weights are not decayed. 

440 multiply_weight_decay_factor_by_learning_rate: if true, 

441 `weight_decay_factor` is multiplied by the current learning rate. 

442 clip_gradient_min: the minimum value to clip by; None means -infinity. 

443 Gradient accumulation must be set to true if this is set. 

444 clip_gradient_max: the maximum value to clip by; None means +infinity. 

445 Gradient accumulation must be set to true if this is set. 

446 """ 

447 super().__init__( 

448 learning_rate=learning_rate, 

449 use_gradient_accumulation=use_gradient_accumulation, 

450 clip_weight_min=clip_weight_min, 

451 clip_weight_max=clip_weight_max, 

452 weight_decay_factor=weight_decay_factor, 

453 multiply_weight_decay_factor_by_learning_rate=( 

454 multiply_weight_decay_factor_by_learning_rate), 

455 clip_gradient_min=clip_gradient_min, 

456 clip_gradient_max=clip_gradient_max, 

457 ) 

458 if initial_accumulator <= 0: 

459 raise ValueError( 

460 f'Adagrad initial_accumulator must be greater than zero. ' 

461 f'Received: {initial_accumulator}.') 

462 self.initial_accumulator = initial_accumulator 

463 
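# Sketch illustrating the constraint checked in _OptimizationParameters:
# gradient clipping limits require use_gradient_accumulation=True. All values
# below are assumed and chosen for illustration only.
def _example_adagrad_with_gradient_clipping():
  return AdagradParameters(
      learning_rate=0.1,
      use_gradient_accumulation=True,
      clip_gradient_min=-10.0,
      clip_gradient_max=10.0)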

464 

465class AdagradMomentumParameters(_OptimizationParameters): 

466 """Optimization parameters for Adagrad + Momentum with TPU embeddings. 

467 

468 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the 

469 `optimization_parameters` argument to set the optimizer and its parameters. 

470 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` 

471 for more details. 

472 

473 ``` 

474 estimator = tf.estimator.tpu.TPUEstimator( 

475 ... 

476 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(

477 ... 

478 optimization_parameters=tf.tpu.experimental.AdagradMomentumParameters(0.1, momentum=0.9),

479 ...)) 

480 ``` 

481 

482 """ 

483 

484 def __init__( 

485 self, 

486 learning_rate: float, 

487 momentum: float, 

488 use_nesterov: bool = False, 

489 exponent: float = 2, 

490 beta2: float = 1, 

491 epsilon: float = 1e-10, 

492 use_gradient_accumulation: bool = True, 

493 clip_weight_min: Optional[float] = None, 

494 clip_weight_max: Optional[float] = None, 

495 weight_decay_factor: Optional[float] = None, 

496 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, 

497 clip_gradient_min: Optional[float] = None, 

498 clip_gradient_max: Optional[float] = None, 

499 ): 

500 """Optimization parameters for Adagrad. 

501 

502 Args: 

503 learning_rate: used for updating embedding table. 

504 momentum: Moving average parameter for the momentum accumulator. 

505 use_nesterov: Whether to use the Nesterov variant of momentum. See 

506 Sutskever et al., 2013. 

507 exponent: Exponent for the Adagrad accumulator. 

508 beta2: Moving average parameter for the Adagrad accumulator. 

509 epsilon: A small constant for numerical stability.

510 use_gradient_accumulation: setting this to `False` makes embedding 

511 gradients calculation less accurate but faster. Please see 

512 `optimization_parameters.proto` for details. 

513 clip_weight_min: the minimum value to clip by; None means -infinity. 

514 clip_weight_max: the maximum value to clip by; None means +infinity. 

515 weight_decay_factor: amount of weight decay to apply; None means that the 

516 weights are not decayed. 

517 multiply_weight_decay_factor_by_learning_rate: if true, 

518 `weight_decay_factor` is multiplied by the current learning rate. 

519 clip_gradient_min: the minimum value to clip by; None means -infinity. 

520 Gradient accumulation must be set to true if this is set. 

521 clip_gradient_max: the maximum value to clip by; None means +infinity. 

522 Gradient accumulation must be set to true if this is set. 

523 """ 

524 super().__init__( 

525 learning_rate=learning_rate, 

526 use_gradient_accumulation=use_gradient_accumulation, 

527 clip_weight_min=clip_weight_min, 

528 clip_weight_max=clip_weight_max, 

529 weight_decay_factor=weight_decay_factor, 

530 multiply_weight_decay_factor_by_learning_rate=( 

531 multiply_weight_decay_factor_by_learning_rate), 

532 clip_gradient_min=clip_gradient_min, 

533 clip_gradient_max=clip_gradient_max, 

534 ) 

535 if epsilon <= 0: 

536 raise ValueError('Adagrad momentum: epsilon must be positive') 

537 if exponent <= 0: 

538 raise ValueError('Adagrad momentum: precondition exponent must be > 0.')

539 self.momentum = momentum 

540 self.use_nesterov = use_nesterov 

541 self.exponent = exponent 

542 self.beta2 = beta2 

543 self.epsilon = epsilon 

544 

545 

546class ProximalAdagradParameters(_OptimizationParameters): 

547 """Optimization parameters for ProximalAdagrad with TPU embeddings. 

548 

549 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the 

550 `optimization_parameters` argument to set the optimizer and its parameters. 

551 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` 

552 for more details. 

553 """ 

554 

555 def __init__( 

556 self, 

557 learning_rate: float, 

558 initial_accumulator: float = 0.1, 

559 l1_regularization_strength: float = 0.0, 

560 l2_regularization_strength: float = 0.0, 

561 use_gradient_accumulation: bool = True, 

562 clip_weight_min: Optional[float] = None, 

563 clip_weight_max: Optional[float] = None, 

564 weight_decay_factor: Optional[float] = None, 

565 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, 

566 clip_gradient_min: Optional[float] = None, 

567 clip_gradient_max: Optional[float] = None, 

568 ): 

569 """Optimization parameters for Adagrad. 

570 

571 Args: 

572 learning_rate: used for updating embedding table. 

573 initial_accumulator: initial accumulator for Adagrad. 

574 l1_regularization_strength: A float value, must be greater than or equal 

575 to zero. 

576 l2_regularization_strength: A float value, must be greater than or equal 

577 to zero. 

578 use_gradient_accumulation: setting this to `False` makes embedding 

579 gradients calculation less accurate but faster. Please see 

580 `optimization_parameters.proto` for details.

581 clip_weight_min: the minimum value to clip by; None means -infinity. 

582 clip_weight_max: the maximum value to clip by; None means +infinity. 

583 weight_decay_factor: amount of weight decay to apply; None means that the 

584 weights are not decayed. 

585 multiply_weight_decay_factor_by_learning_rate: if true, 

586 `weight_decay_factor` is multiplied by the current learning rate. 

587 clip_gradient_min: the minimum value to clip by; None means -infinity. 

588 Gradient accumulation must be set to true if this is set. 

589 clip_gradient_max: the maximum value to clip by; None means +infinity. 

590 Gradient accumulation must be set to true if this is set. 

591 """ 

592 super().__init__( 

593 learning_rate=learning_rate, 

594 use_gradient_accumulation=use_gradient_accumulation, 

595 clip_weight_min=clip_weight_min, 

596 clip_weight_max=clip_weight_max, 

597 weight_decay_factor=weight_decay_factor, 

598 multiply_weight_decay_factor_by_learning_rate=( 

599 multiply_weight_decay_factor_by_learning_rate), 

600 clip_gradient_min=clip_gradient_min, 

601 clip_gradient_max=clip_gradient_max, 

602 ) 

603 if initial_accumulator <= 0: 

604 raise ValueError(f'Adagrad initial_accumulator must be positive. ' 

605 f'Received: {initial_accumulator}.') 

606 if l1_regularization_strength < 0.: 

607 raise ValueError('l1_regularization_strength must be greater than or ' 

608 'equal to 0. got {}.'.format(l1_regularization_strength)) 

609 

610 if l2_regularization_strength < 0.: 

611 raise ValueError('l2_regularization_strength must be greater than or ' 

612 'equal to 0. got {}.'.format(l2_regularization_strength)) 

613 

614 self.initial_accumulator = initial_accumulator 

615 self.l1_regularization_strength = l1_regularization_strength 

616 self.l2_regularization_strength = l2_regularization_strength 

617 

618 

619@tf_export(v1=['tpu.experimental.AdamParameters']) 

620class AdamParameters(_OptimizationParameters): 

621 """Optimization parameters for Adam with TPU embeddings. 

622 

623 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the 

624 `optimization_parameters` argument to set the optimizer and its parameters. 

625 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` 

626 for more details. 

627 

628 ``` 

629 estimator = tf.estimator.tpu.TPUEstimator( 

630 ... 

631 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec( 

632 ... 

633 optimization_parameters=tf.tpu.experimental.AdamParameters(0.1), 

634 ...)) 

635 ``` 

636 

637 """ 

638 

639 def __init__( 

640 self, 

641 learning_rate: float, 

642 beta1: float = 0.9, 

643 beta2: float = 0.999, 

644 epsilon: float = 1e-08, 

645 lazy_adam: bool = True, 

646 sum_inside_sqrt: bool = True, 

647 use_gradient_accumulation: bool = True, 

648 clip_weight_min: Optional[float] = None, 

649 clip_weight_max: Optional[float] = None, 

650 weight_decay_factor: Optional[float] = None, 

651 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, 

652 clip_gradient_min: Optional[float] = None, 

653 clip_gradient_max: Optional[float] = None, 

654 ): 

655 """Optimization parameters for Adam. 

656 

657 Args: 

658 learning_rate: a floating point value. The learning rate. 

659 beta1: A float value. The exponential decay rate for the 1st moment 

660 estimates. 

661 beta2: A float value. The exponential decay rate for the 2nd moment 

662 estimates. 

663 epsilon: A small constant for numerical stability. 

664 lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. See 

665 `optimization_parameters.proto` for details. 

666 sum_inside_sqrt: This improves training speed. Please see 

667 `optimization_parameters.proto` for details. 

668 use_gradient_accumulation: setting this to `False` makes embedding 

669 gradients calculation less accurate but faster. Please see 

670 `optimization_parameters.proto` for details. 

671 clip_weight_min: the minimum value to clip by; None means -infinity. 

672 clip_weight_max: the maximum value to clip by; None means +infinity. 

673 weight_decay_factor: amount of weight decay to apply; None means that the 

674 weights are not decayed. 

675 multiply_weight_decay_factor_by_learning_rate: if true, 

676 `weight_decay_factor` is multiplied by the current learning rate. 

677 clip_gradient_min: the minimum value to clip by; None means -infinity. 

678 Gradient accumulation must be set to true if this is set. 

679 clip_gradient_max: the maximum value to clip by; None means +infinity. 

680 Gradient accumulation must be set to true if this is set. 

681 """ 

682 super().__init__( 

683 learning_rate=learning_rate, 

684 use_gradient_accumulation=use_gradient_accumulation, 

685 clip_weight_min=clip_weight_min, 

686 clip_weight_max=clip_weight_max, 

687 weight_decay_factor=weight_decay_factor, 

688 multiply_weight_decay_factor_by_learning_rate=( 

689 multiply_weight_decay_factor_by_learning_rate), 

690 clip_gradient_min=clip_gradient_min, 

691 clip_gradient_max=clip_gradient_max, 

692 ) 

693 if beta1 < 0. or beta1 >= 1.: 

694 raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1)) 

695 if beta2 < 0. or beta2 >= 1.: 

696 raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2)) 

697 if epsilon <= 0.: 

698 raise ValueError('epsilon must be positive; got {}.'.format(epsilon)) 

699 if not use_gradient_accumulation and not lazy_adam: 

700 raise ValueError( 

701 'When disabling Lazy Adam, gradient accumulation must be used.') 

702 

703 self.beta1 = beta1 

704 self.beta2 = beta2 

705 self.epsilon = epsilon 

706 self.lazy_adam = lazy_adam 

707 self.sum_inside_sqrt = sum_inside_sqrt 

708 
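# Sketch illustrating the check above: disabling lazy Adam is only allowed
# when gradient accumulation stays enabled. The learning rate is an assumed
# value for illustration.
def _example_non_lazy_adam():
  return AdamParameters(
      learning_rate=1e-4,
      lazy_adam=False,
      use_gradient_accumulation=True)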

709 

710@tf_export(v1=['tpu.experimental.FtrlParameters']) 

711class FtrlParameters(_OptimizationParameters): 

712 """Optimization parameters for Ftrl with TPU embeddings. 

713 

714 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the 

715 `optimization_parameters` argument to set the optimizer and its parameters. 

716 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` 

717 for more details. 

718 

719 ``` 

720 estimator = tf.estimator.tpu.TPUEstimator( 

721 ... 

722 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec( 

723 ... 

724 optimization_parameters=tf.tpu.experimental.FtrlParameters(0.1), 

725 ...)) 

726 ``` 

727 

728 """ 

729 

730 def __init__( 

731 self, 

732 learning_rate: float, 

733 learning_rate_power: float = -0.5, 

734 initial_accumulator_value: float = 0.1, 

735 l1_regularization_strength: float = 0.0, 

736 l2_regularization_strength: float = 0.0, 

737 use_gradient_accumulation: bool = True, 

738 clip_weight_min: Optional[float] = None, 

739 clip_weight_max: Optional[float] = None, 

740 weight_decay_factor: Optional[float] = None, 

741 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, 

742 multiply_linear_by_learning_rate: bool = False, 

743 beta: float = 0, 

744 allow_zero_accumulator: bool = False, 

745 clip_gradient_min: Optional[float] = None, 

746 clip_gradient_max: Optional[float] = None, 

747 ): 

748 """Optimization parameters for Ftrl. 

749 

750 Implements FTRL as described in the following [paper]( 

751 https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf) 

752 

753 Args: 

754 learning_rate: a floating point value. The learning rate. 

755 learning_rate_power: A float value, must be less or equal to zero. 

756 Controls how the learning rate decreases during training. Use zero for a 

757 fixed learning rate. See section 3.1 in the 

758 [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf). 

759 initial_accumulator_value: The starting value for accumulators. Only zero 

760 or positive values are allowed. 

761 l1_regularization_strength: A float value, must be greater than or equal 

762 to zero. 

763 l2_regularization_strength: A float value, must be greater than or equal 

764 to zero. 

765 use_gradient_accumulation: setting this to `False` makes embedding 

766 gradients calculation less accurate but faster. Please see 

767 `optimization_parameters.proto` for details.

768 clip_weight_min: the minimum value to clip by; None means -infinity. 

769 clip_weight_max: the maximum value to clip by; None means +infinity. 

770 weight_decay_factor: amount of weight decay to apply; None means that the 

771 weights are not decayed. 

772 multiply_weight_decay_factor_by_learning_rate: if true, 

773 `weight_decay_factor` is multiplied by the current learning rate. 

774 multiply_linear_by_learning_rate: When true, multiplies the usages of the 

775 linear slot in the weight update by the learning rate. This is useful 

776 when ramping up learning rate from 0 (which would normally produce 

777 NaNs). 

778 beta: The beta parameter for FTRL. 

779 allow_zero_accumulator: Changes the implementation of the square root to 

780 allow for the case of initial_accumulator_value being zero. This will 

781 cause a slight performance drop. 

782 clip_gradient_min: the minimum value to clip by; None means -infinity. 

783 Gradient accumulation must be set to true if this is set. 

784 clip_gradient_max: the maximum value to clip by; None means +infinity. 

785 Gradient accumulation must be set to true if this is set. 

786 """ 

787 super().__init__( 

788 learning_rate=learning_rate, 

789 use_gradient_accumulation=use_gradient_accumulation, 

790 clip_weight_min=clip_weight_min, 

791 clip_weight_max=clip_weight_max, 

792 weight_decay_factor=weight_decay_factor, 

793 multiply_weight_decay_factor_by_learning_rate=( 

794 multiply_weight_decay_factor_by_learning_rate), 

795 clip_gradient_min=clip_gradient_min, 

796 clip_gradient_max=clip_gradient_max, 

797 ) 

798 if learning_rate_power > 0.: 

799 raise ValueError('learning_rate_power must be less than or equal to 0. ' 

800 'got {}.'.format(learning_rate_power)) 

801 

802 if initial_accumulator_value < 0.: 

803 raise ValueError('initial_accumulator_value must be greater than or equal' 

804 ' to 0. got {}.'.format(initial_accumulator_value)) 

805 

806 if l1_regularization_strength < 0.: 

807 raise ValueError('l1_regularization_strength must be greater than or ' 

808 'equal to 0. got {}.'.format(l1_regularization_strength)) 

809 

810 if l2_regularization_strength < 0.: 

811 raise ValueError('l2_regularization_strength must be greater than or ' 

812 'equal to 0. got {}.'.format(l2_regularization_strength)) 

813 

814 self.learning_rate_power = learning_rate_power 

815 self.initial_accumulator_value = initial_accumulator_value 

816 self.initial_linear_value = 0.0 

817 self.l1_regularization_strength = l1_regularization_strength 

818 self.l2_regularization_strength = l2_regularization_strength 

819 self.multiply_linear_by_learning_rate = multiply_linear_by_learning_rate 

820 self.beta = beta 

821 self.allow_zero_accumulator = allow_zero_accumulator 

822 

823 

824class ProximalYogiParameters(_OptimizationParameters): 

825 # pylint: disable=line-too-long 

826 """Optimization parameters for Proximal Yogi with TPU embeddings. 

827 

828 Implements the Yogi optimizer as described in 

829 [Adaptive Methods for Nonconvex 

830 Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization). 

831 

832 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the 

833 `optimization_parameters` argument to set the optimizer and its parameters. 

834 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` 

835 for more details. 

836 """ 

837 

838 # pylint: enable=line-too-long 

839 

840 def __init__( 

841 self, 

842 learning_rate: float = 0.01, 

843 beta1: float = 0.9, 

844 beta2: float = 0.999, 

845 epsilon: float = 1e-3, 

846 l1_regularization_strength: float = 0.0, 

847 l2_regularization_strength: float = 0.0, 

848 initial_accumulator_value: float = 1e-6, 

849 use_gradient_accumulation: bool = True, 

850 clip_weight_min: Optional[float] = None, 

851 clip_weight_max: Optional[float] = None, 

852 weight_decay_factor: Optional[float] = None, 

853 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, 

854 clip_gradient_min: Optional[float] = None, 

855 clip_gradient_max: Optional[float] = None, 

856 ): 

857 """Optimization parameters for Proximal Yogi. 

858 

859 Args: 

860 learning_rate: a floating point value. The learning rate. 

861 beta1: A float value. The exponential decay rate for the 1st moment 

862 estimates. 

863 beta2: A float value. The exponential decay rate for the 2nd moment 

864 estimates. 

865 epsilon: A small constant for numerical stability. 

866 l1_regularization_strength: A float value, must be greater than or equal 

867 to zero. 

868 l2_regularization_strength: A float value, must be greater than or equal 

869 to zero. 

870 initial_accumulator_value: The starting value for accumulators. Only zero 

871 or positive values are allowed. 

872 use_gradient_accumulation: setting this to `False` makes embedding 

873 gradients calculation less accurate but faster. Please see 

874 `optimization_parameters.proto` for details.

875 clip_weight_min: the minimum value to clip by; None means -infinity. 

876 clip_weight_max: the maximum value to clip by; None means +infinity. 

877 weight_decay_factor: amount of weight decay to apply; None means that the 

878 weights are not decayed. 

879 multiply_weight_decay_factor_by_learning_rate: if true, 

880 `weight_decay_factor` is multiplied by the current learning rate. 

881 clip_gradient_min: the minimum value to clip by; None means -infinity. 

882 Gradient accumulation must be set to true if this is set. 

883 clip_gradient_max: the maximum value to clip by; None means +infinity. 

884 Gradient accumulation must be set to true if this is set. 

885 """ 

886 super().__init__( 

887 learning_rate=learning_rate, 

888 use_gradient_accumulation=use_gradient_accumulation, 

889 clip_weight_min=clip_weight_min, 

890 clip_weight_max=clip_weight_max, 

891 weight_decay_factor=weight_decay_factor, 

892 multiply_weight_decay_factor_by_learning_rate=( 

893 multiply_weight_decay_factor_by_learning_rate), 

894 clip_gradient_min=clip_gradient_min, 

895 clip_gradient_max=clip_gradient_max, 

896 ) 

897 if beta1 < 0. or beta1 >= 1.: 

898 raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1)) 

899 if beta2 < 0. or beta2 >= 1.: 

900 raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2)) 

901 if epsilon <= 0.: 

902 raise ValueError('epsilon must be positive; got {}.'.format(epsilon)) 

903 if l1_regularization_strength < 0.: 

904 raise ValueError('l1_regularization_strength must be greater than or ' 

905 'equal to 0. got {}.'.format(l1_regularization_strength)) 

906 if l2_regularization_strength < 0.: 

907 raise ValueError('l2_regularization_strength must be greater than or ' 

908 'equal to 0. got {}.'.format(l2_regularization_strength)) 

909 

910 self.beta1 = beta1 

911 self.beta2 = beta2 

912 self.epsilon = epsilon 

913 self.l1_regularization_strength = l1_regularization_strength 

914 self.l2_regularization_strength = l2_regularization_strength 

915 self.initial_accumulator_value = initial_accumulator_value 

916 

917 

918class MomentumParameters(_OptimizationParameters): 

919 """Optimization parameters for Momentum with TPU embeddings. 

920 

921 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the 

922 `optimization_parameters` argument to set the optimizer and its parameters. 

923 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` 

924 for more details. 

925 

926 ``` 

927 estimator = tf.estimator.tpu.TPUEstimator( 

928 ... 

929 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(

930 ... 

931 optimization_parameters=tf.tpu.experimental.MomentumParameters(0.1, momentum=0.9),

932 ...)) 

933 ``` 

934 

935 """ 

936 

937 def __init__( 

938 self, 

939 learning_rate: float, 

940 momentum: float, 

941 use_nesterov: bool = False, 

942 use_gradient_accumulation: bool = True, 

943 clip_weight_min: Optional[float] = None, 

944 clip_weight_max: Optional[float] = None, 

945 weight_decay_factor: Optional[float] = None, 

946 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, 

947 clip_gradient_min: Optional[float] = None, 

948 clip_gradient_max: Optional[float] = None, 

949 ): 

950 """Optimization parameters for momentum. 

951 

952 Args: 

953 learning_rate: a floating point value. The learning rate. 

954 momentum: a floating point value. The momentum. 

955 use_nesterov: If `True` use Nesterov Momentum. See (Sutskever et al., 

956 2013). This implementation always computes gradients at the value of the 

957 variable(s) passed to the optimizer. Using Nesterov Momentum makes the 

958 variable(s) track the values called `theta_t + mu*v_t` in the paper. 

959 This implementation is an approximation of the original formula, valid 

960 for high values of momentum. It will compute the "adjusted gradient" in 

961 NAG by assuming that the new gradient will be estimated by the current 

962 average gradient plus the product of momentum and the change in the 

963 average gradient. 

964 use_gradient_accumulation: setting this to `False` makes embedding 

965 gradients calculation less accurate but faster. Please see 

966 `optimization_parameters.proto` for details. 

967 clip_weight_min: the minimum value to clip by; None means -infinity. 

968 clip_weight_max: the maximum value to clip by; None means +infinity. 

969 weight_decay_factor: amount of weight decay to apply; None means that the 

970 weights are not decayed. 

971 multiply_weight_decay_factor_by_learning_rate: if true, 

972 `weight_decay_factor` is multiplied by the current learning rate. 

973 clip_gradient_min: the minimum value to clip by; None means -infinity. 

974 Gradient accumulation must be set to true if this is set. 

975 clip_gradient_max: the maximum value to clip by; None means +infinity. 

976 Gradient accumulation must be set to true if this is set. 

977 """ 

978 super().__init__( 

979 learning_rate=learning_rate, 

980 use_gradient_accumulation=use_gradient_accumulation, 

981 clip_weight_min=clip_weight_min, 

982 clip_weight_max=clip_weight_max, 

983 weight_decay_factor=weight_decay_factor, 

984 multiply_weight_decay_factor_by_learning_rate=( 

985 multiply_weight_decay_factor_by_learning_rate), 

986 clip_gradient_min=clip_gradient_min, 

987 clip_gradient_max=clip_gradient_max, 

988 ) 

989 self.momentum = momentum 

990 self.use_nesterov = use_nesterov 

991 

992 

993class RMSPropParameters(_OptimizationParameters): 

994 """Optimization parameters for RMSProp with TPU embeddings. 

995 

996 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the 

997 `optimization_parameters` argument to set the optimizer and its parameters. 

998 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` 

999 for more details. 

1000 

1001 ``` 

1002 estimator = tf.estimator.tpu.TPUEstimator( 

1003 ... 

1004 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(

1005 ... 

1006 optimization_parameters=tf.tpu.experimental.RMSPropParameters(0.1, rho=0.9, momentum=0.0, epsilon=1e-10),

1007 ...)) 

1008 ``` 

1009 

1010 """ 

1011 

1012 def __init__( 

1013 self, 

1014 learning_rate: float, 

1015 rho: float, 

1016 momentum: float, 

1017 epsilon: float, 

1018 use_gradient_accumulation: bool = True, 

1019 clip_weight_min: Optional[float] = None, 

1020 clip_weight_max: Optional[float] = None, 

1021 weight_decay_factor: Optional[float] = None, 

1022 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, 

1023 clip_gradient_min: Optional[float] = None, 

1024 clip_gradient_max: Optional[float] = None, 

1025 ): 

1026 """Optimization parameters for RMS prop. 

1027 

1028 Args: 

1029 learning_rate: a floating point value. The learning rate. 

1030 rho: Discounting factor for the history/coming gradient.

1031 momentum: A scalar tensor. 

1032 epsilon: Small value to avoid zero denominator. 

1033 use_gradient_accumulation: setting this to `False` makes embedding 

1034 gradients calculation less accurate but faster. Please see 

1035 `optimization_parameters.proto` for details.

1036 clip_weight_min: the minimum value to clip by; None means -infinity. 

1037 clip_weight_max: the maximum value to clip by; None means +infinity. 

1038 weight_decay_factor: amount of weight decay to apply; None means that the 

1039 weights are not decayed. 

1040 multiply_weight_decay_factor_by_learning_rate: if true, 

1041 `weight_decay_factor` is multiplied by the current learning rate. 

1042 clip_gradient_min: the minimum value to clip by; None means -infinity. 

1043 Gradient accumulation must be set to true if this is set. 

1044 clip_gradient_max: the maximum value to clip by; None means +infinity. 

1045 Gradient accumulation must be set to true if this is set. 

1046 """ 

1047 super().__init__( 

1048 learning_rate=learning_rate, 

1049 use_gradient_accumulation=use_gradient_accumulation, 

1050 clip_weight_min=clip_weight_min, 

1051 clip_weight_max=clip_weight_max, 

1052 weight_decay_factor=weight_decay_factor, 

1053 multiply_weight_decay_factor_by_learning_rate=( 

1054 multiply_weight_decay_factor_by_learning_rate), 

1055 clip_gradient_min=clip_gradient_min, 

1056 clip_gradient_max=clip_gradient_max, 

1057 ) 

1058 self.rho = rho 

1059 self.momentum = momentum 

1060 self.epsilon = epsilon 

1061 

1062 

1063@tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters']) 

1064class StochasticGradientDescentParameters(_OptimizationParameters): 

1065 """Optimization parameters for stochastic gradient descent for TPU embeddings. 

1066 

1067 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the 

1068 `optimization_parameters` argument to set the optimizer and its parameters. 

1069 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` 

1070 for more details. 

1071 

1072 ``` 

1073 estimator = tf.estimator.tpu.TPUEstimator( 

1074 ... 

1075 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec( 

1076 ... 

1077 optimization_parameters=( 

1078 tf.tpu.experimental.StochasticGradientDescentParameters(0.1)))) 

1079 ``` 

1080 

1081 """ 

1082 

1083 def __init__( 

1084 self, 

1085 learning_rate: float, 

1086 use_gradient_accumulation: bool = True, 

1087 clip_weight_min: Optional[float] = None, 

1088 clip_weight_max: Optional[float] = None, 

1089 weight_decay_factor: Optional[float] = None, 

1090 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, 

1091 clip_gradient_min: Optional[float] = None, 

1092 clip_gradient_max: Optional[float] = None, 

1093 ): 

1094 """Optimization parameters for stochastic gradient descent. 

1095 

1096 Args: 

1097 learning_rate: a floating point value. The learning rate. 

1098 use_gradient_accumulation: setting this to `False` makes embedding 

1099 gradients calculation less accurate but faster. Please see 

1100 `optimization_parameters.proto` for details. 

1101 clip_weight_min: the minimum value to clip by; None means -infinity. 

1102 clip_weight_max: the maximum value to clip by; None means +infinity. 

1103 weight_decay_factor: amount of weight decay to apply; None means that the 

1104 weights are not decayed. 

1105 multiply_weight_decay_factor_by_learning_rate: if true, 

1106 `weight_decay_factor` is multiplied by the current learning rate. 

1107 clip_gradient_min: the minimum value to clip by; None means -infinity. 

1108 clip_gradient_max: the maximum value to clip by; None means +infinity. 

1109 """ 

1110 super().__init__( 

1111 learning_rate=learning_rate, 

1112 use_gradient_accumulation=use_gradient_accumulation, 

1113 clip_weight_min=clip_weight_min, 

1114 clip_weight_max=clip_weight_max, 

1115 weight_decay_factor=weight_decay_factor, 

1116 multiply_weight_decay_factor_by_learning_rate=( 

1117 multiply_weight_decay_factor_by_learning_rate), 

1118 clip_gradient_min=clip_gradient_min, 

1119 clip_gradient_max=clip_gradient_max, 

1120 ) 

1121 

1122 

1123class FrequencyEstimatorParameters(_OptimizationParameters): 

1124 """Optimization parameters for Frequency Estimator TPU embeddings. 

1125 

1126 This is a non-standard optimizer, which returns the estimated frequency of 

1127 lookup for the feature passed to it. It should only be used on a table of 

1128 width 1. The gradient fed back to the TPU embedding should always be zero. 

1129 This can be accomplished by using `tf.stop_gradient` on the feature before

1130 using it. 

1131 

1132 You must use the dynamic learning rate mechanism to set the 'learning rate' 

1133 for this table to be a float32 cast of the global training step counter.

1134 

1135 See `tensorflow/core/protobuf/tpu/optimization_parameters.proto` for more 

1136 details on this optimizer. 

1137 

1138 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the 

1139 `optimization_parameters` argument to set the optimizer and its parameters. 

1140 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` 

1141 for more details. 

1142 

1143 ``` 

1144 estimator = tf.estimator.tpu.TPUEstimator( 

1145 ... 

1146 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(

1147 ... 

1148 optimization_parameters=FrequencyEstimatorParameters(tau=0.1, max_delta=..., outlier_threshold=..., weight_exponent=...),

1149 ...)) 

1150 ``` 

1151 

1152 """ 

1153 

1154 def __init__(self, tau: float, max_delta: float, outlier_threshold: float, 

1155 weight_exponent: float): 

1156 """Optimization parameters for frequency estimator. 

1157 

1158 Args: 

1159 tau: Learning rate between (0, 1) that is used to update the array. 

1160 max_delta: Maximum value of delta, the difference between the current 

1161 global step and the last global step at which the row was sampled. 

1162 outlier_threshold: Threshold used to determine whether the current update 

1163 is an outlier. 

1164 weight_exponent: The weight exponent used to transform the estimated delta 

1165 into weights. 

1166 """ 

1167 super().__init__( 

1168 learning_rate=1.0, 

1169 use_gradient_accumulation=True, 

1170 clip_weight_min=None, 

1171 clip_weight_max=None, 

1172 weight_decay_factor=None, 

1173 multiply_weight_decay_factor_by_learning_rate=None, 

1174 ) 

1175 self.tau = tau 

1176 self.max_delta = max_delta 

1177 self.outlier_threshold = outlier_threshold 

1178 self.weight_exponent = weight_exponent 

1179 

1180 

1181DeviceConfig = collections.namedtuple('DeviceConfig', 

1182 ['num_hosts', 'num_cores', 'job_name']) 

1183 

1184 

1185class TPUEmbedding: 

1186 """API for using TPU for embedding. 

1187 

1188 Example: 

1189 ``` 

1190 table_config = tpu_embedding.TableConfig(

1191 vocabulary_size=4, dimension=2,

1192 initializer=initializer, combiner='mean')

1193 table_to_config_dict = {'video': table_config,

1194 'user': table_config}

1195 feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video'), 

1196 'favorited': tpu_embedding.FeatureConfig('video'), 

1197 'friends': tpu_embedding.FeatureConfig('user')} 

1198 batch_size = 4 

1199 num_hosts = 1 

1200 optimization_parameters = tpu_embedding.AdagradParameters(1., 1.) 

1201 mode = tpu_embedding.TRAINING 

1202 embedding = tpu_embedding.TPUEmbedding( 

1203 table_to_config_dict, feature_to_config_dict, 

1204 batch_size, num_hosts, mode, optimization_parameters) 

1205 

1206 batch_size_per_core = embedding.batch_size_per_core 

1207 sparse_features_list = [] 

1208 for host in hosts: 

1209 with ops.device(host): 

1210 for _ in range(embedding.num_cores_per_host): 

1211 sparse_features = {} 

1212 sparse_features['watched'] = sparse_tensor.SparseTensor(...) 

1213 sparse_features['favorited'] = sparse_tensor.SparseTensor(...) 

1214 sparse_features['friends'] = sparse_tensor.SparseTensor(...) 

1215 sparse_features_list.append(sparse_features) 

1216 

1217 enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list) 

1218 embedding_variables_and_ops = embedding.create_variables_and_ops() 

1219 

1220 def computation(): 

1221 activations = embedding.get_activations() 

1222 loss = compute_loss(activations) 

1223 

1224 base_optimizer = gradient_descent.GradientDescentOptimizer( 

1225 learning_rate=1) 

1226 cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer( 

1227 base_optimizer) 

1228 

1229 train_op = cross_shard_optimizer.minimize(loss) 

1230 gradients = ( 

1231 tpu_embedding_gradient.get_gradients_through_compute_gradients( 

1232 cross_shard_optimizer, loss, activations))

1233 send_gradients_op = embedding.generate_send_gradients_op(gradients) 

1234 with ops.control_dependencies([train_op, send_gradients_op]): 

1235 return array_ops.identity(loss)

1236 

1237 loss = tpu.shard(computation, 

1238 num_shards=embedding.num_cores) 

1239 

1240 with self.test_session() as sess: 

1241 sess.run(tpu.initialize_system(embedding_config= 

1242 embedding.config_proto)) 

1243 sess.run(variables.global_variables_initializer()) 

1244 sess.run(embedding_variables_and_ops.load_ops()) 

1245 sess.run(enqueue_ops) 

1246 loss_val = sess.run(loss) 

1247 ``` 

1248 

1249 Example with weight decay: 

1250 

1251 >>> def learning_rate_fn(global_step): 

1252 ... return tf.compat.v1.train.polynomial_decay( 

1253 ... learning_rate=5e-5, 

1254 ... global_step=global_step, 

1255 ... decay_steps=100000, 

1256 ... end_learning_rate=0.0) 

1257 >>> wordpiece_table_config = TableConfig( 

1258 ... vocabulary_size=119547, 

1259 ... dimension=256, 

1260 ... learning_rate_fn=learning_rate_fn) 

1261 >>> wordpiece_feature_config = FeatureConfig( 

1262 ... table_id='bert/embeddings/word_embeddings', 

1263 ... max_sequence_length=512) 

1264 >>> optimization_parameters = AdamParameters( 

1265 ... learning_rate=5e-5, 

1266 ... epsilon=1e-6, 

1267 ... weight_decay_factor=0.01, 

1268 ... multiply_weight_decay_factor_by_learning_rate=True) 

1269 >>> tpu_embedding = TPUEmbedding( 

1270 ... table_to_config_dict={ 

1271 ... 'bert/embeddings/word_embeddings': wordpiece_table_config, 

1272 ... }, 

1273 ... feature_to_config_dict={'input_ids': wordpiece_feature_config}, 

1274 ... batch_size=128, 

1275 ... mode=TRAINING, 

1276 ... optimization_parameters=optimization_parameters, 

1277 ... master='') 

1278 >>> with tf.Graph().as_default(): 

1279 ... init_tpu_op = tf.compat.v1.tpu.initialize_system( 

1280 ... embedding_config=tpu_embedding.config_proto) 

1281 ... tf.compat.v1.Session().run(init_tpu_op) 

1282 """ 

1283 

1284 # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that 

1285 # the feature should not be used to update embedding table (cr/204852758, 

1286 # cr/204940540). Also, this can support different combiners for different 

1287 # features within the same table. 

1288 # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it 

1289 # to `FeatureConfig`? 

1290 

1291 # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and 

1292 # `feature_to_config_dict` lists of `TableSpec` and `FeatureSpec` 

1293 # respectively? 

1294 

1295 # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate 

1296 # for-loops around construction of inputs. 

1297 

1298 # `optimization_parameter` applies to all tables. If the need arises, 

1299 # we can add `optimization_parameters` to `TableConfig` to override this 

1300 # global setting. 

1301 def __init__(self, 

1302 table_to_config_dict, 

1303 feature_to_config_dict, 

1304 batch_size, 

1305 mode, 

1306 master=None, 

1307 optimization_parameters=None, 

1308 cluster_def=None, 

1309 pipeline_execution_with_tensor_core=False, 

1310 partition_strategy='div', 

1311 profile_data_directory=None, 

1312 device_config=None, 

1313 master_job_name=None): 

1314 """API for using TPU for embedding lookups. 

1315 

1316 Args: 

1317 table_to_config_dict: A dictionary mapping from string of table name to 

1318 `TableConfig`. Table refers to an embedding table, e.g. `params` 

1319 argument to `tf.nn.embedding_lookup_sparse()`. 

1320 feature_to_config_dict: A dictionary mapping from string of feature name 

1321 to `FeatureConfig`. Feature refers to ids to lookup in embedding table, 

1322 e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`. 

1323 batch_size: An `int` representing the global batch size. 

1324 mode: `TRAINING` or `INFERENCE`. 

1325 master: A `string` representing the TensorFlow master to use. 

1326 optimization_parameters: `AdagradParameters`, `AdamParameters`, 

1327 `Stochasticgradientdescentparameters`. Must be set in training unless 

1328 all tables specify their own optimizers. And it must be `None` in 

1329 inference. 

1330 cluster_def: A ClusterDef object describing the TPU cluster. 

1331 pipeline_execution_with_tensor_core: setting this to `True` makes training 

1332 faster, but the trained model may differ if step N and step N+1

1333 involve the same set of embedding IDs. Please see 

1334 `tpu_embedding_configuration.proto` for details. 

1335 partition_strategy: A string, either 'mod' or 'div', specifying how to map 

1336 the lookup id to the embedding tensor. For more information see 

1337 `tf.nn.embedding_lookup_sparse`. 

1338 profile_data_directory: Directory where embedding lookup statistics are 

1339 stored. These statistics summarize information about the inputs to the 

1340 embedding lookup operation, in particular, the average number of 

1341 embedding IDs per example and how well the embedding IDs are load 

1342 balanced across the system. The lookup statistics are used during TPU 

1343 initialization for embedding table partitioning. Collection of lookup 

1344 statistics is done at runtime by profiling the embedding inputs; only a

1345 small fraction of input samples are profiled to minimize host CPU 

1346 overhead. Once a suitable number of samples are profiled, the lookup 

1347 statistics are saved to table-specific files in the profile data 

1348 directory generally at the end of a TPU training loop. The filename 

1349 corresponding to each table is obtained by hashing table specific 

1350 parameters (e.g., table name and number of features) and global 

1351 configuration parameters (e.g., sharding strategy and task count). The 

1352 same profile data directory can be shared among several models to reuse 

1353 embedding lookup statistics. 

1354 device_config: A DeviceConfig instance, used when `master` and 

1355 `cluster_def` are both `None`. 

1356 master_job_name: if set, overrides the master job name used to schedule 

1357 embedding ops. 

1358 

1359 Raises: 

1360 ValueError: if any input is invalid. 

1361 """ 

1362 if partition_strategy not in ('div', 'mod'): 

1363 raise ValueError(f'partition_strategy must be "div" or "mod". ' 

1364 f'Received: {partition_strategy}.') 

1365 self._partition_strategy = partition_strategy 

1366 

1367 self._profile_data_directory = profile_data_directory 

1368 

1369 _validate_table_to_config_dict(table_to_config_dict) 

1370 # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`. 

1371 self._table_to_config_dict = _create_ordered_dict(table_to_config_dict) 

1372 

1373 _validate_feature_to_config_dict(table_to_config_dict, 

1374 feature_to_config_dict) 

1375 self._feature_to_config_dict = _create_ordered_dict(feature_to_config_dict) 

1376 self._table_to_features_dict = ( 

1377 _create_table_to_features_dict(self._feature_to_config_dict)) 

1378 self._combiners = _create_combiners(self._table_to_config_dict, 

1379 self._table_to_features_dict) 

1380 

1381 self._batch_size = batch_size 

1382 

1383 if master is None and cluster_def is None: 

1384 if device_config is None: 

1385 raise ValueError('When master and cluster_def are both None, '

1386 'device_config must be set but is not.')

1387 if device_config.num_cores % device_config.num_hosts: 

1388 raise ValueError('num_hosts ({}) should divide num_cores ({}) ' 

1389 'but does not.'.format(device_config.num_hosts,

1390 device_config.num_cores))

1391 self._num_hosts = device_config.num_hosts 

1392 self._num_cores = device_config.num_cores 

1393 self._num_cores_per_host = self._num_cores // self._num_hosts 

1394 self._hosts = [ 

1395 '{}/replica:0/task:{}/device:CPU:0'.format(device_config.job_name, i) 

1396 for i in range(self._num_hosts) 

1397 ] 

1398 else: 

1399 tpu_system_metadata = ( 

1400 tpu_system_metadata_lib._query_tpu_system_metadata( # pylint: disable=protected-access 

1401 master, 

1402 cluster_def=cluster_def)) 

1403 if tpu_system_metadata.num_cores == 0: 

1404 raise ValueError('TPUEmbedding needs TPUs, but master {} does not have ' 

1405 'TPUs.'.format(master)) 

1406 self._num_hosts = tpu_system_metadata.num_hosts 

1407 if master_job_name is None: 

1408 try: 

1409 master_job_name = tpu_system_metadata_lib.master_job( 

1410 master, cluster_def) 

1411 except ValueError as e: 

1412 raise ValueError(str(e) + ' Please specify a master_job_name.') 

1413 self._hosts = [] 

1414 for device in tpu_system_metadata.devices: 

1415 if 'device:CPU:' in device.name and (master_job_name is None or 

1416 master_job_name in device.name): 

1417 self._hosts.append(device.name) 

1418 self._num_cores_per_host = tpu_system_metadata.num_of_cores_per_host 

1419 self._num_cores = tpu_system_metadata.num_cores 

1420 

1421 _validate_batch_size(self._batch_size, self._num_cores) 

1422 self._batch_size_per_core = self._batch_size // self._num_cores 

1423 

1424 # TODO(shizhiw): remove `mode`? 

1425 if mode == TRAINING: 

1426 _validate_optimization_parameters(optimization_parameters, 

1427 self._table_to_config_dict) 

1428 self._optimization_parameters = optimization_parameters 

1429 elif mode == INFERENCE: 

1430 if optimization_parameters is not None: 

1431 raise ValueError(f'`optimization_parameters` should be `None` ' 

1432 f'for inference mode. ' 

1433 f'Received: {optimization_parameters}.') 

1434 self._optimization_parameters = (StochasticGradientDescentParameters(1.)) 

1435 else: 

1436 raise ValueError('`mode` only supports {} and {}; got {}.'.format( 

1437 TRAINING, INFERENCE, mode)) 

1438 self._mode = mode 

1439 

1440 # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler` 

1441 # and create special handler for inference that inherits from 

1442 # StochasticGradientDescentHandler with more user-friendly error message 

1443 # on get_slot(). 

1444 self._optimizer_handler_dict = self._get_optimizer_handler_by_table() 

1445 

1446 self._pipeline_execution_with_tensor_core = ( 

1447 pipeline_execution_with_tensor_core) 

1448 self._learning_rate_fn = list( 

1449 set(c.learning_rate_fn 

1450 for c in self._table_to_config_dict.values() 

1451 if c.learning_rate_fn is not None)) 

1452 self._learning_rate_fn_to_tag = { 

1453 fn: tag for tag, fn in enumerate(self._learning_rate_fn)

1454 } 

1455 

1456 self._config_proto = self._create_config_proto() 

1457 

1458 @property 

1459 def hosts(self): 

1460 """A list of device names for CPU hosts. 

1461 

1462 Returns: 

1463 A list of device names for CPU hosts. 

1464 """ 

1465 return copy.copy(self._hosts) 

1466 

1467 # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and 

1468 # to be consistent with `tpu_embedding_configuration.proto`. 

1469 @property 

1470 def num_cores_per_host(self): 

1471 """Number of TPU cores on a CPU host. 

1472 

1473 Returns: 

1474 Number of TPU cores on a CPU host. 

1475 """ 

1476 return self._num_cores_per_host 

1477 

1478 @property 

1479 def num_cores(self): 

1480 """Total number of TPU cores on all hosts. 

1481 

1482 Returns: 

1483 Total number of TPU cores on all hosts. 

1484 """ 

1485 return self._num_cores 

1486 

1487 @property 

1488 def batch_size_per_core(self): 

1489 """Batch size for each TPU core. 

1490 

1491 The sparse tensors in `sparse_features_list` to `generate_enqueue_ops` 

1492 must have batch dimension equal to this. 

1493 

1494 Returns: 

1495 Batch size for each TPU core. 

1496 """ 

1497 return self._batch_size_per_core 
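# Editor's note (illustrative, hypothetical numbers): with a global
# `batch_size` of 128 on a system with 8 TPU cores, `batch_size_per_core`
# is 128 // 8 = 16, so every per-core enqueue must carry 16 examples.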

1498 

1499 @property 

1500 def config_proto(self): 

1501 """Create embedding config proto for `tpu.initialize_system()`. 

1502 

1503 Returns: 

1504 a `TPUEmbeddingConfiguration` proto describing the desired

1505 configuration of the hardware embedding lookup tables, which 

1506 is passed to `tpu.initialize_system()`. 

1507 """ 

1508 return self._config_proto 

1509 

1510 @property 

1511 def table_to_config_dict(self): 

1512 return copy.copy(self._table_to_config_dict) 

1513 

1514 @property 

1515 def feature_to_config_dict(self): 

1516 return copy.copy(self._feature_to_config_dict) 

1517 

1518 @property 

1519 def table_to_features_dict(self): 

1520 return copy.copy(self._table_to_features_dict) 

1521 

1522 @property 

1523 def optimization_parameters(self): 

1524 return self._optimization_parameters 

1525 

1526 def _create_config_proto(self): 

1527 """Create `TPUEmbeddingConfiguration`.""" 

1528 config_proto = elc.TPUEmbeddingConfiguration() 

1529 for table in self._table_to_config_dict: 

1530 table_descriptor = config_proto.table_descriptor.add() 

1531 table_descriptor.name = table 

1532 

1533 table_config = self._table_to_config_dict[table] 

1534 # For small tables, we pad to the number of hosts so that at least one 

1535 # id will be assigned to each host. 

1536 table_descriptor.vocabulary_size = max(table_config.vocabulary_size, 

1537 len(self.hosts)) 

1538 table_descriptor.dimension = table_config.dimension 

1539 

1540 optimization_parameters = ( 

1541 self._optimizer_handler_dict[table].get_optimization_parameters()) 

1542 

1543 parameters = table_descriptor.optimization_parameters 

1544 if table_config.learning_rate: 

1545 parameters.learning_rate.constant = table_config.learning_rate 

1546 elif table_config.learning_rate_fn: 

1547 parameters.learning_rate.dynamic.tag = ( 

1548 self._learning_rate_fn_to_tag[table_config.learning_rate_fn]) 

1549 else: 

1550 parameters.learning_rate.constant = ( 

1551 optimization_parameters.learning_rate) 

1552 parameters.gradient_accumulation_status = ( 

1553 optimization_parameters_pb2.GradientAccumulationStatus.ENABLED 

1554 if optimization_parameters.use_gradient_accumulation else 

1555 optimization_parameters_pb2.GradientAccumulationStatus.DISABLED) 

1556 

1557 if optimization_parameters.clip_gradient_min is not None: 

1558 parameters.gradient_clipping_limits.lower.value = ( 

1559 optimization_parameters.clip_gradient_min) 

1560 if optimization_parameters.clip_gradient_max is not None: 

1561 parameters.gradient_clipping_limits.upper.value = ( 

1562 optimization_parameters.clip_gradient_max) 

1563 

1564 if optimization_parameters.clip_weight_min is not None: 

1565 parameters.clipping_limits.lower.value = ( 

1566 optimization_parameters.clip_weight_min) 

1567 if optimization_parameters.clip_weight_max is not None: 

1568 parameters.clipping_limits.upper.value = ( 

1569 optimization_parameters.clip_weight_max) 

1570 if optimization_parameters.weight_decay_factor: 

1571 parameters.weight_decay_factor = ( 

1572 optimization_parameters.weight_decay_factor) 

1573 if (optimization_parameters 

1574 .multiply_weight_decay_factor_by_learning_rate): 

1575 parameters.multiply_weight_decay_factor_by_learning_rate = True 

1576 if table_config.hot_id_replication: 

1577 parameters.hot_id_replication_configuration.status = ( 

1578 optimization_parameters_pb2.HotIdReplicationConfiguration.ENABLED) 

1579 optimizer_handler = self._optimizer_handler_dict[table] 

1580 optimizer_handler.set_optimization_parameters(table_descriptor) 

1581 

1582 table_to_id = { 

1583 table: i for i, table in enumerate(self._table_to_config_dict) 

1584 } 

1585 

1586 # Set feature descriptor field in the config proto. 

1587 for table in self._table_to_features_dict: 

1588 features = self._table_to_features_dict[table] 

1589 for feature in features: 

1590 feature_descriptor = config_proto.feature_descriptor.add() 

1591 

1592 feature_descriptor.table_id = table_to_id[ 

1593 self._feature_to_config_dict[feature].table_id] 

1594 if self._feature_to_config_dict[feature].max_sequence_length > 0: 

1595 feature_descriptor.input_shape.extend([ 

1596 self._batch_size_per_core, 

1597 self._feature_to_config_dict[feature].max_sequence_length 

1598 ]) 

1599 else: 

1600 feature_descriptor.input_shape.extend([self._batch_size_per_core]) 

1601 

1602 config_proto.mode = self._mode 

1603 config_proto.num_hosts = self._num_hosts 

1604 config_proto.num_tensor_cores = self._num_cores 

1605 config_proto.sharding_strategy = ( 

1606 elc.TPUEmbeddingConfiguration.DIV_DEFAULT if self._partition_strategy 

1607 == 'div' else elc.TPUEmbeddingConfiguration.MOD) 

1608 config_proto.pipeline_execution_with_tensor_core = ( 

1609 self._pipeline_execution_with_tensor_core) 

1610 if self._profile_data_directory: 

1611 config_proto.profile_data_directory = self._profile_data_directory 

1612 

1613 return config_proto 
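# Editor's note: a minimal sketch (not part of the original source) showing
# how the generated proto can be inspected before it is handed to
# `tpu.initialize_system()`; `tpu_embedding` is an assumed, already
# constructed `TPUEmbedding` instance.
#
#   config = tpu_embedding.config_proto
#   for table_descriptor in config.table_descriptor:
#     print(table_descriptor.name,
#           table_descriptor.vocabulary_size,
#           table_descriptor.dimension)
#   print(config.num_hosts, config.num_tensor_cores, config.sharding_strategy)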

1614 

1615 def create_variables_and_ops(self, 

1616 embedding_variable_name_by_table=None, 

1617 slot_variable_names_by_table=None): 

1618 """Create embedding and slot variables, with ops to load and retrieve them. 

1619 

1620 N.B.: the ops that retrieve the embedding variables (including slot

1621 variables) are returned as lambda functions, so that the call site can

1622 impose control dependencies between the TPU computation and the retrieval.

1623 For example, the following code snippet ensures the TPU computation

1624 finishes first and only then pulls the variables back from TPU to CPU.

1625 

1626 ``` 

1627 update_ops = []

1628 with ops.control_dependencies([loss]): 

1629 for op_fn in retrieve_parameters_op_fns: 

1630 update_ops.append(op_fn()) 

1631 ``` 

1632 

1633 Args: 

1634 embedding_variable_name_by_table: A dictionary mapping from string of 

1635 table name to string of embedding variable name. If `None`, the table

1636 name itself is used as the embedding variable name.

1637 slot_variable_names_by_table: A dictionary mapping from string of table 

1638 name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If 

1639 `None`, defaults from `get_default_slot_variable_names()` will be used. 

1640 

1641 Returns: 

1642 `tpu_embedding.VariablesAndOps` with: 

1643 A dictionary mapping from string of table name to embedding variables,

1644 a dictionary mapping from string of table name to AdagradSlotVariables,

1645 AdamSlotVariables etc. with slot variables,

1646 a function which returns a list of ops to load embedding and slot

1647 variables from CPU to TPU, and

1648 a function which returns a list of ops to retrieve embedding and slot

1649 variables from TPU to CPU.

1650 """ 

1651 embedding_variables_by_table = {} 

1652 slot_variables_by_table = {} 

1653 load_op_fns = [] 

1654 retrieve_op_fns = [] 

1655 

1656 for i, table in enumerate(self._table_to_config_dict): 

1657 if embedding_variable_name_by_table: 

1658 embedding_variable_name = embedding_variable_name_by_table[table] 

1659 else: 

1660 embedding_variable_name = table 

1661 if slot_variable_names_by_table: 

1662 slot_variable_names = slot_variable_names_by_table[table] 

1663 else: 

1664 optimizer_handler = self._optimizer_handler_dict[table] 

1665 slot_variable_names = ( 

1666 optimizer_handler.get_default_slot_variable_names(table)) 

1667 

1668 # TODO(b/139144091): Multi-host support for mid-level API in 

1669 # eager context (TF 2.0) 

1670 # Workaround below allows single-host use case in TF 2.0 

1671 if context.executing_eagerly(): 

1672 device = '' 

1673 else: 

1674 device = _create_device_fn(self._hosts) 

1675 

1676 with ops.device(device): 

1677 table_variables = _create_partitioned_variables( 

1678 name=embedding_variable_name, 

1679 num_hosts=self._num_hosts, 

1680 vocabulary_size=self._table_to_config_dict[table].vocabulary_size, 

1681 embedding_dimension=self._table_to_config_dict[table].dimension, 

1682 initializer=self._table_to_config_dict[table].initializer, 

1683 collections=[ops.GraphKeys.GLOBAL_VARIABLES]) 

1684 embedding_variables_by_table[table] = table_variables 

1685 

1686 # Only pass the embedding config to the load/retrieve nodes of the first

1687 # table on the first host; all other nodes reuse the config from that node.

1688 config = None if i else self.config_proto.SerializeToString() 

1689 slot_variables_for_table, load_ops_fn, retrieve_ops_fn = ( 

1690 self._optimizer_handler_dict[table].create_variables_and_ops( 

1691 table, slot_variable_names, self._num_hosts, 

1692 self._table_to_config_dict[table], table_variables, config)) 

1693 slot_variables_by_table[table] = slot_variables_for_table 

1694 load_op_fns.append(load_ops_fn) 

1695 retrieve_op_fns.append(retrieve_ops_fn) 

1696 

1697 def load_ops(): 

1698 """Calls and returns the load ops for each embedding table. 

1699 

1700 Returns: 

1701 A list of ops to load embedding and slot variables from CPU to TPU. 

1702 """ 

1703 load_ops_list = [] 

1704 for load_op_fn in load_op_fns: 

1705 load_ops_list.extend(load_op_fn()) 

1706 return load_ops_list 

1707 

1708 def retrieve_ops(): 

1709 """Calls and returns the retrieve ops for each embedding table. 

1710 

1711 Returns: 

1712 A list of ops to retrieve embedding and slot variables from TPU to CPU. 

1713 """ 

1714 retrieve_ops_list = [] 

1715 for retrieve_op_fn in retrieve_op_fns: 

1716 retrieve_ops_list.extend(retrieve_op_fn()) 

1717 return retrieve_ops_list 

1718 

1719 return VariablesAndOps(embedding_variables_by_table, 

1720 slot_variables_by_table, load_ops, retrieve_ops) 
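# Editor's note: a hedged usage sketch (not in the original source);
# `embedding` is an assumed graph-mode `TPUEmbedding` instance and `sess` a
# `tf.compat.v1.Session`. It mirrors the intended call order: initialize the
# TPU system with the config proto, initialize the CPU variables, push them
# to the TPU with `load_ops()`, and pull them back with `retrieve_ops()`
# once the TPU computation has finished.
#
#   variables_and_ops = embedding.create_variables_and_ops()
#   sess.run(tf.compat.v1.tpu.initialize_system(
#       embedding_config=embedding.config_proto))
#   sess.run(tf.compat.v1.global_variables_initializer())
#   sess.run(variables_and_ops.load_ops())
#   ...  # run training steps
#   sess.run(variables_and_ops.retrieve_ops())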

1721 

1722 def generate_enqueue_ops( 

1723 self, 

1724 enqueue_datas_list, 

1725 mode_override=None, 

1726 ragged=False, 

1727 ): 

1728 """Generate enqueue ops. 

1729 

1730 Args: 

1731 enqueue_datas_list: a list of dictionaries mapping from string of feature

1732 names to EnqueueData. Each dictionary is for one TPU core. Dictionaries

1733 for the same host should be contiguous in the list. 

1734 mode_override: A string input that overrides the mode specified in the 

1735 TPUEmbeddingConfiguration. Supported values are {'unspecified', 

1736 'inference', 'training', 'backward_pass_only'}. When set to 

1737 'unspecified', the mode set in TPUEmbeddingConfiguration is used, 

1738 otherwise mode_override is used (optional). 

1739 ragged: If True, creates RaggedTensor enqueue ops rather than 

1740 SparseTensor. 

1741 

1742 Returns: 

1743 Ops to enqueue to TPU for embedding. 

1744 """ 

1745 self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list) 

1746 return [ 

1747 self._generate_enqueue_op( # pylint: disable=g-complex-comprehension 

1748 enqueue_datas, 

1749 device_ordinal=i % self._num_cores_per_host, 

1750 mode_override=mode_override, 

1751 ragged=ragged, 

1752 ) for i, enqueue_datas in enumerate(enqueue_datas_list) 

1753 ] 
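# Editor's note: an illustrative sketch (feature name, tensors and core count
# are hypothetical) of the expected `enqueue_datas_list` layout: one dict per
# TPU core, with dicts belonging to the same host contiguous in the list. It
# assumes the module's `EnqueueData.from_sparse_tensor` helper.
#
#   enqueue_datas_list = [
#       {'input_ids': EnqueueData.from_sparse_tensor(ids_for_core_0)},
#       {'input_ids': EnqueueData.from_sparse_tensor(ids_for_core_1)},
#   ]
#   enqueue_ops = tpu_embedding.generate_enqueue_ops(enqueue_datas_list)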

1754 

1755 def _validate_generate_enqueue_ops_enqueue_datas_list(self, 

1756 enqueue_datas_list): 

1757 """Validate `enqueue_datas_list`.""" 

1758 

1759 def _check_agreement(data, name, feature, enqueue_data): 

1760 """Helper function to check device agreement.""" 

1761 if (data is not None and 

1762 data.device != enqueue_data.embedding_indices.device): 

1763 raise ValueError('Device of {0} does not agree with that of '

1764 'embedding_indices for feature {1}.'.format(

1765 name, feature)) 

1766 

1767 feature_set = set(self._feature_to_config_dict.keys()) 

1768 contiguous_device = None 

1769 for i, enqueue_datas in enumerate(enqueue_datas_list): 

1770 used_feature_set = set(enqueue_datas.keys()) 

1771 

1772 # Check features are valid. 

1773 missing_feature_set = feature_set - used_feature_set 

1774 if missing_feature_set: 

1775 raise ValueError('`enqueue_datas_list[{}]` misses a feature that is ' 

1776 'in `feature_to_config_dict`: {}.'.format( 

1777 i, missing_feature_set)) 

1778 

1779 extra_feature_set = used_feature_set - feature_set 

1780 if extra_feature_set: 

1781 raise ValueError('`enqueue_datas_list[{}]` has a feature that is not ' 

1782 'in `feature_to_config_dict`: {}.'.format( 

1783 i, extra_feature_set)) 

1784 

1785 device = None 

1786 device_feature = None 

1787 for feature, enqueue_data in enqueue_datas.items(): 

1788 combiner = self._table_to_config_dict[ 

1789 self._feature_to_config_dict[feature].table_id].combiner 

1790 

1791 if isinstance(enqueue_data, EnqueueData): 

1792 if enqueue_data.sample_indices is None and combiner: 

1793 logging.warn( 

1794 'No sample indices set for feature %s table %s but '

1795 'combiner is set to %s.', feature, 

1796 self._feature_to_config_dict[feature].table_id, combiner) 

1797 _check_agreement(enqueue_data.sample_indices, 'sample_indices', 

1798 feature, enqueue_data) 

1799 _check_agreement(enqueue_data.aggregation_weights, 

1800 'aggregation_weights', feature, enqueue_data) 

1801 

1802 elif isinstance(enqueue_data, RaggedEnqueueData): 

1803 if enqueue_data.row_splits is None and combiner: 

1804 logging.warn( 

1805 'No row splits set for feature %s table %s but '

1806 'combiner is set to %s.', feature, 

1807 self._feature_to_config_dict[feature].table_id, combiner) 

1808 _check_agreement(enqueue_data.row_splits, 'row_splits', feature, 

1809 enqueue_data) 

1810 _check_agreement(enqueue_data.aggregation_weights, 

1811 'aggregation_weights', feature, enqueue_data) 

1812 else: 

1813 raise ValueError( 

1814 '`enqueue_datas_list[{}]` has a feature that is not mapped to ' 

1815 '`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format( 

1816 i, feature)) 

1817 # Check all features are on the same device. 

1818 if device is None: 

1819 device = enqueue_data.embedding_indices.device 

1820 device_feature = feature 

1821 else: 

1822 if device != enqueue_data.embedding_indices.device: 

1823 raise ValueError('Devices are different between features in ' 

1824 '`enqueue_datas_list[{}]`; ' 

1825 'devices: {}, {}; features: {}, {}.'.format( 

1826 i, device, 

1827 enqueue_data.embedding_indices.device, feature, 

1828 device_feature)) 

1829 

1830 if i % self._num_cores_per_host: 

1831 if device != contiguous_device: 

1832 raise ValueError('We expect the `enqueue_datas` that are on the '

1833 'same host to be contiguous in '

1834 '`enqueue_datas_list`; '

1835 '`enqueue_datas_list[{}]` is on device {}, ' 

1836 'but is expected to be on device {}.'.format( 

1837 i, device, contiguous_device)) 

1838 else: 

1839 contiguous_device = device 

1840 

1841 def _generate_enqueue_op(self, 

1842 enqueue_datas, 

1843 device_ordinal, 

1844 mode_override=None, 

1845 ragged=False): 

1846 """Creates op for enqueuing batch to TPU.""" 

1847 enqueue_data0 = list(enqueue_datas.values())[0] 

1848 with ops.colocate_with(enqueue_data0.embedding_indices): 

1849 return tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch( 

1850 device_ordinal=device_ordinal, 

1851 combiners=self._combiners, 

1852 mode_override=mode_override, 

1853 **self._format_for_tpu_embedding_arbitrary_tensor_batch( 

1854 enqueue_datas, ragged)) 

1855 

1856 def _format_for_tpu_embedding_arbitrary_tensor_batch(self, enqueue_datas, 

1857 ragged): 

1858 """Format features for `enqueue_tpu_embedding_arbitrary_tensor_batch()`. 

1859 

1860 Args: 

1861 enqueue_datas: a `Dict` of `EnqueueData` or `RaggedEnqueueData` objects for embedding.

1862 ragged: If True, extract row splits from the data rather than sample 

1863 indices. 

1864 

1865 Returns: 

1866 Dict of arguments for `enqueue_tpu_embedding_arbitrary_tensor_batch()`. 

1867 """ 

1868 

1869 kwargs = { 

1870 'sample_indices_or_row_splits': [], 

1871 'embedding_indices': [], 

1872 'aggregation_weights': [], 

1873 } 

1874 int_zeros = array_ops.zeros((0,), dtype=dtypes.int64) 

1875 float_zeros = array_ops.zeros((0,), dtype=dtypes.float32) 

1876 for table in self._table_to_features_dict: 

1877 features = self._table_to_features_dict[table] 

1878 for feature in features: 

1879 enqueue_data = enqueue_datas[feature] 

1880 if ragged: 

1881 kwargs['sample_indices_or_row_splits'].append( 

1882 enqueue_data.row_splits if enqueue_data 

1883 .row_splits is not None else int_zeros) 

1884 else: 

1885 if (self._feature_to_config_dict[feature].max_sequence_length > 0 and 

1886 enqueue_data.sample_indices is not None and 

1887 enqueue_data.sample_indices.shape[1] == 2): 

1888 # Pad the sample indices as if the enqueued sparse tensor is rank 2. 

1889 sample_indices = array_ops.pad( 

1890 enqueue_data.sample_indices, paddings=[[0, 0], [0, 1]]) 

1891 kwargs['sample_indices_or_row_splits'].append(sample_indices) 

1892 else: 

1893 # If sample_indices is rank 1 or not present, treat the input as a

1894 # dense tensor.

1895 if (enqueue_data.sample_indices is None or 

1896 enqueue_data.sample_indices.shape[1] == 1): 

1897 kwargs['sample_indices_or_row_splits'].append(int_zeros) 

1898 else: 

1899 kwargs['sample_indices_or_row_splits'].append( 

1900 enqueue_data.sample_indices) 

1901 

1902 kwargs['aggregation_weights'].append( 

1903 enqueue_data.aggregation_weights if enqueue_data 

1904 .aggregation_weights is not None else float_zeros) 

1905 

1906 kwargs['embedding_indices'].append(enqueue_data.embedding_indices) 

1907 return kwargs 
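# Editor's note (illustrative): for a sequence feature with
# `max_sequence_length > 0`, a rank-2 SparseTensor's `sample_indices` of
# shape [nnz, 2] are padded above to shape [nnz, 3], matching the layout the
# arbitrary-tensor-batch enqueue op expects; features without sample indices
# or weights are passed as the empty `int_zeros`/`float_zeros` placeholders.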

1908 

1909 def get_activations(self): 

1910 """Get activations for features. 

1911 

1912 This should be called within the `computation` that is passed to

1913 `tpu.replicate` and friends.

1914 

1915 Returns: 

1916 A dictionary mapping from `String` of feature name to `Tensor` 

1917 of activation. 

1918 """ 

1919 recv_activations = tpu_ops.recv_tpu_embedding_activations( 

1920 num_outputs=len(self._feature_to_config_dict), 

1921 config=self._config_proto.SerializeToString()) 

1922 

1923 activations = collections.OrderedDict() 

1924 index = 0 

1925 for table in self._table_to_features_dict: 

1926 for feature in self._table_to_features_dict[table]: 

1927 activations[feature] = recv_activations[index] 

1928 index += 1 

1929 return activations 
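# Editor's note: a minimal sketch (hypothetical names) of consuming the
# activations inside the replicated TPU computation; for a feature with
# `max_sequence_length` set, the activation is assumed to have shape
# [batch_size_per_core, max_sequence_length, dimension].
#
#   activations = tpu_embedding.get_activations()
#   word_embeddings = activations['input_ids']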

1930 

1931 def generate_send_gradients_op(self, feature_to_gradient_dict, step=None): 

1932 """Send gradient to TPU embedding. 

1933 

1934 Args: 

1935 feature_to_gradient_dict: dict mapping feature names to gradients with

1936 respect to the corresponding activations.

1937 step: the current global step, used for dynamic learning rate. 

1938 

1939 Returns: 

1940 SendTPUEmbeddingGradients Op. 

1941 

1942 Raises: 

1943 RuntimeError: If `mode` is not `TRAINING`. 

1944 """ 

1945 if self._mode != TRAINING: 

1946 raise RuntimeError('Gradients should only be sent to TPU '

1947 'embedding in training mode; got mode {}.'.format(

1948 self._mode))

1949 if step is None and self._learning_rate_fn: 

1950 raise ValueError('There are dynamic learning rates but step is None.') 

1951 

1952 gradients = [] 

1953 for table in self._table_to_features_dict: 

1954 for feature in self._table_to_features_dict[table]: 

1955 gradients.append(feature_to_gradient_dict[feature]) 

1956 

1957 return tpu_ops.send_tpu_embedding_gradients( 

1958 inputs=gradients, 

1959 learning_rates=[ 

1960 math_ops.cast(fn(step), dtype=dtypes.float32) 

1961 for fn in self._learning_rate_fn 

1962 ], 

1963 config=self.config_proto.SerializeToString()) 
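# Editor's note: an illustrative sketch (the gradient dict and step tensor
# are hypothetical). When any table uses a dynamic learning rate via
# `learning_rate_fn`, the current global step must be supplied so the
# per-tag learning rates above can be evaluated:
#
#   send_op = tpu_embedding.generate_send_gradients_op(
#       feature_to_gradient_dict={'input_ids': grad_wrt_activation},
#       step=tf.compat.v1.train.get_or_create_global_step())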

1964 

1965 def _get_optimizer_handler_by_table(self): 

1966 optimizer_handlers = {} 

1967 for table, table_config in self.table_to_config_dict.items(): 

1968 if table_config.optimization_parameters is not None: 

1969 optimizer = table_config.optimization_parameters 

1970 else: 

1971 optimizer = self._optimization_parameters 

1972 optimizer_handlers[table] = _get_optimization_handler(optimizer) 

1973 

1974 return optimizer_handlers 

1975 

1976 

1977def _validate_table_to_config_dict(table_to_config_dict): 

1978 """Validate `table_to_config_dict`.""" 

1979 for k, v in table_to_config_dict.items(): 

1980 if not isinstance(v, TableConfig): 

1981 raise ValueError('Value of `table_to_config_dict` must be of type ' 

1982 '`TableConfig`, got {} for {}.'.format(type(v), k)) 

1983 

1984 

1985def _validate_feature_to_config_dict(table_to_config_dict, 

1986 feature_to_config_dict): 

1987 """Validate `feature_to_config_dict`.""" 

1988 used_table_set = set( 

1989 [feature.table_id for feature in feature_to_config_dict.values()]) 

1990 table_set = set(table_to_config_dict.keys()) 

1991 

1992 unused_table_set = table_set - used_table_set 

1993 if unused_table_set: 

1994 raise ValueError( 

1995 '`table_to_config_dict` specifies a table that is not '

1996 'used in `feature_to_config_dict`: {}.'.format(unused_table_set)) 

1997 

1998 extra_table_set = used_table_set - table_set 

1999 if extra_table_set: 

2000 raise ValueError( 

2001 '`feature_to_config_dict` refers to a table that is not ' 

2002 'specified in `table_to_config_dict`: {}.'.format(extra_table_set)) 

2003 

2004 

2005def _validate_batch_size(batch_size, num_cores): 

2006 if batch_size % num_cores: 

2007 raise ValueError('`batch_size` is not a multiple of the number of '

2008 'cores. `batch_size`={}, `_num_cores`={}.'.format( 

2009 batch_size, num_cores)) 
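# Editor's note (illustrative, hypothetical numbers): a global batch size of
# 1024 on 32 cores passes this check (32 examples per core), whereas 1000 on
# 32 cores raises the ValueError above.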

2010 

2011 

2012def _validate_optimization_parameters(optimization_parameters, 

2013 table_to_config_dict): 

2014 """Validate global optimization_parameters and per table optimizers. 

2015 

2016 If the global optimizer is `None`, all table optimizers must be non-`None`.

2017 

2018 Args: 

2019 optimization_parameters: global optimizer provided in `TPUEmbedding` 

2020 constructor. 

2021 table_to_config_dict: A dictionary mapping from string of table name to 

2022 `TableConfig`. 

2023 """ 

2024 tbl_optimizer_missing = False 

2025 for _, table_config in table_to_config_dict.items(): 

2026 if table_config.optimization_parameters is None: 

2027 tbl_optimizer_missing = True 

2028 break 

2029 

2030 if optimization_parameters: 

2031 if not isinstance(optimization_parameters, _OptimizationParameters): 

2032 raise ValueError('`optimization_parameters` must inherit from ' 

2033 '`_OptimizationParameters`. ' 

2034 '`type(optimization_parameters)`={}'.format( 

2035 type(optimization_parameters))) 

2036 else: 

2037 # Missing global optimization_parameters. 

2038 if tbl_optimizer_missing: 

2039 raise ValueError('`optimization_parameters` is missing and at least one table has no per-table optimizer.')
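# Editor's note: hedged illustrative configurations (parameter values are
# hypothetical) that satisfy the validation above. Either a global optimizer,
#
#   TPUEmbedding(..., optimization_parameters=AdagradParameters(
#       learning_rate=0.1, initial_accumulator=0.1))
#
# or a per-table optimizer on every `TableConfig`,
#
#   TableConfig(..., optimization_parameters=AdamParameters(learning_rate=1e-3))
#
# Tables that define their own `optimization_parameters` override the global
# setting; tables without one fall back to it (see
# `_get_optimizer_handler_by_table`).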

2040 

2041 

2042class _OptimizerHandler: 

2043 """Interface class for handling optimizer specific logic.""" 

2044 

2045 def __init__(self, optimization_parameters): 

2046 self._optimization_parameters = optimization_parameters 

2047 

2048 def get_optimization_parameters(self): 

2049 return self._optimization_parameters 

2050 

2051 def set_optimization_parameters(self, table_descriptor): 

2052 raise NotImplementedError() 

2053 

2054 def get_default_slot_variable_names(self, table): 

2055 raise NotImplementedError() 

2056 

2057 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 

2058 table_config, table_variables, config_proto): 

2059 raise NotImplementedError() 

2060 

2061 

2062class _AdagradHandler(_OptimizerHandler): 

2063 """Handles Adagrad specific logic.""" 

2064 

2065 def set_optimization_parameters(self, table_descriptor): 

2066 table_descriptor.optimization_parameters.adagrad.SetInParent() 

2067 

2068 def get_default_slot_variable_names(self, table): 

2069 return AdagradSlotVariableNames('{}/{}'.format(table, 'Adagrad')) 

2070 

2071 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 

2072 table_config, table_variables, config_proto): 

2073 accumulator_initializer = init_ops.constant_initializer( 

2074 self._optimization_parameters.initial_accumulator) 

2075 accumulator_variables = _create_partitioned_variables( 

2076 name=slot_variable_names.accumulator, 

2077 num_hosts=num_hosts, 

2078 vocabulary_size=table_config.vocabulary_size, 

2079 embedding_dimension=table_config.dimension, 

2080 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2081 initializer=accumulator_initializer) 

2082 slot_variables = AdagradSlotVariables(accumulator_variables) 

2083 

2084 def load_ops_fn(): 

2085 """Returns the retrieve ops for AdaGrad embedding tables. 

2086 

2087 Returns: 

2088 A list of ops to load embedding and slot variables from CPU to TPU. 

2089 """ 

2090 config = config_proto 

2091 load_op_list = [] 

2092 for host_id, table_variable, accumulator_variable in zip( 

2093 range(num_hosts), table_variables, accumulator_variables): 

2094 with ops.colocate_with(table_variable): 

2095 load_parameters_op = ( 

2096 tpu_ops.load_tpu_embedding_adagrad_parameters( 

2097 parameters=table_variable, 

2098 accumulators=accumulator_variable, 

2099 table_name=table, 

2100 num_shards=num_hosts, 

2101 shard_id=host_id, 

2102 config=config)) 
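# Set config to None to enforce that the config proto is only loaded with
# the first table's load op; subsequent ops reuse that configuration.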

2103 config = None 

2104 load_op_list.append(load_parameters_op) 

2105 return load_op_list 

2106 

2107 def retrieve_ops_fn(): 

2108 """Returns the retrieve ops for AdaGrad embedding tables. 

2109 

2110 Returns: 

2111 A list of ops to retrieve embedding and slot variables from TPU to CPU. 

2112 """ 

2113 config = config_proto 

2114 retrieve_op_list = [] 

2115 for host_id, table_variable, accumulator_variable in (zip( 

2116 range(num_hosts), table_variables, accumulator_variables)): 

2117 with ops.colocate_with(table_variable): 

2118 retrieved_table, retrieved_accumulator = ( 

2119 tpu_ops.retrieve_tpu_embedding_adagrad_parameters( 

2120 table_name=table, 

2121 num_shards=num_hosts, 

2122 shard_id=host_id, 

2123 config=config)) 

2124 retrieve_parameters_op = control_flow_ops.group( 

2125 state_ops.assign(table_variable, retrieved_table), 

2126 state_ops.assign(accumulator_variable, retrieved_accumulator)) 

2127 config = None 

2128 retrieve_op_list.append(retrieve_parameters_op) 

2129 return retrieve_op_list 

2130 

2131 return slot_variables, load_ops_fn, retrieve_ops_fn 

2132 

2133 

2134class _AdagradMomentumHandler(_OptimizerHandler): 

2135 """Handles Adagrad with Momentum specific logic. 

2136 

2137 Creates slot variables and defines their initializers. Defines load/retrieve 

2138 operations to be used for loading variables into TPU memory (from host memory) 

2139 and retrieving variables from TPU memory (into host memory). 

2140 """ 

2141 

2142 def set_optimization_parameters(self, table_descriptor): 

2143 table_descriptor.optimization_parameters.adagrad_momentum.SetInParent() 

2144 table_descriptor.optimization_parameters.adagrad_momentum.momentum = ( 

2145 self._optimization_parameters.momentum) 

2146 table_descriptor.optimization_parameters.adagrad_momentum.use_nesterov = ( 

2147 self._optimization_parameters.use_nesterov) 

2148 table_descriptor.optimization_parameters.adagrad_momentum.exponent = ( 

2149 self._optimization_parameters.exponent) 

2150 table_descriptor.optimization_parameters.adagrad_momentum.beta2 = ( 

2151 self._optimization_parameters.beta2) 

2152 table_descriptor.optimization_parameters.adagrad_momentum.epsilon = ( 

2153 self._optimization_parameters.epsilon) 

2154 

2155 def get_default_slot_variable_names(self, table): 

2156 return AdagradMomentumSlotVariableNames( 

2157 '{}/{}/Accumulator'.format(table, 'AdagradMomentum'), 

2158 '{}/{}/Momentum'.format(table, 'AdagradMomentum')) 

2159 

2160 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 

2161 table_config, table_variables, config_proto): 

2162 accumulator_initializer = init_ops.zeros_initializer() 

2163 accumulator_variables = _create_partitioned_variables( 

2164 name=slot_variable_names.accumulator, 

2165 num_hosts=num_hosts, 

2166 vocabulary_size=table_config.vocabulary_size, 

2167 embedding_dimension=table_config.dimension, 

2168 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2169 initializer=accumulator_initializer) 

2170 momenta_initializer = init_ops.zeros_initializer() 

2171 momenta_variables = _create_partitioned_variables( 

2172 name=slot_variable_names.momenta, 

2173 num_hosts=num_hosts, 

2174 vocabulary_size=table_config.vocabulary_size, 

2175 embedding_dimension=table_config.dimension, 

2176 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2177 initializer=momenta_initializer) 

2178 slot_variables = AdagradMomentumSlotVariables(accumulator_variables, 

2179 momenta_variables) 

2180 

2181 def load_ops_fn(): 

2182 """Returns the load ops for AdaGrad with momentum embedding tables. 

2183 

2184 Returns: 

2185 A list of ops to load embedding and slot variables from CPU to TPU. 

2186 """ 

2187 config = config_proto 

2188 load_op_list = [] 

2189 for host_id, table_variable, accumulator_variable, momenta_variable in zip( 

2190 range(num_hosts), table_variables, accumulator_variables, 

2191 momenta_variables): 

2192 with ops.colocate_with(table_variable): 

2193 load_parameters_op = ( 

2194 tpu_ops.load_tpu_embedding_adagrad_momentum_parameters( 

2195 parameters=table_variable, 

2196 accumulators=accumulator_variable, 

2197 momenta=momenta_variable, 

2198 table_name=table, 

2199 num_shards=num_hosts, 

2200 shard_id=host_id, 

2201 config=config)) 

2202 config = None 

2203 load_op_list.append(load_parameters_op) 

2204 return load_op_list 

2205 

2206 def retrieve_ops_fn(): 

2207 """Returns the retrieve ops for AdaGrad with momentum embedding tables. 

2208 

2209 Returns: 

2210 A list of ops to retrieve embedding and slot variables from TPU to CPU. 

2211 """ 

2212 config = config_proto 

2213 retrieve_op_list = [] 

2214 for host_id, table_variable, accumulator_variable, momenta_variable in ( 

2215 zip( 

2216 range(num_hosts), table_variables, accumulator_variables, 

2217 momenta_variables)): 

2218 with ops.colocate_with(table_variable): 

2219 retrieved_table, retrieved_accumulator, retrieved_momenta = ( 

2220 tpu_ops.retrieve_tpu_embedding_adagrad_momentum_parameters( 

2221 table_name=table, 

2222 num_shards=num_hosts, 

2223 shard_id=host_id, 

2224 config=config)) 

2225 retrieve_parameters_op = control_flow_ops.group( 

2226 state_ops.assign(table_variable, retrieved_table), 

2227 state_ops.assign(accumulator_variable, retrieved_accumulator), 

2228 state_ops.assign(momenta_variable, retrieved_momenta)) 

2229 config = None 

2230 retrieve_op_list.append(retrieve_parameters_op) 

2231 return retrieve_op_list 

2232 

2233 return slot_variables, load_ops_fn, retrieve_ops_fn 

2234 

2235 

2236class _ProximalAdagradHandler(_OptimizerHandler): 

2237 """Handles ProximalAdagrad specific logic.""" 

2238 

2239 def set_optimization_parameters(self, table_descriptor): 

2240 table_descriptor.optimization_parameters.proximal_adagrad.SetInParent() 

2241 table_descriptor.optimization_parameters.proximal_adagrad.l1 = ( 

2242 self._optimization_parameters.l1_regularization_strength) 

2243 table_descriptor.optimization_parameters.proximal_adagrad.l2 = ( 

2244 self._optimization_parameters.l2_regularization_strength) 

2245 

2246 def get_default_slot_variable_names(self, table): 

2247 return ProximalAdagradSlotVariableNames('{}/{}'.format( 

2248 table, 'ProximalAdagrad')) 

2249 

2250 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 

2251 table_config, table_variables, config_proto): 

2252 accumulator_initializer = init_ops.constant_initializer( 

2253 self._optimization_parameters.initial_accumulator) 

2254 accumulator_variables = _create_partitioned_variables( 

2255 name=slot_variable_names.accumulator, 

2256 num_hosts=num_hosts, 

2257 vocabulary_size=table_config.vocabulary_size, 

2258 embedding_dimension=table_config.dimension, 

2259 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2260 initializer=accumulator_initializer) 

2261 slot_variables = ProximalAdagradSlotVariables(accumulator_variables) 

2262 

2263 def load_ops_fn(): 

2264 """Returns the retrieve ops for Proximal AdaGrad embedding tables. 

2265 

2266 Returns: 

2267 A list of ops to load embedding and slot variables from CPU to TPU. 

2268 """ 

2269 config = config_proto 

2270 load_op_list = [] 

2271 for host_id, table_variable, accumulator_variable in zip( 

2272 range(num_hosts), table_variables, accumulator_variables): 

2273 with ops.colocate_with(table_variable): 

2274 load_parameters_op = ( 

2275 tpu_ops.load_tpu_embedding_proximal_adagrad_parameters( 

2276 parameters=table_variable, 

2277 accumulators=accumulator_variable, 

2278 table_name=table, 

2279 num_shards=num_hosts, 

2280 shard_id=host_id, 

2281 config=config)) 

2282 config = None 

2283 load_op_list.append(load_parameters_op) 

2284 return load_op_list 

2285 

2286 def retrieve_ops_fn(): 

2287 """Returns the retrieve ops for Proximal AdaGrad embedding tables. 

2288 

2289 Returns: 

2290 A list of ops to retrieve embedding and slot variables from TPU to CPU. 

2291 """ 

2292 config = config_proto 

2293 retrieve_op_list = [] 

2294 for host_id, table_variable, accumulator_variable in (zip( 

2295 range(num_hosts), table_variables, accumulator_variables)): 

2296 with ops.colocate_with(table_variable): 

2297 retrieved_table, retrieved_accumulator = ( 

2298 tpu_ops.retrieve_tpu_embedding_proximal_adagrad_parameters( 

2299 table_name=table, 

2300 num_shards=num_hosts, 

2301 shard_id=host_id, 

2302 config=config)) 

2303 retrieve_parameters_op = control_flow_ops.group( 

2304 state_ops.assign(table_variable, retrieved_table), 

2305 state_ops.assign(accumulator_variable, retrieved_accumulator)) 

2306 config = None 

2307 retrieve_op_list.append(retrieve_parameters_op) 

2308 return retrieve_op_list 

2309 

2310 return slot_variables, load_ops_fn, retrieve_ops_fn 

2311 

2312 

2313class _AdamHandler(_OptimizerHandler): 

2314 """Handles Adam specific logic.""" 

2315 

2316 def set_optimization_parameters(self, table_descriptor): 

2317 table_descriptor.optimization_parameters.adam.beta1 = ( 

2318 self._optimization_parameters.beta1) 

2319 table_descriptor.optimization_parameters.adam.beta2 = ( 

2320 self._optimization_parameters.beta2) 

2321 table_descriptor.optimization_parameters.adam.epsilon = ( 

2322 self._optimization_parameters.epsilon) 

2323 table_descriptor.optimization_parameters.adam.use_non_lazy_adam = ( 

2324 not self._optimization_parameters.lazy_adam) 

2325 table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = ( 

2326 self._optimization_parameters.sum_inside_sqrt) 

2327 

2328 def get_default_slot_variable_names(self, table): 

2329 return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'), 

2330 '{}/{}/v'.format(table, 'Adam')) 

2331 

2332 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 

2333 table_config, table_variables, config_proto): 

2334 m_initializer = init_ops.zeros_initializer() 

2335 m_variables = _create_partitioned_variables( 

2336 name=slot_variable_names.m, 

2337 num_hosts=num_hosts, 

2338 vocabulary_size=table_config.vocabulary_size, 

2339 embedding_dimension=table_config.dimension, 

2340 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2341 initializer=m_initializer) 

2342 v_initializer = init_ops.zeros_initializer() 

2343 v_variables = _create_partitioned_variables( 

2344 name=slot_variable_names.v, 

2345 num_hosts=num_hosts, 

2346 vocabulary_size=table_config.vocabulary_size, 

2347 embedding_dimension=table_config.dimension, 

2348 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2349 initializer=v_initializer) 

2350 slot_variables = AdamSlotVariables(m_variables, v_variables) 

2351 

2352 def load_ops_fn(): 

2353 """Returns the retrieve ops for AdaGrad embedding tables. 

2354 

2355 Returns: 

2356 A list of ops to load embedding and slot variables from CPU to TPU. 

2357 """ 

2358 load_op_list = [] 

2359 config = config_proto 

2360 for host_id, table_variable, m_variable, v_variable in (zip( 

2361 range(num_hosts), table_variables, m_variables, v_variables)): 

2362 with ops.colocate_with(table_variable): 

2363 load_parameters_op = ( 

2364 tpu_ops.load_tpu_embedding_adam_parameters( 

2365 parameters=table_variable, 

2366 momenta=m_variable, 

2367 velocities=v_variable, 

2368 table_name=table, 

2369 num_shards=num_hosts, 

2370 shard_id=host_id, 

2371 config=config)) 

2372 # Set config to None to enforce that config is only loaded to the first 

2373 # table. 

2374 config = None 

2375 load_op_list.append(load_parameters_op) 

2376 return load_op_list 

2377 

2378 def retrieve_ops_fn(): 

2379 """Returns the retrieve ops for Adam embedding tables. 

2380 

2381 Returns: 

2382 A list of ops to retrieve embedding and slot variables from TPU to CPU. 

2383 """ 

2384 retrieve_op_list = [] 

2385 config = config_proto 

2386 for host_id, table_variable, m_variable, v_variable in (zip( 

2387 range(num_hosts), table_variables, m_variables, v_variables)): 

2388 with ops.colocate_with(table_variable): 

2389 retrieved_table, retrieved_m, retrieved_v = ( 

2390 tpu_ops.retrieve_tpu_embedding_adam_parameters( 

2391 table_name=table, 

2392 num_shards=num_hosts, 

2393 shard_id=host_id, 

2394 config=config)) 

2395 retrieve_parameters_op = control_flow_ops.group( 

2396 state_ops.assign(table_variable, retrieved_table), 

2397 state_ops.assign(m_variable, retrieved_m), 

2398 state_ops.assign(v_variable, retrieved_v)) 

2399 config = None 

2400 retrieve_op_list.append(retrieve_parameters_op) 

2401 return retrieve_op_list 

2402 

2403 return slot_variables, load_ops_fn, retrieve_ops_fn 

2404 

2405 

2406class _FtrlHandler(_OptimizerHandler): 

2407 """Handles Ftrl specific logic.""" 

2408 

2409 def set_optimization_parameters(self, table_descriptor): 

2410 table_descriptor.optimization_parameters.ftrl.lr_power = ( 

2411 self._optimization_parameters.learning_rate_power) 

2412 table_descriptor.optimization_parameters.ftrl.l1 = ( 

2413 self._optimization_parameters.l1_regularization_strength) 

2414 table_descriptor.optimization_parameters.ftrl.l2 = ( 

2415 self._optimization_parameters.l2_regularization_strength) 

2416 table_descriptor.optimization_parameters.ftrl.multiply_linear_by_lr = ( 

2417 self._optimization_parameters.multiply_linear_by_learning_rate) 

2418 table_descriptor.optimization_parameters.ftrl.beta = ( 

2419 self._optimization_parameters.beta) 

2420 table_descriptor.optimization_parameters.ftrl.allow_zero_accumulator = ( 

2421 self._optimization_parameters.allow_zero_accumulator) 

2422 

2423 def get_default_slot_variable_names(self, table): 

2424 # These match the default slot variable names created by 

2425 # tf.train.FtrlOptimizer. 

2426 return FtrlSlotVariableNames( 

2427 '{}/{}'.format(table, 'Ftrl'), # accumulator 

2428 '{}/{}'.format(table, 'Ftrl_1')) # linear 

2429 

2430 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 

2431 table_config, table_variables, config_proto): 

2432 accumulator_initializer = init_ops.constant_initializer( 

2433 self._optimization_parameters.initial_accumulator_value) 

2434 accumulator_variables = _create_partitioned_variables( 

2435 name=slot_variable_names.accumulator, 

2436 num_hosts=num_hosts, 

2437 vocabulary_size=table_config.vocabulary_size, 

2438 embedding_dimension=table_config.dimension, 

2439 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2440 initializer=accumulator_initializer) 

2441 linear_initializer = init_ops.constant_initializer( 

2442 self._optimization_parameters.initial_linear_value) 

2443 linear_variables = _create_partitioned_variables( 

2444 name=slot_variable_names.linear, 

2445 num_hosts=num_hosts, 

2446 vocabulary_size=table_config.vocabulary_size, 

2447 embedding_dimension=table_config.dimension, 

2448 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2449 initializer=linear_initializer) 

2450 slot_variables = FtrlSlotVariable(accumulator_variables, linear_variables) 

2451 

2452 def load_ops_fn(): 

2453 """Returns the retrieve ops for Ftrl embedding tables. 

2454 

2455 Returns: 

2456 A list of ops to load embedding and slot variables from CPU to TPU. 

2457 """ 

2458 config = config_proto 

2459 load_op_list = [] 

2460 for host_id, table_variable, accumulator_variable, linear_variable in zip( 

2461 range(num_hosts), table_variables, accumulator_variables, 

2462 linear_variables): 

2463 with ops.colocate_with(table_variable): 

2464 load_parameters_op = ( 

2465 tpu_ops.load_tpu_embedding_ftrl_parameters( 

2466 parameters=table_variable, 

2467 accumulators=accumulator_variable, 

2468 linears=linear_variable, 

2469 table_name=table, 

2470 num_shards=num_hosts, 

2471 shard_id=host_id, 

2472 config=config)) 

2473 config = None 

2474 load_op_list.append(load_parameters_op) 

2475 return load_op_list 

2476 

2477 def retrieve_ops_fn(): 

2478 """Returns the retrieve ops for Ftrl embedding tables. 

2479 

2480 Returns: 

2481 A list of ops to retrieve embedding and slot variables from TPU to CPU. 

2482 """ 

2483 config = config_proto 

2484 retrieve_op_list = [] 

2485 for host_id, table_variable, accumulator_variable, linear_variable in zip( 

2486 range(num_hosts), table_variables, accumulator_variables, 

2487 linear_variables): 

2488 with ops.colocate_with(table_variable): 

2489 retrieved_table, retrieved_accumulator, retrieved_linear = ( 

2490 tpu_ops.retrieve_tpu_embedding_ftrl_parameters( 

2491 table_name=table, 

2492 num_shards=num_hosts, 

2493 shard_id=host_id, 

2494 config=config)) 

2495 retrieve_parameters_op = control_flow_ops.group( 

2496 state_ops.assign(table_variable, retrieved_table), 

2497 state_ops.assign(accumulator_variable, retrieved_accumulator), 

2498 state_ops.assign(linear_variable, retrieved_linear)) 

2499 config = None 

2500 retrieve_op_list.append(retrieve_parameters_op) 

2501 return retrieve_op_list 

2502 

2503 return slot_variables, load_ops_fn, retrieve_ops_fn 

2504 

2505 

2506class _ProximalYogiHandler(_OptimizerHandler): 

2507 """Handles Proximal Yogi specific logic.""" 

2508 

2509 def set_optimization_parameters(self, table_descriptor): 

2510 table_descriptor.optimization_parameters.proximal_yogi.SetInParent() 

2511 table_descriptor.optimization_parameters.proximal_yogi.beta1 = ( 

2512 self._optimization_parameters.beta1) 

2513 table_descriptor.optimization_parameters.proximal_yogi.beta2 = ( 

2514 self._optimization_parameters.beta2) 

2515 table_descriptor.optimization_parameters.proximal_yogi.epsilon = ( 

2516 self._optimization_parameters.epsilon) 

2517 table_descriptor.optimization_parameters.proximal_yogi.l1 = ( 

2518 self._optimization_parameters.l1_regularization_strength) 

2519 table_descriptor.optimization_parameters.proximal_yogi.l2 = ( 

2520 self._optimization_parameters.l2_regularization_strength) 

2521 

2522 def get_default_slot_variable_names(self, table): 

2523 return ProximalYogiSlotVariableNames( 

2524 '{}/{}'.format(table, 'ProximalYogi'), # v 

2525 '{}/{}_1'.format(table, 'ProximalYogi')) # m 

2526 

2527 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 

2528 table_config, table_variables, config_proto): 

2529 v_initializer = init_ops.constant_initializer( 

2530 self._optimization_parameters.initial_accumulator_value) 

2531 v_variables = _create_partitioned_variables( 

2532 name=slot_variable_names.v, 

2533 num_hosts=num_hosts, 

2534 vocabulary_size=table_config.vocabulary_size, 

2535 embedding_dimension=table_config.dimension, 

2536 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2537 initializer=v_initializer) 

2538 m_initializer = init_ops.zeros_initializer() 

2539 m_variables = _create_partitioned_variables( 

2540 name=slot_variable_names.m, 

2541 num_hosts=num_hosts, 

2542 vocabulary_size=table_config.vocabulary_size, 

2543 embedding_dimension=table_config.dimension, 

2544 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2545 initializer=m_initializer) 

2546 slot_variables = ProximalYogiSlotVariables(v_variables, m_variables) 

2547 

2548 def load_ops_fn(): 

2549 """Returns the load ops for Proximal Yogi embedding tables. 

2550 

2551 Returns: 

2552 A list of ops to load embedding and slot variables from CPU to TPU. 

2553 """ 

2554 load_op_list = [] 

2555 config = config_proto 

2556 for host_id, table_variable, v_variable, m_variable in (zip( 

2557 range(num_hosts), table_variables, v_variables, m_variables)): 

2558 with ops.colocate_with(table_variable): 

2559 load_parameters_op = ( 

2560 tpu_ops.load_tpu_embedding_proximal_yogi_parameters( 

2561 parameters=table_variable, 

2562 v=v_variable, 

2563 m=m_variable, 

2564 table_name=table, 

2565 num_shards=num_hosts, 

2566 shard_id=host_id, 

2567 config=config)) 

2568 # Set config to None to enforce that config is only loaded to the first 

2569 # table. 

2570 config = None 

2571 load_op_list.append(load_parameters_op) 

2572 return load_op_list 

2573 

2574 def retrieve_ops_fn(): 

2575 """Returns the retrieve ops for Proximal Yogi embedding tables. 

2576 

2577 Returns: 

2578 A list of ops to retrieve embedding and slot variables from TPU to CPU. 

2579 """ 

2580 retrieve_op_list = [] 

2581 config = config_proto 

2582 for host_id, table_variable, v_variable, m_variable in (zip( 

2583 range(num_hosts), table_variables, v_variables, m_variables)): 

2584 with ops.colocate_with(table_variable): 

2585 retrieved_table, retrieved_v, retrieved_m = ( 

2586 tpu_ops.retrieve_tpu_embedding_proximal_yogi_parameters( 

2587 table_name=table, 

2588 num_shards=num_hosts, 

2589 shard_id=host_id, 

2590 config=config)) 

2591 retrieve_parameters_op = control_flow_ops.group( 

2592 state_ops.assign(table_variable, retrieved_table), 

2593 state_ops.assign(v_variable, retrieved_v), 

2594 state_ops.assign(m_variable, retrieved_m)) 

2595 config = None 

2596 retrieve_op_list.append(retrieve_parameters_op) 

2597 return retrieve_op_list 

2598 

2599 return slot_variables, load_ops_fn, retrieve_ops_fn 

2600 

2601 

2602class _MomentumHandler(_OptimizerHandler): 

2603 """Handles Momentum specific logic.""" 

2604 

2605 def set_optimization_parameters(self, table_descriptor): 

2606 (table_descriptor.optimization_parameters.momentum.SetInParent()) 

2607 table_descriptor.optimization_parameters.momentum.momentum = ( 

2608 self._optimization_parameters.momentum) 

2609 table_descriptor.optimization_parameters.momentum.use_nesterov = ( 

2610 self._optimization_parameters.use_nesterov) 

2611 

2612 def get_default_slot_variable_names(self, table): 

2613 return MomentumSlotVariableNames('{}/{}'.format(table, 'Momentum')) 

2614 

2615 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 

2616 table_config, table_variables, config_proto): 

2617 

2618 momenta_initializer = init_ops.zeros_initializer() 

2619 momenta_variables = _create_partitioned_variables( 

2620 name=slot_variable_names.momenta, 

2621 num_hosts=num_hosts, 

2622 vocabulary_size=table_config.vocabulary_size, 

2623 embedding_dimension=table_config.dimension, 

2624 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2625 initializer=momenta_initializer) 

2626 slot_variables = MomentumSlotVariables(momenta_variables) 

2627 

2628 def load_ops_fn(): 

2629 """Returns the retrieve ops for Momentum embedding tables. 

2630 

2631 Returns: 

2632 A list of ops to load embedding and slot variables from CPU to TPU. 

2633 """ 

2634 load_op_list = [] 

2635 config = config_proto 

2636 for host_id, table_variable, momenta_variable in (zip( 

2637 range(num_hosts), table_variables, momenta_variables)): 

2638 with ops.colocate_with(table_variable): 

2639 load_parameters_op = tpu_ops.load_tpu_embedding_momentum_parameters( 

2640 parameters=table_variable, 

2641 momenta=momenta_variable, 

2642 table_name=table, 

2643 num_shards=num_hosts, 

2644 shard_id=host_id, 

2645 config=config, 

2646 ) 

2647 config = None 

2648 load_op_list.append(load_parameters_op) 

2649 return load_op_list 

2650 

2651 def retrieve_ops_fn(): 

2652 """Returns the retrieve ops for Momentum embedding tables. 

2653 

2654 Returns: 

2655 A list of ops to retrieve embedding and slot variables from TPU to CPU. 

2656 """ 

2657 retrieve_op_list = [] 

2658 config = config_proto 

2659 for host_id, table_variable, momenta_variable in (zip( 

2660 range(num_hosts), table_variables, momenta_variables)): 

2661 with ops.colocate_with(table_variable): 

2662 retrieved_table, retrieved_momenta = ( 

2663 tpu_ops.retrieve_tpu_embedding_momentum_parameters( 

2664 table_name=table, 

2665 num_shards=num_hosts, 

2666 shard_id=host_id, 

2667 config=config, 

2668 )) 

2669 retrieve_parameters_op = control_flow_ops.group( 

2670 state_ops.assign(table_variable, retrieved_table), 

2671 state_ops.assign(momenta_variable, retrieved_momenta)) 

2672 config = None 

2673 retrieve_op_list.append(retrieve_parameters_op) 

2674 return retrieve_op_list 

2675 

2676 return slot_variables, load_ops_fn, retrieve_ops_fn 

2677 
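The Momentum handler only needs to populate two proto fields. A short sketch of the resulting message, assuming a TensorFlow install that ships the configuration proto imported at the top of this module and assuming the descriptor is the nested `TableDescriptor` message (the numeric values are illustrative):

from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc

# Build a bare descriptor and fill in the fields that
# set_optimization_parameters above would set.
table_descriptor = elc.TPUEmbeddingConfiguration.TableDescriptor()
table_descriptor.optimization_parameters.momentum.SetInParent()
table_descriptor.optimization_parameters.momentum.momentum = 0.9
table_descriptor.optimization_parameters.momentum.use_nesterov = True
print(table_descriptor.optimization_parameters)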

2678 

2679class _RMSPropHandler(_OptimizerHandler): 

2680 """Handles RMS prop specific logic.""" 

2681 

2682 def set_optimization_parameters(self, table_descriptor): 

2683 (table_descriptor.optimization_parameters.rms_prop.SetInParent()) 

2684 table_descriptor.optimization_parameters.rms_prop.rho = ( 

2685 self._optimization_parameters.rho) 

2686 table_descriptor.optimization_parameters.rms_prop.epsilon = ( 

2687 self._optimization_parameters.epsilon) 

2688 table_descriptor.optimization_parameters.rms_prop.momentum = ( 

2689 self._optimization_parameters.momentum) 

2690 

2691 def get_default_slot_variable_names(self, table): 

2692 return RMSPropSlotVariableNames('{}/{}/ms'.format(table, 'RMSProp'), 

2693 '{}/{}/mom'.format(table, 'RMSProp')) 

2694 

2695 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 

2696 table_config, table_variables, config_proto): 

2697 

2698 ms_variables = _create_partitioned_variables( 

2699 name=slot_variable_names.ms, 

2700 num_hosts=num_hosts, 

2701 vocabulary_size=table_config.vocabulary_size, 

2702 embedding_dimension=table_config.dimension, 

2703 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2704 initializer=init_ops.zeros_initializer(), 

2705 ) 

2706 mom_variables = _create_partitioned_variables( 

2707 name=slot_variable_names.mom, 

2708 num_hosts=num_hosts, 

2709 vocabulary_size=table_config.vocabulary_size, 

2710 embedding_dimension=table_config.dimension, 

2711 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2712 initializer=init_ops.zeros_initializer(), 

2713 ) 

2714 slot_variables = RMSPropSlotVariables(ms_variables, mom_variables) 

2715 

2716 def load_ops_fn(): 

2717 """Returns the retrieve ops for RMS Prop embedding tables. 

2718 

2719 Returns: 

2720 A list of ops to load embedding and slot variables from CPU to TPU. 

2721 """ 

2722 load_op_list = [] 

2723 config = config_proto 

2724 for host_id, table_variable, ms_variable, mom_variable in (zip( 

2725 range(num_hosts), table_variables, ms_variables, mom_variables)): 

2726 with ops.colocate_with(table_variable): 

2727 load_parameters_op = tpu_ops.load_tpu_embedding_rms_prop_parameters( 

2728 parameters=table_variable, 

2729 ms=ms_variable, 

2730 mom=mom_variable, 

2731 table_name=table, 

2732 num_shards=num_hosts, 

2733 shard_id=host_id, 

2734 config=config, 

2735 ) 

2736 config = None 

2737 load_op_list.append(load_parameters_op) 

2738 return load_op_list 

2739 

2740 def retrieve_ops_fn(): 

2741 """Returns the retrieve ops for RMS Prop embedding tables. 

2742 

2743 Returns: 

2744 A list of ops to retrieve embedding and slot variables from TPU to CPU. 

2745 """ 

2746 retrieve_op_list = [] 

2747 config = config_proto 

2748 for host_id, table_variable, ms_variable, mom_variable in (zip( 

2749 range(num_hosts), table_variables, ms_variables, mom_variables)): 

2750 with ops.colocate_with(table_variable): 

2751 retrieved_table, retrieved_ms, retrieved_mom = ( 

2752 tpu_ops.retrieve_tpu_embedding_rms_prop_parameters( 

2753 table_name=table, 

2754 num_shards=num_hosts, 

2755 shard_id=host_id, 

2756 config=config, 

2757 )) 

2758 retrieve_parameters_op = control_flow_ops.group( 

2759 state_ops.assign(table_variable, retrieved_table), 

2760 state_ops.assign(ms_variable, retrieved_ms), 

2761 state_ops.assign(mom_variable, retrieved_mom)) 

2762 config = None 

2763 retrieve_op_list.append(retrieve_parameters_op) 

2764 return retrieve_op_list 

2765 

2766 return slot_variables, load_ops_fn, retrieve_ops_fn 

2767 

2768 
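The default slot-variable names for RMSProp are derived purely from the table name via the format strings above; a quick pure-Python check (the table name is made up):

table = 'video_embedding'
print('{}/{}/ms'.format(table, 'RMSProp'))   # video_embedding/RMSProp/ms
print('{}/{}/mom'.format(table, 'RMSProp'))  # video_embedding/RMSProp/mom

The `ms` slot holds the running average of squared gradients and `mom` the momentum accumulator; both are the zero-initialized partitioned variables created by `create_variables_and_ops` above.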

2769class _FrequencyEstimatorHandler(_OptimizerHandler): 

2770 """Handles frequency estimator specific logic.""" 

2771 

2772 def set_optimization_parameters(self, table_descriptor): 

2773 table_descriptor.optimization_parameters.frequency_estimator.SetInParent() 

2774 freq = table_descriptor.optimization_parameters.frequency_estimator 

2775 freq.tau = self._optimization_parameters.tau 

2776 freq.max_delta = self._optimization_parameters.max_delta 

2777 freq.outlier_threshold = self._optimization_parameters.outlier_threshold 

2778 freq.weight_exponent = self._optimization_parameters.weight_exponent 

2779 

2780 def get_default_slot_variable_names(self, table): 

2781 return FrequencyEstimatorSlotVariableNames( 

2782 '{}/FrequencyEstimator'.format(table)) 

2783 

2784 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 

2785 table_config, table_variables, config_proto): 

2786 if table_config.dimension != 1: 

2787 raise ValueError('FrequencyEstimator tables should only have a dimension ' 

2788 'of 1. Received dimension {}'.format( 

2789 table_config.dimension)) 

2790 

2791 last_hit_step_variables = _create_partitioned_variables( 

2792 name=slot_variable_names.last_hit_step, 

2793 num_hosts=num_hosts, 

2794 vocabulary_size=table_config.vocabulary_size, 

2795 embedding_dimension=table_config.dimension, 

2796 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 

2797 initializer=init_ops.zeros_initializer(), 

2798 ) 

2799 slot_variables = FrequencyEstimatorSlotVariables(last_hit_step_variables) 

2800 

2801 def load_ops_fn(): 

2802 """Returns the retrieve ops for Frequency Estimator embedding tables. 

2803 

2804 Returns: 

2805 A list of ops to load embedding and slot variables from CPU to TPU. 

2806 """ 

2807 load_op_list = [] 

2808 config = config_proto 

2809 for host_id, table_variable, last_hit_step_variable in (zip( 

2810 range(num_hosts), table_variables, last_hit_step_variables)): 

2811 with ops.colocate_with(table_variable): 

2812 load_parameters_op = ( 

2813 tpu_ops.load_tpu_embedding_frequency_estimator_parameters( 

2814 parameters=table_variable, 

2815 last_hit_step=last_hit_step_variable, 

2816 table_name=table, 

2817 num_shards=num_hosts, 

2818 shard_id=host_id, 

2819 config=config)) 

2820 config = None 

2821 load_op_list.append(load_parameters_op) 

2822 return load_op_list 

2823 

2824 def retrieve_ops_fn(): 

2825 """Returns the retrieve ops for Frequency Estimator embedding tables. 

2826 

2827 Returns: 

2828 A list of ops to retrieve embedding and slot variables from TPU to CPU. 

2829 """ 

2830 retrieve_op_list = [] 

2831 config = config_proto 

2832 for host_id, table_variable, last_hit_step_variable in (zip( 

2833 range(num_hosts), table_variables, last_hit_step_variables)): 

2834 with ops.colocate_with(table_variable): 

2835 retrieved_table, retrieved_last_hit_step = ( 

2836 tpu_ops.retrieve_tpu_embedding_frequency_estimator_parameters( 

2837 table_name=table, 

2838 num_shards=num_hosts, 

2839 shard_id=host_id, 

2840 config=config, 

2841 )) 

2842 retrieve_parameters_op = control_flow_ops.group( 

2843 state_ops.assign(table_variable, retrieved_table), 

2844 state_ops.assign(last_hit_step_variable, retrieved_last_hit_step)) 

2845 config = None 

2846 retrieve_op_list.append(retrieve_parameters_op) 

2847 return retrieve_op_list 

2848 

2849 return slot_variables, load_ops_fn, retrieve_ops_fn 

2850 
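Unlike the other optimizers, the frequency estimator constrains the table shape: the guard above rejects any table whose dimension is not 1. A TensorFlow-free sketch of the same check with a stub config (the stub type is illustrative):

import collections

TableConfigStub = collections.namedtuple('TableConfigStub',
                                         ['vocabulary_size', 'dimension'])

def check_frequency_estimator_table(table_config):
  if table_config.dimension != 1:
    raise ValueError('FrequencyEstimator tables should only have a dimension '
                     'of 1. Received dimension {}'.format(
                         table_config.dimension))

check_frequency_estimator_table(
    TableConfigStub(vocabulary_size=1000, dimension=1))  # passes
# TableConfigStub(vocabulary_size=1000, dimension=8) would raise ValueError.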

2851 

2852class _StochasticGradientDescentHandler(_OptimizerHandler): 

2853 """Handles stochastic gradient descent specific logic.""" 

2854 

2855 def set_optimization_parameters(self, table_descriptor): 

2856 (table_descriptor.optimization_parameters.stochastic_gradient_descent 

2857 .SetInParent()) 

2858 

2859 def get_default_slot_variable_names(self, table): 

2860 return None 

2861 

2862 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 

2863 table_config, table_variables, config_proto): 

2864 del table_config 

2865 

2866 def load_ops_fn(): 

2867 """Returns the retrieve ops for AdaGrad embedding tables. 

2868 

2869 Returns: 

2870 A list of ops to load embedding and slot variables from CPU to TPU. 

2871 """ 

2872 load_op_list = [] 

2873 config = config_proto 

2874 for host_id, table_variable in enumerate(table_variables): 

2875 with ops.colocate_with(table_variable): 

2876 load_parameters_op = ( 

2877 tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters( 

2878 parameters=table_variable, 

2879 table_name=table, 

2880 num_shards=num_hosts, 

2881 shard_id=host_id, 

2882 config=config)) 

2883 config = None 

2884 load_op_list.append(load_parameters_op) 

2885 return load_op_list 

2886 

2887 def retrieve_ops_fn(): 

2888 """Returns the retrieve ops for SGD embedding tables. 

2889 

2890 Returns: 

2891 A list of ops to retrieve embedding and slot variables from TPU to CPU. 

2892 """ 

2893 retrieve_op_list = [] 

2894 config = config_proto 

2895 for host_id, table_variable in enumerate(table_variables): 

2896 with ops.colocate_with(table_variable): 

2897 retrieved_table = ( 

2898 tpu_ops 

2899 .retrieve_tpu_embedding_stochastic_gradient_descent_parameters( 

2900 table_name=table, 

2901 num_shards=num_hosts, 

2902 shard_id=host_id, 

2903 config=config)) 

2904 retrieve_parameters_op = control_flow_ops.group( 

2905 state_ops.assign(table_variable, retrieved_table)) 

2906 config = None 

2907 retrieve_op_list.append(retrieve_parameters_op) 

2908 return retrieve_op_list 

2909 

2910 return None, load_ops_fn, retrieve_ops_fn 

2911 

2912 

2913def _get_optimization_handler(optimization_parameters): 

2914 """Gets the optimization handler given the parameter type.""" 

2915 if isinstance(optimization_parameters, AdagradParameters): 

2916 return _AdagradHandler(optimization_parameters) 

2917 elif isinstance(optimization_parameters, AdagradMomentumParameters): 

2918 return _AdagradMomentumHandler(optimization_parameters) 

2919 elif isinstance(optimization_parameters, ProximalAdagradParameters): 

2920 return _ProximalAdagradHandler(optimization_parameters) 

2921 elif isinstance(optimization_parameters, AdamParameters): 

2922 return _AdamHandler(optimization_parameters) 

2923 elif isinstance(optimization_parameters, FtrlParameters): 

2924 return _FtrlHandler(optimization_parameters) 

2925 elif isinstance(optimization_parameters, ProximalYogiParameters): 

2926 return _ProximalYogiHandler(optimization_parameters) 

2927 elif isinstance(optimization_parameters, StochasticGradientDescentParameters): 

2928 return _StochasticGradientDescentHandler(optimization_parameters) 

2929 elif isinstance(optimization_parameters, MomentumParameters): 

2930 return _MomentumHandler(optimization_parameters) 

2931 elif isinstance(optimization_parameters, RMSPropParameters): 

2932 return _RMSPropHandler(optimization_parameters) 

2933 elif isinstance(optimization_parameters, FrequencyEstimatorParameters): 

2934 return _FrequencyEstimatorHandler(optimization_parameters) 

2935 raise NotImplementedError('Unsupported optimization parameters type.')

2936 
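`_get_optimization_handler` is a plain isinstance dispatch, which is also why its final statement must raise the `NotImplementedError` rather than return it. The same shape can be exercised without TensorFlow using stand-in classes (all names below are illustrative, not TensorFlow APIs):

class MomentumLikeParameters(object):
  pass

class SgdLikeParameters(object):
  pass

class _MomentumLikeHandler(object):
  def __init__(self, params):
    self._params = params

class _SgdLikeHandler(object):
  def __init__(self, params):
    self._params = params

def get_handler(optimization_parameters):
  """Maps a parameters object to its handler, mirroring the dispatch above."""
  if isinstance(optimization_parameters, MomentumLikeParameters):
    return _MomentumLikeHandler(optimization_parameters)
  elif isinstance(optimization_parameters, SgdLikeParameters):
    return _SgdLikeHandler(optimization_parameters)
  raise NotImplementedError(
      'Unsupported optimization parameters: {}.'.format(
          type(optimization_parameters)))

print(type(get_handler(SgdLikeParameters())).__name__)  # _SgdLikeHandler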

2937 

2938def _create_ordered_dict(d): 

2939 """Create an OrderedDict from Dict.""" 

2940 return collections.OrderedDict((k, d[k]) for k in sorted(d)) 

2941 

2942 

2943def _create_combiners(table_to_config_dict, table_to_features_dict): 

2944 """Create a per feature list of combiners, ordered by table.""" 

2945 combiners = [] 

2946 for table in table_to_config_dict: 

2947 combiner = table_to_config_dict[table].combiner or 'sum' 

2948 combiners.extend([combiner] * len(table_to_features_dict[table])) 

2949 return combiners 

2950 
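Illustrative use of `_create_combiners` (assuming a TensorFlow build that ships this module); the table and feature names are made up, and the stub config only needs the `combiner` attribute read above. A `None` combiner falls back to 'sum', and each table's combiner is repeated once per feature:

import collections
from tensorflow.python.tpu import tpu_embedding

TableStub = collections.namedtuple('TableStub', ['combiner'])

table_to_config = collections.OrderedDict([
    ('video', TableStub(combiner='mean')),
    ('word', TableStub(combiner=None)),
])
table_to_features = {'video': ['liked', 'watched'], 'word': ['query']}

print(tpu_embedding._create_combiners(table_to_config, table_to_features))
# ['mean', 'mean', 'sum']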

2951 

2952def _create_table_to_features_dict(feature_to_config_dict): 

2953 """Create mapping from table to a list of its features.""" 

2954 table_to_features_dict_tmp = {} 

2955 for feature, feature_config in feature_to_config_dict.items(): 

2956 if feature_config.table_id in table_to_features_dict_tmp: 

2957 table_to_features_dict_tmp[feature_config.table_id].append(feature) 

2958 else: 

2959 table_to_features_dict_tmp[feature_config.table_id] = [feature] 

2960 

2961 table_to_features_dict = collections.OrderedDict() 

2962 for table in sorted(table_to_features_dict_tmp): 

2963 table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table]) 

2964 return table_to_features_dict 

2965 
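Illustrative use of `_create_table_to_features_dict` under the same assumption; the stub feature config only needs the `table_id` attribute, and both the tables and their feature lists come back sorted:

import collections
from tensorflow.python.tpu import tpu_embedding

FeatureStub = collections.namedtuple('FeatureStub', ['table_id'])

feature_to_config = {
    'watched': FeatureStub(table_id='video'),
    'liked': FeatureStub(table_id='video'),
    'query': FeatureStub(table_id='word'),
}
print(tpu_embedding._create_table_to_features_dict(feature_to_config))
# OrderedDict([('video', ['liked', 'watched']), ('word', ['query'])])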

2966 

2967def _create_device_fn(hosts): 

2968 """Create device_fn() to use with _create_partitioned_variables().""" 

2969 

2970 def device_fn(op): 

2971 """Returns the `device` for `op`.""" 

2972 part_match = re.match(r'.*/part_(\d+)(/|$)', op.name) 

2973 dummy_match = re.match(r'.*dummy_(\d+).*', op.name) 

2974 if not part_match and not dummy_match: 

2975 raise RuntimeError( 

2976 'Internal Error: Expected {} to contain /part_* or dummy_*'.format( 

2977 op.name)) 

2978 

2979 if part_match: 

2980 idx = int(part_match.group(1)) 

2981 else: 

2982 idx = int(dummy_match.group(1)) # pytype: disable=attribute-error 

2983 

2984 device = hosts[idx] 

2985 logging.debug('assigning %s to %s.', op, device)

2986 return device 

2987 

2988 return device_fn 

2989 
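The returned `device_fn` keys entirely off the variable op name: real shards are routed by their `part_<N>` component and padding shards by their `dummy_<N>` prefix. A small sketch, again assuming this module is importable; `_OpStub` is an illustrative stand-in, since only `.name` is read:

from tensorflow.python.tpu import tpu_embedding

hosts = ['/job:worker/task:0/device:CPU:0',
         '/job:worker/task:1/device:CPU:0']
device_fn = tpu_embedding._create_device_fn(hosts)

class _OpStub(object):
  def __init__(self, name):
    self.name = name

print(device_fn(_OpStub('embedding/table_0/part_1/Assign')))  # hosts[1]
print(device_fn(_OpStub('dummy_1_table_0/Initializer')))      # hosts[1]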

2990 

2991def _create_partitioned_variables(name, 

2992 num_hosts, 

2993 vocabulary_size, 

2994 embedding_dimension, 

2995 initializer, 

2996 collections=None): # pylint: disable=redefined-outer-name 

2997 """Creates PartitionedVariables based on `num_hosts` for `table`.""" 

2998 

2999 num_slices = min(vocabulary_size, num_hosts) 

3000 

3001 var_list = list( 

3002 variable_scope.get_variable( 

3003 name, 

3004 shape=(vocabulary_size, embedding_dimension), 

3005 partitioner=partitioned_variables.fixed_size_partitioner(num_slices), 

3006 dtype=dtypes.float32, 

3007 initializer=initializer, 

3008 collections=collections, 

3009 trainable=False)) 

3010 

3011 if vocabulary_size >= num_hosts: 

3012 return var_list 

3013 

3014 # For padded part, define the dummy variable to be loaded into TPU system. 

3015 for idx in range(num_hosts - vocabulary_size): 

3016 var_list.append( 

3017 variable_scope.get_variable( 

3018 'dummy_{}_{}'.format(vocabulary_size + idx, name), 

3019 shape=(1, embedding_dimension), 

3020 dtype=dtypes.float32, 

3021 initializer=initializer, 

3022 collections=[ops.GraphKeys.LOCAL_VARIABLES], 

3023 trainable=False)) 

3024 

3025 return var_list
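The padding logic above guarantees that every host receives a variable to load into the TPU system, even when the vocabulary has fewer rows than there are hosts. A TensorFlow-free arithmetic sketch of that decision (the helper name is illustrative):

def shard_plan(vocabulary_size, num_hosts, embedding_dimension):
  """Summarizes how _create_partitioned_variables would split a table."""
  return {
      'real_shards': min(vocabulary_size, num_hosts),
      'dummy_shards': max(0, num_hosts - vocabulary_size),
      'dummy_shape': (1, embedding_dimension),
  }

print(shard_plan(vocabulary_size=3, num_hosts=5, embedding_dimension=8))
# {'real_shards': 3, 'dummy_shards': 2, 'dummy_shape': (1, 8)}

The real shards go into GLOBAL_VARIABLES, while the dummy padding shards are placed in LOCAL_VARIABLES and exist only to give every host something to load.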