
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The implementation of `tf.data.Dataset.batch`."""

import warnings

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import debug_mode
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_dataset_ops


def _batch(input_dataset,
           batch_size,
           drop_remainder=False,
           num_parallel_calls=None,
           deterministic=None,
           name=None):
  """See `Dataset.batch` for details."""
  if num_parallel_calls is None or debug_mode.DEBUG_MODE:
    if deterministic is not None and not debug_mode.DEBUG_MODE:
      warnings.warn("The `deterministic` argument has no effect unless the "
                    "`num_parallel_calls` argument is specified.")
    return _BatchDataset(input_dataset, batch_size, drop_remainder, name=name)
  else:
    return _ParallelBatchDataset(
        input_dataset,
        batch_size,
        drop_remainder,
        num_parallel_calls,
        deterministic,
        name=name)
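
# Illustrative dispatch sketch (not part of the original module): seen through
# the public API, `_batch` above routes as follows. The batch size and dataset
# below are arbitrary example values.
#
#   ds = tf.data.Dataset.range(10)
#   ds.batch(4)                                       # -> _BatchDataset
#   ds.batch(4, num_parallel_calls=tf.data.AUTOTUNE)  # -> _ParallelBatchDataset
#   ds.batch(4, deterministic=True)                   # warns: `deterministic`
#                                                     # needs num_parallel_calls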


class _BatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that batches contiguous elements from its input."""

  def __init__(self, input_dataset, batch_size, drop_remainder, name=None):
    """See `Dataset.batch()` for details."""
    self._input_dataset = input_dataset
    self._batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder)
    # pylint: disable=protected-access
    if constant_drop_remainder:
      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
      # or `False` (explicitly retaining the remainder).
      # pylint: disable=g-long-lambda
      constant_batch_size = tensor_util.constant_value(self._batch_size)
      self._structure = nest.map_structure(
          lambda component_spec: component_spec._batch(constant_batch_size),
          input_dataset.element_spec)
    else:
      self._structure = nest.map_structure(
          lambda component_spec: component_spec._batch(None),
          input_dataset.element_spec)

    self._name = name
    variant_tensor = gen_dataset_ops.batch_dataset_v2(
        input_dataset._variant_tensor,
        batch_size=self._batch_size,
        drop_remainder=self._drop_remainder,
        **self._common_args)
    super().__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._structure
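
# Illustrative note (not part of the original module): the `_structure`
# computed in `__init__` above is what surfaces as `element_spec`. With a
# statically known `drop_remainder=True`, the leading batch dimension is
# static; otherwise it is unknown. For example:
#
#   tf.data.Dataset.range(10).batch(4).element_spec
#   # TensorSpec(shape=(None,), dtype=tf.int64, name=None)
#   tf.data.Dataset.range(10).batch(4, drop_remainder=True).element_spec
#   # TensorSpec(shape=(4,), dtype=tf.int64, name=None)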


class _ParallelBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that batches contiguous elements from its input in parallel."""

  def __init__(self,
               input_dataset,
               batch_size,
               drop_remainder,
               num_parallel_calls,
               deterministic,
               name=None):
    """See `Dataset.batch()` for details."""
    self._input_dataset = input_dataset
    self._batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    if deterministic is None:
      self._deterministic = "default"
    elif deterministic:
      self._deterministic = "true"
    else:
      self._deterministic = "false"
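    # The `ParallelBatchDataset` kernel takes `deterministic` as a string
    # attr ("true"/"false"/"default"); "default" leaves the ordering choice
    # to the tf.data runtime.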

    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder)
    # pylint: disable=protected-access
    if constant_drop_remainder:
      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
      # or `False` (explicitly retaining the remainder).
      # pylint: disable=g-long-lambda
      constant_batch_size = tensor_util.constant_value(self._batch_size)
      self._structure = nest.map_structure(
          lambda component_spec: component_spec._batch(constant_batch_size),
          input_dataset.element_spec)
    else:
      self._structure = nest.map_structure(
          lambda component_spec: component_spec._batch(None),
          input_dataset.element_spec)

    self._name = name
    variant_tensor = gen_dataset_ops.parallel_batch_dataset(
        input_dataset._variant_tensor,
        batch_size=self._batch_size,
        num_parallel_calls=self._num_parallel_calls,
        drop_remainder=self._drop_remainder,
        deterministic=self._deterministic,
        **self._common_args)

    super().__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._structure
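

# A minimal smoke test (illustrative addition, not part of the original
# TensorFlow module), exercising both paths above through the public API.
# It assumes a working TensorFlow installation; the dataset contents are
# arbitrary example values.
if __name__ == "__main__":
  import tensorflow as tf

  ds = tf.data.Dataset.range(10)

  # No `num_parallel_calls`: `_batch` builds a `_BatchDataset`.
  serial = ds.batch(4, drop_remainder=True)
  print(serial.element_spec)  # TensorSpec(shape=(4,), dtype=tf.int64, ...)

  # With `num_parallel_calls`: `_batch` builds a `_ParallelBatchDataset`.
  parallel = ds.batch(4, num_parallel_calls=tf.data.AUTOTUNE,
                      deterministic=False)
  for batch in parallel:
    print(batch.numpy())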