1"""
2My own variation on function-specific inspect-like features.
3"""
4
5# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
6# Copyright (c) 2009 Gael Varoquaux
7# License: BSD Style, 3 clauses.
8
9import inspect
10import warnings
11import re
12import os
13import collections
14
15from itertools import islice
16from tokenize import open as open_py_source
17
18from .logger import pformat
19
# Space-separated field names of the ``FullArgSpec`` record; the same names
# as the fields of the standard library's ``inspect.FullArgSpec``.
full_argspec_fields = (
    'args varargs varkw defaults kwonlyargs kwonlydefaults annotations'
)
# Lightweight record type built from the field names above.
full_argspec_type = collections.namedtuple(
    'FullArgSpec', full_argspec_fields.split())
23
24
def get_func_code(func):
    """ Attempts to retrieve the source code of a function.

        The reason we don't use inspect.getsource is that it caches the
        source, whereas we want this to be modified on the fly when the
        function is modified.

        Parameters
        ----------
        func: callable
            The function to inspect.

        Returns
        -------
        func_code: string
            The function code. When the source cannot be retrieved, this
            is instead the string of the hash of ``func.__code__`` or, for
            objects without ``__code__``, ``repr(func)``.
        source_file: string or None
            The path to the file in which the function is defined. None
            if it could not be determined.
        first_line: int
            The first line of the code in the source file, or -1 when the
            source could not be retrieved.

        Notes
        ------
        This function does a bit more magic than inspect, and is thus
        more robust.
    """
    source_file = None
    try:
        code = func.__code__
        source_file = code.co_filename
        if not os.path.exists(source_file):
            # Use inspect for lambda functions and functions defined in an
            # interactive shell, or in doctests
            source_code = ''.join(inspect.getsourcelines(func)[0])
            line_no = 1
            if source_file.startswith('<doctest '):
                source_file, line_no = re.match(
                    r'\<doctest (.*\.rst)\[(.*)\]\>', source_file).groups()
                line_no = int(line_no)
                source_file = '<doctest %s>' % source_file
            return source_code, source_file, line_no
        # Try to retrieve the source code.
        with open_py_source(source_file) as source_file_obj:
            first_line = code.co_firstlineno
            # All the lines after the function definition:
            source_lines = list(islice(source_file_obj, first_line - 1, None))
        return ''.join(inspect.getblock(source_lines)), source_file, first_line
    except Exception:
        # Only ordinary errors trigger the fallback: KeyboardInterrupt and
        # SystemExit must propagate instead of being turned into a hash.
        # If the source code fails, we use the hash. This is fragile and
        # might change from one session to another.
        if hasattr(func, '__code__'):
            # Python 3.X
            return str(func.__code__.__hash__()), source_file, -1
        else:
            # Weird objects like numpy ufunc don't have __code__
            # This is fragile, as quite often the id of the object is
            # in the repr, so it might not persist across sessions,
            # however it will work for ufuncs.
            return repr(func), source_file, -1
79
80
def _clean_win_chars(string):
    """Percent-encode the characters Windows cannot encode in filenames.

    Parameters
    ----------
    string: str
        The text to sanitize.

    Returns
    -------
    cleaned: str
        ``string`` with each of ``<``, ``>``, ``!``, ``:`` and ``\\``
        replaced by its ``urllib.parse.quote`` escape.
    """
    from urllib.parse import quote

    # None of the escapes contains a character from the list, so a single
    # translate pass gives the same result as successive replacements.
    escapes = {ord(char): quote(char) for char in ('<', '>', '!', ':', '\\')}
    return string.translate(escapes)
