Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/map_op.py: 37%
62 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The implementation of `tf.data.Dataset.map`."""
17import warnings
19from tensorflow.python.data.ops import dataset_ops
20from tensorflow.python.data.ops import debug_mode
21from tensorflow.python.data.ops import structured_function
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import gen_dataset_ops
def _map_v2(input_dataset,  # pylint: disable=unused-private-name
            map_func,
            num_parallel_calls=None,
            deterministic=None,
            name=None):
  """See `Dataset.map()` for details.

  Builds the TF2 map transformation. A `_ParallelMapDataset` is created only
  when `num_parallel_calls` is given and tf.data debug mode is off. Otherwise
  a sequential `_MapDataset` is created. Both variants preserve cardinality.

  Args:
    input_dataset: The dataset to map over.
    map_func: The function applied to each element.
    num_parallel_calls: Optional parallelism level; `None` selects the
      sequential implementation.
    deterministic: Optional determinism flag for the parallel implementation.
    name: Optional name for the transformation.

  Returns:
    A `_MapDataset` or `_ParallelMapDataset`.
  """
  use_parallel_op = num_parallel_calls is not None and not debug_mode.DEBUG_MODE
  if use_parallel_op:
    return _ParallelMapDataset(
        input_dataset,
        map_func,
        num_parallel_calls=num_parallel_calls,
        deterministic=deterministic,
        preserve_cardinality=True,
        name=name)

  # `deterministic` is only consumed by the parallel op, so tell the caller it
  # is being ignored -- unless debug mode forced the sequential path.
  if deterministic is not None and not debug_mode.DEBUG_MODE:
    warnings.warn("The `deterministic` argument has no effect unless the "
                  "`num_parallel_calls` argument is specified.")
  return _MapDataset(
      input_dataset, map_func, preserve_cardinality=True, name=name)
def _map_v1(input_dataset,  # pylint: disable=unused-private-name
            map_func,
            num_parallel_calls=None,
            deterministic=None):
  """See `Dataset.map()` for details.

  Builds the TF1 map transformation wrapped in a `DatasetV1Adapter`. Unlike
  the TF2 path, cardinality is not preserved. A sequential `_MapDataset` is
  used when `num_parallel_calls` is None or tf.data debug mode is active;
  otherwise a `_ParallelMapDataset` is used.

  Args:
    input_dataset: The dataset to map over.
    map_func: The function applied to each element.
    num_parallel_calls: Optional parallelism level; `None` selects the
      sequential implementation.
    deterministic: Optional determinism flag for the parallel implementation.

  Returns:
    A `DatasetV1Adapter` wrapping the mapped dataset.
  """
  if num_parallel_calls is None or debug_mode.DEBUG_MODE:
    # Previously `deterministic` was dropped silently here. The sibling entry
    # points warn in this case, so do the same for consistency. As in
    # `_map_v2`, the warning is skipped when debug mode forced this path.
    if deterministic is not None and not debug_mode.DEBUG_MODE:
      warnings.warn("The `deterministic` argument has no effect unless the "
                    "`num_parallel_calls` argument is specified.")
    return dataset_ops.DatasetV1Adapter(
        _MapDataset(input_dataset, map_func, preserve_cardinality=False))
  return dataset_ops.DatasetV1Adapter(
      _ParallelMapDataset(
          input_dataset,
          map_func,
          num_parallel_calls,
          deterministic,
          preserve_cardinality=False))
def _map_v1_with_legacy_function(  # pylint: disable=unused-private-name
    input_dataset,
    map_func,
    num_parallel_calls=None,
    deterministic=None):
  """See `Dataset.map()` for details.

  TF1 map transformation that wraps `map_func` with
  `use_legacy_function=True` and does not preserve cardinality. Unlike the
  other entry points in this file, the choice between the sequential and
  parallel implementations depends only on `num_parallel_calls`; tf.data
  debug mode is not consulted.

  Args:
    input_dataset: The dataset to map over.
    map_func: The function applied to each element.
    num_parallel_calls: Optional parallelism level; `None` selects the
      sequential implementation.
    deterministic: Optional determinism flag for the parallel implementation.

  Returns:
    A `DatasetV1Adapter` wrapping the mapped dataset.
  """
  if num_parallel_calls is not None:
    mapped = _ParallelMapDataset(
        input_dataset,
        map_func,
        num_parallel_calls,
        deterministic,
        preserve_cardinality=False,
        use_legacy_function=True)
  else:
    # The sequential op has no notion of determinism; flag the ignored arg.
    if deterministic is not None:
      warnings.warn("The `deterministic` argument has no effect unless the "
                    "`num_parallel_calls` argument is specified.")
    mapped = _MapDataset(
        input_dataset,
        map_func,
        preserve_cardinality=False,
        use_legacy_function=True)
  return dataset_ops.DatasetV1Adapter(mapped)
class _MapDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over elements in its input."""

  def __init__(self,
               input_dataset,
               map_func,
               use_inter_op_parallelism=True,
               preserve_cardinality=True,
               use_legacy_function=False,
               name=None):
    """Wraps `map_func` and builds the `map_dataset` variant tensor.

    Args:
      input_dataset: The dataset whose elements are mapped.
      map_func: The user function applied to each element.
      use_inter_op_parallelism: Forwarded to `gen_dataset_ops.map_dataset`.
      preserve_cardinality: Forwarded to `gen_dataset_ops.map_dataset`.
      use_legacy_function: Forwarded to `StructuredFunctionWrapper`.
      name: Optional name for this transformation.
    """
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism
    self._preserve_cardinality = preserve_cardinality
    self._map_func = structured_function.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    # NOTE(review): `_name` is assigned before `_common_args` is read below;
    # `_common_args` lives in dataset_ops and presumably reads `_name` --
    # confirm before reordering these statements.
    self._name = name
    traced_fn = self._map_func.function
    map_variant = gen_dataset_ops.map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        traced_fn.captured_inputs,
        f=traced_fn,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        preserve_cardinality=self._preserve_cardinality,
        **self._common_args)
    super().__init__(input_dataset, map_variant)

  def _functions(self):
    """Returns the list of wrapped functions used by this dataset."""
    return [self._map_func]

  @property
  def element_spec(self):
    """The structure produced by the wrapped map function."""
    return self._map_func.output_structure

  def _transformation_name(self):
    """Name used when wrapping `map_func`."""
    return "Dataset.map()"
