Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ctc_ops.py: 18%
439 statements
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""CTC (Connectionist Temporal Classification) Operations."""
17import uuid
19from tensorflow.python.eager import context
20from tensorflow.python.eager import def_function
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import device
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import function
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import sparse_tensor
28from tensorflow.python.framework import tensor_shape
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import array_ops_stack
32from tensorflow.python.ops import custom_gradient
33from tensorflow.python.ops import functional_ops
34from tensorflow.python.ops import gen_ctc_ops
35from tensorflow.python.ops import inplace_ops
36from tensorflow.python.ops import linalg_ops
37from tensorflow.python.ops import map_fn
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import nn_ops
40from tensorflow.python.ops import sparse_ops
41from tensorflow.python.ops.nn_grad import _BroadcastMul
42from tensorflow.python.util import deprecation
43from tensorflow.python.util import dispatch
44from tensorflow.python.util import nest
45from tensorflow.python.util.tf_export import tf_export
47_DEFUN_API_NAME_ATTRIBUTE = "api_implements"
48_DEFUN_DEVICE_ATTRIBUTE = "api_preferred_device"
49_CPU_DEVICE_NAME = "CPU"
50_GPU_DEVICE_NAME = "GPU"
53def _get_context_device_type():
54 """Parses the current context and returns the device type, eg CPU/GPU."""
55 current_device = context.context().device_name
56 if current_device is None:
57 return None
58 return device.DeviceSpec.from_string(current_device).device_type
61def _generate_defun_backend(unique_api_name, preferred_device, func):
62 function_attributes = {
63 _DEFUN_API_NAME_ATTRIBUTE: unique_api_name,
64 _DEFUN_DEVICE_ATTRIBUTE: preferred_device,
65 }
66 return def_function.function(
67 func=func, experimental_attributes=function_attributes, autograph=False)
69# pylint: disable=protected-access, invalid-name
70@tf_export(v1=["nn.ctc_loss"])
71@dispatch.add_dispatch_support
72def ctc_loss(labels,
73 inputs=None,
74 sequence_length=None,
75 preprocess_collapse_repeated=False,
76 ctc_merge_repeated=True,
77 ignore_longer_outputs_than_inputs=False,
78 time_major=True,
79 logits=None):
80 """Computes the CTC (Connectionist Temporal Classification) Loss.
82 This op implements the CTC loss as presented in (Graves et al., 2006).
84 Input requirements:
86 ```
87 sequence_length(b) <= time for all b
89 max(labels.indices(labels.indices[:, 1] == b, 2))
90 <= sequence_length(b) for all b.
91 ```
93 Notes:
95 This op performs the softmax operation for you, so inputs should
96 be e.g. linear projections of outputs by an LSTM.
98 The `inputs` Tensor's innermost dimension size, `num_classes`, represents
99 `num_labels + 1` classes, where num_labels is the number of true labels, and
100 the largest value `(num_classes - 1)` is reserved for the blank label.
102 For example, for a vocabulary containing 3 labels `[a, b, c]`,
103 `num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`.
105 Regarding the arguments `preprocess_collapse_repeated` and
106 `ctc_merge_repeated`:
108 If `preprocess_collapse_repeated` is True, then a preprocessing step runs
109 before loss calculation, wherein repeated labels passed to the loss
110 are merged into single labels. This is useful if the training labels come
111 from, e.g., forced alignments and therefore have unnecessary repetitions.
113 If `ctc_merge_repeated` is set to False, then deep within the CTC calculation,
114 repeated non-blank labels will not be merged and are interpreted
115 as individual labels. This is a simplified (non-standard) version of CTC.
117 Here is a table of the (roughly) expected first order behavior:
119 * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=True`
121 Classical CTC behavior: Outputs true repeated classes with blanks in
122 between, and can also output repeated classes with no blanks in
123 between that need to be collapsed by the decoder.
125 * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=False`
127 Never learns to output repeated classes, as they are collapsed
128 in the input labels before training.
130 * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=False`
132 Outputs repeated classes with blanks in between, but generally does not
133 require the decoder to collapse/merge repeated classes.
135 * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=True`
137 Untested. Very likely will not learn to output repeated classes.
139 The `ignore_longer_outputs_than_inputs` option allows specifying the behavior
140 of the CTCLoss when dealing with sequences that have longer outputs than
141 inputs. If true, the CTCLoss will simply return zero gradient for those
142 items, otherwise an InvalidArgument error is returned, stopping training.
144 Args:
145 labels: An `int32` `SparseTensor`.
146 `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores the id
147 for (batch b, time t). `labels.values[i]` must take on values in `[0,
148 num_labels)`. See `core/ops/ctc_ops.cc` for more details.
149 inputs: 3-D `float` `Tensor`.
150 If time_major == False, this will be a `Tensor` shaped: `[batch_size,
151 max_time, num_classes]`.
152 If time_major == True (default), this will be a `Tensor` shaped:
153 `[max_time, batch_size, num_classes]`. The logits.
154 sequence_length: 1-D `int32` vector, size `[batch_size]`. The sequence
155 lengths.
156 preprocess_collapse_repeated: Boolean. Default: False. If True, repeated
157 labels are collapsed prior to the CTC calculation.
158 ctc_merge_repeated: Boolean. Default: True.
159 ignore_longer_outputs_than_inputs: Boolean. Default: False. If True,
160 sequences with longer outputs than inputs will be ignored.
161 time_major: The shape format of the `inputs` Tensors. If True, these
162 `Tensors` must be shaped `[max_time, batch_size, num_classes]`. If False,
163 these `Tensors` must be shaped `[batch_size, max_time, num_classes]`.
164 Using `time_major = True` (default) is a bit more efficient because it
165 avoids transposes at the beginning of the ctc_loss calculation. However,
166 most TensorFlow data is batch-major, so this function also accepts
167 inputs in batch-major form.
168 logits: Alias for inputs.
170 Returns:
171 A 1-D `float` `Tensor`, size `[batch]`, containing the negative log
172 probabilities.
174 Raises:
175 TypeError: if labels is not a `SparseTensor`.
177 References:
178 Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
179 with Recurrent Neural Networks:
180 [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
181 ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
182 """
183 return _ctc_loss_impl(
184 labels,
185 inputs,
186 sequence_length,
187 preprocess_collapse_repeated,
188 ctc_merge_repeated,
189 ignore_longer_outputs_than_inputs,
190 time_major,
191 logits,
192 use_cudnn=False)
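# Hedged usage sketch (added for illustration; not part of the original
# TensorFlow source). It shows one way to call the v1 ctc_loss above; the
# helper name and all shapes/label values are hypothetical.
def _example_ctc_loss_v1():
  import tensorflow as tf

  max_time, batch_size, num_classes = 5, 2, 4  # hypothetical sizes
  logits = tf.random.normal([max_time, batch_size, num_classes])
  # Label sequences [1, 2] and [2]; class num_classes - 1 == 3 is the blank.
  labels = tf.sparse.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=tf.constant([1, 2, 2], dtype=tf.int32),
      dense_shape=[batch_size, 2])
  seq_len = tf.fill([batch_size], max_time)
  # Returns a [batch_size] vector of negative log probabilities.
  return tf.compat.v1.nn.ctc_loss(
      labels=labels, inputs=logits, sequence_length=seq_len)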
195def _ctc_loss_impl(labels,
196 inputs=None,
197 sequence_length=None,
198 preprocess_collapse_repeated=False,
199 ctc_merge_repeated=True,
200 ignore_longer_outputs_than_inputs=False,
201 time_major=True,
202 logits=None,
203 use_cudnn=False):
204 # Helper function of ctc_loss with one additional param:
205 # use_cudnn: A bool to enable cuDNN CTC loss operation. If true, the blank
206 # index has to be 0.
208 # The second output tensor contains the gradients. We use it in
209 # _CTCLossGrad() below.
210 if not isinstance(labels, sparse_tensor.SparseTensor):
211 raise TypeError("Expected argument `labels` to be a SparseTensor. "
212 f"Received labels={labels} of type: "
213 f"{type(labels).__name__}")
215 # For internal calculations, we transpose to [time, batch, num_classes]
216 inputs = deprecation.deprecated_argument_lookup("logits", logits, "inputs",
217 inputs)
219 inputs = ops.convert_to_tensor(inputs, name="logits")
220 if not time_major:
221 inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,N) => (T,B,N)
223 orig_dtype = inputs.dtype
224 if orig_dtype in (dtypes.float16, dtypes.bfloat16):
225 inputs = math_ops.cast(inputs, dtypes.float32)
227 # gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the
228 # blank index to be 0, but v1 views it as the last index.
229 if use_cudnn:
230 ctc_loss_func = gen_ctc_ops.ctc_loss_v2
231 else:
232 ctc_loss_func = gen_ctc_ops.ctc_loss
234 loss, _ = ctc_loss_func(
235 inputs,
236 labels.indices,
237 labels.values,
238 sequence_length,
239 preprocess_collapse_repeated=preprocess_collapse_repeated,
240 ctc_merge_repeated=ctc_merge_repeated,
241 ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs)
243 if orig_dtype in (dtypes.float16, dtypes.bfloat16):
244 loss = math_ops.cast(loss, orig_dtype)
246 return loss
248# pylint: disable=unused-argument
249def _CTCLossGradImpl(op, grad_loss, _):
250 # Outputs are: loss, grad
251 #
252 # Currently there is no way to take the second derivative of this op
253 # due to the fused implementation's interaction with tf.gradients(),
254 # so we make sure we prevent silently incorrect results by raising
255 # an error if the second derivative is requested via prevent_gradient.
256 grad_without_gradient = array_ops.prevent_gradient(
257 op.outputs[1],
258 message="Currently there is no way to take the second "
259 " derivative of ctc_loss due to the fused implementation's interaction "
260 " with tf.gradients()")
261 # Return gradient for inputs and None for
262 # labels_indices, labels_values and sequence_length
263 return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]
266# pylint: disable=unused-argument
267@ops.RegisterGradient("CTCLoss")
268def _CTCLossGrad(op, grad_loss, _):
269 """The derivative provided by CTC Loss.
271 Args:
272 op: the CTCLoss op.
273 grad_loss: The backprop for cost.
275 Returns:
276 The CTC Loss gradient.
277 """
278 return _CTCLossGradImpl(op, grad_loss, _)
281# pylint: disable=unused-argument
282@ops.RegisterGradient("CTCLossV2")
283def _CTCLossV2Grad(op, grad_loss, _):
284 """The derivative provided by CTC Loss V2.
286 Args:
287 op: the CTCLossV2 op.
288 grad_loss: The backprop for cost.
290 Returns:
291 The CTC Loss V2 gradient.
292 """
293 return _CTCLossGradImpl(op, grad_loss, _)
296@tf_export("nn.ctc_greedy_decoder")
297@dispatch.add_dispatch_support
298def ctc_greedy_decoder(inputs,
299 sequence_length,
300 merge_repeated=True,
301 blank_index=None):
302 """Performs greedy decoding on the logits given in input (best path).
304 Given a tensor as `inputs`, the `blank_index` parameter defines the class
305 index of the blank symbol.
307 For example:
309 If `blank_index` is equal to 1:
311 >>> inf = float("inf")
312 >>> logits = tf.constant([[[ 0., -inf, -inf],
313 ... [ -2.3, -inf, -0.1]],
314 ... [[ -inf, -0.5, -inf],
315 ... [ -inf, -inf, -0.1]],
316 ... [[ -inf, -inf, -inf],
317 ... [ -0.1, -inf, -2.3]]])
318 >>> seq_lens = tf.constant([2, 3])
319 >>> outputs = tf.nn.ctc_greedy_decoder(
320 ... logits,
321 ... seq_lens,
322 ... blank_index=1)
324 Notes:
326 - Unlike `ctc_beam_search_decoder`, `ctc_greedy_decoder` considers blanks
327 as regular elements when computing the probability of a sequence.
328 - Default `blank_index` is `(num_classes - 1)`, unless overridden.
330 If `merge_repeated` is `True`, merge repeated classes in output.
331 This means that if consecutive logits' maximum indices are the same,
332 only the first of these is emitted. The sequence `A B B * B * B` (where '*'
333 is the blank label) becomes
335 * `A B B B` if `merge_repeated=True`.
336 * `A B B B B` if `merge_repeated=False`.
338 Args:
339 inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`.
340 The logits.
341 sequence_length: 1-D `int32` vector containing sequence lengths, having size
342 `[batch_size]`.
343 merge_repeated: Boolean. Default: True.
344 blank_index: (Optional). Default: `num_classes - 1`. Define the class index
345 to use for the blank label. Negative values will start from num_classes,
346 ie, -1 will reproduce the ctc_greedy_decoder behavior of using
347 num_classes - 1 for the blank symbol, which corresponds to the default.
349 Returns:
350 A tuple `(decoded, neg_sum_logits)` where
352 decoded: A single-element list. `decoded[0]`
353 is a `SparseTensor` containing the decoded outputs s.t.:
355 `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`.
356 The rows store: `[batch, time]`.
358 `decoded.values`: Values vector, size `(total_decoded_outputs)`.
359 The vector stores the decoded classes.
361 `decoded.dense_shape`: Shape vector, size `(2)`.
362 The shape values are: `[batch_size, max_decoded_length]`
364 neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
365 sequence found, the negative of the sum of the greatest logit at each
366 timeframe.
367 """
369 outputs = gen_ctc_ops.ctc_greedy_decoder(
370 inputs,
371 sequence_length,
372 merge_repeated=merge_repeated,
373 blank_index=blank_index)
374 (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs
375 return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val,
376 decoded_shape)], log_probabilities)
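# Hedged follow-up sketch (added for illustration; not part of the original
# TensorFlow source): the decoder returns SparseTensors, so a dense view is
# often convenient for inspection. The helper name and sizes are hypothetical.
def _example_ctc_greedy_decoder_to_dense():
  import tensorflow as tf

  max_time, batch_size, num_classes = 6, 2, 5  # hypothetical sizes
  logits = tf.random.normal([max_time, batch_size, num_classes])
  seq_len = tf.fill([batch_size], max_time)
  (decoded,), neg_sum_logits = tf.nn.ctc_greedy_decoder(logits, seq_len)
  # Pad with -1 so padding is distinguishable from class 0.
  dense_decoded = tf.sparse.to_dense(decoded, default_value=-1)
  return dense_decoded, neg_sum_logits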
379@tf_export(v1=["nn.ctc_beam_search_decoder"])
380@dispatch.add_dispatch_support
381def ctc_beam_search_decoder(inputs,
382 sequence_length,
383 beam_width=100,
384 top_paths=1,
385 merge_repeated=True):
386 """Performs beam search decoding on the logits given in input.
388 **Note** Although in general greedy search is a special case of beam-search
389 with `top_paths=1` and `beam_width=1`, `ctc_beam_search_decoder` differs
390 from `ctc_greedy_decoder` in the treatment of blanks when computing the
391 probability of a sequence:
392 - `ctc_beam_search_decoder` treats blanks as sequence termination
393 - `ctc_greedy_decoder` treats blanks as regular elements
395 If `merge_repeated` is `True`, merge repeated classes in the output beams.
396 This means that if consecutive entries in a beam are the same,
397 only the first of these is emitted. That is, when the sequence is
398 `A B B * B * B` (where '*' is the blank label), the return value is:
400 * `A B` if `merge_repeated = True`.
401 * `A B B B` if `merge_repeated = False`.
403 Args:
404 inputs: 3-D `float` `Tensor`, size `[max_time x batch_size x num_classes]`.
405 The logits.
406 sequence_length: 1-D `int32` vector containing sequence lengths, having size
407 `[batch_size]`.
408 beam_width: An int scalar >= 0 (beam search beam width).
409 top_paths: An int scalar >= 0, <= beam_width (controls output size).
410 merge_repeated: Boolean. Default: True.
412 Returns:
413 A tuple `(decoded, log_probabilities)` where
415 decoded: A list of length top_paths, where `decoded[j]`
416 is a `SparseTensor` containing the decoded outputs:
418 `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`
419 The rows store: [batch, time].
421 `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
422 The vector stores the decoded classes for beam j.
424 `decoded[j].dense_shape`: Shape vector, size `(2)`.
425 The shape values are: `[batch_size, max_decoded_length[j]]`.
427 log_probability: A `float` matrix `(batch_size x top_paths)` containing
428 sequence log-probabilities.
429 """
431 decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
432 gen_ctc_ops.ctc_beam_search_decoder(
433 inputs,
434 sequence_length,
435 beam_width=beam_width,
436 top_paths=top_paths,
437 merge_repeated=merge_repeated))
439 return ([
440 sparse_tensor.SparseTensor(ix, val, shape)
441 for (ix, val, shape) in zip(decoded_ixs, decoded_vals, decoded_shapes)
442 ], log_probabilities)
445@tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"])
446@dispatch.add_dispatch_support
447def ctc_beam_search_decoder_v2(inputs,
448 sequence_length,
449 beam_width=100,
450 top_paths=1):
451 """Performs beam search decoding on the logits given in input.
453 **Note** Although in general greedy search is a special case of beam-search
454 with `top_paths=1` and `beam_width=1`, `ctc_beam_search_decoder` differs
455 from `ctc_greedy_decoder` in the treatment of blanks when computing the
456 probability of a sequence:
457 - `ctc_beam_search_decoder` treats blanks as sequence termination
458 - `ctc_greedy_decoder` treats blanks as regular elements
460 Args:
461 inputs: 3-D `float` `Tensor`, size `[max_time, batch_size, num_classes]`.
462 The logits.
463 sequence_length: 1-D `int32` vector containing sequence lengths, having size
464 `[batch_size]`.
465 beam_width: An int scalar >= 0 (beam search beam width).
466 top_paths: An int scalar >= 0, <= beam_width (controls output size).
468 Returns:
469 A tuple `(decoded, log_probabilities)` where
471 decoded: A list of length top_paths, where `decoded[j]`
472 is a `SparseTensor` containing the decoded outputs:
474 `decoded[j].indices`: Indices matrix `[total_decoded_outputs[j], 2]`;
475 The rows store: `[batch, time]`.
477 `decoded[j].values`: Values vector, size `[total_decoded_outputs[j]]`.
478 The vector stores the decoded classes for beam `j`.
480 `decoded[j].dense_shape`: Shape vector, size `(2)`.
481 The shape values are: `[batch_size, max_decoded_length[j]]`.
483 log_probability: A `float` matrix `[batch_size, top_paths]` containing
484 sequence log-probabilities.
485 """
487 # Note, merge_repeated is an invalid optimization that is removed from the
488 # public API: it returns low probability paths.
489 return ctc_beam_search_decoder(
490 inputs,
491 sequence_length=sequence_length,
492 beam_width=beam_width,
493 top_paths=top_paths,
494 merge_repeated=False)
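# Hedged usage sketch (added for illustration; not part of the original
# TensorFlow source): with top_paths > 1 the v2 decoder returns one
# SparseTensor per beam, best first. The helper name and sizes are
# hypothetical.
def _example_ctc_beam_search_decoder_v2():
  import tensorflow as tf

  max_time, batch_size, num_classes = 10, 2, 6  # hypothetical sizes
  logits = tf.random.normal([max_time, batch_size, num_classes])
  seq_len = tf.fill([batch_size], max_time)
  decoded, log_probs = tf.nn.ctc_beam_search_decoder(
      logits, seq_len, beam_width=20, top_paths=3)
  best_path = tf.sparse.to_dense(decoded[0], default_value=-1)
  return best_path, log_probs  # log_probs has shape [batch_size, top_paths]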
497ops.NotDifferentiable("CTCGreedyDecoder")
498ops.NotDifferentiable("CTCBeamSearchDecoder")
501def _ctc_state_trans(label_seq):
502 """Computes CTC alignment model transition matrix.
504 Args:
505 label_seq: tensor of shape [batch_size, max_seq_length]
507 Returns:
508 tensor of shape [batch_size, states, states] with a state transition matrix
509 computed for each sequence of the batch.
510 """
512 with ops.name_scope("ctc_state_trans"):
513 label_seq = ops.convert_to_tensor(label_seq, name="label_seq")
514 batch_size = _get_dim(label_seq, 0)
515 num_labels = _get_dim(label_seq, 1)
517 num_label_states = num_labels + 1
518 num_states = 2 * num_label_states
520 label_states = math_ops.range(num_label_states)
521 blank_states = label_states + num_label_states
523 # Start state to first label.
524 start_to_label = [[1, 0]]
526 # Blank to label transitions.
527 blank_to_label = array_ops_stack.stack(
528 [label_states[1:], blank_states[:-1]], 1)
530 # Label to blank transitions.
531 label_to_blank = array_ops_stack.stack([blank_states, label_states], 1)
533 # Scatter transitions that don't depend on sequence.
534 indices = array_ops.concat([start_to_label, blank_to_label, label_to_blank],
535 0)
536 values = array_ops.ones([_get_dim(indices, 0)])
537 trans = array_ops.scatter_nd(
538 indices, values, shape=[num_states, num_states])
539 trans += linalg_ops.eye(num_states) # Self-loops.
541 # Label to label transitions. Disallow transitions between repeated labels
542 # with no blank state in between.
543 batch_idx = array_ops.zeros_like(label_states[2:])
544 indices = array_ops_stack.stack(
545 [batch_idx, label_states[2:], label_states[1:-1]], 1)
546 indices = array_ops.tile(
547 array_ops.expand_dims(indices, 0), [batch_size, 1, 1])
548 batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0]
549 indices += array_ops.expand_dims(batch_idx, 1)
550 repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:])
551 values = 1.0 - math_ops.cast(repeats, dtypes.float32)
552 batched_shape = [batch_size, num_states, num_states]
553 label_to_label = array_ops.scatter_nd(indices, values, batched_shape)
555 return array_ops.expand_dims(trans, 0) + label_to_label
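# Hedged illustration (added; not part of the original TensorFlow source) of
# the state layout used by _ctc_state_trans: a sequence of two labels yields
# 2 * (2 + 1) = 6 states (label states followed by their blank states). The
# helper name is hypothetical.
def _example_ctc_state_trans():
  label_seq = constant_op.constant([[1, 2]], dtype=dtypes.int32)
  trans = _ctc_state_trans(label_seq)
  return trans  # shape [1, 6, 6]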
558def ctc_state_log_probs(seq_lengths, max_seq_length):
559 """Computes CTC alignment initial and final state log probabilities.
561 Create the initial/final state values directly as log values to avoid
562 having to take a float64 log on TPU (where float64 log does not exist).
564 Args:
565 seq_lengths: int tensor of shape [batch_size], seq lengths in the batch.
566 max_seq_length: int, max sequence length possible.
568 Returns:
569 initial_state_log_probs, final_state_log_probs
570 """
572 batch_size = _get_dim(seq_lengths, 0)
573 num_label_states = max_seq_length + 1
574 num_duration_states = 2
575 num_states = num_duration_states * num_label_states
576 log_0 = math_ops.cast(
577 math_ops.log(math_ops.cast(0, dtypes.float64) + 1e-307), dtypes.float32)
579 initial_state_log_probs = array_ops.one_hot(
580 indices=array_ops.zeros([batch_size], dtype=dtypes.int32),
581 depth=num_states,
582 on_value=0.0,
583 off_value=log_0,
584 axis=1)
586 label_final_state_mask = array_ops.one_hot(
587 seq_lengths, depth=num_label_states, axis=0)
588 duration_final_state_mask = array_ops.ones(
589 [num_duration_states, 1, batch_size])
590 final_state_mask = duration_final_state_mask * label_final_state_mask
591 final_state_log_probs = (1.0 - final_state_mask) * log_0
592 final_state_log_probs = array_ops.reshape(final_state_log_probs,
593 [num_states, batch_size])
595 return initial_state_log_probs, array_ops.transpose(final_state_log_probs)
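# Hedged illustration (added; not part of the original TensorFlow source):
# for a batch of two sequences with max_seq_length=3 there are
# 2 * (3 + 1) = 8 states, so both returned tensors have shape [2, 8]. The
# helper name is hypothetical.
def _example_ctc_state_log_probs():
  seq_lengths = constant_op.constant([2, 3], dtype=dtypes.int32)
  initial_lp, final_lp = ctc_state_log_probs(seq_lengths, max_seq_length=3)
  return initial_lp, final_lp  # each of shape [2, 8]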
598def _ilabel_to_state(labels, num_labels, ilabel_log_probs):
599 """Project ilabel log probs to state log probs."""
601 num_label_states = _get_dim(labels, 1)
602 blank = ilabel_log_probs[:, :, :1]
603 blank = array_ops.tile(blank, [1, 1, num_label_states + 1])
604 one_hot = array_ops.one_hot(labels, depth=num_labels)
605 one_hot = array_ops.expand_dims(one_hot, axis=0)
606 ilabel_log_probs = array_ops.expand_dims(ilabel_log_probs, axis=2)
607 state_log_probs = math_ops.reduce_sum(ilabel_log_probs * one_hot, axis=3)
608 state_log_probs = array_ops.concat([state_log_probs, blank], axis=2)
609 return array_ops.pad(
610 state_log_probs, [[0, 0], [0, 0], [1, 0]],
611 constant_values=math_ops.log(0.0))
614def _state_to_olabel(labels, num_labels, states):
615 """Sum state log probs to ilabel log probs."""
617 num_label_states = _get_dim(labels, 1) + 1
618 label_states = states[:, :, 1:num_label_states]
619 blank_states = states[:, :, num_label_states:]
620 one_hot = array_ops.one_hot(
621 labels - 1,
622 depth=(num_labels - 1),
623 on_value=0.0,
624 off_value=math_ops.log(0.0))
625 one_hot = array_ops.expand_dims(one_hot, axis=0)
626 label_states = array_ops.expand_dims(label_states, axis=3)
627 label_olabels = math_ops.reduce_logsumexp(label_states + one_hot, axis=2)
628 blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True)
629 return array_ops.concat([blank_olabels, label_olabels], axis=-1)
632# pylint: disable=redefined-outer-name
633def _state_to_olabel_unique(labels, num_labels, states, unique):
634 """Sum state log probs to ilabel log probs using unique label indices."""
636 num_label_states = _get_dim(labels, 1) + 1
637 label_states = states[:, :, 1:num_label_states]
638 blank_states = states[:, :, num_label_states:]
640 unique_y, unique_idx = unique
641 mul_reduce = _sum_states(unique_idx, label_states)
643 num_frames = _get_dim(states, 0)
644 batch_size = _get_dim(states, 1)
645 num_states = num_label_states - 1
646 batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0])
647 batch_state_major = array_ops.reshape(batch_state_major,
648 [batch_size * num_states, num_frames])
649 batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels
650 indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1)
651 indices = array_ops.reshape(indices, [-1, 1])
652 scatter = array_ops.scatter_nd(
653 indices=indices,
654 updates=batch_state_major,
655 shape=[batch_size * num_labels, num_frames])
656 scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])
658 mask = array_ops.ones_like(batch_state_major, dtype=dtypes.bool)
659 mask = array_ops.scatter_nd(
660 indices=indices,
661 updates=mask,
662 shape=[batch_size * num_labels, num_frames])
663 mask = array_ops.reshape(mask, [batch_size, num_labels, num_frames])
665 scatter = array_ops.where(
666 mask, scatter,
667 array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)))
669 label_olabels = array_ops.transpose(scatter, [2, 0, 1])
670 label_olabels = label_olabels[:, :, 1:]
672 blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True)
674 return array_ops.concat([blank_olabels, label_olabels], axis=-1)
677def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None):
678 """Computes the CTC loss and gradients.
680 Most users will want fwd_bwd.ctc_loss
682 This function returns the computed gradient; it does not have a gradient
683 of its own defined.
685 Args:
686 logits: tensor of shape [frames, batch_size, num_labels]
687 labels: tensor of shape [batch_size, max_label_seq_length]
688 label_length: tensor of shape [batch_size] Length of reference label
689 sequence in labels.
690 logit_length: tensor of shape [batch_size] Length of input sequence in
691 logits.
692 unique: (optional) unique label indices as computed by unique(labels). If
693 supplied, enables an implementation that is faster and more memory
694 efficient on TPU.
696 Returns:
697 loss: tensor of shape [batch_size]
698 gradient: tensor of shape [frames, batch_size, num_labels]
699 """
701 num_labels = _get_dim(logits, 2)
702 max_label_seq_length = _get_dim(labels, 1)
704 ilabel_log_probs = nn_ops.log_softmax(logits)
705 state_log_probs = _ilabel_to_state(labels, num_labels, ilabel_log_probs)
706 state_trans_probs = _ctc_state_trans(labels)
707 initial_state_log_probs, final_state_log_probs = ctc_state_log_probs(
708 label_length, max_label_seq_length)
709 fwd_bwd_log_probs, log_likelihood = _forward_backward_log(
710 state_trans_log_probs=math_ops.log(state_trans_probs),
711 initial_state_log_probs=initial_state_log_probs,
712 final_state_log_probs=final_state_log_probs,
713 observed_log_probs=state_log_probs,
714 sequence_length=logit_length)
716 if unique:
717 olabel_log_probs = _state_to_olabel_unique(labels, num_labels,
718 fwd_bwd_log_probs, unique)
719 else:
720 olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs)
722 grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs)
724 # Applies the sequence mask to the gradient. It is enough to apply the mask
725 # only to ilabel_log_probs because olabel_log_probs already accounts for the
726 # mask. However, it is safer and cleaner to also apply it to the gradient.
727 max_logit_length = _get_dim(logits, 0)
728 logit_mask = array_ops.sequence_mask(logit_length, max_logit_length,
729 dtypes.float32)
730 logit_mask = array_ops.transpose(logit_mask, perm=[1, 0])
731 logit_mask = array_ops.expand_dims(logit_mask, axis=2)
732 grad *= logit_mask
734 loss = -log_likelihood
735 return loss, grad
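# Hedged usage sketch (added; not part of the original TensorFlow source):
# calling the helper above directly with dense labels. Label value 0 is the
# blank; all shapes and values are hypothetical.
def _example_ctc_loss_and_grad():
  logits = array_ops.zeros([10, 2, 5])  # [frames, batch_size, num_labels]
  labels = constant_op.constant([[1, 2, 3], [2, 4, 0]], dtype=dtypes.int32)
  label_length = constant_op.constant([3, 2], dtype=dtypes.int32)
  logit_length = constant_op.constant([10, 10], dtype=dtypes.int32)
  loss, grad = ctc_loss_and_grad(logits, labels, label_length, logit_length)
  return loss, grad  # shapes [2] and [10, 2, 5]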
738def _ctc_loss_grad(op, grad_loss, _):
739 grad = op.outputs[1]
740 grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * grad]
741 grad += [None] * (len(op.inputs) - len(grad))
742 return grad
745def _ctc_loss_op_standard(labels, logits, logit_length, logits_time_major,
746 blank_index):
747 part_before = logits[:, :, :blank_index]
748 part_after = logits[:, :, blank_index + 1:]
749 part_blank = logits[:, :, blank_index:blank_index + 1]
750 logits = array_ops.concat([part_before, part_after, part_blank], axis=2)
751 labels = sparse_tensor.SparseTensor(
752 labels.indices,
753 array_ops.where(labels.values < blank_index, labels.values,
754 labels.values - 1), labels.dense_shape)
755 return _ctc_loss_impl(
756 labels=labels,
757 inputs=logits,
758 sequence_length=logit_length,
759 time_major=logits_time_major,
760 use_cudnn=False)
763def _ctc_loss_op_cudnn(labels, logits, logit_length, logits_time_major,
764 blank_index):
765 part_before = logits[:, :, :blank_index]
766 part_after = logits[:, :, blank_index + 1:]
767 part_blank = logits[:, :, blank_index:blank_index + 1]
768 logits = array_ops.concat([part_blank, part_before, part_after], axis=2)
769 labels = sparse_tensor.SparseTensor(
770 labels.indices,
771 array_ops.where(labels.values < blank_index, labels.values + 1,
772 labels.values), labels.dense_shape)
773 return _ctc_loss_impl(
774 labels=labels,
775 inputs=logits,
776 sequence_length=logit_length,
777 time_major=logits_time_major,
778 use_cudnn=True)
781def _ctc_loss_shape(op):
782 return [op.inputs[2].get_shape(), op.inputs[0].get_shape()]
785# pylint: disable=protected-access, invalid-name
786@tf_export(v1=["nn.ctc_loss_v2"])
787@dispatch.add_dispatch_support
788def ctc_loss_v2(labels,
789 logits,
790 label_length,
791 logit_length,
792 logits_time_major=True,
793 unique=None,
794 blank_index=None,
795 name=None):
796 """Computes CTC (Connectionist Temporal Classification) loss.
798 This op implements the CTC loss as presented in (Graves et al., 2006).
800 Notes:
802 - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
803 setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True
804 - Labels may be supplied as either a dense, zero-padded tensor with a
805 vector of label sequence lengths OR as a SparseTensor.
806 - On TPU and GPU: Only dense padded labels are supported.
807 - On CPU: Caller may use SparseTensor or dense padded labels but calling with
808 a SparseTensor will be significantly faster.
809 - Default blank label is 0 rather than num_classes - 1, unless overridden by
810 blank_index.
812 Args:
813 labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
814 logits: tensor of shape [frames, batch_size, num_labels], if
815 logits_time_major == False, shape is [batch_size, frames, num_labels].
816 label_length: tensor of shape [batch_size], None if labels is SparseTensor
817 Length of reference label sequence in labels.
818 logit_length: tensor of shape [batch_size] Length of input sequence in
819 logits.
820 logits_time_major: (optional) If True (default), logits is shaped [time,
821 batch, logits]. If False, shape is [batch, time, logits]
822 unique: (optional) Unique label indices as computed by
823 ctc_unique_labels(labels). If supplied, enables a faster, more
824 memory-efficient implementation on TPU.
825 blank_index: (optional) Set the class index to use for the blank label.
826 Negative values will start from num_classes, ie, -1 will reproduce the
827 ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
828 some memory/performance overhead to switching from the default of 0 as an
829 additional shifted copy of the logits may be created.
830 name: A name for this `Op`. Defaults to "ctc_loss_dense".
832 Returns:
833 loss: tensor of shape [batch_size], negative log probabilities.
835 References:
836 Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
837 with Recurrent Neural Networks:
838 [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
839 ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
840 """
841 if isinstance(labels, sparse_tensor.SparseTensor):
842 if blank_index is None:
843 raise ValueError(
844 "Argument `blank_index` must be provided when labels is a "
845 "SparseTensor.")
847 if blank_index < 0:
848 blank_index += _get_dim(logits, 2)
850 if blank_index != _get_dim(logits, 2) - 1:
851 logits = array_ops.concat([
852 logits[:, :, :blank_index],
853 logits[:, :, blank_index + 1:],
854 logits[:, :, blank_index:blank_index + 1],
855 ],
856 axis=2)
857 labels = sparse_tensor.SparseTensor(
858 labels.indices,
859 array_ops.where(labels.values < blank_index, labels.values,
860 labels.values - 1), labels.dense_shape)
862 return ctc_loss(
863 labels=labels,
864 inputs=logits,
865 sequence_length=logit_length,
866 time_major=logits_time_major)
868 if blank_index is None:
869 blank_index = 0
871 return ctc_loss_dense(
872 labels=labels,
873 logits=logits,
874 label_length=label_length,
875 logit_length=logit_length,
876 logits_time_major=logits_time_major,
877 unique=unique,
878 blank_index=blank_index,
879 name=name)
882@tf_export("nn.ctc_loss", v1=[])
883@dispatch.add_dispatch_support
884def ctc_loss_v3(labels,
885 logits,
886 label_length,
887 logit_length,
888 logits_time_major=True,
889 unique=None,
890 blank_index=None,
891 name=None):
892 """Computes CTC (Connectionist Temporal Classification) loss.
894 This op implements the CTC loss as presented in
895 [Graves et al., 2006](https://www.cs.toronto.edu/~graves/icml_2006.pdf)
897 Connectionist temporal classification (CTC) is a type of neural network output
898 and associated scoring function, for training recurrent neural networks (RNNs)
899 such as LSTM networks to tackle sequence problems where the timing is
900 variable. It can be used for tasks like on-line handwriting recognition or
901 recognizing phones in speech audio. CTC refers to the outputs and scoring, and
902 is independent of the underlying neural network structure.
904 Notes:
906 - This op performs the softmax operation for you, so `logits` should be
907 e.g. linear projections of outputs by an LSTM.
908 - Outputs true repeated classes with blanks in between, and can also output
909 repeated classes with no blanks in between that need to be collapsed by the
910 decoder.
911 - `labels` may be supplied as either a dense, zero-padded `Tensor` with a
912 vector of label sequence lengths OR as a `SparseTensor`.
913 - On TPU: Only dense padded `labels` are supported.
914 - On CPU and GPU: Caller may use `SparseTensor` or dense padded `labels`
915 but calling with a `SparseTensor` will be significantly faster.
916 - Default blank label is `0` instead of `num_labels - 1` (where `num_labels`
917 is the innermost dimension size of `logits`), unless overridden by
918 `blank_index`.
920 >>> tf.random.set_seed(50)
921 >>> batch_size = 8
922 >>> num_labels = 6
923 >>> max_label_length = 5
924 >>> num_frames = 12
925 >>> labels = tf.random.uniform([batch_size, max_label_length],
926 ... minval=1, maxval=num_labels, dtype=tf.int64)
927 >>> logits = tf.random.uniform([num_frames, batch_size, num_labels])
928 >>> label_length = tf.random.uniform([batch_size], minval=2,
929 ... maxval=max_label_length, dtype=tf.int64)
930 >>> label_mask = tf.sequence_mask(label_length, maxlen=max_label_length,
931 ... dtype=label_length.dtype)
932 >>> labels *= label_mask
933 >>> logit_length = [num_frames] * batch_size
934 >>> with tf.GradientTape() as t:
935 ... t.watch(logits)
936 ... ref_loss = tf.nn.ctc_loss(
937 ... labels=labels,
938 ... logits=logits,
939 ... label_length=label_length,
940 ... logit_length=logit_length,
941 ... blank_index=0)
942 >>> ref_grad = t.gradient(ref_loss, logits)
944 Args:
945 labels: `Tensor` of shape `[batch_size, max_label_seq_length]` or
946 `SparseTensor`.
947 logits: `Tensor` of shape `[frames, batch_size, num_labels]`. If
948 `logits_time_major == False`, shape is `[batch_size, frames, num_labels]`.
949 label_length: `Tensor` of shape `[batch_size]`. None, if `labels` is a
950 `SparseTensor`. Length of reference label sequence in `labels`.
951 logit_length: `Tensor` of shape `[batch_size]`. Length of input sequence in
952 `logits`.
953 logits_time_major: (optional) If True (default), `logits` is shaped [frames,
954 batch_size, num_labels]. If False, shape is
955 `[batch_size, frames, num_labels]`.
956 unique: (optional) Unique label indices as computed by
957 `ctc_unique_labels(labels)`. If supplied, enables a faster, more
958 memory-efficient implementation on TPU.
959 blank_index: (optional) Set the class index to use for the blank label.
960 Negative values will start from `num_labels`, ie, `-1` will reproduce the
961 ctc_loss behavior of using `num_labels - 1` for the blank symbol. There is
962 some memory/performance overhead to switching from the default of 0 as an
963 additional shifted copy of `logits` may be created.
964 name: A name for this `Op`. Defaults to "ctc_loss_dense".
966 Returns:
967 loss: A 1-D `float Tensor` of shape `[batch_size]`, containing negative log
968 probabilities.
970 Raises:
971 ValueError: Argument `blank_index` must be provided when `labels` is a
972 `SparseTensor`.
974 References:
975 Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
976 with Recurrent Neural Networks:
977 [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
978 ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
980 https://en.wikipedia.org/wiki/Connectionist_temporal_classification
981 """
982 if isinstance(labels, sparse_tensor.SparseTensor):
983 if blank_index is None:
984 raise ValueError(
985 "Argument `blank_index` must be provided when `labels` is a "
986 "`SparseTensor`.")
988 if blank_index < 0:
989 blank_index += _get_dim(logits, 2)
991 logits = ops.convert_to_tensor(logits, name="logits")
993 params = {
994 "labels": labels,
995 "logits": logits,
996 "logit_length": logit_length,
997 "logits_time_major": logits_time_major,
998 "blank_index": blank_index
999 }
1001 if context.executing_eagerly():
1002 device_type = _get_context_device_type()
1003 can_use_gpu = (
1004 # Either user specified GPU or unspecified but GPU is available.
1005 (device_type == _GPU_DEVICE_NAME or
1006 (device_type is None and context.num_gpus() > 0)))
1007 # Under eager context, check the device placement and prefer the GPU implementation when a GPU is available.
1008 if can_use_gpu:
1009 res = _ctc_loss_op_cudnn(**params)
1010 else:
1011 res = _ctc_loss_op_standard(**params)
1012 else:
1013 api_name = "ctc_loss_" + str(uuid.uuid4())
1014 ctc_loss_op_standard = _generate_defun_backend(api_name, _CPU_DEVICE_NAME,
1015 _ctc_loss_op_standard)
1016 ctc_loss_op_cudnn = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
1017 _ctc_loss_op_cudnn)
1018 res = ctc_loss_op_standard(**params)
1019 concrete_func = ctc_loss_op_cudnn.get_concrete_function(**params)
1020 concrete_func.add_to_graph()
1021 concrete_func.add_gradient_functions_to_graph()
1022 return res
1024 if blank_index is None:
1025 blank_index = 0
1027 return ctc_loss_dense(
1028 labels=labels,
1029 logits=logits,
1030 label_length=label_length,
1031 logit_length=logit_length,
1032 logits_time_major=logits_time_major,
1033 unique=unique,
1034 blank_index=blank_index,
1035 name=name)
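# Hedged usage sketch (added; not part of the original TensorFlow source):
# the SparseTensor path of tf.nn.ctc_loss, which requires an explicit
# blank_index. The helper name, shapes, and label values are hypothetical.
def _example_ctc_loss_v3_sparse_labels():
  import tensorflow as tf

  num_frames, batch_size, num_labels = 12, 2, 6  # hypothetical sizes
  logits = tf.random.normal([num_frames, batch_size, num_labels])
  # Label sequences [1, 2, 3] and [4, 5]; class 0 is used as the blank here.
  labels = tf.sparse.SparseTensor(
      indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
      values=tf.constant([1, 2, 3, 4, 5], dtype=tf.int32),
      dense_shape=[batch_size, 3])
  logit_length = tf.fill([batch_size], num_frames)
  return tf.nn.ctc_loss(
      labels=labels,
      logits=logits,
      label_length=None,  # unused when labels is a SparseTensor
      logit_length=logit_length,
      blank_index=0)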
1038def ctc_loss_dense(labels,
1039 logits,
1040 label_length,
1041 logit_length,
1042 logits_time_major=True,
1043 unique=None,
1044 blank_index=0,
1045 name=None):
1046 """Computes CTC (Connectionist Temporal Classification) loss.
1048 This op implements the CTC loss as presented in (Graves et al., 2006),
1049 using the batched forward backward algorithm described in (Sim et al., 2017).
1051 Notes:
1052 Significant differences from `tf.compat.v1.nn.ctc_loss`:
1053 Supports GPU and TPU (`tf.compat.v1.nn.ctc_loss` supports CPU only):
1054 For batched operations, GPU and TPU are significantly faster than using
1055 `ctc_loss` on CPU.
1056 This implementation also runs on CPU, but is significantly slower than ctc_loss.
1057 Blank label is 0 rather than num_classes - 1, unless overridden by blank_index.
1058 Logits and labels are dense arrays with padding rather than SparseTensor.
1059 The only mode supported is the same as:
1060 preprocess_collapse_repeated=False, ctc_merge_repeated=True
1061 To collapse labels, the caller can preprocess label sequence first.
1063 The dense implementation supports CPU, GPU, and TPU. A fast path is
1064 provided that significantly improves memory use for a large vocabulary if the
1065 caller preprocesses label sequences to get unique label indices on the CPU
1066 (eg. in the data input pipeline) using ctc_ops.unique and supplies this in
1067 the optional "unique" kwarg. This is especially useful for TPU and GPU but
1068 also works if used on CPU.
1070 Args:
1071 labels: tensor of shape [batch_size, max_label_seq_length]
1072 logits: tensor of shape [frames, batch_size, num_labels], if
1073 logits_time_major == False, shape is [batch_size, frames, num_labels].
1074 label_length: tensor of shape [batch_size] Length of reference label
1075 sequence in labels.
1076 logit_length: tensor of shape [batch_size] Length of input sequence in
1077 logits.
1078 logits_time_major: (optional) If True (default), logits is shaped [time,
1079 batch, logits]. If False, shape is [batch, time, logits]
1080 unique: (optional) Unique label indices as computed by unique(labels). If
1081 supplied, enables a faster, more memory-efficient implementation on TPU.
1082 blank_index: (optional) Set the class index to use for the blank label.
1083 Negative values will start from num_classes, ie, -1 will reproduce the
1084 ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
1085 some memory/performance overhead to switching from the default of 0 as an
1086 additional shifted copy of the logits may be created.
1087 name: A name for this `Op`. Defaults to "ctc_loss_dense".
1089 Returns:
1090 loss: tensor of shape [batch_size], negative log probabilities.
1092 References:
1093 Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
1094 with Recurrent Neural Networks:
1095 [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
1096 ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
1097 Improving the efficiency of forward-backward algorithm using batched
1098 computation in TensorFlow:
1099 [Sim et al., 2017](https://ieeexplore.ieee.org/document/8268944)
1100 ([pdf](http://bacchiani.net/resume/papers/ASRU2017.pdf))
1101 """
1103 with ops.name_scope(name, "ctc_loss_dense",
1104 [logits, labels, label_length, logit_length]):
1105 logits = ops.convert_to_tensor(logits, name="logits")
1106 labels = ops.convert_to_tensor(labels, name="labels")
1107 label_length = ops.convert_to_tensor(label_length, name="label_length")
1108 logit_length = ops.convert_to_tensor(logit_length, name="logit_length")
1110 orig_dtype = logits.dtype
1111 if orig_dtype in (dtypes.float16, dtypes.bfloat16):
1112 logits = math_ops.cast(logits, dtypes.float32)
1114 if not logits_time_major:
1115 logits = array_ops.transpose(logits, perm=[1, 0, 2])
1117 if blank_index != 0:
1118 if blank_index < 0:
1119 blank_index += _get_dim(logits, 2)
1120 logits = array_ops.concat([
1121 logits[:, :, blank_index:blank_index + 1],
1122 logits[:, :, :blank_index],
1123 logits[:, :, blank_index + 1:],
1124 ],
1125 axis=2)
1126 labels = array_ops.where(labels < blank_index, labels + 1, labels)
1128 args = [logits, labels, label_length, logit_length]
1130 if unique:
1131 unique_y, unique_idx = unique
1132 if blank_index != 0:
1133 unique_y = array_ops.where(unique_y < blank_index, unique_y + 1,
1134 unique_y)
1135 label_mask_len = math_ops.reduce_max(unique_idx, axis=1) + 1
1136 max_label_length = _get_dim(unique_y, 1)
1137 label_mask = array_ops.sequence_mask(label_mask_len, max_label_length)
1138 unique_y = array_ops.where(label_mask, unique_y,
1139 array_ops.zeros_like(unique_y))
1140 args.extend([unique_y, unique_idx])
1142 @custom_gradient.custom_gradient
1143 def compute_ctc_loss(logits_t, labels_t, label_length_t, logit_length_t,
1144 *unique_t):
1145 """Compute CTC loss."""
1146 logits_t.set_shape(logits.shape)
1147 labels_t.set_shape(labels.shape)
1148 label_length_t.set_shape(label_length.shape)
1149 logit_length_t.set_shape(logit_length.shape)
1150 kwargs = dict(
1151 logits=logits_t,
1152 labels=labels_t,
1153 label_length=label_length_t,
1154 logit_length=logit_length_t)
1155 if unique_t:
1156 kwargs["unique"] = unique_t
1157 result = ctc_loss_and_grad(**kwargs)
1158 def grad(grad_loss):
1159 grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * result[1]]
1160 grad += [None] * (len(args) - len(grad))
1161 return grad
1163 return result[0], grad
1165 loss = compute_ctc_loss(*args)
1166 if orig_dtype in (dtypes.float16, dtypes.bfloat16):
1167 loss = math_ops.cast(loss, orig_dtype)
1168 return loss
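# Hedged usage sketch (added; not part of the original TensorFlow source):
# the fast path described above, with unique label indices precomputed by
# ctc_unique_labels. The helper name and all shapes/values are hypothetical.
def _example_ctc_loss_dense_unique():
  labels = constant_op.constant([[1, 2, 2, 3], [2, 3, 5, 0]],
                                dtype=dtypes.int64)
  logits = array_ops.zeros([10, 2, 6])  # [frames, batch_size, num_labels]
  label_length = constant_op.constant([4, 3], dtype=dtypes.int32)
  logit_length = constant_op.constant([10, 10], dtype=dtypes.int32)
  unique = ctc_unique_labels(labels)
  return ctc_loss_dense(
      labels=labels,
      logits=logits,
      label_length=label_length,
      logit_length=logit_length,
      unique=unique,
      blank_index=0)  # shape [batch_size] = [2]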
1171@tf_export("nn.collapse_repeated")
1172@dispatch.add_dispatch_support
1173def collapse_repeated(labels, seq_length, name=None):
1174 """Merge repeated labels into single labels.
1176 Args:
1177 labels: Tensor of shape [batch, max value in seq_length]
1178 seq_length: Tensor of shape [batch], sequence length of each batch element.
1179 name: A name for this `Op`. Defaults to "collapse_repeated_labels".
1181 Returns:
1182 A tuple `(collapsed_labels, new_seq_length)` where
1184 collapsed_labels: Tensor of shape [batch, max_seq_length] with repeated
1185 labels collapsed and padded to max_seq_length, eg:
1186 `[[A, A, B, B, A], [A, B, C, D, E]] => [[A, B, A, 0, 0], [A, B, C, D, E]]`
1188 new_seq_length: int tensor of shape [batch] with new sequence lengths.
1189 """
1191 with ops.name_scope(name, "collapse_repeated_labels", [labels, seq_length]):
1192 labels = ops.convert_to_tensor(labels, name="labels")
1193 seq_length = ops.convert_to_tensor(seq_length, name="seq_length")
1195 # Mask labels that don't equal previous label.
1196 label_mask = array_ops.concat([
1197 array_ops.ones_like(labels[:, :1], dtypes.bool),
1198 math_ops.not_equal(labels[:, 1:], labels[:, :-1])
1199 ],
1200 axis=1)
1202 # Filter labels that aren't in the original sequence.
1203 maxlen = _get_dim(labels, 1)
1204 seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen)
1205 label_mask = math_ops.logical_and(label_mask, seq_mask)
1207 # Count masks for new sequence lengths.
1208 new_seq_len = math_ops.reduce_sum(
1209 math_ops.cast(label_mask, dtypes.int32), axis=1)
1211 # Mask indexes based on sequence length mask.
1212 new_maxlen = math_ops.reduce_max(new_seq_len)
1213 idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen)
1215 # Flatten everything and mask out labels to keep and sparse indices.
1216 flat_labels = array_ops.reshape(labels, [-1])
1217 flat_label_mask = array_ops.reshape(label_mask, [-1])
1218 flat_idx_mask = array_ops.reshape(idx_mask, [-1])
1219 idx = math_ops.range(_get_dim(flat_idx_mask, 0))
1221 # Scatter to flat shape.
1222 flat = array_ops.scatter_nd(
1223 indices=array_ops.expand_dims(
1224 array_ops.boolean_mask(idx, flat_idx_mask), axis=1),
1225 updates=array_ops.boolean_mask(flat_labels, flat_label_mask),
1226 shape=array_ops.shape(flat_idx_mask))
1228 # Reshape back to square batch.
1229 batch_size = _get_dim(labels, 0)
1230 new_shape = [batch_size, new_maxlen]
1231 return (array_ops.reshape(flat, new_shape),
1232 math_ops.cast(new_seq_len, seq_length.dtype))
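# Hedged usage sketch (added; not part of the original TensorFlow source),
# mirroring the docstring example above with numeric labels. The helper name
# is hypothetical.
def _example_collapse_repeated():
  labels = constant_op.constant([[1, 1, 2, 2, 1],
                                 [1, 2, 3, 4, 5]], dtype=dtypes.int32)
  seq_length = constant_op.constant([5, 5], dtype=dtypes.int32)
  collapsed, new_seq_length = collapse_repeated(labels, seq_length)
  # collapsed -> [[1, 2, 1, 0, 0], [1, 2, 3, 4, 5]]; new_seq_length -> [3, 5]
  return collapsed, new_seq_length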
1235def dense_labels_to_sparse(dense, length):
1236 """Convert dense labels with sequence lengths to sparse tensor.
1238 Args:
1239 dense: tensor of shape [batch, max_length]
1240 length: int tensor of shape [batch] The length of each sequence in dense.
1242 Returns:
1243 tf.sparse.SparseTensor with values only for the valid elements of sequences.
1244 """
1246 flat_values = array_ops.reshape(dense, [-1])
1247 flat_indices = math_ops.range(
1248 array_ops.shape(flat_values, out_type=dtypes.int64)[0])
1249 mask = array_ops.sequence_mask(length, maxlen=array_ops.shape(dense)[1])
1250 flat_mask = array_ops.reshape(mask, [-1])
1251 indices = array_ops.expand_dims(
1252 array_ops.boolean_mask(flat_indices, flat_mask), 1)
1253 values = array_ops.boolean_mask(flat_values, flat_mask)
1254 sparse = sparse_tensor.SparseTensor(
1255 indices=indices,
1256 values=math_ops.cast(values, dtypes.int32),
1257 dense_shape=array_ops.shape(flat_values, out_type=dtypes.int64))
1258 reshaped = sparse_ops.sparse_reshape(sparse, array_ops.shape(dense))
1259 max_length = math_ops.reduce_max(length)
1260 return sparse_tensor.SparseTensor(
1261 indices=reshaped.indices,
1262 values=reshaped.values,
1263 dense_shape=[
1264 math_ops.cast(reshaped.dense_shape[0], dtypes.int64),
1265 math_ops.cast(max_length, dtypes.int64)
1266 ])
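# Hedged usage sketch (added; not part of the original TensorFlow source):
# only the first `length[b]` entries of each row survive in the sparse
# result. The helper name and values are hypothetical.
def _example_dense_labels_to_sparse():
  dense = constant_op.constant([[1, 2, 3, 0], [4, 5, 0, 0]],
                               dtype=dtypes.int32)
  length = constant_op.constant([3, 2], dtype=dtypes.int32)
  sparse = dense_labels_to_sparse(dense, length)
  # sparse.values -> [1, 2, 3, 4, 5]; sparse.dense_shape -> [2, 3]
  return sparse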
1269@tf_export("nn.ctc_unique_labels")
1270@dispatch.add_dispatch_support
1271def ctc_unique_labels(labels, name=None):
1272 """Get unique labels and indices for batched labels for `tf.nn.ctc_loss`.
1274 For use with `tf.nn.ctc_loss` optional argument `unique`: This op can be
1275 used to preprocess labels in the input pipeline for better speed/memory use
1276 when computing the ctc loss on TPU.
1278 Example:
1279 ctc_unique_labels([[3, 4, 4, 3]]) ->
1280 unique labels padded with 0: [[3, 4, 0, 0]]
1281 indices of original labels in unique: [0, 1, 1, 0]
1283 Args:
1284 labels: tensor of shape [batch_size, max_label_length] padded with 0.
1285 name: A name for this `Op`. Defaults to "ctc_unique_labels".
1287 Returns:
1288 tuple of
1289 - unique labels, tensor of shape `[batch_size, max_label_length]`
1290 - indices into unique labels, shape `[batch_size, max_label_length]`
1291 """
1293 with ops.name_scope(name, "ctc_unique_labels", [labels]):
1294 labels = ops.convert_to_tensor(labels, name="labels")
1296 def _unique(x):
1297 u = array_ops.unique(x)
1298 y = array_ops.pad(u.y, [[0, _get_dim(u.idx, 0) - _get_dim(u.y, 0)]])
1299 y = math_ops.cast(y, dtypes.int64)
1300 return [y, u.idx]
1302 return map_fn.map_fn(_unique, labels, dtype=[dtypes.int64, dtypes.int32])
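# Hedged usage sketch (added; not part of the original TensorFlow source),
# matching the docstring example above. The helper name is hypothetical.
def _example_ctc_unique_labels():
  labels = constant_op.constant([[3, 4, 4, 3]], dtype=dtypes.int64)
  unique_y, unique_idx = ctc_unique_labels(labels)
  # unique_y -> [[3, 4, 0, 0]]; unique_idx -> [[0, 1, 1, 0]]
  return unique_y, unique_idx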
1305def _sum_states(idx, states):
1306 """Take logsumexp for each unique state out of all label states.
1308 Args:
1309 idx: tensor of shape [batch, label_length] For each sequence, indices into a
1310 set of unique labels as computed by calling unique.
1311 states: tensor of shape [frames, batch, label_length] Log probabilities for
1312 each label state.
1314 Returns:
1315 tensor of shape [frames, batch_size, label_length], log probabilities summed
1316 for each unique label of the sequence.
1317 """
1319 with ops.name_scope("sum_states"):
1320 idx = ops.convert_to_tensor(idx, name="idx")
1321 num_states = _get_dim(states, 2)
1322 states = array_ops.expand_dims(states, axis=2)
1323 one_hot = array_ops.one_hot(
1324 idx,
1325 depth=num_states,
1326 on_value=0.0,
1327 off_value=math_ops.log(0.0),
1328 axis=1)
1329 return math_ops.reduce_logsumexp(states + one_hot, axis=-1)
1332def _forward_backward_log(state_trans_log_probs, initial_state_log_probs,
1333 final_state_log_probs, observed_log_probs,
1334 sequence_length):
1335 """Forward-backward algorithm computed in log domain.
1337 Args:
1338 state_trans_log_probs: tensor of shape [states, states] or if different
1339 transition matrix per batch [batch_size, states, states]
1340 initial_state_log_probs: tensor of shape [batch_size, states]
1341 final_state_log_probs: tensor of shape [batch_size, states]
1342 observed_log_probs: tensor of shape [frames, batch_size, states]
1343 sequence_length: tensor of shape [batch_size]
1345 Returns:
1346 forward backward log probabilities: tensor of shape [frames, batch, states]
1347 log_likelihood: tensor of shape [batch_size]
1349 Raises:
1350 ValueError: If state_trans_log_probs has unknown or incorrect rank.
1351 """
1353 if state_trans_log_probs.shape.ndims == 2:
1354 perm = [1, 0]
1355 elif state_trans_log_probs.shape.ndims == 3:
1356 perm = [0, 2, 1]
1357 else:
1358 raise ValueError(
1359 "Rank of argument `state_trans_log_probs` must be known and equal to "
1360 f"2 or 3. Received state_trans_log_probs={state_trans_log_probs} of "
1361 f"rank {state_trans_log_probs.shape.ndims}")
1363 bwd_state_trans_log_probs = array_ops.transpose(state_trans_log_probs, perm)
1364 batch_size = _get_dim(observed_log_probs, 1)
1366 def _forward(state_log_prob, obs_log_prob):
1367 state_log_prob = array_ops.expand_dims(state_log_prob, axis=1) # Broadcast.
1368 state_log_prob += state_trans_log_probs
1369 state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)
1370 state_log_prob += obs_log_prob
1371 log_prob_sum = math_ops.reduce_logsumexp(
1372 state_log_prob, axis=-1, keepdims=True)
1373 state_log_prob -= log_prob_sum
1374 return state_log_prob
1376 fwd = _scan(
1377 _forward, observed_log_probs, initial_state_log_probs, inclusive=True)
1379 def _backward(accs, elems):
1380 """Calculate log probs and cumulative sum masked for sequence length."""
1381 state_log_prob, cum_log_sum = accs
1382 obs_log_prob, mask = elems
1383 state_log_prob += obs_log_prob
1384 state_log_prob = array_ops.expand_dims(state_log_prob, axis=1) # Broadcast.
1385 state_log_prob += bwd_state_trans_log_probs
1386 state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)
1388 log_prob_sum = math_ops.reduce_logsumexp(
1389 state_log_prob, axis=-1, keepdims=True)
1390 state_log_prob -= log_prob_sum
1392 cum_log_sum += array_ops.squeeze(log_prob_sum, axis=[-1]) * mask
1393 batched_mask = array_ops.expand_dims(mask, axis=1)
1394 out = state_log_prob * batched_mask
1395 out += final_state_log_probs * (1.0 - batched_mask)
1396 return out, cum_log_sum
1398 zero_log_sum = array_ops.zeros([batch_size])
1399 maxlen = _get_dim(observed_log_probs, 0)
1400 mask = array_ops.sequence_mask(sequence_length, maxlen, dtypes.float32)
1401 mask = array_ops.transpose(mask, perm=[1, 0])
1403 bwd, cum_log_sum = _scan(
1404 _backward, (observed_log_probs, mask),
1405 (final_state_log_probs, zero_log_sum),
1406 reverse=True,
1407 inclusive=True)
1409 fwd_bwd_log_probs = fwd[1:] + bwd[1:]
1410 fwd_bwd_log_probs_sum = math_ops.reduce_logsumexp(
1411 fwd_bwd_log_probs, axis=2, keepdims=True)
1412 fwd_bwd_log_probs -= fwd_bwd_log_probs_sum
1413 fwd_bwd_log_probs += math_ops.log(array_ops.expand_dims(mask, axis=2))
1415 log_likelihood = bwd[0, :, 0] + cum_log_sum[0]
1417 return fwd_bwd_log_probs, log_likelihood
1420# TODO(tombagby): This is currently faster for the ctc implementation than using
1421# functional_ops.scan, but could be replaced by that or something similar if
1422# things change.
1423def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False):
1424 """Repeatedly applies callable `fn` to a sequence of elements.
1426 Implemented by functional_ops.While; TPU friendly; no gradient.
1428 This is similar to functional_ops.scan but significantly faster on TPU/GPU
1429 for the forward-backward use case.
1431 Examples:
1432 scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0]
1434 Multiple accumulators:
1435 scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0))
1437 Multiple inputs:
1438 scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0)
1440 Args:
1441 fn: callable, fn(accumulators, element) return new accumulator values. The
1442 (possibly nested) sequence of accumulators is the same as `initial` and
1443 the return value must have the same structure.
1444 elems: A (possibly nested) tensor which will be unpacked along the first
1445 dimension. The resulting slices will be the second argument to fn. The
1446 first dimension of all nested input tensors must be the same.
1447 initial: A tensor or (possibly nested) sequence of tensors with initial
1448 values for the accumulators.
1449 reverse: (optional) If True, scan and output elems in reverse order.
1450 inclusive: (optional) True includes the initial accumulator values in the
1451 output. Length of output will be len(elem sequence) + 1. Not meaningful if
1452 final_only is True.
1453 final_only: (optional) When True, return only the final accumulated values,
1454 not the concatenation of accumulated values for each input.
1456 Returns:
1457 A (possibly nested) sequence of tensors with the results of applying fn
1458 to tensors unpacked from elems and previous accumulator values.
1459 """
1461 flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)]
1462 num_elems = array_ops.shape(flat_elems[0])[0]
1463 pack_elems = lambda x: nest.pack_sequence_as(structure=elems, flat_sequence=x)
1464 flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)]
1465 pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x)
1466 accum_dtypes = [x.dtype for x in flat_initial]
1467 num_accums = len(flat_initial)
1469 # Types for counter, [outputs], [accumulators] loop arguments.
1470 if final_only:
1471 loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes
1472 else:
1473 loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes + accum_dtypes
1475 # TODO(tombagby): Update to tfe.defun
1476 def cond(i, num_elems, *args):
1477 del args
1478 return i >= 0 if reverse else i < num_elems
1480 # The loop *args are [output tensors] + [accumulator tensors] which must
1481 # be paired. Each output corresponds to one accumulator.
1482 def body(i, num_elems, *args):
1483 """Loop body."""
1484 i.set_shape([])
1485 if final_only:
1486 accum = args
1487 else:
1488 out, accum = args[:num_accums], args[num_accums:]
1489 slices = [array_ops.gather(e, i) for e in flat_elems]
1490 accum = fn(pack(accum), pack_elems(slices))
1491 flat_accum = nest.flatten(accum)
1492 if final_only:
1493 new_out = []
1494 else:
1495 update_i = i + 1 if inclusive and not reverse else i
1496 new_out = [
1497 inplace_ops.alias_inplace_update(x, update_i, y)
1498 for x, y in zip(out, flat_accum)
1499 ]
1500 i = i - 1 if reverse else i + 1
1501 return [i, num_elems] + new_out + flat_accum
1503 init_i = (
1504 array_ops.shape(flat_elems[0])[0] -
1505 1 if reverse else constant_op.constant(0, dtype=dtypes.int32))
1506 outputs = []
1507 if not final_only:
1508 num_outputs = array_ops.shape(flat_elems[0])[0] + (1 if inclusive else 0)
1509 for initial_accum in flat_initial:
1510 out_shape = array_ops.concat(
1511 [[num_outputs], array_ops.shape(initial_accum)], 0)
1512 out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True)
1513 if inclusive:
1514 out = inplace_ops.alias_inplace_add(out, init_i + (1 if reverse else 0),
1515 initial_accum)
1516 outputs.append(out)
1517 loop_in = [init_i, num_elems] + outputs + flat_initial
1518 hostmem = [
1519 i for i, x in enumerate(loop_in)
1520 if x.dtype.base_dtype in (dtypes.int32, dtypes.int64)
1521 ]
1523 if context.executing_eagerly():
1524 loop_results = loop_in
1525 while cond(*loop_results):
1526 loop_results = body(*loop_results)
1527 else:
1528 # TODO(tombagby): Update to while_v2.
1529 cond = function.Defun(*loop_dtypes)(cond)
1530 body = function.Defun(*loop_dtypes)(body)
1531 loop_results = functional_ops.While(loop_in, cond, body, hostmem=hostmem)
1532 out = loop_results[2:num_accums + 2]
1533 return pack(out)
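# Hedged illustration (added; not part of the original TensorFlow source)
# running the docstring's first example above; the helper name is
# hypothetical.
def _example_scan_running_sum():
  elems = constant_op.constant([1.0, 2.0, 3.0])
  initial = constant_op.constant(1.0)
  out = _scan(lambda a, e: a + e, elems, initial)
  # out -> [2.0, 4.0, 7.0]
  return out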
1536def _get_dim(tensor, i):
1537 """Get value of tensor shape[i] preferring static value if available."""
1538 return tensor_shape.dimension_value(
1539 tensor.shape[i]) or array_ops.shape(tensor)[i]