Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/text/crf.py: 16%
198 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import warnings
17import numpy as np
18import tensorflow as tf
20from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
21from tensorflow_addons.utils.types import TensorLike
22from typeguard import typechecked
23from typing import Optional, Tuple
25# TODO: Wrap functions in @tf.function once
26# https://github.com/tensorflow/tensorflow/issues/29075 is resolved
def crf_filtered_inputs(inputs: TensorLike, tag_bitmap: TensorLike) -> tf.Tensor:
    """Constrains the inputs to filter out certain tags at each time step.

    tag_bitmap limits the allowed tags at each input time step.
    This is useful when an observed output at a given time step needs to be
    constrained to a selected set of tags.

    Args:
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
      tag_bitmap: A [batch_size, max_seq_len, num_tags] tensor representing
        all active tags at each index. It is cast to bool, so any non-zero
        entry marks the tag as active.
    Returns:
      filtered_inputs: A [batch_size, max_seq_len, num_tags] tensor equal to
        `inputs` where `tag_bitmap` is true and -inf everywhere else.
    """
    # `tf.where` requires a boolean condition; casting lets integer/float
    # bitmaps be used directly (a no-op for bool input).
    tag_bitmap = tf.cast(tag_bitmap, dtype=tf.bool)

    # Set scores of filtered-out inputs to -inf so they contribute nothing to
    # logsumexp and can never win a max.
    filtered_inputs = tf.where(
        tag_bitmap,
        inputs,
        tf.fill(tf.shape(inputs), tf.cast(float("-inf"), inputs.dtype)),
    )
    return filtered_inputs
def crf_sequence_score(
    inputs: TensorLike,
    tag_indices: TensorLike,
    sequence_lengths: TensorLike,
    transition_params: TensorLike,
) -> tf.Tensor:
    """Scores a given tag sequence under the CRF, without normalization.

    Args:
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
      tag_indices: A [batch_size, max_seq_len] matrix holding the tag chosen
        at every time step; this is the sequence being scored.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] transition matrix.
    Returns:
      sequence_scores: A [batch_size] vector with the unnormalized score of
        each example's tag sequence.
    """
    tag_indices = tf.cast(tag_indices, dtype=tf.int32)
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    def _score_first_step_only():
        # One time step means no transitions: the score is the unary
        # potential of the first tag of every example.
        num_examples = tf.shape(inputs, out_type=tf.int32)[0]
        example_idx = tf.reshape(tf.range(num_examples), [-1, 1])
        step_idx = tf.zeros_like(example_idx)
        first_tags = tf.reshape(
            tf.gather_nd(tag_indices, tf.concat([example_idx, step_idx], axis=1)),
            [-1, 1],
        )
        picked = tf.gather_nd(
            inputs, tf.concat([example_idx, step_idx, first_tags], axis=1)
        )
        # Empty sequences score zero.
        is_empty = tf.less_equal(sequence_lengths, 0)
        return tf.where(is_empty, tf.zeros_like(picked), picked)

    def _score_unary_plus_binary():
        unary = crf_unary_score(tag_indices, sequence_lengths, inputs)
        binary = crf_binary_score(tag_indices, sequence_lengths, transition_params)
        return unary + binary

    has_single_step = tf.equal(tf.shape(inputs)[1], 1)
    return tf.cond(has_single_step, _score_first_step_only, _score_unary_plus_binary)
