Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tornado/websocket.py: 5%

735 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-22 06:15 +0000

1"""Implementation of the WebSocket protocol. 

2 

3`WebSockets <http://dev.w3.org/html5/websockets/>`_ allow for bidirectional 

4communication between the browser and server. WebSockets are supported in the 

5current versions of all major browsers. 

6 

7This module implements the final version of the WebSocket protocol as 

8defined in `RFC 6455 <http://tools.ietf.org/html/rfc6455>`_. 

9 

10.. versionchanged:: 4.0 

11 Removed support for the draft 76 protocol version. 

12""" 

13 

14import abc 

15import asyncio 

16import base64 

17import hashlib 

18import os 

19import sys 

20import struct 

21import tornado 

22from urllib.parse import urlparse 

23import zlib 

24 

25from tornado.concurrent import Future, future_set_result_unless_cancelled 

26from tornado.escape import utf8, native_str, to_unicode 

27from tornado import gen, httpclient, httputil 

28from tornado.ioloop import IOLoop, PeriodicCallback 

29from tornado.iostream import StreamClosedError, IOStream 

30from tornado.log import gen_log, app_log 

31from tornado.netutil import Resolver 

32from tornado import simple_httpclient 

33from tornado.queues import Queue 

34from tornado.tcpclient import TCPClient 

35from tornado.util import _websocket_mask 

36 

37from typing import ( 

38 TYPE_CHECKING, 

39 cast, 

40 Any, 

41 Optional, 

42 Dict, 

43 Union, 

44 List, 

45 Awaitable, 

46 Callable, 

47 Tuple, 

48 Type, 

49) 

50from types import TracebackType 

51 

if TYPE_CHECKING:
    # Evaluated only by static type checkers: TYPE_CHECKING is False at
    # runtime, so typing_extensions is never imported when the module runs.
    from typing_extensions import Protocol

    # zlib does not publicly expose the classes of its compressor and
    # decompressor objects, so these protocols describe exactly the
    # attributes this module relies on.
    class _Compressor(Protocol):
        """Structural type for ``zlib.compressobj()`` results."""

        def compress(self, data: bytes) -> bytes:
            ...

        def flush(self, mode: int) -> bytes:
            ...

    class _Decompressor(Protocol):
        """Structural type for ``zlib.decompressobj()`` results."""

        unconsumed_tail = b""  # type: bytes

        def decompress(self, data: bytes, max_length: int) -> bytes:
            ...

    class _WebSocketDelegate(Protocol):
        """Callbacks shared by both ends of a connection.

        Implemented by `WebSocketHandler` on the server side and by
        `WebSocketClientConnection` on the client side.
        """

        def on_ws_connection_close(
            self, close_code: Optional[int] = None, close_reason: Optional[str] = None
        ) -> None:
            ...

        def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]:
            ...

        def on_ping(self, data: bytes) -> None:
            ...

        def on_pong(self, data: bytes) -> None:
            ...

        def log_exception(
            self,
            typ: Optional[Type[BaseException]],
            value: Optional[BaseException],
            tb: Optional[TracebackType],
        ) -> None:
            ...

95 

96 

# Default cap on an incoming message's size, in bytes (10 MiB).
_default_max_message_size = 10 << 20

98 

99 

class WebSocketError(Exception):
    """Base class for the WebSocket-specific exceptions in this module."""

102 

103 

class WebSocketClosedError(WebSocketError):
    """Raised by operations on a closed connection.

    For example, `WebSocketHandler.write_message` and `WebSocketHandler.ping`
    raise it when there is no open connection to use.

    .. versionadded:: 3.2
    """

111 

112 

class _DecompressTooLargeError(Exception):
    """Internal signal that an inflated message exceeded the size limit."""

115 

116 

