Coverage Report

Created: 2026-06-01 06:32

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/swift-nio/Sources/NIOHTTP1/HTTPServerUpgradeHandler.swift
Line
Count
Source
1
//===----------------------------------------------------------------------===//
2
//
3
// This source file is part of the SwiftNIO open source project
4
//
5
// Copyright (c) 2017-2024 Apple Inc. and the SwiftNIO project authors
6
// Licensed under Apache License v2.0
7
//
8
// See LICENSE.txt for license information
9
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
10
//
11
// SPDX-License-Identifier: Apache-2.0
12
//
13
//===----------------------------------------------------------------------===//
14
import NIOCore
15
16
/// Errors that may be fired down the channel pipeline by `HTTPServerUpgradeHandler`.
public enum HTTPServerUpgradeErrors: Error {
    /// HTTP request parts arrived in an unexpected order, e.g. the first part seen on the connection was not a
    /// request `.head`.
    case invalidHTTPOrdering
}
20
21
/// User inbound events that may be fired by `HTTPServerUpgradeHandler`.
public enum HTTPServerUpgradeEvents: Sendable {
    /// Fired once the upgrader's `upgrade(context:upgradeRequest:)` future has succeeded, just before the
    /// `HTTPServerUpgradeHandler` removes itself from the `ChannelPipeline`.
    ///
    /// - Parameters:
    ///   - toProtocol: The protocol that was upgraded to, as it was spelled in the request's `upgrade` header.
    ///   - upgradeRequest: The head of the request that triggered the upgrade.
    case upgradeComplete(toProtocol: String, upgradeRequest: HTTPRequestHead)
}
28
29
/// A type that knows how to perform an HTTP upgrade from HTTP/1.1 to one particular protocol on a
/// server-side channel.
public protocol HTTPServerProtocolUpgrader {
    /// The name of the protocol this upgrader handles, as clients request it in the `upgrade` header.
    var supportedProtocol: String { get }

    /// The names of the header fields a request must carry for this upgrader to be used.
    ///
    /// Every listed field must be present on the inbound request and must also be named in the request's
    /// `connection` header field. The full request head, including these fields, is handed to the upgrader.
    var requiredUpgradeHeaders: [String] { get }

