Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/crf.py: 25%
93 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15# Original implementation from keras_contrib/layers/crf
16# ==============================================================================
17"""Implementing Conditional Random Field layer."""
19import tensorflow as tf
20from typeguard import typechecked
22from tensorflow_addons.text.crf import crf_decode
23from tensorflow_addons.utils import types
@tf.keras.utils.register_keras_serializable(package="Addons")
class CRF(tf.keras.layers.Layer):
    """Linear chain conditional random field (CRF).

    Inherits from: `tf.keras.layers.Layer`.

    References:
        - [Conditional Random Field](https://en.wikipedia.org/wiki/Conditional_random_field)

    Example:

    >>> layer = tfa.layers.CRF(4)
    >>> inputs = np.random.rand(2, 4, 8).astype(np.float32)
    >>> decoded_sequence, potentials, sequence_length, chain_kernel = layer(inputs)
    >>> decoded_sequence.shape
    TensorShape([2, 4])
    >>> potentials.shape
    TensorShape([2, 4, 4])
    >>> sequence_length
    <tf.Tensor: shape=(2,), dtype=int64, numpy=array([4, 4])>
    >>> chain_kernel.shape
    TensorShape([4, 4])

    Args:
        units: Positive integer, the number of tags. It is the size of the
            last dimension of the potentials and of both dimensions of the
            tag-transition matrix `chain_kernel`.
        chain_initializer: Initializer for the `(units, units)`
            tag-transition matrix `chain_kernel`. Default to `orthogonal`.
        use_boundary: `Boolean`, whether the layer learns start/end
            (boundary) energy vectors that are added to the first/last valid
            step of each sequence. Default to `True`.
        boundary_initializer: Initializer for the `left_boundary` and
            `right_boundary` vectors. Default to `zeros`.
        use_kernel: `Boolean`, whether inputs are first projected to `units`
            features by a `Dense` layer. When `False`, inputs are only cast to
            the layer dtype, so their last dimension must already be `units`.
            Default to `True`.
    Call Args:
        inputs: A `Tensor` of shape `[batch_size, sequence_length, features]`.
        mask: A boolean `Tensor` of shape `[batch_size, sequence_length]`
            or `None`. Default to `None`. Only right padding is supported.
    Call Returns:
        A list `[decoded_sequence, potentials, sequence_length, chain_kernel]`:
        the Viterbi-decoded tags `[batch_size, sequence_length]` (int32),
        the unary potentials `[batch_size, sequence_length, units]`,
        the per-example lengths `[batch_size]` (int64), and the
        `[units, units]` transition matrix.
    Raises:
        ValueError: If input mask doesn't have dim 2 or None.
        NotImplementedError: If left padding is provided (only detected
            when executing eagerly).
    """

    @typechecked
    def __init__(
        self,
        units: int,
        chain_initializer: types.Initializer = "orthogonal",
        use_boundary: bool = True,
        boundary_initializer: types.Initializer = "zeros",
        use_kernel: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # The base class's __init__ resets `supports_masking` to False, so
        # this flag must be set after calling it.
        self.supports_masking = True

        self.units = units  # number of tags

        self.use_boundary = use_boundary
        self.use_kernel = use_kernel
        self.chain_initializer = tf.keras.initializers.get(chain_initializer)
        self.boundary_initializer = tf.keras.initializers.get(boundary_initializer)

        # Tag-to-tag transition weights, passed to `crf_decode` as the
        # transition parameters.
        self.chain_kernel = self.add_weight(
            shape=(self.units, self.units),
            name="chain_kernel",
            initializer=self.chain_initializer,
        )

        # Weights for <START> -> tag (left) and tag -> <END> (right) energies.
        if self.use_boundary:
            self.left_boundary = self.add_weight(
                shape=(self.units,),
                name="left_boundary",
                initializer=self.boundary_initializer,
            )
            self.right_boundary = self.add_weight(
                shape=(self.units,),
                name="right_boundary",
                initializer=self.boundary_initializer,
            )

        if self.use_kernel:
            self._dense_layer = tf.keras.layers.Dense(
                units=self.units, dtype=self.dtype
            )
        else:
            # No projection: inputs are used directly as potentials.
            self._dense_layer = lambda x: tf.cast(x, dtype=self.dtype)

    def call(self, inputs, mask=None):
        """Compute potentials and Viterbi-decode them.

        See the class docstring for the argument and return contract.
        """
        # mask: Tensor(shape=(batch_size, sequence_length), dtype=bool) or None
        if mask is not None:
            if tf.keras.backend.ndim(mask) != 2:
                raise ValueError("Input mask to CRF must have dim 2 if not None")

            # Left padding is not supported by the underlying CRF decoding,
            # which only understands right padding via `sequence_length`.
            # The check needs a concrete value, so it only runs eagerly.
            if tf.executing_eagerly():
                left_boundary_mask = self._compute_mask_left_boundary(mask)
                # True where a sequence's first valid step is position 0.
                first_mask = left_boundary_mask[:, 0]
                if not tf.math.reduce_all(first_mask):
                    raise NotImplementedError(
                        "Currently, CRF layer do not support left padding"
                    )

        potentials = self._dense_layer(inputs)

        # Add learned start/end energies at each sequence's boundaries.
        if self.use_boundary:
            potentials = self.add_boundary_energy(
                potentials, mask, self.left_boundary, self.right_boundary
            )

        sequence_length = self._get_sequence_length(inputs, mask)

        decoded_sequence, _ = self.get_viterbi_decoding(potentials, sequence_length)

        return [decoded_sequence, potentials, sequence_length, self.chain_kernel]

    def _get_sequence_length(self, input_, mask):
        """Compute per-example sequence lengths from `mask` or `input_`.

        The underlying CRF function (`tensorflow_addons.text.crf`) does not
        support masking on both sides (left and right padding). It supports
        right padding only, which is expressed by passing each example's
        sequence length.

        If `mask` is None, every step along the first two dimensions of
        `input_` is treated as valid.
        """
        if mask is not None:
            sequence_length = self.mask_to_sequence_length(mask)
        else:
            # Build an all-ones [batch_size, sequence_length] mask from the
            # input's leading two dimensions.
            input_energy_shape = tf.shape(input_)
            raw_input_shape = tf.slice(input_energy_shape, [0], [2])
            alt_mask = tf.ones(raw_input_shape)

            sequence_length = self.mask_to_sequence_length(alt_mask)

        return sequence_length

    def mask_to_sequence_length(self, mask):
        """Compute sequence length from mask.

        Counts the non-zero entries of each row along axis 1 and returns an
        int64 tensor. For a right-padded mask this is the sequence length.
        """
        sequence_length = tf.reduce_sum(tf.cast(mask, tf.int64), 1)
        return sequence_length

    @staticmethod
    def _compute_mask_right_boundary(mask):
        """input mask: 0011100, output right_boundary: 0000100.

        Operates on axis 1. Any trailing dimensions are carried along, so a
        `[batch, seq, 1]` mask also works. Returns a boolean tensor.
        """
        # shift mask to left by 1: 0011100 => 0111000
        offset = 1
        left_shifted_mask = tf.concat(
            [mask[:, offset:], tf.zeros_like(mask[:, :offset])], axis=1
        )

        # NOTE: below code is different from keras_contrib
        # Original code in keras_contrib:
        # end_mask = K.cast(
        #   K.greater(self.shift_left(mask), mask),
        #   K.floatx()
        # )
        # has a bug, confirmed
        # by the original keras_contrib maintainer
        # Luiz Felix (github: lzfelix),

        # 0011100 > 0111000 => 0000100
        right_boundary = tf.math.greater(
            tf.cast(mask, tf.int32), tf.cast(left_shifted_mask, tf.int32)
        )

        return right_boundary

    @staticmethod
    def _compute_mask_left_boundary(mask):
        """input mask: 0011100, output left_boundary: 0010000.

        Operates on axis 1. Any trailing dimensions are carried along, so a
        `[batch, seq, 1]` mask also works. Returns a boolean tensor.
        """
        # shift mask to right by 1: 0011100 => 0001110
        offset = 1
        right_shifted_mask = tf.concat(
            [tf.zeros_like(mask[:, :offset]), mask[:, :-offset]], axis=1
        )

        # 0011100 > 0001110 => 0010000
        left_boundary = tf.math.greater(
            tf.cast(mask, tf.int32), tf.cast(right_shifted_mask, tf.int32)
        )

        return left_boundary

    def add_boundary_energy(self, potentials, mask, start, end):
        """Add `start`/`end` energy vectors at each sequence's boundaries.

        Args:
            potentials: `[batch_size, sequence_length, units]` potentials.
            mask: `[batch_size, sequence_length]` mask, or None.
            start: `[units]` energy added at the first valid step.
            end: `[units]` energy added at the last valid step.

        Returns:
            The potentials with boundary energies added.

        Without a mask, the boundaries are position 0 and the final
        position. With a mask, they are each row's left and right mask
        boundaries.
        """

        def expand_vector_to_3d(x):
            # reshape a (units,) vector to (1, 1, units) for broadcasting
            return tf.reshape(x, (1, 1, -1))

        start = tf.cast(expand_vector_to_3d(start), potentials.dtype)
        end = tf.cast(expand_vector_to_3d(end), potentials.dtype)
        if mask is None:
            potentials = tf.concat(
                [potentials[:, :1, :] + start, potentials[:, 1:, :]], axis=1
            )
            potentials = tf.concat(
                [potentials[:, :-1, :], potentials[:, -1:, :] + end], axis=1
            )
        else:
            # [batch, seq] -> [batch, seq, 1] so it broadcasts over `units`.
            mask = tf.keras.backend.expand_dims(tf.cast(mask, start.dtype), axis=-1)
            start_mask = tf.cast(self._compute_mask_left_boundary(mask), start.dtype)

            end_mask = tf.cast(self._compute_mask_right_boundary(mask), end.dtype)
            potentials = potentials + start_mask * start
            potentials = potentials + end_mask * end
        return potentials

    def get_viterbi_decoding(self, potentials, sequence_length):
        """Run `crf_decode` with `chain_kernel` as transition parameters.

        Returns:
            A `(decode_tags, best_score)` tuple.
        """
        # decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`
        decode_tags, best_score = crf_decode(
            potentials, self.chain_kernel, sequence_length
        )

        return decode_tags, best_score

    def get_config(self):
        """Return the layer config, used when loading a model from disk."""
        config = {
            "units": self.units,
            "chain_initializer": tf.keras.initializers.serialize(
                self.chain_initializer
            ),
            "use_boundary": self.use_boundary,
            "boundary_initializer": tf.keras.initializers.serialize(
                self.boundary_initializer
            ),
            "use_kernel": self.use_kernel,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        # NOTE(review): this only describes the first output
        # (decoded_sequence, [batch_size, sequence_length]), although
        # `call` returns four tensors.
        output_shape = input_shape[:2]
        return output_shape

    def compute_mask(self, input_, mask=None):
        """keep mask shape [batch_size, max_seq_len]"""
        # The incoming mask is passed through unchanged.
        return mask

    @property
    def _compute_dtype(self):
        # Fixed to int32, the dtype of the tags produced by the underlying
        # CRF decoding function.
        return tf.int32