Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/zmq/eventloop/zmqstream.py: 27%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

277 statements  

1# 

2# Copyright 2009 Facebook 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); you may 

5# not use this file except in compliance with the License. You may obtain 

6# a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 

12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 

13# License for the specific language governing permissions and limitations 

14# under the License. 

15"""A utility class for event-based messaging on a zmq socket using tornado. 

16 

17.. seealso:: 

18 

19 - :mod:`zmq.asyncio` 

20 - :mod:`zmq.eventloop.future` 

21""" 

22 

23from __future__ import annotations 

24 

25import asyncio 

26import pickle 

27import warnings 

28from queue import Queue 

29from typing import Any, Awaitable, Callable, Sequence, cast, overload 

30 

31from tornado.ioloop import IOLoop 

32from tornado.log import gen_log 

33 

34import zmq 

35import zmq._future 

36from zmq import POLLIN, POLLOUT 

37from zmq._typing import Literal 

38from zmq.utils import jsonapi 

39 

40 

class ZMQStream:
    """A utility class to register callbacks when a zmq socket sends and receives.

    For use with tornado IOLoop.

    There are three main methods

    Methods:

    * **on_recv(callback, copy=True):**
        register a callback to be run every time the socket has something to receive
    * **on_send(callback):**
        register a callback to be run every time you call send
    * **send_multipart(msg, flags=0, copy=True, track=False, callback=None):**
        queue a send that will trigger the send callback once performed.
        If callback is passed, it is registered via on_send,
        replacing any previously registered send callback.

    There are also send(), send_string(), send_json(), and send_pyobj(),
    which serialize their argument and then queue it like send_multipart().

    Two other methods for deactivating the callbacks:

    * **stop_on_recv():**
        turn off the recv callback
    * **stop_on_send():**
        turn off the send callback

    which simply call ``on_<evt>(None)``.

    The entire socket interface, excluding direct recv methods, is also
    provided, primarily through direct-linking the methods.
    e.g.

    >>> stream.bind == stream.socket.bind
    True


    .. versionadded:: 25

        send/recv callbacks can be coroutines.

    .. versionchanged:: 25

        ZMQStreams only support base zmq.Socket classes (this has always been true, but not enforced).
        If ZMQStreams are created with e.g. async Socket subclasses,
        a RuntimeWarning will be shown,
        and the socket cast back to the default zmq.Socket
        before connecting events.

        Previously, using async sockets (or any zmq.Socket subclass) would result in undefined behavior for the
        arguments passed to callback functions.
        Now, the callback functions reliably get the return value of the base `zmq.Socket` send/recv_multipart methods
        (the list of message frames).
    """

    socket: zmq.Socket
    io_loop: IOLoop
    poller: zmq.Poller
    # queue of (msg, send_multipart kwargs) pairs waiting for POLLOUT
    _send_queue: Queue
    _recv_callback: Callable | None
    _send_callback: Callable | None
    _close_callback: Callable | None
    # bitmask of zmq.POLLIN / zmq.POLLOUT we are currently interested in
    _state: int = 0
    # True while send/recv callbacks should be skipped because flush()
    # already handled events during this loop iteration
    _flushed: bool = False
    _recv_copy: bool = False
    # raw FD, kept so close() can still unregister after the socket was closed
    _fd: int

    def __init__(self, socket: zmq.Socket, io_loop: IOLoop | None = None):
        """Wrap ``socket`` in a stream and register it with an IOLoop.

        Parameters
        ----------
        socket : zmq.Socket
            The socket to wrap. Async socket subclasses trigger a
            RuntimeWarning and are shadowed back to a base zmq.Socket.
        io_loop : IOLoop, optional
            The loop to register with. Defaults to ``IOLoop.current()``.
        """
        if isinstance(socket, zmq._future._AsyncSocket):
            warnings.warn(
                f"""ZMQStream only supports the base zmq.Socket class.

                Use zmq.Socket(shadow=other_socket)
                or `ctx.socket(zmq.{socket._type_name}, socket_class=zmq.Socket)`
                to create a base zmq.Socket object,
                no matter what other kind of socket your Context creates.
                """,
                RuntimeWarning,
                stacklevel=2,
            )
            # shadow back to base zmq.Socket,
            # otherwise callbacks like `on_recv` will get the wrong types.
            socket = zmq.Socket(shadow=socket)
        self.socket = socket

        # IOLoop.current() is deprecated if called outside the event loop,
        # so pass io_loop explicitly when creating a stream outside a running loop.
        self.io_loop = io_loop or IOLoop.current()
        self.poller = zmq.Poller()
        self._fd = cast(int, self.socket.FD)

        self._send_queue = Queue()
        self._recv_callback = None
        self._send_callback = None
        self._close_callback = None
        self._recv_copy = False
        self._flushed = False

        self._state = 0
        self._init_io_state()

        # shortcircuit some socket methods
        self.bind = self.socket.bind
        self.bind_to_random_port = self.socket.bind_to_random_port
        self.connect = self.socket.connect
        self.setsockopt = self.socket.setsockopt
        self.getsockopt = self.socket.getsockopt
        self.setsockopt_string = self.socket.setsockopt_string
        self.getsockopt_string = self.socket.getsockopt_string
        self.setsockopt_unicode = self.socket.setsockopt_unicode
        self.getsockopt_unicode = self.socket.getsockopt_unicode

    def stop_on_recv(self):
        """Disable callback and automatic receiving."""
        return self.on_recv(None)

    def stop_on_send(self):
        """Disable callback on sending."""
        return self.on_send(None)

    def stop_on_err(self):
        """DEPRECATED, does nothing"""
        # Logger.warn is a deprecated alias of Logger.warning
        gen_log.warning("on_err does nothing, and will be removed")

    def on_err(self, callback: Callable):
        """DEPRECATED, does nothing"""
        gen_log.warning("on_err does nothing, and will be removed")

    @overload
    def on_recv(
        self,
        callback: Callable[[list[bytes]], Any],
    ) -> None: ...

    @overload
    def on_recv(
        self,
        callback: Callable[[list[bytes]], Any],
        copy: Literal[True],
    ) -> None: ...

    @overload
    def on_recv(
        self,
        callback: Callable[[list[zmq.Frame]], Any],
        copy: Literal[False],
    ) -> None: ...

    @overload
    def on_recv(
        self,
        callback: Callable[[list[zmq.Frame]], Any]
        | Callable[[list[bytes]], Any]
        | None,
        copy: bool = ...,
    ): ...

    def on_recv(
        self,
        callback: Callable[[list[zmq.Frame]], Any]
        | Callable[[list[bytes]], Any]
        | None,
        copy: bool = True,
    ) -> None:
        """Register a callback for when a message is ready to recv.

        There can be only one callback registered at a time, so each
        call to `on_recv` replaces previously registered callbacks.

        on_recv(None) disables recv event polling.

        Use on_recv_stream(callback) instead, to register a callback that will receive
        both this ZMQStream and the message, instead of just the message.

        Parameters
        ----------

        callback : callable
            callback must take exactly one argument, which will be a
            list, as returned by socket.recv_multipart()
            if callback is None, recv callbacks are disabled.
        copy : bool
            copy is passed directly to recv, so if copy is False,
            callback will receive zmq.Frame objects. If copy is True,
            then callback will receive bytes objects.

        Returns : None

        Raises
        ------
        OSError
            If the stream is closed.
        """

        self._check_closed()
        assert callback is None or callable(callback)
        self._recv_callback = callback
        self._recv_copy = copy
        if callback is None:
            self._drop_io_state(zmq.POLLIN)
        else:
            self._add_io_state(zmq.POLLIN)

    @overload
    def on_recv_stream(
        self,
        callback: Callable[[ZMQStream, list[bytes]], Any],
    ) -> None: ...

    @overload
    def on_recv_stream(
        self,
        callback: Callable[[ZMQStream, list[bytes]], Any],
        copy: Literal[True],
    ) -> None: ...

    @overload
    def on_recv_stream(
        self,
        callback: Callable[[ZMQStream, list[zmq.Frame]], Any],
        copy: Literal[False],
    ) -> None: ...

    @overload
    def on_recv_stream(
        self,
        callback: (
            Callable[[ZMQStream, list[zmq.Frame]], Any]
            | Callable[[ZMQStream, list[bytes]], Any]
            | None
        ),
        copy: bool = ...,
    ): ...

    def on_recv_stream(
        self,
        callback: (
            Callable[[ZMQStream, list[zmq.Frame]], Any]
            | Callable[[ZMQStream, list[bytes]], Any]
            | None
        ),
        copy: bool = True,
    ):
        """Same as on_recv, but callback will get this stream as first argument

        callback must take exactly two arguments, as it will be called as::

            callback(stream, msg)

        Useful when a single callback should be used with multiple streams.
        Passing None disables the recv callback, like stop_on_recv().
        """
        if callback is None:
            self.stop_on_recv()
        else:

            def stream_callback(msg):
                return callback(self, msg)

            self.on_recv(stream_callback, copy=copy)

    def on_send(
        self,
        callback: Callable[[Sequence[Any], zmq.MessageTracker | None], Any] | None,
    ):
        """Register a callback to be called on each send

        There will be two arguments::

            callback(msg, status)

        * `msg` will be the list of sendable objects that was just sent
        * `status` will be the return result of socket.send_multipart(msg) -
          MessageTracker or None.

        Non-copying sends return a MessageTracker object whose
        `done` attribute will be True when the send is complete.
        This allows users to track when an object is safe to write to
        again.

        The second argument will always be None if copy=True
        on the send.

        If the send raised a zmq.ZMQError, the error is logged and
        passed as `status` instead.

        Use on_send_stream(callback) to register a callback that will be passed
        this ZMQStream as the first argument, in addition to the other two.

        on_send(None) disables the send callback (queued sends are still performed).

        Parameters
        ----------

        callback : callable
            callback must take exactly two arguments, which will be
            the message being sent (always a list),
            and the return result of socket.send_multipart(msg) -
            MessageTracker or None.

            if callback is None, send callbacks are disabled.

        Raises
        ------
        OSError
            If the stream is closed.
        """

        self._check_closed()
        assert callback is None or callable(callback)
        self._send_callback = callback

    def on_send_stream(
        self,
        callback: Callable[[ZMQStream, Sequence[Any], zmq.MessageTracker | None], Any]
        | None,
    ):
        """Same as on_send, but callback will get this stream as first argument

        Callback will be passed three arguments::

            callback(stream, msg, status)

        Useful when a single callback should be used with multiple streams.
        Passing None disables the send callback, like stop_on_send().
        """
        if callback is None:
            self.stop_on_send()
        else:
            self.on_send(lambda msg, status: callback(self, msg, status))

    def send(self, msg, flags=0, copy=True, track=False, callback=None, **kwargs):
        """Send a message, optionally also register a new callback for sends.
        See zmq.socket.send for details.

        The message is wrapped in a single-frame list and queued via
        send_multipart().
        """
        return self.send_multipart(
            [msg], flags=flags, copy=copy, track=track, callback=callback, **kwargs
        )

    def send_multipart(
        self,
        msg: Sequence[Any],
        flags: int = 0,
        copy: bool = True,
        track: bool = False,
        callback: Callable | None = None,
        **kwargs: Any,
    ) -> None:
        """Send a multipart message, optionally also register a new callback for sends.
        See zmq.socket.send_multipart for details.

        The message is queued and sent when the socket becomes writable.
        If `callback` is given it replaces the current send callback.
        """
        kwargs.update(dict(flags=flags, copy=copy, track=track))
        self._send_queue.put((msg, kwargs))
        callback = callback or self._send_callback
        if callback is not None:
            self.on_send(callback)
        else:
            # noop callback
            self.on_send(lambda *args: None)
        self._add_io_state(zmq.POLLOUT)

    def send_string(
        self,
        u: str,
        flags: int = 0,
        encoding: str = 'utf-8',
        callback: Callable | None = None,
        **kwargs: Any,
    ):
        """Send a unicode message with an encoding.
        See zmq.socket.send_unicode for details.

        Raises
        ------
        TypeError
            If `u` is not a str.
        """
        if not isinstance(u, str):
            raise TypeError("unicode/str objects only")
        return self.send(u.encode(encoding), flags=flags, callback=callback, **kwargs)

    send_unicode = send_string

    def send_json(
        self,
        obj: Any,
        flags: int = 0,
        callback: Callable | None = None,
        **kwargs: Any,
    ):
        """Send json-serialized version of an object.
        See zmq.socket.send_json for details.
        """
        msg = jsonapi.dumps(obj)
        return self.send(msg, flags=flags, callback=callback, **kwargs)

    def send_pyobj(
        self,
        obj: Any,
        flags: int = 0,
        protocol: int = -1,
        callback: Callable | None = None,
        **kwargs: Any,
    ):
        """Send a Python object as a message using pickle to serialize.

        See zmq.socket.send_pyobj for details.
        """
        msg = pickle.dumps(obj, protocol)
        return self.send(msg, flags, callback=callback, **kwargs)

    def _finish_flush(self):
        """callback for unsetting _flushed flag."""
        self._flushed = False

    def flush(self, flag: int = zmq.POLLIN | zmq.POLLOUT, limit: int | None = None):
        """Flush pending messages.

        This method safely handles all pending incoming and/or outgoing messages,
        bypassing the inner loop, passing them to the registered callbacks.

        A limit can be specified, to prevent blocking under high load.

        flush will return the first time ANY of these conditions are met:
            * No more events matching the flag are pending.
            * the total number of events handled reaches the limit.

        Note that if ``flag & POLLIN`` is set, recv events will be flushed even if no callback
        is registered, unlike normal IOLoop operation. This allows flush to be
        used to remove *and ignore* incoming messages.

        Parameters
        ----------
        flag : int
            default=POLLIN|POLLOUT
            0MQ poll flags.
            If flag & POLLIN, recv events will be flushed.
            If flag & POLLOUT, send events will be flushed.
            Both flags can be set at once, which is the default.
        limit : None or int, optional
            The maximum number of messages to send or receive.
            Both send and recv count against this limit.
            None (or 0) means no limit.

        Returns
        -------
        int :
            count of events handled (both send and recv)

        Raises
        ------
        OSError
            If the stream is closed.
        """
        self._check_closed()
        # unset self._flushed, so callbacks will execute, in case flush has
        # already been called this iteration
        already_flushed = self._flushed
        self._flushed = False
        # initialize counters
        count = 0
        # Keep the caller's requested events separate from the per-pass poll flag:
        # POLLOUT dropped while the send queue is momentarily empty must come back
        # if a callback enqueues another send during this flush.
        requested = flag

        def update_flag() -> int:
            """Return the poll flag for the next pass.

            POLLOUT is only included while there are pending sends,
            to avoid registering for send events we have nothing to service.
            """
            new_flag = requested & zmq.POLLIN
            if self.sending():
                new_flag |= requested & zmq.POLLOUT
            return new_flag

        flag = update_flag()
        if not flag:
            # nothing to do; restore the flushed state cleared above,
            # matching what the normal exit path does
            self._flushed = already_flushed
            return 0
        self.poller.register(self.socket, flag)
        events = self.poller.poll(0)
        while events and (not limit or count < limit):
            _, event = events[0]
            if event & POLLIN:  # receiving
                self._handle_recv()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break
            if event & POLLOUT and self.sending():
                self._handle_send()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break

            flag = update_flag()
            if flag:
                self.poller.register(self.socket, flag)
                events = self.poller.poll(0)
            else:
                events = []
        if count:  # only bypass loop if we actually flushed something
            # skip send/recv callbacks this iteration
            self._flushed = True
            # reregister them at the end of the loop
            if not already_flushed:  # don't need to do it again
                self.io_loop.add_callback(self._finish_flush)
        elif already_flushed:
            self._flushed = True

        # update ioloop poll state, which may have changed
        self._rebuild_io_state()
        return count

    def set_close_callback(self, callback: Callable | None):
        """Call the given callback when the stream is closed."""
        self._close_callback = callback

    def close(self, linger: int | None = None) -> None:
        """Close this stream.

        Unregisters the socket from the IOLoop, closes the socket with
        `linger`, and then runs the close callback, if one is set.
        Calling close() on an already-closed stream does nothing.
        """
        if self.socket is not None:
            if self.socket.closed:
                # fallback on raw fd for closed sockets
                # hopefully this happened promptly after close,
                # otherwise somebody else may have the FD
                warnings.warn(
                    "Unregistering FD %s after closing socket. "
                    "This could result in unregistering handlers for the wrong socket. "
                    "Please use stream.close() instead of closing the socket directly."
                    % self._fd,
                    stacklevel=2,
                )
                self.io_loop.remove_handler(self._fd)
            else:
                self.io_loop.remove_handler(self.socket)
                self.socket.close(linger)
            self.socket = None  # type: ignore
            if self._close_callback:
                self._run_callback(self._close_callback)

    def receiving(self) -> bool:
        """Returns True if a recv callback is currently registered."""
        return self._recv_callback is not None

    def sending(self) -> bool:
        """Returns True if there are queued messages waiting to be sent."""
        return not self._send_queue.empty()

    def closed(self) -> bool:
        """Return True if this stream is closed.

        If the underlying socket was closed directly (not via this stream),
        this also calls close() to clean up the stream's registration.
        """
        if self.socket is None:
            return True
        if self.socket.closed:
            # underlying socket has been closed, but not by us!
            # trigger our cleanup
            self.close()
            return True
        return False

    def _run_callback(self, callback, *args, **kwargs):
        """Wrap running callbacks in try/except to allow us to
        close our socket.

        Awaitable results (coroutine callbacks) are scheduled with
        asyncio.ensure_future, and their exceptions are logged when they finish.
        """
        try:
            f = callback(*args, **kwargs)
            if isinstance(f, Awaitable):
                f = asyncio.ensure_future(f)
            else:
                f = None
        except Exception:
            gen_log.error("Uncaught exception in ZMQStream callback", exc_info=True)
            # Re-raise the exception so that IOLoop.handle_callback_exception
            # can see it and log the error
            raise

        if f is not None:
            # handle async callbacks
            def _log_error(f):
                try:
                    f.result()
                except Exception:
                    gen_log.error(
                        "Uncaught exception in ZMQStream callback", exc_info=True
                    )

            f.add_done_callback(_log_error)

    def _handle_events(self, fd, events):
        """This method is the actual handler for IOLoop, that gets called whenever
        an event on my socket is posted. It dispatches to _handle_recv, etc.

        The `fd` and `events` arguments are ignored; the socket's EVENTS
        option is read instead.
        """
        if not self.socket:
            gen_log.warning("Got events for closed stream %s", self)
            return
        try:
            zmq_events = self.socket.EVENTS
        except zmq.ContextTerminated:
            gen_log.warning("Got events for stream %s after terminating context", self)
            # trigger close check, this will unregister callbacks
            self.closed()
            return
        except zmq.ZMQError as e:
            # run close check
            # shadow sockets may have been closed elsewhere,
            # which should show up as ENOTSOCK here
            if self.closed():
                gen_log.warning(
                    "Got events for stream %s attached to closed socket: %s", self, e
                )
            else:
                gen_log.error("Error getting events for %s: %s", self, e)
            return
        try:
            # dispatch events:
            if zmq_events & zmq.POLLIN and self.receiving():
                self._handle_recv()
                if not self.socket:
                    return
            if zmq_events & zmq.POLLOUT and self.sending():
                self._handle_send()
                if not self.socket:
                    return

            # rebuild the poll state
            self._rebuild_io_state()
        except Exception:
            gen_log.error("Uncaught exception in zmqstream callback", exc_info=True)
            raise

    def _handle_recv(self):
        """Handle a recv event: receive one multipart message and run the recv callback."""
        if self._flushed:
            return
        try:
            msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # state changed since poll event
                pass
            else:
                raise
        else:
            if self._recv_callback:
                callback = self._recv_callback
                self._run_callback(callback, msg)

    def _handle_send(self):
        """Handle a send event: send one queued message and run the send callback.

        A ZMQError from the send is logged and passed to the callback as the status.
        """
        if self._flushed:
            return
        if not self.sending():
            gen_log.error("Shouldn't have handled a send event")
            return

        msg, kwargs = self._send_queue.get()
        try:
            status = self.socket.send_multipart(msg, **kwargs)
        except zmq.ZMQError as e:
            gen_log.error("SEND Error: %s", e)
            status = e
        if self._send_callback:
            callback = self._send_callback
            self._run_callback(callback, msg, status)

    def _check_closed(self):
        """Raise OSError if this stream has been closed."""
        if not self.socket:
            raise OSError("Stream is closed")

    def _rebuild_io_state(self):
        """rebuild io state based on self.sending() and receiving()"""
        if self.socket is None:
            return
        state = 0
        if self.receiving():
            state |= zmq.POLLIN
        if self.sending():
            state |= zmq.POLLOUT

        self._state = state
        self._update_handler(state)

    def _add_io_state(self, state):
        """Add io_state to poller."""
        self._state = self._state | state
        self._update_handler(self._state)

    def _drop_io_state(self, state):
        """Stop poller from watching an io_state."""
        self._state = self._state & (~state)
        self._update_handler(self._state)

    def _update_handler(self, state):
        """Update IOLoop handler with state."""
        if self.socket is None:
            return

        if state & self.socket.events:
            # events still exist that haven't been processed
            # explicitly schedule handling to avoid missing events due to edge-triggered FDs
            self.io_loop.add_callback(lambda: self._handle_events(self.socket, 0))

    def _init_io_state(self):
        """initialize the ioloop event handler"""
        self.io_loop.add_handler(self.socket, self._handle_events, self.io_loop.READ)