Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/text/crf.py: 16%

198 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15import warnings 

16 

17import numpy as np 

18import tensorflow as tf 

19 

20from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell 

21from tensorflow_addons.utils.types import TensorLike 

22from typeguard import typechecked 

23from typing import Optional, Tuple 

24 

25# TODO: Wrap functions in @tf.function once 

26# https://github.com/tensorflow/tensorflow/issues/29075 is resolved 

27 

28 

def crf_filtered_inputs(inputs: TensorLike, tag_bitmap: TensorLike) -> tf.Tensor:
    """Constrains the inputs to filter out certain tags at each time step.

    tag_bitmap limits the allowed tags at each input time step.
    This is useful when an observed output at a given time step needs to be
    constrained to a selected set of tags.

    Args:
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
          to use as input to the CRF layer.
      tag_bitmap: A [batch_size, max_seq_len, num_tags] tensor representing
          all active tags at each index. It is cast to `tf.bool`, so any
          non-zero entry marks an allowed tag.
    Returns:
      filtered_inputs: A tensor shaped like `inputs`, with the same dtype,
          where every entry whose `tag_bitmap` value is False has been
          replaced by -inf.
    """
    # `tf.where` requires a boolean condition; some callers (e.g.
    # `crf_constrained_decode`) forward the bitmap without casting it.
    tag_bitmap = tf.cast(tag_bitmap, dtype=tf.bool)

    # Set scores of filtered-out tags to -inf so they contribute exp(-inf) = 0
    # to any logsumexp and lose every max/argmax against an allowed tag.
    filtered_inputs = tf.where(
        tag_bitmap,
        inputs,
        tf.fill(tf.shape(inputs), tf.cast(float("-inf"), inputs.dtype)),
    )
    return filtered_inputs

53 

54 

def crf_sequence_score(
    inputs: TensorLike,
    tag_indices: TensorLike,
    sequence_lengths: TensorLike,
    transition_params: TensorLike,
) -> tf.Tensor:
    """Scores a given tag sequence under a CRF, without normalization.

    Args:
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
          to use as input to the CRF layer.
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices whose
          unnormalized score is computed.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] transition matrix.
    Returns:
      sequence_scores: A [batch_size] vector of unnormalized sequence scores.
    """
    tag_indices = tf.cast(tag_indices, dtype=tf.int32)
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    def _single_seq_fn():
        # max_seq_len == 1: there are no transitions, so the score is just the
        # unary potential of the tag chosen at position 0.
        n_batch = tf.shape(inputs, out_type=tf.int32)[0]
        row_ids = tf.reshape(tf.range(n_batch), [-1, 1])
        first_pos = tf.concat([row_ids, tf.zeros_like(row_ids)], axis=1)
        first_tags = tf.reshape(tf.gather_nd(tag_indices, first_pos), [-1, 1])
        gathered = tf.gather_nd(inputs, tf.concat([first_pos, first_tags], axis=1))
        # Sequences of length <= 0 score zero.
        is_empty = tf.less_equal(sequence_lengths, 0)
        return tf.where(is_empty, tf.zeros_like(gathered), gathered)

    def _multi_seq_fn():
        # General case: per-position unary term plus per-transition term.
        unary = crf_unary_score(tag_indices, sequence_lengths, inputs)
        binary = crf_binary_score(tag_indices, sequence_lengths, transition_params)
        return unary + binary

    has_single_step = tf.equal(tf.shape(inputs)[1], 1)
    return tf.cond(has_single_step, _single_seq_fn, _multi_seq_fn)

106 

107 

