/src/swift-protobuf/Sources/SwiftProtobuf/AsyncMessageSequence.swift
Line | Count | Source |
1 | | // |
2 | | // Sources/SwiftProtobuf/AsyncMessageSequence.swift - Async sequence over binary delimited protobuf |
3 | | // |
4 | | // Copyright (c) 2023 Apple Inc. and the project authors |
5 | | // Licensed under Apache License v2.0 with Runtime Library Exception |
6 | | // |
7 | | // See LICENSE.txt for license information: |
8 | | // https://github.com/apple/swift-protobuf/blob/main/LICENSE.txt |
9 | | // |
10 | | // ----------------------------------------------------------------------------- |
11 | | /// |
12 | | /// An async sequence of messages decoded from a binary delimited protobuf stream. |
13 | | /// |
14 | | // ----------------------------------------------------------------------------- |
15 | | |
16 | | @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) |
17 | | extension AsyncSequence where Element == UInt8 { |
18 | | /// Creates an asynchronous sequence of size-delimited messages from this sequence of bytes. |
19 | | /// Delimited format allows a single file or stream to contain multiple messages. A delimited message |
20 | | /// is a varint encoding the message size followed by a message of exactly that size. |
21 | | /// |
22 | | /// - Parameters: |
23 | | /// - messageType: The type of message to read. |
24 | | /// - extensions: An ``ExtensionMap`` used to look up and decode any extensions in |
25 | | /// messages encoded by this sequence, or in messages nested within these messages. |
26 | | /// - partial: If `false` (the default), after decoding a message, ``Message/isInitialized-6abgi` |
27 | | /// will be checked to ensure all fields are present. |
28 | | /// - options: The ``BinaryDecodingOptions`` to use. |
29 | | /// - Returns: An asynchronous sequence of messages read from the `AsyncSequence` of bytes. |
30 | | @inlinable |
31 | | public func binaryProtobufDelimitedMessages<M: Message>( |
32 | | of messageType: M.Type = M.self, |
33 | | extensions: (any ExtensionMap)? = nil, |
34 | | partial: Bool = false, |
35 | | options: BinaryDecodingOptions = BinaryDecodingOptions() |
36 | 52.6k | ) -> AsyncMessageSequence<Self, M> { |
37 | 52.6k | AsyncMessageSequence<Self, M>( |
38 | 52.6k | base: self, |
39 | 52.6k | extensions: extensions, |
40 | 52.6k | partial: partial, |
41 | 52.6k | options: options |
42 | 52.6k | ) |
43 | 52.6k | } |
44 | | } |
45 | | |
46 | | /// An asynchronous sequence of messages decoded from an asynchronous sequence of bytes. |
47 | | @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) |
48 | | public struct AsyncMessageSequence< |
49 | | Base: AsyncSequence, |
50 | | M: Message |
51 | | >: AsyncSequence where Base.Element == UInt8 { |
52 | | |
53 | | /// The message type in this asynchronous sequence. |
54 | | public typealias Element = M |
55 | | |
56 | | private let base: Base |
57 | | private let extensions: (any ExtensionMap)? |
58 | | private let partial: Bool |
59 | | private let options: BinaryDecodingOptions |
60 | | |
61 | | /// Reads size-delimited messages from the given sequence of bytes. Delimited |
62 | | /// format allows a single file or stream to contain multiple messages. A delimited message |
63 | | /// is a varint encoding the message size followed by a message of exactly that size. |
64 | | /// |
65 | | /// - Parameters: |
66 | | /// - base: The `AsyncSequence` to read messages from. |
67 | | /// - extensions: An ``ExtensionMap`` used to look up and decode any extensions in |
68 | | /// messages encoded by this sequence, or in messages nested within these messages. |
69 | | /// - partial: If `false` (the default), after decoding a message, ``Message/isInitialized-6abgi`` |
70 | | /// will be checked to ensure all fields are present. |
71 | | /// - options: The ``BinaryDecodingOptions`` to use. |
72 | | /// - Returns: An asynchronous sequence of messages read from the `AsyncSequence` of bytes. |
73 | | public init( |
74 | | base: Base, |
75 | | extensions: (any ExtensionMap)? = nil, |
76 | | partial: Bool = false, |
77 | | options: BinaryDecodingOptions = BinaryDecodingOptions() |
78 | 52.6k | ) { |
79 | 52.6k | self.base = base |
80 | 52.6k | self.extensions = extensions |
81 | 52.6k | self.partial = partial |
82 | 52.6k | self.options = options |
83 | 52.6k | } |
84 | | |
85 | | /// An asynchronous iterator that produces the messages of this asynchronous sequence. |
86 | | public struct AsyncIterator: AsyncIteratorProtocol { |
87 | | @usableFromInline |
88 | | var iterator: Base.AsyncIterator? |
89 | | @usableFromInline |
90 | | let extensions: (any ExtensionMap)? |
91 | | @usableFromInline |
92 | | let partial: Bool |
93 | | @usableFromInline |
94 | | let options: BinaryDecodingOptions |
95 | | |
96 | | init( |
97 | | iterator: Base.AsyncIterator, |
98 | | extensions: (any ExtensionMap)?, |
99 | | partial: Bool, |
100 | | options: BinaryDecodingOptions |
101 | 13.1k | ) { |
102 | 13.1k | self.iterator = iterator |
103 | 13.1k | self.extensions = extensions |
104 | 13.1k | self.partial = partial |
105 | 13.1k | self.options = options |
106 | 13.1k | } |
107 | | |
108 | | /// Asynchronously reads the next varint. |
109 | | @inlinable |
110 | 170k | mutating func nextVarInt() async throws -> UInt64? { |
111 | 170k | var messageSize: UInt64 = 0 |
112 | 170k | var shift: UInt64 = 0 |
113 | 170k | |
114 | 170k | while let byte = try await iterator?.next() { |
115 | 167k | messageSize |= UInt64(byte & 0x7f) << shift |
116 | 167k | shift += UInt64(7) |
117 | 167k | if shift > 35 { |
118 | 12 | iterator = nil |
119 | 12 | throw SwiftProtobufError.BinaryStreamDecoding.malformedLength() |
120 | 167k | } |
121 | 167k | if byte & 0x80 == 0 { |
122 | 158k | return messageSize |
123 | 158k | } |
124 | 11.7k | } |
125 | 11.7k | if shift > 0 { |
126 | 484 | // The stream has ended inside a varint. |
127 | 484 | iterator = nil |
128 | 484 | throw BinaryDelimited.Error.truncated |
129 | 11.2k | } |
130 | 11.2k | return nil // End of stream reached. |
131 | 170k | } |
132 | | |
133 | | /// Helper to read the given number of bytes. |
134 | | @usableFromInline |
135 | 34.7k | mutating func readBytes(_ size: Int) async throws -> [UInt8] { |
136 | 34.7k | // Even though the bytes are read in chunks, things can still hard fail if |
137 | 34.7k | // there isn't enough memory to append to have all the bytes at once for |
138 | 34.7k | // parsing; but this at least catches some possible OOM attacks. |
139 | 34.7k | var bytesNeeded = size |
140 | 34.7k | var buffer = [UInt8]() |
141 | 34.7k | let kChunkSize = 16 * 1024 * 1024 |
142 | 34.7k | var chunk = [UInt8](repeating: 0, count: Swift.min(bytesNeeded, kChunkSize)) |
143 | 68.9k | while bytesNeeded > 0 { |
144 | 34.7k | var consumedBytes = 0 |
145 | 34.7k | let maxLength = Swift.min(bytesNeeded, chunk.count) |
146 | 13.5M | while consumedBytes < maxLength { |
147 | 13.4M | guard let byte = try await iterator?.next() else { |
148 | 616 | // The iterator hit the end, but the chunk wasn't filled, so the full |
149 | 616 | // payload wasn't read. |
150 | 616 | throw BinaryDelimited.Error.truncated |
151 | 13.4M | } |
152 | 13.4M | chunk[consumedBytes] = byte |
153 | 13.4M | consumedBytes += 1 |
154 | 13.4M | } |
155 | 34.1k | if consumedBytes < chunk.count { |
156 | 0 | buffer += chunk[0..<consumedBytes] |
157 | 34.1k | } else { |
158 | 34.1k | buffer += chunk |
159 | 34.1k | } |
160 | 34.1k | bytesNeeded -= maxLength |
161 | 34.1k | } |
162 | 34.1k | return buffer |
163 | 34.7k | } |
164 | | |
165 | | /// Asynchronously advances to the next message and returns it, or ends the |
166 | | /// sequence if there is no next message. |
167 | | /// |
168 | | /// - Returns: The next message, if it exists, or `nil` to signal the end of |
169 | | /// the sequence. |
170 | | @inlinable |
171 | 170k | public mutating func next() async throws -> M? { |
172 | 170k | guard let messageSize = try await nextVarInt() else { |
173 | 11.2k | iterator = nil |
174 | 11.2k | return nil |
175 | 158k | } |
176 | 158k | guard messageSize <= UInt64(0x7fff_ffff) else { |
177 | 28 | iterator = nil |
178 | 28 | throw SwiftProtobufError.BinaryDecoding.tooLarge() |
179 | 158k | } |
180 | 158k | if messageSize == 0 { |
181 | 19.6k | return try M( |
182 | 19.6k | serializedBytes: [], |
183 | 19.6k | extensions: extensions, |
184 | 19.6k | partial: partial, |
185 | 19.6k | options: options |
186 | 19.6k | ) |
187 | 138k | } |
188 | 138k | let buffer = try await readBytes(Int(messageSize)) |
189 | 136k | return try M( |
190 | 136k | serializedBytes: buffer, |
191 | 136k | extensions: extensions, |
192 | 136k | partial: partial, |
193 | 136k | options: options |
194 | 136k | ) |
195 | 170k | } |
196 | | } |
197 | | |
198 | | /// Creates the asynchronous iterator that produces elements of this |
199 | | /// asynchronous sequence. |
200 | | /// |
201 | | /// - Returns: An instance of the `AsyncIterator` type used to produce |
202 | | /// messages in the asynchronous sequence. |
203 | 13.2k | public func makeAsyncIterator() -> AsyncMessageSequence.AsyncIterator { |
204 | 13.2k | AsyncIterator( |
205 | 13.2k | iterator: base.makeAsyncIterator(), |
206 | 13.2k | extensions: extensions, |
207 | 13.2k | partial: partial, |
208 | 13.2k | options: options |
209 | 13.2k | ) |
210 | 13.2k | } |
211 | | } |
212 | | |
213 | | @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) |
214 | | extension AsyncMessageSequence: Sendable where Base: Sendable {} |
215 | | |
216 | | @available(*, unavailable) |
217 | | extension AsyncMessageSequence.AsyncIterator: Sendable {} |