1"""WebSocket protocol versions 13 and 8."""
2
3import asyncio
4import random
5from functools import partial
6from typing import Any, Final, Optional, Union
7
8from ..base_protocol import BaseProtocol
9from ..client_exceptions import ClientConnectionResetError
10from ..compression_utils import ZLibBackend, ZLibCompressor
11from .helpers import (
12 MASK_LEN,
13 MSG_SIZE,
14 PACK_CLOSE_CODE,
15 PACK_LEN1,
16 PACK_LEN2,
17 PACK_LEN3,
18 PACK_RANDBITS,
19 websocket_mask,
20)
21from .models import WS_DEFLATE_TRAILING, WSMsgType
22
# Default number of bytes WebSocketWriter writes before it checks whether it
# has to wait for the protocol's write buffer to drain.
DEFAULT_LIMIT: Final[int] = 2**16

# For websockets, keeping latency low is extremely important as implementations
# generally expect to be able to send and receive messages quickly. We use a
# larger chunk size than the default to reduce the number of executor calls
# since the executor is a significant source of latency and overhead when
# the chunks are small. A size of 5KiB was chosen because it is also the
# same value python-zlib-ng chose to use as the threshold to release the GIL.

WEBSOCKET_MAX_SYNC_CHUNK_SIZE: Final[int] = 5 * 1024
33
34
class WebSocketWriter:
    """WebSocket writer.

    The writer is responsible for sending messages to the peer. It is
    created by the protocol when a connection is established. The writer
    should avoid implementing any application logic and should only be
    concerned with the low-level details of the WebSocket protocol.
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        transport: asyncio.Transport,
        *,
        use_mask: bool = False,
        limit: int = DEFAULT_LIMIT,
        random: random.Random = random.Random(),
        compress: int = 0,
        notakeover: bool = False,
    ) -> None:
        """Initialize a WebSocket writer.

        protocol: consulted for flow control (``_paused`` / ``_drain_helper``).
        transport: where encoded frames are written.
        use_mask: mask every outgoing payload with a random 32-bit key.
        limit: bytes written between flow-control checks.
        random: source of the masking keys. The default instance is created
            once, at definition time, and shared by every writer that does
            not pass its own.
        compress: window bits for connection-level deflate of data frames
            (passed as ``wbits=-compress``); 0 disables it.
        notakeover: flush each compressed message with Z_FULL_FLUSH instead
            of Z_SYNC_FLUSH.
        """
        # NOTE(review): random.Random is not a cryptographic RNG, while
        # RFC 6455 section 5.3 asks for unpredictable masking keys -- confirm
        # this speed/strength trade-off is intended.
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        self.get_random_bits = partial(random.getrandbits, 32)
        self.compress = compress
        self.notakeover = notakeover
        self._closing = False
        self._limit = limit
        self._output_size = 0
        self._compressobj: Any = None  # actually compressobj
        # Orders frames that go through the shared ``self._compressobj``.
        # Created on first use so that it is made inside the running loop.
        self._send_lock: Optional[asyncio.Lock] = None
        # Strong references to shielded background sends; the event loop
        # only keeps weak references to tasks.
        self._background_tasks: set = set()

    async def send_frame(
        self, message: bytes, opcode: int, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket with message as its payload.

        Data frames (``opcode < 8``) are deflate-compressed and flagged with
        RSV1 when ``compress`` (a one-off compressor for this frame only) or
        the connection-level ``self.compress`` is set.

        If the caller is cancelled while a large payload is being compressed
        with the shared compressor, that frame is still written in the
        background (keeping the compressed stream consistent) and the
        caller receives CancelledError.

        Raises ClientConnectionResetError when the writer is closing and the
        opcode lacks the ``WSMsgType.CLOSE`` bit, or when the transport is
        closing.
        """
        # This tests a single bit. Assuming CLOSE is opcode 0x8 as in
        # RFC 6455, every control opcode (close, ping, pong) passes, so
        # control frames may still be sent while closing.
        if self._closing and not (opcode & WSMsgType.CLOSE):
            raise ClientConnectionResetError("Cannot write to closing transport")

        if not (compress or self.compress) or opcode >= 8:
            # Uncompressed data frames and all control frames.
            self._write_frame(message, opcode, 0)
        elif compress:
            # Do not set self._compressobj if compressing is for this frame.
            # The compressor is discarded afterwards, so a cancellation during
            # compression cannot leave shared state inconsistent.
            # RSV1 (0x40) marks compressed frames:
            # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
            compressed = await self._compress(
                self._make_compress_obj(compress), message
            )
            self._write_frame(compressed, opcode, 0x40)
        else:
            await self._send_shared_compressed(message, opcode)

        # It is safe to return control to the event loop after this point as
        # the frame has already been sent or buffered.

        # Once we have written output_size up to the limit, we call the
        # drain helper which waits for the transport to be ready to accept
        # more data. This is a flow control mechanism to prevent the buffer
        # from growing too large. The drain helper will return right away
        # if the writer is not paused.
        if self._output_size > self._limit:
            self._output_size = 0
            if self.protocol._paused:
                await self.protocol._drain_helper()

    async def _send_shared_compressed(self, message: bytes, opcode: int) -> None:
        """Compress with the connection-level compressor and write the frame.

        The shared compressor carries state from one message to the next, so
        a payload must never be fed to it without its frame being written.
        """
        if not self._compressobj:
            self._compressobj = self._make_compress_obj(self.compress)
        compressobj = self._compressobj
        if self._send_lock is None:
            self._send_lock = asyncio.Lock()
        send_lock = self._send_lock

        if len(message) < WEBSOCKET_MAX_SYNC_CHUNK_SIZE:
            # Presumably compressed inline, because the compressor is built
            # with max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE -- TODO
            # confirm against ZLibCompressor. With no executor hop there is
            # no cancellation point between compressing and writing. The lock
            # keeps this frame behind any larger frame already in flight.
            async with send_lock:
                compressed = await self._compress(compressobj, message)
                self._write_frame(compressed, opcode, 0x40)
            return

        # Large payloads may be compressed in an executor. Run compress+write
        # in a task that the caller's cancellation cannot reach. The lock is
        # taken *before* the task exists so that frames keep their call order.
        # It is released from the task's done callback, which also runs if the
        # task is cancelled before it starts.
        await send_lock.acquire()
        coro = self._compress_and_write(compressobj, message, opcode)
        try:
            task = asyncio.get_running_loop().create_task(coro)
        except BaseException:
            coro.close()
            send_lock.release()
            raise
        self._background_tasks.add(task)
        task.add_done_callback(partial(self._finish_background_send, send_lock))
        await asyncio.shield(task)

    async def _compress_and_write(
        self, compressobj: ZLibCompressor, message: bytes, opcode: int
    ) -> None:
        """Compress ``message`` and write its frame without yielding in between."""
        compressed = await self._compress(compressobj, message)
        self._write_frame(compressed, opcode, 0x40)

    def _finish_background_send(
        self, send_lock: asyncio.Lock, task: "asyncio.Task[None]"
    ) -> None:
        """Drop the task reference and let the next shared-compressor frame go."""
        self._background_tasks.discard(task)
        send_lock.release()

    async def _compress(self, compressobj: ZLibCompressor, message: bytes) -> bytes:
        """Deflate ``message`` and flush the compressor.

        Uses Z_FULL_FLUSH when ``self.notakeover`` is set, otherwise
        Z_SYNC_FLUSH, and strips the WS_DEFLATE_TRAILING suffix.
        """
        return (
            await compressobj.compress(message)
            + compressobj.flush(
                ZLibBackend.Z_FULL_FLUSH
                if self.notakeover
                else ZLibBackend.Z_SYNC_FLUSH
            )
        ).removesuffix(WS_DEFLATE_TRAILING)

    def _write_frame(self, message: bytes, opcode: int, rsv: int) -> None:
        """Build the frame header for ``message`` and write the frame.

        Deliberately synchronous: once a compressed payload is ready it must
        reach the transport without returning control to the event loop, so
        frames from concurrent senders cannot be interleaved.

        Raises ClientConnectionResetError if the transport is closing.
        """
        msg_length = len(message)

        use_mask = self.use_mask
        mask_bit = 0x80 if use_mask else 0

        # Depending on the message length, the header is assembled differently.
        # The first byte holds the FIN bit (0x80), the RSV bits and the opcode.
        # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
        first_byte = 0x80 | rsv | opcode
        if msg_length < 126:
            header = PACK_LEN1(first_byte, msg_length | mask_bit)
            header_len = 2
        elif msg_length < 65536:
            header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
            header_len = 4
        else:
            header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
            header_len = 10

        if self.transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")

        # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
        # When masking, a fresh random 32-bit key is generated for every
        # frame and XOR-ed over the payload before it is sent.
        if use_mask:
            mask = PACK_RANDBITS(self.get_random_bits())
            message = bytearray(message)
            websocket_mask(mask, message)
            self.transport.write(header + mask + message)
            self._output_size += MASK_LEN
        elif msg_length > MSG_SIZE:
            # Large payload: written separately, presumably to avoid copying
            # it into a concatenated buffer.
            self.transport.write(header)
            self.transport.write(message)
        else:
            self.transport.write(header + message)

        self._output_size += header_len + msg_length

    def _make_compress_obj(self, compress: int) -> ZLibCompressor:
        """Create a raw-deflate compressor (``wbits=-compress``) tuned for speed."""
        return ZLibCompressor(
            level=ZLibBackend.Z_BEST_SPEED,
            wbits=-compress,
            max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
        )

    async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
        """Close the websocket, sending the specified code and message.

        A ``str`` message is UTF-8 encoded. The writer is marked as closing
        even if sending the close frame fails; afterwards send_frame raises
        for frames whose opcode lacks the ``WSMsgType.CLOSE`` bit.
        """
        # NOTE(review): the payload length is not checked here; RFC 6455
        # limits control-frame payloads to 125 bytes.
        if isinstance(message, str):
            message = message.encode("utf-8")
        try:
            await self.send_frame(
                PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
            )
        finally:
            self._closing = True