Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/embedding_bag.py: 37%
41 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16import tensorflow as tf
17from typeguard import typechecked
19from tensorflow_addons.utils.types import Constraint, Initializer, Regularizer
20from tensorflow_addons.utils.resource_loader import LazySO
# Handle to the compiled custom-op library that provides the
# `addons_embedding_bag` / `addons_embedding_bag_grad` ops used below.
# NOTE(review): `LazySO` presumably defers loading the .so until first
# attribute access — confirm in tensorflow_addons.utils.resource_loader.
_embedding_bag_so = LazySO(
    "custom_ops/layers/_embedding_bag_ops.so"
)
25def _embedding_bag(
26 indices,
27 params,
28 weights=None,
29 combiner="sum",
30 name=None,
31):
32 """EmbeddingBag computation.
34 See [PyTorch op](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html).
36 Equivalent to tf.gather() followed by tf.reduce_{sum,mean}() across the last dimension, with optional
37 weights. Fusing these into a single op has massive benefits for execution speed and particularly
38 memory usage, as the intermediate output of the gather never needs to be materialized.
40 Args:
41 indices: An int32 or int64 `Tensor` of the indices to gather from
42 `params`. Must be at least 2-dimensional, as the last dimension
43 will be summed out. Maximum value must be less than params.shape[0].
44 params: A float32 `Tensor` from which to gather params. Must be rank 2.
45 weights: A float32 `Tensor` of weights which will be applied to each of
46 the gathered embedding vectors before the sum step.
47 name: A name for the operation (optional).
49 Returns:
50 A `Tensor` of the format specified by `data_format`.
51 """
52 if weights is None:
53 weights = tf.ones_like(indices, dtype=params.dtype)
54 elif combiner != "sum":
55 raise RuntimeError(
56 "Combiner mode must be 'sum' when weights are supplied to EmbeddingBag!"
57 )
59 return _embedding_bag_so.ops.addons_embedding_bag(
60 indices, params, weights, combiner=combiner.upper(), name=name
61 )
@tf.RegisterGradient("Addons>EmbeddingBag")
def _embedding_bag_grad(op, grads):
    """Gradient function registered for the `Addons>EmbeddingBag` op.

    Args:
      op: The forward `Addons>EmbeddingBag` operation being differentiated.
      grads: The upstream gradient with respect to the op's output.

    Returns:
      A list with one entry per forward input, in input order: `None` for
      the integer indices, then the gradients for the params and the
      weights as produced by the `addons_embedding_bag_grad` custom op.
    """
    bag_indices, table, bag_weights = op.inputs[:3]
    params_grad, weights_grad = _embedding_bag_so.ops.addons_embedding_bag_grad(
        bag_indices, table, bag_weights, grads, combiner=op.get_attr("combiner")
    )
    # Indices are integer lookups, so they receive no gradient.
    return [None, params_grad, weights_grad]
@tf.keras.utils.register_keras_serializable(package="Addons")
class EmbeddingBag(tf.keras.layers.Layer):
    """EmbeddingBag Layer.

    See [PyTorch op](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html).

    Equivalent to tf.gather() followed by tf.reduce_{sum,mean}() (selected by
    `combiner`) across the last dimension, with optional weights. Fusing these
    into a single op has massive benefits for execution speed and particularly
    memory usage, as the intermediate output of the gather never needs to be
    materialized.

    Args:
      input_dim: Number of rows in the embedding table. Must be positive.
      output_dim: Length of each embedding vector. Must be positive.
      embeddings_initializer: Initializer for the embedding table.
      embeddings_regularizer: Regularizer applied to the embedding table.
      embeddings_constraint: Constraint applied to the embedding table.
      mask_zero: Stored in the config and used to set `supports_masking`.
          It is not passed to the embedding-bag op.
      combiner: `"sum"` or `"mean"` (case-insensitive). Must be `"sum"`
          when `weights` are passed to `call`.

    Call arguments:
      indices: An int32 or int64 `Tensor` of the indices to gather from the
          embedding table. Must be at least 2-dimensional, as the last
          dimension will be reduced out. Maximum value must be less than
          `input_dim`.
      weights: Optional float32 `Tensor` of weights which will be applied to
          each of the gathered embedding vectors before the sum step.

    Output shape:
      indices.shape[:-1], output_dim

    Raises:
      ValueError: At construction, if `input_dim` or `output_dim` is not
          positive, or if `combiner` is not `"sum"`/`"mean"`.
    """

    @typechecked
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        embeddings_initializer: Initializer = "uniform",
        embeddings_regularizer: Regularizer = None,
        embeddings_constraint: Constraint = None,
        mask_zero: bool = False,
        combiner: str = "sum",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if input_dim <= 0 or output_dim <= 0:
            raise ValueError(
                "Both `input_dim` and `output_dim` should be positive, "
                "found input_dim {} and output_dim {}".format(input_dim, output_dim)
            )
        # Fail fast instead of at the first call. The embedding-bag computation
        # only reduces with sum or mean, and the value is upper-cased before
        # reaching the op, so compare case-insensitively.
        if combiner.lower() not in ("sum", "mean"):
            raise ValueError(
                "`combiner` must be 'sum' or 'mean', found {!r}".format(combiner)
            )
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embeddings_initializer = tf.keras.initializers.get(embeddings_initializer)
        self.embeddings_regularizer = tf.keras.regularizers.get(embeddings_regularizer)
        self.embeddings_constraint = tf.keras.constraints.get(embeddings_constraint)
        self.mask_zero = mask_zero
        self.supports_masking = mask_zero
        # Stored as given (not normalized) so get_config round-trips exactly.
        self.combiner = combiner

    def build(self, input_shape):
        """Create the `(input_dim, output_dim)` embedding table."""
        self.embeddings = self.add_weight(
            shape=(self.input_dim, self.output_dim),
            name="embeddings",
            initializer=self.embeddings_initializer,
            regularizer=self.embeddings_regularizer,
            constraint=self.embeddings_constraint,
        )
        self.built = True

    def call(self, indices, weights=None):
        """Gather rows of the table for `indices` and reduce the last axis."""
        return _embedding_bag(indices, self.embeddings, weights, combiner=self.combiner)

    def get_config(self):
        """Return the layer config; layer-specific keys override base keys."""
        config = {
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
            "embeddings_initializer": tf.keras.initializers.serialize(
                self.embeddings_initializer
            ),
            "embeddings_regularizer": tf.keras.regularizers.serialize(
                self.embeddings_regularizer
            ),
            "embeddings_constraint": tf.keras.constraints.serialize(
                self.embeddings_constraint
            ),
            "mask_zero": self.mask_zero,
            "combiner": self.combiner,
        }
        base_config = super().get_config()
        return {**base_config, **config}