# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Embedding operations."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import dispatch
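

# The `dispatch_for_api` decorators below register these functions as the
# implementations used when the corresponding embedding APIs (for example,
# `tf.nn.embedding_lookup`) are called with `RaggedTensor` inputs.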
@dispatch.dispatch_for_api(embedding_ops.embedding_lookup)
def embedding_lookup(
    params,
    ids: ragged_tensor.Ragged,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None,
):
40 """Look up the ragged ids in a list of embedding tensors.
42 Args:
43 params: A tensor representing the complete embedding tensor having the shape
44 [e1, ...eM]
45 ragged_ids: A 'RaggedTensor' with type 'int32' or 'int64' containing the ids
46 to be looked up in 'params' of shape [r0, ..rN]. Values must be in the
47 range '[0, params.shape[0]]'.
48 partition_strategy: A string specifying the partitioning strategy.
49 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
50 than this value.
51 name: A name for the operation (optional)
53 Returns:
54 A ragged tensor of shape [r0, r1, ...rN, e1, ...eM].
56 Raises:
57 ValueError: When params is empty or the type of the ids is not int32 or
58 int64.
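
  For example, a minimal sketch (assuming `import tensorflow as tf`, TF 2.x,
  and eager execution; the values are made up):

  ```python
  params = tf.random.uniform([10, 4])  # 10 embeddings of width 4.
  ids = tf.ragged.constant([[0, 3], [5]], dtype=tf.int64)
  result = tf.nn.embedding_lookup(params, ids)  # Ragged ids dispatch here.
  # `result` is a RaggedTensor of shape [2, None, 4].
  ```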
59 """
  if params is None:
    raise ValueError("params must be specified.")
  if isinstance(params, (list, tuple)) and not params:
    raise ValueError("params should not be empty.")
  if ids.dtype != dtypes.int32 and ids.dtype != dtypes.int64:
    raise ValueError(
        f"The values contained by the inputs have type {ids.dtype} and cannot"
        " be processed. All values should be indices, either of type `int32`"
        " or `int64`."
    )
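
  # `map_flat_values` applies the dense lookup to `ids.flat_values` and
  # rebuilds a RaggedTensor with the same row partitions as `ids`, which
  # yields the [r0, ..., rN, e1, ..., eM] result shape.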
  with ops.name_scope(name, "embedding_lookup_ragged") as name:
    looked_up_ragged = ragged_functional_ops.map_flat_values(
        embedding_ops.embedding_lookup,
        params=params,
        ids=ids,
        partition_strategy=partition_strategy,
        max_norm=max_norm,
    )

    return looked_up_ragged


@dispatch.dispatch_for_api(embedding_ops.embedding_lookup_sparse)
def embedding_lookup_sparse(
    params,
    sp_ids: ragged_tensor.Ragged,
    sp_weights,
    partition_strategy="mod",
    name=None,
    combiner=None,
    max_norm=None,
    allow_fast_lookup=False,
):
95 """Looks up embeddings for the given ids and weights from a list of tensors.
97 This op assumes that there is at least one id for each row in the dense tensor
98 represented by sp_ids (i.e. there are no rows with empty features), and that
99 all the indices of sp_ids are in canonical row-major order.
101 `sp_ids` and `sp_weights` (if not None) are `RaggedTensor`s with rank of 2.
102 Embeddings are always aggregated along the last dimension.
104 It also assumes that all id values lie in the range [0, p0), where p0
105 is the sum of the size of params along dimension 0.
107 Args:
108 params: A single tensor representing the complete embedding tensor, or a
109 list tensors all of same shape except for the first dimension,
110 representing sharded embedding tensors. Alternatively, a
111 `PartitionedVariable`, created by partitioning along dimension 0. Each
112 element must be appropriately sized for the given `partition_strategy`.
113 sp_ids: `RaggedTensor` with rank 2. The rank is not verified for performance
114 reasons.
115 sparse_weights: `RaggedTensor` of same type and shape as `sparse_ids`,
116 containing float / double weights corresponding to `sparse_ids`, or `None`
117 if all weights are assumed to be 1.0.
118 partition_strategy: A string specifying the partitioning strategy, relevant
119 if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
120 is `"mod"`. See `tf.nn.embedding_lookup` for more details.
121 name: Optional name for the op.
122 combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
123 and "sum" are supported. "sum" computes the weighted sum of the embedding
124 results for each row. "mean" is the weighted sum divided by the total
125 weight. "sqrtn" is the weighted sum divided by the square root of the sum
126 of the squares of the weights. Defaults to `mean`.
127 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
128 than this value, before combining.
129 allow_fast_lookup: An optional boolean specifying whether to allow
130 simplified embedding lookups when `params` is a single tensor and
131 `max_norm` is `None`. Setting this flag to `True` during training can
132 cause the use of dense gradients with increased memory footprint.
134 Returns:
135 A dense tensor representing the combined embeddings for the
136 sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
137 looks up the embeddings for all ids in that row, multiplies them by the
138 corresponding weight, and combines these embeddings as specified.
140 In other words, if
142 `shape(combined params) = [p0, p1, ..., pm]`
144 and
146 `shape(sp_ids) = shape(sp_weights) = [d0, d1]`
148 then
150 `shape(output) = [d0, p1, ..., pm]`.
152 For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are
154 ```python
155 [0, 0]: id 1, weight 2.0
156 [0, 1]: id 3, weight 0.5
157 [1, 0]: id 0, weight 1.0
158 [2, 3]: id 1, weight 3.0
159 ```
161 with `combiner`="mean", then the output will be a 3x20 matrix where
163 ```python
164 output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
165 output[1, :] = (params[0, :] * 1.0) / 1.0
166 output[2, :] = (params[1, :] * 3.0) / 3.0
167 ```
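
    The listing above corresponds to rank-2 `RaggedTensor` inputs along the
    lines of the following (an illustrative construction, assuming
    `import tensorflow as tf`; column positions within a row do not matter,
    since embeddings are combined along the last dimension):

    ```python
    sp_ids = tf.ragged.constant([[1, 3], [0], [1]], dtype=tf.int64)
    sp_weights = tf.ragged.constant([[2.0, 0.5], [1.0], [3.0]])
    ```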

  Raises:
    TypeError: If `sp_weights` is neither `None` nor of the same type as
      `sp_ids`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  rt_ids = sp_ids
  rt_weights = sp_weights
  if combiner is None:
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError(
        f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}"
    )
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  ignore_weights = rt_weights is None
  if not ignore_weights:
    if not isinstance(rt_weights, ragged_tensor.RaggedTensor):
      raise TypeError(
          "sp_ids must be of the same type as sp_weights, "
          f"received {type(sp_ids).__name__!r} for sp_ids and "
          f"{type(sp_weights).__name__!r} for sp_weights."
      )
    rt_ids.values.get_shape().assert_is_compatible_with(
        rt_weights.values.get_shape()
    )
    rt_ids.get_shape().assert_is_compatible_with(rt_weights.get_shape())

  with ops.name_scope(
      name, "embedding_lookup_sparse", params + [rt_ids]
  ) as name:
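    # `flat_values` flattens the ragged ids into a single vector, while
    # `value_rowids` records the row each id came from; the dense
    # implementation uses the latter as segment ids when combining the
    # per-row results.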
    segment_ids = rt_ids.value_rowids()
    ids = rt_ids.flat_values

    return embedding_ops.embedding_lookup_sparse_impl(
        params,
        segment_ids,
        sp_weights,
        ids,
        combiner,
        ignore_weights,
        max_norm,
        allow_fast_lookup,
        partition_strategy,
        name,
    )


@dispatch.dispatch_for_api(embedding_ops.safe_embedding_lookup_sparse)
def safe_embedding_lookup_sparse(
    embedding_weights,
    sparse_ids: ragged_tensor.Ragged,
    sparse_weights=None,
    combiner="mean",
    default_id=None,
    name=None,
    partition_strategy="div",
    max_norm=None,
    allow_fast_lookup=False,
):
231 """Lookup embedding results, accounting for invalid IDs and empty features.
233 The partitioned embedding in `embedding_weights` must all be the same shape
234 except for the first dimension. The first dimension is allowed to vary as the
235 vocabulary size is not necessarily a multiple of `P`. `embedding_weights`
236 may be a `PartitionedVariable` as returned by using
237 `tf.compat.v1.get_variable()` with a
238 partitioner.
240 Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
241 with non-positive weight. For an entry with no features, the embedding vector
242 for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
244 The ids and weights may be multi-dimensional `SparseTensor`s or
245 `RaggedTensor`s with rank of 2. For `SpareTensor`s with left-aligned non-zero
246 entries which can be described as `RaggedTensor`s, use of `RaggedTensor`s can
247 yield higher performance. Embeddings are always aggregated along the last
248 dimension.
250 Args:
251 embedding_weights: A single tensor representing the complete embedding
252 tensor, or a list tensors all of same shape except for the first
253 dimension, representing sharded embedding tensors. Alternatively, a
254 `PartitionedVariable`, created by partitioning along dimension 0. Each
255 element must be appropriately sized for the given `partition_strategy`.
256 sp_ids: `RaggedTensor` with rank 2. The rank is not verified for performance
257 reasons.
258 sparse_weights: `RaggedTensor` of same type and shape as `sparse_ids`,
259 containing float weights corresponding to `sparse_ids`, or `None` if all
260 weights are assumed to be 1.0.
261 combiner: A string specifying how to combine embedding results for each
262 entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the
263 default.
264 default_id: The id to use for an entry with no features.
265 name: A name for this operation (optional).
266 partition_strategy: A string specifying the partitioning strategy. Currently
267 `"div"` and `"mod"` are supported. Default is `"div"`.
268 max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
269 combining.
270 allow_fast_lookup: An optional boolean specifying whether to allow
271 simplified embedding lookups when `params` is a single tensor and
272 `max_norm` is `None`. Setting this flag to `True` during training can
273 cause the use of dense gradients with increased memory footprint.
275 Returns:
276 A dense tensor representing the combined embeddings for the
277 sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
278 looks up the embeddings for all ids in that row, multiplies them by the
279 corresponding weight, and combines these embeddings as specified.
281 In other words, if
283 `shape(combined embedding_weights) = [p0, p1, ..., pm]`
285 and
287 `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`
289 then
291 `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.
293 For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are
295 ```python
296 [0, 0]: id 1, weight 2.0
297 [0, 1]: id 3, weight 0.5
298 [1, 0]: id -1, weight 1.0
299 [2, 3]: id 1, weight 3.0
300 ```
302 `default_id` is 0.
304 with `combiner`="mean", then the output will be a 3x20 matrix where
306 ```python
307 output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
308 output[1, :] = (params[0, :] * 1.0) / 1.0
309 output[2, :] = (params[1, :] * 3.0) / 3.0
310 ```
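
    As rank-2 `RaggedTensor`s, those inputs could be built along these lines
    (an illustrative construction, assuming `import tensorflow as tf`; the
    `-1` id is pruned, leaving row 1 empty so that it falls back to the
    `default_id` embedding):

    ```python
    sparse_ids = tf.ragged.constant([[1, 3], [-1], [1]], dtype=tf.int64)
    sparse_weights = tf.ragged.constant([[2.0, 0.5], [1.0], [3.0]])
    ```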

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  ragged_ids = sparse_ids
  ragged_weights = sparse_weights
  if embedding_weights is None:
    raise ValueError(f"Missing embedding_weights {embedding_weights}.")
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)  # get underlying Variables.
  if not isinstance(embedding_weights, list):
    embedding_weights = [embedding_weights]
  if len(embedding_weights) < 1:
    raise ValueError(f"Missing embedding_weights {embedding_weights}.")

  dtype = ragged_weights.dtype if ragged_weights is not None else None
  embedding_weights = [
      w
      if (
          isinstance(w, resource_variable_ops.ResourceVariable)
          and dtype in (None, w.dtype)
      )
      else ops.convert_to_tensor(w, dtype=dtype)
      for w in embedding_weights
  ]

  with ops.name_scope(
      name,
      "embedding_lookup",
      embedding_weights + [ragged_ids, ragged_weights],
  ) as scope:
    # Prune invalid ids and weights.
    ragged_ids, ragged_weights = _prune_invalid_ids_ragged(
        ragged_ids, ragged_weights
    )
    if combiner != "sum":
      ragged_ids, ragged_weights = _prune_invalid_weights_ragged(
          ragged_ids, ragged_weights
      )
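    # Pruning may have emptied some rows. Fill them with `default_id`, or with
    # the placeholder id 0 when `default_id` is None; in the latter case the
    # corresponding output rows are zeroed out below using `is_row_empty`.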
    ragged_ids, is_row_empty = ragged_array_ops.fill_empty_rows(
        ragged_ids, default_id or 0
    )
    if ragged_weights is not None:
      ragged_weights, _ = ragged_array_ops.fill_empty_rows(ragged_weights, 1.0)

    result = embedding_lookup_sparse(
        embedding_weights,
        ragged_ids,
        ragged_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm,
        allow_fast_lookup=allow_fast_lookup,
    )

    if default_id is None:
      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
      # for use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops_stack.stack([1, array_ops.shape(result)[1]]),
      )

      result = array_ops.where(
          is_row_empty, array_ops.zeros_like(result), result, name=scope
      )

    return result


def _prune_invalid_ids_ragged(ids, weights):
  """Prune invalid IDs (< 0) from the input ids and weights."""
  is_id_valid = math_ops.greater_equal(ids.values, 0)
  nrows = ids.nrows()
  # TODO(philipphack): Consider calling ragged_array_ops.boolean_mask once the
  # resulting performance is comparable to array_ops.boolean_mask. Currently,
  # ragged_array_ops.boolean_mask constructs the returned RaggedTensor by
  # calling its from_row_splits method which does not set value_row_ids and
  # requires it to be computed on demand.
  pruned_values = array_ops.boolean_mask_v2(ids.values, is_id_valid)
  pruned_value_rowids = array_ops.boolean_mask_v2(
      ids.value_rowids(), is_id_valid
  )
  ids = ragged_tensor.RaggedTensor.from_value_rowids(
      pruned_values, pruned_value_rowids, nrows=nrows, validate=False
  )
  if weights is not None:
    pruned_weights_values = array_ops.boolean_mask_v2(
        weights.values, is_id_valid
    )
    weights = ragged_tensor.RaggedTensor.from_value_rowids(
        pruned_weights_values, pruned_value_rowids, nrows=nrows, validate=False
    )

  return ids, weights


def _prune_invalid_weights_ragged(ids, weights):
  """Prune invalid weights (<= 0) from the input ids and weights."""
  if weights is not None:
    is_weights_valid = math_ops.greater(weights.values, 0)
    nrows = ids.nrows()
    # TODO(philipphack): Consider calling ragged_array_ops.boolean_mask once
    # the resulting performance is comparable to array_ops.boolean_mask.
    # Currently, ragged_array_ops.boolean_mask constructs the returned
    # RaggedTensor by calling its from_row_splits method which does not set
    # value_row_ids and requires it to be computed on demand.
    pruned_values = array_ops.boolean_mask_v2(ids.values, is_weights_valid)
    pruned_value_rowids = array_ops.boolean_mask_v2(
        ids.value_rowids(), is_weights_valid
    )
    ids = ragged_tensor.RaggedTensor.from_value_rowids(
        pruned_values, pruned_value_rowids, nrows=nrows, validate=False
    )

    pruned_weights_values = array_ops.boolean_mask_v2(
        weights.values, is_weights_valid
    )
    weights = ragged_tensor.RaggedTensor.from_value_rowids(
        pruned_weights_values, pruned_value_rowids, nrows=nrows, validate=False
    )

  return ids, weights