Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/attention/base_dense_attention.py: 21% (94 statements)

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

15"""Base class for attention layers that can be used in sequence DNN/CNN models. 

16 

17This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2. 

18Attention is formed by three tensors: Query, Key and Value. 

19""" 

20 

21import tensorflow.compat.v2 as tf 

22from absl import logging 

23 

24from keras.src import backend 

25from keras.src.engine import base_layer 

26from keras.src.utils import control_flow_util 

27 

28# isort: off 

29from tensorflow.python.util.tf_export import keras_export 

30 

31 

@keras_export("keras.__internal__.layers.BaseDenseAttention", v1=[])
class BaseDenseAttention(base_layer.BaseRandomLayer):
    """Base Attention class for Dense networks.

    This class is suitable for Dense or CNN networks, and not for RNN
    networks.

    Implementations of attention mechanisms should inherit from this class,
    and reuse the `_apply_scores()` method.

    Args:
        dropout: Float between 0 and 1. Fraction of the units to drop for the
            attention scores.

    Call Args:
        inputs: List of the following tensors:
            * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
            * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
            * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If
                not given, will use `value` for both `key` and `value`, which
                is the most common case.
        mask: List of the following tensors:
            * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
                If given, the output will be zero at the positions where
                `mask==False`.
            * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
                If given, will apply the mask such that values at positions
                where `mask==False` do not contribute to the result.
        training: Python boolean indicating whether the layer should behave in
            training mode (adding dropout) or in inference mode (no dropout).
        return_attention_scores: bool, if `True`, returns the attention scores
            (after masking and softmax) as an additional output argument.

    Output:

        Attention outputs of shape `[batch_size, Tq, dim]`.
        [Optional] Attention scores after masking and softmax with shape
        `[batch_size, Tq, Tv]`.
    """

    def __init__(self, dropout=0.0, **kwargs):
        # The deprecated `causal` argument determines whether to use causal
        # masking. Use `use_causal_mask` in the call() method instead.
        if "causal" in kwargs:
            logging.warning(
                "`causal` argument is deprecated. Please use "
                "`use_causal_mask` in call() method to specify causal "
                "masking."
            )
        self.causal = kwargs.pop("causal", False)
        super().__init__(**kwargs)
        self.dropout = dropout
        self.supports_masking = True

    def build(self, input_shape):
        # Skip RNG initialization if dropout rate is 0. This will let the
        # layer be purely stateless, with no reference to any variable.
        if self.dropout > 0:
            super().build(input_shape)
        self.built = True

    def _calculate_scores(self, query, key):
        """Calculates attention scores.

        Args:
            query: Query tensor of shape `[batch_size, Tq, dim]`.
            key: Key tensor of shape `[batch_size, Tv, dim]`.

        Returns:
            Tensor of shape `[batch_size, Tq, Tv]`.
        """
        raise NotImplementedError

    def _apply_scores(self, scores, value, scores_mask=None, training=None):
        """Applies attention scores to the given value tensor.

        To use this method in your attention layer, follow the steps:

        * Use `query` tensor of shape `[batch_size, Tq, dim]` and `key`
          tensor of shape `[batch_size, Tv, dim]` to calculate the attention
          `scores`.
        * Pass `scores` and `value` tensors to this method. The method
          applies `scores_mask`, calculates
          `attention_distribution = softmax(scores)`, then returns
          `matmul(attention_distribution, value)`.
        * Apply `query_mask` and return the result.

        Args:
            scores: Scores float tensor of shape `[batch_size, Tq, Tv]`.
            value: Value tensor of shape `[batch_size, Tv, dim]`.
            scores_mask: A boolean mask `Tensor` of shape `[batch_size, 1, Tv]`
                or `[batch_size, Tq, Tv]`. If given, scores at positions where
                `scores_mask==False` do not contribute to the result. It must
                contain at least one `True` value in each row along the last
                dimension.
            training: Python boolean indicating whether the layer should
                behave in training mode (adding dropout) or in inference mode
                (no dropout).

        Returns:
            Tensor of shape `[batch_size, Tq, dim]`.
            Attention scores after masking and softmax with shape
            `[batch_size, Tq, Tv]`.
        """
        if scores_mask is not None:
            padding_mask = tf.logical_not(scores_mask)
            # Bias so padding positions do not contribute to attention
            # distribution. Note 65504. is the max float16 value.
            if scores.dtype is tf.float16:
                scores -= 65504.0 * tf.cast(padding_mask, dtype=scores.dtype)
            else:
                scores -= 1.0e9 * tf.cast(padding_mask, dtype=scores.dtype)
        if training is None:
            training = backend.learning_phase()
        weights = tf.nn.softmax(scores)

        if self.dropout > 0:

            def dropped_weights():
                return self._random_generator.dropout(
                    weights, rate=self.dropout
                )

            weights = control_flow_util.smart_cond(
                training, dropped_weights, lambda: tf.identity(weights)
            )
        return tf.matmul(weights, value), weights
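    # Worked illustration of the masking above: with scores = [[[1., 2.]]]
    # and scores_mask = [[[True, False]]], the masked position receives a
    # -1e9 bias, softmax produces weights close to [[[1., 0.]]], and the
    # returned product is approximately the first row of `value`.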

    # TODO(b/125916026): Consider exposing a __call__ method with named args.
    def call(
        self,
        inputs,
        mask=None,
        training=None,
        return_attention_scores=False,
        use_causal_mask=False,
    ):
        self._validate_call_args(inputs=inputs, mask=mask)
        q = inputs[0]
        v = inputs[1]
        k = inputs[2] if len(inputs) > 2 else v
        q_mask = mask[0] if mask else None
        v_mask = mask[1] if mask else None
        scores = self._calculate_scores(query=q, key=k)
        if v_mask is not None:
            # Mask of shape [batch_size, 1, Tv].
            v_mask = tf.expand_dims(v_mask, axis=-2)
        if self.causal or use_causal_mask:
            # Creates a lower triangular mask, so position i cannot attend to
            # positions j > i. This prevents the flow of information from the
            # future into the past.
            scores_shape = tf.shape(scores)
            # causal_mask_shape = [1, Tq, Tv].
            causal_mask_shape = tf.concat(
                [tf.ones_like(scores_shape[:-2]), scores_shape[-2:]], axis=0
            )
            causal_mask = _lower_triangular_mask(causal_mask_shape)
        else:
            causal_mask = None
        scores_mask = _merge_masks(v_mask, causal_mask)
        result, attention_scores = self._apply_scores(
            scores=scores, value=v, scores_mask=scores_mask, training=training
        )
        if q_mask is not None:
            # Mask of shape [batch_size, Tq, 1].
            q_mask = tf.expand_dims(q_mask, axis=-1)
            result *= tf.cast(q_mask, dtype=result.dtype)
        if return_attention_scores:
            return result, attention_scores
        return result
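    # Note on mask shapes in call() (illustrative): `v_mask` has shape
    # [batch_size, 1, Tv] and `causal_mask` has shape [1, Tq, Tv], so
    # `_merge_masks` relies on `tf.logical_and` broadcasting both to the
    # full [batch_size, Tq, Tv] score shape.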

    def compute_mask(self, inputs, mask=None):
        self._validate_call_args(inputs=inputs, mask=mask)
        if mask:
            q_mask = mask[0]
            if q_mask is None:
                return None
            return tf.convert_to_tensor(q_mask)
        return None

    def compute_output_shape(self, input_shape):
        # The `return_attention_scores` argument of `BaseDenseAttention.call`
        # is ignored here, so the output shape of `attention_scores` cannot
        # be returned.
        return tf.TensorShape(input_shape[0])

    def _validate_call_args(self, inputs, mask):
        """Validates arguments of the call method."""
        class_name = self.__class__.__name__
        if not isinstance(inputs, list):
            raise ValueError(
                f"{class_name} layer must be called on a list of inputs, "
                "namely [query, value] or [query, value, key]. "
                f"Received: {inputs}."
            )
        if len(inputs) < 2 or len(inputs) > 3:
            raise ValueError(
                f"{class_name} layer accepts inputs list of length 2 or 3, "
                "namely [query, value] or [query, value, key]. "
                f"Received length: {len(inputs)}."
            )
        if mask:
            if not isinstance(mask, list):
                raise ValueError(
                    f"{class_name} layer mask must be a list, "
                    f"namely [query_mask, value_mask]. Received: {mask}."
                )
            if len(mask) < 2 or len(mask) > len(inputs):
                raise ValueError(
                    f"{class_name} layer mask must be a list of length 2, "
                    "namely [query_mask, value_mask]. "
                    f"Received length: {len(mask)}."
                )

    def get_config(self):
        config = {
            "dropout": self.dropout,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


def _lower_triangular_mask(shape):
    """Creates a lower-triangular boolean mask over the last 2 dimensions."""
    row_index = tf.cumsum(tf.ones(shape=shape, dtype=tf.int32), axis=-2)
    col_index = tf.cumsum(tf.ones(shape=shape, dtype=tf.int32), axis=-1)
    return tf.greater_equal(row_index, col_index)
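# Illustration of the cumsum trick above: with shape = [3, 3], row_index is
# [[1, 1, 1], [2, 2, 2], [3, 3, 3]] and col_index is
# [[1, 2, 3], [1, 2, 3], [1, 2, 3]], so `row_index >= col_index` yields the
# lower-triangular mask
# [[True, False, False], [True, True, False], [True, True, True]].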


def _merge_masks(x, y):
    """Merges two boolean masks with logical AND; `None` means "no mask"."""
    if x is None:
        return y
    if y is None:
        return x
    return tf.logical_and(x, y)
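# Minimal usage sketch (illustrative; not part of the original module). A
# hypothetical subclass only needs to implement `_calculate_scores`, e.g. as
# a plain dot product, similar in spirit to `keras.layers.Attention`:
#
#     class DotProductAttention(BaseDenseAttention):
#         def _calculate_scores(self, query, key):
#             # [batch, Tq, dim] x [batch, Tv, dim]^T -> [batch, Tq, Tv]
#             return tf.matmul(query, key, transpose_b=True)
#
#     layer = DotProductAttention(dropout=0.1)
#     query = tf.random.normal([2, 4, 8])  # [batch_size, Tq, dim]
#     value = tf.random.normal([2, 6, 8])  # [batch_size, Tv, dim]
#     output = layer([query, value])       # shape [2, 4, 8]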