Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/structured_function.py: 24%
123 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for managing tf.data user-defined functions."""
17import warnings
19from tensorflow.python.data.ops import debug_mode
20from tensorflow.python.data.util import nest
21from tensorflow.python.data.util import structure
22from tensorflow.python.eager import context
23from tensorflow.python.eager import def_function
25from tensorflow.python.framework import function
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import script_ops
28from tensorflow.python.util import function_utils
29from tensorflow.python.util import lazy_loader
30from tensorflow.python.util import variable_utils
# `autograph` and `autograph_ctx` are bound through `lazy_loader.LazyLoader`
# instead of plain `import` statements.
# NOTE(review): presumably this defers importing the autograph modules until
# their first use (e.g. to avoid an import cycle) -- confirm.
autograph = lazy_loader.LazyLoader(
    "autograph", globals(),
    "tensorflow.python.autograph.impl.api")
# TODO(mdan): Create a public API for this.
autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")
41def _should_pack(arg):
42 """Determines whether the caller needs to pack the argument in a tuple.
44 If user-defined function returns a list of tensors, `nest.flatten()` and
45 `ops.convert_to_tensor()` and would conspire to attempt to stack those tensors
46 into a single tensor because the tf.data version of `nest.flatten()` does
47 not recurse into lists. Since it is more likely that the list arose from
48 returning the result of an operation (such as `tf.numpy_function()`) that
49 returns a list of not-necessarily-stackable tensors, we treat the returned
50 value as a `tuple` instead. A user wishing to pack the return value into a
51 single tensor can use an explicit `tf.stack()` before returning.
53 Args:
54 arg: argument to check
56 Returns:
57 Indication of whether the caller needs to pack the argument in a tuple.
58 """
59 return isinstance(arg, list)
62def _should_unpack(arg):
63 """Determines whether the caller needs to unpack the argument from a tuple.
65 Args:
66 arg: argument to check
68 Returns:
69 Indication of whether the caller needs to unpack the argument from a tuple.
70 """
71 return type(arg) is tuple # pylint: disable=unidiomatic-typecheck
class StructuredFunctionWrapper():
  """A function wrapper that supports structured arguments and return values.

  The wrapped function is built in the constructor and exposed through the
  `function` property. The structure of the value returned by the user function
  is recorded while that function is being traced and is exposed through
  `output_structure` and its legacy views `output_classes`, `output_shapes`
  and `output_types`.
  """

  def __init__(self,
               func,
               transformation_name,
               dataset=None,
               input_classes=None,
               input_shapes=None,
               input_types=None,
               input_structure=None,
               add_to_graph=True,
               use_legacy_function=False,
               defun_kwargs=None):
    """Creates a new `StructuredFunctionWrapper` for the given function.

    Args:
      func: A function from a (nested) structure to another (nested) structure.
      transformation_name: Human-readable name of the transformation in which
        this function is being instantiated, for error messages.
      dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
        dataset will be assumed as the structure for `func` arguments; otherwise
        `input_classes`, `input_shapes`, and `input_types` must be defined.
      input_classes: (Optional.) A (nested) structure of `type`. If given, this
        argument defines the Python types for `func` arguments.
      input_shapes: (Optional.) A (nested) structure of `tf.TensorShape`. If
        given, this argument defines the shapes and structure for `func`
        arguments.
      input_types: (Optional.) A (nested) structure of `tf.DType`. If given,
        this argument defines the element types and structure for `func`
        arguments.
      input_structure: (Optional.) A `Structure` object. If given, this argument
        defines the element types and structure for `func` arguments.
      add_to_graph: (Optional.) If `True`, the function will be added to the
        default graph, if it exists.
      use_legacy_function: (Optional.) A boolean that determines whether the
        function should be created using
        `tensorflow.python.eager.function.defun` (default behavior) or
        `tensorflow.python.framework.function.Defun` (legacy behavior).
      defun_kwargs: (Optional.) A dictionary mapping string argument names to
        values. If supplied, will be passed to `function` as keyword arguments.
        The dictionary is copied; the caller's object is never modified.

    Raises:
      ValueError: If an invalid combination of `dataset`, `input_classes`,
        `input_shapes`, and `input_types` is passed.
      TypeError: If, while tracing, `func` returns a value for which no type
        spec can be derived.
    """
    # pylint: disable=protected-access
    # Exactly one source of the input structure is accepted: an explicit
    # `input_structure`, a `dataset`, or the full legacy triple of
    # classes/shapes/types.
    if input_structure is None:
      if dataset is None:
        if input_classes is None or input_shapes is None or input_types is None:
          raise ValueError("Either `dataset`, `input_structure` or all of "
                           "`input_classes`, `input_shapes`, and `input_types` "
                           "must be specified.")
        self._input_structure = structure.convert_legacy_structure(
            input_types, input_shapes, input_classes)
      else:
        if not (input_classes is None and input_shapes is None and
                input_types is None):
          raise ValueError("Either `dataset`, `input_structure` or all of "
                           "`input_classes`, `input_shapes`, and `input_types` "
                           "must be specified.")
        self._input_structure = dataset.element_spec
    else:
      if not (dataset is None and input_classes is None and
              input_shapes is None and input_types is None):
        raise ValueError("Either `dataset`, `input_structure`, or all of "
                         "`input_classes`, `input_shapes`, and `input_types` "
                         "must be specified.")
      self._input_structure = input_structure

    self._func = func

    # Work on a private copy: the code below calls `update()` and `pop()` on
    # this dictionary, which must not leak into the object owned by the caller.
    defun_kwargs = {} if defun_kwargs is None else dict(defun_kwargs)

    # Dots become underscores and the last two characters are dropped.
    # NOTE(review): presumably the dropped characters are the trailing "()" of
    # names such as "Dataset.map()" -- confirm against the callers.
    readable_transformation_name = transformation_name.replace(
        ".", "_")[:-2] if len(transformation_name) > 2 else ""

    func_name = "_".join(
        [readable_transformation_name,
         function_utils.get_func_name(func)])
    # Sanitize function name to remove symbols that interfere with graph
    # construction.
    for symbol in ["<", ">", "\\", "'", " "]:
      func_name = func_name.replace(symbol, "")

    # Captured here, outside of the traced function, and handed to
    # `autograph.tf_convert` inside `wrapper_helper`.
    ag_ctx = autograph_ctx.control_status_ctx()

    def wrapper_helper(*args):
      """Wrapper for passing nested structures to and from tf.data functions."""
      nested_args = structure.from_compatible_tensor_list(
          self._input_structure, args)
      # Anything that is not exactly a tuple is passed as a single argument.
      if not _should_unpack(nested_args):
        nested_args = (nested_args,)
      ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
      ret = variable_utils.convert_variables_to_tensors(ret)
      if _should_pack(ret):
        ret = tuple(ret)

      # Side effect: the output structure is recorded on the wrapper so that
      # the enclosing trace functions and the `output_*` properties can use it.
      try:
        self._output_structure = structure.type_spec_from_value(ret)
      except (ValueError, TypeError) as e:
        raise TypeError(f"Unsupported return value from function passed to "
                        f"{transformation_name}: {ret}.") from e
      return ret

    def trace_legacy_function(defun_kwargs):
      """Returns a factory for the `function.Defun`-based wrapped function."""

      @function.Defun(*structure.get_flat_tensor_types(self._input_structure),
                      **defun_kwargs)
      def wrapped_fn(*args):
        ret = wrapper_helper(*args)
        return structure.to_tensor_list(self._output_structure, ret)

      return lambda: wrapped_fn

    def trace_py_function(defun_kwargs):
      """Returns a factory for a function that runs `func` via eager_py_func."""
      # First we trace the function to infer the output structure.
      def unused(*args):  # pylint: disable=missing-docstring,unused-variable
        ret = wrapper_helper(*args)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      # "func_name" is removed from the kwargs here, so the remaining entries
      # are passed on as `experimental_attributes` without it.
      func_name = defun_kwargs.pop("func_name", "unused")
      tf_function = def_function.Function(
          python_function=unused,
          name=func_name,
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure
          ),
          autograph=False,
          experimental_attributes=defun_kwargs,
      )

      # The result is discarded; the call is made for its side effect of
      # setting `self._output_structure` inside `wrapper_helper`.
      _ = tf_function.get_concrete_function()

      def py_function_wrapper(*args):
        nested_args = structure.from_compatible_tensor_list(
            self._input_structure, args)
        if not _should_unpack(nested_args):
          nested_args = (nested_args,)
        # The user function is called directly here, without autograph
        # conversion.
        ret = self._func(*nested_args)
        if _should_pack(ret):
          ret = tuple(ret)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      # Next we trace the function wrapped in `eager_py_func` to force eager
      # execution.
      @def_function.function(
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure),
          autograph=False,
          experimental_attributes=defun_kwargs)
      def wrapped_fn(*args):  # pylint: disable=missing-docstring
        return script_ops.eager_py_func(
            py_function_wrapper, args,
            structure.get_flat_tensor_types(self._output_structure))

      return wrapped_fn.get_concrete_function

    def trace_tf_function(defun_kwargs):
      """Returns a factory for the `def_function.Function`-based function."""
      # Note: wrapper_helper will apply autograph based on context.
      def wrapped_fn(*args):  # pylint: disable=missing-docstring
        ret = wrapper_helper(*args)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      func_name = defun_kwargs.pop("func_name", "wrapped_fn")
      tf_function = def_function.Function(
          python_function=wrapped_fn,
          name=func_name,
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure
          ),
          autograph=False,
          experimental_attributes=defun_kwargs,
      )

      return tf_function.get_concrete_function

    if use_legacy_function:
      # The legacy path appends `ops.uid()` to the function name.
      defun_kwargs.update({"func_name": func_name + "_" + str(ops.uid())})
      fn_factory = trace_legacy_function(defun_kwargs)
    else:
      defun_kwargs.update({"func_name": func_name})
      defun_kwargs.update({"_tf_data_function": True})
      if debug_mode.DEBUG_MODE:
        fn_factory = trace_py_function(defun_kwargs)
      else:
        if def_function.functions_run_eagerly():
          warnings.warn(
              "Even though the `tf.config.experimental_run_functions_eagerly` "
              "option is set, this option does not apply to tf.data functions. "
              "To force eager execution of tf.data functions, please use "
              "`tf.data.experimental.enable_debug_mode()`.")
        fn_factory = trace_tf_function(defun_kwargs)

    self._function = fn_factory()
    # There is no graph to add in eager mode.
    add_to_graph &= not context.executing_eagerly()
    # There are some lifetime issues when a legacy function is not added to a
    # out-living graph. It's already deprecated so de-prioritizing the fix.
    add_to_graph |= use_legacy_function
    if add_to_graph:
      self._function.add_to_graph(ops.get_default_graph())

    if not use_legacy_function:
      # Warn when the function graph has the same seed as the outer graph and
      # that seed has been marked as used.
      outer_graph_seed = ops.get_default_graph().seed
      if outer_graph_seed and self._function.graph.seed == outer_graph_seed:
        if self._function.graph._seed_used:
          warnings.warn(
              "Seed %s from outer graph might be getting used by function %s, "
              "if the random op has not been provided any seed. Explicitly set "
              "the seed in the function if this is not the intended behavior." %
              (outer_graph_seed, func_name),
              stacklevel=4)

  @property
  def output_structure(self):
    """The type spec structure recorded for the value returned by `func`."""
    return self._output_structure

  @property
  def output_classes(self):
    """`output_structure` mapped through `_to_legacy_output_classes()`."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def output_shapes(self):
    """`output_structure` mapped through `_to_legacy_output_shapes()`."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def output_types(self):
    """`output_structure` mapped through `_to_legacy_output_types()`."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def function(self):
    """The wrapped function object built by the constructor."""
    return self._function