Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/convert_to_constants.py: 27%
540 statements
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to convert variables to constants in TensorFlow 2.0."""

import collections
import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.util import deprecation
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export


# Used in _FunctionConverterDataInGraph().
VAR_ASSIGN_COLLECTION = "extra_var_assign_ops"
_CONDITIONAL_OPS = set(["If", "StatelessIf"])
_LOOP_OPS = set(["While", "StatelessWhile"])
_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(_LOOP_OPS)


class _TensorData(
    collections.namedtuple("_TensorData", ["numpy", "dtype", "index"])):
  """Data about a tensor that was converted to a constant."""
  __slots__ = ()

  @property
  def dtype_attr(self):
    return attr_value_pb2.AttrValue(type=self.dtype)


class _EndPoint(collections.namedtuple("_EndPoint", ["convertible", "index"])):
  """An endpoint in a graph."""
  __slots__ = ()

  def __str__(self):
    return "{}[{}]".format(self.convertible, self.index)


class _Edge(collections.namedtuple("_Edge", ["source", "destination"])):
  """A directed graph edge."""
  __slots__ = ()

  def __str__(self):
    return "{} -> {}".format(self.source, self.destination)


class _Convertible(object):
  """An entity that can have variables converted to constants."""

  def __init__(self, enclosing_graph):
    self._enclosing_graph = enclosing_graph
    self._outgoing_edges = []
    self._converted_self = None

  def converted_self(self):
    """A copy of this Convertible to be modified during conversion.

    Returns:
      Implementations should return the copied instance, which in turn should
      be contained in converted_enclosing_graph(). This instance is the one
      that will be modified during conversion. Its main use will be in the
      implementations of convert_variable_to_constant().
    """
    raise NotImplementedError

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts a variable in this Convertible and its dependencies.

    This method should make sure that a converted copy of itself is present in
    the converted graph, and that all Convertibles depending on this one also
    go through the same process.

    Args:
      incoming_edge: The graph edge into this Convertible that is being
        converted to a constant.
      tensor_data: The tensor representing the constant.
    """
    raise NotImplementedError

  def create_edges(self):
    """Calls add_outgoing_edge for all edges known to this Convertible.

    This is used to build the graph dependencies, so that the conversion of
    variables to constants can be properly propagated through the graph.
    Usually this method will call add_outgoing_edge() for each of the
    Convertible's inputs.
    """
    raise NotImplementedError

  def add_outgoing_edge(self, edge):
    """Adds an outgoing edge to the Convertible's list of edges.

    Args:
      edge: The outgoing edge (its source should be 'self').
    """
    self._outgoing_edges.append(edge)

  @property
  def converted_enclosing_graph(self):
    """The graph being converted."""
    return self._enclosing_graph.converted_self()

  @property
  def outgoing_edges(self):
    """The list of edges starting at this Convertible."""
    return self._outgoing_edges


class _Function(_Convertible):
  """A library function Convertible.

  Edges into functions are edges from node _inputs_ into function _inputs_:
  Functions get their input from their callers, not from node outputs, and the
  callers in turn get those values as inputs.
  """

  def __init__(self, function, enclosing_graph):
    super(_Function, self).__init__(enclosing_graph)
    self._function = function
    self._nodes = {
        n.name:
            _Node.new(node=n, function=self, enclosing_graph=enclosing_graph)
        for n in function.node_def
    }

  def __str__(self):
    return self.function.signature.name

  @property
  def function(self):
    return self._function

  @property
  def nodes(self):
    return self._nodes

  def converted_self(self):
    """The Function copy to be converted.

    The copy will be renamed according to the graph's converted_function_names
    map, to ensure the name does not match anything currently in TensorFlow's
    function cache.

    Returns:
      The function instance to be converted.
    """
    if self._converted_self is None:
      old_name = self.function.signature.name
      new_name = self._enclosing_graph.converted_function_names[old_name]
      self.converted_enclosing_graph.rename_function(old_name, new_name)
      self._converted_self = self.converted_enclosing_graph.functions[new_name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts one function argument into a constant.

    Args:
      incoming_edge: The edge into the argument to be converted.
      tensor_data: The constant value.
    """
    index = incoming_edge.destination.index
    for edge in self.outgoing_edges:
      if edge.source.index == index:
        edge.destination.convertible.convert_variable_to_constant(
            edge, tensor_data)

    function = self.converted_self().function
    function.signature.input_arg[index].type = tensor_data.dtype
    # TODO(b/176982859): Find a more satisfying way to update shape information
    # than clearing it, or migrate users to a workflow that does not require
    # freezing.
    if "_input_shapes" in function.attr:
      function.attr["_input_shapes"].list.shape[index].unknown_rank = True
      del function.attr["_input_shapes"].list.shape[index].dim[:]
    arg_attrs = function.arg_attr[index].attr
    if "_output_shapes" in arg_attrs:
      arg_attrs["_output_shapes"].list.shape[0].unknown_rank = True
      del arg_attrs["_output_shapes"].list.shape[0].dim[:]

  def create_edges(self):
    for n in self._nodes.values():
      n.create_edges()


class _Node(_Convertible):
  """A Convertible NodeDef."""

  def __init__(self, node, function, enclosing_graph):
    super(_Node, self).__init__(enclosing_graph)
    self._node = node
    self._function = function

  def __str__(self):
    return self._node.name

  @staticmethod
  def new(node, function, enclosing_graph):
    """Creates a new _Node based on its operation type."""
    if node.op in ["VariableV2", "VarHandleOp", "Placeholder"]:
      return _VarHandle(node, function, enclosing_graph)
    elif node.op == "Case":
      return _Case(node, function, enclosing_graph)
    elif node.op == "Merge":
      return _Merge(node, function, enclosing_graph)
    elif node.op in ["PartitionedCall", "StatefulPartitionedCall"]:
      return _PartitionedCall(node, function, enclosing_graph)
    elif node.op == "ReadVariableOp":
      return _ReadVariable(node, function, enclosing_graph)
    elif node.op == "ResourceGather":
      return _ResourceGather(node, function, enclosing_graph)
    elif node.op == "ResourceGatherNd":
      return _ResourceGatherNd(node, function, enclosing_graph)
    elif node.op in ["If", "StatelessIf"]:
      return _If(node, function, enclosing_graph)
    elif node.op in ["While", "StatelessWhile"]:
      return _While(node, function, enclosing_graph)
    elif node.op in [
        "Enter", "Exit", "Identity", "NextIteration", "Switch", "_SwitchN"
    ]:
      return _Intermediate(node, function, enclosing_graph)
    else:
      return _Node(node, function, enclosing_graph)

  @property
  def node(self):
    return self._node

  @property
  def container(self):
    """The node container (either a graph or a function)."""
    if self._function is not None:
      return self._function.function
    return self._enclosing_graph.graph_def

  def converted_self(self):
    """The NodeDef to be converted.

    Returns:
      The NodeDef to be converted, which can come from either a graph or a
      function. Derived classes should call this (via 'super') to make sure
      the node is retrieved from the right place.
    """
    if self._converted_self is None:
      source = self._function or self._enclosing_graph
      self._converted_self = source.converted_self().nodes[self._node.name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    pass

  def create_edges(self):
    for index, name in enumerate(self._node.input):
      # Discard edges from control inputs.
      if name[0] == "^":
        continue
      source = self.resolve_input(name)
      source.convertible.add_outgoing_edge(
          _Edge(source, _EndPoint(self, index)))

  def resolve_input(self, input_name):
    """Resolves an input into its _EndPoint.

    A NodeDef's input name can refer to either global NodeDefs (in the
    GraphDef's node list), a NodeDef in a function's node list, or a Function
    (in the GraphDef's function library). The name can also carry semantic
    information, depending on whether it starts with "^". This method handles
    all that logic in order to find the object to which the input name refers.

    Args:
      input_name: The input name to resolve.

    Returns:
      The _EndPoint referred to by 'input_name'.
    """
    # The logic below oversimplifies the semantics, but is good enough for the
    # purposes of converting to constants. The introduction of new types of
    # operations may change this, forcing the code to be more generic.
    #
    # In particular, we are assuming that the lack of an index suffix means
    # ":0", when it could mean "all the outputs of a node." This works now
    # because converting to constants relies very little on output types, and
    # when it does it specializes its treatment in dedicated classes.
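    #
    # Illustrative examples of the name forms handled below (the node names
    # are hypothetical):
    #   "MatMul"    -> _EndPoint(node "MatMul", output 0)
    #   "MatMul:2"  -> _EndPoint(node "MatMul", output 2)
    #   "^MatMul"   -> _EndPoint(node "MatMul", output 0), a control input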
    name_elts = input_name.split(":")
    source_name = name_elts[0]
    if source_name[0] == "^":
      source_name = source_name[1:]
    source_index = 0
    if len(name_elts) > 1 and name_elts[-1].isnumeric():
      source_index = int(name_elts[-1])

    if self._function is None:
      return _EndPoint(self._enclosing_graph.nodes[source_name], source_index)

    if source_index != 0 or source_name in self._function.nodes:
      return _EndPoint(self._function.nodes[source_name], source_index)

    inputs = [i.name for i in self._function.function.signature.input_arg]
    return _EndPoint(self._function, inputs.index(source_name))

  def update_dtype(self, attr_name, index, dtype):
    """Changes the type of a given input.

    Args:
      attr_name: The NodeDef attribute containing the type to change.
      index: The index of the input type to change.
      dtype: The type to change to.
    """
    attr = self._node.attr[attr_name]
    num_types = 0
    # Check for the various 'oneof' possibilities, and update the type if the
    # index is in range.
    if attr.HasField("list"):
      types = attr.list.type
      num_types = len(types)
      if num_types > index:
        types[index] = dtype
        return
    elif attr.HasField("type"):
      num_types = 1
      if index == 0:
        attr.type = dtype
        return
    raise ValueError(f"`index` {index:d} is out of range for "
                     f"node({self._node.name}).attr({attr_name}), which has "
                     f"{num_types:d} elements.")


class _Intermediate(_Node):
  """Specialization of _Node to intermediate ops."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    node = self.converted_self()
    node.update_dtype("T", incoming_edge.destination.index, tensor_data.dtype)
    if "_output_shapes" in node.node.attr:
      del node.node.attr["_output_shapes"]
    for edge in self.outgoing_edges:
      edge.destination.convertible.convert_variable_to_constant(
          edge, tensor_data)


class _Merge(_Node):
  """Specialization of _Node to Merge ops."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # The Merge operation has a single type for all its inputs, the number of
    # which is reflected in the "N" attribute. For the time being, we assume
    # that unilaterally changing all of them at once is ok.
    super(_Merge, self).convert_variable_to_constant(
        _Edge(incoming_edge.source,
              _EndPoint(incoming_edge.destination.convertible, 0)),
        tensor_data)


class _VarHandle(_Node):
  """Specialization of _Node to VarHandleOp."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    tensor_proto = tensor_util.make_tensor_proto(tensor_data.numpy,
                                                 tensor_data.dtype,
                                                 tensor_data.numpy.shape)

    node = self.converted_self().node
    node.Clear()
    node.name = self._node.name
    node.op = "Const"
    node.attr["dtype"].CopyFrom(tensor_data.dtype_attr)
    node.attr["value"].tensor.CopyFrom(tensor_proto)

    for edge in self.outgoing_edges:
      edge.destination.convertible.convert_variable_to_constant(
          edge, tensor_data)


class _ResourceGather(_Node):
  """Specialization of _Node to ResourceGather."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # We currently skip the conversion if this is inside a function.
    if self._function is not None:
      return
    if self._node.attr["batch_dims"].i != 0:
      raise ValueError("batch_dims must be 0 for freeze_graph, but got "
                       f"node({self._node.name}).attr('batch_dims') = "
                       f"{self._node.attr['batch_dims'].i}.")
    axis_node_name = self._node.name + "/axis"
    axis_dtype = self._node.attr["Tindices"]
    axis_data = np.array(self._node.attr["batch_dims"].i)
    converted_graph = self._enclosing_graph.converted_self()
    # Add the Const axis node, or reuse it if it already exists, to avoid
    # duplicates.
    if axis_node_name not in converted_graph.nodes:
      converted_graph.nodes[axis_node_name] = _Node.new(
          node=converted_graph.graph_def.node.add(),
          function=self._function,
          enclosing_graph=converted_graph)
      output_axis_node = converted_graph.nodes[axis_node_name].node
      output_axis_node.name = axis_node_name
      output_axis_node.op = "Const"
      output_axis_node.attr["dtype"].CopyFrom(axis_dtype)
      tensor = tensor_util.make_tensor_proto(
          axis_data, dtype=axis_dtype.type, shape=axis_data.shape)
      output_axis_node.attr["value"].tensor.CopyFrom(tensor)

    output_node = self.converted_self().node
    output_node.Clear()
    output_node.name = self._node.name
    output_node.op = "GatherV2"
    output_node.input.extend(
        [self._node.input[0], self._node.input[1], axis_node_name])
    output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"])
    output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"])
    output_node.attr["Taxis"].CopyFrom(axis_dtype)
    if "_class" in self._node.attr:
      output_node.attr["_class"].CopyFrom(self._node.attr["_class"])


class _ResourceGatherNd(_Node):
  """Specialization of _Node to ResourceGatherNd."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    output_node = self.converted_self().node
    output_node.Clear()
    output_node.name = self._node.name
    output_node.op = "GatherNd"
    output_node.input.extend([self._node.input[0], self._node.input[1]])
    output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"])
    output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"])
    if "_class" in self._node.attr:
      output_node.attr["_class"].CopyFrom(self._node.attr["_class"])


class _ReadVariable(_Node):
  """Specialization of _Node to ReadVariableOp."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    node = self.converted_self().node
    node.Clear()
    node.name = self._node.name
    node.op = "Identity"

    node.input.append(self._node.input[0])
    node.attr["T"].CopyFrom(self._node.attr["dtype"])
    if "_class" in self._node.attr:
      node.attr["_class"].CopyFrom(self._node.attr["_class"])

    # If the ReadVariableOp is part of a function, then every node that has
    # this ReadVariableOp as an input will refer to it using a ":value"
    # suffix. We need to change that to ":output".
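    # For example (a hypothetical name), an input string such as
    # "my_var/Read/ReadVariableOp:value" must become
    # "my_var/Read/ReadVariableOp:output" once the op has been rewritten to
    # an Identity, whose single output is named "output".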
    if self._function is not None:
      for edge in self.outgoing_edges:
        index = edge.destination.index
        dest = edge.destination.convertible.converted_self()
        if isinstance(dest, _Node):
          input_name_parts = dest.node.input[index].split(":")
          if len(input_name_parts) > 1 and input_name_parts[1] == "value":
            input_name_parts[1] = "output"
            dest.node.input[index] = ":".join(input_name_parts)


class _FunctionCaller(_Node):
  """A base class for Convertibles that reference functions."""

  def __init__(self, node, function, enclosing_graph, first_function_input,
               type_attribute, function_attributes):
    """Initializes a _FunctionCaller.

    Args:
      node: As in _Node.
      function: As in _Node.
      enclosing_graph: As in _Node.
      first_function_input: The index of the first NodeDef input that is tied
        to the function inputs. It is assumed that the rest of the NodeDef
        inputs map one to one to function inputs.
      type_attribute: The name of the NodeDef attribute that defines the input
        types. It is assumed that the types listed here map one-to-one with
        the function inputs (that is, they do _not_ specify types for inputs
        that are not passed to functions).
      function_attributes: The names of the NodeDef attributes containing
        references to functions.
    """
    super(_FunctionCaller, self).__init__(node, function, enclosing_graph)
    self._first_function_input = first_function_input
    self._type_attribute = type_attribute
    self._function_attributes = function_attributes

  def converted_self(self):
    if self._converted_self is None:
      node = super(_FunctionCaller, self).converted_self().node
      converted_names = self._enclosing_graph.converted_function_names
      for attr_name in self._function_attributes:
        attr = node.attr[attr_name]
        if (attr.HasField("func") and
            self._enclosing_graph.is_converted_function(attr.func.name)):
          attr.func.name = converted_names[attr.func.name]
        elif attr.HasField("list"):
          for func in attr.list.func:
            if self._enclosing_graph.is_converted_function(func.name):
              func.name = converted_names[func.name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    index = incoming_edge.destination.index
    # The loop below is reasonable but not correct in general: the outgoing
    # edges going into the functions are correct, because the caller's inputs
    # map to the function inputs. But the edges going into other nodes do not
    # take into account the logic of the body function, which may do arbitrary
    # things to the node's output:
    #
    #   while x < 0:
    #     return y
    #
    # In this case, the node's ":0" output may map to its ":1" input. For the
    # time being, then, we only process edges into functions.
    for edge in self.outgoing_edges:
      dest = edge.destination.convertible
      if edge.source.index == index and isinstance(dest, _Function):
        dest.convert_variable_to_constant(edge, tensor_data)

    node = self.converted_self()
    if index >= self._first_function_input:
      node.update_dtype(self._type_attribute,
                        index - self._first_function_input, tensor_data.dtype)

  def create_edges(self):
    """Creates edges related to a function caller.

    Edges from a function caller to its called functions are always edges from
    _inputs_ to _inputs_: a FunctionDef input is given by the caller, based on
    its own inputs.
    """
    super(_FunctionCaller, self).create_edges()
    for attr_name in self._function_attributes:
      attr = self._node.attr[attr_name]
      if attr.HasField("func"):
        function = self._enclosing_graph.functions[attr.func.name]
        for index in range(len(self._node.input) - self._first_function_input):
          self.add_outgoing_edge(
              _Edge(
                  _EndPoint(self, index + self._first_function_input),
                  _EndPoint(function, index)))
      elif attr.HasField("list"):
        for func in attr.list.func:
          function = self._enclosing_graph.functions[func.name]
          for index in range(
              len(self._node.input) - self._first_function_input):
            self.add_outgoing_edge(
                _Edge(
                    _EndPoint(self, index + self._first_function_input),
                    _EndPoint(function, index)))


class _If(_FunctionCaller):
  """Specialization of _Node to If-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_If, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=1,
        type_attribute="Tin",
        function_attributes=["then_branch", "else_branch"])


class _Case(_FunctionCaller):
  """Specialization of _Node to Case-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_Case, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=1,
        type_attribute="Tin",
        function_attributes=["branches"])


class _PartitionedCall(_FunctionCaller):
  """Specialization of _Node to PartitionedCall-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_PartitionedCall, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=0,
        type_attribute="Tin",
        function_attributes=["f"])


class _While(_FunctionCaller):
  """Specialization of _Node to While-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_While, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=0,
        type_attribute="T",
        function_attributes=["body", "cond"])

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    super(_While, self).convert_variable_to_constant(incoming_edge,
                                                     tensor_data)
    node = self.converted_self()
    if node.node.attr["output_shapes"].list.shape:
      node.node.attr["output_shapes"].list.shape[
          incoming_edge.destination.index].CopyFrom(
              tensor_shape_pb2.TensorShapeProto(dim=[
                  tensor_shape_pb2.TensorShapeProto.Dim(size=dim)
                  for dim in tensor_data.numpy.shape
              ]))

    # The while's body inputs and outputs have the same type, so here we can
    # go ahead and change that function's output type.
    body_name = self._node.attr["body"].func.name
    body = self._enclosing_graph.functions[body_name].converted_self().function
    body.signature.output_arg[
        incoming_edge.destination.index].type = tensor_data.dtype


class _GraphDef(_Convertible):
  """A convertible GraphDef."""

  def __init__(self, graph_def):
    super(_GraphDef, self).__init__(enclosing_graph=None)
    self._graph_def = graph_def
    self._nodes = {
        n.name: _Node.new(node=n, function=None, enclosing_graph=self)
        for n in graph_def.node
    }
    self._functions = {
        f.signature.name: _Function(f, enclosing_graph=self)
        for f in graph_def.library.function
    }
    self.create_edges()
    self._converted_function_names = None

  @property
  def graph_def(self):
    return self._graph_def

  @property
  def nodes(self):
    return self._nodes

  @property
  def functions(self):
    return self._functions

  @property
  def converted_function_names(self):
    """Map from original to new function names.

    In order to avoid conflicts (two functions with the same name, one
    converted and one not), we need to change the name of every converted
    function to something that is hopefully unique.

    Returns:
      Map from original to new suggested function names.
    """
    if self._converted_function_names is None:
      parsed_names = []  # List of (id, base_name, original_name)
      for name in self.functions:
        elements = name.rsplit("_", 1)
        if len(elements) == 2 and elements[1].isnumeric():
          parsed_names.append((int(elements[1]), elements[0], name))
        else:
          parsed_names.append((-1, name, name))
      self._converted_function_names = {
          name: "{}_frozen_{}".format(base_name, ops.uid())
          for (_, base_name, name) in sorted(parsed_names)
      }

    return self._converted_function_names

  def rename_function(self, old_name, new_name):
    func = self.functions.pop(old_name)
    func.function.signature.name = new_name
    self.functions[new_name] = func

  def is_converted_function(self, function_name):
    # Only converted functions will be renamed.
    return (function_name not in self.converted_self().functions) and (
        function_name in self.converted_function_names)

  def converted_self(self):
    if self._converted_self is None:
      copied_graph = graph_pb2.GraphDef()
      copied_graph.CopyFrom(self._graph_def)
      self._converted_self = _GraphDef(copied_graph)
    return self._converted_self

  def create_edges(self):
    for n in self._nodes.values():
      n.create_edges()
    for f in self._functions.values():
      f.create_edges()


class _ConverterData(object):
  """Container for constant conversion supporting data.

  The data includes the graph being converted, and the pre-converted tensors.
  This class will be specialized for ConcreteFunction and Session-based
  conversions, as the means to obtain that data is different for each case.
  """

  def __init__(self,
               graph_def,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    self._graph_def = graph_def
    self._tensor_data = {}
    self._build_node_defs_list()
    self._variable_names_allowlist = variable_names_allowlist
    self._variable_names_denylist = variable_names_denylist

  @property
  def graph_def(self):
    """The graph to be converted."""
    return self._graph_def

  @property
  def node_defs(self):
    """All the node defs in the graph to be converted.

    Returns:
      A map from node name to the NodeDef for all NodeDefs in the graph, as
      well as all control flow NodeDefs in the functions.
    """
    return self._node_defs

  @property
  def tensor_data(self):
    """A map from tensor name to its converted _TensorData."""
    return self._tensor_data

  def _should_convert(self, name):
    """Checks whether to convert the given variable name to a constant."""
    return (self._variable_names_allowlist is None or
            name in self._variable_names_allowlist) and (
                self._variable_names_denylist is None or
                name not in self._variable_names_denylist)

  def _build_node_defs_list(self):
    """Builds the list of NodeDefs in the GraphDef.

    This list consists of all NodeDefs in the main graph as well as all
    control flow NodeDefs in the functions.

    The remaining NodeDefs in the functions are not included because their op
    names are not unique and their variables are handled differently than in
    the main graph. The control flow ops need to be extracted because their
    attributes need to be updated in the same way as those of the control
    flow ops in the main graph.
    """
    self._node_defs = {node.name: node for node in self._graph_def.node}

    if self._graph_def.library:
      for func in self._graph_def.library.function:
        self._node_defs.update({
            node.name: node
            for node in func.node_def
            if node.op in _CONTROL_FLOW_OPS
        })


class _FunctionConverterData(_ConverterData):
  """Container for ConcreteFunction-based conversion data."""

  def __init__(self,
               func,
               lower_control_flow,
               aggressive_inlining,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    """Creates the conversion data for the given function.

    Args:
      func: ConcreteFunction.
      lower_control_flow: Boolean indicating whether or not to lower control
        flow ops such as If and While.
      aggressive_inlining: Boolean indicating whether or not to do aggressive
        function inlining (might be unsafe if the function has stateful ops
        not properly connected to control outputs).
      variable_names_allowlist: The set of variable names to convert (by
        default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting
        to constants.
    """

    self._func = func
    # Inline the graph in order to remove functions when possible.
    graph_def = _run_inline_graph_optimization(func, lower_control_flow,
                                               aggressive_inlining)
    super(_FunctionConverterData, self).__init__(
        graph_def,
        variable_names_allowlist=variable_names_allowlist,
        variable_names_denylist=variable_names_denylist)

    self._build_tensor_data()

  def _eval(self, tensor):
    """Returns the value in the tensor. Must be implemented in sub-classes."""
    raise errors.UnimplementedError(
        "The evaluation method should be implemented in sub-classes.")

  def _build_tensor_data(self):
    """Caches the tensor data for all Placeholders in the given function."""
    map_index_to_variable = {}
    for var in self._func.graph.variables:
      for idx, captured_input in enumerate(self._func.captured_inputs):
        if var.handle is captured_input:  # pylint: disable=protected-access
          map_index_to_variable[idx] = var
          break

    # Iterates through all captures which are represented as Placeholders.
    for idx, (val_tensor, name_tensor) in enumerate(self._func.graph.captures):
      tensor_name = name_tensor.name.split(":")[0]
      if not self._should_convert(tensor_name):
        continue
      if idx in map_index_to_variable:
        data = self._eval(map_index_to_variable[idx])
      else:
        if val_tensor.dtype == dtypes.resource:
          logging.vlog(1, "Skip converting resource tensor %s" % tensor_name)
          continue
        data = np.array(self._eval(val_tensor))

      self._tensor_data[tensor_name] = _TensorData(
          numpy=data,
          dtype=dtypes.as_dtype(data.dtype).as_datatype_enum,
          index=idx)

    # Get data for VariableV2 ops (reference variables) that cannot be lifted.
    for node in self.node_defs.values():
      if node.op == "VariableV2":
        if not self._should_convert(node.name):
          continue
        if node.name not in self.tensor_data:
          with self._func.graph.as_default():
            identity_node = array_ops.identity(
                self._func.graph.as_graph_element(node.name + ":0"))
          pruned_graph = self._func.prune([], [identity_node.name])()[0]
          self._tensor_data[node.name] = _TensorData(
              numpy=pruned_graph.numpy(),
              dtype=node.attr["dtype"].type,
              index=None)


class _FunctionConverterDataInEager(_FunctionConverterData):
  """Container for ConcreteFunction-based conversion data in Eager mode."""

  def _eval(self, tensor):
    """Returns the value in the tensor, evaluated eagerly."""
    return tensor.numpy()


class _FunctionConverterDataInGraph(_FunctionConverterData):
  """Container for ConcreteFunction-based conversion data in Graph mode."""

  def __init__(self,
               func,
               lower_control_flow,
               aggressive_inlining,
               variable_names_allowlist=None,
               variable_names_denylist=None,
               session=None):
    """Creates the conversion data for the given function.

    Args:
      func: ConcreteFunction.
      lower_control_flow: Boolean indicating whether or not to lower control
        flow ops such as If and While.
      aggressive_inlining: Boolean indicating whether or not to do aggressive
        function inlining (might be unsafe if the function has stateful ops
        not properly connected to control outputs).
      variable_names_allowlist: The set of variable names to convert (by
        default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting
        to constants.
      session: Session object.
    """
    self._session = session

    session.run(variables.global_variables_initializer())
    # Run extra assignment ops if needed.
    # These assignments are run sequentially to ensure order.
    for op in ops.get_default_graph().get_collection(VAR_ASSIGN_COLLECTION):
      session.run(op)

    super(_FunctionConverterDataInGraph, self).__init__(
        func,
        lower_control_flow,
        aggressive_inlining,
        variable_names_allowlist,
        variable_names_denylist)

  def _eval(self, tensor):
    """Returns the value in the tensor by running it in the Session."""
    return self._session.run(tensor)


class _SessionConverterData(_ConverterData):
  """Container for Session-based conversion data."""

  def __init__(self,
               session,
               graph_def,
               output_node_names,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    graph_def = graph_util.extract_sub_graph(graph_def, output_node_names)
    super(_SessionConverterData, self).__init__(
        graph_def,
        variable_names_allowlist=variable_names_allowlist,
        variable_names_denylist=variable_names_denylist)

    nodes_to_convert = []
    tensor_names_to_convert = []
    for node in self.graph_def.node:
      if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
        tensor_name = node.name
        if not self._should_convert(tensor_name):
          continue
        if node.op == "VarHandleOp":
          tensor_name = tensor_name + "/Read/ReadVariableOp"
        nodes_to_convert.append(node)
        tensor_names_to_convert.append(tensor_name + ":0")

    if tensor_names_to_convert:
      converted_tensors = session.run(tensor_names_to_convert)
      for node, tensor_value in zip(nodes_to_convert, converted_tensors):
        self._tensor_data[node.name] = _TensorData(
            numpy=tensor_value, dtype=node.attr["dtype"].type, index=None)


def disable_lower_using_switch_merge(graph_def):
  """Sets the '_lower_using_switch_merge' attribute to False.

  Sets the attribute to False in the NodeDefs in the main graph and the
  NodeDefs in each function's graph.

  Args:
    graph_def: GraphDef proto.

  Returns:
    GraphDef
  """
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.CopyFrom(graph_def)

  def disable_control_flow_lowering(node):
    if node.op in _CONTROL_FLOW_OPS:
      node.attr["_lower_using_switch_merge"].b = False

  for node in output_graph_def.node:
    disable_control_flow_lowering(node)

  if output_graph_def.library:
    for func in output_graph_def.library.function:
      for node in func.node_def:
        disable_control_flow_lowering(node)
  return output_graph_def


def _run_inline_graph_optimization(func, lower_control_flow,
                                   aggressive_inlining):
  """Applies the function inlining optimization to the graph.

  Returns the GraphDef after Grappler's function inlining optimization is
  applied. This optimization does not work on models with control flow.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control
      flow ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if the function has stateful ops not
      properly connected to control outputs).

  Returns:
    GraphDef
  """
  graph_def = func.graph.as_graph_def()
  if not lower_control_flow:
    graph_def = disable_lower_using_switch_merge(graph_def)

  # In some cases, a secondary implementation of the function (e.g. for GPU)
  # is written to the "api_implements" attribute (e.g. `tf.keras.layers.LSTM`
  # in TF2 produces a CuDNN-based RNN for GPU).
  # This function is supposed to inline all function calls, but
  # "api_implements" prevents this from happening. Removing the attribute
  # solves the problem. To learn more about "api_implements", see:
  #   tensorflow/core/grappler/optimizers/implementation_selector.h
  for function in graph_def.library.function:
    if "api_implements" in function.attr:
      del function.attr["api_implements"]

  meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph)

  # Clear the initializer_name in the variables collections, since it is not
  # needed after saving to a SavedModel.
  for name in [
      "variables", "model_variables", "trainable_variables", "local_variables"
  ]:
    raw_list = []
    for raw in meta_graph.collection_def[name].bytes_list.value:
      variable = variable_pb2.VariableDef()
      variable.ParseFromString(raw)
      variable.ClearField("initializer_name")
      raw_list.append(variable.SerializeToString())
    meta_graph.collection_def[name].bytes_list.value[:] = raw_list

  # Add a collection 'train_op' so that Grappler knows the outputs.
  fetch_collection = meta_graph_pb2.CollectionDef()
  for array in func.inputs + func.outputs:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Initialize RewriterConfig with everything disabled except function
  # inlining.
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.min_graph_nodes = -1  # do not skip small graphs
  rewrite_options.optimizers.append("function")
  if aggressive_inlining:
    rewrite_options.function_optimization = (
        rewriter_config_pb2.RewriterConfig.AGGRESSIVE)
  return tf_optimizer.OptimizeGraph(config, meta_graph)


def _construct_concrete_function(func, output_graph_def,
                                 converted_input_indices):
  """Constructs a concrete function from the `output_graph_def`.

  Args:
    func: ConcreteFunction.
    output_graph_def: GraphDef proto.
    converted_input_indices: Set of integers of input indices that were
      converted to constants.

  Returns:
    ConcreteFunction.
  """
  # Create a ConcreteFunction from the new GraphDef.
  input_tensors = func.graph.internal_captures
  converted_inputs = object_identity.ObjectIdentitySet(
      [input_tensors[index] for index in converted_input_indices])
  not_converted_inputs = [
      tensor for tensor in func.inputs if tensor not in converted_inputs
  ]
  not_converted_inputs_map = {
      tensor.name: tensor for tensor in not_converted_inputs
  }

  new_input_names = [tensor.name for tensor in not_converted_inputs]
  new_output_names = [tensor.name for tensor in func.outputs]

  # Remove old functions to use updated functions from graph def.
  for f in output_graph_def.library.function:
    if context.context().has_function(f.signature.name):
      context.context().remove_function(f.signature.name)

  new_func = wrap_function.function_from_graph_def(output_graph_def,
                                                   new_input_names,
                                                   new_output_names)

  # Manually propagate shapes for input tensors where the shape is not
  # correctly propagated. Scalar shapes are lost when wrapping the function.
  for input_tensor in new_func.inputs:
    input_tensor.set_shape(not_converted_inputs_map[input_tensor.name].shape)
  return new_func


def _replace_variables_by_constants(converter_data):
  """Replaces variables by constants on a given graph.

  Given a _ConverterData instance with converted variables in its tensor_data
  field, creates a new graph where the respective variables are replaced with
  the converted constants.

  Args:
    converter_data: A pre-populated _ConverterData instance.

  Returns:
    The converted graph, and the set of indices of the inputs that were
    converted to constants.
  """
  input_graph = _GraphDef(converter_data.graph_def)

  for tensor_name, tensor_data in converter_data.tensor_data.items():
    input_graph.nodes[tensor_name].convert_variable_to_constant(
        None, tensor_data)

  converted_graph = input_graph.converted_self().graph_def

  converted_input_indices = {
      t.index
      for t in converter_data.tensor_data.values()
      if t.index is not None
  }

  return converted_graph, converted_input_indices


def convert_variables_to_constants_v2(func,
                                      lower_control_flow=True,
                                      aggressive_inlining=False):
  """Replaces all the variables in a graph with constants of the same values.

  TensorFlow 2.0 function for converting all Variable ops into Const ops
  holding the same values. This makes it possible to describe the network
  fully with a single GraphDef file, and allows the removal of a lot of ops
  related to loading and saving the variables. This function runs Grappler's
  function inlining optimization in order to return a single subgraph.

  The current implementation only works for graphs that do not contain any
  control flow or embedding related ops.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control
      flow ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if the function has stateful ops not
      properly connected to control outputs). (default False)

  Returns:
    ConcreteFunction containing a simplified version of the original.
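
  Example (a minimal sketch; the toy model and the `tf.*` public-API calls
  below are illustrative, not part of this module):

  ```python
  v = tf.Variable(2.0)

  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
  def f(x):
    return v * x

  frozen = convert_variables_to_constants_v2(f.get_concrete_function())
  print(frozen(tf.constant(3.0)))  # 6.0, with `v` baked in as a Const op.
  ```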
1162 """
1164 converter_data = _FunctionConverterDataInEager(
1165 func=func,
1166 lower_control_flow=lower_control_flow,
1167 aggressive_inlining=aggressive_inlining)
1169 output_graph_def, converted_input_indices = _replace_variables_by_constants(
1170 converter_data=converter_data)
1172 return _construct_concrete_function(func, output_graph_def,
1173 converted_input_indices)


def convert_var_to_const_function_in_v1(func,
                                        lower_control_flow=True,
                                        aggressive_inlining=False):
  """Replaces all the variables in a graph with constants of the same values.

  This function works the same as convert_variables_to_constants_v2, but it
  should be used in Graph mode. It is a temporary solution when users want to
  integrate their models written in TF2 with infra that requires TF1 mode.

  The current implementation only works for graphs that do not contain any
  control flow or embedding related ops.

  The function must be called in a Session context.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control
      flow ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if the function has stateful ops not
      properly connected to control outputs). (default False)

  Raises:
    RuntimeError: If no Session context is present.

  Returns:
    ConcreteFunction containing a simplified version of the original.
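
  Example (a minimal sketch; assumes `func` is a ConcreteFunction obtained
  while tracing in TF1 graph mode):

  ```python
  with tf.compat.v1.Session():
    frozen_func = convert_var_to_const_function_in_v1(func)
  ```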
1203 """
1205 session = ops.get_default_session()
1206 if session is None:
1207 raise RuntimeError(
1208 "The conversion must be carried out in a Session context.")
1210 converter_data = _FunctionConverterDataInGraph(
1211 func=func,
1212 lower_control_flow=lower_control_flow,
1213 aggressive_inlining=aggressive_inlining,
1214 session=session)
1216 output_graph_def, converted_input_indices = _replace_variables_by_constants(
1217 converter_data=converter_data)
1219 return _construct_concrete_function(func, output_graph_def,
1220 converted_input_indices)


def convert_variables_to_constants_v2_as_graph(func,
                                               lower_control_flow=True,
                                               aggressive_inlining=False):
  """Replaces all the variables in a graph with constants of the same values.

  This function works the same as convert_variables_to_constants_v2, but it
  returns the intermediate `GraphDef` as well. This `GraphDef` contains all
  the debug information after all the transformations in the freezing phase.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control
      flow ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if the function has stateful ops not
      properly connected to control outputs).

  Returns:
    ConcreteFunction containing a simplified version of the original, and
    also the intermediate GraphDef containing the node debug information for
    the transformations in the freezing phase.
  """
  converter_data = _FunctionConverterDataInEager(
      func=func,
      lower_control_flow=lower_control_flow,
      aggressive_inlining=aggressive_inlining)

  output_graph_def, converted_input_indices = _replace_variables_by_constants(
      converter_data=converter_data)

  frozen_func = _construct_concrete_function(func, output_graph_def,
                                             converted_input_indices)
  return frozen_func, output_graph_def


def convert_variables_to_constants_from_session_graph(
    session,
    graph_def,
    output_node_names,
    variable_names_allowlist=None,
    variable_names_denylist=None):
  """Replaces all the variables in a graph with constants of the same values.

  This function works similarly to convert_variables_to_constants_v2, but it
  retrieves the constant values from a Session instead of from a
  ConcreteFunction. This is useful when converting graphs generated from
  TensorFlow V1, where ConcreteFunctions are not available. This also differs
  from graph_util.convert_variables_to_constants in that it supports resource
  variables when V2 control flow constructions are present.

  Args:
    session: Active TensorFlow session containing the variables.
    graph_def: A GraphDef to convert.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_allowlist: The set of variable names to convert (by
      default, all variables are converted).
    variable_names_denylist: The set of variable names to omit converting to
      constants.

  Returns:
    An optimized GraphDef.
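
  Example (a minimal sketch; the `tf.compat.v1` calls and names are
  illustrative, not part of this module):

  ```python
  with tf.compat.v1.Session() as sess:
    v = tf.compat.v1.get_variable("v", initializer=1.0)
    out = tf.identity(v, name="out")
    sess.run(tf.compat.v1.global_variables_initializer())
    frozen_graph_def = convert_variables_to_constants_from_session_graph(
        sess, sess.graph_def, ["out"])
  ```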
1284 """
1285 graph_def, _ = _replace_variables_by_constants(
1286 converter_data=_SessionConverterData(
1287 session=session,
1288 graph_def=graph_def,
1289 output_node_names=output_node_names,
1290 variable_names_allowlist=variable_names_allowlist,
1291 variable_names_denylist=variable_names_denylist))
1292 return graph_def


@deprecation.deprecated(
    date=None,
    instructions="This API was designed for TensorFlow v1. See "
    "https://www.tensorflow.org/guide/migrate for instructions on how to "
    "migrate your code to TensorFlow v2.")
@tf_export(v1=["graph_util.convert_variables_to_constants"])
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None):
  """Replaces all the variables in a graph with constants of the same values.

  If you have a trained graph containing Variable ops, it can be convenient
  to convert them all to Const ops holding the same values. This makes it
  possible to describe the network fully with a single GraphDef file, and
  allows the removal of a lot of ops related to loading and saving the
  variables.

  Args:
    sess: Active TensorFlow session containing the variables.
    input_graph_def: GraphDef object holding the network.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_whitelist: The set of variable names to convert (by
      default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting to
      constants.

  Returns:
    GraphDef containing a simplified version of the original.

  Raises:
    RuntimeError: if a DT_RESOURCE op is found whose ancestor Variables are
      both denylisted AND allowlisted for freezing.
  """
  ret = convert_variables_to_constants_from_session_graph(
      session=sess,
      graph_def=input_graph_def,
      output_node_names=output_node_names,
      variable_names_allowlist=variable_names_whitelist,
      variable_names_denylist=variable_names_blacklist)
  return ret