Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/wrappers/hooks.py: 27%
92 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""tfdbg CLI as SessionRunHook."""
17from tensorflow.core.protobuf import config_pb2
18from tensorflow.python.debug.lib import debug_utils
19from tensorflow.python.debug.wrappers import dumping_wrapper
20from tensorflow.python.debug.wrappers import framework
21from tensorflow.python.debug.wrappers import grpc_wrapper
22from tensorflow.python.debug.wrappers import local_cli_wrapper
23from tensorflow.python.training import session_run_hook
class LocalCLIDebugHook(session_run_hook.SessionRunHook):
  """Command-line-interface debugger hook.

  Can be used as a hook for `tf.compat.v1.train.MonitoredSession`s and
  `tf.estimator.Estimator`s. Provides a substitute for
  `tfdbg.LocalCLIDebugWrapperSession` in cases where the session is not directly
  available.
  """

  def __init__(self,
               ui_type="curses",
               dump_root=None,
               thread_name_filter=None,
               config_file_path=None):
    """Create a local debugger command-line interface (CLI) hook.

    Args:
      ui_type: (`str`) requested user-interface type. Currently supported:
        (curses | readline).
      dump_root: (`str`) optional path to the dump root directory. Must be a
        directory that does not exist or an empty directory. If the directory
        does not exist, it will be created by the debugger core during debug
        `run()` calls and removed afterwards.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      config_file_path: Optional override to the default configuration file
        path, which is at `${HOME}/.tfdbg_config`.
    """
    self._ui_type = ui_type
    self._dump_root = dump_root
    self._thread_name_filter = thread_name_filter
    self._config_file_path = config_file_path
    # The wrapper needs `run_context.session`, which first becomes available
    # in `before_run()`, so it is constructed lazily there.
    self._session_wrapper = None
    # Filters added through `add_tensor_filter()` before the wrapper exists;
    # they are registered on the wrapper when it is constructed.
    self._pending_tensor_filters = {}
    # Action returned by the wrapper's `on_run_start()` in the most recent
    # `before_run()`; handed back to `on_run_end()` in `after_run()`.
    self._performed_action = None

  def add_tensor_filter(self, filter_name, tensor_filter):
    """Add a tensor filter.

    See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    Override default behavior to accommodate the possibility of this method
    being called prior to the initialization of the underlying
    `LocalCLIDebugWrapperSession` object.

    Args:
      filter_name: See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()`
        for details.
      tensor_filter: See doc of
        `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    """
    if self._session_wrapper is not None:
      self._session_wrapper.add_tensor_filter(filter_name, tensor_filter)
    else:
      self._pending_tensor_filters[filter_name] = tensor_filter

  def begin(self):
    pass

  def before_run(self, run_context):
    """Consults the CLI and turns its decision into `SessionRunArgs`.

    Args:
      run_context: A `session_run_hook.SessionRunContext` for the upcoming run.

    Returns:
      A `session_run_hook.SessionRunArgs` with no extra fetches or feeds. Its
      `RunOptions` are decorated for debugging when the CLI chose `DEBUG_RUN`,
      or for profiling when it chose `PROFILE_RUN`.
    """
    if self._session_wrapper is None:
      self._session_wrapper = local_cli_wrapper.LocalCLIDebugWrapperSession(
          run_context.session,
          ui_type=self._ui_type,
          dump_root=self._dump_root,
          thread_name_filter=self._thread_name_filter,
          config_file_path=self._config_file_path)

      # Actually register tensor filters registered prior to the construction
      # of the underlying LocalCLIDebugWrapperSession object.
      for filter_name, tensor_filter in self._pending_tensor_filters.items():
        self._session_wrapper.add_tensor_filter(filter_name, tensor_filter)

    # Increment run call counter.
    self._session_wrapper.increment_run_call_count()

    # Adapt run_context to an instance of OnRunStartRequest for invoking
    # superclass on_run_start().
    on_run_start_request = framework.OnRunStartRequest(
        run_context.original_args.fetches, run_context.original_args.feed_dict,
        None, None, self._session_wrapper.run_call_count)

    on_run_start_response = self._session_wrapper.on_run_start(
        on_run_start_request)
    self._performed_action = on_run_start_response.action

    run_args = session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=config_pb2.RunOptions())
    if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
      # pylint: disable=protected-access
      self._session_wrapper._decorate_run_options_for_debug(
          run_args.options,
          on_run_start_response.debug_urls,
          debug_ops=on_run_start_response.debug_ops,
          node_name_regex_allowlist=(
              on_run_start_response.node_name_regex_allowlist),
          op_type_regex_allowlist=(
              on_run_start_response.op_type_regex_allowlist),
          tensor_dtype_regex_allowlist=(
              on_run_start_response.tensor_dtype_regex_allowlist),
          tolerate_debug_op_creation_failures=(
              on_run_start_response.tolerate_debug_op_creation_failures))
      # pylint: enable=protected-access
    elif self._performed_action == framework.OnRunStartAction.PROFILE_RUN:
      # pylint: disable=protected-access
      self._session_wrapper._decorate_run_options_for_profile(run_args.options)
      # pylint: enable=protected-access

    return run_args

  def after_run(self, run_context, run_values):
    """Reports the finished run back to the CLI wrapper.

    Args:
      run_context: A `session_run_hook.SessionRunContext` (unused here).
      run_values: A `session_run_hook.SessionRunValues`; its `run_metadata` is
        forwarded to the wrapper.
    """
    # Adapt run_context and run_values to OnRunEndRequest and invoke superclass
    # on_run_end()
    on_run_end_request = framework.OnRunEndRequest(self._performed_action,
                                                   run_values.run_metadata)
    self._session_wrapper.on_run_end(on_run_end_request)
class DumpingDebugHook(session_run_hook.SessionRunHook):
  """A debugger hook that dumps debug data to filesystem.

  Can be used as a hook for `tf.compat.v1.train.MonitoredSession`s and
  `tf.estimator.Estimator`s.
  """

  def __init__(self,
               session_root,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Create a debugger hook that dumps debug data to the filesystem.

    Args:
      session_root: See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__`.
      watch_fn: See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__`.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (bool) Whether usage is to be logged.
    """
    self._session_root = session_root
    self._watch_fn = watch_fn
    self._thread_name_filter = thread_name_filter
    self._log_usage = log_usage
    # Constructed lazily in `before_run()`, the first point at which
    # `run_context.session` is available.
    self._session_wrapper = None

  def begin(self):
    pass

  def before_run(self, run_context):
    """Builds `RunOptions` that watch the graph for this run.

    Args:
      run_context: A `session_run_hook.SessionRunContext` for the upcoming run.

    Returns:
      A `session_run_hook.SessionRunArgs` with no extra fetches or feeds whose
      `options` carry the debug watches produced by `debug_utils.watch_graph`.
    """
    # Only the call that constructs the wrapper passes
    # `reset_disk_byte_usage=True` to `watch_graph`; later calls pass False.
    reset_disk_byte_usage = False
    if self._session_wrapper is None:
      self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
          run_context.session,
          self._session_root,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)
      reset_disk_byte_usage = True

    self._session_wrapper.increment_run_call_count()

    # pylint: disable=protected-access
    debug_urls, watch_options = self._session_wrapper._prepare_run_watch_config(
        run_context.original_args.fetches, run_context.original_args.feed_dict)
    # pylint: enable=protected-access
    run_options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        run_options,
        run_context.session.graph,
        debug_urls=debug_urls,
        debug_ops=watch_options.debug_ops,
        node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
        op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=(
            watch_options.tolerate_debug_op_creation_failures),
        reset_disk_byte_usage=reset_disk_byte_usage)

    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=run_options)

  def after_run(self, run_context, run_values):
    pass
