Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/convert_to_constants.py: 27%

540 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Helpers to convert variables to constants in TensorFlow 2.0.""" 

16 

17import collections 

18import numpy as np 

19 

20from tensorflow.core.framework import attr_value_pb2 

21from tensorflow.core.framework import graph_pb2 

22from tensorflow.core.framework import tensor_shape_pb2 

23from tensorflow.core.framework import variable_pb2 

24from tensorflow.core.protobuf import config_pb2 

25from tensorflow.core.protobuf import meta_graph_pb2 

26from tensorflow.core.protobuf import rewriter_config_pb2 

27from tensorflow.python.eager import context 

28from tensorflow.python.eager import wrap_function 

29from tensorflow.python.framework import dtypes 

30from tensorflow.python.framework import errors 

31from tensorflow.python.framework import graph_util 

32from tensorflow.python.framework import ops 

33from tensorflow.python.framework import tensor_util 

34from tensorflow.python.grappler import tf_optimizer 

35from tensorflow.python.ops import array_ops 

36from tensorflow.python.ops import variables 

37from tensorflow.python.platform import tf_logging as logging 

38from tensorflow.python.training.saver import export_meta_graph 

39from tensorflow.python.util import deprecation 

40from tensorflow.python.util import object_identity 

41from tensorflow.python.util.tf_export import tf_export 

42 

43 

# Used in _FunctionConverterDataInGraph().
VAR_ASSIGN_COLLECTION = "extra_var_assign_ops"

# Op types of the function-calling conditionals and loops; their union is the
# set of control-flow ops whose NodeDefs are also collected from functions.
_CONDITIONAL_OPS = {"If", "StatelessIf"}
_LOOP_OPS = {"While", "StatelessWhile"}
_CONTROL_FLOW_OPS = _CONDITIONAL_OPS | _LOOP_OPS

49 

50 

class _TensorData(
    collections.namedtuple("_TensorData", "numpy dtype index")):
  """Immutable record describing a tensor that is being turned into a constant.

  Fields:
    numpy: The tensor's value as a NumPy array.
    dtype: The tensor's DataType enum value.
    index: Position of the source tensor (e.g. its index among the captured
      inputs of the function being frozen).
  """
  __slots__ = ()

  @property
  def dtype_attr(self):
    """An AttrValue proto whose `type` field is this tensor's dtype."""
    return attr_value_pb2.AttrValue(type=self.dtype)

59 

60 

class _EndPoint(collections.namedtuple("_EndPoint", "convertible index")):
  """One end of a graph edge: a Convertible together with an input/output slot.

  Printed as "<convertible>[<index>]".
  """
  __slots__ = ()

  def __str__(self):
    return f"{self.convertible}[{self.index}]"

67 

68 

class _Edge(collections.namedtuple("_Edge", "source destination")):
  """A directed edge, normally between two _EndPoints.

  Printed as "<source> -> <destination>".
  """
  __slots__ = ()

  def __str__(self):
    return f"{self.source} -> {self.destination}"

75 

76 

class _Convertible(object):
  """Base class for graph entities in which variables can become constants.

  A Convertible knows the graph that encloses it, the edges leaving it, and
  (lazily) the copy of itself that conversion is allowed to modify. Subclasses
  provide converted_self(), convert_variable_to_constant() and create_edges().
  """

  def __init__(self, enclosing_graph):
    self._converted_self = None  # Filled in on demand by converted_self().
    self._outgoing_edges = []
    self._enclosing_graph = enclosing_graph

  def converted_self(self):
    """Returns the copy of this Convertible that conversion may modify.

    Implementations should return a copied instance that belongs to
    converted_enclosing_graph(). That copy, not `self`, is what gets edited
    during conversion, mostly from convert_variable_to_constant().
    """
    raise NotImplementedError

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts a variable flowing into this Convertible, and its dependents.

    Implementations must make sure a converted copy of this Convertible exists
    in the converted graph, and must forward the conversion to every
    Convertible that depends on this one.

    Args:
      incoming_edge: The edge into this Convertible whose value becomes a
        constant.
      tensor_data: The _TensorData holding the constant's value.
    """
    raise NotImplementedError

  def create_edges(self):
    """Registers, via add_outgoing_edge(), every edge known to this entity.

    The resulting dependency graph is what lets a conversion propagate from a
    variable to everything that consumes it. Typically each Convertible input
    of this entity receives an outgoing edge pointing here.
    """
    raise NotImplementedError

  def add_outgoing_edge(self, edge):
    """Records an edge leaving this Convertible.

    Args:
      edge: The outgoing _Edge; its source is expected to be `self`.
    """
    self._outgoing_edges.append(edge)

  @property
  def converted_enclosing_graph(self):
    """The converted copy of the graph that encloses this Convertible."""
    graph = self._enclosing_graph
    return graph.converted_self()

  @property
  def outgoing_edges(self):
    """The edges whose source is this Convertible, in insertion order."""
    return self._outgoing_edges

136 

137 

class _Function(_Convertible):
  """A library function Convertible.

  Edges into functions are edges from node _inputs_ into function _inputs_:
  Functions get their input from their callers, not from node outputs, and the
  callers in turn get those values as inputs.
  """

  def __init__(self, function, enclosing_graph):
    super(_Function, self).__init__(enclosing_graph)
    self._function = function
    # Wrap every NodeDef of the function body, keyed by node name.
    self._nodes = {
        n.name:
            _Node.new(node=n, function=self, enclosing_graph=enclosing_graph)
        for n in function.node_def
    }

  def __str__(self):
    return self.function.signature.name

  @property
  def function(self):
    """The wrapped FunctionDef proto."""
    return self._function

  @property
  def nodes(self):
    """Map from node name to the _Node wrapping it, for this function body."""
    return self._nodes

  def converted_self(self):
    """The Function copy to be converted.

    The copy will be renamed according to the graph's converted_function_names
    map, to ensure the name does not match anything currently in TensorFlow's
    function cache.

    Returns:
      The function instance to be converted.
    """
    if self._converted_self is None:
      old_name = self.function.signature.name
      new_name = self._enclosing_graph.converted_function_names[old_name]
      self.converted_enclosing_graph.rename_function(old_name, new_name)
      self._converted_self = self.converted_enclosing_graph.functions[new_name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts one function argument into a constant.

    Every consumer of the argument inside the function body is converted
    first. Then the converted FunctionDef's signature type for the argument is
    updated, and any recorded shape for it is cleared.

    Args:
      incoming_edge: The edge into the argument to be converted.
      tensor_data: The constant value.
    """
    index = incoming_edge.destination.index
    for edge in self.outgoing_edges:
      if edge.source.index == index:
        edge.destination.convertible.convert_variable_to_constant(
            edge, tensor_data)

    function = self.converted_self().function
    function.signature.input_arg[index].type = tensor_data.dtype
    # TODO(b/176982859): Find a more satisfying way to update shape information
    # than clearing it, or migrate users to a workflow that does not require
    # freezing.
    if "_input_shapes" in function.attr:
      function.attr["_input_shapes"].list.shape[index].unknown_rank = True
      del function.attr["_input_shapes"].list.shape[index].dim[:]
    # Indexing a protobuf map with a missing key inserts a default entry, so
    # only touch arg_attr when this argument already has attributes.
    if index in function.arg_attr:
      arg_attrs = function.arg_attr[index].attr
      if "_output_shapes" in arg_attrs:
        arg_attrs["_output_shapes"].list.shape[0].unknown_rank = True
        del arg_attrs["_output_shapes"].list.shape[0].dim[:]

  def create_edges(self):
    """Creates the edges of every node in the function body."""
    for n in self._nodes.values():
      n.create_edges()

