Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/structured_function.py: 24%

123 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for managing tf.data user-defined functions.""" 

16 

17import warnings 

18 

19from tensorflow.python.data.ops import debug_mode 

20from tensorflow.python.data.util import nest 

21from tensorflow.python.data.util import structure 

22from tensorflow.python.eager import context 

23from tensorflow.python.eager import def_function 

24 

25from tensorflow.python.framework import function 

26from tensorflow.python.framework import ops 

27from tensorflow.python.ops import script_ops 

28from tensorflow.python.util import function_utils 

29from tensorflow.python.util import lazy_loader 

30from tensorflow.python.util import variable_utils 

31 

32autograph = lazy_loader.LazyLoader( 

33 "autograph", globals(), 

34 "tensorflow.python.autograph.impl.api") 

35# TODO(mdan): Create a public API for this. 

36autograph_ctx = lazy_loader.LazyLoader( 

37 "autograph_ctx", globals(), 

38 "tensorflow.python.autograph.core.ag_ctx") 

39 

40 

def _should_pack(arg):
  """Reports whether the caller has to convert `arg` into a tuple.

  A user-defined function may hand back a `list` of tensors. The tf.data
  version of `nest.flatten()` does not recurse into lists, so together with
  `ops.convert_to_tensor()` it would try to stack such a list into a single
  tensor. A list most likely comes from an operation (for example
  `tf.numpy_function()`) whose outputs are not necessarily stackable, so the
  returned value is treated as a `tuple` instead. Users who do want a single
  tensor can call `tf.stack()` explicitly before returning.

  Args:
    arg: The value to inspect.

  Returns:
    `True` when `arg` is a `list` (or an instance of a `list` subclass), i.e.
    when the caller needs to pack it in a tuple; `False` otherwise.
  """
  is_list = isinstance(arg, list)
  return is_list

60 

61 

def _should_unpack(arg):
  """Reports whether the caller has to unpack `arg` from a tuple.

  The comparison is made on the exact type, by identity: only plain `tuple`
  objects qualify. Instances of `tuple` subclasses are deliberately not
  treated as tuples to unpack.

  Args:
    arg: The value to inspect.

  Returns:
    `True` when the type of `arg` is exactly `tuple`, i.e. when the caller
    needs to unpack it; `False` otherwise.
  """
  arg_type = type(arg)
  return arg_type is tuple

72 

73 

class StructuredFunctionWrapper:
  """A function wrapper that supports structured arguments and return values.

  The constructor resolves the input structure, builds a wrapper around the
  user function that converts between flat tensor lists and (nested)
  structures, and creates the underlying function object. The result is
  exposed through the `function` property; the structure of the user
  function's return value is exposed through `output_structure` and the
  legacy `output_classes` / `output_shapes` / `output_types` properties.
  """

  def __init__(self,
               func,
               transformation_name,
               dataset=None,
               input_classes=None,
               input_shapes=None,
               input_types=None,
               input_structure=None,
               add_to_graph=True,
               use_legacy_function=False,
               defun_kwargs=None):
    """Creates a new `StructuredFunctionWrapper` for the given function.

    Args:
      func: A function from a (nested) structure to another (nested) structure.
      transformation_name: Human-readable name of the transformation in which
        this function is being instantiated, for error messages.
      dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
        dataset will be assumed as the structure for `func` arguments; otherwise
        `input_classes`, `input_shapes`, and `input_types` must be defined.
      input_classes: (Optional.) A (nested) structure of `type`. If given, this
        argument defines the Python types for `func` arguments.
      input_shapes: (Optional.) A (nested) structure of `tf.TensorShape`. If
        given, this argument defines the shapes and structure for `func`
        arguments.
      input_types: (Optional.) A (nested) structure of `tf.DType`. If given,
        this argument defines the element types and structure for `func`
        arguments.
      input_structure: (Optional.) A `Structure` object. If given, this argument
        defines the element types and structure for `func` arguments.
      add_to_graph: (Optional.) If `True`, the function will be added to the
        default graph, if it exists. This is ignored (treated as `False`) when
        executing eagerly, and forced on when `use_legacy_function` is set.
      use_legacy_function: (Optional.) A boolean that determines whether the
        function is created using `def_function.Function` /
        `def_function.function` (default behavior) or
        `tensorflow.python.framework.function.Defun` (legacy behavior).
      defun_kwargs: (Optional.) A dictionary mapping string argument names to
        values. If supplied, will be passed to `function` as keyword arguments.
        The dictionary is copied; the caller's object is never modified.

    Raises:
      ValueError: If an invalid combination of `dataset`, `input_classes`,
        `input_shapes`, and `input_types` is passed.
      TypeError: Raised from within the wrapped function if the value returned
        by `func` cannot be converted to a type spec.
    """
    # pylint: disable=protected-access
    # Exactly one source of input structure is accepted: `input_structure`,
    # `dataset`, or the full legacy (classes, shapes, types) triple.
    if input_structure is None:
      if dataset is None:
        if input_classes is None or input_shapes is None or input_types is None:
          raise ValueError("Either `dataset`, `input_structure` or all of "
                           "`input_classes`, `input_shapes`, and `input_types` "
                           "must be specified.")
        self._input_structure = structure.convert_legacy_structure(
            input_types, input_shapes, input_classes)
      else:
        if not (input_classes is None and input_shapes is None and
                input_types is None):
          raise ValueError("Either `dataset`, `input_structure` or all of "
                           "`input_classes`, `input_shapes`, and `input_types` "
                           "must be specified.")
        self._input_structure = dataset.element_spec
    else:
      if not (dataset is None and input_classes is None and
              input_shapes is None and input_types is None):
        raise ValueError("Either `dataset`, `input_structure`, or all of "
                         "`input_classes`, `input_shapes`, and `input_types` "
                         "must be specified.")
      self._input_structure = input_structure

    self._func = func

    # Work on a private copy: the code below adds entries (`func_name`,
    # `_tf_data_function`) and the tracing helpers pop `func_name` again.
    # None of that may leak into a dictionary owned by the caller.
    defun_kwargs = {} if defun_kwargs is None else dict(defun_kwargs)

    # Dots are replaced by underscores and the last two characters are dropped
    # (presumably a trailing "()" as in "Dataset.map()" -- verify against
    # callers). Names of two characters or fewer yield an empty prefix.
    readable_transformation_name = transformation_name.replace(
        ".", "_")[:-2] if len(transformation_name) > 2 else ""

    func_name = "_".join(
        [readable_transformation_name,
         function_utils.get_func_name(func)])
    # Sanitize function name to remove symbols that interfere with graph
    # construction.
    for symbol in ["<", ">", "\\", "'", " "]:
      func_name = func_name.replace(symbol, "")

    # Capture the autograph control status at construction time; it is applied
    # to `func` whenever `wrapper_helper` runs.
    ag_ctx = autograph_ctx.control_status_ctx()

    def wrapper_helper(*args):
      """Wrapper for passing nested structures to and from tf.data functions."""
      nested_args = structure.from_compatible_tensor_list(
          self._input_structure, args)
      # A non-tuple element is passed to `func` as a single argument.
      if not _should_unpack(nested_args):
        nested_args = (nested_args,)
      ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
      ret = variable_utils.convert_variables_to_tensors(ret)
      # Lists are treated as tuples; see `_should_pack`.
      if _should_pack(ret):
        ret = tuple(ret)

      try:
        # Side effect: records the output structure on the wrapper instance.
        self._output_structure = structure.type_spec_from_value(ret)
      except (ValueError, TypeError) as e:
        raise TypeError(f"Unsupported return value from function passed to "
                        f"{transformation_name}: {ret}.") from e
      return ret

    def trace_legacy_function(defun_kwargs):
      """Returns a factory for a `function.Defun`-based wrapped function."""

      @function.Defun(*structure.get_flat_tensor_types(self._input_structure),
                      **defun_kwargs)
      def wrapped_fn(*args):
        ret = wrapper_helper(*args)
        return structure.to_tensor_list(self._output_structure, ret)

      return lambda: wrapped_fn

    def trace_py_function(defun_kwargs):
      """Returns a factory for a function that runs `func` via eager_py_func."""
      # First we trace the function to infer the output structure.
      def unused(*args):  # pylint: disable=missing-docstring,unused-variable
        ret = wrapper_helper(*args)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      # `func_name` is removed so that it is not forwarded as an attribute.
      func_name = defun_kwargs.pop("func_name", "unused")
      tf_function = def_function.Function(
          python_function=unused,
          name=func_name,
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure
          ),
          autograph=False,
          experimental_attributes=defun_kwargs,
      )

      # Traced only for its side effect of setting `self._output_structure`.
      _ = tf_function.get_concrete_function()

      def py_function_wrapper(*args):
        # Same packing/unpacking as `wrapper_helper`, but `func` is called
        # directly (no autograph conversion, no variable conversion).
        nested_args = structure.from_compatible_tensor_list(
            self._input_structure, args)
        if not _should_unpack(nested_args):
          nested_args = (nested_args,)
        ret = self._func(*nested_args)
        if _should_pack(ret):
          ret = tuple(ret)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      # Next we trace the function wrapped in `eager_py_func` to force eager
      # execution.
      @def_function.function(
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure),
          autograph=False,
          experimental_attributes=defun_kwargs)
      def wrapped_fn(*args):  # pylint: disable=missing-docstring
        return script_ops.eager_py_func(
            py_function_wrapper, args,
            structure.get_flat_tensor_types(self._output_structure))

      return wrapped_fn.get_concrete_function

    def trace_tf_function(defun_kwargs):
      """Returns a factory for a `def_function.Function`-based function."""
      # Note: wrapper_helper will apply autograph based on context.
      def wrapped_fn(*args):  # pylint: disable=missing-docstring
        ret = wrapper_helper(*args)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      # `func_name` is removed so that it is not forwarded as an attribute.
      func_name = defun_kwargs.pop("func_name", "wrapped_fn")
      tf_function = def_function.Function(
          python_function=wrapped_fn,
          name=func_name,
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure
          ),
          autograph=False,
          experimental_attributes=defun_kwargs,
      )

      return tf_function.get_concrete_function

    if use_legacy_function:
      # The legacy path appends a uid to the function name.
      defun_kwargs["func_name"] = func_name + "_" + str(ops.uid())
      fn_factory = trace_legacy_function(defun_kwargs)
    else:
      defun_kwargs["func_name"] = func_name
      defun_kwargs["_tf_data_function"] = True
      if debug_mode.DEBUG_MODE:
        fn_factory = trace_py_function(defun_kwargs)
      else:
        if def_function.functions_run_eagerly():
          warnings.warn(
              "Even though the `tf.config.experimental_run_functions_eagerly` "
              "option is set, this option does not apply to tf.data functions. "
              "To force eager execution of tf.data functions, please use "
              "`tf.data.experimental.enable_debug_mode()`.")
        fn_factory = trace_tf_function(defun_kwargs)

    self._function = fn_factory()
    # There is no graph to add in eager mode.
    add_to_graph &= not context.executing_eagerly()
    # There are some lifetime issues when a legacy function is not added to a
    # out-living graph. It's already deprecated so de-prioritizing the fix.
    add_to_graph |= use_legacy_function
    if add_to_graph:
      self._function.add_to_graph(ops.get_default_graph())

    if not use_legacy_function:
      # Warn when the function's graph shares the outer graph's seed and that
      # seed has been marked as used.
      outer_graph_seed = ops.get_default_graph().seed
      if outer_graph_seed and self._function.graph.seed == outer_graph_seed:
        if self._function.graph._seed_used:
          warnings.warn(
              "Seed %s from outer graph might be getting used by function %s, "
              "if the random op has not been provided any seed. Explicitly set "
              "the seed in the function if this is not the intended behavior." %
              (outer_graph_seed, func_name),
              stacklevel=4)

  @property
  def output_structure(self):
    """The structure recorded for the return value of the wrapped function."""
    return self._output_structure

  @property
  def output_classes(self):
    """The legacy output classes, mapped over `output_structure`."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def output_shapes(self):
    """The legacy output shapes, mapped over `output_structure`."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def output_types(self):
    """The legacy output types, mapped over `output_structure`."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def function(self):
    """The function object produced by the constructor."""
    return self._function