class _WebSocketParams:
    """Plain holder for the per-connection settings handed to the protocol.

    :param ping_interval: seconds between keep-alive pings, or None.
    :param ping_timeout: seconds to wait for a ping response, or None.
    :param max_message_size: largest accepted message, in bytes.
    :param compression_options: options for permessage-deflate, or None
        to leave compression disabled.
    """

    def __init__(
        self,
        ping_interval: Optional[float] = None,
        ping_timeout: Optional[float] = None,
        max_message_size: int = _default_max_message_size,
        compression_options: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Keep-alive settings.
        self.ping_interval = ping_interval
        self.ping_timeout = ping_timeout
        # Size limit for incoming messages.
        self.max_message_size = max_message_size
        # Compression configuration (None means "not enabled").
        self.compression_options = compression_options

129 

130 

class WebSocketHandler(tornado.web.RequestHandler):
    """Subclass this class to create a basic WebSocket handler.

    Override `on_message` to handle incoming messages, and use
    `write_message` to send messages to the client. You can also
    override `open` and `on_close` to handle opened and closed
    connections.

    Custom upgrade response headers can be sent by overriding
    `~tornado.web.RequestHandler.set_default_headers` or
    `~tornado.web.RequestHandler.prepare`.

    See http://dev.w3.org/html5/websockets/ for details on the
    JavaScript interface.  The protocol is specified at
    http://tools.ietf.org/html/rfc6455.

    Here is an example WebSocket handler that echoes all received messages
    back to the client:

    .. testcode::

      class EchoWebSocket(tornado.websocket.WebSocketHandler):
          def open(self):
              print("WebSocket opened")

          def on_message(self, message):
              self.write_message(u"You said: " + message)

          def on_close(self):
              print("WebSocket closed")

    .. testoutput::
       :hide:

    WebSockets are not standard HTTP connections. The "handshake" is
    HTTP, but after the handshake, the protocol is
    message-based. Consequently, most of the Tornado HTTP facilities
    are not available in handlers of this type. The only communication
    methods available to you are `write_message()`, `ping()`, and
    `close()`. Likewise, your request handler class should implement
    the `open()` method rather than ``get()`` or ``post()``.

    If you map the handler above to ``/websocket`` in your application, you can
    invoke it in JavaScript with::

      var ws = new WebSocket("ws://localhost:8888/websocket");
      ws.onopen = function() {
         ws.send("Hello, world");
      };
      ws.onmessage = function (evt) {
         alert(evt.data);
      };

    This script pops up an alert box that says "You said: Hello, world".

    Web browsers allow any site to open a websocket connection to any other,
    instead of using the same-origin policy that governs other network
    access from JavaScript.  This can be surprising and is a potential
    security hole, so since Tornado 4.0 `WebSocketHandler` requires
    applications that wish to receive cross-origin websockets to opt in
    by overriding the `~WebSocketHandler.check_origin` method (see that
    method's docs for details).  Failure to do so is the most likely
    cause of 403 errors when making a websocket connection.

    When using a secure websocket connection (``wss://``) with a self-signed
    certificate, the connection from a browser may fail because it wants
    to show the "accept this certificate" dialog but has nowhere to show it.
    You must first visit a regular HTML page using the same certificate
    to accept it before the websocket connection will succeed.

    If the application setting ``websocket_ping_interval`` has a non-zero
    value, a ping will be sent periodically, and the connection will be
    closed if a response is not received before the ``websocket_ping_timeout``.

    Messages larger than the ``websocket_max_message_size`` application setting
    (default 10MiB) will not be accepted.

    .. versionchanged:: 4.5
       Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and
       ``websocket_max_message_size``.
    """

    def __init__(
        self,
        application: tornado.web.Application,
        request: httputil.HTTPServerRequest,
        **kwargs: Any
    ) -> None:
        super().__init__(application, request, **kwargs)
        self.ws_connection = None  # type: Optional[WebSocketProtocol]
        self.close_code = None  # type: Optional[int]
        self.close_reason = None  # type: Optional[str]
        # Guards against invoking on_close() more than once.
        self._on_close_called = False

    async def get(self, *args: Any, **kwargs: Any) -> None:
        """Validate the upgrade request and hand it to the protocol.

        The URL arguments are saved so they can later be passed to `open`.
        Requests that are not acceptable WebSocket upgrades receive a 400
        (bad ``Upgrade``/``Connection`` headers), 403 (rejected by
        `check_origin`) or 426 (unsupported ``Sec-WebSocket-Version``).
        """
        self.open_args = args
        self.open_kwargs = kwargs

        # Upgrade header should be present and should be equal to WebSocket
        if self.request.headers.get("Upgrade", "").lower() != "websocket":
            self.set_status(400)
            log_msg = 'Can "Upgrade" only to "WebSocket".'
            self.finish(log_msg)
            gen_log.debug(log_msg)
            return

        # Connection header should be upgrade.
        # Some proxy servers/load balancers
        # might mess with it.
        headers = self.request.headers
        connection = {
            token.strip().lower() for token in headers.get("Connection", "").split(",")
        }
        if "upgrade" not in connection:
            self.set_status(400)
            log_msg = '"Connection" must be "Upgrade".'
            self.finish(log_msg)
            gen_log.debug(log_msg)
            return

        # Handle WebSocket Origin naming convention differences
        # The difference between version 8 and 13 is that in 8 the
        # client sends a "Sec-Websocket-Origin" header and in 13 it's
        # simply "Origin".
        if "Origin" in self.request.headers:
            origin = self.request.headers.get("Origin")
        else:
            origin = self.request.headers.get("Sec-Websocket-Origin", None)

        # If there was an origin header, check to make sure it matches
        # according to check_origin. When the origin is None, we assume it
        # did not come from a browser and that it can be passed on.
        if origin is not None and not self.check_origin(origin):
            self.set_status(403)
            log_msg = "Cross origin websockets not allowed"
            self.finish(log_msg)
            gen_log.debug(log_msg)
            return

        self.ws_connection = self.get_websocket_protocol()
        if self.ws_connection:
            await self.ws_connection.accept_connection(self)
        else:
            self.set_status(426, "Upgrade Required")
            self.set_header("Sec-WebSocket-Version", "7, 8, 13")

    @property
    def ping_interval(self) -> Optional[float]:
        """The interval for websocket keep-alive pings.

        Set websocket_ping_interval = 0 to disable pings.
        """
        return self.settings.get("websocket_ping_interval", None)

    @property
    def ping_timeout(self) -> Optional[float]:
        """If no response to a keep-alive ping is received in this many
        seconds, close the websocket connection (VPNs, etc. can fail to
        cleanly close ws connections).
        Default is max of 3 pings or 30 seconds.
        """
        return self.settings.get("websocket_ping_timeout", None)

    @property
    def max_message_size(self) -> int:
        """Maximum allowed message size.

        If the remote peer sends a message larger than this, the connection
        will be closed.

        Default is 10MiB.
        """
        return self.settings.get(
            "websocket_max_message_size", _default_max_message_size
        )

    def write_message(
        self, message: Union[bytes, str, Dict[str, Any]], binary: bool = False
    ) -> "Future[None]":
        """Sends the given message to the client of this Web Socket.

        The message may be a string, bytes, or a dict (which will be
        encoded as json).  If the ``binary`` argument is false, the
        message will be sent as utf8; in binary mode any byte string
        is allowed.

        If the connection is already closed, raises `WebSocketClosedError`.
        Returns a `.Future` which can be used for flow control.

        .. versionchanged:: 3.2
           `WebSocketClosedError` was added (previously a closed connection
           would raise an `AttributeError`)

        .. versionchanged:: 4.3
           Returns a `.Future` which can be used for flow control.

        .. versionchanged:: 5.0
           Consistently raises `WebSocketClosedError`. Previously could
           sometimes raise `.StreamClosedError`.
        """
        if self.ws_connection is None or self.ws_connection.is_closing():
            raise WebSocketClosedError()
        if isinstance(message, dict):
            message = tornado.escape.json_encode(message)
        return self.ws_connection.write_message(message, binary=binary)

    def select_subprotocol(self, subprotocols: List[str]) -> Optional[str]:
        """Override to implement subprotocol negotiation.

        ``subprotocols`` is a list of strings identifying the
        subprotocols proposed by the client.  This method may be
        overridden to return one of those strings to select it, or
        ``None`` to not select a subprotocol.

        Failure to select a subprotocol does not automatically abort
        the connection, although clients may close the connection if
        none of their proposed subprotocols was selected.

        The list may be empty, in which case this method must return
        None. This method is always called exactly once even if no
        subprotocols were proposed so that the handler can be advised
        of this fact.

        .. versionchanged:: 5.1

           Previously, this method was called with a list containing
           an empty string instead of an empty list if no subprotocols
           were proposed by the client.
        """
        return None

    @property
    def selected_subprotocol(self) -> Optional[str]:
        """The subprotocol returned by `select_subprotocol`.

        .. versionadded:: 5.1
        """
        assert self.ws_connection is not None
        return self.ws_connection.selected_subprotocol

    def get_compression_options(self) -> Optional[Dict[str, Any]]:
        """Override to return compression options for the connection.

        If this method returns None (the default), compression will
        be disabled.  If it returns a dict (even an empty one), it
        will be enabled.  The contents of the dict may be used to
        control the following compression options:

        ``compression_level`` specifies the compression level.

        ``mem_level`` specifies the amount of memory used for the internal compression state.

         These parameters are documented in detail here:
         https://docs.python.org/3.6/library/zlib.html#zlib.compressobj

        .. versionadded:: 4.1

        .. versionchanged:: 4.5

           Added ``compression_level`` and ``mem_level``.
        """
        # TODO: Add wbits option.
        return None

    def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]:
        """Invoked when a new WebSocket is opened.

        The arguments to `open` are extracted from the `tornado.web.URLSpec`
        regular expression, just like the arguments to
        `tornado.web.RequestHandler.get`.

        `open` may be a coroutine. `on_message` will not be called until
        `open` has returned.

        .. versionchanged:: 5.1

           ``open`` may be a coroutine.
        """
        pass

    def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]:
        """Handle incoming messages on the WebSocket

        This method must be overridden.

        .. versionchanged:: 4.5

           ``on_message`` can be a coroutine.
        """
        raise NotImplementedError

    def ping(self, data: Union[str, bytes] = b"") -> None:
        """Send ping frame to the remote end.

        The data argument allows a small amount of data (up to 125
        bytes) to be sent as a part of the ping message. Note that not
        all websocket implementations expose this data to
        applications.

        Consider using the ``websocket_ping_interval`` application
        setting instead of sending pings manually.

        .. versionchanged:: 5.1

           The data argument is now optional.

        """
        data = utf8(data)
        if self.ws_connection is None or self.ws_connection.is_closing():
            raise WebSocketClosedError()
        self.ws_connection.write_ping(data)

    def on_pong(self, data: bytes) -> None:
        """Invoked when the response to a ping frame is received."""
        pass

    def on_ping(self, data: bytes) -> None:
        """Invoked when a ping frame is received."""
        pass

    def on_close(self) -> None:
        """Invoked when the WebSocket is closed.

        If the connection was closed cleanly and a status code or reason
        phrase was supplied, these values will be available as the attributes
        ``self.close_code`` and ``self.close_reason``.

        .. versionchanged:: 4.0

           Added ``close_code`` and ``close_reason`` attributes.
        """
        pass

    def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
        """Closes this Web Socket.

        Once the close handshake is successful the socket will be closed.

        ``code`` may be a numeric status code, taken from the values
        defined in `RFC 6455 section 7.4.1
        <https://tools.ietf.org/html/rfc6455#section-7.4.1>`_.
        ``reason`` may be a textual message about why the connection is
        closing.  These values are made available to the client, but are
        not otherwise interpreted by the websocket protocol.

        .. versionchanged:: 4.0

           Added the ``code`` and ``reason`` arguments.
        """
        if self.ws_connection:
            self.ws_connection.close(code, reason)
            self.ws_connection = None

    def check_origin(self, origin: str) -> bool:
        """Override to enable support for allowing alternate origins.

        The ``origin`` argument is the value of the ``Origin`` HTTP
        header, the url responsible for initiating this request.  This
        method is not called for clients that do not send this header;
        such requests are always allowed (because all browsers that
        implement WebSockets support this header, and non-browser
        clients do not have the same cross-site security concerns).

        Should return ``True`` to accept the request or ``False`` to
        reject it. By default, rejects all requests with an origin on
        a host other than this one.  Host names are compared
        case-insensitively; ports must match exactly.

        This is a security protection against cross site scripting attacks on
        browsers, since WebSockets are allowed to bypass the usual same-origin
        policies and don't use CORS headers.

        .. warning::

           This is an important security measure; don't disable it
           without understanding the security implications.  In
           particular, if your authentication is cookie-based, you
           must either restrict the origins allowed by
           ``check_origin()`` or implement your own XSRF-like
           protection for websocket connections.  See `these
           <https://www.christian-schneider.net/CrossSiteWebSocketHijacking.html>`_
           `articles
           <https://devcenter.heroku.com/articles/websocket-security>`_
           for more.

        To accept all cross-origin traffic (which was the default prior to
        Tornado 4.0), simply override this method to always return ``True``::

            def check_origin(self, origin):
                return True

        To allow connections from any subdomain of your site, you might
        do something like::

            def check_origin(self, origin):
                parsed_origin = urllib.parse.urlparse(origin)
                return parsed_origin.netloc.endswith(".mydomain.com")

        .. versionadded:: 4.0

        """
        origin_netloc = urlparse(origin).netloc.lower()

        host = self.request.headers.get("Host")
        if host is None:
            return False

        # Check to see that origin matches host directly, including ports.
        # Host names are case-insensitive (RFC 3986 section 3.2.2), so both
        # sides are lowercased; previously only the origin was, which
        # rejected same-host requests whose Host header had uppercase letters.
        return origin_netloc == host.lower()

    def set_nodelay(self, value: bool) -> None:
        """Set the no-delay flag for this stream.

        By default, small messages may be delayed and/or combined to minimize
        the number of packets sent.  This can sometimes cause 200-500ms delays
        due to the interaction between Nagle's algorithm and TCP delayed
        ACKs.  To reduce this delay (at the expense of possibly increasing
        bandwidth usage), call ``self.set_nodelay(True)`` once the websocket
        connection is established.

        See `.BaseIOStream.set_nodelay` for additional details.

        .. versionadded:: 3.1
        """
        assert self.ws_connection is not None
        self.ws_connection.set_nodelay(value)

    def on_connection_close(self) -> None:
        """Tear down the protocol object and run `on_close` exactly once."""
        if self.ws_connection:
            self.ws_connection.on_connection_close()
            self.ws_connection = None
        if not self._on_close_called:
            self._on_close_called = True
            self.on_close()
            self._break_cycles()

    def on_ws_connection_close(
        self, close_code: Optional[int] = None, close_reason: Optional[str] = None
    ) -> None:
        """Record the close code/reason and run the close handling."""
        self.close_code = close_code
        self.close_reason = close_reason
        self.on_connection_close()

    def _break_cycles(self) -> None:
        # WebSocketHandlers call finish() early, but we don't want to
        # break up reference cycles (which makes it impossible to call
        # self.render_string) until after we've really closed the
        # connection (if it was established in the first place,
        # indicated by status code 101).
        if self.get_status() != 101 or self._on_close_called:
            super()._break_cycles()

    def get_websocket_protocol(self) -> Optional["WebSocketProtocol"]:
        """Return a protocol object for the requested version, or None.

        Versions 7, 8 and 13 are served by `WebSocketProtocol13`; any other
        ``Sec-WebSocket-Version`` (or none) yields None.
        """
        websocket_version = self.request.headers.get("Sec-WebSocket-Version")
        if websocket_version in ("7", "8", "13"):
            params = _WebSocketParams(
                ping_interval=self.ping_interval,
                ping_timeout=self.ping_timeout,
                max_message_size=self.max_message_size,
                compression_options=self.get_compression_options(),
            )
            return WebSocketProtocol13(self, False, params)
        return None

    def _detach_stream(self) -> IOStream:
        """Disable the plain-HTTP methods and take over the raw stream."""
        # disable non-WS methods
        for method in [
            "write",
            "redirect",
            "set_header",
            "set_cookie",
            "set_status",
            "flush",
            "finish",
        ]:
            setattr(self, method, _raise_not_supported_for_websockets)
        return self.detach()

