Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/zmq/eventloop/zmqstream.py: 26%

284 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-01 06:54 +0000

1# 

2# Copyright 2009 Facebook 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); you may 

5# not use this file except in compliance with the License. You may obtain 

6# a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 

12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 

13# License for the specific language governing permissions and limitations 

14# under the License. 

15 

"""A utility class for event-based messaging on a zmq socket using tornado.

Provides :class:`ZMQStream`, which wraps a base :class:`zmq.Socket`, registers
it with a tornado IOLoop, and dispatches send/recv events to user callbacks.

.. seealso::

    - :mod:`zmq.asyncio`
    - :mod:`zmq.eventloop.future`
"""

23 

24import asyncio 

25import pickle 

26import warnings 

27from queue import Queue 

28from typing import ( 

29 Any, 

30 Awaitable, 

31 Callable, 

32 List, 

33 Optional, 

34 Sequence, 

35 Union, 

36 cast, 

37 overload, 

38) 

39 

40from tornado.ioloop import IOLoop 

41from tornado.log import gen_log 

42 

43import zmq 

44import zmq._future 

45from zmq import POLLIN, POLLOUT 

46from zmq._typing import Literal 

47from zmq.utils import jsonapi 

48 

49 

class ZMQStream:
    """A utility class to register callbacks when a zmq socket sends and receives.

    For use with tornado IOLoop.

    There are three main methods:

    * **on_recv(callback, copy=True):**
        register a callback to be run every time the socket has something to receive
    * **on_send(callback):**
        register a callback to be run every time you call send
    * **send_multipart(msg, flags=0, copy=True, track=False, callback=None):**
        queue a send that will trigger the send callback;
        if ``callback`` is passed, it is registered via on_send
        (replacing any previously registered send callback).

    There are also send(), send_string(), send_json() and send_pyobj().

    Two methods deactivate the callbacks:

    * **stop_on_recv():**
        turn off the recv callback
    * **stop_on_send():**
        turn off the send callback

    which simply call ``on_<evt>(None)``.

    The socket's bind/connect and option methods (bind, bind_to_random_port,
    connect, setsockopt, getsockopt, setsockopt_string, getsockopt_string,
    setsockopt_unicode, getsockopt_unicode) are also provided by
    direct-linking the socket's bound methods. Direct recv methods are not
    exposed. e.g.

    >>> stream.bind == stream.socket.bind
    True


    .. versionadded:: 25

        send/recv callbacks can be coroutines.

    .. versionchanged:: 25

        ZMQStreams only support base zmq.Socket classes (this has always been true, but not enforced).
        If ZMQStreams are created with e.g. async Socket subclasses,
        a RuntimeWarning will be shown,
        and the socket cast back to the default zmq.Socket
        before connecting events.

        Previously, using async sockets (or any zmq.Socket subclass) would result in undefined behavior for the
        arguments passed to callback functions.
        Now, the callback functions reliably get the return value of the base `zmq.Socket` send/recv_multipart methods
        (the list of message frames).
    """

    socket: zmq.Socket  # the wrapped base socket; set to None by close()
    io_loop: IOLoop  # loop the socket's handler is registered with
    poller: zmq.Poller  # used by flush() for non-blocking (timeout 0) polls
    _send_queue: Queue  # pending (msg, kwargs) pairs queued by send_multipart()
    _recv_callback: Optional[Callable]  # set by on_recv(); None disables receiving
    _send_callback: Optional[Callable]  # set by on_send()
    _close_callback: Optional[Callable]  # set by set_close_callback(); run by close()
    _state: int = 0  # POLLIN/POLLOUT bitmask of events the stream wants handled
    _flushed: bool = False  # set by flush(); makes _handle_recv/_handle_send skip work
    _recv_copy: bool = False  # copy= argument passed to recv_multipart()
    _fd: int  # socket.FD captured at construction; used to unregister after an external close

115 

