Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/wrappers/hooks.py: 27%

92 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""tfdbg CLI as SessionRunHook.""" 

16 

17from tensorflow.core.protobuf import config_pb2 

18from tensorflow.python.debug.lib import debug_utils 

19from tensorflow.python.debug.wrappers import dumping_wrapper 

20from tensorflow.python.debug.wrappers import framework 

21from tensorflow.python.debug.wrappers import grpc_wrapper 

22from tensorflow.python.debug.wrappers import local_cli_wrapper 

23from tensorflow.python.training import session_run_hook 

24 

25 

class LocalCLIDebugHook(session_run_hook.SessionRunHook):
  """Command-line-interface debugger hook.

  Can be used as a hook for `tf.compat.v1.train.MonitoredSession`s and
  `tf.estimator.Estimator`s. Provides a substitute for
  `tfdbg.LocalCLIDebugWrapperSession` in cases where the session is not directly
  available.
  """

  def __init__(self,
               ui_type="curses",
               dump_root=None,
               thread_name_filter=None,
               config_file_path=None):
    """Create a local debugger command-line interface (CLI) hook.

    Args:
      ui_type: (`str`) requested user-interface type. Currently supported:
        (curses | readline).
      dump_root: (`str`) optional path to the dump root directory. Must be a
        directory that does not exist or an empty directory. If the directory
        does not exist, it will be created by the debugger core during debug
        `run()` calls and removed afterwards.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      config_file_path: Optional override to the default configuration file
        path, which is at `${HOME}/.tfdbg_config`.
    """
    self._ui_type = ui_type
    self._dump_root = dump_root
    self._thread_name_filter = thread_name_filter
    self._config_file_path = config_file_path
    # The LocalCLIDebugWrapperSession is created lazily in the first
    # `before_run()` call, because the session it wraps is taken from the
    # `run_context` passed to that method.
    self._session_wrapper = None
    # Filters added through `add_tensor_filter()` before the wrapper exists.
    # They are forwarded to the wrapper right after it is constructed.
    self._pending_tensor_filters = {}
    # Action returned by the wrapper's `on_run_start()` for the current run.
    # It is set in `before_run()` and consumed by `after_run()`.
    self._performed_action = None

  def add_tensor_filter(self, filter_name, tensor_filter):
    """Add a tensor filter.

    See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    Overrides the default behavior to accommodate the possibility of this
    method being called prior to the initialization of the underlying
    `LocalCLIDebugWrapperSession` object. In that case the filter is stored and
    registered with the wrapper during the first `before_run()` call.

    Args:
      filter_name: See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()`
        for details.
      tensor_filter: See doc of
        `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    """
    if self._session_wrapper is not None:
      self._session_wrapper.add_tensor_filter(filter_name, tensor_filter)
    else:
      self._pending_tensor_filters[filter_name] = tensor_filter

  def begin(self):
    pass

  def before_run(self, run_context):
    """Prepare the debug/profile options for the upcoming `Session.run()`.

    On the first call, this constructs the wrapper session and registers any
    pending tensor filters. On every call, it increments the wrapper's
    run-call counter and asks the wrapper's `on_run_start()` which action to
    take. It then decorates the returned `RunOptions` for a debug run or a
    profile run accordingly.

    Args:
      run_context: A `session_run_hook.SessionRunContext`.

    Returns:
      A `session_run_hook.SessionRunArgs` carrying only `RunOptions`; no
      extra fetches or feeds are requested.
    """
    if self._session_wrapper is None:
      self._session_wrapper = local_cli_wrapper.LocalCLIDebugWrapperSession(
          run_context.session,
          ui_type=self._ui_type,
          dump_root=self._dump_root,
          thread_name_filter=self._thread_name_filter,
          config_file_path=self._config_file_path)

      # Actually register tensor filters that were added prior to the
      # construction of the underlying LocalCLIDebugWrapperSession object.
      for filter_name, tensor_filter in self._pending_tensor_filters.items():
        self._session_wrapper.add_tensor_filter(filter_name, tensor_filter)

    # Increment run call counter.
    self._session_wrapper.increment_run_call_count()

    # Adapt run_context to an instance of OnRunStartRequest for invoking
    # superclass on_run_start().
    on_run_start_request = framework.OnRunStartRequest(
        run_context.original_args.fetches, run_context.original_args.feed_dict,
        None, None, self._session_wrapper.run_call_count)

    on_run_start_response = self._session_wrapper.on_run_start(
        on_run_start_request)
    self._performed_action = on_run_start_response.action

    run_args = session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=config_pb2.RunOptions())
    if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
      # pylint: disable=protected-access
      self._session_wrapper._decorate_run_options_for_debug(
          run_args.options,
          on_run_start_response.debug_urls,
          debug_ops=on_run_start_response.debug_ops,
          node_name_regex_allowlist=(
              on_run_start_response.node_name_regex_allowlist),
          op_type_regex_allowlist=(
              on_run_start_response.op_type_regex_allowlist),
          tensor_dtype_regex_allowlist=(
              on_run_start_response.tensor_dtype_regex_allowlist),
          tolerate_debug_op_creation_failures=(
              on_run_start_response.tolerate_debug_op_creation_failures))
      # pylint: enable=protected-access
    elif self._performed_action == framework.OnRunStartAction.PROFILE_RUN:
      # pylint: disable=protected-access
      self._session_wrapper._decorate_run_options_for_profile(run_args.options)
      # pylint: enable=protected-access

    return run_args

  def after_run(self, run_context, run_values):
    """Report the completed run to the wrapper session.

    Forwards the action chosen in `before_run()` together with the run
    metadata to the wrapper's `on_run_end()`.

    Args:
      run_context: A `session_run_hook.SessionRunContext` (unused).
      run_values: A `session_run_hook.SessionRunValues`; its `run_metadata`
        is forwarded.
    """
    # Adapt run_context and run_values to OnRunEndRequest and invoke superclass
    # on_run_end()
    on_run_end_request = framework.OnRunEndRequest(self._performed_action,
                                                   run_values.run_metadata)
    self._session_wrapper.on_run_end(on_run_end_request)

145 

146 

class DumpingDebugHook(session_run_hook.SessionRunHook):
  """A debugger hook that dumps debug data to filesystem.

  Can be used as a hook for `tf.compat.v1.train.MonitoredSession`s and
  `tf.estimator.Estimator`s.
  """

  def __init__(self,
               session_root,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Create a debugger hook that dumps debug data to the filesystem.

    Args:
      session_root: See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__`.
      watch_fn: See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__`.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (bool) Whether usage is to be logged.
    """
    self._session_root = session_root
    self._watch_fn = watch_fn
    self._thread_name_filter = thread_name_filter
    self._log_usage = log_usage
    # Created lazily in the first `before_run()` call, from the session
    # exposed by that call's `run_context`.
    self._session_wrapper = None

  def begin(self):
    pass

  def before_run(self, run_context):
    """Build `RunOptions` that watch the graph for the upcoming run.

    On the first call, this constructs the `DumpingDebugWrapperSession`. On
    every call, it increments the wrapper's run-call counter, obtains the debug
    URLs and watch options from the wrapper, and applies them to a fresh
    `RunOptions` via `debug_utils.watch_graph`.

    Args:
      run_context: A `session_run_hook.SessionRunContext`.

    Returns:
      A `session_run_hook.SessionRunArgs` carrying only `RunOptions`.
    """
    # Passed to watch_graph() as True only on the call that created the
    # wrapper, and False on every later call.
    reset_disk_byte_usage = False
    if not self._session_wrapper:
      self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
          run_context.session,
          self._session_root,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)
      reset_disk_byte_usage = True

    self._session_wrapper.increment_run_call_count()

    # pylint: disable=protected-access
    debug_urls, watch_options = self._session_wrapper._prepare_run_watch_config(
        run_context.original_args.fetches, run_context.original_args.feed_dict)
    # pylint: enable=protected-access
    run_options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        run_options,
        run_context.session.graph,
        debug_urls=debug_urls,
        debug_ops=watch_options.debug_ops,
        node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
        op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=(
            watch_options.tolerate_debug_op_creation_failures),
        reset_disk_byte_usage=reset_disk_byte_usage)

    run_args = session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=run_options)
    return run_args

  def after_run(self, run_context, run_values):
    pass