93
94
def get_func_name(func, resolv_alias=True, win_characters=True):
    """ Return the function import path (as a list of module names), and
        a name for the function.

        Parameters
        ----------
        func: callable
            The func to inspect
        resolv_alias: boolean, optional
            If true, possible local aliases are indicated.
        win_characters: boolean, optional
            If true, substitute special characters using urllib.quote
            This is useful in Windows, as it cannot encode some filenames

        Returns
        -------
        module: list of strings
            The dotted module path split on '.', extended with the
            enclosing scopes taken from ``__qualname__`` when it differs
            from the name.
        name: string
            The name of the function, or 'unknown' if it has none.
    """
    if hasattr(func, '__module__'):
        module = func.__module__
    else:
        try:
            module = inspect.getmodule(func)
        except TypeError:
            if hasattr(func, '__class__'):
                module = func.__class__.__module__
            else:
                module = 'unknown'
        else:
            # inspect.getmodule returns a module object (or None), whereas
            # the code below manipulates the dotted module name as a string
            if module is not None and not isinstance(module, str):
                module = module.__name__
    if module is None:
        # Happens in doctests, eg
        module = ''
    if module == '__main__':
        try:
            filename = os.path.abspath(inspect.getsourcefile(func))
        except Exception:
            # No usable source file: keep the bare module name
            filename = None
        if filename is not None:
            # mangling of full path to filename
            parts = filename.split(os.sep)
            if parts[-1].startswith('<ipython-input'):
                # We're in a IPython (or notebook) session. parts[-1] comes
                # from func.__code__.co_filename and is of the form
                # <ipython-input-N-XYZ>, where:
                # - N is the cell number where the function was defined
                # - XYZ is a hash representing the function's code (and
                #   name). It will be consistent across sessions and kernel
                #   restarts, and will change if the function's code/name
                #   changes
                # We remove N so that cache is properly hit if the cell where
                # the func is defined is re-executed.
                # The XYZ hash should avoid collisions between functions with
                # the same name, both within the same notebook but also across
                # notebooks
                splitted = parts[-1].split('-')
                parts[-1] = '-'.join(splitted[:2] + splitted[3:])
            elif len(parts) > 2 and parts[-2].startswith('ipykernel_'):
                # In a notebook session (ipykernel). Filename seems to be 'xyz'
                # of above. parts[-2] has the structure ipykernel_XXXXXX where
                # XXXXXX is a six-digit number identifying the current run (?).
                # If we split it off, the function again has the same
                # identifier across runs.
                parts[-2] = 'ipykernel'
            filename = '-'.join(parts)
            if filename.endswith('.py'):
                filename = filename[:-3]
            module = module + '-' + filename
    module = module.split('.')
    if hasattr(func, 'func_name'):
        name = func.func_name
    elif hasattr(func, '__name__'):
        name = func.__name__
    else:
        name = 'unknown'
    # Hack to detect functions not defined at the module-level
    if resolv_alias:
        # TODO: Maybe add a warning here?
        if hasattr(func, 'func_globals') and name in func.func_globals:
            if func.func_globals[name] is not func:
                name = '%s-alias' % name
    if hasattr(func, '__qualname__') and func.__qualname__ != name:
        # Extend the module name in case of nested functions to avoid
        # (module, name) collisions
        module.extend(func.__qualname__.split(".")[:-1])
    if inspect.ismethod(func):
        # We need to add the name of the class
        if hasattr(func, 'im_class'):
            klass = func.im_class
            module.append(klass.__name__)
    if os.name == 'nt' and win_characters:
        # Windows can't encode certain characters in filenames
        name = _clean_win_chars(name)
        module = [_clean_win_chars(s) for s in module]
    return module, name
183
184
def _signature_str(function_name, arg_sig):
    """Render ``function_name`` immediately followed by its signature."""
    template = '{}{}'
    return template.format(function_name, arg_sig)