def __init__(self, socket: "zmq.Socket", io_loop: Optional[IOLoop] = None):
    """Wrap ``socket`` in a stream and register it with an IOLoop.

    Parameters
    ----------
    socket : zmq.Socket
        The socket to wrap. Async sockets (``zmq._future._AsyncSocket``
        instances) trigger a RuntimeWarning and are re-wrapped as a base
        ``zmq.Socket`` via ``shadow=``.
    io_loop : IOLoop, optional
        Loop to register with; when not given (falsy), ``IOLoop.current()``
        is used.
    """
    if isinstance(socket, zmq._future._AsyncSocket):
        warnings.warn(
            f"""ZMQStream only supports the base zmq.Socket class.

            Use zmq.Socket(shadow=other_socket)
            or `ctx.socket(zmq.{socket._type_name}, socket_class=zmq.Socket)`
            to create a base zmq.Socket object,
            no matter what other kind of socket your Context creates.
            """,
            RuntimeWarning,
            stacklevel=2,
        )
        # shadow back to base zmq.Socket,
        # otherwise callbacks like `on_recv` will get the wrong types.
        socket = zmq.Socket(shadow=socket)
    self.socket = socket

    # IOLoop.current() is deprecated if called outside the event loop,
    # so callers constructing a stream outside a running loop should
    # pass io_loop explicitly.
    self.io_loop = io_loop or IOLoop.current()
    self.poller = zmq.Poller()
    # remember the raw FD so close() can still unregister the handler
    # if the socket is closed behind our back
    self._fd = cast(int, self.socket.FD)

    self._send_queue = Queue()
    self._recv_callback = None
    self._send_callback = None
    self._close_callback = None
    self._recv_copy = False
    self._flushed = False

    self._state = 0
    # registers self._handle_events with the IOLoop for READ events
    self._init_io_state()

    # shortcircuit some socket methods
    self.bind = self.socket.bind
    self.bind_to_random_port = self.socket.bind_to_random_port
    self.connect = self.socket.connect
    self.setsockopt = self.socket.setsockopt
    self.getsockopt = self.socket.getsockopt
    self.setsockopt_string = self.socket.setsockopt_string
    self.getsockopt_string = self.socket.getsockopt_string
    self.setsockopt_unicode = self.socket.setsockopt_unicode
    self.getsockopt_unicode = self.socket.getsockopt_unicode

160 

def stop_on_recv(self):
    """Turn off receiving: equivalent to ``on_recv(None)``.

    Returns whatever ``on_recv`` returns.
    """
    no_callback = None
    return self.on_recv(no_callback)

164 

def stop_on_send(self):
    """Remove the send callback: equivalent to ``on_send(None)``.

    Returns whatever ``on_send`` returns.
    """
    no_callback = None
    return self.on_send(no_callback)

168 

def stop_on_err(self):
    """DEPRECATED, does nothing.

    Kept only for backward compatibility; it logs a warning and returns None.
    """
    # Logger.warn is a deprecated alias of Logger.warning
    gen_log.warning("on_err does nothing, and will be removed")

172 

def on_err(self, callback: Callable):
    """DEPRECATED, does nothing.

    Kept only for backward compatibility; ``callback`` is ignored and a
    warning is logged.
    """
    # Logger.warn is a deprecated alias of Logger.warning
    gen_log.warning("on_err does nothing, and will be removed")

176 

@overload
def on_recv(
    self,
    callback: Callable[[List[bytes]], Any],
) -> None:
    ...

@overload
def on_recv(
    self,
    callback: Callable[[List[bytes]], Any],
    copy: Literal[True],
) -> None:
    ...

@overload
def on_recv(
    self,
    callback: Callable[[List[zmq.Frame]], Any],
    copy: Literal[False],
) -> None:
    ...

@overload
def on_recv(
    self,
    callback: Optional[
        Union[
            Callable[[List[zmq.Frame]], Any],
            Callable[[List[bytes]], Any],
        ]
    ],
    copy: bool = ...,
) -> None:
    ...

