Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scipy/_lib/_uarray/_backend.py: 37%

163 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-12 06:31 +0000

1import typing 

2import types 

3import inspect 

4import functools 

5from . import _uarray 

6import copyreg 

7import pickle 

8import contextlib 

9 

# An argument extractor has the multimethod's own signature and returns a
# tuple of ``Dispatchable`` markers for the arguments to dispatch on.
ArgumentExtractorType = typing.Callable[
    ...,
    typing.Tuple["Dispatchable", ...],
]

# An argument replacer maps ``(args, kwargs, dispatchables)`` to a new
# ``(args, kwargs)`` pair with the dispatchables substituted in.
ArgumentReplacerType = typing.Callable[
    [typing.Tuple, typing.Dict, typing.Tuple],
    typing.Tuple[typing.Tuple, typing.Dict],
]

14 

15from ._uarray import ( # type: ignore 

16 BackendNotImplementedError, 

17 _Function, 

18 _SkipBackendContext, 

19 _SetBackendContext, 

20 _BackendState, 

21) 

22 

23__all__ = [ 

24 "set_backend", 

25 "set_global_backend", 

26 "skip_backend", 

27 "register_backend", 

28 "determine_backend", 

29 "determine_backend_multi", 

30 "clear_backends", 

31 "create_multimethod", 

32 "generate_multimethod", 

33 "_Function", 

34 "BackendNotImplementedError", 

35 "Dispatchable", 

36 "wrap_single_convertor", 

37 "wrap_single_convertor_instance", 

38 "all_of_type", 

39 "mark_as", 

40 "set_state", 

41 "get_state", 

42 "reset_state", 

43 "_BackendState", 

44 "_SkipBackendContext", 

45 "_SetBackendContext", 

46] 

47 

48 

def unpickle_function(mod_name, qname, self_):
    """
    Reconstruct a function (or bound method) from its module and qualified name.

    Parameters
    ----------
    mod_name : str
        Name of the module to import.
    qname : str
        Dotted qualified name of the object inside ``mod_name``; each
        component is resolved with ``getattr``.
    self_ : object or None
        When not ``None``, the resolved function is bound to this object with
        :class:`types.MethodType`.

    Returns
    -------
    callable
        The resolved object, bound to ``self_`` if one was given.

    Raises
    ------
    pickle.UnpicklingError
        If the module cannot be imported (including an empty module name) or
        a component of ``qname`` cannot be resolved. The underlying exception
        is chained as ``__cause__``.
    """
    import importlib

    try:
        func = importlib.import_module(mod_name)
        for part in qname.split("."):
            func = getattr(func, part)
    except (ImportError, AttributeError, ValueError) as e:
        # ValueError: importlib rejects an empty module name. AttributeError
        # also covers a ``None`` mod_name/qname coming from pickle_function.
        raise pickle.UnpicklingError(
            "Can't unpickle {!r} from module {!r}".format(qname, mod_name)
        ) from e

    if self_ is not None:
        func = types.MethodType(func, self_)

    return func

67 

68 

def pickle_function(func):
    """
    ``copyreg`` reducer for :obj:`_Function` objects.

    The function is pickled by reference: its ``__module__``, ``__qualname__``
    and ``__self__`` (each ``None`` if missing) are recorded, and
    ``unpickle_function`` is used to look it up again on load.

    Raises
    ------
    pickle.PicklingError
        If looking the function up by that reference does not give back the
        very same object.
    """
    location = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
        getattr(func, "__self__", None),
    )

    # Round-trip check: the reference is only usable if it resolves to func.
    try:
        resolved = unpickle_function(*location)
    except pickle.UnpicklingError:
        resolved = None

    if resolved is not func:
        raise pickle.PicklingError(
            "Can't pickle {}: it's not the same object as {}".format(func, resolved)
        )

    return unpickle_function, location

85 

86 

def pickle_state(state):
    """
    ``copyreg`` reducer for backend state objects.

    Returns ``_BackendState._unpickle`` as the reconstructor, together with
    whatever ``state._pickle()`` produces as its arguments.
    """
    reconstruct = _uarray._BackendState._unpickle
    return reconstruct, state._pickle()

89 

90 