217 

218 

class GrpcDebugHook(session_run_hook.SessionRunHook):
  """A hook that streams debugger-related events to any grpc_debug_server.

  For example, the debugger data server is a grpc_debug_server. The debugger
  data server writes debugger-related events it receives via GRPC to logdir.
  This enables debugging features in Tensorboard such as health pills.

  When the arguments of debug_utils.watch_graph changes, strongly consider
  changing arguments here too so that features are available to tflearn users.

  Can be used as a hook for `tf.compat.v1.train.MonitoredSession`s and
  `tf.estimator.Estimator`s.
  """

  def __init__(self,
               grpc_debug_server_addresses,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Constructs a GrpcDebugHook.

    Args:
      grpc_debug_server_addresses: (`list` of `str`) A list of the gRPC debug
        server addresses, in the format of <host:port>, with or without the
        "grpc://" prefix. For example: ["localhost:7000", "192.168.0.2:8000"].
        A single address `str` or a `tuple` of addresses is also accepted and
        normalized to a `list`.
      watch_fn: A function that allows for customizing which ops to watch at
        which specific steps. See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__` for details.
        If `None`, default `framework.WatchOptions()` are used for every run.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (bool) Whether usage is to be logged.
    """
    # Created lazily in the first `before_run()` call.
    self._grpc_debug_wrapper_session = None
    self._thread_name_filter = thread_name_filter
    if isinstance(grpc_debug_server_addresses, (list, tuple)):
      self._grpc_debug_server_addresses = list(grpc_debug_server_addresses)
    else:
      self._grpc_debug_server_addresses = [grpc_debug_server_addresses]

    self._watch_fn = watch_fn
    self._log_usage = log_usage

  def before_run(self, run_context):
    """Called right before a session is run.

    Args:
      run_context: A session_run_hook.SessionRunContext. Encapsulates
        information on the run.

    Returns:
      A session_run_hook.SessionRunArgs object.
    """
    if not self._grpc_debug_wrapper_session:
      self._grpc_debug_wrapper_session = grpc_wrapper.GrpcDebugWrapperSession(
          run_context.session,
          self._grpc_debug_server_addresses,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)

    fetches = run_context.original_args.fetches
    feed_dict = run_context.original_args.feed_dict
    if self._watch_fn is None:
      # `watch_fn` is optional in the constructor. Fall back to the default
      # watch options instead of attempting to call `None`.
      watch_options = framework.WatchOptions()
    else:
      watch_options = self._watch_fn(fetches, feed_dict)
    run_options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        run_options,
        run_context.session.graph,
        debug_urls=self._grpc_debug_wrapper_session.prepare_run_debug_urls(
            fetches, feed_dict),
        debug_ops=watch_options.debug_ops,
        node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
        op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=(
            watch_options.tolerate_debug_op_creation_failures))

    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=run_options)

299 

300 

class TensorBoardDebugHook(GrpcDebugHook):
  """A tfdbg hook that can be used with TensorBoard Debugger Plugin.

  This hook is the same as `GrpcDebugHook`, except that it uses a predefined
  `watch_fn` that
    1) uses `DebugIdentity` debug ops with the `gated_grpc` attribute set to
        `True`, to allow the interactive enabling and disabling of tensor
       breakpoints.
    2) watches all tensors in the graph.
  This saves the need for the user to define a `watch_fn`.
  """

  def __init__(self,
               grpc_debug_server_addresses,
               thread_name_filter=None,
               send_traceback_and_source_code=True,
               log_usage=True):
    """Constructor of TensorBoardDebugHook.

    Args:
      grpc_debug_server_addresses: gRPC address(es) of debug server(s), as a
        `str` or a `list` of `str`s. E.g., "localhost:2333",
        "grpc://localhost:2333", ["192.168.0.7:2333", "192.168.0.8:2333"].
      thread_name_filter: Optional filter for thread names.
      send_traceback_and_source_code: Whether traceback of graph elements and
        the source code are to be sent to the debug server(s).
      log_usage: Whether the usage of this class is to be logged (if
        applicable).
    """

    def _gated_grpc_watch_fn(fetches, feeds):
      del fetches, feeds  # Unused.
      return framework.WatchOptions(
          debug_ops=["DebugIdentity(gated_grpc=true)"])

    # The parent constructor stores the addresses normalized to a list. That
    # list is used both for the wrapper session and for publish_traceback(),
    # so it is intentionally not overwritten with the raw argument here.
    super(TensorBoardDebugHook, self).__init__(
        grpc_debug_server_addresses,
        watch_fn=_gated_grpc_watch_fn,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)

    self._send_traceback_and_source_code = send_traceback_and_source_code
    # Graph version last reported by publish_traceback(). It starts at -1,
    # presumably so the first call treats any graph as not yet sent; the
    # comparison itself lives in grpc_wrapper -- confirm there.
    self._sent_graph_version = -1
    grpc_wrapper.register_signal_handler()

  def before_run(self, run_context):
    """Optionally publish tracebacks, then defer to `GrpcDebugHook`.

    If `send_traceback_and_source_code` was enabled, this calls
    `grpc_wrapper.publish_traceback` with the session graph, feeds, fetches
    and the previously recorded graph version. It stores the version that call
    returns.

    Args:
      run_context: A `session_run_hook.SessionRunContext`.

    Returns:
      The `session_run_hook.SessionRunArgs` produced by
      `GrpcDebugHook.before_run`.
    """
    if self._send_traceback_and_source_code:
      self._sent_graph_version = grpc_wrapper.publish_traceback(
          self._grpc_debug_server_addresses, run_context.session.graph,
          run_context.original_args.feed_dict,
          run_context.original_args.fetches, self._sent_graph_version)
    return super(TensorBoardDebugHook, self).before_run(run_context)