188
189
def _function_called_str(function_name, args, kwargs):
    """Helper function to output a function call

    Parameters
    ----------
    function_name: str
        Name displayed in front of the parenthesised arguments.
    args: sequence
        Positional arguments; each one is rendered with ``repr``.
    kwargs: dict
        Keyword arguments, rendered as ``key=value`` using ``%s``.

    Returns
    -------
    call_str: str
        Text such as ``f(1, 'a', b=2)``.
    """
    # Join only the pieces that exist, so that empty args and/or kwargs do
    # not leave dangling separators such as 'f(, )' or 'f(1, )'.
    pieces = [repr(arg) for arg in args]
    pieces.extend('%s=%s' % (key, value) for key, value in kwargs.items())
    return '{0}({1})'.format(function_name, ', '.join(pieces))
199
200
def filter_args(func, ignore_lst, args=(), kwargs=None):
    """ Filters the given args and kwargs using a list of arguments to
        ignore, and a function specification.

        Parameters
        ----------
        func: callable
            Function giving the argument specification
        ignore_lst: list of strings
            List of arguments to ignore (either a name of an argument
            in the function spec, or '*', or '**')
        args: sequence, optional
            Positional arguments passed to the function.
        kwargs: dict, optional
            Keyword arguments passed to the function. None (the default)
            is treated as no keyword arguments.

        Returns
        -------
        filtered_args: dict
            Mapping from argument name to value, without the ignored
            entries. Extra positional arguments are stored under '*' and
            extra keyword arguments under '**'. If func is neither a
            function nor a method, the mapping only holds '*' and '**'.

        Raises
        ------
        ValueError
            If ignore_lst is a string, if a keyword-only parameter is
            passed positionally, if an argument is missing, or if an
            ignored name is not an argument of func.
        TypeError
            If kwargs contains a name func does not accept.
    """
    args = list(args)
    # Use a fresh dict per call: a shared default would be handed back to
    # the caller under '**' below and could be mutated across calls.
    kwargs = {} if kwargs is None else kwargs
    if isinstance(ignore_lst, str):
        # Catch a common mistake
        raise ValueError(
            'ignore_lst must be a list of parameters to ignore '
            '%s (type %s) was given' % (ignore_lst, type(ignore_lst)))
    # Special case for functools.partial objects
    if (not inspect.ismethod(func) and not inspect.isfunction(func)):
        if ignore_lst:
            warnings.warn('Cannot inspect object %s, ignore list will '
                          'not work.' % func, stacklevel=2)
        return {'*': args, '**': kwargs}
    arg_sig = inspect.signature(func)
    arg_names = []
    # Defaults are keyed by parameter name: keyword-only parameters may
    # declare defaults in any order, so a positional lookup from the end
    # of a list can pick the default of another parameter.
    arg_defaults = {}
    arg_kwonlyargs = []
    arg_varargs = None
    arg_varkw = None
    # NOTE: POSITIONAL_ONLY parameters are not collected by this loop.
    for param in arg_sig.parameters.values():
        if param.kind is param.POSITIONAL_OR_KEYWORD:
            arg_names.append(param.name)
        elif param.kind is param.KEYWORD_ONLY:
            arg_names.append(param.name)
            arg_kwonlyargs.append(param.name)
        elif param.kind is param.VAR_POSITIONAL:
            arg_varargs = param.name
        elif param.kind is param.VAR_KEYWORD:
            arg_varkw = param.name
        if param.default is not param.empty:
            arg_defaults[param.name] = param.default
    if inspect.ismethod(func):
        # First argument is 'self', it has been removed by Python
        # we need to add it back:
        args = [func.__self__, ] + args
        # func is an instance method, inspect.signature(func) does not
        # include self, we need to fetch it from the class method, i.e
        # func.__func__
        class_method_sig = inspect.signature(func.__func__)
        self_name = next(iter(class_method_sig.parameters))
        arg_names = [self_name] + arg_names
    # XXX: Maybe I need an inspect.isbuiltin to detect C-level methods, such
    # as on ndarrays.

    _, name = get_func_name(func, resolv_alias=False)
    arg_dict = dict()
    arg_position = -1
    for arg_position, arg_name in enumerate(arg_names):
        if arg_position < len(args):
            # Positional argument or keyword argument given as positional
            if arg_name not in arg_kwonlyargs:
                arg_dict[arg_name] = args[arg_position]
            else:
                raise ValueError(
                    "Keyword-only parameter '%s' was passed as "
                    'positional parameter for %s:\n'
                    '     %s was called.'
                    % (arg_name,
                       _signature_str(name, arg_sig),
                       _function_called_str(name, args, kwargs))
                )

        else:
            if arg_name in kwargs:
                arg_dict[arg_name] = kwargs[arg_name]
            else:
                try:
                    arg_dict[arg_name] = arg_defaults[arg_name]
                except KeyError as e:
                    # Missing argument
                    raise ValueError(
                        'Wrong number of arguments for %s:\n'
                        '     %s was called.'
                        % (_signature_str(name, arg_sig),
                           _function_called_str(name, args, kwargs))
                    ) from e

    varkwargs = dict()
    for arg_name, arg_value in sorted(kwargs.items()):
        if arg_name in arg_dict:
            arg_dict[arg_name] = arg_value
        elif arg_varkw is not None:
            varkwargs[arg_name] = arg_value
        else:
            raise TypeError("Ignore list for %s() contains an unexpected "
                            "keyword argument '%s'" % (name, arg_name))

    if arg_varkw is not None:
        arg_dict['**'] = varkwargs
    if arg_varargs is not None:
        # Everything beyond the named parameters goes to '*'
        varargs = args[arg_position + 1:]
        arg_dict['*'] = varargs

    # Now remove the arguments to be ignored
    for item in ignore_lst:
        if item in arg_dict:
            arg_dict.pop(item)
        else:
            raise ValueError("Ignore list: argument '%s' is not defined for "
                             "function %s"
                             % (item,
                                _signature_str(name, arg_sig))
                             )
    # XXX: Return a sorted list of pairs?
    return arg_dict
