Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/wrappers/grpc_wrapper.py: 30%

54 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

"""Debugger wrapper sessions that send debug data to gRPC (grpc://) URLs."""

16import signal 

17import sys 

18import traceback 

19 

20# Google-internal import(s). 

21from tensorflow.python.debug.lib import common 

22from tensorflow.python.debug.wrappers import framework 

23 

24 

def publish_traceback(debug_server_urls,
                      graph,
                      feed_dict,
                      fetches,
                      old_graph_version):
  """Sends graph tracebacks and source code when the graph has changed.

  The current `graph.version` is compared against `old_graph_version`. Only if
  the current version is strictly greater (i.e., the graph is newer than what
  was last published) are the graph's tracebacks and the associated source
  files sent to the debug server(s) at `debug_server_urls`.

  Args:
    debug_server_urls: One gRPC debug server URL (`str`) or a `list` of such
      URLs.
    graph: The Python `tf.Graph` whose tracebacks are to be published.
    feed_dict: The feed dictionary passed to the `Session.run()` call.
    fetches: The fetches passed to the `Session.run()` call.
    old_graph_version: The graph version that was published previously.

  Returns:
    `graph.version` (an `int`) if it is greater than `old_graph_version`;
    otherwise `old_graph_version`, unchanged.
  """
  # TODO(cais): Consider moving this back to the top, after grpc becomes a
  # pip dependency of tensorflow or tf_debug.
  # pylint:disable=g-import-not-at-top
  from tensorflow.python.debug.lib import source_remote
  # pylint:enable=g-import-not-at-top

  # Guard clause: nothing newer than what was already sent, so skip the
  # (potentially expensive) network transfer entirely.
  if graph.version <= old_graph_version:
    return old_graph_version

  run_key = common.get_run_key(feed_dict, fetches)
  # The Python call stack at this point is shipped together with the graph so
  # the server can relate graph elements back to the user's source code.
  source_remote.send_graph_tracebacks(
      debug_server_urls, run_key, traceback.extract_stack(), graph,
      send_source=True)
  return graph.version

61 

62 

