Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/graph_util_impl.py: 20%
175 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helpers to manipulate a tensor graph in python.
16"""
import collections
import copy
import re

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import _proto_comparators
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# Register the generated `graph_pb2.GraphDef` proto class with `tf_export`
# under the v1-only API name "GraphDef" (no v2 name is requested here).
tf_export(v1=["GraphDef"])(graph_pb2.GraphDef)
31_VARIABLE_OPS = {
32 "Assign",
33 "AssignAdd",
34 "AssignSub",
35 "Queue",
36 "ScatterAdd",
37 "ScatterSub",
38 "ScatterUpdate",
39 "TruncatedNormal",
40 "Variable",
41 "VariableV2",
42}
44_CONTROL_FLOW_OP_NAMES_OR_IDENTITY = [
45 "Switch",
46 "Enter",
47 "Exit",
48 "Identity",
49 "Merge",
50 "NextIteration",
51]
53_DEPRECATION_MSG = (
54 "This API was designed for TensorFlow v1. See "
55 "https://www.tensorflow.org/guide/migrate for instructions on how to "
56 "migrate your code to TensorFlow v2.")
def _is_variable_op(op):
  """Checks whether an op type name is one of the variable-related ops.

  Args:
    op: An op type string, such as the `op` field of a NodeDef.

  Returns:
    True if `op` is a member of `_VARIABLE_OPS`, otherwise False.
  """
  found = op in _VARIABLE_OPS
  return found
# GraphDef protobuf docstring.
# The interactive example uses `...` continuation prompts so that the decorated
# function definition is a well-formed REPL/doctest transcript (a `>>>` prompt
# on every line would submit the decorator on its own).
graph_pb2.GraphDef.__doc__ = """\
A protobuf containing the graph of operations.

@compatibility(TF2)
This API is not available in TensorFlow 2.x.

You should not need to use `GraphDef`s directly in TF2. To load `GraphDef`s in
TF2, use SavedModel. The SavedModel contains the `GraphDef`.

Before:

```python
with tf.io.gfile.GFile('/tmp/graph.pb', 'rb') as f:
  graph_def = tf.compat.v1.GraphDef()
  graph_def.ParseFromString(f.read())
```

After:

```python
tf.saved_model.load('/tmp/saved_model')
```

If you would like to create a `GraphDef` in TF2, use `tf.function` and
`get_concrete_function`.

>>> @tf.function
... def f(x):
...   return x
>>> graph_def = f.get_concrete_function(1.).graph.as_graph_def()
>>> print(graph_def)

@end_compatibility

