Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/wrappers/dumping_wrapper.py: 28%

39 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Debugger wrapper session that dumps debug data to file:// URLs.""" 

16import os 

17import threading 

18import time 

19 

20# Google-internal import(s). 

21from tensorflow.core.util import event_pb2 

22from tensorflow.python.debug.lib import debug_data 

23from tensorflow.python.debug.wrappers import framework 

24from tensorflow.python.platform import gfile 

25 

26 

class DumpingDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
  """Debug Session wrapper that dumps debug data to filesystem."""

  def __init__(self,
               sess,
               session_root,
               watch_fn=None,
               thread_name_filter=None,
               pass_through_operrors=None,
               log_usage=True):
    """Constructor of DumpingDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      session_root: (`str`) Path to the session root directory. Must be a
        directory that does not exist or an empty directory. A leading `~` is
        expanded to the user's home directory. If the directory does not
        exist, this constructor creates it (via `gfile.MakeDirs`).
        As the `run()` calls occur, subdirectories will be added to
        `session_root`. The subdirectories' names have the following pattern:
          run_<epoch_time_stamp_in_microseconds>_<zero_based_run_counter>
        E.g., run_1480734393835964_0
      watch_fn: (`Callable`) A Callable that can be used to define per-run
        debug ops and watched tensors. See the doc of
        `NonInteractiveDebugWrapperSession.__init__()` for details.
      thread_name_filter: Regular-expression white list for threads on which
        the wrapper session will be active. See doc of
        `BaseDebugWrapperSession` for more details.
      pass_through_operrors: If true, all captured OpErrors will be
        propagated. By default this captures all OpErrors.
      log_usage: (`bool`) whether the usage of this class is to be logged.

    Raises:
      ValueError: If `session_root` is an existing and non-empty directory or
        if `session_root` is a file.
    """

    if log_usage:
      pass  # No logging for open-source.

    framework.NonInteractiveDebugWrapperSession.__init__(
        self, sess, watch_fn=watch_fn, thread_name_filter=thread_name_filter,
        pass_through_operrors=pass_through_operrors)

    session_root = os.path.expanduser(session_root)
    if gfile.Exists(session_root):
      if not gfile.IsDirectory(session_root):
        raise ValueError(
            "session_root path points to a file: %s" % session_root)
      elif gfile.ListDirectory(session_root):
        raise ValueError(
            "session_root path points to a non-empty directory: %s" %
            session_root)
    else:
      gfile.MakeDirs(session_root)
    self._session_root = session_root

    # Monotonically increasing per-instance counter appended to each run
    # directory name; guarded by a lock because run() may be called
    # concurrently from multiple threads.
    self._run_counter = 0
    self._run_counter_lock = threading.Lock()

  def prepare_run_debug_urls(self, fetches, feed_dict):
    """Implementation of abstract method in superclass.

    See doc of `NonInteractiveDebugWrapperSession.prepare_run_debug_urls()`
    for details. This implementation creates a run-specific subdirectory under
    self._session_root and stores information regarding run `fetches` and
    `feed_dict.keys()` in the subdirectory.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`
      feed_dict: Same as the `feed_dict` argument to `Session.run()`

    Returns:
      debug_urls: (`list` of `str`) A single-element list holding the file://
        debug URL of the run-specific subdirectory, to be used in this
        `Session.run()` call.
    """

    # The run counter (read and bumped under the lock) keeps directory names
    # distinct across concurrent run() calls on this instance, even when two
    # calls observe the same microsecond timestamp. Using `with` guarantees
    # the lock is released if anything inside raises.
    with self._run_counter_lock:
      run_dir = os.path.join(
          self._session_root,
          "run_%d_%d" % (int(time.time() * 1e6), self._run_counter))
      self._run_counter += 1
    gfile.MkDir(run_dir)

    fetches_path = os.path.join(
        run_dir,
        debug_data.METADATA_FILE_PREFIX + debug_data.FETCHES_INFO_FILE_TAG)
    self._write_log_message_event(fetches_path, repr(fetches))

    # A falsy feed_dict (None or empty) is recorded via its own repr, since
    # None has no .keys().
    feed_keys_message = (repr(feed_dict.keys()) if feed_dict
                         else repr(feed_dict))
    feed_keys_path = os.path.join(
        run_dir,
        debug_data.METADATA_FILE_PREFIX + debug_data.FEED_KEYS_INFO_FILE_TAG)
    self._write_log_message_event(feed_keys_path, feed_keys_message)

    return ["file://" + run_dir]

  @staticmethod
  def _write_log_message_event(path, message):
    """Writes a serialized `Event` whose `log_message.message` is `message`.

    Args:
      path: (`str`) Destination file path; opened in binary write mode.
      message: (`str`) Text stored in the event's log message.
    """
    event = event_pb2.Event()
    event.log_message.message = message
    with gfile.Open(path, "wb") as f:
      f.write(event.SerializeToString())