Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/graph_util_impl.py: 20%

175 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Helpers to manipulate a tensor graph in python. 

16""" 

17 

import collections
import copy
import re

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import _proto_comparators
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

28 

# Register the generated GraphDef proto class under the v1 API name "GraphDef".
tf_export(v1=["GraphDef"])(graph_pb2.GraphDef)

30 

# Op type names that `_is_variable_op` treats as variable-related.
_VARIABLE_OPS = {
    "Assign", "AssignAdd", "AssignSub",
    "Queue",
    "ScatterAdd", "ScatterSub", "ScatterUpdate",
    "TruncatedNormal",
    "Variable", "VariableV2",
}

# Control-flow op type names, plus "Identity".
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY = [
    "Switch", "Enter", "Exit", "Identity", "Merge", "NextIteration",
]

# Instructions passed to the @deprecation.deprecated decorators in this file.
_DEPRECATION_MSG = (
    "This API was designed for TensorFlow v1. "
    "See https://www.tensorflow.org/guide/migrate "
    "for instructions on how to migrate your code to TensorFlow v2.")

57 

58 

def _is_variable_op(op):
  """Reports whether the op type name `op` is listed in `_VARIABLE_OPS`."""
  variable_op_types = _VARIABLE_OPS
  return op in variable_op_types

62 

# GraphDef protobuf docstring. The interactive example uses doctest's "..."
# continuation prompt for the lines of the multi-line function definition.
graph_pb2.GraphDef.__doc__ = """\
A protobuf containing the graph of operations.

@compatibility(TF2)
This API is not available in TensorFlow 2.x.

You should not need to use `GraphDef`s directly in TF2. To load `GraphDef`s in
TF2, use SavedModel. The SavedModel contains the `GraphDef`.

Before:

```python
with tf.io.gfile.GFile('/tmp/graph.pb', 'rb') as f:
  graph_def = tf.compat.v1.GraphDef()
  graph_def.ParseFromString(f.read())
```

After:

```python
tf.saved_model.load('/tmp/saved_model')
```

If you would like to create a `GraphDef` in TF2, use `tf.function` and
`get_concrete_function`.

>>> @tf.function
... def f(x):
...   return x
>>>
>>> graph_def = f.get_concrete_function(1.).graph.as_graph_def()
>>> print(graph_def)

@end_compatibility

