/src/swift-nio/Sources/NIOHTTP1/NIOHTTPClientUpgradeHandler.swift
Line | Count | Source |
1 | | //===----------------------------------------------------------------------===// |
2 | | // |
3 | | // This source file is part of the SwiftNIO open source project |
4 | | // |
5 | | // Copyright (c) 2019-2024 Apple Inc. and the SwiftNIO project authors |
6 | | // Licensed under Apache License v2.0 |
7 | | // |
8 | | // See LICENSE.txt for license information |
9 | | // See CONTRIBUTORS.txt for the list of SwiftNIO project authors |
10 | | // |
11 | | // SPDX-License-Identifier: Apache-2.0 |
12 | | // |
13 | | //===----------------------------------------------------------------------===// |
14 | | import NIOCore |
15 | | |
16 | | /// Errors that may be raised by the `HTTPClientProtocolUpgrader`. |
17 | | public struct NIOHTTPClientUpgradeError: Hashable, Error { |
18 | | |
19 | | // Uses the open enum style to allow additional errors to be added in future. |
20 | | private enum Code: Hashable { |
21 | | case responseProtocolNotFound |
22 | | case invalidHTTPOrdering |
23 | | case upgraderDeniedUpgrade |
24 | | case writingToHandlerDuringUpgrade |
25 | | case writingToHandlerAfterUpgradeCompleted |
26 | | case writingToHandlerAfterUpgradeFailed |
27 | | case receivedResponseBeforeRequestSent |
28 | | case receivedResponseAfterUpgradeCompleted |
29 | | } |
30 | | |
31 | | private var code: Code |
32 | | |
33 | 0 | private init(_ code: Code) { |
34 | 0 | self.code = code |
35 | 0 | } |
36 | | |
37 | | public static let responseProtocolNotFound = NIOHTTPClientUpgradeError(.responseProtocolNotFound) |
38 | | public static let invalidHTTPOrdering = NIOHTTPClientUpgradeError(.invalidHTTPOrdering) |
39 | | public static let upgraderDeniedUpgrade = NIOHTTPClientUpgradeError(.upgraderDeniedUpgrade) |
40 | | public static let writingToHandlerDuringUpgrade = NIOHTTPClientUpgradeError(.writingToHandlerDuringUpgrade) |
41 | | public static let writingToHandlerAfterUpgradeCompleted = NIOHTTPClientUpgradeError( |
42 | | .writingToHandlerAfterUpgradeCompleted |
43 | | ) |
44 | | public static let writingToHandlerAfterUpgradeFailed = NIOHTTPClientUpgradeError( |
45 | | .writingToHandlerAfterUpgradeFailed |
46 | | ) |
47 | | public static let receivedResponseBeforeRequestSent = NIOHTTPClientUpgradeError(.receivedResponseBeforeRequestSent) |
48 | | public static let receivedResponseAfterUpgradeCompleted = NIOHTTPClientUpgradeError( |
49 | | .receivedResponseAfterUpgradeCompleted |
50 | | ) |
51 | | } |
52 | | |
53 | | extension NIOHTTPClientUpgradeError: CustomStringConvertible { |
54 | 0 | public var description: String { |
55 | 0 | String(describing: self.code) |
56 | 0 | } |
57 | | } |
58 | | |
59 | | /// An object that implements `NIOHTTPClientProtocolUpgrader` knows how to handle HTTP upgrade to |
60 | | /// a protocol on a client-side channel. |
61 | | /// It has the option of denying this upgrade based upon the server response. |
62 | | public protocol NIOHTTPClientProtocolUpgrader { |
63 | | |
64 | | /// The protocol this upgrader knows how to support. |
65 | | var supportedProtocol: String { get } |
66 | | |
67 | | /// All the header fields the protocol requires in the request to successfully upgrade. |
68 | | /// These header fields will be added to the outbound request's "Connection" header field. |
69 | | /// It is the responsibility of the custom headers call to actually add these required headers. |
70 | | var requiredUpgradeHeaders: [String] { get } |
71 | | |
72 | | /// Additional headers to be added to the request, beyond the "Upgrade" and "Connection" headers. |
73 | | func addCustom(upgradeRequestHeaders: inout HTTPHeaders) |
74 | | |
75 | | /// Gives the receiving upgrader the chance to deny the upgrade based on the upgrade HTTP response. |
76 | | func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool |
77 | | |
78 | | /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel |
79 | | /// pipeline to add whatever channel handlers are required. |
80 | | /// Until the returned `EventLoopFuture` succeeds, all received data will be buffered. |
81 | | func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture<Void> |
82 | | } |
83 | | |
84 | | /// A client-side channel handler that sends a HTTP upgrade handshake request to perform a HTTP-upgrade. |
85 | | /// When the first HTTP request is sent, this handler will add all appropriate headers to perform an upgrade to |
86 | | /// the a protocol. It may add headers for a set of protocols in preference order. |
87 | | /// If the upgrade fails (i.e. response is not 101 Switching Protocols), this handler simply |
88 | | /// removes itself from the pipeline. If the upgrade is successful, it upgrades the pipeline to the new protocol. |
89 | | /// |
90 | | /// The request sends an order of preference to request which protocol it would like to use for the upgrade. |
91 | | /// It will only upgrade to the protocol that is returned first in the list and does not currently |
92 | | /// have the capability to upgrade to multiple simultaneous layered protocols. |
93 | | public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableChannelHandler { |
94 | | |
95 | | public typealias OutboundIn = HTTPClientRequestPart |
96 | | public typealias OutboundOut = HTTPClientRequestPart |
97 | | |
98 | | public typealias InboundIn = HTTPClientResponsePart |
99 | | public typealias InboundOut = HTTPClientResponsePart |
100 | | |
101 | | private var upgraders: [NIOHTTPClientProtocolUpgrader] |
102 | | private let httpHandlers: [RemovableChannelHandler] |
103 | | private let upgradeCompletionHandler: (ChannelHandlerContext) -> Void |
104 | | |
105 | | /// Whether we've already seen the first response from our initial upgrade request. |
106 | 0 | private var seenFirstResponse = false |
107 | | |
108 | 0 | private var upgradeState: UpgradeState = .requestRequired |
109 | | |
110 | 0 | private var receivedMessages: CircularBuffer<NIOAny> = CircularBuffer() |
111 | | |
112 | | /// Create a `HTTPClientUpgradeHandler`. |
113 | | /// |
114 | | /// - Parameter upgraders: All `HTTPClientProtocolUpgrader` objects that will add their upgrade request |
115 | | /// headers and handle the upgrade if there is a response for their protocol. They should be placed in |
116 | | /// order of the preference for the upgrade. |
117 | | /// - Parameter httpHandlers: All `RemovableChannelHandler` objects which will be removed from the pipeline |
118 | | /// once the upgrade response is sent. This is used to ensure that the pipeline will be in a clean state |
119 | | /// after the upgrade. It should include any handlers that are directly related to handling HTTP. |
120 | | /// At the very least this should include the `HTTPEncoder` and `HTTPDecoder`, but should also include |
121 | | /// any other handler that cannot tolerate receiving non-HTTP data. |
122 | | /// - Parameter upgradeCompletionHandler: A closure that will be fired when HTTP upgrade is complete. |
123 | | public convenience init( |
124 | | upgraders: [NIOHTTPClientProtocolUpgrader], |
125 | | httpHandlers: [RemovableChannelHandler], |
126 | | upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void |
127 | 0 | ) { |
128 | 0 | self.init(_upgraders: upgraders, httpHandlers: httpHandlers, upgradeCompletionHandler: upgradeCompletionHandler) |
129 | 0 | } |
130 | | |
131 | | private init( |
132 | | _upgraders upgraders: [NIOHTTPClientProtocolUpgrader], |
133 | | httpHandlers: [RemovableChannelHandler], |
134 | | upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void |
135 | 0 | ) { |
136 | 0 | precondition(upgraders.count > 0, "A minimum of one protocol upgrader must be specified.") |
137 | 0 |
|
138 | 0 | self.upgraders = upgraders |
139 | 0 | self.httpHandlers = httpHandlers |
140 | 0 | self.upgradeCompletionHandler = upgradeCompletionHandler |
141 | 0 | } |
142 | | |
143 | 0 | public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) { |
144 | 0 |
|
145 | 0 | switch self.upgradeState { |
146 | 0 |
|
147 | 0 | case .requestRequired: |
148 | 0 | let updatedData = self.addHeadersToOutboundOut(data: data) |
149 | 0 | context.write(updatedData, promise: promise) |
150 | 0 |
|
151 | 0 | case .awaitingConfirmationResponse: |
152 | 0 | // Still have full http stack. |
153 | 0 | context.write(data, promise: promise) |
154 | 0 |
|
155 | 0 | case .upgraderReady, .upgrading: |
156 | 0 | promise?.fail(NIOHTTPClientUpgradeError.writingToHandlerDuringUpgrade) |
157 | 0 | context.fireErrorCaught(NIOHTTPClientUpgradeError.writingToHandlerDuringUpgrade) |
158 | 0 |
|
159 | 0 | case .upgradingAddingHandlers: |
160 | 0 | // These are most likely messages immediately fired by a new protocol handler. |
161 | 0 | // As that is added last we can just forward them on. |
162 | 0 | context.write(data, promise: promise) |
163 | 0 |
|
164 | 0 | case .upgradeComplete: |
165 | 0 | // Upgrade complete and this handler should have been removed from the pipeline. |
166 | 0 | promise?.fail(NIOHTTPClientUpgradeError.writingToHandlerAfterUpgradeCompleted) |
167 | 0 | context.fireErrorCaught(NIOHTTPClientUpgradeError.writingToHandlerAfterUpgradeCompleted) |
168 | 0 |
|
169 | 0 | case .upgradeFailed: |
170 | 0 | // Upgrade failed and this handler should have been removed from the pipeline. |
171 | 0 | promise?.fail(NIOHTTPClientUpgradeError.writingToHandlerAfterUpgradeCompleted) |
172 | 0 | context.fireErrorCaught(NIOHTTPClientUpgradeError.writingToHandlerAfterUpgradeCompleted) |
173 | 0 | } |
174 | 0 | } |
175 | | |
176 | 0 | private func addHeadersToOutboundOut(data: NIOAny) -> NIOAny { |
177 | 0 |
|
178 | 0 | let interceptedOutgoingRequest = NIOHTTPClientUpgradeHandler.unwrapOutboundIn(data) |
179 | 0 |
|
180 | 0 | if case .head(var requestHead) = interceptedOutgoingRequest { |
181 | 0 |
|
182 | 0 | self.upgradeState = .awaitingConfirmationResponse |
183 | 0 |
|
184 | 0 | self.addConnectionHeaders(to: &requestHead) |
185 | 0 | self.addUpgradeHeaders(to: &requestHead) |
186 | 0 | return NIOHTTPClientUpgradeHandler.wrapOutboundOut(.head(requestHead)) |
187 | 0 | } |
188 | 0 |
|
189 | 0 | return data |
190 | 0 | } |
191 | | |
192 | 0 | private func addConnectionHeaders(to requestHead: inout HTTPRequestHead) { |
193 | 0 |
|
194 | 0 | let requiredHeaders = ["upgrade"] + self.upgraders.flatMap { $0.requiredUpgradeHeaders } |
195 | 0 | requestHead.headers.add(name: "Connection", value: requiredHeaders.joined(separator: ",")) |
196 | 0 | } |
197 | | |
198 | 0 | private func addUpgradeHeaders(to requestHead: inout HTTPRequestHead) { |
199 | 0 |
|
200 | 0 | let allProtocols = self.upgraders.map { $0.supportedProtocol.lowercased() } |
201 | 0 | requestHead.headers.add(name: "Upgrade", value: allProtocols.joined(separator: ",")) |
202 | 0 |
|
203 | 0 | // Allow each upgrader the chance to add custom headers. |
204 | 0 | for upgrader in self.upgraders { |
205 | 0 | upgrader.addCustom(upgradeRequestHeaders: &requestHead.headers) |
206 | 0 | } |
207 | 0 | } |
208 | | |
209 | 0 | public func channelRead(context: ChannelHandlerContext, data: NIOAny) { |
210 | 0 |
|
211 | 0 | guard !self.seenFirstResponse else { |
212 | 0 | // We're waiting for upgrade to complete: buffer this data. |
213 | 0 | self.receivedMessages.append(data) |
214 | 0 | return |
215 | 0 | } |
216 | 0 |
|
217 | 0 | let responsePart = NIOHTTPClientUpgradeHandler.unwrapInboundIn(data) |
218 | 0 |
|
219 | 0 | switch self.upgradeState { |
220 | 0 | case .awaitingConfirmationResponse: |
221 | 0 | self.firstResponseHeadReceived(context: context, responsePart: responsePart) |
222 | 0 | case .upgrading, .upgradingAddingHandlers: |
223 | 0 | if case .end = responsePart { |
224 | 0 | // This is the end of the first response. Swallow it, we're buffering the rest. |
225 | 0 | self.seenFirstResponse = true |
226 | 0 | } |
227 | 0 | case .upgraderReady(let upgrade): |
228 | 0 | if case .end = responsePart { |
229 | 0 | // This is the end of the first response, and we can upgrade. Time to kick it off. |
230 | 0 | self.seenFirstResponse = true |
231 | 0 | upgrade() |
232 | 0 | } |
233 | 0 | case .upgradeFailed: |
234 | 0 | // We were reentrantly called while delivering the response head. We can just pass this through. |
235 | 0 | context.fireChannelRead(data) |
236 | 0 | case .upgradeComplete: |
237 | 0 | //Upgrade has completed but we have not seen a whole response and still got reentrantly called. |
238 | 0 | context.fireErrorCaught(NIOHTTPClientUpgradeError.receivedResponseAfterUpgradeCompleted) |
239 | 0 | case .requestRequired: |
240 | 0 | //We are receiving an upgrade response and we have not requested the upgrade. |
241 | 0 | context.fireErrorCaught(NIOHTTPClientUpgradeError.receivedResponseBeforeRequestSent) |
242 | 0 | } |
243 | 0 | } |
244 | | |
245 | 0 | private func firstResponseHeadReceived(context: ChannelHandlerContext, responsePart: HTTPClientResponsePart) { |
246 | 0 |
|
247 | 0 | // We should decide if we're can upgrade based on the first response header: if we aren't upgrading, |
248 | 0 | // by the time the body comes in we should be out of the pipeline. That means that if we don't think we're |
249 | 0 | // upgrading, the only thing we should see is a response head. Anything else in an error. |
250 | 0 | guard case .head(let response) = responsePart else { |
251 | 0 | self.notUpgrading(context: context, data: responsePart, error: .invalidHTTPOrdering) |
252 | 0 | return |
253 | 0 | } |
254 | 0 |
|
255 | 0 | // Assess whether the upgrade response has accepted our upgrade request. |
256 | 0 | guard case .switchingProtocols = response.status else { |
257 | 0 | self.notUpgrading(context: context, data: responsePart, error: nil) |
258 | 0 | return |
259 | 0 | } |
260 | 0 |
|
261 | 0 | do { |
262 | 0 | let callback = try self.handleUpgrade(context: context, upgradeResponse: response) |
263 | 0 | self.gotUpgrader(upgrader: callback) |
264 | 0 | } catch { |
265 | 0 | let clientError = error as? NIOHTTPClientUpgradeError |
266 | 0 | self.notUpgrading(context: context, data: responsePart, error: clientError) |
267 | 0 | } |
268 | 0 | } |
269 | | |
270 | | private func handleUpgrade( |
271 | | context: ChannelHandlerContext, |
272 | | upgradeResponse response: HTTPResponseHead |
273 | 0 | ) throws -> (() -> Void) { |
274 | 0 |
|
275 | 0 | // Ok, we have a HTTP response. Check if it's an upgrade confirmation. |
276 | 0 | // If it's not, we want to pass it on and remove ourselves from the channel pipeline. |
277 | 0 | let acceptedProtocols = response.headers[canonicalForm: "upgrade"] |
278 | 0 |
|
279 | 0 | // At the moment we only upgrade to the first protocol returned from the server. |
280 | 0 | guard let protocolName = acceptedProtocols.first?.lowercased() else { |
281 | 0 | // There are no upgrade protocols returned. |
282 | 0 | throw NIOHTTPClientUpgradeError.responseProtocolNotFound |
283 | 0 | } |
284 | 0 |
|
285 | 0 | return try self.handleUpgradeForProtocol( |
286 | 0 | context: context, |
287 | 0 | protocolName: protocolName, |
288 | 0 | response: response |
289 | 0 | ) |
290 | 0 | } |
291 | | |
292 | | /// Attempt to upgrade a single protocol. |
293 | | private func handleUpgradeForProtocol( |
294 | | context: ChannelHandlerContext, |
295 | | protocolName: String, |
296 | | response: HTTPResponseHead |
297 | 0 | ) throws -> (() -> Void) { |
298 | 0 |
|
299 | 0 | let matchingUpgrader = self.upgraders |
300 | 0 | .first(where: { $0.supportedProtocol.lowercased() == protocolName }) |
301 | 0 |
|
302 | 0 | guard let upgrader = matchingUpgrader else { |
303 | 0 | // There is no upgrader for this protocol. |
304 | 0 | throw NIOHTTPClientUpgradeError.responseProtocolNotFound |
305 | 0 | } |
306 | 0 |
|
307 | 0 | guard upgrader.shouldAllowUpgrade(upgradeResponse: response) else { |
308 | 0 | // The upgrader says no. |
309 | 0 | throw NIOHTTPClientUpgradeError.upgraderDeniedUpgrade |
310 | 0 | } |
311 | 0 |
|
312 | 0 | return self.performUpgrade(context: context, upgrader: upgrader, response: response) |
313 | 0 | } |
314 | | |
315 | | private func performUpgrade( |
316 | | context: ChannelHandlerContext, |
317 | | upgrader: NIOHTTPClientProtocolUpgrader, |
318 | | response: HTTPResponseHead |
319 | 0 | ) -> () -> Void { |
320 | 0 |
|
321 | 0 | // Before we start the upgrade we have to remove the HTTPEncoder and HTTPDecoder handlers from the |
322 | 0 | // pipeline, to prevent them parsing any more data. We'll buffer the incoming data until that completes. |
323 | 0 | // While there are a lot of Futures involved here it's quite possible that all of this code will |
324 | 0 | // actually complete synchronously: we just want to program for the possibility that it won't. |
325 | 0 | // Once that's done, we call the internal handler, then call the upgrader code, and then finally when the |
326 | 0 | // upgrader code is done, we do our final cleanup steps, namely we replay the received data we |
327 | 0 | // buffered in the meantime and then remove ourselves from the pipeline. |
328 | 0 | let pipeline = context.pipeline |
329 | 0 | return { |
330 | 0 | self.upgradeState = .upgrading |
331 | 0 |
|
332 | 0 | self.removeHTTPHandlers(pipeline: pipeline) |
333 | 0 | .map { |
334 | 0 | // Let the other handlers be removed before continuing with upgrade. |
335 | 0 | self.upgradeCompletionHandler(context) |
336 | 0 | self.upgradeState = .upgradingAddingHandlers |
337 | 0 | } |
338 | 0 | .flatMap { |
339 | 0 | upgrader.upgrade(context: context, upgradeResponse: response) |
340 | 0 | } |
341 | 0 | .map { |
342 | 0 | // We unbuffer any buffered data here. |
343 | 0 |
|
344 | 0 | // If we received any, we fire readComplete. |
345 | 0 | let fireReadComplete = self.receivedMessages.count > 0 |
346 | 0 | while self.receivedMessages.count > 0 { |
347 | 0 | let bufferedPart = self.receivedMessages.removeFirst() |
348 | 0 | context.fireChannelRead(bufferedPart) |
349 | 0 | } |
350 | 0 | if fireReadComplete { |
351 | 0 | context.fireChannelReadComplete() |
352 | 0 | } |
353 | 0 |
|
354 | 0 | // We wait with the state change until _after_ the channel reads here. |
355 | 0 | // This is to prevent firing writes in response to these reads after we went to .upgradeComplete |
356 | 0 | // See: https://github.com/apple/swift-nio/issues/1279 |
357 | 0 | self.upgradeState = .upgradeComplete |
358 | 0 | } |
359 | 0 | .whenComplete { _ in |
360 | 0 | context.pipeline.syncOperations.removeHandler(context: context, promise: nil) |
361 | 0 | } |
362 | 0 | } |
363 | 0 | } |
364 | | |
365 | | /// Removes any extra HTTP-related handlers from the channel pipeline. |
366 | 0 | private func removeHTTPHandlers(pipeline: ChannelPipeline) -> EventLoopFuture<Void>.Isolated { |
367 | 0 | guard self.httpHandlers.count > 0 else { |
368 | 0 | return pipeline.eventLoop.makeSucceededIsolatedFuture(()) |
369 | 0 | } |
370 | 0 |
|
371 | 0 | let removeFutures = self.httpHandlers.map { pipeline.syncOperations.removeHandler($0) } |
372 | 0 | return EventLoopFuture.andAllSucceed(removeFutures, on: pipeline.eventLoop).assumeIsolated() |
373 | 0 | } |
374 | | |
375 | 0 | private func gotUpgrader(upgrader: @escaping (() -> Void)) { |
376 | 0 |
|
377 | 0 | self.upgradeState = .upgraderReady(upgrader) |
378 | 0 | if self.seenFirstResponse { |
379 | 0 | // Ok, we're good to go, we can upgrade. Otherwise we're waiting for .end, which |
380 | 0 | // will trigger the upgrade. |
381 | 0 | upgrader() |
382 | 0 | } |
383 | 0 | } |
384 | | |
385 | | private func notUpgrading( |
386 | | context: ChannelHandlerContext, |
387 | | data: HTTPClientResponsePart, |
388 | | error: NIOHTTPClientUpgradeError? |
389 | 0 | ) { |
390 | 0 |
|
391 | 0 | self.upgradeState = .upgradeFailed |
392 | 0 |
|
393 | 0 | if let error = error { |
394 | 0 | context.fireErrorCaught(error) |
395 | 0 | } |
396 | 0 |
|
397 | 0 | assert(self.receivedMessages.isEmpty) |
398 | 0 | context.fireChannelRead(NIOHTTPClientUpgradeHandler.wrapInboundOut(data)) |
399 | 0 |
|
400 | 0 | // We've delivered the data. We can now remove ourselves, which should happen synchronously. |
401 | 0 | context.pipeline.syncOperations.removeHandler(context: context, promise: nil) |
402 | 0 | } |
403 | | } |
404 | | |
405 | | extension NIOHTTPClientUpgradeHandler: @unchecked Sendable {} |
406 | | |
407 | | extension NIOHTTPClientUpgradeHandler { |
408 | | /// The state of the upgrade handler. |
409 | | fileprivate enum UpgradeState { |
410 | | /// Request not sent. This will need to be sent to initiate the upgrade. |
411 | | case requestRequired |
412 | | |
413 | | /// Awaiting confirmation response which will allow the upgrade to zero one or more protocols. |
414 | | case awaitingConfirmationResponse |
415 | | |
416 | | /// The response head has been received. We have an upgrader, which means we can begin upgrade. |
417 | | case upgraderReady(() -> Void) |
418 | | |
419 | | /// The response head has been received. The upgrade is in process. |
420 | | case upgrading |
421 | | |
422 | | /// The upgrade is in process and all of the http handlers have been removed. |
423 | | case upgradingAddingHandlers |
424 | | |
425 | | /// The upgrade has succeeded, and we are being removed from the pipeline. |
426 | | case upgradeComplete |
427 | | |
428 | | /// The upgrade has failed. |
429 | | case upgradeFailed |
430 | | } |
431 | | } |