def crf_multitag_sequence_score(
    inputs: TensorLike,
    tag_bitmap: TensorLike,
    sequence_lengths: TensorLike,
    transition_params: TensorLike,
) -> tf.Tensor:
    """Computes the unnormalized score of all tag sequences matching
    tag_bitmap.

    tag_bitmap enables more than one tag to be considered correct at each time
    step. This is useful when an observed output at a given time step is
    consistent with more than one tag, and thus the log likelihood of that
    observation must take into account all possible consistent tags.

    Using one-hot vectors in tag_bitmap gives results identical to
    crf_sequence_score.

    Args:
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
          to use as input to the CRF layer.
      tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
          representing all active tags at each index for which to calculate the
          unnormalized score.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] transition matrix.
    Returns:
      sequence_scores: A [batch_size] vector of unnormalized sequence scores.
          Sequences with length <= 0 score 0.
    """
    tag_bitmap = tf.cast(tag_bitmap, dtype=tf.bool)
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
    # Disallowed tags become -inf so they drop out of every logsumexp below.
    filtered_inputs = crf_filtered_inputs(inputs, tag_bitmap)

    # If max_seq_len is 1, we skip the score calculation and simply take the
    # logsumexp of the unary potentials of all active tags.
    def _single_seq_fn():
        scores = tf.reduce_logsumexp(filtered_inputs, axis=[1, 2], keepdims=False)
        # Mask sequences with length <= 0 to zero, exactly as the multi-step
        # path (via crf_log_norm) and crf_sequence_score do; otherwise the
        # one-hot identity with crf_sequence_score breaks for empty sequences.
        return tf.where(
            tf.less_equal(sequence_lengths, 0), tf.zeros_like(scores), scores
        )

    def _multi_seq_fn():
        # Compute the logsumexp of all scores of sequences
        # matching the given tags.
        return crf_log_norm(
            inputs=filtered_inputs,
            sequence_lengths=sequence_lengths,
            transition_params=transition_params,
        )

    return tf.cond(tf.equal(tf.shape(inputs)[1], 1), _single_seq_fn, _multi_seq_fn)

155 

156 

def crf_log_norm(
    inputs: TensorLike, sequence_lengths: TensorLike, transition_params: TensorLike
) -> tf.Tensor:
    """Computes the log partition function (normalizer) of a CRF.

    Args:
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
          to use as input to the CRF layer.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] transition matrix.
    Returns:
      log_norm: A [batch_size] vector of normalizers for a CRF; entries for
          sequences with length <= 0 are 0.
    """
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    # Potentials at position 0 seed the forward recursion.
    first_input = tf.squeeze(tf.slice(inputs, [0, 0, 0], [-1, 1, -1]), [1])
    is_empty = tf.less_equal(sequence_lengths, 0)

    def _mask_empty(log_norm):
        # Zero out the normalizer of sequences with length <= 0.
        return tf.where(is_empty, tf.zeros_like(log_norm), log_norm)

    def _single_seq_fn():
        # max_seq_len == 1: the partition function is a logsumexp over the
        # unary potentials of the only step.
        return _mask_empty(tf.reduce_logsumexp(first_input, [1]))

    def _multi_seq_fn():
        # Run the forward algorithm over positions 1.. and reduce the final
        # alphas to obtain the partition function.
        remaining = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])
        final_alphas = crf_forward(
            remaining, first_input, transition_params, sequence_lengths
        )
        return _mask_empty(tf.reduce_logsumexp(final_alphas, [1]))

    return tf.cond(tf.equal(tf.shape(inputs)[1], 1), _single_seq_fn, _multi_seq_fn)

203 

204 

