Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ctc_ops.py: 18%

439 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""CTC (Connectionist Temporal Classification) Operations.""" 

16 

17import uuid 

18 

19from tensorflow.python.eager import context 

20from tensorflow.python.eager import def_function 

21 

22from tensorflow.python.framework import constant_op 

23from tensorflow.python.framework import device 

24from tensorflow.python.framework import dtypes 

25from tensorflow.python.framework import function 

26from tensorflow.python.framework import ops 

27from tensorflow.python.framework import sparse_tensor 

28from tensorflow.python.framework import tensor_shape 

29 

30from tensorflow.python.ops import array_ops 

31from tensorflow.python.ops import array_ops_stack 

32from tensorflow.python.ops import custom_gradient 

33from tensorflow.python.ops import functional_ops 

34from tensorflow.python.ops import gen_ctc_ops 

35from tensorflow.python.ops import inplace_ops 

36from tensorflow.python.ops import linalg_ops 

37from tensorflow.python.ops import map_fn 

38from tensorflow.python.ops import math_ops 

39from tensorflow.python.ops import nn_ops 

40from tensorflow.python.ops import sparse_ops 

41from tensorflow.python.ops.nn_grad import _BroadcastMul 

42from tensorflow.python.util import deprecation 

43from tensorflow.python.util import dispatch 

44from tensorflow.python.util import nest 

45from tensorflow.python.util.tf_export import tf_export 

46 

47_DEFUN_API_NAME_ATTRIBUTE = "api_implements" 

48_DEFUN_DEVICE_ATTRIBUTE = "api_preferred_device" 

49_CPU_DEVICE_NAME = "CPU" 

50_GPU_DEVICE_NAME = "GPU" 

51 

52 

53def _get_context_device_type(): 

54 """Parses the current context and returns the device type, eg CPU/GPU.""" 

55 current_device = context.context().device_name 

56 if current_device is None: 

57 return None 

58 return device.DeviceSpec.from_string(current_device).device_type 

59 

60 

61def _generate_defun_backend(unique_api_name, preferred_device, func): 

62 function_attributes = { 

63 _DEFUN_API_NAME_ATTRIBUTE: unique_api_name, 

64 _DEFUN_DEVICE_ATTRIBUTE: preferred_device, 

65 } 

66 return def_function.function( 

67 func=func, experimental_attributes=function_attributes, autograph=False) 

68 

69# pylint: disable=protected-access, invalid-name 

70@tf_export(v1=["nn.ctc_loss"]) 

71@dispatch.add_dispatch_support 

72def ctc_loss(labels, 

73 inputs=None, 

74 sequence_length=None, 

75 preprocess_collapse_repeated=False, 

76 ctc_merge_repeated=True, 

77 ignore_longer_outputs_than_inputs=False, 

78 time_major=True, 

79 logits=None): 

80 """Computes the CTC (Connectionist Temporal Classification) Loss. 

81 

82 This op implements the CTC loss as presented in (Graves et al., 2006). 

83 

84 Input requirements: 

85 

86 ``` 

87 sequence_length(b) <= time for all b 

88 

89 max(labels.indices(labels.indices[:, 1] == b, 2)) 

90 <= sequence_length(b) for all b. 

91 ``` 

92 

93 Notes: 

94 

95 This op performs the softmax operation for you, so inputs should

96 be e.g. linear projections of outputs by an LSTM. 

97 

98 The `inputs` Tensor's innermost dimension size, `num_classes`, represents 

99 `num_labels + 1` classes, where num_labels is the number of true labels, and 

100 the largest value `(num_classes - 1)` is reserved for the blank label. 

101 

102 For example, for a vocabulary containing 3 labels `[a, b, c]`, 

103 `num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`. 

104 

105 Regarding the arguments `preprocess_collapse_repeated` and 

106 `ctc_merge_repeated`: 

107 

108 If `preprocess_collapse_repeated` is True, then a preprocessing step runs 

109 before loss calculation, wherein repeated labels passed to the loss 

110 are merged into single labels. This is useful if the training labels come 

111 from, e.g., forced alignments and therefore have unnecessary repetitions. 

112 

113 If `ctc_merge_repeated` is set False, then deep within the CTC calculation, 

114 repeated non-blank labels will not be merged and are interpreted 

115 as individual labels. This is a simplified (non-standard) version of CTC. 

116 

117 Here is a table of the (roughly) expected first order behavior: 

118 

119 * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=True` 

120 

121 Classical CTC behavior: Outputs true repeated classes with blanks in 

122 between, and can also output repeated classes with no blanks in 

123 between that need to be collapsed by the decoder. 

124 

125 * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=False` 

126 

127 Never learns to output repeated classes, as they are collapsed 

128 in the input labels before training. 

129 

130 * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=False` 

131 

132 Outputs repeated classes with blanks in between, but generally does not 

133 require the decoder to collapse/merge repeated classes. 

134 

135 * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=True` 

136 

137 Untested. Very likely will not learn to output repeated classes. 

138 

139 The `ignore_longer_outputs_than_inputs` option allows specifying the behavior

140 of the CTCLoss when dealing with sequences that have longer outputs than 

141 inputs. If true, the CTCLoss will simply return zero gradient for those 

142 items, otherwise an InvalidArgument error is returned, stopping training. 

143 

144 Args: 

145 labels: An `int32` `SparseTensor`. 

146 `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores the id 

147 for (batch b, time t). `labels.values[i]` must take on values in `[0, 

148 num_labels)`. See `core/ops/ctc_ops.cc` for more details. 

149 inputs: 3-D `float` `Tensor`. 

150 If time_major == False, this will be a `Tensor` shaped: `[batch_size, 

151 max_time, num_classes]`. 

152 If time_major == True (default), this will be a `Tensor` shaped: 

153 `[max_time, batch_size, num_classes]`. The logits. 

154 sequence_length: 1-D `int32` vector, size `[batch_size]`. The sequence 

155 lengths. 

156 preprocess_collapse_repeated: Boolean. Default: False. If True, repeated 

157 labels are collapsed prior to the CTC calculation. 

158 ctc_merge_repeated: Boolean. Default: True. 

159 ignore_longer_outputs_than_inputs: Boolean. Default: False. If True, 

160 sequences with longer outputs than inputs will be ignored. 

161 time_major: The shape format of the `inputs` Tensors. If True, these 

162 `Tensors` must be shaped `[max_time, batch_size, num_classes]`. If False, 

163 these `Tensors` must be shaped `[batch_size, max_time, num_classes]`. 

164 Using `time_major = True` (default) is a bit more efficient because it 

165 avoids transposes at the beginning of the ctc_loss calculation. However, 

166 most TensorFlow data is batch-major, so this function also accepts

167 inputs in batch-major form. 

168 logits: Alias for inputs. 

169 

170 Returns: 

171 A 1-D `float` `Tensor`, size `[batch]`, containing the negative log 

172 probabilities. 

173 

174 Raises: 

175 TypeError: if labels is not a `SparseTensor`. 

176 

177 References: 

178 Connectionist Temporal Classification - Labeling Unsegmented Sequence Data 

179 with Recurrent Neural Networks: 

180 [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) 

181 ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) 

182 """ 

183 return _ctc_loss_impl( 

184 labels, 

185 inputs, 

186 sequence_length, 

187 preprocess_collapse_repeated, 

188 ctc_merge_repeated, 

189 ignore_longer_outputs_than_inputs, 

190 time_major, 

191 logits, 

192 use_cudnn=False) 
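
# Illustrative usage sketch (added for illustration; not part of the original
# module): how a caller would invoke the v1 ctc_loss above. In user code the
# same op is reachable as tf.compat.v1.nn.ctc_loss. Shapes and label values
# below are arbitrary placeholders.
def _example_ctc_loss_v1():  # pragma: no cover
  batch_size, max_time, num_classes = 2, 5, 4  # blank label is num_classes - 1
  logits = array_ops.ones([max_time, batch_size, num_classes])  # time-major
  labels = sparse_tensor.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],  # (batch b, time t) positions
      values=constant_op.constant([0, 1, 2], dtype=dtypes.int32),
      dense_shape=[batch_size, 2])
  seq_len = constant_op.constant([max_time] * batch_size, dtype=dtypes.int32)
  return ctc_loss(labels, logits, seq_len)  # negative log probs, shape [batch_size]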

193 

194 

195def _ctc_loss_impl(labels, 

196 inputs=None, 

197 sequence_length=None, 

198 preprocess_collapse_repeated=False, 

199 ctc_merge_repeated=True, 

200 ignore_longer_outputs_than_inputs=False, 

201 time_major=True, 

202 logits=None, 

203 use_cudnn=False): 

204 # Helper function of ctc_loss with one additional param: 

205 # use_cudnn: A bool to enable cuDNN CTC loss operation. If true, the blank 

206 # index has to be 0. 

207 

208 # The second, third, etc output tensors contain the gradients. We use it in 

209 # _CTCLossGrad() below. 

210 if not isinstance(labels, sparse_tensor.SparseTensor): 

211 raise TypeError("Expected argument `labels` to be a SparseTensor. " 

212 f"Received labels={labels} of type: " 

213 f"{type(labels).__name__}") 

214 

215 # For internal calculations, we transpose to [time, batch, num_classes] 

216 inputs = deprecation.deprecated_argument_lookup("logits", logits, "inputs", 

217 inputs) 

218 

219 inputs = ops.convert_to_tensor(inputs, name="logits") 

220 if not time_major: 

221 inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,N) => (T,B,N) 

222 

223 orig_dtype = inputs.dtype 

224 if orig_dtype in (dtypes.float16, dtypes.bfloat16): 

225 inputs = math_ops.cast(inputs, dtypes.float32) 

226 

227 # gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the 

228 # blank index to be 0, but v1 views it as the last index. 

229 if use_cudnn: 

230 ctc_loss_func = gen_ctc_ops.ctc_loss_v2 

231 else: 

232 ctc_loss_func = gen_ctc_ops.ctc_loss 

233 

234 loss, _ = ctc_loss_func( 

235 inputs, 

236 labels.indices, 

237 labels.values, 

238 sequence_length, 

239 preprocess_collapse_repeated=preprocess_collapse_repeated, 

240 ctc_merge_repeated=ctc_merge_repeated, 

241 ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs) 

242 

243 if orig_dtype in (dtypes.float16, dtypes.bfloat16): 

244 loss = math_ops.cast(loss, orig_dtype) 

245 

246 return loss 

247 

248# pylint: disable=unused-argument 

249def _CTCLossGradImpl(op, grad_loss, _): 

250 # Outputs are: loss, grad 

251 # 

252 # Currently there is no way to take the second derivative of this op 

253 # due to the fused implementation's interaction with tf.gradients(), 

254 # so we make sure we prevent silently incorrect results by raising 

255 # an error if the second derivative is requested via prevent_gradient. 

256 grad_without_gradient = array_ops.prevent_gradient( 

257 op.outputs[1], 

258 message="Currently there is no way to take the second " 

259 " derivative of ctc_loss due to the fused implementation's interaction " 

260 " with tf.gradients()") 

261 # Return gradient for inputs and None for 

262 # labels_indices, labels_values and sequence_length 

263 return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None] 

264 

265 

266# pylint: disable=unused-argument 

267@ops.RegisterGradient("CTCLoss") 

268def _CTCLossGrad(op, grad_loss, _): 

269 """The derivative provided by CTC Loss. 

270 

271 Args: 

272 op: the CTCLoss op. 

273 grad_loss: The backprop for cost. 

274 

275 Returns: 

276 The CTC Loss gradient. 

277 """ 

278 return _CTCLossGradImpl(op, grad_loss, _) 

279 

280 

281# pylint: disable=unused-argument 

282@ops.RegisterGradient("CTCLossV2") 

283def _CTCLossV2Grad(op, grad_loss, _): 

284 """The derivative provided by CTC Loss V2. 

285 

286 Args: 

287 op: the CTCLossV2 op. 

288 grad_loss: The backprop for cost. 

289 

290 Returns: 

291 The CTC Loss V2 gradient. 

292 """ 

293 return _CTCLossGradImpl(op, grad_loss, _) 

294 

295 

296@tf_export("nn.ctc_greedy_decoder") 

297@dispatch.add_dispatch_support 

298def ctc_greedy_decoder(inputs, 

299 sequence_length, 

300 merge_repeated=True, 

301 blank_index=None): 

302 """Performs greedy decoding on the logits given in input (best path). 

303 

304 Given a tensor as `inputs`, the `blank_index` parameter defines the class 

305 index of the blank symbol. 

306 

307 For example: 

308 

309 If `blank_index` is equal to 1: 

310 

311 >>> inf = float("inf") 

312 >>> logits = tf.constant([[[ 0., -inf, -inf], 

313 ... [ -2.3, -inf, -0.1]], 

314 ... [[ -inf, -0.5, -inf], 

315 ... [ -inf, -inf, -0.1]], 

316 ... [[ -inf, -inf, -inf], 

317 ... [ -0.1, -inf, -2.3]]]) 

318 >>> seq_lens = tf.constant([2, 3]) 

319 >>> outputs = tf.nn.ctc_greedy_decoder( 

320 ... logits, 

321 ... seq_lens, 

322 ... blank_index=1) 

323 

324 Notes: 

325 

326 - Unlike `ctc_beam_search_decoder`, `ctc_greedy_decoder` considers blanks 

327 as regular elements when computing the probability of a sequence. 

328 - Default `blank_index` is `(num_classes - 1)`, unless overridden.

329 

330 If `merge_repeated` is `True`, merge repeated classes in output. 

331 This means that if consecutive logits' maximum indices are the same, 

332 only the first of these is emitted. The sequence `A B B * B * B` (where '*' 

333 is the blank label) becomes 

334 

335 * `A B B B` if `merge_repeated=True`. 

336 * `A B B B B` if `merge_repeated=False`. 

337 

338 Args: 

339 inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`. 

340 The logits. 

341 sequence_length: 1-D `int32` vector containing sequence lengths, having size 

342 `[batch_size]`. 

343 merge_repeated: Boolean. Default: True. 

344 blank_index: (Optional). Default: `num_classes - 1`. Define the class index 

345 to use for the blank label. Negative values will start from num_classes, 

346 ie, -1 will reproduce the ctc_greedy_decoder behavior of using 

347 num_classes - 1 for the blank symbol, which corresponds to the default. 

348 

349 Returns: 

350 A tuple `(decoded, neg_sum_logits)` where 

351 

352 decoded: A single-element list. `decoded[0]` 

353 is a `SparseTensor` containing the decoded outputs s.t.:

354 

355 `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`. 

356 The rows store: `[batch, time]`. 

357 

358 `decoded.values`: Values vector, size `(total_decoded_outputs)`. 

359 The vector stores the decoded classes. 

360 

361 `decoded.dense_shape`: Shape vector, size `(2)`. 

362 The shape values are: `[batch_size, max_decoded_length]` 

363 

364 neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the 

365 sequence found, the negative of the sum of the greatest logit at each 

366 timeframe. 

367 """ 

368 

369 outputs = gen_ctc_ops.ctc_greedy_decoder( 

370 inputs, 

371 sequence_length, 

372 merge_repeated=merge_repeated, 

373 blank_index=blank_index) 

374 (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs 

375 return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val, 

376 decoded_shape)], log_probabilities) 
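
# Illustrative sketch (added for illustration; not part of the original
# module): running the doctest example from the docstring above and converting
# the decoded SparseTensor into a dense, padded matrix with
# sparse_ops.sparse_tensor_to_dense.
def _example_ctc_greedy_decoder():  # pragma: no cover
  inf = float("inf")
  logits = constant_op.constant([[[0., -inf, -inf], [-2.3, -inf, -0.1]],
                                 [[-inf, -0.5, -inf], [-inf, -inf, -0.1]],
                                 [[-inf, -inf, -inf], [-0.1, -inf, -2.3]]])
  seq_lens = constant_op.constant([2, 3])
  (decoded,), neg_sum_logits = ctc_greedy_decoder(
      logits, seq_lens, blank_index=1)
  dense = sparse_ops.sparse_tensor_to_dense(decoded)  # [batch, max_decoded_length]
  return dense, neg_sum_logits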

377 

378 

379@tf_export(v1=["nn.ctc_beam_search_decoder"]) 

380@dispatch.add_dispatch_support 

381def ctc_beam_search_decoder(inputs, 

382 sequence_length, 

383 beam_width=100, 

384 top_paths=1, 

385 merge_repeated=True): 

386 """Performs beam search decoding on the logits given in input. 

387 

388 **Note** Although in general greedy search is a special case of beam-search 

389 with `top_paths=1` and `beam_width=1`, `ctc_beam_search_decoder` differs 

390 from `ctc_greedy_decoder` in the treatment of blanks when computing the 

391 probability of a sequence: 

392 - `ctc_beam_search_decoder` treats blanks as sequence termination 

393 - `ctc_greedy_decoder` treats blanks as regular elements 

394 

395 If `merge_repeated` is `True`, merge repeated classes in the output beams. 

396 This means that if consecutive entries in a beam are the same, 

397 only the first of these is emitted. That is, when the sequence is 

398 `A B B * B * B` (where '*' is the blank label), the return value is: 

399 

400 * `A B` if `merge_repeated = True`. 

401 * `A B B B` if `merge_repeated = False`. 

402 

403 Args: 

404 inputs: 3-D `float` `Tensor`, size `[max_time x batch_size x num_classes]`. 

405 The logits. 

406 sequence_length: 1-D `int32` vector containing sequence lengths, having size 

407 `[batch_size]`. 

408 beam_width: An int scalar >= 0 (beam search beam width). 

409 top_paths: An int scalar >= 0, <= beam_width (controls output size). 

410 merge_repeated: Boolean. Default: True. 

411 

412 Returns: 

413 A tuple `(decoded, log_probabilities)` where 

414 

415 decoded: A list of length top_paths, where `decoded[j]` 

416 is a `SparseTensor` containing the decoded outputs: 

417 

418 `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)` 

419 The rows store: [batch, time]. 

420 

421 `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`. 

422 The vector stores the decoded classes for beam j. 

423 

424 `decoded[j].dense_shape`: Shape vector, size `(2)`. 

425 The shape values are: `[batch_size, max_decoded_length[j]]`. 

426 

427 log_probability: A `float` matrix `(batch_size x top_paths)` containing 

428 sequence log-probabilities. 

429 """ 

430 

431 decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = ( 

432 gen_ctc_ops.ctc_beam_search_decoder( 

433 inputs, 

434 sequence_length, 

435 beam_width=beam_width, 

436 top_paths=top_paths, 

437 merge_repeated=merge_repeated)) 

438 

439 return ([ 

440 sparse_tensor.SparseTensor(ix, val, shape) 

441 for (ix, val, shape) in zip(decoded_ixs, decoded_vals, decoded_shapes) 

442 ], log_probabilities) 

443 

444 

445@tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"]) 

446@dispatch.add_dispatch_support 

447def ctc_beam_search_decoder_v2(inputs, 

448 sequence_length, 

449 beam_width=100, 

450 top_paths=1): 

451 """Performs beam search decoding on the logits given in input. 

452 

453 **Note** Although in general greedy search is a special case of beam-search 

454 with `top_paths=1` and `beam_width=1`, `ctc_beam_search_decoder` differs 

455 from `ctc_greedy_decoder` in the treatment of blanks when computing the 

456 probability of a sequence: 

457 - `ctc_beam_search_decoder` treats blanks as sequence termination 

458 - `ctc_greedy_decoder` treats blanks as regular elements 

459 

460 Args: 

461 inputs: 3-D `float` `Tensor`, size `[max_time, batch_size, num_classes]`. 

462 The logits. 

463 sequence_length: 1-D `int32` vector containing sequence lengths, having size 

464 `[batch_size]`. 

465 beam_width: An int scalar >= 0 (beam search beam width). 

466 top_paths: An int scalar >= 0, <= beam_width (controls output size). 

467 

468 Returns: 

469 A tuple `(decoded, log_probabilities)` where 

470 

471 decoded: A list of length top_paths, where `decoded[j]` 

472 is a `SparseTensor` containing the decoded outputs: 

473 

474 `decoded[j].indices`: Indices matrix `[total_decoded_outputs[j], 2]`; 

475 The rows store: `[batch, time]`. 

476 

477 `decoded[j].values`: Values vector, size `[total_decoded_outputs[j]]`. 

478 The vector stores the decoded classes for beam `j`. 

479 

480 `decoded[j].dense_shape`: Shape vector, size `(2)`. 

481 The shape values are: `[batch_size, max_decoded_length[j]]`. 

482 

483 log_probability: A `float` matrix `[batch_size, top_paths]` containing 

484 sequence log-probabilities. 

485 """ 

486 

487 # Note, merge_repeated is an invalid optimization that is removed from the 

488 # public API: it returns low probability paths. 

489 return ctc_beam_search_decoder( 

490 inputs, 

491 sequence_length=sequence_length, 

492 beam_width=beam_width, 

493 top_paths=top_paths, 

494 merge_repeated=False) 
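
# Illustrative sketch (added for illustration; not part of the original
# module): beam-search decoding time-major logits, keeping the two best paths.
# In user code this is reachable as tf.nn.ctc_beam_search_decoder. Shapes are
# arbitrary placeholders; class num_classes - 1 acts as the blank.
def _example_ctc_beam_search_decoder_v2():  # pragma: no cover
  max_time, batch_size, num_classes = 6, 2, 4
  logits = array_ops.ones([max_time, batch_size, num_classes])
  seq_len = constant_op.constant([max_time] * batch_size, dtype=dtypes.int32)
  decoded, log_probs = ctc_beam_search_decoder_v2(
      logits, seq_len, beam_width=10, top_paths=2)
  # decoded is a list of two SparseTensors; log_probs has shape [batch_size, 2].
  return decoded, log_probs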

495 

496 

497ops.NotDifferentiable("CTCGreedyDecoder") 

498ops.NotDifferentiable("CTCBeamSearchDecoder") 

499 

500 

501def _ctc_state_trans(label_seq): 

502 """Computes CTC alignment model transition matrix. 

503 

504 Args: 

505 label_seq: tensor of shape [batch_size, max_seq_length] 

506 

507 Returns: 

508 tensor of shape [batch_size, states, states] with a state transition matrix 

509 computed for each sequence of the batch. 

510 """ 

511 

512 with ops.name_scope("ctc_state_trans"): 

513 label_seq = ops.convert_to_tensor(label_seq, name="label_seq") 

514 batch_size = _get_dim(label_seq, 0) 

515 num_labels = _get_dim(label_seq, 1) 

516 

517 num_label_states = num_labels + 1 

518 num_states = 2 * num_label_states 

519 

520 label_states = math_ops.range(num_label_states) 

521 blank_states = label_states + num_label_states 

522 

523 # Start state to first label. 

524 start_to_label = [[1, 0]] 

525 

526 # Blank to label transitions. 

527 blank_to_label = array_ops_stack.stack( 

528 [label_states[1:], blank_states[:-1]], 1) 

529 

530 # Label to blank transitions. 

531 label_to_blank = array_ops_stack.stack([blank_states, label_states], 1) 

532 

533 # Scatter transitions that don't depend on sequence. 

534 indices = array_ops.concat([start_to_label, blank_to_label, label_to_blank], 

535 0) 

536 values = array_ops.ones([_get_dim(indices, 0)]) 

537 trans = array_ops.scatter_nd( 

538 indices, values, shape=[num_states, num_states]) 

539 trans += linalg_ops.eye(num_states) # Self-loops. 

540 

541 # Label to label transitions. Disallow transitions between repeated labels 

542 # with no blank state in between. 

543 batch_idx = array_ops.zeros_like(label_states[2:]) 

544 indices = array_ops_stack.stack( 

545 [batch_idx, label_states[2:], label_states[1:-1]], 1) 

546 indices = array_ops.tile( 

547 array_ops.expand_dims(indices, 0), [batch_size, 1, 1]) 

548 batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0] 

549 indices += array_ops.expand_dims(batch_idx, 1) 

550 repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:]) 

551 values = 1.0 - math_ops.cast(repeats, dtypes.float32) 

552 batched_shape = [batch_size, num_states, num_states] 

553 label_to_label = array_ops.scatter_nd(indices, values, batched_shape) 

554 

555 return array_ops.expand_dims(trans, 0) + label_to_label 
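
# Illustrative note (added for illustration): for label_seq = [[1, 1]] there
# are num_labels = 2, hence 2 * (2 + 1) = 6 states (three label states followed
# by three blank states). Because the two labels repeat, the direct
# label-to-label transition between them is masked to zero, so a valid path
# must pass through the intervening blank state:
#   trans = _ctc_state_trans(constant_op.constant([[1, 1]]))  # shape [1, 6, 6]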

556 

557 

558def ctc_state_log_probs(seq_lengths, max_seq_length): 

559 """Computes CTC alignment initial and final state log probabilities. 

560 

561 Create the initial/final state values directly as log values to avoid 

562 having to take a float64 log on tpu (which does not exist). 

563 

564 Args: 

565 seq_lengths: int tensor of shape [batch_size], seq lengths in the batch. 

566 max_seq_length: int, max sequence length possible. 

567 

568 Returns: 

569 initial_state_log_probs, final_state_log_probs 

570 """ 

571 

572 batch_size = _get_dim(seq_lengths, 0) 

573 num_label_states = max_seq_length + 1 

574 num_duration_states = 2 

575 num_states = num_duration_states * num_label_states 

576 log_0 = math_ops.cast( 

577 math_ops.log(math_ops.cast(0, dtypes.float64) + 1e-307), dtypes.float32) 

578 

579 initial_state_log_probs = array_ops.one_hot( 

580 indices=array_ops.zeros([batch_size], dtype=dtypes.int32), 

581 depth=num_states, 

582 on_value=0.0, 

583 off_value=log_0, 

584 axis=1) 

585 

586 label_final_state_mask = array_ops.one_hot( 

587 seq_lengths, depth=num_label_states, axis=0) 

588 duration_final_state_mask = array_ops.ones( 

589 [num_duration_states, 1, batch_size]) 

590 final_state_mask = duration_final_state_mask * label_final_state_mask 

591 final_state_log_probs = (1.0 - final_state_mask) * log_0 

592 final_state_log_probs = array_ops.reshape(final_state_log_probs, 

593 [num_states, batch_size]) 

594 

595 return initial_state_log_probs, array_ops.transpose(final_state_log_probs) 

596 

597 

598def _ilabel_to_state(labels, num_labels, ilabel_log_probs): 

599 """Project ilabel log probs to state log probs.""" 

600 

601 num_label_states = _get_dim(labels, 1) 

602 blank = ilabel_log_probs[:, :, :1] 

603 blank = array_ops.tile(blank, [1, 1, num_label_states + 1]) 

604 one_hot = array_ops.one_hot(labels, depth=num_labels) 

605 one_hot = array_ops.expand_dims(one_hot, axis=0) 

606 ilabel_log_probs = array_ops.expand_dims(ilabel_log_probs, axis=2) 

607 state_log_probs = math_ops.reduce_sum(ilabel_log_probs * one_hot, axis=3) 

608 state_log_probs = array_ops.concat([state_log_probs, blank], axis=2) 

609 return array_ops.pad( 

610 state_log_probs, [[0, 0], [0, 0], [1, 0]], 

611 constant_values=math_ops.log(0.0)) 

612 

613 

614def _state_to_olabel(labels, num_labels, states): 

615 """Sum state log probs to ilabel log probs.""" 

616 

617 num_label_states = _get_dim(labels, 1) + 1 

618 label_states = states[:, :, 1:num_label_states] 

619 blank_states = states[:, :, num_label_states:] 

620 one_hot = array_ops.one_hot( 

621 labels - 1, 

622 depth=(num_labels - 1), 

623 on_value=0.0, 

624 off_value=math_ops.log(0.0)) 

625 one_hot = array_ops.expand_dims(one_hot, axis=0) 

626 label_states = array_ops.expand_dims(label_states, axis=3) 

627 label_olabels = math_ops.reduce_logsumexp(label_states + one_hot, axis=2) 

628 blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True) 

629 return array_ops.concat([blank_olabels, label_olabels], axis=-1) 

630 

631 

632# pylint: disable=redefined-outer-name 

633def _state_to_olabel_unique(labels, num_labels, states, unique): 

634 """Sum state log probs to ilabel log probs using unique label indices.""" 

635 

636 num_label_states = _get_dim(labels, 1) + 1 

637 label_states = states[:, :, 1:num_label_states] 

638 blank_states = states[:, :, num_label_states:] 

639 

640 unique_y, unique_idx = unique 

641 mul_reduce = _sum_states(unique_idx, label_states) 

642 

643 num_frames = _get_dim(states, 0) 

644 batch_size = _get_dim(states, 1) 

645 num_states = num_label_states - 1 

646 batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0]) 

647 batch_state_major = array_ops.reshape(batch_state_major, 

648 [batch_size * num_states, num_frames]) 

649 batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels 

650 indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1) 

651 indices = array_ops.reshape(indices, [-1, 1]) 

652 scatter = array_ops.scatter_nd( 

653 indices=indices, 

654 updates=batch_state_major, 

655 shape=[batch_size * num_labels, num_frames]) 

656 scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames]) 

657 

658 mask = array_ops.ones_like(batch_state_major, dtype=dtypes.bool) 

659 mask = array_ops.scatter_nd( 

660 indices=indices, 

661 updates=mask, 

662 shape=[batch_size * num_labels, num_frames]) 

663 mask = array_ops.reshape(mask, [batch_size, num_labels, num_frames]) 

664 

665 scatter = array_ops.where( 

666 mask, scatter, 

667 array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0))) 

668 

669 label_olabels = array_ops.transpose(scatter, [2, 0, 1]) 

670 label_olabels = label_olabels[:, :, 1:] 

671 

672 blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True) 

673 

674 return array_ops.concat([blank_olabels, label_olabels], axis=-1) 

675 

676 

677def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None): 

678 """Computes the CTC loss and gradients. 

679 

680 Most users will want fwd_bwd.ctc_loss 

681 

682 This function returns the computed gradient; it does not have a gradient

683 of its own defined. 

684 

685 Args: 

686 logits: tensor of shape [frames, batch_size, num_labels] 

687 labels: tensor of shape [batch_size, max_label_seq_length] 

688 label_length: tensor of shape [batch_size] Length of reference label 

689 sequence in labels. 

690 logit_length: tensor of shape [batch_size] Length of input sequence in 

691 logits. 

692 unique: (optional) unique label indices as computed by unique(labels). If

693 supplied, enables an implementation that is faster and more memory 

694 efficient on TPU. 

695 

696 Returns: 

697 loss: tensor of shape [batch_size] 

698 gradient: tensor of shape [frames, batch_size, num_labels] 

699 """ 

700 

701 num_labels = _get_dim(logits, 2) 

702 max_label_seq_length = _get_dim(labels, 1) 

703 

704 ilabel_log_probs = nn_ops.log_softmax(logits) 

705 state_log_probs = _ilabel_to_state(labels, num_labels, ilabel_log_probs) 

706 state_trans_probs = _ctc_state_trans(labels) 

707 initial_state_log_probs, final_state_log_probs = ctc_state_log_probs( 

708 label_length, max_label_seq_length) 

709 fwd_bwd_log_probs, log_likelihood = _forward_backward_log( 

710 state_trans_log_probs=math_ops.log(state_trans_probs), 

711 initial_state_log_probs=initial_state_log_probs, 

712 final_state_log_probs=final_state_log_probs, 

713 observed_log_probs=state_log_probs, 

714 sequence_length=logit_length) 

715 

716 if unique: 

717 olabel_log_probs = _state_to_olabel_unique(labels, num_labels, 

718 fwd_bwd_log_probs, unique) 

719 else: 

720 olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs) 

721 

722 grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs) 

723 

724 # Applies the sequence mask to the gradient. It is enough to apply the mask

725 # only to ilabel_log_probs because olabel_log_probs already accounts for the

726 # mask. However, applying it to the gradient as well is safe and clean.

727 max_logit_length = _get_dim(logits, 0) 

728 logit_mask = array_ops.sequence_mask(logit_length, max_logit_length, 

729 dtypes.float32) 

730 logit_mask = array_ops.transpose(logit_mask, perm=[1, 0]) 

731 logit_mask = array_ops.expand_dims(logit_mask, axis=2) 

732 grad *= logit_mask 

733 

734 loss = -log_likelihood 

735 return loss, grad 

736 

737 

738def _ctc_loss_grad(op, grad_loss, _): 

739 grad = op.outputs[1] 

740 grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * grad] 

741 grad += [None] * (len(op.inputs) - len(grad)) 

742 return grad 

743 

744 

745def _ctc_loss_op_standard(labels, logits, logit_length, logits_time_major, 

746 blank_index): 

747 part_before = logits[:, :, :blank_index] 

748 part_after = logits[:, :, blank_index + 1:] 

749 part_blank = logits[:, :, blank_index:blank_index + 1] 

750 logits = array_ops.concat([part_before, part_after, part_blank], axis=2) 

751 labels = sparse_tensor.SparseTensor( 

752 labels.indices, 

753 array_ops.where(labels.values < blank_index, labels.values, 

754 labels.values - 1), labels.dense_shape) 

755 return _ctc_loss_impl( 

756 labels=labels, 

757 inputs=logits, 

758 sequence_length=logit_length, 

759 time_major=logits_time_major, 

760 use_cudnn=False) 

761 

762 

763def _ctc_loss_op_cudnn(labels, logits, logit_length, logits_time_major, 

764 blank_index): 

765 part_before = logits[:, :, :blank_index] 

766 part_after = logits[:, :, blank_index + 1:] 

767 part_blank = logits[:, :, blank_index:blank_index + 1] 

768 logits = array_ops.concat([part_blank, part_before, part_after], axis=2) 

769 labels = sparse_tensor.SparseTensor( 

770 labels.indices, 

771 array_ops.where(labels.values < blank_index, labels.values + 1, 

772 labels.values), labels.dense_shape) 

773 return _ctc_loss_impl( 

774 labels=labels, 

775 inputs=logits, 

776 sequence_length=logit_length, 

777 time_major=logits_time_major, 

778 use_cudnn=True) 

779 

780 

781def _ctc_loss_shape(op): 

782 return [op.inputs[2].get_shape(), op.inputs[0].get_shape()] 

783 

784 

785# pylint: disable=protected-access, invalid-name 

786@tf_export(v1=["nn.ctc_loss_v2"]) 

787@dispatch.add_dispatch_support 

788def ctc_loss_v2(labels, 

789 logits, 

790 label_length, 

791 logit_length, 

792 logits_time_major=True, 

793 unique=None, 

794 blank_index=None, 

795 name=None): 

796 """Computes CTC (Connectionist Temporal Classification) loss. 

797 

798 This op implements the CTC loss as presented in (Graves et al., 2006). 

799 

800 Notes: 

801 

802 - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss 

803 setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True 

804 - Labels may be supplied as either a dense, zero-padded tensor with a 

805 vector of label sequence lengths OR as a SparseTensor. 

806 - On TPU and GPU: Only dense padded labels are supported. 

807 - On CPU: Caller may use SparseTensor or dense padded labels but calling with 

808 a SparseTensor will be significantly faster. 

809 - Default blank label is 0 rather than num_classes - 1, unless overridden by

810 blank_index. 

811 

812 Args: 

813 labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor 

814 logits: tensor of shape [frames, batch_size, num_labels], if 

815 logits_time_major == False, shape is [batch_size, frames, num_labels]. 

816 label_length: tensor of shape [batch_size], None if labels is SparseTensor 

817 Length of reference label sequence in labels. 

818 logit_length: tensor of shape [batch_size] Length of input sequence in 

819 logits. 

820 logits_time_major: (optional) If True (default), logits is shaped [time, 

821 batch, logits]. If False, shape is [batch, time, logits] 

822 unique: (optional) Unique label indices as computed by 

823 ctc_unique_labels(labels). If supplied, enable a faster, memory efficient 

824 implementation on TPU. 

825 blank_index: (optional) Set the class index to use for the blank label. 

826 Negative values will start from num_classes, ie, -1 will reproduce the 

827 ctc_loss behavior of using num_classes - 1 for the blank symbol. There is 

828 some memory/performance overhead to switching from the default of 0 as an 

829 additional shifted copy of the logits may be created. 

830 name: A name for this `Op`. Defaults to "ctc_loss_dense". 

831 

832 Returns: 

833 loss: tensor of shape [batch_size], negative log probabilities. 

834 

835 References: 

836 Connectionist Temporal Classification - Labeling Unsegmented Sequence Data 

837 with Recurrent Neural Networks: 

838 [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) 

839 ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) 

840 """ 

841 if isinstance(labels, sparse_tensor.SparseTensor): 

842 if blank_index is None: 

843 raise ValueError( 

844 "Argument `blank_index` must be provided when labels is a " 

845 "SparseTensor.") 

846 

847 if blank_index < 0: 

848 blank_index += _get_dim(logits, 2) 

849 

850 if blank_index != _get_dim(logits, 2) - 1: 

851 logits = array_ops.concat([ 

852 logits[:, :, :blank_index], 

853 logits[:, :, blank_index + 1:], 

854 logits[:, :, blank_index:blank_index + 1], 

855 ], 

856 axis=2) 

857 labels = sparse_tensor.SparseTensor( 

858 labels.indices, 

859 array_ops.where(labels.values < blank_index, labels.values, 

860 labels.values - 1), labels.dense_shape) 

861 

862 return ctc_loss( 

863 labels=labels, 

864 inputs=logits, 

865 sequence_length=logit_length, 

866 time_major=logits_time_major) 

867 

868 if blank_index is None: 

869 blank_index = 0 

870 

871 return ctc_loss_dense( 

872 labels=labels, 

873 logits=logits, 

874 label_length=label_length, 

875 logit_length=logit_length, 

876 logits_time_major=logits_time_major, 

877 unique=unique, 

878 blank_index=blank_index, 

879 name=name) 

880 

881 

882@tf_export("nn.ctc_loss", v1=[]) 

883@dispatch.add_dispatch_support 

884def ctc_loss_v3(labels, 

885 logits, 

886 label_length, 

887 logit_length, 

888 logits_time_major=True, 

889 unique=None, 

890 blank_index=None, 

891 name=None): 

892 """Computes CTC (Connectionist Temporal Classification) loss. 

893 

894 This op implements the CTC loss as presented in 

895 [Graves et al., 2006](https://www.cs.toronto.edu/~graves/icml_2006.pdf) 

896 

897 Connectionist temporal classification (CTC) is a type of neural network output 

898 and associated scoring function, for training recurrent neural networks (RNNs) 

899 such as LSTM networks to tackle sequence problems where the timing is 

900 variable. It can be used for tasks like on-line handwriting recognition or 

901 recognizing phones in speech audio. CTC refers to the outputs and scoring, and 

902 is independent of the underlying neural network structure. 

903 

904 Notes: 

905 

906 - This op performs the softmax operation for you, so `logits` should be

907 e.g. linear projections of outputs by an LSTM. 

908 - Outputs true repeated classes with blanks in between, and can also output 

909 repeated classes with no blanks in between that need to be collapsed by the 

910 decoder. 

911 - `labels` may be supplied as either a dense, zero-padded `Tensor` with a 

912 vector of label sequence lengths OR as a `SparseTensor`. 

913 - On TPU: Only dense padded `labels` are supported. 

914 - On CPU and GPU: Caller may use `SparseTensor` or dense padded `labels` 

915 but calling with a `SparseTensor` will be significantly faster. 

916 - Default blank label is `0` instead of `num_labels - 1` (where `num_labels` 

917 is the innermost dimension size of `logits`), unless overridden by 

918 `blank_index`. 

919 

920 >>> tf.random.set_seed(50) 

921 >>> batch_size = 8 

922 >>> num_labels = 6 

923 >>> max_label_length = 5 

924 >>> num_frames = 12 

925 >>> labels = tf.random.uniform([batch_size, max_label_length], 

926 ... minval=1, maxval=num_labels, dtype=tf.int64) 

927 >>> logits = tf.random.uniform([num_frames, batch_size, num_labels]) 

928 >>> label_length = tf.random.uniform([batch_size], minval=2, 

929 ... maxval=max_label_length, dtype=tf.int64) 

930 >>> label_mask = tf.sequence_mask(label_length, maxlen=max_label_length, 

931 ... dtype=label_length.dtype) 

932 >>> labels *= label_mask 

933 >>> logit_length = [num_frames] * batch_size 

934 >>> with tf.GradientTape() as t: 

935 ... t.watch(logits) 

936 ... ref_loss = tf.nn.ctc_loss( 

937 ... labels=labels, 

938 ... logits=logits, 

939 ... label_length=label_length, 

940 ... logit_length=logit_length, 

941 ... blank_index=0) 

942 >>> ref_grad = t.gradient(ref_loss, logits) 

943 

944 Args: 

945 labels: `Tensor` of shape `[batch_size, max_label_seq_length]` or 

946 `SparseTensor`. 

947 logits: `Tensor` of shape `[frames, batch_size, num_labels]`. If 

948 `logits_time_major == False`, shape is `[batch_size, frames, num_labels]`. 

949 label_length: `Tensor` of shape `[batch_size]`. None, if `labels` is a 

950 `SparseTensor`. Length of reference label sequence in `labels`. 

951 logit_length: `Tensor` of shape `[batch_size]`. Length of input sequence in 

952 `logits`. 

953 logits_time_major: (optional) If True (default), `logits` is shaped [frames, 

954 batch_size, num_labels]. If False, shape is 

955 `[batch_size, frames, num_labels]`. 

956 unique: (optional) Unique label indices as computed by 

957 `ctc_unique_labels(labels)`. If supplied, enable a faster, memory 

958 efficient implementation on TPU. 

959 blank_index: (optional) Set the class index to use for the blank label. 

960 Negative values will start from `num_labels`, ie, `-1` will reproduce the 

961 ctc_loss behavior of using `num_labels - 1` for the blank symbol. There is 

962 some memory/performance overhead to switching from the default of 0 as an 

963 additional shifted copy of `logits` may be created. 

964 name: A name for this `Op`. Defaults to "ctc_loss_dense". 

965 

966 Returns: 

967 loss: A 1-D `float Tensor` of shape `[batch_size]`, containing negative log 

968 probabilities. 

969 

970 Raises: 

971 ValueError: Argument `blank_index` must be provided when `labels` is a 

972 `SparseTensor`. 

973 

974 References: 

975 Connectionist Temporal Classification - Labeling Unsegmented Sequence Data 

976 with Recurrent Neural Networks: 

977 [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) 

978 ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) 

979 

980 https://en.wikipedia.org/wiki/Connectionist_temporal_classification 

981 """ 

982 if isinstance(labels, sparse_tensor.SparseTensor): 

983 if blank_index is None: 

984 raise ValueError( 

985 "Argument `blank_index` must be provided when `labels` is a " 

986 "`SparseTensor`.") 

987 

988 if blank_index < 0: 

989 blank_index += _get_dim(logits, 2) 

990 

991 logits = ops.convert_to_tensor(logits, name="logits") 

992 

993 params = { 

994 "labels": labels, 

995 "logits": logits, 

996 "logit_length": logit_length, 

997 "logits_time_major": logits_time_major, 

998 "blank_index": blank_index 

999 } 

1000 

1001 if context.executing_eagerly(): 

1002 device_type = _get_context_device_type() 

1003 can_use_gpu = ( 

1004 # Either user specified GPU or unspecified but GPU is available. 

1005 (device_type == _GPU_DEVICE_NAME or 

1006 (device_type is None and context.num_gpus() > 0))) 

1007 # Under eager context, check the device placement and prefer the cuDNN implementation when a GPU is available.

1008 if can_use_gpu: 

1009 res = _ctc_loss_op_cudnn(**params) 

1010 else: 

1011 res = _ctc_loss_op_standard(**params) 

1012 else: 

1013 api_name = "ctc_loss_" + str(uuid.uuid4()) 

1014 ctc_loss_op_standard = _generate_defun_backend(api_name, _CPU_DEVICE_NAME, 

1015 _ctc_loss_op_standard) 

1016 ctc_loss_op_cudnn = _generate_defun_backend(api_name, _GPU_DEVICE_NAME, 

1017 _ctc_loss_op_cudnn) 

1018 res = ctc_loss_op_standard(**params) 

1019 concrete_func = ctc_loss_op_cudnn.get_concrete_function(**params) 

1020 concrete_func.add_to_graph() 

1021 concrete_func.add_gradient_functions_to_graph() 

1022 return res 

1023 

1024 if blank_index is None: 

1025 blank_index = 0 

1026 

1027 return ctc_loss_dense( 

1028 labels=labels, 

1029 logits=logits, 

1030 label_length=label_length, 

1031 logit_length=logit_length, 

1032 logits_time_major=logits_time_major, 

1033 unique=unique, 

1034 blank_index=blank_index, 

1035 name=name) 
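
# Illustrative sketch (added for illustration; not part of the original
# module): the SparseTensor-labels path of ctc_loss_v3 above, which requires an
# explicit blank_index. In user code the op is reachable as tf.nn.ctc_loss.
# Shapes and label values are arbitrary placeholders; with blank_index=0 the
# label values must lie in [1, num_labels).
def _example_ctc_loss_v3_sparse_labels():  # pragma: no cover
  num_frames, batch_size, num_labels = 12, 2, 6
  logits = array_ops.ones([num_frames, batch_size, num_labels])  # time-major
  labels = sparse_tensor.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],  # (batch b, time t) positions
      values=constant_op.constant([2, 3, 1], dtype=dtypes.int32),
      dense_shape=[batch_size, 2])
  logit_length = constant_op.constant([num_frames] * batch_size,
                                      dtype=dtypes.int32)
  return ctc_loss_v3(labels=labels, logits=logits, label_length=None,
                     logit_length=logit_length, blank_index=0)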

1036 

1037 

1038def ctc_loss_dense(labels, 

1039 logits, 

1040 label_length, 

1041 logit_length, 

1042 logits_time_major=True, 

1043 unique=None, 

1044 blank_index=0, 

1045 name=None): 

1046 """Computes CTC (Connectionist Temporal Classification) loss. 

1047 

1048 This op implements the CTC loss as presented in (Graves et al., 2006), 

1049 using the batched forward backward algorithm described in (Sim et al., 2017). 

1050 

1051 Notes: 

1052 Significant differences from `tf.compat.v1.nn.ctc_loss`: 

1053 Supports GPU and TPU (`tf.compat.v1.nn.ctc_loss` supports CPU only): 

1054 For batched operations, GPU and TPU are significantly faster than using 

1055 `ctc_loss` on CPU. 

1056 This implementation runs on CPU, but is significantly slower than ctc_loss.

1057 Blank label is 0 rather than num_classes - 1, unless overridden by blank_index.

1058 Logits and labels are dense arrays with padding rather than SparseTensor. 

1059 The only mode supported is the same as: 

1060 preprocess_collapse_repeated=False, ctc_merge_repeated=True 

1061 To collapse labels, the caller can preprocess label sequence first. 

1062 

1063 The dense implementation supports CPU, GPU and TPU. A fast path is

1064 provided that significantly improves memory use for large vocabularies if the

1065 caller preprocesses label sequences to get unique label indices on the CPU

1066 (eg. in the data input pipeline) using ctc_ops.unique and supplies them in

1067 the optional "unique" kwarg. This is especially useful for TPU and GPU but

1068 also works if used on CPU.

1069 

1070 Args: 

1071 labels: tensor of shape [batch_size, max_label_seq_length] 

1072 logits: tensor of shape [frames, batch_size, num_labels], if 

1073 logits_time_major == False, shape is [batch_size, frames, num_labels]. 

1074 label_length: tensor of shape [batch_size] Length of reference label 

1075 sequence in labels. 

1076 logit_length: tensor of shape [batch_size] Length of input sequence in 

1077 logits. 

1078 logits_time_major: (optional) If True (default), logits is shaped [time, 

1079 batch, logits]. If False, shape is [batch, time, logits] 

1080 unique: (optional) Unique label indices as computed by unique(labels). If 

1081 supplied, enable a faster, memory efficient implementation on TPU. 

1082 blank_index: (optional) Set the class index to use for the blank label. 

1083 Negative values will start from num_classes, ie, -1 will reproduce the 

1084 ctc_loss behavior of using num_classes - 1 for the blank symbol. There is 

1085 some memory/performance overhead to switching from the default of 0 as an 

1086 additional shifted copy of the logits may be created. 

1087 name: A name for this `Op`. Defaults to "ctc_loss_dense". 

1088 

1089 Returns: 

1090 loss: tensor of shape [batch_size], negative log probabilities. 

1091 

1092 References: 

1093 Connectionist Temporal Classification - Labeling Unsegmented Sequence Data 

1094 with Recurrent Neural Networks: 

1095 [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) 

1096 ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) 

1097 Improving the efficiency of forward-backward algorithm using batched 

1098 computation in TensorFlow: 

1099 [Sim et al., 2017](https://ieeexplore.ieee.org/document/8268944) 

1100 ([pdf](http://bacchiani.net/resume/papers/ASRU2017.pdf)) 

1101 """ 

1102 

1103 with ops.name_scope(name, "ctc_loss_dense", 

1104 [logits, labels, label_length, logit_length]): 

1105 logits = ops.convert_to_tensor(logits, name="logits") 

1106 labels = ops.convert_to_tensor(labels, name="labels") 

1107 label_length = ops.convert_to_tensor(label_length, name="label_length") 

1108 logit_length = ops.convert_to_tensor(logit_length, name="logit_length") 

1109 

1110 orig_dtype = logits.dtype 

1111 if orig_dtype in (dtypes.float16, dtypes.bfloat16): 

1112 logits = math_ops.cast(logits, dtypes.float32) 

1113 

1114 if not logits_time_major: 

1115 logits = array_ops.transpose(logits, perm=[1, 0, 2]) 

1116 

1117 if blank_index != 0: 

1118 if blank_index < 0: 

1119 blank_index += _get_dim(logits, 2) 

1120 logits = array_ops.concat([ 

1121 logits[:, :, blank_index:blank_index + 1], 

1122 logits[:, :, :blank_index], 

1123 logits[:, :, blank_index + 1:], 

1124 ], 

1125 axis=2) 

1126 labels = array_ops.where(labels < blank_index, labels + 1, labels) 

1127 

1128 args = [logits, labels, label_length, logit_length] 

1129 

1130 if unique: 

1131 unique_y, unique_idx = unique 

1132 if blank_index != 0: 

1133 unique_y = array_ops.where(unique_y < blank_index, unique_y + 1, 

1134 unique_y) 

1135 label_mask_len = math_ops.reduce_max(unique_idx, axis=1) + 1 

1136 max_label_length = _get_dim(unique_y, 1) 

1137 label_mask = array_ops.sequence_mask(label_mask_len, max_label_length) 

1138 unique_y = array_ops.where(label_mask, unique_y, 

1139 array_ops.zeros_like(unique_y)) 

1140 args.extend([unique_y, unique_idx]) 

1141 

1142 @custom_gradient.custom_gradient 

1143 def compute_ctc_loss(logits_t, labels_t, label_length_t, logit_length_t, 

1144 *unique_t): 

1145 """Compute CTC loss.""" 

1146 logits_t.set_shape(logits.shape) 

1147 labels_t.set_shape(labels.shape) 

1148 label_length_t.set_shape(label_length.shape) 

1149 logit_length_t.set_shape(logit_length.shape) 

1150 kwargs = dict( 

1151 logits=logits_t, 

1152 labels=labels_t, 

1153 label_length=label_length_t, 

1154 logit_length=logit_length_t) 

1155 if unique_t: 

1156 kwargs["unique"] = unique_t 

1157 result = ctc_loss_and_grad(**kwargs) 

1158 def grad(grad_loss): 

1159 grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * result[1]] 

1160 grad += [None] * (len(args) - len(grad)) 

1161 return grad 

1162 

1163 return result[0], grad 

1164 

1165 loss = compute_ctc_loss(*args) 

1166 if orig_dtype in (dtypes.float16, dtypes.bfloat16): 

1167 loss = math_ops.cast(loss, orig_dtype) 

1168 return loss 

1169 

1170 

1171@tf_export("nn.collapse_repeated") 

1172@dispatch.add_dispatch_support 

1173def collapse_repeated(labels, seq_length, name=None): 

1174 """Merge repeated labels into single labels. 

1175 

1176 Args: 

1177 labels: Tensor of shape [batch, max value in seq_length] 

1178 seq_length: Tensor of shape [batch], sequence length of each batch element. 

1179 name: A name for this `Op`. Defaults to "collapse_repeated_labels". 

1180 

1181 Returns: 

1182 A tuple `(collapsed_labels, new_seq_length)` where 

1183 

1184 collapsed_labels: Tensor of shape [batch, max_seq_length] with repeated 

1185 labels collapsed and padded to max_seq_length, eg: 

1186 `[[A, A, B, B, A], [A, B, C, D, E]] => [[A, B, A, 0, 0], [A, B, C, D, E]]` 

1187 

1188 new_seq_length: int tensor of shape [batch] with new sequence lengths. 

1189 """ 

1190 

1191 with ops.name_scope(name, "collapse_repeated_labels", [labels, seq_length]): 

1192 labels = ops.convert_to_tensor(labels, name="labels") 

1193 seq_length = ops.convert_to_tensor(seq_length, name="seq_length") 

1194 

1195 # Mask labels that don't equal previous label. 

1196 label_mask = array_ops.concat([ 

1197 array_ops.ones_like(labels[:, :1], dtypes.bool), 

1198 math_ops.not_equal(labels[:, 1:], labels[:, :-1]) 

1199 ], 

1200 axis=1) 

1201 

1202 # Filter labels that aren't in the original sequence. 

1203 maxlen = _get_dim(labels, 1) 

1204 seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen) 

1205 label_mask = math_ops.logical_and(label_mask, seq_mask) 

1206 

1207 # Count masks for new sequence lengths. 

1208 new_seq_len = math_ops.reduce_sum( 

1209 math_ops.cast(label_mask, dtypes.int32), axis=1) 

1210 

1211 # Mask indexes based on sequence length mask. 

1212 new_maxlen = math_ops.reduce_max(new_seq_len) 

1213 idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen) 

1214 

1215 # Flatten everything and mask out labels to keep and sparse indices. 

1216 flat_labels = array_ops.reshape(labels, [-1]) 

1217 flat_label_mask = array_ops.reshape(label_mask, [-1]) 

1218 flat_idx_mask = array_ops.reshape(idx_mask, [-1]) 

1219 idx = math_ops.range(_get_dim(flat_idx_mask, 0)) 

1220 

1221 # Scatter to flat shape. 

1222 flat = array_ops.scatter_nd( 

1223 indices=array_ops.expand_dims( 

1224 array_ops.boolean_mask(idx, flat_idx_mask), axis=1), 

1225 updates=array_ops.boolean_mask(flat_labels, flat_label_mask), 

1226 shape=array_ops.shape(flat_idx_mask)) 

1227 

1228 # Reshape back to square batch. 

1229 batch_size = _get_dim(labels, 0) 

1230 new_shape = [batch_size, new_maxlen] 

1231 return (array_ops.reshape(flat, new_shape), 

1232 math_ops.cast(new_seq_len, seq_length.dtype)) 
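
# Illustrative sketch (added for illustration; not part of the original
# module): collapsing the docstring example [[A, A, B, B, A], [A, B, C, D, E]]
# with A..E encoded as 1..5. In user code this is reachable as
# tf.nn.collapse_repeated.
def _example_collapse_repeated():  # pragma: no cover
  labels = constant_op.constant([[1, 1, 2, 2, 1],
                                 [1, 2, 3, 4, 5]])
  seq_length = constant_op.constant([5, 5])
  collapsed, new_seq_length = collapse_repeated(labels, seq_length)
  # collapsed      => [[1, 2, 1, 0, 0], [1, 2, 3, 4, 5]]
  # new_seq_length => [3, 5]
  return collapsed, new_seq_length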

1233 

1234 

1235def dense_labels_to_sparse(dense, length): 

1236 """Convert dense labels with sequence lengths to sparse tensor. 

1237 

1238 Args: 

1239 dense: tensor of shape [batch, max_length] 

1240 length: int tensor of shape [batch] The length of each sequence in dense. 

1241 

1242 Returns: 

1243 tf.sparse.SparseTensor with values only for the valid elements of sequences. 

1244 """ 

1245 

1246 flat_values = array_ops.reshape(dense, [-1]) 

1247 flat_indices = math_ops.range( 

1248 array_ops.shape(flat_values, out_type=dtypes.int64)[0]) 

1249 mask = array_ops.sequence_mask(length, maxlen=array_ops.shape(dense)[1]) 

1250 flat_mask = array_ops.reshape(mask, [-1]) 

1251 indices = array_ops.expand_dims( 

1252 array_ops.boolean_mask(flat_indices, flat_mask), 1) 

1253 values = array_ops.boolean_mask(flat_values, flat_mask) 

1254 sparse = sparse_tensor.SparseTensor( 

1255 indices=indices, 

1256 values=math_ops.cast(values, dtypes.int32), 

1257 dense_shape=array_ops.shape(flat_values, out_type=dtypes.int64)) 

1258 reshaped = sparse_ops.sparse_reshape(sparse, array_ops.shape(dense)) 

1259 max_length = math_ops.reduce_max(length) 

1260 return sparse_tensor.SparseTensor( 

1261 indices=reshaped.indices, 

1262 values=reshaped.values, 

1263 dense_shape=[ 

1264 math_ops.cast(reshaped.dense_shape[0], dtypes.int64), 

1265 math_ops.cast(max_length, dtypes.int64) 

1266 ]) 
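
# Illustrative sketch (added for illustration; not part of the original
# module): converting zero-padded dense labels plus per-sequence lengths into
# the SparseTensor form expected by the v1 ctc_loss above.
def _example_dense_labels_to_sparse():  # pragma: no cover
  dense = constant_op.constant([[1, 2, 0], [3, 0, 0]])
  length = constant_op.constant([2, 1])
  # Resulting SparseTensor has values [1, 2, 3] and dense_shape [2, 2].
  return dense_labels_to_sparse(dense, length)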

1267 

1268 

1269@tf_export("nn.ctc_unique_labels") 

1270@dispatch.add_dispatch_support 

1271def ctc_unique_labels(labels, name=None): 

1272 """Get unique labels and indices for batched labels for `tf.nn.ctc_loss`. 

1273 

1274 For use with `tf.nn.ctc_loss` optional argument `unique`: This op can be 

1275 used to preprocess labels in the input pipeline for better speed/memory use

1276 when computing the ctc loss on TPU.

1277 

1278 Example: 

1279 ctc_unique_labels([[3, 4, 4, 3]]) -> 

1280 unique labels padded with 0: [[3, 4, 0, 0]] 

1281 indices of original labels in unique: [0, 1, 1, 0] 

1282 

1283 Args: 

1284 labels: tensor of shape [batch_size, max_label_length] padded with 0. 

1285 name: A name for this `Op`. Defaults to "ctc_unique_labels". 

1286 

1287 Returns: 

1288 tuple of 

1289 - unique labels, tensor of shape `[batch_size, max_label_length]` 

1290 - indices into unique labels, shape `[batch_size, max_label_length]` 

1291 """ 

1292 

1293 with ops.name_scope(name, "ctc_unique_labels", [labels]): 

1294 labels = ops.convert_to_tensor(labels, name="labels") 

1295 

1296 def _unique(x): 

1297 u = array_ops.unique(x) 

1298 y = array_ops.pad(u.y, [[0, _get_dim(u.idx, 0) - _get_dim(u.y, 0)]]) 

1299 y = math_ops.cast(y, dtypes.int64) 

1300 return [y, u.idx] 

1301 

1302 return map_fn.map_fn(_unique, labels, dtype=[dtypes.int64, dtypes.int32]) 
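
# Illustrative sketch (added for illustration; not part of the original
# module): the docstring example above, reachable in user code as
# tf.nn.ctc_unique_labels.
def _example_ctc_unique_labels():  # pragma: no cover
  unique = ctc_unique_labels(constant_op.constant([[3, 4, 4, 3]]))
  # unique[0] => [[3, 4, 0, 0]]   (unique labels, zero padded)
  # unique[1] => [[0, 1, 1, 0]]   (index of each original label in unique[0])
  return unique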

1303 

1304 

1305def _sum_states(idx, states): 

1306 """Take logsumexp for each unique state out of all label states. 

1307 

1308 Args: 

1309 idx: tensor of shape [batch, label_length] For each sequence, indices into a 

1310 set of unique labels as computed by calling unique. 

1311 states: tensor of shape [frames, batch, label_length] Log probabilities for 

1312 each label state. 

1313 

1314 Returns: 

1315 tensor of shape [frames, batch_size, label_length], log probabilities summed 

1316 for each unique label of the sequence. 

1317 """ 

1318 

1319 with ops.name_scope("sum_states"): 

1320 idx = ops.convert_to_tensor(idx, name="idx") 

1321 num_states = _get_dim(states, 2) 

1322 states = array_ops.expand_dims(states, axis=2) 

1323 one_hot = array_ops.one_hot( 

1324 idx, 

1325 depth=num_states, 

1326 on_value=0.0, 

1327 off_value=math_ops.log(0.0), 

1328 axis=1) 

1329 return math_ops.reduce_logsumexp(states + one_hot, axis=-1) 

1330 

1331 

1332def _forward_backward_log(state_trans_log_probs, initial_state_log_probs, 

1333 final_state_log_probs, observed_log_probs, 

1334 sequence_length): 

1335 """Forward-backward algorithm computed in log domain. 

1336 

1337 Args: 

1338 state_trans_log_probs: tensor of shape [states, states] or if different 

1339 transition matrix per batch [batch_size, states, states] 

1340 initial_state_log_probs: tensor of shape [batch_size, states] 

1341 final_state_log_probs: tensor of shape [batch_size, states] 

1342 observed_log_probs: tensor of shape [frames, batch_size, states] 

1343 sequence_length: tensor of shape [batch_size] 

1344 

1345 Returns: 

1346 forward backward log probabilities: tensor of shape [frames, batch, states] 

1347 log_likelihood: tensor of shape [batch_size] 

1348 

1349 Raises: 

1350 ValueError: If state_trans_log_probs has unknown or incorrect rank. 

1351 """ 

1352 

1353 if state_trans_log_probs.shape.ndims == 2: 

1354 perm = [1, 0] 

1355 elif state_trans_log_probs.shape.ndims == 3: 

1356 perm = [0, 2, 1] 

1357 else: 

1358 raise ValueError( 

1359 "Rank of argument `state_trans_log_probs` must be known and equal to " 

1360 f"2 or 3. Received state_trans_log_probs={state_trans_log_probs} of " 

1361 f"rank {state_trans_log_probs.shape.ndims}") 

1362 

1363 bwd_state_trans_log_probs = array_ops.transpose(state_trans_log_probs, perm) 

1364 batch_size = _get_dim(observed_log_probs, 1) 

1365 

1366 def _forward(state_log_prob, obs_log_prob): 

1367 state_log_prob = array_ops.expand_dims(state_log_prob, axis=1) # Broadcast. 

1368 state_log_prob += state_trans_log_probs 

1369 state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1) 

1370 state_log_prob += obs_log_prob 

1371 log_prob_sum = math_ops.reduce_logsumexp( 

1372 state_log_prob, axis=-1, keepdims=True) 

1373 state_log_prob -= log_prob_sum 

1374 return state_log_prob 

1375 

1376 fwd = _scan( 

1377 _forward, observed_log_probs, initial_state_log_probs, inclusive=True) 

1378 

1379 def _backward(accs, elems): 

1380 """Calculate log probs and cumulative sum masked for sequence length.""" 

1381 state_log_prob, cum_log_sum = accs 

1382 obs_log_prob, mask = elems 

1383 state_log_prob += obs_log_prob 

1384 state_log_prob = array_ops.expand_dims(state_log_prob, axis=1) # Broadcast. 

1385 state_log_prob += bwd_state_trans_log_probs 

1386 state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1) 

1387 

1388 log_prob_sum = math_ops.reduce_logsumexp( 

1389 state_log_prob, axis=-1, keepdims=True) 

1390 state_log_prob -= log_prob_sum 

1391 

1392 cum_log_sum += array_ops.squeeze(log_prob_sum, axis=[-1]) * mask 

1393 batched_mask = array_ops.expand_dims(mask, axis=1) 

1394 out = state_log_prob * batched_mask 

1395 out += final_state_log_probs * (1.0 - batched_mask) 

1396 return out, cum_log_sum 

1397 

1398 zero_log_sum = array_ops.zeros([batch_size]) 

1399 maxlen = _get_dim(observed_log_probs, 0) 

1400 mask = array_ops.sequence_mask(sequence_length, maxlen, dtypes.float32) 

1401 mask = array_ops.transpose(mask, perm=[1, 0]) 

1402 

1403 bwd, cum_log_sum = _scan( 

1404 _backward, (observed_log_probs, mask), 

1405 (final_state_log_probs, zero_log_sum), 

1406 reverse=True, 

1407 inclusive=True) 

1408 

1409 fwd_bwd_log_probs = fwd[1:] + bwd[1:] 

1410 fwd_bwd_log_probs_sum = math_ops.reduce_logsumexp( 

1411 fwd_bwd_log_probs, axis=2, keepdims=True) 

1412 fwd_bwd_log_probs -= fwd_bwd_log_probs_sum 

1413 fwd_bwd_log_probs += math_ops.log(array_ops.expand_dims(mask, axis=2)) 

1414 

1415 log_likelihood = bwd[0, :, 0] + cum_log_sum[0] 

1416 

1417 return fwd_bwd_log_probs, log_likelihood 

1418 

1419 

1420# TODO(tombagby): This is currently faster for the ctc implementation than using 

1421# functional_ops.scan, but could be replaced by that or something similar if 

1422# things change. 

1423def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False): 

1424 """Repeatedly applies callable `fn` to a sequence of elements. 

1425 

1426 Implemented by functional_ops.While, tpu friendly, no gradient. 

1427 

1428 This is similar to functional_ops.scan but significantly faster on tpu/gpu 

1429 for the forward backward use case. 

1430 

1431 Examples: 

1432 scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0] 

1433 

1434 Multiple accumulators: 

1435 scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0)) 

1436 

1437 Multiple inputs: 

1438 scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0) 

1439 

1440 Args: 

1441 fn: callable, fn(accumulators, element) return new accumulator values. The 

1442 (possibly nested) sequence of accumulators is the same as `initial` and 

1443 the return value must have the same structure. 

1444 elems: A (possibly nested) tensor which will be unpacked along the first 

1445 dimension. The resulting slices will be the second argument to fn. The 

1446 first dimension of all nested input tensors must be the same. 

1447 initial: A tensor or (possibly nested) sequence of tensors with initial 

1448 values for the accumulators. 

1449 reverse: (optional) True enables scan and output elems in reverse order. 

1450 inclusive: (optional) True includes the initial accumulator values in the 

1451 output. Length of output will be len(elem sequence) + 1. Not meaningful if 

1452 final_only is True. 

1453 final_only: (optional) When True, return only the final accumulated values, 

1454 not the concatenation of accumulated values for each input. 

1455 

1456 Returns: 

1457 A (possibly nested) sequence of tensors with the results of applying fn 

1458 to tensors unpacked from elems and previous accumulator values. 

1459 """ 

1460 

1461 flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)] 

1462 num_elems = array_ops.shape(flat_elems[0])[0] 

1463 pack_elems = lambda x: nest.pack_sequence_as(structure=elems, flat_sequence=x) 

1464 flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)] 

1465 pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x) 

1466 accum_dtypes = [x.dtype for x in flat_initial] 

1467 num_accums = len(flat_initial) 

1468 

1469 # Types for counter, [outputs], [accumulators] loop arguments. 

1470 if final_only: 

1471 loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes 

1472 else: 

1473 loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes + accum_dtypes 

1474 

1475 # TODO(tombagby): Update to tfe.defun 

1476 def cond(i, num_elems, *args): 

1477 del args 

1478 return i >= 0 if reverse else i < num_elems 

1479 

1480 # The loop *args are [output tensors] + [accumulator tensors] which must 

1481 # be paired. Each output corresponds to one accumulator. 

1482 def body(i, num_elems, *args): 

1483 """Loop body.""" 

1484 i.set_shape([]) 

1485 if final_only: 

1486 accum = args 

1487 else: 

1488 out, accum = args[:num_accums], args[num_accums:] 

1489 slices = [array_ops.gather(e, i) for e in flat_elems] 

1490 accum = fn(pack(accum), pack_elems(slices)) 

1491 flat_accum = nest.flatten(accum) 

1492 if final_only: 

1493 new_out = [] 

1494 else: 

1495 update_i = i + 1 if inclusive and not reverse else i 

1496 new_out = [ 

1497 inplace_ops.alias_inplace_update(x, update_i, y) 

1498 for x, y in zip(out, flat_accum) 

1499 ] 

1500 i = i - 1 if reverse else i + 1 

1501 return [i, num_elems] + new_out + flat_accum 

1502 

1503 init_i = ( 

1504 array_ops.shape(flat_elems[0])[0] - 

1505 1 if reverse else constant_op.constant(0, dtype=dtypes.int32)) 

1506 outputs = [] 

1507 if not final_only: 

1508 num_outputs = array_ops.shape(flat_elems[0])[0] + (1 if inclusive else 0) 

1509 for initial_accum in flat_initial: 

1510 out_shape = array_ops.concat( 

1511 [[num_outputs], array_ops.shape(initial_accum)], 0) 

1512 out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True) 

1513 if inclusive: 

1514 out = inplace_ops.alias_inplace_add(out, init_i + (1 if reverse else 0), 

1515 initial_accum) 

1516 outputs.append(out) 

1517 loop_in = [init_i, num_elems] + outputs + flat_initial 

1518 hostmem = [ 

1519 i for i, x in enumerate(loop_in) 

1520 if x.dtype.base_dtype in (dtypes.int32, dtypes.int64) 

1521 ] 

1522 

1523 if context.executing_eagerly(): 

1524 loop_results = loop_in 

1525 while cond(*loop_results): 

1526 loop_results = body(*loop_results) 

1527 else: 

1528 # TODO(tombagby): Update to while_v2. 

1529 cond = function.Defun(*loop_dtypes)(cond) 

1530 body = function.Defun(*loop_dtypes)(body) 

1531 loop_results = functional_ops.While(loop_in, cond, body, hostmem=hostmem) 

1532 out = loop_results[2:num_accums + 2] 

1533 return pack(out) 
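
# Illustrative note (added for illustration): the first docstring example
# above,
#   _scan(lambda a, e: a + e, constant_op.constant([1.0, 2.0, 3.0]),
#         constant_op.constant(1.0))
# accumulates a running sum and returns [2.0, 4.0, 7.0]; with inclusive=True
# the initial accumulator 1.0 is prepended to the output as well.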

1534 

1535 

1536def _get_dim(tensor, i): 

1537 """Get value of tensor shape[i] preferring static value if available.""" 

1538 return tensor_shape.dimension_value( 

1539 tensor.shape[i]) or array_ops.shape(tensor)[i]