Coverage Report

Created: 2025-06-24 06:59

/src/grpc-swift/Sources/GRPC/WebCORSHandler.swift
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Copyright 2019, gRPC Authors All rights reserved.
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *     http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16
import NIOCore
17
import NIOHTTP1
18
19
/// Handler that manages the CORS protocol for requests incoming from the browser.
///
/// CORS preflight requests (`OPTIONS` with an "access-control-request-method" and an "origin"
/// header) are answered by this handler itself. All other requests are forwarded down the
/// pipeline, and CORS headers are added to their response head on the way out.
internal final class WebCORSHandler {
  /// The phases a single request passes through while this handler observes it.
  private enum State: Equatable {
    /// No request is currently in progress; this is the initial state.
    case idle
    /// A CORS preflight request is being answered by this handler (not forwarded).
    case processingPreflightRequest
    /// A regular ("real") request is in progress; `origin` is its "origin" header, if any.
    case processingRequest(origin: String?)
  }

  /// The CORS policy this handler enforces.
  let configuration: Server.Configuration.CORS

  /// The phase of the request currently being handled.
  private var state: State = .idle

  /// Creates a handler that applies `configuration` to every request on the channel.
  init(configuration: Server.Configuration.CORS) {
    self.configuration = configuration
  }
}
37
38
extension WebCORSHandler: ChannelInboundHandler {
  typealias InboundIn = HTTPServerRequestPart
  typealias InboundOut = HTTPServerRequestPart
  typealias OutboundOut = HTTPServerResponsePart

  /// Dispatches each inbound request part to the handler for its kind.
  func channelRead(context: ChannelHandlerContext, data: NIOAny) {
    let part = self.unwrapInboundIn(data)
    switch part {
    case .head(let requestHead):
      self.receivedRequestHead(context: context, requestHead)
    case .body(let buffer):
      self.receivedRequestBody(context: context, buffer)
    case .end(let trailers):
      self.receivedRequestEnd(context: context, trailers)
    }
  }

  /// Classifies a request head as either a CORS preflight or a regular request.
  ///
  /// A preflight is an `OPTIONS` request carrying both an "access-control-request-method"
  /// header and an "origin" header. It is answered here and not propagated further. Any other
  /// head is forwarded unchanged after its origin (if present) is remembered, so the response
  /// head can later be decorated with CORS headers.
  private func receivedRequestHead(context: ChannelHandlerContext, _ head: HTTPRequestHead) {
    let requestOrigin = head.headers.first(name: "origin")
    let looksLikePreflight =
      head.method == .OPTIONS && head.headers.contains(.accessControlRequestMethod)

    guard looksLikePreflight, let preflightOrigin = requestOrigin else {
      self.state = .processingRequest(origin: requestOrigin)
      context.fireChannelRead(self.wrapInboundOut(.head(head)))
      return
    }

    self.state = .processingPreflightRequest
    self.handlePreflightRequest(context: context, head: head, origin: preflightOrigin)
  }

  /// Forwards request body bytes, unless a preflight request is being handled.
  private func receivedRequestBody(context: ChannelHandlerContext, _ body: ByteBuffer) {
    // OPTIONS requests are not expected to carry a body; any that arrives during a preflight
    // is dropped rather than forwarded.
    guard self.state != .processingPreflightRequest else {
      return
    }
    context.fireChannelRead(self.wrapInboundOut(.body(body)))
  }

  /// Forwards the end of a regular request, or completes the preflight response.
  private func receivedRequestEnd(context: ChannelHandlerContext, _ trailers: HTTPHeaders?) {
    guard self.state == .processingPreflightRequest else {
      context.fireChannelRead(self.wrapInboundOut(.end(trailers)))
      return
    }

    // The preflight request is over: go back to idle and terminate the response whose head
    // was written by `handlePreflightRequest`.
    self.state = .idle
    context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
  }

  /// Writes the response head for a preflight request coming from `origin`.
  ///
  /// An allowed origin gets a `200 OK` carrying the configured CORS headers. A disallowed
  /// origin gets a bare `403 Forbidden`. Only the head is written here; the response end is
  /// written when the request's end part arrives.
  private func handlePreflightRequest(
    context: ChannelHandlerContext,
    head: HTTPRequestHead,
    origin: String
  ) {
    guard let allowedOrigin = self.configuration.allowedOrigins.header(origin) else {
      // Rejecting with 403 is acceptable for a preflight request.
      let rejection = HTTPResponseHead(version: head.version, status: .forbidden)
      context.write(self.wrapOutboundOut(.head(rejection)), promise: nil)
      return
    }

    let acceptance = HTTPResponseHead(
      version: head.version,
      status: .ok,
      headers: self.makePreflightResponseHeaders(allowedOrigin: allowedOrigin)
    )
    context.write(self.wrapOutboundOut(.head(acceptance)), promise: nil)
  }

