Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/embedding_bag.py: 37%

41 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16import tensorflow as tf 

17from typeguard import typechecked 

18 

19from tensorflow_addons.utils.types import Constraint, Initializer, Regularizer 

20from tensorflow_addons.utils.resource_loader import LazySO 

21 

22_embedding_bag_so = LazySO("custom_ops/layers/_embedding_bag_ops.so") 

23 

24 

25def _embedding_bag( 

26 indices, 

27 params, 

28 weights=None, 

29 combiner="sum", 

30 name=None, 

31): 

32 """EmbeddingBag computation. 

33 

34 See [PyTorch op](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html). 

35 

36 Equivalent to tf.gather() followed by tf.reduce_{sum,mean}() across the last dimension, with optional 

37 weights. Fusing these into a single op has massive benefits for execution speed and particularly 

38 memory usage, as the intermediate output of the gather never needs to be materialized. 

39 

40 Args: 

41 indices: An int32 or int64 `Tensor` of the indices to gather from 

42 `params`. Must be at least 2-dimensional, as the last dimension 

43 will be summed out. Maximum value must be less than params.shape[0]. 

44 params: A float32 `Tensor` from which to gather params. Must be rank 2. 

45 weights: A float32 `Tensor` of weights which will be applied to each of 

46 the gathered embedding vectors before the sum step. 

47 name: A name for the operation (optional). 

48 

49 Returns: 

50 A `Tensor` of the format specified by `data_format`. 

51 """ 

52 if weights is None: 

53 weights = tf.ones_like(indices, dtype=params.dtype) 

54 elif combiner != "sum": 

55 raise RuntimeError( 

56 "Combiner mode must be 'sum' when weights are supplied to EmbeddingBag!" 

57 ) 

58 

59 return _embedding_bag_so.ops.addons_embedding_bag( 

60 indices, params, weights, combiner=combiner.upper(), name=name 

61 ) 

62 

63 

@tf.RegisterGradient("Addons>EmbeddingBag")
def _embedding_bag_grad(op, grads):
    """Gradient function registered for the "Addons>EmbeddingBag" op.

    Args:
      op: The forward op. Its first three inputs are taken as
        (indices, params, weights) and its "combiner" attribute is forwarded
        unchanged to the gradient kernel.
      grads: Gradient flowing in for the forward op's output.

    Returns:
      A three-element list aligned with the forward inputs: `None` for
      `indices`, then the gradients the kernel returns for `params` and for
      `weights`.
    """
    idx, table, per_sample_weights = op.inputs[:3]
    reduction = op.get_attr("combiner")

    grad_kernel = _embedding_bag_so.ops.addons_embedding_bag_grad
    table_grad, weights_grad = grad_kernel(
        idx, table, per_sample_weights, grads, combiner=reduction
    )

    # No gradient is propagated to the index input.
    return [None, table_grad, weights_grad]

72 

73 

@tf.keras.utils.register_keras_serializable(package="Addons")
class EmbeddingBag(tf.keras.layers.Layer):
    """EmbeddingBag Layer.

    See [PyTorch op](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html).

    Equivalent to tf.gather() followed by tf.reduce_{sum,mean}() across the last dimension, with optional
    weights. Fusing these into a single op has massive benefits for execution speed and particularly
    memory usage, as the intermediate output of the gather never needs to be materialized.

    The embedding table is owned by the layer: `build()` creates a weight
    named "embeddings" of shape `(input_dim, output_dim)`, and `call()` gathers
    from it.

    Args:
      input_dim: Number of rows in the embedding table. Must be positive.
      output_dim: Number of columns in the embedding table. Must be positive.
      embeddings_initializer: Initializer for the embedding table.
      embeddings_regularizer: Regularizer applied to the embedding table.
      embeddings_constraint: Constraint applied to the embedding table.
      mask_zero: Stored on the layer, mirrored to `supports_masking` and
        serialized by `get_config()`. `call()` itself does not consult it.
      combiner: Reduction name forwarded to the fused op on every call.

    Input Shapes:
      indices: An int32 or int64 `Tensor` of the indices to gather from
        the embedding table. Must be at least 2-dimensional, as the last
        dimension will be summed out. Maximum value must be less than
        `input_dim`.
      weights: An optional float32 `Tensor` of weights which will be applied
        to each of the gathered embedding vectors before the sum step.

    Output shape:
      indices.shape[:-1], output_dim
    """

    @typechecked
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        embeddings_initializer: Initializer = "uniform",
        embeddings_regularizer: Regularizer = None,
        embeddings_constraint: Constraint = None,
        mask_zero: bool = False,
        combiner: str = "sum",
        **kwargs,
    ):
        """Validate the table dimensions and store the layer configuration.

        Raises:
          ValueError: If `input_dim` or `output_dim` is not positive.
        """
        super().__init__(**kwargs)
        if input_dim <= 0 or output_dim <= 0:
            raise ValueError(
                "Both `input_dim` and `output_dim` should be positive, "
                "found input_dim {} and output_dim {}".format(input_dim, output_dim)
            )
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Resolve string / config identifiers into Keras objects up front so
        # that build() and get_config() always work with resolved instances.
        self.embeddings_initializer = tf.keras.initializers.get(embeddings_initializer)
        self.embeddings_regularizer = tf.keras.regularizers.get(embeddings_regularizer)
        self.embeddings_constraint = tf.keras.constraints.get(embeddings_constraint)
        self.mask_zero = mask_zero
        self.supports_masking = mask_zero
        self.combiner = combiner

    def build(self, input_shape):
        """Create the `(input_dim, output_dim)` embedding table weight."""
        self.embeddings = self.add_weight(
            shape=(self.input_dim, self.output_dim),
            name="embeddings",
            initializer=self.embeddings_initializer,
            regularizer=self.embeddings_regularizer,
            constraint=self.embeddings_constraint,
        )
        self.built = True

    def call(self, indices, weights=None):
        """Gather rows of the embedding table for `indices` and reduce them.

        Delegates to `_embedding_bag` with the layer's table and combiner.
        """
        return _embedding_bag(indices, self.embeddings, weights, combiner=self.combiner)

    def get_config(self):
        """Return the layer configuration, merged over the base layer config."""
        config = {
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
            "embeddings_initializer": tf.keras.initializers.serialize(
                self.embeddings_initializer
            ),
            "embeddings_regularizer": tf.keras.regularizers.serialize(
                self.embeddings_regularizer
            ),
            "embeddings_constraint": tf.keras.constraints.serialize(
                self.embeddings_constraint
            ),
            "mask_zero": self.mask_zero,
            "combiner": self.combiner,
        }
        base_config = super().get_config()
        # Base entries first; this layer's own keys win on any collision.
        return {**base_config, **config}