/src/grpc-swift/Sources/GRPC/ReadWriteStates.swift
Line | Count | Source |
1 | | /* |
2 | | * Copyright 2019, gRPC Authors All rights reserved. |
3 | | * |
4 | | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | | * you may not use this file except in compliance with the License. |
6 | | * You may obtain a copy of the License at |
7 | | * |
8 | | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | | * |
10 | | * Unless required by applicable law or agreed to in writing, software |
11 | | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | | * See the License for the specific language governing permissions and |
14 | | * limitations under the License. |
15 | | */ |
16 | | import NIOCore |
17 | | import SwiftProtobuf |
18 | | |
/// How many messages a stream is expected to carry.
///
/// - `one`: a single message is expected on the stream.
/// - `many`: any number of messages may appear on the stream.
enum MessageArity {
  case one, many
}
24 | | |
/// Holds everything needed to build a `WriteState` once the message encoding is known.
struct PendingWriteState {
  /// How many messages we expect to write to the stream.
  var arity: MessageArity

  /// The 'content-type' of the messages being written.
  var contentType: ContentType

  /// Builds a `WriteState` backed by a `CoalescingLengthPrefixedMessageWriter`.
  ///
  /// - Parameters:
  ///   - messageEncoding: The client's message encoding. When enabled, its `outbound`
  ///     setting is handed to the writer as the compression algorithm; when disabled the
  ///     writer is created without one.
  ///   - allocator: The allocator passed through to the writer.
  /// - Returns: A `WriteState` carrying this state's arity and content-type.
  func makeWriteState(
    messageEncoding: ClientMessageEncoding,
    allocator: ByteBufferAllocator
  ) -> WriteState {
    let outboundCompression: CompressionAlgorithm?
    if case let .enabled(configuration) = messageEncoding {
      outboundCompression = configuration.outbound
    } else {
      // Encoding is disabled: never compress outbound messages.
      outboundCompression = nil
    }

    return WriteState(
      arity: self.arity,
      contentType: self.contentType,
      writer: CoalescingLengthPrefixedMessageWriter(
        compression: outboundCompression,
        allocator: allocator
      )
    )
  }
}
52 | | |
/// Tracks what may still be written to a stream and buffers the messages written so far.
struct WriteState {
  private var arity: MessageArity
  private var contentType: ContentType
  private var writer: CoalescingLengthPrefixedMessageWriter

  /// Whether another message may be appended. Cleared after the first write when the
  /// arity is `.one`.
  private var canWrite: Bool = true

  init(
    arity: MessageArity,
    contentType: ContentType,
    writer: CoalescingLengthPrefixedMessageWriter
  ) {
    self.arity = arity
    self.contentType = contentType
    self.writer = writer
  }

  /// Appends a message to the `writer`.
  ///
  /// - Parameters:
  ///   - message: The bytes of the message to write.
  ///   - compressed: Forwarded to the writer as its `compress` argument.
  ///   - promise: An optional promise handed to the writer together with the message.
  /// - Returns: `.success` if the message was appended, or `.failure(.cardinalityViolation)`
  ///   if no further messages may be written. On failure the writer is not touched.
  mutating func write(
    _ message: ByteBuffer,
    compressed: Bool,
    promise: EventLoopPromise<Void>?
  ) -> Result<Void, MessageWriteError> {
    if !self.canWrite {
      return .failure(.cardinalityViolation)
    }

    self.writer.append(buffer: message, compress: compressed, promise: promise)

    // A stream expecting a single message accepts no writes after this one.
    if case .one = self.arity {
      self.canWrite = false
    }

    return .success(())
  }

  /// Pulls the next item from the `writer`, if it has one.
  ///
  /// - Returns: `nil` when the writer has nothing to offer. Otherwise the writer's result,
  ///   with any failure replaced by `.serializationFailed`, paired with the promise the
  ///   writer returned alongside it.
  mutating func next() -> (Result<ByteBuffer, MessageWriteError>, EventLoopPromise<Void>?)? {
    guard let pending = self.writer.next() else {
      return nil
    }

    let (outcome, promise) = pending
    return (outcome.mapError { _ in MessageWriteError.serializationFailed }, promise)
  }
}
103 | | |
/// Reasons a message could not be written to a stream.
enum MessageWriteError: Error {
  /// More messages were written than the stream permits.
  case cardinalityViolation

  /// The message could not be serialized.
  case serializationFailed

  /// The state machine was in an invalid state. This is a serious implementation error.
  case invalidState
}
114 | | |
/// Holds everything needed to build a `ReadState` once the inbound compression is known.
struct PendingReadState {
  /// How many messages we expect to read from the stream.
  var arity: MessageArity

