Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scipy/_lib/_uarray/_backend.py: 37%
163 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-12 06:31 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-12 06:31 +0000
import contextlib
import copyreg
import functools
import inspect
import pickle
import threading
import types
import typing

from . import _uarray
# An argument extractor has the same signature as the multimethod it backs
# and returns the dispatchable arguments as a tuple of ``Dispatchable``.
ArgumentExtractorType = typing.Callable[
    ..., typing.Tuple["Dispatchable", ...]
]

# An argument replacer is called as ``(args, kwargs, dispatchables)`` and
# returns a new ``(args, kwargs)`` pair with the dispatchables substituted.
ArgumentReplacerType = typing.Callable[
    [typing.Tuple, typing.Dict, typing.Tuple],
    typing.Tuple[typing.Tuple, typing.Dict],
]
15from ._uarray import ( # type: ignore
16 BackendNotImplementedError,
17 _Function,
18 _SkipBackendContext,
19 _SetBackendContext,
20 _BackendState,
21)
__all__ = [
    # Selecting, skipping and registering backends.
    "set_backend",
    "set_global_backend",
    "skip_backend",
    "register_backend",
    "determine_backend",
    "determine_backend_multi",
    "clear_backends",
    # Building multimethods.
    "create_multimethod",
    "generate_multimethod",
    "_Function",
    "BackendNotImplementedError",
    # Marking dispatchable arguments and writing converters.
    "Dispatchable",
    "wrap_single_convertor",
    "wrap_single_convertor_instance",
    "all_of_type",
    "mark_as",
    # Saving and restoring backend state.
    "set_state",
    "get_state",
    "reset_state",
    # Types implemented in the ``_uarray`` extension module.
    "_BackendState",
    "_SkipBackendContext",
    "_SetBackendContext",
]
def unpickle_function(mod_name, qname, self_):
    """
    Resolve a function from its module and qualified name.

    This is the reconstructor returned by :func:`pickle_function`.

    Parameters
    ----------
    mod_name : str
        Name of the module to import.
    qname : str
        Dotted ``__qualname__`` of the object inside that module.
    self_ : object or None
        If not ``None``, the resolved function is bound to it with
        :class:`types.MethodType`.

    Returns
    -------
    callable
        The resolved (and possibly bound) function.

    Raises
    ------
    pickle.UnpicklingError
        If the module cannot be imported or an attribute along ``qname`` is
        missing. The original error is chained as ``__cause__``.
    """
    import importlib

    try:
        obj = importlib.import_module(mod_name)
        for attr in qname.split("."):
            obj = getattr(obj, attr)
    except (ImportError, AttributeError) as e:
        # Name what failed to resolve; a bare UnpicklingError gives the user
        # nothing to go on.
        raise pickle.UnpicklingError(
            f"Can't unpickle {qname!r} from module {mod_name!r}"
        ) from e

    if self_ is not None:
        obj = types.MethodType(obj, self_)

    return obj
def pickle_function(func):
    """
    Reduce ``func`` to a by-reference description for :mod:`pickle`.

    The function is described by its ``__module__``, ``__qualname__`` and
    ``__self__`` (each ``None`` when absent). The description is then
    round-tripped through :func:`unpickle_function`. Pickling is refused
    unless that lookup yields ``func`` itself.

    Returns
    -------
    tuple
        ``(unpickle_function, (mod_name, qname, self_))``.

    Raises
    ------
    pickle.PicklingError
        If the recorded location does not resolve back to ``func``.
    """
    location = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
        getattr(func, "__self__", None),
    )

    try:
        resolved = unpickle_function(*location)
    except pickle.UnpicklingError:
        resolved = None

    if resolved is not func:
        raise pickle.PicklingError(
            "Can't pickle {}: it's not the same object as {}".format(func, resolved)
        )

    return unpickle_function, location
def pickle_state(state):
    """
    Reduce a backend state object for :mod:`pickle`.

    Returns the ``_BackendState._unpickle`` reconstructor paired with the
    arguments produced by ``state._pickle()``.
    """
    reconstructor = _uarray._BackendState._unpickle
    return reconstructor, state._pickle()