606 

607 

def _raise_not_supported_for_websockets(*args: Any, **kwargs: Any) -> None:
    """Stand-in for HTTP-only handler methods after the WebSocket upgrade.

    `WebSocketHandler._detach_stream` installs this in place of ``write``,
    ``finish`` and similar methods. It ignores its arguments.

    :raises RuntimeError: always.
    """
    message = "Method not supported for Web Sockets"
    raise RuntimeError(message)

610 

611 

class WebSocketProtocol(abc.ABC):
    """Abstract base shared by the WebSocket protocol implementations.

    Stores the delegate (``handler``), the stream once one is attached,
    and the client/server termination flags.
    """

    def __init__(self, handler: "_WebSocketDelegate") -> None:
        self.handler = handler
        self.stream = None  # type: Optional[IOStream]
        self.client_terminated = False
        self.server_terminated = False

    def _run_callback(
        self, callback: Callable, *args: Any, **kwargs: Any
    ) -> "Optional[Future[Any]]":
        """Call ``callback(*args, **kwargs)``, shielding the connection.

        If the callback raises, the exception is logged through the handler,
        the connection is aborted, and None is returned. If it returns an
        awaitable, that awaitable is converted with ``gen.convert_yielded``
        and returned; otherwise None is returned.
        """
        try:
            outcome = callback(*args, **kwargs)
        except Exception:
            self.handler.log_exception(*sys.exc_info())
            self._abort()
            return None
        if outcome is None:
            return None
        converted = gen.convert_yielded(outcome)
        assert self.stream is not None
        # Retrieve the result on completion so a failure inside the
        # coroutine is re-raised in an IOLoop callback rather than silently
        # dropped (presumably so the IOLoop logs it).
        self.stream.io_loop.add_future(converted, lambda fut: fut.result())
        return converted

    def on_connection_close(self) -> None:
        """React to the underlying connection going away by aborting."""
        self._abort()

    def _abort(self) -> None:
        """Tear the connection down immediately, without a close handshake."""
        # Mark both directions finished before touching the stream.
        self.client_terminated = True
        self.server_terminated = True
        if self.stream is not None:
            self.stream.close()  # forcibly tear down the connection
        self.close()  # let the subclass cleanup

    @abc.abstractmethod
    def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
        """Start (or finish) closing the connection."""
        raise NotImplementedError()

    @abc.abstractmethod
    def is_closing(self) -> bool:
        """Return whether the connection is closing or closed."""
        raise NotImplementedError()

    @abc.abstractmethod
    async def accept_connection(self, handler: WebSocketHandler) -> None:
        """Perform the server side of the handshake for ``handler``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def write_message(
        self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False
    ) -> "Future[None]":
        """Send one message to the peer."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def selected_subprotocol(self) -> Optional[str]:
        """The negotiated subprotocol, if any."""
        raise NotImplementedError()

    @abc.abstractmethod
    def write_ping(self, data: bytes) -> None:
        """Send a ping frame carrying ``data``."""
        raise NotImplementedError()

    # The hooks below exist for WebSocketClientConnection, which was added
    # once WebSocketProtocol13 was the only supported version; the split
    # between this base class and WebSocketProtocol13 is therefore ad hoc.
    @abc.abstractmethod
    def _process_server_headers(
        self, key: Union[str, bytes], headers: httputil.HTTPHeaders
    ) -> None:
        """Check the server's handshake response headers (client side)."""
        raise NotImplementedError()

    @abc.abstractmethod
    def start_pinging(self) -> None:
        """Begin sending periodic keep-alive pings."""
        raise NotImplementedError()

    @abc.abstractmethod
    async def _receive_frame_loop(self) -> None:
        """Read and dispatch frames until the connection ends."""
        raise NotImplementedError()

    @abc.abstractmethod
    def set_nodelay(self, x: bool) -> None:
        """Toggle the no-delay flag on the underlying stream."""
        raise NotImplementedError()

