1"""Implementation of __array_function__ overrides from NEP-18."""
2import collections
3import functools
4import os
5
6from numpy.core._multiarray_umath import (
7 add_docstring, implement_array_function, _get_implementing_args)
8from numpy.compat._inspect import getargspec
9
10
# Global switch for __array_function__ dispatch, evaluated once at import
# time from the NUMPY_EXPERIMENTAL_ARRAY_FUNCTION environment variable.
# When the variable is unset the default of 1 enables dispatch. The value is
# parsed with int(), so it must be an integer string ("0" disables, any other
# integer enables); a non-integer value raises ValueError during import.
ARRAY_FUNCTION_ENABLED = bool(
    int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 1)))
13
# Shared numpydoc entry for the ``like`` parameter. set_array_function_like_doc
# substitutes this text for the literal placeholder "${ARRAY_FUNCTION_LIKE}"
# in a function's docstring, so every such function documents ``like`` the
# same way.
array_function_like_doc = (
    """like : array_like, optional
        Reference object to allow the creation of arrays which are not
        NumPy arrays. If an array-like passed in as ``like`` supports
        the ``__array_function__`` protocol, the result will be defined
        by it. In this case, it ensures the creation of an array object
        compatible with that passed in via this argument."""
)
22
def set_array_function_like_doc(public_api):
    """Fill in the shared ``like=`` parameter docs on ``public_api``.

    Every occurrence of the placeholder ``${ARRAY_FUNCTION_LIKE}`` in the
    function's docstring is replaced with ``array_function_like_doc``.
    Functions without a docstring (e.g. under ``python -OO``) are left
    untouched. Returns ``public_api`` so this can be used as a decorator.
    """
    current_doc = public_api.__doc__
    if current_doc is None:
        return public_api

    public_api.__doc__ = current_doc.replace(
        "${ARRAY_FUNCTION_LIKE}", array_function_like_doc)
    return public_api
30
31
# Attach documentation to ``implement_array_function``, which is imported from
# the compiled ``_multiarray_umath`` extension module rather than defined in
# Python, so its docstring is supplied here.
# NOTE(review): array_function_dispatch in this file calls this function with
# a sixth positional argument (``use_like``) that the Parameters section below
# does not list -- consider documenting it once its C-side semantics are
# confirmed.
add_docstring(
    implement_array_function,
    """
    Implement a function with checks for __array_function__ overrides.

    All arguments are required, and can only be passed by position.

    Parameters
    ----------
    implementation : function
        Function that implements the operation on NumPy array without
        overrides when called like ``implementation(*args, **kwargs)``.
    public_api : function
        Function exposed by NumPy's public API originally called like
        ``public_api(*args, **kwargs)`` on which arguments are now being
        checked.
    relevant_args : iterable
        Iterable of arguments to check for __array_function__ methods.
    args : tuple
        Arbitrary positional arguments originally passed into ``public_api``.
    kwargs : dict
        Arbitrary keyword arguments originally passed into ``public_api``.

    Returns
    -------
    Result from calling ``implementation()`` or an ``__array_function__``
    method, as appropriate.

    Raises
    ------
    TypeError : if no implementation is found.
    """)
64
65
# exposed for testing purposes; used internally by implement_array_function
# Like implement_array_function, ``_get_implementing_args`` comes from the
# compiled ``_multiarray_umath`` module, so its docstring is attached here.
add_docstring(
    _get_implementing_args,
    """
    Collect arguments on which to call __array_function__.

    Parameters
    ----------
    relevant_args : iterable of array-like
        Iterable of possibly array-like arguments to check for
        __array_function__ methods.

    Returns
    -------
    Sequence of arguments with __array_function__ methods, in the order in
    which they should be called.
    """)
83
84
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')


def verify_matching_signatures(implementation, dispatcher):
    """Check that ``dispatcher`` mirrors the signature of ``implementation``.

    Both callables must have the same named-argument list (in order), the
    same ``*args`` and ``**kwargs`` names, and the same number of default
    values. Every default value on the dispatcher must be ``None``.

    Raises
    ------
    RuntimeError
        If the signatures differ, or if a dispatcher default is not ``None``.
    """
    impl = ArgSpec(*getargspec(implementation))
    disp = ArgSpec(*getargspec(dispatcher))

    # Conditions are evaluated in the same order as the equivalent
    # "any mismatch" check, stopping at the first one that fails.
    signatures_match = (
        impl.args == disp.args
        and impl.varargs == disp.varargs
        and impl.keywords == disp.keywords
        and bool(impl.defaults) == bool(disp.defaults)
        and (impl.defaults is None
             or len(impl.defaults) == len(disp.defaults))
    )
    if not signatures_match:
        raise RuntimeError('implementation and dispatcher for %s have '
                           'different function signatures' % implementation)

    if impl.defaults is None:
        return

    only_nones = (None,) * len(disp.defaults)
    if disp.defaults != only_nones:
        raise RuntimeError('dispatcher functions can only use None for '
                           'default argument values')
