Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/wrappers/grpc_wrapper.py: 30%
54 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Debugger wrapper sessions that send debug data to gRPC debug server URLs."""
16import signal
17import sys
18import traceback
20# Google-internal import(s).
21from tensorflow.python.debug.lib import common
22from tensorflow.python.debug.wrappers import framework
def publish_traceback(debug_server_urls,
                      graph,
                      feed_dict,
                      fetches,
                      old_graph_version):
  """Publishes graph tracebacks and source code when the graph is newer.

  Nothing is sent unless `graph.version` is strictly greater than
  `old_graph_version`. When it is, the graph tracebacks and the associated
  source code are sent to the debug server(s) at the given gRPC URL(s),
  together with the call stack captured via `traceback.extract_stack()`.

  Args:
    debug_server_urls: A single gRPC debug server URL as a `str` or a `list` of
      debug server URLs.
    graph: A Python `tf.Graph` object.
    feed_dict: Feed dictionary given to the `Session.run()` call.
    fetches: Fetches from the `Session.run()` call.
    old_graph_version: Old graph version to compare to.

  Returns:
    The new graph version (an `int`) if `graph.version > old_graph_version`;
    otherwise `old_graph_version`, unchanged.
  """
  # TODO(cais): Consider moving this back to the top, after grpc becomes a
  # pip dependency of tensorflow or tf_debug.
  # pylint:disable=g-import-not-at-top
  from tensorflow.python.debug.lib import source_remote
  # pylint:enable=g-import-not-at-top

  if not graph.version > old_graph_version:
    # The graph has not advanced past the last published version, so there
    # is nothing new to send.
    return old_graph_version

  source_remote.send_graph_tracebacks(
      debug_server_urls,
      common.get_run_key(feed_dict, fetches),
      traceback.extract_stack(),
      graph,
      send_source=True)
  return graph.version
class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
  """Debug Session wrapper that sends debug data to gRPC stream(s)."""

  def __init__(self,
               sess,
               grpc_debug_server_addresses,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Constructor of GrpcDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      grpc_debug_server_addresses: (`str` or `list` of `str`) Single or a list
        of the gRPC debug server addresses, in the format of
        <host:port>, with or without the "grpc://" prefix. For example:
          "localhost:7000",
          ["localhost:7000", "192.168.0.2:8000"]
      watch_fn: (`Callable`) A Callable that can be used to define per-run
        debug ops and watched tensors. See the doc of
        `NonInteractiveDebugWrapperSession.__init__()` for details.
      thread_name_filter: Regular-expression white list for threads on which
        the wrapper session will be active. See doc of
        `BaseDebugWrapperSession` for more details.
      log_usage: (`bool`) whether the usage of this class is to be logged.
        Currently has no effect in this open-source build.

    Raises:
      TypeError: If `grpc_debug_server_addresses` is not a `str` or a `list`
        of `str`.
    """
    if log_usage:
      pass  # No logging for open-source.

    framework.NonInteractiveDebugWrapperSession.__init__(
        self, sess, watch_fn=watch_fn, thread_name_filter=thread_name_filter)

    if isinstance(grpc_debug_server_addresses, str):
      addresses = [grpc_debug_server_addresses]
    elif isinstance(grpc_debug_server_addresses, list):
      # Validate every element before normalizing any of them, so the URL
      # list is only assigned once the whole input is known to be well-formed.
      for address in grpc_debug_server_addresses:
        if not isinstance(address, str):
          raise TypeError(
              "Expected type str in list grpc_debug_server_addresses, "
              "received type %s" % type(address))
      addresses = grpc_debug_server_addresses
    else:
      raise TypeError(
          "Expected type str or list in grpc_debug_server_addresses, "
          "received type %s" % type(grpc_debug_server_addresses))

    # Normalized URLs (each carrying `common.GRPC_URL_PREFIX`), returned for
    # every `Session.run()` call by `prepare_run_debug_urls()`.
    self._grpc_debug_server_urls = [
        self._normalize_grpc_url(address) for address in addresses]

  def prepare_run_debug_urls(self, fetches, feed_dict):
    """Implementation of abstract method in superclass.

    See doc of `NonInteractiveDebugWrapperSession.prepare_run_debug_urls()`
    for details. The same gRPC URLs, fixed at construction time, are returned
    for every call; this is the wrapper's internal list, not a copy.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`. Unused.
      feed_dict: Same as the `feed_dict` argument to `Session.run()`. Unused.

    Returns:
      debug_urls: (`list` of `str`) gRPC debug URLs to be used in this
        `Session.run()` call.
    """
    return self._grpc_debug_server_urls

  def _normalize_grpc_url(self, address):
    """Returns `address` with `common.GRPC_URL_PREFIX` prepended if missing."""
    if address.startswith(common.GRPC_URL_PREFIX):
      return address
    return common.GRPC_URL_PREFIX + address
