/src/swift-nio/Sources/NIOPosix/BaseStreamSocketChannel.swift
Line | Count | Source |
1 | | //===----------------------------------------------------------------------===// |
2 | | // |
3 | | // This source file is part of the SwiftNIO open source project |
4 | | // |
5 | | // Copyright (c) 2019-2021 Apple Inc. and the SwiftNIO project authors |
6 | | // Licensed under Apache License v2.0 |
7 | | // |
8 | | // See LICENSE.txt for license information |
9 | | // See CONTRIBUTORS.txt for the list of SwiftNIO project authors |
10 | | // |
11 | | // SPDX-License-Identifier: Apache-2.0 |
12 | | // |
13 | | //===----------------------------------------------------------------------===// |
14 | | |
15 | | #if !os(WASI) |
16 | | |
17 | | import NIOCore |
18 | | |
19 | | class BaseStreamSocketChannel<Socket: SocketProtocol>: BaseSocketChannel<Socket>, @unchecked Sendable { |
/// Handle to a scheduled task, if one is stored; a full close (`close0` with mode `.all`) cancels and clears it.
var connectTimeoutScheduled: Optional<Scheduled<Void>>
/// Current value of the `AllowRemoteHalfClosureOption` channel option.
private var allowRemoteHalfClosure = false
/// `true` once the input side has been shut down (explicitly, or after an EOF-triggered input close).
private var inputShutdown = false
/// `true` once the output side has been shut down through `shutdownSocket(mode:)`.
private var outputShutdown = false
/// Manager holding the outbound writes that have been buffered but not yet written to the socket.
private let pendingWrites: PendingStreamWritesManager
25 | | |
/// Creates the channel around `socket`.
///
/// The pending-writes manager is built from the event loop's buffer pool and the channel is
/// handed to the superclass with `supportReconnect` set to `false`.
///
/// - Throws: whatever the superclass initializer throws.
init(
    socket: Socket,
    parent: Channel?,
    eventLoop: SelectableEventLoop,
    recvAllocator: RecvByteBufferAllocator
) throws {
    // Both stored properties without a default value must be set before delegating upwards.
    self.connectTimeoutScheduled = nil
    self.pendingWrites = PendingStreamWritesManager(bufferPool: eventLoop.bufferPool)

    try super.init(
        socket: socket,
        parent: parent,
        eventLoop: eventLoop,
        recvAllocator: recvAllocator,
        supportReconnect: false
    )
}
42 | | |
deinit {
    // Writes still buffered at this point could mean their callbacks are leaked, so in debug
    // builds we insist that nothing is left behind.
    assert(self.pendingWrites.isEmpty)
}
47 | | |
48 | | // MARK: BaseSocketChannel's must-override API that may be further refined by subclasses |
/// Sets a channel option; must be called on the event loop.
///
/// The half-closure, write-spin and write-buffer-water-mark options are handled here;
/// every other option is forwarded to the superclass.
///
/// - Throws: `ChannelError._ioOnClosedChannel` if the channel is not open, or whatever the
///   superclass throws for options it handles.
override func setOption0<Option: ChannelOption>(_ option: Option, value: Option.Value) throws {
    self.eventLoop.assertInEventLoop()

    if !self.isOpen {
        throw ChannelError._ioOnClosedChannel
    }

    // The option's type decides which piece of state is updated; the force casts rely on each
    // option type's declared `Value`.
    switch option {
    case is ChannelOptions.Types.AllowRemoteHalfClosureOption:
        self.allowRemoteHalfClosure = value as! Bool
    case is ChannelOptions.Types.WriteSpinOption:
        self.pendingWrites.writeSpinCount = value as! UInt
    case is ChannelOptions.Types.WriteBufferWaterMarkOption:
        self.pendingWrites.waterMark = value as! ChannelOptions.Types.WriteBufferWaterMark
    default:
        try super.setOption0(option, value: value)
    }
}
67 | | |
/// Reads a channel option; must be called on the event loop.
///
/// The half-closure, write-spin, write-buffer-water-mark and buffered-writable-bytes options
/// are answered here; every other option is forwarded to the superclass.
///
/// - Throws: `ChannelError._ioOnClosedChannel` if the channel is not open, or whatever the
///   superclass throws for options it handles.
override func getOption0<Option: ChannelOption>(_ option: Option) throws -> Option.Value {
    self.eventLoop.assertInEventLoop()

    if !self.isOpen {
        throw ChannelError._ioOnClosedChannel
    }

    switch option {
    case is ChannelOptions.Types.AllowRemoteHalfClosureOption:
        return self.allowRemoteHalfClosure as! Option.Value
    case is ChannelOptions.Types.WriteSpinOption:
        return self.pendingWrites.writeSpinCount as! Option.Value
    case is ChannelOptions.Types.WriteBufferWaterMarkOption:
        return self.pendingWrites.waterMark as! Option.Value
    case is ChannelOptions.Types.BufferedWritableBytesOption:
        // The manager's byte counter is converted to `Int` before being returned.
        return Int(self.pendingWrites.bufferedBytes) as! Option.Value
    default:
        return try super.getOption0(option)
    }
}
88 | | |
/// Hook for customizable socket shutdown processing for subclasses, e.g. PipeChannel.
///
/// Shuts down the requested half of the socket and records that fact in the matching flag.
/// The flag is only set when the shutdown call did not throw. `.all` is a no-op here.
///
/// - Throws: whatever `socket.shutdown(how:)` throws.
func shutdownSocket(mode: CloseMode) throws {
    switch mode {
    case .all:
        return
    case .input:
        try self.socket.shutdown(how: .RD)
        self.inputShutdown = true
    case .output:
        try self.socket.shutdown(how: .WR)
        self.outputShutdown = true
    }
}
102 | | |
103 | | // MARK: BaseSocketChannel's must-override API that cannot be further refined by subclasses |
/// Writability as reported by the pending-writes manager.
///
/// This is `Channel` API so must be thread-safe.
final override public var isWritable: Bool {
    return self.pendingWrites.isWritable
}
108 | | |
/// Whether the channel is open, as reported by the superclass; must be read on the event loop.
///
/// In debug builds this also asserts that the pending-writes manager agrees with the superclass.
final override var isOpen: Bool {
    self.eventLoop.assertInEventLoop()
    let superIsOpen = super.isOpen
    assert(superIsOpen == self.pendingWrites.isOpen)
    return superIsOpen
}
114 | | |
/// Reads from the socket up to `maxMessagesPerRead` times, firing `channelRead` for every
/// non-empty read. Must be called on the event loop.
///
/// - Returns: `.some` if at least one buffer was delivered to the pipeline, otherwise `.none`.
/// - Throws: `ChannelError._eof` if the channel is closed or its input is shut down before a
///   read attempt, or if a read returns zero bytes while the input has not been shut down;
///   also rethrows errors from the buffer pool / socket read.
final override func readFromSocket() throws -> ReadResult {
    self.eventLoop.assertInEventLoop()
    var outcome = ReadResult.none

    for _ in 1...self.maxMessagesPerRead {
        if !self.isOpen || self.inputShutdown {
            throw ChannelError._eof
        }

        // A buffer is obtained from the pool at the top of each iteration rather than at the
        // end, so that leaving the loop does not cost an allocation.
        let (buffer, readResult) = try self.recvBufferPool.buffer(allocator: self.allocator) { buffer in
            try buffer.withMutableWritePointer { pointer in
                try self.socket.read(pointer: pointer)
            }
        }

        switch readResult {
        case .wouldBlock(let bytesRead):
            assert(bytesRead == 0)
            return outcome

        case .processed(let bytesRead):
            guard bytesRead > 0 else {
                guard self.inputShutdown else {
                    // end-of-file
                    throw ChannelError._eof
                }
                // The EOF is the result of our own shutdown of the fd: unregister from the
                // Selector and return.
                self.readPending = false
                self.unregisterForReadable()
                return outcome
            }

            self.recvBufferPool.record(actualReadBytes: bytesRead)
            self.readPending = false

            assert(self.isActive)
            self.pipeline.syncOperations.fireChannelRead(NIOAny(buffer))
            outcome = .some

            if buffer.writableBytes > 0 {
                // read(...) did not fill the whole buffer, so stop and wait for the next
                // notification: another read(...) would most likely return nothing or very
                // little. Returning now also lets fireChannelReadComplete() run, which may
                // give the user the chance to flush out all pending writes.
                return outcome
            }
        }
    }
    return outcome
}
165 | | |
/// Asks the pending-writes manager to perform whatever write operations are appropriate,
/// supplying the socket-level primitives for single-buffer, gathering and file writes.
///
/// - Returns: the overall result reported by the pending-writes manager.
/// - Throws: errors from the manager or from the underlying socket calls.
final override func writeToSocket() throws -> OverallWriteResult {
    return try self.pendingWrites.triggerAppropriateWriteOperations(
        scalarBufferWriteOperation: { bytes in
            if bytes.count > 0 {
                // normal write
                return try self.socket.write(pointer: bytes)
            }
            // An empty buffer needs no write call at all.
            return .processed(0)
        },
        vectorBufferWriteOperation: { iovecs in
            // Gathering write
            try self.socket.writev(iovecs: iovecs)
        },
        scalarFileWriteOperation: { descriptor, index, endIndex in
            try self.socket.sendFile(fd: descriptor, offset: index, count: endIndex - index)
        }
    )
}
186 | | |
/// Closes the channel fully or half-closes one direction, depending on `mode`.
///
/// - `.all`: cancels and clears any stored `connectTimeoutScheduled`, then defers to the superclass.
/// - `.input`: fails with `_inputClosed` if the input is already shut down; escalates to a full
///   close if the output is already shut down; otherwise shuts the read side down (skipping the
///   socket call when `error` is `ChannelError.eof`), unregisters for readability, succeeds the
///   promise and fires `ChannelEvent.inputClosed`.
/// - `.output`: fails with `_outputClosed` if the output is already shut down; escalates to a
///   full close if the input is already shut down; otherwise hands the promise to
///   `pendingWrites.closeOutbound` and acts on its result.
///
/// Errors thrown while handling the request are delivered by failing `promise`.
final override func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?) {
    do {
        switch mode {
        case .all:
            // Clear the stored handle before cancelling it.
            if let pendingTimeout = self.connectTimeoutScheduled {
                self.connectTimeoutScheduled = nil
                pendingTimeout.cancel()
            }
            super.close0(error: error, mode: mode, promise: promise)

        case .input:
            guard !self.inputShutdown else {
                promise?.fail(ChannelError._inputClosed)
                return
            }
            guard !self.outputShutdown else {
                // Escalate to full closure
                self.close0(error: error, mode: .all, promise: promise)
                return
            }

            switch error {
            case ChannelError.eof:
                // We already received an EOF, so an explicit socket.shutdown(...) is unnecessary
                // and would only cause ENOTCONN.
                self.inputShutdown = true
            default:
                try self.shutdownSocket(mode: mode)
            }
            self.unregisterForReadable()
            promise?.succeed(())

            self.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed)

        case .output:
            guard !self.outputShutdown else {
                promise?.fail(ChannelError._outputClosed)
                return
            }
            guard !self.inputShutdown else {
                // Escalate to full closure
                self.close0(error: error, mode: .all, promise: promise)
                return
            }

            switch self.pendingWrites.closeOutbound(promise) {
            case .pending:
                // The promise is now stored in the `pendingWrites` state and completed later.
                break

            case .readyForClose(let closePromise), .closed(let closePromise):
                // Shut the socket down only once the pending writes are dealt with ...
                // ... or if we think we are already closed - just to make sure it *is* closed /
                // to match the old behavior.
                do {
                    try self.shutdownSocket(mode: mode)
                    closePromise?.succeed(())
                } catch let shutdownError {
                    closePromise?.fail(shutdownError)
                }
                self.unregisterForWritable()
                self.pipeline.fireUserInboundEventTriggered(ChannelEvent.outputClosed)

            case .errored(let closeError, let closePromise):
                assertionFailure("Close errored: \(closeError)")
                closePromise?.fail(closeError)

                // Escalate to full closure. No promise is passed on: the supplied one has
                // already been used to convey the failure of the half-close.
                self.close0(error: closeError, mode: .all, promise: nil)
            }
        }
    } catch let failure {
        promise?.fail(failure)
    }
}
262 | | |
/// Reports the pending-writes manager's `isFlushPending` state.
final override func hasFlushedPendingWrites() -> Bool {
    return self.pendingWrites.isFlushPending
}
266 | | |
/// Records a flush checkpoint in the pending-writes manager.
final override func markFlushPoint() {
    // The checkpoint has to be marked even when the EventLoop will call writable() later:
    // it is what guarantees that every flushed message is actually written at that point.
    self.pendingWrites.markFlushCheckpoint()
}
272 | | |
/// Fails all pending writes with `error`; if the manager hands back a promise, that promise
/// is failed with the same error.
final override func cancelWritesOnClose(error: Error) {
    self.pendingWrites.failAll(error: error)?.fail(error)
}
278 | | |
/// Defers to the superclass unless the input side has been shut down.
///
/// - Returns: `false` when the input is shut down, otherwise the superclass's result.
@discardableResult
final override func readIfNeeded0() -> Bool {
    guard !self.inputShutdown else {
        return false
    }
    return super.readIfNeeded0()
}
286 | | |
/// Forwards the read request to the superclass; does nothing once the input side is shut down.
final override public func read0() {
    guard !self.inputShutdown else {
        return
    }
    super.read0()
}
293 | | |
/// Queues `data` (unwrapped as `IOData`) in the pending-writes manager.
///
/// If the output side is already shut down the promise is failed with
/// `ChannelError._outputClosed` and nothing is queued. When the manager's `add` call returns
/// `false`, a writability-changed event is fired through the pipeline.
final override func bufferPendingWrite(data: NIOAny, promise: EventLoopPromise<Void>?) {
    guard !self.outputShutdown else {
        promise?.fail(ChannelError._outputClosed)
        return
    }

    let ioData = self.unwrapData(data, as: IOData.self)
    let addResult = self.pendingWrites.add(data: ioData, promise: promise)

    if !addResult {
        self.pipeline.syncOperations.fireChannelWritabilityChanged()
    }
}
306 | | } |
307 | | #endif // !os(WASI) |