Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/numpy/core/overrides.py: 69%
61 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 06:27 +0000
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 06:27 +0000
1"""Implementation of __array_function__ overrides from NEP-18."""
2import collections
3import functools
4import os
6from .._utils import set_module
7from .._utils._inspect import getargspec
8from numpy.core._multiarray_umath import (
9 add_docstring, implement_array_function, _get_implementing_args)
# Set of every public function created by ``array_function_dispatch`` while
# ``__array_function__`` overrides are enabled.
ARRAY_FUNCTIONS = set()

# Dispatch is on unless NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is set to an integer
# string equal to zero.  A non-integer value makes ``int()`` raise here, i.e.
# at import time.
ARRAY_FUNCTION_ENABLED = int(
    os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 1)) != 0

# Replacement text for the ``${ARRAY_FUNCTION_LIKE}`` docstring placeholder;
# it documents the ``like=`` parameter of array-creation functions.
array_function_like_doc = """like : array_like, optional
        Reference object to allow the creation of arrays which are not
        NumPy arrays. If an array-like passed in as ``like`` supports
        the ``__array_function__`` protocol, the result will be defined
        by it. In this case, it ensures the creation of an array object
        compatible with that passed in via this argument."""
def set_array_function_like_doc(public_api):
    """Expand the ``${ARRAY_FUNCTION_LIKE}`` placeholder in a docstring.

    Every occurrence of the placeholder in ``public_api.__doc__`` is replaced
    with ``array_function_like_doc``.  Objects whose ``__doc__`` is None are
    left untouched.  The same object is returned, which lets this be applied
    as a decorator.
    """
    doc = public_api.__doc__
    if doc is None:
        return public_api

    public_api.__doc__ = doc.replace("${ARRAY_FUNCTION_LIKE}",
                                     array_function_like_doc)
    return public_api
# Attach the docstring below to ``implement_array_function``, which is
# imported from the compiled ``numpy.core._multiarray_umath`` module.  It is
# the function that the ``public_api`` wrapper built by
# ``array_function_dispatch`` calls (when ARRAY_FUNCTION_ENABLED) to pick
# between the NumPy implementation and an ``__array_function__`` override.
add_docstring(
    implement_array_function,
    """
    Implement a function with checks for __array_function__ overrides.

    All arguments are required, and can only be passed by position.

    Parameters
    ----------
    implementation : function
        Function that implements the operation on NumPy array without
        overrides when called like ``implementation(*args, **kwargs)``.
    public_api : function
        Function exposed by NumPy's public API originally called like
        ``public_api(*args, **kwargs)`` on which arguments are now being
        checked.
    relevant_args : iterable
        Iterable of arguments to check for __array_function__ methods.
    args : tuple
        Arbitrary positional arguments originally passed into ``public_api``.
    kwargs : dict
        Arbitrary keyword arguments originally passed into ``public_api``.

    Returns
    -------
    Result from calling ``implementation()`` or an ``__array_function__``
    method, as appropriate.

    Raises
    ------
    TypeError : if no implementation is found.
    """)
# exposed for testing purposes; used internally by implement_array_function
# Attach the docstring below to ``_get_implementing_args``, a helper imported
# from the compiled ``numpy.core._multiarray_umath`` module.
add_docstring(
    _get_implementing_args,
    """
    Collect arguments on which to call __array_function__.

    Parameters
    ----------
    relevant_args : iterable of array-like
        Iterable of possibly array-like arguments to check for
        __array_function__ methods.

    Returns
    -------
    Sequence of arguments with __array_function__ methods, in the order in
    which they should be called.
    """)
# Named view over the 4-tuple returned by ``getargspec``, so signature fields
# can be compared by name in ``verify_matching_signatures``.
ArgSpec = collections.namedtuple(
    'ArgSpec', ['args', 'varargs', 'keywords', 'defaults'])
def verify_matching_signatures(implementation, dispatcher):
    """Check that ``dispatcher`` mirrors the signature of ``implementation``.

    Positional argument names, the ``*args`` name, the ``**kwargs`` name, and
    the number of defaulted arguments must all agree.  In addition, every
    default value of the dispatcher must be ``None``.

    Raises
    ------
    RuntimeError
        If the signatures differ, or if the dispatcher uses a default other
        than ``None``.
    """
    impl = ArgSpec(*getargspec(implementation))
    disp = ArgSpec(*getargspec(dispatcher))
    impl_defaults = impl.defaults
    disp_defaults = disp.defaults

    # Terms are evaluated in the same order as a chain of "!=" checks, so
    # len() is only reached once everything before it already matches.
    same_shape = (
        impl.args == disp.args
        and impl.varargs == disp.varargs
        and impl.keywords == disp.keywords
        and bool(impl_defaults) == bool(disp_defaults)
        and (impl_defaults is None
             or len(impl_defaults) == len(disp_defaults))
    )
    if not same_shape:
        raise RuntimeError('implementation and dispatcher for %s have '
                           'different function signatures' % implementation)

    # Dispatchers only forward arguments, so any non-None default is a bug.
    if (impl_defaults is not None
            and disp_defaults != (None,) * len(disp_defaults)):
        raise RuntimeError('dispatcher functions can only use None for '
                           'default argument values')
