1"""
2My own variation on function-specific inspect-like features.
3"""
4
5# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
6# Copyright (c) 2009 Gael Varoquaux
7# License: BSD Style, 3 clauses.
8
9import collections
10import inspect
11import os
12import re
13import warnings
14from itertools import islice
15from tokenize import open as open_py_source
16
17from .logger import pformat
18
# Space-separated field names of the ``FullArgSpec`` record defined below.
full_argspec_fields = " ".join(
    [
        "args",
        "varargs",
        "varkw",
        "defaults",
        "kwonlyargs",
        "kwonlydefaults",
        "annotations",
    ]
)
# Lightweight record type built from the field names above.
full_argspec_type = collections.namedtuple("FullArgSpec", full_argspec_fields)
23
24
def get_func_code(func):
    """Attempts to retrieve a reliable function code hash.

    The reason we don't use inspect.getsource is that it caches the
    source, whereas we want this to be modified on the fly when the
    function is modified.

    Parameters
    ----------
    func: callable
        The function (or other callable object) to inspect.

    Returns
    -------
    func_code: string
        The function code. When the source cannot be retrieved this is
        instead the string of the code object's hash, or ``repr(func)``
        for objects without a ``__code__`` attribute.
    source_file: string or None
        The path to the file in which the function is defined. None when
        ``func`` has no ``__code__`` attribute.
    first_line: int
        The first line of the code in the source file, or -1 when the
        source could not be retrieved.

    Notes
    ------
    This function does a bit more magic than inspect, and is thus
    more robust.
    """
    source_file = None
    try:
        code = func.__code__
        source_file = code.co_filename
        if not os.path.exists(source_file):
            # Use inspect for lambda functions and functions defined in an
            # interactive shell, or in doctests
            source_code = "".join(inspect.getsourcelines(func)[0])
            line_no = 1
            if source_file.startswith("<doctest "):
                # A doctest name that does not match the pattern makes
                # re.match return None; the resulting AttributeError is
                # handled by the fallback below.
                source_file, line_no = re.match(
                    r"\<doctest (.*\.rst)\[(.*)\]\>", source_file
                ).groups()
                line_no = int(line_no)
                source_file = "<doctest %s>" % source_file
            return source_code, source_file, line_no
        # Try to retrieve the source code.
        with open_py_source(source_file) as source_file_obj:
            first_line = code.co_firstlineno
            # All the lines after the function definition:
            source_lines = list(islice(source_file_obj, first_line - 1, None))
        return "".join(inspect.getblock(source_lines)), source_file, first_line
    except Exception:
        # Best-effort fallback for any ordinary failure. KeyboardInterrupt
        # and SystemExit are deliberately not intercepted: they must reach
        # the caller instead of being turned into a hash.
        # If the source code fails, we use the hash. This is fragile and
        # might change from one session to another.
        if hasattr(func, "__code__"):
            # Python 3.X
            return str(func.__code__.__hash__()), source_file, -1
        else:
            # Weird objects like numpy ufunc don't have __code__
            # This is fragile, as quite often the id of the object is
            # in the repr, so it might not persist across sessions,
            # however it will work for ufuncs.
            return repr(func), source_file, -1