class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
  """Debug Session wrapper that sends debug data to gRPC stream(s)."""

  def __init__(self,
               sess,
               grpc_debug_server_addresses,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Constructor of GrpcDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      grpc_debug_server_addresses: (`str` or `list` of `str`) Single or a list
        of the gRPC debug server addresses, in the format of
        <host:port>, with or without the "grpc://" prefix. For example:
          "localhost:7000",
          ["localhost:7000", "192.168.0.2:8000"]
      watch_fn: (`Callable`) A Callable that can be used to define per-run
        debug ops and watched tensors. See the doc of
        `NonInteractiveDebugWrapperSession.__init__()` for details.
      thread_name_filter: Regular-expression allowlist for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession`
        for more details.
      log_usage: (`bool`) whether the usage of this class is to be logged.

    Raises:
      TypeError: If `grpc_debug_server_addresses` is not a `str` or a `list`
        of `str`.
    """

    if log_usage:
      pass  # No logging for open-source.

    framework.NonInteractiveDebugWrapperSession.__init__(
        self, sess, watch_fn=watch_fn, thread_name_filter=thread_name_filter)

    # Validate the argument first, then normalize in a single pass.
    if isinstance(grpc_debug_server_addresses, str):
      addresses = [grpc_debug_server_addresses]
    elif isinstance(grpc_debug_server_addresses, list):
      for address in grpc_debug_server_addresses:
        if not isinstance(address, str):
          raise TypeError(
              "Expected type str in list grpc_debug_server_addresses, "
              "received type %s" % type(address))
      addresses = grpc_debug_server_addresses
    else:
      raise TypeError(
          "Expected type str or list in grpc_debug_server_addresses, "
          "received type %s" % type(grpc_debug_server_addresses))

    # Always a `list` of URLs, each carrying the gRPC URL prefix.
    self._grpc_debug_server_urls = [
        self._normalize_grpc_url(address) for address in addresses]

  def prepare_run_debug_urls(self, fetches, feed_dict):
    """Implementation of abstract method in superclass.

    See doc of `NonInteractiveDebugWrapperSession.prepare_run_debug_urls()`
    for details.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`
      feed_dict: Same as the `feed_dict` argument to `Session.run()`

    Returns:
      debug_urls: (`list` of `str`) The gRPC (grpc://) debug server URLs
        given to the constructor, to be used in this `Session.run()` call.
        The same list object is returned for every call, regardless of
        `fetches` and `feed_dict`.
    """

    return self._grpc_debug_server_urls

  def _normalize_grpc_url(self, address):
    """Returns `address` with the gRPC URL prefix, adding it if missing.

    Args:
      address: (`str`) A debug server address, e.g. "localhost:7000" or one
        that already starts with `common.GRPC_URL_PREFIX`.

    Returns:
      `address` unchanged if it already starts with `common.GRPC_URL_PREFIX`;
      otherwise `common.GRPC_URL_PREFIX + address`.
    """
    return (common.GRPC_URL_PREFIX + address
            if not address.startswith(common.GRPC_URL_PREFIX) else address)

136 

137 

def _signal_handler(unused_signal, unused_frame):
  """SIGINT handler that asks the user whether to quit the program.

  Prompts on stdin until a recognized answer is given. The answer is compared
  case-insensitively after stripping whitespace:
    - "" (just Enter), "y" or "yes": exit the process with status 0.
    - "n" or "no": return, letting the interrupted program continue.
    - anything else: prompt again.
  If stdin has no more input (EOF, e.g. when it is closed or redirected), the
  prompt's default answer ("Y") is assumed. Without this, the `EOFError` from
  `input()` would propagate into whatever code the signal interrupted.

  Args:
    unused_signal: The signal number (unused).
    unused_frame: The interrupted stack frame (unused).
  """
  while True:
    try:
      response = input("\nSIGINT received. Quit program? (Y/n): ")
    except EOFError:
      # No interactive input is available; fall back to the default answer.
      response = ""
    response = response.strip().lower()
    if response in ("", "y", "yes"):
      sys.exit(0)
    elif response in ("n", "no"):
      break

145 

146 

def register_signal_handler():
  """Installs `_signal_handler` as the process-wide SIGINT handler.

  `signal.signal()` raises `ValueError` when it is not called from the main
  thread. In that case registration is skipped without error, and whatever
  SIGINT handler was installed before stays in effect.
  """
  try:
    signal.signal(signal.SIGINT, _signal_handler)
  except ValueError:
    # Not on the main thread: best-effort, so leave the existing handler.
    return

153 

154 

class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
  """Session wrapper meant to be used with the TensorBoard Debugger Plugin.

  Behaves like `GrpcDebugWrapperSession`, except that it supplies its own
  `watch_fn`, so callers do not have to write one. That `watch_fn`:
    1) attaches `DebugIdentity` debug ops whose `gated_grpc` attribute is
       `True`, which allows tensor breakpoints to be enabled and disabled
       interactively, and
    2) watches all tensors in the graph.
  """

  def __init__(self,
               sess,
               grpc_debug_server_addresses,
               thread_name_filter=None,
               send_traceback_and_source_code=True,
               log_usage=True):
    """Creates a TensorBoardDebugWrapperSession.

    Construction also installs this module's SIGINT handler via
    `register_signal_handler()`.

    Args:
      sess: The `tf.compat.v1.Session` instance to wrap.
      grpc_debug_server_addresses: Address(es) of the gRPC debug server(s),
        either one `str` or a `list` of `str`s. E.g., "localhost:2333",
        "grpc://localhost:2333", ["192.168.0.7:2333", "192.168.0.8:2333"].
      thread_name_filter: Optional filter on thread names.
      send_traceback_and_source_code: If true, the tracebacks of graph
        elements and the related source code are sent to the debug server(s)
        whenever `run()` sees a newer graph version.
      log_usage: Whether usage of this class should be logged (where
        applicable).
    """

    def _gated_grpc_watch_fn(fetches, feeds):
      # The same gated DebugIdentity op is used regardless of the run inputs.
      del fetches, feeds  # Unused.
      return framework.WatchOptions(
          debug_ops=["DebugIdentity(gated_grpc=true)"])

    super().__init__(
        sess,
        grpc_debug_server_addresses,
        watch_fn=_gated_grpc_watch_fn,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)

    self._send_traceback_and_source_code = send_traceback_and_source_code
    # Version of the newest Python graph already published to the debug
    # servers; -1 means nothing has been published yet.
    self._sent_graph_version = -1

    register_signal_handler()

  def run(self,
          fetches,
          feed_dict=None,
          options=None,
          run_metadata=None,
          callable_runner=None,
          callable_runner_args=None,
          callable_options=None):
    """Runs the wrapped session, publishing graph tracebacks first if enabled.

    When `send_traceback_and_source_code` was set at construction,
    `publish_traceback()` is called before the run and the graph version it
    returns is remembered. All arguments are then forwarded unchanged to the
    parent class's `run()`, whose result is returned.
    """
    if self._send_traceback_and_source_code:
      self._sent_graph_version = publish_traceback(
          self._grpc_debug_server_urls, self.graph, feed_dict, fetches,
          self._sent_graph_version)

    forwarded_kwargs = dict(
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata,
        callable_runner=callable_runner,
        callable_runner_args=callable_runner_args,
        callable_options=callable_options)
    return super().run(fetches, **forwarded_kwargs)