Coverage for /pythoncovmergedfiles/medio/medio/src/jupyter_server/jupyter_server/base/websocket.py: 34%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

92 statements  

1"""Base websocket classes.""" 

2 

3import re 

4import warnings 

5from typing import Optional, no_type_check 

6from urllib.parse import urlparse 

7 

8from tornado import ioloop, web 

9from tornado.iostream import IOStream 

10 

11from jupyter_server.base.handlers import JupyterHandler 

12from jupyter_server.utils import JupyterServerAuthWarning 

13 

# Default keep-alive ping interval for websockets, in milliseconds (30 seconds).
WS_PING_INTERVAL = 30_000

16 

17 

class WebSocketMixin:
    """Mixin for common websocket options.

    Adds origin checking, optional authentication for handlers that do not
    inherit from ``JupyterHandler``, and a periodic keep-alive ping that closes
    the connection when pongs stop arriving.

    The host class must provide ``settings``, ``request``, ``log``,
    ``ws_connection``, ``ping``, ``close`` and a ``prepare``/``open`` to chain
    to via ``super()``. Presumably this is tornado's ``WebSocketHandler``;
    confirm at the use site.
    """

    # PeriodicCallback driving send_ping; created in open() when pings are enabled.
    ping_callback = None
    # Event-loop timestamps (seconds) of the last ping sent / pong received.
    last_ping = 0.0
    last_pong = 0.0
    stream: Optional[IOStream] = None

    @property
    def ping_interval(self):
        """The interval for websocket keep-alive pings, in milliseconds.

        Read from the ``ws_ping_interval`` setting, defaulting to
        ``WS_PING_INTERVAL``. Set ws_ping_interval = 0 to disable pings.
        """
        return self.settings.get("ws_ping_interval", WS_PING_INTERVAL)  # type:ignore[attr-defined]

    @property
    def ping_timeout(self):
        """Milliseconds without a pong after which the websocket is closed.

        This guards against connections that silently die (VPNs, etc. can fail
        to cleanly close ws connections). It is read from the ``ws_ping_timeout``
        setting. The default is three ping intervals or ``WS_PING_INTERVAL``
        (30 seconds), whichever is larger.
        """
        return self.settings.get(  # type:ignore[attr-defined]
            "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL)
        )

    @no_type_check
    def check_origin(self, origin: Optional[str] = None) -> bool:
        """Check Origin == Host or Access-Control-Allow-Origin.

        Tornado >= 4 calls this method automatically, raising 403 if it returns False.

        :param origin: the Origin to check; when None it is read via ``get_origin()``.
        :returns: True if the connection should be accepted.
        """
        if self.allow_origin == "*" or (
            hasattr(self, "skip_check_origin") and self.skip_check_origin()
        ):
            return True

        host = self.request.headers.get("Host")
        if origin is None:
            origin = self.get_origin()

        # If no origin or host header is provided, assume from script
        if origin is None or host is None:
            return True

        # Scheme and host are case-insensitive. The origin is lowercased, so its
        # comparands must be too, or a mixed-case Host header / allow_origin
        # setting could never match.
        origin = origin.lower()
        origin_host = urlparse(origin).netloc

        # OK if origin matches host
        if origin_host == host.lower():
            return True

        # Check CORS headers
        if self.allow_origin:
            allow = self.allow_origin.lower() == origin
        elif self.allow_origin_pat:
            allow = bool(re.match(self.allow_origin_pat, origin))
        else:
            # No CORS headers deny the request
            allow = False
        if not allow:
            self.log.warning(
                "Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
                origin,
                host,
            )
        return allow

    def clear_cookie(self, *args, **kwargs):
        """meaningless for websockets"""

    @no_type_check
    def _maybe_auth(self):
        """Verify authentication if required.

        Only used when the websocket class does not inherit from JupyterHandler.

        :raises web.HTTPError: 403 if the request has no method, or if the
            handler method is not marked as allowing unauthenticated access
            and ``current_user`` is None.
        """
        if not self.settings.get("allow_unauthenticated_access", False):
            if not self.request.method:
                raise web.HTTPError(403)
            method = getattr(self, self.request.method.lower())
            # The marker attribute is presumably set by an "allow
            # unauthenticated" decorator defined elsewhere; verify against caller.
            if not getattr(method, "__allow_unauthenticated", False):
                # rather than re-using `web.authenticated` which also redirects
                # to login page on GET, just raise 403 if user is not known
                user = self.current_user
                if user is None:
                    self.log.warning("Couldn't authenticate WebSocket connection")
                    raise web.HTTPError(403)

    @no_type_check
    def prepare(self, *args, **kwargs):
        """Run pre-request checks before the websocket handshake.

        For handlers that do not inherit from ``JupyterHandler``:
        - warn if an identity provider is configured (it cannot be honored here)
          while authentication is required;
        - fall back to ``_maybe_auth``.

        ``JupyterHandler`` subclasses are asked not to redirect to the login
        page, so an authentication failure is not turned into a redirect.
        """
        if not isinstance(self, JupyterHandler):
            should_authenticate = not self.settings.get("allow_unauthenticated_access", False)
            if "identity_provider" in self.settings and should_authenticate:
                warnings.warn(
                    "WebSocketMixin sub-class does not inherit from JupyterHandler"
                    " preventing proper authentication using custom identity provider.",
                    JupyterServerAuthWarning,
                    stacklevel=2,
                )
            self._maybe_auth()
            return super().prepare(*args, **kwargs)
        return super().prepare(*args, **kwargs, _redirect_to_login=False)

    @no_type_check
    def open(self, *args, **kwargs):
        """Open the websocket and, if ``ping_interval`` > 0, start keep-alive pings."""
        self.log.debug("Opening websocket %s", self.request.path)

        # start the pinging
        if self.ping_interval > 0:
            loop = ioloop.IOLoop.current()
            self.last_ping = loop.time()  # Remember time of last ping
            # Treat the connection as freshly ponged so the timeout window starts now.
            self.last_pong = self.last_ping
            self.ping_callback = ioloop.PeriodicCallback(
                self.send_ping,
                self.ping_interval,
            )
            self.ping_callback.start()
        return super().open(*args, **kwargs)

    @no_type_check
    def send_ping(self):
        """Send a ping to keep the websocket alive.

        This runs on every ``ping_callback`` tick. In order:
        - stop pinging once the connection is gone;
        - close if the client has terminated;
        - close if no pong arrived within ``ping_timeout``;
        - otherwise send an empty ping.
        """
        if self.ws_connection is None:
            # Connection already torn down. Always return here: falling through
            # would dereference None below. Stop the timer only if one exists.
            if self.ping_callback is not None:
                self.ping_callback.stop()
            return

        if self.ws_connection.client_terminated:
            self.close()
            return

        # check for timeout on pong. Make sure that we really have sent a recent ping in
        # case the machine with both server and client has been suspended since the last ping.
        now = ioloop.IOLoop.current().time()
        since_last_pong = 1e3 * (now - self.last_pong)  # seconds -> milliseconds
        since_last_ping = 1e3 * (now - self.last_ping)
        if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout:
            self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong)
            self.close()
            return

        self.ping(b"")
        self.last_ping = now

    def on_pong(self, data):
        """Handle a pong message by recording its arrival time; ``data`` is ignored."""
        self.last_pong = ioloop.IOLoop.current().time()