def crf_multitag_sequence_score(
    inputs: TensorLike,
    tag_bitmap: TensorLike,
    sequence_lengths: TensorLike,
    transition_params: TensorLike,
) -> tf.Tensor:
    """Computes the unnormalized score of all tag sequences matching
    tag_bitmap.

    tag_bitmap enables more than one tag to be considered correct at each time
    step. This is useful when an observed output at a given time step is
    consistent with more than one tag, and thus the log likelihood of that
    observation must take into account all possible consistent tags.

    Using one-hot vectors in tag_bitmap gives results identical to
    crf_sequence_score.

    Args:
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
      tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
        representing all active tags at each index for which to calculate the
        unnormalized score.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] transition matrix.
    Returns:
      sequence_scores: A [batch_size] vector of unnormalized sequence scores.
        Sequences with length <= 0 score 0.
    """
    tag_bitmap = tf.cast(tag_bitmap, dtype=tf.bool)
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
    # Disallowed tags become -inf so they drop out of the logsumexp below.
    filtered_inputs = crf_filtered_inputs(inputs, tag_bitmap)

    # If max_seq_len is 1, we skip the score calculation and simply gather the
    # unary potentials of all active tags.
    def _single_seq_fn():
        scores = tf.reduce_logsumexp(filtered_inputs, axis=[1, 2], keepdims=False)
        # Zero out sequences with length <= 0, as crf_sequence_score and
        # crf_log_norm do; otherwise the one-hot equivalence breaks here.
        return tf.where(
            tf.less_equal(sequence_lengths, 0), tf.zeros_like(scores), scores
        )

    def _multi_seq_fn():
        # Compute the logsumexp of all scores of sequences
        # matching the given tags (crf_log_norm already masks empty ones).
        return crf_log_norm(
            inputs=filtered_inputs,
            sequence_lengths=sequence_lengths,
            transition_params=transition_params,
        )

    return tf.cond(tf.equal(tf.shape(inputs)[1], 1), _single_seq_fn, _multi_seq_fn)
def crf_log_norm(
    inputs: TensorLike, sequence_lengths: TensorLike, transition_params: TensorLike
) -> tf.Tensor:
    """Computes the log partition function (normalizer) of a CRF.

    Args:
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] transition matrix.
    Returns:
      log_norm: A [batch_size] vector of normalizers for a CRF; entries for
        sequences with length <= 0 are 0.
    """
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    # The first time step seeds the forward recursion: [batch_size, num_tags].
    first_input = tf.squeeze(tf.slice(inputs, [0, 0, 0], [-1, 1, -1]), [1])

    def _zero_if_empty(log_norm):
        # Sequences with length <= 0 get a normalizer of zero.
        return tf.where(
            tf.less_equal(sequence_lengths, 0), tf.zeros_like(log_norm), log_norm
        )

    def _single_step():
        # Only one step: logsumexp over the initial unary potentials.
        return _zero_if_empty(tf.reduce_logsumexp(first_input, [1]))

    def _many_steps():
        # Run the forward algorithm over the remaining steps, then reduce the
        # final alphas to the partition function.
        remaining = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])
        final_alphas = crf_forward(
            remaining, first_input, transition_params, sequence_lengths
        )
        return _zero_if_empty(tf.reduce_logsumexp(final_alphas, [1]))

    return tf.cond(tf.equal(tf.shape(inputs)[1], 1), _single_step, _many_steps)
