Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/zmq/_future.py: 20%

388 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-01 06:54 +0000

1"""Future-returning APIs for coroutines.""" 

2 

3# Copyright (c) PyZMQ Developers. 

4# Distributed under the terms of the Modified BSD License. 

5 

6import warnings 

7from asyncio import Future 

8from collections import deque 

9from itertools import chain 

10from typing import ( 

11 Any, 

12 Awaitable, 

13 Callable, 

14 Dict, 

15 List, 

16 NamedTuple, 

17 Optional, 

18 Tuple, 

19 Type, 

20 TypeVar, 

21 Union, 

22 cast, 

23 overload, 

24) 

25 

26import zmq as _zmq 

27from zmq import EVENTS, POLLIN, POLLOUT 

28from zmq._typing import Literal 

29 

30 

31class _FutureEvent(NamedTuple): 

32 future: Future 

33 kind: str 

34 kwargs: Dict 

35 msg: Any 

36 timer: Any 

37 

38 

39# These are incomplete classes and need a Mixin for compatibility with an eventloop 

40# defining the following attributes: 

41# 

42# _Future 

43# _READ 

44# _WRITE 

45# _default_loop() 

46 

47 

48class _Async: 

49 """Mixin for common async logic""" 

50 

51 _current_loop: Any = None 

52 _Future: Type[Future] 

53 

54 def _get_loop(self) -> Any: 

55 """Get event loop 

56 

57 Notice if event loop has changed, 

58 and register init_io_state on activation of a new event loop 

59 """ 

60 if self._current_loop is None: 

61 self._current_loop = self._default_loop() 

62 self._init_io_state(self._current_loop) 

63 return self._current_loop 

64 current_loop = self._default_loop() 

65 if current_loop is not self._current_loop: 

66 # warn? This means a socket is being used in multiple loops! 

67 self._current_loop = current_loop 

68 self._init_io_state(current_loop) 

69 return current_loop 

70 

71 def _default_loop(self) -> Any: 

72 raise NotImplementedError("Must be implemented in a subclass") 

73 

74 def _init_io_state(self, loop=None) -> None: 

75 pass 

76 

77 