def pickle_set_backend_context(ctx):
    """
    ``copyreg`` reducer for :obj:`_SetBackendContext` objects.

    On load, the context is rebuilt by calling ``_SetBackendContext`` with the
    arguments produced by ``ctx._pickle()``.
    """
    constructor_args = ctx._pickle()
    return _SetBackendContext, constructor_args

93 

94 

def pickle_skip_backend_context(ctx):
    """
    ``copyreg`` reducer for :obj:`_SkipBackendContext` objects.

    On load, the context is rebuilt by calling ``_SkipBackendContext`` with
    the arguments produced by ``ctx._pickle()``.
    """
    constructor_args = ctx._pickle()
    return _SkipBackendContext, constructor_args

97 

98 

# Register pickle reducers for the extension types re-exported above.
for _type, _reducer in (
    (_Function, pickle_function),
    (_uarray._BackendState, pickle_state),
    (_SetBackendContext, pickle_set_backend_context),
    (_SkipBackendContext, pickle_skip_backend_context),
):
    copyreg.pickle(_type, _reducer)
# Keep the module namespace exactly as before: no loop temporaries left over.
del _type, _reducer

103 

104 

def get_state():
    """
    Return an opaque snapshot of the current state of all backends.

    The snapshot can be used to synchronize backend state between
    threads/processes.

    See Also
    --------
    set_state
        Applies a snapshot returned by this function.
    """
    snapshot = _uarray.get_state()
    return snapshot

117 

118 

@contextlib.contextmanager
def reset_state():
    """
    Context manager that puts the backend state back as it was on entry.

    A snapshot is taken with :obj:`get_state` when the block is entered. The
    block then runs inside :obj:`set_state` with that snapshot, so any
    changes made inside the block are undone once it exits.

    See Also
    --------
    set_state
        Context manager that sets the backend state.
    get_state
        Gets a state to be set by this context manager.
    """
    snapshot = get_state()
    with set_state(snapshot):
        yield

133 

134 

@contextlib.contextmanager
def set_state(state):
    """
    Context manager that applies a backend state returned by :obj:`get_state`.

    The state active before entry is captured first and restored when the
    block exits, whether it exits normally or with an exception.

    See Also
    --------
    get_state
        Gets a state to be set by this context manager.
    """
    previous = get_state()
    _uarray.set_state(state)
    try:
        yield
    finally:
        # NOTE(review): the trailing ``True`` is an extra flag understood by
        # _uarray.set_state; its meaning is not visible here - confirm there.
        _uarray.set_state(previous, True)

151 

152 

def create_multimethod(*args, **kwargs):
    """
    Creates a decorator for generating multimethods.

    The decorated function is used as the argument extractor. All arguments
    given here are forwarded, after it, to :obj:`generate_multimethod`.

    See Also
    --------
    generate_multimethod
        Generates a multimethod.
    """

    def wrapper(a):
        multimethod = generate_multimethod(a, *args, **kwargs)
        return multimethod

    return wrapper

172 

173 

def generate_multimethod(
    argument_extractor: ArgumentExtractorType,
    argument_replacer: ArgumentReplacerType,
    domain: str,
    default: typing.Optional[typing.Callable] = None,
):
    """
    Generates a multimethod.

    Parameters
    ----------
    argument_extractor : ArgumentExtractorType
        A callable which extracts the dispatchable arguments. Extracted arguments
        should be marked by the :obj:`Dispatchable` class. It has the same signature
        as the desired multimethod.
    argument_replacer : ArgumentReplacerType
        A callable with the signature (args, kwargs, dispatchables), which should also
        return an (args, kwargs) pair with the dispatchables replaced inside the args/kwargs.
    domain : str
        A string value indicating the domain of this multimethod.
    default : Optional[Callable], optional
        The default implementation of this multimethod, where ``None`` (the default) specifies
        there is no default implementation.

    Returns
    -------
    _Function
        The multimethod. ``functools.update_wrapper`` copies the metadata
        (name, docstring, ...) of ``argument_extractor`` onto it.

    Examples
    --------
    In this example, ``a`` is to be dispatched over, so we return it, while marking it as an ``int``.
    The trailing comma is needed because the args have to be returned as an iterable.

    >>> def override_me(a, b):
    ...     return Dispatchable(a, int),

    Next, we define the argument replacer that replaces the dispatchables inside args/kwargs with the
    supplied ones.

    >>> def override_replacer(args, kwargs, dispatchables):
    ...     return (dispatchables[0], args[1]), {}

    Next, we define the multimethod.

    >>> overridden_me = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples"
    ... )

    Notice that there's no default implementation, unless you supply one.

    >>> overridden_me(1, "a")
    Traceback (most recent call last):
        ...
    uarray.BackendNotImplementedError: ...

    >>> overridden_me2 = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples", default=lambda x, y: (x, y)
    ... )
    >>> overridden_me2(1, "a")
    (1, 'a')

    See Also
    --------
    uarray
        See the module documentation for how to override the method by creating backends.
    """
    # The third value from get_defaults (the set of parameter names) is not
    # needed here; only the keyword and positional defaults are handed on.
    kw_defaults, arg_defaults, _ = get_defaults(argument_extractor)
    ua_func = _Function(
        argument_extractor,
        argument_replacer,
        domain,
        arg_defaults,
        kw_defaults,
        default,
    )

    return functools.update_wrapper(ua_func, argument_extractor)