def on_recv(
    self,
    callback: Optional[
        Union[
            Callable[[List[zmq.Frame]], Any],
            Callable[[List[bytes]], Any],
        ]
    ],
    copy: bool = True,
) -> None:
    """Register a callback for when a message is ready to recv.

    There can be only one callback registered at a time, so each
    call to `on_recv` replaces previously registered callbacks.

    on_recv(None) disables recv event polling.

    Use on_recv_stream(callback) instead, to register a callback that will receive
    both this ZMQStream and the message, instead of just the message.

    Parameters
    ----------

    callback : callable or None
        callback must take exactly one argument, which will be a
        list, as returned by socket.recv_multipart().
        If callback is None, recv callbacks are disabled.
    copy : bool
        copy is passed directly to recv_multipart, so if copy is False,
        callback will receive a list of zmq.Frame objects. If copy is True,
        callback will receive a list of bytes.

    Returns : None

    Raises
    ------
    OSError
        If the stream has been closed.
    """

    self._check_closed()
    assert callback is None or callable(callback)
    self._recv_callback = callback
    self._recv_copy = copy
    if callback is None:
        self._drop_io_state(zmq.POLLIN)
    else:
        self._add_io_state(zmq.POLLIN)

252 

@overload
def on_recv_stream(
    self,
    callback: Callable[["ZMQStream", List[bytes]], Any],
) -> None:
    ...

@overload
def on_recv_stream(
    self,
    callback: Callable[["ZMQStream", List[bytes]], Any],
    copy: Literal[True],
) -> None:
    ...

@overload
def on_recv_stream(
    self,
    callback: Callable[["ZMQStream", List[zmq.Frame]], Any],
    copy: Literal[False],
) -> None:
    ...

@overload
def on_recv_stream(
    self,
    callback: Optional[
        Union[
            Callable[["ZMQStream", List[zmq.Frame]], Any],
            Callable[["ZMQStream", List[bytes]], Any],
        ]
    ],
    copy: bool = ...,
) -> None:
    ...

def on_recv_stream(
    self,
    callback: Optional[
        Union[
            Callable[["ZMQStream", List[zmq.Frame]], Any],
            Callable[["ZMQStream", List[bytes]], Any],
        ]
    ],
    copy: bool = True,
) -> None:
    """Same as on_recv, but callback will get this stream as first argument.

    callback must take exactly two arguments, as it will be called as::

        callback(stream, msg)

    Useful when a single callback should be used with multiple streams.

    Passing None is equivalent to stop_on_recv(). ``copy`` has the same
    meaning as in on_recv.
    """
    if callback is None:
        self.stop_on_recv()
    else:

        def stream_callback(msg):
            return callback(self, msg)

        self.on_recv(stream_callback, copy=copy)

311 

def on_send(
    self,
    callback: Optional[
        Callable[[Sequence[Any], Optional[zmq.MessageTracker]], Any]
    ],
):
    """Register a callback to be called on each send.

    There will be two arguments::

        callback(msg, status)

    * `msg` will be the list of sendable objects that was just sent
    * `status` will be the return result of socket.send_multipart(msg) -
      MessageTracker or None. If the send raised zmq.ZMQError, the
      exception instance is passed as `status` instead.

    Non-copying sends return a MessageTracker object whose
    `done` attribute will be True when the send is complete.
    This allows users to track when an object is safe to write to
    again.

    The second argument will always be None if copy=True
    on the send.

    Use on_send_stream(callback) to register a callback that will be passed
    this ZMQStream as the first argument, in addition to the other two.

    on_send(None) removes the send callback. It does not change event
    polling: interest in POLLOUT is driven by whether sends are queued.

    Parameters
    ----------

    callback : callable or None
        callback must take exactly two arguments, which will be
        the message being sent (always a list),
        and the return result of socket.send_multipart(msg) -
        MessageTracker or None.

        if callback is None, send callbacks are disabled.

    Raises
    ------
    OSError
        If the stream has been closed.
    """

    self._check_closed()
    assert callback is None or callable(callback)
    self._send_callback = callback

352 self._send_callback = callback 

353 

def on_send_stream(
    self,
    callback: Callable[
        ["ZMQStream", Sequence[Any], Optional[zmq.MessageTracker]], Any
    ],
):
    """Like on_send, but the callback also receives this stream first.

    The callback is invoked as::

        callback(stream, msg, status)

    which lets one callback serve several streams. Passing None is the
    same as stop_on_send().
    """
    if callback is not None:

        def _forward(msg, status):
            return callback(self, msg, status)

        self.on_send(_forward)
    else:
        self.stop_on_send()

372 