"""

100 

101 

def _node_attr_type(node_def, attr_name):
  """Returns the `type` field of attr `attr_name` of `node_def`, or None.

  Membership is checked before indexing because indexing a protobuf message
  map with a missing key inserts a default entry, which would silently mutate
  a NodeDef owned by the caller.
  """
  if attr_name not in node_def.attr:
    return None
  return node_def.attr[attr_name].type


@deprecation.deprecated(
    date=None,
    instructions=_DEPRECATION_MSG)
@tf_export(v1=["graph_util.must_run_on_cpu"])
def must_run_on_cpu(node, pin_variables_on_cpu=False):
  """Returns True if the given node_def must run on CPU, otherwise False.

  Args:
    node: The node to be assigned to a device. Could be either an ops.Operation
      or NodeDef.
    pin_variables_on_cpu: If True, this function will return True if node_def
      represents a variable-related op (an op type listed in `_VARIABLE_OPS`).

  Returns:
    True if the given node must run on CPU, otherwise False. The node is
    pinned to CPU when it is a variable op (with `pin_variables_on_cpu`), a
    string or int32 `Const`, an int32 `DynamicStitch`/`ParallelDynamicStitch`,
    or a `Cast` whose source type is int32.
  """
  if isinstance(node, ops.Operation):
    node_def = node.node_def
  else:
    assert isinstance(node, node_def_pb2.NodeDef)
    node_def = node

  # If the op is a variable-related op, should we pin it on CPU?
  if pin_variables_on_cpu and _is_variable_op(node_def.op):
    return True

  # Constant operations producing a string or int32 must run on CPU.
  if node_def.op == "Const":
    dtype = _node_attr_type(node_def, "dtype")
    if dtype == dtypes.string or dtype == dtypes.int32:
      return True

  # int32 DynamicStitch / ParallelDynamicStitch is pinned to CPU.
  # NOTE(review): an earlier comment here said the GPU kernel "only works for
  # int32 values", which contradicts pinning int32 to CPU; rationale unconfirmed.
  if node_def.op in ("DynamicStitch", "ParallelDynamicStitch"):
    if _node_attr_type(node_def, "T") == dtypes.int32:
      return True

  # Cast with an int32 source is pinned to CPU (noted as unsupported on GPU).
  if node_def.op == "Cast":
    if _node_attr_type(node_def, "SrcT") == dtypes.int32:
      return True

  return False

148 

149 

150################################################################################ 

151# 

# Helpers for summarizing, pruning and comparing GraphDefs.

153# 

154################################################################################ 

155 

156 

def _node_name(n):
  """Strips a GraphDef input string down to a node name.

  A leading "^" (control input) is removed and nothing else is touched;
  otherwise everything from the first ":" (the output port) onward is dropped.
  """
  if not n.startswith("^"):
    name, _, _ = n.partition(":")
    return name
  return n[1:]

162 

163 

def _get_colocated_node_name(colocated_node_name):
  """Decodes a UTF-8 "_class" entry and strips a leading "loc:@" if present."""
  location_prefix = "loc:@"
  decoded = colocated_node_name.decode("utf-8")
  if not decoded.startswith(location_prefix):
    return decoded
  return decoded[len(location_prefix):]

170 

171 

def _extract_graph_summary(graph_def):
  """Builds name-keyed lookup tables over the nodes of `graph_def`.

  Every table is keyed by `_node_name(node.name)`.

  Args:
    graph_def: A GraphDef; only its `node` field is read.

  Returns:
    A tuple `(name_to_input_name, name_to_node, name_to_seq_num)`:
      name_to_input_name: the node's input names (via `_node_name`), followed
        by the names of any nodes listed in its "_class" colocation attr.
      name_to_node: the node message itself.
      name_to_seq_num: the node's position in `graph_def.node`, so callers can
        emit nodes in their original order.
  """
  name_to_input_name = {}
  name_to_node = {}
  name_to_seq_num = {}
  for position, node in enumerate(graph_def.node):
    key = _node_name(node.name)
    inputs = [_node_name(raw_input) for raw_input in node.input]
    if "_class" in node.attr:
      # Treat colocated nodes as extra inputs so pruning does not lose them.
      inputs.extend(
          _get_colocated_node_name(entry)
          for entry in node.attr["_class"].list.s)
    name_to_node[key] = node
    name_to_input_name[key] = inputs
    name_to_seq_num[key] = position
  return name_to_input_name, name_to_node, name_to_seq_num

193 

194 

def _assert_nodes_are_present(name_to_node, nodes):
  """Checks that every name in `nodes` is a key of `name_to_node`.

  Args:
    name_to_node: Mapping from node name to node.
    nodes: Iterable of node names to look up.

  Raises:
    AssertionError: If some name in `nodes` is missing from `name_to_node`.
  """
  for d in nodes:
    # Raise explicitly instead of using `assert`, so the check survives
    # `python -O`; the exception type and message are unchanged for callers.
    if d not in name_to_node:
      raise AssertionError("%s is not in graph" % (d,))

199 

200 

def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
  """Breadth first search for reachable nodes from target nodes.

  Args:
    target_nodes: Iterable of node names to start from.
    name_to_input_name: Mapping from node name to the list of names it reads
      from. Names absent from the mapping are kept but not expanded.

  Returns:
    The set of all names reachable from `target_nodes` (inclusive) by
    repeatedly following `name_to_input_name`.
  """
  nodes_to_keep = set()
  # deque gives O(1) pops from the front; `del list[0]` is O(n) per pop.
  next_to_visit = collections.deque(target_nodes)
  while next_to_visit:
    node = next_to_visit.popleft()
    if node in nodes_to_keep:
      # Already visited this node.
      continue
    nodes_to_keep.add(node)
    if node in name_to_input_name:
      next_to_visit.extend(name_to_input_name[node])
  return nodes_to_keep

216 

217 

@deprecation.deprecated(
    date=None,
    instructions=_DEPRECATION_MSG)
@tf_export(v1=["graph_util.extract_sub_graph"])
def extract_sub_graph(graph_def, dest_nodes):
  """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.

  A node is kept if it is in `dest_nodes` or can reach one of them through
  data inputs, control inputs, or "_class" colocation references. Kept nodes
  appear in their original order; the function library and versions of
  `graph_def` are copied into the result.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    dest_nodes: An iterable of strings specifying the destination node names.
      One-shot iterators such as generators are accepted.

  Returns:
    The GraphDef of the sub-graph.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, or if
      'dest_nodes' is a single string rather than an iterable of strings.
    AssertionError: If a name in 'dest_nodes' is not a node of 'graph_def'
      (checked by `_assert_nodes_are_present`).
  """

  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto, but got "
                    f"type {type(graph_def)}.")

  if isinstance(dest_nodes, str):
    raise TypeError("dest_nodes must be an iterable of strings, but got "
                    f"type {type(dest_nodes)}.")

  # dest_nodes is traversed twice below (presence check, then the search);
  # materialize it so a generator is not exhausted by the first pass, which
  # would otherwise silently produce an empty subgraph.
  dest_nodes = list(dest_nodes)

  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  _assert_nodes_are_present(name_to_node, dest_nodes)

  nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)

  # Emit kept nodes in the order they appeared in the input graph.
  nodes_to_keep_list = sorted(nodes_to_keep, key=lambda n: name_to_seq_num[n])

  out = graph_pb2.GraphDef()
  out.node.extend([copy.deepcopy(name_to_node[n]) for n in nodes_to_keep_list])
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out

259 

260 

@deprecation.deprecated(
    date=None,
    instructions=_DEPRECATION_MSG)
@tf_export(v1=["graph_util.tensor_shape_from_node_def_name"])
def tensor_shape_from_node_def_name(graph, input_name):
  """Convenience function to get a shape from a NodeDef's input string.

  Args:
    graph: Any object with a `get_tensor_by_name` method (presumably a
      `tf.Graph`).
    input_name: A NodeDef input string, with or without a ":<port>" suffix.

  Returns:
    Whatever `get_shape()` returns for the named tensor.
  """
  # Tensor names have the form <op>:<port> (e.g. 'Mul:0'), but GraphDef input
  # strings often leave the port off; in that case fall back to port 0.
  if ":" in input_name:
    tensor_name = input_name
  else:
    tensor_name = input_name + ":0"
  return graph.get_tensor_by_name(tensor_name).get_shape()

277 

278 

@deprecation.deprecated(
    date=None,
    instructions=_DEPRECATION_MSG)
@tf_export(v1=["graph_util.remove_training_nodes"])
def remove_training_nodes(input_graph, protected_nodes=None):
  """Prunes out nodes that aren't needed for inference.

  There are nodes like Identity and CheckNumerics that are only useful
  during training, and can be removed in graphs that will be used for
  nothing but inference. Here we identify and remove them, returning an
  equivalent graph. To be specific, CheckNumerics nodes are removed unless
  protected, and Identity nodes are spliced out so that their input and
  outputs are directly connected. An Identity node is kept when it is
  protected, has control inputs, is itself used as a control input, or is
  named as a colocation target in another node's "_class" attr.

  Args:
    input_graph: Model to analyze and prune; its `node` field is read.
    protected_nodes: An optional iterable of names of nodes to be kept
      unconditionally. This is for example useful to preserve Identity output
      nodes. A single string is treated as one node name.

  Returns:
    A new `graph_pb2.GraphDef` containing copies of the surviving nodes. If
    `input_graph` is a `graph_pb2.GraphDef`, its function library and versions
    are carried over when set.
  """
  if not protected_nodes:
    protected = set()
  elif isinstance(protected_nodes, str):
    # Avoid splitting a lone name into characters.
    protected = {protected_nodes}
  else:
    protected = set(protected_nodes)

  types_to_remove = {"CheckNumerics"}

  input_nodes = input_graph.node
  names_to_remove = {
      node.name for node in input_nodes
      if node.op in types_to_remove and node.name not in protected
  }

  # Pass 1: copy every surviving node, dropping inputs (data or "^" control)
  # whose source node was removed.
  # NOTE(review): an input that names a removed node with an explicit port
  # (e.g. "check:0") is not matched and stays as a dangling reference, and data
  # consumers are not rewired to the CheckNumerics input; confirm intent.
  nodes_after_removal = []
  for node in input_nodes:
    if node.name in names_to_remove:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    del new_node.input[:]
    for full_input_name in node.input:
      input_name = re.sub(r"^\^", "", full_input_name)
      if input_name in names_to_remove:
        continue
      new_node.input.append(full_input_name)
    nodes_after_removal.append(new_node)

  # Record control-edge endpoints and colocation targets; Identity nodes in
  # either group are not spliced.
  control_input_names = set()
  node_names_with_control_input = set()
  node_in_colocated = set()
  for node in nodes_after_removal:
    for node_input in node.input:
      if "^" in node_input:
        control_input_names.add(node_input.replace("^", ""))
        node_names_with_control_input.add(node.name)
    # Prevent colocated nodes from being lost.
    if "_class" in node.attr:
      for colocated_node_name in node.attr["_class"].list.s:
        node_in_colocated.add(_get_colocated_node_name(colocated_node_name))

  types_to_splice = {"Identity"}
  names_to_splice = {}
  for node in nodes_after_removal:
    if node.op not in types_to_splice or node.name in protected:
      continue
    # Nodes with control edge inputs, nodes used as control edge inputs, and
    # colocation targets might be involved in subtle dependency issues that
    # removing them would jeopardize.
    if (node.name in node_in_colocated or
        node.name in node_names_with_control_input or
        node.name in control_input_names):
      continue
    names_to_splice[node.name] = node.input[0]

  # Pass 2: copy the remaining nodes, redirecting inputs that pointed at a
  # spliced Identity to that Identity's own input (following chains).
  nodes_after_splicing = []
  for node in nodes_after_removal:
    if node.name in names_to_splice:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    del new_node.input[:]
    for full_input_name in node.input:
      input_name = re.sub(r"^\^", "", full_input_name)
      while input_name in names_to_splice:
        full_input_name = names_to_splice[input_name]
        input_name = re.sub(r"^\^", "", full_input_name)
      new_node.input.append(full_input_name)
    nodes_after_splicing.append(new_node)

  output_graph = graph_pb2.GraphDef()
  output_graph.node.extend(nodes_after_splicing)
  # Carry over the function library and versions so nodes calling functions
  # from input_graph.library keep their definitions. The HasField guards leave
  # the output unchanged for graphs that never set these fields.
  if isinstance(input_graph, graph_pb2.GraphDef):
    if input_graph.HasField("library"):
      output_graph.library.CopyFrom(input_graph.library)
    if input_graph.HasField("versions"):
      output_graph.versions.CopyFrom(input_graph.versions)
  return output_graph

377 

378 

@tf_export("__internal__.graph_util.graph_defs_equal", v1=[])
def graph_defs_equal(graph_def_1: graph_pb2.GraphDef,
                     graph_def_2: graph_pb2.GraphDef,
                     treat_nan_as_equal: bool = False) -> bool:
  """Returns True iff the graph def arguments are structurally equivalent.

  Equivalence here means the NodeDefs of the main graph body and of the
  GraphDef's function library match, with the functions in the function
  library compared as sets.

  Example usage:

  ```
  with tf.Graph().as_default() as g1:
    tf.constant(1)

  with tf.Graph().as_default() as g2:
    tf.constant(2)

  with tf.Graph().as_default() as g3:
    tf.constant(1)

  assert tf.__internal__.graph_util.graph_defs_equal(g1.as_graph_def(),
                                                     g3.as_graph_def())

  assert not tf.__internal__.graph_util.graph_defs_equal(g1.as_graph_def(),
                                                         g2.as_graph_def())
  ```

  Args:
    graph_def_1: Instance of `graph_pb2.GraphDef` to compare.
    graph_def_2: Instance of `graph_pb2.GraphDef` to compare.
    treat_nan_as_equal: Whether floating-point NaN values compare equal to each
      other. Without it, a GraphDef holding NaN values may compare unequal to
      itself, so any equivalence relation over GraphDefs needs this enabled.

  Returns:
    Boolean indicating structural equivalence as described above.

  Raises:
    TypeError: If either of the GraphDefs are not instances of
      `graph_pb2.GraphDef`.
  """
  for arg_name, candidate in (("graph_def_1", graph_def_1),
                              ("graph_def_2", graph_def_2)):
    if not isinstance(candidate, graph_pb2.GraphDef):
      raise TypeError(f"{arg_name} must be a graph_pb2.GraphDef proto, but got "
                      f"type {type(candidate)}.")
  comparison_options = _proto_comparators.ProtoComparisonOptions(
      treat_nan_as_equal)
  # The comparator is handed the serialized bytes of both protos.
  serialized_1 = graph_def_1.SerializeToString()
  serialized_2 = graph_def_2.SerializeToString()
  return _proto_comparators.EqualsGraphDef(serialized_1, serialized_2,
                                           comparison_options)