247 

248 

def set_backend(backend, coerce=False, only=False):
    """
    A context manager that sets the preferred backend.

    Contexts are memoized in a ``__ua_cache__`` dict stored on ``backend``
    (the attribute is created on first use). Repeated calls with the same
    ``coerce``/``only`` flags therefore return the same context object.

    Parameters
    ----------
    backend
        The backend to set.
    coerce
        Whether or not to coerce to a specific backend's types. Implies ``only``.
    only
        Whether or not this should be the last backend to try.

    Returns
    -------
    _SetBackendContext
        The (possibly cached) context manager.

    See Also
    --------
    skip_backend: A context manager that allows skipping of backends.
    set_global_backend: Set a single, global backend for a domain.
    """
    cache_key = ("set", coerce, only)

    try:
        cache = backend.__ua_cache__
    except AttributeError:
        # First context requested for this backend: attach an empty cache.
        cache = backend.__ua_cache__ = {}

    if cache_key in cache:
        return cache[cache_key]

    ctx = _SetBackendContext(backend, coerce, only)
    cache[cache_key] = ctx
    return ctx

277 

278 

def skip_backend(backend):
    """
    A context manager that allows one to skip a given backend from processing
    entirely. This allows one to use another backend's code in a library that
    is also a consumer of the same backend.

    The context is memoized under the ``"skip"`` key of a ``__ua_cache__`` dict
    stored on ``backend`` (the attribute is created on first use).

    Parameters
    ----------
    backend
        The backend to skip.

    Returns
    -------
    _SkipBackendContext
        The (possibly cached) context manager.

    See Also
    --------
    set_backend: A context manager that allows setting of backends.
    set_global_backend: Set a single, global backend for a domain.
    """
    try:
        cache = backend.__ua_cache__
    except AttributeError:
        # First context requested for this backend: attach an empty cache.
        cache = backend.__ua_cache__ = {}

    if "skip" in cache:
        return cache["skip"]

    ctx = _SkipBackendContext(backend)
    cache["skip"] = ctx
    return ctx

305 

306 

def get_defaults(f):
    """
    Collect default-value information from the signature of ``f``.

    Returns
    -------
    kw_defaults : dict
        Maps the name of every parameter that has a default to that default.
    arg_defaults : tuple
        One entry per positional-only or positional-or-keyword parameter, in
        signature order: its default, or ``inspect.Parameter.empty`` when it
        has none.
    opts : set
        The names of all parameters of ``f``.
    """
    params = inspect.signature(f).parameters
    no_default = inspect.Parameter.empty
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    kw_defaults = {
        name: p.default for name, p in params.items() if p.default is not no_default
    }
    arg_defaults = tuple(p.default for p in params.values() if p.kind in positional_kinds)
    opts = set(params)

    return kw_defaults, arg_defaults, opts

323 

324 

