/src/swift-nio/Sources/NIOHTTP1/HTTPServerUpgradeHandler.swift
Line | Count | Source |
1 | | //===----------------------------------------------------------------------===// |
2 | | // |
3 | | // This source file is part of the SwiftNIO open source project |
4 | | // |
5 | | // Copyright (c) 2017-2024 Apple Inc. and the SwiftNIO project authors |
6 | | // Licensed under Apache License v2.0 |
7 | | // |
8 | | // See LICENSE.txt for license information |
9 | | // See CONTRIBUTORS.txt for the list of SwiftNIO project authors |
10 | | // |
11 | | // SPDX-License-Identifier: Apache-2.0 |
12 | | // |
13 | | //===----------------------------------------------------------------------===// |
14 | | import NIOCore |
15 | | |
/// Errors that `HTTPServerUpgradeHandler` may fire down the channel pipeline.
public enum HTTPServerUpgradeErrors: Error {
    /// The first inbound part on the connection was not a request head, so the request parts arrived
    /// in an order the upgrade handler cannot interpret.
    case invalidHTTPOrdering
}
20 | | |
/// User inbound events that `HTTPServerUpgradeHandler` fires through the channel pipeline.
public enum HTTPServerUpgradeEvents: Sendable {
    /// The HTTP upgrade has completed and the `HTTPServerUpgradeHandler` is about to remove itself
    /// from the `ChannelPipeline`.
    ///
    /// - toProtocol: The protocol name that was negotiated, as it appeared in the request's `Upgrade` header.
    /// - upgradeRequest: The head of the request that triggered the upgrade.
    case upgradeComplete(toProtocol: String, upgradeRequest: HTTPRequestHead)
}
28 | | |
/// A type that knows how to take over a server-side channel for one specific protocol
/// once an HTTP/1.1 `Upgrade` request for that protocol has been accepted.
public protocol HTTPServerProtocolUpgrader {
    /// The protocol token (as it appears in the `Upgrade` header) that this upgrader handles.
    var supportedProtocol: String { get }

    /// Names of the header fields that must be present in the request for this upgrade to be attempted.
    ///
    /// Each of them must also be listed in the request's `Connection` header field. The headers are made
    /// available to the upgrader when it is asked to perform the upgrade.
    var requiredUpgradeHeaders: [String] { get }

    /// Produces the complete set of headers to send in the `101 Switching Protocols` response.
    ///
    /// - Parameters:
    ///   - channel: The channel on which the upgrade is happening.
    ///   - upgradeRequest: The head of the request asking for the upgrade.
    ///   - initialResponseHeaders: Headers already prepared for the response, to be extended as needed.
    /// - Returns: A future with the final response headers. Failing the future refuses the upgrade.
    func buildUpgradeResponse(
        channel: Channel,
        upgradeRequest: HTTPRequestHead,
        initialResponseHeaders: HTTPHeaders
    ) -> EventLoopFuture<HTTPHeaders>

