Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/function_def_to_graph.py: 12%

175 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================= 

15"""Utility to convert FunctionDef to GraphDef and Graph.""" 

16 

17import itertools 

18 

19 

20from tensorflow.core.framework import function_pb2 

21from tensorflow.core.framework import graph_pb2 

22from tensorflow.core.framework import tensor_shape_pb2 

23from tensorflow.core.framework import types_pb2 

24from tensorflow.core.framework import versions_pb2 

25from tensorflow.python.eager import context 

26from tensorflow.python.framework import cpp_shape_inference_pb2 

27from tensorflow.python.framework import importer 

28from tensorflow.python.framework import ops 

29from tensorflow.python.framework import versions 

30from tensorflow.python.framework.func_graph import FuncGraph 

31from tensorflow.python.ops import resource_variable_ops 

32 

33 

def function_def_to_graph(
    fdef,
    structured_input_signature=None,
    structured_outputs=None,
    input_shapes=None,
    propagate_device_spec=False,
    include_library_functions=False,
):
  """Builds a FuncGraph (a `Graph` subclass) from a FunctionDef.

  The FuncGraph is named after `fdef.signature.name`. Every function input is
  materialized as a placeholder, and the graph's `inputs`, `outputs` and
  `control_outputs` are populated from `fdef`'s signature.

  Note: `FuncGraph.captures` is not populated here; callers may set it.

  Args:
    fdef: The FunctionDef to convert.
    structured_input_signature: Optional structured input signature forwarded
      to the FuncGraph constructor (see `FuncGraph` for details).
    structured_outputs: Optional structured outputs forwarded to the FuncGraph
      constructor (see `FuncGraph` for details).
    input_shapes: Optional list of per-input shapes (TensorShape objects or
      TensorShapeProtos). When omitted, the `_input_shapes` attr of `fdef` is
      used if present. When given, its length must equal the number of
      `fdef.signature.input_arg`s; a None entry leaves that input placeholder's
      shape unknown.
    propagate_device_spec: Optional. Whether device assignments recorded in
      `fdef` are carried over into the new graph.
    include_library_functions: Optional. Whether nested library functions are
      copied into the output FuncGraph. They are looked up in the outer graph
      in graph mode, or in the eager context in eager mode.

  Returns:
    The constructed FuncGraph.
  """
  func_graph = FuncGraph(
      fdef.signature.name,
      structured_input_signature=structured_input_signature,
      structured_outputs=structured_outputs)

  if input_shapes is None:
    shapes_attr = fdef.attr.get("_input_shapes", None)
    if shapes_attr is not None:
      # Resource inputs that carry handle data get no shape at all: pinning
      # them to either the handle shape (`[]`) or the variable shape can
      # confuse shape inference.
      input_shapes = [
          None if (arg_def.type == types_pb2.DT_RESOURCE and
                   arg_def.handle_data) else shape
          for shape, arg_def in zip(shapes_attr.list.shape,
                                    fdef.signature.input_arg)
      ]

  graph_def, tensor_name_map = function_def_to_graph_def(
      fdef, input_shapes, include_library_functions=include_library_functions)

  with func_graph.as_default():
    # Materialize every function node inside the FuncGraph.
    importer.import_graph_def_for_function(
        graph_def, name="", propagate_device_spec=propagate_device_spec)

    signature = fdef.signature

    # FuncGraph-specific fields: inputs.
    flat_input_names = [tensor_name_map[arg.name] for arg in signature.input_arg]
    func_graph.inputs = [
        func_graph.get_tensor_by_name(name) for name in flat_input_names
    ]

    # FuncGraph-specific fields: outputs and control outputs.
    flat_output_names = [
        tensor_name_map[fdef.ret[arg.name]] for arg in signature.output_arg
    ]
    func_graph.outputs = [
        func_graph.get_tensor_by_name(name) for name in flat_output_names
    ]
    func_graph.control_outputs = [
        func_graph.get_operation_by_name(fdef.control_ret[ret_name])
        for ret_name in signature.control_output
    ]

    _set_handle_data(func_graph, fdef)

    for node in graph_def.node:
      shapes_attr = node.attr.get("_output_shapes", None)
      if shapes_attr is None:
        continue
      op = func_graph.get_operation_by_name(node.name)
      # `_output_shapes` may list more shapes than the op has outputs when the
      # output-intermediates-for-gradients variant of the function was
      # substituted before saving; the extra entries are ignored
      # (See b/133666530).
      op_outputs = op.outputs
      for index, shape in enumerate(shapes_attr.list.shape[:len(op_outputs)]):
        op_outputs[index].set_shape(shape)

    func_graph._output_names = {  # pylint: disable=protected-access
        ops.tensor_id(func_graph.get_tensor_by_name(name)): ret_arg.name
        for ret_arg, name in zip(signature.output_arg, flat_output_names)
    }
  return func_graph

