Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/meta_graph.py: 11%
413 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""MetaGraph and related functions."""
17import copy
18from packaging import version as packaging_version # pylint: disable=g-bad-import-order
19import os.path
20import re
21import sys
23from google.protobuf.any_pb2 import Any
24from google.protobuf import text_format
26from tensorflow.core.framework import attr_value_pb2
27from tensorflow.core.framework import graph_pb2
28from tensorflow.core.framework import op_def_pb2
29from tensorflow.core.protobuf import meta_graph_pb2
30from tensorflow.core.protobuf import saver_pb2
31from tensorflow.python.client import pywrap_tf_session as c_api
32from tensorflow.python.eager import context
33from tensorflow.python.framework import byte_swap_tensor as bst
34from tensorflow.python.framework import error_interpolation
35from tensorflow.python.framework import graph_io
36from tensorflow.python.framework import importer
37from tensorflow.python.framework import op_def_registry
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import tensor
40from tensorflow.python.framework import versions
41from tensorflow.python.lib.io import file_io
42from tensorflow.python.platform import tf_logging as logging
43from tensorflow.python.util import compat
# Prefix prepended to the names of unbound inputs so they are easy to identify.
_UNBOUND_INPUT_PREFIX = "$unbound_inputs_"

# Collections that did not register proto conversion functions. As a result,
# in meta graphs exported by earlier releases their items were serialized with
# a different data type than the one current releases expect.
_COMPAT_COLLECTION_LIST = [
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.MODEL_VARIABLES,
    ops.GraphKeys.METRIC_VARIABLES,
]
def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Returns a copy of `from_node_def` with `export_scope` stripped from names.

  Inputs that are not under `export_scope` are rewritten with the
  `_UNBOUND_INPUT_PREFIX` marker (inserted after a leading "^", if any) and
  recorded in `unbound_inputs`. Every other input, the node name, the
  `_class` colocation entries, and the `frame_name` of Enter/RefEnter nodes
  (when it lies under `export_scope`) have the scope removed.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer. It is deep-copied
      and not modified.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: A list that receives the (prefixed) names of the inputs
      that are not under `export_scope`. Appended to in place.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  result = copy.deepcopy(from_node_def)

  for idx, input_name in enumerate(result.input):
    is_bound = (not export_scope or
                result.input[idx].lstrip("^").startswith(export_scope))
    if is_bound:
      result.input[idx] = ops.strip_name_scope(input_name, export_scope)
      continue
    # Keep an optional leading "^" (control input) in front of the marker.
    result.input[idx] = re.sub(r"([\^]|^)(.*)",
                               r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                               compat.as_str(input_name))
    unbound_inputs.append(result.input[idx])

  result.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))

  for attr_key, attr_val in from_node_def.attr.items():
    if attr_key == "_class":
      # NOTE(review): assumes every entry contains an "@" (presumably the
      # "loc:@name" form); an entry without one raises IndexError here.
      kept = [
          compat.as_bytes(ops.strip_name_scope(loc, export_scope))
          for loc in attr_val.list.s
          if not export_scope or
          compat.as_str(loc).split("@")[1].startswith(export_scope)
      ]
      result.attr[attr_key].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=kept)))
    elif result.op in ("Enter", "RefEnter") and attr_key == "frame_name":
      # Frames outside the export scope keep the deep-copied value as is.
      if not export_scope or compat.as_str(attr_val.s).startswith(export_scope):
        stripped = compat.as_bytes(
            ops.strip_name_scope(attr_val.s, export_scope))
        result.attr[attr_key].CopyFrom(attr_value_pb2.AttrValue(s=stripped))
    else:
      result.attr[attr_key].CopyFrom(attr_val)

  if clear_devices:
    result.device = ""

  return result
def _read_file(filename):
  """Reads a file containing `GraphDef` and returns the protocol buffer.

  The content is first parsed as a binary proto. If that fails, it is parsed
  as a text proto.

  Args:
    filename: `graph_def` filename including the path.

  Returns:
    A `GraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  graph_def = graph_pb2.GraphDef()
  if not file_io.file_exists(filename):
    raise IOError(f"File {filename} does not exist.")
  # First try to read it as a binary file.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  try:
    graph_def.ParseFromString(file_content)
    return graph_def
  except Exception:  # pylint: disable=broad-except
    # Not a binary proto: fall back to the text format below.
    pass

  # A failed binary parse may leave partially decoded fields behind, and
  # `text_format.Merge` merges into (rather than replaces) the message, so
  # start the text parse from an empty proto.
  graph_def.Clear()
  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content, graph_def)
  except (text_format.ParseError, UnicodeDecodeError) as e:
    # Non-UTF-8 content is also "cannot be parsed"; report it as IOError.
    raise IOError(f"Cannot parse file {filename}: {str(e)}.") from e

  return graph_def
def ops_used_by_graph_def(graph_def):
  """Collect the list of ops used by a graph.

  Does not validate that the ops are all registered.

  Functions from `graph_def.library` are followed transitively. A function is
  pulled in when a node's op names it, or when it is the `f` attr of a
  `PartitionedCall`/`StatefulPartitionedCall` node. The ops inside a function
  are reported, but the function names themselves are not.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    A list of strings, each naming an op used by the graph. The order is
    unspecified.
  """
  functions_by_name = {f.signature.name: f for f in graph_def.library.function}

  seen = set()  # Every op name encountered, primitive ops and functions alike.
  pending = []  # Function definitions whose bodies have not been scanned yet.

  def visit(op_name):
    if op_name in seen:
      return
    seen.add(op_name)
    if op_name in functions_by_name:
      pending.append(functions_by_name[op_name])

  def visit_node(node):
    visit(node.op)
    if node.op in ("PartitionedCall", "StatefulPartitionedCall"):
      visit(node.attr["f"].func.name)

  for node in graph_def.node:
    visit_node(node)
  while pending:
    for node in pending.pop().node_def:
      visit_node(node)

  return [name for name in seen if name not in functions_by_name]
def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos. The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Ops that are not found in `op_def_registry` are skipped. The returned
  OpDefs are ordered by op name.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph.
  """
  # Mirrors the C++ StrippedOpListForGraph, except that unregistered ops are
  # tolerated instead of rejected. This is needed for Prelu fusion in tfjs.
  looked_up = (op_def_registry.get(name)
               for name in sorted(ops_used_by_graph_def(graph_def)))
  return op_def_pb2.OpList(op=[d for d in looked_up if d is not None])
def _get_kind_name(item):
  """Returns the `CollectionDef` kind used to store `item`.

  Args:
    item: A data item.

  Returns:
    The kind as a string:
      "bytes_list" for `str` or `bytes`,
      "int64_list" for `int` (which includes `bool`),
      "float_list" for `float`,
      "any_list" for protobuf `Any` messages,
      "node_list" for anything else.
  """
  if isinstance(item, (str, bytes)):
    return "bytes_list"
  if isinstance(item, int):
    return "int64_list"
  if isinstance(item, float):
    return "float_list"
  if isinstance(item, Any):
    return "any_list"
  return "node_list"
# Op types that save or restore checkpoints. `_find_extraneous_saver_nodes`
# uses them to locate the name scopes that belong to Savers.
SAVE_AND_RESTORE_OPS = [
    "SaveV2",
    "Save",
    "SaveSlice",
    "LegacySave",
    "LegacySaveSlice",
    "RestoreV2",
    "Restore",
    "RestoreSlice",
    "LegacyRestore",
    "LegacyRestoreSlice",
]
def _get_scope(node_name):
  """Returns the name scope that directly contains `node_name`.

  A single leading "^" (control-dependency marker) is ignored. The scope is
  everything before the last "/". Names without a "/" live in the root scope,
  which is reported as "".

  Args:
    node_name: the full name of an Op or a Tensor in the graph.

  Returns:
    The deepest named scope containing the node.

  Raises:
    ValueError: if node_name is None or empty.
  """
  if not node_name:
    raise ValueError(
        f"Node name cannot be empty or None. Received: {node_name}.")

  name = node_name[1:] if node_name.startswith("^") else node_name
  # rpartition yields ("", "", name) when there is no "/".
  scope, _, _ = name.rpartition("/")
  return scope
def _find_extraneous_saver_nodes(graph_def, saver_def):
  """Identifies any nodes in the graph_def related to unused Savers.

  This approach assumes that each Saver is cleanly isolated in its own name
  scope. It therefore only needs to identify the scopes that contain
  save/restore ops not referenced by `saver_def`, and return every node in
  those scopes.

  Args:
    graph_def: a GraphDef proto to evaluate.
    saver_def: a SaverDef proto referencing Save/Restore ops to be retained,
      or None, in which case no saver scope is retained.

  Returns:
    A set of node names that may be safely omitted.
  """
  # TODO(soergel): confirm that the assumption of scope isolation is valid.
  # If not, we need to walk up the graph from any restore_all nodes, and walk
  # down the graph from any Save/Restore nodes. I drafted that approach too,
  # but it seems unnecessarily complex given the name scope solution.

  # Minimal DAG view (node name -> (input op names, op type)) so that no full
  # Graph object has to be built.
  nodes = {
      node_def.name: (
          set(tensor.get_op_name(x) for x in node_def.input), node_def.op)
      for node_def in graph_def.node
  }

  # A graph without Variables may have no saver at all.
  retained_scopes = set()
  if saver_def is not None:
    # The save and restore scopes should match. Keep both in case they differ.
    retained_op_names = (tensor.get_op_name(saver_def.save_tensor_name),
                         tensor.get_op_name(saver_def.restore_op_name))
    for op_name in retained_op_names:
      retained_scopes.add(_get_scope(op_name) + "/")

  saver_node_names = {
      name for name, (_, op_type) in nodes.items()
      if op_type in SAVE_AND_RESTORE_OPS
  }
  # Scopes holding save/restore ops, excluding scopes that are themselves
  # saver node names. Each gets a trailing "/" for prefix matching.
  saver_scopes = {
      scope + "/"
      for scope in {_get_scope(n) for n in saver_node_names} - saver_node_names
  }
  extraneous_scopes = saver_scopes - retained_scopes

  return {
      name for name in nodes
      if any(name.startswith(scope) for scope in extraneous_scopes)
  }
def _should_include_node(node_or_node_name, export_scope, exclude_nodes):
  """Returns `True` if a node should be included.

  Objects that are neither strings nor have a `name` attribute are always
  kept. Anything listed in `exclude_nodes` (by object or by name) is dropped.
  Otherwise a node is kept when its name carries the unbound-input prefix or
  lies under `export_scope` (every name qualifies when no scope is given).

  Args:
    node_or_node_name: A node or `string` node name.
    export_scope: `string`. Name scope under which to extract the subgraph. The
      scope name will be stripped from the node definitions for easy import
      later into new name scopes.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      export, or None. Note no sanity-checking is done, so this list must be
      carefully constructed to avoid producing an invalid graph.

  Returns:
    `True` if the node should be included.
  """
  if isinstance(node_or_node_name, str):
    node_name = node_or_node_name
  else:
    try:
      node_name = node_or_node_name.name
    except AttributeError:
      # Keep objects we do not know how to match against names.
      return True

  if exclude_nodes and (node_or_node_name in exclude_nodes or
                        node_name in exclude_nodes):
    return False

  if node_name.startswith(_UNBOUND_INPUT_PREFIX):
    return True
  return not export_scope or node_name.startswith(export_scope)
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None, exclude_nodes=None,
                       override_contents=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Only `str`/`bytes` keys are serialized; other keys are logged and skipped.
  If serialization raises, a warning is logged and any partially written
  entry for `key` is removed from `meta_graph_def.collection_def`.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      collection, or None.
    override_contents: An iterable of values to place in the collection,
      ignoring the current values (if set). An empty value falls back to the
      graph's collection.
  """
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError(
        f"graph must be of type Graph. Received type: {type(graph)}.")

  if not isinstance(key, (str, bytes)):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  source = override_contents if override_contents else graph.get_collection(key)
  # Drop the entries that must not be exported.
  items = [x for x in source
           if _should_include_node(x, export_scope, exclude_nodes)]
  if not items:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      for x in items:
        proto = to_proto(x, export_scope=export_scope)
        if not proto:
          continue
        # Make sure the registered function produced the expected proto type.
        assert isinstance(proto, proto_type)
        col_def.bytes_list.value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(items[0])
      if kind == "node_list":
        for x in items:
          if not export_scope or x.name.startswith(export_scope):
            col_def.node_list.value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # Python 3 separates str from bytes; store everything as bytes.
        col_def.bytes_list.value.extend(
            [compat.as_bytes(x) for x in items])
      else:
        getattr(col_def, kind).value.extend(list(items))
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Issue encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef. Note this is a warning "
                    "and probably safe to ignore.\n%s", key, str(e))
    # Do not leave a half-populated collection behind.
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
def _is_default_attr_value(op_def, attr_name, attr_value):
  """Checks if given attribute matches the default value in the op def.

  An attribute that `op_def` does not declare, or declares without a default
  value, is never considered default-valued. Only the first declaration named
  `attr_name` is consulted.
  """
  attr_def = next((a for a in op_def.attr if a.name == attr_name), None)
  if attr_def is None or not attr_def.HasField("default_value"):
    return False
  # c_api.EqualAttrValueWrapper returns an empty string if both arguments
  # represent an equivalent AttrValue instance.
  return not c_api.EqualAttrValueWrapper(
      attr_value.SerializeToString(),
      attr_def.default_value.SerializeToString())
def strip_graph_default_valued_attrs(meta_graph_def):
  """Strips default valued attributes for node defs in given MetaGraphDef.

  Both the top-level nodes and the nodes inside library functions are
  processed. Nodes whose op is a library function, or whose op is not found
  in `op_def_registry`, are left untouched. This method also sets
  `meta_info_def.stripped_default_attrs` in the given `MetaGraphDef` proto to
  True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer

  Returns:
    None.
  """
  graph_def = meta_graph_def.graph_def
  function_names = {f.signature.name for f in graph_def.library.function}

  def _strip(node_def):
    """Removes default valued attributes from a single node def."""
    if node_def.op in function_names:
      return
    op_def = op_def_registry.get(node_def.op)
    if op_def is None:
      return
    # Collect first so the attr map is not modified while being iterated.
    default_names = [
        name for name, value in node_def.attr.items()
        if _is_default_attr_value(op_def, name, value)
    ]
    for name in default_names:
      del node_def.attr[name]

  for node_def in graph_def.node:
    _strip(node_def)
  for function_def in graph_def.library.function:
    for node_def in function_def.node_def:
      _strip(node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
  # pylint: disable=line-too-long
  """Construct and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer. It is copied into the result
      and not modified. The result's `tensorflow_version` and
      `tensorflow_git_version` are always set to the current build.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    graph: The `Graph` to create `MetaGraphDef` out of.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from all
      collection, or None.
    clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
        collection.  Note this method does not alter the graph, so any
        extraneous Save/Restore ops should have been removed already, as needed.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # pylint: enable=line-too-long
  # Type check.
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError(
        f"graph must be of type Graph. Received type: {type(graph)}.")
  if meta_info_def and not isinstance(meta_info_def,
                                      meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError(
        "meta_info_def must be of type MetaInfoDef. "
        f"Received type: {type(meta_info_def)}.")
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError(
        "graph_def must be of type GraphDef. "
        f"Received type: {type(graph_def)}.")
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError(
        "saver_def must be of type SaverDef. "
        f"Received type: {type(saver_def)}.")

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Creates a MetaGraphDef proto.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # Copy the caller's meta_info_def (if any) and then stamp the current TF
  # build's version strings onto the copy. Writing them on the copy rather
  # than on `meta_info_def` keeps the caller's proto unmodified.
  if meta_info_def:
    meta_graph_def.meta_info_def.MergeFrom(meta_info_def)
  meta_graph_def.meta_info_def.tensorflow_version = versions.__version__
  meta_graph_def.meta_info_def.tensorflow_git_version = (
      versions.__git_version__)

  # Adds graph_def or the default.
  if not graph_def:
    meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))
  else:
    meta_graph_def.graph_def.MergeFrom(graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
  # pylint: disable=g-explicit-length-test
  if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))
  # pylint: enable=g-explicit-length-test

  # Strip default valued attributes in graph_def.
  if strip_default_attrs:
    strip_graph_default_valued_attrs(meta_graph_def)

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list.
  if collection_list is not None:
    clist = collection_list
  else:
    clist = graph.get_all_collection_keys()

  for ctype in clist:
    if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS:
      # Avoid importing Saver here
      from_proto = ops.get_from_proto_function(ctype)
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes,
                         override_contents=[from_proto(saver_def)])
    else:
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes)
  return meta_graph_def
def read_meta_graph_file(filename):
  """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

  The content is first parsed as a binary proto. If that fails, it is decoded
  as UTF-8 and parsed as a text proto. On big-endian hosts, tensor contents
  are then byte-swapped from little- to big-endian.

  Args:
    filename: `meta_graph_def` filename including the path.

  Returns:
    A `MetaGraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  if not file_io.file_exists(filename):
    raise IOError(f"File does not exist. Received: {filename}.")
  # First try to read it as a binary file.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  try:
    meta_graph_def.ParseFromString(file_content)
  except Exception:  # pylint: disable=broad-except
    # Not a binary proto: fall back to the text format below.
    pass
  else:
    # Kept outside the `try` so a byte-swap failure is reported as itself
    # instead of being masked by a misleading text-parse error.
    if sys.byteorder == "big":
      bst.swap_tensor_content_in_graph_function(meta_graph_def, "little", "big")
    return meta_graph_def

  # A failed binary parse may leave partially decoded fields behind, and
  # `text_format.Merge` merges into the message, so start from an empty proto.
  meta_graph_def.Clear()
  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
  except (text_format.ParseError, UnicodeDecodeError) as e:
    # Content that is neither binary nor valid UTF-8 text is also unparseable.
    raise IOError(f"Cannot parse file {filename}: {str(e)}.") from e
  if sys.byteorder == "big":
    bst.swap_tensor_content_in_graph_function(meta_graph_def, "little", "big")

  return meta_graph_def
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If the
  argument is a file containing a `MetaGraphDef` protocol buffer, it
  constructs a protocol buffer from the file content. The function then adds
  all the nodes from the `graph_def` field to the current graph, recreates the
  desired collections, and returns a dictionary of all the Variables imported
  into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  This is `import_scoped_meta_graph_with_return_elements` without
  `return_elements`; only the variable dictionary is returned.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  var_list, _ = import_scoped_meta_graph_with_return_elements(
      meta_graph_or_file,
      clear_devices=clear_devices,
      graph=graph,
      import_scope=import_scope,
      input_map=input_map,
      unbound_inputs_col_name=unbound_inputs_col_name,
      restore_collections_predicate=restore_collections_predicate)
  return var_list