def crf_log_likelihood(
    inputs: TensorLike,
    tag_indices: TensorLike,
    sequence_lengths: TensorLike,
    transition_params: Optional[TensorLike] = None,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Computes the log-likelihood of tag sequences in a CRF.

    Args:
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which
        we compute the log-likelihood.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] transition matrix,
        if available. When None, a new `tf.Variable` named "transitions" is
        created from a Glorot-uniform initializer; that requires
        `inputs.shape[2]` (num_tags) to be statically known.
    Returns:
      log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
        each example, given the sequence of tag indices.
      transition_params: A [num_tags, num_tags] transition matrix, cast to
        `inputs.dtype`. This is either provided by the caller or created in
        this function.
    """
    inputs = tf.convert_to_tensor(inputs)

    num_tags = inputs.shape[2]

    # cast type to handle different types
    tag_indices = tf.cast(tag_indices, dtype=tf.int32)
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    # TODO(windqaq): re-evaluate if `transition_params` can be `None`.
    if transition_params is None:
        initializer = tf.keras.initializers.GlorotUniform()
        # `name` must be a keyword: the second positional parameter of
        # `tf.Variable` is `trainable`, so a positional string would set
        # that flag instead of naming the variable.
        transition_params = tf.Variable(
            initializer([num_tags, num_tags]), name="transitions"
        )
    transition_params = tf.cast(transition_params, inputs.dtype)
    sequence_scores = crf_sequence_score(
        inputs, tag_indices, sequence_lengths, transition_params
    )
    log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)

    # Normalize the scores to get the log-likelihood per example.
    log_likelihood = sequence_scores - log_norm
    return log_likelihood, transition_params
def crf_unary_score(
    tag_indices: TensorLike, sequence_lengths: TensorLike, inputs: TensorLike
) -> tf.Tensor:
    """Computes the unary scores of tag sequences.

    Args:
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.
    Returns:
      unary_scores: A [batch_size] vector of unary scores: for each example,
        the sum of inputs[b, t, tag_indices[b, t]] over t < sequence_lengths[b].
    """
    tag_indices = tf.cast(tag_indices, dtype=tf.int32)
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    batch_size = tf.shape(inputs)[0]
    max_seq_len = tf.shape(inputs)[1]
    num_tags = tf.shape(inputs)[2]

    flattened_inputs = tf.reshape(inputs, [-1])

    # offsets[b, t] is the flat position of inputs[b, t, 0]; adding the tag
    # index selects inputs[b, t, tag_indices[b, t]]. Both operands are int32
    # (tag_indices was cast above), so no dtype reconciliation is needed.
    offsets = tf.expand_dims(tf.range(batch_size) * max_seq_len * num_tags, 1)
    offsets += tf.expand_dims(tf.range(max_seq_len) * num_tags, 0)
    flattened_tag_indices = tf.reshape(offsets + tag_indices, [-1])

    unary_scores = tf.reshape(
        tf.gather(flattened_inputs, flattened_tag_indices), [batch_size, max_seq_len]
    )

    # Ignore positions past each sequence's true length.
    masks = tf.sequence_mask(
        sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=unary_scores.dtype
    )

    unary_scores = tf.reduce_sum(unary_scores * masks, 1)
    return unary_scores
def crf_binary_score(
    tag_indices: TensorLike, sequence_lengths: TensorLike, transition_params: TensorLike
) -> tf.Tensor:
    """Sums the transition (binary) potentials along each tag sequence.

    Args:
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] matrix of binary potentials.
    Returns:
      binary_scores: A [batch_size] vector with, per example, the sum of
        transition_params[tag[t-1], tag[t]] over 1 <= t < sequence_length.
    """
    tag_indices = tf.cast(tag_indices, dtype=tf.int32)
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    max_seq_len = tf.shape(tag_indices)[1]
    n_tags = tf.shape(transition_params)[0]
    n_transitions = max_seq_len - 1

    # Source tags are positions 0..T-2, destination tags are positions 1..T-1.
    from_tags = tf.slice(tag_indices, [0, 0], [-1, n_transitions])
    to_tags = tf.slice(tag_indices, [0, 1], [-1, n_transitions])

    # Row-major lookup into the flattened transition matrix.
    flat_transitions = tf.reshape(transition_params, [-1])
    pair_scores = tf.gather(flat_transitions, from_tags * n_tags + to_tags)

    # A transition into position t counts only when t < sequence_length.
    valid = tf.sequence_mask(
        sequence_lengths, maxlen=max_seq_len, dtype=pair_scores.dtype
    )
    valid = tf.slice(valid, [0, 1], [-1, -1])
    return tf.reduce_sum(pair_scores * valid, 1)
def crf_forward(
    inputs: TensorLike,
    state: TensorLike,
    transition_params: TensorLike,
    sequence_lengths: TensorLike,
) -> tf.Tensor:
    """Computes the alpha values in a linear-chain CRF.

    See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.

    Args:
      inputs: A [batch_size, num_steps, num_tags] tensor of unary potentials
        for the time steps after the first one (`crf_log_norm` passes
        `inputs[:, 1:, :]`).
      state: A [batch_size, num_tags] matrix with the alpha values of the
        first time step; it seeds the recursion.
      transition_params: A [num_tags, num_tags] matrix of binary potentials.
        This matrix is expanded into a [1, num_tags, num_tags] in preparation
        for the broadcast summation occurring within the recursion.
      sequence_lengths: A [batch_size] vector of true sequence lengths,
        counting the first time step.

    Returns:
      new_alphas: A [batch_size, num_tags] matrix containing the alpha values
        at each sequence's last valid time step (`state` itself for sequences
        of length <= 1).
    """
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    # Index into all_alphas (which has `state` prepended at position 0) of the
    # last valid step; clamped so empty sequences read `state`.
    last_index = tf.maximum(
        tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 1
    )
    # tf.scan iterates over the leading axis, so make time the first axis.
    inputs = tf.transpose(inputs, [1, 0, 2])
    transition_params = tf.expand_dims(transition_params, 0)

    def _scan_fn(_state, _inputs):
        # [batch, num_tags, 1] + [1, num_tags, num_tags]: entry [b, i, j] is
        # alpha_prev[b, i] + transition[i, j]; reduce over previous tag i.
        _state = tf.expand_dims(_state, 2)
        transition_scores = _state + transition_params
        new_alphas = _inputs + tf.reduce_logsumexp(transition_scores, [1])
        return new_alphas

    all_alphas = tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2])
    # add first state for sequences of length 1
    all_alphas = tf.concat([tf.expand_dims(state, 1), all_alphas], 1)

    idxs = tf.stack([tf.range(tf.shape(last_index)[0]), last_index], axis=1)
    return tf.gather_nd(all_alphas, idxs)
def viterbi_decode(
    score: TensorLike, transition_params: TensorLike
) -> Tuple[list, float]:
    """Decode the highest scoring sequence of tags outside of TensorFlow.

    This should only be used at test time.

    Args:
      score: A [seq_len, num_tags] matrix of unary potentials. Any array-like
        accepted by `np.asarray` (NumPy array, eager tensor, nested lists).
      transition_params: A [num_tags, num_tags] matrix of binary potentials.

    Returns:
      viterbi: A [seq_len] list of integers containing the highest scoring tag
        indices.
      viterbi_score: A float containing the score for the Viterbi sequence.
    """
    # Normalize array-likes so `.shape` and fancy indexing work; this does not
    # copy or change dtype for inputs that are already NumPy arrays.
    score = np.asarray(score)
    transition_params = np.asarray(transition_params)

    trellis = np.zeros_like(score)
    backpointers = np.zeros_like(score, dtype=np.int32)
    trellis[0] = score[0]

    for t in range(1, score.shape[0]):
        # v[i, j]: best score ending in tag i at t-1, then moving to tag j.
        v = np.expand_dims(trellis[t - 1], 1) + transition_params
        trellis[t] = score[t] + np.max(v, 0)
        backpointers[t] = np.argmax(v, 0)

    # Follow backpointers from the best final tag to recover the path.
    viterbi = [np.argmax(trellis[-1])]
    for bp in reversed(backpointers[1:]):
        viterbi.append(bp[viterbi[-1]])
    viterbi.reverse()

    viterbi_score = np.max(trellis[-1])
    return viterbi, viterbi_score
class CrfDecodeForwardRnnCell(AbstractRNNCell):
    """Computes the forward decoding in a linear-chain CRF.

    Each step takes the previous best score per tag and emits the backpointer
    (best previous tag) for every current tag together with the new scores.
    """

    @typechecked
    def __init__(self, transition_params: TensorLike, **kwargs):
        """Initialize the CrfDecodeForwardRnnCell.

        Args:
          transition_params: A [num_tags, num_tags] matrix of binary
            potentials. This matrix is expanded into a
            [1, num_tags, num_tags] in preparation for the broadcast
            summation occurring within the cell.
          **kwargs: Forwarded to the base cell constructor (e.g. `dtype`).
        """
        super().__init__(**kwargs)
        self._transition_params = tf.expand_dims(transition_params, 0)
        # Read num_tags from the expanded tensor so array-likes without a
        # `.shape` attribute (e.g. nested lists) are accepted too.
        self._num_tags = self._transition_params.shape[1]

    @property
    def state_size(self):
        return self._num_tags

    @property
    def output_size(self):
        return self._num_tags

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs, state):
        """Compute one forward (max-product) decoding step.

        Args:
          inputs: A [batch_size, num_tags] matrix of unary potentials.
          state: A sequence whose first element is a [batch_size, num_tags]
            matrix containing the previous step's score values.

        Returns:
          backpointers: A [batch_size, num_tags] int32 matrix of backpointers.
          new_state: A [batch_size, num_tags] matrix of new score values.
        """
        # [batch, num_tags, 1] + [1, num_tags, num_tags]: entry [b, i, j] is
        # score_prev[b, i] + transition[i, j]; axis 1 is the previous tag.
        prev_scores = tf.expand_dims(state[0], 2)
        transition_scores = prev_scores + tf.cast(
            self._transition_params, self._compute_dtype
        )
        new_state = inputs + tf.reduce_max(transition_scores, [1])
        backpointers = tf.cast(tf.argmax(transition_scores, 1), dtype=tf.int32)
        return backpointers, new_state

    def get_config(self) -> dict:
        config = {
            "transition_params": tf.squeeze(self._transition_params, 0).numpy().tolist()
        }
        base_config = super().get_config()
        # Cell-specific keys override base keys, as before.
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config: dict) -> "CrfDecodeForwardRnnCell":
        config["transition_params"] = np.array(
            config["transition_params"], dtype=np.float32
        )
        return cls(**config)
def crf_decode_forward(
    inputs: TensorLike,
    state: TensorLike,
    transition_params: TensorLike,
    sequence_lengths: TensorLike,
) -> tf.Tensor:
    """Computes forward decoding in a linear-chain CRF.

    Args:
      inputs: A [batch_size, num_steps, num_tags] tensor of unary potentials
        (in `crf_decode` these are every time step after the first).
      state: A [batch_size, num_tags] matrix of initial score values.
      transition_params: A [num_tags, num_tags] matrix of binary potentials.
      sequence_lengths: A [batch_size] vector with the number of valid steps
        of each example in `inputs`; used to build the RNN mask.

    Returns:
      backpointers: A [batch_size, num_steps, num_tags] tensor of
        backpointers (zeros at masked steps, see `zero_output_for_mask`).
      last_score: A [batch_size, num_tags] matrix, the RNN's final state.
    """
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
    mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1])
    crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params, dtype=inputs.dtype)
    crf_fwd_layer = tf.keras.layers.RNN(
        crf_fwd_cell,
        return_sequences=True,
        return_state=True,
        dtype=inputs.dtype,
        # Emit zeros (not the previous output) at masked steps.
        zero_output_for_mask=True,  # See: https://github.com/tensorflow/addons/issues/2639
    )
    return crf_fwd_layer(inputs, state, mask=mask)
def crf_decode_backward(inputs: TensorLike, state: TensorLike) -> tf.Tensor:
    """Computes backward decoding in a linear-chain CRF.

    Args:
      inputs: A [batch_size, num_steps, num_tags] tensor of backpointers,
        ordered in the direction the walk should proceed (`crf_decode`
        reverses them before calling this).
      state: A [batch_size, 1] int32 matrix with the tag index to start from.

    Returns:
      new_tags: A [batch_size, num_steps, 1] tensor with the tag index
        selected at each step of the walk.
    """
    # tf.scan iterates over the leading axis: [num_steps, batch_size, num_tags].
    time_major = tf.transpose(inputs, [1, 0, 2])

    def _scan_fn(prev_tags, step_backpointers):
        # Look up, for each example, the backpointer of the currently
        # selected tag; that is the tag at the previous position.
        prev_tags = tf.squeeze(prev_tags, axis=[1])
        idxs = tf.stack(
            [tf.range(tf.shape(step_backpointers)[0]), prev_tags], axis=1
        )
        new_tags = tf.expand_dims(tf.gather_nd(step_backpointers, idxs), axis=-1)
        return new_tags

    return tf.transpose(tf.scan(_scan_fn, time_major, state), [1, 0, 2])
def crf_decode(
    potentials: TensorLike, transition_params: TensorLike, sequence_length: TensorLike
) -> tf.Tensor:
    """Decode the highest scoring sequence of tags.

    Args:
      potentials: A [batch_size, max_seq_len, num_tags] tensor of
        unary potentials.
      transition_params: A [num_tags, num_tags] matrix of
        binary potentials.
      sequence_length: A [batch_size] vector of true sequence lengths.

    Returns:
      decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
        Contains the highest scoring tag indices.
      best_score: A [batch_size] vector, containing the score of `decode_tags`.
    """
    if tf.__version__[:3] == "2.4":
        warnings.warn(
            "CRF Decoding does not work with KerasTensors in TF2.4. "
            "The bug has since been fixed in tensorflow/tensorflow#45534",
            stacklevel=2,
        )

    sequence_length = tf.cast(sequence_length, dtype=tf.int32)

    # If max_seq_len is 1, we skip the algorithm and simply return the
    # argmax tag and the max activation.
    def _single_seq_fn():
        decode_tags = tf.cast(tf.argmax(potentials, axis=2), dtype=tf.int32)
        best_score = tf.reshape(tf.reduce_max(potentials, axis=2), shape=[-1])
        return decode_tags, best_score

    def _multi_seq_fn():
        # Computes forward decoding. Get last score and backpointers.
        initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1])
        initial_state = tf.squeeze(initial_state, axis=[1])
        inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1])

        # The forward pass runs over steps 1..len-1 only.
        sequence_length_less_one = tf.maximum(
            tf.constant(0, dtype=tf.int32), sequence_length - 1
        )

        backpointers, last_score = crf_decode_forward(
            inputs, initial_state, transition_params, sequence_length_less_one
        )

        # Reverse only each row's valid prefix so the backward walk starts at
        # that sequence's true last step.
        backpointers = tf.reverse_sequence(
            backpointers, sequence_length_less_one, seq_axis=1
        )

        # The best final tag seeds the backward walk.
        initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32)
        initial_state = tf.expand_dims(initial_state, axis=-1)

        decode_tags = crf_decode_backward(backpointers, initial_state)
        decode_tags = tf.squeeze(decode_tags, axis=[2])
        decode_tags = tf.concat([initial_state, decode_tags], axis=1)
        # Tags were collected last-to-first; restore forward order.
        decode_tags = tf.reverse_sequence(decode_tags, sequence_length, seq_axis=1)

        best_score = tf.reduce_max(last_score, axis=1)
        return decode_tags, best_score

    if potentials.shape[1] is not None:
        # The shape is statically known, so we just execute
        # the appropriate code path instead of building a tf.cond.
        if potentials.shape[1] == 1:
            return _single_seq_fn()
        return _multi_seq_fn()
    return tf.cond(
        tf.equal(tf.shape(potentials)[1], 1), _single_seq_fn, _multi_seq_fn
    )
def crf_constrained_decode(
    potentials: TensorLike,
    tag_bitmap: TensorLike,
    transition_params: TensorLike,
    sequence_length: TensorLike,
) -> tf.Tensor:
    """Decode the highest scoring tag sequence, restricted to allowed tags.

    Potentials of tags not marked in `tag_bitmap` are replaced with -inf
    (via `crf_filtered_inputs`) before running `crf_decode`, so the decoder
    cannot prefer them.

    Args:
      potentials: A [batch_size, max_seq_len, num_tags] tensor of
        unary potentials.
      tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
        marking which tags are allowed at each position.
      transition_params: A [num_tags, num_tags] matrix of
        binary potentials.
      sequence_length: A [batch_size] vector of true sequence lengths.
    Returns:
      decode_tags: A [batch_size, max_seq_len] `tf.int32` matrix with the
        highest scoring tag indices.
      best_score: A [batch_size] vector with the score of `decode_tags`.
    """
    masked_potentials = crf_filtered_inputs(potentials, tag_bitmap)
    return crf_decode(masked_potentials, transition_params, sequence_length)