701 

702 

class _PerMessageDeflateCompressor(object):
    """Outgoing-side compressor for the ``permessage-deflate`` extension.

    :param persistent: if true, one zlib compressor is reused for every
        message (context takeover); otherwise each call to `compress`
        creates a fresh one.
    :param max_wbits: base-2 log of the LZ77 window size, 8..``zlib.MAX_WBITS``;
        None means ``zlib.MAX_WBITS``.
    :param compression_options: optional dict; ``compression_level`` and
        ``mem_level`` keys are honoured.
    :raises ValueError: if ``max_wbits`` is out of range.
    """

    def __init__(
        self,
        persistent: bool,
        max_wbits: Optional[int],
        compression_options: Optional[Dict[str, Any]] = None,
    ) -> None:
        if max_wbits is None:
            max_wbits = zlib.MAX_WBITS
        # There is no symbolic constant for the minimum wbits value.
        if not (8 <= max_wbits <= zlib.MAX_WBITS):
            # Format the message here; passing the values as extra
            # ValueError args left the placeholders unfilled.
            raise ValueError(
                "Invalid max_wbits value %r; allowed range 8-%d"
                % (max_wbits, zlib.MAX_WBITS)
            )
        self._max_wbits = max_wbits

        options = compression_options or {}
        if "compression_level" in options:
            self._compression_level = options["compression_level"]
        else:
            # Only looked up when needed, matching the default used for
            # gzip content encoding.
            self._compression_level = tornado.web.GZipContentEncoding.GZIP_LEVEL
        self._mem_level = options.get("mem_level", 8)

        if persistent:
            self._compressor = self._create_compressor()  # type: Optional[_Compressor]
        else:
            self._compressor = None

    def _create_compressor(self) -> "_Compressor":
        # A negative wbits makes zlib emit a raw deflate stream (no header
        # or checksum), as permessage-deflate requires.
        return zlib.compressobj(
            self._compression_level, zlib.DEFLATED, -self._max_wbits, self._mem_level
        )

    def compress(self, data: bytes) -> bytes:
        """Compress one message payload.

        The sync flush always ends with ``00 00 ff ff``; those four bytes
        are stripped as RFC 7692 section 7.2.1 specifies (the receiver
        appends them again before inflating).
        """
        compressor = self._compressor or self._create_compressor()
        data = compressor.compress(data) + compressor.flush(zlib.Z_SYNC_FLUSH)
        assert data.endswith(b"\x00\x00\xff\xff")
        return data[:-4]

749 

750 

class _PerMessageDeflateDecompressor(object):
    """Incoming-side decompressor for the ``permessage-deflate`` extension.

    :param persistent: if true, one zlib decompressor is reused across
        messages (context takeover); otherwise a fresh one per message.
    :param max_wbits: base-2 log of the window size, 8..``zlib.MAX_WBITS``;
        None means ``zlib.MAX_WBITS``.
    :param max_message_size: maximum number of bytes one message may
        inflate to.
    :param compression_options: accepted for interface symmetry with the
        compressor; not used here.
    :raises ValueError: if ``max_wbits`` is out of range.
    """

    def __init__(
        self,
        persistent: bool,
        max_wbits: Optional[int],
        max_message_size: int,
        compression_options: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._max_message_size = max_message_size
        if max_wbits is None:
            max_wbits = zlib.MAX_WBITS
        if not (8 <= max_wbits <= zlib.MAX_WBITS):
            # Format the message here; passing the values as extra
            # ValueError args left the placeholders unfilled.
            raise ValueError(
                "Invalid max_wbits value %r; allowed range 8-%d"
                % (max_wbits, zlib.MAX_WBITS)
            )
        self._max_wbits = max_wbits
        if persistent:
            self._decompressor = (
                self._create_decompressor()
            )  # type: Optional[_Decompressor]
        else:
            self._decompressor = None

    def _create_decompressor(self) -> "_Decompressor":
        # Negative wbits: the input is a raw deflate stream.
        return zlib.decompressobj(-self._max_wbits)

    def decompress(self, data: bytes) -> bytes:
        """Inflate one message payload.

        The sender strips the trailing ``00 00 ff ff`` of its sync flush, so
        it is appended back before inflating. At most ``max_message_size``
        bytes are produced; if input is left over (``unconsumed_tail``), the
        message was too large and `_DecompressTooLargeError` is raised.
        """
        decompressor = self._decompressor or self._create_decompressor()
        result = decompressor.decompress(
            data + b"\x00\x00\xff\xff", self._max_message_size
        )
        if decompressor.unconsumed_tail:
            raise _DecompressTooLargeError()
        return result

787 

788 

789class WebSocketProtocol13(WebSocketProtocol): 

790 """Implementation of the WebSocket protocol from RFC 6455. 

791 

792 This class supports versions 7 and 8 of the protocol in addition to the 

793 final version 13. 

794 """ 

795 

796 # Bit masks for the first byte of a frame. 

797 FIN = 0x80 

798 RSV1 = 0x40 

799 RSV2 = 0x20 

800 RSV3 = 0x10 

801 RSV_MASK = RSV1 | RSV2 | RSV3 

802 OPCODE_MASK = 0x0F 

803 

804 stream = None # type: IOStream 

805 

806 def __init__( 

807 self, 

808 handler: "_WebSocketDelegate", 

809 mask_outgoing: bool, 

810 params: _WebSocketParams, 

811 ) -> None: 

812 WebSocketProtocol.__init__(self, handler) 

813 self.mask_outgoing = mask_outgoing 

814 self.params = params 

815 self._final_frame = False 

816 self._frame_opcode = None 

817 self._masked_frame = None 

818 self._frame_mask = None # type: Optional[bytes] 

819 self._frame_length = None 

820 self._fragmented_message_buffer = None # type: Optional[bytearray] 

821 self._fragmented_message_opcode = None 

822 self._waiting = None # type: object 

823 self._compression_options = params.compression_options 

824 self._decompressor = None # type: Optional[_PerMessageDeflateDecompressor] 

825 self._compressor = None # type: Optional[_PerMessageDeflateCompressor] 

826 self._frame_compressed = None # type: Optional[bool] 

827 # The total uncompressed size of all messages received or sent. 

828 # Unicode messages are encoded to utf8. 

829 # Only for testing; subject to change. 

830 self._message_bytes_in = 0 

831 self._message_bytes_out = 0 

832 # The total size of all packets received or sent. Includes 

833 # the effect of compression, frame overhead, and control frames. 

834 self._wire_bytes_in = 0 

835 self._wire_bytes_out = 0 

836 self.ping_callback = None # type: Optional[PeriodicCallback] 

837 self.last_ping = 0.0 

838 self.last_pong = 0.0 

839 self.close_code = None # type: Optional[int] 

840 self.close_reason = None # type: Optional[str] 

841 

842 # Use a property for this to satisfy the abc. 

843 @property 

844 def selected_subprotocol(self) -> Optional[str]: 

845 return self._selected_subprotocol 

846 

847 @selected_subprotocol.setter 

848 def selected_subprotocol(self, value: Optional[str]) -> None: 

849 self._selected_subprotocol = value 

850 

851 async def accept_connection(self, handler: WebSocketHandler) -> None: 

852 try: 

853 self._handle_websocket_headers(handler) 

854 except ValueError: 

855 handler.set_status(400) 

856 log_msg = "Missing/Invalid WebSocket headers" 

857 handler.finish(log_msg) 

858 gen_log.debug(log_msg) 

859 return 

860 

861 try: 

862 await self._accept_connection(handler) 

863 except asyncio.CancelledError: 

864 self._abort() 

865 return 

866 except ValueError: 

867 gen_log.debug("Malformed WebSocket request received", exc_info=True) 

868 self._abort() 

869 return 

870 

871 def _handle_websocket_headers(self, handler: WebSocketHandler) -> None: 

872 """Verifies all invariant- and required headers 

873 

874 If a header is missing or have an incorrect value ValueError will be 

875 raised 

876 """ 

877 fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version") 

878 if not all(map(lambda f: handler.request.headers.get(f), fields)): 

879 raise ValueError("Missing/Invalid WebSocket headers") 

880 

881 @staticmethod 

882 def compute_accept_value(key: Union[str, bytes]) -> str: 

883 """Computes the value for the Sec-WebSocket-Accept header, 

884 given the value for Sec-WebSocket-Key. 

885 """ 

886 sha1 = hashlib.sha1() 

887 sha1.update(utf8(key)) 

888 sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11") # Magic value 

889 return native_str(base64.b64encode(sha1.digest())) 

890 

891 def _challenge_response(self, handler: WebSocketHandler) -> str: 

892 return WebSocketProtocol13.compute_accept_value( 

893 cast(str, handler.request.headers.get("Sec-Websocket-Key")) 

894 ) 

895 

896 async def _accept_connection(self, handler: WebSocketHandler) -> None: 

897 subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol") 

898 if subprotocol_header: 

899 subprotocols = [s.strip() for s in subprotocol_header.split(",")] 

900 else: 

901 subprotocols = [] 

902 self.selected_subprotocol = handler.select_subprotocol(subprotocols) 

903 if self.selected_subprotocol: 

904 assert self.selected_subprotocol in subprotocols 

905 handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol) 