def send(self, msg, flags=0, copy=True, track=False, callback=None, **kwargs):
    """Queue a single-frame message, optionally registering a send callback.

    The frame is wrapped in a one-element list and handed to send_multipart.
    See zmq.socket.send for details.
    """
    frames = [msg]
    options = dict(flags=flags, copy=copy, track=track, callback=callback)
    options.update(kwargs)
    return self.send_multipart(frames, **options)

380 

def send_multipart(
    self,
    msg: Sequence[Any],
    flags: int = 0,
    copy: bool = True,
    track: bool = False,
    callback: Optional[Callable] = None,
    **kwargs: Any,
) -> None:
    """Queue a multipart message, optionally registering a send callback.

    The message and its send options are queued and POLLOUT interest is
    added, so the actual send happens when the IOLoop reports writability.
    If ``callback`` is given it replaces the current send callback; if
    neither it nor an existing send callback is set, a no-op callback is
    registered. See zmq.socket.send_multipart for details.
    """
    kwargs.update(flags=flags, copy=copy, track=track)
    self._send_queue.put((msg, kwargs))
    chosen = callback or self._send_callback
    if chosen is None:
        # noop callback
        chosen = lambda *args: None  # noqa: E731
    self.on_send(chosen)
    self._add_io_state(zmq.POLLOUT)

402 

def send_string(
    self,
    u: str,
    flags: int = 0,
    encoding: str = 'utf-8',
    callback: Optional[Callable] = None,
    **kwargs: Any,
):
    """Encode a str with ``encoding`` and send it as a single frame.

    Raises TypeError for non-str input. See zmq.socket.send_string for details.
    """
    if isinstance(u, str):
        payload = u.encode(encoding)
        return self.send(payload, flags=flags, callback=callback, **kwargs)
    raise TypeError("unicode/str objects only")

send_unicode = send_string

419 

def send_json(
    self,
    obj: Any,
    flags: int = 0,
    callback: Optional[Callable] = None,
    **kwargs: Any,
):
    """Serialize ``obj`` with ``jsonapi.dumps`` and send it as a single frame.

    See zmq.socket.send_json for details.
    """
    return self.send(jsonapi.dumps(obj), flags=flags, callback=callback, **kwargs)

432 

def send_pyobj(
    self,
    obj: Any,
    flags: int = 0,
    protocol: int = -1,
    callback: Optional[Callable] = None,
    **kwargs: Any,
):
    """Pickle ``obj`` with ``protocol`` and send it as a single frame.

    See zmq.socket.send_pyobj for details.
    """
    payload = pickle.dumps(obj, protocol)
    return self.send(payload, flags, callback=callback, **kwargs)

447 

448 def _finish_flush(self): 

449 """callback for unsetting _flushed flag.""" 

450 self._flushed = False 

451 

def flush(self, flag: int = zmq.POLLIN | zmq.POLLOUT, limit: Optional[int] = None) -> int:
    """Flush pending messages.

    This method safely handles all pending incoming and/or outgoing messages,
    bypassing the inner loop, passing them to the registered callbacks.

    A limit can be specified, to prevent blocking under high load.

    flush will return the first time ANY of these conditions are met:
        * No more events matching the flag are pending.
        * the total number of events handled reaches the limit.

    Note that if ``flag & POLLIN`` is nonzero, recv events will be flushed even
    if no callback is registered, unlike normal IOLoop operation. This allows
    flush to be used to remove *and ignore* incoming messages.

    Parameters
    ----------
    flag : int, default=POLLIN|POLLOUT
        0MQ poll flags.
        If ``flag & POLLIN``, recv events will be flushed.
        If ``flag & POLLOUT``, send events will be flushed.
        Both flags can be set at once, which is the default.
    limit : None or int, optional
        The maximum number of messages to send or receive.
        Both send and recv count against this limit.
        None (or 0) means no limit.

    Returns
    -------
    int : count of events handled (both send and recv)

    Raises
    ------
    OSError
        If the stream is already closed.
    """
    self._check_closed()
    # unset self._flushed, so callbacks will execute, in case flush has
    # already been called this iteration
    already_flushed = self._flushed
    self._flushed = False
    # initialize counters
    count = 0

    def update_flag():
        """Update the poll flag, to prevent registering POLLOUT events
        if we don't have pending sends."""
        # parsed as (flag & POLLIN) | (sending() and flag & POLLOUT);
        # reads the enclosing `flag`, which is reassigned below
        return flag & zmq.POLLIN | (self.sending() and flag & zmq.POLLOUT)

    flag = update_flag()
    if not flag:
        # nothing to do
        return 0
    self.poller.register(self.socket, flag)
    events = self.poller.poll(0)
    while events and (not limit or count < limit):
        # only self.socket is registered with this poller here, so the
        # first entry carries its events
        s, event = events[0]
        if event & POLLIN:  # receiving
            self._handle_recv()
            count += 1
            if self.socket is None:
                # break if socket was closed during callback
                break
        if event & POLLOUT and self.sending():
            self._handle_send()
            count += 1
            if self.socket is None:
                # break if socket was closed during callback
                break

        flag = update_flag()
        if flag:
            self.poller.register(self.socket, flag)
            events = self.poller.poll(0)
        else:
            events = []
    if count:  # only bypass loop if we actually flushed something
        # skip send/recv callbacks this iteration
        self._flushed = True
        # reregister them at the end of the loop
        if not already_flushed:  # don't need to do it again
            self.io_loop.add_callback(self._finish_flush)
    elif already_flushed:
        # restore the marker we cleared above; _finish_flush is already scheduled
        self._flushed = True

    # update ioloop poll state, which may have changed
    self._rebuild_io_state()
    return count

