Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/map_op.py: 37%

62 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""The implementation of `tf.data.Dataset.map`.""" 

16 

17import warnings 

18 

19from tensorflow.python.data.ops import dataset_ops 

20from tensorflow.python.data.ops import debug_mode 

21from tensorflow.python.data.ops import structured_function 

22from tensorflow.python.framework import dtypes 

23from tensorflow.python.framework import ops 

24from tensorflow.python.ops import gen_dataset_ops 

25 

26 

def _map_v2(input_dataset,  # pylint: disable=unused-private-name
            map_func,
            num_parallel_calls=None,
            deterministic=None,
            name=None):
  """See `Dataset.map()` for details.

  Picks between the sequential and the parallel map implementation.

  Args:
    input_dataset: The dataset whose elements `map_func` is applied to.
    map_func: The user function applied to each element.
    num_parallel_calls: If `None`, or if tf.data debug mode is enabled, a
      sequential `_MapDataset` is built. Otherwise the value is forwarded to
      `_ParallelMapDataset`.
    deterministic: Forwarded to `_ParallelMapDataset`. On the sequential path
      it is ignored, and a warning is emitted if it was set (unless debug mode
      is what forced the sequential path).
    name: Optional name forwarded to the created dataset.

  Returns:
    A `_MapDataset` or `_ParallelMapDataset` built with
    `preserve_cardinality=True`.
  """
  run_sequentially = num_parallel_calls is None or debug_mode.DEBUG_MODE
  if not run_sequentially:
    return _ParallelMapDataset(
        input_dataset,
        map_func,
        num_parallel_calls=num_parallel_calls,
        deterministic=deterministic,
        preserve_cardinality=True,
        name=name)

  # `deterministic` is only consumed by the parallel implementation. When debug
  # mode forces the sequential path, the user did nothing wrong, so stay quiet.
  if deterministic is not None and not debug_mode.DEBUG_MODE:
    warnings.warn("The `deterministic` argument has no effect unless the "
                  "`num_parallel_calls` argument is specified.")
  return _MapDataset(
      input_dataset, map_func, preserve_cardinality=True, name=name)

47 

48 

def _map_v1(input_dataset,  # pylint: disable=unused-private-name
            map_func,
            num_parallel_calls=None,
            deterministic=None):
  """See `Dataset.map()` for details.

  TF1 variant of the map dispatch. The result is wrapped in
  `dataset_ops.DatasetV1Adapter` and built with `preserve_cardinality=False`.

  Args:
    input_dataset: The dataset whose elements `map_func` is applied to.
    map_func: The user function applied to each element.
    num_parallel_calls: If `None`, or if tf.data debug mode is enabled, a
      sequential `_MapDataset` is built. Otherwise the value is forwarded to
      `_ParallelMapDataset`.
    deterministic: Forwarded to `_ParallelMapDataset`. On the sequential path
      it is ignored, and a warning is emitted if it was set (unless debug mode
      is what forced the sequential path).

  Returns:
    A `DatasetV1Adapter` wrapping a `_MapDataset` or `_ParallelMapDataset`.
  """
  if num_parallel_calls is None or debug_mode.DEBUG_MODE:
    # Match `_map_v2` and `_map_v1_with_legacy_function`: tell the caller that
    # `deterministic` is being ignored, instead of dropping it silently.
    # Debug mode forces this path regardless of user input, so no warning then.
    if deterministic is not None and not debug_mode.DEBUG_MODE:
      warnings.warn("The `deterministic` argument has no effect unless the "
                    "`num_parallel_calls` argument is specified.")
    return dataset_ops.DatasetV1Adapter(
        _MapDataset(input_dataset, map_func, preserve_cardinality=False))
  return dataset_ops.DatasetV1Adapter(
      _ParallelMapDataset(
          input_dataset,
          map_func,
          num_parallel_calls,
          deterministic,
          preserve_cardinality=False))

65 

66 

def _map_v1_with_legacy_function(  # pylint: disable=unused-private-name
    input_dataset,
    map_func,
    num_parallel_calls=None,
    deterministic=None):
  """See `Dataset.map()` for details.

  TF1 map dispatch that wraps `map_func` with `use_legacy_function=True`. The
  result is wrapped in `dataset_ops.DatasetV1Adapter` and built with
  `preserve_cardinality=False`.

  Args:
    input_dataset: The dataset whose elements `map_func` is applied to.
    map_func: The user function applied to each element.
    num_parallel_calls: If `None`, a sequential `_MapDataset` is built.
      Otherwise the value is forwarded to `_ParallelMapDataset`.
    deterministic: Forwarded to `_ParallelMapDataset`. On the sequential path
      it is ignored, and a warning is emitted if it was set.

  Returns:
    A `DatasetV1Adapter` wrapping a `_MapDataset` or `_ParallelMapDataset`.
  """
  # NOTE(review): unlike `_map_v1`/`_map_v2`, this path does not check
  # `debug_mode.DEBUG_MODE`; confirm whether that is intentional for legacy
  # functions.
  if num_parallel_calls is not None:
    parallel_ds = _ParallelMapDataset(
        input_dataset,
        map_func,
        num_parallel_calls,
        deterministic,
        preserve_cardinality=False,
        use_legacy_function=True)
    return dataset_ops.DatasetV1Adapter(parallel_ds)

  if deterministic is not None:
    warnings.warn("The `deterministic` argument has no effect unless the "
                  "`num_parallel_calls` argument is specified.")
  sequential_ds = _MapDataset(
      input_dataset,
      map_func,
      preserve_cardinality=False,
      use_legacy_function=True)
  return dataset_ops.DatasetV1Adapter(sequential_ds)

92 

93 

class _MapDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over elements in its input."""

  def __init__(self,
               input_dataset,
               map_func,
               use_inter_op_parallelism=True,
               preserve_cardinality=True,
               use_legacy_function=False,
               name=None):
    """Wraps `map_func` and builds the `map_dataset` op over `input_dataset`.

    Args:
      input_dataset: The upstream dataset.
      map_func: User function, wrapped in a `StructuredFunctionWrapper`.
      use_inter_op_parallelism: Forwarded to `gen_dataset_ops.map_dataset`.
      preserve_cardinality: Forwarded to `gen_dataset_ops.map_dataset`.
      use_legacy_function: Forwarded to `StructuredFunctionWrapper`.
      name: Optional dataset name, stored as `self._name`.
    """
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism
    self._preserve_cardinality = preserve_cardinality
    wrapped_fn = structured_function.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    self._map_func = wrapped_fn
    # Assigned before the op is built. Presumably `self._common_args` consumes
    # it (defined in dataset_ops); verify before reordering.
    self._name = name
    op_kwargs = dict(
        f=wrapped_fn.function,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        preserve_cardinality=self._preserve_cardinality,
    )
    variant_tensor = gen_dataset_ops.map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        wrapped_fn.function.captured_inputs,
        **op_kwargs,
        **self._common_args)
    super().__init__(input_dataset, variant_tensor)

  def _functions(self):
    """Returns the wrapped user functions this dataset depends on."""
    functions = [self._map_func]
    return functions

  @property
  def element_spec(self):
    """The output structure of the wrapped map function."""
    wrapped_fn = self._map_func
    return wrapped_fn.output_structure

  def _transformation_name(self):
    """Name of the user-facing transformation, passed to the function wrapper."""
    transformation = "Dataset.map()"
    return transformation