906 

907 extensions = self._parse_extensions_header(handler.request.headers) 

908 for ext in extensions: 

909 if ext[0] == "permessage-deflate" and self._compression_options is not None: 

910 # TODO: negotiate parameters if compression_options 

911 # specifies limits. 

912 self._create_compressors("server", ext[1], self._compression_options) 

913 if ( 

914 "client_max_window_bits" in ext[1] 

915 and ext[1]["client_max_window_bits"] is None 

916 ): 

917 # Don't echo an offered client_max_window_bits 

918 # parameter with no value. 

919 del ext[1]["client_max_window_bits"] 

920 handler.set_header( 

921 "Sec-WebSocket-Extensions", 

922 httputil._encode_header("permessage-deflate", ext[1]), 

923 ) 

924 break 

925 

926 handler.clear_header("Content-Type") 

927 handler.set_status(101) 

928 handler.set_header("Upgrade", "websocket") 

929 handler.set_header("Connection", "Upgrade") 

930 handler.set_header("Sec-WebSocket-Accept", self._challenge_response(handler)) 

931 handler.finish() 

932 

933 self.stream = handler._detach_stream() 

934 

935 self.start_pinging() 

936 try: 

937 open_result = handler.open(*handler.open_args, **handler.open_kwargs) 

938 if open_result is not None: 

939 await open_result 

940 except Exception: 

941 handler.log_exception(*sys.exc_info()) 

942 self._abort() 

943 return 

944 

945 await self._receive_frame_loop() 

946 

947 def _parse_extensions_header( 

948 self, headers: httputil.HTTPHeaders 

949 ) -> List[Tuple[str, Dict[str, str]]]: 

950 extensions = headers.get("Sec-WebSocket-Extensions", "") 

951 if extensions: 

952 return [httputil._parse_header(e.strip()) for e in extensions.split(",")] 

953 return [] 

954 

955 def _process_server_headers( 

956 self, key: Union[str, bytes], headers: httputil.HTTPHeaders 

957 ) -> None: 

958 """Process the headers sent by the server to this client connection. 

959 

960 'key' is the websocket handshake challenge/response key. 

961 """ 

962 assert headers["Upgrade"].lower() == "websocket" 

963 assert headers["Connection"].lower() == "upgrade" 

964 accept = self.compute_accept_value(key) 

965 assert headers["Sec-Websocket-Accept"] == accept 

966 

967 extensions = self._parse_extensions_header(headers) 

968 for ext in extensions: 

969 if ext[0] == "permessage-deflate" and self._compression_options is not None: 

970 self._create_compressors("client", ext[1]) 

971 else: 

972 raise ValueError("unsupported extension %r", ext) 

973 

974 self.selected_subprotocol = headers.get("Sec-WebSocket-Protocol", None) 

975 

976 def _get_compressor_options( 

977 self, 

978 side: str, 

979 agreed_parameters: Dict[str, Any], 

980 compression_options: Optional[Dict[str, Any]] = None, 

981 ) -> Dict[str, Any]: 

982 """Converts a websocket agreed_parameters set to keyword arguments 

983 for our compressor objects. 

984 """ 

985 options = dict( 

986 persistent=(side + "_no_context_takeover") not in agreed_parameters 

987 ) # type: Dict[str, Any] 

988 wbits_header = agreed_parameters.get(side + "_max_window_bits", None) 

989 if wbits_header is None: 

990 options["max_wbits"] = zlib.MAX_WBITS 

991 else: 

992 options["max_wbits"] = int(wbits_header) 

993 options["compression_options"] = compression_options 

994 return options 

995 

996 def _create_compressors( 

997 self, 

998 side: str, 

999 agreed_parameters: Dict[str, Any], 

1000 compression_options: Optional[Dict[str, Any]] = None, 

1001 ) -> None: 

1002 # TODO: handle invalid parameters gracefully 

1003 allowed_keys = set( 

1004 [ 

1005 "server_no_context_takeover", 

1006 "client_no_context_takeover", 

1007 "server_max_window_bits", 

1008 "client_max_window_bits", 

1009 ] 

1010 ) 

1011 for key in agreed_parameters: 

1012 if key not in allowed_keys: 

1013 raise ValueError("unsupported compression parameter %r" % key) 

1014 other_side = "client" if (side == "server") else "server" 

1015 self._compressor = _PerMessageDeflateCompressor( 

1016 **self._get_compressor_options(side, agreed_parameters, compression_options) 

1017 ) 

1018 self._decompressor = _PerMessageDeflateDecompressor( 

1019 max_message_size=self.params.max_message_size, 

1020 **self._get_compressor_options( 

1021 other_side, agreed_parameters, compression_options 

1022 ) 

1023 ) 

1024 

1025 def _write_frame( 

1026 self, fin: bool, opcode: int, data: bytes, flags: int = 0 

1027 ) -> "Future[None]": 

1028 data_len = len(data) 

1029 if opcode & 0x8: 

1030 # All control frames MUST have a payload length of 125 

1031 # bytes or less and MUST NOT be fragmented. 

1032 if not fin: 

