Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_embedding.py: 19%
894 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TPU embedding APIs."""
17import collections
18import copy
19import math
20import re
21from typing import Optional
23from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
24from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
25from tensorflow.python.eager import context
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import init_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import partitioned_variables
33from tensorflow.python.ops import state_ops
34from tensorflow.python.ops import variable_scope
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
37from tensorflow.python.tpu.ops import tpu_ops
38from tensorflow.python.util.tf_export import tf_export
40TRAINING = elc.TPUEmbeddingConfiguration.TRAINING
41INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE
44# TODO(shizhiw): a more future-proof way is to have optimization_parameter such
45# as AdagradParameters etc instead of learning_rate.
46class TableConfig(
47 collections.namedtuple('TableConfig', [
48 'vocabulary_size',
49 'dimension',
50 'initializer',
51 'combiner',
52 'hot_id_replication',
53 'learning_rate',
54 'learning_rate_fn',
55 'optimization_parameters',
56 ])):
57 """Embedding table configuration."""
59 def __new__(cls,
60 vocabulary_size,
61 dimension,
62 initializer=None,
63 combiner='mean',
64 hot_id_replication=False,
65 learning_rate=None,
66 learning_rate_fn=None,
67 optimization_parameters=None):
68 """Embedding table configuration.
70 Args:
71 vocabulary_size: Number of rows (vocabulary entries) in the table.
72 dimension: The embedding dimension.
73 initializer: A variable initializer function to be used in embedding
74 variable initialization. If not specified, defaults to
75 `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and standard
76 deviation `1/sqrt(dimension)`.
77 combiner: A string specifying how to reduce if there are multiple entries
78 in a single row. Currently 'mean', 'sqrtn', 'sum' and None are
79 supported, with 'mean' the default. 'sqrtn' often achieves good
80 accuracy, in particular with bag-of-words columns. For more information,
81 see `tf.nn.embedding_lookup_sparse`. None is only valid for dense rather
82 than sparse tensors.
83 hot_id_replication: If true, enables hot id replication, which can make
84 embedding lookups faster if there are some hot rows in the table.
85 learning_rate: float, static learning rate for this table. If
86 learning_rate and learning_rate_fn are both `None`, static learning rate
87 as specified in local `optimization_parameters` will be used. In case
88 local `optimization_parameters` is `None`, global
89 `optimization_parameters` in `TPUEmbedding` constructor will be used.
90 `learning_rate_fn` must be `None` if `learning_rate` is not `None`.
91 learning_rate_fn: string, use dynamic learning rate given by the function.
92 This function will be passed the current global step. If learning_rate
93 and learning_rate_fn are both `None`, static learning rate as specified
94 in `optimization_parameters` is used. `learning_rate` must be `None` if
95 `learning_rate_fn` is not `None`.
96 optimization_parameters: `AdagradParameters`, `AdamParameters`,
97 `StochasticGradientDescentParameters`. Specifies the table-level optimizer.
98 If it is `None`, the global optimizer in the `TPUEmbedding` constructor is used.
100 Returns:
101 `TableConfig`.
103 Raises:
104 ValueError: if `vocabulary_size` is not positive integer.
105 ValueError: if `dimension` is not positive integer.
106 ValueError: if `initializer` is specified and is not callable.
107 ValueError: if `combiner` is not supported.
108 ValueError: if `learning_rate` and `learning_rate_fn` are both not
109 `None`.
110 """
111 if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
112 raise ValueError(f'vocabulary_size must be a positive int. '
113 f'Received: {vocabulary_size}.')
115 if not isinstance(dimension, int) or dimension < 1:
116 raise ValueError(
117 f'dimension must be a positive int. Received: {dimension}.')
119 if (initializer is not None) and (not callable(initializer)):
120 raise ValueError(f'initializer must be callable if specified. '
121 f'Received: {initializer}.')
122 if initializer is None:
123 initializer = init_ops.truncated_normal_initializer(
124 mean=0.0, stddev=1 / math.sqrt(dimension))
126 if combiner not in ('mean', 'sum', 'sqrtn', None):
127 raise ValueError(f'combiner must be "mean", "sum", "sqrtn" or None. '
128 f'Received: {combiner}.')
130 if learning_rate is not None and learning_rate_fn is not None:
131 raise ValueError('At most one of learning_rate and learning_rate_fn '
132 'can be specified. Received: {} and {}'.format(
133 learning_rate, learning_rate_fn))
135 if optimization_parameters is not None:
136 if not isinstance(optimization_parameters, _OptimizationParameters):
137 raise ValueError(f'`optimization_parameters` must inherit from '
138 f'`_OptimizationParameters`. '
139 f'Received: `type(optimization_parameters)`='
140 f'{type(optimization_parameters)}.')
142 return super().__new__(cls, vocabulary_size, dimension, initializer,
143 combiner, hot_id_replication, learning_rate,
144 learning_rate_fn, optimization_parameters)
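# Editor's note: a minimal, hypothetical usage sketch of TableConfig based on the
# constructor above; the function name and the concrete sizes/learning rate are
# illustrative and not part of the original module.
def _example_table_config():
  """Builds an example embedding table configuration."""
  return TableConfig(
      vocabulary_size=100000,  # number of rows in the table
      dimension=64,            # width of each embedding vector
      combiner='mean',         # average multiple ids looked up per sample
      learning_rate=0.05)      # static per-table learning rate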
147class FeatureConfig(
148 collections.namedtuple('FeatureConfig',
149 ['table_id', 'max_sequence_length', 'weight_key'])):
150 """Feature configuration."""
152 def __new__(cls, table_id, max_sequence_length=0, weight_key=None):
153 """Feature configuration.
155 Args:
156 table_id: Which table the feature uses for embedding lookups.
157 max_sequence_length: If positive, the feature is a sequence feature with
158 the corresponding maximum sequence length. If the sequence is longer
159 than this, it will be truncated. If 0, the feature is not a sequence
160 feature.
161 weight_key: If using weights for the combiner, this key specifies which
162 input feature contains the weights.
164 Returns:
165 `FeatureConfig`.
167 Raises:
168 ValueError: if `max_sequence_length` is not an integer or is negative.
169 """
170 if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
171 raise ValueError(f'max_sequence_length must be zero or a positive int, '
172 f'got {max_sequence_length}.')
174 return super().__new__(cls, table_id, max_sequence_length, weight_key)
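# Editor's note: a hypothetical sketch showing how features reference tables via
# FeatureConfig; the table and feature names below are illustrative only.
def _example_feature_configs():
  """Two features sharing one 'video' table; the second is a sequence feature."""
  return {
      'watched': FeatureConfig(table_id='video'),
      'watch_history': FeatureConfig(table_id='video', max_sequence_length=16),
  }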
177class EnqueueData(
178 collections.namedtuple(
179 'EnqueueData',
180 ['embedding_indices', 'sample_indices', 'aggregation_weights'])):
181 """Data to be enqueued through generate_enqueue_ops()."""
183 def __new__(cls,
184 embedding_indices,
185 sample_indices=None,
186 aggregation_weights=None):
187 """Data to be enqueued through generate_enqueue_ops().
189 Args:
190 embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
191 corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
192 and int64 are allowed and will be converted to int32 internally.
193 sample_indices: A rank 2 Tensor specifying the training example to which
194 the corresponding embedding_indices and aggregation_weights values
195 belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
196 If it is None, we assume each embedding_indices belongs to a different
197 sample. Both int32 and int64 are allowed and will be converted to int32
198 internally.
199 aggregation_weights: A rank 1 Tensor containing aggregation weights. It
200 corresponds to sp_weights.values in embedding_lookup_sparse(). If it is
201 None, we assume all weights are 1. Both float32 and float64 are allowed
202 and will be converted to float32 internally.
204 Returns:
205 An EnqueueData tuple.
207 """
208 return super().__new__(cls, embedding_indices, sample_indices,
209 aggregation_weights)
211 @staticmethod
212 def from_sparse_tensor(sp_tensor, weights=None):
213 return EnqueueData(
214 sp_tensor.values,
215 sp_tensor.indices,
216 aggregation_weights=weights.values if weights is not None else None)
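# Editor's note: an illustrative sketch of wrapping a SparseTensor as EnqueueData;
# the ids and shapes are arbitrary example values.
def _example_enqueue_data():
  """EnqueueData for a batch of 2 samples containing ids [3, 7] and [7]."""
  from tensorflow.python.framework import sparse_tensor
  sp = sparse_tensor.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],  # (sample, position) of each id
      values=[3, 7, 7],                  # ids to look up in the table
      dense_shape=[2, 2])
  return EnqueueData.from_sparse_tensor(sp)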
219class RaggedEnqueueData(
220 collections.namedtuple(
221 'RaggedEnqueueData',
222 ['embedding_indices', 'row_splits', 'aggregation_weights'])):
223 """RaggedTensor Data to be enqueued through generate_enqueue_ops()."""
225 def __new__(cls,
226 embedding_indices,
227 row_splits=None,
228 aggregation_weights=None):
229 """Data to be enqueued through generate_enqueue_ops().
231 Args:
232 embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
233 corresponds to ids.values in embedding_lookup(), when ids is a
234 RaggedTensor. Both int32 and int64 are allowed and will be converted to
235 int32 internally.
236 row_splits: A rank 1 Tensor specifying the length of the break points for
237 splitting embedding_indices and aggregation_weights. It corresponds to
238 ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both
239 int32 and int64 are allowed and will be converted to int32 internally.
240 aggregation_weights: A rank 1 Tensor containing per training example
241 aggregation weights. It corresponds to the values field of a
242 RaggedTensor with the same row_splits as ids in embedding_lookup(), when
243 ids is a RaggedTensor.
245 Returns:
246 A RaggedEnqueueData tuple.
248 """
249 return super().__new__(cls, embedding_indices, row_splits,
250 aggregation_weights)
252 @staticmethod
253 def from_ragged_tensor(rg_tensor, weights=None):
254 return RaggedEnqueueData(
255 rg_tensor.values,
256 rg_tensor.row_splits,
257 aggregation_weights=weights.values if weights is not None else None)
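# Editor's note: the RaggedTensor counterpart of the sketch above; the values are
# arbitrary example ids.
def _example_ragged_enqueue_data():
  """RaggedEnqueueData for 2 samples containing ids [3, 7] and [7]."""
  from tensorflow.python.ops.ragged import ragged_tensor
  rt = ragged_tensor.RaggedTensor.from_row_splits(
      values=[3, 7, 7], row_splits=[0, 2, 3])
  return RaggedEnqueueData.from_ragged_tensor(rt)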
260def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
261 """Convenient function for generate_enqueue_ops().
263 Args:
264 sp_tensors_list: a list of dictionary mapping from string of feature names
265 to SparseTensor. Each dictionary is for one TPU core. Dictionaries for the
266 same host should be contiguous on the list.
268 Returns:
269 enqueue_datas_list: a list of dictionary mapping from string
270 of feature names to EnqueueData. Each dictionary is for one
271 TPU core. Dictionaries for the same host should be contiguous
272 on the list.
274 """
275 enqueue_datas_list = []
276 for sp_tensors in sp_tensors_list:
277 enqueue_datas = collections.OrderedDict(
278 (k, EnqueueData.from_sparse_tensor(v)) for k, v in sp_tensors.items())
279 enqueue_datas_list.append(enqueue_datas)
280 return enqueue_datas_list
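# Editor's note: hypothetical sketch of preparing the per-core input structure;
# `sparse_features_fn` is an assumed callable that returns a dict mapping feature
# names to SparseTensors for one core.
def _example_enqueue_datas_list(sparse_features_fn, num_cores):
  """Builds one feature dict per core and converts them all to EnqueueData."""
  sp_tensors_list = [sparse_features_fn(core) for core in range(num_cores)]
  return get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list)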
283def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list):
284 """Convenient function for generate_enqueue_ops().
286 Args:
287 rg_tensors_list: a list of dictionary mapping from string of feature names
288 to RaggedTensor. Each dictionary is for one TPU core. Dictionaries for the
289 same host should be contiguous on the list.
291 Returns:
292 enqueue_datas_list: a list of dictionary mapping from string
293 of feature names to RaggedEnqueueData. Each dictionary is for one
294 TPU core. Dictionaries for the same host should be contiguous
295 on the list.
297 """
298 enqueue_datas_list = []
299 for rg_tensors in rg_tensors_list:
300 enqueue_datas = collections.OrderedDict(
301 (k, RaggedEnqueueData.from_ragged_tensor(v))
302 for k, v in rg_tensors.items())
303 enqueue_datas_list.append(enqueue_datas)
304 return enqueue_datas_list
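# Editor's note: ragged counterpart of the sketch above; each per-core dict maps
# feature names to RaggedTensors instead of SparseTensors. `ragged_features_fn`
# is an assumed callable.
def _example_ragged_enqueue_datas_list(ragged_features_fn, num_cores):
  rg_tensors_list = [ragged_features_fn(core) for core in range(num_cores)]
  return get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list)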
307AdamSlotVariableNames = collections.namedtuple('AdamSlotVariableNames',
308 ['m', 'v'])
310AdagradSlotVariableNames = collections.namedtuple('AdagradSlotVariableNames',
311 ['accumulator'])
313MomentumSlotVariableNames = collections.namedtuple('MomentumSlotVariableNames',
314 ['momenta'])
316AdagradMomentumSlotVariableNames = collections.namedtuple(
317 'AdagradMomentumSlotVariableNames', ['accumulator', 'momenta'])
319RMSPropSlotVariableNames = collections.namedtuple('RMSPropSlotVariableNames',
320 ['ms', 'mom'])
322ProximalAdagradSlotVariableNames = collections.namedtuple(
323 'ProximalAdagradSlotVariableNames', ['accumulator'])
325FtrlSlotVariableNames = collections.namedtuple('FtrlSlotVariableNames',
326 ['accumulator', 'linear'])
328ProximalYogiSlotVariableNames = collections.namedtuple(
329 'ProximalYogiSlotVariableNames', ['v', 'm'])
331FrequencyEstimatorSlotVariableNames = collections.namedtuple(
332 'FrequencyEstimatorSlotVariableNames', ['last_hit_step'])
334AdamSlotVariables = collections.namedtuple('AdamSlotVariables', ['m', 'v'])
336MomentumSlotVariables = collections.namedtuple('MomentumSlotVariables',
337 ['momenta'])
339AdagradMomentumSlotVariables = collections.namedtuple(
340 'AdagradMomentumSlotVariables', ['accumulator', 'momenta'])
342RMSPropSlotVariables = collections.namedtuple('RMSPropSlotVariables',
343 ['ms', 'mom'])
345AdagradSlotVariables = collections.namedtuple('AdagradSlotVariables',
346 ['accumulator'])
348ProximalAdagradSlotVariables = collections.namedtuple(
349 'ProximalAdagradSlotVariables', ['accumulator'])
351FtrlSlotVariable = collections.namedtuple('FtrlSlotVariable',
352 ['accumulator', 'linear'])
354ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables',
355 ['v', 'm'])
357FrequencyEstimatorSlotVariables = collections.namedtuple(
358 'FrequencyEstimatorSlotVariables', ['last_hit_step'])
360VariablesAndOps = collections.namedtuple('VariablesAndOps', [
361 'embedding_variables_by_table', 'slot_variables_by_table', 'load_ops',
362 'retrieve_ops'
363])
366class _OptimizationParameters:
367 """Parameters common to all optimizations."""
369 def __init__(
370 self,
371 learning_rate: float,
372 use_gradient_accumulation: bool,
373 clip_weight_min: Optional[float],
374 clip_weight_max: Optional[float],
375 weight_decay_factor: Optional[float],
376 multiply_weight_decay_factor_by_learning_rate: Optional[bool],
377 clip_gradient_min: Optional[float] = None,
378 clip_gradient_max: Optional[float] = None,
379 ):
380 self.learning_rate = learning_rate
381 self.use_gradient_accumulation = use_gradient_accumulation
382 self.clip_weight_min = clip_weight_min
383 self.clip_weight_max = clip_weight_max
384 self.weight_decay_factor = weight_decay_factor
385 self.multiply_weight_decay_factor_by_learning_rate = (
386 multiply_weight_decay_factor_by_learning_rate)
387 self.clip_gradient_min = clip_gradient_min
388 self.clip_gradient_max = clip_gradient_max
390 if not use_gradient_accumulation and (clip_gradient_min is not None or
391 clip_gradient_max is not None):
392 raise ValueError('When using gradient clipping limits, gradient '
393 'accumulation must be enabled.')
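# Editor's note: an illustrative sketch of the constraint enforced above; gradient
# clipping limits are only accepted while gradient accumulation stays enabled.
# AdagradParameters is defined further down in this module, so the reference
# resolves at call time.
def _example_gradient_clipping():
  """Valid combination: clipping with accumulation enabled. Setting
  use_gradient_accumulation=False here would trigger the ValueError above."""
  return AdagradParameters(
      learning_rate=0.1,
      use_gradient_accumulation=True,
      clip_gradient_min=-10.0,
      clip_gradient_max=10.0)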
396@tf_export(v1=['tpu.experimental.AdagradParameters'])
397class AdagradParameters(_OptimizationParameters):
398 """Optimization parameters for Adagrad with TPU embeddings.
400 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
401 `optimization_parameters` argument to set the optimizer and its parameters.
402 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
403 for more details.
405 ```
406 estimator = tf.estimator.tpu.TPUEstimator(
407 ...
408 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
409 ...
410 optimization_parameters=tf.tpu.experimental.AdagradParameters(0.1),
411 ...))
412 ```
414 """
416 def __init__(
417 self,
418 learning_rate: float,
419 initial_accumulator: float = 0.1,
420 use_gradient_accumulation: bool = True,
421 clip_weight_min: Optional[float] = None,
422 clip_weight_max: Optional[float] = None,
423 weight_decay_factor: Optional[float] = None,
424 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
425 clip_gradient_min: Optional[float] = None,
426 clip_gradient_max: Optional[float] = None,
427 ):
428 """Optimization parameters for Adagrad.
430 Args:
431 learning_rate: used for updating embedding table.
432 initial_accumulator: initial accumulator for Adagrad.
433 use_gradient_accumulation: setting this to `False` makes embedding
434 gradients calculation less accurate but faster. Please see
435 `optimization_parameters.proto` for details.
436 clip_weight_min: the minimum value to clip by; None means -infinity.
437 clip_weight_max: the maximum value to clip by; None means +infinity.
438 weight_decay_factor: amount of weight decay to apply; None means that the
439 weights are not decayed.
440 multiply_weight_decay_factor_by_learning_rate: if true,
441 `weight_decay_factor` is multiplied by the current learning rate.
442 clip_gradient_min: the minimum value to clip by; None means -infinity.
443 Gradient accumulation must be set to true if this is set.
444 clip_gradient_max: the maximum value to clip by; None means +infinity.
445 Gradient accumulation must be set to true if this is set.
446 """
447 super().__init__(
448 learning_rate=learning_rate,
449 use_gradient_accumulation=use_gradient_accumulation,
450 clip_weight_min=clip_weight_min,
451 clip_weight_max=clip_weight_max,
452 weight_decay_factor=weight_decay_factor,
453 multiply_weight_decay_factor_by_learning_rate=(
454 multiply_weight_decay_factor_by_learning_rate),
455 clip_gradient_min=clip_gradient_min,
456 clip_gradient_max=clip_gradient_max,
457 )
458 if initial_accumulator <= 0:
459 raise ValueError(
460 f'Adagrad initial_accumulator must be greater than zero. '
461 f'Received: {initial_accumulator}.')
462 self.initial_accumulator = initial_accumulator
465class AdagradMomentumParameters(_OptimizationParameters):
466 """Optimization parameters for Adagrad + Momentum with TPU embeddings.
468 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
469 `optimization_parameters` argument to set the optimizer and its parameters.
470 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
471 for more details.
473 ```
474 estimator = tf.estimator.tpu.TPUEstimator(
475 ...
476 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
477 ...
478 optimization_parameters=tf.tpu.experimental.AdagradMomentumParameters(0.1, momentum=0.9),
479 ...))
480 ```
482 """
484 def __init__(
485 self,
486 learning_rate: float,
487 momentum: float,
488 use_nesterov: bool = False,
489 exponent: float = 2,
490 beta2: float = 1,
491 epsilon: float = 1e-10,
492 use_gradient_accumulation: bool = True,
493 clip_weight_min: Optional[float] = None,
494 clip_weight_max: Optional[float] = None,
495 weight_decay_factor: Optional[float] = None,
496 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
497 clip_gradient_min: Optional[float] = None,
498 clip_gradient_max: Optional[float] = None,
499 ):
500 """Optimization parameters for Adagrad.
502 Args:
503 learning_rate: used for updating embedding table.
504 momentum: Moving average parameter for the momentum accumulator.
505 use_nesterov: Whether to use the Nesterov variant of momentum. See
506 Sutskever et al., 2013.
507 exponent: Exponent for the Adagrad accumulator.
508 beta2: Moving average parameter for the Adagrad accumulator.
509 epsilon: A small constant for numerical stability.
510 use_gradient_accumulation: setting this to `False` makes embedding
511 gradients calculation less accurate but faster. Please see
512 `optimization_parameters.proto` for details.
513 clip_weight_min: the minimum value to clip by; None means -infinity.
514 clip_weight_max: the maximum value to clip by; None means +infinity.
515 weight_decay_factor: amount of weight decay to apply; None means that the
516 weights are not decayed.
517 multiply_weight_decay_factor_by_learning_rate: if true,
518 `weight_decay_factor` is multiplied by the current learning rate.
519 clip_gradient_min: the minimum value to clip by; None means -infinity.
520 Gradient accumulation must be set to true if this is set.
521 clip_gradient_max: the maximum value to clip by; None means +infinity.
522 Gradient accumulation must be set to true if this is set.
523 """
524 super().__init__(
525 learning_rate=learning_rate,
526 use_gradient_accumulation=use_gradient_accumulation,
527 clip_weight_min=clip_weight_min,
528 clip_weight_max=clip_weight_max,
529 weight_decay_factor=weight_decay_factor,
530 multiply_weight_decay_factor_by_learning_rate=(
531 multiply_weight_decay_factor_by_learning_rate),
532 clip_gradient_min=clip_gradient_min,
533 clip_gradient_max=clip_gradient_max,
534 )
535 if epsilon <= 0:
536 raise ValueError('Adagrad momentum: epsilon must be positive')
537 if exponent <= 0:
538 raise ValueError('Adagrad momentum: precondition exponent must be positive.')
539 self.momentum = momentum
540 self.use_nesterov = use_nesterov
541 self.exponent = exponent
542 self.beta2 = beta2
543 self.epsilon = epsilon
546class ProximalAdagradParameters(_OptimizationParameters):
547 """Optimization parameters for ProximalAdagrad with TPU embeddings.
549 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
550 `optimization_parameters` argument to set the optimizer and its parameters.
551 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
552 for more details.
553 """
555 def __init__(
556 self,
557 learning_rate: float,
558 initial_accumulator: float = 0.1,
559 l1_regularization_strength: float = 0.0,
560 l2_regularization_strength: float = 0.0,
561 use_gradient_accumulation: bool = True,
562 clip_weight_min: Optional[float] = None,
563 clip_weight_max: Optional[float] = None,
564 weight_decay_factor: Optional[float] = None,
565 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
566 clip_gradient_min: Optional[float] = None,
567 clip_gradient_max: Optional[float] = None,
568 ):
569 """Optimization parameters for Adagrad.
571 Args:
572 learning_rate: used for updating embedding table.
573 initial_accumulator: initial accumulator for Adagrad.
574 l1_regularization_strength: A float value, must be greater than or equal
575 to zero.
576 l2_regularization_strength: A float value, must be greater than or equal
577 to zero.
578 use_gradient_accumulation: setting this to `False` makes embedding
579 gradients calculation less accurate but faster. Please see
580 `optimization_parameters.proto` for details.
581 clip_weight_min: the minimum value to clip by; None means -infinity.
582 clip_weight_max: the maximum value to clip by; None means +infinity.
583 weight_decay_factor: amount of weight decay to apply; None means that the
584 weights are not decayed.
585 multiply_weight_decay_factor_by_learning_rate: if true,
586 `weight_decay_factor` is multiplied by the current learning rate.
587 clip_gradient_min: the minimum value to clip by; None means -infinity.
588 Gradient accumulation must be set to true if this is set.
589 clip_gradient_max: the maximum value to clip by; None means +infinity.
590 Gradient accumulation must be set to true if this is set.
591 """
592 super().__init__(
593 learning_rate=learning_rate,
594 use_gradient_accumulation=use_gradient_accumulation,
595 clip_weight_min=clip_weight_min,
596 clip_weight_max=clip_weight_max,
597 weight_decay_factor=weight_decay_factor,
598 multiply_weight_decay_factor_by_learning_rate=(
599 multiply_weight_decay_factor_by_learning_rate),
600 clip_gradient_min=clip_gradient_min,
601 clip_gradient_max=clip_gradient_max,
602 )
603 if initial_accumulator <= 0:
604 raise ValueError(f'ProximalAdagrad initial_accumulator must be positive. '
605 f'Received: {initial_accumulator}.')
606 if l1_regularization_strength < 0.:
607 raise ValueError('l1_regularization_strength must be greater than or '
608 'equal to 0. got {}.'.format(l1_regularization_strength))
610 if l2_regularization_strength < 0.:
611 raise ValueError('l2_regularization_strength must be greater than or '
612 'equal to 0. got {}.'.format(l2_regularization_strength))
614 self.initial_accumulator = initial_accumulator
615 self.l1_regularization_strength = l1_regularization_strength
616 self.l2_regularization_strength = l2_regularization_strength
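# Editor's note: hypothetical usage sketch for ProximalAdagradParameters (this
# class has no example in its docstring); the regularization strengths are
# illustrative values only.
def _example_proximal_adagrad():
  """ProximalAdagrad with L1/L2 regularization applied to the embedding weights."""
  return ProximalAdagradParameters(
      learning_rate=0.1,
      initial_accumulator=0.1,
      l1_regularization_strength=0.001,
      l2_regularization_strength=0.0001)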
619@tf_export(v1=['tpu.experimental.AdamParameters'])
620class AdamParameters(_OptimizationParameters):
621 """Optimization parameters for Adam with TPU embeddings.
623 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
624 `optimization_parameters` argument to set the optimizer and its parameters.
625 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
626 for more details.
628 ```
629 estimator = tf.estimator.tpu.TPUEstimator(
630 ...
631 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
632 ...
633 optimization_parameters=tf.tpu.experimental.AdamParameters(0.1),
634 ...))
635 ```
637 """
639 def __init__(
640 self,
641 learning_rate: float,
642 beta1: float = 0.9,
643 beta2: float = 0.999,
644 epsilon: float = 1e-08,
645 lazy_adam: bool = True,
646 sum_inside_sqrt: bool = True,
647 use_gradient_accumulation: bool = True,
648 clip_weight_min: Optional[float] = None,
649 clip_weight_max: Optional[float] = None,
650 weight_decay_factor: Optional[float] = None,
651 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
652 clip_gradient_min: Optional[float] = None,
653 clip_gradient_max: Optional[float] = None,
654 ):
655 """Optimization parameters for Adam.
657 Args:
658 learning_rate: a floating point value. The learning rate.
659 beta1: A float value. The exponential decay rate for the 1st moment
660 estimates.
661 beta2: A float value. The exponential decay rate for the 2nd moment
662 estimates.
663 epsilon: A small constant for numerical stability.
664 lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. See
665 `optimization_parameters.proto` for details.
666 sum_inside_sqrt: This improves training speed. Please see
667 `optimization_parameters.proto` for details.
668 use_gradient_accumulation: setting this to `False` makes embedding
669 gradients calculation less accurate but faster. Please see
670 `optimization_parameters.proto` for details.
671 clip_weight_min: the minimum value to clip by; None means -infinity.
672 clip_weight_max: the maximum value to clip by; None means +infinity.
673 weight_decay_factor: amount of weight decay to apply; None means that the
674 weights are not decayed.
675 multiply_weight_decay_factor_by_learning_rate: if true,
676 `weight_decay_factor` is multiplied by the current learning rate.
677 clip_gradient_min: the minimum value to clip by; None means -infinity.
678 Gradient accumulation must be set to true if this is set.
679 clip_gradient_max: the maximum value to clip by; None means +infinity.
680 Gradient accumulation must be set to true if this is set.
681 """
682 super().__init__(
683 learning_rate=learning_rate,
684 use_gradient_accumulation=use_gradient_accumulation,
685 clip_weight_min=clip_weight_min,
686 clip_weight_max=clip_weight_max,
687 weight_decay_factor=weight_decay_factor,
688 multiply_weight_decay_factor_by_learning_rate=(
689 multiply_weight_decay_factor_by_learning_rate),
690 clip_gradient_min=clip_gradient_min,
691 clip_gradient_max=clip_gradient_max,
692 )
693 if beta1 < 0. or beta1 >= 1.:
694 raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
695 if beta2 < 0. or beta2 >= 1.:
696 raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
697 if epsilon <= 0.:
698 raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
699 if not use_gradient_accumulation and not lazy_adam:
700 raise ValueError(
701 'When disabling Lazy Adam, gradient accumulation must be used.')
703 self.beta1 = beta1
704 self.beta2 = beta2
705 self.epsilon = epsilon
706 self.lazy_adam = lazy_adam
707 self.sum_inside_sqrt = sum_inside_sqrt
710@tf_export(v1=['tpu.experimental.FtrlParameters'])
711class FtrlParameters(_OptimizationParameters):
712 """Optimization parameters for Ftrl with TPU embeddings.
714 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
715 `optimization_parameters` argument to set the optimizer and its parameters.
716 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
717 for more details.
719 ```
720 estimator = tf.estimator.tpu.TPUEstimator(
721 ...
722 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
723 ...
724 optimization_parameters=tf.tpu.experimental.FtrlParameters(0.1),
725 ...))
726 ```
728 """
730 def __init__(
731 self,
732 learning_rate: float,
733 learning_rate_power: float = -0.5,
734 initial_accumulator_value: float = 0.1,
735 l1_regularization_strength: float = 0.0,
736 l2_regularization_strength: float = 0.0,
737 use_gradient_accumulation: bool = True,
738 clip_weight_min: Optional[float] = None,
739 clip_weight_max: Optional[float] = None,
740 weight_decay_factor: Optional[float] = None,
741 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
742 multiply_linear_by_learning_rate: bool = False,
743 beta: float = 0,
744 allow_zero_accumulator: bool = False,
745 clip_gradient_min: Optional[float] = None,
746 clip_gradient_max: Optional[float] = None,
747 ):
748 """Optimization parameters for Ftrl.
750 Implements FTRL as described in the following [paper](
751 https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf)
753 Args:
754 learning_rate: a floating point value. The learning rate.
755 learning_rate_power: A float value, must be less or equal to zero.
756 Controls how the learning rate decreases during training. Use zero for a
757 fixed learning rate. See section 3.1 in the
758 [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
759 initial_accumulator_value: The starting value for accumulators. Only zero
760 or positive values are allowed.
761 l1_regularization_strength: A float value, must be greater than or equal
762 to zero.
763 l2_regularization_strength: A float value, must be greater than or equal
764 to zero.
765 use_gradient_accumulation: setting this to `False` makes embedding
766 gradients calculation less accurate but faster. Please see
767 `optimization_parameters.proto` for details.
768 clip_weight_min: the minimum value to clip by; None means -infinity.
769 clip_weight_max: the maximum value to clip by; None means +infinity.
770 weight_decay_factor: amount of weight decay to apply; None means that the
771 weights are not decayed.
772 multiply_weight_decay_factor_by_learning_rate: if true,
773 `weight_decay_factor` is multiplied by the current learning rate.
774 multiply_linear_by_learning_rate: When true, multiplies the usages of the
775 linear slot in the weight update by the learning rate. This is useful
776 when ramping up learning rate from 0 (which would normally produce
777 NaNs).
778 beta: The beta parameter for FTRL.
779 allow_zero_accumulator: Changes the implementation of the square root to
780 allow for the case of initial_accumulator_value being zero. This will
781 cause a slight performance drop.
782 clip_gradient_min: the minimum value to clip by; None means -infinity.
783 Gradient accumulation must be set to true if this is set.
784 clip_gradient_max: the maximum value to clip by; None means +infinity.
785 Gradient accumulation must be set to true if this is set.
786 """
787 super().__init__(
788 learning_rate=learning_rate,
789 use_gradient_accumulation=use_gradient_accumulation,
790 clip_weight_min=clip_weight_min,
791 clip_weight_max=clip_weight_max,
792 weight_decay_factor=weight_decay_factor,
793 multiply_weight_decay_factor_by_learning_rate=(
794 multiply_weight_decay_factor_by_learning_rate),
795 clip_gradient_min=clip_gradient_min,
796 clip_gradient_max=clip_gradient_max,
797 )
798 if learning_rate_power > 0.:
799 raise ValueError('learning_rate_power must be less than or equal to 0. '
800 'got {}.'.format(learning_rate_power))
802 if initial_accumulator_value < 0.:
803 raise ValueError('initial_accumulator_value must be greater than or equal'
804 ' to 0. got {}.'.format(initial_accumulator_value))
806 if l1_regularization_strength < 0.:
807 raise ValueError('l1_regularization_strength must be greater than or '
808 'equal to 0. got {}.'.format(l1_regularization_strength))
810 if l2_regularization_strength < 0.:
811 raise ValueError('l2_regularization_strength must be greater than or '
812 'equal to 0. got {}.'.format(l2_regularization_strength))
814 self.learning_rate_power = learning_rate_power
815 self.initial_accumulator_value = initial_accumulator_value
816 self.initial_linear_value = 0.0
817 self.l1_regularization_strength = l1_regularization_strength
818 self.l2_regularization_strength = l2_regularization_strength
819 self.multiply_linear_by_learning_rate = multiply_linear_by_learning_rate
820 self.beta = beta
821 self.allow_zero_accumulator = allow_zero_accumulator
824class ProximalYogiParameters(_OptimizationParameters):
825 # pylint: disable=line-too-long
826 """Optimization parameters for Proximal Yogi with TPU embeddings.
828 Implements the Yogi optimizer as described in
829 [Adaptive Methods for Nonconvex
830 Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization).
832 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
833 `optimization_parameters` argument to set the optimizer and its parameters.
834 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
835 for more details.
836 """
838 # pylint: enable=line-too-long
840 def __init__(
841 self,
842 learning_rate: float = 0.01,
843 beta1: float = 0.9,
844 beta2: float = 0.999,
845 epsilon: float = 1e-3,
846 l1_regularization_strength: float = 0.0,
847 l2_regularization_strength: float = 0.0,
848 initial_accumulator_value: float = 1e-6,
849 use_gradient_accumulation: bool = True,
850 clip_weight_min: Optional[float] = None,
851 clip_weight_max: Optional[float] = None,
852 weight_decay_factor: Optional[float] = None,
853 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
854 clip_gradient_min: Optional[float] = None,
855 clip_gradient_max: Optional[float] = None,
856 ):
857 """Optimization parameters for Proximal Yogi.
859 Args:
860 learning_rate: a floating point value. The learning rate.
861 beta1: A float value. The exponential decay rate for the 1st moment
862 estimates.
863 beta2: A float value. The exponential decay rate for the 2nd moment
864 estimates.
865 epsilon: A small constant for numerical stability.
866 l1_regularization_strength: A float value, must be greater than or equal
867 to zero.
868 l2_regularization_strength: A float value, must be greater than or equal
869 to zero.
870 initial_accumulator_value: The starting value for accumulators. Only zero
871 or positive values are allowed.
872 use_gradient_accumulation: setting this to `False` makes embedding
873 gradients calculation less accurate but faster. Please see
874 `optimization_parameters.proto` for details.
875 clip_weight_min: the minimum value to clip by; None means -infinity.
876 clip_weight_max: the maximum value to clip by; None means +infinity.
877 weight_decay_factor: amount of weight decay to apply; None means that the
878 weights are not decayed.
879 multiply_weight_decay_factor_by_learning_rate: if true,
880 `weight_decay_factor` is multiplied by the current learning rate.
881 clip_gradient_min: the minimum value to clip by; None means -infinity.
882 Gradient accumulation must be set to true if this is set.
883 clip_gradient_max: the maximum value to clip by; None means +infinity.
884 Gradient accumulation must be set to true if this is set.
885 """
886 super().__init__(
887 learning_rate=learning_rate,
888 use_gradient_accumulation=use_gradient_accumulation,
889 clip_weight_min=clip_weight_min,
890 clip_weight_max=clip_weight_max,
891 weight_decay_factor=weight_decay_factor,
892 multiply_weight_decay_factor_by_learning_rate=(
893 multiply_weight_decay_factor_by_learning_rate),
894 clip_gradient_min=clip_gradient_min,
895 clip_gradient_max=clip_gradient_max,
896 )
897 if beta1 < 0. or beta1 >= 1.:
898 raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
899 if beta2 < 0. or beta2 >= 1.:
900 raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
901 if epsilon <= 0.:
902 raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
903 if l1_regularization_strength < 0.:
904 raise ValueError('l1_regularization_strength must be greater than or '
905 'equal to 0. got {}.'.format(l1_regularization_strength))
906 if l2_regularization_strength < 0.:
907 raise ValueError('l2_regularization_strength must be greater than or '
908 'equal to 0. got {}.'.format(l2_regularization_strength))
910 self.beta1 = beta1
911 self.beta2 = beta2
912 self.epsilon = epsilon
913 self.l1_regularization_strength = l1_regularization_strength
914 self.l2_regularization_strength = l2_regularization_strength
915 self.initial_accumulator_value = initial_accumulator_value
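# Editor's note: hypothetical usage sketch for ProximalYogiParameters, using the
# documented defaults plus a small illustrative amount of L2 regularization.
def _example_proximal_yogi():
  return ProximalYogiParameters(
      learning_rate=0.01,
      beta1=0.9,
      beta2=0.999,
      epsilon=1e-3,
      l2_regularization_strength=0.0001)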
918class MomentumParameters(_OptimizationParameters):
919 """Optimization parameters for Momentum with TPU embeddings.
921 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
922 `optimization_parameters` argument to set the optimizer and its parameters.
923 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
924 for more details.
926 ```
927 estimator = tf.estimator.tpu.TPUEstimator(
928 ...
929 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
930 ...
931 optimization_parameters=tf.tpu.experimental.MomentumParameters(0.1, momentum=0.9),
932 ...))
933 ```
935 """
937 def __init__(
938 self,
939 learning_rate: float,
940 momentum: float,
941 use_nesterov: bool = False,
942 use_gradient_accumulation: bool = True,
943 clip_weight_min: Optional[float] = None,
944 clip_weight_max: Optional[float] = None,
945 weight_decay_factor: Optional[float] = None,
946 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
947 clip_gradient_min: Optional[float] = None,
948 clip_gradient_max: Optional[float] = None,
949 ):
950 """Optimization parameters for momentum.
952 Args:
953 learning_rate: a floating point value. The learning rate.
954 momentum: a floating point value. The momentum.
955 use_nesterov: If `True` use Nesterov Momentum. See (Sutskever et al.,
956 2013). This implementation always computes gradients at the value of the
957 variable(s) passed to the optimizer. Using Nesterov Momentum makes the
958 variable(s) track the values called `theta_t + mu*v_t` in the paper.
959 This implementation is an approximation of the original formula, valid
960 for high values of momentum. It will compute the "adjusted gradient" in
961 NAG by assuming that the new gradient will be estimated by the current
962 average gradient plus the product of momentum and the change in the
963 average gradient.
964 use_gradient_accumulation: setting this to `False` makes embedding
965 gradients calculation less accurate but faster. Please see
966 `optimization_parameters.proto` for details.
967 clip_weight_min: the minimum value to clip by; None means -infinity.
968 clip_weight_max: the maximum value to clip by; None means +infinity.
969 weight_decay_factor: amount of weight decay to apply; None means that the
970 weights are not decayed.
971 multiply_weight_decay_factor_by_learning_rate: if true,
972 `weight_decay_factor` is multiplied by the current learning rate.
973 clip_gradient_min: the minimum value to clip by; None means -infinity.
974 Gradient accumulation must be set to true if this is set.
975 clip_gradient_max: the maximum value to clip by; None means +infinity.
976 Gradient accumulation must be set to true if this is set.
977 """
978 super().__init__(
979 learning_rate=learning_rate,
980 use_gradient_accumulation=use_gradient_accumulation,
981 clip_weight_min=clip_weight_min,
982 clip_weight_max=clip_weight_max,
983 weight_decay_factor=weight_decay_factor,
984 multiply_weight_decay_factor_by_learning_rate=(
985 multiply_weight_decay_factor_by_learning_rate),
986 clip_gradient_min=clip_gradient_min,
987 clip_gradient_max=clip_gradient_max,
988 )
989 self.momentum = momentum
990 self.use_nesterov = use_nesterov
993class RMSPropParameters(_OptimizationParameters):
994 """Optimization parameters for RMSProp with TPU embeddings.
996 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
997 `optimization_parameters` argument to set the optimizer and its parameters.
998 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
999 for more details.
1001 ```
1002 estimator = tf.estimator.tpu.TPUEstimator(
1003 ...
1004 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
1005 ...
1006 optimization_parameters=tf.tpu.experimental.RMSPropParameters(0.1, rho=0.9, momentum=0.9, epsilon=1e-10),
1007 ...))
1008 ```
1010 """
1012 def __init__(
1013 self,
1014 learning_rate: float,
1015 rho: float,
1016 momentum: float,
1017 epsilon: float,
1018 use_gradient_accumulation: bool = True,
1019 clip_weight_min: Optional[float] = None,
1020 clip_weight_max: Optional[float] = None,
1021 weight_decay_factor: Optional[float] = None,
1022 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
1023 clip_gradient_min: Optional[float] = None,
1024 clip_gradient_max: Optional[float] = None,
1025 ):
1026 """Optimization parameters for RMS prop.
1028 Args:
1029 learning_rate: a floating point value. The learning rate.
1030 rho: Discounting factor for the history/coming gradient.
1031 momentum: A scalar tensor.
1032 epsilon: Small value to avoid zero denominator.
1033 use_gradient_accumulation: setting this to `False` makes embedding
1034 gradients calculation less accurate but faster. Please see
1035 `optimization_parameters.proto` for details.
1036 clip_weight_min: the minimum value to clip by; None means -infinity.
1037 clip_weight_max: the maximum value to clip by; None means +infinity.
1038 weight_decay_factor: amount of weight decay to apply; None means that the
1039 weights are not decayed.
1040 multiply_weight_decay_factor_by_learning_rate: if true,
1041 `weight_decay_factor` is multiplied by the current learning rate.
1042 clip_gradient_min: the minimum value to clip by; None means -infinity.
1043 Gradient accumulation must be set to true if this is set.
1044 clip_gradient_max: the maximum value to clip by; None means +infinity.
1045 Gradient accumulation must be set to true if this is set.
1046 """
1047 super().__init__(
1048 learning_rate=learning_rate,
1049 use_gradient_accumulation=use_gradient_accumulation,
1050 clip_weight_min=clip_weight_min,
1051 clip_weight_max=clip_weight_max,
1052 weight_decay_factor=weight_decay_factor,
1053 multiply_weight_decay_factor_by_learning_rate=(
1054 multiply_weight_decay_factor_by_learning_rate),
1055 clip_gradient_min=clip_gradient_min,
1056 clip_gradient_max=clip_gradient_max,
1057 )
1058 self.rho = rho
1059 self.momentum = momentum
1060 self.epsilon = epsilon
1063@tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters'])
1064class StochasticGradientDescentParameters(_OptimizationParameters):
1065 """Optimization parameters for stochastic gradient descent for TPU embeddings.
1067 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
1068 `optimization_parameters` argument to set the optimizer and its parameters.
1069 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
1070 for more details.
1072 ```
1073 estimator = tf.estimator.tpu.TPUEstimator(
1074 ...
1075 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
1076 ...
1077 optimization_parameters=(
1078 tf.tpu.experimental.StochasticGradientDescentParameters(0.1))))
1079 ```
1081 """
1083 def __init__(
1084 self,
1085 learning_rate: float,
1086 use_gradient_accumulation: bool = True,
1087 clip_weight_min: Optional[float] = None,
1088 clip_weight_max: Optional[float] = None,
1089 weight_decay_factor: Optional[float] = None,
1090 multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
1091 clip_gradient_min: Optional[float] = None,
1092 clip_gradient_max: Optional[float] = None,
1093 ):
1094 """Optimization parameters for stochastic gradient descent.
1096 Args:
1097 learning_rate: a floating point value. The learning rate.
1098 use_gradient_accumulation: setting this to `False` makes embedding
1099 gradients calculation less accurate but faster. Please see
1100 `optimization_parameters.proto` for details.
1101 clip_weight_min: the minimum value to clip by; None means -infinity.
1102 clip_weight_max: the maximum value to clip by; None means +infinity.
1103 weight_decay_factor: amount of weight decay to apply; None means that the
1104 weights are not decayed.
1105 multiply_weight_decay_factor_by_learning_rate: if true,
1106 `weight_decay_factor` is multiplied by the current learning rate.
1107 clip_gradient_min: the minimum value to clip by; None means -infinity.
1108 clip_gradient_max: the maximum value to clip by; None means +infinity.
1109 """
1110 super().__init__(
1111 learning_rate=learning_rate,
1112 use_gradient_accumulation=use_gradient_accumulation,
1113 clip_weight_min=clip_weight_min,
1114 clip_weight_max=clip_weight_max,
1115 weight_decay_factor=weight_decay_factor,
1116 multiply_weight_decay_factor_by_learning_rate=(
1117 multiply_weight_decay_factor_by_learning_rate),
1118 clip_gradient_min=clip_gradient_min,
1119 clip_gradient_max=clip_gradient_max,
1120 )
1123class FrequencyEstimatorParameters(_OptimizationParameters):
1124 """Optimization parameters for Frequency Estimator TPU embeddings.
1126 This is a non-standard optimizer, which returns the estimated frequency of
1127 lookup for the feature passed to it. It should only be used on a table of
1128 width 1. The gradient fed back to the TPU embedding should always be zero.
1129 This can be accomplished by using `tf.stop_gradient` on the feature before
1130 using it.
1132 You must use the dynamic learning rate mechanism to set the 'learning rate'
1133 for this table to be a float32 cast of the global training step counter.
1135 See `tensorflow/core/protobuf/tpu/optimization_parameters.proto` for more
1136 details on this optimizer.
1138 Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
1139 `optimization_parameters` argument to set the optimizer and its parameters.
1140 See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
1141 for more details.
1143 ```
1144 estimator = tf.estimator.tpu.TPUEstimator(
1145 ...
1146 embedding_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
1147 ...
1148 optimization_parameters=FrequencyEstimatorParameters(tau=0.1, max_delta=100.0, outlier_threshold=10.0, weight_exponent=1.0),
1149 ...))
1150 ```
1152 """
1154 def __init__(self, tau: float, max_delta: float, outlier_threshold: float,
1155 weight_exponent: float):
1156 """Optimization parameters for frequency estimator.
1158 Args:
1159 tau: Learning rate between (0, 1) that is used to update the frequency estimate stored in the table.
1160 max_delta: Maximum value of delta, the difference between the current
1161 global step and the last global step at which the row was sampled.
1162 outlier_threshold: Threshold used to determine whether the current update
1163 is an outlier.
1164 weight_exponent: The weight exponent used to transform the estimated delta
1165 into weights.
1166 """
1167 super().__init__(
1168 learning_rate=1.0,
1169 use_gradient_accumulation=True,
1170 clip_weight_min=None,
1171 clip_weight_max=None,
1172 weight_decay_factor=None,
1173 multiply_weight_decay_factor_by_learning_rate=None,
1174 )
1175 self.tau = tau
1176 self.max_delta = max_delta
1177 self.outlier_threshold = outlier_threshold
1178 self.weight_exponent = weight_exponent
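# Editor's note: illustrative sketch of attaching FrequencyEstimatorParameters to
# a width-1 table, with the dynamic learning rate set to a float32 cast of the
# global step as the docstring above requires. The vocabulary size and threshold
# values are arbitrary assumptions.
def _example_frequency_estimator_table():
  def _global_step_as_learning_rate(global_step):
    return math_ops.cast(global_step, dtypes.float32)
  return TableConfig(
      vocabulary_size=100000,
      dimension=1,  # frequency estimator tables must have width 1
      learning_rate_fn=_global_step_as_learning_rate,
      optimization_parameters=FrequencyEstimatorParameters(
          tau=0.1, max_delta=100.0, outlier_threshold=10.0,
          weight_exponent=1.0))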
1181DeviceConfig = collections.namedtuple('DeviceConfig',
1182 ['num_hosts', 'num_cores', 'job_name'])
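# Editor's note: hypothetical sketch of a DeviceConfig, used by TPUEmbedding below
# when both `master` and `cluster_def` are None; the job name format is an
# assumption and the core/host counts are illustrative.
def _example_device_config():
  """One host exposing 8 TPU cores under an assumed '/job:tpu_worker' job."""
  return DeviceConfig(num_hosts=1, num_cores=8, job_name='/job:tpu_worker')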
1185class TPUEmbedding:
1186 """API for using TPU for embedding.
1188 Example:
1189 ```
1190 table_config_user = tpu_embedding.TableConfig(
1191 vocabulary_size=4, dimension=2,
1192 initializer=initializer, combiner='mean')
1193 table_to_config_dict = {'video': table_config_video,
1194 'user': table_config_user}
1195 feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video'),
1196 'favorited': tpu_embedding.FeatureConfig('video'),
1197 'friends': tpu_embedding.FeatureConfig('user')}
1198 batch_size = 4
1199 num_hosts = 1
1200 optimization_parameters = tpu_embedding.AdagradParameters(1., 1.)
1201 mode = tpu_embedding.TRAINING
1202 embedding = tpu_embedding.TPUEmbedding(
1203 table_to_config_dict, feature_to_config_dict,
1204 batch_size, num_hosts, mode, optimization_parameters)
1206 batch_size_per_core = embedding.batch_size_per_core
1207 sparse_features_list = []
1208 for host in hosts:
1209 with ops.device(host):
1210 for _ in range(embedding.num_cores_per_host):
1211 sparse_features = {}
1212 sparse_features['watched'] = sparse_tensor.SparseTensor(...)
1213 sparse_features['favorited'] = sparse_tensor.SparseTensor(...)
1214 sparse_features['friends'] = sparse_tensor.SparseTensor(...)
1215 sparse_features_list.append(sparse_features)
1217 enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list)
1218 embedding_variables_and_ops = embedding.create_variables_and_ops()
1220 def computation():
1221 activations = embedding.get_activations()
1222 loss = compute_loss(activations)
1224 base_optimizer = gradient_descent.GradientDescentOptimizer(
1225 learning_rate=1)
1226 cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer(
1227 base_optimizer)
1229 train_op = cross_shard_optimizer.minimize(loss)
1230 gradients = (
1231 tpu_embedding_gradient.get_gradients_through_compute_gradients(
1232 cross_shard_optimizer, loss, activations))
1233 send_gradients_op = embedding.generate_send_gradients_op(gradients)
1234 with ops.control_dependencies([train_op, send_gradients_op]):
1235 loss = array_ops.identity(loss)
1237 loss = tpu.shard(computation,
1238 num_shards=embedding.num_cores)
1240 with self.test_session() as sess:
1241 sess.run(tpu.initialize_system(embedding_config=
1242 embedding.config_proto))
1243 sess.run(variables.global_variables_initializer())
1244 sess.run(embedding_variables_and_ops.load_ops())
1245 sess.run(enqueue_ops)
1246 loss_val = sess.run(loss)
1247 ```
1249 Example with weight decay:
1251 >>> def learning_rate_fn(global_step):
1252 ... return tf.compat.v1.train.polynomial_decay(
1253 ... learning_rate=5e-5,
1254 ... global_step=global_step,
1255 ... decay_steps=100000,
1256 ... end_learning_rate=0.0)
1257 >>> wordpiece_table_config = TableConfig(
1258 ... vocabulary_size=119547,
1259 ... dimension=256,
1260 ... learning_rate_fn=learning_rate_fn)
1261 >>> wordpiece_feature_config = FeatureConfig(
1262 ... table_id='bert/embeddings/word_embeddings',
1263 ... max_sequence_length=512)
1264 >>> optimization_parameters = AdamParameters(
1265 ... learning_rate=5e-5,
1266 ... epsilon=1e-6,
1267 ... weight_decay_factor=0.01,
1268 ... multiply_weight_decay_factor_by_learning_rate=True)
1269 >>> tpu_embedding = TPUEmbedding(
1270 ... table_to_config_dict={
1271 ... 'bert/embeddings/word_embeddings': wordpiece_table_config,
1272 ... },
1273 ... feature_to_config_dict={'input_ids': wordpiece_feature_config},
1274 ... batch_size=128,
1275 ... mode=TRAINING,
1276 ... optimization_parameters=optimization_parameters,
1277 ... master='')
1278 >>> with tf.Graph().as_default():
1279 ... init_tpu_op = tf.compat.v1.tpu.initialize_system(
1280 ... embedding_config=tpu_embedding.config_proto)
1281 ... tf.compat.v1.Session().run(init_tpu_op)
1282 """
1284 # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
1285 # the feature should not be used to update embedding table (cr/204852758,
1286 # cr/204940540). Also, this can support different combiners for different
1287 # features within the same table.
1288 # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it
1289 # to `FeatureConfig`?
1291 # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and
1292 # `feature_to_config_dict` lists of `TableSpec` and `FeatureSpec`
1293 # respectively?
1295 # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate
1296 # for-loops around construction of inputs.
1298 # `optimization_parameter` applies to all tables. If the need arises,
1299 # we can add `optimization_parameters` to `TableConfig` to override this
1300 # global setting.
1301 def __init__(self,
1302 table_to_config_dict,
1303 feature_to_config_dict,
1304 batch_size,
1305 mode,
1306 master=None,
1307 optimization_parameters=None,
1308 cluster_def=None,
1309 pipeline_execution_with_tensor_core=False,
1310 partition_strategy='div',
1311 profile_data_directory=None,
1312 device_config=None,
1313 master_job_name=None):
1314 """API for using TPU for embedding lookups.
1316 Args:
1317 table_to_config_dict: A dictionary mapping from string of table name to
1318 `TableConfig`. Table refers to an embedding table, e.g. `params`
1319 argument to `tf.nn.embedding_lookup_sparse()`.
1320 feature_to_config_dict: A dictionary mapping from string of feature name
1321 to `FeatureConfig`. Feature refers to ids to lookup in embedding table,
1322 e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
1323 batch_size: An `int` representing the global batch size.
1324 mode: `TRAINING` or `INFERENCE`.
1325 master: A `string` representing the TensorFlow master to use.
1326 optimization_parameters: `AdagradParameters`, `AdamParameters`,
1327 `StochasticGradientDescentParameters`. Must be set in training unless
1328 all tables specify their own optimizers. It must be `None` in
1329 inference.
1330 cluster_def: A ClusterDef object describing the TPU cluster.
1331 pipeline_execution_with_tensor_core: setting this to `True` makes training
1332 faster, but trained model will be different if step N and step N+1
1333 involve the same set of embedding IDs. Please see
1334 `tpu_embedding_configuration.proto` for details.
1335 partition_strategy: A string, either 'mod' or 'div', specifying how to map
1336 the lookup id to the embedding tensor. For more information see
1337 `tf.nn.embedding_lookup_sparse`.
1338 profile_data_directory: Directory where embedding lookup statistics are
1339 stored. These statistics summarize information about the inputs to the
1340 embedding lookup operation, in particular, the average number of
1341 embedding IDs per example and how well the embedding IDs are load
1342 balanced across the system. The lookup statistics are used during TPU
1343 initialization for embedding table partitioning. Collection of lookup
1344 statistics is done at runtime by profiling the embedding inputs, only a
1345 small fraction of input samples are profiled to minimize host CPU
1346 overhead. Once a suitable number of samples are profiled, the lookup
1347 statistics are saved to table-specific files in the profile data
1348 directory generally at the end of a TPU training loop. The filename
1349 corresponding to each table is obtained by hashing table specific
1350 parameters (e.g., table name and number of features) and global
1351 configuration parameters (e.g., sharding strategy and task count). The
1352 same profile data directory can be shared among several models to reuse
1353 embedding lookup statistics.
1354 device_config: A DeviceConfig instance, used when `master` and
1355 `cluster_def` are both `None`.
1356 master_job_name: if set, overrides the master job name used to schedule
1357 embedding ops.
1359 Raises:
1360 ValueError: if any input is invalid.
1361 """
1362 if partition_strategy not in ('div', 'mod'):
1363 raise ValueError(f'partition_strategy must be "div" or "mod". '
1364 f'Received: {partition_strategy}.')
1365 self._partition_strategy = partition_strategy
1367 self._profile_data_directory = profile_data_directory
1369 _validate_table_to_config_dict(table_to_config_dict)
1370 # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
1371 self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)
1373 _validate_feature_to_config_dict(table_to_config_dict,
1374 feature_to_config_dict)
1375 self._feature_to_config_dict = _create_ordered_dict(feature_to_config_dict)
1376 self._table_to_features_dict = (
1377 _create_table_to_features_dict(self._feature_to_config_dict))
1378 self._combiners = _create_combiners(self._table_to_config_dict,
1379 self._table_to_features_dict)
1381 self._batch_size = batch_size
1383 if master is None and cluster_def is None:
1384 if device_config is None:
1385 raise ValueError('When master and cluster_def are both None, '
1386 'device_config must be set but is not.')
1387 if device_config.num_cores % device_config.num_hosts:
1388 raise ValueError('num_hosts ({}) should divide num_cores ({}) '
1389 'but does not.'.format(device_config.num_cores,
1390 device_config.num_hosts))
1391 self._num_hosts = device_config.num_hosts
1392 self._num_cores = device_config.num_cores
1393 self._num_cores_per_host = self._num_cores // self._num_hosts
1394 self._hosts = [
1395 '{}/replica:0/task:{}/device:CPU:0'.format(device_config.job_name, i)
1396 for i in range(self._num_hosts)
1397 ]
1398 else:
1399 tpu_system_metadata = (
1400 tpu_system_metadata_lib._query_tpu_system_metadata( # pylint: disable=protected-access
1401 master,
1402 cluster_def=cluster_def))
1403 if tpu_system_metadata.num_cores == 0:
1404 raise ValueError('TPUEmbedding needs TPUs, but master {} does not have '
1405 'TPUs.'.format(master))
1406 self._num_hosts = tpu_system_metadata.num_hosts
1407 if master_job_name is None:
1408 try:
1409 master_job_name = tpu_system_metadata_lib.master_job(
1410 master, cluster_def)
1411 except ValueError as e:
1412 raise ValueError(str(e) + ' Please specify a master_job_name.')
1413 self._hosts = []
1414 for device in tpu_system_metadata.devices:
1415 if 'device:CPU:' in device.name and (master_job_name is None or
1416 master_job_name in device.name):
1417 self._hosts.append(device.name)
1418 self._num_cores_per_host = tpu_system_metadata.num_of_cores_per_host
1419 self._num_cores = tpu_system_metadata.num_cores
1421 _validate_batch_size(self._batch_size, self._num_cores)
1422 self._batch_size_per_core = self._batch_size // self._num_cores
1424 # TODO(shizhiw): remove `mode`?
1425 if mode == TRAINING:
1426 _validate_optimization_parameters(optimization_parameters,
1427 self._table_to_config_dict)
1428 self._optimization_parameters = optimization_parameters
1429 elif mode == INFERENCE:
1430 if optimization_parameters is not None:
1431 raise ValueError(f'`optimization_parameters` should be `None` '
1432 f'for inference mode. '
1433 f'Received: {optimization_parameters}.')
1434 self._optimization_parameters = (StochasticGradientDescentParameters(1.))
1435 else:
1436 raise ValueError('`mode` only supports {} and {}; got {}.'.format(
1437 TRAINING, INFERENCE, mode))
1438 self._mode = mode
1440 # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler`
1441 # and create special handler for inference that inherits from
1442 # StochasticGradientDescentHandler with more user-friendly error message
1443 # on get_slot().
1444 self._optimizer_handler_dict = self._get_optimizer_handler_by_table()
1446 self._pipeline_execution_with_tensor_core = (
1447 pipeline_execution_with_tensor_core)
1448 self._learning_rate_fn = list(
1449 set(c.learning_rate_fn
1450 for c in self._table_to_config_dict.values()
1451 if c.learning_rate_fn is not None))
1452 self._learning_rate_fn_to_tag = {
1453 fn: tag for tag, fn in enumerate(self._learning_rate_fn)
1454 }
1456 self._config_proto = self._create_config_proto()
1458 @property
1459 def hosts(self):
1460 """A list of device names for CPU hosts.
1462 Returns:
1463 A list of device names for CPU hosts.
1464 """
1465 return copy.copy(self._hosts)
1467 # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and
1468 # to be consistent with `tpu_embedding_configuration.proto`.
1469 @property
1470 def num_cores_per_host(self):
1471 """Number of TPU cores on a CPU host.
1473 Returns:
1474 Number of TPU cores on a CPU host.
1475 """
1476 return self._num_cores_per_host
1478 @property
1479 def num_cores(self):
1480 """Total number of TPU cores on all hosts.
1482 Returns:
1483 Total number of TPU cores on all hosts.
1484 """
1485 return self._num_cores
1487 @property
1488 def batch_size_per_core(self):
1489 """Batch size for each TPU core.
1491 The sparse tensors in `enqueue_datas_list` passed to `generate_enqueue_ops`
1492 must have a batch dimension equal to this.
1494 Returns:
1495 Batch size for each TPU core.
1496 """
1497 return self._batch_size_per_core
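    # For example, a global `batch_size` of 1024 spread over 32 TPU cores gives
    # a per-core batch size of 1024 // 32 = 32.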
1499 @property
1500 def config_proto(self):
1501 """Create embedding config proto for `tpu.initialize_system()`.
1503 Returns:
1504 a `TPUEmbeddingConfiguration` proto describing the desired
1505 configuration of the hardware embedding lookup tables, which
1506 is passed to `tpu.initialize_system()`.
1507 """
1508 return self._config_proto
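    # Editor's note: a hedged usage sketch (not from the original source). It
    # assumes `tpu.initialize_system` from `tensorflow.python.tpu.tpu`, which
    # takes the embedding configuration proto via its `embedding_config`
    # argument; check the signature in your TF version.
    #
    #   from tensorflow.python.tpu import tpu
    #   init_op = tpu.initialize_system(
    #       embedding_config=tpu_embedding.config_proto)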
1510 @property
1511 def table_to_config_dict(self):
1512 return copy.copy(self._table_to_config_dict)
1514 @property
1515 def feature_to_config_dict(self):
1516 return copy.copy(self._feature_to_config_dict)
1518 @property
1519 def table_to_features_dict(self):
1520 return copy.copy(self._table_to_features_dict)
1522 @property
1523 def optimization_parameters(self):
1524 return self._optimization_parameters
1526 def _create_config_proto(self):
1527 """Create `TPUEmbeddingConfiguration`."""
1528 config_proto = elc.TPUEmbeddingConfiguration()
1529 for table in self._table_to_config_dict:
1530 table_descriptor = config_proto.table_descriptor.add()
1531 table_descriptor.name = table
1533 table_config = self._table_to_config_dict[table]
1534 # For small tables, we pad to the number of hosts so that at least one
1535 # id will be assigned to each host.
1536 table_descriptor.vocabulary_size = max(table_config.vocabulary_size,
1537 len(self.hosts))
1538 table_descriptor.dimension = table_config.dimension
1540 optimization_parameters = (
1541 self._optimizer_handler_dict[table].get_optimization_parameters())
1543 parameters = table_descriptor.optimization_parameters
1544 if table_config.learning_rate:
1545 parameters.learning_rate.constant = table_config.learning_rate
1546 elif table_config.learning_rate_fn:
1547 parameters.learning_rate.dynamic.tag = (
1548 self._learning_rate_fn_to_tag[table_config.learning_rate_fn])
1549 else:
1550 parameters.learning_rate.constant = (
1551 optimization_parameters.learning_rate)
1552 parameters.gradient_accumulation_status = (
1553 optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
1554 if optimization_parameters.use_gradient_accumulation else
1555 optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
1557 if optimization_parameters.clip_gradient_min is not None:
1558 parameters.gradient_clipping_limits.lower.value = (
1559 optimization_parameters.clip_gradient_min)
1560 if optimization_parameters.clip_gradient_max is not None:
1561 parameters.gradient_clipping_limits.upper.value = (
1562 optimization_parameters.clip_gradient_max)
1564 if optimization_parameters.clip_weight_min is not None:
1565 parameters.clipping_limits.lower.value = (
1566 optimization_parameters.clip_weight_min)
1567 if optimization_parameters.clip_weight_max is not None:
1568 parameters.clipping_limits.upper.value = (
1569 optimization_parameters.clip_weight_max)
1570 if optimization_parameters.weight_decay_factor:
1571 parameters.weight_decay_factor = (
1572 optimization_parameters.weight_decay_factor)
1573 if (optimization_parameters
1574 .multiply_weight_decay_factor_by_learning_rate):
1575 parameters.multiply_weight_decay_factor_by_learning_rate = True
1576 if table_config.hot_id_replication:
1577 parameters.hot_id_replication_configuration.status = (
1578 optimization_parameters_pb2.HotIdReplicationConfiguration.ENABLED)
1579 optimizer_handler = self._optimizer_handler_dict[table]
1580 optimizer_handler.set_optimization_parameters(table_descriptor)
1582 table_to_id = {
1583 table: i for i, table in enumerate(self._table_to_config_dict)
1584 }
1586 # Set feature descriptor field in the config proto.
1587 for table in self._table_to_features_dict:
1588 features = self._table_to_features_dict[table]
1589 for feature in features:
1590 feature_descriptor = config_proto.feature_descriptor.add()
1592 feature_descriptor.table_id = table_to_id[
1593 self._feature_to_config_dict[feature].table_id]
1594 if self._feature_to_config_dict[feature].max_sequence_length > 0:
1595 feature_descriptor.input_shape.extend([
1596 self._batch_size_per_core,
1597 self._feature_to_config_dict[feature].max_sequence_length
1598 ])
1599 else:
1600 feature_descriptor.input_shape.extend([self._batch_size_per_core])
1602 config_proto.mode = self._mode
1603 config_proto.num_hosts = self._num_hosts
1604 config_proto.num_tensor_cores = self._num_cores
1605 config_proto.sharding_strategy = (
1606 elc.TPUEmbeddingConfiguration.DIV_DEFAULT if self._partition_strategy
1607 == 'div' else elc.TPUEmbeddingConfiguration.MOD)
1608 config_proto.pipeline_execution_with_tensor_core = (
1609 self._pipeline_execution_with_tensor_core)
1610 if self._profile_data_directory:
1611 config_proto.profile_data_directory = self._profile_data_directory
1613 return config_proto
1615 def create_variables_and_ops(self,
1616 embedding_variable_name_by_table=None,
1617 slot_variable_names_by_table=None):
1618 """Create embedding and slot variables, with ops to load and retrieve them.
1620 N.B.: the ops that retrieve the embedding variables (including slot
1621 variables) are returned as lambda functions, because the call site may want
1622 to impose control dependencies between the TPU computation and the retrieve
1623 actions. For example, the following code snippet ensures the TPU computation
1624 finishes first, and only then pulls the variables back from TPU to CPU.
1626 ```
1627 update_ops = []
1628 with ops.control_dependencies([loss]):
1629 for op_fn in retrieve_parameters_op_fns:
1630 update_ops.append(op_fn())
1631 ```
1633 Args:
1634 embedding_variable_name_by_table: A dictionary mapping from string of
1635 table name to string of embedding variable name. If `None`, defaults
1636 from `get_default_slot_variable_names()` will be used.
1637 slot_variable_names_by_table: A dictionary mapping from string of table
1638 name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If
1639 `None`, defaults from `get_default_slot_variable_names()` will be used.
1641 Returns:
1642 `tpu_embedding.VariablesAndOps` with:
1643 A dictionary mapping from string of table name to embedding variables,
1644 A dictionary mapping from string of table name to AdagradSlotVariables,
1645 AdamSlotVariables etc with slot variables,
1646 A function which returns a list of ops to load embedding and slot
1647 variables from CPU to TPU.
1648 A function which returns a list of ops to retrieve embedding and slot
1649 variables from TPU to CPU.
1650 """
1651 embedding_variables_by_table = {}
1652 slot_variables_by_table = {}
1653 load_op_fns = []
1654 retrieve_op_fns = []
1656 for i, table in enumerate(self._table_to_config_dict):
1657 if embedding_variable_name_by_table:
1658 embedding_variable_name = embedding_variable_name_by_table[table]
1659 else:
1660 embedding_variable_name = table
1661 if slot_variable_names_by_table:
1662 slot_variable_names = slot_variable_names_by_table[table]
1663 else:
1664 optimizer_handler = self._optimizer_handler_dict[table]
1665 slot_variable_names = (
1666 optimizer_handler.get_default_slot_variable_names(table))
1668 # TODO(b/139144091): Multi-host support for mid-level API in
1669 # eager context (TF 2.0)
1670 # Workaround below allows single-host use case in TF 2.0
1671 if context.executing_eagerly():
1672 device = ''
1673 else:
1674 device = _create_device_fn(self._hosts)
1676 with ops.device(device):
1677 table_variables = _create_partitioned_variables(
1678 name=embedding_variable_name,
1679 num_hosts=self._num_hosts,
1680 vocabulary_size=self._table_to_config_dict[table].vocabulary_size,
1681 embedding_dimension=self._table_to_config_dict[table].dimension,
1682 initializer=self._table_to_config_dict[table].initializer,
1683 collections=[ops.GraphKeys.GLOBAL_VARIABLES])
1684 embedding_variables_by_table[table] = table_variables
1686 # Only load the embedding config into the load/retrieve nodes for the first
1687 # table on the first host; the other nodes reuse the config from that node.
1688 config = None if i else self.config_proto.SerializeToString()
1689 slot_variables_for_table, load_ops_fn, retrieve_ops_fn = (
1690 self._optimizer_handler_dict[table].create_variables_and_ops(
1691 table, slot_variable_names, self._num_hosts,
1692 self._table_to_config_dict[table], table_variables, config))
1693 slot_variables_by_table[table] = slot_variables_for_table
1694 load_op_fns.append(load_ops_fn)
1695 retrieve_op_fns.append(retrieve_ops_fn)
1697 def load_ops():
1698 """Calls and returns the load ops for each embedding table.
1700 Returns:
1701 A list of ops to load embedding and slot variables from CPU to TPU.
1702 """
1703 load_ops_list = []
1704 for load_op_fn in load_op_fns:
1705 load_ops_list.extend(load_op_fn())
1706 return load_ops_list
1708 def retrieve_ops():
1709 """Calls and returns the retrieve ops for each embedding table.
1711 Returns:
1712 A list of ops to retrieve embedding and slot variables from TPU to CPU.
1713 """
1714 retrieve_ops_list = []
1715 for retrieve_op_fn in retrieve_op_fns:
1716 retrieve_ops_list.extend(retrieve_op_fn())
1717 return retrieve_ops_list
1719 return VariablesAndOps(embedding_variables_by_table,
1720 slot_variables_by_table, load_ops, retrieve_ops)
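  # Editor's note: a hedged usage sketch (not from the original source); the
  # session setup and training loop are hypothetical. Load ops run after
  # variable initialization to push tables to the TPU; retrieve ops pull them
  # back, e.g. before writing a checkpoint.
  #
  #   variables_and_ops = tpu_embedding.create_variables_and_ops()
  #   with tf.compat.v1.Session(master) as sess:
  #     sess.run(tf.compat.v1.global_variables_initializer())
  #     sess.run(variables_and_ops.load_ops())      # CPU -> TPU
  #     # ... run training steps ...
  #     sess.run(variables_and_ops.retrieve_ops())  # TPU -> CPU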
1722 def generate_enqueue_ops(
1723 self,
1724 enqueue_datas_list,
1725 mode_override=None,
1726 ragged=False,
1727 ):
1728 """Generate enqueue ops.
1730 Args:
1731 enqueue_datas_list: a list of dictionaries mapping from feature name
1732 strings to EnqueueData. Each dictionary is for one TPU core. Dictionaries
1733 for the same host should be contiguous in the list.
1734 mode_override: A string input that overrides the mode specified in the
1735 TPUEmbeddingConfiguration. Supported values are {'unspecified',
1736 'inference', 'training', 'backward_pass_only'}. When set to
1737 'unspecified', the mode set in TPUEmbeddingConfiguration is used,
1738 otherwise mode_override is used (optional).
1739 ragged: If True, creates RaggedTensor enqueue ops rather than
1740 SparseTensor.
1742 Returns:
1743 Ops to enqueue to TPU for embedding.
1744 """
1745 self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list)
1746 return [
1747 self._generate_enqueue_op( # pylint: disable=g-complex-comprehension
1748 enqueue_datas,
1749 device_ordinal=i % self._num_cores_per_host,
1750 mode_override=mode_override,
1751 ragged=ragged,
1752 ) for i, enqueue_datas in enumerate(enqueue_datas_list)
1753 ]
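  # Editor's note: a hedged sketch of building `enqueue_datas_list` for two TPU
  # cores (not from the original source). `sparse_watched_0/1` are hypothetical
  # per-core `tf.sparse.SparseTensor`s; the `EnqueueData` fields follow the
  # usage above (embedding_indices, sample_indices, aggregation_weights).
  #
  #   def _to_enqueue_data(sp):
  #     return EnqueueData(embedding_indices=sp.values,
  #                        sample_indices=sp.indices,
  #                        aggregation_weights=None)
  #
  #   enqueue_datas_list = [
  #       {'watched': _to_enqueue_data(sparse_watched_0)},  # core 0
  #       {'watched': _to_enqueue_data(sparse_watched_1)},  # core 1
  #   ]
  #   enqueue_ops = tpu_embedding.generate_enqueue_ops(enqueue_datas_list)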
1755 def _validate_generate_enqueue_ops_enqueue_datas_list(self,
1756 enqueue_datas_list):
1757 """Validate `enqueue_datas_list`."""
1759 def _check_agreement(data, name, feature, enqueue_data):
1760 """Helper function to check device agreement."""
1761 if (data is not None and
1762 data.device != enqueue_data.embedding_indices.device):
1763 raise ValueError('Device of {0} does not agree with that of '
1764 'embedding_indices for feature {1}.'.format(
1765 name, feature))
1767 feature_set = set(self._feature_to_config_dict.keys())
1768 contiguous_device = None
1769 for i, enqueue_datas in enumerate(enqueue_datas_list):
1770 used_feature_set = set(enqueue_datas.keys())
1772 # Check features are valid.
1773 missing_feature_set = feature_set - used_feature_set
1774 if missing_feature_set:
1775 raise ValueError('`enqueue_datas_list[{}]` is missing a feature that is '
1776 'in `feature_to_config_dict`: {}.'.format(
1777 i, missing_feature_set))
1779 extra_feature_set = used_feature_set - feature_set
1780 if extra_feature_set:
1781 raise ValueError('`enqueue_datas_list[{}]` has a feature that is not '
1782 'in `feature_to_config_dict`: {}.'.format(
1783 i, extra_feature_set))
1785 device = None
1786 device_feature = None
1787 for feature, enqueue_data in enqueue_datas.items():
1788 combiner = self._table_to_config_dict[
1789 self._feature_to_config_dict[feature].table_id].combiner
1791 if isinstance(enqueue_data, EnqueueData):
1792 if enqueue_data.sample_indices is None and combiner:
1793 logging.warn(
1794 'No sample indices set for feature %s table %s but '
1795 'combiner is set to %s.', feature,
1796 self._feature_to_config_dict[feature].table_id, combiner)
1797 _check_agreement(enqueue_data.sample_indices, 'sample_indices',
1798 feature, enqueue_data)
1799 _check_agreement(enqueue_data.aggregation_weights,
1800 'aggregation_weights', feature, enqueue_data)
1802 elif isinstance(enqueue_data, RaggedEnqueueData):
1803 if enqueue_data.row_splits is None and combiner:
1804 logging.warn(
1805 'No row splits set for feature %s table %s but '
1806 'combiner is set to %s.', feature,
1807 self._feature_to_config_dict[feature].table_id, combiner)
1808 _check_agreement(enqueue_data.row_splits, 'row_splits', feature,
1809 enqueue_data)
1810 _check_agreement(enqueue_data.aggregation_weights,
1811 'aggregation_weights', feature, enqueue_data)
1812 else:
1813 raise ValueError(
1814 '`enqueue_datas_list[{}]` has a feature that is not mapped to '
1815 '`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format(
1816 i, feature))
1817 # Check all features are on the same device.
1818 if device is None:
1819 device = enqueue_data.embedding_indices.device
1820 device_feature = feature
1821 else:
1822 if device != enqueue_data.embedding_indices.device:
1823 raise ValueError('Devices are different between features in '
1824 '`enqueue_datas_list[{}]`; '
1825 'devices: {}, {}; features: {}, {}.'.format(
1826 i, device,
1827 enqueue_data.embedding_indices.device, feature,
1828 device_feature))
1830 if i % self._num_cores_per_host:
1831 if device != contiguous_device:
1832 raise ValueError('`enqueue_datas` on the same host must be '
1833 'contiguous in '
1834 '`enqueue_datas_list`; '
1835 '`enqueue_datas_list[{}]` is on device {}, '
1836 'but was expected to be on device {}.'.format(
1837 i, device, contiguous_device))
1838 else:
1839 contiguous_device = device
1841 def _generate_enqueue_op(self,
1842 enqueue_datas,
1843 device_ordinal,
1844 mode_override=None,
1845 ragged=False):
1846 """Creates op for enqueuing batch to TPU."""
1847 enqueue_data0 = list(enqueue_datas.values())[0]
1848 with ops.colocate_with(enqueue_data0.embedding_indices):
1849 return tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch(
1850 device_ordinal=device_ordinal,
1851 combiners=self._combiners,
1852 mode_override=mode_override,
1853 **self._format_for_tpu_embedding_arbitrary_tensor_batch(
1854 enqueue_datas, ragged))
1856 def _format_for_tpu_embedding_arbitrary_tensor_batch(self, enqueue_datas,
1857 ragged):
1858 """Format features for `enqueue_tpu_embedding_arbitrary_tensor_batch()`.
1860 Args:
1861 enqueue_datas: a `Dict` of `RaggedEnqueueData` objects for embedding.
1862 ragged: If True, extract row splits from the data rather than sample
1863 indices.
1865 Returns:
1866 Dict of arguments for `enqueue_tpu_embedding_arbitrary_tensor_batch()`.
1867 """
1869 kwargs = {
1870 'sample_indices_or_row_splits': [],
1871 'embedding_indices': [],
1872 'aggregation_weights': [],
1873 }
1874 int_zeros = array_ops.zeros((0,), dtype=dtypes.int64)
1875 float_zeros = array_ops.zeros((0,), dtype=dtypes.float32)
1876 for table in self._table_to_features_dict:
1877 features = self._table_to_features_dict[table]
1878 for feature in features:
1879 enqueue_data = enqueue_datas[feature]
1880 if ragged:
1881 kwargs['sample_indices_or_row_splits'].append(
1882 enqueue_data.row_splits if enqueue_data
1883 .row_splits is not None else int_zeros)
1884 else:
1885 if (self._feature_to_config_dict[feature].max_sequence_length > 0 and
1886 enqueue_data.sample_indices is not None and
1887 enqueue_data.sample_indices.shape[1] == 2):
1888 # Pad the sample indices as if the enqueued sparse tensor is rank 2.
1889 sample_indices = array_ops.pad(
1890 enqueue_data.sample_indices, paddings=[[0, 0], [0, 1]])
1891 kwargs['sample_indices_or_row_splits'].append(sample_indices)
1892 else:
1893 # If the sample_indices is rank 1 or not present, treat it as dense
1894 # tensor.
1895 if (enqueue_data.sample_indices is None or
1896 enqueue_data.sample_indices.shape[1] == 1):
1897 kwargs['sample_indices_or_row_splits'].append(int_zeros)
1898 else:
1899 kwargs['sample_indices_or_row_splits'].append(
1900 enqueue_data.sample_indices)
1902 kwargs['aggregation_weights'].append(
1903 enqueue_data.aggregation_weights if enqueue_data
1904 .aggregation_weights is not None else float_zeros)
1906 kwargs['embedding_indices'].append(enqueue_data.embedding_indices)
1907 return kwargs
1909 def get_activations(self):
1910 """Get activations for features.
1912 This should be called within `computation` that is passed to
1913 `tpu.replicate` and friends.
1915 Returns:
1916 A dictionary mapping from `String` of feature name to `Tensor`
1917 of activation.
1918 """
1919 recv_activations = tpu_ops.recv_tpu_embedding_activations(
1920 num_outputs=len(self._feature_to_config_dict),
1921 config=self._config_proto.SerializeToString())
1923 activations = collections.OrderedDict()
1924 index = 0
1925 for table in self._table_to_features_dict:
1926 for feature in self._table_to_features_dict[table]:
1927 activations[feature] = recv_activations[index]
1928 index += 1
1929 return activations
1931 def generate_send_gradients_op(self, feature_to_gradient_dict, step=None):
1932 """Send gradient to TPU embedding.
1934 Args:
1935 feature_to_gradient_dict: dict mapping feature names to gradient wrt
1936 activations.
1937 step: the current global step, used for dynamic learning rate.
1939 Returns:
1940 SendTPUEmbeddingGradients Op.
1942 Raises:
1943 RuntimeError: If `mode` is not `TRAINING`.
1944 """
1945 if self._mode != TRAINING:
1946 raise RuntimeError('Gradients should only be sent to TPU embedding '
1947 'in training mode; got mode {}.'.format(
1948 self._mode))
1949 if step is None and self._learning_rate_fn:
1950 raise ValueError('There are dynamic learning rates but step is None.')
1952 gradients = []
1953 for table in self._table_to_features_dict:
1954 for feature in self._table_to_features_dict[table]:
1955 gradients.append(feature_to_gradient_dict[feature])
1957 return tpu_ops.send_tpu_embedding_gradients(
1958 inputs=gradients,
1959 learning_rates=[
1960 math_ops.cast(fn(step), dtype=dtypes.float32)
1961 for fn in self._learning_rate_fn
1962 ],
1963 config=self.config_proto.SerializeToString())
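  # Editor's note: a hedged sketch of a per-replica step body passed to
  # `tpu.replicate`/`tpu.shard` (not from the original source); `model_fn` and
  # `global_step` are hypothetical.
  #
  #   def tpu_step():
  #     activations = tpu_embedding.get_activations()
  #     loss = model_fn(activations)
  #     grads = tf.gradients(loss, list(activations.values()))
  #     feature_to_gradient = dict(zip(activations.keys(), grads))
  #     send_op = tpu_embedding.generate_send_gradients_op(
  #         feature_to_gradient, step=global_step)
  #     with tf.control_dependencies([send_op]):
  #       return tf.identity(loss)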
1965 def _get_optimizer_handler_by_table(self):
1966 optimizer_handlers = {}
1967 for table, table_config in self.table_to_config_dict.items():
1968 if table_config.optimization_parameters is not None:
1969 optimizer = table_config.optimization_parameters
1970 else:
1971 optimizer = self._optimization_parameters
1972 optimizer_handlers[table] = _get_optimization_handler(optimizer)
1974 return optimizer_handlers
1977def _validate_table_to_config_dict(table_to_config_dict):
1978 """Validate `table_to_config_dict`."""
1979 for k, v in table_to_config_dict.items():
1980 if not isinstance(v, TableConfig):
1981 raise ValueError('Value of `table_to_config_dict` must be of type '
1982 '`TableConfig`, got {} for {}.'.format(type(v), k))
1985def _validate_feature_to_config_dict(table_to_config_dict,
1986 feature_to_config_dict):
1987 """Validate `feature_to_config_dict`."""
1988 used_table_set = set(
1989 [feature.table_id for feature in feature_to_config_dict.values()])
1990 table_set = set(table_to_config_dict.keys())
1992 unused_table_set = table_set - used_table_set
1993 if unused_table_set:
1994 raise ValueError(
1995 '`table_to_config_dict` specifies a table that is not '
1996 'used in `feature_to_config_dict`: {}.'.format(unused_table_set))
1998 extra_table_set = used_table_set - table_set
1999 if extra_table_set:
2000 raise ValueError(
2001 '`feature_to_config_dict` refers to a table that is not '
2002 'specified in `table_to_config_dict`: {}.'.format(extra_table_set))
2005def _validate_batch_size(batch_size, num_cores):
2006 if batch_size % num_cores:
2007 raise ValueError('`batch_size` is not a multiple of number of '
2008 'cores. `batch_size`={}, `_num_cores`={}.'.format(
2009 batch_size, num_cores))
2012def _validate_optimization_parameters(optimization_parameters,
2013 table_to_config_dict):
2014 """Validate global optimization_parameters and per table optimizers.
2016 If the global optimizer is `None`, all table optimizers must be non-`None`.
2018 Args:
2019 optimization_parameters: global optimizer provided in `TPUEmbedding`
2020 constructor.
2021 table_to_config_dict: A dictionary mapping from string of table name to
2022 `TableConfig`.
2023 """
2024 tbl_optimizer_missing = False
2025 for _, table_config in table_to_config_dict.items():
2026 if table_config.optimization_parameters is None:
2027 tbl_optimizer_missing = True
2028 break
2030 if optimization_parameters:
2031 if not isinstance(optimization_parameters, _OptimizationParameters):
2032 raise ValueError('`optimization_parameters` must inherit from '
2033 '`_OptimizationParameters`. '
2034 '`type(optimization_parameters)`={}'.format(
2035 type(optimization_parameters)))
2036 else:
2037 # Missing global optimization_parameters.
2038 if tbl_optimizer_missing:
2039 raise ValueError('`optimization_parameters` is missing.')
2042class _OptimizerHandler:
2043 """Interface class for handling optimizer specific logic."""
2045 def __init__(self, optimization_parameters):
2046 self._optimization_parameters = optimization_parameters
2048 def get_optimization_parameters(self):
2049 return self._optimization_parameters
2051 def set_optimization_parameters(self, table_descriptor):
2052 raise NotImplementedError()
2054 def get_default_slot_variable_names(self, table):
2055 raise NotImplementedError()
2057 def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2058 table_config, table_variables, config_proto):
2059 raise NotImplementedError()
2062class _AdagradHandler(_OptimizerHandler):
2063 """Handles Adagrad specific logic."""
2065 def set_optimization_parameters(self, table_descriptor):
2066 table_descriptor.optimization_parameters.adagrad.SetInParent()
2068 def get_default_slot_variable_names(self, table):
2069 return AdagradSlotVariableNames('{}/{}'.format(table, 'Adagrad'))
2071 def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2072 table_config, table_variables, config_proto):
2073 accumulator_initializer = init_ops.constant_initializer(
2074 self._optimization_parameters.initial_accumulator)
2075 accumulator_variables = _create_partitioned_variables(
2076 name=slot_variable_names.accumulator,
2077 num_hosts=num_hosts,
2078 vocabulary_size=table_config.vocabulary_size,
2079 embedding_dimension=table_config.dimension,
2080 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2081 initializer=accumulator_initializer)
2082 slot_variables = AdagradSlotVariables(accumulator_variables)
2084 def load_ops_fn():
2085 """Returns the retrieve ops for AdaGrad embedding tables.
2087 Returns:
2088 A list of ops to load embedding and slot variables from CPU to TPU.
2089 """
2090 config = config_proto
2091 load_op_list = []
2092 for host_id, table_variable, accumulator_variable in zip(
2093 range(num_hosts), table_variables, accumulator_variables):
2094 with ops.colocate_with(table_variable):
2095 load_parameters_op = (
2096 tpu_ops.load_tpu_embedding_adagrad_parameters(
2097 parameters=table_variable,
2098 accumulators=accumulator_variable,
2099 table_name=table,
2100 num_shards=num_hosts,
2101 shard_id=host_id,
2102 config=config))
2103 config = None
2104 load_op_list.append(load_parameters_op)
2105 return load_op_list
2107 def retrieve_ops_fn():
2108 """Returns the retrieve ops for AdaGrad embedding tables.
2110 Returns:
2111 A list of ops to retrieve embedding and slot variables from TPU to CPU.
2112 """
2113 config = config_proto
2114 retrieve_op_list = []
2115 for host_id, table_variable, accumulator_variable in (zip(
2116 range(num_hosts), table_variables, accumulator_variables)):
2117 with ops.colocate_with(table_variable):
2118 retrieved_table, retrieved_accumulator = (
2119 tpu_ops.retrieve_tpu_embedding_adagrad_parameters(
2120 table_name=table,
2121 num_shards=num_hosts,
2122 shard_id=host_id,
2123 config=config))
2124 retrieve_parameters_op = control_flow_ops.group(
2125 state_ops.assign(table_variable, retrieved_table),
2126 state_ops.assign(accumulator_variable, retrieved_accumulator))
2127 config = None
2128 retrieve_op_list.append(retrieve_parameters_op)
2129 return retrieve_op_list
2131 return slot_variables, load_ops_fn, retrieve_ops_fn
2134class _AdagradMomentumHandler(_OptimizerHandler):
2135 """Handles Adagrad with Momentum specific logic.
2137 Creates slot variables and defines their initializers. Defines load/retrieve
2138 operations to be used for loading variables into TPU memory (from host memory)
2139 and retrieving variables from TPU memory (into host memory).
2140 """
2142 def set_optimization_parameters(self, table_descriptor):
2143 table_descriptor.optimization_parameters.adagrad_momentum.SetInParent()
2144 table_descriptor.optimization_parameters.adagrad_momentum.momentum = (
2145 self._optimization_parameters.momentum)
2146 table_descriptor.optimization_parameters.adagrad_momentum.use_nesterov = (
2147 self._optimization_parameters.use_nesterov)
2148 table_descriptor.optimization_parameters.adagrad_momentum.exponent = (
2149 self._optimization_parameters.exponent)
2150 table_descriptor.optimization_parameters.adagrad_momentum.beta2 = (
2151 self._optimization_parameters.beta2)
2152 table_descriptor.optimization_parameters.adagrad_momentum.epsilon = (
2153 self._optimization_parameters.epsilon)
2155 def get_default_slot_variable_names(self, table):
2156 return AdagradMomentumSlotVariableNames(
2157 '{}/{}/Accumulator'.format(table, 'AdagradMomentum'),
2158 '{}/{}/Momentum'.format(table, 'AdagradMomentum'))
2160 def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2161 table_config, table_variables, config_proto):
2162 accumulator_initializer = init_ops.zeros_initializer()
2163 accumulator_variables = _create_partitioned_variables(
2164 name=slot_variable_names.accumulator,
2165 num_hosts=num_hosts,
2166 vocabulary_size=table_config.vocabulary_size,
2167 embedding_dimension=table_config.dimension,
2168 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2169 initializer=accumulator_initializer)
2170 momenta_initializer = init_ops.zeros_initializer()
2171 momenta_variables = _create_partitioned_variables(
2172 name=slot_variable_names.momenta,
2173 num_hosts=num_hosts,
2174 vocabulary_size=table_config.vocabulary_size,
2175 embedding_dimension=table_config.dimension,
2176 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2177 initializer=momenta_initializer)
2178 slot_variables = AdagradMomentumSlotVariables(accumulator_variables,
2179 momenta_variables)
2181 def load_ops_fn():
2182 """Returns the load ops for AdaGrad with momentum embedding tables.
2184 Returns:
2185 A list of ops to load embedding and slot variables from CPU to TPU.
2186 """
2187 config = config_proto
2188 load_op_list = []
2189 for host_id, table_variable, accumulator_variable, momenta_variable in zip(
2190 range(num_hosts), table_variables, accumulator_variables,
2191 momenta_variables):
2192 with ops.colocate_with(table_variable):
2193 load_parameters_op = (
2194 tpu_ops.load_tpu_embedding_adagrad_momentum_parameters(
2195 parameters=table_variable,
2196 accumulators=accumulator_variable,
2197 momenta=momenta_variable,
2198 table_name=table,
2199 num_shards=num_hosts,
2200 shard_id=host_id,
2201 config=config))
2202 config = None
2203 load_op_list.append(load_parameters_op)
2204 return load_op_list
2206 def retrieve_ops_fn():
2207 """Returns the retrieve ops for AdaGrad with momentum embedding tables.
2209 Returns:
2210 A list of ops to retrieve embedding and slot variables from TPU to CPU.
2211 """
2212 config = config_proto
2213 retrieve_op_list = []
2214 for host_id, table_variable, accumulator_variable, momenta_variable in (
2215 zip(
2216 range(num_hosts), table_variables, accumulator_variables,
2217 momenta_variables)):
2218 with ops.colocate_with(table_variable):
2219 retrieved_table, retrieved_accumulator, retrieved_momenta = (
2220 tpu_ops.retrieve_tpu_embedding_adagrad_momentum_parameters(
2221 table_name=table,
2222 num_shards=num_hosts,
2223 shard_id=host_id,
2224 config=config))
2225 retrieve_parameters_op = control_flow_ops.group(
2226 state_ops.assign(table_variable, retrieved_table),
2227 state_ops.assign(accumulator_variable, retrieved_accumulator),
2228 state_ops.assign(momenta_variable, retrieved_momenta))
2229 config = None
2230 retrieve_op_list.append(retrieve_parameters_op)
2231 return retrieve_op_list
2233 return slot_variables, load_ops_fn, retrieve_ops_fn
2236class _ProximalAdagradHandler(_OptimizerHandler):
2237 """Handles ProximalAdagrad specific logic."""
2239 def set_optimization_parameters(self, table_descriptor):
2240 table_descriptor.optimization_parameters.proximal_adagrad.SetInParent()
2241 table_descriptor.optimization_parameters.proximal_adagrad.l1 = (
2242 self._optimization_parameters.l1_regularization_strength)
2243 table_descriptor.optimization_parameters.proximal_adagrad.l2 = (
2244 self._optimization_parameters.l2_regularization_strength)
2246 def get_default_slot_variable_names(self, table):
2247 return ProximalAdagradSlotVariableNames('{}/{}'.format(
2248 table, 'ProximalAdagrad'))
2250 def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2251 table_config, table_variables, config_proto):
2252 accumulator_initializer = init_ops.constant_initializer(
2253 self._optimization_parameters.initial_accumulator)
2254 accumulator_variables = _create_partitioned_variables(
2255 name=slot_variable_names.accumulator,
2256 num_hosts=num_hosts,
2257 vocabulary_size=table_config.vocabulary_size,
2258 embedding_dimension=table_config.dimension,
2259 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2260 initializer=accumulator_initializer)
2261 slot_variables = ProximalAdagradSlotVariables(accumulator_variables)
2263 def load_ops_fn():
2264 """Returns the retrieve ops for Proximal AdaGrad embedding tables.
2266 Returns:
2267 A list of ops to load embedding and slot variables from CPU to TPU.
2268 """
2269 config = config_proto
2270 load_op_list = []
2271 for host_id, table_variable, accumulator_variable in zip(
2272 range(num_hosts), table_variables, accumulator_variables):
2273 with ops.colocate_with(table_variable):
2274 load_parameters_op = (
2275 tpu_ops.load_tpu_embedding_proximal_adagrad_parameters(
2276 parameters=table_variable,
2277 accumulators=accumulator_variable,
2278 table_name=table,
2279 num_shards=num_hosts,
2280 shard_id=host_id,
2281 config=config))
2282 config = None
2283 load_op_list.append(load_parameters_op)
2284 return load_op_list
2286 def retrieve_ops_fn():
2287 """Returns the retrieve ops for Proximal AdaGrad embedding tables.
2289 Returns:
2290 A list of ops to retrieve embedding and slot variables from TPU to CPU.
2291 """
2292 config = config_proto
2293 retrieve_op_list = []
2294 for host_id, table_variable, accumulator_variable in (zip(
2295 range(num_hosts), table_variables, accumulator_variables)):
2296 with ops.colocate_with(table_variable):
2297 retrieved_table, retrieved_accumulator = (
2298 tpu_ops.retrieve_tpu_embedding_proximal_adagrad_parameters(
2299 table_name=table,
2300 num_shards=num_hosts,
2301 shard_id=host_id,
2302 config=config))
2303 retrieve_parameters_op = control_flow_ops.group(
2304 state_ops.assign(table_variable, retrieved_table),
2305 state_ops.assign(accumulator_variable, retrieved_accumulator))
2306 config = None
2307 retrieve_op_list.append(retrieve_parameters_op)
2308 return retrieve_op_list
2310 return slot_variables, load_ops_fn, retrieve_ops_fn
2313class _AdamHandler(_OptimizerHandler):
2314 """Handles Adam specific logic."""
2316 def set_optimization_parameters(self, table_descriptor):
2317 table_descriptor.optimization_parameters.adam.beta1 = (
2318 self._optimization_parameters.beta1)
2319 table_descriptor.optimization_parameters.adam.beta2 = (
2320 self._optimization_parameters.beta2)
2321 table_descriptor.optimization_parameters.adam.epsilon = (
2322 self._optimization_parameters.epsilon)
2323 table_descriptor.optimization_parameters.adam.use_non_lazy_adam = (
2324 not self._optimization_parameters.lazy_adam)
2325 table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = (
2326 self._optimization_parameters.sum_inside_sqrt)
2328 def get_default_slot_variable_names(self, table):
2329 return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'),
2330 '{}/{}/v'.format(table, 'Adam'))
2332 def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2333 table_config, table_variables, config_proto):
2334 m_initializer = init_ops.zeros_initializer()
2335 m_variables = _create_partitioned_variables(
2336 name=slot_variable_names.m,
2337 num_hosts=num_hosts,
2338 vocabulary_size=table_config.vocabulary_size,
2339 embedding_dimension=table_config.dimension,
2340 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2341 initializer=m_initializer)
2342 v_initializer = init_ops.zeros_initializer()
2343 v_variables = _create_partitioned_variables(
2344 name=slot_variable_names.v,
2345 num_hosts=num_hosts,
2346 vocabulary_size=table_config.vocabulary_size,
2347 embedding_dimension=table_config.dimension,
2348 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2349 initializer=v_initializer)
2350 slot_variables = AdamSlotVariables(m_variables, v_variables)
2352 def load_ops_fn():
2353 """Returns the retrieve ops for AdaGrad embedding tables.
2355 Returns:
2356 A list of ops to load embedding and slot variables from CPU to TPU.
2357 """
2358 load_op_list = []
2359 config = config_proto
2360 for host_id, table_variable, m_variable, v_variable in (zip(
2361 range(num_hosts), table_variables, m_variables, v_variables)):
2362 with ops.colocate_with(table_variable):
2363 load_parameters_op = (
2364 tpu_ops.load_tpu_embedding_adam_parameters(
2365 parameters=table_variable,
2366 momenta=m_variable,
2367 velocities=v_variable,
2368 table_name=table,
2369 num_shards=num_hosts,
2370 shard_id=host_id,
2371 config=config))
2372 # Set config to None to enforce that config is only loaded to the first
2373 # table.
2374 config = None
2375 load_op_list.append(load_parameters_op)
2376 return load_op_list
2378 def retrieve_ops_fn():
2379 """Returns the retrieve ops for Adam embedding tables.
2381 Returns:
2382 A list of ops to retrieve embedding and slot variables from TPU to CPU.
2383 """
2384 retrieve_op_list = []
2385 config = config_proto
2386 for host_id, table_variable, m_variable, v_variable in (zip(
2387 range(num_hosts), table_variables, m_variables, v_variables)):
2388 with ops.colocate_with(table_variable):
2389 retrieved_table, retrieved_m, retrieved_v = (
2390 tpu_ops.retrieve_tpu_embedding_adam_parameters(
2391 table_name=table,
2392 num_shards=num_hosts,
2393 shard_id=host_id,
2394 config=config))
2395 retrieve_parameters_op = control_flow_ops.group(
2396 state_ops.assign(table_variable, retrieved_table),
2397 state_ops.assign(m_variable, retrieved_m),
2398 state_ops.assign(v_variable, retrieved_v))
2399 config = None
2400 retrieve_op_list.append(retrieve_parameters_op)
2401 return retrieve_op_list
2403 return slot_variables, load_ops_fn, retrieve_ops_fn
2406class _FtrlHandler(_OptimizerHandler):
2407 """Handles Ftrl specific logic."""
2409 def set_optimization_parameters(self, table_descriptor):
2410 table_descriptor.optimization_parameters.ftrl.lr_power = (
2411 self._optimization_parameters.learning_rate_power)
2412 table_descriptor.optimization_parameters.ftrl.l1 = (
2413 self._optimization_parameters.l1_regularization_strength)
2414 table_descriptor.optimization_parameters.ftrl.l2 = (
2415 self._optimization_parameters.l2_regularization_strength)
2416 table_descriptor.optimization_parameters.ftrl.multiply_linear_by_lr = (
2417 self._optimization_parameters.multiply_linear_by_learning_rate)
2418 table_descriptor.optimization_parameters.ftrl.beta = (
2419 self._optimization_parameters.beta)
2420 table_descriptor.optimization_parameters.ftrl.allow_zero_accumulator = (
2421 self._optimization_parameters.allow_zero_accumulator)
2423 def get_default_slot_variable_names(self, table):
2424 # These match the default slot variable names created by
2425 # tf.train.FtrlOptimizer.
2426 return FtrlSlotVariableNames(
2427 '{}/{}'.format(table, 'Ftrl'), # accumulator
2428 '{}/{}'.format(table, 'Ftrl_1')) # linear
2430 def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2431 table_config, table_variables, config_proto):
2432 accumulator_initializer = init_ops.constant_initializer(
2433 self._optimization_parameters.initial_accumulator_value)
2434 accumulator_variables = _create_partitioned_variables(
2435 name=slot_variable_names.accumulator,
2436 num_hosts=num_hosts,
2437 vocabulary_size=table_config.vocabulary_size,
2438 embedding_dimension=table_config.dimension,
2439 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2440 initializer=accumulator_initializer)
2441 linear_initializer = init_ops.constant_initializer(
2442 self._optimization_parameters.initial_linear_value)
2443 linear_variables = _create_partitioned_variables(
2444 name=slot_variable_names.linear,
2445 num_hosts=num_hosts,
2446 vocabulary_size=table_config.vocabulary_size,
2447 embedding_dimension=table_config.dimension,
2448 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2449 initializer=linear_initializer)
2450 slot_variables = FtrlSlotVariable(accumulator_variables, linear_variables)
2452 def load_ops_fn():
2453 """Returns the retrieve ops for Ftrl embedding tables.
2455 Returns:
2456 A list of ops to load embedding and slot variables from CPU to TPU.
2457 """
2458 config = config_proto
2459 load_op_list = []
2460 for host_id, table_variable, accumulator_variable, linear_variable in zip(
2461 range(num_hosts), table_variables, accumulator_variables,
2462 linear_variables):
2463 with ops.colocate_with(table_variable):
2464 load_parameters_op = (
2465 tpu_ops.load_tpu_embedding_ftrl_parameters(
2466 parameters=table_variable,
2467 accumulators=accumulator_variable,
2468 linears=linear_variable,
2469 table_name=table,
2470 num_shards=num_hosts,
2471 shard_id=host_id,
2472 config=config))
2473 config = None
2474 load_op_list.append(load_parameters_op)
2475 return load_op_list
2477 def retrieve_ops_fn():
2478 """Returns the retrieve ops for Ftrl embedding tables.
2480 Returns:
2481 A list of ops to retrieve embedding and slot variables from TPU to CPU.
2482 """
2483 config = config_proto
2484 retrieve_op_list = []
2485 for host_id, table_variable, accumulator_variable, linear_variable in zip(
2486 range(num_hosts), table_variables, accumulator_variables,
2487 linear_variables):
2488 with ops.colocate_with(table_variable):
2489 retrieved_table, retrieved_accumulator, retrieved_linear = (
2490 tpu_ops.retrieve_tpu_embedding_ftrl_parameters(
2491 table_name=table,
2492 num_shards=num_hosts,
2493 shard_id=host_id,
2494 config=config))
2495 retrieve_parameters_op = control_flow_ops.group(
2496 state_ops.assign(table_variable, retrieved_table),
2497 state_ops.assign(accumulator_variable, retrieved_accumulator),
2498 state_ops.assign(linear_variable, retrieved_linear))
2499 config = None
2500 retrieve_op_list.append(retrieve_parameters_op)
2501 return retrieve_op_list
2503 return slot_variables, load_ops_fn, retrieve_ops_fn
2506class _ProximalYogiHandler(_OptimizerHandler):
2507 """Handles Proximal Yogi specific logic."""
2509 def set_optimization_parameters(self, table_descriptor):
2510 table_descriptor.optimization_parameters.proximal_yogi.SetInParent()
2511 table_descriptor.optimization_parameters.proximal_yogi.beta1 = (
2512 self._optimization_parameters.beta1)
2513 table_descriptor.optimization_parameters.proximal_yogi.beta2 = (
2514 self._optimization_parameters.beta2)
2515 table_descriptor.optimization_parameters.proximal_yogi.epsilon = (
2516 self._optimization_parameters.epsilon)
2517 table_descriptor.optimization_parameters.proximal_yogi.l1 = (
2518 self._optimization_parameters.l1_regularization_strength)
2519 table_descriptor.optimization_parameters.proximal_yogi.l2 = (
2520 self._optimization_parameters.l2_regularization_strength)
2522 def get_default_slot_variable_names(self, table):
2523 return ProximalYogiSlotVariableNames(
2524 '{}/{}'.format(table, 'ProximalYogi'), # v
2525 '{}/{}_1'.format(table, 'ProximalYogi')) # m
2527 def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2528 table_config, table_variables, config_proto):
2529 v_initializer = init_ops.constant_initializer(
2530 self._optimization_parameters.initial_accumulator_value)
2531 v_variables = _create_partitioned_variables(
2532 name=slot_variable_names.v,
2533 num_hosts=num_hosts,
2534 vocabulary_size=table_config.vocabulary_size,
2535 embedding_dimension=table_config.dimension,
2536 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2537 initializer=v_initializer)
2538 m_initializer = init_ops.zeros_initializer()
2539 m_variables = _create_partitioned_variables(
2540 name=slot_variable_names.m,
2541 num_hosts=num_hosts,
2542 vocabulary_size=table_config.vocabulary_size,
2543 embedding_dimension=table_config.dimension,
2544 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2545 initializer=m_initializer)
2546 slot_variables = ProximalYogiSlotVariables(v_variables, m_variables)
2548 def load_ops_fn():
2549 """Returns the load ops for Proximal Yogi embedding tables.
2551 Returns:
2552 A list of ops to load embedding and slot variables from CPU to TPU.
2553 """
2554 load_op_list = []
2555 config = config_proto
2556 for host_id, table_variable, v_variable, m_variable in (zip(
2557 range(num_hosts), table_variables, v_variables, m_variables)):
2558 with ops.colocate_with(table_variable):
2559 load_parameters_op = (
2560 tpu_ops.load_tpu_embedding_proximal_yogi_parameters(
2561 parameters=table_variable,
2562 v=v_variable,
2563 m=m_variable,
2564 table_name=table,
2565 num_shards=num_hosts,
2566 shard_id=host_id,
2567 config=config))
2568 # Set config to None to enforce that config is only loaded to the first
2569 # table.
2570 config = None
2571 load_op_list.append(load_parameters_op)
2572 return load_op_list
2574 def retrieve_ops_fn():
2575 """Returns the retrieve ops for Proximal Yogi embedding tables.
2577 Returns:
2578 A list of ops to retrieve embedding and slot variables from TPU to CPU.
2579 """
2580 retrieve_op_list = []
2581 config = config_proto
2582 for host_id, table_variable, v_variable, m_variable in (zip(
2583 range(num_hosts), table_variables, v_variables, m_variables)):
2584 with ops.colocate_with(table_variable):
2585 retrieved_table, retrieved_v, retrieved_m = (
2586 tpu_ops.retrieve_tpu_embedding_proximal_yogi_parameters(
2587 table_name=table,
2588 num_shards=num_hosts,
2589 shard_id=host_id,
2590 config=config))
2591 retrieve_parameters_op = control_flow_ops.group(
2592 state_ops.assign(table_variable, retrieved_table),
2593 state_ops.assign(v_variable, retrieved_v),
2594 state_ops.assign(m_variable, retrieved_m))
2595 config = None
2596 retrieve_op_list.append(retrieve_parameters_op)
2597 return retrieve_op_list
2599 return slot_variables, load_ops_fn, retrieve_ops_fn
2602class _MomentumHandler(_OptimizerHandler):
2603 """Handles Momentum specific logic."""
2605 def set_optimization_parameters(self, table_descriptor):
2606 (table_descriptor.optimization_parameters.momentum.SetInParent())
2607 table_descriptor.optimization_parameters.momentum.momentum = (
2608 self._optimization_parameters.momentum)
2609 table_descriptor.optimization_parameters.momentum.use_nesterov = (
2610 self._optimization_parameters.use_nesterov)
2612 def get_default_slot_variable_names(self, table):
2613 return MomentumSlotVariableNames('{}/{}'.format(table, 'Momentum'))
2615 def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2616 table_config, table_variables, config_proto):
2618 momenta_initializer = init_ops.zeros_initializer()
2619 momenta_variables = _create_partitioned_variables(
2620 name=slot_variable_names.momenta,
2621 num_hosts=num_hosts,
2622 vocabulary_size=table_config.vocabulary_size,
2623 embedding_dimension=table_config.dimension,
2624 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2625 initializer=momenta_initializer)
2626 slot_variables = MomentumSlotVariables(momenta_variables)
2628 def load_ops_fn():
2629 """Returns the retrieve ops for Momentum embedding tables.
2631 Returns:
2632 A list of ops to load embedding and slot variables from CPU to TPU.
2633 """
2634 load_op_list = []
2635 config = config_proto
2636 for host_id, table_variable, momenta_variable in (zip(
2637 range(num_hosts), table_variables, momenta_variables)):
2638 with ops.colocate_with(table_variable):
2639 load_parameters_op = tpu_ops.load_tpu_embedding_momentum_parameters(
2640 parameters=table_variable,
2641 momenta=momenta_variable,
2642 table_name=table,
2643 num_shards=num_hosts,
2644 shard_id=host_id,
2645 config=config,
2646 )
2647 config = None
2648 load_op_list.append(load_parameters_op)
2649 return load_op_list
2651 def retrieve_ops_fn():
2652 """Returns the retrieve ops for Momentum embedding tables.
2654 Returns:
2655 A list of ops to retrieve embedding and slot variables from TPU to CPU.
2656 """
2657 retrieve_op_list = []
2658 config = config_proto
2659 for host_id, table_variable, momenta_variable in (zip(
2660 range(num_hosts), table_variables, momenta_variables)):
2661 with ops.colocate_with(table_variable):
2662 retrieved_table, retrieved_momenta = (
2663 tpu_ops.retrieve_tpu_embedding_momentum_parameters(
2664 table_name=table,
2665 num_shards=num_hosts,
2666 shard_id=host_id,
2667 config=config,
2668 ))
2669 retrieve_parameters_op = control_flow_ops.group(
2670 state_ops.assign(table_variable, retrieved_table),
2671 state_ops.assign(momenta_variable, retrieved_momenta))
2672 config = None
2673 retrieve_op_list.append(retrieve_parameters_op)
2674 return retrieve_op_list
2676 return slot_variables, load_ops_fn, retrieve_ops_fn
2679class _RMSPropHandler(_OptimizerHandler):
2680 """Handles RMS prop specific logic."""
2682 def set_optimization_parameters(self, table_descriptor):
2683 (table_descriptor.optimization_parameters.rms_prop.SetInParent())
2684 table_descriptor.optimization_parameters.rms_prop.rho = (
2685 self._optimization_parameters.rho)
2686 table_descriptor.optimization_parameters.rms_prop.epsilon = (
2687 self._optimization_parameters.epsilon)
2688 table_descriptor.optimization_parameters.rms_prop.momentum = (
2689 self._optimization_parameters.momentum)
2691 def get_default_slot_variable_names(self, table):
2692 return RMSPropSlotVariableNames('{}/{}/ms'.format(table, 'RMSProp'),
2693 '{}/{}/mom'.format(table, 'RMSProp'))
2695 def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2696 table_config, table_variables, config_proto):
2698 ms_variables = _create_partitioned_variables(
2699 name=slot_variable_names.ms,
2700 num_hosts=num_hosts,
2701 vocabulary_size=table_config.vocabulary_size,
2702 embedding_dimension=table_config.dimension,
2703 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2704 initializer=init_ops.zeros_initializer(),
2705 )
2706 mom_variables = _create_partitioned_variables(
2707 name=slot_variable_names.mom,
2708 num_hosts=num_hosts,
2709 vocabulary_size=table_config.vocabulary_size,
2710 embedding_dimension=table_config.dimension,
2711 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2712 initializer=init_ops.zeros_initializer(),
2713 )
2714 slot_variables = RMSPropSlotVariables(ms_variables, mom_variables)
2716 def load_ops_fn():
2717 """Returns the retrieve ops for RMS Prop embedding tables.
2719 Returns:
2720 A list of ops to load embedding and slot variables from CPU to TPU.
2721 """
2722 load_op_list = []
2723 config = config_proto
2724 for host_id, table_variable, ms_variable, mom_variable in (zip(
2725 range(num_hosts), table_variables, ms_variables, mom_variables)):
2726 with ops.colocate_with(table_variable):
2727 load_parameters_op = tpu_ops.load_tpu_embedding_rms_prop_parameters(
2728 parameters=table_variable,
2729 ms=ms_variable,
2730 mom=mom_variable,
2731 table_name=table,
2732 num_shards=num_hosts,
2733 shard_id=host_id,
2734 config=config,
2735 )
2736 config = None
2737 load_op_list.append(load_parameters_op)
2738 return load_op_list
2740 def retrieve_ops_fn():
2741 """Returns the retrieve ops for RMS Prop embedding tables.
2743 Returns:
2744 A list of ops to retrieve embedding and slot variables from TPU to CPU.
2745 """
2746 retrieve_op_list = []
2747 config = config_proto
2748 for host_id, table_variable, ms_variable, mom_variable in (zip(
2749 range(num_hosts), table_variables, ms_variables, mom_variables)):
2750 with ops.colocate_with(table_variable):
2751 retrieved_table, retrieved_ms, retrieved_mom = (
2752 tpu_ops.retrieve_tpu_embedding_rms_prop_parameters(
2753 table_name=table,
2754 num_shards=num_hosts,
2755 shard_id=host_id,
2756 config=config,
2757 ))
2758 retrieve_parameters_op = control_flow_ops.group(
2759 state_ops.assign(table_variable, retrieved_table),
2760 state_ops.assign(ms_variable, retrieved_ms),
2761 state_ops.assign(mom_variable, retrieved_mom))
2762 config = None
2763 retrieve_op_list.append(retrieve_parameters_op)
2764 return retrieve_op_list
2766 return slot_variables, load_ops_fn, retrieve_ops_fn
2769class _FrequencyEstimatorHandler(_OptimizerHandler):
2770 """Handles frequency estimator specific logic."""
2772 def set_optimization_parameters(self, table_descriptor):
2773 table_descriptor.optimization_parameters.frequency_estimator.SetInParent()
2774 freq = table_descriptor.optimization_parameters.frequency_estimator
2775 freq.tau = self._optimization_parameters.tau
2776 freq.max_delta = self._optimization_parameters.max_delta
2777 freq.outlier_threshold = self._optimization_parameters.outlier_threshold
2778 freq.weight_exponent = self._optimization_parameters.weight_exponent
2780 def get_default_slot_variable_names(self, table):
2781 return FrequencyEstimatorSlotVariableNames(
2782 '{}/FrequencyEstimator'.format(table))
2784 def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2785 table_config, table_variables, config_proto):
2786 if table_config.dimension != 1:
2787 raise ValueError('FrequencyEstimator tables should only have a dimension '
2788 'of 1. Received dimension {}'.format(
2789 table_config.dimension))
2791 last_hit_step_variables = _create_partitioned_variables(
2792 name=slot_variable_names.last_hit_step,
2793 num_hosts=num_hosts,
2794 vocabulary_size=table_config.vocabulary_size,
2795 embedding_dimension=table_config.dimension,
2796 collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2797 initializer=init_ops.zeros_initializer(),
2798 )
2799 slot_variables = FrequencyEstimatorSlotVariables(last_hit_step_variables)
2801 def load_ops_fn():
2802 """Returns the retrieve ops for Frequency Estimator embedding tables.
2804 Returns:
2805 A list of ops to load embedding and slot variables from CPU to TPU.
2806 """
2807 load_op_list = []
2808 config = config_proto
2809 for host_id, table_variable, last_hit_step_variable in (zip(
2810 range(num_hosts), table_variables, last_hit_step_variables)):
2811 with ops.colocate_with(table_variable):
2812 load_parameters_op = (
2813 tpu_ops.load_tpu_embedding_frequency_estimator_parameters(
2814 parameters=table_variable,
2815 last_hit_step=last_hit_step_variable,
2816 table_name=table,
2817 num_shards=num_hosts,
2818 shard_id=host_id,
2819 config=config))
2820 config = None
2821 load_op_list.append(load_parameters_op)
2822 return load_op_list
2824 def retrieve_ops_fn():
2825 """Returns the retrieve ops for Frequency Estimator embedding tables.
2827 Returns:
2828 A list of ops to retrieve embedding and slot variables from TPU to CPU.
2829 """
2830 retrieve_op_list = []
2831 config = config_proto
2832 for host_id, table_variable, last_hit_step_variable in (zip(
2833 range(num_hosts), table_variables, last_hit_step_variables)):
2834 with ops.colocate_with(table_variable):
2835 retrieved_table, retrieved_last_hit_step = (
2836 tpu_ops.retrieve_tpu_embedding_frequency_estimator_parameters(
2837 table_name=table,
2838 num_shards=num_hosts,
2839 shard_id=host_id,
2840 config=config,
2841 ))
2842 retrieve_parameters_op = control_flow_ops.group(
2843 state_ops.assign(table_variable, retrieved_table),
2844 state_ops.assign(last_hit_step_variable, retrieved_last_hit_step))
2845 config = None
2846 retrieve_op_list.append(retrieve_parameters_op)
2847 return retrieve_op_list
2849 return slot_variables, load_ops_fn, retrieve_ops_fn
2852class _StochasticGradientDescentHandler(_OptimizerHandler):
2853 """Handles stochastic gradient descent specific logic."""
2855 def set_optimization_parameters(self, table_descriptor):
2856 (table_descriptor.optimization_parameters.stochastic_gradient_descent
2857 .SetInParent())
2859 def get_default_slot_variable_names(self, table):
2860 return None
2862 def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2863 table_config, table_variables, config_proto):
2864 del table_config
2866 def load_ops_fn():
2867      """Returns the load ops for stochastic gradient descent embedding tables.
2869 Returns:
2870 A list of ops to load embedding and slot variables from CPU to TPU.
2871 """
2872 load_op_list = []
2873 config = config_proto
2874 for host_id, table_variable in enumerate(table_variables):
2875 with ops.colocate_with(table_variable):
2876 load_parameters_op = (
2877 tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters(
2878 parameters=table_variable,
2879 table_name=table,
2880 num_shards=num_hosts,
2881 shard_id=host_id,
2882 config=config))
2883 config = None
2884 load_op_list.append(load_parameters_op)
2885 return load_op_list
2887 def retrieve_ops_fn():
2888 """Returns the retrieve ops for SGD embedding tables.
2890 Returns:
2891 A list of ops to retrieve embedding and slot variables from TPU to CPU.
2892 """
2893 retrieve_op_list = []
2894 config = config_proto
2895 for host_id, table_variable in enumerate(table_variables):
2896 with ops.colocate_with(table_variable):
2897 retrieved_table = (
2898 tpu_ops
2899 .retrieve_tpu_embedding_stochastic_gradient_descent_parameters(
2900 table_name=table,
2901 num_shards=num_hosts,
2902 shard_id=host_id,
2903 config=config))
2904 retrieve_parameters_op = control_flow_ops.group(
2905 state_ops.assign(table_variable, retrieved_table))
2906 config = None
2907 retrieve_op_list.append(retrieve_parameters_op)
2908 return retrieve_op_list
2910 return None, load_ops_fn, retrieve_ops_fn
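# Editor's note: unlike the other handlers, plain SGD keeps no slot variables, so
# create_variables_and_ops returns None in the first position of the
# (slot_variables, load_ops_fn, retrieve_ops_fn) triple. A caller consuming that
# triple might look like the hypothetical sketch below (the handler and its
# arguments are placeholders, not this module's actual call sites).
def _example_consume_handler_output(handler, table, slot_names, num_hosts,
                                    table_config, table_variables, config_proto):
  slot_variables, load_ops_fn, retrieve_ops_fn = (
      handler.create_variables_and_ops(table, slot_names, num_hosts, table_config,
                                       table_variables, config_proto))
  if slot_variables is None:
    pass  # Plain SGD: only the embedding table itself is loaded and retrieved.
  return load_ops_fn(), retrieve_ops_fn()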
2913def _get_optimization_handler(optimization_parameters):
2914 """Gets the optimization handler given the parameter type."""
2915 if isinstance(optimization_parameters, AdagradParameters):
2916 return _AdagradHandler(optimization_parameters)
2917 elif isinstance(optimization_parameters, AdagradMomentumParameters):
2918 return _AdagradMomentumHandler(optimization_parameters)
2919 elif isinstance(optimization_parameters, ProximalAdagradParameters):
2920 return _ProximalAdagradHandler(optimization_parameters)
2921 elif isinstance(optimization_parameters, AdamParameters):
2922 return _AdamHandler(optimization_parameters)
2923 elif isinstance(optimization_parameters, FtrlParameters):
2924 return _FtrlHandler(optimization_parameters)
2925 elif isinstance(optimization_parameters, ProximalYogiParameters):
2926 return _ProximalYogiHandler(optimization_parameters)
2927 elif isinstance(optimization_parameters, StochasticGradientDescentParameters):
2928 return _StochasticGradientDescentHandler(optimization_parameters)
2929 elif isinstance(optimization_parameters, MomentumParameters):
2930 return _MomentumHandler(optimization_parameters)
2931 elif isinstance(optimization_parameters, RMSPropParameters):
2932 return _RMSPropHandler(optimization_parameters)
2933 elif isinstance(optimization_parameters, FrequencyEstimatorParameters):
2934 return _FrequencyEstimatorHandler(optimization_parameters)
2935  raise NotImplementedError('Unsupported optimization parameters type.')
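# Editor's note: the isinstance() chain above maps each *Parameters type to its
# handler; a type-keyed dict is an equivalent, table-driven way to express the same
# dispatch. The sketch below is self-contained and uses made-up parameter and
# handler classes rather than the real ones defined in this module.
def _example_dict_dispatch():
  class FakeAdagradParams(object):
    pass

  class FakeAdagradHandler(object):

    def __init__(self, params):
      self.params = params

  handlers = {FakeAdagradParams: FakeAdagradHandler}

  def get_handler(optimization_parameters):
    handler_cls = handlers.get(type(optimization_parameters))
    if handler_cls is None:
      raise NotImplementedError(
          'Unsupported optimization parameters: {!r}.'.format(
              optimization_parameters))
    return handler_cls(optimization_parameters)

  return get_handler(FakeAdagradParams())  # A FakeAdagradHandler instance.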
2938def _create_ordered_dict(d):
2939  """Create an OrderedDict from a dict, with keys in sorted order."""
2940 return collections.OrderedDict((k, d[k]) for k in sorted(d))
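# Editor's note: a tiny, self-contained illustration of _create_ordered_dict. The
# keys are made-up table names; the point is only that iteration order follows
# sorted(d), not insertion order.
def _example_create_ordered_dict():
  ordered = _create_ordered_dict({'watched': 2, 'favorited': 1})
  return list(ordered.items())  # [('favorited', 1), ('watched', 2)]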
2943def _create_combiners(table_to_config_dict, table_to_features_dict):
2944 """Create a per feature list of combiners, ordered by table."""
2945 combiners = []
2946 for table in table_to_config_dict:
2947 combiner = table_to_config_dict[table].combiner or 'sum'
2948 combiners.extend([combiner] * len(table_to_features_dict[table]))
2949 return combiners
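# Editor's note: a self-contained example of _create_combiners. The table and
# feature names are hypothetical, and a lightweight namedtuple stands in for
# TableConfig since only the `combiner` attribute is read here; a combiner of None
# falls back to 'sum'.
def _example_create_combiners():
  fake_table_config = collections.namedtuple('FakeTableConfig', ['combiner'])
  table_to_config = collections.OrderedDict([
      ('table_a', fake_table_config(combiner='mean')),
      ('table_b', fake_table_config(combiner=None)),
  ])
  table_to_features = {'table_a': ['f1', 'f2'], 'table_b': ['f3']}
  # Expected result: ['mean', 'mean', 'sum'].
  return _create_combiners(table_to_config, table_to_features)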
2952def _create_table_to_features_dict(feature_to_config_dict):
2953 """Create mapping from table to a list of its features."""
2954 table_to_features_dict_tmp = {}
2955 for feature, feature_config in feature_to_config_dict.items():
2956 if feature_config.table_id in table_to_features_dict_tmp:
2957 table_to_features_dict_tmp[feature_config.table_id].append(feature)
2958 else:
2959 table_to_features_dict_tmp[feature_config.table_id] = [feature]
2961 table_to_features_dict = collections.OrderedDict()
2962 for table in sorted(table_to_features_dict_tmp):
2963 table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table])
2964 return table_to_features_dict
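# Editor's note: a self-contained example of _create_table_to_features_dict. A
# minimal namedtuple stands in for the feature config, since only `table_id` is
# read; feature and table names are made up. Both the tables and each per-table
# feature list come back sorted.
def _example_create_table_to_features_dict():
  fake_feature_config = collections.namedtuple('FakeFeatureConfig', ['table_id'])
  feature_to_config = {
      'query': fake_feature_config(table_id='words'),
      'title': fake_feature_config(table_id='words'),
      'country': fake_feature_config(table_id='geo'),
  }
  # Expected: OrderedDict([('geo', ['country']), ('words', ['query', 'title'])]).
  return _create_table_to_features_dict(feature_to_config)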
2967def _create_device_fn(hosts):
2968 """Create device_fn() to use with _create_partitioned_variables()."""
2970 def device_fn(op):
2971 """Returns the `device` for `op`."""
2972 part_match = re.match(r'.*/part_(\d+)(/|$)', op.name)
2973 dummy_match = re.match(r'.*dummy_(\d+).*', op.name)
2974 if not part_match and not dummy_match:
2975 raise RuntimeError(
2976 'Internal Error: Expected {} to contain /part_* or dummy_*'.format(
2977 op.name))
2979 if part_match:
2980 idx = int(part_match.group(1))
2981 else:
2982 idx = int(dummy_match.group(1)) # pytype: disable=attribute-error
2984 device = hosts[idx]
2985    logging.debug('assigning %s to %s.', op, device)
2986 return device
2988 return device_fn
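# Editor's note: a self-contained illustration of the device_fn returned above. A
# tiny namedtuple stands in for an op (only `.name` is inspected) and the host
# device strings are made up; the op names follow the '/part_<N>' and 'dummy_<N>'
# conventions produced by _create_partitioned_variables below.
def _example_device_fn():
  fake_op = collections.namedtuple('FakeOp', ['name'])
  hosts = ['/job:worker/task:0', '/job:worker/task:1']
  device_fn = _create_device_fn(hosts)
  # Expected: ('/job:worker/task:1', '/job:worker/task:0').
  return (device_fn(fake_op(name='emb/table_0/part_1/Initializer/zeros')),
          device_fn(fake_op(name='dummy_0_table_0')))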
2991def _create_partitioned_variables(name,
2992 num_hosts,
2993 vocabulary_size,
2994 embedding_dimension,
2995 initializer,
2996 collections=None): # pylint: disable=redefined-outer-name
2997  """Creates partitioned variables for `name`, sharded across `num_hosts` hosts."""
2999 num_slices = min(vocabulary_size, num_hosts)
3001 var_list = list(
3002 variable_scope.get_variable(
3003 name,
3004 shape=(vocabulary_size, embedding_dimension),
3005 partitioner=partitioned_variables.fixed_size_partitioner(num_slices),
3006 dtype=dtypes.float32,
3007 initializer=initializer,
3008 collections=collections,
3009 trainable=False))
3011 if vocabulary_size >= num_hosts:
3012 return var_list
3014  # For the padded shards, define dummy variables to be loaded into the TPU system.
3015 for idx in range(num_hosts - vocabulary_size):
3016 var_list.append(
3017 variable_scope.get_variable(
3018 'dummy_{}_{}'.format(vocabulary_size + idx, name),
3019 shape=(1, embedding_dimension),
3020 dtype=dtypes.float32,
3021 initializer=initializer,
3022 collections=[ops.GraphKeys.LOCAL_VARIABLES],
3023 trainable=False))
3025 return var_list
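# Editor's note: a small worked example of the padding logic above, with made-up
# numbers. With vocabulary_size=3 and num_hosts=5, only min(3, 5) = 3 real
# partitions are created, so 5 - 3 = 2 single-row 'dummy_*' local variables are
# appended so every host still has a shard to load into the TPU system. The helper
# below reproduces that arithmetic without creating any TensorFlow variables.
def _example_partition_counts(vocabulary_size, num_hosts):
  num_real_partitions = min(vocabulary_size, num_hosts)
  num_dummy_variables = max(0, num_hosts - vocabulary_size)
  return num_real_partitions, num_dummy_variables
# _example_partition_counts(3, 5) -> (3, 2); _example_partition_counts(100, 4) -> (4, 0).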