143 

144 

def is_function(fname, graph):
  """Checks for a function definition with `fname` in the current context.

  In eager mode the eager context is the single source of truth and `graph` is
  ignored. In graph mode, `graph` and then each of its `outer_graph`s is
  searched in turn.

  Args:
    fname: Name of the function to look for.
    graph: Graph to start searching from in graph mode. May be None.

  Returns:
    True if the function is found, False otherwise. In graph mode this is
    always a real bool, even when `graph` is None or the chain of outer graphs
    ends in None.
  """
  if context.executing_eagerly():
    # Eager mode: use eager context as the single source of truth.
    return context.context().has_function(fname)
  # Graph mode: use outer graphs as the single source of truth.
  while graph is not None:
    if graph._is_function(fname):  # pylint: disable=protected-access
      return True
    graph = getattr(graph, "outer_graph", None)
  return False

159 

160 

def get_function_def(fname, graph):
  """Returns the `FunctionDef` named `fname` visible from the current context.

  In graph mode, `graph` and then each enclosing `outer_graph` is searched and
  the first match's `cached_definition` is returned. In eager mode the eager
  context is consulted instead and `graph` is ignored.

  Args:
    fname: Name of the function to fetch.
    graph: Graph to start searching from in graph mode. May be None.

  Returns:
    The matching FunctionDef, or None if no definition is found.
  """
  if not context.executing_eagerly():
    # Graph mode: walk from `graph` outward through its enclosing graphs.
    current = graph
    while current is not None:
      if current._is_function(fname):  # pylint: disable=protected-access
        return current._get_function(fname).cached_definition  # pylint: disable=protected-access
      current = getattr(current, "outer_graph", None)
    return None

  # Eager mode: the eager context is the only place to look.
  eager_ctx = context.context()
  if eager_ctx.has_function(fname):
    return eager_ctx.get_function_def(fname)
  return None

173 

174 

def copy_function_def_to_graph_def_recursively(
    func_name, graph_def, copied_functions, default_graph=None):
  """Recursively copies `FunctionDef`s to `GraphDef`.

  Copies the `FunctionDef` named `func_name` into `graph_def.library`, and
  then every `FunctionDef` referenced from its nodes' `func` / `list(func)`
  attrs. Names are added to `copied_functions` before recursing, so each
  function is copied at most once and reference cycles terminate.

  FunctionDefs are looked up in the eager context when executing eagerly, and
  in `default_graph` and its outer graphs otherwise.

  Args:
    func_name: The signature name of the FunctionDef to be copied. An empty
      name is ignored (custom ops may contain a func attr with an empty name).
    graph_def: The GraphDef that will contain all `FunctionDef`s in its library.
    copied_functions: A set of already-copied function names; updated in place.
    default_graph: The `tf.Graph` in which `FunctionDef`s are looked up in graph
      mode, and in which op definitions are resolved. When None, the current
      default graph is used.

  Raises:
    ValueError: If a non-empty `func_name` cannot be found.
  """
  # Custom ops may contain a func attr with an empty fname. There is nothing
  # to copy for those; without this early return `get_function_def` would
  # return None and `CopyFrom(None)` would fail.
  if not func_name:
    return

  if default_graph is None:
    default_graph = ops.get_default_graph()

  if not is_function(func_name, default_graph):
    raise ValueError(f"Function {func_name} was not found. Please make "
                     "sure the FunctionDef `fdef` is correct.")

  # If `copied_functions` contains `func_name`, the FunctionDef has already
  # been added to GraphDef so we simply return here.
  if func_name in copied_functions:
    return

  copied_functions.add(func_name)
  func_def = get_function_def(func_name, default_graph)
  graph_def.library.function.add().CopyFrom(func_def)

  for node_def in func_def.node_def:
    op_def = default_graph.op_def_for_type(node_def.op)
    for attr in op_def.attr:
      if attr.type not in ("func", "list(func)"):
        continue
      # `.get` instead of `[]`: indexing a protobuf message map inserts missing
      # keys, which would mutate the (possibly cached) FunctionDef.
      attr_value = node_def.attr.get(attr.name, None)
      if attr_value is None:
        continue
      if attr.type == "func":
        nested_names = [attr_value.func.name]
      else:
        nested_names = [fn.name for fn in attr_value.list.func]
      for nested_name in nested_names:
        copy_function_def_to_graph_def_recursively(
            nested_name, graph_def, copied_functions, default_graph)