class GrpcDebugHook(session_run_hook.SessionRunHook):
  """A hook that streams debugger-related events to any grpc_debug_server.

  For example, the debugger data server is a grpc_debug_server. The debugger
  data server writes debugger-related events it receives via GRPC to logdir.
  This enables debugging features in Tensorboard such as health pills.

  When the arguments of debug_utils.watch_graph changes, strongly consider
  changing arguments here too so that features are available to tflearn users.

  Can be used as a hook for `tf.compat.v1.train.MonitoredSession`s and
  `tf.estimator.Estimator`s.
  """

  def __init__(self,
               grpc_debug_server_addresses,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Constructs a GrpcDebugHook.

    Args:
      grpc_debug_server_addresses: (`list` of `str`) A list of the gRPC debug
        server addresses, in the format of <host:port>, with or without the
        "grpc://" prefix. For example: ["localhost:7000", "192.168.0.2:8000"].
        A `tuple` of addresses is converted to a `list`; any other single
        value (e.g. one address `str`) is wrapped in a one-element `list`.
      watch_fn: A function that allows for customizing which ops to watch at
        which specific steps. See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__` for details.
        May be `None`, in which case the wrapper session's default watch
        options are used.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (bool) Whether usage is to be logged.
    """
    # Constructed lazily in `before_run()`, once `run_context.session` exists.
    self._grpc_debug_wrapper_session = None
    self._thread_name_filter = thread_name_filter
    addresses = grpc_debug_server_addresses
    if isinstance(addresses, tuple):
      addresses = list(addresses)
    elif not isinstance(addresses, list):
      addresses = [addresses]
    self._grpc_debug_server_addresses = addresses

    self._watch_fn = watch_fn
    self._log_usage = log_usage

  def before_run(self, run_context):
    """Called right before a session is run.

    Args:
      run_context: A session_run_hook.SessionRunContext. Encapsulates
        information on the run.

    Returns:
      A session_run_hook.SessionRunArgs object whose `options` carry the debug
      watches for this run.
    """
    if self._grpc_debug_wrapper_session is None:
      self._grpc_debug_wrapper_session = grpc_wrapper.GrpcDebugWrapperSession(
          run_context.session,
          self._grpc_debug_server_addresses,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)

    # Let the wrapper session resolve the debug URLs and the watch options.
    # Calling `self._watch_fn` directly would raise a TypeError for the default
    # `watch_fn=None`; the wrapper falls back to default watch options then.
    # pylint: disable=protected-access
    debug_urls, watch_options = (
        self._grpc_debug_wrapper_session._prepare_run_watch_config(
            run_context.original_args.fetches,
            run_context.original_args.feed_dict))
    # pylint: enable=protected-access
    run_options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        run_options,
        run_context.session.graph,
        debug_urls=debug_urls,
        debug_ops=watch_options.debug_ops,
        node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
        op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=(
            watch_options.tolerate_debug_op_creation_failures))

    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=run_options)
class TensorBoardDebugHook(GrpcDebugHook):
  """A tfdbg hook that can be used with TensorBoard Debugger Plugin.

  This hook is the same as `GrpcDebugHook`, except that it uses a predefined
  `watch_fn` that
    1) uses `DebugIdentity` debug ops with the `gated_grpc` attribute set to
       `True`, to allow the interactive enabling and disabling of tensor
       breakpoints.
    2) watches all tensors in the graph.
  This saves the need for the user to define a `watch_fn`.
  """

  def __init__(self,
               grpc_debug_server_addresses,
               thread_name_filter=None,
               send_traceback_and_source_code=True,
               log_usage=True):
    """Constructor of TensorBoardDebugHook.

    Args:
      grpc_debug_server_addresses: gRPC address(es) of debug server(s), as a
        `str` or a `list` of `str`s. E.g., "localhost:2333",
        "grpc://localhost:2333", ["192.168.0.7:2333", "192.168.0.8:2333"].
      thread_name_filter: Optional filter for thread names.
      send_traceback_and_source_code: Whether traceback of graph elements and
        the source code are to be sent to the debug server(s).
      log_usage: Whether the usage of this class is to be logged (if
        applicable).
    """

    def _gated_grpc_watch_fn(fetches, feeds):
      del fetches, feeds  # Unused.
      return framework.WatchOptions(
          debug_ops=["DebugIdentity(gated_grpc=true)"])

    # The parent constructor stores the server addresses normalized to a list;
    # that value is used both for the wrapper session and for publishing
    # tracebacks, so it is intentionally not reassigned here.
    super(TensorBoardDebugHook, self).__init__(
        grpc_debug_server_addresses,
        watch_fn=_gated_grpc_watch_fn,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)

    self._send_traceback_and_source_code = send_traceback_and_source_code
    # Graph version returned by the last `publish_traceback()` call; starts at
    # -1 before anything has been published.
    self._sent_graph_version = -1
    grpc_wrapper.register_signal_handler()

  def before_run(self, run_context):
    """Optionally publishes graph tracebacks, then defers to `GrpcDebugHook`.

    Args:
      run_context: A `session_run_hook.SessionRunContext` for the upcoming run.

    Returns:
      The `session_run_hook.SessionRunArgs` produced by
      `GrpcDebugHook.before_run()`.
    """
    if self._send_traceback_and_source_code:
      self._sent_graph_version = grpc_wrapper.publish_traceback(
          self._grpc_debug_server_addresses, run_context.session.graph,
          run_context.original_args.feed_dict,
          run_context.original_args.fetches, self._sent_graph_version)
    return super(TensorBoardDebugHook, self).before_run(run_context)