def crf_log_likelihood(
    inputs: TensorLike,
    tag_indices: TensorLike,
    sequence_lengths: TensorLike,
    transition_params: Optional[TensorLike] = None,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Computes the log-likelihood of tag sequences in a CRF.

    Args:
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
          to use as input to the CRF layer.
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which
          we compute the log-likelihood.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] transition matrix,
          if available. When None, a Glorot-uniform initialized `tf.Variable`
          named "transitions" is created.
    Returns:
      log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
          each example, given the sequence of tag indices.
      transition_params: A [num_tags, num_tags] transition matrix (cast to
          the dtype of `inputs`). This is either provided by the caller or
          created in this function.
    """
    inputs = tf.convert_to_tensor(inputs)

    # Static tag count; only used when a transition matrix must be created.
    num_tags = inputs.shape[2]

    # cast type to handle different types
    tag_indices = tf.cast(tag_indices, dtype=tf.int32)
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    # TODO(windqaq): re-evaluate if `transition_params` can be `None`.
    if transition_params is None:
        initializer = tf.keras.initializers.GlorotUniform()
        # Pass the name by keyword: the second positional parameter of
        # `tf.Variable` is `trainable`, not `name`.
        transition_params = tf.Variable(
            initializer([num_tags, num_tags]), name="transitions"
        )
    transition_params = tf.cast(transition_params, inputs.dtype)
    sequence_scores = crf_sequence_score(
        inputs, tag_indices, sequence_lengths, transition_params
    )
    log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)

    # Normalize the scores to get the log-likelihood per example.
    log_likelihood = sequence_scores - log_norm
    return log_likelihood, transition_params

250 

251 

def crf_unary_score(
    tag_indices: TensorLike, sequence_lengths: TensorLike, inputs: TensorLike
) -> tf.Tensor:
    """Computes the unary scores of tag sequences.

    Args:
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.
    Returns:
      unary_scores: A [batch_size] vector of unary scores, summing only the
          first `sequence_lengths[b]` positions of each row.
    """
    # Both are pinned to int32; tag_indices then matches the int32 `offsets`
    # below, so no further dtype reconciliation is needed.
    tag_indices = tf.cast(tag_indices, dtype=tf.int32)
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    batch_size = tf.shape(inputs)[0]
    max_seq_len = tf.shape(inputs)[1]
    num_tags = tf.shape(inputs)[2]

    flattened_inputs = tf.reshape(inputs, [-1])

    # offsets[b, t] is the flat (row-major) position of inputs[b, t, 0].
    offsets = tf.expand_dims(tf.range(batch_size) * max_seq_len * num_tags, 1)
    offsets += tf.expand_dims(tf.range(max_seq_len) * num_tags, 0)
    flattened_tag_indices = tf.reshape(offsets + tag_indices, [-1])

    unary_scores = tf.reshape(
        tf.gather(flattened_inputs, flattened_tag_indices), [batch_size, max_seq_len]
    )

    # Ignore positions past each sequence's true length.
    masks = tf.sequence_mask(
        sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=unary_scores.dtype
    )

    unary_scores = tf.reduce_sum(unary_scores * masks, 1)
    return unary_scores

290 

291 

def crf_binary_score(
    tag_indices: TensorLike, sequence_lengths: TensorLike, transition_params: TensorLike
) -> tf.Tensor:
    """Sums the transition potentials along each tag sequence.

    Args:
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] matrix of binary potentials.
    Returns:
      binary_scores: A [batch_size] vector of binary scores.
    """
    tag_indices = tf.cast(tag_indices, dtype=tf.int32)
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    n_tags = tf.shape(transition_params)[0]
    seq_width = tf.shape(tag_indices)[1]
    n_steps = seq_width - 1

    # Pair every tag with its successor: transition t goes from column t to t+1.
    from_tags = tf.slice(tag_indices, [0, 0], [-1, n_steps])
    to_tags = tf.slice(tag_indices, [0, 1], [-1, n_steps])

    # Look the pairs up in the row-major flattened transition matrix.
    flat_ids = from_tags * n_tags + to_tags
    pair_scores = tf.gather(tf.reshape(transition_params, [-1]), flat_ids)

    # A transition into position t counts only if t < length; column 0 of the
    # length mask has no incoming transition, so it is dropped.
    length_mask = tf.sequence_mask(
        sequence_lengths, maxlen=seq_width, dtype=pair_scores.dtype
    )
    step_mask = tf.slice(length_mask, [0, 1], [-1, -1])
    return tf.reduce_sum(pair_scores * step_mask, 1)

328 

329 

def crf_forward(
    inputs: TensorLike,
    state: TensorLike,
    transition_params: TensorLike,
    sequence_lengths: TensorLike,
) -> tf.Tensor:
    """Computes the alpha values in a linear-chain CRF.

    See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.

    Args:
      inputs: A [batch_size, num_steps, num_tags] tensor of unary potentials
          for the time steps that follow `state`. It is transposed to
          time-major internally, so it must be rank 3.
      state: A [batch_size, num_tags] matrix containing the alpha values of
          the first time step.
      transition_params: A [num_tags, num_tags] matrix of binary potentials.
          This matrix is expanded into a [1, num_tags, num_tags] in preparation
          for the broadcast summation occurring within the scan.
      sequence_lengths: A [batch_size] vector of true sequence lengths,
          counting the time step represented by `state`.

    Returns:
      new_alphas: A [batch_size, num_tags] matrix holding, for each sequence,
          the alpha values at its last valid time step (index
          `sequence_lengths - 1`, clamped to 0, where index 0 is `state`).
    """
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    # Index of the last valid step; length-0 sequences fall back to `state`.
    last_index = tf.maximum(
        tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 1
    )
    # tf.scan iterates over axis 0, so switch to time-major.
    inputs = tf.transpose(inputs, [1, 0, 2])
    transition_params = tf.expand_dims(transition_params, 0)

    def _scan_fn(_state, _inputs):
        # _state[b, i] + transition[i, j], then logsumexp over previous tag i.
        _state = tf.expand_dims(_state, 2)
        transition_scores = _state + transition_params
        new_alphas = _inputs + tf.reduce_logsumexp(transition_scores, [1])
        return new_alphas

    all_alphas = tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2])
    # add first state for sequences of length 1
    all_alphas = tf.concat([tf.expand_dims(state, 1), all_alphas], 1)

    # Pick alphas[b, last_index[b]] for every batch row b.
    idxs = tf.stack([tf.range(tf.shape(last_index)[0]), last_index], axis=1)
    return tf.gather_nd(all_alphas, idxs)

373 

374 

def viterbi_decode(
    score: TensorLike, transition_params: TensorLike
) -> Tuple[list, float]:
    """Decode the highest scoring sequence of tags outside of TensorFlow.

    This should only be used at test time.

    Args:
      score: A [seq_len, num_tags] matrix of unary potentials. Any array-like
          accepted by `np.asarray` (e.g. a NumPy array or nested lists) works.
      transition_params: A [num_tags, num_tags] matrix of binary potentials;
          `transition_params[i, j]` scores moving from tag i to tag j.

    Returns:
      viterbi: A [seq_len] list of integers containing the highest scoring tag
          indices.
      viterbi_score: A float containing the score for the Viterbi sequence.
    """
    # Normalize array-likes up front; `score.shape` below fails on plain lists.
    score = np.asarray(score)
    transition_params = np.asarray(transition_params)

    trellis = np.zeros_like(score)
    backpointers = np.zeros_like(score, dtype=np.int32)
    trellis[0] = score[0]

    for t in range(1, score.shape[0]):
        # v[i, j] = best score ending in tag i at t-1, then moving to tag j.
        v = np.expand_dims(trellis[t - 1], 1) + transition_params
        trellis[t] = score[t] + np.max(v, 0)
        backpointers[t] = np.argmax(v, 0)

    # Backtrack from the best final tag.
    viterbi = [np.argmax(trellis[-1])]
    for bp in reversed(backpointers[1:]):
        viterbi.append(bp[viterbi[-1]])
    viterbi.reverse()

    viterbi_score = np.max(trellis[-1])
    return viterbi, viterbi_score

405 

406 

class CrfDecodeForwardRnnCell(AbstractRNNCell):
    """Computes the forward decoding in a linear-chain CRF."""

    @typechecked
    def __init__(self, transition_params: TensorLike, **kwargs):
        """Initialize the CrfDecodeForwardRnnCell.

        Args:
          transition_params: A [num_tags, num_tags] matrix of binary
            potentials. This matrix is expanded into a
            [1, num_tags, num_tags] in preparation for the broadcast
            summation occurring within the cell.
          **kwargs: Forwarded to the base cell constructor (e.g. `dtype`).
        """
        super().__init__(**kwargs)
        self._transition_params = tf.expand_dims(transition_params, 0)
        # Read the tag count from the expanded tensor so array-likes without a
        # `.shape` attribute (e.g. nested lists) are accepted as well.
        self._num_tags = self._transition_params.shape[1]

    @property
    def state_size(self):
        return self._num_tags

    @property
    def output_size(self):
        return self._num_tags

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs, state):
        """Runs one forward (max-product) decoding step.

        Args:
          inputs: A [batch_size, num_tags] matrix of unary potentials.
          state: A single-element list whose item is a [batch_size, num_tags]
            matrix containing the previous step's score values.

        Returns:
          backpointers: A [batch_size, num_tags] int32 matrix; entry [b, j] is
            the previous tag that maximizes the score of reaching tag j.
          new_state: A [batch_size, num_tags] matrix of new score values.
        """
        state = tf.expand_dims(state[0], 2)
        # scores[b, i, j] = previous score of tag i + transition(i -> j).
        transition_scores = state + tf.cast(
            self._transition_params, self._compute_dtype
        )
        new_state = inputs + tf.reduce_max(transition_scores, [1])
        backpointers = tf.argmax(transition_scores, 1)
        backpointers = tf.cast(backpointers, dtype=tf.int32)
        return backpointers, new_state

    def get_config(self) -> dict:
        # NOTE(review): `.numpy()` requires eager execution — confirm
        # serialization is never attempted in graph mode.
        config = {
            "transition_params": tf.squeeze(self._transition_params, 0).numpy().tolist()
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config: dict) -> "CrfDecodeForwardRnnCell":
        # Copy so the caller's config dict is not mutated.
        config = dict(config)
        config["transition_params"] = np.array(
            config["transition_params"], dtype=np.float32
        )
        return cls(**config)

469 

470 

def crf_decode_forward(
    inputs: TensorLike,
    state: TensorLike,
    transition_params: TensorLike,
    sequence_lengths: TensorLike,
) -> tf.Tensor:
    """Computes forward decoding in a linear-chain CRF.

    Args:
      inputs: A [batch_size, num_steps, num_tags] tensor of unary potentials
          for the steps that follow `state` (axis 1 is the time axis used for
          masking).
      state: A [batch_size, num_tags] matrix containing the score values of
          the step preceding `inputs`.
      transition_params: A [num_tags, num_tags] matrix of binary potentials.
      sequence_lengths: A [batch_size] vector with the number of valid steps
          of each row of `inputs`.

    Returns:
      The Keras RNN result `[backpointers, last_state]`:
        backpointers: A [batch_size, num_steps, num_tags] int32 tensor of
            per-step backpointers; steps beyond `sequence_lengths` are zero
            because `zero_output_for_mask=True`.
        last_state: A [batch_size, num_tags] matrix with the score values
            after the last valid step of each sequence.
    """
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
    mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1])
    crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params, dtype=inputs.dtype)
    crf_fwd_layer = tf.keras.layers.RNN(
        crf_fwd_cell,
        return_sequences=True,
        return_state=True,
        dtype=inputs.dtype,
        zero_output_for_mask=True,  # See: https://github.com/tensorflow/addons/issues/2639
    )
    return crf_fwd_layer(inputs, state, mask=mask)

501 

502 

def crf_decode_backward(inputs: TensorLike, state: TensorLike) -> tf.Tensor:
    """Computes backward decoding in a linear-chain CRF.

    Args:
      inputs: A [batch_size, num_steps, num_tags] tensor of backpointers,
          already ordered from the last step to the first. It is transposed
          to time-major internally, so it must be rank 3.
      state: A [batch_size, 1] int32 matrix with the tag index of the step
          after the first entry of `inputs` (in time order).

    Returns:
      new_tags: A [batch_size, num_steps, 1] int32 tensor with the tag index
          recovered at each step of the backtrace.
    """
    # tf.scan iterates over axis 0, so switch to time-major.
    inputs = tf.transpose(inputs, [1, 0, 2])

    def _scan_fn(prev_tags, step_backpointers):
        # Follow each row's backpointer for the tag chosen at the next step.
        prev_tags = tf.squeeze(prev_tags, axis=[1])
        idxs = tf.stack(
            [tf.range(tf.shape(step_backpointers)[0]), prev_tags], axis=1
        )
        new_tags = tf.expand_dims(tf.gather_nd(step_backpointers, idxs), axis=-1)
        return new_tags

    return tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2])

524 

525 

def crf_decode(
    potentials: TensorLike, transition_params: TensorLike, sequence_length: TensorLike
) -> tf.Tensor:
    """Decode the highest scoring sequence of tags.

    Args:
      potentials: A [batch_size, max_seq_len, num_tags] tensor of
          unary potentials.
      transition_params: A [num_tags, num_tags] matrix of
          binary potentials.
      sequence_length: A [batch_size] vector of true sequence lengths.

    Returns:
      decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
          Contains the highest scoring tag indices.
      best_score: A [batch_size] vector, containing the score of `decode_tags`.
    """
    # Compare the major/minor components rather than a 3-character prefix so
    # only 2.4.x releases trigger the warning.
    if tf.__version__.split(".")[:2] == ["2", "4"]:
        warnings.warn(
            "CRF Decoding does not work with KerasTensors in TF2.4. The bug has since been fixed in tensorflow/tensorflow#45534",
            stacklevel=2,
        )

    sequence_length = tf.cast(sequence_length, dtype=tf.int32)

    # If max_seq_len is 1, we skip the algorithm and simply return the
    # argmax tag and the max activation.
    def _single_seq_fn():
        decode_tags = tf.cast(tf.argmax(potentials, axis=2), dtype=tf.int32)
        best_score = tf.reshape(tf.reduce_max(potentials, axis=2), shape=[-1])
        return decode_tags, best_score

    def _multi_seq_fn():
        # Computes forward decoding. Get last score and backpointers.
        # Step 0 seeds the scores; steps 1.. are fed through the RNN cell.
        initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1])
        initial_state = tf.squeeze(initial_state, axis=[1])
        inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1])

        # Number of transitions per sequence, clamped at zero.
        sequence_length_less_one = tf.maximum(
            tf.constant(0, dtype=tf.int32), sequence_length - 1
        )

        backpointers, last_score = crf_decode_forward(
            inputs, initial_state, transition_params, sequence_length_less_one
        )

        # Reverse only each row's valid prefix so the backward scan starts at
        # the last real step of every sequence.
        backpointers = tf.reverse_sequence(
            backpointers, sequence_length_less_one, seq_axis=1
        )

        # The best final tag starts the backtrace.
        initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32)
        initial_state = tf.expand_dims(initial_state, axis=-1)

        decode_tags = crf_decode_backward(backpointers, initial_state)
        decode_tags = tf.squeeze(decode_tags, axis=[2])
        decode_tags = tf.concat([initial_state, decode_tags], axis=1)
        # The tags were collected last-to-first; restore time order within
        # each sequence's valid prefix.
        decode_tags = tf.reverse_sequence(decode_tags, sequence_length, seq_axis=1)

        best_score = tf.reduce_max(last_score, axis=1)
        return decode_tags, best_score

    if potentials.shape[1] is not None:
        # shape is statically known, so we just execute
        # the appropriate code path
        if potentials.shape[1] == 1:
            return _single_seq_fn()
        return _multi_seq_fn()
    return tf.cond(
        tf.equal(tf.shape(potentials)[1], 1), _single_seq_fn, _multi_seq_fn
    )

597 

598 

def crf_constrained_decode(
    potentials: TensorLike,
    tag_bitmap: TensorLike,
    transition_params: TensorLike,
    sequence_length: TensorLike,
) -> tf.Tensor:
    """Viterbi-decode the best tag sequence, restricted to allowed tags.

    Potentials of tags that `tag_bitmap` marks as inactive are replaced by
    -inf, and the result is decoded with `crf_decode`. This effectively
    excludes those tags from the decoded path.

    Args:
      potentials: A [batch_size, max_seq_len, num_tags] tensor of
          unary potentials.
      tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
          marking the tags that are allowed at each position.
      transition_params: A [num_tags, num_tags] matrix of
          binary potentials.
      sequence_length: A [batch_size] vector of true sequence lengths.
    Returns:
      decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
          Contains the highest scoring tag indices.
      best_score: A [batch_size] vector, containing the score of `decode_tags`.
    """
    masked_potentials = crf_filtered_inputs(potentials, tag_bitmap)
    return crf_decode(masked_potentials, transition_params, sequence_length)