def pickle_set_backend_context(ctx):
    """
    Reduce a ``_SetBackendContext`` for :mod:`pickle`.

    The context is rebuilt by calling ``_SetBackendContext`` with the
    arguments returned by ``ctx._pickle()``.
    """
    constructor = _SetBackendContext
    return constructor, ctx._pickle()
def pickle_skip_backend_context(ctx):
    """
    Reduce a ``_SkipBackendContext`` for :mod:`pickle`.

    The context is rebuilt by calling ``_SkipBackendContext`` with the
    arguments returned by ``ctx._pickle()``.
    """
    constructor = _SkipBackendContext
    return constructor, ctx._pickle()
# Register reducers so that pickle can serialise these extension types; each
# reducer returns a (callable, args) pair used to rebuild the object.
for _ua_type, _ua_reducer in (
    (_Function, pickle_function),
    (_uarray._BackendState, pickle_state),
    (_SetBackendContext, pickle_set_backend_context),
    (_SkipBackendContext, pickle_skip_backend_context),
):
    copyreg.pickle(_ua_type, _ua_reducer)
# Don't leave the loop variables behind as module attributes.
del _ua_type, _ua_reducer
def get_state():
    """
    Return an opaque snapshot of the current state of all backends.

    The snapshot can be used for synchronization between threads/processes.

    See Also
    --------
    set_state
        Sets the state returned by this function.
    """
    snapshot = _uarray.get_state()
    return snapshot
@contextlib.contextmanager
def reset_state():
    """
    Context manager that restores all backend state once exited.

    The state is captured on entry and reinstated on exit, so any backend
    changes made inside the ``with`` block are undone.

    See Also
    --------
    set_state
        Context manager that sets the backend state.
    get_state
        Gets a state to be set by this context manager.
    """
    entry_state = get_state()
    with set_state(entry_state):
        yield
@contextlib.contextmanager
def set_state(state):
    """
    Context manager that installs a backend state from :obj:`get_state`.

    The previous state is captured before ``state`` is installed. It is put
    back when the block exits, including when an exception propagates.

    See Also
    --------
    get_state
        Gets a state to be set by this context manager.
    """
    saved = get_state()
    _uarray.set_state(state)
    try:
        yield
    finally:
        # NOTE(review): the meaning of the second argument (``True``) is
        # defined by the ``_uarray`` extension, which is not visible here;
        # it is only passed when restoring the saved state.
        _uarray.set_state(saved, True)
def create_multimethod(*args, **kwargs):
    """
    Creates a decorator for generating multimethods.

    The returned decorator takes an argument extractor. It calls
    :obj:`generate_multimethod` with that extractor first, followed by the
    ``args`` and ``kwargs`` given here.

    See Also
    --------
    generate_multimethod
        Generates a multimethod.
    """

    def decorate(extractor):
        return generate_multimethod(extractor, *args, **kwargs)

    return decorate