80
81
82def _clean_win_chars(string):
83 """Windows cannot encode some characters in filename."""
84 import urllib
85
86 if hasattr(urllib, "quote"):
87 quote = urllib.quote
88 else:
89 # In Python 3, quote is elsewhere
90 import urllib.parse
91
92 quote = urllib.parse.quote
93 for char in ("<", ">", "!", ":", "\\"):
94 string = string.replace(char, quote(char))
95 return string
96
97
def get_func_name(func, resolv_alias=True, win_characters=True):
    """Return the function import path (as a list of module names), and
    a name for the function.

    Parameters
    ----------
    func: callable
        The func to inspect
    resolv_alias: boolean, optional
        If true, possible local aliases are indicated.
    win_characters: boolean, optional
        If true, substitute special characters using urllib.quote
        This is useful in Windows, as it cannot encode some filenames

    Returns
    -------
    module: list of strings
        The dotted module path split into its components, possibly
        extended with a mangled file name (for ``__main__``) and with the
        enclosing scopes taken from ``__qualname__``.
    name: string
        The name of the function, or "unknown" if none can be found.
    """
    _absent = object()

    # --- Which module does the callable report? ---
    module = getattr(func, "__module__", _absent)
    if module is _absent:
        try:
            module = inspect.getmodule(func)
        except TypeError:
            owner = getattr(func, "__class__", _absent)
            module = "unknown" if owner is _absent else owner.__module__
    if module is None:
        # Happens in doctests, eg
        module = ""

    # --- Functions living in __main__ are disambiguated by their file ---
    if module == "__main__":
        try:
            source_path = os.path.abspath(inspect.getsourcefile(func))
        except:  # noqa: E722
            source_path = None
        if source_path is not None:
            # Mangle the full path into a single dash-separated token.
            pieces = source_path.split(os.sep)
            if pieces[-1].startswith("<ipython-input"):
                # IPython / notebook session: the last component comes from
                # func.__code__.co_filename and looks like
                # <ipython-input-N-XYZ>, with N the number of the cell that
                # defined the function and XYZ a hash of its code and name.
                # XYZ is stable across sessions and kernel restarts and
                # changes with the code/name, so it prevents collisions
                # between same-named functions within and across notebooks.
                # N is dropped so that re-executing the defining cell still
                # hits the cache.
                tokens = pieces[-1].split("-")
                pieces[-1] = "-".join(tokens[:2] + tokens[3:])
            elif len(pieces) > 2 and pieces[-2].startswith("ipykernel_"):
                # Notebook session (ipykernel): the directory is named
                # ipykernel_XXXXXX, where XXXXXX seems to identify the
                # current run. Stripping it keeps the identifier the same
                # across runs.
                pieces[-2] = "ipykernel"
            source_path = "-".join(pieces)
            if source_path.endswith(".py"):
                source_path = source_path[:-3]
            module = module + "-" + source_path
    module = module.split(".")

    # --- Name of the callable: first attribute that exists wins ---
    name = "unknown"
    for attribute in ("func_name", "__name__"):
        if hasattr(func, attribute):
            name = getattr(func, attribute)
            break

    # Hack to detect functions not defined at the module-level
    # TODO: Maybe add a warning here?
    if resolv_alias and hasattr(func, "func_globals"):
        if name in func.func_globals and func.func_globals[name] is not func:
            name = "%s-alias" % name

    qualified = getattr(func, "__qualname__", _absent)
    if qualified is not _absent and qualified != name:
        # Extend the module name in case of nested functions to avoid
        # (module, name) collisions
        module.extend(qualified.split(".")[:-1])

    if inspect.ismethod(func) and hasattr(func, "im_class"):
        # We need to add the name of the class
        module.append(func.im_class.__name__)

    if os.name == "nt" and win_characters:
        # Windows can't encode certain characters in filenames
        name = _clean_win_chars(name)
        module = [_clean_win_chars(component) for component in module]
    return module, name
