Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/partial_batch_padding_handler.py: 18%

56 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utility object to handler partial batches for TPUStrategy.""" 

16 

17import numpy as np 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src import backend 

21 

22 

class PartialBatchPaddingHandler:
    """A container that holds info about partial batches for `predict()`.

    Batches can be padded along their leading (batch) axis up to
    `padded_batch_size` with `pad_batch`. A 1-D mask marking real rows with 1
    and padded rows with 0 is built with `update_mask`. `apply_mask` then uses
    `self.padding_mask` to drop the rows of the prediction output that came
    from padding.
    """

    def __init__(self, output_shape):
        """Creates a handler with an empty padding mask.

        Args:
          output_shape: Sized collection describing the model outputs. Only its
            length is used here: a length of 1 selects single-output handling
            in `apply_mask`, anything else selects per-output handling.
            (Presumably one entry per model output -- confirm with the caller.)
        """
        # Target size for the leading axis of every batch. Nothing in this
        # class assigns it after construction, so the caller presumably sets
        # it before calling `update_mask` / `pad_batch`.
        self.padded_batch_size = 0
        # 1-D mask read by `apply_mask`. `update_mask` returns an extended
        # mask instead of storing it; the caller presumably assigns the
        # result back here.
        self.padding_mask = tf.zeros(0)
        self.output_shape = output_shape

    def get_real_batch_size(self, dataset_batch):
        """Returns the number of elements in a potentially partial batch.

        Args:
          dataset_batch: A tensor, a nested structure containing tensors, or a
            tuple/list whose first element holds the features (presumably an
            `(x, y, ...)` batch -- confirm with callers).

        Returns:
          A scalar int64 tensor holding the leading dimension of the first
          tensor found (in `tf.nest.flatten` order) in the features.

        Raises:
          ValueError: If the features contain no tensor.
        """
        if isinstance(dataset_batch, (tuple, list)):
            # Only the features (first element) determine the batch size.
            dataset_batch = dataset_batch[0]

        flat_features = tf.nest.flatten(dataset_batch)
        assert flat_features

        tensors = [x for x in flat_features if tf.is_tensor(x)]
        if not tensors:
            raise ValueError("Cannot find any Tensor in features dict.")

        return backend.cast(backend.shape(tensors[0])[0], dtype="int64")

    def update_mask(self, padding_mask, dataset_batch):
        """Returns `padding_mask` extended with the mask for one more batch.

        The appended segment has `real_batch_size` ones followed by
        `padded_batch_size - real_batch_size` zeros. Neither `padding_mask`
        nor `self.padding_mask` is modified; a new tensor is returned.

        Args:
          padding_mask: The 1-D mask accumulated so far.
          dataset_batch: The (unpadded) batch; see `get_real_batch_size`.

        Returns:
          The concatenation of `padding_mask` and the new batch's mask.
        """
        original_batch_size = self.get_real_batch_size(dataset_batch)
        missing_count = self.padded_batch_size - original_batch_size
        mask = backend.concatenate(
            [tf.ones(original_batch_size), tf.zeros(missing_count)], axis=0
        )
        return backend.concatenate([padding_mask, mask], axis=0)

    def pad_batch(self, *dataset_batch_elements):
        """Pads the batch dimension of each element to `padded_batch_size`.

        Args:
          *dataset_batch_elements: One or more batch elements. Each is a tensor
            of rank >= 1, or a (possibly nested) dict of such tensors.

        Returns:
          The padded element if exactly one element was given; otherwise a
          tuple of padded elements in the same order as the inputs.
        """

        def _pad(batch):
            """Zero-pads one tensor, or every tensor in a dict, along axis 0."""
            if isinstance(batch, dict):
                return {key: _pad(value) for key, value in batch.items()}

            rank = len(batch.shape)
            assert rank > 0
            missing_count = self.padded_batch_size - self.get_real_batch_size(
                batch
            )
            # Append `missing_count` rows after the existing ones on the
            # leading axis; leave every other axis untouched.
            padding = backend.stack(
                [[0, missing_count]] + [[0, 0]] * (rank - 1)
            )
            return tf.pad(batch, padding, "constant")

        if len(dataset_batch_elements) == 1:
            return _pad(dataset_batch_elements[0])
        return tuple(_pad(element) for element in dataset_batch_elements)

    @staticmethod
    def _drop_padded_rows(prediction, padding_mask):
        """Returns the rows of `prediction` whose mask entry is non-zero.

        `padding_mask` is truncated to `len(prediction)`. Only the leading
        axis is filtered; every other axis (including size-1 axes) is kept.
        """
        # Index with a flat 1-D array so no extra leading axis is created.
        # That avoids a squeeze, which could also drop legitimate size-1
        # axes, e.g. a (batch, 1) output or a batch with one real row.
        keep = np.flatnonzero(padding_mask[: len(prediction)])
        return np.take(prediction, keep, axis=0)

    def apply_mask(self, prediction_result):
        """Removes prediction output that corresponds to padded input.

        Args:
          prediction_result: If `self.output_shape` has length 1, a single
            array-like of predictions. Otherwise an integer-indexable sequence
            with one array-like per output.

        Returns:
          A NumPy array (single output) or a list of NumPy arrays (one per
          output) containing only the rows whose `self.padding_mask` entry is
          non-zero.
        """
        padding_mask = backend.get_value(self.padding_mask)
        assert len(padding_mask.shape) == 1

        if len(self.output_shape) == 1:
            return self._drop_padded_rows(prediction_result, padding_mask)

        num_outputs = len(self.output_shape)
        return [
            self._drop_padded_rows(prediction_result[i], padding_mask)
            for i in range(num_outputs)
        ]

115