Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/traceback_utils.py: 20%
82 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities related to Keras exception stack trace prettifying."""
17import inspect
18import os
19import sys
20import traceback
21import types
23import tensorflow.compat.v2 as tf
# Substrings of source-file paths whose frames are hidden from filtered
# tracebacks (matched by plain substring containment in `include_frame`):
#   * the directory two levels above this file (the `keras/src` tree when
#     this module lives at `keras/src/utils/`), and
#   * TensorFlow's internal `tensorflow/python` tree, joined with the
#     platform's path separator.
_EXCLUDED_PATHS = (
    os.path.abspath(os.path.join(__file__, os.pardir, os.pardir)),
    os.path.join("tensorflow", "python"),
)
def include_frame(fname):
    """Decide whether a traceback frame from file `fname` should be shown.

    Args:
        fname: Source file name of the frame's code object.

    Returns:
        `False` if any entry of `_EXCLUDED_PATHS` occurs as a substring of
        `fname`, otherwise `True`.
    """
    return not any(excluded in fname for excluded in _EXCLUDED_PATHS)
def _process_traceback_frames(tb):
    """Iterate through traceback frames and return a new, filtered traceback.

    Walks the chain `tb -> tb.tb_next -> ...` and rebuilds it from only those
    entries whose frame's source file is accepted by `include_frame`. Every
    rebuilt entry copies the original entry's `tb_lasti` and `tb_lineno`, so
    line and position information matches the original traceback.

    Args:
        tb: Traceback to filter (e.g. `exception.__traceback__`), or `None`.

    Returns:
        The head of the rebuilt traceback chain. If no entry survives
        filtering, a single-entry traceback for the last entry of the chain
        (the innermost frame, where the exception was raised) is returned so
        the error location is never lost. Returns `None` if `tb` is `None`.
    """
    entries = []
    while tb is not None:
        entries.append(tb)
        tb = tb.tb_next

    last_tb = None
    # Build the new chain back-to-front, since each TracebackType needs its
    # `tb_next` at construction time.
    for entry in reversed(entries):
        frame = entry.tb_frame
        if include_frame(frame.f_code.co_filename):
            # Use the entry's own `tb_lasti`, not `frame.f_lasti`: the frame
            # may still be running (or may have re-raised) and have moved on
            # to a different instruction since this entry was recorded.
            last_tb = types.TracebackType(
                last_tb, frame, entry.tb_lasti, entry.tb_lineno
            )
    if last_tb is None and entries:
        # If no frames were kept during filtering, create a new traceback
        # from the innermost entry (where the exception was raised).
        innermost = entries[-1]
        last_tb = types.TracebackType(
            None, innermost.tb_frame, innermost.tb_lasti, innermost.tb_lineno
        )
    return last_tb
def filter_traceback(fn):
    """Wrap `fn` so that exceptions it raises carry a filtered traceback.

    While `tf.debugging.is_traceback_filtering_enabled()` is true, any
    `Exception` escaping `fn` is re-raised with the traceback returned by
    `_process_traceback_frames` and with exception chaining suppressed
    (`from None`). While it is false, `fn` is simply called.

    Args:
        fn: The callable to wrap.

    Returns:
        The wrapper built by `tf.__internal__.decorator.make_decorator`, or
        `fn` itself when not running on Python 3.7+ (presumably because
        `types.TracebackType` can only be constructed from Python code
        starting with 3.7).
    """
    version = sys.version_info
    if version.major != 3 or version.minor < 7:
        return fn

    def error_handler(*args, **kwargs):
        if tf.debugging.is_traceback_filtering_enabled():
            rebuilt_tb = None
            try:
                return fn(*args, **kwargs)
            except Exception as exc:
                rebuilt_tb = _process_traceback_frames(exc.__traceback__)
                # To get the full stack trace, call:
                # `tf.debugging.disable_traceback_filtering()`
                raise exc.with_traceback(rebuilt_tb) from None
            finally:
                # Drop this frame's local reference to the rebuilt traceback.
                del rebuilt_tb
        return fn(*args, **kwargs)

    return tf.__internal__.decorator.make_decorator(fn, error_handler)
def inject_argument_info_in_traceback(fn, object_name=None):
    """Add information about call argument values to an error message.

    Arguments:
        fn: Function to wrap. Exceptions raised by this function will be
            re-raised with additional information added to the error message,
            displaying the values of the different arguments that the function
            was called with.
        object_name: String, display name of the class/function being called,
            e.g. `'layer "layer_name" (LayerClass)'`. When falsy,
            `fn.__name__` is used instead.

    Returns:
        A wrapped version of `fn`. The rebuilt exception is marked with a
        `_keras_call_info_injected` attribute so that enclosing wrappers
        re-raise it untouched. If `fn`'s signature cannot be introspected,
        or the received arguments cannot be bound to it, the original
        exception is re-raised unchanged.
    """

    def error_handler(*args, **kwargs):
        signature = None
        bound_signature = None
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            if hasattr(e, "_keras_call_info_injected"):
                # Only inject info for the innermost failing call
                raise e
            try:
                # `inspect.signature` raises ValueError/TypeError for
                # callables it cannot introspect. `bind` raises TypeError
                # when the arguments do not match the signature. Note that
                # `self` (if any) is bound like any other argument.
                signature = inspect.signature(fn)
                bound_signature = signature.bind(*args, **kwargs)
            except (TypeError, ValueError):
                bound_signature = None
            if bound_signature is None:
                # Re-raise the user's error outside the inner `except`, so
                # the introspection failure is not attached to it as
                # `__context__` ("During handling of the above exception...").
                raise e

            # Add argument context
            arguments_context = []
            for arg in signature.parameters.values():
                if arg.name in bound_signature.arguments:
                    value = tf.nest.map_structure(
                        format_argument_value,
                        bound_signature.arguments[arg.name],
                    )
                else:
                    value = arg.default
                arguments_context.append(f" • {arg.name}={value}")

            if arguments_context:
                arguments_context = "\n".join(arguments_context)
                # Get original error message and append information to it.
                if isinstance(e, tf.errors.OpError):
                    message = e.message
                elif e.args:
                    # Canonically, the 1st argument in an exception is the
                    # error message. This works for all built-in Python
                    # exceptions.
                    message = e.args[0]
                else:
                    message = ""
                display_name = f"{object_name if object_name else fn.__name__}"
                message = (
                    f"Exception encountered when calling {display_name}.\n\n"
                    f"{message}\n\n"
                    f"Call arguments received by {display_name}:\n"
                    f"{arguments_context}"
                )

                # Reraise exception, with added context
                if isinstance(e, tf.errors.OpError):
                    new_e = e.__class__(e.node_def, e.op, message, e.error_code)
                else:
                    try:
                        # For standard exceptions such as ValueError,
                        # TypeError, etc.
                        new_e = e.__class__(message)
                    except TypeError:
                        # For any custom error that doesn't have a standard
                        # signature.
                        new_e = RuntimeError(message)
                new_e._keras_call_info_injected = True
            else:
                new_e = e
            raise new_e.with_traceback(e.__traceback__) from None
        finally:
            # Drop references to the bound call arguments held by this frame.
            del signature
            del bound_signature

    return tf.__internal__.decorator.make_decorator(fn, error_handler)
def format_argument_value(value):
    """Return the display string for a single call-argument value.

    `tf.Tensor` instances are summarized by their shape and dtype name (their
    contents are not printed). Any other value is rendered with `repr`.
    """
    if not isinstance(value, tf.Tensor):
        return repr(value)
    # Simplified representation for eager / graph tensors to keep error
    # messages readable.
    shape = value.shape
    dtype_name = value.dtype.name
    return f"tf.Tensor(shape={shape}, dtype={dtype_name})"