Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/embedding_ops.py: 17%

244 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Operations for embeddings.""" 

16 

17from tensorflow.python.framework import constant_op 

18from tensorflow.python.framework import dtypes 

19from tensorflow.python.framework import indexed_slices 

20from tensorflow.python.framework import ops 

21from tensorflow.python.framework import sparse_tensor 

22from tensorflow.python.framework import tensor_shape 

23from tensorflow.python.ops import array_ops 

24from tensorflow.python.ops import array_ops_stack 

25from tensorflow.python.ops import clip_ops 

26# Imports gradient definitions. 

27from tensorflow.python.ops import data_flow_grad # pylint: disable=unused-import 

28from tensorflow.python.ops import data_flow_ops 

29from tensorflow.python.ops import math_ops 

30from tensorflow.python.ops import resource_variable_ops 

31from tensorflow.python.ops import sparse_ops 

32from tensorflow.python.ops import variables 

33from tensorflow.python.util import dispatch 

34from tensorflow.python.util.tf_export import tf_export 

35 

36 

37def _clip(params, ids, max_norm): 

38 """Helper function for _embedding_lookup_and_transform. 

39 

40 This function optionally clips embeddings to an l2-norm of max_norm. 

41 

42 Args: 

43 params: A `Tensor` of embeddings retrieved by `gather`. 

44 ids: The `ids` argument that was passed to `gather`. 

45 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger 

46 than this value. 

47 

48 Returns: 

49 A `Tensor` with the same type as `params`. 

50 """ 

51 

52 def _rank(x): 

53 """Helper function to retrieve the rank of a tensor. 

54 

55 Args: 

56 x: Something convertible to `Tensor`. 

57 

58 Returns: 

59 Either a pair `(rank, True)` where `rank` is an integer or a pair 

60 `(rank, False)` where `rank` is an integer `Tensor`. In either case, 

61 `rank` is the rank of `x`. 

62 """ 

63 rank = ops.convert_to_tensor(x).get_shape().ndims 

64 if rank: 

65 return rank, True 

66 else: 

67 return array_ops.rank(x), False 

68 

69 if max_norm is None: 

70 return params 

71 ids_rank, ids_static = _rank(ids) 

72 params_rank, params_static = _rank(params) 

73 return clip_ops.clip_by_norm( 

74 params, 

75 max_norm, 

76 axes=(list(range(ids_rank, params_rank)) if ids_static and params_static 

77 else math_ops.range(ids_rank, params_rank))) 

78 
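For reference, the clipping performed by this helper is equivalent to `tf.clip_by_norm` restricted to the embedding axes. A minimal sketch (values chosen purely for illustration):

```python
import tensorflow as tf

# Two gathered 3-dimensional embeddings (illustrative values).
gathered = tf.constant([[3.0, 4.0, 0.0], [0.3, 0.4, 0.0]])

# With max_norm=1.0, only rows whose l2-norm exceeds 1.0 are rescaled.
# axes=[1] limits the norm to the embedding dimensions, as _clip does when
# ids are 1-D and params are 2-D.
clipped = tf.clip_by_norm(gathered, 1.0, axes=[1])
print(clipped.numpy())  # first row rescaled to norm 1.0, second row unchanged
```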

79 

80def _colocate_with(param): 

81 if ops.inside_function() and hasattr(param, "handle"): 

82 # The `ops.colocate_with` will hard-code a device string if `param.device` 

83 # is known, which will then break serving. We capture it here so that it 

84 # produces a tensor without a device. 

85 return ops.colocate_with(ops.get_default_graph().capture(param.handle)) 

86 else: 

87 return ops.colocate_with(param) 

88 

89 

90def _embedding_lookup_and_transform(params, 

91 ids, 

92 partition_strategy="mod", 

93 name=None, 

94 max_norm=None, 

95 transform_fn=None): 

96 """Helper function for embedding_lookup and _compute_sampled_logits. 

97 

98 This function is a generalization of embedding_lookup that optionally 

99 applies a caller-specified transformation to each embedding. This is 

100 done through the `transform_fn` argument. If provided, the function is 

101 applied to each partitioned tensor of retrieved embeddings, colocated 

102 with the embeddings. This function will be called with a single `Tensor` 

103 argument of the same type as the `params` tensor and should return a 

104 `Tensor`. The shape of the argument will be the same as `params` except 

105 for the size of the first dimension. The first dimension of the result's 

106 shape must be the same size as the argument's. 

107 

108 Args: 

109 params: See embedding_lookup. 

110 ids: See embedding_lookup. 

111 partition_strategy: See embedding_lookup. 

112 name: See embedding_lookup. 

113 max_norm: See embedding_lookup. 

114 transform_fn: An optional function to apply to each retrieved embedding. If 

115 max_norm is provided, transform_fn is applied to the norm-limited 

116 embeddings. 

117 

118 Returns: 

119 See embedding_lookup for details. 

120 Raises: 

121 ValueError: If `params` is empty. 

122 """ 

123 if params is None: 

124 raise ValueError("params must be specified") 

125 if isinstance(params, (list, tuple)) and not params: 

126 raise ValueError("Length of params is currently 0. " 

127 "Need at least one param.") 

128 if isinstance(params, variables.PartitionedVariable): 

129 params = list(params) # Iterate to get the underlying Variables. 

130 if not isinstance(params, list): 

131 params = [params] 

132 

133 with ops.name_scope(name, "embedding_lookup", params + [ids]) as name: 

134 np = len(params) # Number of partitions 

135 # Preserve the resource variable status to avoid accidental dense reads. 

136 if not any( 

137 isinstance(p, resource_variable_ops.BaseResourceVariable) 

138 for p in params): 

139 params = indexed_slices.convert_n_to_tensor_or_indexed_slices( 

140 params, name="params") 

141 ids = ops.convert_to_tensor(ids, name="ids") 

142 if np == 1 and (not transform_fn or ids.get_shape().ndims == 1): 

143 with _colocate_with(params[0]): 

144 result = _clip( 

145 array_ops.gather(params[0], ids, name=name), ids, max_norm) 

146 if transform_fn: 

147 result = transform_fn(result) 

148 # Make sure the final result does not have colocation constraints on the 

149 # params. Similar to the case np > 1 where parallel_dynamic_stitch is 

150 # outside the scope of all with _colocate_with(params[p]). 

151 return array_ops.identity(result) 

152 else: 

153 # Flatten the ids. There are two cases where we need to do this. 

154 # - There is more than one params tensor. 

155 # - There is a transform_fn and ids is not statically known to be 1-D. 

156 # We must flatten in this case because transform_fn expects a flat 

157 # tensor of embeddings. 

158 flat_ids = array_ops.reshape(ids, [-1]) 

159 original_indices = math_ops.range(array_ops.size(flat_ids)) 

160 

161 # Create p_assignments and set new_ids depending on the strategy. 

162 if partition_strategy == "mod": 

163 p_assignments = flat_ids % np 

164 new_ids = flat_ids // np 

165 elif partition_strategy == "div": 

166 # Compute num_total_ids as the sum of dim-0 of params, then assign to 

167 # partitions based on a constant number of ids per partition. Optimize 

168 # if we already know the full shape statically. 

169 dim_0_size = tensor_shape.Dimension( 

170 tensor_shape.dimension_value(params[0].get_shape()[0])) 

171 for p in range(1, np): 

172 dim_0_size += tensor_shape.Dimension( 

173 tensor_shape.dimension_value(params[p].get_shape()[0])) 

174 if dim_0_size.value: 

175 num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype) 

176 else: 

177 dim_0_sizes = [] 

178 for p in range(np): 

179 param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0]) 

180 if param_p_dim is not None: 

181 dim_0_sizes.append(param_p_dim) 

182 else: 

183 with _colocate_with(params[p]): 

184 dim_0_sizes.append(array_ops.shape(params[p])[0]) 

185 num_total_ids = math_ops.reduce_sum( 

186 math_ops.cast(array_ops_stack.stack(dim_0_sizes), flat_ids.dtype)) 

187 ids_per_partition = num_total_ids // np 

188 extras = num_total_ids % np 

189 

190 p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1), 

191 (flat_ids - extras) // 

192 ids_per_partition) 

193 

194 # Emulate a conditional using a boolean indicator tensor 

195 new_ids = array_ops.where(p_assignments < extras, 

196 flat_ids % (ids_per_partition + 1), 

197 (flat_ids - extras) % ids_per_partition) 

198 else: 

199 raise ValueError( 

200 f"Unrecognized partition strategy: {partition_strategy}. "

201 "Must be either `mod` or `div`.")

202 

203 # Cast partition assignments to int32 for use in dynamic_partition. 

204 # There really should not be more than 2^32 partitions. 

205 p_assignments = math_ops.cast(p_assignments, dtypes.int32) 

206 # Partition list of ids based on assignments into np separate lists 

207 gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np) 

208 # Similarly, partition the original indices. 

209 pindices = data_flow_ops.dynamic_partition(original_indices, 

210 p_assignments, np) 

211 # Do np separate lookups, finding embeddings for plist[p] in params[p] 

212 partitioned_result = [] 

213 for p in range(np): 

214 pids = gather_ids[p] 

215 with ops.device_v2(None): 

216 with _colocate_with(params[p]): 

217 result = array_ops.gather(params[p], pids) 

218 if transform_fn: 

219 # If transform_fn is provided, the clip_by_norm precedes 

220 # the transform and hence must be co-located. See below 

221 # for the counterpart if transform_fn is not provided. 

222 result = transform_fn(_clip(result, pids, max_norm)) 

223 partitioned_result.append(result) 

224 # Stitch these back together 

225 ret = data_flow_ops.parallel_dynamic_stitch( 

226 pindices, partitioned_result, name=name) 

227 

228 # Determine the static element shape. 

229 if transform_fn is None: 

230 element_shape_s = params[0].get_shape()[1:] 

231 for p in params[1:]: 

232 element_shape_s = element_shape_s.merge_with(p.get_shape()[1:]) 

233 else: 

234 element_shape_s = ret.get_shape()[1:] 

235 

236 # Compute the dynamic element shape. 

237 if element_shape_s.is_fully_defined(): 

238 element_shape_d = element_shape_s 

239 elif transform_fn is None: 

240 # It's important that we compute params[0].shape on the right device 

241 # to avoid data motion. 

242 with _colocate_with(params[0]): 

243 params_shape = array_ops.shape(params[0]) 

244 element_shape_d = params_shape[1:] 

245 else: 

246 element_shape_d = array_ops.shape(ret)[1:] 

247 

248 # Reshape to reverse the flattening of ids. 

249 ret = array_ops.reshape( 

250 ret, array_ops.concat([array_ops.shape(ids), element_shape_d], 0)) 

251 

252 # Normally the reshape is sufficient, but setting shape explicitly 

253 # teaches shape inference that params[1:].get_shape() matters 

254 # (in the case that transform_fn is None). 

255 ret.set_shape(ids.get_shape().concatenate(element_shape_s)) 

256 if not transform_fn: 

257 # If transform_fn was provided, the clip_by_norm was done above. 

258 ret = _clip(ret, ids, max_norm) 

259 return ret 

260 
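A minimal sketch of the `transform_fn` hook described above, calling this private helper directly (for illustration only; the values and the l2-normalizing transform are arbitrary choices):

```python
import tensorflow as tf
from tensorflow.python.ops import embedding_ops

params = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# On the single-shard fast path this reduces to a gather followed by the
# transform; here the transform simply l2-normalizes each looked-up row.
result = embedding_ops._embedding_lookup_and_transform(
    params, tf.constant([0, 1]),
    transform_fn=lambda emb: tf.math.l2_normalize(emb, axis=-1))
```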

261 

262@tf_export(v1=["nn.embedding_lookup"]) 

263@dispatch.add_dispatch_support 

264def embedding_lookup( 

265 params, 

266 ids, 

267 partition_strategy="mod", 

268 name=None, 

269 validate_indices=True, # pylint: disable=unused-argument 

270 max_norm=None): 

271 """Looks up embeddings for the given `ids` from a list of tensors. 

272 

273 This function is used to perform parallel lookups on the list of tensors in 

274 `params`. It is a generalization of `tf.gather`, where `params` is 

275 interpreted as a partitioning of a large embedding tensor. `params` may be 

276 a `PartitionedVariable` as returned by using `tf.compat.v1.get_variable()` 

277 with a partitioner. 

278 

279 If `len(params) > 1`, each element `id` of `ids` is partitioned between 

280 the elements of `params` according to the `partition_strategy`. 

281 In all strategies, if the id space does not evenly divide the number of 

282 partitions, each of the first `(max_id + 1) % len(params)` partitions will 

283 be assigned one more id. 

284 

285 If `partition_strategy` is `"mod"`, we assign each id to partition 

286 `p = id % len(params)`. For instance, 

287 13 ids are split across 5 partitions as: 

288 `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]` 

289 

290 If `partition_strategy` is `"div"`, we assign ids to partitions in a 

291 contiguous manner. In this case, 13 ids are split across 5 partitions as: 

292 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]` 

293 

294 If the input ids are ragged tensors, partitioned variables are not supported, and

295 the partition strategy and max_norm are ignored.

296 The results of the lookup are concatenated into a dense 

297 tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. 

298 

299 Args: 

300 params: A single tensor representing the complete embedding tensor, or a 

301 list of P tensors all of same shape except for the first dimension, 

302 representing sharded embedding tensors. Alternatively, a 

303 `PartitionedVariable`, created by partitioning along dimension 0. Each 

304 element must be appropriately sized for the given `partition_strategy`. 

305 ids: A `Tensor` or a `RaggedTensor` with type `int32` or `int64` containing

306 the ids to be looked up in `params`. 

307 partition_strategy: A string specifying the partitioning strategy, relevant 

308 if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default 

309 is `"mod"`. 

310 name: A name for the operation (optional). 

311 validate_indices: DEPRECATED. If this operation is assigned to CPU, values 

312 in `indices` are always validated to be within range. If assigned to GPU, 

313 out-of-bound indices result in safe but unspecified behavior, which may 

314 include raising an error. 

315 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger 

316 than this value. 

317 

318 Returns: 

319 A `Tensor` or a `RaggedTensor`, depending on the input, with the same type

320 as the tensors in `params`. 

321 

322 Raises: 

323 ValueError: If `params` is empty. 

324 """ 

325 

326 return _embedding_lookup_and_transform( 

327 params=params, 

328 ids=ids, 

329 partition_strategy=partition_strategy, 

330 name=name, 

331 max_norm=max_norm, 

332 transform_fn=None) 

333 
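A short sketch of the `"mod"` and `"div"` strategies described in the docstring, using the public `tf.compat.v1` endpoint (shard values are illustrative):

```python
import tensorflow as tf

# Two shards of a 4x2 embedding table (illustrative values).
shard0 = tf.constant([[0.0, 0.0], [1.0, 1.0]])
shard1 = tf.constant([[2.0, 2.0], [3.0, 3.0]])
ids = tf.constant([0, 1, 2, 3])

# "mod": id 2 goes to shard 2 % 2 = 0, row 2 // 2 = 1.
mod_result = tf.compat.v1.nn.embedding_lookup(
    [shard0, shard1], ids, partition_strategy="mod")

# "div": ids are assigned contiguously, so ids 0-1 use shard0 and 2-3 use shard1.
div_result = tf.compat.v1.nn.embedding_lookup(
    [shard0, shard1], ids, partition_strategy="div")
```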

334 

335@tf_export("nn.embedding_lookup", v1=[]) 

336@dispatch.add_dispatch_support 

337def embedding_lookup_v2(params, ids, max_norm=None, name=None): 

338 """Looks up embeddings for the given `ids` from a list of tensors. 

339 

340 This function is used to perform parallel lookups on the list of tensors in 

341 `params`. It is a generalization of `tf.gather`, where `params` is 

342 interpreted as a partitioning of a large embedding tensor. 

343 

344 If `len(params) > 1`, each element `id` of `ids` is partitioned between the 

345 elements of `params` according to the "div" partition strategy, which means we 

346 assign ids to partitions in a contiguous manner. For instance, 13 ids are 

347 split across 5 partitions as: 

348 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`. 

349 

350 If the id space does not evenly divide the number of partitions, each of the 

351 first `(max_id + 1) % len(params)` partitions will be assigned one more id. 

352 

353 The results of the lookup are concatenated into a dense 

354 tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. 

355 

356 Args: 

357 params: A single tensor representing the complete embedding tensor, or a 

358 list of tensors all of same shape except for the first dimension, 

359 representing sharded embedding tensors following "div" partition strategy. 

360 ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked 

361 up in `params`. 

362 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger 

363 than this value. 

364 name: A name for the operation (optional). 

365 

366 Returns: 

367 A `Tensor` with the same type as the tensors in `params`. 

368 

369 For instance, if `params` is a 5x2 matrix: 

370 

371 ```python 

372 [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] 

373 ``` 

374 

375 or a list of matrices: 

376 

377 ```python 

378 params[0]: [[1, 2], [3, 4]] 

379 params[1]: [[5, 6], [7, 8]] 

380 params[2]: [[9, 10]] 

381 ``` 

382 

383 and `ids` is: 

384 

385 ```python 

386 [0, 3, 4] 

387 ``` 

388 

389 The output will be a 3x2 matrix: 

390 

391 ```python 

392 [[1, 2], [7, 8], [9, 10]] 

393 ``` 

394 

395 Raises: 

396 ValueError: If `params` is empty. 

397 """ 

398 return embedding_lookup(params, ids, "div", name, max_norm=max_norm) 

399 
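The example in the docstring can be reproduced with the exported `tf.nn.embedding_lookup` endpoint; a minimal sketch:

```python
import tensorflow as tf

params = [tf.constant([[1.0, 2.0], [3.0, 4.0]]),
          tf.constant([[5.0, 6.0], [7.0, 8.0]]),
          tf.constant([[9.0, 10.0]])]
ids = tf.constant([0, 3, 4])

# "div" partitioning assigns ids 0-1 to params[0], 2-3 to params[1], 4 to params[2].
result = tf.nn.embedding_lookup(params, ids)
print(result.numpy())  # [[1. 2.] [7. 8.] [9. 10.]]
```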

400 

401@tf_export(v1=["nn.embedding_lookup_sparse"]) 

402@dispatch.add_dispatch_support 

403def embedding_lookup_sparse( 

404 params, 

405 sp_ids, 

406 sp_weights, 

407 partition_strategy="mod", 

408 name=None, 

409 combiner=None, 

410 max_norm=None, 

411 allow_fast_lookup=False, 

412): 

413 """Looks up embeddings for the given ids and weights from a list of tensors. 

414 

415 This op assumes that there is at least one id for each row in the dense tensor 

416 represented by sp_ids (i.e. there are no rows with empty features), and that 

417 all the indices of sp_ids are in canonical row-major order. 

418 

419 `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s or `RaggedTensor`s 

420 with rank of 2. For `SparseTensor`s with left-aligned non-zero entries which

421 can be described as `RaggedTensor`s, use of `RaggedTensor`s can yield higher 

422 performance. 

423 

424 It also assumes that all id values lie in the range [0, p0), where p0 

425 is the sum of the size of params along dimension 0. 

426 

427 Args: 

428 params: A single tensor representing the complete embedding tensor, or a 

429 list of tensors all of the same shape except for the first dimension,

430 representing sharded embedding tensors. Alternatively, a 

431 `PartitionedVariable`, created by partitioning along dimension 0. Each 

432 element must be appropriately sized for the given `partition_strategy`. 

433 sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size

434 and M is arbitrary, or a `RaggedTensor` with rank 2.

435 sp_weights: `SparseTensor` or `RaggedTensor` of same type and shape as

436 `sp_ids`, containing float / double weights corresponding to

437 `sp_ids`, or `None` if all weights are assumed to be 1.0.

438 partition_strategy: A string specifying the partitioning strategy, relevant 

439 if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default 

440 is `"mod"`. See `tf.nn.embedding_lookup` for more details. 

441 name: Optional name for the op. 

442 combiner: A string specifying the reduction op. Currently "mean", "sqrtn" 

443 and "sum" are supported. "sum" computes the weighted sum of the embedding 

444 results for each row. "mean" is the weighted sum divided by the total 

445 weight. "sqrtn" is the weighted sum divided by the square root of the sum 

446 of the squares of the weights. Defaults to `mean`. 

447 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger 

448 than this value, before combining. 

449 allow_fast_lookup: An optional boolean specifying whether to allow 

450 simplified embedding lookups when `params` is a single tensor and 

451 `max_norm` is `None`. Setting this flag to `True` during training can 

452 cause the use of dense gradients with increased memory footprint. 

453 

454 Returns: 

455 A dense tensor representing the combined embeddings for the 

456 sparse ids. For each row in the dense tensor represented by `sp_ids`, the op 

457 looks up the embeddings for all ids in that row, multiplies them by the 

458 corresponding weight, and combines these embeddings as specified. 

459 

460 In other words, if 

461 

462 `shape(combined params) = [p0, p1, ..., pm]` 

463 

464 and 

465 

466 `shape(sp_ids) = shape(sp_weights) = [d0, d1]` 

467 

468 then 

469 

470 `shape(output) = [d0, p1, ..., pm]`. 

471 

472 For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are 

473 

474 ```python 

475 [0, 0]: id 1, weight 2.0 

476 [0, 1]: id 3, weight 0.5 

477 [1, 0]: id 0, weight 1.0 

478 [2, 3]: id 1, weight 3.0 

479 ``` 

480 

481 with `combiner`="mean", then the output will be a 3x20 matrix where 

482 

483 ```python 

484 output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) 

485 output[1, :] = (params[0, :] * 1.0) / 1.0 

486 output[2, :] = (params[1, :] * 3.0) / 3.0 

487 ``` 

488 

489 Raises: 

490 TypeError: If `sp_ids` is not a `SparseTensor` or `RaggedTensor`, or if 

491 `sp_weights` is neither `None` nor of the same type as `sp_ids`. 

492 ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}. 

493 """ 

494 if combiner is None: 

495 combiner = "mean" 

496 if combiner not in ("mean", "sqrtn", "sum"): 

497 raise ValueError( 

498 f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}") 

499 if isinstance(params, variables.PartitionedVariable): 

500 params = list(params) # Iterate to get the underlying Variables. 

501 if not isinstance(params, list): 

502 params = [params] 

503 if not isinstance(sp_ids, sparse_tensor.SparseTensor): 

504 raise TypeError(f"sp_ids must be SparseTensor, got {type(sp_ids)}") 

505 ignore_weights = sp_weights is None 

506 if not ignore_weights: 

507 if not isinstance(sp_weights, sparse_tensor.SparseTensor): 

508 raise TypeError(f"sp_weights must be either None or SparseTensor, "

509 f"got {type(sp_weights)}") 

510 sp_ids.values.get_shape().assert_is_compatible_with( 

511 sp_weights.values.get_shape()) 

512 sp_ids.indices.get_shape().assert_is_compatible_with( 

513 sp_weights.indices.get_shape()) 

514 sp_ids.dense_shape.get_shape().assert_is_compatible_with( 

515 sp_weights.dense_shape.get_shape()) 

516 # TODO(yleon): Add enhanced node assertions to verify that sp_ids and 

517 # sp_weights have equal indices and shapes. 

518 

519 with ops.name_scope(name, "embedding_lookup_sparse", 

520 params + [sp_ids]) as name: 

521 

522 segment_ids = sp_ids.indices[:, 0] 

523 ids = sp_ids.values 

524 

525 return embedding_lookup_sparse_impl( 

526 params, 

527 segment_ids, 

528 sp_weights, 

529 ids, 

530 combiner, 

531 ignore_weights, 

532 max_norm, 

533 allow_fast_lookup, 

534 partition_strategy, 

535 name, 

536 ) 

537 
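A runnable sketch of the weighted-mean example from the docstring, using the `tf.compat.v1` endpoint (the 10x20 matrix is replaced by a small 4x2 table with illustrative values):

```python
import tensorflow as tf

params = tf.constant([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])
sp_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0], [2, 3]],
    values=tf.constant([1, 3, 0, 1], dtype=tf.int64),
    dense_shape=[3, 4])
sp_weights = tf.sparse.SparseTensor(
    indices=sp_ids.indices,
    values=[2.0, 0.5, 1.0, 3.0],
    dense_shape=[3, 4])

# Row 0 of the output is the weighted mean of embeddings 1 and 3:
# (params[1] * 2.0 + params[3] * 0.5) / (2.0 + 0.5)
out = tf.compat.v1.nn.embedding_lookup_sparse(
    params, sp_ids, sp_weights, combiner="mean")
```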

538 

539@tf_export("nn.embedding_lookup_sparse", v1=[]) 

540@dispatch.add_dispatch_support 

541def embedding_lookup_sparse_v2( 

542 params, 

543 sp_ids, 

544 sp_weights, 

545 combiner=None, 

546 max_norm=None, 

547 name=None, 

548 allow_fast_lookup=False, 

549): 

550 """Looks up embeddings for the given ids and weights from a list of tensors. 

551 

552 This op assumes that there is at least one id for each row in the dense tensor 

553 represented by sp_ids (i.e. there are no rows with empty features), and that 

554 all the indices of sp_ids are in canonical row-major order. 

555 

556 `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s or `RaggedTensor`s 

557 with rank of 2. For `SparseTensor`s with left-aligned non-zero entries which

558 can be described as `RaggedTensor`s, use of `RaggedTensor`s can yield higher 

559 performance. 

560 

561 It also assumes that all id values lie in the range [0, p0), where p0 

562 is the sum of the size of params along dimension 0. 

563 

564 If `len(params) > 1`, each element of `sp_ids` is partitioned between the 

565 elements of `params` according to the "div" partition strategy, which means we 

566 assign ids to partitions in a contiguous manner. For instance, 13 ids are 

567 split across 5 partitions as: 

568 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`. 

569 

570 If the id space does not evenly divide the number of partitions, each of the 

571 first `(max_id + 1) % len(params)` partitions will be assigned one more id. 

572 

573 Args: 

574 params: A single tensor representing the complete embedding tensor, or a 

575 list of tensors all of same shape except for the first dimension, 

576 representing sharded embedding tensors following "div" partition strategy. 

577 sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size

578 and M is arbitrary, or a `RaggedTensor` with rank 2.

579 sp_weights: `SparseTensor` or `RaggedTensor` of same type and shape as

580 `sp_ids`, containing float / double weights corresponding to

581 `sp_ids`, or `None` if all weights are assumed to be 1.0.

582 combiner: A string specifying the reduction op. Currently "mean", "sqrtn" 

583 and "sum" are supported. "sum" computes the weighted sum of the embedding 

584 results for each row. "mean" is the weighted sum divided by the total 

585 weight. "sqrtn" is the weighted sum divided by the square root of the sum 

586 of the squares of the weights. Defaults to `mean`. 

587 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger 

588 than this value, before combining. 

589 name: Optional name for the op. 

590 allow_fast_lookup: An optional boolean specifying whether to allow 

591 simplified embedding lookups when `params` is a single tensor and 

592 `max_norm` is `None`. Setting this flag to `True` during training can 

593 cause the use of dense gradients with increased memory footprint. 

594 

595 Returns: 

596 A dense tensor representing the combined embeddings for the 

597 sparse ids. For each row in the dense tensor represented by `sp_ids`, the op 

598 looks up the embeddings for all ids in that row, multiplies them by the 

599 corresponding weight, and combines these embeddings as specified. 

600 

601 In other words, if 

602 

603 `shape(combined params) = [p0, p1, ..., pm]` 

604 

605 and 

606 

607 `shape(sp_ids) = shape(sp_weights) = [d0, d1]` 

608 

609 then 

610 

611 `shape(output) = [d0, p1, ..., pm]`. 

612 

613 For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are 

614 

615 ```python 

616 [0, 0]: id 1, weight 2.0 

617 [0, 1]: id 3, weight 0.5 

618 [1, 0]: id 0, weight 1.0 

619 [2, 3]: id 1, weight 3.0 

620 ``` 

621 

622 with `combiner`="mean", then the output will be a 3x20 matrix where 

623 

624 ```python 

625 output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) 

626 output[1, :] = (params[0, :] * 1.0) / 1.0 

627 output[2, :] = (params[1, :] * 3.0) / 3.0 

628 ``` 

629 

630 Raises: 

631 TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is 

632 neither `None` nor `SparseTensor`. 

633 ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}. 

634 """ 

635 return embedding_lookup_sparse( 

636 params, 

637 sp_ids, 

638 sp_weights, 

639 "div", 

640 name, 

641 combiner, 

642 max_norm, 

643 allow_fast_lookup, 

644 ) 

645 
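In this v2 endpoint the partition strategy is fixed to `"div"`. A minimal sketch with unweighted ids (`sp_weights=None`) and the `"sum"` combiner (values are illustrative):

```python
import tensorflow as tf

params = tf.constant([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
sp_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([0, 2, 1], dtype=tf.int64),
    dense_shape=[2, 2])

# With sp_weights=None every id has weight 1.0, so row 0 is params[0] + params[2].
out = tf.nn.embedding_lookup_sparse(params, sp_ids, None, combiner="sum")
print(out.numpy())  # [[4. 4.] [2. 2.]]
```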

646 

647@tf_export("nn.safe_embedding_lookup_sparse", v1=[]) 

648@dispatch.add_dispatch_support 

649def safe_embedding_lookup_sparse_v2( 

650 embedding_weights, 

651 sparse_ids, 

652 sparse_weights=None, 

653 combiner="mean", 

654 default_id=None, 

655 max_norm=None, 

656 name=None, 

657 allow_fast_lookup=False, 

658): 

659 """Look up embedding results, accounting for invalid IDs and empty features.

660 

661 The partitioned embeddings in `embedding_weights` must all have the same shape

662 except for the first dimension, which is allowed to vary because the

663 vocabulary size is not necessarily a multiple of the number of shards.

664 

665 Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs 

666 with non-positive weight. For an entry with no features, the embedding vector 

667 for `default_id` is returned, or the 0-vector if `default_id` is not supplied. 

668 

669 The ids and weights may be multi-dimensional `SparseTensor`s or 

670 `RaggedTensor`s with rank of 2. For `SparseTensor`s with left-aligned non-zero

671 entries which can be described as `RaggedTensor`s, use of `RaggedTensor`s can 

672 yield higher performance. 

673 

674 If `len(embedding_weights) > 1`, each element `id` of `ids` is partitioned 

675 between the elements of `embedding_weights` according to the "div" partition 

676 strategy, which means we assign ids to partitions in a contiguous manner. For 

677 instance, 13 ids are split across 5 partitions as: 

678 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`. 

679 

680 If the id space does not evenly divide the number of partitions, each of the 

681 first `(max_id + 1) % len(embedding_weights)` partitions will be assigned one 

682 more id. 

683 

684 Args: 

685 embedding_weights: A single tensor representing the complete embedding 

686 tensor, or a list of tensors all of same shape except for the first 

687 dimension, representing sharded embedding tensors following "div" 

688 partition strategy. 

689 sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the 

690 ids, where `d_0` is typically batch size, or a `RaggedTensor` with rank 2. 

691 sparse_weights: `SparseTensor` or `RaggedTensor` of same type and shape as 

692 `sparse_ids`, containing float weights corresponding to `sparse_ids`, or 

693 `None` if all weights are assumed to be 1.0. 

694 combiner: A string specifying how to combine embedding results for each 

695 entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the 

696 default. 

697 default_id: The id to use for an entry with no features. If `None`, the

698 0-vector is returned for such entries.

699 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger

700 than this value, before combining.

701 name: A name for this operation (optional). 

702 allow_fast_lookup: An optional boolean specifying whether to allow 

703 simplified embedding lookups when `embedding_weights` is a single tensor and

704 `max_norm` is `None`. Setting this flag to `True` during training can 

705 cause the use of dense gradients with increased memory footprint. 

706 

707 Returns: 

708 A dense tensor representing the combined embeddings for the 

709 sparse ids. For each row in the dense tensor represented by `sparse_ids`, 

710 the op looks up the embeddings for all ids in that row, multiplies them by 

711 the corresponding weight, and combines these embeddings as specified. 

712 

713 In other words, if 

714 

715 `shape(combined embedding_weights) = [p0, p1, ..., pm]` 

716 

717 and 

718 

719 `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]` 

720 

721 then 

722 

723 `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`. 

724 

725 For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are 

726 

727 ```python 

728 [0, 0]: id 1, weight 2.0 

729 [0, 1]: id 3, weight 0.5 

730 [1, 0]: id -1, weight 1.0 

731 [2, 3]: id 1, weight 3.0 

732 ``` 

733 

734 `default_id` is 0. 

735 

736 with `combiner`="mean", then the output will be a 3x20 matrix where 

737 

738 ```python 

739 output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) 

740 output[1, :] = (params[0, :] * 1.0) / 1.0 

741 output[2, :] = (params[1, :] * 3.0) / 3.0 

742 ``` 

743 

744 Raises: 

745 ValueError: if `embedding_weights` is empty. 

746 """ 

747 return safe_embedding_lookup_sparse( 

748 embedding_weights, 

749 sparse_ids, 

750 sparse_weights=sparse_weights, 

751 combiner=combiner, 

752 default_id=default_id, 

753 name=name, 

754 partition_strategy="div", 

755 max_norm=max_norm, 

756 allow_fast_lookup=allow_fast_lookup, 

757 ) 

758 
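A sketch of the invalid-id handling described above: the id `-1` is pruned, and because `default_id` is not given the resulting empty row comes back as the 0-vector (table values are illustrative):

```python
import tensorflow as tf

weights = tf.constant([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
sparse_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [1, 0], [2, 0]],
    values=tf.constant([1, -1, 2], dtype=tf.int64),
    dense_shape=[3, 1])

# Row 1 only contains the invalid id -1, so it is returned as the 0-vector.
out = tf.nn.safe_embedding_lookup_sparse(weights, sparse_ids)
print(out.numpy())  # [[2. 2.] [0. 0.] [3. 3.]]
```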

759 

760@tf_export(v1=["nn.safe_embedding_lookup_sparse"]) 

761@dispatch.add_dispatch_support 

762def safe_embedding_lookup_sparse( 

763 embedding_weights, 

764 sparse_ids, 

765 sparse_weights=None, 

766 combiner="mean", 

767 default_id=None, 

768 name=None, 

769 partition_strategy="div", 

770 max_norm=None, 

771 allow_fast_lookup=False, 

772): 

773 """Look up embedding results, accounting for invalid IDs and empty features.

774 

775 The partitioned embeddings in `embedding_weights` must all have the same shape

776 except for the first dimension, which is allowed to vary because the

777 vocabulary size is not necessarily a multiple of the number of shards. `embedding_weights`

778 may be a `PartitionedVariable` as returned by using 

779 `tf.compat.v1.get_variable()` with a 

780 partitioner. 

781 

782 Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs 

783 with non-positive weight. For an entry with no features, the embedding vector 

784 for `default_id` is returned, or the 0-vector if `default_id` is not supplied. 

785 

786 The ids and weights may be multi-dimensional `SparseTensor`s or 

787 `RaggedTensor`s with rank of 2. For `SparseTensor`s with left-aligned non-zero

788 entries which can be described as `RaggedTensor`s, use of `RaggedTensor`s can 

789 yield higher performance. Embeddings are always aggregated along the last 

790 dimension. 

791 

792 Args: 

793 embedding_weights: A single tensor representing the complete embedding 

794 tensor, or a list of tensors all of the same shape except for the first

795 dimension, representing sharded embedding tensors. Alternatively, a 

796 `PartitionedVariable`, created by partitioning along dimension 0. Each 

797 element must be appropriately sized for the given `partition_strategy`. 

798 sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the 

799 ids, where `d_0` is typically batch size, or a `RaggedTensor` with rank 2. 

800 sparse_weights: `SparseTensor` or `RaggedTensor` of same type and shape as 

801 `sparse_ids`, containing float weights corresponding to `sparse_ids`, or 

802 `None` if all weights are assumed to be 1.0. 

803 combiner: A string specifying how to combine embedding results for each 

804 entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the 

805 default. 

806 default_id: The id to use for an entry with no features. 

807 name: A name for this operation (optional). 

808 partition_strategy: A string specifying the partitioning strategy. Currently 

809 `"div"` and `"mod"` are supported. Default is `"div"`. 

810 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger

811 than this value, before combining.

812 allow_fast_lookup: An optional boolean specifying whether to allow 

813 simplified embedding lookups when `embedding_weights` is a single tensor and

814 `max_norm` is `None`. Setting this flag to `True` during training can 

815 cause the use of dense gradients with increased memory footprint. 

816 

817 Returns: 

818 A dense tensor representing the combined embeddings for the 

819 sparse ids. For each row in the dense tensor represented by `sp_ids`, the op 

820 looks up the embeddings for all ids in that row, multiplies them by the 

821 corresponding weight, and combines these embeddings as specified. 

822 

823 In other words, if 

824 

825 `shape(combined embedding_weights) = [p0, p1, ..., pm]` 

826 

827 and 

828 

829 `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]` 

830 

831 then 

832 

833 `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`. 

834 

835 For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are 

836 

837 ```python 

838 [0, 0]: id 1, weight 2.0 

839 [0, 1]: id 3, weight 0.5 

840 [1, 0]: id -1, weight 1.0 

841 [2, 3]: id 1, weight 3.0 

842 ``` 

843 

844 `default_id` is 0. 

845 

846 with `combiner`="mean", then the output will be a 3x20 matrix where 

847 

848 ```python 

849 output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) 

850 output[1, :] = (params[0, :] * 1.0) / 1.0 

851 output[2, :] = (params[1, :] * 3.0) / 3.0 

852 ``` 

853 

854 Raises: 

855 ValueError: if `embedding_weights` is empty. 

856 """ 

857 if embedding_weights is None: 

858 raise ValueError(f"Missing embedding_weights {embedding_weights}.") 

859 if isinstance(embedding_weights, variables.PartitionedVariable): 

860 embedding_weights = list(embedding_weights) # get underlying Variables. 

861 if not isinstance(embedding_weights, list): 

862 embedding_weights = [embedding_weights] 

863 if len(embedding_weights) < 1: 

864 raise ValueError(f"Missing embedding_weights {embedding_weights}.") 

865 

866 dtype = sparse_weights.dtype if sparse_weights is not None else None 

867 embedding_weights = [ 

868 w if (isinstance(w, resource_variable_ops.ResourceVariable) 

869 and dtype in (None, w.dtype)) 

870 else ops.convert_to_tensor(w, dtype=dtype) 

871 for w in embedding_weights 

872 ] 

873 

874 with ops.name_scope(name, "embedding_lookup", embedding_weights + 

875 [sparse_ids, sparse_weights]) as scope: 

876 # Reshape higher-rank sparse ids and weights to linear segment ids. 

877 original_shape = sparse_ids.dense_shape 

878 original_rank_dim = tensor_shape.dimension_value( 

879 sparse_ids.dense_shape.get_shape()[0]) 

880 original_rank = ( 

881 array_ops.size(original_shape) 

882 if original_rank_dim is None else original_rank_dim) 

883 sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [ 

884 math_ops.reduce_prod( 

885 array_ops.slice(original_shape, [0], [original_rank - 1])), 

886 array_ops.gather(original_shape, original_rank - 1) 

887 ]) 

888 if sparse_weights is not None: 

889 sparse_weights = sparse_tensor.SparseTensor(sparse_ids.indices, 

890 sparse_weights.values, 

891 sparse_ids.dense_shape) 

892 

893 # Prune invalid ids and weights. 

894 sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights) 

895 if combiner != "sum": 

896 sparse_ids, sparse_weights = _prune_invalid_weights( 

897 sparse_ids, sparse_weights) 

898 

899 # Fill in dummy values for empty features, if necessary. 

900 sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows( 

901 sparse_ids, default_id or 0) 

902 if sparse_weights is not None: 

903 sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0) 

904 

905 result = embedding_lookup_sparse( 

906 embedding_weights, 

907 sparse_ids, 

908 sparse_weights, 

909 combiner=combiner, 

910 partition_strategy=partition_strategy, 

911 name=None if default_id is None else scope, 

912 max_norm=max_norm, 

913 allow_fast_lookup=allow_fast_lookup, 

914 ) 

915 

916 if default_id is None: 

917 # Broadcast is_row_empty to the same shape as embedding_lookup_result, 

918 # for use in Select. 

919 is_row_empty = array_ops.tile( 

920 array_ops.reshape(is_row_empty, [-1, 1]), 

921 array_ops_stack.stack([1, array_ops.shape(result)[1]])) 

922 

923 result = array_ops.where( 

924 is_row_empty, array_ops.zeros_like(result), result, name=scope) 

925 

926 # Reshape back from linear ids back into higher-dimensional dense result. 

927 final_result = array_ops.reshape( 

928 result, 

929 array_ops.concat([ 

930 array_ops.slice( 

931 math_ops.cast(original_shape, dtypes.int32), [0], 

932 [original_rank - 1]), 

933 array_ops.slice(array_ops.shape(result), [1], [-1]) 

934 ], 0)) 

935 final_result.set_shape( 

936 tensor_shape.unknown_shape( 

937 (tensor_shape.Dimension(original_rank_dim) - 1).value 

938 ).concatenate(result.get_shape()[1:]) 

939 ) 

940 return final_result 

941 
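The empty-row handling above is built on `tf.sparse.fill_empty_rows`, which inserts a default value and reports which rows were empty. A small sketch of that building block (values are illustrative):

```python
import tensorflow as tf

sp = tf.sparse.SparseTensor(indices=[[0, 0], [2, 0]], values=[5, 7],
                            dense_shape=[3, 1])

# Row 1 has no entries, so it is filled with the default id 0 and flagged.
filled, is_row_empty = tf.sparse.fill_empty_rows(sp, 0)
print(filled.values.numpy())   # [5 0 7]
print(is_row_empty.numpy())    # [False  True False]
```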

942 

943def embedding_lookup_sparse_impl( 

944 params, 

945 segment_ids, 

946 sp_weights, 

947 ids, 

948 combiner, 

949 ignore_weights, 

950 max_norm, 

951 allow_fast_lookup, 

952 partition_strategy, 

953 name, 

954): 

955 """Implementation of sparse embedding aggregation.""" 

956 if len(params) == 1 and max_norm is None and allow_fast_lookup: 

957 idx = ids 

958 embeddings = params[0] 

959 else: 

960 ids, idx = array_ops.unique(ids) 

961 embeddings = embedding_lookup( 

962 params, ids, partition_strategy=partition_strategy, max_norm=max_norm 

963 ) 

964 

965 if not ignore_weights: 

966 if segment_ids.dtype != dtypes.int32: 

967 segment_ids = math_ops.cast(segment_ids, dtypes.int32) 

968 

969 weights = sp_weights.values 

970 embeddings = array_ops.gather(embeddings, idx) 

971 

972 original_dtype = embeddings.dtype 

973 if embeddings.dtype in (dtypes.float16, dtypes.bfloat16): 

974 # Cast low-precision embeddings to float32 during the computation to 

975 # avoid numerical issues. 

976 embeddings = math_ops.cast(embeddings, dtypes.float32) 

977 if weights.dtype != embeddings.dtype: 

978 weights = math_ops.cast(weights, embeddings.dtype) 

979 

980 # Reshape weights to allow broadcast 

981 ones_shape = array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0) 

982 ones = array_ops.ones(ones_shape, dtype=dtypes.int32) 

983 bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], 0) 

984 

985 orig_weights_shape = weights.get_shape() 

986 weights = array_ops.reshape(weights, bcast_weights_shape) 

987 

988 # Set the weight shape, since after reshaping to bcast_weights_shape, 

989 # the shape becomes None. 

990 if embeddings.get_shape().ndims is not None: 

991 weights.set_shape( 

992 orig_weights_shape.concatenate( 

993 [1 for _ in range(embeddings.get_shape().ndims - 1)] 

994 ) 

995 ) 

996 

997 embeddings *= weights 

998 

999 if combiner == "sum": 

1000 embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name) 

1001 elif combiner == "mean": 

1002 embeddings = math_ops.segment_sum(embeddings, segment_ids) 

1003 weight_sum = math_ops.segment_sum(weights, segment_ids) 

1004 embeddings = math_ops.div_no_nan(embeddings, weight_sum, name=name) 

1005 elif combiner == "sqrtn": 

1006 embeddings = math_ops.segment_sum(embeddings, segment_ids) 

1007 weights_squared = math_ops.pow(weights, 2) 

1008 weight_sum = math_ops.segment_sum(weights_squared, segment_ids) 

1009 weight_sum_sqrt = math_ops.sqrt(weight_sum) 

1010 embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt, name=name) 

1011 else: 

1012 assert False, "Unrecognized combiner" 

1013 if embeddings.dtype != original_dtype: 

1014 embeddings = math_ops.cast(embeddings, original_dtype) 

1015 else: 

1016 if segment_ids.dtype not in (dtypes.int32, dtypes.int64): 

1017 segment_ids = math_ops.cast(segment_ids, dtypes.int32) 

1018 assert idx is not None 

1019 if combiner == "sum": 

1020 embeddings = math_ops.sparse_segment_sum( 

1021 embeddings, idx, segment_ids, name=name 

1022 ) 

1023 elif combiner == "mean": 

1024 embeddings = math_ops.sparse_segment_mean( 

1025 embeddings, idx, segment_ids, name=name 

1026 ) 

1027 elif combiner == "sqrtn": 

1028 embeddings = math_ops.sparse_segment_sqrt_n( 

1029 embeddings, idx, segment_ids, name=name 

1030 ) 

1031 else: 

1032 assert False, "Unrecognized combiner" 

1033 

1034 return embeddings 

1035 
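On the unweighted path the combine step reduces to the sparse segment ops. A small sketch of `tf.sparse.segment_mean`, which averages the gathered rows that share a segment id (values are illustrative):

```python
import tensorflow as tf

embeddings = tf.constant([[2.0, 2.0], [4.0, 4.0], [6.0, 6.0]])

# Rows 0 and 1 belong to segment 0 and are averaged; row 2 forms segment 1.
out = tf.sparse.segment_mean(embeddings,
                             indices=tf.constant([0, 1, 2]),
                             segment_ids=tf.constant([0, 0, 1]))
print(out.numpy())  # [[3. 3.] [6. 6.]]
```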

1036 

1037def _prune_invalid_ids(sparse_ids, sparse_weights): 

1038 """Prune invalid IDs (< 0) from the input ids and weights.""" 

1039 is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) 

1040 if sparse_weights is not None: 

1041 is_id_valid = math_ops.logical_and( 

1042 is_id_valid, 

1043 array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)) 

1044 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) 

1045 if sparse_weights is not None: 

1046 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) 

1047 return sparse_ids, sparse_weights 

1048 

1049 

1050def _prune_invalid_weights(sparse_ids, sparse_weights): 

1051 """Prune invalid weights (< 0) from the input ids and weights.""" 

1052 if sparse_weights is not None: 

1053 is_weights_valid = math_ops.greater(sparse_weights.values, 0) 

1054 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid) 

1055 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid) 

1056 return sparse_ids, sparse_weights
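Both pruning helpers are thin wrappers around `tf.sparse.retain`, which drops entries selected by a boolean mask while keeping the dense shape. A small sketch mirroring `_prune_invalid_ids` (values are illustrative):

```python
import tensorflow as tf

sp = tf.sparse.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                            values=[3, -1, 4], dense_shape=[2, 2])

# Keep only the entries whose value is >= 0.
keep = tf.greater_equal(sp.values, 0)
pruned = tf.sparse.retain(sp, keep)
print(pruned.values.numpy())  # [3 4]
```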