Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/lib/debug_data.py: 23%
526 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and functions to handle debug-dump data of TensorFlow Debugger."""
17import collections
18import glob
19import json
20import os
21import platform
22import re
24import numpy as np
26from tensorflow.core.framework import graph_pb2
27from tensorflow.core.framework import types_pb2
28from tensorflow.core.util import event_pb2
29from tensorflow.python.debug.lib import debug_graphs
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.platform import gfile
32from tensorflow.python.platform import tf_logging as logging
33from tensorflow.python.util import compat
36# TODO(cais): Tie these string constants in with C++?
37METADATA_FILE_PREFIX = "_tfdbg_"
38CORE_METADATA_TAG = "core_metadata_"
39GRAPH_FILE_TAG = "graph_"
40DEVICE_TAG = "device_"
41HASH_TAG = "hash"
43FETCHES_INFO_FILE_TAG = "fetches_info_"
44FEED_KEYS_INFO_FILE_TAG = "feed_keys_info_"
47def _glob(glob_pattern):
48 if platform.system() == "Windows":
49 return glob.glob(glob_pattern)
50 else:
51 return gfile.Glob(glob_pattern)
54class InconvertibleTensorProto:
55 """Represents a TensorProto that cannot be converted to np.ndarray."""
57 def __init__(self, tensor_proto, initialized=True):
58 """Constructor.
60 Args:
61 tensor_proto: the `TensorProto` object that cannot be represented as a
62 `np.ndarray` object.
63 initialized: (`bool`) whether the Tensor is initialized.
64 """
65 self._tensor_proto = tensor_proto
66 self._initialized = initialized
68 def __str__(self):
69 output = "" if self._initialized else "Uninitialized tensor:\n"
70 output += str(self._tensor_proto)
71 return output
73 @property
74 def initialized(self):
75 return self._initialized
78def load_tensor_from_event_file(event_file_path):
79 """Load a tensor from an event file.
81 Assumes that the event file contains a `Event` protobuf and the `Event`
82 protobuf contains a `Tensor` value.
84 Args:
85 event_file_path: (`str`) path to the event file.
87 Returns:
88 The tensor value loaded from the event file, as a `numpy.ndarray`. For
89 uninitialized Tensors, returns `None`. For Tensors of data types that
90 cannot be converted to `numpy.ndarray` (e.g., `tf.resource`), return
91 `None`.
92 """
94 event = event_pb2.Event()
95 with gfile.Open(event_file_path, "rb") as f:
96 event.ParseFromString(f.read())
97 return load_tensor_from_event(event)
100def load_tensor_from_event(event):
101 """Load a tensor from an Event proto.
103 Args:
104 event: The Event proto, assumed to hold a tensor value in its
105 summary.value[0] field.
107 Returns:
108 The tensor value loaded from the event file, as a `numpy.ndarray`, if
109 representation of the tensor value by a `numpy.ndarray` is possible.
110 For uninitialized Tensors, returns `None`. For Tensors of data types that
111 cannot be represented as `numpy.ndarray` (e.g., `tf.resource`), return
112 the `TensorProto` protobuf object without converting it to a
113 `numpy.ndarray`.
114 """
116 tensor_proto = event.summary.value[0].tensor
117 shape = tensor_util.TensorShapeProtoToList(tensor_proto.tensor_shape)
118 num_elements = 1
119 for shape_dim in shape:
120 num_elements *= shape_dim
122 if tensor_proto.tensor_content or tensor_proto.string_val or not num_elements:
123 # Initialized tensor or empty tensor.
124 if tensor_proto.dtype == types_pb2.DT_RESOURCE:
125 tensor_value = InconvertibleTensorProto(tensor_proto)
126 else:
127 try:
128 tensor_value = tensor_util.MakeNdarray(tensor_proto)
129 except KeyError:
130 tensor_value = InconvertibleTensorProto(tensor_proto)
131 else:
132 # Uninitialized tensor or tensor of unconvertible data type.
133 tensor_value = InconvertibleTensorProto(tensor_proto, False)
135 return tensor_value
138def _load_graph_def_from_event_file(event_file_path):
139 event = event_pb2.Event()
140 with gfile.Open(event_file_path, "rb") as f:
141 event.ParseFromString(f.read())
143 return graph_pb2.GraphDef.FromString(event.graph_def)
146def _load_log_message_from_event_file(event_file_path):
147 event = event_pb2.Event()
148 with gfile.Open(event_file_path, "rb") as f:
149 event.ParseFromString(f.read())
151 return event.log_message.message
154def _is_graph_file(file_name):
155 return file_name.startswith(METADATA_FILE_PREFIX + GRAPH_FILE_TAG)
158def _is_run_fetches_info_file(file_name):
159 return file_name == METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG
162def _is_run_feed_keys_info_file(file_name):
163 return file_name == METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG
166def _get_tensor_name(node_name, output_slot):
167 """Get tensor name given node name and output slot index.
169 Args:
170 node_name: Name of the node that outputs the tensor, as a string.
171 output_slot: Output slot index of the tensor, as an integer.
173 Returns:
174 Name of the tensor, as a string.
175 """
177 return "%s:%d" % (node_name, output_slot)
180def _get_tensor_watch_key(node_name, output_slot, debug_op):
181 """Get the string representation of a debug watch on a tensor.
183 Args:
184 node_name: Name of the node by which the watched tensor is produced, as a
185 string.
186 output_slot: Output slot index of the tensor, as an integer.
187 debug_op: Name of the debug op that is used to watch the tensor, as a
188 string.
190 Returns:
191 A string representing the debug watch on the tensor (i.e., the "watch
192 key").
193 """
194 return "%s:%s" % (_get_tensor_name(node_name, output_slot), debug_op)
197def has_inf_or_nan(datum, tensor):
198 """A predicate for whether a tensor consists of any bad numerical values.
200 This predicate is common enough to merit definition in this module.
201 Bad numerical values include `nan`s and `inf`s.
202 The signature of this function follows the requirement of the method
203 `DebugDumpDir.find()`.
205 Args:
206 datum: (`DebugTensorDatum`) Datum metadata.
207 tensor: (`numpy.ndarray` or None) Value of the tensor. None represents
208 an uninitialized tensor.
210 Returns:
211 (`bool`) True if and only if tensor consists of any nan or inf values.
212 """
214 _ = datum # Datum metadata is unused in this predicate.
216 if isinstance(tensor, InconvertibleTensorProto):
217 # Uninitialized tensor doesn't have bad numerical values.
218 # Also return False for data types that cannot be represented as numpy
219 # arrays.
220 return False
221 elif (np.issubdtype(tensor.dtype, np.floating) or
222 np.issubdtype(tensor.dtype, np.complexfloating) or
223 np.issubdtype(tensor.dtype, np.integer)):
224 return np.any(np.isnan(tensor)) or np.any(np.isinf(tensor))
225 else:
226 return False
229_CoreMetadata = collections.namedtuple("CoreMetadata", [
230 "global_step", "session_run_index", "executor_step_index", "input_names",
231 "output_names", "target_nodes"
232])
235def extract_core_metadata_from_event_proto(event):
236 json_metadata = json.loads(event.log_message.message)
237 return _CoreMetadata(json_metadata["global_step"],
238 json_metadata["session_run_index"],
239 json_metadata["executor_step_index"],
240 json_metadata["input_names"],
241 json_metadata["output_names"],
242 json_metadata["target_nodes"])
245def device_name_to_device_path(device_name):
246 """Convert device name to device path."""
247 device_name_items = compat.as_text(device_name).split("/")
248 device_name_items = [item.replace(":", "_") for item in device_name_items]
249 return METADATA_FILE_PREFIX + DEVICE_TAG + ",".join(device_name_items)
252def device_path_to_device_name(device_dir):
253 """Parse device name from device path.
255 Args:
256 device_dir: (str) a directory name for the device.
258 Returns:
259 (str) parsed device name.
260 """
261 path_items = os.path.basename(device_dir)[
262 len(METADATA_FILE_PREFIX) + len(DEVICE_TAG):].split(",")
263 return "/".join([
264 path_item.replace("device_", "device:").replace("_", ":", 1)
265 for path_item in path_items])
268class DebugTensorDatum:
269 """A single tensor dumped by TensorFlow Debugger (tfdbg).
271 Contains metadata about the dumped tensor, including `timestamp`,
272 `node_name`, `output_slot`, `debug_op`, and path to the dump file
273 (`file_path`).
275 This type does not hold the generally space-expensive tensor value (numpy
276 array). Instead, it points to the file from which the tensor value can be
277 loaded (with the `get_tensor` method) if needed.
278 """
280 def __init__(self, dump_root, debug_dump_rel_path):
281 """`DebugTensorDatum` constructor.
283 Args:
284 dump_root: (`str`) Debug dump root directory. This path should not include
285 the path component that represents the device name (see also below).
286 debug_dump_rel_path: (`str`) Path to a debug dump file, relative to the
287 `dump_root`. The first item of this relative path is assumed to be
288 a path representing the name of the device that the Tensor belongs to.
289 See `device_path_to_device_name` for more details on the device path.
290 For example, suppose the debug dump root
291 directory is `/tmp/tfdbg_1` and the dump file is at
292 `/tmp/tfdbg_1/<device_path>/>ns_1/node_a_0_DebugIdentity_123456789`,
293 then the value of the debug_dump_rel_path should be
294 `<device_path>/ns_1/node_a_0_DebugIdentity_1234456789`.
296 Raises:
297 ValueError: If the base file name of the dump file does not conform to
298 the dump file naming pattern:
299 `node_name`_`output_slot`_`debug_op`_`timestamp`
300 """
302 path_components = os.path.normpath(debug_dump_rel_path).split(os.sep)
303 self._device_name = device_path_to_device_name(path_components[0])
304 base = path_components[-1]
305 if base.count("_") < 3:
306 raise ValueError(
307 "Dump file path does not conform to the naming pattern: %s" % base)
309 self._extended_timestamp = base.split("_")[-1]
310 # It may include an index suffix at the end if file path collision happened
311 # due to identical timestamps.
312 if "-" in self._extended_timestamp:
313 self._timestamp = int(
314 self._extended_timestamp[:self._extended_timestamp.find("-")])
315 else:
316 self._timestamp = int(self._extended_timestamp)
318 self._debug_op = base.split("_")[-2]
319 self._output_slot = int(base.split("_")[-3])
321 node_base_name = "_".join(base.split("_")[:-3])
322 self._node_name = "/".join(path_components[1:-1] + [node_base_name])
324 self._file_path = os.path.join(dump_root, debug_dump_rel_path)
325 self._dump_size_bytes = (gfile.Stat(self._file_path).length if
326 gfile.Exists(self._file_path) else None)
328 def __str__(self):
329 return "{DebugTensorDatum (%s) %s:%d @ %s @ %d}" % (self.device_name,
330 self.node_name,
331 self.output_slot,
332 self.debug_op,
333 self.timestamp)
335 def __repr__(self):
336 return self.__str__()
338 def get_tensor(self):
339 """Get tensor from the dump (`Event`) file.
341 Returns:
342 The tensor loaded from the dump (`Event`) file.
343 """
345 return load_tensor_from_event_file(self.file_path)
347 # TODO(cais): Add time unit suffix to timestamp and t0 (us).
348 @property
349 def timestamp(self):
350 """Timestamp of when this tensor value was dumped.
352 Returns:
353 (`int`) The timestamp in microseconds.
354 """
356 return self._timestamp
358 @property
359 def extended_timestamp(self):
360 """Extended timestamp, possibly with an index suffix.
362 The index suffix, e.g., "-1", is for disambiguating multiple dumps of the
363 same tensor with the same timestamp, which can occur if the dumping events
364 are spaced by shorter than the temporal resolution of the timestamps.
366 Returns:
367 (`str`) The extended timestamp.
368 """
370 return self._extended_timestamp
372 @property
373 def debug_op(self):
374 """Name of the debug op.
376 Returns:
377 (`str`) debug op name (e.g., `DebugIdentity`).
378 """
380 return self._debug_op
382 @property
383 def device_name(self):
384 """Name of the device that the tensor belongs to.
386 Returns:
387 (`str`) device name.
388 """
390 return self._device_name
392 @property
393 def node_name(self):
394 """Name of the node from which the tensor value was dumped.
396 Returns:
397 (`str`) name of the node watched by the debug op.
398 """
400 return self._node_name
402 @property
403 def output_slot(self):
404 """Output slot index from which the tensor value was dumped.
406 Returns:
407 (`int`) output slot index watched by the debug op.
408 """
410 return self._output_slot
412 @property
413 def tensor_name(self):
414 """Name of the tensor watched by the debug op.
416 Returns:
417 (`str`) `Tensor` name, in the form of `node_name`:`output_slot`
418 """
420 return _get_tensor_name(self.node_name, self.output_slot)
422 @property
423 def watch_key(self):
424 """Watch key identities a debug watch on a tensor.
426 Returns:
427 (`str`) A watch key, in the form of `tensor_name`:`debug_op`.
428 """
430 return _get_tensor_watch_key(self.node_name, self.output_slot,
431 self.debug_op)
433 @property
434 def file_path(self):
435 """Path to the file which stores the value of the dumped tensor."""
437 return self._file_path
439 @property
440 def dump_size_bytes(self):
441 """Size of the dump file.
443 Unit: byte.
445 Returns:
446 If the dump file exists, size of the dump file, in bytes.
447 If the dump file does not exist, None.
448 """
450 return self._dump_size_bytes
453class WatchKeyDoesNotExistInDebugDumpDirError(ValueError):
454 pass
457class DebugDumpDir:
458 """Data set from a debug-dump directory on filesystem.
460 An instance of `DebugDumpDir` contains all `DebugTensorDatum` instances
461 in a tfdbg dump root directory.
462 """
464 def __init__(self, dump_root, partition_graphs=None, validate=True):
465 """`DebugDumpDir` constructor.
467 Args:
468 dump_root: (`str`) path to the dump root directory.
469 partition_graphs: A repeated field of GraphDefs representing the
470 partition graphs executed by the TensorFlow runtime.
471 validate: (`bool`) whether the dump files are to be validated against the
472 partition graphs.
474 Raises:
475 IOError: If dump_root does not exist as a directory.
476 ValueError: If more than one core metadata file is found under the dump
477 root directory.
478 """
480 if not gfile.IsDirectory(dump_root):
481 raise IOError("Dump root directory %s does not exist" % dump_root)
483 self._core_metadata = []
485 # Find the list of devices.
486 self._dump_root = dump_root
488 self._load_core_metadata()
489 self._load_fetches_info()
490 self._load_feeds_info()
491 self._load_all_device_dumps(partition_graphs, validate)
493 self._python_graph = None
495 def _load_all_device_dumps(self, partition_graphs, validate):
496 """Load the dump data for all devices."""
497 device_dirs = _glob(os.path.join(
498 self._dump_root, METADATA_FILE_PREFIX + DEVICE_TAG + "*"))
500 self._device_names = []
501 self._t0s = {}
502 self._dump_tensor_data = {}
503 self._dump_graph_file_paths = {}
504 self._debug_watches = {}
505 self._watch_key_to_devices = {}
506 self._watch_key_to_datum = {}
507 self._watch_key_to_rel_time = {}
508 self._watch_key_to_dump_size_bytes = {}
509 for device_dir in device_dirs:
510 device_name = device_path_to_device_name(device_dir)
511 self._device_names.append(device_name)
512 self._load_device_dumps(device_name, device_dir)
513 self._load_partition_graphs(partition_graphs, validate)
514 self._calculate_t0()
516 for device_name in self._device_names:
517 self._create_tensor_watch_maps(device_name)
519 def _load_device_dumps(self, device_name, device_root):
520 """Load `DebugTensorDatum` instances from the dump root of a given device.
522 Populates a map {device_name: a list of `DebugTensorDatum`}, where the list
523 is sorted by ascending timestamp.
525 This sorting order reflects the order in which the TensorFlow executor
526 processed the nodes of the graph. It is (one of many possible) topological
527 sort of the nodes. This is useful for displaying tensors in the debugger
528 frontend as well as for the use case in which the user wants to find a
529 "culprit tensor", i.e., the first tensor in the graph that exhibits certain
530 problematic properties, i.e., all zero values, or bad numerical values such
531 as nan and inf.
533 In addition, creates a map from node name to debug watches. In this Map,
534 the key is the watched node name; the value is a dictionary.
535 Of this dictionary, the key is the watched_output_slot.
537 This method attempts to load the debug watches from the tensor dump files
538 first, before loading the full set of debug watches from the partition
539 graphs as done later. This is necessary because sometimes the partition
540 graphs may not be available, e.g., when the run errors out.
542 Args:
543 device_name: (`str`) name of the device.
544 device_root: (`str`) dump root directory of the given device.
546 Raises:
547 ValueError: If GraphDef for the device is not available.
548 """
550 self._dump_tensor_data[device_name] = []
551 self._debug_watches[device_name] = collections.defaultdict(
552 lambda: collections.defaultdict(set))
554 for root, _, files in gfile.Walk(device_root):
555 for f in files:
556 if _is_graph_file(f):
557 self._dump_graph_file_paths[device_name] = os.path.join(root, f)
558 else:
559 datum = self._dump_file_name_to_datum(root, f)
560 self._dump_tensor_data[device_name].append(datum)
561 self._debug_watches[device_name][datum.node_name][
562 datum.output_slot].add(datum.debug_op)
564 self._dump_tensor_data[device_name] = sorted(
565 self._dump_tensor_data[device_name],
566 key=lambda x: x.extended_timestamp)
568 if self._dump_tensor_data[device_name]:
569 self._t0s[device_name] = self._dump_tensor_data[device_name][0].timestamp
570 else:
571 self._t0s[device_name] = None
573 def _calculate_t0(self):
574 """Calculate the first timestamp across all devices."""
575 t0s = [t0 for t0 in self._t0s.values() if t0 is not None]
576 self._t0 = min(t0s) if t0s else None
578 def _load_core_metadata(self):
579 core_metadata_files = _glob(os.path.join(
580 self._dump_root, METADATA_FILE_PREFIX + CORE_METADATA_TAG + "*"))
581 for core_metadata_file in core_metadata_files:
582 with gfile.Open(core_metadata_file, "rb") as f:
583 event = event_pb2.Event()
584 event.ParseFromString(f.read())
585 self._core_metadata.append(
586 extract_core_metadata_from_event_proto(event))
588 def _load_fetches_info(self):
589 fetches_info_files = _glob(os.path.join(
590 self._dump_root, METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG + "*"))
591 self._run_fetches_info = []
592 for fetches_info_file in fetches_info_files:
593 self._run_fetches_info.append(
594 _load_log_message_from_event_file(fetches_info_file))
596 def _load_feeds_info(self):
597 feeds_info_files = _glob(os.path.join(
598 self._dump_root, METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG + "*"))
599 self._run_feed_keys_info = []
600 for feeds_info_file in feeds_info_files:
601 self._run_feed_keys_info.append(
602 _load_log_message_from_event_file(feeds_info_file))
604 def _dump_file_name_to_datum(self, dir_name, file_name):
605 """Obtain a DebugTensorDatum from the directory and file name.
607 Args:
608 dir_name: (`str`) Name of the directory in which the dump file resides.
609 file_name: (`str`) Base name of the dump file.
611 Returns:
612 (`DebugTensorDatum`) The `DebugTensorDatum` loaded from the dump file.
613 """
615 # Calculate the relative path of the dump file with respect to the root.
616 debug_dump_rel_path = os.path.join(
617 os.path.relpath(dir_name, self._dump_root), file_name)
618 return DebugTensorDatum(self._dump_root, debug_dump_rel_path)
620 def _create_tensor_watch_maps(self, device_name):
621 """Create maps from tensor watch keys to datum and to timestamps.
623 Create a map from watch key (tensor name + debug op) to `DebugTensorDatum`
624 item. Also make a map from watch key to relative timestamp.
625 "relative" means (absolute timestamp - t0).
627 Args:
628 device_name: (str) name of the device.
629 """
631 self._watch_key_to_datum[device_name] = {}
632 self._watch_key_to_rel_time[device_name] = {}
633 self._watch_key_to_dump_size_bytes[device_name] = {}
634 for datum in self._dump_tensor_data[device_name]:
635 if datum.watch_key not in self._watch_key_to_devices:
636 self._watch_key_to_devices[datum.watch_key] = {device_name}
637 else:
638 self._watch_key_to_devices[datum.watch_key].add(device_name)
640 if datum.watch_key not in self._watch_key_to_datum[device_name]:
641 self._watch_key_to_datum[device_name][datum.watch_key] = [datum]
642 self._watch_key_to_rel_time[device_name][datum.watch_key] = [
643 datum.timestamp - self._t0]
644 self._watch_key_to_dump_size_bytes[device_name][datum.watch_key] = [
645 datum.dump_size_bytes]
646 else:
647 self._watch_key_to_datum[device_name][datum.watch_key].append(datum)
648 self._watch_key_to_rel_time[device_name][datum.watch_key].append(
649 datum.timestamp - self._t0)
650 self._watch_key_to_dump_size_bytes[device_name][datum.watch_key].append(
651 datum.dump_size_bytes)
653 def set_python_graph(self, python_graph):
654 """Provide Python `Graph` object to the wrapper.
656 Unlike the partition graphs, which are protobuf `GraphDef` objects, `Graph`
657 is a Python object and carries additional information such as the traceback
658 of the construction of the nodes in the graph.
660 Args:
661 python_graph: (ops.Graph) The Python Graph object.
662 """
664 self._python_graph = python_graph
665 self._node_traceback = {}
666 if self._python_graph:
667 for op in self._python_graph.get_operations():
668 self._node_traceback[op.name] = tuple(map(tuple, op.traceback))
670 @property
671 def python_graph(self):
672 """Get the Python graph.
674 Returns:
675 If the Python graph has been set, returns a `tf.Graph` object. Otherwise,
676 returns None.
677 """
679 return self._python_graph
681 @property
682 def core_metadata(self):
683 """Metadata about the `Session.run()` call from the core runtime.
685 Of the three counters available in the return value, `global_step` is
686 supplied by the caller of the debugged `Session.run()`, while
687 `session_run_index` and `executor_step_index` are determined by the state
688 of the core runtime, automatically. For the same fetch list, feed keys and
689 debug tensor watch options, the same executor will be used and
690 `executor_step_index` should increase by one at a time. However, runs with
691 different fetch lists, feed keys and debug_tensor watch options that all
692 share the same `Session` object can lead to gaps in `session_run_index`.
694 Returns:
695 If core metadata are loaded, a `namedtuple` with the fields:
696 `global_step`: A global step count supplied by the caller of
697 `Session.run()`. It is optional to the caller. If the caller did not
698 supply this parameter, its value will be -1.
699 `session_run_index`: A sorted index for Run() calls to the underlying
700 TensorFlow `Session` object.
701 `executor_step_index`: A counter for invocations of a given runtime
702 executor. The same executor is re-used for the same fetched tensors,
703 target nodes, input feed keys and debug tensor watch options.
704 `input_names`: Names of the input (feed) Tensors.
705 `output_names`: Names of the output (fetched) Tensors.
706 `target_nodes`: Names of the target nodes.
707 If the core metadata have not been loaded, `None`.
708 If more than one core metadata files exist, return a list of the
709 `nametuple` described above.
710 """
712 output = self._core_metadata
713 return output[0] if len(output) == 1 else output
715 @property
716 def dumped_tensor_data(self):
717 """Retrieve dumped tensor data."""
718 if len(self.devices()) == 1:
719 return self._dump_tensor_data[self.devices()[0]]
720 else:
721 all_devices_data = self._dump_tensor_data.values()
722 data = []
723 for device_data in all_devices_data:
724 data.extend(device_data)
725 return sorted(data, key=lambda x: x.extended_timestamp)
727 @property
728 def t0(self):
729 """Absolute timestamp of the first dumped tensor across all devices.
731 Returns:
732 (`int`) absolute timestamp of the first dumped tensor, in microseconds.
733 """
734 return self._t0
736 @property
737 def size(self):
738 """Total number of dumped tensors in the dump root directory.
740 Returns:
741 (`int`) The total number of dumped tensors in the dump root directory.
742 """
743 return sum(len(self._dump_tensor_data[device_name])
744 for device_name in self._dump_tensor_data)
746 def _load_partition_graphs(self, client_partition_graphs, validate):
747 """Load and process partition graphs.
749 Load the graphs; parse the input and control input structure; obtain the
750 device and op type of each node; remove the Copy and debug ops inserted
751 by the debugger. The gathered information can be used to validate the
752 tensor dumps.
754 Args:
755 client_partition_graphs: A repeated field of GraphDefs representing the
756 partition graphs executed by the TensorFlow runtime, from the Python
757 client. These partition graphs are used only if partition graphs
758 cannot be loaded from the dump directory on the file system.
759 validate: (`bool`) Whether the dump files are to be validated against the
760 partition graphs.
762 Raises:
763 ValueError: If the partition GraphDef of one or more devices fail to be
764 loaded.
765 """
766 self._debug_graphs = {}
767 self._node_devices = {}
769 partition_graphs_and_device_names = []
770 for device_name in self._device_names:
771 partition_graph = None
772 if device_name in self._dump_graph_file_paths:
773 partition_graph = _load_graph_def_from_event_file(
774 self._dump_graph_file_paths[device_name])
775 else:
776 logging.warn(
777 "Failed to load partition graphs for device %s from disk. "
778 "As a fallback, the client graphs will be used. This "
779 "may cause mismatches in device names." % device_name)
780 partition_graph = self._find_partition_graph(client_partition_graphs,
781 device_name)
783 if partition_graph:
784 partition_graphs_and_device_names.append((partition_graph,
785 device_name))
787 for partition_graph, maybe_device_name in partition_graphs_and_device_names:
788 debug_graph = debug_graphs.DebugGraph(partition_graph,
789 device_name=maybe_device_name)
790 self._debug_graphs[debug_graph.device_name] = debug_graph
791 self._collect_node_devices(debug_graph)
793 if validate and debug_graph.device_name in self._dump_tensor_data:
794 self._validate_dump_with_graphs(debug_graph.device_name)
796 def _find_partition_graph(self, partition_graphs, device_name):
797 if partition_graphs is None:
798 return None
799 else:
800 for graph_def in partition_graphs:
801 for node_def in graph_def.node:
802 if node_def.device == device_name:
803 return graph_def
804 return None
806 def _collect_node_devices(self, debug_graph):
807 for node_name in debug_graph.node_devices:
808 if node_name in self._node_devices:
809 self._node_devices[node_name] = self._node_devices[node_name].union(
810 debug_graph.node_devices[node_name])
811 else:
812 self._node_devices[node_name] = debug_graph.node_devices[node_name]
814 def _validate_dump_with_graphs(self, device_name):
815 """Validate the dumped tensor data against the partition graphs.
817 Only the watched nodes are validated by this method, because tfdbg allows
818 clients to watch only a subset of the nodes.
820 Args:
821 device_name: (`str`) device name.
823 Raises:
824 LookupError: If the partition graphs have not been loaded yet.
825 ValueError: If dumps contain node names not found in partition graph.
826 Or if the temporal order of the dump's timestamps violate the
827 input relations on the partition graphs.
828 """
829 if not self._debug_graphs:
830 raise LookupError(
831 "No partition graphs loaded for device %s" % device_name)
832 debug_graph = self._debug_graphs[device_name]
834 # Verify that the node names in the dump data are all present in the
835 # partition graphs.
836 for datum in self._dump_tensor_data[device_name]:
837 if datum.node_name not in debug_graph.node_inputs:
838 raise ValueError("Node name '%s' is not found in partition graphs of "
839 "device %s." % (datum.node_name, device_name))
841 pending_inputs = {}
842 for node in debug_graph.node_inputs:
843 pending_inputs[node] = []
844 inputs = debug_graph.node_inputs[node]
845 for inp in inputs:
846 inp_node = debug_graphs.get_node_name(inp)
847 inp_output_slot = debug_graphs.get_output_slot(inp)
848 # Inputs from Enter and NextIteration nodes are not validated because
849 # DebugNodeInserter::InsertNodes() in the debugger core skips creating
850 # control edges from debug ops watching these types of nodes.
851 if (inp_node in self._debug_watches[device_name] and
852 inp_output_slot in self._debug_watches[device_name][inp_node] and
853 debug_graph.node_op_types.get(inp) not in (
854 "Enter", "NextIteration") and
855 (inp_node, inp_output_slot) not in pending_inputs[node]):
856 pending_inputs[node].append((inp_node, inp_output_slot))
858 for i, datum in enumerate(self._dump_tensor_data[device_name]):
859 node = datum.node_name
860 slot = datum.output_slot
861 # In some cases (e.g., system clocks with insufficient precision),
862 # the upstream and downstream tensors may have identical timestamps, the
863 # following check examines this possibility and avoids raising an error if
864 # that is the case.
865 if not self._satisfied_at_timestamp(
866 device_name, pending_inputs[node], datum.timestamp, start_i=i + 1):
867 raise ValueError("Causality violated in timing relations of debug "
868 "dumps: %s (%d): "
869 "these input(s) are not satisfied: %s" %
870 (node, datum.timestamp, repr(pending_inputs[node])))
872 recipients = debug_graph.node_recipients[node]
873 for recipient in recipients:
874 recipient_pending_inputs = pending_inputs[recipient]
875 if (node, slot) in recipient_pending_inputs:
876 if self.node_op_type(recipient) == "Merge":
877 # If this is a Merge op, we automatically clear the list because
878 # a Merge node only requires one of its two inputs.
879 del recipient_pending_inputs[:]
880 else:
881 del recipient_pending_inputs[
882 recipient_pending_inputs.index((node, slot))]
884 def _satisfied_at_timestamp(self, device_name, pending, timestamp, start_i=0):
885 """Determine whether pending inputs are satisfied at given timestamp.
887 Note: This method mutates the input argument "pending".
889 Args:
890 device_name: (str) device name.
891 pending: A list of 2-tuple (node_name, output_slot): the dependencies to
892 check.
893 timestamp: (int) the timestamp in question.
894 start_i: (int) the index in self._dump_tensor_data to start searching for
895 the timestamp.
897 Returns:
898 (bool) Whether all the dependencies in pending are satisfied at the
899 timestamp. If pending is empty to begin with, return True.
900 """
901 if not pending:
902 return True
904 for datum in self._dump_tensor_data[device_name][start_i:]:
905 if datum.timestamp > timestamp:
906 break
907 if (datum.timestamp == timestamp and
908 (datum.node_name, datum.output_slot) in pending):
909 pending.remove((datum.node_name, datum.output_slot))
910 if not pending:
911 return True
913 return not pending
915 def loaded_partition_graphs(self):
916 """Test whether partition graphs have been loaded."""
917 return bool(self._debug_graphs)
919 def partition_graphs(self):
920 """Get the partition graphs.
922 Returns:
923 Partition graphs as a list of GraphDef.
925 Raises:
926 LookupError: If no partition graphs have been loaded.
927 """
928 if not self._debug_graphs:
929 raise LookupError("No partition graphs have been loaded.")
930 return [self._debug_graphs[key].debug_graph_def
931 for key in self._debug_graphs]
933 def reconstructed_non_debug_partition_graphs(self):
934 """Reconstruct partition graphs with the debugger-inserted ops stripped.
936 The reconstructed partition graphs are identical to the original (i.e.,
937 non-debugger-decorated) partition graphs except in the following respects:
938 1) The exact names of the runtime-inserted internal nodes may differ.
939 These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops.
940 2) As a consequence of 1, the nodes that receive input directly from such
941 send- and recv-type ops will have different input names.
942 3) The parallel_iteration attribute of while-loop Enter ops are set to 1.
944 Returns:
945 A dict mapping device names (`str`s) to reconstructed
946 `tf.compat.v1.GraphDef`s.
947 """
948 non_debug_graphs = {}
949 for key in self._debug_graphs:
950 non_debug_graphs[key] = self._debug_graphs[key].non_debug_graph_def
951 return non_debug_graphs
953 @property
954 def run_fetches_info(self):
955 """Get a str representation of the fetches used in the Session.run() call.
957 Returns:
958 If the information is available from one `Session.run` call, a `str`
959 obtained from `repr(fetches)`.
960 If the information is available from multiple `Session.run` calls, a
961 `list` of `str` from `repr(fetches)`.
962 If the information is not available, `None`.
963 """
965 output = self._run_fetches_info
966 return output[0] if len(output) == 1 else output
968 @property
969 def run_feed_keys_info(self):
970 """Get a str representation of the feed_dict used in the Session.run() call.
972 Returns:
973 If the information is available from one `Session.run` call, a `str`
974 obtained from `repr(feed_dict)`.
975 If the information is available from multiple `Session.run` calls, a
976 `list` of `str` obtained from `repr(feed_dict)`.
977 If the information is not available, `None`.
978 """
980 output = self._run_feed_keys_info
981 return output[0] if len(output) == 1 else output
983 def _infer_device_name(self, device_name, node_name):
984 """Infer the device name given node name.
986 If device_name is provided (i.e., not None), it'll be simply returned right
987 away.
989 Args:
990 device_name: (str or None) name of the device. If None, will try to infer
991 the device name by looking at the available nodes.
992 node_name: (str) name of the node.
994 Returns:
995 (str) Inferred name of the device, if available.
997 Raises:
998 ValueError: If the node name does not exist on any of the available
999 devices or if there are multiple devices that contain the node with
1000 the given name.
1001 """
1002 if device_name is None:
1003 if node_name in self._node_devices:
1004 if len(self._node_devices[node_name]) == 1:
1005 return list(self._node_devices[node_name])[0]
1006 else:
1007 raise ValueError(
1008 "There are multiple (%d) devices with nodes named '%s' but "
1009 "device_name is not specified." %
1010 (len(self._node_devices[node_name]), node_name))
1011 else:
1012 raise ValueError("None of the %d device(s) has a node named '%s'." %
1013 (len(self._device_names), node_name))
1014 else:
1015 return device_name
1017 def nodes(self, device_name=None):
1018 """Get a list of all nodes from the partition graphs.
1020 Args:
1021 device_name: (`str`) name of device. If None, all nodes from all available
1022 devices will be included.
1024 Returns:
1025 All nodes' names, as a list of str.
1027 Raises:
1028 LookupError: If no partition graphs have been loaded.
1029 ValueError: If specified node name does not exist.
1030 """
1031 if not self._debug_graphs:
1032 raise LookupError("No partition graphs have been loaded.")
1033 if device_name is None:
1034 nodes = []
1035 for device_name in self._debug_graphs:
1036 nodes.extend(self._debug_graphs[device_name].node_inputs.keys())
1037 return nodes
1038 else:
1039 if device_name not in self._debug_graphs:
1040 raise ValueError("Invalid device name: %s" % device_name)
1041 return self._debug_graphs[device_name].node_inputs.keys()
1043 def node_attributes(self, node_name, device_name=None):
1044 """Get the attributes of a node.
1046 Args:
1047 node_name: Name of the node in question.
1048 device_name: (`str`) name of the device. If there is only one device or if
1049 node_name exists on only one device, this argument is optional.
1051 Returns:
1052 Attributes of the node.
1054 Raises:
1055 LookupError: If no partition graphs have been loaded.
1056 """
1057 if not self._debug_graphs:
1058 raise LookupError("No partition graphs have been loaded.")
1060 device_name = self._infer_device_name(device_name, node_name)
1061 return self._debug_graphs[device_name].node_attributes[node_name]
1063 def node_inputs(self, node_name, is_control=False, device_name=None):
1064 """Get the inputs of given node according to partition graphs.
1066 Args:
1067 node_name: Name of the node.
1068 is_control: (`bool`) Whether control inputs, rather than non-control
1069 inputs, are to be returned.
1070 device_name: (`str`) name of the device. If there is only one device or if
1071 node_name exists on only one device, this argument is optional.
1073 Returns:
1074 (`list` of `str`) inputs to the node, as a list of node names.
1076 Raises:
1077 LookupError: If node inputs and control inputs have not been loaded
1078 from partition graphs yet.
1079 """
1080 if not self._debug_graphs:
1081 raise LookupError(
1082 "Node inputs are not loaded from partition graphs yet.")
1084 device_name = self._infer_device_name(device_name, node_name)
1085 if is_control:
1086 return self._debug_graphs[device_name].node_ctrl_inputs[node_name]
1087 else:
1088 return self._debug_graphs[device_name].node_inputs[node_name]
1090 def transitive_inputs(self,
1091 node_name,
1092 include_control=True,
1093 include_reversed_ref=False,
1094 device_name=None,):
1095 """Get the transitive inputs of given node according to partition graphs.
1097 Args:
1098 node_name: Name of the node.
1099 include_control: Include control inputs (True by default).
1100 include_reversed_ref: Whether a ref input, say from A to B, is to be also
1101 considered as an input from B to A. The rationale is that ref inputs
1102 generally let the recipient (e.g., B in this case) mutate the value of
1103 the source (e.g., A in this case). So the reverse direction of the ref
1104 edge reflects the direction of information flow.
1105 device_name: (`str`) name of the device. If there is only one device or if
1106 node_name exists on only one device, this argument is optional.
1108 Returns:
1109 (`list` of `str`) all transitive inputs to the node, as a list of node
1110 names.
1112 Raises:
1113 LookupError: If node inputs and control inputs have not been loaded
1114 from partition graphs yet.
1115 """
1116 if not self._debug_graphs:
1117 raise LookupError(
1118 "Node inputs are not loaded from partition graphs yet.")
1120 device_name = self._infer_device_name(device_name, node_name)
1122 input_lists = [self._debug_graphs[device_name].node_inputs]
1123 if include_control:
1124 input_lists.append(self._debug_graphs[device_name].node_ctrl_inputs)
1125 if include_reversed_ref:
1126 input_lists.append(
1127 self._debug_graphs[device_name].node_reversed_ref_inputs)
1128 tracer = debug_graphs.DFSGraphTracer(
1129 input_lists,
1130 skip_node_names=self._get_merge_node_names(device_name))
1131 tracer.trace(node_name)
1132 return tracer.inputs()
1134 def _get_merge_node_names(self, device_name):
1135 """Lazily get a list of Merge nodes on a given device."""
1136 if device_name not in self._device_names:
1137 raise ValueError("Invalid device name: %s" % device_name)
1139 if not hasattr(self, "_merge_node_names"):
1140 self._merge_node_names = {}
1141 if device_name not in self._merge_node_names:
1142 debug_graph = self._debug_graphs[device_name]
1143 self._merge_node_names[device_name] = [
1144 node for node in debug_graph.node_op_types
1145 if debug_graph.node_op_types[node] == "Merge"]
1146 return self._merge_node_names[device_name]
1148 def find_some_path(self,
1149 src_node_name,
1150 dst_node_name,
1151 include_control=True,
1152 include_reversed_ref=False,
1153 device_name=None):
1154 """Find a path between a source node and a destination node.
1156 Limitation: the source and destination are required to be on the same
1157 device, i.e., this method does not yet take into account Send/Recv nodes
1158 across devices.
1160 TODO(cais): Make this method work across device edges by tracing Send/Recv
1161 nodes.
1163 Args:
1164 src_node_name: (`str`) name of the source node or name of an output tensor
1165 of the node.
1166 dst_node_name: (`str`) name of the destination node or name of an output
1167 tensor of the node.
1168 include_control: (`bool`) whrther control edges are considered in the
1169 graph tracing.
1170 include_reversed_ref: Whether a ref input, say from A to B, is to be also
1171 considered as an input from B to A. The rationale is that ref inputs
1172 generally let the recipient (e.g., B in this case) mutate the value of
1173 the source (e.g., A in this case). So the reverse direction of the ref
1174 edge reflects the direction of information flow.
1175 device_name: (`str`) name of the device. If there is only one device or if
1176 node_name exists on only one device, this argument is optional.
1178 Returns:
1179 A path from the src_node_name to dst_node_name, as a `list` of `str`, if
1180 it exists. The list includes src_node_name as the first item and
1181 dst_node_name as the last.
1182 If such a path does not exist, `None`.
1184 Raises:
1185 ValueError: If the source and destination nodes are not on the same
1186 device.
1187 """
1188 src_device_name = self._infer_device_name(device_name, src_node_name)
1189 dst_device_name = self._infer_device_name(device_name, dst_node_name)
1191 if src_device_name != dst_device_name:
1192 raise ValueError(
1193 "Source (%s) and destination (%s) are not on the same device: "
1194 "%s vs. %s" % (src_node_name, dst_node_name, src_device_name,
1195 dst_device_name))
1197 input_lists = [self._debug_graphs[dst_device_name].node_inputs]
1198 debug_graph = self._debug_graphs[dst_device_name]
1199 if include_control:
1200 input_lists.append(debug_graph.node_ctrl_inputs)
1201 if include_reversed_ref:
1202 input_lists.append(debug_graph.node_reversed_ref_inputs)
1203 tracer = debug_graphs.DFSGraphTracer(
1204 input_lists,
1205 skip_node_names=self._get_merge_node_names(dst_device_name),
1206 destination_node_name=src_node_name)
1207 # Here the value of destination_node_name is src_node_name, because we
1208 # are tracing the graph from output to its inputs (i.e., going backwards
1209 # on the graph).
1211 try:
1212 tracer.trace(dst_node_name)
1213 except debug_graphs.GraphTracingReachedDestination:
1214 # Prune nodes not on the path.
1215 inputs = [dst_node_name] + tracer.inputs()
1216 depth_list = [0] + tracer.depth_list()
1218 path = []
1219 curr_depth = depth_list[-1]
1220 for inp, depth in zip(reversed(inputs), reversed(depth_list)):
1221 if depth == curr_depth:
1222 path.append(inp)
1223 curr_depth -= 1
1224 return path
1226 def node_recipients(self, node_name, is_control=False, device_name=None):
1227 """Get recipient of the given node's output according to partition graphs.
1229 Args:
1230 node_name: (`str`) name of the node.
1231 is_control: (`bool`) whether control outputs, rather than non-control
1232 outputs, are to be returned.
1233 device_name: (`str`) name of the device. If there is only one device or if
1234 node_name exists on only one device, this argument is optional.
1236 Returns:
1237 (`list` of `str`) all inputs to the node, as a list of node names.
1239 Raises:
1240 LookupError: If node inputs and control inputs have not been loaded
1241 from partition graphs yet.
1242 """
1244 if not self._debug_graphs:
1245 raise LookupError(
1246 "Node recipients are not loaded from partition graphs yet.")
1248 device_name = self._infer_device_name(device_name, node_name)
1249 debug_graph = self._debug_graphs[device_name]
1250 if is_control:
1251 return debug_graph.node_ctrl_recipients[node_name]
1252 else:
1253 return debug_graph.node_recipients[node_name]
1255 def devices(self):
1256 """Get the list of device names.
1258 Returns:
1259 (`list` of `str`) names of the devices.
1260 """
1261 return self._device_names
1263 def node_exists(self, node_name, device_name=None):
1264 """Test if a node exists in the partition graphs.
1266 Args:
1267 node_name: (`str`) name of the node to be checked.
1268 device_name: optional device name. If None, will search for the node
1269 on all available devices. Otherwise, search for the node only on
1270 the given device.
1272 Returns:
1273 A boolean indicating whether the node exists.
1275 Raises:
1276 LookupError: If no partition graphs have been loaded yet.
1277 ValueError: If device_name is specified but cannot be found.
1278 """
1279 if not self._debug_graphs:
1280 raise LookupError(
1281 "Nodes have not been loaded from partition graphs yet.")
1283 if (device_name is not None) and device_name not in self._debug_graphs:
1284 raise ValueError(
1285 "The specified device_name '%s' cannot be found." % device_name)
1287 for _, debug_graph in self._debug_graphs.items():
1288 if node_name in debug_graph.node_inputs:
1289 return True
1290 return False
1292 def node_device(self, node_name):
1293 """Get the names of the devices that has nodes of the specified name.
1295 Args:
1296 node_name: (`str`) name of the node.
1298 Returns:
1299 (`str` or `list` of `str`) name of the device(s) on which the node of the
1300 given name is found. Returns a `str` if there is only one such device,
1301 otherwise return a `list` of `str`.
1303 Raises:
1304 LookupError: If node inputs and control inputs have not been loaded
1305 from partition graphs yet.
1306 ValueError: If the node does not exist in partition graphs.
1307 """
1308 if not self._debug_graphs:
1309 raise LookupError(
1310 "Node devices are not loaded from partition graphs yet.")
1312 if node_name not in self._node_devices:
1313 raise ValueError("Node '%s' does not exist in partition graphs." %
1314 node_name)
1316 output = list(self._node_devices[node_name])
1317 return output[0] if len(output) == 1 else output
1319 def node_op_type(self, node_name, device_name=None):
1320 """Get the op type of given node.
1322 Args:
1323 node_name: (`str`) name of the node.
1324 device_name: (`str`) name of the device. If there is only one device or if
1325 node_name exists on only one device, this argument is optional.
1327 Returns:
1328 (`str`) op type of the node.
1330 Raises:
1331 LookupError: If node op types have not been loaded
1332 from partition graphs yet.
1333 """
1334 if not self._debug_graphs:
1335 raise LookupError(
1336 "Node op types are not loaded from partition graphs yet.")
1338 device_name = self._infer_device_name(device_name, node_name)
1339 return self._debug_graphs[device_name].node_op_types[node_name]
1341 def debug_watch_keys(self, node_name, device_name=None):
1342 """Get all tensor watch keys of given node according to partition graphs.
1344 Args:
1345 node_name: (`str`) name of the node.
1346 device_name: (`str`) name of the device. If there is only one device or if
1347 node_name exists on only one device, this argument is optional.
1349 Returns:
1350 (`list` of `str`) all debug tensor watch keys. Returns an empty list if
1351 the node name does not correspond to any debug watch keys.
1353 Raises:
1354 `LookupError`: If debug watch information has not been loaded from
1355 partition graphs yet.
1356 """
1358 try:
1359 device_name = self._infer_device_name(device_name, node_name)
1360 except ValueError:
1361 return []
1363 if node_name not in self._debug_watches[device_name]:
1364 return []
1366 watch_keys = []
1367 for watched_slot in self._debug_watches[device_name][node_name]:
1368 debug_ops = self._debug_watches[device_name][node_name][watched_slot]
1369 for debug_op in debug_ops:
1370 watch_keys.append(
1371 _get_tensor_watch_key(node_name, watched_slot, debug_op))
1373 return watch_keys
1375 def watch_key_to_data(self, debug_watch_key, device_name=None):
1376 """Get all `DebugTensorDatum` instances corresponding to a debug watch key.
1378 Args:
1379 debug_watch_key: (`str`) debug watch key.
1380 device_name: (`str`) name of the device. If there is only one device or if
1381 the specified debug_watch_key exists on only one device, this argument
1382 is optional.
1384 Returns:
1385 A list of `DebugTensorDatum` instances that correspond to the debug watch
1386 key. If the watch key does not exist, returns an empty list.
1388 Raises:
1389 ValueError: If there are multiple devices that have the debug_watch_key,
1390 but device_name is not specified.
1391 """
1392 if device_name is None:
1393 matching_device_names = [
1394 name for name in self._watch_key_to_datum
1395 if debug_watch_key in self._watch_key_to_datum[name]]
1396 if not matching_device_names:
1397 return []
1398 elif len(matching_device_names) == 1:
1399 device_name = matching_device_names[0]
1400 else:
1401 raise ValueError(
1402 "The debug watch key '%s' exists on multiple (%d) devices, but "
1403 "device name is not specified." %
1404 (debug_watch_key, len(matching_device_names)))
1405 elif device_name not in self._debug_key_to_datum:
1406 raise ValueError(
1407 "There is no device named '%s' consisting of debug watch keys." %
1408 device_name)
1410 return self._watch_key_to_datum[device_name].get(debug_watch_key, [])
1412 def find(self,
1413 predicate,
1414 first_n=0,
1415 device_name=None,
1416 exclude_node_names=None):
1417 """Find dumped tensor data by a certain predicate.
1419 Args:
1420 predicate: A callable that takes two input arguments:
1422 ```python
1423 def predicate(debug_tensor_datum, tensor):
1424 # returns a bool
1425 ```
1427 where `debug_tensor_datum` is an instance of `DebugTensorDatum`, which
1428 carries the metadata, such as the `Tensor`'s node name, output slot
1429 timestamp, debug op name, etc.; and `tensor` is the dumped tensor value
1430 as a `numpy.ndarray`.
1431 first_n: (`int`) return only the first n `DebugTensotDatum` instances (in
1432 time order) for which the predicate returns True. To return all the
1433 `DebugTensotDatum` instances, let first_n be <= 0.
1434 device_name: optional device name.
1435 exclude_node_names: Optional regular expression to exclude nodes with
1436 names matching the regular expression.
1438 Returns:
1439 A list of all `DebugTensorDatum` objects in this `DebugDumpDir` object
1440 for which predicate returns True, sorted in ascending order of the
1441 timestamp.
1442 """
1443 if exclude_node_names:
1444 exclude_node_names = re.compile(exclude_node_names)
1446 matched_data = []
1447 for device in (self._dump_tensor_data if device_name is None
1448 else (self._dump_tensor_data[device_name],)):
1449 for datum in self._dump_tensor_data[device]:
1450 if exclude_node_names and exclude_node_names.match(datum.node_name):
1451 continue
1453 if predicate(datum, datum.get_tensor()):
1454 matched_data.append(datum)
1456 if first_n > 0 and len(matched_data) >= first_n:
1457 return matched_data
1459 return matched_data
1461 def get_tensor_file_paths(self,
1462 node_name,
1463 output_slot,
1464 debug_op,
1465 device_name=None):
1466 """Get the file paths from a debug-dumped tensor.
1468 Args:
1469 node_name: (`str`) name of the node that the tensor is produced by.
1470 output_slot: (`int`) output slot index of tensor.
1471 debug_op: (`str`) name of the debug op.
1472 device_name: (`str`) name of the device. If there is only one device or if
1473 the specified debug_watch_key exists on only one device, this argument
1474 is optional.
1476 Returns:
1477 List of file path(s) loaded. This is a list because each debugged tensor
1478 may be dumped multiple times.
1480 Raises:
1481 WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in
1482 the debug-dump data.
1483 """
1485 device_name = self._infer_device_name(device_name, node_name)
1486 watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1487 if watch_key not in self._watch_key_to_datum[device_name]:
1488 raise WatchKeyDoesNotExistInDebugDumpDirError(
1489 "Watch key \"%s\" does not exist in the debug dump of device %s" %
1490 (watch_key, device_name))
1492 return [datum.file_path for datum in
1493 self._watch_key_to_datum[device_name][watch_key]]
1495 def get_tensors(self, node_name, output_slot, debug_op, device_name=None):
1496 """Get the tensor value from for a debug-dumped tensor.
1498 The tensor may be dumped multiple times in the dump root directory, so a
1499 list of tensors (`numpy.ndarray`) is returned.
1501 Args:
1502 node_name: (`str`) name of the node that the tensor is produced by.
1503 output_slot: (`int`) output slot index of tensor.
1504 debug_op: (`str`) name of the debug op.
1505 device_name: (`str`) name of the device. If there is only one device or if
1506 the specified debug_watch_key exists on only one device, this argument
1507 is optional.
1509 Returns:
1510 List of tensors (`numpy.ndarray`) loaded from the debug-dump file(s).
1512 Raises:
1513 WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in
1514 the debug-dump data.
1515 """
1517 watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1518 try:
1519 device_name = self._infer_device_name(device_name, node_name)
1520 return [datum.get_tensor() for datum in
1521 self._watch_key_to_datum[device_name][watch_key]]
1522 except (ValueError, KeyError):
1523 raise WatchKeyDoesNotExistInDebugDumpDirError(
1524 "Watch key \"%s\" does not exist in the debug dump of device %s" %
1525 (watch_key, device_name))
1527 def get_rel_timestamps(self,
1528 node_name,
1529 output_slot,
1530 debug_op,
1531 device_name=None):
1532 """Get the relative timestamp from for a debug-dumped tensor.
1534 Relative timestamp means (absolute timestamp - `t0`), where `t0` is the
1535 absolute timestamp of the first dumped tensor in the dump root. The tensor
1536 may be dumped multiple times in the dump root directory, so a list of
1537 relative timestamps (`numpy.ndarray`) is returned.
1539 Args:
1540 node_name: (`str`) name of the node that the tensor is produced by.
1541 output_slot: (`int`) output slot index of tensor.
1542 debug_op: (`str`) name of the debug op.
1543 device_name: (`str`) name of the device. If there is only one device or if
1544 the specified debug_watch_key exists on only one device, this argument
1545 is optional.
1547 Returns:
1548 (`list` of `int`) list of relative timestamps.
1550 Raises:
1551 WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not
1552 exist in the debug dump data.
1553 """
1555 device_name = self._infer_device_name(device_name, node_name)
1556 watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1557 if watch_key not in self._watch_key_to_datum[device_name]:
1558 raise WatchKeyDoesNotExistInDebugDumpDirError(
1559 "Watch key \"%s\" does not exist in the debug dump" % watch_key)
1561 # TODO(cais): Figure out whether this should be relative to the global t0.
1562 return self._watch_key_to_rel_time[device_name][watch_key]
1564 def get_dump_sizes_bytes(self,
1565 node_name,
1566 output_slot,
1567 debug_op,
1568 device_name=None):
1569 """Get the sizes of the dump files for a debug-dumped tensor.
1571 Unit of the file size: byte.
1573 Args:
1574 node_name: (`str`) name of the node that the tensor is produced by.
1575 output_slot: (`int`) output slot index of tensor.
1576 debug_op: (`str`) name of the debug op.
1577 device_name: (`str`) name of the device. If there is only one device or if
1578 the specified debug_watch_key exists on only one device, this argument
1579 is optional.
1581 Returns:
1582 (`list` of `int`): list of dump file sizes in bytes.
1584 Raises:
1585 WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not
1586 exist in the debug dump data.
1587 """
1589 device_name = self._infer_device_name(device_name, node_name)
1590 watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1591 if watch_key not in self._watch_key_to_datum[device_name]:
1592 raise WatchKeyDoesNotExistInDebugDumpDirError(
1593 "Watch key \"%s\" does not exist in the debug dump of device %s" %
1594 (watch_key, device_name))
1596 return self._watch_key_to_dump_size_bytes[device_name][watch_key]
1598 def node_traceback(self, element_name):
1599 """Try to retrieve the Python traceback of node's construction.
1601 Args:
1602 element_name: (`str`) Name of a graph element (node or tensor).
1604 Returns:
1605 (list) The traceback list object as returned by the `extract_trace`
1606 method of Python's traceback module.
1608 Raises:
1609 LookupError: If Python graph is not available for traceback lookup.
1610 KeyError: If the node cannot be found in the Python graph loaded.
1611 """
1613 if self._python_graph is None:
1614 raise LookupError("Python graph is not available for traceback lookup")
1616 node_name = debug_graphs.get_node_name(element_name)
1617 if node_name not in self._node_traceback:
1618 raise KeyError("Cannot find node \"%s\" in Python graph" % node_name)
1620 return self._node_traceback[node_name]