Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/wrappers/dumping_wrapper.py: 28%
39 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Debugger wrapper session that dumps debug data to file:// URLs."""
16import os
17import threading
18import time
20# Google-internal import(s).
21from tensorflow.core.util import event_pb2
22from tensorflow.python.debug.lib import debug_data
23from tensorflow.python.debug.wrappers import framework
24from tensorflow.python.platform import gfile
class DumpingDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
  """Debug Session wrapper that dumps debug data to filesystem."""

  def __init__(self,
               sess,
               session_root,
               watch_fn=None,
               thread_name_filter=None,
               pass_through_operrors=None,
               log_usage=True):
    """Constructor of DumpingDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      session_root: (`str`) Path to the session root directory. Must be a
        directory that does not exist or an empty directory. A leading `~` is
        expanded. If the directory does not exist, it is created by this
        constructor.
        As the `run()` calls occur, subdirectories will be added to
        `session_root`. The subdirectories' names have the following pattern:
          run_<epoch_time_stamp_in_microseconds>_<zero_based_run_counter>
        E.g., run_1480734393835964_0
      watch_fn: (`Callable`) A Callable that can be used to define per-run
        debug ops and watched tensors. See the doc of
        `NonInteractiveDebugWrapperSession.__init__()` for details.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      pass_through_operrors: If true, all captured OpErrors will be
        propagated. By default this captures all OpErrors.
      log_usage: (`bool`) whether the usage of this class is to be logged.

    Raises:
      ValueError: If `session_root` is an existing and non-empty directory or
        if `session_root` is a file.
    """

    if log_usage:
      pass  # No logging for open-source.

    framework.NonInteractiveDebugWrapperSession.__init__(
        self, sess, watch_fn=watch_fn, thread_name_filter=thread_name_filter,
        pass_through_operrors=pass_through_operrors)

    session_root = os.path.expanduser(session_root)
    if gfile.Exists(session_root):
      if not gfile.IsDirectory(session_root):
        raise ValueError(
            "session_root path points to a file: %s" % session_root)
      elif gfile.ListDirectory(session_root):
        raise ValueError(
            "session_root path points to a non-empty directory: %s" %
            session_root)
    else:
      gfile.MakeDirs(session_root)
    self._session_root = session_root

    # Per-session counter appended to run-directory names; guarded by the lock
    # so that concurrent run() calls never receive the same counter value.
    self._run_counter = 0
    self._run_counter_lock = threading.Lock()

  @staticmethod
  def _write_metadata_event(run_dir, file_tag, message):
    """Writes a serialized log-message `Event` into a metadata file.

    Args:
      run_dir: (`str`) Run-specific directory the file is written into.
      file_tag: (`str`) Suffix appended to `debug_data.METADATA_FILE_PREFIX`
        to form the file name.
      message: (`str`) Text stored in the event's `log_message.message`.
    """
    event = event_pb2.Event()
    event.log_message.message = message
    path = os.path.join(run_dir, debug_data.METADATA_FILE_PREFIX + file_tag)
    with gfile.Open(path, "wb") as f:
      f.write(event.SerializeToString())

  def prepare_run_debug_urls(self, fetches, feed_dict):
    """Implementation of abstract method in superclass.

    See doc of `NonInteractiveDebugWrapperSession.prepare_run_debug_urls()`
    for details. This implementation creates a run-specific subdirectory under
    self._session_root and stores information regarding run `fetches` and
    `feed_dict.keys()` in the subdirectory.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`
      feed_dict: Same as the `feed_dict` argument to `Session.run()`

    Returns:
      debug_urls: (`str` or `list` of `str`) file:// debug URLs to be used in
        this `Session.run()` call.
    """

    # The run directory name combines a microsecond timestamp with a
    # monotonically increasing counter. Concurrent run() calls in the same
    # microsecond therefore still get distinct directories. The `with`
    # statement guarantees the lock is released even if building the name
    # raises.
    with self._run_counter_lock:
      run_dir = os.path.join(self._session_root, "run_%d_%d" %
                             (int(time.time() * 1e6), self._run_counter))
      self._run_counter += 1
    gfile.MkDir(run_dir)

    self._write_metadata_event(
        run_dir, debug_data.FETCHES_INFO_FILE_TAG, repr(fetches))

    # An empty or None feed_dict is recorded as-is; otherwise only its keys.
    feed_keys_message = (repr(feed_dict.keys()) if feed_dict
                         else repr(feed_dict))
    self._write_metadata_event(
        run_dir, debug_data.FEED_KEYS_INFO_FILE_TAG, feed_keys_message)

    return ["file://" + run_dir]