class _ParallelMapDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over elements in its input in parallel."""

  def __init__(self,
               input_dataset,
               map_func,
               num_parallel_calls,
               deterministic,
               use_inter_op_parallelism=True,
               preserve_cardinality=False,
               use_legacy_function=False,
               name=None):
    """See `Dataset.map()` for details.

    Args:
      input_dataset: The dataset whose elements are mapped.
      map_func: The user function applied to each element.
      num_parallel_calls: Parallelism level; converted to an int64 tensor.
      deterministic: `None`, or a truthy/falsy value, mapped to the op's
        "default" / "true" / "false" string attribute respectively.
      use_inter_op_parallelism: Forwarded to `parallel_map_dataset_v2`.
      preserve_cardinality: Forwarded to `parallel_map_dataset_v2`.
      use_legacy_function: Forwarded to `StructuredFunctionWrapper`.
      name: Optional name for this transformation.
    """
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism
    self._map_func = structured_function.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    # The op takes determinism as a string attribute rather than a bool.
    if deterministic is None:
      determinism_attr = "default"
    else:
      determinism_attr = "true" if deterministic else "false"
    self._deterministic = determinism_attr
    self._preserve_cardinality = preserve_cardinality
    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    # NOTE(review): `_name` is assigned before `_common_args` is read below;
    # `_common_args` lives in dataset_ops and presumably reads `_name` --
    # confirm before reordering these statements.
    self._name = name
    traced_fn = self._map_func.function
    parallel_variant = gen_dataset_ops.parallel_map_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        traced_fn.captured_inputs,
        f=traced_fn,
        num_parallel_calls=self._num_parallel_calls,
        deterministic=self._deterministic,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        preserve_cardinality=self._preserve_cardinality,
        **self._common_args)
    super().__init__(input_dataset, parallel_variant)

  def _functions(self):
    """Returns the list of wrapped functions used by this dataset."""
    return [self._map_func]

  @property
  def element_spec(self):
    """The structure produced by the wrapped map function."""
    return self._map_func.output_structure

  def _transformation_name(self):
    """Name used when wrapping `map_func`."""
    return "Dataset.map()"