535 

def set_close_callback(self, callback: Optional[Callable]):
    """Register ``callback`` (called with no arguments) to run when close() closes the stream.

    Passing None clears any previously registered close callback.
    """
    self._close_callback = callback

539 

def close(self, linger: Optional[int] = None) -> None:
    """Close this stream.

    Unregisters the IOLoop handler, closes the socket with ``linger``, drops
    the socket reference, and then runs the close callback if one is set.
    Closing an already-closed stream does nothing.
    """
    if self.socket is None:
        return

    if not self.socket.closed:
        self.io_loop.remove_handler(self.socket)
    else:
        # fallback on raw fd for closed sockets
        # hopefully this happened promptly after close,
        # otherwise somebody else may have the FD
        warnings.warn(
            "Unregistering FD %s after closing socket. "
            "This could result in unregistering handlers for the wrong socket. "
            "Please use stream.close() instead of closing the socket directly."
            % self._fd,
            stacklevel=2,
        )
        self.io_loop.remove_handler(self._fd)

    self.socket.close(linger)
    self.socket = None  # type: ignore
    if self._close_callback:
        self._run_callback(self._close_callback)

561 

def receiving(self) -> bool:
    """Return True while a recv callback is registered (set via on_recv)."""
    callback = self._recv_callback
    return not (callback is None)

565 

def sending(self) -> bool:
    """Return True while queued messages are waiting to be sent."""
    return self._send_queue.qsize() > 0

569 

def closed(self) -> bool:
    """Return True if this stream is closed.

    If the underlying socket was closed directly (not through this stream),
    close() is invoked to clean up before returning True.
    """
    if self.socket is None:
        return True
    if not self.socket.closed:
        return False
    # underlying socket has been closed, but not by us!
    # trigger our cleanup
    self.close()
    return True

579 

580 def _run_callback(self, callback, *args, **kwargs): 

581 """Wrap running callbacks in try/except to allow us to 

582 close our socket.""" 

583 try: 

584 f = callback(*args, **kwargs) 

585 if isinstance(f, Awaitable): 

586 f = asyncio.ensure_future(f) 

587 else: 

588 f = None 

589 except Exception: 

590 gen_log.error("Uncaught exception in ZMQStream callback", exc_info=True) 

591 # Re-raise the exception so that IOLoop.handle_callback_exception 

592 # can see it and log the error 

593 raise 

594 

595 if f is not None: 

596 # handle async callbacks 

597 def _log_error(f): 

598 try: 

599 f.result() 

600 except Exception: 

601 gen_log.error( 

602 "Uncaught exception in ZMQStream callback", exc_info=True 

603 ) 

604 

605 f.add_done_callback(_log_error) 

606 