def array_function_dispatch(dispatcher, module=None, verify=True,
                            docs_from_dispatcher=False):
    """Decorator for adding dispatch with the __array_function__ protocol.

    See NEP-18 for example usage.

    Parameters
    ----------
    dispatcher : callable
        Function that when called like ``dispatcher(*args, **kwargs)`` with
        arguments from the NumPy function call returns an iterable of
        array-like arguments to check for ``__array_function__``.
    module : str, optional
        __module__ attribute to set on new function, e.g., ``module='numpy'``.
        By default, module is copied from the decorated function.
    verify : bool, optional
        If True, verify that the signatures of the dispatcher and the
        decorated function match exactly: all required and optional arguments
        should appear in order with the same names, but the default values for
        all optional arguments should be ``None``. Only disable verification
        if the dispatcher's signature needs to deviate for some particular
        reason, e.g., because the function has a signature like
        ``func(*args, **kwargs)``.
    docs_from_dispatcher : bool, optional
        If True, copy docs from the dispatcher function onto the dispatched
        function, rather than from the implementation. This is useful for
        functions defined in C, which otherwise don't have docstrings.

    Returns
    -------
    Function suitable for decorating the implementation of a NumPy function.
    """

    if not ARRAY_FUNCTION_ENABLED:
        # Overrides are switched off: return the implementation itself
        # (optionally with the dispatcher's docs and a new __module__), so
        # no wrapper is added.
        def decorator(implementation):
            if docs_from_dispatcher:
                add_docstring(implementation, dispatcher.__doc__)
            if module is not None:
                implementation.__module__ = module
            return implementation
        return decorator

    def decorator(implementation):
        if verify:
            verify_matching_signatures(implementation, dispatcher)

        # Done before functools.wraps so the copied __doc__ is the
        # dispatcher's docstring.
        if docs_from_dispatcher:
            add_docstring(implementation, dispatcher.__doc__)

        @functools.wraps(implementation)
        def public_api(*args, **kwargs):
            try:
                relevant_args = dispatcher(*args, **kwargs)
            except TypeError as exc:
                # Try to clean up a signature related TypeError. Such an
                # error will be something like:
                #     dispatcher.__name__() got an unexpected keyword argument
                #
                # So replace the dispatcher name in this case. In principle
                # TypeErrors may be raised from _within_ the dispatcher, so
                # we check that the traceback contains a string that starts
                # with the name. (In principle we could also check the
                # traceback length, as it would be deeper.)
                #
                # ``exc.args`` can be empty (e.g. a bare ``raise TypeError``
                # inside the dispatcher). Indexing it would replace the real
                # error with an IndexError, so treat that as "no message"
                # and re-raise the original exception unchanged.
                msg = exc.args[0] if exc.args else None
                disp_name = dispatcher.__name__
                if not isinstance(msg, str) or not msg.startswith(disp_name):
                    raise

                # Replace with the correct name and re-raise:
                new_msg = msg.replace(disp_name, public_api.__name__)
                raise TypeError(new_msg) from None

            return implement_array_function(
                implementation, public_api, relevant_args, args, kwargs)

        # Rename the wrapper's code object so tracebacks report the public
        # function's name and a synthetic filename rather than this module.
        public_api.__code__ = public_api.__code__.replace(
            co_name=implementation.__name__,
            co_filename='<__array_function__ internals>')
        if module is not None:
            public_api.__module__ = module

        public_api._implementation = implementation

        ARRAY_FUNCTIONS.add(public_api)

        return public_api

    return decorator
def array_function_from_dispatcher(
        implementation, module=None, verify=True, docs_from_dispatcher=True):
    """Like array_function_dispatch, but with function arguments flipped.

    The returned decorator is applied to the *dispatcher* and wraps the
    already-known ``implementation``.  Unlike ``array_function_dispatch``
    (where the default is False), docs are taken from the dispatcher by
    default, which suits implementations defined in C.

    Parameters
    ----------
    implementation : callable
        Function implementing the operation without overrides.
    module, verify, docs_from_dispatcher
        Forwarded unchanged to ``array_function_dispatch``.

    Returns
    -------
    Decorator that takes the dispatcher and returns the public function.
    """

    def decorator(dispatcher):
        return array_function_dispatch(
            dispatcher, module, verify=verify,
            docs_from_dispatcher=docs_from_dispatcher)(implementation)
    return decorator