  /// Builds the CORS headers sent in a successful preflight response.
  private func makePreflightResponseHeaders(allowedOrigin: String) -> HTTPHeaders {
    let config = self.configuration
    var headers = HTTPHeaders()
    // Room for origin, methods, credentials and max-age, plus one entry per allowed header.
    headers.reserveCapacity(4 + config.allowedHeaders.count)

    headers.add(name: .accessControlAllowOrigin, value: allowedOrigin)
    headers.add(name: .accessControlAllowMethods, value: "POST")
    for allowedHeader in config.allowedHeaders {
      headers.add(name: .accessControlAllowHeaders, value: allowedHeader)
    }

    if config.allowCredentialedRequests {
      headers.add(name: .accessControlAllowCredentials, value: "true")
    }

    // A non-positive expiration means no max-age header is sent.
    if config.preflightCacheExpiration > 0 {
      headers.add(
        name: .accessControlMaxAge,
        value: "\(config.preflightCacheExpiration)"
      )
    }

    return headers
  }
}
127
128
extension WebCORSHandler: ChannelOutboundHandler {
  typealias OutboundIn = HTTPServerResponsePart

  /// Intercepts outbound response parts to apply CORS headers and track request state.
  ///
  /// - A response head written while a regular request is in progress gets CORS headers
  ///   derived from that request's origin.
  /// - A response head written while no regular request is in progress is a programmer
  ///   error. The head is dropped, the write's promise is failed, and the channel is closed.
  /// - Body parts pass through unchanged.
  /// - The response end resets the handler to `.idle` before being forwarded.
  func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
    let responsePart = self.unwrapOutboundIn(data)
    switch responsePart {
    case var .head(responseHead):
      switch self.state {
      case let .processingRequest(origin):
        self.prepareCORSResponseHead(&responseHead, origin: origin)
        context.write(self.wrapOutboundOut(.head(responseHead)), promise: promise)

      case .idle, .processingPreflightRequest:
        assertionFailure("Writing response head when no request is in progress")
        // The head is never written, so complete the caller's promise instead of leaking
        // it; otherwise anything waiting on this write would never be notified (notably
        // in release builds, where `assertionFailure` does not trap).
        promise?.fail(ChannelError.inappropriateOperationForState)
        context.close(promise: nil)
      }

    case .body:
      context.write(data, promise: promise)

    case .end:
      self.state = .idle
      context.write(data, promise: promise)
    }
  }

  /// Adds CORS response headers to `head` for a request from `origin`.
  ///
  /// If `origin` is `nil` or is not allowed by the configuration, `head` is left untouched
  /// and the response is not treated as a CORS response.
  private func prepareCORSResponseHead(_ head: inout HTTPResponseHead, origin: String?) {
    guard let header = origin.flatMap({ self.configuration.allowedOrigins.header($0) }) else {
      // No origin or the origin is not allowed; don't treat it as a CORS request.
      return
    }

    head.headers.replaceOrAdd(name: .accessControlAllowOrigin, value: header)

    if self.configuration.allowCredentialedRequests {
      // `replaceOrAdd` (like the origin header above) so a head that already carries this
      // header doesn't end up with a duplicate.
      head.headers.replaceOrAdd(name: .accessControlAllowCredentials, value: "true")
    }

    // FIXME: Check whether we can let browsers keep connections alive. It's not possible
    // now as the channel has a state that can't be reused since the pipeline is modified to
    // inject the gRPC call handler.
    head.headers.replaceOrAdd(name: "Connection", value: "close")
  }
}
172
173
extension HTTPHeaders {
  /// CORS-related header names, spelled in lowercase as they are sent on the wire.
  fileprivate enum CORSHeader: String {
    case accessControlRequestMethod = "access-control-request-method"
    case accessControlRequestHeaders = "access-control-request-headers"
    case accessControlAllowOrigin = "access-control-allow-origin"
    case accessControlAllowMethods = "access-control-allow-methods"
    case accessControlAllowHeaders = "access-control-allow-headers"
    case accessControlAllowCredentials = "access-control-allow-credentials"
    case accessControlMaxAge = "access-control-max-age"
  }

  /// Returns whether a header with the name of `header` is present.
  fileprivate func contains(_ header: CORSHeader) -> Bool {
    let headerName = header.rawValue
    return self.contains(name: headerName)
  }

  /// Appends a `header` entry with the given `value`, keeping any existing entries.
  fileprivate mutating func add(name header: CORSHeader, value: String) {
    let headerName = header.rawValue
    self.add(name: headerName, value: value)
  }

  /// Sets the `header` entry to `value`, replacing any existing entries with that name.
  fileprivate mutating func replaceOrAdd(name header: CORSHeader, value: String) {
    let headerName = header.rawValue
    self.replaceOrAdd(name: headerName, value: value)
  }
}
196
197
extension Server.Configuration.CORS.AllowedOrigins {
  /// Returns the "access-control-allow-origin" value to send for a request from `origin`,
  /// or `nil` if that origin is not permitted.
  ///
  /// - `.all` always yields the wildcard `"*"`.
  /// - `.originBased` echoes `origin` back unconditionally.
  /// - `.only` echoes `origin` back only when the allowed collection contains it.
  /// - `.custom` delegates the decision to the custom checker's `check(origin:)`.
  internal func header(_ origin: String) -> String? {
    switch self.wrapped {
    case .all:
      return "*"

    case .originBased:
      return origin

    case .only(let permittedOrigins):
      guard permittedOrigins.contains(origin) else {
        return nil
      }
      return origin

    case .custom(let checker):
      return checker.check(origin: origin)
    }
  }
}