def set_global_backend(backend, coerce=False, only=False, *, try_last=False):
    """
    Replace the default (global) backend for permanent use.

    The global backend is tried automatically in the list of backends, unless
    the ``only`` flag is set on some backend. Outside any :obj:`set_backend`
    context manager, it is the first backend tried.

    Note that this method is not thread-safe.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case. This function
        is meant to be used only by users themselves, or by a reference
        implementation, if one exists.

    Parameters
    ----------
    backend
        The backend to register.
    coerce : bool
        Whether to coerce input types when trying this backend.
    only : bool
        If ``True``, no more backends will be tried if this fails.
        Implied by ``coerce=True``.
    try_last : bool
        If ``True``, the global backend is tried after registered backends.

    See Also
    --------
    set_backend: A context manager that allows setting of backends.
    skip_backend: A context manager that allows skipping of backends.
    """
    positional = (backend, coerce, only, try_last)
    _uarray.set_global_backend(*positional)

358 

359 

def register_backend(backend):
    """
    This utility method registers a backend for permanent use. It
    will be tried in the list of backends automatically, unless the
    ``only`` flag is set on a backend.

    Note that this method is not thread-safe.

    Parameters
    ----------
    backend
        The backend to register.

    See Also
    --------
    clear_backends : Clear registered (and/or global) backends.
    set_global_backend : Set a global backend.
    """
    _uarray.register_backend(backend)

374 

375 

def clear_backends(domain, registered=True, globals=False):
    """
    Clear registered and/or global backends.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case. This function
        is meant to be used only by users themselves.

    .. warning::
        Do NOT use this method inside a multimethod call, or the
        program is likely to crash.

    Parameters
    ----------
    domain : Optional[str]
        The domain for which to de-register backends. ``None`` means
        de-register for all domains.
    registered : bool
        Whether or not to clear registered backends. See :obj:`register_backend`.
    globals : bool
        Whether or not to clear global backends. See :obj:`set_global_backend`.

    See Also
    --------
    register_backend : Register a backend globally.
    set_global_backend : Set a global backend.
    """
    # ``globals`` shadows the builtin, but it is a public keyword argument and
    # so keeps its name.
    clear_registered, clear_globals = registered, globals
    _uarray.clear_backends(domain, clear_registered, clear_globals)

405 

406 

class Dispatchable:
    """
    A utility class which marks an argument with a specific dispatch type.

    Parameters
    ----------
    value
        The value being marked.
    dispatch_type
        The type that ``value`` is dispatched as.
    coercible : bool, optional
        Whether ``value`` may be coerced to a backend's types. Defaults to
        ``True``.

    Attributes
    ----------
    value
        The value of the Dispatchable.
    type
        The type of the Dispatchable.
    coercible
        Whether the value may be coerced. ``wrap_single_convertor`` passes
        ``coerce and coercible`` to the single-value converter.

    Notes
    -----
    Indexing follows the ``(type, value)`` order: ``d[0]`` is the type and
    ``d[1]`` is the value. Any other index raises ``IndexError``, so a
    Dispatchable can be unpacked as ``type_, value = d``.

    Examples
    --------
    >>> x = Dispatchable(1, str)
    >>> x
    <Dispatchable: type=<class 'str'>, value=1>

    See Also
    --------
    all_of_type
        Marks all unmarked parameters of a function.

    mark_as
        Allows one to create a utility function to mark as a given type.
    """

    def __init__(self, value, dispatch_type, coercible=True):
        self.value = value
        self.type = dispatch_type
        self.coercible = coercible

    def __getitem__(self, index):
        # (type, value) ordering; IndexError beyond index 1 ends iteration.
        return (self.type, self.value)[index]

    def __str__(self):
        return "<{0}: type={1!r}, value={2!r}>".format(
            type(self).__name__, self.type, self.value
        )

    __repr__ = __str__

449 

450 

def mark_as(dispatch_type):
    """
    Build a callable that marks values with a fixed dispatch type.

    The result is a :func:`functools.partial` of :class:`Dispatchable` with
    ``dispatch_type`` bound as a keyword argument.

    Examples
    --------
    >>> mark_int = mark_as(int)
    >>> mark_int(1)
    <Dispatchable: type=<class 'int'>, value=1>
    """
    marker = functools.partial(Dispatchable, dispatch_type=dispatch_type)
    return marker

462 

463 