"""
@deprecation.deprecated(
    date=None,
    instructions=_DEPRECATION_MSG)
@tf_export(v1=["graph_util.must_run_on_cpu"])
def must_run_on_cpu(node, pin_variables_on_cpu=False):
  """Returns True if the given node_def must run on CPU, otherwise False.

  Args:
    node: The node to be assigned to a device. Could be either an ops.Operation
      or NodeDef.
    pin_variables_on_cpu: If True, this function will return True if node_def
      represents a variable-related op (an op type listed in `_VARIABLE_OPS`).

  Returns:
    True if the given node must run on CPU, otherwise False.
  """

  if isinstance(node, ops.Operation):
    node_def = node.node_def
  else:
    assert isinstance(node, node_def_pb2.NodeDef)
    node_def = node

  # If the op is a variable-related op, should we pin it on CPU?
  if pin_variables_on_cpu and _is_variable_op(node_def.op):
    return True

  # Constant operations producing a string or int32 must run on CPU.
  if node_def.op == "Const":
    # Get the value of the 'dtype' attr
    dtype = node_def.attr["dtype"].type
    if dtype == dtypes.string or dtype == dtypes.int32:
      return True

  if node_def.op in ("DynamicStitch", "ParallelDynamicStitch"):
    dtype = node_def.attr["T"].type
    if dtype == dtypes.int32:
      # int32 (Parallel)DynamicStitch is pinned to CPU.
      # NOTE(review): the rationale is not visible in this file; an older
      # comment here claimed GPU DynamicStitch "only works for int32 values",
      # which would argue the opposite of pinning int32 to CPU — confirm.
      return True

  if node_def.op == "Cast":
    dtype = node_def.attr["SrcT"].type
    if dtype == dtypes.int32:
      # Cast on GPU does not work for int32 values.
      return True
  return False
################################################################################
#
# Device functions for use in `with g.device(...)` blocks.
#
################################################################################
157def _node_name(n):
158 if n.startswith("^"):
159 return n[1:]
160 else:
161 return n.split(":")[0]
164def _get_colocated_node_name(colocated_node_name):
165 """Decodes colocated node name and returns it without loc:@ prepended."""
166 colocated_node_decoded = colocated_node_name.decode("utf-8")
167 if colocated_node_decoded.startswith("loc:@"):
168 return colocated_node_decoded[5:]
169 return colocated_node_decoded
def _extract_graph_summary(graph_def):
  """Builds name-keyed lookup tables for every node of `graph_def`.

  Args:
    graph_def: A GraphDef proto whose `node` entries are scanned in order.

  Returns:
    A tuple `(name_to_input_name, name_to_node, name_to_seq_num)`:
      name_to_input_name: node name -> list of its input node names (passed
        through `_node_name`), followed by any colocation targets decoded
        from the node's "_class" attr.
      name_to_node: node name -> the NodeDef itself.
      name_to_seq_num: node name -> index of the node in `graph_def.node`,
        so callers can emit nodes in their original order.
  """
  name_to_input_name = {}
  name_to_node = {}
  name_to_seq_num = {}
  for position, node in enumerate(graph_def.node):
    key = _node_name(node.name)
    inputs = [_node_name(raw_input) for raw_input in node.input]
    # Colocation targets are treated as inputs so they are not lost when a
    # subgraph is extracted.
    if "_class" in node.attr:
      inputs.extend(
          _get_colocated_node_name(entry)
          for entry in node.attr["_class"].list.s)
    name_to_node[key] = node
    name_to_input_name[key] = inputs
    name_to_seq_num[key] = position
  return name_to_input_name, name_to_node, name_to_seq_num
195def _assert_nodes_are_present(name_to_node, nodes):
196 """Assert that nodes are present in the graph."""
197 for d in nodes:
198 assert d in name_to_node, "%s is not in graph" % d
201def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
202 """Breadth first search for reachable nodes from target nodes."""
203 nodes_to_keep = set()
204 # Breadth first search to find all the nodes that we should keep.
205 next_to_visit = list(target_nodes)
206 while next_to_visit:
207 node = next_to_visit[0]
208 del next_to_visit[0]
209 if node in nodes_to_keep:
210 # Already visited this node.
211 continue
212 nodes_to_keep.add(node)
213 if node in name_to_input_name:
214 next_to_visit += name_to_input_name[node]
215 return nodes_to_keep
@deprecation.deprecated(
    date=None,
    instructions=_DEPRECATION_MSG)
@tf_export(v1=["graph_util.extract_sub_graph"])
def extract_sub_graph(graph_def, dest_nodes):
  """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    dest_nodes: An iterable of strings specifying the destination node names.

  Returns:
    The GraphDef of the sub-graph. Kept nodes appear in their original order,
    and the `library` and `versions` fields are copied from `graph_def`.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, or if
      'dest_nodes' is a single string rather than an iterable of strings.
    AssertionError: If a name in 'dest_nodes' is not a node of 'graph_def'
      (raised by `_assert_nodes_are_present`).
  """

  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto, but got "
                    f"type {type(graph_def)}.")

  if isinstance(dest_nodes, str):
    raise TypeError("dest_nodes must be an iterable of strings, but got "
                    f"type {type(dest_nodes)}.")

  # `dest_nodes` is iterated twice below (presence check, then BFS). Build a
  # list once so that a one-shot iterator such as a generator is not used up
  # by the first pass, which would silently produce an empty subgraph.
  dest_nodes = list(dest_nodes)

  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  _assert_nodes_are_present(name_to_node, dest_nodes)

  nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)

  # Emit the kept nodes in the order they appear in the input GraphDef.
  nodes_to_keep_list = sorted(nodes_to_keep, key=lambda n: name_to_seq_num[n])
  # Now construct the output GraphDef
  out = graph_pb2.GraphDef()
  for n in nodes_to_keep_list:
    out.node.extend([copy.deepcopy(name_to_node[n])])
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out
@deprecation.deprecated(
    date=None,
    instructions=_DEPRECATION_MSG)
@tf_export(v1=["graph_util.tensor_shape_from_node_def_name"])
def tensor_shape_from_node_def_name(graph, input_name):
  """Looks up the shape of the tensor named by a NodeDef input string.

  `graph.get_tensor_by_name` needs a name of the form <input>:<port> (for
  example 'Mul:0'). GraphDef input strings may omit the port, so ':0' is
  appended whenever `input_name` contains no colon.

  Args:
    graph: The graph to look the tensor up in.
    input_name: A NodeDef input string, with or without a ':<port>' suffix.

  Returns:
    The result of calling `get_shape()` on the looked-up tensor.
  """
  has_port = ":" in input_name
  tensor_name = input_name if has_port else input_name + ":0"
  return graph.get_tensor_by_name(tensor_name).get_shape()
def _forward_input(full_input_name, forwarding):
  """Resolves a NodeDef input string through nodes that are being removed.

  Args:
    full_input_name: A NodeDef input string: "name", "name:port" or "^name".
    forwarding: Dict mapping the name of each removed pass-through node to the
      input string it forwards, or to None if it has nothing to forward.

  Returns:
    The input string that should replace `full_input_name`, or None if the
    edge should be dropped: either a control edge on a removed node, or a data
    edge on a removed node that has nothing to forward.
  """
  # `_node_name` strips a leading "^" and any ":<port>" suffix, so "foo",
  # "foo:0" and "^foo" all match a removed node named "foo".
  name = _node_name(full_input_name)
  while name in forwarding:
    if full_input_name.startswith("^"):
      return None
    full_input_name = forwarding[name]
    if full_input_name is None:
      return None
    name = _node_name(full_input_name)
  return full_input_name


def _copy_with_forwarded_inputs(node, forwarding):
  """Returns a copy of `node` whose inputs bypass the nodes in `forwarding`."""
  new_node = node_def_pb2.NodeDef()
  new_node.CopyFrom(node)
  del new_node.input[:]
  for full_input_name in node.input:
    forwarded = _forward_input(full_input_name, forwarding)
    if forwarded is not None:
      new_node.input.append(forwarded)
  return new_node


@deprecation.deprecated(
    date=None,
    instructions=_DEPRECATION_MSG)
@tf_export(v1=["graph_util.remove_training_nodes"])
def remove_training_nodes(input_graph, protected_nodes=None):
  """Prunes out nodes that aren't needed for inference.

  There are nodes like Identity and CheckNumerics that are only useful
  during training, and can be removed in graphs that will be used for
  nothing but inference. Here we identify and remove them, returning an
  equivalent graph. To be specific:

  * Unprotected CheckNumerics nodes are removed. Control edges on them are
    dropped, and data edges that consumed them are rewired to the
    CheckNumerics node's own data input.
  * Unprotected Identity nodes are spliced out, so that their input and
    outputs are directly connected. An Identity node is kept if it has control
    inputs, is itself the source of a control edge, is named as a colocation
    target in some node's "_class" attr, or has no inputs.

  Input references of the form "name:0" are matched like "name".

  Args:
    input_graph: Model to analyze and prune.
    protected_nodes: An optional list of names of nodes to be kept
      unconditionally. This is for example useful to preserve Identity output
      nodes.

  Returns:
    A GraphDef holding the remaining nodes in their original order. Only its
    `node` field is populated; `library` and `versions` are not copied from
    `input_graph`.
  """
  if not protected_nodes:
    protected_nodes = []

  types_to_remove = {"CheckNumerics"}
  types_to_splice = {"Identity"}

  # Phase 1: remove unprotected CheckNumerics nodes. Each one forwards its
  # first input when that input is a data (non "^") input.
  check_numerics_forwarding = {}
  for node in input_graph.node:
    if node.op in types_to_remove and node.name not in protected_nodes:
      has_data_input = bool(node.input) and not node.input[0].startswith("^")
      check_numerics_forwarding[node.name] = (
          node.input[0] if has_data_input else None)

  nodes_after_removal = [
      _copy_with_forwarded_inputs(node, check_numerics_forwarding)
      for node in input_graph.node
      if node.name not in check_numerics_forwarding
  ]

  # Phase 2: gather the facts that make an Identity node unsafe to splice.
  control_input_names = set()  # Nodes used as the source of a control edge.
  node_names_with_control_input = set()  # Nodes that have control inputs.
  node_in_colocated = set()  # Nodes named in some node's "_class" attr.
  for node in nodes_after_removal:
    for node_input in node.input:
      if node_input.startswith("^"):
        control_input_names.add(node_input[1:])
        node_names_with_control_input.add(node.name)
    # Prevent colocated nodes from being lost.
    if "_class" in node.attr:
      for colocated_node_name in node.attr["_class"].list.s:
        node_in_colocated.add(_get_colocated_node_name(colocated_node_name))

  names_to_splice = {}
  for node in nodes_after_removal:
    if node.op not in types_to_splice or node.name in protected_nodes:
      continue
    if node.name in node_in_colocated:
      continue
    # We don't want to remove nodes that have control edge inputs, because
    # they might be involved in subtle dependency issues that removing them
    # will jeopardize.
    if node.name in node_names_with_control_input:
      continue
    # We also don't want to remove nodes which are used as control edge inputs.
    if node.name in control_input_names:
      continue
    # An Identity without inputs has nothing to forward, so keep it as-is.
    if not node.input:
      continue
    names_to_splice[node.name] = node.input[0]

  # Phase 3: drop the spliced Identity nodes and rewire their consumers.
  nodes_after_splicing = [
      _copy_with_forwarded_inputs(node, names_to_splice)
      for node in nodes_after_removal
      if node.name not in names_to_splice
  ]

  output_graph = graph_pb2.GraphDef()
  output_graph.node.extend(nodes_after_splicing)
  return output_graph
@tf_export("__internal__.graph_util.graph_defs_equal", v1=[])
def graph_defs_equal(graph_def_1: graph_pb2.GraphDef,
                     graph_def_2: graph_pb2.GraphDef,
                     treat_nan_as_equal: bool = False) -> bool:
  """Checks whether two GraphDefs are structurally equivalent.

  Equivalence here means that the set of NodeDefs in the function library and
  in the main graph body is identical. In addition, the functions in the
  function library must be equal as sets. Both protos are serialized and
  compared by `_proto_comparators.EqualsGraphDef`.

  Example usage:

  ```
  with tf.Graph().as_default() as g1:
    tf.constant(1)

  with tf.Graph().as_default() as g2:
    tf.constant(2)

  with tf.Graph().as_default() as g3:
    tf.constant(1)

  assert tf.__internal__.graph_util.graph_defs_equal(g1.as_graph_def(),
                                                     g3.as_graph_def())

  assert not tf.__internal__.graph_util.graph_defs_equal(g1.as_graph_def(),
                                                         g2.as_graph_def())
  ```

  Args:
    graph_def_1: First `graph_pb2.GraphDef` to compare.
    graph_def_2: Second `graph_pb2.GraphDef` to compare.
    treat_nan_as_equal: Whether NaN floating-point values compare as equal.
      This is needed for the comparison to behave as an equivalence relation
      (in particular, to be symmetric).

  Returns:
    True if the two GraphDefs are structurally equivalent, otherwise False.

  Raises:
    TypeError: If either argument is not a `graph_pb2.GraphDef`. The first
      argument is checked first.
  """
  for candidate, message_prefix in (
      (graph_def_1,
       "graph_def_1 must be a graph_pb2.GraphDef proto, but got "),
      (graph_def_2,
       "graph_def_2 must be a graph_pb2.GraphDef proto, but got "),
  ):
    if not isinstance(candidate, graph_pb2.GraphDef):
      raise TypeError(message_prefix + f"type {type(candidate)}.")

  comparison_options = _proto_comparators.ProtoComparisonOptions(
      treat_nan_as_equal)
  serialized_1 = graph_def_1.SerializeToString()
  serialized_2 = graph_def_2.SerializeToString()
  return _proto_comparators.EqualsGraphDef(serialized_1, serialized_2,
                                           comparison_options)