def import_scoped_meta_graph_with_return_elements(
    meta_graph_or_file,
    clear_devices=False,
    graph=None,
    import_scope=None,
    input_map=None,
    unbound_inputs_col_name="unbound_inputs",
    restore_collections_predicate=(lambda key: True),
    return_elements=None):
  """Imports graph from `MetaGraphDef` and returns vars and return elements.

  This function takes a `MetaGraphDef` protocol buffer as input. If the
  argument is a file containing a `MetaGraphDef` protocol buffer, it
  constructs a protocol buffer from the file content. The function then adds
  all the nodes from the `graph_def` field to the current graph, recreates the
  desired collections, and returns a dictionary of all the Variables imported
  into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.
    return_elements: A list of strings containing operation names in the
      `MetaGraphDef` that will be returned as `Operation` objects; and/or
      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.

  Returns:
    A tuple of (
      dictionary of all the `Variables` imported into the name scope,
      list of `Operation` or `Tensor` objects from the `return_elements` list).

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs whose names do not exactly match the keys of `input_map`.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # A CollectionDef with no kind set holds no unbound inputs.
        if kind is None:
          break
        field = getattr(col_def, kind)
        # Values may be stored as bytes (bytes_list) while input_map keys are
        # strings, so normalize once and use the same form for both the check
        # and the error message.
        unbound = [compat.as_str(v) for v in field.value]
        if unbound and (not input_map or sorted(unbound) != sorted(input_map)):
          missing = [v for v in unbound if not input_map or v not in input_map]
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join(missing))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    # NOTE(review): input_graph_def is not a copy, so this also clears the
    # devices of a MetaGraphDef passed in by the caller.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    imported_return_elements = importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list,
        return_elements=return_elements)

    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
    # without a VariableDef.trainable field set.
    tf_version = meta_graph_def.meta_info_def.tensorflow_version
    if not tf_version:
      variables_have_trainable = True
    else:
      variables_have_trainable = (
          packaging_version.parse(tf_version) >= packaging_version.parse("1.9"))

    # Sort collections so we see TRAINABLE_VARIABLES first and can default these
    # variables to trainable if the value is not set in their VariableDef.
    sorted_collections = []
    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
      sorted_collections.append(
          (ops.GraphKeys.TRAINABLE_VARIABLES,
           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
    for key, value in sorted(meta_graph_def.collection_def.items()):
      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
        sorted_collections.append((key, value))

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted_collections:
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)

      # Temporary change to allow the TFMA evaluator to read metric variables
      # saved as a bytes list.
      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
      if key == ops.GraphKeys.METRIC_VARIABLES:
        # Metric variables will use the same proto functions as GLOBAL_VARIABLES
        from_proto = ops.get_from_proto_function(ops.GraphKeys.GLOBAL_VARIABLES)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            # Reuse the variable object when the same serialized VariableDef
            # already appeared in an earlier collection.
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              if not variables_have_trainable:
                # If the VariableDef proto does not contain a "trainable"
                # property because it was exported before that property was
                # added, we default it to whether the variable is in the
                # TRAINABLE_VARIABLES collection. We've sorted
                # TRAINABLE_VARIABLES to be first, so trainable variables will
                # be created from that collection.
                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list, imported_return_elements
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into a
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` (and, when
      `save_debug_info` is set, the debug-info file) as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`. Note that this collection is cleared and repopulated on
      `graph` itself as a side effect.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that are not associated
      with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True` and `filename` is given, save the
      GraphDebugInfo to a separate file in the same directory as `filename`,
      named by replacing the extension of `filename` with `.debug`
      (e.g. `model.meta` -> `model.debug`).
    **kwargs: Optional keyed arguments, including meta_info_def and
      collection_list.

  Returns:
    A tuple of the `MetaGraphDef` proto and a dictionary of the `Variables`
    in the exported name scope, keyed by variable name with `export_scope`
    stripped.

  Raises:
    ValueError: When the `GraphDef` built from `graph` is larger than 2GB.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
  """
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  # Filled by `_node_def` with inputs that fall outside `export_scope`.
  unbound_inputs = []
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      # Filter/rewrite the nodes of the caller-supplied GraphDef.
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # Only do this complicated work if we want to remove a name scope.
      # Builds the GraphDef directly from the live `graph`, one node at a time.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      # Iterate in node-id order so the output node order is deterministic.
      for key in sorted(graph._nodes_by_id):
        if _should_include_node(graph._nodes_by_id[key].name,
                                export_scope,
                                exclude_nodes):
          value = graph._nodes_by_id[key]
          # pylint: enable=protected-access
          node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                               clear_devices=clear_devices)
          graph_def.node.extend([node_def])
          if value.outputs:
            # Record the inferred output shapes on the exported node.
            assert "_output_shapes" not in graph_def.node[-1].attr
            graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
                output.get_shape().as_proto() for output in value.outputs])
          bytesize += value.node_def.ByteSize()
          if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError(
                "GraphDef cannot be larger than 2GB. "
                f"Received size: {bytesize}.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

  # It's possible that not all the inputs are in the export_scope.
  # If we would like such information included in the exported meta_graph,
  # add them to a special unbound_inputs collection.
  if unbound_inputs_col_name:
    # Clears the unbound_inputs collection on `graph` before refilling it.
    graph.clear_collection(unbound_inputs_col_name)
    for k in unbound_inputs:
      graph.add_to_collection(unbound_inputs_col_name, k)

  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      # The debug file replaces the extension of `filename` with ".debug",
      # e.g. "/dir/model.meta" -> "/dir/model.debug".
      name, _ = os.path.splitext(filename)
      debug_filename = name + ".debug"

      # Looks up the operation in `graph` for every node of the exported
      # GraphDef, re-applying `export_scope` to recover the original name.
      # TODO(liufengdb): fix this for functions.
      ops_to_export = []
      for node in scoped_meta_graph_def.graph_def.node:
        scoped_op_name = ops.prepend_name_scope(node.name, export_scope)
        ops_to_export.append(("", graph.get_operation_by_name(scoped_op_name)))

      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list
def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Copies a sub-meta_graph from one scope to another.

  The subgraph under `from_scope` in `from_graph` is exported with
  `export_scoped_meta_graph` and then re-imported under `to_scope` into
  `to_graph` with `import_scoped_meta_graph`. As a side effect of the export,
  the `unbound_inputs` collection of `from_graph` is cleared and repopulated.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`, the
      default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that has been copied into `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
    ValueError: When executing eagerly (raised by the export step, since no
      `graph_def` is passed to it).
  """
  from_graph = from_graph or ops.get_default_graph()
  to_graph = to_graph or ops.get_default_graph()

  # A copy onto the very same scope of the very same graph is rejected.
  if from_graph == to_graph and from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph. "
                     f"Received: 'from_graph': {from_graph}, "
                     f"'to_graph': {to_graph}, "
                     f"'from_scope': {from_scope}, 'to_scope': {to_scope}.")

  # The export's own variable dict describes the source scope and is unused;
  # callers receive the variables created by the import instead.
  orig_meta_graph, _ = export_scoped_meta_graph(
      export_scope=from_scope, graph=from_graph)
  return import_scoped_meta_graph(orig_meta_graph,
                                  graph=to_graph,
                                  import_scope=to_scope)