1"""Implementation of the WebSocket protocol.
2
3`WebSockets <http://dev.w3.org/html5/websockets/>`_ allow for bidirectional
4communication between the browser and server. WebSockets are supported in the
5current versions of all major browsers.
6
7This module implements the final version of the WebSocket protocol as
8defined in `RFC 6455 <http://tools.ietf.org/html/rfc6455>`_.
9
10.. versionchanged:: 4.0
11 Removed support for the draft 76 protocol version.
12"""
13
14import abc
15import asyncio
16import base64
17import hashlib
18import os
19import sys
20import struct
21import tornado
22from urllib.parse import urlparse
23import warnings
24import zlib
25
26from tornado.concurrent import Future, future_set_result_unless_cancelled
27from tornado.escape import utf8, native_str, to_unicode
28from tornado import gen, httpclient, httputil
29from tornado.ioloop import IOLoop, PeriodicCallback
30from tornado.iostream import StreamClosedError, IOStream
31from tornado.log import gen_log, app_log
32from tornado.netutil import Resolver
33from tornado import simple_httpclient
34from tornado.queues import Queue
35from tornado.tcpclient import TCPClient
36from tornado.util import _websocket_mask
37
38from typing import (
39 TYPE_CHECKING,
40 cast,
41 Any,
42 Optional,
43 Dict,
44 Union,
45 List,
46 Awaitable,
47 Callable,
48 Tuple,
49 Type,
50)
51from types import TracebackType
52
53if TYPE_CHECKING:
54 from typing_extensions import Protocol
55
56 # The zlib compressor types aren't actually exposed anywhere
57 # publicly, so declare protocols for the portions we use.
58 class _Compressor(Protocol):
59 def compress(self, data: bytes) -> bytes:
60 pass
61
62 def flush(self, mode: int) -> bytes:
63 pass
64
65 class _Decompressor(Protocol):
66 unconsumed_tail = b"" # type: bytes
67
68 def decompress(self, data: bytes, max_length: int) -> bytes:
69 pass
70
71 class _WebSocketDelegate(Protocol):
72 # The common base interface implemented by WebSocketHandler on
73 # the server side and WebSocketClientConnection on the client
74 # side.
75 def on_ws_connection_close(
76 self, close_code: Optional[int] = None, close_reason: Optional[str] = None
77 ) -> None:
78 pass
79
80 def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]:
81 pass
82
83 def on_ping(self, data: bytes) -> None:
84 pass
85
86 def on_pong(self, data: bytes) -> None:
87 pass
88
89 def log_exception(
90 self,
91 typ: Optional[Type[BaseException]],
92 value: Optional[BaseException],
93 tb: Optional[TracebackType],
94 ) -> None:
95 pass
96
97
98_default_max_message_size = 10 * 1024 * 1024
99
100
101class WebSocketError(Exception):
102 pass
103
104
105class WebSocketClosedError(WebSocketError):
106 """Raised by operations on a closed connection.
107
108 .. versionadded:: 3.2
109 """
110
111 pass
112
113
114class _DecompressTooLargeError(Exception):
115 pass
116
117
118class _WebSocketParams(object):
119 def __init__(
120 self,
121 ping_interval: Optional[float] = None,
122 ping_timeout: Optional[float] = None,
123 max_message_size: int = _default_max_message_size,
124 compression_options: Optional[Dict[str, Any]] = None,
125 ) -> None:
126 self.ping_interval = ping_interval
127 self.ping_timeout = ping_timeout
128 self.max_message_size = max_message_size
129 self.compression_options = compression_options
130
131
132class WebSocketHandler(tornado.web.RequestHandler):
133 """Subclass this class to create a basic WebSocket handler.
134
135 Override `on_message` to handle incoming messages, and use
136 `write_message` to send messages to the client. You can also
137 override `open` and `on_close` to handle opened and closed
138 connections.
139
140 Custom upgrade response headers can be sent by overriding
141 `~tornado.web.RequestHandler.set_default_headers` or
142 `~tornado.web.RequestHandler.prepare`.
143
144 See http://dev.w3.org/html5/websockets/ for details on the
145 JavaScript interface. The protocol is specified at
146 http://tools.ietf.org/html/rfc6455.
147
148 Here is an example WebSocket handler that echos back all received messages
149 back to the client:
150
151 .. testcode::
152
153 class EchoWebSocket(tornado.websocket.WebSocketHandler):
154 def open(self):
155 print("WebSocket opened")
156
157 def on_message(self, message):
158 self.write_message(u"You said: " + message)
159
160 def on_close(self):
161 print("WebSocket closed")
162
163 .. testoutput::
164 :hide:
165
166 WebSockets are not standard HTTP connections. The "handshake" is
167 HTTP, but after the handshake, the protocol is
168 message-based. Consequently, most of the Tornado HTTP facilities
169 are not available in handlers of this type. The only communication
170 methods available to you are `write_message()`, `ping()`, and
171 `close()`. Likewise, your request handler class should implement
172 `open()` method rather than ``get()`` or ``post()``.
173
174 If you map the handler above to ``/websocket`` in your application, you can
175 invoke it in JavaScript with::
176
177 var ws = new WebSocket("ws://localhost:8888/websocket");
178 ws.onopen = function() {
179 ws.send("Hello, world");
180 };
181 ws.onmessage = function (evt) {
182 alert(evt.data);
183 };
184
185 This script pops up an alert box that says "You said: Hello, world".
186
187 Web browsers allow any site to open a websocket connection to any other,
188 instead of using the same-origin policy that governs other network
189 access from JavaScript. This can be surprising and is a potential
190 security hole, so since Tornado 4.0 `WebSocketHandler` requires
191 applications that wish to receive cross-origin websockets to opt in
192 by overriding the `~WebSocketHandler.check_origin` method (see that
193 method's docs for details). Failure to do so is the most likely
194 cause of 403 errors when making a websocket connection.
195
196 When using a secure websocket connection (``wss://``) with a self-signed
197 certificate, the connection from a browser may fail because it wants
198 to show the "accept this certificate" dialog but has nowhere to show it.
199 You must first visit a regular HTML page using the same certificate
200 to accept it before the websocket connection will succeed.
201
202 If the application setting ``websocket_ping_interval`` has a non-zero
203 value, a ping will be sent periodically, and the connection will be
204 closed if a response is not received before the ``websocket_ping_timeout``.
205
206 Messages larger than the ``websocket_max_message_size`` application setting
207 (default 10MiB) will not be accepted.
208
209 .. versionchanged:: 4.5
210 Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and
211 ``websocket_max_message_size``.
212 """
213
214 def __init__(
215 self,
216 application: tornado.web.Application,
217 request: httputil.HTTPServerRequest,
218 **kwargs: Any
219 ) -> None:
220 super().__init__(application, request, **kwargs)
221 self.ws_connection = None # type: Optional[WebSocketProtocol]
222 self.close_code = None # type: Optional[int]
223 self.close_reason = None # type: Optional[str]
224 self._on_close_called = False
225
226 async def get(self, *args: Any, **kwargs: Any) -> None:
227 self.open_args = args
228 self.open_kwargs = kwargs
229
230 # Upgrade header should be present and should be equal to WebSocket
231 if self.request.headers.get("Upgrade", "").lower() != "websocket":
232 self.set_status(400)
233 log_msg = 'Can "Upgrade" only to "WebSocket".'
234 self.finish(log_msg)
235 gen_log.debug(log_msg)
236 return
237
238 # Connection header should be upgrade.
239 # Some proxy servers/load balancers
240 # might mess with it.
241 headers = self.request.headers
242 connection = map(
243 lambda s: s.strip().lower(), headers.get("Connection", "").split(",")
244 )
245 if "upgrade" not in connection:
246 self.set_status(400)
247 log_msg = '"Connection" must be "Upgrade".'
248 self.finish(log_msg)
249 gen_log.debug(log_msg)
250 return
251
252 # Handle WebSocket Origin naming convention differences
253 # The difference between version 8 and 13 is that in 8 the
254 # client sends a "Sec-Websocket-Origin" header and in 13 it's
255 # simply "Origin".
256 if "Origin" in self.request.headers:
257 origin = self.request.headers.get("Origin")
258 else:
259 origin = self.request.headers.get("Sec-Websocket-Origin", None)
260
261 # If there was an origin header, check to make sure it matches
262 # according to check_origin. When the origin is None, we assume it
263 # did not come from a browser and that it can be passed on.
264 if origin is not None and not self.check_origin(origin):
265 self.set_status(403)
266 log_msg = "Cross origin websockets not allowed"
267 self.finish(log_msg)
268 gen_log.debug(log_msg)
269 return
270
271 self.ws_connection = self.get_websocket_protocol()
272 if self.ws_connection:
273 await self.ws_connection.accept_connection(self)
274 else:
275 self.set_status(426, "Upgrade Required")
276 self.set_header("Sec-WebSocket-Version", "7, 8, 13")
277
278 @property
279 def ping_interval(self) -> Optional[float]:
280 """The interval for websocket keep-alive pings.
281
282 Set websocket_ping_interval = 0 to disable pings.
283 """
284 return self.settings.get("websocket_ping_interval", None)
285
286 @property
287 def ping_timeout(self) -> Optional[float]:
288 """If no ping is received in this many seconds,
289 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
290 Default is max of 3 pings or 30 seconds.
291 """
292 return self.settings.get("websocket_ping_timeout", None)
293
294 @property
295 def max_message_size(self) -> int:
296 """Maximum allowed message size.
297
298 If the remote peer sends a message larger than this, the connection
299 will be closed.
300
301 Default is 10MiB.
302 """
303 return self.settings.get(
304 "websocket_max_message_size", _default_max_message_size
305 )
306
307 def write_message(
308 self, message: Union[bytes, str, Dict[str, Any]], binary: bool = False
309 ) -> "Future[None]":
310 """Sends the given message to the client of this Web Socket.
311
312 The message may be either a string or a dict (which will be
313 encoded as json). If the ``binary`` argument is false, the
314 message will be sent as utf8; in binary mode any byte string
315 is allowed.
316
317 If the connection is already closed, raises `WebSocketClosedError`.
318 Returns a `.Future` which can be used for flow control.
319
320 .. versionchanged:: 3.2
321 `WebSocketClosedError` was added (previously a closed connection
322 would raise an `AttributeError`)
323
324 .. versionchanged:: 4.3
325 Returns a `.Future` which can be used for flow control.
326
327 .. versionchanged:: 5.0
328 Consistently raises `WebSocketClosedError`. Previously could
329 sometimes raise `.StreamClosedError`.
330 """
331 if self.ws_connection is None or self.ws_connection.is_closing():
332 raise WebSocketClosedError()
333 if isinstance(message, dict):
334 message = tornado.escape.json_encode(message)
335 return self.ws_connection.write_message(message, binary=binary)
336
337 def select_subprotocol(self, subprotocols: List[str]) -> Optional[str]:
338 """Override to implement subprotocol negotiation.
339
340 ``subprotocols`` is a list of strings identifying the
341 subprotocols proposed by the client. This method may be
342 overridden to return one of those strings to select it, or
343 ``None`` to not select a subprotocol.
344
345 Failure to select a subprotocol does not automatically abort
346 the connection, although clients may close the connection if
347 none of their proposed subprotocols was selected.
348
349 The list may be empty, in which case this method must return
350 None. This method is always called exactly once even if no
351 subprotocols were proposed so that the handler can be advised
352 of this fact.
353
354 .. versionchanged:: 5.1
355
356 Previously, this method was called with a list containing
357 an empty string instead of an empty list if no subprotocols
358 were proposed by the client.
359 """
360 return None
361
362 @property
363 def selected_subprotocol(self) -> Optional[str]:
364 """The subprotocol returned by `select_subprotocol`.
365
366 .. versionadded:: 5.1
367 """
368 assert self.ws_connection is not None
369 return self.ws_connection.selected_subprotocol
370
371 def get_compression_options(self) -> Optional[Dict[str, Any]]:
372 """Override to return compression options for the connection.
373
374 If this method returns None (the default), compression will
375 be disabled. If it returns a dict (even an empty one), it
376 will be enabled. The contents of the dict may be used to
377 control the following compression options:
378
379 ``compression_level`` specifies the compression level.
380
381 ``mem_level`` specifies the amount of memory used for the internal compression state.
382
383 These parameters are documented in details here:
384 https://docs.python.org/3.6/library/zlib.html#zlib.compressobj
385
386 .. versionadded:: 4.1
387
388 .. versionchanged:: 4.5
389
390 Added ``compression_level`` and ``mem_level``.
391 """
392 # TODO: Add wbits option.
393 return None
394
395 def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]:
396 """Invoked when a new WebSocket is opened.
397
398 The arguments to `open` are extracted from the `tornado.web.URLSpec`
399 regular expression, just like the arguments to
400 `tornado.web.RequestHandler.get`.
401
402 `open` may be a coroutine. `on_message` will not be called until
403 `open` has returned.
404
405 .. versionchanged:: 5.1
406
407 ``open`` may be a coroutine.
408 """
409 pass
410
411 def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]:
412 """Handle incoming messages on the WebSocket
413
414 This method must be overridden.
415
416 .. versionchanged:: 4.5
417
418 ``on_message`` can be a coroutine.
419 """
420 raise NotImplementedError
421
422 def ping(self, data: Union[str, bytes] = b"") -> None:
423 """Send ping frame to the remote end.
424
425 The data argument allows a small amount of data (up to 125
426 bytes) to be sent as a part of the ping message. Note that not
427 all websocket implementations expose this data to
428 applications.
429
430 Consider using the ``websocket_ping_interval`` application
431 setting instead of sending pings manually.
432
433 .. versionchanged:: 5.1
434
435 The data argument is now optional.
436
437 """
438 data = utf8(data)
439 if self.ws_connection is None or self.ws_connection.is_closing():
440 raise WebSocketClosedError()
441 self.ws_connection.write_ping(data)
442
443 def on_pong(self, data: bytes) -> None:
444 """Invoked when the response to a ping frame is received."""
445 pass
446
447 def on_ping(self, data: bytes) -> None:
448 """Invoked when the a ping frame is received."""
449 pass
450
451 def on_close(self) -> None:
452 """Invoked when the WebSocket is closed.
453
454 If the connection was closed cleanly and a status code or reason
455 phrase was supplied, these values will be available as the attributes
456 ``self.close_code`` and ``self.close_reason``.
457
458 .. versionchanged:: 4.0
459
460 Added ``close_code`` and ``close_reason`` attributes.
461 """
462 pass
463
464 def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
465 """Closes this Web Socket.
466
467 Once the close handshake is successful the socket will be closed.
468
469 ``code`` may be a numeric status code, taken from the values
470 defined in `RFC 6455 section 7.4.1
471 <https://tools.ietf.org/html/rfc6455#section-7.4.1>`_.
472 ``reason`` may be a textual message about why the connection is
473 closing. These values are made available to the client, but are
474 not otherwise interpreted by the websocket protocol.
475
476 .. versionchanged:: 4.0
477
478 Added the ``code`` and ``reason`` arguments.
479 """
480 if self.ws_connection:
481 self.ws_connection.close(code, reason)
482 self.ws_connection = None
483
484 def check_origin(self, origin: str) -> bool:
485 """Override to enable support for allowing alternate origins.
486
487 The ``origin`` argument is the value of the ``Origin`` HTTP
488 header, the url responsible for initiating this request. This
489 method is not called for clients that do not send this header;
490 such requests are always allowed (because all browsers that
491 implement WebSockets support this header, and non-browser
492 clients do not have the same cross-site security concerns).
493
494 Should return ``True`` to accept the request or ``False`` to
495 reject it. By default, rejects all requests with an origin on
496 a host other than this one.
497
498 This is a security protection against cross site scripting attacks on
499 browsers, since WebSockets are allowed to bypass the usual same-origin
500 policies and don't use CORS headers.
501
502 .. warning::
503
504 This is an important security measure; don't disable it
505 without understanding the security implications. In
506 particular, if your authentication is cookie-based, you
507 must either restrict the origins allowed by
508 ``check_origin()`` or implement your own XSRF-like
509 protection for websocket connections. See `these
510 <https://www.christian-schneider.net/CrossSiteWebSocketHijacking.html>`_
511 `articles
512 <https://devcenter.heroku.com/articles/websocket-security>`_
513 for more.
514
515 To accept all cross-origin traffic (which was the default prior to
516 Tornado 4.0), simply override this method to always return ``True``::
517
518 def check_origin(self, origin):
519 return True
520
521 To allow connections from any subdomain of your site, you might
522 do something like::
523
524 def check_origin(self, origin):
525 parsed_origin = urllib.parse.urlparse(origin)
526 return parsed_origin.netloc.endswith(".mydomain.com")
527
528 .. versionadded:: 4.0
529
530 """
531 parsed_origin = urlparse(origin)
532 origin = parsed_origin.netloc
533 origin = origin.lower()
534
535 host = self.request.headers.get("Host")
536
537 # Check to see that origin matches host directly, including ports
538 return origin == host
539
540 def set_nodelay(self, value: bool) -> None:
541 """Set the no-delay flag for this stream.
542
543 By default, small messages may be delayed and/or combined to minimize
544 the number of packets sent. This can sometimes cause 200-500ms delays
545 due to the interaction between Nagle's algorithm and TCP delayed
546 ACKs. To reduce this delay (at the expense of possibly increasing
547 bandwidth usage), call ``self.set_nodelay(True)`` once the websocket
548 connection is established.
549
550 See `.BaseIOStream.set_nodelay` for additional details.
551
552 .. versionadded:: 3.1
553 """
554 assert self.ws_connection is not None
555 self.ws_connection.set_nodelay(value)
556
557 def on_connection_close(self) -> None:
558 if self.ws_connection:
559 self.ws_connection.on_connection_close()
560 self.ws_connection = None
561 if not self._on_close_called:
562 self._on_close_called = True
563 self.on_close()
564 self._break_cycles()
565
566 def on_ws_connection_close(
567 self, close_code: Optional[int] = None, close_reason: Optional[str] = None
568 ) -> None:
569 self.close_code = close_code
570 self.close_reason = close_reason
571 self.on_connection_close()
572
573 def _break_cycles(self) -> None:
574 # WebSocketHandlers call finish() early, but we don't want to
575 # break up reference cycles (which makes it impossible to call
576 # self.render_string) until after we've really closed the
577 # connection (if it was established in the first place,
578 # indicated by status code 101).
579 if self.get_status() != 101 or self._on_close_called:
580 super()._break_cycles()
581
582 def get_websocket_protocol(self) -> Optional["WebSocketProtocol"]:
583 websocket_version = self.request.headers.get("Sec-WebSocket-Version")
584 if websocket_version in ("7", "8", "13"):
585 params = _WebSocketParams(
586 ping_interval=self.ping_interval,
587 ping_timeout=self.ping_timeout,
588 max_message_size=self.max_message_size,
589 compression_options=self.get_compression_options(),
590 )
591 return WebSocketProtocol13(self, False, params)
592 return None
593
594 def _detach_stream(self) -> IOStream:
595 # disable non-WS methods
596 for method in [
597 "write",
598 "redirect",
599 "set_header",
600 "set_cookie",
601 "set_status",
602 "flush",
603 "finish",
604 ]:
605 setattr(self, method, _raise_not_supported_for_websockets)
606 return self.detach()
607
608
609def _raise_not_supported_for_websockets(*args: Any, **kwargs: Any) -> None:
610 raise RuntimeError("Method not supported for Web Sockets")
611
612
613class WebSocketProtocol(abc.ABC):
614 """Base class for WebSocket protocol versions."""
615
616 def __init__(self, handler: "_WebSocketDelegate") -> None:
617 self.handler = handler
618 self.stream = None # type: Optional[IOStream]
619 self.client_terminated = False
620 self.server_terminated = False
621
622 def _run_callback(
623 self, callback: Callable, *args: Any, **kwargs: Any
624 ) -> "Optional[Future[Any]]":
625 """Runs the given callback with exception handling.
626
627 If the callback is a coroutine, returns its Future. On error, aborts the
628 websocket connection and returns None.
629 """
630 try:
631 result = callback(*args, **kwargs)
632 except Exception:
633 self.handler.log_exception(*sys.exc_info())
634 self._abort()
635 return None
636 else:
637 if result is not None:
638 result = gen.convert_yielded(result)
639 assert self.stream is not None
640 self.stream.io_loop.add_future(result, lambda f: f.result())
641 return result
642
643 def on_connection_close(self) -> None:
644 self._abort()
645
646 def _abort(self) -> None:
647 """Instantly aborts the WebSocket connection by closing the socket"""
648 self.client_terminated = True
649 self.server_terminated = True
650 if self.stream is not None:
651 self.stream.close() # forcibly tear down the connection
652 self.close() # let the subclass cleanup
653
654 @abc.abstractmethod
655 def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
656 raise NotImplementedError()
657
658 @abc.abstractmethod
659 def is_closing(self) -> bool:
660 raise NotImplementedError()
661
662 @abc.abstractmethod
663 async def accept_connection(self, handler: WebSocketHandler) -> None:
664 raise NotImplementedError()
665
666 @abc.abstractmethod
667 def write_message(
668 self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False
669 ) -> "Future[None]":
670 raise NotImplementedError()
671
672 @property
673 @abc.abstractmethod
674 def selected_subprotocol(self) -> Optional[str]:
675 raise NotImplementedError()
676
677 @abc.abstractmethod
678 def write_ping(self, data: bytes) -> None:
679 raise NotImplementedError()
680
681 # The entry points below are used by WebSocketClientConnection,
682 # which was introduced after we only supported a single version of
683 # WebSocketProtocol. The WebSocketProtocol/WebSocketProtocol13
684 # boundary is currently pretty ad-hoc.
685 @abc.abstractmethod
686 def _process_server_headers(
687 self, key: Union[str, bytes], headers: httputil.HTTPHeaders
688 ) -> None:
689 raise NotImplementedError()
690
691 @abc.abstractmethod
692 def start_pinging(self) -> None:
693 raise NotImplementedError()
694
695 @abc.abstractmethod
696 async def _receive_frame_loop(self) -> None:
697 raise NotImplementedError()
698
699 @abc.abstractmethod
700 def set_nodelay(self, x: bool) -> None:
701 raise NotImplementedError()
702
703
704class _PerMessageDeflateCompressor(object):
705 def __init__(
706 self,
707 persistent: bool,
708 max_wbits: Optional[int],
709 compression_options: Optional[Dict[str, Any]] = None,
710 ) -> None:
711 if max_wbits is None:
712 max_wbits = zlib.MAX_WBITS
713 # There is no symbolic constant for the minimum wbits value.
714 if not (8 <= max_wbits <= zlib.MAX_WBITS):
715 raise ValueError(
716 "Invalid max_wbits value %r; allowed range 8-%d",
717 max_wbits,
718 zlib.MAX_WBITS,
719 )
720 self._max_wbits = max_wbits
721
722 if (
723 compression_options is None
724 or "compression_level" not in compression_options
725 ):
726 self._compression_level = tornado.web.GZipContentEncoding.GZIP_LEVEL
727 else:
728 self._compression_level = compression_options["compression_level"]
729
730 if compression_options is None or "mem_level" not in compression_options:
731 self._mem_level = 8
732 else:
733 self._mem_level = compression_options["mem_level"]
734
735 if persistent:
736 self._compressor = self._create_compressor() # type: Optional[_Compressor]
737 else:
738 self._compressor = None
739
740 def _create_compressor(self) -> "_Compressor":
741 return zlib.compressobj(
742 self._compression_level, zlib.DEFLATED, -self._max_wbits, self._mem_level
743 )
744
745 def compress(self, data: bytes) -> bytes:
746 compressor = self._compressor or self._create_compressor()
747 data = compressor.compress(data) + compressor.flush(zlib.Z_SYNC_FLUSH)
748 assert data.endswith(b"\x00\x00\xff\xff")
749 return data[:-4]
750
751
752class _PerMessageDeflateDecompressor(object):
753 def __init__(
754 self,
755 persistent: bool,
756 max_wbits: Optional[int],
757 max_message_size: int,
758 compression_options: Optional[Dict[str, Any]] = None,
759 ) -> None:
760 self._max_message_size = max_message_size
761 if max_wbits is None:
762 max_wbits = zlib.MAX_WBITS
763 if not (8 <= max_wbits <= zlib.MAX_WBITS):
764 raise ValueError(
765 "Invalid max_wbits value %r; allowed range 8-%d",
766 max_wbits,
767 zlib.MAX_WBITS,
768 )
769 self._max_wbits = max_wbits
770 if persistent:
771 self._decompressor = (
772 self._create_decompressor()
773 ) # type: Optional[_Decompressor]
774 else:
775 self._decompressor = None
776
777 def _create_decompressor(self) -> "_Decompressor":
778 return zlib.decompressobj(-self._max_wbits)
779
780 def decompress(self, data: bytes) -> bytes:
781 decompressor = self._decompressor or self._create_decompressor()
782 result = decompressor.decompress(
783 data + b"\x00\x00\xff\xff", self._max_message_size
784 )
785 if decompressor.unconsumed_tail:
786 raise _DecompressTooLargeError()
787 return result
788
789
790class WebSocketProtocol13(WebSocketProtocol):
791 """Implementation of the WebSocket protocol from RFC 6455.
792
793 This class supports versions 7 and 8 of the protocol in addition to the
794 final version 13.
795 """
796
797 # Bit masks for the first byte of a frame.
798 FIN = 0x80
799 RSV1 = 0x40
800 RSV2 = 0x20
801 RSV3 = 0x10
802 RSV_MASK = RSV1 | RSV2 | RSV3
803 OPCODE_MASK = 0x0F
804
805 stream = None # type: IOStream
806
807 def __init__(
808 self,
809 handler: "_WebSocketDelegate",
810 mask_outgoing: bool,
811 params: _WebSocketParams,
812 ) -> None:
813 WebSocketProtocol.__init__(self, handler)
814 self.mask_outgoing = mask_outgoing
815 self.params = params
816 self._final_frame = False
817 self._frame_opcode = None
818 self._masked_frame = None
819 self._frame_mask = None # type: Optional[bytes]
820 self._frame_length = None
821 self._fragmented_message_buffer = None # type: Optional[bytearray]
822 self._fragmented_message_opcode = None
823 self._waiting = None # type: object
824 self._compression_options = params.compression_options
825 self._decompressor = None # type: Optional[_PerMessageDeflateDecompressor]
826 self._compressor = None # type: Optional[_PerMessageDeflateCompressor]
827 self._frame_compressed = None # type: Optional[bool]
828 # The total uncompressed size of all messages received or sent.
829 # Unicode messages are encoded to utf8.
830 # Only for testing; subject to change.
831 self._message_bytes_in = 0
832 self._message_bytes_out = 0
833 # The total size of all packets received or sent. Includes
834 # the effect of compression, frame overhead, and control frames.
835 self._wire_bytes_in = 0
836 self._wire_bytes_out = 0
837 self.ping_callback = None # type: Optional[PeriodicCallback]
838 self.last_ping = 0.0
839 self.last_pong = 0.0
840 self.close_code = None # type: Optional[int]
841 self.close_reason = None # type: Optional[str]
842
843 # Use a property for this to satisfy the abc.
844 @property
845 def selected_subprotocol(self) -> Optional[str]:
846 return self._selected_subprotocol
847
848 @selected_subprotocol.setter
849 def selected_subprotocol(self, value: Optional[str]) -> None:
850 self._selected_subprotocol = value
851
852 async def accept_connection(self, handler: WebSocketHandler) -> None:
853 try:
854 self._handle_websocket_headers(handler)
855 except ValueError:
856 handler.set_status(400)
857 log_msg = "Missing/Invalid WebSocket headers"
858 handler.finish(log_msg)
859 gen_log.debug(log_msg)
860 return
861
862 try:
863 await self._accept_connection(handler)
864 except asyncio.CancelledError:
865 self._abort()
866 return
867 except ValueError:
868 gen_log.debug("Malformed WebSocket request received", exc_info=True)
869 self._abort()
870 return
871
872 def _handle_websocket_headers(self, handler: WebSocketHandler) -> None:
873 """Verifies all invariant- and required headers
874
875 If a header is missing or have an incorrect value ValueError will be
876 raised
877 """
878 fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version")
879 if not all(map(lambda f: handler.request.headers.get(f), fields)):
880 raise ValueError("Missing/Invalid WebSocket headers")
881
882 @staticmethod
883 def compute_accept_value(key: Union[str, bytes]) -> str:
884 """Computes the value for the Sec-WebSocket-Accept header,
885 given the value for Sec-WebSocket-Key.
886 """
887 sha1 = hashlib.sha1()
888 sha1.update(utf8(key))
889 sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11") # Magic value
890 return native_str(base64.b64encode(sha1.digest()))
891
892 def _challenge_response(self, handler: WebSocketHandler) -> str:
893 return WebSocketProtocol13.compute_accept_value(
894 cast(str, handler.request.headers.get("Sec-Websocket-Key"))
895 )
896
897 async def _accept_connection(self, handler: WebSocketHandler) -> None:
898 subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol")
899 if subprotocol_header:
900 subprotocols = [s.strip() for s in subprotocol_header.split(",")]
901 else:
902 subprotocols = []
903 self.selected_subprotocol = handler.select_subprotocol(subprotocols)
904 if self.selected_subprotocol:
905 assert self.selected_subprotocol in subprotocols
906 handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol)
907
908 extensions = self._parse_extensions_header(handler.request.headers)
909 for ext in extensions:
910 if ext[0] == "permessage-deflate" and self._compression_options is not None:
911 # TODO: negotiate parameters if compression_options
912 # specifies limits.
913 self._create_compressors("server", ext[1], self._compression_options)
914 if (
915 "client_max_window_bits" in ext[1]
916 and ext[1]["client_max_window_bits"] is None
917 ):
918 # Don't echo an offered client_max_window_bits
919 # parameter with no value.
920 del ext[1]["client_max_window_bits"]
921 handler.set_header(
922 "Sec-WebSocket-Extensions",
923 httputil._encode_header("permessage-deflate", ext[1]),
924 )
925 break
926
927 handler.clear_header("Content-Type")
928 handler.set_status(101)
929 handler.set_header("Upgrade", "websocket")
930 handler.set_header("Connection", "Upgrade")
931 handler.set_header("Sec-WebSocket-Accept", self._challenge_response(handler))
932 handler.finish()
933
934 self.stream = handler._detach_stream()
935
936 self.start_pinging()
937 try:
938 open_result = handler.open(*handler.open_args, **handler.open_kwargs)
939 if open_result is not None:
940 await open_result
941 except Exception:
942 handler.log_exception(*sys.exc_info())
943 self._abort()
944 return
945
946 await self._receive_frame_loop()
947
948 def _parse_extensions_header(
949 self, headers: httputil.HTTPHeaders
950 ) -> List[Tuple[str, Dict[str, str]]]:
951 extensions = headers.get("Sec-WebSocket-Extensions", "")
952 if extensions:
953 return [httputil._parse_header(e.strip()) for e in extensions.split(",")]
954 return []
955
956 def _process_server_headers(
957 self, key: Union[str, bytes], headers: httputil.HTTPHeaders
958 ) -> None:
959 """Process the headers sent by the server to this client connection.
960
961 'key' is the websocket handshake challenge/response key.
962 """
963 assert headers["Upgrade"].lower() == "websocket"
964 assert headers["Connection"].lower() == "upgrade"
965 accept = self.compute_accept_value(key)
966 assert headers["Sec-Websocket-Accept"] == accept
967
968 extensions = self._parse_extensions_header(headers)
969 for ext in extensions:
970 if ext[0] == "permessage-deflate" and self._compression_options is not None:
971 self._create_compressors("client", ext[1])
972 else:
973 raise ValueError("unsupported extension %r", ext)
974
975 self.selected_subprotocol = headers.get("Sec-WebSocket-Protocol", None)
976
977 def _get_compressor_options(
978 self,
979 side: str,
980 agreed_parameters: Dict[str, Any],
981 compression_options: Optional[Dict[str, Any]] = None,
982 ) -> Dict[str, Any]:
983 """Converts a websocket agreed_parameters set to keyword arguments
984 for our compressor objects.
985 """
986 options = dict(
987 persistent=(side + "_no_context_takeover") not in agreed_parameters
988 ) # type: Dict[str, Any]
989 wbits_header = agreed_parameters.get(side + "_max_window_bits", None)
990 if wbits_header is None:
991 options["max_wbits"] = zlib.MAX_WBITS
992 else:
993 options["max_wbits"] = int(wbits_header)
994 options["compression_options"] = compression_options
995 return options
996
997 def _create_compressors(
998 self,
999 side: str,
1000 agreed_parameters: Dict[str, Any],
1001 compression_options: Optional[Dict[str, Any]] = None,
1002 ) -> None:
1003 # TODO: handle invalid parameters gracefully
1004 allowed_keys = set(
1005 [
1006 "server_no_context_takeover",
1007 "client_no_context_takeover",
1008 "server_max_window_bits",
1009 "client_max_window_bits",
1010 ]
1011 )
1012 for key in agreed_parameters:
1013 if key not in allowed_keys:
1014 raise ValueError("unsupported compression parameter %r" % key)
1015 other_side = "client" if (side == "server") else "server"
1016 self._compressor = _PerMessageDeflateCompressor(
1017 **self._get_compressor_options(side, agreed_parameters, compression_options)
1018 )
1019 self._decompressor = _PerMessageDeflateDecompressor(
1020 max_message_size=self.params.max_message_size,
1021 **self._get_compressor_options(
1022 other_side, agreed_parameters, compression_options
1023 )
1024 )
1025
1026 def _write_frame(
1027 self, fin: bool, opcode: int, data: bytes, flags: int = 0
1028 ) -> "Future[None]":
1029 data_len = len(data)
1030 if opcode & 0x8:
1031 # All control frames MUST have a payload length of 125
1032 # bytes or less and MUST NOT be fragmented.
1033 if not fin:
1034 raise ValueError("control frames may not be fragmented")
1035 if data_len > 125:
1036 raise ValueError("control frame payloads may not exceed 125 bytes")
1037 if fin:
1038 finbit = self.FIN
1039 else:
1040 finbit = 0
1041 frame = struct.pack("B", finbit | opcode | flags)
1042 if self.mask_outgoing:
1043 mask_bit = 0x80
1044 else:
1045 mask_bit = 0
1046 if data_len < 126:
1047 frame += struct.pack("B", data_len | mask_bit)
1048 elif data_len <= 0xFFFF:
1049 frame += struct.pack("!BH", 126 | mask_bit, data_len)
1050 else:
1051 frame += struct.pack("!BQ", 127 | mask_bit, data_len)
1052 if self.mask_outgoing:
1053 mask = os.urandom(4)
1054 data = mask + _websocket_mask(mask, data)
1055 frame += data
1056 self._wire_bytes_out += len(frame)
1057 return self.stream.write(frame)
1058
1059 def write_message(
1060 self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False
1061 ) -> "Future[None]":
1062 """Sends the given message to the client of this Web Socket."""
1063 if binary:
1064 opcode = 0x2
1065 else:
1066 opcode = 0x1
1067 if isinstance(message, dict):
1068 message = tornado.escape.json_encode(message)
1069 message = tornado.escape.utf8(message)
1070 assert isinstance(message, bytes)
1071 self._message_bytes_out += len(message)
1072 flags = 0
1073 if self._compressor:
1074 message = self._compressor.compress(message)
1075 flags |= self.RSV1
1076 # For historical reasons, write methods in Tornado operate in a semi-synchronous
1077 # mode in which awaiting the Future they return is optional (But errors can
1078 # still be raised). This requires us to go through an awkward dance here
1079 # to transform the errors that may be returned while presenting the same
1080 # semi-synchronous interface.
1081 try:
1082 fut = self._write_frame(True, opcode, message, flags=flags)
1083 except StreamClosedError:
1084 raise WebSocketClosedError()
1085
1086 async def wrapper() -> None:
1087 try:
1088 await fut
1089 except StreamClosedError:
1090 raise WebSocketClosedError()
1091
1092 return asyncio.ensure_future(wrapper())
1093
1094 def write_ping(self, data: bytes) -> None:
1095 """Send ping frame."""
1096 assert isinstance(data, bytes)
1097 self._write_frame(True, 0x9, data)
1098
1099 async def _receive_frame_loop(self) -> None:
1100 try:
1101 while not self.client_terminated:
1102 await self._receive_frame()
1103 except StreamClosedError:
1104 self._abort()
1105 self.handler.on_ws_connection_close(self.close_code, self.close_reason)
1106
1107 async def _read_bytes(self, n: int) -> bytes:
1108 data = await self.stream.read_bytes(n)
1109 self._wire_bytes_in += n
1110 return data
1111
1112 async def _receive_frame(self) -> None:
1113 # Read the frame header.
1114 data = await self._read_bytes(2)
1115 header, mask_payloadlen = struct.unpack("BB", data)
1116 is_final_frame = header & self.FIN
1117 reserved_bits = header & self.RSV_MASK
1118 opcode = header & self.OPCODE_MASK
1119 opcode_is_control = opcode & 0x8
1120 if self._decompressor is not None and opcode != 0:
1121 # Compression flag is present in the first frame's header,
1122 # but we can't decompress until we have all the frames of
1123 # the message.
1124 self._frame_compressed = bool(reserved_bits & self.RSV1)
1125 reserved_bits &= ~self.RSV1
1126 if reserved_bits:
1127 # client is using as-yet-undefined extensions; abort
1128 self._abort()
1129 return
1130 is_masked = bool(mask_payloadlen & 0x80)
1131 payloadlen = mask_payloadlen & 0x7F
1132
1133 # Parse and validate the length.
1134 if opcode_is_control and payloadlen >= 126:
1135 # control frames must have payload < 126
1136 self._abort()
1137 return
1138 if payloadlen < 126:
1139 self._frame_length = payloadlen
1140 elif payloadlen == 126:
1141 data = await self._read_bytes(2)
1142 payloadlen = struct.unpack("!H", data)[0]
1143 elif payloadlen == 127:
1144 data = await self._read_bytes(8)
1145 payloadlen = struct.unpack("!Q", data)[0]
1146 new_len = payloadlen
1147 if self._fragmented_message_buffer is not None:
1148 new_len += len(self._fragmented_message_buffer)
1149 if new_len > self.params.max_message_size:
1150 self.close(1009, "message too big")
1151 self._abort()
1152 return
1153
1154 # Read the payload, unmasking if necessary.
1155 if is_masked:
1156 self._frame_mask = await self._read_bytes(4)
1157 data = await self._read_bytes(payloadlen)
1158 if is_masked:
1159 assert self._frame_mask is not None
1160 data = _websocket_mask(self._frame_mask, data)
1161
1162 # Decide what to do with this frame.
1163 if opcode_is_control:
1164 # control frames may be interleaved with a series of fragmented
1165 # data frames, so control frames must not interact with
1166 # self._fragmented_*
1167 if not is_final_frame:
1168 # control frames must not be fragmented
1169 self._abort()
1170 return
1171 elif opcode == 0: # continuation frame
1172 if self._fragmented_message_buffer is None:
1173 # nothing to continue
1174 self._abort()
1175 return
1176 self._fragmented_message_buffer.extend(data)
1177 if is_final_frame:
1178 opcode = self._fragmented_message_opcode
1179 data = bytes(self._fragmented_message_buffer)
1180 self._fragmented_message_buffer = None
1181 else: # start of new data message
1182 if self._fragmented_message_buffer is not None:
1183 # can't start new message until the old one is finished
1184 self._abort()
1185 return
1186 if not is_final_frame:
1187 self._fragmented_message_opcode = opcode
1188 self._fragmented_message_buffer = bytearray(data)
1189
1190 if is_final_frame:
1191 handled_future = self._handle_message(opcode, data)
1192 if handled_future is not None:
1193 await handled_future
1194
1195 def _handle_message(self, opcode: int, data: bytes) -> "Optional[Future[None]]":
1196 """Execute on_message, returning its Future if it is a coroutine."""
1197 if self.client_terminated:
1198 return None
1199
1200 if self._frame_compressed:
1201 assert self._decompressor is not None
1202 try:
1203 data = self._decompressor.decompress(data)
1204 except _DecompressTooLargeError:
1205 self.close(1009, "message too big after decompression")
1206 self._abort()
1207 return None
1208
1209 if opcode == 0x1:
1210 # UTF-8 data
1211 self._message_bytes_in += len(data)
1212 try:
1213 decoded = data.decode("utf-8")
1214 except UnicodeDecodeError:
1215 self._abort()
1216 return None
1217 return self._run_callback(self.handler.on_message, decoded)
1218 elif opcode == 0x2:
1219 # Binary data
1220 self._message_bytes_in += len(data)
1221 return self._run_callback(self.handler.on_message, data)
1222 elif opcode == 0x8:
1223 # Close
1224 self.client_terminated = True
1225 if len(data) >= 2:
1226 self.close_code = struct.unpack(">H", data[:2])[0]
1227 if len(data) > 2:
1228 self.close_reason = to_unicode(data[2:])
1229 # Echo the received close code, if any (RFC 6455 section 5.5.1).
1230 self.close(self.close_code)
1231 elif opcode == 0x9:
1232 # Ping
1233 try:
1234 self._write_frame(True, 0xA, data)
1235 except StreamClosedError:
1236 self._abort()
1237 self._run_callback(self.handler.on_ping, data)
1238 elif opcode == 0xA:
1239 # Pong
1240 self.last_pong = IOLoop.current().time()
1241 return self._run_callback(self.handler.on_pong, data)
1242 else:
1243 self._abort()
1244 return None
1245
1246 def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
1247 """Closes the WebSocket connection."""
1248 if not self.server_terminated:
1249 if not self.stream.closed():
1250 if code is None and reason is not None:
1251 code = 1000 # "normal closure" status code
1252 if code is None:
1253 close_data = b""
1254 else:
1255 close_data = struct.pack(">H", code)
1256 if reason is not None:
1257 close_data += utf8(reason)
1258 try:
1259 self._write_frame(True, 0x8, close_data)
1260 except StreamClosedError:
1261 self._abort()
1262 self.server_terminated = True
1263 if self.client_terminated:
1264 if self._waiting is not None:
1265 self.stream.io_loop.remove_timeout(self._waiting)
1266 self._waiting = None
1267 self.stream.close()
1268 elif self._waiting is None:
1269 # Give the client a few seconds to complete a clean shutdown,
1270 # otherwise just close the connection.
1271 self._waiting = self.stream.io_loop.add_timeout(
1272 self.stream.io_loop.time() + 5, self._abort
1273 )
1274 if self.ping_callback:
1275 self.ping_callback.stop()
1276 self.ping_callback = None
1277
1278 def is_closing(self) -> bool:
1279 """Return ``True`` if this connection is closing.
1280
1281 The connection is considered closing if either side has
1282 initiated its closing handshake or if the stream has been
1283 shut down uncleanly.
1284 """
1285 return self.stream.closed() or self.client_terminated or self.server_terminated
1286
1287 @property
1288 def ping_interval(self) -> Optional[float]:
1289 interval = self.params.ping_interval
1290 if interval is not None:
1291 return interval
1292 return 0
1293
1294 @property
1295 def ping_timeout(self) -> Optional[float]:
1296 timeout = self.params.ping_timeout
1297 if timeout is not None:
1298 return timeout
1299 assert self.ping_interval is not None
1300 return max(3 * self.ping_interval, 30)
1301
1302 def start_pinging(self) -> None:
1303 """Start sending periodic pings to keep the connection alive"""
1304 assert self.ping_interval is not None
1305 if self.ping_interval > 0:
1306 self.last_ping = self.last_pong = IOLoop.current().time()
1307 self.ping_callback = PeriodicCallback(
1308 self.periodic_ping, self.ping_interval * 1000
1309 )
1310 self.ping_callback.start()
1311
1312 def periodic_ping(self) -> None:
1313 """Send a ping to keep the websocket alive
1314
1315 Called periodically if the websocket_ping_interval is set and non-zero.
1316 """
1317 if self.is_closing() and self.ping_callback is not None:
1318 self.ping_callback.stop()
1319 return
1320
1321 # Check for timeout on pong. Make sure that we really have
1322 # sent a recent ping in case the machine with both server and
1323 # client has been suspended since the last ping.
1324 now = IOLoop.current().time()
1325 since_last_pong = now - self.last_pong
1326 since_last_ping = now - self.last_ping
1327 assert self.ping_interval is not None
1328 assert self.ping_timeout is not None
1329 if (
1330 since_last_ping < 2 * self.ping_interval
1331 and since_last_pong > self.ping_timeout
1332 ):
1333 self.close()
1334 return
1335
1336 self.write_ping(b"")
1337 self.last_ping = now
1338
1339 def set_nodelay(self, x: bool) -> None:
1340 self.stream.set_nodelay(x)
1341
1342
1343class WebSocketClientConnection(simple_httpclient._HTTPConnection):
1344 """WebSocket client connection.
1345
1346 This class should not be instantiated directly; use the
1347 `websocket_connect` function instead.
1348 """
1349
1350 protocol = None # type: WebSocketProtocol
1351
1352 def __init__(
1353 self,
1354 request: httpclient.HTTPRequest,
1355 on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None,
1356 compression_options: Optional[Dict[str, Any]] = None,
1357 ping_interval: Optional[float] = None,
1358 ping_timeout: Optional[float] = None,
1359 max_message_size: int = _default_max_message_size,
1360 subprotocols: Optional[List[str]] = None,
1361 resolver: Optional[Resolver] = None,
1362 ) -> None:
1363 self.connect_future = Future() # type: Future[WebSocketClientConnection]
1364 self.read_queue = Queue(1) # type: Queue[Union[None, str, bytes]]
1365 self.key = base64.b64encode(os.urandom(16))
1366 self._on_message_callback = on_message_callback
1367 self.close_code = None # type: Optional[int]
1368 self.close_reason = None # type: Optional[str]
1369 self.params = _WebSocketParams(
1370 ping_interval=ping_interval,
1371 ping_timeout=ping_timeout,
1372 max_message_size=max_message_size,
1373 compression_options=compression_options,
1374 )
1375
1376 scheme, sep, rest = request.url.partition(":")
1377 scheme = {"ws": "http", "wss": "https"}[scheme]
1378 request.url = scheme + sep + rest
1379 request.headers.update(
1380 {
1381 "Upgrade": "websocket",
1382 "Connection": "Upgrade",
1383 "Sec-WebSocket-Key": self.key,
1384 "Sec-WebSocket-Version": "13",
1385 }
1386 )
1387 if subprotocols is not None:
1388 request.headers["Sec-WebSocket-Protocol"] = ",".join(subprotocols)
1389 if compression_options is not None:
1390 # Always offer to let the server set our max_wbits (and even though
1391 # we don't offer it, we will accept a client_no_context_takeover
1392 # from the server).
1393 # TODO: set server parameters for deflate extension
1394 # if requested in self.compression_options.
1395 request.headers["Sec-WebSocket-Extensions"] = (
1396 "permessage-deflate; client_max_window_bits"
1397 )
1398
1399 # Websocket connection is currently unable to follow redirects
1400 request.follow_redirects = False
1401
1402 self.tcp_client = TCPClient(resolver=resolver)
1403 super().__init__(
1404 None,
1405 request,
1406 lambda: None,
1407 self._on_http_response,
1408 104857600,
1409 self.tcp_client,
1410 65536,
1411 104857600,
1412 )
1413
1414 def __del__(self) -> None:
1415 if self.protocol is not None:
1416 # Unclosed client connections can sometimes log "task was destroyed but
1417 # was pending" warnings if shutdown strikes at the wrong time (such as
1418 # while a ping is being processed due to ping_interval). Log our own
1419 # warning to make it a little more deterministic (although it's still
1420 # dependent on GC timing).
1421 warnings.warn("Unclosed WebSocketClientConnection", ResourceWarning)
1422
1423 def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
1424 """Closes the websocket connection.
1425
1426 ``code`` and ``reason`` are documented under
1427 `WebSocketHandler.close`.
1428
1429 .. versionadded:: 3.2
1430
1431 .. versionchanged:: 4.0
1432
1433 Added the ``code`` and ``reason`` arguments.
1434 """
1435 if self.protocol is not None:
1436 self.protocol.close(code, reason)
1437 self.protocol = None # type: ignore
1438
1439 def on_connection_close(self) -> None:
1440 if not self.connect_future.done():
1441 self.connect_future.set_exception(StreamClosedError())
1442 self._on_message(None)
1443 self.tcp_client.close()
1444 super().on_connection_close()
1445
1446 def on_ws_connection_close(
1447 self, close_code: Optional[int] = None, close_reason: Optional[str] = None
1448 ) -> None:
1449 self.close_code = close_code
1450 self.close_reason = close_reason
1451 self.on_connection_close()
1452
1453 def _on_http_response(self, response: httpclient.HTTPResponse) -> None:
1454 if not self.connect_future.done():
1455 if response.error:
1456 self.connect_future.set_exception(response.error)
1457 else:
1458 self.connect_future.set_exception(
1459 WebSocketError("Non-websocket response")
1460 )
1461
1462 async def headers_received(
1463 self,
1464 start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
1465 headers: httputil.HTTPHeaders,
1466 ) -> None:
1467 assert isinstance(start_line, httputil.ResponseStartLine)
1468 if start_line.code != 101:
1469 await super().headers_received(start_line, headers)
1470 return
1471
1472 if self._timeout is not None:
1473 self.io_loop.remove_timeout(self._timeout)
1474 self._timeout = None
1475
1476 self.headers = headers
1477 self.protocol = self.get_websocket_protocol()
1478 self.protocol._process_server_headers(self.key, self.headers)
1479 self.protocol.stream = self.connection.detach()
1480
1481 IOLoop.current().add_callback(self.protocol._receive_frame_loop)
1482 self.protocol.start_pinging()
1483
1484 # Once we've taken over the connection, clear the final callback
1485 # we set on the http request. This deactivates the error handling
1486 # in simple_httpclient that would otherwise interfere with our
1487 # ability to see exceptions.
1488 self.final_callback = None # type: ignore
1489
1490 future_set_result_unless_cancelled(self.connect_future, self)
1491
1492 def write_message(
1493 self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False
1494 ) -> "Future[None]":
1495 """Sends a message to the WebSocket server.
1496
1497 If the stream is closed, raises `WebSocketClosedError`.
1498 Returns a `.Future` which can be used for flow control.
1499
1500 .. versionchanged:: 5.0
1501 Exception raised on a closed stream changed from `.StreamClosedError`
1502 to `WebSocketClosedError`.
1503 """
1504 if self.protocol is None:
1505 raise WebSocketClosedError("Client connection has been closed")
1506 return self.protocol.write_message(message, binary=binary)
1507
1508 def read_message(
1509 self,
1510 callback: Optional[Callable[["Future[Union[None, str, bytes]]"], None]] = None,
1511 ) -> Awaitable[Union[None, str, bytes]]:
1512 """Reads a message from the WebSocket server.
1513
1514 If on_message_callback was specified at WebSocket
1515 initialization, this function will never return messages
1516
1517 Returns a future whose result is the message, or None
1518 if the connection is closed. If a callback argument
1519 is given it will be called with the future when it is
1520 ready.
1521 """
1522
1523 awaitable = self.read_queue.get()
1524 if callback is not None:
1525 self.io_loop.add_future(asyncio.ensure_future(awaitable), callback)
1526 return awaitable
1527
1528 def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]:
1529 return self._on_message(message)
1530
1531 def _on_message(
1532 self, message: Union[None, str, bytes]
1533 ) -> Optional[Awaitable[None]]:
1534 if self._on_message_callback:
1535 self._on_message_callback(message)
1536 return None
1537 else:
1538 return self.read_queue.put(message)
1539
1540 def ping(self, data: bytes = b"") -> None:
1541 """Send ping frame to the remote end.
1542
1543 The data argument allows a small amount of data (up to 125
1544 bytes) to be sent as a part of the ping message. Note that not
1545 all websocket implementations expose this data to
1546 applications.
1547
1548 Consider using the ``ping_interval`` argument to
1549 `websocket_connect` instead of sending pings manually.
1550
1551 .. versionadded:: 5.1
1552
1553 """
1554 data = utf8(data)
1555 if self.protocol is None:
1556 raise WebSocketClosedError()
1557 self.protocol.write_ping(data)
1558
1559 def on_pong(self, data: bytes) -> None:
1560 pass
1561
1562 def on_ping(self, data: bytes) -> None:
1563 pass
1564
1565 def get_websocket_protocol(self) -> WebSocketProtocol:
1566 return WebSocketProtocol13(self, mask_outgoing=True, params=self.params)
1567
1568 @property
1569 def selected_subprotocol(self) -> Optional[str]:
1570 """The subprotocol selected by the server.
1571
1572 .. versionadded:: 5.1
1573 """
1574 return self.protocol.selected_subprotocol
1575
1576 def log_exception(
1577 self,
1578 typ: "Optional[Type[BaseException]]",
1579 value: Optional[BaseException],
1580 tb: Optional[TracebackType],
1581 ) -> None:
1582 assert typ is not None
1583 assert value is not None
1584 app_log.error("Uncaught exception %s", value, exc_info=(typ, value, tb))
1585
1586
1587def websocket_connect(
1588 url: Union[str, httpclient.HTTPRequest],
1589 callback: Optional[Callable[["Future[WebSocketClientConnection]"], None]] = None,
1590 connect_timeout: Optional[float] = None,
1591 on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None,
1592 compression_options: Optional[Dict[str, Any]] = None,
1593 ping_interval: Optional[float] = None,
1594 ping_timeout: Optional[float] = None,
1595 max_message_size: int = _default_max_message_size,
1596 subprotocols: Optional[List[str]] = None,
1597 resolver: Optional[Resolver] = None,
1598) -> "Awaitable[WebSocketClientConnection]":
1599 """Client-side websocket support.
1600
1601 Takes a url and returns a Future whose result is a
1602 `WebSocketClientConnection`.
1603
1604 ``compression_options`` is interpreted in the same way as the
1605 return value of `.WebSocketHandler.get_compression_options`.
1606
1607 The connection supports two styles of operation. In the coroutine
1608 style, the application typically calls
1609 `~.WebSocketClientConnection.read_message` in a loop::
1610
1611 conn = yield websocket_connect(url)
1612 while True:
1613 msg = yield conn.read_message()
1614 if msg is None: break
1615 # Do something with msg
1616
1617 In the callback style, pass an ``on_message_callback`` to
1618 ``websocket_connect``. In both styles, a message of ``None``
1619 indicates that the connection has been closed.
1620
1621 ``subprotocols`` may be a list of strings specifying proposed
1622 subprotocols. The selected protocol may be found on the
1623 ``selected_subprotocol`` attribute of the connection object
1624 when the connection is complete.
1625
1626 .. versionchanged:: 3.2
1627 Also accepts ``HTTPRequest`` objects in place of urls.
1628
1629 .. versionchanged:: 4.1
1630 Added ``compression_options`` and ``on_message_callback``.
1631
1632 .. versionchanged:: 4.5
1633 Added the ``ping_interval``, ``ping_timeout``, and ``max_message_size``
1634 arguments, which have the same meaning as in `WebSocketHandler`.
1635
1636 .. versionchanged:: 5.0
1637 The ``io_loop`` argument (deprecated since version 4.1) has been removed.
1638
1639 .. versionchanged:: 5.1
1640 Added the ``subprotocols`` argument.
1641
1642 .. versionchanged:: 6.3
1643 Added the ``resolver`` argument.
1644 """
1645 if isinstance(url, httpclient.HTTPRequest):
1646 assert connect_timeout is None
1647 request = url
1648 # Copy and convert the headers dict/object (see comments in
1649 # AsyncHTTPClient.fetch)
1650 request.headers = httputil.HTTPHeaders(request.headers)
1651 else:
1652 request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
1653 request = cast(
1654 httpclient.HTTPRequest,
1655 httpclient._RequestProxy(request, httpclient.HTTPRequest._DEFAULTS),
1656 )
1657 conn = WebSocketClientConnection(
1658 request,
1659 on_message_callback=on_message_callback,
1660 compression_options=compression_options,
1661 ping_interval=ping_interval,
1662 ping_timeout=ping_timeout,
1663 max_message_size=max_message_size,
1664 subprotocols=subprotocols,
1665 resolver=resolver,
1666 )
1667 if callback is not None:
1668 IOLoop.current().add_future(conn.connect_future, callback)
1669 return conn.connect_future