327
328
def _format_arg(arg):
    """Pretty-print ``arg``, abbreviating very long representations.

    A representation longer than 1500 characters is cut down to its first
    700 characters followed by '...'.
    """
    text = pformat(arg, indent=2)
    if len(text) <= 1500:
        return text
    return '%s...' % text[:700]
334
335
def format_signature(func, *args, **kwargs):
    """ Return the dotted path of ``func`` and a textual rendering of the
        call ``func(*args, **kwargs)``.

        Returns
        -------
        module_path: string
            The non-empty module components and the function name, joined
            with '.'.
        signature: string
            The function name followed by the formatted arguments.
    """
    # XXX: Should this use inspect.formatargvalues/formatargspec?
    module_parts, name = get_func_name(func)
    module_parts = [part for part in module_parts if part]
    if module_parts:
        module_parts.append(name)
        module_path = '.'.join(module_parts)
    else:
        module_path = name

    rendered_args = []
    last_length = 0
    for positional in args:
        rendered = _format_arg(positional)
        # Start a new line after an argument whose rendering was long
        if last_length > 80:
            rendered = '\n%s' % rendered
        last_length = len(rendered)
        rendered_args.append(rendered)
    for key, value in kwargs.items():
        rendered_args.append('%s=%s' % (key, _format_arg(value)))

    signature = '%s(%s)' % (name, ', '.join(rendered_args))
    return module_path, signature
358
359
def format_call(func, args, kwargs, object_name="Memory"):
    """ Returns a nicely formatted statement displaying the function
        call with the given arguments.

        The message is an 80-underscore rule, a line of the form
        '[object_name] Calling path...' and the formatted signature.
    """
    path, signature = format_signature(func, *args, **kwargs)
    rule = 80 * '_'
    # XXX: Not using logging framework; the message is only returned.
    return '%s\n[%s] Calling %s...\n%s' % (rule, object_name,
                                           path, signature)