Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/embedding_ops.py: 17%
244 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operations for embeddings."""
17from tensorflow.python.framework import constant_op
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import indexed_slices
20from tensorflow.python.framework import ops
21from tensorflow.python.framework import sparse_tensor
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import array_ops_stack
25from tensorflow.python.ops import clip_ops
26# Imports gradient definitions.
27from tensorflow.python.ops import data_flow_grad # pylint: disable=unused-import
28from tensorflow.python.ops import data_flow_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import resource_variable_ops
31from tensorflow.python.ops import sparse_ops
32from tensorflow.python.ops import variables
33from tensorflow.python.util import dispatch
34from tensorflow.python.util.tf_export import tf_export
37def _clip(params, ids, max_norm):
38 """Helper function for _embedding_lookup_and_transform.
40 This function optionally clips embeddings to an l2-norm of max_norm.
42 Args:
43 params: A `Tensor` of embeddings retrieved by `gather`.
44 ids: The `ids` argument that was passed to `gather`.
45 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
46 than this value.
48 Returns:
49 A `Tensor` with the same type as `params`.
50 """
52 def _rank(x):
53 """Helper function to retrieve the rank of a tensor.
55 Args:
56 x: Something convertible to `Tensor`.
58 Returns:
59 Either a pair `(rank, True)` where `rank` is an integer or a pair
60 `(rank, False)` where `rank` is an integer `Tensor`. In either case,
61 `rank` is the rank of `x`.
62 """
63 rank = ops.convert_to_tensor(x).get_shape().ndims
64 if rank:
65 return rank, True
66 else:
67 return array_ops.rank(x), False
69 if max_norm is None:
70 return params
71 ids_rank, ids_static = _rank(ids)
72 params_rank, params_static = _rank(params)
73 return clip_ops.clip_by_norm(
74 params,
75 max_norm,
76 axes=(list(range(ids_rank, params_rank)) if ids_static and params_static
77 else math_ops.range(ids_rank, params_rank)))
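For reference, the effect of the `max_norm` clipping performed by this helper is visible through the public lookup API. A minimal sketch, assuming TensorFlow 2.x is installed; the values are illustrative only:

```python
import tensorflow as tf

params = tf.constant([[3.0, 4.0],    # L2 norm 5.0
                      [0.6, 0.8]])   # L2 norm 1.0
ids = tf.constant([0, 1])
# Rows whose L2 norm exceeds max_norm are rescaled to that norm; rows already
# within the bound pass through unchanged.
out = tf.nn.embedding_lookup(params, ids, max_norm=1.0)
# out is approximately [[0.6, 0.8], [0.6, 0.8]]
```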
80def _colocate_with(param):
81 if ops.inside_function() and hasattr(param, "handle"):
82 # The `ops.colocate_with` will hard-code a device string if `param.device`
83 # is known, which will then break serving. We capture it here so that it
84 # produces a tensor without a device.
85 return ops.colocate_with(ops.get_default_graph().capture(param.handle))
86 else:
87 return ops.colocate_with(param)
90def _embedding_lookup_and_transform(params,
91 ids,
92 partition_strategy="mod",
93 name=None,
94 max_norm=None,
95 transform_fn=None):
96 """Helper function for embedding_lookup and _compute_sampled_logits.
98 This function is a generalization of embedding_lookup that optionally
99 applies a caller-specified transformation to each embedding. This is
100 done through the `transform_fn` argument. If provided, the function is
101 applied to each partitioned tensor of retrieved embeddings, colocated
102 with the embeddings. This function will be called with a single `Tensor`
103 argument of the same type as the `params` tensor and should return a
104 `Tensor`. The shape of the argument will be the same as `params` except
105 for the size of the first dimension. The first dimension of the result's
106 shape must be the same size as the argument's.
108 Args:
109 params: See embedding_lookup.
110 ids: See embedding_lookup.
111 partition_strategy: See embedding_lookup.
112 name: See embedding_lookup.
113 max_norm: See embedding_lookup.
114 transform_fn: An optional function to apply to each retrieved embedding. If
115 max_norm is provided, transform_fn is applied to the norm-limited
116 embeddings.
118 Returns:
119 See embedding_lookup for details.
120 Raises:
121 ValueError: If `params` is empty.
122 """
123 if params is None:
124 raise ValueError("params must be specified")
125 if isinstance(params, (list, tuple)) and not params:
126 raise ValueError("Length of params is currently 0. "
127 "Need at least one param.")
128 if isinstance(params, variables.PartitionedVariable):
129 params = list(params) # Iterate to get the underlying Variables.
130 if not isinstance(params, list):
131 params = [params]
133 with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
134 np = len(params) # Number of partitions
135 # Preserve the resource variable status to avoid accidental dense reads.
136 if not any(
137 isinstance(p, resource_variable_ops.BaseResourceVariable)
138 for p in params):
139 params = indexed_slices.convert_n_to_tensor_or_indexed_slices(
140 params, name="params")
141 ids = ops.convert_to_tensor(ids, name="ids")
142 if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
143 with _colocate_with(params[0]):
144 result = _clip(
145 array_ops.gather(params[0], ids, name=name), ids, max_norm)
146 if transform_fn:
147 result = transform_fn(result)
148 # Make sure the final result does not have colocation constraints on the
149 # params. Similar to the case np > 1 where parallel_dynamic_stitch is
150 # outside the scope of all with _colocate_with(params[p]).
151 return array_ops.identity(result)
152 else:
153 # Flatten the ids. There are two cases where we need to do this.
154 # - There is more than one params tensor.
155 # - There is a transform_fn and ids is not statically known to be 1-D.
156 # We must flatten in this case because transform_fn expects a flat
157 # tensor of embeddings.
158 flat_ids = array_ops.reshape(ids, [-1])
159 original_indices = math_ops.range(array_ops.size(flat_ids))
161 # Create p_assignments and set new_ids depending on the strategy.
162 if partition_strategy == "mod":
163 p_assignments = flat_ids % np
164 new_ids = flat_ids // np
165 elif partition_strategy == "div":
166 # Compute num_total_ids as the sum of dim-0 of params, then assign to
167 # partitions based on a constant number of ids per partition. Optimize
168 # if we already know the full shape statically.
169 dim_0_size = tensor_shape.Dimension(
170 tensor_shape.dimension_value(params[0].get_shape()[0]))
171 for p in range(1, np):
172 dim_0_size += tensor_shape.Dimension(
173 tensor_shape.dimension_value(params[p].get_shape()[0]))
174 if dim_0_size.value:
175 num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
176 else:
177 dim_0_sizes = []
178 for p in range(np):
179 param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
180 if param_p_dim is not None:
181 dim_0_sizes.append(param_p_dim)
182 else:
183 with _colocate_with(params[p]):
184 dim_0_sizes.append(array_ops.shape(params[p])[0])
185 num_total_ids = math_ops.reduce_sum(
186 math_ops.cast(array_ops_stack.stack(dim_0_sizes), flat_ids.dtype))
187 ids_per_partition = num_total_ids // np
188 extras = num_total_ids % np
190 p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
191 (flat_ids - extras) //
192 ids_per_partition)
194 # Emulate a conditional using a boolean indicator tensor
195 new_ids = array_ops.where(p_assignments < extras,
196 flat_ids % (ids_per_partition + 1),
197 (flat_ids - extras) % ids_per_partition)
198 else:
199 raise ValueError(
200 f"Unrecognized partition strategy: {partition_strategy}."
201 "Must be one of either `mod` or `div`.")
203 # Cast partition assignments to int32 for use in dynamic_partition.
204 # There really should not be more than 2^32 partitions.
205 p_assignments = math_ops.cast(p_assignments, dtypes.int32)
206 # Partition list of ids based on assignments into np separate lists
207 gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
208 # Similarly, partition the original indices.
209 pindices = data_flow_ops.dynamic_partition(original_indices,
210 p_assignments, np)
211 # Do np separate lookups, finding embeddings for plist[p] in params[p]
212 partitioned_result = []
213 for p in range(np):
214 pids = gather_ids[p]
215 with ops.device_v2(None):
216 with _colocate_with(params[p]):
217 result = array_ops.gather(params[p], pids)
218 if transform_fn:
219 # If transform_fn is provided, the clip_by_norm precedes
220 # the transform and hence must be co-located. See below
221 # for the counterpart if transform_fn is not provided.
222 result = transform_fn(_clip(result, pids, max_norm))
223 partitioned_result.append(result)
224 # Stitch these back together
225 ret = data_flow_ops.parallel_dynamic_stitch(
226 pindices, partitioned_result, name=name)
228 # Determine the static element shape.
229 if transform_fn is None:
230 element_shape_s = params[0].get_shape()[1:]
231 for p in params[1:]:
232 element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
233 else:
234 element_shape_s = ret.get_shape()[1:]
236 # Compute the dynamic element shape.
237 if element_shape_s.is_fully_defined():
238 element_shape_d = element_shape_s
239 elif transform_fn is None:
240 # It's important that we compute params[0].shape on the right device
241 # to avoid data motion.
242 with _colocate_with(params[0]):
243 params_shape = array_ops.shape(params[0])
244 element_shape_d = params_shape[1:]
245 else:
246 element_shape_d = array_ops.shape(ret)[1:]
248 # Reshape to reverse the flattening of ids.
249 ret = array_ops.reshape(
250 ret, array_ops.concat([array_ops.shape(ids), element_shape_d], 0))
252 # Normally the reshape is sufficient, but setting shape explicitly
253 # teaches shape inference that params[1:].get_shape() matters
254 # (in the case that transform_fn is None).
255 ret.set_shape(ids.get_shape().concatenate(element_shape_s))
256 if not transform_fn:
257 # If transform_fn was provided, the clip_by_norm was done above.
258 ret = _clip(ret, ids, max_norm)
259 return ret
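The partition arithmetic above can be checked with a plain-Python sketch (an illustration only, not part of the library), using the 13-ids / 5-partitions case from the `embedding_lookup` docstring below:

```python
# Plain-Python sketch of the two partition strategies implemented above.
num_partitions = 5
num_total_ids = 13
ids = range(num_total_ids)

# "mod": id i goes to partition i % num_partitions.
mod_split = [[i for i in ids if i % num_partitions == p]
             for p in range(num_partitions)]
# [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]

# "div": contiguous blocks; the first (num_total_ids % num_partitions)
# partitions receive one extra id.
ids_per_partition, extras = divmod(num_total_ids, num_partitions)
div_split = [[i for i in ids
              if max(i // (ids_per_partition + 1),
                     (i - extras) // ids_per_partition) == p]
             for p in range(num_partitions)]
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
```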
262@tf_export(v1=["nn.embedding_lookup"])
263@dispatch.add_dispatch_support
264def embedding_lookup(
265 params,
266 ids,
267 partition_strategy="mod",
268 name=None,
269 validate_indices=True, # pylint: disable=unused-argument
270 max_norm=None):
271 """Looks up embeddings for the given `ids` from a list of tensors.
273 This function is used to perform parallel lookups on the list of tensors in
274 `params`. It is a generalization of `tf.gather`, where `params` is
275 interpreted as a partitioning of a large embedding tensor. `params` may be
276 a `PartitionedVariable` as returned by using `tf.compat.v1.get_variable()`
277 with a partitioner.
279 If `len(params) > 1`, each element `id` of `ids` is partitioned between
280 the elements of `params` according to the `partition_strategy`.
281 In all strategies, if the id space does not evenly divide the number of
282 partitions, each of the first `(max_id + 1) % len(params)` partitions will
283 be assigned one more id.
285 If `partition_strategy` is `"mod"`, we assign each id to partition
286 `p = id % len(params)`. For instance,
287 13 ids are split across 5 partitions as:
288 `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
290 If `partition_strategy` is `"div"`, we assign ids to partitions in a
291 contiguous manner. In this case, 13 ids are split across 5 partitions as:
292 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`
 294 If the input ids are ragged tensors, partitioned variables are not supported,
 295 and both the partition strategy and `max_norm` are ignored.
296 The results of the lookup are concatenated into a dense
297 tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
299 Args:
300 params: A single tensor representing the complete embedding tensor, or a
301 list of P tensors all of same shape except for the first dimension,
302 representing sharded embedding tensors. Alternatively, a
303 `PartitionedVariable`, created by partitioning along dimension 0. Each
304 element must be appropriately sized for the given `partition_strategy`.
305 ids: A `Tensor` or a 'RaggedTensor' with type `int32` or `int64` containing
306 the ids to be looked up in `params`.
307 partition_strategy: A string specifying the partitioning strategy, relevant
308 if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
309 is `"mod"`.
310 name: A name for the operation (optional).
311 validate_indices: DEPRECATED. If this operation is assigned to CPU, values
312 in `indices` are always validated to be within range. If assigned to GPU,
313 out-of-bound indices result in safe but unspecified behavior, which may
314 include raising an error.
315 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
316 than this value.
318 Returns:
319 A `Tensor` or a 'RaggedTensor', depending on the input, with the same type
320 as the tensors in `params`.
322 Raises:
323 ValueError: If `params` is empty.
324 """
326 return _embedding_lookup_and_transform(
327 params=params,
328 ids=ids,
329 partition_strategy=partition_strategy,
330 name=name,
331 max_norm=max_norm,
332 transform_fn=None)
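As a usage sketch of the sharded case (assuming TensorFlow 2.x with the `tf.compat.v1` API available), here is a 5x2 table split with the default `"mod"` strategy; the values are illustrative only:

```python
import tensorflow as tf

# Shard p of a logical 5x2 table holds the rows whose id satisfies id % 3 == p.
shards = [tf.constant([[1., 2.], [7., 8.]]),    # ids 0, 3
          tf.constant([[3., 4.], [9., 10.]]),   # ids 1, 4
          tf.constant([[5., 6.]])]              # id 2
ids = tf.constant([0, 3, 4])
out = tf.compat.v1.nn.embedding_lookup(shards, ids)  # partition_strategy="mod"
# out == [[1., 2.], [7., 8.], [9., 10.]]
```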
335@tf_export("nn.embedding_lookup", v1=[])
336@dispatch.add_dispatch_support
337def embedding_lookup_v2(params, ids, max_norm=None, name=None):
338 """Looks up embeddings for the given `ids` from a list of tensors.
340 This function is used to perform parallel lookups on the list of tensors in
341 `params`. It is a generalization of `tf.gather`, where `params` is
342 interpreted as a partitioning of a large embedding tensor.
344 If `len(params) > 1`, each element `id` of `ids` is partitioned between the
345 elements of `params` according to the "div" partition strategy, which means we
346 assign ids to partitions in a contiguous manner. For instance, 13 ids are
347 split across 5 partitions as:
348 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.
350 If the id space does not evenly divide the number of partitions, each of the
351 first `(max_id + 1) % len(params)` partitions will be assigned one more id.
353 The results of the lookup are concatenated into a dense
354 tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
356 Args:
357 params: A single tensor representing the complete embedding tensor, or a
358 list of tensors all of same shape except for the first dimension,
359 representing sharded embedding tensors following "div" partition strategy.
360 ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
361 up in `params`.
362 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
363 than this value.
364 name: A name for the operation (optional).
366 Returns:
367 A `Tensor` with the same type as the tensors in `params`.
369 For instance, if `params` is a 5x2 matrix:
371 ```python
372 [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
373 ```
375 or a list of matrices:
377 ```python
378 params[0]: [[1, 2], [3, 4]]
379 params[1]: [[5, 6], [7, 8]]
380 params[2]: [[9, 10]]
381 ```
383 and `ids` is:
385 ```python
386 [0, 3, 4]
387 ```
389 The output will be a 3x2 matrix:
391 ```python
392 [[1, 2], [7, 8], [9, 10]]
393 ```
395 Raises:
396 ValueError: If `params` is empty.
397 """
398 return embedding_lookup(params, ids, "div", name, max_norm=max_norm)
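The 5x2 example from the docstring above, written out as runnable code (assuming TensorFlow 2.x):

```python
import tensorflow as tf

params = tf.constant([[1., 2.], [3., 4.], [5., 6.], [7., 8.], [9., 10.]])
ids = tf.constant([0, 3, 4])
out = tf.nn.embedding_lookup(params, ids)
# out == [[1., 2.], [7., 8.], [9., 10.]], shape (3, 2)
```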
401@tf_export(v1=["nn.embedding_lookup_sparse"])
402@dispatch.add_dispatch_support
403def embedding_lookup_sparse(
404 params,
405 sp_ids,
406 sp_weights,
407 partition_strategy="mod",
408 name=None,
409 combiner=None,
410 max_norm=None,
411 allow_fast_lookup=False,
412):
413 """Looks up embeddings for the given ids and weights from a list of tensors.
415 This op assumes that there is at least one id for each row in the dense tensor
416 represented by sp_ids (i.e. there are no rows with empty features), and that
417 all the indices of sp_ids are in canonical row-major order.
419 `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s or `RaggedTensor`s
 420 with rank 2. For `SparseTensor`s with left-aligned non-zero entries which
421 can be described as `RaggedTensor`s, use of `RaggedTensor`s can yield higher
422 performance.
424 It also assumes that all id values lie in the range [0, p0), where p0
425 is the sum of the size of params along dimension 0.
427 Args:
428 params: A single tensor representing the complete embedding tensor, or a
 429 list of tensors all of the same shape except for the first dimension,
430 representing sharded embedding tensors. Alternatively, a
431 `PartitionedVariable`, created by partitioning along dimension 0. Each
432 element must be appropriately sized for the given `partition_strategy`.
433 sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
434 and M is arbitrary or a `RaggedTensor` with rank 2.
 435 sp_weights: `SparseTensor` or `RaggedTensor` of the same type and shape as
 436 `sp_ids`, containing float / double weights corresponding to
 437 `sp_ids`, or `None` if all weights are assumed to be 1.0.
438 partition_strategy: A string specifying the partitioning strategy, relevant
439 if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
440 is `"mod"`. See `tf.nn.embedding_lookup` for more details.
441 name: Optional name for the op.
442 combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
443 and "sum" are supported. "sum" computes the weighted sum of the embedding
444 results for each row. "mean" is the weighted sum divided by the total
445 weight. "sqrtn" is the weighted sum divided by the square root of the sum
446 of the squares of the weights. Defaults to `mean`.
447 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
448 than this value, before combining.
449 allow_fast_lookup: An optional boolean specifying whether to allow
450 simplified embedding lookups when `params` is a single tensor and
451 `max_norm` is `None`. Setting this flag to `True` during training can
452 cause the use of dense gradients with increased memory footprint.
454 Returns:
455 A dense tensor representing the combined embeddings for the
456 sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
457 looks up the embeddings for all ids in that row, multiplies them by the
458 corresponding weight, and combines these embeddings as specified.
460 In other words, if
462 `shape(combined params) = [p0, p1, ..., pm]`
464 and
466 `shape(sp_ids) = shape(sp_weights) = [d0, d1]`
468 then
470 `shape(output) = [d0, p1, ..., pm]`.
472 For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are
474 ```python
475 [0, 0]: id 1, weight 2.0
476 [0, 1]: id 3, weight 0.5
477 [1, 0]: id 0, weight 1.0
478 [2, 3]: id 1, weight 3.0
479 ```
481 with `combiner`="mean", then the output will be a 3x20 matrix where
483 ```python
484 output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
485 output[1, :] = (params[0, :] * 1.0) / 1.0
486 output[2, :] = (params[1, :] * 3.0) / 3.0
487 ```
489 Raises:
490 TypeError: If `sp_ids` is not a `SparseTensor` or `RaggedTensor`, or if
491 `sp_weights` is neither `None` nor of the same type as `sp_ids`.
492 ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
493 """
494 if combiner is None:
495 combiner = "mean"
496 if combiner not in ("mean", "sqrtn", "sum"):
497 raise ValueError(
498 f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}")
499 if isinstance(params, variables.PartitionedVariable):
500 params = list(params) # Iterate to get the underlying Variables.
501 if not isinstance(params, list):
502 params = [params]
503 if not isinstance(sp_ids, sparse_tensor.SparseTensor):
504 raise TypeError(f"sp_ids must be SparseTensor, got {type(sp_ids)}")
505 ignore_weights = sp_weights is None
506 if not ignore_weights:
507 if not isinstance(sp_weights, sparse_tensor.SparseTensor):
 508 raise TypeError(f"sp_weights must be either None or SparseTensor, "
 509 f"got {type(sp_weights)}")
510 sp_ids.values.get_shape().assert_is_compatible_with(
511 sp_weights.values.get_shape())
512 sp_ids.indices.get_shape().assert_is_compatible_with(
513 sp_weights.indices.get_shape())
514 sp_ids.dense_shape.get_shape().assert_is_compatible_with(
515 sp_weights.dense_shape.get_shape())
516 # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
517 # sp_weights have equal indices and shapes.
519 with ops.name_scope(name, "embedding_lookup_sparse",
520 params + [sp_ids]) as name:
522 segment_ids = sp_ids.indices[:, 0]
523 ids = sp_ids.values
525 return embedding_lookup_sparse_impl(
526 params,
527 segment_ids,
528 sp_weights,
529 ids,
530 combiner,
531 ignore_weights,
532 max_norm,
533 allow_fast_lookup,
534 partition_strategy,
535 name,
536 )
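A runnable version of the 10x20 docstring example above (a sketch assuming TensorFlow 2.x; the compat.v1 endpoint defaults to the `"mod"` strategy, which is irrelevant for a single unsharded table):

```python
import tensorflow as tf

params = tf.random.normal([10, 20])   # single (unsharded) 10x20 table
sp_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0], [2, 3]],
    values=tf.constant([1, 3, 0, 1], dtype=tf.int64),
    dense_shape=[3, 4])
sp_weights = tf.sparse.SparseTensor(
    indices=sp_ids.indices,
    values=tf.constant([2.0, 0.5, 1.0, 3.0]),
    dense_shape=sp_ids.dense_shape)
out = tf.compat.v1.nn.embedding_lookup_sparse(
    params, sp_ids, sp_weights, combiner="mean")
# out has shape [3, 20]; e.g. out[0] == (params[1] * 2.0 + params[3] * 0.5) / 2.5
```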
539@tf_export("nn.embedding_lookup_sparse", v1=[])
540@dispatch.add_dispatch_support
541def embedding_lookup_sparse_v2(
542 params,
543 sp_ids,
544 sp_weights,
545 combiner=None,
546 max_norm=None,
547 name=None,
548 allow_fast_lookup=False,
549):
550 """Looks up embeddings for the given ids and weights from a list of tensors.
552 This op assumes that there is at least one id for each row in the dense tensor
553 represented by sp_ids (i.e. there are no rows with empty features), and that
554 all the indices of sp_ids are in canonical row-major order.
556 `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s or `RaggedTensor`s
 557 with rank 2. For `SparseTensor`s with left-aligned non-zero entries which
558 can be described as `RaggedTensor`s, use of `RaggedTensor`s can yield higher
559 performance.
561 It also assumes that all id values lie in the range [0, p0), where p0
562 is the sum of the size of params along dimension 0.
564 If `len(params) > 1`, each element of `sp_ids` is partitioned between the
565 elements of `params` according to the "div" partition strategy, which means we
566 assign ids to partitions in a contiguous manner. For instance, 13 ids are
567 split across 5 partitions as:
568 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.
570 If the id space does not evenly divide the number of partitions, each of the
571 first `(max_id + 1) % len(params)` partitions will be assigned one more id.
573 Args:
574 params: A single tensor representing the complete embedding tensor, or a
575 list of tensors all of same shape except for the first dimension,
576 representing sharded embedding tensors following "div" partition strategy.
577 sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
578 and M is arbitrary or a `RaggedTensor` with rank 2.
 579 sp_weights: `SparseTensor` or `RaggedTensor` of the same type and shape as
 580 `sp_ids`, containing float / double weights corresponding to
 581 `sp_ids`, or `None` if all weights are assumed to be 1.0.
582 combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
583 and "sum" are supported. "sum" computes the weighted sum of the embedding
584 results for each row. "mean" is the weighted sum divided by the total
585 weight. "sqrtn" is the weighted sum divided by the square root of the sum
586 of the squares of the weights. Defaults to `mean`.
587 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
588 than this value, before combining.
589 name: Optional name for the op.
590 allow_fast_lookup: An optional boolean specifying whether to allow
591 simplified embedding lookups when `params` is a single tensor and
592 `max_norm` is `None`. Setting this flag to `True` during training can
593 cause the use of dense gradients with increased memory footprint.
595 Returns:
596 A dense tensor representing the combined embeddings for the
597 sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
598 looks up the embeddings for all ids in that row, multiplies them by the
599 corresponding weight, and combines these embeddings as specified.
601 In other words, if
603 `shape(combined params) = [p0, p1, ..., pm]`
605 and
607 `shape(sp_ids) = shape(sp_weights) = [d0, d1]`
609 then
611 `shape(output) = [d0, p1, ..., pm]`.
613 For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are
615 ```python
616 [0, 0]: id 1, weight 2.0
617 [0, 1]: id 3, weight 0.5
618 [1, 0]: id 0, weight 1.0
619 [2, 3]: id 1, weight 3.0
620 ```
622 with `combiner`="mean", then the output will be a 3x20 matrix where
624 ```python
625 output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
626 output[1, :] = (params[0, :] * 1.0) / 1.0
627 output[2, :] = (params[1, :] * 3.0) / 3.0
628 ```
630 Raises:
631 TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
632 neither `None` nor `SparseTensor`.
633 ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
634 """
635 return embedding_lookup_sparse(
636 params,
637 sp_ids,
638 sp_weights,
639 "div",
640 name,
641 combiner,
642 max_norm,
643 allow_fast_lookup,
644 )
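A short sketch of the v2 endpoint with `sp_weights=None`, i.e. all weights treated as 1.0 (assuming TensorFlow 2.x; illustrative values only):

```python
import tensorflow as tf

params = tf.random.normal([10, 20])
sp_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0], [2, 3]],
    values=tf.constant([1, 3, 0, 1], dtype=tf.int64),
    dense_shape=[3, 4])
# With sp_weights=None, "sqrtn" divides each row's summed embeddings by the
# square root of the number of ids in that row.
out = tf.nn.embedding_lookup_sparse(params, sp_ids, None, combiner="sqrtn")
# out has shape [3, 20]; e.g. out[0] == (params[1] + params[3]) / sqrt(2)
```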
647@tf_export("nn.safe_embedding_lookup_sparse", v1=[])
648@dispatch.add_dispatch_support
649def safe_embedding_lookup_sparse_v2(
650 embedding_weights,
651 sparse_ids,
652 sparse_weights=None,
653 combiner="mean",
654 default_id=None,
655 max_norm=None,
656 name=None,
657 allow_fast_lookup=False,
658):
659 """Lookup embedding results, accounting for invalid IDs and empty features.
661 The partitioned embedding in `embedding_weights` must all be the same shape
662 except for the first dimension. The first dimension is allowed to vary as the
663 vocabulary size is not necessarily a multiple of num of shards.
665 Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
666 with non-positive weight. For an entry with no features, the embedding vector
667 for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
669 The ids and weights may be multi-dimensional `SparseTensor`s or
 670 `RaggedTensor`s with rank 2. For `SparseTensor`s with left-aligned non-zero
671 entries which can be described as `RaggedTensor`s, use of `RaggedTensor`s can
672 yield higher performance.
674 If `len(embedding_weights) > 1`, each element `id` of `ids` is partitioned
675 between the elements of `embedding_weights` according to the "div" partition
676 strategy, which means we assign ids to partitions in a contiguous manner. For
677 instance, 13 ids are split across 5 partitions as:
678 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.
680 If the id space does not evenly divide the number of partitions, each of the
681 first `(max_id + 1) % len(embedding_weights)` partitions will be assigned one
682 more id.
684 Args:
685 embedding_weights: A single tensor representing the complete embedding
686 tensor, or a list of tensors all of same shape except for the first
687 dimension, representing sharded embedding tensors following "div"
688 partition strategy.
689 sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
690 ids, where `d_0` is typically batch size, or a `RaggedTensor` with rank 2.
691 sparse_weights: `SparseTensor` or `RaggedTensor` of same type and shape as
692 `sparse_ids`, containing float weights corresponding to `sparse_ids`, or
693 `None` if all weights are assumed to be 1.0.
694 combiner: A string specifying how to combine embedding results for each
695 entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the
696 default.
 697 default_id: The id to use for an entry with no features. If `None` (the
 698 default), such entries are filled with the 0-vector.
699 max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
700 combining.
701 name: A name for this operation (optional).
702 allow_fast_lookup: An optional boolean specifying whether to allow
703 simplified embedding lookups when `params` is a single tensor and
704 `max_norm` is `None`. Setting this flag to `True` during training can
705 cause the use of dense gradients with increased memory footprint.
707 Returns:
708 A dense tensor representing the combined embeddings for the
709 sparse ids. For each row in the dense tensor represented by `sparse_ids`,
710 the op looks up the embeddings for all ids in that row, multiplies them by
711 the corresponding weight, and combines these embeddings as specified.
713 In other words, if
715 `shape(combined embedding_weights) = [p0, p1, ..., pm]`
717 and
719 `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`
721 then
723 `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.
725 For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are
727 ```python
728 [0, 0]: id 1, weight 2.0
729 [0, 1]: id 3, weight 0.5
730 [1, 0]: id -1, weight 1.0
731 [2, 3]: id 1, weight 3.0
732 ```
 734 and `default_id` is 0,
 736 then with `combiner`="mean" the output will be a 3x20 matrix where
738 ```python
739 output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
740 output[1, :] = (params[0, :] * 1.0) / 1.0
741 output[2, :] = (params[1, :] * 3.0) / 3.0
742 ```
744 Raises:
745 ValueError: if `embedding_weights` is empty.
746 """
747 return safe_embedding_lookup_sparse(
748 embedding_weights,
749 sparse_ids,
750 sparse_weights=sparse_weights,
751 combiner=combiner,
752 default_id=default_id,
753 name=name,
754 partition_strategy="div",
755 max_norm=max_norm,
756 allow_fast_lookup=allow_fast_lookup,
757 )
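A runnable sketch of the invalid-id and empty-row handling described above (assuming TensorFlow 2.x; illustrative values only):

```python
import tensorflow as tf

weights = tf.random.normal([10, 20])
sparse_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0], [2, 3]],
    values=tf.constant([1, 3, -1, 1], dtype=tf.int64),
    dense_shape=[4, 4])                       # row 3 has no features at all
out = tf.nn.safe_embedding_lookup_sparse(weights, sparse_ids)
# The invalid id -1 is pruned, so rows 1 and 3 end up with no features and,
# because default_id is None, are filled with the 0-vector. out has shape [4, 20].
```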
760@tf_export(v1=["nn.safe_embedding_lookup_sparse"])
761@dispatch.add_dispatch_support
762def safe_embedding_lookup_sparse(
763 embedding_weights,
764 sparse_ids,
765 sparse_weights=None,
766 combiner="mean",
767 default_id=None,
768 name=None,
769 partition_strategy="div",
770 max_norm=None,
771 allow_fast_lookup=False,
772):
773 """Lookup embedding results, accounting for invalid IDs and empty features.
775 The partitioned embedding in `embedding_weights` must all be the same shape
776 except for the first dimension. The first dimension is allowed to vary as the
777 vocabulary size is not necessarily a multiple of `P`. `embedding_weights`
778 may be a `PartitionedVariable` as returned by using
779 `tf.compat.v1.get_variable()` with a
780 partitioner.
782 Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
783 with non-positive weight. For an entry with no features, the embedding vector
784 for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
786 The ids and weights may be multi-dimensional `SparseTensor`s or
 788 `RaggedTensor`s with rank 2. For `SparseTensor`s with left-aligned non-zero
788 entries which can be described as `RaggedTensor`s, use of `RaggedTensor`s can
789 yield higher performance. Embeddings are always aggregated along the last
790 dimension.
792 Args:
793 embedding_weights: A single tensor representing the complete embedding
 794 tensor, or a list of tensors all of the same shape except for the first
795 dimension, representing sharded embedding tensors. Alternatively, a
796 `PartitionedVariable`, created by partitioning along dimension 0. Each
797 element must be appropriately sized for the given `partition_strategy`.
798 sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
799 ids, where `d_0` is typically batch size, or a `RaggedTensor` with rank 2.
800 sparse_weights: `SparseTensor` or `RaggedTensor` of same type and shape as
801 `sparse_ids`, containing float weights corresponding to `sparse_ids`, or
802 `None` if all weights are assumed to be 1.0.
803 combiner: A string specifying how to combine embedding results for each
804 entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the
805 default.
806 default_id: The id to use for an entry with no features.
807 name: A name for this operation (optional).
808 partition_strategy: A string specifying the partitioning strategy. Currently
809 `"div"` and `"mod"` are supported. Default is `"div"`.
810 max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
811 combining.
812 allow_fast_lookup: An optional boolean specifying whether to allow
813 simplified embedding lookups when `params` is a single tensor and
814 `max_norm` is `None`. Setting this flag to `True` during training can
815 cause the use of dense gradients with increased memory footprint.
817 Returns:
818 A dense tensor representing the combined embeddings for the
819 sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
820 looks up the embeddings for all ids in that row, multiplies them by the
821 corresponding weight, and combines these embeddings as specified.
823 In other words, if
825 `shape(combined embedding_weights) = [p0, p1, ..., pm]`
827 and
829 `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`
831 then
833 `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.
835 For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are
837 ```python
838 [0, 0]: id 1, weight 2.0
839 [0, 1]: id 3, weight 0.5
840 [1, 0]: id -1, weight 1.0
841 [2, 3]: id 1, weight 3.0
842 ```
 844 and `default_id` is 0,
 846 then with `combiner`="mean" the output will be a 3x20 matrix where
848 ```python
849 output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
850 output[1, :] = (params[0, :] * 1.0) / 1.0
851 output[2, :] = (params[1, :] * 3.0) / 3.0
852 ```
854 Raises:
855 ValueError: if `embedding_weights` is empty.
856 """
857 if embedding_weights is None:
858 raise ValueError(f"Missing embedding_weights {embedding_weights}.")
859 if isinstance(embedding_weights, variables.PartitionedVariable):
860 embedding_weights = list(embedding_weights) # get underlying Variables.
861 if not isinstance(embedding_weights, list):
862 embedding_weights = [embedding_weights]
863 if len(embedding_weights) < 1:
864 raise ValueError(f"Missing embedding_weights {embedding_weights}.")
866 dtype = sparse_weights.dtype if sparse_weights is not None else None
867 embedding_weights = [
868 w if (isinstance(w, resource_variable_ops.ResourceVariable)
869 and dtype in (None, w.dtype))
870 else ops.convert_to_tensor(w, dtype=dtype)
871 for w in embedding_weights
872 ]
874 with ops.name_scope(name, "embedding_lookup", embedding_weights +
875 [sparse_ids, sparse_weights]) as scope:
876 # Reshape higher-rank sparse ids and weights to linear segment ids.
877 original_shape = sparse_ids.dense_shape
878 original_rank_dim = tensor_shape.dimension_value(
879 sparse_ids.dense_shape.get_shape()[0])
880 original_rank = (
881 array_ops.size(original_shape)
882 if original_rank_dim is None else original_rank_dim)
883 sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
884 math_ops.reduce_prod(
885 array_ops.slice(original_shape, [0], [original_rank - 1])),
886 array_ops.gather(original_shape, original_rank - 1)
887 ])
888 if sparse_weights is not None:
889 sparse_weights = sparse_tensor.SparseTensor(sparse_ids.indices,
890 sparse_weights.values,
891 sparse_ids.dense_shape)
893 # Prune invalid ids and weights.
894 sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
895 if combiner != "sum":
896 sparse_ids, sparse_weights = _prune_invalid_weights(
897 sparse_ids, sparse_weights)
899 # Fill in dummy values for empty features, if necessary.
900 sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
901 sparse_ids, default_id or 0)
902 if sparse_weights is not None:
903 sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)
905 result = embedding_lookup_sparse(
906 embedding_weights,
907 sparse_ids,
908 sparse_weights,
909 combiner=combiner,
910 partition_strategy=partition_strategy,
911 name=None if default_id is None else scope,
912 max_norm=max_norm,
913 allow_fast_lookup=allow_fast_lookup,
914 )
916 if default_id is None:
917 # Broadcast is_row_empty to the same shape as embedding_lookup_result,
918 # for use in Select.
919 is_row_empty = array_ops.tile(
920 array_ops.reshape(is_row_empty, [-1, 1]),
921 array_ops_stack.stack([1, array_ops.shape(result)[1]]))
923 result = array_ops.where(
924 is_row_empty, array_ops.zeros_like(result), result, name=scope)
926 # Reshape back from linear ids back into higher-dimensional dense result.
927 final_result = array_ops.reshape(
928 result,
929 array_ops.concat([
930 array_ops.slice(
931 math_ops.cast(original_shape, dtypes.int32), [0],
932 [original_rank - 1]),
933 array_ops.slice(array_ops.shape(result), [1], [-1])
934 ], 0))
935 final_result.set_shape(
936 tensor_shape.unknown_shape(
937 (tensor_shape.Dimension(original_rank_dim) - 1).value
938 ).concatenate(result.get_shape()[1:])
939 )
940 return final_result
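The reshape logic above also handles sparse ids of rank greater than 2. A sketch under the assumption that TensorFlow 2.x with the compat.v1 endpoint is available; the shapes and ids are illustrative only:

```python
import tensorflow as tf

weights = tf.random.normal([10, 20])
# Rank-3 ids of dense shape [2, 3, 4]: two examples, three feature slots each.
sparse_ids = tf.sparse.SparseTensor(
    indices=[[0, 0, 0], [0, 2, 1], [1, 1, 0]],
    values=tf.constant([4, 7, 2], dtype=tf.int64),
    dense_shape=[2, 3, 4])
out = tf.compat.v1.nn.safe_embedding_lookup_sparse(weights, sparse_ids)
# Embeddings are combined along the last dimension, so out has shape [2, 3, 20];
# slots with no ids (e.g. out[0, 1, :]) are the 0-vector.
```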
943def embedding_lookup_sparse_impl(
944 params,
945 segment_ids,
946 sp_weights,
947 ids,
948 combiner,
949 ignore_weights,
950 max_norm,
951 allow_fast_lookup,
952 partition_strategy,
953 name,
954):
955 """Implementation of sparse embedding aggregation."""
956 if len(params) == 1 and max_norm is None and allow_fast_lookup:
957 idx = ids
958 embeddings = params[0]
959 else:
960 ids, idx = array_ops.unique(ids)
961 embeddings = embedding_lookup(
962 params, ids, partition_strategy=partition_strategy, max_norm=max_norm
963 )
965 if not ignore_weights:
966 if segment_ids.dtype != dtypes.int32:
967 segment_ids = math_ops.cast(segment_ids, dtypes.int32)
969 weights = sp_weights.values
970 embeddings = array_ops.gather(embeddings, idx)
972 original_dtype = embeddings.dtype
973 if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
974 # Cast low-precision embeddings to float32 during the computation to
975 # avoid numerical issues.
976 embeddings = math_ops.cast(embeddings, dtypes.float32)
977 if weights.dtype != embeddings.dtype:
978 weights = math_ops.cast(weights, embeddings.dtype)
980 # Reshape weights to allow broadcast
981 ones_shape = array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0)
982 ones = array_ops.ones(ones_shape, dtype=dtypes.int32)
983 bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], 0)
985 orig_weights_shape = weights.get_shape()
986 weights = array_ops.reshape(weights, bcast_weights_shape)
988 # Set the weight shape, since after reshaping to bcast_weights_shape,
989 # the shape becomes None.
990 if embeddings.get_shape().ndims is not None:
991 weights.set_shape(
992 orig_weights_shape.concatenate(
993 [1 for _ in range(embeddings.get_shape().ndims - 1)]
994 )
995 )
997 embeddings *= weights
999 if combiner == "sum":
1000 embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
1001 elif combiner == "mean":
1002 embeddings = math_ops.segment_sum(embeddings, segment_ids)
1003 weight_sum = math_ops.segment_sum(weights, segment_ids)
1004 embeddings = math_ops.div_no_nan(embeddings, weight_sum, name=name)
1005 elif combiner == "sqrtn":
1006 embeddings = math_ops.segment_sum(embeddings, segment_ids)
1007 weights_squared = math_ops.pow(weights, 2)
1008 weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
1009 weight_sum_sqrt = math_ops.sqrt(weight_sum)
1010 embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt, name=name)
1011 else:
1012 assert False, "Unrecognized combiner"
1013 if embeddings.dtype != original_dtype:
1014 embeddings = math_ops.cast(embeddings, original_dtype)
1015 else:
1016 if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
1017 segment_ids = math_ops.cast(segment_ids, dtypes.int32)
1018 assert idx is not None
1019 if combiner == "sum":
1020 embeddings = math_ops.sparse_segment_sum(
1021 embeddings, idx, segment_ids, name=name
1022 )
1023 elif combiner == "mean":
1024 embeddings = math_ops.sparse_segment_mean(
1025 embeddings, idx, segment_ids, name=name
1026 )
1027 elif combiner == "sqrtn":
1028 embeddings = math_ops.sparse_segment_sqrt_n(
1029 embeddings, idx, segment_ids, name=name
1030 )
1031 else:
1032 assert False, "Unrecognized combiner"
1034 return embeddings
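The weighted `"mean"` combine used above reduces to two segment sums followed by a safe division. A hand-rolled sketch of the same arithmetic with illustrative values (not part of the library):

```python
import tensorflow as tf

embeddings = tf.constant([[1., 1.], [2., 2.], [4., 4.]])
weights = tf.constant([[2.0], [0.5], [1.0]])   # already reshaped for broadcast
segment_ids = tf.constant([0, 0, 1])           # rows 0-1 -> output row 0, row 2 -> row 1
weighted = embeddings * weights
out = tf.math.divide_no_nan(
    tf.math.segment_sum(weighted, segment_ids),
    tf.math.segment_sum(weights, segment_ids))
# out == [[1.2, 1.2], [4., 4.]]:
# row 0 is (1*2.0 + 2*0.5) / (2.0 + 0.5), row 1 is (4*1.0) / 1.0.
```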
1037def _prune_invalid_ids(sparse_ids, sparse_weights):
1038 """Prune invalid IDs (< 0) from the input ids and weights."""
1039 is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
1040 if sparse_weights is not None:
1041 is_id_valid = math_ops.logical_and(
1042 is_id_valid,
1043 array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
1044 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
1045 if sparse_weights is not None:
1046 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
1047 return sparse_ids, sparse_weights
1050def _prune_invalid_weights(sparse_ids, sparse_weights):
1051 """Prune invalid weights (< 0) from the input ids and weights."""
1052 if sparse_weights is not None:
1053 is_weights_valid = math_ops.greater(sparse_weights.values, 0)
1054 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
1055 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
1056 return sparse_ids, sparse_weights
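A small sketch of the pruning semantics using the public `tf.sparse.retain` op rather than the private helpers above (assuming TensorFlow 2.x; illustrative values only):

```python
import tensorflow as tf

sparse_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([5, -1, 3], dtype=tf.int64),
    dense_shape=[2, 2])
keep = sparse_ids.values >= 0          # mirrors the check in _prune_invalid_ids
pruned = tf.sparse.retain(sparse_ids, keep)
# pruned.values == [5, 3]; the entry with id -1 at position [0, 1] is dropped.
```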