Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/partial_batch_padding_handler.py: 21% of 58 statements

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utility object to handler partial batches for TPUStrategy.""" 

16# pylint: disable=protected-access 

17 

import numpy as np

from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest


class PartialBatchPaddingHandler(object):
  """A container that holds info about partial batches for `predict()`."""

  def __init__(self, output_shape):
    self.padded_batch_size = 0
    self.padding_mask = array_ops.zeros(0)
    self.output_shape = output_shape

  def get_real_batch_size(self, dataset_batch):
    """Returns the number of elements in a potentially partial batch."""
    if isinstance(dataset_batch, (tuple, list)):
      dataset_batch = dataset_batch[0]

    assert nest.flatten(dataset_batch)

    def _find_any_tensor(batch_features):
      tensors = [
          x for x in nest.flatten(batch_features) if tensor_util.is_tf_type(x)
      ]
      if not tensors:
        raise ValueError('Cannot find any Tensor in features dict.')
      return tensors[0]

    return backend.cast(backend.shape(_find_any_tensor(dataset_batch))[0],
                        dtype='int64')
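  # For example, get_real_batch_size on a features dict whose first flattened
  # tensor has shape (5, 3) returns the scalar int64 tensor 5.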

  def update_mask(self, padding_mask, dataset_batch):
    """Calculate and cache the amount of padding required for a batch."""
    original_batch_size = self.get_real_batch_size(dataset_batch)
    missing_count = self.padded_batch_size - original_batch_size
    mask = backend.concatenate([array_ops.ones(original_batch_size),
                                array_ops.zeros(missing_count)], axis=0)
    return backend.concatenate([padding_mask, mask], axis=0)
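  # For example, with `padded_batch_size == 8` and a real batch of 5 examples,
  # update_mask appends [1., 1., 1., 1., 1., 0., 0., 0.] to the running mask.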

  def pad_batch(self, *dataset_batch_elements):
    """Pads out the batch dimension of a tensor to the complete batch size."""
    def _pad(batch):
      """Helper function to pad nested data within each batch element."""
      padded_dict_batch = {}
      if isinstance(batch, dict):
        for key, value in batch.items():
          padded_dict_batch[key] = _pad(value)
        return padded_dict_batch

      rank = len(batch.shape)
      assert rank > 0
      missing_count = (self.padded_batch_size -
                       self.get_real_batch_size(batch))
      padding = backend.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
      return array_ops.pad(batch, padding, 'constant')

    if len(dataset_batch_elements) == 1:
      return _pad(dataset_batch_elements[0])

    batch_elements = []
    for batch_element in dataset_batch_elements:
      batch_elements.append(_pad(batch_element))
    return tuple(batch_elements)
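  # For example, padding a (5, 3) tensor out to `padded_batch_size == 8` builds
  # the paddings matrix [[0, 3], [0, 0]] in pad_batch, i.e. three zero rows are
  # appended along the batch dimension only.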

  def apply_mask(self, prediction_result):
    """Removes prediction output that corresponds to padded input."""
    padding_mask = backend.get_value(self.padding_mask)
    assert len(padding_mask.shape) == 1

    if len(self.output_shape) == 1:
      prediction = np.take(prediction_result,
                           np.nonzero(
                               padding_mask[:len(prediction_result)]),
                           axis=0)
      if prediction.shape[0] == 1:
        prediction = np.squeeze(prediction, axis=0)
      return prediction

    else:
      predictions = []
      for i in range(len(self.output_shape)):
        prediction = prediction_result[i]
        prediction = np.take(prediction, np.nonzero(
            padding_mask[:len(prediction)]), axis=0)
        predictions.append(np.squeeze(prediction))

      return predictions
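

# A minimal usage sketch, not part of the upstream module: it illustrates how
# the handler pads a final partial batch up to the full padded batch size and
# then trims the matching predictions. The concrete sizes here
# (padded_batch_size == 8, a partial batch of 5 rows) and the
# output_shape=[(None, 3)] value are hypothetical; assumes TF 2.x with eager
# execution enabled.
if __name__ == '__main__':
  import tensorflow as tf

  handler = PartialBatchPaddingHandler(output_shape=[(None, 3)])
  handler.padded_batch_size = 8

  features = tf.ones((5, 3))            # final, partial batch of 5 examples
  padded = handler.pad_batch(features)  # zero-padded out to shape (8, 3)
  handler.padding_mask = handler.update_mask(handler.padding_mask, features)

  # Pretend the model produced one prediction row per padded input row.
  fake_predictions = np.arange(24, dtype=np.float32).reshape(8, 3)
  trimmed = handler.apply_mask(fake_predictions)

  print(padded.shape)   # (8, 3)
  print(trimmed.shape)  # (5, 3) -- rows corresponding to padding removed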