219 

220 

def function_def_to_graph_def(
    fdef, input_shapes=None, include_library_functions=False
):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
     in FunctionDefs is different from GraphDefs.

  `fdef.arg_attr` is read without inserting entries into the caller's proto.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects (or
      TensorShapeProtos) of the shapes of function inputs. If specified, its
      length must match length of `fdef.signature.input_arg`. If a shape is
      None, the corresponding input placeholder will have unknown shape.
    include_library_functions: Optional. If enabled, copy `fdef` and its
      nested `FunctionDef`s to the library functions of the returned `GraphDef`.
      In graph mode, the functions will be found from outer graph. In eager
      mode, the functions will be found from eager context.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args, if a referenced function cannot be found, or if the
      FunctionDef is otherwise invalid.
  """
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  default_graph = ops.get_default_graph()

  # Names of functions already present in `graph_def.library`.
  copied_functions = set()

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    raise ValueError("Length of `input_shapes` must match the number "
                     f"of `input_arg`s in `fdef`. Got "
                     f"{len(input_shapes)} `input_shapes` and "
                     f"{len(fdef.signature.input_arg)} `input_arg`s.")

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      input_shape = input_shapes[i]
      if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto):
        input_shape = input_shape.as_proto()
      node_def.attr["shape"].shape.CopyFrom(input_shape)
    # `arg_attr` is a protobuf message map: `fdef.arg_attr[i]` on a missing key
    # would insert an empty entry into the caller's FunctionDef.
    if i not in fdef.arg_attr:
      continue
    arg_attrs = fdef.arg_attr[i].attr
    for k in arg_attrs:
      # Only copy internal attributes. Normal attributes for nodes cannot be
      # applied to these Placeholder nodes.
      if k == "_output_shapes":
        attr_value = arg_attrs[k]
        kind = attr_value.WhichOneof("value")
        if kind == "list":
          # Guard against an empty shape list instead of raising IndexError.
          if attr_value.list.shape:
            node_def.attr["shape"].shape.CopyFrom(attr_value.list.shape[0])
        elif kind == "shape":
          node_def.attr["shape"].shape.CopyFrom(attr_value.shape)
      elif k.startswith("_"):
        node_def.attr[k].CopyFrom(arg_attrs[k])

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
  nested_to_flat_tensor_name = {}

  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
    control_name = "^" + arg_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  for node_def in fdef.node_def:
    f = _find_function_in_graph_stack(node_def.op, default_graph)

    if f is not None:
      # The node's op type is itself a function. Use a distinct name so the
      # `fdef` parameter is never shadowed.
      called_fdef = f.cached_definition
      op_def = called_fdef.signature
      if node_def.op not in copied_functions:
        # Since this function is referenced as an op type, we have no choice but
        # to copy it into the GraphDef if we want downstream tools to process
        # it.
        graph_def.library.function.add().CopyFrom(called_fdef)
        copied_functions.add(node_def.op)
        if getattr(f, "grad_func_name", None):
          grad_def = function_pb2.GradientDef()
          grad_def.function_name = f.name
          grad_def.gradient_func = f.grad_func_name
          graph_def.library.gradient.extend([grad_def])
    else:
      op_def = default_graph.op_def_for_type(node_def.op)

    for fname in _referenced_function_names(node_def, op_def):
      # Custom ops may contain a func attr with an empty fname; skip it rather
      # than validating or copying it.
      if not fname:
        continue
      if not is_function(fname, default_graph):
        raise ValueError(f"Function {fname} was not found. Please make sure "
                         "the FunctionDef `fdef` is correct.")
      if include_library_functions:
        copy_function_def_to_graph_def_recursively(
            fname, graph_def, copied_functions, default_graph)

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1
    control_name = "^" + node_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  # Update inputs of all nodes in graph.
  for node_def in graph_def.node:
    for i in range(len(node_def.input)):
      node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]

  return graph_def, nested_to_flat_tensor_name


def _find_function_in_graph_stack(name, graph):
  """Returns the function registered as `name` in `graph` or an outer graph.

  Searches `graph`, then its `outer_graph`, and so on, stopping at a graph
  without an `outer_graph` attribute (or a None one).

  Returns:
    The registered function object, or None if no graph in the chain has it.
  """
  while graph is not None:
    f = graph._functions.get(name, None)  # pylint: disable=protected-access
    if f is not None:
      return f
    graph = getattr(graph, "outer_graph", None)
  return None


def _referenced_function_names(node_def, op_def):
  """Yields function names held in `node_def`'s func-typed attrs.

  Only attrs that `op_def` declares with type "func" or "list(func)" are
  inspected, in declaration order. Yielded names may be empty strings. Uses
  `.get` so that a missing attr is not inserted into `node_def.attr`.
  """
  for attr in op_def.attr:
    if attr.type not in ("func", "list(func)"):
      continue
    attr_value = node_def.attr.get(attr.name, None)
    if attr_value is None:
      continue
    if attr.type == "func":
      yield attr_value.func.name
    else:
      for fn in attr_value.list.func:
        yield fn.name

376 

377 

378# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange. 

379def _get_num_args(arg_def, node_def): 

380 if arg_def.number_attr: 

381 return node_def.attr[arg_def.number_attr].i 

382 elif arg_def.type_list_attr: 

383 return len(node_def.attr[arg_def.type_list_attr].list.type) 

384 elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: 

385 return 1 

386 else: 

387 raise ValueError(f"Invalid arg_def:\n\n{arg_def}. Please make sure the " 

388 "FunctionDef `fdef` is correct.") 

389 

390 

def _set_handle_data(func_graph, fdef):
  """Attaches resource handle data to the FuncGraph's inputs and outputs.

  For every input/output whose signature ArgDef carries `handle_data`, the
  tensor's own shape is set to `[]` (a resource handle is a scalar), and the
  handle data built from the ArgDef's first `handle_data` entry is attached to
  the tensor. Previously the handle's shape was left as `None`; this corrects
  both shapes.

  Args:
    func_graph: The FuncGraph whose `inputs` and `outputs` are already set.
    fdef: The FunctionDef the graph was built from.
  """
  signature = fdef.signature
  tensor_arg_pairs = itertools.chain(
      zip(func_graph.inputs, signature.input_arg),
      zip(func_graph.outputs, signature.output_arg))
  for tensor, arg_def in tensor_arg_pairs:
    if not arg_def.handle_data:
      continue

    tensor.set_shape([])

    # Only the first handle_data entry is propagated.
    first_entry = arg_def.handle_data[0]
    shape_and_type = (
        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
            shape=first_entry.shape, dtype=first_entry.dtype))
    handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
        is_set=True, shape_and_type=[shape_and_type])
    resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
        tensor, handle_data, True)