1033 raise ValueError("control frames may not be fragmented") 

1034 if data_len > 125: 

1035 raise ValueError("control frame payloads may not exceed 125 bytes") 

1036 if fin: 

1037 finbit = self.FIN 

1038 else: 

1039 finbit = 0 

1040 frame = struct.pack("B", finbit | opcode | flags) 

1041 if self.mask_outgoing: 

1042 mask_bit = 0x80 

1043 else: 

1044 mask_bit = 0 

1045 if data_len < 126: 

1046 frame += struct.pack("B", data_len | mask_bit) 

1047 elif data_len <= 0xFFFF: 

1048 frame += struct.pack("!BH", 126 | mask_bit, data_len) 

1049 else: 

1050 frame += struct.pack("!BQ", 127 | mask_bit, data_len) 

1051 if self.mask_outgoing: 

1052 mask = os.urandom(4) 

1053 data = mask + _websocket_mask(mask, data) 

1054 frame += data 

1055 self._wire_bytes_out += len(frame) 

1056 return self.stream.write(frame) 

1057 

1058 def write_message( 

1059 self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False 

1060 ) -> "Future[None]": 

1061 """Sends the given message to the client of this Web Socket.""" 

1062 if binary: 

1063 opcode = 0x2 

1064 else: 

1065 opcode = 0x1 

1066 if isinstance(message, dict): 

1067 message = tornado.escape.json_encode(message) 

1068 message = tornado.escape.utf8(message) 

1069 assert isinstance(message, bytes) 

1070 self._message_bytes_out += len(message) 

1071 flags = 0 

1072 if self._compressor: 

1073 message = self._compressor.compress(message) 

1074 flags |= self.RSV1 

1075 # For historical reasons, write methods in Tornado operate in a semi-synchronous 

1076 # mode in which awaiting the Future they return is optional (But errors can 

1077 # still be raised). This requires us to go through an awkward dance here 

1078 # to transform the errors that may be returned while presenting the same 

1079 # semi-synchronous interface. 

1080 try: 

1081 fut = self._write_frame(True, opcode, message, flags=flags) 

1082 except StreamClosedError: 

1083 raise WebSocketClosedError() 

1084 

1085 async def wrapper() -> None: 

1086 try: 

1087 await fut 

1088 except StreamClosedError: 

1089 raise WebSocketClosedError() 

1090 

1091 return asyncio.ensure_future(wrapper()) 

1092 

1093 def write_ping(self, data: bytes) -> None: 

1094 """Send ping frame.""" 

1095 assert isinstance(data, bytes) 

1096 self._write_frame(True, 0x9, data) 

1097 

1098 async def _receive_frame_loop(self) -> None: 

1099 try: 

1100 while not self.client_terminated: 

1101 await self._receive_frame() 

1102 except StreamClosedError: 

1103 self._abort() 

1104 self.handler.on_ws_connection_close(self.close_code, self.close_reason) 

1105 

1106 async def _read_bytes(self, n: int) -> bytes: 

1107 data = await self.stream.read_bytes(n) 

1108 self._wire_bytes_in += n 

1109 return data 

1110 

1111 async def _receive_frame(self) -> None: 

1112 # Read the frame header. 

1113 data = await self._read_bytes(2) 

1114 header, mask_payloadlen = struct.unpack("BB", data) 

1115 is_final_frame = header & self.FIN 

1116 reserved_bits = header & self.RSV_MASK 

1117 opcode = header & self.OPCODE_MASK 

1118 opcode_is_control = opcode & 0x8 

1119 if self._decompressor is not None and opcode != 0: 

1120 # Compression flag is present in the first frame's header, 

1121 # but we can't decompress until we have all the frames of 

1122 # the message. 

1123 self._frame_compressed = bool(reserved_bits & self.RSV1) 

1124 reserved_bits &= ~self.RSV1 

1125 if reserved_bits: 

1126 # client is using as-yet-undefined extensions; abort 

1127 self._abort() 

1128 return 

1129 is_masked = bool(mask_payloadlen & 0x80) 

1130 payloadlen = mask_payloadlen & 0x7F 

1131 

1132 # Parse and validate the length. 

1133 if opcode_is_control and payloadlen >= 126: 

1134 # control frames must have payload < 126 

1135 self._abort() 

1136 return 

1137 if payloadlen < 126: 

1138 self._frame_length = payloadlen 

1139 elif payloadlen == 126: 

1140 data = await self._read_bytes(2) 

1141 payloadlen = struct.unpack("!H", data)[0] 

1142 elif payloadlen == 127: 

1143 data = await self._read_bytes(8) 

1144 payloadlen = struct.unpack("!Q", data)[0] 

1145 new_len = payloadlen 

1146 if self._fragmented_message_buffer is not None: 

1147 new_len += len(self._fragmented_message_buffer) 

1148 if new_len > self.params.max_message_size: 

1149 self.close(1009, "message too big") 

1150 self._abort() 

1151 return 

1152 

1153 # Read the payload, unmasking if necessary. 

1154 if is_masked: 

1155 self._frame_mask = await self._read_bytes(4) 

1156 data = await self._read_bytes(payloadlen) 

1157 if is_masked: 

1158 assert self._frame_mask is not None 

1159 data = _websocket_mask(self._frame_mask, data) 

1160 

1161 # Decide what to do with this frame. 

1162 if opcode_is_control: 

1163 # control frames may be interleaved with a series of fragmented 

1164 # data frames, so control frames must not interact with 

1165 # self._fragmented_* 

1166 if not is_final_frame: 

1167 # control frames must not be fragmented 

1168 self._abort() 

1169 return 

1170 elif opcode == 0: # continuation frame 

1171 if self._fragmented_message_buffer is None: 

1172 # nothing to continue 

1173 self._abort() 

1174 return 

1175 self._fragmented_message_buffer.extend(data) 

1176 if is_final_frame: 

1177 opcode = self._fragmented_message_opcode 

1178 data = bytes(self._fragmented_message_buffer) 

1179 self._fragmented_message_buffer = None 

1180 else: # start of new data message 

1181 if self._fragmented_message_buffer is not None: 

1182 # can't start new message until the old one is finished 

1183 self._abort() 

1184 return 

1185 if not is_final_frame: 

1186 self._fragmented_message_opcode = opcode 

1187 self._fragmented_message_buffer = bytearray(data) 

1188 

1189 if is_final_frame: 

1190 handled_future = self._handle_message(opcode, data) 

1191 if handled_future is not None: 

1192 await handled_future 

1193 

1194 def _handle_message(self, opcode: int, data: bytes) -> "Optional[Future[None]]": 

1195 """Execute on_message, returning its Future if it is a coroutine.""" 

1196 if self.client_terminated: 

1197 return None 

1198 

1199 if self._frame_compressed: 

1200 assert self._decompressor is not None 

1201 try: 

1202 data = self._decompressor.decompress(data) 

1203 except _DecompressTooLargeError: 

1204 self.close(1009, "message too big after decompression") 

1205 self._abort() 

1206 return None 

1207 

1208 if opcode == 0x1: 

1209 # UTF-8 data 

1210 self._message_bytes_in += len(data) 

1211 try: 

1212 decoded = data.decode("utf-8") 

1213 except UnicodeDecodeError: 

1214 self._abort() 

1215 return None 

1216 return self._run_callback(self.handler.on_message, decoded) 

1217 elif opcode == 0x2: 

1218 # Binary data 

1219 self._message_bytes_in += len(data) 

1220 return self._run_callback(self.handler.on_message, data) 

1221 elif opcode == 0x8: 

1222 # Close 

1223 self.client_terminated = True 

1224 if len(data) >= 2: 

1225 self.close_code = struct.unpack(">H", data[:2])[0] 

