Coverage for /pythoncovmergedfiles/medio/medio/src/jupyter_server/jupyter_server/base/websocket.py: 34%

68 statements  

« prev     ^ index     » next       coverage.py v7.3.3, created at 2023-12-15 06:13 +0000

1"""Base websocket classes.""" 

2import re 

3from typing import Optional, no_type_check 

4from urllib.parse import urlparse 

5 

6from tornado import ioloop 

7from tornado.iostream import IOStream 

8 

# Default interval between websocket keep-alive pings, in milliseconds (30 seconds).
WS_PING_INTERVAL = 30 * 1000

11 

12 

class WebSocketMixin:
    """Mixin for common websocket options.

    Adds keep-alive pings, pong-timeout detection and an Origin check to a
    websocket handler. The class it is mixed into must provide ``settings``,
    ``request``, ``log``, ``ws_connection``, ``ping``, ``close``, ``get_origin``
    and ``open`` (presumably tornado's ``WebSocketHandler``). ``check_origin``
    additionally reads ``allow_origin``, ``allow_origin_pat`` and, if present,
    ``skip_check_origin()``.
    """

    # PeriodicCallback driving send_ping(); created in open() when pings are enabled.
    ping_callback = None
    # IOLoop timestamps (seconds) of the last ping sent and the last pong received.
    last_ping = 0.0
    last_pong = 0.0
    stream: Optional[IOStream] = None

    @property
    def ping_interval(self):
        """The interval for websocket keep-alive pings, in milliseconds.

        Read from the ``ws_ping_interval`` setting, defaulting to
        ``WS_PING_INTERVAL``. Set ws_ping_interval = 0 to disable pings.
        """
        return self.settings.get("ws_ping_interval", WS_PING_INTERVAL)  # type:ignore[attr-defined]

    @property
    def ping_timeout(self):
        """If no pong is received in this many milliseconds,
        close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).

        Read from the ``ws_ping_timeout`` setting. Default is the larger of three
        ping intervals or ``WS_PING_INTERVAL`` (30 seconds).
        """
        return self.settings.get(  # type:ignore[attr-defined]
            "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL)
        )

    @no_type_check
    def check_origin(self, origin: Optional[str] = None) -> bool:
        """Check Origin == Host or Access-Control-Allow-Origin.

        Tornado >= 4 calls this method automatically, raising 403 if it returns False.

        Parameters
        ----------
        origin : str, optional
            The Origin to check. Defaults to ``self.get_origin()``.

        Returns
        -------
        bool
            True if ``allow_origin`` is ``"*"``, the handler opts out via
            ``skip_check_origin()``, no Origin or Host header is present, the
            Origin's host matches the Host header (case-insensitively), or the
            Origin matches ``allow_origin`` / ``allow_origin_pat``. Otherwise a
            warning is logged and False is returned.
        """

        if self.allow_origin == "*" or (
            hasattr(self, "skip_check_origin") and self.skip_check_origin()
        ):
            return True

        host = self.request.headers.get("Host")
        if origin is None:
            origin = self.get_origin()

        # If no origin or host header is provided, assume from script
        if origin is None or host is None:
            return True

        origin = origin.lower()
        origin_host = urlparse(origin).netloc

        # OK if origin matches host. Hostnames are case-insensitive (RFC 3986
        # section 3.2.2) and origin was lowercased above, so normalise Host too.
        if origin_host == host.lower():
            return True

        # Check CORS headers
        if self.allow_origin:
            allow = self.allow_origin == origin
        elif self.allow_origin_pat:
            # NOTE: re.match anchors only at the start of the string, so the
            # pattern may match a prefix of the origin unless it ends with "$".
            allow = bool(re.match(self.allow_origin_pat, origin))
        else:
            # No CORS headers deny the request
            allow = False
        if not allow:
            self.log.warning(
                "Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
                origin,
                host,
            )
        return allow

    def clear_cookie(self, *args, **kwargs):
        """Do nothing: clearing cookies is meaningless for websockets."""

    @no_type_check
    def open(self, *args, **kwargs):
        """Open the websocket.

        When ``ping_interval`` is positive, starts a periodic ``send_ping``
        callback, then defers to the next ``open`` in the MRO.
        """
        self.log.debug("Opening websocket %s", self.request.path)

        # start the pinging
        if self.ping_interval > 0:
            loop = ioloop.IOLoop.current()
            self.last_ping = loop.time()  # Remember time of last ping
            # Treat the socket as freshly ponged so the first timeout check is
            # measured from the moment the connection opened.
            self.last_pong = self.last_ping
            self.ping_callback = ioloop.PeriodicCallback(
                self.send_ping,
                self.ping_interval,
            )
            self.ping_callback.start()
        return super().open(*args, **kwargs)

    @no_type_check
    def send_ping(self):
        """Send a ping to keep the websocket alive.

        Stops the periodic callback once the connection is gone, closes the
        socket if the client side has terminated, and closes it when no pong
        has arrived within ``ping_timeout``.
        """
        if self.ws_connection is None:
            # Connection already gone: stop pinging (if a callback exists) and
            # return instead of dereferencing None below.
            if self.ping_callback is not None:
                self.ping_callback.stop()
            return

        if self.ws_connection.client_terminated:
            self.close()
            return

        # check for timeout on pong. Make sure that we really have sent a recent ping in
        # case the machine with both server and client has been suspended since the last ping.
        now = ioloop.IOLoop.current().time()
        since_last_pong = 1e3 * (now - self.last_pong)
        since_last_ping = 1e3 * (now - self.last_ping)
        if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout:
            self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong)
            self.close()
            return

        self.ping(b"")
        self.last_ping = now

    def on_pong(self, data):
        """Handle a pong message by recording the time it arrived."""
        self.last_pong = ioloop.IOLoop.current().time()