Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/traceback_utils.py: 20%

82 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities related to Keras exception stack trace prettifying.""" 

16 

17import inspect 

18import os 

19import sys 

20import traceback 

21import types 

22 

23import tensorflow.compat.v2 as tf 

24 

# Path fragments identifying stack frames to hide. `include_frame` rejects
# any filename that contains one of these entries as a substring.
_EXCLUDED_PATHS = (
    # Absolute path of the parent of the directory that contains this file.
    os.path.abspath(os.path.join(__file__, os.pardir, os.pardir)),
    # Relative fragment; matches any path containing `tensorflow/python`
    # (joined with the platform separator).
    os.path.join("tensorflow", "python"),
)

29 

30 

def include_frame(fname):
    """Tell whether a frame coming from file `fname` should be kept.

    Arguments:
        fname: String, the filename of the frame's code object.

    Returns:
        `False` when any entry of `_EXCLUDED_PATHS` occurs as a substring
        of `fname`, `True` otherwise.
    """
    return not any(exclusion in fname for exclusion in _EXCLUDED_PATHS)

36 

37 

def _process_traceback_frames(tb):
    """Build a new traceback containing only the frames worth showing.

    Arguments:
        tb: Traceback object to filter (may be `None`).

    Returns:
        A freshly constructed traceback chain made of the frames accepted by
        `include_frame`, in their original order. If every frame is rejected,
        a single-entry traceback for the last frame of `tb` is returned
        instead. `None` is returned when `tb` has no frames at all.
    """
    entries = list(traceback.walk_tb(tb))

    # Walk from the last frame back to the first so that each new traceback
    # node can point at the node that follows it.
    selected = [
        (frame, line_no)
        for frame, line_no in reversed(entries)
        if include_frame(frame.f_code.co_filename)
    ]
    if not selected and entries:
        # Filtering removed everything: fall back to the last frame yielded
        # by `walk_tb` so the exception still carries some location.
        selected = [entries[-1]]

    chained = None
    for frame, line_no in selected:
        chained = types.TracebackType(chained, frame, frame.f_lasti, line_no)
    return chained

51 

52 

def filter_traceback(fn):
    """Wrap `fn` so exceptions it raises carry a filtered stack trace.

    Arguments:
        fn: Function to wrap.

    Returns:
        `fn` itself on interpreters other than Python 3.7+; otherwise a
        decorated version of `fn` that, while TensorFlow traceback filtering
        is enabled, re-raises exceptions with the traceback produced by
        `_process_traceback_frames`.
    """
    # NOTE(review): presumably gated on 3.7 because the filtering code
    # instantiates `types.TracebackType` directly -- confirm.
    supported = sys.version_info.major == 3 and sys.version_info.minor >= 7
    if not supported:
        return fn

    def error_handler(*args, **kwargs):
        if not tf.debugging.is_traceback_filtering_enabled():
            # Filtering is switched off: behave exactly like `fn`.
            return fn(*args, **kwargs)

        filtered_tb = None
        try:
            return fn(*args, **kwargs)
        except Exception as exc:
            filtered_tb = _process_traceback_frames(exc.__traceback__)
            # To get the full stack trace, call:
            # `tf.debugging.disable_traceback_filtering()`
            raise exc.with_traceback(filtered_tb) from None
        finally:
            # Drop the local reference to the traceback once we are done.
            del filtered_tb

    return tf.__internal__.decorator.make_decorator(fn, error_handler)

75 

76 

def inject_argument_info_in_traceback(fn, object_name=None):
    """Wrap `fn` so that its error messages list the call argument values.

    Arguments:
        fn: Function to wrap. Exceptions raised by this function are
            re-raised with extra information appended to the error message,
            showing the values of the arguments the function was called
            with.
        object_name: String, display name of the class/function being called,
            e.g. `'layer "layer_name" (LayerClass)'`. When falsy,
            `fn.__name__` is used instead.

    Returns:
        A wrapped version of `fn`.
    """

    def error_handler(*args, **kwargs):
        sig = None
        bound = None
        try:
            return fn(*args, **kwargs)
        except Exception as exc:
            if hasattr(exc, "_keras_call_info_injected"):
                # Only inject info for the innermost failing call
                raise exc

            sig = inspect.signature(fn)
            try:
                bound = sig.bind(*args, **kwargs)
            except TypeError:
                # Likely unbindable arguments
                raise exc

            # One bullet line per parameter declared in the signature.
            described = []
            for param in sig.parameters.values():
                name = param.name
                if name in bound.arguments:
                    shown = tf.nest.map_structure(
                        format_argument_value,
                        bound.arguments[name],
                    )
                else:
                    # Not supplied by the caller: show the declared default.
                    shown = param.default
                described.append(f" • {name}={shown}")

            if not described:
                # No parameters to report: re-raise the exception unchanged.
                raise exc.with_traceback(exc.__traceback__) from None

            arguments_context = "\n".join(described)
            is_op_error = isinstance(exc, tf.errors.OpError)

            # Recover the original error message so it can be extended.
            if is_op_error:
                original_message = exc.message
            elif exc.args:
                # Canonically, the 1st argument in an exception is the error
                # message. This works for all built-in Python exceptions.
                original_message = exc.args[0]
            else:
                original_message = ""

            display_name = f"{object_name or fn.__name__}"
            message = (
                f"Exception encountered when calling {display_name}.\n\n"
                f"{original_message}\n\n"
                f"Call arguments received by {display_name}:\n"
                f"{arguments_context}"
            )

            # Rebuild an exception of the same class around the new message.
            if is_op_error:
                new_exc = exc.__class__(
                    exc.node_def, exc.op, message, exc.error_code
                )
            else:
                try:
                    # For standard exceptions such as ValueError, TypeError,
                    # etc.
                    new_exc = exc.__class__(message)
                except TypeError:
                    # For any custom error that doesn't have a standard
                    # signature.
                    new_exc = RuntimeError(message)
            # Mark the exception so enclosing wrappers leave it untouched.
            new_exc._keras_call_info_injected = True
            raise new_exc.with_traceback(exc.__traceback__) from None
        finally:
            del sig
            del bound

    return tf.__internal__.decorator.make_decorator(fn, error_handler)

161 

162 

def format_argument_value(value):
    """Return a short printable description of a call argument.

    Arguments:
        value: The argument value to describe.

    Returns:
        For a `tf.Tensor`, a string giving only its shape and dtype name;
        for anything else, `repr(value)`.
    """
    if not isinstance(value, tf.Tensor):
        return repr(value)
    # Simplified representation for eager / graph tensors
    # to keep messages readable
    return f"tf.Tensor(shape={value.shape}, dtype={value.dtype.name})"

169