def _signal_handler(unused_signal, unused_frame):
  """SIGINT handler that asks the user whether to quit the program.

  Prompts repeatedly until a recognized answer is given. An empty answer,
  "Y" or "y" exits the process with status 0. "N" or "n" returns, so the
  interrupted code resumes. If standard input reaches end-of-file, no answer
  can be read, so the default answer ("Y") is assumed and the process exits.

  Args:
    unused_signal: Signal number passed by the `signal` module (unused).
    unused_frame: Interrupted stack frame passed by the `signal` module
      (unused).
  """
  while True:
    try:
      response = input("\nSIGINT received. Quit program? (Y/n): ").strip()
    except EOFError:
      # stdin is closed or non-interactive. Fall back to the default answer
      # instead of letting EOFError escape into the interrupted code.
      response = ""
    if response in ("", "Y", "y"):
      sys.exit(0)
    elif response in ("N", "n"):
      break
def register_signal_handler():
  """Installs `_signal_handler` as the SIGINT handler when permitted.

  `signal.signal()` raises `ValueError` when called outside the main thread.
  In that case registration is skipped and whatever SIGINT handler was
  already installed stays in effect.
  """
  try:
    signal.signal(signal.SIGINT, _signal_handler)
  except ValueError:
    # Not on the main thread: handlers can only be set there, so give up
    # quietly rather than failing the caller.
    return
class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
  """tfdbg Session wrapper intended for the TensorBoard Debugger Plugin.

  Behaves like `GrpcDebugWrapperSession`, but supplies a fixed `watch_fn`
  instead of accepting one from the caller. That `watch_fn`:
    1) attaches `DebugIdentity` debug ops with the `gated_grpc` attribute set
       to `true`, so tensor breakpoints can be enabled and disabled
       interactively;
    2) watches all tensors in the graph.
  Users therefore do not need to define a `watch_fn` themselves.
  """

  def __init__(self,
               sess,
               grpc_debug_server_addresses,
               thread_name_filter=None,
               send_traceback_and_source_code=True,
               log_usage=True):
    """Creates a TensorBoardDebugWrapperSession.

    Construction also installs a SIGINT handler via
    `register_signal_handler()`.

    Args:
      sess: The `tf.compat.v1.Session` instance to be wrapped.
      grpc_debug_server_addresses: gRPC address(es) of debug server(s), as a
        `str` or a `list` of `str`s. E.g., "localhost:2333",
        "grpc://localhost:2333", ["192.168.0.7:2333", "192.168.0.8:2333"].
      thread_name_filter: Optional filter for thread names.
      send_traceback_and_source_code: Whether traceback of graph elements and
        the source code are to be sent to the debug server(s).
      log_usage: Whether the usage of this class is to be logged (if
        applicable).
    """

    def watch_everything_gated(fetches, feeds):
      # The same options apply to every run, whatever is fetched or fed.
      del fetches, feeds  # Unused.
      return framework.WatchOptions(
          debug_ops=["DebugIdentity(gated_grpc=true)"])

    super().__init__(
        sess,
        grpc_debug_server_addresses,
        watch_fn=watch_everything_gated,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)

    self._send_traceback_and_source_code = send_traceback_and_source_code
    # Version of the Python graph most recently published to the debug
    # server(s); starts at -1, before anything has been published.
    self._sent_graph_version = -1

    register_signal_handler()

  def run(self,
          fetches,
          feed_dict=None,
          options=None,
          run_metadata=None,
          callable_runner=None,
          callable_runner_args=None,
          callable_options=None):
    """Optionally publishes graph tracebacks, then runs the wrapped session.

    When `send_traceback_and_source_code` was enabled at construction, this
    calls `publish_traceback()` and records the graph version it returns
    before delegating to the superclass `run()`. All arguments are forwarded
    unchanged.
    """
    if self._send_traceback_and_source_code:
      latest_version = publish_traceback(
          self._grpc_debug_server_urls,
          self.graph,
          feed_dict,
          fetches,
          self._sent_graph_version)
      self._sent_graph_version = latest_version

    return super().run(
        fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata,
        callable_runner=callable_runner,
        callable_runner_args=callable_runner_args,
        callable_options=callable_options)