1226 if len(data) > 2: 

1227 self.close_reason = to_unicode(data[2:]) 

1228 # Echo the received close code, if any (RFC 6455 section 5.5.1). 

1229 self.close(self.close_code) 

1230 elif opcode == 0x9: 

1231 # Ping 

1232 try: 

1233 self._write_frame(True, 0xA, data) 

1234 except StreamClosedError: 

1235 self._abort() 

1236 self._run_callback(self.handler.on_ping, data) 

1237 elif opcode == 0xA: 

1238 # Pong 

1239 self.last_pong = IOLoop.current().time() 

1240 return self._run_callback(self.handler.on_pong, data) 

1241 else: 

1242 self._abort() 

1243 return None 

1244 

1245 def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None: 

1246 """Closes the WebSocket connection.""" 

1247 if not self.server_terminated: 

1248 if not self.stream.closed(): 

1249 if code is None and reason is not None: 

1250 code = 1000 # "normal closure" status code 

1251 if code is None: 

1252 close_data = b"" 

1253 else: 

1254 close_data = struct.pack(">H", code) 

1255 if reason is not None: 

1256 close_data += utf8(reason) 

1257 try: 

1258 self._write_frame(True, 0x8, close_data) 

1259 except StreamClosedError: 

1260 self._abort() 

1261 self.server_terminated = True 

1262 if self.client_terminated: 

1263 if self._waiting is not None: 

1264 self.stream.io_loop.remove_timeout(self._waiting) 

1265 self._waiting = None 

1266 self.stream.close() 

1267 elif self._waiting is None: 

1268 # Give the client a few seconds to complete a clean shutdown, 

1269 # otherwise just close the connection. 

1270 self._waiting = self.stream.io_loop.add_timeout( 

1271 self.stream.io_loop.time() + 5, self._abort 

1272 ) 

1273 if self.ping_callback: 

1274 self.ping_callback.stop() 

1275 self.ping_callback = None 

1276 

1277 def is_closing(self) -> bool: 

1278 """Return ``True`` if this connection is closing. 

1279 

1280 The connection is considered closing if either side has 

1281 initiated its closing handshake or if the stream has been 

1282 shut down uncleanly. 

1283 """ 

1284 return self.stream.closed() or self.client_terminated or self.server_terminated 

1285 

1286 @property 

1287 def ping_interval(self) -> Optional[float]: 

1288 interval = self.params.ping_interval 

1289 if interval is not None: 

1290 return interval 

1291 return 0 

1292 

1293 @property 

1294 def ping_timeout(self) -> Optional[float]: 

1295 timeout = self.params.ping_timeout 

1296 if timeout is not None: 

1297 return timeout 

1298 assert self.ping_interval is not None 

1299 return max(3 * self.ping_interval, 30) 

1300 

1301 def start_pinging(self) -> None: 

1302 """Start sending periodic pings to keep the connection alive""" 

1303 assert self.ping_interval is not None 

1304 if self.ping_interval > 0: 

1305 self.last_ping = self.last_pong = IOLoop.current().time() 

1306 self.ping_callback = PeriodicCallback( 

1307 self.periodic_ping, self.ping_interval * 1000 

1308 ) 

1309 self.ping_callback.start() 

1310 

1311 def periodic_ping(self) -> None: 

1312 """Send a ping to keep the websocket alive 

1313 

1314 Called periodically if the websocket_ping_interval is set and non-zero. 

1315 """ 

1316 if self.is_closing() and self.ping_callback is not None: 

1317 self.ping_callback.stop() 

1318 return 

1319 

1320 # Check for timeout on pong. Make sure that we really have 

1321 # sent a recent ping in case the machine with both server and 

1322 # client has been suspended since the last ping. 

1323 now = IOLoop.current().time() 

1324 since_last_pong = now - self.last_pong 

1325 since_last_ping = now - self.last_ping 

1326 assert self.ping_interval is not None 

1327 assert self.ping_timeout is not None 

1328 if ( 

1329 since_last_ping < 2 * self.ping_interval 

1330 and since_last_pong > self.ping_timeout 

1331 ): 

1332 self.close() 

1333 return 

1334 

1335 self.write_ping(b"") 

1336 self.last_ping = now 

1337 

1338 def set_nodelay(self, x: bool) -> None: 

1339 self.stream.set_nodelay(x) 

1340 

1341 

