/src/grpc-swift/Sources/GRPC/WebCORSHandler.swift
Line | Count | Source (jump to first uncovered line) |
1 | | /* |
2 | | * Copyright 2019, gRPC Authors All rights reserved. |
3 | | * |
4 | | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | | * you may not use this file except in compliance with the License. |
6 | | * You may obtain a copy of the License at |
7 | | * |
8 | | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | | * |
10 | | * Unless required by applicable law or agreed to in writing, software |
11 | | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | | * See the License for the specific language governing permissions and |
14 | | * limitations under the License. |
15 | | */ |
16 | | import NIOCore |
17 | | import NIOHTTP1 |
18 | | |
19 | | /// Handler that manages the CORS protocol for requests incoming from the browser. |
20 | | internal final class WebCORSHandler { |
21 | | let configuration: Server.Configuration.CORS |
22 | | |
23 | 19.9k | private var state: State = .idle |
24 | | private enum State: Equatable { |
25 | | /// Starting state. |
26 | | case idle |
27 | | /// CORS preflight request is in progress. |
28 | | case processingPreflightRequest |
29 | | /// "Real" request is in progress. |
30 | | case processingRequest(origin: String?) |
31 | | } |
32 | | |
33 | 12.0k | init(configuration: Server.Configuration.CORS) { |
34 | 12.0k | self.configuration = configuration |
35 | 12.0k | } |
36 | | } |
37 | | |
38 | | extension WebCORSHandler: ChannelInboundHandler { |
39 | | typealias InboundIn = HTTPServerRequestPart |
40 | | typealias InboundOut = HTTPServerRequestPart |
41 | | typealias OutboundOut = HTTPServerResponsePart |
42 | | |
43 | 549k | func channelRead(context: ChannelHandlerContext, data: NIOAny) { |
44 | 549k | switch self.unwrapInboundIn(data) { |
45 | 549k | case let .head(head): |
46 | 233k | self.receivedRequestHead(context: context, head) |
47 | 549k | |
48 | 549k | case let .body(body): |
49 | 84.3k | self.receivedRequestBody(context: context, body) |
50 | 549k | |
51 | 549k | case let .end(trailers): |
52 | 230k | self.receivedRequestEnd(context: context, trailers) |
53 | 549k | } |
54 | 549k | } |
55 | | |
56 | 356k | private func receivedRequestHead(context: ChannelHandlerContext, _ head: HTTPRequestHead) { |
57 | 356k | if head.method == .OPTIONS, |
58 | 356k | head.headers.contains(.accessControlRequestMethod), |
59 | 356k | let origin = head.headers.first(name: "origin") |
60 | 356k | { |
61 | 0 | // If the request is OPTIONS with a access-control-request-method header it's a CORS |
62 | 0 | // preflight request and is not propagated further. |
63 | 0 | self.state = .processingPreflightRequest |
64 | 0 | self.handlePreflightRequest(context: context, head: head, origin: origin) |
65 | 356k | } else { |
66 | 356k | self.state = .processingRequest(origin: head.headers.first(name: "origin")) |
67 | 356k | context.fireChannelRead(self.wrapInboundOut(.head(head))) |
68 | 356k | } |
69 | 356k | } |
70 | | |
71 | 163k | private func receivedRequestBody(context: ChannelHandlerContext, _ body: ByteBuffer) { |
72 | 163k | // OPTIONS requests do not have a body, but still handle this case to be |
73 | 163k | // cautious. |
74 | 163k | if self.state == .processingPreflightRequest { |
75 | 0 | return |
76 | 163k | } |
77 | 163k | |
78 | 163k | context.fireChannelRead(self.wrapInboundOut(.body(body))) |
79 | 163k | } |
80 | | |
81 | 351k | private func receivedRequestEnd(context: ChannelHandlerContext, _ trailers: HTTPHeaders?) { |
82 | 351k | if self.state == .processingPreflightRequest { |
83 | 0 | // End of OPTIONS request; reset state and finish the response. |
84 | 0 | self.state = .idle |
85 | 0 | context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) |
86 | 351k | } else { |
87 | 351k | context.fireChannelRead(self.wrapInboundOut(.end(trailers))) |
88 | 351k | } |
89 | 351k | } |
90 | | |
91 | | private func handlePreflightRequest( |
92 | | context: ChannelHandlerContext, |
93 | | head: HTTPRequestHead, |
94 | | origin: String |
95 | 0 | ) { |
96 | 0 | let responseHead: HTTPResponseHead |
97 | 0 |
|
98 | 0 | if let allowedOrigin = self.configuration.allowedOrigins.header(origin) { |
99 | 0 | var headers = HTTPHeaders() |
100 | 0 | headers.reserveCapacity(4 + self.configuration.allowedHeaders.count) |
101 | 0 | headers.add(name: .accessControlAllowOrigin, value: allowedOrigin) |
102 | 0 | headers.add(name: .accessControlAllowMethods, value: "POST") |
103 | 0 |
|
104 | 0 | for value in self.configuration.allowedHeaders { |
105 | 0 | headers.add(name: .accessControlAllowHeaders, value: value) |
106 | 0 | } |
107 | 0 |
|
108 | 0 | if self.configuration.allowCredentialedRequests { |
109 | 0 | headers.add(name: .accessControlAllowCredentials, value: "true") |
110 | 0 | } |
111 | 0 |
|
112 | 0 | if self.configuration.preflightCacheExpiration > 0 { |
113 | 0 | headers.add( |
114 | 0 | name: .accessControlMaxAge, |
115 | 0 | value: "\(self.configuration.preflightCacheExpiration)" |
116 | 0 | ) |
117 | 0 | } |
118 | 0 | responseHead = HTTPResponseHead(version: head.version, status: .ok, headers: headers) |
119 | 0 | } else { |
120 | 0 | // Not allowed; respond with 403. This is okay in a pre-flight request. |
121 | 0 | responseHead = HTTPResponseHead(version: head.version, status: .forbidden) |
122 | 0 | } |
123 | 0 |
|
124 | 0 | context.write(self.wrapOutboundOut(.head(responseHead)), promise: nil) |
125 | 0 | } |
126 | | } |
127 | | |
128 | | extension WebCORSHandler: ChannelOutboundHandler { |
129 | | typealias OutboundIn = HTTPServerResponsePart |
130 | | |
131 | 629k | func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) { |
132 | 629k | let responsePart = self.unwrapOutboundIn(data) |
133 | 629k | switch responsePart { |
134 | 629k | case var .head(responseHead): |
135 | 233k | switch self.state { |
136 | 233k | case let .processingRequest(origin): |
137 | 233k | self.prepareCORSResponseHead(&responseHead, origin: origin) |
138 | 233k | context.write(self.wrapOutboundOut(.head(responseHead)), promise: promise) |
139 | 233k | |
140 | 233k | case .idle, .processingPreflightRequest: |
141 | 0 | assertionFailure("Writing response head when no request is in progress") |
142 | 0 | context.close(promise: nil) |
143 | 629k | } |
144 | 629k | |
145 | 629k | case .body: |
146 | 162k | context.write(data, promise: promise) |
147 | 629k | |
148 | 629k | case .end: |
149 | 233k | self.state = .idle |
150 | 233k | context.write(data, promise: promise) |
151 | 629k | } |
152 | 629k | } |
153 | | |
154 | 356k | private func prepareCORSResponseHead(_ head: inout HTTPResponseHead, origin: String?) { |
155 | 356k | guard let header = origin.flatMap({ self.configuration.allowedOrigins.header($0) }) else { |
156 | 353k | // No origin or the origin is not allowed; don't treat it as a CORS request. |
157 | 353k | return |
158 | 353k | } |
159 | 3.16k | |
160 | 3.16k | head.headers.replaceOrAdd(name: .accessControlAllowOrigin, value: header) |
161 | 3.16k | |
162 | 3.16k | if self.configuration.allowCredentialedRequests { |
163 | 0 | head.headers.add(name: .accessControlAllowCredentials, value: "true") |
164 | 3.16k | } |
165 | 3.16k | |
166 | 3.16k | //! FIXME: Check whether we can let browsers keep connections alive. It's not possible |
167 | 3.16k | // now as the channel has a state that can't be reused since the pipeline is modified to |
168 | 3.16k | // inject the gRPC call handler. |
169 | 3.16k | head.headers.replaceOrAdd(name: "Connection", value: "close") |
170 | 3.16k | } |
171 | | } |
172 | | |
173 | | extension HTTPHeaders { |
174 | | fileprivate enum CORSHeader: String { |
175 | | case accessControlRequestMethod = "access-control-request-method" |
176 | | case accessControlRequestHeaders = "access-control-request-headers" |
177 | | case accessControlAllowOrigin = "access-control-allow-origin" |
178 | | case accessControlAllowMethods = "access-control-allow-methods" |
179 | | case accessControlAllowHeaders = "access-control-allow-headers" |
180 | | case accessControlAllowCredentials = "access-control-allow-credentials" |
181 | | case accessControlMaxAge = "access-control-max-age" |
182 | | } |
183 | | |
184 | 1.90k | fileprivate func contains(_ name: CORSHeader) -> Bool { |
185 | 1.90k | return self.contains(name: name.rawValue) |
186 | 1.90k | } |
187 | | |
188 | 0 | fileprivate mutating func add(name: CORSHeader, value: String) { |
189 | 0 | self.add(name: name.rawValue, value: value) |
190 | 0 | } |
191 | | |
192 | 3.16k | fileprivate mutating func replaceOrAdd(name: CORSHeader, value: String) { |
193 | 3.16k | self.replaceOrAdd(name: name.rawValue, value: value) |
194 | 3.16k | } |
195 | | } |
196 | | |
197 | | extension Server.Configuration.CORS.AllowedOrigins { |
198 | 3.16k | internal func header(_ origin: String) -> String? { |
199 | 3.16k | switch self.wrapped { |
200 | 3.16k | case .all: |
201 | 3.16k | return "*" |
202 | 3.16k | case .originBased: |
203 | 0 | return origin |
204 | 3.16k | case let .only(allowed): |
205 | 0 | return allowed.contains(origin) ? origin : nil |
206 | 3.16k | case let .custom(custom): |
207 | 0 | return custom.check(origin: origin) |
208 | 3.16k | } |
209 | 3.16k | } |
210 | | } |