/proc/self/cwd/source/extensions/filters/network/thrift_proxy/header_transport_impl.cc
Line | Count | Source (jump to first uncovered line) |
1 | | #include "source/extensions/filters/network/thrift_proxy/header_transport_impl.h" |
2 | | |
3 | | #include <limits> |
4 | | |
5 | | #include "envoy/common/exception.h" |
6 | | #include "envoy/http/header_formatter.h" |
7 | | |
8 | | #include "source/common/buffer/buffer_impl.h" |
9 | | #include "source/extensions/filters/network/thrift_proxy/buffer_helper.h" |
10 | | |
11 | | #include "absl/strings/str_replace.h" |
12 | | |
13 | | namespace Envoy { |
14 | | namespace Extensions { |
15 | | namespace NetworkFilters { |
16 | | namespace ThriftProxy { |
17 | | namespace { |
18 | | |
// Protocol identifiers as they appear in the header transport's variable header data. For the
// reference enumeration, c.f.
// https://github.com/apache/thrift/blob/master/lib/cpp/src/thrift/protocol/TProtocolTypes.h#L27
enum class HeaderProtocolType {
  Binary = 0,
  JSON = 1,
  Compact = 2,

  // Aliases for the lowest and highest identifiers defined above.
  FirstHeaderProtocolType = Binary,
  LastHeaderProtocolType = Compact,
};
29 | | |
// Size in bytes of the fixed portion of the frame header (everything after the 4-byte frame size
// and before the variable header data):
//   header magic (2) + flags (2) + sequence number (4) + header data size (2)
constexpr uint64_t MinFrameStartSizeNoHeaders = 2 + 2 + 4 + 2;

// Smallest acceptable frame size: the fixed portion of the frame header plus the minimum of 4
// bytes of header data.
constexpr int32_t MinFrameStartSize = MinFrameStartSizeNoHeaders + 4;

// Bytes required before decoding can start: the 4-byte frame size followed by the fixed portion
// of the frame header.
constexpr uint64_t MinDecodeBytes = 4 + MinFrameStartSizeNoHeaders;

// Upper bound, in bytes, for the variable header data.
constexpr int32_t MaxHeadersSize = 65536;
45 | | |
46 | | } // namespace |
47 | | |
48 | 0 | bool HeaderTransportImpl::decodeFrameStart(Buffer::Instance& buffer, MessageMetadata& metadata) { |
49 | 0 | if (buffer.length() < MinDecodeBytes) { |
50 | 0 | return false; |
51 | 0 | } |
52 | | |
53 | | // Size of frame, not including the length bytes. |
54 | 0 | const int32_t frame_size = buffer.peekBEInt<int32_t>(); |
55 | | |
56 | | // Minimum header frame size is 18 bytes (4 bytes of frame size + 10 bytes of fixed header + |
57 | | // minimum 4 bytes of variable header data), so frame_size must be at least 14. |
58 | 0 | if (frame_size < MinFrameStartSize || frame_size > MaxFrameSize) { |
59 | 0 | throw EnvoyException(absl::StrCat("invalid thrift header transport frame size ", frame_size)); |
60 | 0 | } |
61 | | |
62 | 0 | int16_t magic = buffer.peekBEInt<uint16_t>(4); |
63 | 0 | if (!isMagic(magic)) { |
64 | 0 | throw EnvoyException(fmt::format("invalid thrift header transport magic {:04x}", magic)); |
65 | 0 | } |
66 | | |
67 | | // offset 6: 16 bit flags field |
68 | 0 | int16_t header_flags = buffer.peekBEInt<int16_t>(6); |
69 | | |
70 | | // offset 8: 32 bit sequence number field |
71 | 0 | int32_t seq_id = buffer.peekBEInt<int32_t>(8); |
72 | | |
73 | | // offset 12: 16 bit (remaining) header size / 4 (spec erroneously claims / 32). |
74 | 0 | int16_t raw_header_size = buffer.peekBEInt<int16_t>(12); |
75 | 0 | int32_t header_size = static_cast<int32_t>(raw_header_size) * 4; |
76 | 0 | if (header_size < 0 || header_size > MaxHeadersSize) { |
77 | 0 | throw EnvoyException(fmt::format("invalid thrift header transport header size {} ({:04x})", |
78 | 0 | header_size, static_cast<uint16_t>(raw_header_size))); |
79 | 0 | } |
80 | | |
81 | 0 | if (header_size == 0) { |
82 | 0 | throw EnvoyException("no header data"); |
83 | 0 | } |
84 | | |
85 | 0 | if (buffer.length() < static_cast<uint64_t>(header_size) + MinDecodeBytes) { |
86 | | // Need more header data. |
87 | 0 | return false; |
88 | 0 | } |
89 | | |
90 | | // Header data starts at offset 14 (4 bytes of frame size followed by 10 bytes of fixed header). |
91 | 0 | buffer.drain(MinDecodeBytes); |
92 | | |
93 | | // Remaining frame size is the original frame size (which does not count itself), less the 10 |
94 | | // fixed bytes of the header (magic, flags, etc), less the size of the variable header data |
95 | | // (header_size). |
96 | 0 | metadata.setFrameSize( |
97 | 0 | static_cast<uint32_t>(frame_size - header_size - MinFrameStartSizeNoHeaders)); |
98 | 0 | metadata.setHeaderFlags(header_flags); |
99 | 0 | metadata.setSequenceId(seq_id); |
100 | |
|
101 | 0 | ProtocolType proto = ProtocolType::Auto; |
102 | 0 | HeaderProtocolType header_proto = |
103 | 0 | static_cast<HeaderProtocolType>(drainVarIntI16(buffer, header_size, "protocol id")); |
104 | 0 | switch (header_proto) { |
105 | 0 | case HeaderProtocolType::Binary: |
106 | 0 | proto = ProtocolType::Binary; |
107 | 0 | break; |
108 | 0 | case HeaderProtocolType::Compact: |
109 | 0 | proto = ProtocolType::Compact; |
110 | 0 | break; |
111 | 0 | default: |
112 | 0 | throw EnvoyException(fmt::format("Unknown protocol {}", static_cast<int>(header_proto))); |
113 | 0 | } |
114 | 0 | metadata.setProtocol(proto); |
115 | |
|
116 | 0 | int16_t num_xforms = drainVarIntI16(buffer, header_size, "transform count"); |
117 | 0 | if (num_xforms < 0) { |
118 | 0 | throw EnvoyException(absl::StrCat("invalid header transport transform count ", num_xforms)); |
119 | 0 | } |
120 | | |
121 | 0 | while (num_xforms-- > 0) { |
122 | 0 | int32_t xform_id = drainVarIntI32(buffer, header_size, "transform id"); |
123 | | |
124 | | // To date, no transforms have a data field. In the future, some transform IDs may require |
125 | | // consuming another varint 32 at this point. The known transform IDs are: |
126 | | // 1: zlib compression |
127 | | // 2: hmac (appended to end of packet) |
128 | | // 3: snappy compression |
129 | 0 | buffer.drain(header_size); |
130 | 0 | metadata.setAppException(AppExceptionType::MissingResult, |
131 | 0 | absl::StrCat("Unknown transform ", xform_id)); |
132 | 0 | return true; |
133 | 0 | } |
134 | | |
135 | 0 | const bool is_request = metadata.isRequest(); |
136 | 0 | auto formatter = |
137 | 0 | is_request ? metadata.requestHeaders().formatter() : metadata.responseHeaders().formatter(); |
138 | |
|
139 | 0 | while (header_size > 0) { |
140 | | // Attempt to read info blocks |
141 | 0 | int32_t info_id = drainVarIntI32(buffer, header_size, "info id"); |
142 | 0 | if (info_id != 1) { |
143 | | // 0 indicates a padding byte, and the end of the info block. |
144 | | // 1 indicates an info id header/value pair. |
145 | | // Any other value is an unknown info id block, which we ignore. |
146 | 0 | break; |
147 | 0 | } |
148 | | |
149 | 0 | int32_t num_headers = drainVarIntI32(buffer, header_size, "header count"); |
150 | 0 | if (num_headers < 0) { |
151 | 0 | throw EnvoyException(absl::StrCat("invalid header transport header count ", num_headers)); |
152 | 0 | } |
153 | | |
154 | 0 | while (num_headers-- > 0) { |
155 | 0 | std::string key_string = drainVarString(buffer, header_size, "header key"); |
156 | 0 | if (formatter) { |
157 | 0 | formatter->processKey(key_string); |
158 | 0 | } |
159 | | // LowerCaseString doesn't allow '\0', '\n', and '\r'. |
160 | 0 | key_string = |
161 | 0 | absl::StrReplaceAll(key_string, {{std::string(1, '\0'), ""}, {"\n", ""}, {"\r", ""}}); |
162 | |
|
163 | 0 | const Http::LowerCaseString key = Http::LowerCaseString(key_string); |
164 | 0 | const std::string value = drainVarString(buffer, header_size, "header value"); |
165 | |
|
166 | 0 | if (is_request) { |
167 | 0 | metadata.requestHeaders().addCopy(key, value); |
168 | 0 | } else { |
169 | 0 | metadata.responseHeaders().addCopy(key, value); |
170 | 0 | } |
171 | 0 | } |
172 | 0 | } |
173 | | |
174 | | // Remaining bytes are padding or ignored info blocks. |
175 | 0 | if (header_size > 0) { |
176 | 0 | buffer.drain(header_size); |
177 | 0 | } |
178 | |
|
179 | 0 | return true; |
180 | 0 | } |
181 | | |
182 | 0 | bool HeaderTransportImpl::decodeFrameEnd(Buffer::Instance&) { |
183 | 0 | exception_.reset(); |
184 | 0 | exception_reason_.clear(); |
185 | |
|
186 | 0 | return true; |
187 | 0 | } |
188 | | |
189 | | void HeaderTransportImpl::encodeFrame(Buffer::Instance& buffer, const MessageMetadata& metadata, |
190 | 0 | Buffer::Instance& message) { |
191 | 0 | uint64_t msg_size = message.length(); |
192 | 0 | if (msg_size == 0) { |
193 | 0 | throw EnvoyException(absl::StrCat("invalid thrift header transport message size ", msg_size)); |
194 | 0 | } |
195 | | |
196 | 0 | const uint32_t headers_size = |
197 | 0 | metadata.isRequest() ? metadata.requestHeaders().size() : metadata.responseHeaders().size(); |
198 | 0 | if (headers_size > MaxHeadersSize / 2) { |
199 | | // Each header takes a minimum of 2 bytes, yielding this limit. |
200 | 0 | throw EnvoyException( |
201 | 0 | absl::StrCat("invalid thrift header transport too many headers ", headers_size)); |
202 | 0 | } |
203 | | |
204 | 0 | Buffer::OwnedImpl header_buffer; |
205 | |
|
206 | 0 | if (!metadata.hasProtocol()) { |
207 | 0 | throw EnvoyException("missing header transport protocol"); |
208 | 0 | } |
209 | | |
210 | 0 | switch (metadata.protocol()) { |
211 | 0 | case ProtocolType::Binary: |
212 | 0 | BufferHelper::writeVarIntI32(header_buffer, static_cast<int32_t>(HeaderProtocolType::Binary)); |
213 | 0 | break; |
214 | 0 | case ProtocolType::Compact: |
215 | 0 | BufferHelper::writeVarIntI32(header_buffer, static_cast<int32_t>(HeaderProtocolType::Compact)); |
216 | 0 | break; |
217 | 0 | default: |
218 | 0 | throw EnvoyException(fmt::format("invalid header transport protocol {}", |
219 | 0 | ProtocolNames::get().fromType(metadata.protocol()))); |
220 | 0 | } |
221 | | |
222 | 0 | BufferHelper::writeVarIntI32(header_buffer, 0); // num transforms |
223 | |
|
224 | 0 | if (headers_size > 0) { |
225 | | // Info ID 1 |
226 | 0 | header_buffer.writeByte(1); |
227 | | |
228 | | // Num headers |
229 | 0 | BufferHelper::writeVarIntI32(header_buffer, static_cast<int32_t>(headers_size)); |
230 | |
|
231 | 0 | auto formatter = metadata.isRequest() ? metadata.requestHeaders().formatter() |
232 | 0 | : metadata.responseHeaders().formatter(); |
233 | |
|
234 | 0 | auto header_writer = [&header_buffer, |
235 | 0 | formatter](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate { |
236 | 0 | const auto header_key = header.key().getStringView(); |
237 | |
|
238 | 0 | writeVarString(header_buffer, formatter ? formatter->format(header_key) : header_key); |
239 | 0 | writeVarString(header_buffer, header.value().getStringView()); |
240 | 0 | return Http::HeaderMap::Iterate::Continue; |
241 | 0 | }; |
242 | |
|
243 | 0 | if (metadata.isRequest()) { |
244 | 0 | metadata.requestHeaders().iterate(header_writer); |
245 | 0 | } else { |
246 | 0 | metadata.responseHeaders().iterate(header_writer); |
247 | 0 | } |
248 | 0 | } |
249 | |
|
250 | 0 | uint64_t header_size = header_buffer.length(); |
251 | | |
252 | | // Always pad (as the Apache implementation does). |
253 | 0 | const int padding = 4 - (header_size % 4); |
254 | 0 | header_buffer.add("\0\0\0\0", padding); |
255 | 0 | header_size += padding; |
256 | |
|
257 | 0 | if (header_size > MaxHeadersSize) { |
258 | 0 | throw EnvoyException(absl::StrCat("invalid thrift header transport header size ", header_size)); |
259 | 0 | } |
260 | | |
261 | | // Frame size does not include the frame length itself. |
262 | 0 | uint64_t size = header_size + msg_size + MinFrameStartSizeNoHeaders; |
263 | 0 | if (size > MaxFrameSize) { |
264 | 0 | throw EnvoyException(absl::StrCat("invalid thrift header transport frame size ", size)); |
265 | 0 | } |
266 | | |
267 | 0 | int32_t seq_id = 0; |
268 | 0 | if (metadata.hasSequenceId()) { |
269 | 0 | seq_id = metadata.sequenceId(); |
270 | 0 | } |
271 | 0 | int16_t header_flags = 0; |
272 | 0 | if (metadata.hasHeaderFlags()) { |
273 | 0 | header_flags = metadata.headerFlags(); |
274 | 0 | } |
275 | |
|
276 | 0 | buffer.writeBEInt<uint32_t>(static_cast<uint32_t>(size)); |
277 | 0 | buffer.writeBEInt<uint16_t>(Magic); |
278 | 0 | buffer.writeBEInt<uint16_t>(header_flags); // flags |
279 | 0 | buffer.writeBEInt<int32_t>(seq_id); |
280 | 0 | buffer.writeBEInt<uint16_t>(static_cast<uint16_t>(header_size / 4)); |
281 | |
|
282 | 0 | buffer.move(header_buffer); |
283 | 0 | buffer.move(message); |
284 | 0 | } |
285 | | |
286 | | int16_t HeaderTransportImpl::drainVarIntI16(Buffer::Instance& buffer, int32_t& header_size, |
287 | 0 | const char* desc) { |
288 | 0 | int32_t value = drainVarIntI32(buffer, header_size, desc); |
289 | 0 | if (value > static_cast<int32_t>(std::numeric_limits<int16_t>::max())) { |
290 | 0 | throw EnvoyException(fmt::format("header transport {}: value {} exceeds max i16 ({})", desc, |
291 | 0 | value, std::numeric_limits<int16_t>::max())); |
292 | 0 | } |
293 | 0 | return static_cast<int16_t>(value); |
294 | 0 | } |
295 | | |
296 | | int32_t HeaderTransportImpl::drainVarIntI32(Buffer::Instance& buffer, int32_t& header_size, |
297 | 0 | const char* desc) { |
298 | 0 | if (header_size <= 0) { |
299 | 0 | throw EnvoyException(fmt::format("unable to read header transport {}: header too small", desc)); |
300 | 0 | } |
301 | | |
302 | 0 | int size; |
303 | 0 | int32_t value = BufferHelper::peekVarIntI32(buffer, 0, size); |
304 | 0 | if (size < 0 || (header_size - size) < 0) { |
305 | 0 | throw EnvoyException(fmt::format("unable to read header transport {}: header too small", desc)); |
306 | 0 | } |
307 | 0 | buffer.drain(size); |
308 | 0 | header_size -= size; |
309 | 0 | return value; |
310 | 0 | } |
311 | | |
312 | | std::string HeaderTransportImpl::drainVarString(Buffer::Instance& buffer, int32_t& header_size, |
313 | 0 | const char* desc) { |
314 | 0 | const int16_t str_len = drainVarIntI16(buffer, header_size, desc); |
315 | 0 | if (str_len == 0) { |
316 | 0 | return ""; |
317 | 0 | } |
318 | | |
319 | 0 | if (header_size < static_cast<int32_t>(str_len)) { |
320 | 0 | throw EnvoyException(fmt::format("unable to read header transport {}: header too small", desc)); |
321 | 0 | } |
322 | | |
323 | 0 | const std::string value(static_cast<char*>(buffer.linearize(str_len)), str_len); |
324 | 0 | buffer.drain(str_len); |
325 | 0 | header_size -= str_len; |
326 | 0 | return value; |
327 | 0 | } |
328 | | |
329 | 0 | void HeaderTransportImpl::writeVarString(Buffer::Instance& buffer, const absl::string_view str) { |
330 | 0 | const std::string::size_type len = str.length(); |
331 | 0 | if (len > static_cast<uint32_t>(std::numeric_limits<int16_t>::max())) { |
332 | 0 | throw EnvoyException(absl::StrCat("header string too long: ", len)); |
333 | 0 | } |
334 | | |
335 | 0 | BufferHelper::writeVarIntI32(buffer, static_cast<int32_t>(len)); |
336 | 0 | if (len == 0) { |
337 | 0 | return; |
338 | 0 | } |
339 | 0 | buffer.add(str.data(), len); |
340 | 0 | } |
341 | | |
342 | | class HeaderTransportConfigFactory : public TransportFactoryBase<HeaderTransportImpl> { |
343 | | public: |
344 | 4 | HeaderTransportConfigFactory() : TransportFactoryBase(TransportNames::get().HEADER) {} |
345 | | }; |
346 | | |
347 | | /** |
348 | | * Static registration for the header transport. @see RegisterFactory. |
349 | | */ |
350 | | REGISTER_FACTORY(HeaderTransportConfigFactory, NamedTransportConfigFactory); |
351 | | |
352 | | } // namespace ThriftProxy |
353 | | } // namespace NetworkFilters |
354 | | } // namespace Extensions |
355 | | } // namespace Envoy |