class WebSocketClientConnection(simple_httpclient._HTTPConnection):
    """WebSocket client connection.

    This class should not be instantiated directly; use the
    `websocket_connect` function instead.
    """

    protocol = None  # type: WebSocketProtocol

    def __init__(
        self,
        request: httpclient.HTTPRequest,
        on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None,
        compression_options: Optional[Dict[str, Any]] = None,
        ping_interval: Optional[float] = None,
        ping_timeout: Optional[float] = None,
        max_message_size: int = _default_max_message_size,
        subprotocols: Optional[List[str]] = None,
        resolver: Optional[Resolver] = None,
    ) -> None:
        """Prepare the upgrade request and start the HTTP connection.

        ``request.url`` must use the ``ws`` or ``wss`` scheme (any other
        scheme raises ``KeyError``). The request is mutated in place: the
        scheme is mapped to http/https, handshake headers are added and
        redirects are disabled. ``subprotocols`` is only offered when it is
        a non-empty list.
        """
        self.connect_future = Future()  # type: Future[WebSocketClientConnection]
        self.read_queue = Queue(1)  # type: Queue[Union[None, str, bytes]]
        # Random 16-byte nonce sent as Sec-WebSocket-Key.
        self.key = base64.b64encode(os.urandom(16))
        self._on_message_callback = on_message_callback
        self.close_code = None  # type: Optional[int]
        self.close_reason = None  # type: Optional[str]
        self.params = _WebSocketParams(
            ping_interval=ping_interval,
            ping_timeout=ping_timeout,
            max_message_size=max_message_size,
            compression_options=compression_options,
        )

        scheme, sep, rest = request.url.partition(":")
        scheme = {"ws": "http", "wss": "https"}[scheme]
        request.url = scheme + sep + rest
        request.headers.update(
            {
                "Upgrade": "websocket",
                "Connection": "Upgrade",
                "Sec-WebSocket-Key": self.key,
                "Sec-WebSocket-Version": "13",
            }
        )
        if subprotocols:
            # Only offer subprotocols when there are some. The default used
            # to be a (shared, mutable) empty list checked with
            # ``is not None``, which sent an empty Sec-WebSocket-Protocol
            # header when no subprotocols were requested.
            request.headers["Sec-WebSocket-Protocol"] = ",".join(subprotocols)
        if compression_options is not None:
            # Always offer to let the server set our max_wbits (and even though
            # we don't offer it, we will accept a client_no_context_takeover
            # from the server).
            # TODO: set server parameters for deflate extension
            # if requested in self.compression_options.
            request.headers[
                "Sec-WebSocket-Extensions"
            ] = "permessage-deflate; client_max_window_bits"

        # Websocket connection is currently unable to follow redirects
        request.follow_redirects = False

        self.tcp_client = TCPClient(resolver=resolver)
        super().__init__(
            None,
            request,
            lambda: None,
            self._on_http_response,
            104857600,
            self.tcp_client,
            65536,
            104857600,
        )

    def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
        """Closes the websocket connection.

        ``code`` and ``reason`` are documented under
        `WebSocketHandler.close`.

        .. versionadded:: 3.2

        .. versionchanged:: 4.0

           Added the ``code`` and ``reason`` arguments.
        """
        if self.protocol is not None:
            self.protocol.close(code, reason)
            self.protocol = None  # type: ignore

    def on_connection_close(self) -> None:
        """Fail a pending connect, deliver ``None`` to readers, release resources."""
        if not self.connect_future.done():
            self.connect_future.set_exception(StreamClosedError())
        self._on_message(None)
        self.tcp_client.close()
        super().on_connection_close()

    def on_ws_connection_close(
        self, close_code: Optional[int] = None, close_reason: Optional[str] = None
    ) -> None:
        """Record the peer's close code/reason, then run connection-close handling."""
        self.close_code = close_code
        self.close_reason = close_reason
        self.on_connection_close()

    def _on_http_response(self, response: httpclient.HTTPResponse) -> None:
        """Fail the connect future if the server answered without upgrading."""
        if not self.connect_future.done():
            if response.error:
                self.connect_future.set_exception(response.error)
            else:
                self.connect_future.set_exception(
                    WebSocketError("Non-websocket response")
                )

    async def headers_received(
        self,
        start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
        headers: httputil.HTTPHeaders,
    ) -> None:
        """Take over the connection once the server answers 101.

        Any other status is handed to the regular HTTP client handling.
        """
        assert isinstance(start_line, httputil.ResponseStartLine)
        if start_line.code != 101:
            await super().headers_received(start_line, headers)
            return

        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None

        self.headers = headers
        self.protocol = self.get_websocket_protocol()
        self.protocol._process_server_headers(self.key, self.headers)
        self.protocol.stream = self.connection.detach()

        IOLoop.current().add_callback(self.protocol._receive_frame_loop)
        self.protocol.start_pinging()

        # Once we've taken over the connection, clear the final callback
        # we set on the http request. This deactivates the error handling
        # in simple_httpclient that would otherwise interfere with our
        # ability to see exceptions.
        self.final_callback = None  # type: ignore

        future_set_result_unless_cancelled(self.connect_future, self)

    def write_message(
        self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False
    ) -> "Future[None]":
        """Sends a message to the WebSocket server.

        If the stream is closed, raises `WebSocketClosedError`.
        Returns a `.Future` which can be used for flow control.

        .. versionchanged:: 5.0
           Exception raised on a closed stream changed from `.StreamClosedError`
           to `WebSocketClosedError`.
        """
        if self.protocol is None:
            raise WebSocketClosedError("Client connection has been closed")
        return self.protocol.write_message(message, binary=binary)

    def read_message(
        self,
        callback: Optional[Callable[["Future[Union[None, str, bytes]]"], None]] = None,
    ) -> Awaitable[Union[None, str, bytes]]:
        """Reads a message from the WebSocket server.

        If on_message_callback was specified at WebSocket
        initialization, this function will never return messages.

        Returns a future whose result is the message, or None
        if the connection is closed. If a callback argument
        is given it will be called with the future when it is
        ready.
        """

        awaitable = self.read_queue.get()
        if callback is not None:
            self.io_loop.add_future(asyncio.ensure_future(awaitable), callback)
        return awaitable

    def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]:
        """Deliver a received message to the callback or the read queue."""
        return self._on_message(message)

    def _on_message(
        self, message: Union[None, str, bytes]
    ) -> Optional[Awaitable[None]]:
        """Route ``message`` (``None`` meaning closed) to callback or queue."""
        if self._on_message_callback:
            self._on_message_callback(message)
            return None
        else:
            return self.read_queue.put(message)

    def ping(self, data: bytes = b"") -> None:
        """Send ping frame to the remote end.

        The data argument allows a small amount of data (up to 125
        bytes) to be sent as a part of the ping message. Note that not
        all websocket implementations expose this data to
        applications.

        Consider using the ``ping_interval`` argument to
        `websocket_connect` instead of sending pings manually.

        .. versionadded:: 5.1

        """
        data = utf8(data)
        if self.protocol is None:
            raise WebSocketClosedError()
        self.protocol.write_ping(data)

    def on_pong(self, data: bytes) -> None:
        """Hook called when a pong arrives; does nothing by default."""
        pass

    def on_ping(self, data: bytes) -> None:
        """Hook called when a ping arrives; does nothing by default."""
        pass

    def get_websocket_protocol(self) -> WebSocketProtocol:
        """Create the protocol object; client frames are always masked."""
        return WebSocketProtocol13(self, mask_outgoing=True, params=self.params)

    @property
    def selected_subprotocol(self) -> Optional[str]:
        """The subprotocol selected by the server.

        .. versionadded:: 5.1
        """
        return self.protocol.selected_subprotocol

    def log_exception(
        self,
        typ: "Optional[Type[BaseException]]",
        value: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        """Log an uncaught exception from a callback via ``app_log``."""
        assert typ is not None
        assert value is not None
        app_log.error("Uncaught exception %s", value, exc_info=(typ, value, tb))

1575 

1576 

def websocket_connect(
    url: Union[str, httpclient.HTTPRequest],
    callback: Optional[Callable[["Future[WebSocketClientConnection]"], None]] = None,
    connect_timeout: Optional[float] = None,
    on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None,
    compression_options: Optional[Dict[str, Any]] = None,
    ping_interval: Optional[float] = None,
    ping_timeout: Optional[float] = None,
    max_message_size: int = _default_max_message_size,
    subprotocols: Optional[List[str]] = None,
    resolver: Optional[Resolver] = None,
) -> "Awaitable[WebSocketClientConnection]":
    """Client-side websocket support.

    Opens a connection to ``url`` (a ``ws://``/``wss://`` URL string or an
    ``HTTPRequest``) and returns a Future whose result is a
    `WebSocketClientConnection`.

    ``compression_options`` is interpreted in the same way as the
    return value of `.WebSocketHandler.get_compression_options`.

    The connection supports two styles of operation. In the coroutine
    style, the application typically calls
    `~.WebSocketClientConnection.read_message` in a loop::

        conn = yield websocket_connect(url)
        while True:
            msg = yield conn.read_message()
            if msg is None: break
            # Do something with msg

    In the callback style, pass an ``on_message_callback`` to
    ``websocket_connect``. In both styles, a message of ``None``
    indicates that the connection has been closed.

    ``subprotocols`` may be a list of strings specifying proposed
    subprotocols. The selected protocol may be found on the
    ``selected_subprotocol`` attribute of the connection object
    when the connection is complete.

    .. versionchanged:: 3.2
       Also accepts ``HTTPRequest`` objects in place of urls.

    .. versionchanged:: 4.1
       Added ``compression_options`` and ``on_message_callback``.

    .. versionchanged:: 4.5
       Added the ``ping_interval``, ``ping_timeout``, and ``max_message_size``
       arguments, which have the same meaning as in `WebSocketHandler`.

    .. versionchanged:: 5.0
       The ``io_loop`` argument (deprecated since version 4.1) has been removed.

    .. versionchanged:: 5.1
       Added the ``subprotocols`` argument.

    .. versionchanged:: 6.3
       Added the ``resolver`` argument.
    """
    if not isinstance(url, httpclient.HTTPRequest):
        request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
    else:
        # A ready-made request carries its own timeouts.
        assert connect_timeout is None
        request = url
        # Copy and convert the headers dict/object (see comments in
        # AsyncHTTPClient.fetch)
        request.headers = httputil.HTTPHeaders(request.headers)
    proxied = httpclient._RequestProxy(request, httpclient.HTTPRequest._DEFAULTS)
    request = cast(httpclient.HTTPRequest, proxied)
    conn = WebSocketClientConnection(
        request,
        on_message_callback=on_message_callback,
        compression_options=compression_options,
        ping_interval=ping_interval,
        ping_timeout=ping_timeout,
        max_message_size=max_message_size,
        subprotocols=subprotocols,
        resolver=resolver,
    )
    if callback is not None:
        IOLoop.current().add_future(conn.connect_future, callback)
    return conn.connect_future