def generate_multimethod(
    argument_extractor: ArgumentExtractorType,
    argument_replacer: ArgumentReplacerType,
    domain: str,
    default: typing.Optional[typing.Callable] = None,
):
    """
    Generates a multimethod.

    Parameters
    ----------
    argument_extractor : ArgumentExtractorType
        A callable which extracts the dispatchable arguments. Extracted arguments
        should be marked by the :obj:`Dispatchable` class. It has the same signature
        as the desired multimethod.
    argument_replacer : ArgumentReplacerType
        A callable with the signature (args, kwargs, dispatchables), which should also
        return an (args, kwargs) pair with the dispatchables replaced inside the args/kwargs.
    domain : str
        A string value indicating the domain of this multimethod.
    default : Optional[Callable], optional
        The default implementation of this multimethod, where ``None`` (the default) specifies
        there is no default implementation.

    Returns
    -------
    _Function
        The multimethod. Its metadata (``__name__``, ``__doc__``, ...) is
        copied from ``argument_extractor`` with :func:`functools.update_wrapper`.

    Examples
    --------
    In this example, ``a`` is to be dispatched over, so we return it, while marking it as an ``int``.
    The trailing comma is needed because the args have to be returned as an iterable.

    >>> def override_me(a, b):
    ...   return Dispatchable(a, int),

    Next, we define the argument replacer that replaces the dispatchables inside args/kwargs with the
    supplied ones.

    >>> def override_replacer(args, kwargs, dispatchables):
    ...     return (dispatchables[0], args[1]), {}

    Next, we define the multimethod.

    >>> overridden_me = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples"
    ... )

    Notice that there's no default implementation, unless you supply one.

    >>> overridden_me(1, "a")
    Traceback (most recent call last):
        ...
    uarray.BackendNotImplementedError: ...

    >>> overridden_me2 = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples", default=lambda x, y: (x, y)
    ... )
    >>> overridden_me2(1, "a")
    (1, 'a')

    See Also
    --------
    uarray
        See the module documentation for how to override the method by creating backends.
    """
    # ``get_defaults`` also returns the set of names that have defaults; the
    # ``_Function`` constructor does not take it, so it is discarded here.
    kw_defaults, arg_defaults, _ = get_defaults(argument_extractor)
    ua_func = _Function(
        argument_extractor,
        argument_replacer,
        domain,
        arg_defaults,
        kw_defaults,
        default,
    )

    # Make the multimethod look like the extractor it was generated from.
    return functools.update_wrapper(ua_func, argument_extractor)
def set_backend(backend, coerce=False, only=False):
    """
    A context manager that sets the preferred backend.

    Parameters
    ----------
    backend
        The backend to set.
    coerce
        Whether or not to coerce to a specific backend's types. Implies ``only``.
    only
        Whether or not this should be the last backend to try.

    Returns
    -------
    _SetBackendContext
        The context manager. Instances are cached in ``backend.__ua_cache__``
        per calling thread. Repeated calls with the same arguments from the
        same thread return the same object.

    See Also
    --------
    skip_backend: A context manager that allows skipping of backends.
    set_global_backend: Set a single, global backend for a domain.
    """
    # Include the thread id in the cache key. The backend state that
    # ``get_state`` exposes is documented as usable for synchronising between
    # threads, i.e. it is per-thread. A single cached context object must
    # therefore not be entered by several threads at once.
    # NOTE(review): newer uarray/SciPy releases key this cache by thread id as
    # well -- confirm against the extension's context implementation.
    key = (threading.get_ident(), "set", coerce, only)
    try:
        return backend.__ua_cache__[key]
    except AttributeError:
        # First use of this backend object: create its cache lazily.
        backend.__ua_cache__ = {}
    except KeyError:
        pass

    ctx = _SetBackendContext(backend, coerce, only)
    backend.__ua_cache__[key] = ctx
    return ctx
def skip_backend(backend):
    """
    A context manager that allows one to skip a given backend from processing
    entirely. This allows one to use another backend's code in a library that
    is also a consumer of the same backend.

    Parameters
    ----------
    backend
        The backend to skip.

    Returns
    -------
    _SkipBackendContext
        The context manager. It is cached in ``backend.__ua_cache__`` per
        calling thread.

    See Also
    --------
    set_backend: A context manager that allows setting of backends.
    set_global_backend: Set a single, global backend for a domain.
    """
    # Cache per thread, for the same reason as ``set_backend``: one context
    # object must not be shared by threads that enter it concurrently.
    key = (threading.get_ident(), "skip")
    try:
        return backend.__ua_cache__[key]
    except AttributeError:
        # First use of this backend object: create its cache lazily.
        backend.__ua_cache__ = {}
    except KeyError:
        pass

    ctx = _SkipBackendContext(backend)
    backend.__ua_cache__[key] = ctx
    return ctx
