# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Utility to convert FunctionDef to GraphDef and Graph."""

import itertools

from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import resource_variable_ops


def function_def_to_graph(
    fdef,
    structured_input_signature=None,
    structured_outputs=None,
    input_shapes=None,
    propagate_device_spec=False,
    include_library_functions=False,
):
  """Converts a FunctionDef to a FuncGraph (a subclass of Graph).

  The returned FuncGraph's `name`, `inputs` and `outputs` fields will be set.
  The input tensors are represented as placeholders.

  Note: `FuncGraph.captures` is not set and may be set by the caller.

  Args:
    fdef: FunctionDef.
    structured_input_signature: Optional. The structured input signature to
      use for initializing the FuncGraph. See the docstring for FuncGraph for
      more information.
    structured_outputs: Optional. The structured outputs to use for
      initializing the FuncGraph. See the docstring for FuncGraph for more
      information.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. Defaults to the function's "_input_shapes" attribute.
      If specified, its length must match the length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.
    propagate_device_spec: Optional. Whether to propagate assigned device
      information when constructing a new Graph from a FunctionDef.
    include_library_functions: Optional. Whether to include library functions
      in the output FuncGraph. In graph mode, the library functions are found
      in the outer graph. In eager mode, they are found in the eager context.

  Returns:
    A FuncGraph.
  """
  func_graph = FuncGraph(fdef.signature.name,
                         structured_input_signature=structured_input_signature,
                         structured_outputs=structured_outputs)
  if input_shapes is None:
    input_shapes_attr = fdef.attr.get("_input_shapes", None)
    if input_shapes_attr is not None:
      raw_input_shapes = input_shapes_attr.list.shape

      # Replace resource handle shapes in the inputs to disable shape
      # inference. Setting the shape to either the variable handle shape
      # (which is always `[]`) or the variable shape can cause shape
      # inference issues.
      input_shapes = []
      for input_shape, arg_def in zip(raw_input_shapes,
                                      fdef.signature.input_arg):
        if arg_def.type == types_pb2.DT_RESOURCE and arg_def.handle_data:
          input_shapes.append(None)
        else:
          input_shapes.append(input_shape)

  graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
      fdef, input_shapes, include_library_functions=include_library_functions
  )

  with func_graph.as_default():
    # Add all function nodes to the graph.
    importer.import_graph_def_for_function(
        graph_def, name="", propagate_device_spec=propagate_device_spec)

    # Initialize fields specific to FuncGraph.

    # inputs
    input_tensor_names = [
        nested_to_flat_tensor_name[arg.name]
        for arg in fdef.signature.input_arg
    ]
    func_graph.inputs = [
        func_graph.get_tensor_by_name(name) for name in input_tensor_names
    ]

    # outputs
    output_tensor_names = [
        nested_to_flat_tensor_name[fdef.ret[arg.name]]
        for arg in fdef.signature.output_arg
    ]
    func_graph.outputs = [
        func_graph.get_tensor_by_name(name) for name in output_tensor_names
    ]
    func_graph.control_outputs = [
        func_graph.get_operation_by_name(fdef.control_ret[ret_name])
        for ret_name in fdef.signature.control_output
    ]

    _set_handle_data(func_graph, fdef)

    for node in graph_def.node:
      output_shapes = node.attr.get("_output_shapes", None)
      if output_shapes is not None:
        op = func_graph.get_operation_by_name(node.name)
        # _output_shapes for functions can sometimes be too long because the
        # output-intermediates-for-gradients version of the function was
        # substituted before saving. We'll accept that here. (See b/133666530.)
        for output_index, shape in enumerate(
            output_shapes.list.shape[:len(op.outputs)]):
          op.outputs[output_index].set_shape(shape)
    output_names = {}
    for ret_arg_def, tensor_name in zip(
        fdef.signature.output_arg, output_tensor_names):
      output_names[ops.tensor_id(
          func_graph.get_tensor_by_name(tensor_name))] = (
              ret_arg_def.name)
    func_graph._output_names = output_names  # pylint: disable=protected-access
  return func_graph
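

# Example sketch (not exercised by this module): round-tripping a
# `tf.function` back into a FuncGraph. The `tf` import, the `square`
# function, and the `function_def` property on ConcreteFunction are
# illustrative assumptions.
#
#   import tensorflow as tf
#
#   @tf.function
#   def square(x):
#     return x * x
#
#   concrete = square.get_concrete_function(tf.TensorSpec([None], tf.float32))
#   graph = function_def_to_graph(concrete.function_def)
#   print(graph.name, [t.name for t in graph.inputs],
#         [t.name for t in graph.outputs])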


def is_function(fname, graph):
  """Checks for a function definition with `fname` in the current context."""
  if context.executing_eagerly():
    # Eager mode: use the eager context as the single source of truth.
    return context.context().has_function(fname)
  else:
    # Graph mode: use outer graphs as the single source of truth.
    while graph is not None:
      if graph._is_function(fname):  # pylint: disable=protected-access
        return True
      if hasattr(graph, "outer_graph"):
        graph = graph.outer_graph
      else:
        return False
    # Also return False when no graph was supplied at all.
    return False


def get_function_def(fname, graph):
  """Returns the FunctionDef for `fname` in the current context, or None."""
  if context.executing_eagerly():
    # Eager mode: use the eager context as the single source of truth.
    if context.context().has_function(fname):
      return context.context().get_function_def(fname)
  else:
    # Graph mode: use outer graphs as the single source of truth.
    while graph is not None:
      if graph._is_function(fname):  # pylint: disable=protected-access
        return graph._get_function(fname).cached_definition  # pylint: disable=protected-access
      graph = getattr(graph, "outer_graph", None)
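

# Example sketch: looking up a traced function by name. In eager mode, traced
# FunctionDefs are registered with the eager context, so no graph is needed.
# `concrete` refers to the hypothetical ConcreteFunction from the sketch
# above.
#
#   fname = concrete.function_def.signature.name
#   assert is_function(fname, None)
#   fdef = get_function_def(fname, None)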


def copy_function_def_to_graph_def_recursively(
    func_name, graph_def, copied_functions, default_graph=None):
  """Recursively copies `FunctionDef`s to a `GraphDef`.

  Copies the outermost `FunctionDef` and all nested `FunctionDef`s to
  `graph_def`. The `copied_functions` set ensures that every `FunctionDef` is
  copied at most once. The `FunctionDef`s are found in `default_graph` when
  this function is called in graph mode, and in the eager context when it is
  called in eager mode.

  Args:
    func_name: The signature name of the FunctionDef to be copied to
      `graph_def`.
    graph_def: The GraphDef that will contain all `FunctionDef`s in its
      library.
    copied_functions: A set containing the names of all functions copied so
      far.
    default_graph: The `tf.Graph` used to resolve `OpDef`s and, in graph
      mode, to look up `FunctionDef`s. Function lookup uses the eager context
      instead in eager mode.
  """
  # Custom ops may contain a func attr with an empty fname.
  if func_name and not is_function(func_name, default_graph):
    raise ValueError(f"Function {func_name} was not found. Please make "
                     "sure the FunctionDef `fdef` is correct.")

  # An empty name means there is no function to copy; if `copied_functions`
  # contains `func_name`, the FunctionDef has already been added to the
  # GraphDef. Return in both cases.
  if not func_name or func_name in copied_functions:
    return

  copied_functions.add(func_name)
  func_def = get_function_def(func_name, default_graph)
  graph_def.library.function.add().CopyFrom(func_def)

  for node_def in func_def.node_def:
    op_def = default_graph.op_def_for_type(node_def.op)
    for attr in op_def.attr:
      if attr.type == "func":
        nested_func_name = node_def.attr[attr.name].func.name
        copy_function_def_to_graph_def_recursively(
            nested_func_name, graph_def, copied_functions, default_graph)
      elif attr.type == "list(func)":
        for fn in node_def.attr[attr.name].list.func:
          copy_function_def_to_graph_def_recursively(
              fn.name, graph_def, copied_functions, default_graph)
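

# Example sketch: copying a registered function and everything it references
# into a fresh GraphDef's library. `fdef` here is a hypothetical FunctionDef
# whose name is known to the current context (e.g. from `get_function_def`).
#
#   gdef = graph_pb2.GraphDef()
#   copy_function_def_to_graph_def_recursively(
#       fdef.signature.name, gdef, set(), ops.get_default_graph())
#   print([f.signature.name for f in gdef.library.function])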


def function_def_to_graph_def(
    fdef, input_shapes=None, include_library_functions=False
):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See the comment on `FunctionDef.node_def` for how tensor
     naming in FunctionDefs differs from GraphDefs.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match the length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.
    include_library_functions: Optional. If enabled, copies `fdef` and its
      nested `FunctionDef`s into the library functions of the returned
      `GraphDef`. In graph mode, the functions are found in the outer graph.
      In eager mode, they are found in the eager context.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in
    GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args or if the FunctionDef is invalid.
  """
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  default_graph = ops.get_default_graph()

  copied_functions = set()

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    raise ValueError("Length of `input_shapes` must match the number "
                     f"of `input_arg`s in `fdef`. Got "
                     f"{len(input_shapes)} `input_shapes` and "
                     f"{len(fdef.signature.input_arg)} `input_arg`s.")

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      input_shape = input_shapes[i]
      if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto):
        input_shape = input_shape.as_proto()
      node_def.attr["shape"].shape.CopyFrom(input_shape)
    arg_attrs = fdef.arg_attr[i].attr
    for k in arg_attrs:
      # Only copy internal attributes. Normal attributes for nodes cannot be
      # applied to these Placeholder nodes.
      if k == "_output_shapes":
        if arg_attrs[k].WhichOneof("value") == "list":
          node_def.attr["shape"].shape.CopyFrom(arg_attrs[k].list.shape[0])
        elif arg_attrs[k].WhichOneof("value") == "shape":
          node_def.attr["shape"].shape.CopyFrom(arg_attrs[k].shape)
      elif k.startswith("_"):
        node_def.attr[k].CopyFrom(arg_attrs[k])

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
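  # For example, if a node "split" has a single output arg "output" that
  # produces three tensors, the nested FunctionDef name "split:output:2"
  # maps to the flat GraphDef name "split:2".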
  nested_to_flat_tensor_name = {}

  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
    control_name = "^" + arg_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  for node_def in fdef.node_def:
    graph = default_graph
    while True:
      f = graph._functions.get(node_def.op, None)  # pylint: disable=protected-access
      if f is not None or not hasattr(graph, "outer_graph"):
        break
      graph = graph.outer_graph

    if f is not None:
      # `node_def.op` refers to a library function. Use a distinct name here
      # so the `fdef` argument being converted is not shadowed.
      referenced_fdef = f.cached_definition
      op_def = referenced_fdef.signature
      if node_def.op not in copied_functions:
        # Since this function is referenced as an op type, we have no choice
        # but to copy it into the GraphDef if we want downstream tools to
        # process it.
        graph_def.library.function.add().CopyFrom(referenced_fdef)
        copied_functions.add(node_def.op)
        if getattr(f, "grad_func_name", None):
          grad_def = function_pb2.GradientDef()
          grad_def.function_name = f.name
          grad_def.gradient_func = f.grad_func_name
          graph_def.library.gradient.extend([grad_def])
    else:
      op_def = default_graph.op_def_for_type(node_def.op)

    for attr in op_def.attr:
      if attr.type == "func":
        fname = node_def.attr[attr.name].func.name
        # Custom ops may contain a func attr with an empty fname.
        if fname and not is_function(fname, default_graph):
          raise ValueError(f"Function {fname} was not found. Please make sure "
                           "the FunctionDef `fdef` is correct.")
        if include_library_functions:
          copy_function_def_to_graph_def_recursively(
              fname, graph_def, copied_functions, default_graph)
      elif attr.type == "list(func)":
        for fn in node_def.attr[attr.name].list.func:
          fname = fn.name
          # Custom ops may contain a func attr with an empty fname.
          if fname and not is_function(fname, default_graph):
            raise ValueError(f"Function {fname} was not found. Please make "
                             "sure the FunctionDef `fdef` is correct.")
          if include_library_functions:
            copy_function_def_to_graph_def_recursively(
                fname, graph_def, copied_functions, default_graph)

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1
    control_name = "^" + node_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  # Update inputs of all nodes in graph.
  for node_def in graph_def.node:
    for i in range(len(node_def.input)):
      node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]

  return graph_def, nested_to_flat_tensor_name
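

# Example sketch: inspecting the name mapping returned alongside the GraphDef.
# `fdef` is a hypothetical FunctionDef whose body contains a single "Mul" node
# (the Mul OpDef names its output arg "z").
#
#   gdef, name_map = function_def_to_graph_def(fdef)
#   # Input placeholders keep their arg names, e.g. {"x": "x:0", ...}; body
#   # tensors are flattened, e.g. {"mul:z:0": "mul:0", ...}.
#   print(name_map)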


# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange.
def _get_num_args(arg_def, node_def):
  if arg_def.number_attr:
    return node_def.attr[arg_def.number_attr].i
  elif arg_def.type_list_attr:
    return len(node_def.attr[arg_def.type_list_attr].list.type)
  elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
    return 1
  else:
    raise ValueError(f"Invalid arg_def:\n\n{arg_def}. Please make sure the "
                     "FunctionDef `fdef` is correct.")


def _set_handle_data(func_graph, fdef):
  """Adds handle data for resource type inputs and outputs."""
  # The shape of the handle itself is [], while the variable shape is saved
  # in `handle_data`. Previously, the shape of the resource handle was set to
  # `None`. Correct both shapes here.
  for tensor, arg_def in itertools.chain(
      zip(func_graph.inputs, fdef.signature.input_arg),
      zip(func_graph.outputs, fdef.signature.output_arg)):
    if arg_def.handle_data:
      tensor.set_shape([])

      shape_and_dtype = arg_def.handle_data[0]
      handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
      handle_data.is_set = True
      handle_data.shape_and_type.append(
          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
              shape=shape_and_dtype.shape, dtype=shape_and_dtype.dtype))
      resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
          tensor, handle_data, True)
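

# For example, a DT_RESOURCE input backed by a float32 [2, 3] variable ends up
# with tensor shape [] and handle data carrying shape [2, 3] and dtype
# DT_FLOAT, which downstream shape inference can read back from the handle.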