Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_embedding_v1.py: 21%
142 statements
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mid level API for TPU Embeddings without Embedding Accelerator."""

from typing import Any, Dict, Iterable, Optional, Text, Union

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.tpu import tpu_embedding_base
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export("tpu.experimental.embedding.TPUEmbeddingV0")
class TPUEmbeddingV0(tpu_embedding_base.TPUEmbeddingBase):
  """The TPUEmbedding mid level API running on TPU without Embedding accelerator.

  NOTE: This mid level API is not intended for large embedding table lookup.
  Embedding tables will be replicated across devices rather than sharded
  across them. To do large embedding table lookup, please use the
  `tpu.experimental.embedding.TPUEmbedding` class. This class is an alternative
  way to do embedding lookups when the TPU doesn't support any version of the
  embedding feature. See
  `tpu.experimental.tpu_hardware_feature.embedding_feature` for a detailed
  explanation.

  This class has to be created under the `TPUStrategy`, otherwise a RuntimeError
  will be raised.
  ```python
  strategy = tf.distribute.TPUStrategy(...)
  with strategy.scope():
    embedding = tf.tpu.experimental.embedding.TPUEmbeddingV0(
        feature_config=feature_config,
        optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  ```
  When creating a distributed dataset that is to be passed to the lookup
  operation, a special input option must be specified:

  ```python
  distributed_dataset = (
      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_fetch_to_device=False)))
  dataset_iterator = iter(distributed_dataset)
  ```

  Below is an example of a training and evaluation step:

  ```python
  optimizer = tf.keras.optimizers.SGD(0.1)

  @tf.function
  def training_step(dataset_iterator, num_steps):
    def tpu_step(embedding_features):
      with tf.GradientTape() as tape:
        tape.watch(embedding.embedding_tables.values())
        activations = embedding(embedding_features)
        model_output = model(activations)
        loss = ...  # some function of labels and model_output

      embedding_gradients = tape.gradient(loss,
                                          embedding.embedding_tables.values())
      optimizer.apply_gradients(list(zip(embedding_gradients,
                                         embedding.embedding_tables.values())))
      # Insert your model gradient and optimizer application here

    for _ in tf.range(num_steps):
      strategy.run(tpu_step, args=(next(dataset_iterator), ))

  @tf.function
  def evaluation_step(dataset_iterator, num_steps):
    def tpu_step(embedding_features):
      activations = embedding(embedding_features)
      model_output = model(activations)
      # Insert your evaluation code here.

    for _ in tf.range(num_steps):
      strategy.run(tpu_step, args=(next(dataset_iterator), ))
  ```

  NOTE: The optimizer used here is a Keras optimizer. To keep slot variable
  creation consistent between the Keras optimizer and the embedding optimizer,
  the `slot_variable_creation_fn` argument of the embedding optimizer has to be
  passed a function that uses the Keras `add_slot` method. Also note that the
  slot names might be slightly different between them.

  ```python
  optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)

  def slot_variable_creation_fn(table, slot_names, slot_initializers):
    slots = {}
    for slot, initializer in zip(slot_names, slot_initializers):
      slots[slot] = optimizer.add_slot(table, slot, initializer)
    return slots

  embedding_optimizer = tf.tpu.experimental.embedding.Adagrad(
      learning_rate=0.1,
      slot_variable_creation_fn=slot_variable_creation_fn)

  # Use the embedding optimizer to create the mid level api and the keras
  # optimizer to apply gradients.
  ```
  """

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer]):  # pylint:disable=protected-access
    super(TPUEmbeddingV0, self).__init__(feature_config, optimizer)
    self._strategy = distribute_lib.get_strategy()
    if not isinstance(self._strategy,
                      (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2)):
      raise RuntimeError(
          "TPUEmbeddingV0 should be created under TPUStrategy but found {}."
          .format(self._strategy))
    self._built = False

  @property
  def embedding_tables(
      self) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]:
    """Returns a dict of embedding tables, keyed by `TableConfig`."""
    self._maybe_build()
    # Only return the tables and not the slot variables.
    return {
        table: self._variables[table.name]["parameters"]
        for table in self._table_config
    }

  def _create_variables_and_slots(
      self) -> Dict[Text, Dict[Text, tf_variables.Variable]]:
    """Create variables for TPU embeddings.

    Note that this will always ensure that the variable is created under the
    TPUStrategy.

    Returns:
      A dict of dicts. The outer dict is keyed by the table names and the inner
      dicts are keyed by 'parameters' and the slot variable names.
    """
    variables = {}
    for table in self._table_config:
      # Create a TPUDistributedVariable per table.
      variables[table.name] = self._create_variables(table, trainable=True)
    return variables
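
  # A minimal sketch (hypothetical table and slot names, not taken from this
  # file) of the structure returned by `_create_variables_and_slots` for a
  # single Adagrad-optimized table called "video":
  #
  #   {
  #       "video": {
  #           "parameters": <tf.Variable shape=(vocab_size, dim)>,
  #           "accumulators": <tf.Variable shape=(vocab_size, dim)>,
  #       },
  #   }
  #
  # The `embedding_tables` property above exposes only the "parameters"
  # entries, keyed by `TableConfig` rather than by table name.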

  def _maybe_build(self):
    if not self._built:
      # This can be called while tracing a function, so we wrap the
      # initialization code with init_scope so it runs eagerly. This means that
      # it will not be included in the function graph generated by tracing, so
      # we can be sure that we only initialize the TPU for embeddings
      # exactly once.
      with ops.init_scope():
        self.build()

  def _apply_combiner_to_embeddings(
      self,
      embeddings: ops.Tensor,
      weight: ops.Tensor,
      combiner: Optional[Text] = None) -> ops.Tensor:
    """Apply the combiner to the embedding lookup result on the second to last axis.

    Args:
      embeddings: A Tensor of the embedding lookup result.
      weight: A Tensor of weights with the same shape as the embeddings.
      combiner: One of "mean", "sum", "sqrtn". Defaults to "mean".

    Raises:
      ValueError: If the combiner is not one of 'mean', 'sqrtn' or 'sum'.
    Returns:
      A Tensor.
    """
    if combiner is None:
      combiner = "mean"
    if combiner == "sum":
      embeddings = math_ops.reduce_sum(embeddings, axis=-2)
    elif combiner == "mean":
      embeddings = math_ops.reduce_sum(embeddings, axis=-2)
      weight_sum = math_ops.reduce_sum(weight, axis=-2)
      embeddings = math_ops.div_no_nan(embeddings, weight_sum)
    elif combiner == "sqrtn":
      embeddings = math_ops.reduce_sum(embeddings, axis=-2)
      weight_squared = math_ops.pow(weight, 2)
      weight_sum = math_ops.reduce_sum(weight_squared, axis=-2)
      weight_sum_sqrt = math_ops.sqrt(weight_sum)
      embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt)
    else:
      raise ValueError(
          f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}")
    return embeddings
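
  # A minimal worked example (illustrative only, not part of the
  # implementation) of what the combiners above compute for a single row with
  # two valid ids, weights w = [2.0, 3.0] and embedding rows e1, e2. Note that
  # the callers below have already multiplied the weights into `embeddings`:
  #   "sum"   -> 2.0*e1 + 3.0*e2
  #   "mean"  -> (2.0*e1 + 3.0*e2) / (2.0 + 3.0)
  #   "sqrtn" -> (2.0*e1 + 3.0*e2) / sqrt(2.0**2 + 3.0**2)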

  def _pad_or_truncate_with_sequence_length(self, embeddings: ops.Tensor,
                                            sequence_length: int) -> ops.Tensor:
    """Pad or truncate the embedding lookup result based on the sequence length.

    Args:
      embeddings: A rank 3 Tensor of the embedding lookup result.
      sequence_length: The max sequence length set in the feature config.

    Returns:
      A Tensor with the second to last axis padded or truncated.
    """
    original_sequence_length = embeddings.shape[1]
    if original_sequence_length > sequence_length:
      embeddings = array_ops.slice(
          embeddings, begin=[0, 0, 0], size=[-1, sequence_length, -1])
    else:
      embeddings = array_ops.pad(
          embeddings,
          paddings=[[0, 0], [0, sequence_length - original_sequence_length],
                    [0, 0]])
    return embeddings
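
  # Shape sketch (illustrative numbers only): with `embeddings` of shape
  # [batch, seq, dim] = [8, 5, 16] and sequence_length=3, the result is
  # truncated to [8, 3, 16]; with sequence_length=7 it is zero-padded on the
  # sequence axis to [8, 7, 16].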

  def embedding_lookup(self,
                       features: Any,
                       weights: Optional[Any] = None) -> Any:
    """Apply embedding lookup on TPUs using TensorCore.

    Note that all the sparse and ragged tensors will be converted to dense
    tensors on CPU and then passed to the TPU to do embedding lookup. Large
    embedding lookup is not supported by this API, use the TPUEmbedding mid
    level api instead.

    Args:
      features: a nested structure of Tensors, SparseTensors or RaggedTensors.
      weights: a nested structure of Tensors, SparseTensors or RaggedTensors or
        None for no weights. If not None, structure must match that of inputs,
        but entries are allowed to be None.

    Returns:
      A nested structure of Tensors with the same structure as inputs.
    """
    if not self._built:
      self.build()
    nest.assert_same_structure(features, self._feature_config)

    flat_inputs = nest.flatten(features)
    flat_weights = [None] * len(flat_inputs)
    if weights is not None:
      nest.assert_same_structure(features, weights)
      flat_weights = nest.flatten(weights)
    flat_features = nest.flatten_with_joined_string_paths(self._feature_config)

    outputs = []
    for inp, weight, (path, feature) in zip(flat_inputs, flat_weights,
                                            flat_features):
      table = self.embedding_tables[feature.table]

      if weight is not None:
        if isinstance(inp, ops.Tensor):
          raise ValueError(
              "Weight specified for {}, but input is dense.".format(path))
        elif type(weight) is not type(inp):
          raise ValueError(
              "Weight for {} is of type {} but it does not match type of the "
              "input which is {}.".format(path, type(weight), type(inp)))
        elif feature.max_sequence_length > 0:
          raise ValueError("Weight specified for {}, but this is a sequence "
                           "feature.".format(path))

      if isinstance(inp, ops.Tensor):
        if feature.max_sequence_length > 0:
          raise ValueError(
              "Feature {} is a sequence feature but a dense tensor "
              "was passed.".format(path))
        outputs.append(embedding_ops.embedding_lookup_v2(table, inp))

      elif isinstance(inp, sparse_tensor.SparseTensor):
        outputs.append(
            self._embedding_lookup_for_sparse_tensor(inp, weight, table,
                                                     feature))
      elif isinstance(inp, ragged_tensor.RaggedTensor):
        outputs.append(
            self._embedding_lookup_for_ragged_tensor(inp, weight, table,
                                                     feature))
      else:
        raise ValueError("Input {} is type {}. Tensor, SparseTensor or "
                         "RaggedTensor expected.".format(path, type(inp)))
    return nest.pack_sequence_as(self._feature_config, outputs)
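
  # A minimal usage sketch (illustrative only) of the lookup above, run from
  # inside `strategy.run`, where `embedding_features` is the nested structure
  # yielded by the distributed dataset shown in the class docstring:
  #
  #   def tpu_step(embedding_features):
  #     activations = embedding.embedding_lookup(embedding_features)
  #     ...  # feed `activations` to the rest of the model
  #
  # Calling the object directly, `embedding(embedding_features)`, as in the
  # class docstring, goes through the base class `__call__`, which is assumed
  # here to delegate to `embedding_lookup`.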

  def _embedding_lookup_for_sparse_tensor(
      self, inp: sparse_tensor.SparseTensor,
      weight: Optional[sparse_tensor.SparseTensor],
      table: tf_variables.Variable,
      feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor:
    """Embedding lookup for sparse tensor based on its feature config.

    Args:
      inp: a single SparseTensor input.
      weight: None or a SparseTensor with the same shape as the input.
      table: a table variable.
      feature: a feature config.

    Returns:
      Embedding lookup result.
    """

    # This computation needs to be placed outside of the TPU as the size of
    # the indices and values can change between batches, which would cause
    # the program to re-compile.
    def sparse_to_dense_computation(inp, weight):
      if weight is None:
        weight = sparse_tensor.SparseTensor(
            inp.indices,
            array_ops.ones_like(inp.values, dtype=dtypes.float32),
            dense_shape=inp.dense_shape)
      # Pad the sparse tensor (and its weights) into a dense tensor.
      inp = sparse_ops.sparse_tensor_to_dense(inp)
      weight = sparse_ops.sparse_tensor_to_dense(weight)
      return inp, weight

    inp, weight = tpu_replication.outside_compilation(
        sparse_to_dense_computation, inp=inp, weight=weight)

    embeddings = embedding_ops.embedding_lookup_v2(table, inp)
    weight = array_ops.expand_dims(weight, -1)
    embeddings *= weight
    if not feature.output_shape and feature.max_sequence_length > 0:
      embeddings = self._pad_or_truncate_with_sequence_length(
          embeddings, feature.max_sequence_length)
    else:
      embeddings = self._apply_combiner_to_embeddings(embeddings, weight,
                                                      feature.table.combiner)
    return embeddings
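
  # Worked example (illustrative values only): a SparseTensor with
  # dense_shape=[2, 3], indices=[[0, 0], [0, 1], [1, 0]] and values=[4, 7, 9]
  # is densified above to ids [[4, 7, 0], [9, 0, 0]] and weights
  # [[1., 1., 0.], [1., 0., 0.]]. The zero weights on the padded positions
  # remove those rows from the weighted sums computed in
  # `_apply_combiner_to_embeddings`.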

  def _embedding_lookup_for_ragged_tensor(
      self, inp: ragged_tensor.RaggedTensor,
      weight: Optional[ragged_tensor.RaggedTensor],
      table: tf_variables.Variable,
      feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor:
    """Embedding lookup for ragged tensor based on its feature config.

    Args:
      inp: a single rank 2 RaggedTensor input.
      weight: None or a RaggedTensor with the same shape as the input.
      table: a table variable.
      feature: a feature config.

    Returns:
      Embedding lookup result.

    Raises:
      ValueError: if the input ragged tensor is not rank 2 or the output shape
        set in the feature config doesn't match the first dim size of the
        input.
    """
    if inp.shape.rank != 2:
      raise ValueError(
          "Only rank 2 ragged tensor is supported, but got rank {}".format(
              inp.shape.rank))
    batch_size = inp.shape[0]

    # This computation needs to be placed outside of the TPU as the size of
    # the row splits and values can change between batches, which would cause
    # the program to re-compile.
    def ragged_to_dense_outside_compilation(inp, weight, batch_size, feature):
      if weight is None:
        weight = ragged_tensor.RaggedTensor.from_row_splits(
            array_ops.ones_like(inp.values, dtype=dtypes.float32),
            inp.row_splits)
      if not feature.output_shape and feature.max_sequence_length > 0:
        inp = inp.to_tensor(shape=(batch_size, feature.max_sequence_length))
        # Ignore weight if it is a sequence feature.
        weight = array_ops.ones_like(inp, dtype=dtypes.float32)
      elif feature.output_shape:
        # Eagerly run the following op as the result has to be a number in
        # order to use it as part of the output shape.
        with ops.init_scope():
          output_batch_size = math_ops.reduce_prod(feature.output_shape).numpy()
        # If the output batch size matches the data batch size, treat it as a
        # normal ragged input.
        if output_batch_size == batch_size:
          inp, weight = inp.to_tensor(), weight.to_tensor()
        # If the data batch size is a factor of the output batch size, the
        # division result is the sequence length. Ignore the weights and
        # combiner.
        elif output_batch_size > batch_size and output_batch_size % batch_size == 0:
          # Pad or truncate in the sequence dimension.
          seq_length = output_batch_size // batch_size
          inp = inp.to_tensor(shape=(batch_size, seq_length))
          # Ignore weight if it is a sequence feature.
          weight = array_ops.ones_like(inp, dtype=dtypes.float32)
        else:
          raise ValueError(
              "Output shape set in the FeatureConfig should be a multiple of "
              "the input data batch size. But instead got output shape {}, "
              "input data batch size {}".format(feature.output_shape,
                                                batch_size))
      else:
        inp, weight = inp.to_tensor(), weight.to_tensor()
      return inp, weight

    inp, weight = tpu_replication.outside_compilation(
        ragged_to_dense_outside_compilation,
        inp=inp,
        weight=weight,
        batch_size=batch_size,
        feature=feature)

    embeddings = embedding_ops.embedding_lookup_v2(table, inp)
    weight = array_ops.expand_dims(weight, -1)
    embeddings *= weight

    if feature.output_shape:
      with ops.init_scope():
        output_batch_size = math_ops.reduce_prod(feature.output_shape).numpy()
      if output_batch_size == batch_size:
        embeddings = self._apply_combiner_to_embeddings(embeddings, weight,
                                                        feature.table.combiner)
      embeddings = array_ops.reshape(
          embeddings, shape=feature.output_shape + [feature.table.dim])
    else:
      if feature.max_sequence_length == 0:
        embeddings = self._apply_combiner_to_embeddings(embeddings, weight,
                                                        feature.table.combiner)
    return embeddings
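
  # Worked example of the output_shape handling above (illustrative numbers
  # only): with a ragged input of batch_size=4 and feature.output_shape=[2, 2],
  # output_batch_size is 4 == batch_size, so the per-row embeddings are
  # combined and then reshaped to [2, 2, dim]. With feature.output_shape=[4, 2],
  # output_batch_size is 8, which is 2 * batch_size, so each row is padded or
  # truncated to a sequence length of 2 and the result is reshaped to
  # [4, 2, dim] without applying the combiner.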