    /// Produces the headers to send to the client in the `101 Switching Protocols` response.
    ///
    /// Fail the returned future if the upgrade cannot go ahead for any reason.
    func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders)
        -> EventLoopFuture<HTTPHeaders>

    /// Invoked after the upgrade response has been flushed.
    ///
    /// From this point it is safe to modify the channel pipeline and install whatever handlers the new protocol
    /// needs. Inbound data keeps being buffered until the returned future succeeds.
    func upgrade(
        context: ChannelHandlerContext,
        upgradeRequest: HTTPRequestHead
    ) -> EventLoopFuture<Void>
}
54
55
/// A server-side channel handler that receives HTTP requests and optionally performs a HTTP-upgrade.
///
/// Only the first inbound request on a connection is considered for upgrade. If that request does not ask for an
/// upgrade, or no configured upgrader accepts it, the request is passed through and the handler removes itself
/// from the channel pipeline. If the upgrade completes successfully the handler also removes itself, replaying any
/// inbound data it buffered in the meantime. If a step fails *after* the upgrade has been committed to (removing
/// the HTTP handlers, sending the 101 response, or the upgrader's own `upgrade(context:upgradeRequest:)`), the
/// error is fired down the pipeline and the handler stays in place.
///
/// This handler behaves a bit differently from its Netty counterpart because it does not allow upgrade
/// on any request but the first on a connection. This is primarily to handle clients that pipeline: it's
/// sufficiently difficult to ensure that the upgrade happens at a safe time while dealing with pipelined
/// requests that we choose to punt on it entirely and not allow it. As it happens this is mostly fine:
/// the odds of someone needing to upgrade midway through the lifetime of a connection are very low.
public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableChannelHandler {
    public typealias InboundIn = HTTPServerRequestPart
    public typealias InboundOut = HTTPServerRequestPart
    public typealias OutboundOut = HTTPServerResponsePart

    /// The registered upgraders, keyed by their lowercased `supportedProtocol`.
    private let upgraders: [String: HTTPServerProtocolUpgrader]
    private let upgradeCompletionHandler: (ChannelHandlerContext) -> Void

    private let httpEncoder: HTTPResponseEncoder
    private let extraHTTPHandlers: [RemovableChannelHandler]

    /// Whether we've already seen the `.end` of the first request.
    private var seenFirstRequest = false

    /// Where we are in the upgrade state machine. In `.upgraderReady` this carries the closure that starts the
    /// upgrade once the first request's `.end` has been received.
    private var upgradeState: UpgradeState = .idle

    /// Inbound messages received after the first request's `.end`. They are replayed when this handler is removed
    /// from the pipeline.
    private var receivedMessages: CircularBuffer<NIOAny> = CircularBuffer()

    /// Create a `HTTPServerUpgradeHandler`.
    ///
    /// - Parameter upgraders: All `HTTPServerProtocolUpgrader` objects that this pipeline will be able
    ///     to use to handle HTTP upgrade. Protocols are matched case-insensitively; if two upgraders report the
    ///     same protocol, the later one in the array wins.
    /// - Parameter httpEncoder: The `HTTPResponseEncoder` encoding responses from this handler and which will
    ///     be removed from the pipeline once the upgrade response is sent. This is used to ensure
    ///     that the pipeline will be in a clean state after upgrade.
    /// - Parameter extraHTTPHandlers: Any other handlers that are directly related to handling HTTP. At the very least
    ///     this should include the `HTTPDecoder`, but should also include any other handler that cannot tolerate
    ///     receiving non-HTTP data.
    /// - Parameter upgradeCompletionHandler: A block that will be fired when HTTP upgrade is complete.
    public init(
        upgraders: [HTTPServerProtocolUpgrader],
        httpEncoder: HTTPResponseEncoder,
        extraHTTPHandlers: [RemovableChannelHandler],
        upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void
    ) {
        var upgraderMap = [String: HTTPServerProtocolUpgrader]()
        for upgrader in upgraders {
            upgraderMap[upgrader.supportedProtocol.lowercased()] = upgrader
        }
        self.upgraders = upgraderMap
        self.upgradeCompletionHandler = upgradeCompletionHandler
        self.httpEncoder = httpEncoder
        self.extraHTTPHandlers = extraHTTPHandlers
    }

    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        guard !self.seenFirstRequest else {
            // We're waiting for upgrade to complete: buffer this data.
            self.receivedMessages.append(data)
            return
        }

        let requestPart = HTTPServerUpgradeHandler.unwrapInboundIn(data)

        switch self.upgradeState {
        case .idle:
            self.firstRequestHeadReceived(context: context, requestPart: requestPart)
        case .awaitingUpgrader:
            // NOTE(review): `.body` parts of the upgrade request are not forwarded in this state. If the upgrade is
            // later abandoned, `notUpgrading` only delivers the head (plus a trailer-less `.end` if we got that
            // far), so any request body is dropped. Confirm this is intended.
            if case .end = requestPart {
                // This is the end of the first request. Swallow it, we're buffering the rest.
                self.seenFirstRequest = true
            }
        case .upgraderReady(let upgrade):
            if case .end = requestPart {
                // This is the end of the first request, and we can upgrade. Time to kick it off.
                self.seenFirstRequest = true
                upgrade()
            }
        case .upgradeFailed:
            // We were re-entrantly called while delivering the request head. We can just pass this through.
            context.fireChannelRead(data)
        case .upgradeComplete:
            preconditionFailure(
                "Upgrade has completed but we have not seen a whole request and still got re-entrantly called."
            )
        case .upgrading:
            preconditionFailure(
                "We think we saw .end before and began upgrading, but somehow we have not set seenFirstRequest"
            )
        }
    }

    /// Replays every buffered inbound message, then leaves the pipeline.
    public func removeHandler(context: ChannelHandlerContext, removalToken: ChannelHandlerContext.RemovalToken) {
        // We have been formally removed from the pipeline. We should send any buffered data we have.
        // Note that we loop twice. This is because we want to guard against being reentrantly called from
        // fireChannelReadComplete, which may append more messages to the buffer.
        while !self.receivedMessages.isEmpty {
            while !self.receivedMessages.isEmpty {
                let bufferedPart = self.receivedMessages.removeFirst()
                context.fireChannelRead(bufferedPart)
            }

            context.fireChannelReadComplete()
        }

        context.leavePipeline(removalToken: removalToken)
    }

    /// Handles the very first request part seen on the connection and decides whether to attempt an upgrade.
    private func firstRequestHeadReceived(context: ChannelHandlerContext, requestPart: HTTPServerRequestPart) {
        // We should decide if we're going to upgrade based on the first request header: if we aren't upgrading,
        // by the time the body comes in we should be out of the pipeline. That means that if we don't think we're
        // upgrading, the only thing we should see is a request head. Anything else is an error.
        guard case .head(let request) = requestPart else {
            context.fireErrorCaught(HTTPServerUpgradeErrors.invalidHTTPOrdering)
            self.notUpgrading(context: context, data: requestPart)
            return
        }

        // Ok, we have a HTTP request. Check if it's an upgrade. If it's not, we want to pass it on and remove ourselves
        // from the channel pipeline.
        let requestedProtocols = request.headers[canonicalForm: "upgrade"].map(String.init)
        guard !requestedProtocols.isEmpty else {
            self.notUpgrading(context: context, data: requestPart)
            return
        }

        // Cool, this is an upgrade request.
        // We'll attempt to upgrade. This may take a while, so while we're waiting more data can come in.
        self.upgradeState = .awaitingUpgrader

        self.handleUpgrade(context: context, request: request, requestedProtocols: requestedProtocols)
            .whenSuccess { callback in
                if let callback {
                    self.gotUpgrader(upgrader: callback)
                } else {
                    self.notUpgrading(context: context, data: requestPart)
                }
            }
    }

    /// The core of the upgrade handling logic.
    ///
    /// - Returns: An isolated `EventLoopFuture` that will contain a callback to invoke if upgrade is requested,
    /// or nil if upgrade has failed. Never returns a failed future.
    private func handleUpgrade(
        context: ChannelHandlerContext,
        request: HTTPRequestHead,
        requestedProtocols: [String]
    ) -> EventLoopFuture<(() -> Void)?>.Isolated {
        // Header names and Connection tokens are compared case-insensitively.
        let connectionHeader = Set(request.headers[canonicalForm: "connection"].map { $0.lowercased() })
        let allHeaderNames = Set(request.headers.map { $0.name.lowercased() })

        // We now set off a chain of Futures to try to find a protocol upgrade. While this is blocking, we need to buffer inbound data.
        let protocolIterator = requestedProtocols.makeIterator()
        return self.handleUpgradeForProtocol(
            context: context,
            protocolIterator: protocolIterator,
            request: request,
            allHeaderNames: allHeaderNames,
            connectionHeader: connectionHeader
        )
    }

    /// Attempt to upgrade a single protocol.
    ///
    /// Will recurse through `protocolIterator` if upgrade fails.
    ///
    /// - Returns: An isolated `EventLoopFuture` that will contain a callback to invoke if upgrade is requested,
    /// or nil if upgrade has failed. Never returns a failed future.
    private func handleUpgradeForProtocol(
        context: ChannelHandlerContext,
        protocolIterator: Array<String>.Iterator,
        request: HTTPRequestHead,
        allHeaderNames: Set<String>,
        connectionHeader: Set<String>
    ) -> EventLoopFuture<(() -> Void)?>.Isolated {
        // Take a local copy of the iterator so we can advance it; the advanced copy is used for the next attempt.
        var remainingProtocols = protocolIterator
        guard let proto = remainingProtocols.next() else {
            // We're done! No suitable protocol for upgrade.
            return context.eventLoop.makeSucceededIsolatedFuture(nil)
        }
        let nextProtocols = remainingProtocols

        // Moves on to the next requested protocol (if any) with the same request context.
        func tryNextProtocol() -> EventLoopFuture<(() -> Void)?>.Isolated {
            self.handleUpgradeForProtocol(
                context: context,
                protocolIterator: nextProtocols,
                request: request,
                allHeaderNames: allHeaderNames,
                connectionHeader: connectionHeader
            )
        }

        guard let upgrader = self.upgraders[proto.lowercased()] else {
            // No upgrader is registered for this protocol.
            return tryNextProtocol()
        }

        // Every header the upgrader requires must be present on the request AND be named in its Connection header.
        let requiredHeaders = Set(upgrader.requiredUpgradeHeaders.map { $0.lowercased() })
        guard requiredHeaders.isSubset(of: allHeaderNames) && requiredHeaders.isSubset(of: connectionHeader) else {
            return tryNextProtocol()
        }

        let responseHeaders = self.buildUpgradeHeaders(protocol: proto)
        let pipeline = context.pipeline

        return upgrader.buildUpgradeResponse(
            channel: context.channel,
            upgradeRequest: request,
            initialResponseHeaders: responseHeaders
        ).hop(to: context.eventLoop)
            .assumeIsolated()
            .map { finalResponseHeaders in
                {
                    // Ok, we're upgrading.
                    self.upgradeState = .upgrading

                    // Before we finish the upgrade we have to remove the HTTPDecoder and any other non-Encoder HTTP
                    // handlers from the pipeline, to prevent them parsing any more data. We'll buffer the data until
                    // that completes.
                    // While there are a lot of Futures involved here it's quite possible that all of this code will
                    // actually complete synchronously: we just want to program for the possibility that it won't.
                    // Once that's done, we send the upgrade response, then remove the HTTP encoder, then call the
                    // internal handler, then call the user code, and then finally when the user code is done we do
                    // our final cleanup steps, namely we replay the received data we buffered in the meantime and
                    // then remove ourselves from the pipeline.
                    self.removeExtraHandlers(pipeline: pipeline)
                        .assumeIsolated()
                        .flatMap {
                            self.sendUpgradeResponse(
                                context: context,
                                upgradeRequest: request,
                                responseHeaders: finalResponseHeaders
                            )
                        }.flatMap {
                            pipeline.syncOperations.removeHandler(self.httpEncoder)
                        }.flatMap { () -> EventLoopFuture<Void> in
                            self.upgradeCompletionHandler(context)
                            return upgrader.upgrade(context: context, upgradeRequest: request)
                        }.whenComplete { result in
                            switch result {
                            case .success:
                                context.fireUserInboundEventTriggered(
                                    HTTPServerUpgradeEvents.upgradeComplete(toProtocol: proto, upgradeRequest: request)
                                )
                                self.upgradeState = .upgradeComplete
                                // When we remove ourselves we'll be delivering any buffered data.
                                context.pipeline.syncOperations.removeHandler(context: context, promise: nil)

                            case .failure(let error):
                                // Remain in the '.upgrading' state.
                                context.fireErrorCaught(error)
                            }
                        }
                }
            }.flatMapError { error in
                // No upgrade here. We want to fire the error down the pipeline, and then try another loop iteration.
                context.fireErrorCaught(error)
                return tryNextProtocol()
            }
    }

    /// Records that an upgrader was found, and starts the upgrade immediately if the request's `.end` has already
    /// been received.
    private func gotUpgrader(upgrader: @escaping (() -> Void)) {
        switch self.upgradeState {
        case .awaitingUpgrader:
            self.upgradeState = .upgraderReady(upgrader)
            if self.seenFirstRequest {
                // Ok, we're good to go, we can upgrade. Otherwise we're waiting for .end, which
                // will trigger the upgrade.
                upgrader()
            }
        case .idle, .upgradeComplete, .upgraderReady, .upgradeFailed, .upgrading:
            preconditionFailure("Unexpected upgrader state: \(self.upgradeState)")
        }
    }

    /// Sends the 101 Switching Protocols response for the pipeline.
    private func sendUpgradeResponse(
        context: ChannelHandlerContext,
        upgradeRequest: HTTPRequestHead,
        responseHeaders: HTTPHeaders
    ) -> EventLoopFuture<Void> {
        var response = HTTPResponseHead(version: .http1_1, status: .switchingProtocols)
        response.headers = responseHeaders
        return context.writeAndFlush(self.wrapOutboundOut(HTTPServerResponsePart.head(response)))
    }

    /// Called when we know we're not upgrading. Passes the data on and then removes this object from the pipeline.
    private func notUpgrading(context: ChannelHandlerContext, data: HTTPServerRequestPart) {
        self.upgradeState = .upgradeFailed

        if !self.seenFirstRequest {
            // We haven't seen the first request .end. That means we're not buffering anything, and we can
            // just deliver data.
            assert(self.receivedMessages.isEmpty)
            context.fireChannelRead(HTTPServerUpgradeHandler.wrapInboundOut(data))
        } else {
            // This is trickier. We've seen the first request .end, so we now need to deliver the .head we
            // got passed, as well as the .end we swallowed. Any buffered parts are replayed by `removeHandler`
            // below. While we're doing this we may be re-entrantly called, which will cause us to buffer new
            // parts. To make that safe, we must ensure we aren't holding the buffer mutably, so no for loop for us.
            context.fireChannelRead(HTTPServerUpgradeHandler.wrapInboundOut(data))
            context.fireChannelRead(HTTPServerUpgradeHandler.wrapInboundOut(.end(nil)))
        }

        context.fireChannelReadComplete()

        // Ok, we've delivered all the parts. We can now remove ourselves, which should happen synchronously.
        context.pipeline.syncOperations.removeHandler(context: context, promise: nil)
    }

    /// Builds the initial mandatory HTTP headers for HTTP upgrade responses.
    private func buildUpgradeHeaders(`protocol`: String) -> HTTPHeaders {
        HTTPHeaders([("connection", "upgrade"), ("upgrade", `protocol`)])
    }

    /// Removes any extra HTTP-related handlers from the channel pipeline.
    ///
    /// - Returns: A future that succeeds once every extra handler has been removed (immediately if there are none).
    private func removeExtraHandlers(pipeline: ChannelPipeline) -> EventLoopFuture<Void> {
        guard !self.extraHTTPHandlers.isEmpty else {
            return pipeline.eventLoop.makeSucceededFuture(())
        }

        return .andAllSucceed(
            self.extraHTTPHandlers.map { pipeline.syncOperations.removeHandler($0) },
            on: pipeline.eventLoop
        )
    }
}
389
390
// NOTE(review): `@unchecked` because the handler holds unsynchronized mutable state (`seenFirstRequest`,
// `upgradeState`, `receivedMessages`). This presumably relies on that state only ever being touched on the
// channel's event loop (the upgrade future chain hops there via `hop(to:)` / `assumeIsolated()`); confirm before
// using an instance from any other thread.
extension HTTPServerUpgradeHandler: @unchecked Sendable {}
391
392
extension HTTPServerUpgradeHandler {
    /// The state of the upgrade handler.
    private enum UpgradeState {
        /// Nothing from the first request has been processed yet.
        case idle

        /// The request head has been received. We're currently running the future chain awaiting an upgrader.
        case awaitingUpgrader

        /// An upgrader accepted the request. The associated closure begins the upgrade; it is invoked once the
        /// first request's `.end` has been received.
        case upgraderReady(() -> Void)

        /// The upgrade is in progress. If a step of the upgrade fails, the handler remains in this state.
        case upgrading

        /// No upgrade will happen (the request did not ask for one, it was malformed, or no upgrader accepted it),
        /// and we are being removed from the pipeline.
        case upgradeFailed

        /// The upgrade has succeeded, and we are being removed from the pipeline.
        case upgradeComplete
    }
}