186
187
188def _signature_str(function_name, arg_sig):
189 """Helper function to output a function signature"""
190 return "{}{}".format(function_name, arg_sig)
191
192
193def _function_called_str(function_name, args, kwargs):
194 """Helper function to output a function call"""
195 template_str = "{0}({1}, {2})"
196
197 args_str = repr(args)[1:-1]
198 kwargs_str = ", ".join("%s=%s" % (k, v) for k, v in kwargs.items())
199 return template_str.format(function_name, args_str, kwargs_str)
200
201
def filter_args(func, ignore_lst, args=(), kwargs=None):
    """Filters the given args and kwargs using a list of arguments to
    ignore, and a function specification.

    Parameters
    ----------
    func: callable
        Function giving the argument specification
    ignore_lst: list of strings
        List of arguments to ignore (either a name of an argument
        in the function spec, or '*', or '**')
    args: list
        Positional arguments passed to the function.
    kwargs: dict, optional
        Keyword arguments passed to the function. None (the default) is
        treated as an empty dict.

    Returns
    -------
    filtered_args: dict
        Mapping from parameter name to the value it receives for this
        call. The extra positional arguments are stored under ``'*'`` and
        the extra keyword arguments under ``'**'`` when the function
        accepts them. If ``func`` is neither a function nor a method, the
        result is simply ``{'*': args, '**': kwargs}``.

    Raises
    ------
    ValueError
        If ``ignore_lst`` is a string, if a keyword-only parameter is given
        positionally, if a parameter has neither a value nor a default, or
        if ``ignore_lst`` names an argument that is not defined.
    TypeError
        If ``kwargs`` holds a name the function does not accept.

    Notes
    -----
    Positional-only parameters are not recorded as named arguments.
    """
    args = list(args)
    if kwargs is None:
        # A fresh dict per call: a shared default would leak between calls
        # through the "**" entry returned for non-inspectable callables.
        kwargs = {}
    if isinstance(ignore_lst, str):
        # Catch a common mistake
        raise ValueError(
            "ignore_lst must be a list of parameters to ignore "
            "%s (type %s) was given" % (ignore_lst, type(ignore_lst))
        )
    # Special case for functools.partial objects
    if not inspect.ismethod(func) and not inspect.isfunction(func):
        if ignore_lst:
            warnings.warn(
                "Cannot inspect object %s, ignore list will not work." % func,
                stacklevel=2,
            )
        return {"*": args, "**": kwargs}

    arg_sig = inspect.signature(func)
    arg_names = []
    # Defaults are keyed by parameter name: positional bookkeeping breaks
    # as soon as keyword-only parameters are mixed with defaults.
    arg_defaults = {}
    arg_kwonlyargs = []
    arg_varargs = None
    arg_varkw = None
    for param in arg_sig.parameters.values():
        if param.kind is param.POSITIONAL_OR_KEYWORD:
            arg_names.append(param.name)
        elif param.kind is param.KEYWORD_ONLY:
            arg_names.append(param.name)
            arg_kwonlyargs.append(param.name)
        elif param.kind is param.VAR_POSITIONAL:
            arg_varargs = param.name
        elif param.kind is param.VAR_KEYWORD:
            arg_varkw = param.name
        if param.default is not param.empty:
            arg_defaults[param.name] = param.default
    if inspect.ismethod(func):
        # First argument is 'self', it has been removed by Python
        # we need to add it back:
        args = [func.__self__] + args
        # func is an instance method, inspect.signature(func) does not
        # include self, we need to fetch it from the class method, i.e
        # func.__func__
        class_method_sig = inspect.signature(func.__func__)
        self_name = next(iter(class_method_sig.parameters))
        arg_names = [self_name] + arg_names
    # XXX: Maybe I need an inspect.isbuiltin to detect C-level methods, such
    # as on ndarrays.

    # Signatures list keyword-only parameters last, so the first
    # n_positional entries of arg_names can be filled from args.
    n_positional = len(arg_names) - len(arg_kwonlyargs)

    def _func_name():
        # Only error messages need the display name: resolve it lazily.
        return get_func_name(func, resolv_alias=False)[1]

    arg_dict = dict()
    for arg_position, arg_name in enumerate(arg_names):
        is_kwonly = arg_position >= n_positional
        given_positionally = arg_position < len(args)
        if given_positionally and not is_kwonly:
            # Positional argument or keyword argument given as positional
            arg_dict[arg_name] = args[arg_position]
        elif given_positionally and arg_varargs is None:
            # More positional values than positional parameters, and no
            # *args to absorb them.
            name = _func_name()
            raise ValueError(
                "Keyword-only parameter '%s' was passed as "
                "positional parameter for %s:\n"
                "     %s was called."
                % (
                    arg_name,
                    _signature_str(name, arg_sig),
                    _function_called_str(name, args, kwargs),
                )
            )
        elif arg_name in kwargs:
            arg_dict[arg_name] = kwargs[arg_name]
        else:
            try:
                arg_dict[arg_name] = arg_defaults[arg_name]
            except KeyError as e:
                # Missing argument
                name = _func_name()
                raise ValueError(
                    "Wrong number of arguments for %s:\n"
                    "     %s was called."
                    % (
                        _signature_str(name, arg_sig),
                        _function_called_str(name, args, kwargs),
                    )
                ) from e

    varkwargs = dict()
    for arg_name, arg_value in sorted(kwargs.items()):
        if arg_name in arg_dict:
            arg_dict[arg_name] = arg_value
        elif arg_varkw is not None:
            varkwargs[arg_name] = arg_value
        else:
            raise TypeError(
                "Ignore list for %s() contains an unexpected "
                "keyword argument '%s'" % (_func_name(), arg_name)
            )

    if arg_varkw is not None:
        arg_dict["**"] = varkwargs
    if arg_varargs is not None:
        # Everything beyond the positional parameters belongs to *args.
        arg_dict["*"] = args[n_positional:]

    # Now remove the arguments to be ignored
    for item in ignore_lst:
        if item in arg_dict:
            arg_dict.pop(item)
        else:
            raise ValueError(
                "Ignore list: argument '%s' is not defined for "
                "function %s" % (item, _signature_str(_func_name(), arg_sig))
            )
    # XXX: Return a sorted list of pairs?
    return arg_dict
338
339
def _format_arg(arg):
    """Pretty-print ``arg`` for display, shortening very long output.

    The argument is rendered with ``pformat(arg, indent=2)``. When the
    rendering is longer than 1500 characters, only its first 700
    characters are kept, followed by ``...``.
    """
    text = pformat(arg, indent=2)
    if len(text) <= 1500:
        return text
    return "%s..." % text[:700]
345
346
def format_signature(func, *args, **kwargs):
    """Describe a call to ``func`` as a pair of strings.

    Returns
    -------
    module_path: string
        The non-empty module components from ``get_func_name`` joined with
        dots and followed by the function name, or the bare function name
        when there are no such components.
    signature: string
        The function name followed by the formatted positional and keyword
        arguments between parentheses.
    """
    # XXX: Should this use inspect.formatargvalues/formatargspec?
    module, name = get_func_name(func)
    components = [component for component in module if component]
    if components:
        module_path = ".".join(components + [name])
    else:
        module_path = name

    rendered = []
    previous_length = 0
    for arg in args:
        text = _format_arg(arg)
        if previous_length > 80:
            # The previous argument was long: start this one on a new line.
            text = "\n%s" % text
        previous_length = len(text)
        rendered.append(text)
    for key, value in kwargs.items():
        rendered.append("%s=%s" % (key, _format_arg(value)))

    signature = "%s(%s)" % (name, ", ".join(rendered))
    return module_path, signature
369
370
def format_call(func, args, kwargs, object_name="Memory"):
    """Returns a nicely formatted statement displaying the function
    call with the given arguments.

    The message is made of a line of 80 underscores, a line of the form
    ``[object_name] Calling path...`` and the formatted signature.
    """
    path, signature = format_signature(func, *args, **kwargs)
    ruler = 80 * "_"
    # XXX: Not using logging framework; the message is only returned.
    return "%s\n[%s] Calling %s...\n%s" % (ruler, object_name, path, signature)