  /// The message encoding configuration, and whether it is enabled or not.
  var messageEncoding: ClientMessageEncoding

  /// Builds a `ReadState` in the `.reading` state.
  ///
  /// - Parameter compression: The compression algorithm to read with, if any.
  /// - Returns: A `.reading` state with this state's arity. The reader is configured with
  ///   `compression` and the configured decompression limit only when message encoding is
  ///   enabled *and* an algorithm was supplied; otherwise a default reader is used.
  func makeReadState(compression: CompressionAlgorithm? = nil) -> ReadState {
    guard case let .enabled(configuration) = self.messageEncoding,
          let compression = compression
    else {
      // Either encoding is disabled or no algorithm was provided.
      return .reading(self.arity, LengthPrefixedMessageReader())
    }

    let reader = LengthPrefixedMessageReader(
      compression: compression,
      decompressionLimit: configuration.decompressionLimit
    )
    return .reading(self.arity, reader)
  }
}
139 | | |
/// The read state of a stream.
enum ReadState {
  /// Messages may be read using the associated reader.
  case reading(MessageArity, LengthPrefixedMessageReader)

  /// No further reads are allowed: either an earlier read failed or no more messages are
  /// valid on this stream.
  case notReading

  /// Feeds `buffer` to the reader and extracts every complete length-prefixed message.
  ///
  /// With an arity of `.one` this produces **at most** one message; once that message has
  /// been produced, any later call fails.
  ///
  /// - Parameters:
  ///   - buffer: The buffer to read from; it is appended to the reader.
  ///   - maxLength: Passed to the reader's `nextMessage(maxLength:)` for every message.
  /// - Returns: The messages decoded by this call (possibly none, since a payload may be
  ///   split across frames), or a `MessageReadError`. After any failure the state is
  ///   `.notReading`.
  mutating func readMessages(
    _ buffer: inout ByteBuffer,
    maxLength: Int
  ) -> Result<[ByteBuffer], MessageReadError> {
    guard case .reading(let arity, var reader) = self else {
      return .failure(.cardinalityViolation)
    }

    // Take the reader out of `self` before mutating it (avoids CoWs). From here on `self`
    // stays `.notReading` unless a success path below explicitly restores `.reading`.
    self = .notReading
    reader.append(buffer: &buffer)

    var decoded: [ByteBuffer] = []
    do {
      while let message = try reader.nextMessage(maxLength: maxLength) {
        decoded.append(message)
      }
    } catch {
      // Only errors wrapped in `GRPCError.WithContext` map to something more specific than
      // a deserialization failure.
      guard let contextual = error as? GRPCError.WithContext else {
        return .failure(.deserializationFailed)
      }
      if let exceeded = contextual.error as? GRPCError.DecompressionLimitExceeded {
        return .failure(.decompressionLimitExceeded(exceeded.compressedSize))
      }
      if let exceeded = contextual.error as? GRPCError.PayloadLengthLimitExceeded {
        return .failure(.lengthExceedsLimit(exceeded))
      }
      return .failure(.deserializationFailed)
    }

    // Validate how many messages were decoded against the expected arity.
    switch arity {
    case .many:
      // Any number of messages is acceptable; keep reading.
      self = .reading(arity, reader)
      return .success(decoded)

    case .one:
      switch decoded.count {
      case 0:
        // Nothing complete yet: the payload may be split across frames.
        self = .reading(arity, reader)
        return .success(decoded)

      case 1:
        // The single permitted message has arrived, so reading stops here. There must be
        // no leftover bytes and no partially read message.
        let fullyConsumed = reader.unprocessedBytes == 0 && !reader.isReading
        return fullyConsumed ? .success(decoded) : .failure(.leftOverBytes)

      default:
        // More than one message on a single-message stream.
        return .failure(.cardinalityViolation)
      }
    }
  }
}
214 | | |
/// Reasons messages could not be read from a stream.
enum MessageReadError: Error, Equatable {
  /// More messages were read than the stream permits.
  case cardinalityViolation

  /// Enough messages were read, but there are left-over bytes.
  case leftOverBytes

  /// The message could not be deserialized.
  case deserializationFailed

  /// The decompression limit was exceeded.
  case decompressionLimitExceeded(Int)

  /// The message was longer than the permitted maximum length.
  case lengthExceedsLimit(GRPCError.PayloadLengthLimitExceeded)

  /// The state machine was in an invalid state. This is a serious implementation error.
  case invalidState
}