Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/partial_batch_padding_handler.py: 18% of 56 statements
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utility object to handler partial batches for TPUStrategy."""

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src import backend


class PartialBatchPaddingHandler:
    """A container that holds info about partial batches for `predict()`."""

    def __init__(self, output_shape):
        self.padded_batch_size = 0
        self.padding_mask = tf.zeros(0)
        self.output_shape = output_shape

    def get_real_batch_size(self, dataset_batch):
        """Returns the number of elements in a potentially partial batch."""
        if isinstance(dataset_batch, (tuple, list)):
            dataset_batch = dataset_batch[0]

        assert tf.nest.flatten(dataset_batch)
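
        # The batch may be an arbitrarily nested structure of features; any
        # tensor inside it carries the batch size in its leading dimension.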
        def _find_any_tensor(batch_features):
            tensors = [
                x for x in tf.nest.flatten(batch_features) if tf.is_tensor(x)
            ]
            if not tensors:
                raise ValueError("Cannot find any Tensor in features dict.")
            return tensors[0]

        return backend.cast(
            backend.shape(_find_any_tensor(dataset_batch))[0], dtype="int64"
        )

    def update_mask(self, padding_mask, dataset_batch):
        """Calculate and cache the amount of padding required for a batch."""
        original_batch_size = self.get_real_batch_size(dataset_batch)
        missing_count = self.padded_batch_size - original_batch_size
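        # Real rows are marked with ones and padded rows with zeros; the
        # per-batch mask is appended to the mask accumulated so far.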
        mask = backend.concatenate(
            [tf.ones(original_batch_size), tf.zeros(missing_count)], axis=0
        )
        return backend.concatenate([padding_mask, mask], axis=0)

    def pad_batch(self, *dataset_batch_elements):
        """Pads the batch dimension of a tensor to the complete batch size."""

        def _pad(batch):
            """Helper function to pad nested data within each batch element."""
            padded_dict_batch = {}
            if isinstance(batch, dict):
                for key, value in batch.items():
                    padded_dict_batch[key] = _pad(value)
                return padded_dict_batch
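
            # For plain tensors, pad only the leading (batch) dimension with
            # zeros and leave every other dimension untouched.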
            rank = len(batch.shape)
            assert rank > 0
            missing_count = self.padded_batch_size - self.get_real_batch_size(
                batch
            )
            padding = backend.stack(
                [[0, missing_count]] + [[0, 0]] * (rank - 1)
            )
            return tf.pad(batch, padding, "constant")

        if len(dataset_batch_elements) == 1:
            return _pad(dataset_batch_elements[0])

        batch_elements = []
        for batch_element in dataset_batch_elements:
            batch_elements.append(_pad(batch_element))
        return tuple(batch_elements)

    def apply_mask(self, prediction_result):
        """Removes prediction output that corresponds to padded input."""
        padding_mask = backend.get_value(self.padding_mask)
        assert len(padding_mask.shape) == 1
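
        # Single-output model: keep only the rows whose mask entry is nonzero.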
        if len(self.output_shape) == 1:
            prediction = np.take(
                prediction_result,
                np.nonzero(padding_mask[: len(prediction_result)]),
                axis=0,
            )
            if prediction.shape[0] == 1:
                prediction = np.squeeze(prediction, axis=0)
            return prediction

        else:
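            # Multi-output model: strip the padded rows from each output
            # separately.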
            predictions = []
            for i in range(len(self.output_shape)):
                prediction = prediction_result[i]
                prediction = np.take(
                    prediction,
                    np.nonzero(padding_mask[: len(prediction)]),
                    axis=0,
                )
                predictions.append(np.squeeze(prediction))

            return predictions
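

# -----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes the
# caller knows the full batch size produced by the input pipeline (here 8) and
# that `output_shape` is a list with one entry per model output, which is how
# `predict()` is expected to drive this handler.
if __name__ == "__main__":
    handler = PartialBatchPaddingHandler(output_shape=[(None, 4)])
    handler.padded_batch_size = 8  # full batch size expected by the device

    # A partial batch containing only 5 real examples.
    partial_batch = tf.ones((5, 4))

    # Record which rows are real (1.0) and which are padding (0.0).
    handler.padding_mask = handler.update_mask(
        handler.padding_mask, partial_batch
    )

    # Pad the batch up to the full batch size before running the model.
    padded = handler.pad_batch(partial_batch)
    assert padded.shape[0] == 8

    # Pretend the model produced one prediction row per (padded) input row,
    # then strip the rows that correspond to padding.
    fake_predictions = np.arange(8 * 4, dtype="float32").reshape(8, 4)
    real_predictions = handler.apply_mask(fake_predictions)
    print(real_predictions.shape)  # (5, 4)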