def all_of_type(arg_type):
    """
    Decorator factory: mark every unmarked extracted argument as ``arg_type``.

    The decorated function must return an iterable of extracted arguments.
    Items that are already :class:`Dispatchable` are kept as-is; all others
    are wrapped as ``Dispatchable(item, arg_type)``. The result is a tuple.

    Examples
    --------
    >>> @all_of_type(str)
    ... def f(a, b):
    ...     return a, Dispatchable(b, int)
    >>> f('a', 1)
    (<Dispatchable: type=<class 'str'>, value='a'>, <Dispatchable: type=<class 'int'>, value=1>)
    """

    def outer(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            marked = []
            for item in func(*args, **kwargs):
                if not isinstance(item, Dispatchable):
                    item = Dispatchable(item, arg_type)
                marked.append(item)
            return tuple(marked)

        return inner

    return outer

491 

492 

def wrap_single_convertor(convert_single):
    """
    Lift a per-element converter to a full ``__ua_convert__``.

    ``convert_single(value, type, coerce)`` is called on each dispatchable in
    order, with ``coerce and dispatchable.coercible`` as its ``coerce``
    argument. If any call returns ``NotImplemented``, conversion stops there
    and ``NotImplemented`` is returned. Otherwise the list of converted
    values is returned.
    """

    @functools.wraps(convert_single)
    def __ua_convert__(dispatchables, coerce):
        results = []
        for item in dispatchables:
            outcome = convert_single(item.value, item.type, coerce and item.coercible)
            if outcome is NotImplemented:
                return NotImplemented
            results.append(outcome)
        return results

    return __ua_convert__

516 

517 

def wrap_single_convertor_instance(convert_single):
    """
    Wraps a ``__ua_convert__`` method defined for a single element to all elements.

    This is the method counterpart of ``wrap_single_convertor``: the wrapped
    converter also receives the instance it is called on. If any element
    converts to ``NotImplemented``, conversion stops and ``NotImplemented``
    is returned.

    Accepts a signature of (self, value, type, coerce).

    Parameters
    ----------
    convert_single : callable
        ``convert_single(self, value, type, coerce)``. It is called with
        ``coerce and dispatchable.coercible`` as ``coerce``.

    Returns
    -------
    callable
        ``__ua_convert__(self, dispatchables, coerce)``, which returns the list
        of converted values, or ``NotImplemented``.
    """

    @functools.wraps(convert_single)
    def __ua_convert__(self, dispatchables, coerce):
        converted = []
        for d in dispatchables:
            c = convert_single(self, d.value, d.type, coerce and d.coercible)

            if c is NotImplemented:
                return NotImplemented

            converted.append(c)

        return converted

    return __ua_convert__

541 

542 

def determine_backend(value, dispatch_type, *, domain, only=True, coerce=False):
    """Set the backend to the first active backend that supports ``value``

    This is useful for functions that call multimethods without any dispatchable
    arguments. You can use :func:`determine_backend` to ensure the same backend
    is used everywhere in a block of multimethod calls.

    Parameters
    ----------
    value
        The value being tested
    dispatch_type
        The dispatch type associated with ``value``, aka
        ":ref:`marking <MarkingGlossary>`".
    domain : str
        The domain to query for backends and set.
    coerce : bool
        Whether or not to allow coercion to the backend's types. Implies ``only``.
    only : bool
        Whether or not this should be the last backend to try.

    Returns
    -------
    The context manager returned by ``set_backend`` for the chosen backend.

    See Also
    --------
    set_backend: For when you know which backend to set

    Notes
    -----

    Support is determined by the ``__ua_convert__`` protocol. Backends not
    supporting the type must return ``NotImplemented`` from their
    ``__ua_convert__`` if they don't support input of that type.

    Examples
    --------

    Suppose we have two backends ``BackendA`` and ``BackendB`` each supporting
    different types, ``TypeA`` and ``TypeB``. Neither supporting the other type:

    >>> with ua.set_backend(ex.BackendA):
    ...     ex.call_multimethod(ex.TypeB(), ex.TypeB())
    Traceback (most recent call last):
        ...
    uarray.BackendNotImplementedError: ...

    Now consider a multimethod that creates a new object of ``TypeA``, or
    ``TypeB`` depending on the active backend.

    >>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
    ...     res = ex.creation_multimethod()
    ...     ex.call_multimethod(res, ex.TypeA())
    Traceback (most recent call last):
        ...
    uarray.BackendNotImplementedError: ...

    ``res`` is an object of ``TypeB`` because ``BackendB`` is set in the
    innermost with statement. So, ``call_multimethod`` fails since the types
    don't match.

    Instead, we need to first find a backend suitable for all of our objects.

    >>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
    ...     x = ex.TypeA()
    ...     with ua.determine_backend(x, "mark", domain="ua_examples"):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, x)
    TypeA

    """
    # The value is marked coercible exactly when coercion was requested.
    marked = Dispatchable(value, dispatch_type, coerce)
    chosen = _uarray.determine_backend(domain, (marked,), coerce)

    return set_backend(chosen, coerce=coerce, only=only)

615 

616 

def determine_backend_multi(
    dispatchables, *, domain, only=True, coerce=False, **kwargs
):
    """Set a backend supporting all ``dispatchables``

    This is useful for functions that call multimethods without any dispatchable
    arguments. You can use :func:`determine_backend_multi` to ensure the same
    backend is used everywhere in a block of multimethod calls involving
    multiple arrays.

    Parameters
    ----------
    dispatchables : Sequence[Union[uarray.Dispatchable, Any]]
        The dispatchables that must be supported
    domain : str
        The domain to query for backends and set.
    coerce : bool
        Whether or not to allow coercion to the backend's types. Implies ``only``.
    only : bool
        Whether or not this should be the last backend to try.
    dispatch_type : Optional[Any]
        The default dispatch type associated with ``dispatchables``, aka
        ":ref:`marking <MarkingGlossary>`". Passed as a keyword argument; items
        that are not already ``Dispatchable`` are marked with it.

    Returns
    -------
    The context manager returned by ``set_backend`` for the chosen backend.

    Raises
    ------
    TypeError
        If ``dispatch_type`` is not given and some item of ``dispatchables``
        is not a ``Dispatchable``, or if any other unexpected keyword argument
        is passed.

    See Also
    --------
    determine_backend: For a single dispatch value
    set_backend: For when you know which backend to set

    Notes
    -----

    Support is determined by the ``__ua_convert__`` protocol. Backends not
    supporting the type must return ``NotImplemented`` from their
    ``__ua_convert__`` if they don't support input of that type.

    Examples
    --------

    :func:`determine_backend` allows the backend to be set from a single
    object. :func:`determine_backend_multi` allows multiple objects to be
    checked simultaneously for support in the backend. Suppose we have a
    ``BackendAB`` which supports ``TypeA`` and ``TypeB`` in the same call,
    and a ``BackendBC`` that doesn't support ``TypeA``.

    >>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
    ...     a, b = ex.TypeA(), ex.TypeB()
    ...     with ua.determine_backend_multi(
    ...         [ua.Dispatchable(a, "mark"), ua.Dispatchable(b, "mark")],
    ...         domain="ua_examples"
    ...     ):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, a, b)
    TypeA

    This won't call ``BackendBC`` because it doesn't support ``TypeA``.

    We can also leave out the ``ua.Dispatchable`` if we specify the
    default ``dispatch_type`` for the ``dispatchables`` argument.

    >>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
    ...     a, b = ex.TypeA(), ex.TypeB()
    ...     with ua.determine_backend_multi(
    ...         [a, b], dispatch_type="mark", domain="ua_examples"
    ...     ):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, a, b)
    TypeA

    """
    # ``dispatch_type`` is read from **kwargs rather than declared as a
    # parameter; presumably so "not passed" can be told apart from any value,
    # including ``None``.
    if "dispatch_type" in kwargs:
        disp_type = kwargs.pop("dispatch_type")
        dispatchables = tuple(
            d if isinstance(d, Dispatchable) else Dispatchable(d, disp_type)
            for d in dispatchables
        )
    else:
        dispatchables = tuple(dispatchables)
        if not all(isinstance(d, Dispatchable) for d in dispatchables):
            raise TypeError("dispatchables must be instances of uarray.Dispatchable")

    # Anything left in kwargs at this point was not recognised.
    if kwargs:
        raise TypeError("Received unexpected keyword arguments: {}".format(kwargs))

    backend = _uarray.determine_backend(domain, dispatchables, coerce)

    return set_backend(backend, coerce=coerce, only=only)