108
109
def set_module(module):
    """Decorator for overriding __module__ on a function or class.

    Passing ``None`` leaves the decorated object's ``__module__`` unchanged.
    The decorated object itself is returned, not a wrapper.

    Example usage::

        @set_module('numpy')
        def example():
            pass

        assert example.__module__ == 'numpy'
    """
    def decorator(func):
        if module is None:
            return func
        func.__module__ = module
        return func
    return decorator
126
127
def array_function_dispatch(dispatcher, module=None, verify=True,
                            docs_from_dispatcher=False, use_like=False):
    """Decorator for adding dispatch with the __array_function__ protocol.

    See NEP-18 for example usage.

    Parameters
    ----------
    dispatcher : callable
        Function that when called like ``dispatcher(*args, **kwargs)`` with
        arguments from the NumPy function call returns an iterable of
        array-like arguments to check for ``__array_function__``.
    module : str, optional
        __module__ attribute to set on new function, e.g., ``module='numpy'``.
        By default, module is copied from the decorated function.
    verify : bool, optional
        If True, verify that the signatures of the dispatcher and the
        decorated function match exactly: all required and optional arguments
        should appear in order with the same names, but the default values for
        all optional arguments should be ``None``. Only disable verification
        if the dispatcher's signature needs to deviate for some particular
        reason, e.g., because the function has a signature like
        ``func(*args, **kwargs)``. Only checked when ``__array_function__``
        dispatch is enabled.
    docs_from_dispatcher : bool, optional
        If True, copy docs from the dispatcher function onto the dispatched
        function, rather than from the implementation. This is useful for
        functions defined in C, which otherwise don't have docstrings.
    use_like : bool, optional
        Forwarded unchanged as the final argument of
        ``implement_array_function`` on every call; set it for functions that
        accept a ``like=`` argument. Ignored when ``__array_function__``
        dispatch is disabled.

    Returns
    -------
    Function suitable for decorating the implementation of a NumPy function.
    """

    if not ARRAY_FUNCTION_ENABLED:
        # Dispatch is globally disabled: hand back the implementation itself,
        # applying only the docstring and __module__ adjustments.
        def decorator(implementation):
            if docs_from_dispatcher:
                add_docstring(implementation, dispatcher.__doc__)
            if module is not None:
                implementation.__module__ = module
            return implementation
        return decorator

    def decorator(implementation):
        if verify:
            verify_matching_signatures(implementation, dispatcher)

        if docs_from_dispatcher:
            add_docstring(implementation, dispatcher.__doc__)

        @functools.wraps(implementation)
        def public_api(*args, **kwargs):
            try:
                relevant_args = dispatcher(*args, **kwargs)
            except TypeError as exc:
                # Try to clean up a signature related TypeError. Such an
                # error will be something like:
                #     dispatcher.__name__() got an unexpected keyword argument
                #
                # So replace the dispatcher name in this case. In principle
                # TypeErrors may be raised from _within_ the dispatcher, so
                # we check that the message is a string that starts with the
                # name. (In principle we could also check the traceback
                # length, as it would be deeper.)
                # A TypeError raised without arguments has an empty ``args``;
                # re-raise it untouched instead of failing with IndexError.
                msg = exc.args[0] if exc.args else None
                disp_name = dispatcher.__name__
                if not isinstance(msg, str) or not msg.startswith(disp_name):
                    raise

                # Swap only the leading dispatcher name for the public name
                # (the rest of the message may legitimately contain it too).
                new_msg = public_api.__name__ + msg[len(disp_name):]
                raise TypeError(new_msg) from None

            return implement_array_function(
                implementation, public_api, relevant_args, args, kwargs,
                use_like)

        # Make tracebacks through this wrapper report the implementation's
        # name and a synthetic filename rather than this module's source.
        public_api.__code__ = public_api.__code__.replace(
            co_name=implementation.__name__,
            co_filename='<__array_function__ internals>')
        if module is not None:
            public_api.__module__ = module

        # Keep a handle on the undecorated function.
        public_api._implementation = implementation

        return public_api

    return decorator
215
216
def array_function_from_dispatcher(
        implementation, module=None, verify=True, docs_from_dispatcher=True,
        use_like=False):
    """Like array_function_dispatch, but with function arguments flipped.

    Returns a decorator to apply to a *dispatcher*; the result is the public
    function wrapping the already-existing ``implementation``. ``module``,
    ``verify``, ``docs_from_dispatcher`` and ``use_like`` are forwarded to
    ``array_function_dispatch``. Unlike there, ``docs_from_dispatcher``
    defaults to True, so the dispatcher's docstring is attached to the
    implementation.
    """

    def decorator(dispatcher):
        return array_function_dispatch(
            dispatcher, module, verify=verify,
            docs_from_dispatcher=docs_from_dispatcher,
            use_like=use_like)(implementation)
    return decorator