def _handle_events(self, fd, events):
    """IOLoop handler for this stream's socket.

    Registered with the IOLoop for READ events by _init_io_state, and also
    scheduled directly by _update_handler with ``events=0``. ``fd`` and
    ``events`` are not used: the pending zmq events are re-read from
    ``socket.EVENTS`` and dispatched to _handle_recv / _handle_send, after
    which the poll state is rebuilt. Exceptions from the dispatch are
    logged and re-raised.
    """
    if not self.socket:
        gen_log.warning("Got events for closed stream %s", self)
        return
    try:
        zmq_events = self.socket.EVENTS
    except zmq.ContextTerminated:
        gen_log.warning("Got events for stream %s after terminating context", self)
        # trigger close check, this will unregister callbacks
        self.closed()
        return
    except zmq.ZMQError as e:
        # run close check
        # shadow sockets may have been closed elsewhere,
        # which should show up as ENOTSOCK here
        if self.closed():
            gen_log.warning(
                "Got events for stream %s attached to closed socket: %s", self, e
            )
        else:
            gen_log.error("Error getting events for %s: %s", self, e)
        return
    try:
        # dispatch events:
        if zmq_events & zmq.POLLIN and self.receiving():
            self._handle_recv()
            # a callback may have closed the stream
            if not self.socket:
                return
        if zmq_events & zmq.POLLOUT and self.sending():
            self._handle_send()
            # a callback may have closed the stream
            if not self.socket:
                return

        # rebuild the poll state
        self._rebuild_io_state()
    except Exception:
        gen_log.error("Uncaught exception in zmqstream callback", exc_info=True)
        raise

647 

def _handle_recv(self):
    """Receive one multipart message without blocking and pass it to the recv callback.

    Does nothing while a flush() marker is set. EAGAIN (no message after
    all) is ignored; other ZMQErrors propagate. A message received while no
    callback is registered is discarded.
    """
    if self._flushed:
        return
    try:
        msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy)
    except zmq.ZMQError as e:
        if e.errno != zmq.EAGAIN:
            raise
        # state changed since poll event
        return
    callback = self._recv_callback
    if callback:
        self._run_callback(callback, msg)

664 

def _handle_send(self):
    """Send the oldest queued message and report the result to the send callback.

    Does nothing while a flush() marker is set. If send_multipart raises
    zmq.ZMQError, the error is logged and the exception instance is passed
    to the callback as ``status`` in place of the usual return value.
    """
    if self._flushed:
        return
    if not self.sending():
        gen_log.error("Shouldn't have handled a send event")
        return

    frames, send_kwargs = self._send_queue.get()
    try:
        status = self.socket.send_multipart(frames, **send_kwargs)
    except zmq.ZMQError as err:
        gen_log.error("SEND Error: %s", err)
        status = err
    callback = self._send_callback
    if callback:
        self._run_callback(callback, frames, status)

682 

683 def _check_closed(self): 

684 if not self.socket: 

685 raise OSError("Stream is closed") 

686 

def _rebuild_io_state(self):
    """Recompute the wanted POLLIN/POLLOUT mask from receiving() and sending().

    The new mask is stored in ``_state`` and passed to _update_handler.
    Does nothing once the stream is closed.
    """
    if self.socket is None:
        return
    wants_in = zmq.POLLIN if self.receiving() else 0
    wants_out = zmq.POLLOUT if self.sending() else 0
    self._state = wants_in | wants_out
    self._update_handler(self._state)

699 

700 def _add_io_state(self, state): 

701 """Add io_state to poller.""" 

702 self._state = self._state | state 

703 self._update_handler(self._state) 

704 

705 def _drop_io_state(self, state): 

706 """Stop poller from watching an io_state.""" 

707 self._state = self._state & (~state) 

708 self._update_handler(self._state) 

709 

710 def _update_handler(self, state): 

711 """Update IOLoop handler with state.""" 

712 if self.socket is None: 

713 return 

714 

715 if state & self.socket.events: 

716 # events still exist that haven't been processed 

717 # explicitly schedule handling to avoid missing events due to edge-triggered FDs 

718 self.io_loop.add_callback(lambda: self._handle_events(self.socket, 0)) 

719 

720 def _init_io_state(self): 

721 """initialize the ioloop event handler""" 

722 self.io_loop.add_handler(self.socket, self._handle_events, self.io_loop.READ)