131 

132 

class _ParallelMapDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over elements in its input in parallel."""

  def __init__(self,
               input_dataset,
               map_func,
               num_parallel_calls,
               deterministic,
               use_inter_op_parallelism=True,
               preserve_cardinality=False,
               use_legacy_function=False,
               name=None):
    """See `Dataset.map()` for details.

    Args:
      input_dataset: The upstream dataset.
      map_func: User function, wrapped in a `StructuredFunctionWrapper`.
      num_parallel_calls: Converted to an int64 tensor named
        "num_parallel_calls".
      deterministic: `None` -> "default"; otherwise truthy -> "true" and
        falsy -> "false". The string is passed to the op.
      use_inter_op_parallelism: Forwarded to
        `gen_dataset_ops.parallel_map_dataset_v2`.
      preserve_cardinality: Forwarded to
        `gen_dataset_ops.parallel_map_dataset_v2`.
      use_legacy_function: Forwarded to `StructuredFunctionWrapper`.
      name: Optional dataset name, stored as `self._name`.
    """
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism
    wrapped_fn = structured_function.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    self._map_func = wrapped_fn

    # The op takes a string attribute rather than an optional bool.
    if deterministic is None:
      deterministic_attr = "default"
    else:
      deterministic_attr = "true" if deterministic else "false"
    self._deterministic = deterministic_attr

    self._preserve_cardinality = preserve_cardinality
    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    # Assigned before the op is built. Presumably `self._common_args` consumes
    # it (defined in dataset_ops); verify before reordering.
    self._name = name
    op_kwargs = dict(
        f=wrapped_fn.function,
        num_parallel_calls=self._num_parallel_calls,
        deterministic=self._deterministic,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        preserve_cardinality=self._preserve_cardinality,
    )
    variant_tensor = gen_dataset_ops.parallel_map_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        wrapped_fn.function.captured_inputs,
        **op_kwargs,
        **self._common_args)
    super().__init__(input_dataset, variant_tensor)

  def _functions(self):
    """Returns the wrapped user functions this dataset depends on."""
    functions = [self._map_func]
    return functions

  @property
  def element_spec(self):
    """The output structure of the wrapped map function."""
    wrapped_fn = self._map_func
    return wrapped_fn.output_structure

  def _transformation_name(self):
    """Name of the user-facing transformation, passed to the function wrapper."""
    transformation = "Dataset.map()"
    return transformation