212 

213 

class _Node(_Convertible):
  """A Convertible NodeDef."""

  def __init__(self, node, function, enclosing_graph):
    super(_Node, self).__init__(enclosing_graph)
    self._node = node
    self._function = function

  def __str__(self):
    return self._node.name

  @staticmethod
  def new(node, function, enclosing_graph):
    """Creates a new _Node based on its operation type.

    Args:
      node: The NodeDef to wrap.
      function: The _Function containing `node`, or None for graph nodes.
      enclosing_graph: The _GraphDef that owns `node`.

    Returns:
      An instance of the _Node subclass that handles `node.op`, or a plain
      _Node for ops without specialized conversion logic.
    """
    op = node.op
    if op in ("VariableV2", "VarHandleOp", "Placeholder"):
      node_class = _VarHandle
    elif op == "Case":
      node_class = _Case
    elif op == "Merge":
      node_class = _Merge
    elif op in ("PartitionedCall", "StatefulPartitionedCall"):
      node_class = _PartitionedCall
    elif op == "ReadVariableOp":
      node_class = _ReadVariable
    elif op == "ResourceGather":
      node_class = _ResourceGather
    elif op == "ResourceGatherNd":
      node_class = _ResourceGatherNd
    elif op in _CONDITIONAL_OPS:
      node_class = _If
    elif op in _LOOP_OPS:
      node_class = _While
    elif op in ("Enter", "Exit", "Identity", "NextIteration", "Switch",
                "_SwitchN"):
      node_class = _Intermediate
    else:
      node_class = _Node
    return node_class(node, function, enclosing_graph)

  @property
  def node(self):
    """The wrapped NodeDef proto."""
    return self._node

  @property
  def container(self):
    """The node container (either a graph or a function)."""
    if self._function is not None:
      return self._function.function
    return self._enclosing_graph.graph_def

  def converted_self(self):
    """The NodeDef to be converted.

    Returns:
      The NodeDef to be converted, which can come from either a graph or a
      function. Derived classes should call this (via 'super') to make sure the
      node is retrieved from the right place.
    """
    if self._converted_self is None:
      source = self._function or self._enclosing_graph
      self._converted_self = source.converted_self().nodes[self._node.name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Default behavior: ops without a specialization are left unchanged."""
    pass

  def create_edges(self):
    """Adds an edge from the source of each data input to this node.

    Control inputs (names starting with "^") carry no data and are skipped.
    """
    for index, name in enumerate(self._node.input):
      if name.startswith("^"):
        continue
      source = self.resolve_input(name)
      source.convertible.add_outgoing_edge(
          _Edge(source, _EndPoint(self, index)))

  def resolve_input(self, input_name):
    """Resolves an input into its _EndPoint.

    A NodeDef's input name can refer to either global NodeDefs (in the
    GraphDef's node list), a NodeDef in a function's node list, or a Function
    (in the GraphDef's function library). The name can also carry semantic
    information, depending on whether it starts with "^". This method handles
    all that logic in order to find the object the input name refers to.

    Args:
      input_name: The input name to resolve.

    Returns:
      The _EndPoint referred to by 'input_name'.
    """

    # The logic below oversimplifies the semantics, but is good enough for the
    # purposes of converting to constants. The introduction of new types of
    # operations may change this, forcing the code to be more generic.
    #
    # In particular, we are assuming that the lack of an index suffix means
    # ":0", when it could mean "all the outputs of a node." This works now
    # because converting to constants relies very little on output types, and
    # when it does it specializes its treatment in dedicated classes.
    name_elts = input_name.split(":")
    source_name = name_elts[0]
    if source_name.startswith("^"):
      source_name = source_name[1:]
    source_index = 0
    if len(name_elts) > 1 and name_elts[-1].isnumeric():
      source_index = int(name_elts[-1])

    if self._function is None:
      return _EndPoint(self._enclosing_graph.nodes[source_name], source_index)

    if source_index != 0 or source_name in self._function.nodes:
      return _EndPoint(self._function.nodes[source_name], source_index)

    # Otherwise the name refers to one of the enclosing function's arguments.
    inputs = [i.name for i in self._function.function.signature.input_arg]
    return _EndPoint(self._function, inputs.index(source_name))

  def update_dtype(self, attr_name, index, dtype):
    """Changes the type of a given input.

    Args:
      attr_name: The NodeDef attribute containing the type to change.
      index: The index of the input type to change.
      dtype: The type to change to.

    Raises:
      ValueError: If `index` is out of range for the attribute, or the
        attribute holds neither a type list nor a single type.
    """
    attr = self._node.attr[attr_name]
    num_types = 0
    # Check for various 'oneof' possibilities, and update the type if
    # index in range.
    if attr.HasField("list"):
      types = attr.list.type
      num_types = len(types)
      if num_types > index:
        types[index] = dtype
        return
    elif attr.HasField("type"):
      num_types = 1
      if index == 0:
        attr.type = dtype
        return
    raise ValueError(f"`index` {index:d} is out of range for "
                     f"node({self._node.name}).attr({attr_name}), which has "
                     f"{num_types:d} elements.")

358 

359 

class _Intermediate(_Node):
  """Specialization of _Node to intermediate ops (Identity, Switch, ...).

  These ops are treated as passing their inputs through: converting an input
  updates the node's "T" type attribute and is forwarded to every consumer.
  """

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    converted = self.converted_self()
    converted.update_dtype("T", incoming_edge.destination.index,
                           tensor_data.dtype)
    attrs = converted.node.attr
    # Drop cached output-shape metadata rather than trying to update it.
    if "_output_shapes" in attrs:
      del attrs["_output_shapes"]
    for out_edge in self.outgoing_edges:
      consumer = out_edge.destination.convertible
      consumer.convert_variable_to_constant(out_edge, tensor_data)

371 

372 

class _Merge(_Node):
  """Specialization of _Node to Merge ops."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts an input of a Merge node, redirecting it to input 0.

    Args:
      incoming_edge: The edge into the Merge input being converted.
      tensor_data: The constant value.
    """
    # The Merge operation has a single type for all its inputs, the number of
    # which is reflected in the "N" attribute. For the time being, we assume
    # that unilaterally changing all of them at once is ok, so the edge is
    # rewritten to target input 0. The destination must be an _EndPoint
    # (convertible, index), not an _Edge, so that `destination.index` works.
    # NOTE(review): the inherited _Node implementation is currently a no-op.
    super(_Merge, self).convert_variable_to_constant(
        _Edge(incoming_edge.source,
              _EndPoint(incoming_edge.destination.convertible, 0)),
        tensor_data)

383 

384 

class _VarHandle(_Node):
  """Specialization of _Node to VariableV2, VarHandleOp and Placeholder.

  The node is replaced, in the converted graph, by a Const holding the value,
  and the conversion is forwarded to all of its consumers.
  """

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    value = tensor_data.numpy
    const_proto = tensor_util.make_tensor_proto(value, tensor_data.dtype,
                                                value.shape)

    const_node = self.converted_self().node
    const_node.Clear()
    const_node.name = self._node.name
    const_node.op = "Const"
    const_node.attr["dtype"].CopyFrom(tensor_data.dtype_attr)
    const_node.attr["value"].tensor.CopyFrom(const_proto)

    # Everything that consumed the variable now consumes the constant.
    for out_edge in self.outgoing_edges:
      consumer = out_edge.destination.convertible
      consumer.convert_variable_to_constant(out_edge, tensor_data)

403 

404 

class _ResourceGather(_Node):
  """Specialization of _Node to ResourceGather.

  In the main graph the node is rewritten as a GatherV2 over its first two
  inputs plus a Const "<name>/axis" node holding the gather axis.
  """

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # Conversion is currently skipped for nodes that live inside a function.
    if self._function is not None:
      return
    original = self._node
    batch_dims = original.attr["batch_dims"].i
    if batch_dims != 0:
      raise ValueError("batch_dims must be 0 for freeze_graph, but got "
                       f"node({original.name}).attr('batch_dims') = "
                       f"{batch_dims}.")
    axis_name = original.name + "/axis"
    index_type = original.attr["Tindices"]
    axis_value = np.array(batch_dims)
    graph = self._enclosing_graph.converted_self()
    # Reuse the axis Const if it already exists, to avoid duplicate nodes.
    if axis_name not in graph.nodes:
      graph.nodes[axis_name] = _Node.new(
          node=graph.graph_def.node.add(),
          function=self._function,
          enclosing_graph=graph)
    axis_node = graph.nodes[axis_name].node
    axis_node.name = axis_name
    axis_node.op = "Const"
    axis_node.attr["dtype"].CopyFrom(index_type)
    axis_proto = tensor_util.make_tensor_proto(
        axis_value, dtype=index_type.type, shape=axis_value.shape)
    axis_node.attr["value"].tensor.CopyFrom(axis_proto)

    gather = self.converted_self().node
    gather.Clear()
    gather.name = original.name
    gather.op = "GatherV2"
    gather.input.extend([original.input[0], original.input[1], axis_name])
    gather.attr["Tparams"].CopyFrom(original.attr["dtype"])
    gather.attr["Tindices"].CopyFrom(original.attr["Tindices"])
    gather.attr["Taxis"].CopyFrom(index_type)
    if "_class" in original.attr:
      gather.attr["_class"].CopyFrom(original.attr["_class"])

445 

446 

class _ResourceGatherNd(_Node):
  """Specialization of _Node to ResourceGatherNd, rewritten as GatherNd."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    target = self.converted_self().node
    source = self._node
    target.Clear()
    target.name = source.name
    target.op = "GatherNd"
    target.input.extend([source.input[0], source.input[1]])
    for target_key, source_key in (("Tparams", "dtype"),
                                   ("Tindices", "Tindices")):
      target.attr[target_key].CopyFrom(source.attr[source_key])
    if "_class" in source.attr:
      target.attr["_class"].CopyFrom(source.attr["_class"])

460 

461 

class _ReadVariable(_Node):
  """Specialization of _Node to ReadVariableOp, rewritten as Identity."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    identity = self.converted_self().node
    source = self._node
    identity.Clear()
    identity.name = source.name
    identity.op = "Identity"

    identity.input.append(source.input[0])
    identity.attr["T"].CopyFrom(source.attr["dtype"])
    if "_class" in source.attr:
      identity.attr["_class"].CopyFrom(source.attr["_class"])

    # Inside a function, consumers refer to a ReadVariableOp output with the
    # "<name>:value" syntax; the Identity replacing it exposes "output", so
    # those references are rewritten.
    if self._function is None:
      return
    for out_edge in self.outgoing_edges:
      position = out_edge.destination.index
      consumer = out_edge.destination.convertible.converted_self()
      if not isinstance(consumer, _Node):
        continue
      parts = consumer.node.input[position].split(":")
      if len(parts) > 1 and parts[1] == "value":
        parts[1] = "output"
        consumer.node.input[position] = ":".join(parts)

488 

489 

class _FunctionCaller(_Node):
  """Base class for node Convertibles whose attributes reference functions."""

  def __init__(self, node, function, enclosing_graph, first_function_input,
               type_attribute, function_attributes):
    """Initializes a _FunctionCaller.

    Args:
      node: As in _Node.
      function: As in _Node.
      enclosing_graph: As in _Node.
      first_function_input: Index of the first NodeDef input that feeds the
        called functions. The remaining NodeDef inputs are assumed to map one
        to one onto the function inputs.
      type_attribute: Name of the NodeDef attribute listing the input types.
        Its entries are assumed to map one to one onto the function inputs,
        i.e. it does _not_ list types of inputs that are not forwarded.
      function_attributes: Names of the NodeDef attributes that reference
        functions.
    """
    super(_FunctionCaller, self).__init__(node, function, enclosing_graph)
    self._function_attributes = function_attributes
    self._type_attribute = type_attribute
    self._first_function_input = first_function_input

  def converted_self(self):
    """Returns the converted node, pointing it at renamed functions.

    On first use, every function referenced through one of the function
    attributes that is being converted is renamed in the converted NodeDef.
    """
    if self._converted_self is not None:
      return self._converted_self
    converted_node = super(_FunctionCaller, self).converted_self().node
    graph = self._enclosing_graph
    new_names = graph.converted_function_names
    for attr_name in self._function_attributes:
      attr = converted_node.attr[attr_name]
      if attr.HasField("func") and graph.is_converted_function(
          attr.func.name):
        attr.func.name = new_names[attr.func.name]
      elif attr.HasField("list"):
        for func in attr.list.func:
          if graph.is_converted_function(func.name):
            func.name = new_names[func.name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts one input and forwards it into the called functions."""
    index = incoming_edge.destination.index
    # Forwarding along every outgoing edge would be reasonable but not correct
    # in general. Edges going into the functions are correct, because the
    # inputs map to the function inputs. Edges going into other nodes, however,
    # ignore the logic of the body function, which may do arbitrary things to
    # the node's output:
    #
    # while x < 0:
    # return y
    #
    # Here the node's ":0" output may map to its ":1" input. For the time
    # being, then, only edges into functions are processed.
    for out_edge in self.outgoing_edges:
      target = out_edge.destination.convertible
      if out_edge.source.index != index or not isinstance(target, _Function):
        continue
      target.convert_variable_to_constant(out_edge, tensor_data)

    type_position = index - self._first_function_input
    converted = self.converted_self()
    if type_position >= 0:
      converted.update_dtype(self._type_attribute, type_position,
                             tensor_data.dtype)

  def create_edges(self):
    """Creates edges related to a function caller.

    Edges from a function caller to its called functions always go from
    _inputs_ to _inputs_: a FunctionDef input is supplied by the caller, based
    on the caller's own inputs.
    """
    super(_FunctionCaller, self).create_edges()
    for attr_name in self._function_attributes:
      attr = self._node.attr[attr_name]
      if attr.HasField("func"):
        self._add_edges_to_function(
            self._enclosing_graph.functions[attr.func.name])
      elif attr.HasField("list"):
        for func in attr.list.func:
          self._add_edges_to_function(
              self._enclosing_graph.functions[func.name])

  def _add_edges_to_function(self, called):
    """Links each function-feeding input of this node to `called`'s inputs."""
    offset = self._first_function_input
    for arg_index in range(len(self._node.input) - offset):
      self.add_outgoing_edge(
          _Edge(
              _EndPoint(self, arg_index + offset),
              _EndPoint(called, arg_index)))

581 

582 

class _If(_FunctionCaller):
  """Specialization of _Node to If-like operations (If, StatelessIf).

  Input 0 is not forwarded to the branches; the remaining inputs, typed by
  "Tin", feed both "then_branch" and "else_branch".
  """

  def __init__(self, node, function, enclosing_graph):
    super(_If, self).__init__(
        node=node,
        function=function,
        enclosing_graph=enclosing_graph,
        function_attributes=["then_branch", "else_branch"],
        type_attribute="Tin",
        first_function_input=1)

594 

595 

class _Case(_FunctionCaller):
  """Specialization of _Node to Case-like operations.

  Input 0 is not forwarded to the branches; the remaining inputs, typed by
  "Tin", feed every function listed in "branches".
  """

  def __init__(self, node, function, enclosing_graph):
    super(_Case, self).__init__(
        node=node,
        function=function,
        enclosing_graph=enclosing_graph,
        function_attributes=["branches"],
        type_attribute="Tin",
        first_function_input=1)

607 

608 

class _PartitionedCall(_FunctionCaller):
  """Specialization of _Node to (Stateful)PartitionedCall operations.

  All inputs, typed by "Tin", are forwarded to the function named by "f".
  """

  def __init__(self, node, function, enclosing_graph):
    super(_PartitionedCall, self).__init__(
        node=node,
        function=function,
        enclosing_graph=enclosing_graph,
        function_attributes=["f"],
        type_attribute="Tin",
        first_function_input=0)

620 

621 

class _While(_FunctionCaller):
  """Specialization of _Node to While-like operations (While, StatelessWhile).

  All inputs, typed by "T", are forwarded to both the "body" and "cond"
  functions.
  """

  def __init__(self, node, function, enclosing_graph):
    super(_While, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=0,
        type_attribute="T",
        function_attributes=["body", "cond"])

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts one loop variable, updating shapes and the body's output type.

    Args:
      incoming_edge: The edge into the loop input being converted.
      tensor_data: The constant value.
    """
    super(_While, self).convert_variable_to_constant(incoming_edge, tensor_data)
    index = incoming_edge.destination.index
    attrs = self.converted_self().node.attr
    # Check membership first: indexing a protobuf map with a missing key would
    # insert an empty "output_shapes" entry into the converted node.
    if "output_shapes" in attrs and attrs["output_shapes"].list.shape:
      attrs["output_shapes"].list.shape[index].CopyFrom(
          tensor_shape_pb2.TensorShapeProto(dim=[
              tensor_shape_pb2.TensorShapeProto.Dim(size=dim)
              for dim in tensor_data.numpy.shape
          ]))

    # The while's body inputs and outputs have the same type, so here we can go
    # ahead and change that function's output type.
    body_name = self._node.attr["body"].func.name
    body = self._enclosing_graph.functions[body_name].converted_self().function
    body.signature.output_arg[index].type = tensor_data.dtype

651 

652 

class _GraphDef(_Convertible):
  """A convertible GraphDef, wrapping its nodes and library functions."""

  def __init__(self, graph_def):
    super(_GraphDef, self).__init__(enclosing_graph=None)
    self._graph_def = graph_def
    self._nodes = {}
    for node_def in graph_def.node:
      self._nodes[node_def.name] = _Node.new(
          node=node_def, function=None, enclosing_graph=self)
    self._functions = {}
    for function_def in graph_def.library.function:
      self._functions[function_def.signature.name] = _Function(
          function_def, enclosing_graph=self)
    self.create_edges()
    self._converted_function_names = None

  @property
  def graph_def(self):
    """The wrapped GraphDef proto."""
    return self._graph_def

  @property
  def nodes(self):
    """Map from node name to _Node for the top-level graph nodes."""
    return self._nodes

  @property
  def functions(self):
    """Map from function name to _Function for the function library."""
    return self._functions

  @property
  def converted_function_names(self):
    """Map from original function names to their new (frozen) names.

    Every converted function is renamed to something hopefully unique, so that
    a converted and an unconverted function never share a name. Names are
    assigned in a deterministic order: by numeric "_<n>" suffix (or -1 when
    there is none), then by base name.

    Returns:
      Map from original to new suggested function names.
    """
    if self._converted_function_names is not None:
      return self._converted_function_names
    keyed_names = []  # Entries are (numeric suffix or -1, base name, name).
    for name in self.functions:
      base, separator, suffix = name.rpartition("_")
      if separator and suffix.isnumeric():
        keyed_names.append((int(suffix), base, name))
      else:
        keyed_names.append((-1, name, name))
    renamed = {}
    for _, base_name, name in sorted(keyed_names):
      renamed[name] = "{}_frozen_{}".format(base_name, ops.uid())
    self._converted_function_names = renamed
    return self._converted_function_names

  def rename_function(self, old_name, new_name):
    """Re-registers the function `old_name` under `new_name`."""
    renamed = self.functions.pop(old_name)
    renamed.function.signature.name = new_name
    self.functions[new_name] = renamed

  def is_converted_function(self, function_name):
    """Whether `function_name` is an original name that has been renamed.

    Only converted functions get renamed, so such a name is absent from the
    converted graph yet present in converted_function_names.
    """
    if function_name in self.converted_self().functions:
      return False
    return function_name in self.converted_function_names

  def converted_self(self):
    """Returns a _GraphDef wrapping a deep copy of this graph (created once)."""
    if self._converted_self is not None:
      return self._converted_self
    graph_copy = graph_pb2.GraphDef()
    graph_copy.CopyFrom(self._graph_def)
    self._converted_self = _GraphDef(graph_copy)
    return self._converted_self

  def create_edges(self):
    """Creates edges for every graph node, then for every library function."""
    for graph_node in self._nodes.values():
      graph_node.create_edges()
    for library_function in self._functions.values():
      library_function.create_edges()

730 

731 

class _ConverterData(object):
  """Holds the data needed to convert variables to constants.

  This is the GraphDef being converted plus the values of the tensors that
  become constants. Subclasses specialize how that data is obtained, e.g. for
  ConcreteFunction-based or Session-based conversion.
  """

  def __init__(self,
               graph_def,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    self._graph_def = graph_def
    self._tensor_data = {}
    self._build_node_defs_list()
    self._variable_names_allowlist = variable_names_allowlist
    self._variable_names_denylist = variable_names_denylist

  @property
  def graph_def(self):
    """The GraphDef to be converted."""
    return self._graph_def

  @property
  def node_defs(self):
    """NodeDefs of interest in the graph being converted.

    Returns:
      A map from node name to NodeDef, covering every NodeDef of the main
      graph plus the control-flow NodeDefs found inside library functions.
    """
    return self._node_defs

  @property
  def tensor_data(self):
    """A map from tensor name to its converted _TensorData."""
    return self._tensor_data

  def _should_convert(self, name):
    """Whether variable `name` passes the allowlist and avoids the denylist."""
    allowed = self._variable_names_allowlist
    denied = self._variable_names_denylist
    if allowed is not None and name not in allowed:
      return False
    return denied is None or name not in denied

  def _build_node_defs_list(self):
    """Builds the name-to-NodeDef map exposed by `node_defs`.

    The map holds every NodeDef of the main graph, plus the control-flow
    NodeDefs of the library functions. Other function NodeDefs are left out
    because their names are not unique across functions and their variables
    are handled differently from the main graph. Control-flow NodeDefs are
    included because their attributes need to be updated just like those of
    control-flow ops in the main graph.
    """
    self._node_defs = {node.name: node for node in self._graph_def.node}

    if not self._graph_def.library:
      return
    for func in self._graph_def.library.function:
      for node in func.node_def:
        if node.op in _CONTROL_FLOW_OPS:
          self._node_defs[node.name] = node

799 

800 

class _FunctionConverterData(_ConverterData):
  """Container for ConcreteFunction-based conversion data.

  The function's graph is inlined with Grappler, and the values of its captured
  inputs and of any `VariableV2` nodes are cached in `tensor_data`. Subclasses
  provide `_eval`, which turns a tensor or variable into a concrete value.
  """

  def __init__(self,
               func,
               lower_control_flow,
               aggressive_inlining,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    """Creates the conversion data for the given function.

    Args:
      func: ConcreteFunction.
      lower_control_flow: Boolean indicating whether or not to lower control
        flow ops such as If and While.
      aggressive_inlining: Boolean indicating whether or not to do aggressive
        function inlining (might be unsafe if function has stateful ops, not
        properly connected to control outputs).
      variable_names_allowlist: The set of variable names to convert (by
        default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting to
        constants.
    """

    self._func = func
    # Inline the graph in order to remove functions when possible.
    graph_def = _run_inline_graph_optimization(func, lower_control_flow,
                                               aggressive_inlining)
    super(_FunctionConverterData, self).__init__(
        graph_def,
        variable_names_allowlist=variable_names_allowlist,
        variable_names_denylist=variable_names_denylist)

    self._build_tensor_data()

  def _eval(self, tensor):
    """Returns the value in the tensor. Must be implemented in sub-classes.

    Raises:
      errors.UnimplementedError: Always; sub-classes override this method.
    """
    # Like every `OpError`, `UnimplementedError` takes `node_def` and `op`
    # positional arguments before the message. Passing only the message would
    # fail with a TypeError instead of raising the intended error.
    raise errors.UnimplementedError(
        None, None,
        "The evaluation method should be implemented in sub-classes.")

  def _build_tensor_data(self):
    """Caches the tensor data for all Placeholders in the given function.

    Captured inputs that are variable handles are evaluated through their
    variable. Other captured inputs are evaluated directly, except resource
    tensors, which are skipped. Finally, `VariableV2` (reference variable)
    nodes that are not yet cached are evaluated by pruning the function down to
    an identity of the variable.
    """
    # `captured_inputs` is read once. Each captured object is keyed by identity
    # and maps to its first index, which matches the original "first match
    # wins" scan while avoiding a rescan for every variable.
    first_capture_index = {}
    for idx, captured_input in enumerate(self._func.captured_inputs):
      first_capture_index.setdefault(id(captured_input), idx)

    map_index_to_variable = {}
    if first_capture_index:
      for var in self._func.graph.variables:
        idx = first_capture_index.get(id(var.handle))
        if idx is not None:
          map_index_to_variable[idx] = var

    # Iterates through all captures which are represented as Placeholders.
    for idx, (val_tensor, name_tensor) in enumerate(self._func.graph.captures):
      tensor_name = name_tensor.name.split(":")[0]
      if not self._should_convert(tensor_name):
        continue
      if idx in map_index_to_variable:
        data = self._eval(map_index_to_variable[idx])
      else:
        if val_tensor.dtype == dtypes.resource:
          logging.vlog(1, "Skip converting resource tensor %s", tensor_name)
          continue
        data = np.array(self._eval(val_tensor))

      self._tensor_data[tensor_name] = _TensorData(
          numpy=data,
          dtype=dtypes.as_dtype(data.dtype).as_datatype_enum,
          index=idx)

    # Get data for VariableV2 ops (reference variables) that cannot be lifted.
    for node in self.node_defs.values():
      if node.op != "VariableV2":
        continue
      if not self._should_convert(node.name):
        continue
      if node.name in self.tensor_data:
        continue
      with self._func.graph.as_default():
        identity_node = array_ops.identity(
            self._func.graph.as_graph_element(node.name + ":0"))
      # Running the pruned function yields the current value of the variable.
      variable_value = self._func.prune([], [identity_node.name])()[0]
      self._tensor_data[node.name] = _TensorData(
          numpy=variable_value.numpy(),
          dtype=node.attr["dtype"].type,
          index=None)

882 

883 

class _FunctionConverterDataInEager(_FunctionConverterData):
  """Container for ConcreteFunction-based conversion data in Eager mode.

  Values are read by calling `.numpy()` directly on each tensor or variable;
  no Session is involved.
  """

  def _eval(self, tensor):
    """Returns the eagerly computed value of `tensor` via `tensor.numpy()`."""
    value = tensor.numpy()
    return value

890 

891 

class _FunctionConverterDataInGraph(_FunctionConverterData):
  """Container for ConcreteFunction-based conversion data in Graph mode.

  Values are fetched with `session.run`. Before any value is read, the global
  variable initializer and the ops in `VAR_ASSIGN_COLLECTION` are run in that
  session.
  """

  def __init__(self,
               func,
               lower_control_flow,
               aggressive_inlining,
               variable_names_allowlist=None,
               variable_names_denylist=None,
               session=None):
    """Creates the conversion data for the given function.

    Args:
      func: ConcreteFunction.
      lower_control_flow: Boolean indicating whether or not to lower control
        flow ops such as If and While.
      aggressive_inlining: Boolean indicating whether or not to do aggressive
        function inlining (might be unsafe if function has stateful ops, not
        properly connected to control outputs).
      variable_names_allowlist: The set of variable names to convert (by
        default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting to
        constants.
      session: Session object. Required despite the `None` default.

    Raises:
      ValueError: If `session` is None.
    """
    if session is None:
      # Fail with a clear message instead of an AttributeError on `None.run`.
      raise ValueError(
          "A Session is required to convert variables to constants in "
          "Graph mode.")

    # Must be set before the base constructor runs: it builds the tensor data,
    # which calls `_eval` and therefore reads `self._session`.
    self._session = session

    session.run(variables.global_variables_initializer())
    # Run extra assignment ops if needed.
    # These assignments are run sequentially to ensure order.
    for op in ops.get_default_graph().get_collection(VAR_ASSIGN_COLLECTION):
      session.run(op)

    super(_FunctionConverterDataInGraph, self).__init__(
        func,
        lower_control_flow,
        aggressive_inlining,
        variable_names_allowlist,
        variable_names_denylist)

  def _eval(self, tensor):
    """Returns the value of `tensor` (or a variable) via `session.run`."""
    return self._session.run(tensor)

935 

936 

class _SessionConverterData(_ConverterData):
  """Container for Session-based conversion data.

  Variable values are fetched from a live Session rather than from a
  ConcreteFunction.
  """

  def __init__(self,
               session,
               graph_def,
               output_node_names,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    """Creates the conversion data from a Session.

    Args:
      session: Session holding the variable values to freeze.
      graph_def: GraphDef to convert. It is first passed through
        `graph_util.extract_sub_graph` with `output_node_names`.
      output_node_names: Names of the result nodes of the graph.
      variable_names_allowlist: The set of variable names to convert (by
        default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting to
        constants.
    """
    super(_SessionConverterData, self).__init__(
        graph_util.extract_sub_graph(graph_def, output_node_names),
        variable_names_allowlist=variable_names_allowlist,
        variable_names_denylist=variable_names_denylist)

    # One (node, fetch name) pair per variable node to freeze. Resource
    # variables (VarHandleOp) are fetched through their
    # "/Read/ReadVariableOp" node instead of the handle itself.
    pending = []
    for node in self.graph_def.node:
      if node.op not in ("Variable", "VariableV2", "VarHandleOp"):
        continue
      if not self._should_convert(node.name):
        continue
      if node.op == "VarHandleOp":
        read_name = node.name + "/Read/ReadVariableOp"
      else:
        read_name = node.name
      pending.append((node, read_name + ":0"))

    if not pending:
      return

    # Fetch all values with a single session call.
    fetched_values = session.run([fetch for _, fetch in pending])
    for (node, _), value in zip(pending, fetched_values):
      self._tensor_data[node.name] = _TensorData(
          numpy=value, dtype=node.attr["dtype"].type, index=None)

969 

970 

def disable_lower_using_switch_merge(graph_def):
  """Returns a copy of `graph_def` with control flow lowering disabled.

  Every NodeDef whose op is in `_CONTROL_FLOW_OPS`, in the main graph or in any
  library function, gets its `_lower_using_switch_merge` attribute set to
  False. The input proto is not modified.

  Args:
    graph_def: GraphDef proto.

  Returns:
    GraphDef
  """
  result = graph_pb2.GraphDef()
  result.CopyFrom(graph_def)

  def _every_node_def():
    # Main-graph nodes first, then the nodes of each library function.
    yield from result.node
    if result.library:
      for function_def in result.library.function:
        yield from function_def.node_def

  for node_def in _every_node_def():
    if node_def.op in _CONTROL_FLOW_OPS:
      node_def.attr["_lower_using_switch_merge"].b = False
  return result

997 return output_graph_def 

998 

999 

def _run_inline_graph_optimization(func, lower_control_flow,
                                   aggressive_inlining):
  """Apply function inline optimization to the graph.

  Returns the GraphDef after Grappler's function inlining optimization is
  applied. This optimization does not work on models with control flow.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops not
      properly connected to control outputs).

  Returns:
    GraphDef
  """
  graph_def = func.graph.as_graph_def()
  if not lower_control_flow:
    graph_def = disable_lower_using_switch_merge(graph_def)

  # In some cases, a secondary implementation of the function (e.g. for GPU) is
  # written to the "api_implements" attribute. (e.g. `tf.keras.layers.LSTM` in
  # TF2 produces a CuDNN-based RNN for GPU).
  # This function is supposed to inline all function calls, but
  # "api_implements" prevents this from happening. Removing the attribute
  # solves the problem.
  # To learn more about "api_implements", see:
  #   tensorflow/core/grappler/optimizers/implementation_selector.h
  for function in graph_def.library.function:
    if "api_implements" in function.attr:
      del function.attr["api_implements"]

  meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph)

  # Clear the initializer_name for the variables collections, since they are not
  # needed after saved to saved_model. Each collection is rewritten from its own
  # contents. Previously every collection was overwritten with a copy of
  # "variables". Collections absent from the MetaGraphDef are left absent,
  # because indexing a proto map would create an empty entry.
  for name in [
      "variables", "model_variables", "trainable_variables", "local_variables"
  ]:
    if name not in meta_graph.collection_def:
      continue
    raw_list = []
    for raw in meta_graph.collection_def[name].bytes_list.value:
      variable = variable_pb2.VariableDef()
      variable.ParseFromString(raw)
      variable.ClearField("initializer_name")
      raw_list.append(variable.SerializeToString())
    meta_graph.collection_def[name].bytes_list.value[:] = raw_list

  # Add a collection 'train_op' so that Grappler knows the outputs.
  fetch_collection = meta_graph_pb2.CollectionDef()
  for tensor in func.inputs + func.outputs:
    fetch_collection.node_list.value.append(tensor.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Initialize RewriterConfig with everything disabled except function inlining.
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.min_graph_nodes = -1  # do not skip small graphs
  rewrite_options.optimizers.append("function")
  if aggressive_inlining:
    rewrite_options.function_optimization = (
        rewriter_config_pb2.RewriterConfig.AGGRESSIVE)
  return tf_optimizer.OptimizeGraph(config, meta_graph)

1063 

1064 

def _construct_concrete_function(func, output_graph_def,
                                 converted_input_indices):
  """Builds a ConcreteFunction from `output_graph_def`.

  Internal captures whose index appears in `converted_input_indices` were
  folded into constants, so they are dropped from the new function's inputs.
  The remaining inputs and all outputs keep their original tensor names.

  Args:
    func: ConcreteFunction
    output_graph_def: GraphDef proto.
    converted_input_indices: Set of integers of input indices that were
      converted to constants.

  Returns:
    ConcreteFunction.
  """
  captures = func.graph.internal_captures
  converted = object_identity.ObjectIdentitySet(
      [captures[position] for position in converted_input_indices])

  kept_inputs = [t for t in func.inputs if t not in converted]
  kept_by_name = {t.name: t for t in kept_inputs}

  # Unregister stale copies of the library functions so the updated
  # definitions from `output_graph_def` are used instead.
  ctx = context.context()
  for function_def in output_graph_def.library.function:
    function_name = function_def.signature.name
    if ctx.has_function(function_name):
      ctx.remove_function(function_name)

  new_func = wrap_function.function_from_graph_def(
      output_graph_def,
      [t.name for t in kept_inputs],
      [t.name for t in func.outputs])

  # Scalar shapes are lost when wrapping the function, so copy each input's
  # shape back from the matching original tensor.
  for placeholder in new_func.inputs:
    placeholder.set_shape(kept_by_name[placeholder.name].shape)
  return new_func

1106 

1107 

def _replace_variables_by_constants(converter_data):
  """Replaces variables by constants on a given graph.

  Given a _ConverterData instance with converted variables in its tensor_data
  field, create a new graph where the respective variables are replaced with the
  converted constants.

  Args:
    converter_data: A pre-populated _ConverterData instance.

  Returns:
    A tuple `(converted_graph, converted_input_indices)`:
      converted_graph: The converted GraphDef.
      converted_input_indices: Set of the non-None `index` values found in
        `converter_data.tensor_data`, i.e. the capture indices whose values
        were turned into constants.
  """
  input_graph = _GraphDef(converter_data.graph_def)

  for tensor_name, tensor_data in converter_data.tensor_data.items():
    input_graph.nodes[tensor_name].convert_variable_to_constant(
        None, tensor_data)

  converted_graph = input_graph.converted_self().graph_def

  # Entries with `index=None` (e.g. reference variables) do not correspond to
  # a captured input, so they are not reported.
  converted_input_indices = {
      t.index
      for t in converter_data.tensor_data.values()
      if t.index is not None
  }

  return converted_graph, converted_input_indices

1136 

1137 

def convert_variables_to_constants_v2(func,
                                      lower_control_flow=True,
                                      aggressive_inlining=False):
  """Replaces all the variables in a graph with constants of the same values.

  TensorFlow 2.0 function for converting all Variable ops into Const ops holding
  the same values. This makes it possible to describe the network fully with a
  single GraphDef file, and allows the removal of a lot of ops related to
  loading and saving the variables. This function runs Grappler's function
  inlining optimization in order to return a single subgraph.

  The current implementation only works for graphs that do not contain any
  control flow or embedding related ops.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops, not
      properly connected to control outputs). (default False)

  Returns:
    ConcreteFunction containing a simplified version of the original.
  """
  # Values are read eagerly from the function's captures.
  eager_data = _FunctionConverterDataInEager(
      func=func,
      lower_control_flow=lower_control_flow,
      aggressive_inlining=aggressive_inlining)

  frozen_graph_def, folded_indices = _replace_variables_by_constants(
      converter_data=eager_data)
  return _construct_concrete_function(func, frozen_graph_def, folded_indices)

1174 

1175 

def convert_var_to_const_function_in_v1(func,
                                        lower_control_flow=True,
                                        aggressive_inlining=False):
  """Replaces all the variables in a graph with constants of the same values.

  Graph-mode counterpart of convert_variables_to_constants_v2. Variable values
  are read through the default Session. This is a temporary solution for users
  who want to integrate models written in TF2 with infrastructure that
  requires TF1 mode.

  The current implementation only works for graphs that do not contain any
  control flow or embedding related ops.

  The function must be called in a Session context.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops, not
      properly connected to control outputs). (default False)

  Raises:
    RuntimeError: If no Session context is present.

  Returns:
    ConcreteFunction containing a simplified version of the original.
  """
  default_session = ops.get_default_session()
  if default_session is None:
    raise RuntimeError(
        "The conversion must be carried out in a Session context.")

  graph_data = _FunctionConverterDataInGraph(
      func=func,
      lower_control_flow=lower_control_flow,
      aggressive_inlining=aggressive_inlining,
      session=default_session)
  frozen_graph_def, folded_indices = _replace_variables_by_constants(
      converter_data=graph_data)
  return _construct_concrete_function(func, frozen_graph_def, folded_indices)

1221 

1222 

def convert_variables_to_constants_v2_as_graph(func,
                                               lower_control_flow=True,
                                               aggressive_inlining=False):
  """Replaces all the variables in a graph with constants of the same values.

  Same as convert_variables_to_constants_v2, except that it also returns the
  intermediate frozen `GraphDef`, which keeps the node debug information from
  all transformations of the freezing phase.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops, not
      properly connected to control outputs).

  Returns:
    A `(ConcreteFunction, GraphDef)` pair: the simplified version of the
    original function, and the intermediate GraphDef containing the node debug
    information for the transformations in the frozen phase.
  """
  eager_data = _FunctionConverterDataInEager(
      func=func,
      lower_control_flow=lower_control_flow,
      aggressive_inlining=aggressive_inlining)
  frozen_graph_def, folded_indices = _replace_variables_by_constants(
      converter_data=eager_data)
  return (
      _construct_concrete_function(func, frozen_graph_def, folded_indices),
      frozen_graph_def,
  )

1256 

1257 

def convert_variables_to_constants_from_session_graph(
    session,
    graph_def,
    output_node_names,
    variable_names_allowlist=None,
    variable_names_denylist=None):
  """Replaces all the variables in a graph with constants of the same values.

  Works like convert_variables_to_constants_v2, but reads the constant values
  from a Session instead of a ConcreteFunction. This is useful for graphs built
  with TensorFlow V1, where no ConcreteFunction is available. Unlike
  graph_util.convert_variables_to_constants, it supports resource variables
  when V2 control flow constructions are present.

  Args:
    session: Active TensorFlow session containing the variables.
    graph_def: A GraphDef to convert.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_allowlist: The set of variable names to convert (by default,
      all variables are converted).
    variable_names_denylist: The set of variable names to omit converting to
      constants.

  Returns:
    An optimized GraphDef.
  """
  session_data = _SessionConverterData(
      session=session,
      graph_def=graph_def,
      output_node_names=output_node_names,
      variable_names_allowlist=variable_names_allowlist,
      variable_names_denylist=variable_names_denylist)
  # The converted input indices only matter for ConcreteFunctions.
  frozen_graph_def, _ = _replace_variables_by_constants(
      converter_data=session_data)
  return frozen_graph_def

1293 

1294 

@deprecation.deprecated(
    date=None,
    instructions="This API was designed for TensorFlow v1. See "
    "https://www.tensorflow.org/guide/migrate for instructions on how to "
    "migrate your code to TensorFlow v2."
)
@tf_export(v1=["graph_util.convert_variables_to_constants"])
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None):
  """Replaces all the variables in a graph with constants of the same values.

  If you have a trained graph containing Variable ops, it can be convenient to
  convert them all to Const ops holding the same values. This makes it possible
  to describe the network fully with a single GraphDef file, and allows the
  removal of a lot of ops related to loading and saving the variables.

  Args:
    sess: Active TensorFlow session containing the variables.
    input_graph_def: GraphDef object holding the network.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_whitelist: The set of variable names to convert (by default,
      all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting to
      constants.

  Returns:
    GraphDef containing a simplified version of the original.

  Raises:
    RuntimeError: if a DT_RESOURCE op is found whose ancestor Variables are both
      blacklisted AND whitelisted for freezing.
  """
  # Legacy argument names are mapped onto the allowlist/denylist API.
  return convert_variables_to_constants_from_session_graph(
      session=sess,
      graph_def=input_graph_def,
      output_node_names=output_node_names,
      variable_names_allowlist=variable_names_whitelist,
      variable_names_denylist=variable_names_blacklist)