def get_defaults(f):
    """
    Collect the default values declared in the signature of ``f``.

    Returns
    -------
    kw_defaults : dict
        Maps every parameter that has a default to that default, in
        signature order.
    arg_defaults : tuple
        Defaults of the positional-only and positional-or-keyword parameters
        that have one, in signature order.
    opts : set
        Names of all parameters that have a default.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    defaulted = [
        param
        for param in inspect.signature(f).parameters.values()
        if param.default is not inspect.Parameter.empty
    ]
    kw_defaults = {param.name: param.default for param in defaulted}
    arg_defaults = tuple(
        param.default for param in defaulted if param.kind in positional_kinds
    )
    return kw_defaults, arg_defaults, set(kw_defaults)
def set_global_backend(backend, coerce=False, only=False, *, try_last=False):
    """
    Replace the default backend for permanent use.

    The global backend is tried automatically as part of the list of
    backends, unless the ``only`` flag is set on a backend. It is the first
    backend tried outside the :obj:`set_backend` context manager.

    Note that this method is not thread-safe.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case. This function
        is meant to be used only by users themselves, or by a reference
        implementation, if one exists.

    Parameters
    ----------
    backend
        The backend to register.
    coerce : bool
        Whether to coerce input types when trying this backend.
    only : bool
        If ``True``, no more backends will be tried if this fails.
        Implied by ``coerce=True``.
    try_last : bool
        If ``True``, the global backend is tried after registered backends.

    See Also
    --------
    set_backend: A context manager that allows setting of backends.
    skip_backend: A context manager that allows skipping of backends.
    """
    options = (coerce, only, try_last)
    _uarray.set_global_backend(backend, *options)
def register_backend(backend):
    """
    Register ``backend`` for permanent use.

    A registered backend is tried automatically as part of the list of
    backends, unless the ``only`` flag is set on a backend.

    Note that this method is not thread-safe.

    Parameters
    ----------
    backend
        The backend to register.
    """
    register = _uarray.register_backend
    register(backend)
def clear_backends(domain, registered=True, globals=False):
    """
    Clear registered and/or global backends.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case. This function
        is meant to be used only by users themselves.

    .. warning::
        Do NOT use this method inside a multimethod call, or the
        program is likely to crash.

    Parameters
    ----------
    domain : Optional[str]
        The domain for which to de-register backends. ``None`` means
        de-register for all domains.
    registered : bool
        Whether or not to clear registered backends. See :obj:`register_backend`.
    globals : bool
        Whether or not to clear global backends. See :obj:`set_global_backend`.

    See Also
    --------
    register_backend : Register a backend globally.
    set_global_backend : Set a global backend.
    """
    # ``globals`` shadows the builtin, but it is part of the public signature.
    which = (registered, globals)
    _uarray.clear_backends(domain, *which)
class Dispatchable:
    """
    A utility class which marks an argument with a specific dispatch type.

    Attributes
    ----------
    value
        The value of the Dispatchable.
    type
        The type of the Dispatchable.
    coercible
        Whether this value may be coerced. Converters built with
        :func:`wrap_single_convertor` pass ``coerce and coercible`` to the
        single-value converter.

    Examples
    --------
    >>> x = Dispatchable(1, str)
    >>> x
    <Dispatchable: type=<class 'str'>, value=1>

    See Also
    --------
    all_of_type
        Marks all unmarked parameters of a function.
    mark_as
        Allows one to create a utility function to mark as a given type.
    """

    def __init__(self, value, dispatch_type, coercible=True):
        self.value = value
        self.type = dispatch_type
        self.coercible = coercible

    def __getitem__(self, index):
        # Index as though this were the pair ``(type, value)``. Negative
        # indices, slices and two-name unpacking therefore all work.
        as_pair = (self.type, self.value)
        return as_pair[index]

    def __str__(self):
        return f"<{type(self).__name__}: type={self.type!r}, value={self.value!r}>"

    __repr__ = __str__
def mark_as(dispatch_type):
    """
    Create a callable that marks a value as ``dispatch_type``.

    The result is a :func:`functools.partial` of :class:`Dispatchable` with
    ``dispatch_type`` bound by keyword.

    Examples
    --------
    >>> mark_int = mark_as(int)
    >>> mark_int(1)
    <Dispatchable: type=<class 'int'>, value=1>
    """
    marker = functools.partial(Dispatchable, dispatch_type=dispatch_type)
    return marker
def all_of_type(arg_type):
    """
    Marks all unmarked arguments as a given type.

    The returned decorator wraps an argument extractor. Every item the
    extractor yields that is not already a :class:`Dispatchable` is wrapped
    as ``Dispatchable(item, arg_type)``. The items are returned as a tuple.

    Examples
    --------
    >>> @all_of_type(str)
    ... def f(a, b):
    ...     return a, Dispatchable(b, int)
    >>> f('a', 1)
    (<Dispatchable: type=<class 'str'>, value='a'>, <Dispatchable: type=<class 'int'>, value=1>)
    """

    def decorator(extractor):
        @functools.wraps(extractor)
        def marked(*args, **kwargs):
            result = []
            for item in extractor(*args, **kwargs):
                if isinstance(item, Dispatchable):
                    result.append(item)
                else:
                    result.append(Dispatchable(item, arg_type))
            return tuple(result)

        return marked

    return decorator
def wrap_single_convertor(convert_single):
    """
    Lift a single-element ``__ua_convert__`` to one over all elements.

    ``convert_single`` has the signature ``(value, type, coerce)``. It is
    called for each dispatchable in order, with ``coerce`` combined with the
    dispatchable's own ``coercible`` flag. If any call returns
    ``NotImplemented``, the operation is assumed to be undefined: that value
    is returned immediately and the remaining elements are not converted.
    Otherwise the list of converted values is returned.
    """

    @functools.wraps(convert_single)
    def __ua_convert__(dispatchables, coerce):
        results = []
        for item in dispatchables:
            outcome = convert_single(item.value, item.type, coerce and item.coercible)
            if outcome is not NotImplemented:
                results.append(outcome)
                continue
            return NotImplemented
        return results

    return __ua_convert__
def wrap_single_convertor_instance(convert_single):
    """
    Lift a single-element ``__ua_convert__`` method to one over all elements.

    This is the method variant of :func:`wrap_single_convertor`.
    ``convert_single`` has the signature ``(self, value, type, coerce)`` and
    is called for each dispatchable in order, with ``coerce`` combined with
    the dispatchable's own ``coercible`` flag. If any call returns
    ``NotImplemented``, the operation is assumed to be undefined: that value
    is returned immediately and the remaining elements are not converted.
    Otherwise the list of converted values is returned.
    """

    @functools.wraps(convert_single)
    def __ua_convert__(self, dispatchables, coerce):
        results = []
        for item in dispatchables:
            outcome = convert_single(
                self, item.value, item.type, coerce and item.coercible
            )
            if outcome is not NotImplemented:
                results.append(outcome)
                continue
            return NotImplemented
        return results

    return __ua_convert__
def determine_backend(value, dispatch_type, *, domain, only=True, coerce=False):
    """Set the backend to the first active backend that supports ``value``

    This is useful for functions that call multimethods without any dispatchable
    arguments. You can use :func:`determine_backend` to ensure the same backend
    is used everywhere in a block of multimethod calls.

    Parameters
    ----------
    value
        The value being tested
    dispatch_type
        The dispatch type associated with ``value``, aka
        ":ref:`marking <MarkingGlossary>`".
    domain: string
        The domain to query for backends and set.
    coerce: bool
        Whether or not to allow coercion to the backend's types. Implies ``only``.
    only: bool
        Whether or not this should be the last backend to try.

    See Also
    --------
    set_backend: For when you know which backend to set

    Notes
    -----
    Support is determined by the ``__ua_convert__`` protocol. Backends not
    supporting the type must return ``NotImplemented`` from their
    ``__ua_convert__`` if they don't support input of that type.

    Examples
    --------
    Suppose we have two backends ``BackendA`` and ``BackendB`` each supporting
    different types, ``TypeA`` and ``TypeB``. Neither supporting the other type:

    >>> with ua.set_backend(ex.BackendA):
    ...     ex.call_multimethod(ex.TypeB(), ex.TypeB())
    Traceback (most recent call last):
        ...
    uarray.BackendNotImplementedError: ...

    Now consider a multimethod that creates a new object of ``TypeA``, or
    ``TypeB`` depending on the active backend.

    >>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, ex.TypeA())
    Traceback (most recent call last):
        ...
    uarray.BackendNotImplementedError: ...

    ``res`` is an object of ``TypeB`` because ``BackendB`` is set in the
    innermost with statement. So, ``call_multimethod`` fails since the types
    don't match.

    Instead, we need to first find a backend suitable for all of our objects.

    >>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
    ...     x = ex.TypeA()
    ...     with ua.determine_backend(x, "mark", domain="ua_examples"):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, x)
    TypeA

    """
    # The value's ``coercible`` flag mirrors the requested ``coerce``.
    marked = Dispatchable(value, dispatch_type, coerce)
    chosen = _uarray.determine_backend(domain, (marked,), coerce)

    return set_backend(chosen, coerce=coerce, only=only)
def determine_backend_multi(
    dispatchables, *, domain, only=True, coerce=False, **kwargs
):
    """Set a backend supporting all ``dispatchables``

    This is useful for functions that call multimethods without any dispatchable
    arguments. You can use :func:`determine_backend_multi` to ensure the same
    backend is used everywhere in a block of multimethod calls involving
    multiple arrays.

    Parameters
    ----------
    dispatchables: Sequence[Union[uarray.Dispatchable, Any]]
        The dispatchables that must be supported
    domain: string
        The domain to query for backends and set.
    coerce: bool
        Whether or not to allow coercion to the backend's types. Implies ``only``.
    only: bool
        Whether or not this should be the last backend to try.
    dispatch_type: Optional[Any]
        The default dispatch type associated with ``dispatchables``, aka
        ":ref:`marking <MarkingGlossary>`".

    Raises
    ------
    TypeError
        If ``dispatch_type`` is not given and some element is not a
        :class:`Dispatchable`, or if any other keyword argument is passed.

    See Also
    --------
    determine_backend: For a single dispatch value
    set_backend: For when you know which backend to set

    Notes
    -----
    Support is determined by the ``__ua_convert__`` protocol. Backends not
    supporting the type must return ``NotImplemented`` from their
    ``__ua_convert__`` if they don't support input of that type.

    Examples
    --------
    :func:`determine_backend` allows the backend to be set from a single
    object. :func:`determine_backend_multi` allows multiple objects to be
    checked simultaneously for support in the backend. Suppose we have a
    ``BackendAB`` which supports ``TypeA`` and ``TypeB`` in the same call,
    and a ``BackendBC`` that doesn't support ``TypeA``.

    >>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
    ...     a, b = ex.TypeA(), ex.TypeB()
    ...     with ua.determine_backend_multi(
    ...         [ua.Dispatchable(a, "mark"), ua.Dispatchable(b, "mark")],
    ...         domain="ua_examples"
    ...     ):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, a, b)
    TypeA

    This won't call ``BackendBC`` because it doesn't support ``TypeA``.

    We can also leave out the ``ua.Dispatchable`` if we specify the
    default ``dispatch_type`` for the ``dispatchables`` argument.

    >>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
    ...     a, b = ex.TypeA(), ex.TypeB()
    ...     with ua.determine_backend_multi(
    ...         [a, b], dispatch_type="mark", domain="ua_examples"
    ...     ):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, a, b)
    TypeA

    """
    # ``dispatch_type`` is accepted through ``**kwargs``. A private sentinel
    # tells "not given" apart from any real value, including ``None``.
    missing = object()
    disp_type = kwargs.pop("dispatch_type", missing)

    if disp_type is missing:
        dispatchables = tuple(dispatchables)
        if not all(isinstance(d, Dispatchable) for d in dispatchables):
            raise TypeError("dispatchables must be instances of uarray.Dispatchable")
    else:
        # Wrap only the elements that are not already marked.
        dispatchables = tuple(
            d if isinstance(d, Dispatchable) else Dispatchable(d, disp_type)
            for d in dispatchables
        )

    if kwargs:
        raise TypeError("Received unexpected keyword arguments: {}".format(kwargs))

    chosen = _uarray.determine_backend(domain, dispatchables, coerce)

    return set_backend(chosen, coerce=coerce, only=only)