Coverage for /pythoncovmergedfiles/medio/medio/src/jupyter_server/jupyter_server/base/websocket.py: 34%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/prev highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1"""Base websocket classes."""
3import re
4import warnings
5from typing import Optional, no_type_check
6from urllib.parse import urlparse
8from tornado import ioloop, web
9from tornado.iostream import IOStream
11from jupyter_server.base.handlers import JupyterHandler
12from jupyter_server.utils import JupyterServerAuthWarning
# Interval between websocket keep-alive pings, in milliseconds (30 seconds).
WS_PING_INTERVAL = 30000
class WebSocketMixin:
    """Mixin for common websocket options.

    Adds keep-alive pings with a pong timeout, an Origin check, and an
    authentication fallback for websocket handlers that do not inherit
    from ``JupyterHandler``.
    """

    # PeriodicCallback driving send_ping(); created by open() when pings are enabled.
    ping_callback = None
    # IOLoop.time() values of the last ping sent and the last pong received.
    last_ping = 0.0
    last_pong = 0.0
    stream: Optional[IOStream] = None

    @property
    def ping_interval(self):
        """The interval for websocket keep-alive pings, in milliseconds.

        Read from the ``ws_ping_interval`` setting, defaulting to
        ``WS_PING_INTERVAL``. Set ws_ping_interval = 0 to disable pings.
        """
        return self.settings.get("ws_ping_interval", WS_PING_INTERVAL)  # type:ignore[attr-defined]

    @property
    def ping_timeout(self):
        """If no pong is received in this many milliseconds,
        close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).

        Read from the ``ws_ping_timeout`` setting. The default is the larger of
        three ping intervals or ``WS_PING_INTERVAL`` (30 seconds).
        """
        return self.settings.get(  # type:ignore[attr-defined]
            "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL)
        )

    @no_type_check
    def check_origin(self, origin: Optional[str] = None) -> bool:
        """Check Origin == Host or Access-Control-Allow-Origin.

        Tornado >= 4 calls this method automatically, raising 403 if it returns False.

        Parameters
        ----------
        origin : str, optional
            The request origin. When omitted it is taken from ``self.get_origin()``.

        Returns
        -------
        bool
            True if the websocket connection should be accepted.
        """
        if self.allow_origin == "*" or (
            hasattr(self, "skip_check_origin") and self.skip_check_origin()
        ):
            return True

        host = self.request.headers.get("Host")
        if origin is None:
            origin = self.get_origin()

        # If no origin or host header is provided, assume from script
        if origin is None or host is None:
            return True

        origin = origin.lower()
        origin_host = urlparse(origin).netloc

        # OK if origin matches host.
        # NOTE(review): origin is lowercased but host is compared exactly as
        # sent; confirm whether a case-insensitive comparison is intended.
        if origin_host == host:
            return True

        # Check CORS headers
        if self.allow_origin:
            allow = self.allow_origin == origin
        elif self.allow_origin_pat:
            allow = bool(re.match(self.allow_origin_pat, origin))
        else:
            # No CORS headers deny the request
            allow = False
        if not allow:
            self.log.warning(
                "Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
                origin,
                host,
            )
        return allow

    def clear_cookie(self, *args, **kwargs):
        """meaningless for websockets"""

    @no_type_check
    def _maybe_auth(self):
        """Verify authentication if required.

        Only used when the websocket class does not inherit from JupyterHandler.
        Skipped entirely when the ``allow_unauthenticated_access`` setting is true.

        Raises
        ------
        tornado.web.HTTPError
            403 if the request has no method, or if the handler method is not
            marked ``__allow_unauthenticated`` and there is no current user.
        """
        if not self.settings.get("allow_unauthenticated_access", False):
            if not self.request.method:
                raise web.HTTPError(403)
            method = getattr(self, self.request.method.lower())
            if not getattr(method, "__allow_unauthenticated", False):
                # rather than re-using `web.authenticated` which also redirects
                # to login page on GET, just raise 403 if user is not known
                user = self.current_user
                if user is None:
                    self.log.warning("Couldn't authenticate WebSocket connection")
                    raise web.HTTPError(403)

    @no_type_check
    def prepare(self, *args, **kwargs):
        """Run authentication checks before the request is handled.

        For subclasses of ``JupyterHandler``, defer to its ``prepare`` with
        ``_redirect_to_login=False`` (presumably so websocket requests are not
        redirected to a login page). For other handlers, warn when an
        ``identity_provider`` is configured and authentication is required,
        then authenticate via ``_maybe_auth``.
        """
        if not isinstance(self, JupyterHandler):
            should_authenticate = not self.settings.get("allow_unauthenticated_access", False)
            if "identity_provider" in self.settings and should_authenticate:
                warnings.warn(
                    "WebSocketMixin sub-class does not inherit from JupyterHandler"
                    " preventing proper authentication using custom identity provider.",
                    JupyterServerAuthWarning,
                    stacklevel=2,
                )
            self._maybe_auth()
            return super().prepare(*args, **kwargs)
        return super().prepare(*args, **kwargs, _redirect_to_login=False)

    @no_type_check
    def open(self, *args, **kwargs):
        """Open the websocket and start keep-alive pings if enabled."""
        self.log.debug("Opening websocket %s", self.request.path)

        # start the pinging
        if self.ping_interval > 0:
            loop = ioloop.IOLoop.current()
            self.last_ping = loop.time()  # Remember time of last ping
            # Treat the open as a fresh pong so the timeout clock starts now.
            self.last_pong = self.last_ping
            self.ping_callback = ioloop.PeriodicCallback(
                self.send_ping,
                self.ping_interval,
            )
            self.ping_callback.start()
        return super().open(*args, **kwargs)

    @no_type_check
    def send_ping(self):
        """Send a ping to keep the websocket alive.

        Stops pinging once the connection is gone, closes the socket if the
        client terminated it, and closes it if no pong arrived within
        ``ping_timeout``.
        """
        if self.ws_connection is None:
            # Checked independently of ping_callback so the code below never
            # dereferences a None connection.
            if self.ping_callback is not None:
                self.ping_callback.stop()
            return

        if self.ws_connection.client_terminated:
            self.close()
            return

        # check for timeout on pong. Make sure that we really have sent a recent ping in
        # case the machine with both server and client has been suspended since the last ping.
        now = ioloop.IOLoop.current().time()
        # Convert loop-time deltas to milliseconds to compare with ping_interval/ping_timeout.
        since_last_pong = 1e3 * (now - self.last_pong)
        since_last_ping = 1e3 * (now - self.last_ping)
        if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout:
            self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong)
            self.close()
            return

        self.ping(b"")
        self.last_ping = now

    def on_pong(self, data):
        """Handle a pong message by recording when it arrived."""
        self.last_pong = ioloop.IOLoop.current().time()