    /// Performs the upgrade after the `101` response has been flushed.
    ///
    /// The channel pipeline may safely be reconfigured here. Inbound data is held back until the
    /// returned future succeeds.
    func upgrade(
        context: ChannelHandlerContext,
        upgradeRequest: HTTPRequestHead
    ) -> EventLoopFuture<Void>
}
54 | | |
/// A server-side channel handler that receives HTTP requests and optionally performs a HTTP-upgrade.
/// Removes itself from the channel pipeline after the first inbound request on the connection, regardless of
/// whether the upgrade succeeded or not.
///
/// This handler behaves a bit differently from its Netty counterpart because it does not allow upgrade
/// on any request but the first on a connection. This is primarily to handle clients that pipeline: it's
/// sufficiently difficult to ensure that the upgrade happens at a safe time while dealing with pipelined
/// requests that we choose to punt on it entirely and not allow it. As it happens this is mostly fine:
/// the odds of someone needing to upgrade midway through the lifetime of a connection are very low.
///
/// If no upgrader takes the first request, that request is passed on intact before this handler removes
/// itself: its head, any body parts, and its `.end` part (including trailers).
public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableChannelHandler {
    public typealias InboundIn = HTTPServerRequestPart
    public typealias InboundOut = HTTPServerRequestPart
    public typealias OutboundOut = HTTPServerResponsePart

    /// Upgraders keyed by their lowercased `supportedProtocol`.
    private let upgraders: [String: HTTPServerProtocolUpgrader]
    private let upgradeCompletionHandler: (ChannelHandlerContext) -> Void

    private let httpEncoder: HTTPResponseEncoder
    private let extraHTTPHandlers: [RemovableChannelHandler]

    /// Whether we've already seen the `.end` of the first request.
    private var seenFirstRequest = false

    /// Where we are in the upgrade process.
    private var upgradeState: UpgradeState = .idle

    /// Inbound messages received after the first request's `.end`; replayed by `removeHandler`.
    private var receivedMessages: CircularBuffer<NIOAny> = CircularBuffer()

    /// The `.body`/`.end` parts of the first request that arrived while we were still looking for an upgrader.
    /// They are replayed after the head if we end up not upgrading, and discarded once an upgrader takes over.
    private var pendingFirstRequestParts: CircularBuffer<HTTPServerRequestPart> = CircularBuffer()

    /// Create a `HTTPServerUpgradeHandler`.
    ///
    /// - Parameter upgraders: All `HTTPServerProtocolUpgrader` objects that this pipeline will be able
    ///     to use to handle HTTP upgrade. Protocol names are matched case-insensitively; if two upgraders
    ///     support the same protocol, the later one in the array wins.
    /// - Parameter httpEncoder: The `HTTPResponseEncoder` encoding responses from this handler and which will
    ///     be removed from the pipeline once the upgrade response is sent. This is used to ensure
    ///     that the pipeline will be in a clean state after upgrade.
    /// - Parameter extraHTTPHandlers: Any other handlers that are directly related to handling HTTP. At the very least
    ///     this should include the `HTTPDecoder`, but should also include any other handler that cannot tolerate
    ///     receiving non-HTTP data.
    /// - Parameter upgradeCompletionHandler: A block that will be fired when HTTP upgrade is complete.
    public init(
        upgraders: [HTTPServerProtocolUpgrader],
        httpEncoder: HTTPResponseEncoder,
        extraHTTPHandlers: [RemovableChannelHandler],
        upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void
    ) {
        var upgraderMap = [String: HTTPServerProtocolUpgrader]()
        for upgrader in upgraders {
            upgraderMap[upgrader.supportedProtocol.lowercased()] = upgrader
        }
        self.upgraders = upgraderMap
        self.upgradeCompletionHandler = upgradeCompletionHandler
        self.httpEncoder = httpEncoder
        self.extraHTTPHandlers = extraHTTPHandlers
    }

    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        guard !self.seenFirstRequest else {
            // We're waiting for upgrade to complete: buffer this data.
            self.receivedMessages.append(data)
            return
        }

        let requestPart = HTTPServerUpgradeHandler.unwrapInboundIn(data)

        switch self.upgradeState {
        case .idle:
            self.firstRequestHeadReceived(context: context, requestPart: requestPart)
        case .awaitingUpgrader:
            // We don't yet know whether we'll upgrade. Keep the rest of the first request so it can be
            // delivered intact (body and trailers included) if no upgrader takes it.
            self.pendingFirstRequestParts.append(requestPart)
            if case .end = requestPart {
                // This is the end of the first request. Everything after this goes to `receivedMessages`.
                self.seenFirstRequest = true
            }
        case .upgraderReady(let upgrade):
            if case .end = requestPart {
                // This is the end of the first request, and we can upgrade. Time to kick it off.
                self.seenFirstRequest = true
                upgrade()
            }
        case .upgradeFailed:
            if self.pendingFirstRequestParts.isEmpty {
                // We were re-entrantly called while delivering the request. We can just pass this through.
                context.fireChannelRead(data)
            } else {
                // We were re-entrantly called while `notUpgrading` is still replaying held-back parts of the
                // first request. Queue this part behind them so ordering is preserved; the replay loop delivers it.
                self.pendingFirstRequestParts.append(requestPart)
            }
        case .upgradeComplete:
            preconditionFailure(
                "Upgrade has completed but we have not seen a whole request and still got re-entrantly called."
            )
        case .upgrading:
            preconditionFailure(
                "We think we saw .end before and began upgrading, but somehow we have not set seenFirstRequest"
            )
        }
    }

    public func removeHandler(context: ChannelHandlerContext, removalToken: ChannelHandlerContext.RemovalToken) {
        // We have been formally removed from the pipeline. We should send any buffered data we have.
        // Note that we loop twice. This is because we want to guard against being reentrantly called from fireChannelReadComplete.
        while !self.receivedMessages.isEmpty {
            while !self.receivedMessages.isEmpty {
                let bufferedPart = self.receivedMessages.removeFirst()
                context.fireChannelRead(bufferedPart)
            }

            context.fireChannelReadComplete()
        }

        context.leavePipeline(removalToken: removalToken)
    }

    /// Handles the first part seen on the connection and decides whether an upgrade should be attempted.
    private func firstRequestHeadReceived(context: ChannelHandlerContext, requestPart: HTTPServerRequestPart) {
        // We should decide if we're going to upgrade based on the first request header: if we aren't upgrading,
        // by the time the body comes in we should be out of the pipeline. That means that if we don't think we're
        // upgrading, the only thing we should see is a request head. Anything else is an error.
        guard case .head(let request) = requestPart else {
            context.fireErrorCaught(HTTPServerUpgradeErrors.invalidHTTPOrdering)
            self.notUpgrading(context: context, data: requestPart)
            return
        }

        // Ok, we have a HTTP request. Check if it's an upgrade. If it's not, we want to pass it on and remove ourselves
        // from the channel pipeline.
        let requestedProtocols = request.headers[canonicalForm: "upgrade"].map(String.init)
        guard !requestedProtocols.isEmpty else {
            self.notUpgrading(context: context, data: requestPart)
            return
        }

        // Cool, this is an upgrade request.
        // We'll attempt to upgrade. This may take a while, so while we're waiting more data can come in.
        self.upgradeState = .awaitingUpgrader

        self.handleUpgrade(context: context, request: request, requestedProtocols: requestedProtocols)
            .whenSuccess { callback in
                if let callback = callback {
                    self.gotUpgrader(upgrader: callback)
                } else {
                    self.notUpgrading(context: context, data: requestPart)
                }
            }
    }

    /// The core of the upgrade handling logic.
    ///
    /// - Returns: An isolated `EventLoopFuture` that will contain a callback to invoke if upgrade is requested,
    ///     or nil if upgrade has failed. Never returns a failed future.
    private func handleUpgrade(
        context: ChannelHandlerContext,
        request: HTTPRequestHead,
        requestedProtocols: [String]
    ) -> EventLoopFuture<(() -> Void)?>.Isolated {

        let connectionHeader = Set(request.headers[canonicalForm: "connection"].map { $0.lowercased() })
        let allHeaderNames = Set(request.headers.map { $0.name.lowercased() })

        // We now set off a chain of Futures to try to find a protocol upgrade. While this is blocking, we need to buffer inbound data.
        let protocolIterator = requestedProtocols.makeIterator()
        return self.handleUpgradeForProtocol(
            context: context,
            protocolIterator: protocolIterator,
            request: request,
            allHeaderNames: allHeaderNames,
            connectionHeader: connectionHeader
        )
    }

    /// Attempt to upgrade a single protocol.
    ///
    /// Will recurse through `protocolIterator` if upgrade fails: when no upgrader is registered for the protocol,
    /// when the request lacks one of the upgrader's required headers (in the header list or in `Connection`),
    /// or when the upgrader fails its `buildUpgradeResponse` future.
    ///
    /// - Returns: An isolated `EventLoopFuture` that will contain a callback to invoke if upgrade is requested,
    ///     or nil if upgrade has failed. Never returns a failed future.
    private func handleUpgradeForProtocol(
        context: ChannelHandlerContext,
        protocolIterator: Array<String>.Iterator,
        request: HTTPRequestHead,
        allHeaderNames: Set<String>,
        connectionHeader: Set<String>
    ) -> EventLoopFuture<(() -> Void)?>.Isolated {
        // We want a local copy of the protocol iterator. We'll pass it to the next invocation of the function.
        var protocolIterator = protocolIterator
        guard let proto = protocolIterator.next() else {
            // We're done! No suitable protocol for upgrade.
            return context.eventLoop.makeSucceededIsolatedFuture(nil)
        }

        guard let upgrader = self.upgraders[proto.lowercased()] else {
            return self.handleUpgradeForProtocol(
                context: context,
                protocolIterator: protocolIterator,
                request: request,
                allHeaderNames: allHeaderNames,
                connectionHeader: connectionHeader
            )
        }

        let requiredHeaders = Set(upgrader.requiredUpgradeHeaders.map { $0.lowercased() })
        guard requiredHeaders.isSubset(of: allHeaderNames) && requiredHeaders.isSubset(of: connectionHeader) else {
            return self.handleUpgradeForProtocol(
                context: context,
                protocolIterator: protocolIterator,
                request: request,
                allHeaderNames: allHeaderNames,
                connectionHeader: connectionHeader
            )
        }

        let responseHeaders = self.buildUpgradeHeaders(protocol: proto)
        let pipeline = context.pipeline

        return upgrader.buildUpgradeResponse(
            channel: context.channel,
            upgradeRequest: request,
            initialResponseHeaders: responseHeaders
        ).hop(to: context.eventLoop)
            .assumeIsolated()
            .map { finalResponseHeaders in
                {
                    // Ok, we're upgrading.
                    self.upgradeState = .upgrading

                    // Before we finish the upgrade we have to remove the HTTPDecoder and any other non-Encoder HTTP
                    // handlers from the pipeline, to prevent them parsing any more data. We'll buffer the data until
                    // that completes.
                    // While there are a lot of Futures involved here it's quite possible that all of this code will
                    // actually complete synchronously: we just want to program for the possibility that it won't.
                    // Once that's done, we send the upgrade response, then remove the HTTP encoder, then call the
                    // internal handler, then call the user code, and then finally when the user code is done we do
                    // our final cleanup steps, namely we replay the received data we buffered in the meantime and
                    // then remove ourselves from the pipeline.
                    self.removeExtraHandlers(pipeline: pipeline)
                        .assumeIsolated()
                        .flatMap {
                            self.sendUpgradeResponse(
                                context: context,
                                upgradeRequest: request,
                                responseHeaders: finalResponseHeaders
                            )
                        }.flatMap {
                            pipeline.syncOperations.removeHandler(self.httpEncoder)
                        }.flatMap { () -> EventLoopFuture<Void> in
                            self.upgradeCompletionHandler(context)
                            return upgrader.upgrade(context: context, upgradeRequest: request)
                        }.whenComplete { result in
                            switch result {
                            case .success:
                                context.fireUserInboundEventTriggered(
                                    HTTPServerUpgradeEvents.upgradeComplete(toProtocol: proto, upgradeRequest: request)
                                )
                                self.upgradeState = .upgradeComplete
                                // When we remove ourselves we'll be delivering any buffered data.
                                context.pipeline.syncOperations.removeHandler(context: context, promise: nil)

                            case .failure(let error):
                                // Remain in the '.upgrading' state.
                                context.fireErrorCaught(error)
                            }
                        }
                }
            }.flatMapError { error in
                // No upgrade here. We want to fire the error down the pipeline, and then try another loop iteration.
                context.fireErrorCaught(error)
                return self.handleUpgradeForProtocol(
                    context: context,
                    protocolIterator: protocolIterator,
                    request: request,
                    allHeaderNames: allHeaderNames,
                    connectionHeader: connectionHeader
                )
            }
    }

    /// Records that an upgrader accepted the request, and starts the upgrade if the request's `.end` was already seen.
    private func gotUpgrader(upgrader: @escaping (() -> Void)) {
        switch self.upgradeState {
        case .awaitingUpgrader:
            // The upgrader consumes the first request, so the parts held back for a possible replay are not needed.
            self.pendingFirstRequestParts.removeAll()
            self.upgradeState = .upgraderReady(upgrader)
            if self.seenFirstRequest {
                // Ok, we're good to go, we can upgrade. Otherwise we're waiting for .end, which
                // will trigger the upgrade.
                upgrader()
            }
        case .idle, .upgradeComplete, .upgraderReady, .upgradeFailed, .upgrading:
            preconditionFailure("Unexpected upgrader state: \(self.upgradeState)")
        }
    }

    /// Sends the 101 Switching Protocols response for the pipeline.
    private func sendUpgradeResponse(
        context: ChannelHandlerContext,
        upgradeRequest: HTTPRequestHead,
        responseHeaders: HTTPHeaders
    ) -> EventLoopFuture<Void> {
        var response = HTTPResponseHead(version: .http1_1, status: .switchingProtocols)
        response.headers = responseHeaders
        return context.writeAndFlush(wrapOutboundOut(HTTPServerResponsePart.head(response)))
    }

    /// Called when we know we're not upgrading. Passes the first request on, intact, and then removes this
    /// object from the pipeline.
    ///
    /// - Parameter data: The part that triggered the decision. Normally this is the first request's `.head`.
    private func notUpgrading(context: ChannelHandlerContext, data: HTTPServerRequestPart) {
        self.upgradeState = .upgradeFailed

        if !self.seenFirstRequest {
            // We haven't seen the first request .end, so nothing can have been buffered in `receivedMessages` yet.
            assert(self.receivedMessages.isEmpty)
        }

        // Deliver the part we were handed first...
        context.fireChannelRead(HTTPServerUpgradeHandler.wrapInboundOut(data))

        // ...followed by any body parts and the `.end` (with its trailers) that we held back while looking for an
        // upgrader. While we're doing this we may be re-entrantly called, which either queues new parts behind the
        // pending ones (see `channelRead`) or buffers them in `receivedMessages`. To make that safe, we must
        // ensure we aren't holding the buffer mutably, so no for loop for us.
        while !self.pendingFirstRequestParts.isEmpty {
            let part = self.pendingFirstRequestParts.removeFirst()
            context.fireChannelRead(HTTPServerUpgradeHandler.wrapInboundOut(part))
        }

        context.fireChannelReadComplete()

        // Ok, we've delivered the first request. We can now remove ourselves, which should happen synchronously;
        // `removeHandler` replays anything buffered in `receivedMessages`.
        context.pipeline.syncOperations.removeHandler(context: context, promise: nil)
    }

    /// Builds the initial mandatory HTTP headers for HTTP upgrade responses.
    private func buildUpgradeHeaders(`protocol`: String) -> HTTPHeaders {
        HTTPHeaders([("connection", "upgrade"), ("upgrade", `protocol`)])
    }

    /// Removes any extra HTTP-related handlers from the channel pipeline.
    ///
    /// - Returns: A future that succeeds once every extra handler has been removed (immediately if there are none).
    private func removeExtraHandlers(pipeline: ChannelPipeline) -> EventLoopFuture<Void> {
        guard !self.extraHTTPHandlers.isEmpty else {
            return pipeline.eventLoop.makeSucceededFuture(())
        }

        return .andAllSucceed(
            self.extraHTTPHandlers.map { pipeline.syncOperations.removeHandler($0) },
            on: pipeline.eventLoop
        )
    }
}
389 | | |
// The handler holds mutable state (`private var`s) with no locking, and its methods assume they run on the
// channel's event loop (`assumeIsolated()`). NOTE(review): the `@unchecked` conformance relies on that
// confinement, which is not enforced here — confirm callers only touch the handler from its event loop.
extension HTTPServerUpgradeHandler: @unchecked Sendable {
}
391 | | |
extension HTTPServerUpgradeHandler {
    /// Progress of the handler through a (possible) upgrade of the first request.
    ///
    /// Typical lifecycle: `idle` → `awaitingUpgrader` → (`upgraderReady` →) `upgrading` → `upgradeComplete`,
    /// or `upgradeFailed` once it is clear no upgrade will take place.
    private enum UpgradeState {
        /// Nothing has happened yet.
        case idle

        /// The request head arrived and the future chain searching for a suitable upgrader is running.
        case awaitingUpgrader

        /// An upgrader was found; the stored closure starts the upgrade.
        case upgraderReady(() -> Void)

        /// The upgrade is under way.
        case upgrading

        /// The upgrade succeeded and the handler is being removed from the pipeline.
        case upgradeComplete

        /// No upgrade will happen and the handler is being removed from the pipeline.
        case upgradeFailed
    }
}