Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/multihead_attention.py: 12%


# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import typing
import warnings

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="Addons")
class MultiHeadAttention(tf.keras.layers.Layer):

24 r"""MultiHead Attention layer. 

25 

26 Defines the MultiHead Attention operation as described in 

27 [Attention Is All You Need](https://arxiv.org/abs/1706.03762) which takes 

28 in the tensors `query`, `key`, and `value`, and returns the dot-product attention 

29 between them: 

30 

31 >>> mha = MultiHeadAttention(head_size=128, num_heads=12) 

32 >>> query = np.random.rand(3, 5, 4) # (batch_size, query_elements, query_depth) 

33 >>> key = np.random.rand(3, 6, 5) # (batch_size, key_elements, key_depth) 

34 >>> value = np.random.rand(3, 6, 6) # (batch_size, key_elements, value_depth) 

35 >>> attention = mha([query, key, value]) # (batch_size, query_elements, value_depth) 

36 >>> attention.shape 

37 TensorShape([3, 5, 6]) 

38 

39 If `value` is not given then internally `value = key` will be used: 

40 

41 >>> mha = MultiHeadAttention(head_size=128, num_heads=12) 

42 >>> query = np.random.rand(3, 5, 5) # (batch_size, query_elements, query_depth) 

43 >>> key = np.random.rand(3, 6, 10) # (batch_size, key_elements, key_depth) 

44 >>> attention = mha([query, key]) # (batch_size, query_elements, key_depth) 

45 >>> attention.shape 

46 TensorShape([3, 5, 10]) 

47 

48 Args: 

49 head_size: int, dimensionality of the `query`, `key` and `value` tensors 

50 after the linear transformation. 

51 num_heads: int, number of attention heads. 

52 output_size: int, dimensionality of the output space, if `None` then the 

53 input dimension of `value` or `key` will be used, 

54 default `None`. 

55 dropout: float, `rate` parameter for the dropout layer that is 

56 applied to attention after softmax, 

57 default `0`. 

58 use_projection_bias: bool, whether to use a bias term after the linear 

59 output projection. 

60 return_attn_coef: bool, if `True`, return the attention coefficients as 

61 an additional output argument. 

62 kernel_initializer: initializer, initializer for the kernel weights. 

63 kernel_regularizer: regularizer, regularizer for the kernel weights. 

64 kernel_constraint: constraint, constraint for the kernel weights. 

65 bias_initializer: initializer, initializer for the bias weights. 

66 bias_regularizer: regularizer, regularizer for the bias weights. 

67 bias_constraint: constraint, constraint for the bias weights. 

68 

69 Call Args: 

70 inputs: List of `[query, key, value]` where 

71 * `query`: Tensor of shape `(..., query_elements, query_depth)` 

72 * `key`: `Tensor of shape '(..., key_elements, key_depth)` 

73 * `value`: Tensor of shape `(..., key_elements, value_depth)`, optional, if not given `key` will be used. 

74 mask: a binary Tensor of shape `[batch_size?, num_heads?, query_elements, key_elements]` 

75 which specifies which query elements can attendo to which key elements, 

76 `1` indicates attention and `0` indicates no attention. 

77 

78 Output shape: 

79 * `(..., query_elements, output_size)` if `output_size` is given, else 

80 * `(..., query_elements, value_depth)` if `value` is given, else 

81 * `(..., query_elements, key_depth)` 

82 """ 


    def __init__(
        self,
        head_size: int,
        num_heads: int,
        output_size: int = None,
        dropout: float = 0.0,
        use_projection_bias: bool = True,
        return_attn_coef: bool = False,
        kernel_initializer: typing.Union[str, typing.Callable] = "glorot_uniform",
        kernel_regularizer: typing.Union[str, typing.Callable] = None,
        kernel_constraint: typing.Union[str, typing.Callable] = None,
        bias_initializer: typing.Union[str, typing.Callable] = "zeros",
        bias_regularizer: typing.Union[str, typing.Callable] = None,
        bias_constraint: typing.Union[str, typing.Callable] = None,
        **kwargs,
    ):
        warnings.warn(
            "`MultiHeadAttention` will be deprecated in Addons 0.13. "
            "Please use `tf.keras.layers.MultiHeadAttention` instead.",
            DeprecationWarning,
        )

        super().__init__(**kwargs)

        if output_size is not None and output_size < 1:
            raise ValueError("output_size must be a positive number")

        self.head_size = head_size
        self.num_heads = num_heads
        self.output_size = output_size
        self.use_projection_bias = use_projection_bias
        self.return_attn_coef = return_attn_coef

        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
        self.kernel_constraint = tf.keras.constraints.get(kernel_constraint)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self.bias_constraint = tf.keras.constraints.get(bias_constraint)

        self.dropout = tf.keras.layers.Dropout(dropout)
        self._dropout_rate = dropout

    def build(self, input_shape):

        num_query_features = input_shape[0][-1]
        num_key_features = input_shape[1][-1]
        num_value_features = (
            input_shape[2][-1] if len(input_shape) > 2 else num_key_features
        )
        output_size = (
            self.output_size if self.output_size is not None else num_value_features
        )

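        # Weight-shape note (added commentary, not in the original source): each
        # kernel below is allocated per head, with shape
        # [num_heads, input_features, head_size]. For the docstring example
        # (head_size=128, num_heads=12, query_depth=4) the query kernel is
        # (12, 4, 128), i.e. 12 * 4 * 128 = 6,144 weights, and the projection
        # kernel is (12, 128, output_size).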

        self.query_kernel = self.add_weight(
            name="query_kernel",
            shape=[self.num_heads, num_query_features, self.head_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.key_kernel = self.add_weight(
            name="key_kernel",
            shape=[self.num_heads, num_key_features, self.head_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.value_kernel = self.add_weight(
            name="value_kernel",
            shape=[self.num_heads, num_value_features, self.head_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.projection_kernel = self.add_weight(
            name="projection_kernel",
            shape=[self.num_heads, self.head_size, output_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )

        if self.use_projection_bias:
            self.projection_bias = self.add_weight(
                name="projection_bias",
                shape=[output_size],
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        else:
            self.projection_bias = None

        super().build(input_shape)


    def call(self, inputs, training=None, mask=None):

        # einsum nomenclature
        # ------------------------
        # N = query elements
        # M = key/value elements
        # H = heads
        # I = input features
        # O = output features

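        # Worked shape trace (added commentary, not in the original source), using
        # the first docstring example: batch_size=3, N=5, M=6, H=12, head_size=128,
        # value_depth=6, output_size=None:
        #   after the linear transformations: query (3, 5, 12, 128),
        #   key and value (3, 6, 12, 128)
        #   logits  "...NHO,...MHO->...HNM" -> (3, 12, 5, 6)
        #   context "...HNM,...MHI->...NHI" -> (3, 5, 12, 128)
        #   output  "...NHI,HIO->...NO"     -> (3, 5, 6)
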

        query = inputs[0]
        key = inputs[1]
        value = inputs[2] if len(inputs) > 2 else key

        # verify shapes
        if key.shape[-2] != value.shape[-2]:
            raise ValueError(
                "the number of elements in 'key' must be equal to "
                "the number of elements in 'value'"
            )

        if mask is not None:
            if len(mask.shape) < 2:
                raise ValueError("'mask' must have at least 2 dimensions")
            if query.shape[-2] != mask.shape[-2]:
                raise ValueError(
                    "mask's second to last dimension must be equal to "
                    "the number of elements in 'query'"
                )
            if key.shape[-2] != mask.shape[-1]:
                raise ValueError(
                    "mask's last dimension must be equal to the number of elements in 'key'"
                )

        # Linear transformations
        query = tf.einsum("...NI , HIO -> ...NHO", query, self.query_kernel)
        key = tf.einsum("...MI , HIO -> ...MHO", key, self.key_kernel)
        value = tf.einsum("...MI , HIO -> ...MHO", value, self.value_kernel)

        # Scale the dot product: dividing `query` (or `key`) by sqrt(depth)
        # instead of scaling their product saves some computation
        depth = tf.constant(self.head_size, dtype=query.dtype)
        query /= tf.sqrt(depth)

        # Calculate dot product attention
        logits = tf.einsum("...NHO,...MHO->...HNM", query, key)

        # apply mask
        if mask is not None:
            mask = tf.cast(mask, tf.float32)

            # possibly expand on the head dimension so broadcasting works
            if len(mask.shape) != len(logits.shape):
                mask = tf.expand_dims(mask, -3)

            logits += -10e9 * (1.0 - mask)

        attn_coef = tf.nn.softmax(logits)

        # attention dropout
        attn_coef_dropout = self.dropout(attn_coef, training=training)

        # attention * value
        multihead_output = tf.einsum("...HNM,...MHI->...NHI", attn_coef_dropout, value)

        # Run the outputs through another linear projection layer. Recombining heads
        # is automatically done.
        output = tf.einsum(
            "...NHI,HIO->...NO", multihead_output, self.projection_kernel
        )

        if self.projection_bias is not None:
            output += self.projection_bias

        if self.return_attn_coef:
            return output, attn_coef
        else:
            return output


    def compute_output_shape(self, input_shape):
        num_value_features = (
            input_shape[2][-1] if len(input_shape) > 2 else input_shape[1][-1]
        )
        output_size = (
            self.output_size if self.output_size is not None else num_value_features
        )

        output_shape = input_shape[0][:-1] + (output_size,)

        if self.return_attn_coef:
            num_query_elements = input_shape[0][-2]
            num_key_elements = input_shape[1][-2]
            attn_coef_shape = input_shape[0][:-2] + (
                self.num_heads,
                num_query_elements,
                num_key_elements,
            )

            return output_shape, attn_coef_shape
        else:
            return output_shape


    def get_config(self):
        config = super().get_config()

        config.update(
            head_size=self.head_size,
            num_heads=self.num_heads,
            output_size=self.output_size,
            dropout=self._dropout_rate,
            use_projection_bias=self.use_projection_bias,
            return_attn_coef=self.return_attn_coef,
            kernel_initializer=tf.keras.initializers.serialize(self.kernel_initializer),
            kernel_regularizer=tf.keras.regularizers.serialize(self.kernel_regularizer),
            kernel_constraint=tf.keras.constraints.serialize(self.kernel_constraint),
            bias_initializer=tf.keras.initializers.serialize(self.bias_initializer),
            bias_regularizer=tf.keras.regularizers.serialize(self.bias_regularizer),
            bias_constraint=tf.keras.constraints.serialize(self.bias_constraint),
        )

        return config
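

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original module): a
# minimal example of calling the layer, requesting the attention coefficients,
# and round-tripping it through the Keras config machinery. Shapes follow the
# docstring example above; `np` is a local import for this sketch only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    mha = MultiHeadAttention(head_size=128, num_heads=12, return_attn_coef=True)

    query = np.random.rand(3, 5, 4).astype("float32")  # (batch, query_elements, query_depth)
    key = np.random.rand(3, 6, 5).astype("float32")    # (batch, key_elements, key_depth)
    value = np.random.rand(3, 6, 6).astype("float32")  # (batch, key_elements, value_depth)

    output, attn_coef = mha([query, key, value])
    print(output.shape)     # (3, 5, 6)
    print(attn_coef.shape)  # (3, 12, 5, 6) = (batch, heads, query_elements, key_elements)

    # Serialization round-trip via get_config()/from_config().
    restored = MultiHeadAttention.from_config(mha.get_config())
    print(restored.head_size, restored.num_heads)  # 128 12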