class _AsyncPoller(_Async, _zmq.Poller):
    """Poller that returns a Future on poll, instead of blocking.

    zmq sockets are watched through the recv/send event queues of their async
    wrappers; any other registered object (a "raw" socket) is watched through
    the loop-specific ``_watch_raw_socket`` hook.
    """

    # async socket class used to wrap blocking zmq.Socket instances
    _socket_class: Type["_AsyncSocket"]
    # loop-specific readiness flags passed to _watch_raw_socket
    _READ: int
    _WRITE: int
    raw_sockets: List[Any]

    def _watch_raw_socket(self, loop: Any, socket: Any, evt: int, f: Callable) -> None:
        """Schedule callback for a raw socket"""
        raise NotImplementedError()

    def _unwatch_raw_sockets(self, loop: Any, *sockets: Any) -> None:
        """Unschedule callback for a raw socket"""
        raise NotImplementedError()

    def poll(self, timeout=-1) -> Awaitable[List[Tuple[Any, int]]]:  # type: ignore
        """Return a Future for a poll event

        timeout is in milliseconds (it is scaled by 1e-3 for ``call_later``).
        0 polls once, immediately; None or a value <= 0 other than 0 sets no
        timer, so the Future waits until a registered socket signals
        readiness. Once readiness (or the timeout) is signalled, the Future
        resolves to a non-blocking ``zmq.Poller.poll(0)`` result, i.e. a
        (possibly empty) list of (socket, event mask) pairs.
        """
        future = self._Future()
        if timeout == 0:
            # non-blocking poll: answer synchronously through the Future
            try:
                result = super().poll(0)
            except Exception as e:
                future.set_exception(e)
            else:
                future.set_result(result)
            return future

        loop = self._get_loop()

        # register Future to be called as soon as any event is available on any socket
        # (the single "watcher" is shared by every registered socket; whichever
        # signals first resolves it)
        watcher = self._Future()

        # watch raw sockets:
        raw_sockets: List[Any] = []

        def wake_raw(*args):
            if not watcher.done():
                watcher.set_result(None)

        # stop watching raw sockets as soon as the watcher resolves, for any reason
        watcher.add_done_callback(
            lambda f: self._unwatch_raw_sockets(loop, *raw_sockets)
        )

        for socket, mask in self.sockets:
            if isinstance(socket, _zmq.Socket):
                if not isinstance(socket, self._socket_class):
                    # it's a blocking zmq.Socket, wrap it in async
                    # NOTE(review): this wrapper is never closed here, so any
                    # loop handler it registers may outlive the poll - confirm
                    socket = self._socket_class.from_socket(socket)
                # queue 'poll' events that just resolve the shared watcher
                if mask & _zmq.POLLIN:
                    socket._add_recv_event('poll', future=watcher)
                if mask & _zmq.POLLOUT:
                    socket._add_send_event('poll', future=watcher)
            else:
                raw_sockets.append(socket)
                # translate zmq poll flags into the loop's readiness flags
                evt = 0
                if mask & _zmq.POLLIN:
                    evt |= self._READ
                if mask & _zmq.POLLOUT:
                    evt |= self._WRITE
                self._watch_raw_socket(loop, socket, evt, wake_raw)

        def on_poll_ready(f):
            # Forward the watcher's outcome to the caller's Future; on success
            # do a real non-blocking poll to collect the ready sockets.
            if future.done():
                return
            if watcher.cancelled():
                try:
                    future.cancel()
                except RuntimeError:
                    # RuntimeError may be called during teardown
                    pass
                return
            if watcher.exception():
                future.set_exception(watcher.exception())
            else:
                try:
                    result = super(_AsyncPoller, self).poll(0)
                except Exception as e:
                    future.set_exception(e)
                else:
                    future.set_result(result)

        watcher.add_done_callback(on_poll_ready)

        if timeout is not None and timeout > 0:
            # schedule cancel to fire on poll timeout, if any
            # (the timeout resolves the watcher, so on_poll_ready still runs
            # a final poll(0))
            def trigger_timeout():
                if not watcher.done():
                    watcher.set_result(None)

            timeout_handle = loop.call_later(1e-3 * timeout, trigger_timeout)

            def cancel_timeout(f):
                # asyncio-style handles have .cancel(); otherwise fall back to
                # the loop's remove_timeout
                if hasattr(timeout_handle, 'cancel'):
                    timeout_handle.cancel()
                else:
                    loop.remove_timeout(timeout_handle)

            future.add_done_callback(cancel_timeout)

        def cancel_watcher(f):
            # caller's Future finished (or was cancelled) first: release the watcher
            if not watcher.done():
                watcher.cancel()

        future.add_done_callback(cancel_watcher)

        return future

185 

186 

187class _NoTimer: 

188 @staticmethod 

189 def cancel(): 

190 pass 

191 

192 

# Type variable for alternate constructors such as ``from_socket``, so they
# are typed as returning the concrete _AsyncSocket subclass they are called on.
T = TypeVar("T", bound="_AsyncSocket")

194 

195 

