Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/lib/debug_graphs.py: 22%

234 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Classes and methods for processing debugger-decorated graphs.""" 

16from tensorflow.core.framework import graph_pb2 

17from tensorflow.python.framework import op_def_registry 

18from tensorflow.python.platform import tf_logging as logging 

19 

20 

def parse_node_or_tensor_name(name):
  """Split a node or tensor name into its node name and output slot.

  A tensor name has the form "<node_name>:<slot>" (e.g. "node_a:0"); a plain
  node name (e.g. "node_a") has no slot. Only the last colon acts as the
  separator, and a name that ends with a colon is treated as a node name.

  Args:
    name: An input node name (e.g., "node_a") or tensor name (e.g.,
      "node_a:0"), as a str.

  Returns:
    A `(node_name, output_slot)` tuple:
      node_name: `name` with its final colon and the following output slot
        stripped if `name` is a tensor name; otherwise `name` unchanged.
      output_slot: The output slot, as an int, if `name` is a tensor name;
        otherwise None.

  Raises:
    ValueError: If the text after the final colon is not a valid int literal.
  """
  looks_like_tensor = ":" in name and not name.endswith(":")
  if not looks_like_tensor:
    return name, None

  # rpartition splits on the LAST colon, so colons inside the node name survive.
  node_part, _, slot_part = name.rpartition(":")
  return node_part, int(slot_part)

43 

44 

def get_node_name(element_name):
  """Return the node name of a graph element.

  Args:
    element_name: A node name (e.g., "node_a") or tensor name (e.g.,
      "node_a:0"), as a str.

  Returns:
    The node name, as a str: `element_name` with any trailing ":<slot>"
    removed.
  """
  return parse_node_or_tensor_name(element_name)[0]

48 

49 

def get_output_slot(element_name):
  """Return the output slot number encoded in a graph element name.

  A name without a ":<slot>" suffix (i.e. a plain node name) is treated as
  referring to output slot 0.

  Args:
    element_name: (`str`) name of the graph element in question.

  Returns:
    (`int`) output slot number.
  """
  slot = parse_node_or_tensor_name(element_name)[1]
  if slot is None:
    return 0
  return slot

64 

65 

def is_copy_node(node_name):
  """Tell whether `node_name` is the name of a debug Copy node.

  TensorFlow core inserts such nodes when requested through
  RunOptions.debug_options.debug_tensor_watch_opts; their names start with
  "__copy_".

  Args:
    node_name: Name of the node.

  Returns:
    True if `node_name` names a debug Copy node, False otherwise.
  """
  copy_node_prefix = "__copy_"
  return node_name.startswith(copy_node_prefix)

80 

81 

def is_debug_node(node_name):
  """Tell whether `node_name` is the name of a debug node.

  TensorFlow core inserts such nodes when requested through
  RunOptions.debug_options.debug_tensor_watch_opts; their names start with
  "__dbg_".

  Args:
    node_name: Name of the node.

  Returns:
    True if `node_name` names a debug node, False otherwise.
  """
  debug_node_prefix = "__dbg_"
  return node_name.startswith(debug_node_prefix)

95 

96 

def parse_debug_node_name(node_name):
  """Parse the name of a debug node.

  A debug node name is expected to look like
  "__dbg_<watched_node>:<output_slot>_<debug_op_index>_<debug_op>",
  e.g. "__dbg_node_a:0_0_DebugIdentity".

  Args:
    node_name: Name of the debug node.

  Returns:
    A 4-tuple of:
      1. Name of the watched node, as a str.
      2. Output slot index of the watched tensor, as an int.
      3. Index of the debug node, as an int.
      4. Name of the debug op, as a str, e.g, "DebugIdentity".

  Raises:
    ValueError: If the input node name is not a valid debug node name.
  """
  prefix = "__dbg_"
  if not node_name.startswith(prefix):
    raise ValueError("Invalid prefix in debug node name: '%s'" % node_name)

  body = node_name[len(prefix):]
  if body.count("_") < 2:
    raise ValueError("Invalid debug node name: '%s'" % node_name)

  # Only the last two "_" separators delimit fields; the watched tensor name
  # itself may contain underscores.
  tensor_name, index_text, debug_op = body.rsplit("_", 2)
  debug_op_index = int(index_text)

  if tensor_name.count(":") != 1:
    raise ValueError("Invalid tensor name in debug node name: '%s'" % node_name)

  watched_node_name, _, slot_text = tensor_name.partition(":")
  watched_output_slot = int(slot_text)

  return watched_node_name, watched_output_slot, debug_op_index, debug_op

136 

137 

class GraphTracingReachedDestination(Exception):
  """Raised by `DFSGraphTracer.trace()` when the destination node is reached."""

140 

141 

class DFSGraphTracer(object):
  """Graph input tracer using depth-first search."""

  def __init__(self,
               input_lists,
               skip_node_names=None,
               destination_node_name=None):
    """Constructor of DFSGraphTracer.

    Args:
      input_lists: A list of dicts. Each dict is an adjacency (input) map from
        the recipient node name as the key and the list of input node names
        as the value.
      skip_node_names: Optional: a collection of node names to skip tracing.
        A skipped node is still recorded as an input of its recipient, but its
        own inputs are not traced. `None` means no node is skipped.
      destination_node_name: Optional: destination node name. If not `None`,
        it should be the name of a destination node as a str, and `trace()`
        will raise GraphTracingReachedDestination as soon as that node is
        reached.
    """
    self._input_lists = input_lists
    # `None` means "skip nothing"; substitute an empty collection so the
    # membership test in trace() does not raise TypeError.
    self._skip_node_names = (
        skip_node_names if skip_node_names is not None else ())

    self._inputs = []
    # Only membership is ever queried, so a set gives O(1) lookups.
    self._visited_nodes = set()
    self._depth_count = 0
    self._depth_list = []

    self._destination_node_name = destination_node_name

  def trace(self, graph_element_name):
    """Trace inputs.

    Recursively walks `input_lists` starting from `graph_element_name`. Each
    newly discovered input is appended to `inputs()`, and its depth (1 for
    direct inputs of the starting element) is appended to `depth_list()`.

    Args:
      graph_element_name: Name of the node or an output tensor of the node, as
        a str.

    Raises:
      GraphTracingReachedDestination: if destination_node_name of this tracer
        object is not None and the specified node is reached.
    """
    self._depth_count += 1
    try:
      node_name = get_node_name(graph_element_name)
      if node_name == self._destination_node_name:
        raise GraphTracingReachedDestination()

      if node_name in self._skip_node_names:
        return
      if node_name in self._visited_nodes:
        return

      self._visited_nodes.add(node_name)

      for input_list in self._input_lists:
        if node_name not in input_list:
          continue
        for inp in input_list[node_name]:
          if get_node_name(inp) in self._visited_nodes:
            continue
          self._inputs.append(inp)
          self._depth_list.append(self._depth_count)
          self.trace(inp)
    finally:
      # Restore the depth on every exit path, including the early returns
      # above; otherwise a skipped node would inflate the depth recorded for
      # the siblings traced after it.
      self._depth_count -= 1

  def inputs(self):
    """Return the traced input names, in the order they were discovered."""
    return self._inputs

  def depth_list(self):
    """Return the depth of each entry of `inputs()`, index for index."""
    return self._depth_list

217 

218 

219def _infer_device_name(graph_def): 

220 """Infer device name from a partition GraphDef.""" 

221 device_name = None 

222 for node in graph_def.node: 

223 if node.device: 

224 device_name = node.device 

225 break 

226 if device_name is None: 

227 logging.warn( 

228 "Failed to infer device name from partition GraphDef: none of the " 

229 "nodes of the GraphDef has a non-empty device name.") 

230 return device_name 

231 

232 

class DebugGraph(object):
  """Represents a debugger-decorated graph.

  Builds per-node maps (inputs, control inputs, recipients, devices, op types,
  attributes) from a partition GraphDef that may contain nodes inserted by the
  debugger: Copy nodes (names starting with "__copy_") and Debug nodes (names
  starting with "__dbg_"). Debug nodes are never added to the maps. Copy nodes
  are pruned from them, and data edges that went through a Copy node are
  rewired to the Copy node's first input.
  """

  def __init__(self, debug_graph_def, device_name=None):
    """Constructor of DebugGraph.

    Args:
      debug_graph_def: The debugger-decorated partition `GraphDef`.
      device_name: Optional: name of the device of the partition. If empty or
        `None`, it is inferred from the nodes of `debug_graph_def`.

    Raises:
      ValueError: If duplicate node names are encountered.
    """
    self._debug_graph_def = debug_graph_def
    # Built lazily by _reconstruct_non_debug_graph_def().
    self._non_debug_graph_def = None

    self._node_attributes = {}
    self._node_inputs = {}
    self._node_reversed_ref_inputs = {}
    self._node_ctrl_inputs = {}
    self._node_recipients = {}
    self._node_ctrl_recipients = {}
    self._node_devices = {}
    self._node_op_types = {}
    # Names of _Send/_Retval nodes that take an input from a Copy node.
    self._copy_send_nodes = []
    # Node name -> list of that node's ref-type output names.
    self._ref_args = {}

    self._device_name = device_name
    if not self._device_name:
      self._device_name = _infer_device_name(debug_graph_def)

    for node in debug_graph_def.node:
      self._process_debug_graph_node(node)

    # Order matters: inputs that point at Copy nodes must be rewired (which
    # reads the Copy nodes' own input lists) before the Copy nodes are deleted
    # from the maps, and the recipient maps are derived from the pruned inputs.
    self._prune_non_control_edges_of_debug_ops()
    self._prune_control_edges_of_debug_ops()
    self._prune_nodes_from_input_and_recipient_maps(self._get_copy_nodes())

    self._populate_recipient_maps()

  def _process_debug_graph_node(self, node):
    """Process a node from the debug GraphDef.

    Records the node's attributes, device, op type, ref-type outputs, data
    inputs and control inputs. Debug nodes are ignored.

    Args:
      node: (NodeDef) A partition-graph node to be processed.

    Raises:
      ValueError: If duplicate node names are encountered.
    """
    if is_debug_node(node.name):
      # Debug nodes are inserted by the debugger and are not part of the
      # original graph, so they are not added to any of the maps.
      return

    if node.name in self._node_inputs:
      raise ValueError("Duplicate node name on device %s: '%s'" %
                       (self._device_name, node.name))

    self._node_attributes[node.name] = node.attr

    self._node_inputs[node.name] = []
    self._node_ctrl_inputs[node.name] = []
    self._node_recipients[node.name] = []
    self._node_ctrl_recipients[node.name] = []

    if node.name not in self._node_devices:
      self._node_devices[node.name] = set()
    # Nodes without an explicit device fall back to the graph's device name.
    self._node_devices[node.name].add(
        node.device if node.device else self._device_name)
    self._node_op_types[node.name] = node.op
    self._ref_args[node.name] = self._get_ref_args(node)

    for inp in node.input:
      if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"):
        self._copy_send_nodes.append(node.name)

      if inp.startswith("^"):
        # Control inputs are stored without their "^" prefix.
        cinp = inp[1:]
        self._node_ctrl_inputs[node.name].append(cinp)
      else:
        self._node_inputs[node.name].append(inp)

  def _get_ref_args(self, node):
    """Determine which outputs of a node are ref-type.

    Args:
      node: A `NodeDef`.

    Returns:
      A list of the names (as strs) of the node's ref-type outputs: the node
      name itself for output 0 and "<node_name>:<index>" for any other output.
      Empty if `node.op` is not found in the op-def registry.
    """
    op_def = op_def_registry.get(node.op)
    if op_def is None:
      return []

    ref_args = []
    for i, output_arg in enumerate(op_def.output_arg):
      if output_arg.is_ref:
        arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i))
        ref_args.append(arg_name)
    return ref_args

  def _get_copy_nodes(self):
    """Find all Copy nodes in the loaded graph."""
    return [node for node in self._node_inputs if is_copy_node(node)]

  def _prune_non_control_edges_of_debug_ops(self):
    """Prune (non-control) edges related to debug ops.

    Every non-control input that names a Copy node is replaced, in place, with
    the first non-control input of that Copy node. The Copy nodes themselves
    are removed from the maps separately.
    """
    for inputs in self._node_inputs.values():
      for i, inp in enumerate(inputs):
        if is_copy_node(inp):
          # Assumes the Copy node's first input is the original input of the
          # recipient (i.e. the watched tensor).
          inputs[i] = self._node_inputs[inp][0]

  def _prune_control_edges_of_debug_ops(self):
    """Prune control edges related to the debug ops.

    Removes, in place, every control input that names a Debug node.
    """
    for ctrl_inputs in self._node_ctrl_inputs.values():
      # Slice assignment keeps the same list object (it is shared with the
      # map) and filters in one linear pass instead of repeated remove() calls.
      ctrl_inputs[:] = [
          ctrl_inp for ctrl_inp in ctrl_inputs if not is_debug_node(ctrl_inp)
      ]

  def _populate_recipient_maps(self):
    """Populate the map from node name to recipient(s) of its output(s).

    This method also populates the input map based on reversed ref edges.
    """
    for node, inputs in self._node_inputs.items():
      for inp in inputs:
        input_node = get_node_name(inp)
        self._node_recipients.setdefault(input_node, []).append(node)

        # NOTE(review): `self._ref_args` is keyed by node name and has an entry
        # for every node processed from the GraphDef, so this test holds for
        # any such input node, not only for ref-typed edges; the ref-arg lists
        # stored as its values are never consulted. Confirm the intended
        # condition before relying on `node_reversed_ref_inputs`.
        if input_node in self._ref_args:
          self._node_reversed_ref_inputs.setdefault(input_node, []).append(node)

    for node, ctrl_inputs in self._node_ctrl_inputs.items():
      for ctrl_inp in ctrl_inputs:
        # Control edges out of _Send/_Retval nodes fed by a Copy node are not
        # recorded as control recipients.
        if ctrl_inp in self._copy_send_nodes:
          continue
        self._node_ctrl_recipients.setdefault(ctrl_inp, []).append(node)

  def _prune_nodes_from_input_and_recipient_maps(self, nodes_to_prune):
    """Prune nodes out of input and recipient maps.

    Args:
      nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned. Each
        must be present in all four maps, otherwise KeyError is raised.
    """
    for node in nodes_to_prune:
      del self._node_inputs[node]
      del self._node_ctrl_inputs[node]
      del self._node_recipients[node]
      del self._node_ctrl_recipients[node]

  def _reconstruct_non_debug_graph_def(self):
    """Reconstruct non-debug GraphDef.

    Non-debug GraphDef means the original GraphDef without the Copy* and Debug
    nodes inserted by the debugger. The result is cached, so the
    reconstruction runs at most once.
    """
    if self._non_debug_graph_def is not None:
      return

    self._non_debug_graph_def = graph_pb2.GraphDef()
    for node in self._debug_graph_def.node:
      if is_copy_node(node.name) or is_debug_node(node.name):
        continue

      new_node = self._non_debug_graph_def.node.add()
      new_node.CopyFrom(node)

      # Redo the list of inputs, because in _debug_graph_def, the list can
      # consist of Copy* and Debug* nodes inserted by the debugger. Those are
      # replaced here with the pruned inputs: data inputs first, then control
      # inputs re-prefixed with "^".
      del new_node.input[:]
      for inp in self._node_inputs[node.name]:
        new_node.input.append(inp)
      for ctrl_inp in self._node_ctrl_inputs[node.name]:
        new_node.input.append("^" + ctrl_inp)

  @property
  def device_name(self):
    """Name of the device of this partition (given or inferred; may be None)."""
    return self._device_name

  @property
  def debug_graph_def(self):
    """The debugger-decorated GraphDef."""
    return self._debug_graph_def

  @property
  def non_debug_graph_def(self):
    """The GraphDef without the Copy* and Debug* nodes added by the debugger."""
    self._reconstruct_non_debug_graph_def()
    return self._non_debug_graph_def

  @property
  def node_devices(self):
    """Dict mapping node name to the set of device names seen for it."""
    return self._node_devices

  @property
  def node_op_types(self):
    """Dict mapping node name to its op type, as a str."""
    return self._node_op_types

  @property
  def node_attributes(self):
    """Dict mapping node name to its `attr` map from the NodeDef."""
    return self._node_attributes

  @property
  def node_inputs(self):
    """Dict mapping node name to its (Copy-pruned) non-control input names."""
    return self._node_inputs

  @property
  def node_ctrl_inputs(self):
    """Dict mapping node name to its control input names, without "^"."""
    return self._node_ctrl_inputs

  @property
  def node_reversed_ref_inputs(self):
    """Dict mapping input node name to recipients, see _populate_recipient_maps."""
    return self._node_reversed_ref_inputs

  @property
  def node_recipients(self):
    """Dict mapping node name to the names of nodes consuming its outputs."""
    return self._node_recipients

  @property
  def node_ctrl_recipients(self):
    """Dict mapping node name to the names of its control recipients."""
    return self._node_ctrl_recipients

475 

476 

def reconstruct_non_debug_graph_def(debug_graph_def):
  """Strip debugger-inserted nodes from a partition GraphDef.

  Removes the Copy*- and Debug*-type nodes that the debugger inserted into the
  input `tf.compat.v1.GraphDef` and rewires the remaining nodes' inputs.

  The result matches the original (i.e., non-debugger-decorated) partition
  graph except that:
    1) The exact names of the runtime-inserted internal nodes may differ.
       These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops.
    2) As a consequence of 1, the nodes that receive input directly from such
       send- and recv-type ops will have different input names.
    3) The parallel_iteration attribute of while-loop Enter ops are set to 1.

  Args:
    debug_graph_def: The debugger-decorated `tf.compat.v1.GraphDef`, with the
      debugger-inserted Copy* and Debug* nodes.

  Returns:
    The reconstructed `tf.compat.v1.GraphDef` stripped of the debugger-inserted
    nodes.
  """
  decorated_graph = DebugGraph(debug_graph_def)
  return decorated_graph.non_debug_graph_def

498 """ 

499 return DebugGraph(debug_graph_def).non_debug_graph_def