Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/attention/attention.py: 28%

36 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Attention layer that can be used in sequence DNN/CNN models.

This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
Attention is formed by three tensors: Query, Key and Value.
"""


import tensorflow.compat.v2 as tf

from keras.src.layers.attention.base_dense_attention import BaseDenseAttention

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.Attention")
class Attention(BaseDenseAttention):
    """Dot-product attention layer, a.k.a. Luong-style attention.

    Inputs are a `query` tensor of shape `[batch_size, Tq, dim]`, a `value`
    tensor of shape `[batch_size, Tv, dim]` and a `key` tensor of shape
    `[batch_size, Tv, dim]`. The calculation follows the steps:

    1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key`
       dot product: `scores = tf.matmul(query, key, transpose_b=True)`.
    2. Use scores to calculate a distribution with shape
       `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
    3. Use `distribution` to create a linear combination of `value` with
       shape `[batch_size, Tq, dim]`:
       `return tf.matmul(distribution, value)`.
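
    A minimal sketch of these three steps on random tensors, assuming default
    constructor arguments (no scaling, no dropout) and no masking; in that
    setting the layer output should match the composed ops:

    ```python
    query = tf.random.normal((2, 3, 4))  # [batch_size, Tq, dim]
    value = tf.random.normal((2, 5, 4))  # [batch_size, Tv, dim]

    scores = tf.matmul(query, value, transpose_b=True)  # [2, 3, 5]
    distribution = tf.nn.softmax(scores)                # [2, 3, 5]
    manual_output = tf.matmul(distribution, value)      # [2, 3, 4]

    layer_output = tf.keras.layers.Attention()([query, value])
    # `manual_output` and `layer_output` agree up to float tolerance.
    ```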


    Args:
        use_scale: If `True`, will create a scalar variable to scale the
            attention scores.
        dropout: Float between 0 and 1. Fraction of the units to drop for the
            attention scores. Defaults to 0.0.
        score_mode: Function to use to compute attention scores, one of
            `{"dot", "concat"}`. `"dot"` refers to the dot product between the
            query and key vectors. `"concat"` refers to the hyperbolic tangent
            of the concatenation of the query and key vectors.

    Call Args:

        inputs: List of the following tensors:
            * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
            * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
            * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If
                not given, will use `value` for both `key` and `value`, which
                is the most common case.
        mask: List of the following tensors:
            * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
                If given, the output will be zero at the positions where
                `mask==False`.
            * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
                If given, will apply the mask such that values at positions
                where `mask==False` do not contribute to the result.
        return_attention_scores: bool, if `True`, returns the attention scores
            (after masking and softmax) as an additional output argument.
        training: Python boolean indicating whether the layer should behave in
            training mode (adding dropout) or in inference mode (no dropout).
        use_causal_mask: Boolean. Set to `True` for decoder self-attention.
            Adds a mask such that position `i` cannot attend to positions
            `j > i`. This prevents the flow of information from the future
            towards the past. Defaults to `False`.

    Output:

        Attention outputs of shape `[batch_size, Tq, dim]`.
        [Optional] Attention scores after masking and softmax with shape
        `[batch_size, Tq, Tv]`.

    The meaning of `query`, `value` and `key` depends on the application. In
    the case of text similarity, for example, `query` is the sequence
    embeddings of the first piece of text and `value` is the sequence
    embeddings of the second piece of text. `key` is usually the same tensor
    as `value`.

    Here is a code example for using `Attention` in a CNN+Attention network:

    ```python
    # Variable-length int sequences.
    query_input = tf.keras.Input(shape=(None,), dtype='int32')
    value_input = tf.keras.Input(shape=(None,), dtype='int32')

    # Embedding lookup.
    token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
    # Query embeddings of shape [batch_size, Tq, dimension].
    query_embeddings = token_embedding(query_input)
    # Value embeddings of shape [batch_size, Tv, dimension].
    value_embeddings = token_embedding(value_input)

    # CNN layer.
    cnn_layer = tf.keras.layers.Conv1D(
        filters=100,
        kernel_size=4,
        # Use 'same' padding so outputs have the same shape as inputs.
        padding='same')
    # Query encoding of shape [batch_size, Tq, filters].
    query_seq_encoding = cnn_layer(query_embeddings)
    # Value encoding of shape [batch_size, Tv, filters].
    value_seq_encoding = cnn_layer(value_embeddings)

    # Query-value attention of shape [batch_size, Tq, filters].
    query_value_attention_seq = tf.keras.layers.Attention()(
        [query_seq_encoding, value_seq_encoding])

    # Reduce over the sequence axis to produce encodings of shape
    # [batch_size, filters].
    query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
        query_seq_encoding)
    query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
        query_value_attention_seq)

    # Concatenate query and document encodings to produce a DNN input layer.
    input_layer = tf.keras.layers.Concatenate()(
        [query_encoding, query_value_attention])

    # Add DNN layers, and create Model.
    # ...
    ```
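
    The call also supports causal masking and returning the attention
    distribution. A brief, illustrative sketch, reusing `query_seq_encoding`
    from the example above for decoder-style self-attention:

    ```python
    self_attention = tf.keras.layers.Attention(use_scale=True)
    attended, attention_scores = self_attention(
        [query_seq_encoding, query_seq_encoding],
        return_attention_scores=True,
        use_causal_mask=True,
    )
    # attended: [batch_size, Tq, filters]
    # attention_scores: [batch_size, Tq, Tq]
    ```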

    """

    def __init__(self, use_scale=False, score_mode="dot", **kwargs):
        super().__init__(**kwargs)
        self.use_scale = use_scale
        self.score_mode = score_mode
        if self.score_mode not in ["dot", "concat"]:
            raise ValueError(
                f"Received: score_mode={score_mode}. Acceptable values "
                'are: ["dot", "concat"]'
            )

    def build(self, input_shape):
        """Creates a variable when `use_scale` is `True` or `score_mode` is
        `"concat"`."""
        if self.use_scale:
            self.scale = self.add_weight(
                name="scale",
                shape=(),
                initializer="ones",
                dtype=self.dtype,
                trainable=True,
            )
        else:
            self.scale = None
        if self.score_mode == "concat":
            self.concat_score_weight = self.add_weight(
                name="concat_score_weight",
                shape=(),
                initializer="ones",
                dtype=self.dtype,
                trainable=True,
            )
        else:
            self.concat_score_weight = None
        super().build(input_shape)

    def _calculate_scores(self, query, key):
        """Calculates attention scores as a query-key dot product.

        Args:
            query: Query tensor of shape `[batch_size, Tq, dim]`.
            key: Key tensor of shape `[batch_size, Tv, dim]`.

        Returns:
            Tensor of shape `[batch_size, Tq, Tv]`.
        """
        if self.score_mode == "dot":
            scores = tf.matmul(query, key, transpose_b=True)
            if self.scale is not None:
                scores *= self.scale
        elif self.score_mode == "concat":
            # Reshape tensors to enable broadcasting.
            # Reshape into [batch_size, Tq, 1, dim].
            q_reshaped = tf.expand_dims(query, axis=-2)
            # Reshape into [batch_size, 1, Tv, dim].
            k_reshaped = tf.expand_dims(key, axis=-3)
            if self.scale is not None:
                scores = self.concat_score_weight * tf.reduce_sum(
                    tf.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1
                )
            else:
                scores = self.concat_score_weight * tf.reduce_sum(
                    tf.tanh(q_reshaped + k_reshaped), axis=-1
                )
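            # In either branch, element [b, i, j] of `scores` works out to
            #     concat_score_weight * sum_d(tanh(query[b, i, d] + key[b, j, d]))
            # (with the tanh argument additionally scaled by `self.scale` when
            # `use_scale=True`); the broadcasted addition scores every
            # (query position, key position) pair in one pass.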


        return scores

    def get_config(self):
        config = {"use_scale": self.use_scale, "score_mode": self.score_mode}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
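

# Usage sketch (illustrative): `get_config` returns the constructor arguments,
# so an equivalent layer can be rebuilt with the inherited `from_config`
# classmethod, e.g.:
#
#     layer = Attention(use_scale=True, score_mode="concat")
#     clone = Attention.from_config(layer.get_config())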