class _AsyncSocket(_Async, _zmq.Socket[Future]):
    """Socket whose send/recv/poll methods return Futures instead of blocking.

    Incomplete on its own: a loop-specific subclass supplies ``_Future``,
    ``_READ``, ``_WRITE`` and ``_default_loop()``.

    Pending requests are queued as ``_FutureEvent`` entries in
    ``_recv_futures`` / ``_send_futures``. The actual I/O is performed on
    ``_shadow_sock`` (a plain ``zmq.Socket`` on the same underlying libzmq
    socket), always with ``zmq.DONTWAIT``, once the ``EVENTS`` option reports
    the socket ready.
    """

    # Warning: these class variables exist only so that super().__setattr__
    # can be called on them. They are overridden at instance initialization
    # and are not shared across the whole class.
    _recv_futures = None
    _send_futures = None
    _state = 0
    _shadow_sock: "_zmq.Socket"
    _poller_class = _AsyncPoller
    _fd = None

    def __init__(
        self,
        context=None,
        socket_type=-1,
        io_loop=None,
        _from_socket: Optional["_zmq.Socket"] = None,
        **kwargs,
    ) -> None:
        """Create a new async socket, or wrap an existing ``zmq.Socket``.

        Passing a ``zmq.Socket`` as ``context`` is shorthand for
        ``_from_socket``. ``io_loop`` is deprecated and ignored apart from a
        DeprecationWarning; the currently active loop is always used.
        """
        if isinstance(context, _zmq.Socket):
            context, _from_socket = (None, context)
        if _from_socket is not None:
            super().__init__(shadow=_from_socket.underlying)  # type: ignore
            self._shadow_sock = _from_socket
        else:
            super().__init__(context, socket_type, **kwargs)  # type: ignore
            self._shadow_sock = _zmq.Socket.shadow(self.underlying)

        if io_loop is not None:
            warnings.warn(
                f"{self.__class__.__name__}(io_loop) argument is deprecated in pyzmq 22.2."
                " The currently active loop will always be used.",
                DeprecationWarning,
                stacklevel=3,
            )
        self._recv_futures = deque()
        self._send_futures = deque()
        self._state = 0
        # remember the FD so the loop handler can still be removed after the
        # shadow socket has been closed (see _clear_io_state)
        self._fd = self._shadow_sock.FD

    @classmethod
    def from_socket(cls: Type[T], socket: "_zmq.Socket", io_loop: Any = None) -> T:
        """Create an async socket from an existing Socket"""
        return cls(_from_socket=socket, io_loop=io_loop)

    def close(self, linger: Optional[int] = None) -> None:
        if not self.closed and self._fd is not None:
            # cancel every pending send/recv/poll Future before closing
            event_list: List[_FutureEvent] = list(
                chain(self._recv_futures or [], self._send_futures or [])
            )
            for event in event_list:
                if not event.future.done():
                    try:
                        event.future.cancel()
                    except RuntimeError:
                        # RuntimeError may be called during teardown
                        pass
            self._clear_io_state()
        super().close(linger=linger)

    close.__doc__ = _zmq.Socket.close.__doc__

    def get(self, key):
        result = super().get(key)
        if key == EVENTS:
            # reading EVENTS can consume the edge that would wake the loop
            # handler, so re-check pending events ourselves
            self._schedule_remaining_events(result)
        return result

    get.__doc__ = _zmq.Socket.get.__doc__

    @overload  # type: ignore
    def recv_multipart(
        self, flags: int = 0, *, track: bool = False
    ) -> Awaitable[List[bytes]]:
        ...

    @overload
    def recv_multipart(
        self, flags: int = 0, *, copy: Literal[True], track: bool = False
    ) -> Awaitable[List[bytes]]:
        ...

    @overload
    def recv_multipart(
        self, flags: int = 0, *, copy: Literal[False], track: bool = False
    ) -> Awaitable[List[_zmq.Frame]]:  # type: ignore
        ...

    @overload
    def recv_multipart(
        self, flags: int = 0, copy: bool = True, track: bool = False
    ) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]:
        ...

    def recv_multipart(
        self, flags: int = 0, copy: bool = True, track: bool = False
    ) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]:
        """Receive a complete multipart zmq message.

        Returns a Future whose result will be a multipart message.
        """
        return self._add_recv_event(
            'recv_multipart', dict(flags=flags, copy=copy, track=track)
        )

    def recv(  # type: ignore
        self, flags: int = 0, copy: bool = True, track: bool = False
    ) -> Awaitable[Union[bytes, _zmq.Frame]]:
        """Receive a single zmq frame.

        Returns a Future, whose result will be the received frame.

        Recommend using recv_multipart instead.
        """
        return self._add_recv_event('recv', dict(flags=flags, copy=copy, track=track))

    def send_multipart(  # type: ignore
        self, msg_parts: Any, flags: int = 0, copy: bool = True, track=False, **kwargs
    ) -> Awaitable[Optional[_zmq.MessageTracker]]:
        """Send a complete multipart zmq message.

        Returns a Future that resolves when sending is complete.
        """
        kwargs['flags'] = flags
        kwargs['copy'] = copy
        kwargs['track'] = track
        return self._add_send_event('send_multipart', msg=msg_parts, kwargs=kwargs)

    def send(  # type: ignore
        self,
        data: Any,
        flags: int = 0,
        copy: bool = True,
        track: bool = False,
        **kwargs: Any,
    ) -> Awaitable[Optional[_zmq.MessageTracker]]:
        """Send a single zmq frame.

        Returns a Future that resolves when sending is complete.

        Recommend using send_multipart instead.
        """
        # (previously these were also re-applied via a redundant kwargs.update)
        kwargs['flags'] = flags
        kwargs['copy'] = copy
        kwargs['track'] = track
        return self._add_send_event('send', msg=data, kwargs=kwargs)

    def _deserialize(self, recvd, load):
        """Deserialize with Futures

        Returns a Future resolving to ``load(recvd.result())``. Exceptions
        from ``recvd`` or from ``load`` are set on the returned Future, and
        cancelling the returned Future cancels ``recvd``.
        """
        f = self._Future()

        def _chain(_):
            """Chain result through serialization to recvd"""
            if f.done():
                return
            if recvd.exception():
                f.set_exception(recvd.exception())
            else:
                buf = recvd.result()
                try:
                    loaded = load(buf)
                except Exception as e:
                    f.set_exception(e)
                else:
                    f.set_result(loaded)

        recvd.add_done_callback(_chain)

        def _chain_cancel(_):
            """Chain cancellation from f to recvd"""
            if recvd.done():
                return
            if f.cancelled():
                recvd.cancel()

        f.add_done_callback(_chain_cancel)

        return f

    def poll(self, timeout=None, flags=_zmq.POLLIN) -> Awaitable[int]:  # type: ignore
        """poll the socket for events

        Returns a Future resolving to this socket's event mask, or 0 if none
        of ``flags`` became ready before ``timeout`` (milliseconds; None waits
        indefinitely). Cancelling the returned Future also cancels the
        underlying poll.

        Raises zmq.ZMQError(ENOTSUP) if the socket is closed.
        """

        if self.closed:
            raise _zmq.ZMQError(_zmq.ENOTSUP)

        p = self._poller_class()
        p.register(self, flags)
        poll_future = cast(Future, p.poll(timeout))

        future = self._Future()

        def unwrap_result(f):
            # translate the poller's [(socket, events)] result into an int mask
            if future.done():
                return
            if f.cancelled():
                try:
                    future.cancel()
                except RuntimeError:
                    # RuntimeError may be called during teardown
                    pass
                return
            if f.exception():
                future.set_exception(f.exception())
            else:
                evts = dict(f.result())
                future.set_result(evts.get(self, 0))

        def cancel_poll(_):
            # If the caller's Future finishes first (e.g. it was cancelled),
            # cancel the internal poll too, so the poller stops waiting and its
            # queued watcher no longer sits in this socket's event lists.
            if not poll_future.done():
                try:
                    poll_future.cancel()
                except RuntimeError:
                    # RuntimeError may be called during teardown
                    pass

        if poll_future.done():
            # poll already finished (e.g. timeout=0): hook up result immediately
            unwrap_result(poll_future)
        else:
            poll_future.add_done_callback(unwrap_result)
        future.add_done_callback(cancel_poll)
        return future

    # overrides only necessary for updated types
    def recv_string(self, *args, **kwargs) -> Awaitable[str]:  # type: ignore
        return super().recv_string(*args, **kwargs)  # type: ignore

    def send_string(self, s: str, flags: int = 0, encoding: str = 'utf-8') -> Awaitable[None]:  # type: ignore
        return super().send_string(s, flags=flags, encoding=encoding)  # type: ignore

    def _add_timeout(self, future, timeout):
        """Add a timeout for a send or recv Future

        ``timeout`` is in seconds. Returns the loop's timer handle; when it
        fires before ``future`` resolves, ``future`` fails with zmq.Again.
        """

        def future_timeout():
            if future.done():
                # future already resolved, do nothing
                return

            # raise EAGAIN
            future.set_exception(_zmq.Again())

        return self._call_later(timeout, future_timeout)

    def _call_later(self, delay, callback):
        """Schedule a function to be called later

        Override for different IOLoop implementations

        Tornado and asyncio happen to both have ioloop.call_later
        with the same signature.
        """
        return self._get_loop().call_later(delay, callback)

    @staticmethod
    def _remove_finished_future(future, event_list):
        """Make sure that futures are removed from the event list when they resolve

        Avoids delaying cleanup until the next send/recv event,
        which may never come.
        """
        for f_idx, event in enumerate(event_list):
            if event.future is future:
                break
        else:
            return

        # "future" instance may be shared between sockets (poll watchers), but
        # each socket has its own event list: drop only this list's entry.
        # Delete by index rather than re-scanning with list.remove().
        del event_list[f_idx]

    def _add_recv_event(self, kind, kwargs=None, future=None):
        """Add a recv event, returning the corresponding Future

        ``kind`` is 'recv', 'recv_multipart' or 'poll'; ``kwargs`` (unused
        for 'poll') are passed to the shadow socket's method of that name.
        An existing ``future`` may be supplied (as the poller does).
        """
        f = future or self._Future()
        if kind.startswith('recv') and kwargs.get('flags', 0) & _zmq.DONTWAIT:
            # short-circuit non-blocking calls
            recv = getattr(self._shadow_sock, kind)
            try:
                r = recv(**kwargs)
            except Exception as e:
                f.set_exception(e)
            else:
                f.set_result(r)
            return f

        timer = _NoTimer
        if hasattr(_zmq, 'RCVTIMEO'):
            timeout_ms = self._shadow_sock.rcvtimeo
            if timeout_ms >= 0:
                # fails f with zmq.Again if it is still pending at the timeout
                timer = self._add_timeout(f, timeout_ms * 1e-3)

        # queue the event together with its timer, so the timer can be
        # cancelled once the event is handled
        self._recv_futures.append(_FutureEvent(f, kind, kwargs, msg=None, timer=timer))

        # Don't let the Future sit in _recv_futures after it's done
        # (resolved, timed out, or cancelled)
        f.add_done_callback(
            lambda f: self._remove_finished_future(f, self._recv_futures)
        )

        if self._shadow_sock.get(EVENTS) & POLLIN:
            # recv immediately, if we can
            self._handle_recv()
        if self._recv_futures:
            self._add_io_state(POLLIN)
        return f

    def _add_send_event(self, kind, msg=None, kwargs=None, future=None):
        """Add a send event, returning the corresponding Future

        ``kind`` is 'send', 'send_multipart' or 'poll'. For sends, an
        immediate DONTWAIT attempt is made when no other send is queued;
        only if that raises EAGAIN (and the caller did not request DONTWAIT)
        is the send queued for when the socket becomes writable.
        """
        f = future or self._Future()
        # attempt send with DONTWAIT if no futures are waiting
        # short-circuit for sends that will resolve immediately
        # only call if no send Futures are waiting
        if kind in ('send', 'send_multipart') and not self._send_futures:
            flags = kwargs.get('flags', 0)
            nowait_kwargs = kwargs.copy()
            nowait_kwargs['flags'] = flags | _zmq.DONTWAIT

            # short-circuit non-blocking calls
            send = getattr(self._shadow_sock, kind)
            # track whether the send resolved the Future; EAGAIN without a
            # caller-requested DONTWAIT falls through to the async path
            finish_early = True
            try:
                r = send(msg, **nowait_kwargs)
            except _zmq.Again as e:
                if flags & _zmq.DONTWAIT:
                    f.set_exception(e)
                else:
                    # EAGAIN raised and DONTWAIT not requested,
                    # proceed with async send
                    finish_early = False
            except Exception as e:
                f.set_exception(e)
            else:
                f.set_result(r)

            if finish_early:
                # short-circuit resolved, return finished Future
                # schedule wake for recv if there are any receivers waiting
                if self._recv_futures:
                    self._schedule_remaining_events()
                return f

        timer = _NoTimer
        if hasattr(_zmq, 'SNDTIMEO'):
            timeout_ms = self._shadow_sock.get(_zmq.SNDTIMEO)
            if timeout_ms >= 0:
                # fails f with zmq.Again if it is still pending at the timeout
                timer = self._add_timeout(f, timeout_ms * 1e-3)

        # queue the event together with its timer, so the timer can be
        # cancelled once the send is performed
        self._send_futures.append(
            _FutureEvent(f, kind, kwargs=kwargs, msg=msg, timer=timer)
        )
        # Don't let the Future sit in _send_futures after it's done
        # (resolved, timed out, or cancelled)
        f.add_done_callback(
            lambda f: self._remove_finished_future(f, self._send_futures)
        )

        self._add_io_state(POLLOUT)
        return f

    def _handle_recv(self):
        """Handle recv events

        Resolves the oldest still-pending recv-side Future, if the socket is
        readable.
        """
        if not self._shadow_sock.get(EVENTS) & POLLIN:
            # event triggered, but state may have been changed between trigger and callback
            return
        f = None
        while self._recv_futures:
            f, kind, kwargs, _, timer = self._recv_futures.popleft()
            # skip any cancelled futures
            if f.done():
                f = None
            else:
                break

        if not self._recv_futures:
            self._drop_io_state(POLLIN)

        if f is None:
            return

        timer.cancel()

        if kind == 'poll':
            # on poll event, just signal ready, nothing else.
            f.set_result(None)
            return
        elif kind == 'recv_multipart':
            recv = self._shadow_sock.recv_multipart
        elif kind == 'recv':
            recv = self._shadow_sock.recv
        else:
            raise ValueError("Unhandled recv event type: %r" % kind)

        kwargs['flags'] |= _zmq.DONTWAIT
        try:
            result = recv(**kwargs)
        except Exception as e:
            f.set_exception(e)
        else:
            f.set_result(result)

    def _handle_send(self):
        """Handle send events

        Performs the oldest still-pending queued send, if the socket is
        writable.
        """
        if not self._shadow_sock.get(EVENTS) & POLLOUT:
            # event triggered, but state may have been changed between trigger and callback
            return
        f = None
        while self._send_futures:
            f, kind, kwargs, msg, timer = self._send_futures.popleft()
            # skip any cancelled futures
            if f.done():
                f = None
            else:
                break

        if not self._send_futures:
            self._drop_io_state(POLLOUT)

        if f is None:
            return

        timer.cancel()

        if kind == 'poll':
            # on poll event, just signal ready, nothing else.
            f.set_result(None)
            return
        elif kind == 'send_multipart':
            send = self._shadow_sock.send_multipart
        elif kind == 'send':
            send = self._shadow_sock.send
        else:
            raise ValueError("Unhandled send event type: %r" % kind)

        kwargs['flags'] |= _zmq.DONTWAIT
        try:
            result = send(msg, **kwargs)
        except Exception as e:
            f.set_exception(e)
        else:
            f.set_result(result)

    # event masking from ZMQStream
    def _handle_events(self, fd=0, events=0):
        """Dispatch IO events to _handle_recv, etc."""
        if self._shadow_sock.closed:
            return

        zmq_events = self._shadow_sock.get(EVENTS)
        if zmq_events & _zmq.POLLIN:
            self._handle_recv()
        if zmq_events & _zmq.POLLOUT:
            self._handle_send()
        self._schedule_remaining_events()

    def _schedule_remaining_events(self, events=None):
        """Schedule a call to handle_events next loop iteration

        If there are still events to handle.
        """
        # edge-triggered handling
        # allow passing events in, in case this is triggered by retrieving events,
        # so we don't have to retrieve it twice.
        if self._state == 0:
            # not watching for anything, nothing to schedule
            return
        if events is None:
            events = self._shadow_sock.get(EVENTS)
        if events & self._state:
            self._call_later(0, self._handle_events)

    def _add_io_state(self, state):
        """Add io_state to poller."""
        if self._state != state:
            self._state = self._state | state
        self._update_handler(self._state)

    def _drop_io_state(self, state):
        """Stop poller from watching an io_state."""
        if self._state & state:
            self._state = self._state & (~state)
        self._update_handler(self._state)

    def _update_handler(self, state):
        """Update IOLoop handler with state.

        zmq FD is always read-only.
        """
        # ensure loop is registered and init_io has been called
        # if there are any events to watch for
        if state:
            self._get_loop()
        self._schedule_remaining_events()

    def _init_io_state(self, loop=None):
        """initialize the ioloop event handler"""
        if loop is None:
            loop = self._get_loop()
        loop.add_handler(self._shadow_sock, self._handle_events, self._READ)
        # check once for events that were already pending before registration
        self._call_later(0, self._handle_events)

    def _clear_io_state(self):
        """unregister the ioloop event handler

        called once during close
        """
        fd = self._shadow_sock
        if self._shadow_sock.closed:
            # the socket object can no longer be used; fall back to the saved FD
            fd = self._fd
        if self._current_loop is not None:
            self._current_loop.remove_handler(fd)