/rust/registry/src/index.crates.io-1949cf8c6b5b557f/prost-0.13.5/src/encoding/varint.rs
Line | Count | Source |
1 | | use core::cmp::min; |
2 | | use core::num::NonZeroU64; |
3 | | |
4 | | use ::bytes::{Buf, BufMut}; |
5 | | |
6 | | use crate::DecodeError; |
7 | | |
8 | | /// Encodes an integer value into LEB128 variable length format, and writes it to the buffer. |
9 | | /// The buffer must have enough remaining space (maximum 10 bytes). |
10 | | #[inline] |
11 | 0 | pub fn encode_varint(mut value: u64, buf: &mut impl BufMut) { |
12 | | // Varints are never more than 10 bytes |
13 | 0 | for _ in 0..10 { |
14 | 0 | if value < 0x80 { |
15 | 0 | buf.put_u8(value as u8); |
16 | 0 | break; |
17 | 0 | } else { |
18 | 0 | buf.put_u8(((value & 0x7F) | 0x80) as u8); |
19 | 0 | value >>= 7; |
20 | 0 | } |
21 | | } |
22 | 0 | } |
23 | | |
/// Returns the encoded length of the value in LEB128 variable length format.
/// The returned value will be between 1 and 10, inclusive.
///
/// Usable in `const` contexts.
#[inline]
pub const fn encoded_len_varint(value: u64) -> usize {
    // Based on [VarintSize64][1].
    // [1]: https://github.com/protocolbuffers/protobuf/blob/v28.3/src/google/protobuf/io/coded_stream.h#L1744-L1756
    //
    // OR-ing with 1 leaves the position of the highest set bit unchanged for every non-zero
    // input and maps 0 to 1, so zero is reported as a one-byte encoding.
    let log2value = match NonZeroU64::new(value | 1) {
        Some(nonzero) => nonzero.ilog2(),
        // `value | 1` always has bit 0 set, so it can never be zero.
        None => unreachable!(),
    };

    // Computes floor(log2value / 7) + 1 without a division by 7: multiplying by 9/64
    // approximates 1/7 closely enough to be exact for every log2value in 0..=63.
    ((log2value * 9 + (64 + 9)) / 64) as usize
}
34 | | |
35 | | /// Decodes a LEB128-encoded variable length integer from the buffer. |
36 | | #[inline] |
37 | 0 | pub fn decode_varint(buf: &mut impl Buf) -> Result<u64, DecodeError> { |
38 | 0 | let bytes = buf.chunk(); |
39 | 0 | let len = bytes.len(); |
40 | 0 | if len == 0 { |
41 | 0 | return Err(DecodeError::new("invalid varint")); |
42 | 0 | } |
43 | | |
44 | 0 | let byte = bytes[0]; |
45 | 0 | if byte < 0x80 { |
46 | 0 | buf.advance(1); |
47 | 0 | Ok(u64::from(byte)) |
48 | 0 | } else if len > 10 || bytes[len - 1] < 0x80 { |
49 | 0 | let (value, advance) = decode_varint_slice(bytes)?; |
50 | 0 | buf.advance(advance); |
51 | 0 | Ok(value) |
52 | | } else { |
53 | 0 | decode_varint_slow(buf) |
54 | | } |
55 | 0 | } |
56 | | |
57 | | /// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the |
58 | | /// number of bytes read. |
59 | | /// |
60 | | /// Based loosely on [`ReadVarint64FromArray`][1] with a varint overflow check from |
61 | | /// [`ConsumeVarint`][2]. |
62 | | /// |
63 | | /// ## Safety |
64 | | /// |
65 | | /// The caller must ensure that `bytes` is non-empty and either `bytes.len() >= 10` or the last |
66 | | /// element in bytes is < `0x80`. |
67 | | /// |
68 | | /// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406 |
69 | | /// [2]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358 |
70 | | #[inline] |
71 | 0 | fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> { |
72 | | // Fully unrolled varint decoding loop. Splitting into 32-bit pieces gives better performance. |
73 | | |
74 | | // Use assertions to ensure memory safety, but it should always be optimized after inline. |
75 | 0 | assert!(!bytes.is_empty()); |
76 | 0 | assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80); |
77 | | |
78 | 0 | let mut b: u8 = unsafe { *bytes.get_unchecked(0) }; |
79 | 0 | let mut part0: u32 = u32::from(b); |
80 | 0 | if b < 0x80 { |
81 | 0 | return Ok((u64::from(part0), 1)); |
82 | 0 | }; |
83 | 0 | part0 -= 0x80; |
84 | 0 | b = unsafe { *bytes.get_unchecked(1) }; |
85 | 0 | part0 += u32::from(b) << 7; |
86 | 0 | if b < 0x80 { |
87 | 0 | return Ok((u64::from(part0), 2)); |
88 | 0 | }; |
89 | 0 | part0 -= 0x80 << 7; |
90 | 0 | b = unsafe { *bytes.get_unchecked(2) }; |
91 | 0 | part0 += u32::from(b) << 14; |
92 | 0 | if b < 0x80 { |
93 | 0 | return Ok((u64::from(part0), 3)); |
94 | 0 | }; |
95 | 0 | part0 -= 0x80 << 14; |
96 | 0 | b = unsafe { *bytes.get_unchecked(3) }; |
97 | 0 | part0 += u32::from(b) << 21; |
98 | 0 | if b < 0x80 { |
99 | 0 | return Ok((u64::from(part0), 4)); |
100 | 0 | }; |
101 | 0 | part0 -= 0x80 << 21; |
102 | 0 | let value = u64::from(part0); |
103 | | |
104 | 0 | b = unsafe { *bytes.get_unchecked(4) }; |
105 | 0 | let mut part1: u32 = u32::from(b); |
106 | 0 | if b < 0x80 { |
107 | 0 | return Ok((value + (u64::from(part1) << 28), 5)); |
108 | 0 | }; |
109 | 0 | part1 -= 0x80; |
110 | 0 | b = unsafe { *bytes.get_unchecked(5) }; |
111 | 0 | part1 += u32::from(b) << 7; |
112 | 0 | if b < 0x80 { |
113 | 0 | return Ok((value + (u64::from(part1) << 28), 6)); |
114 | 0 | }; |
115 | 0 | part1 -= 0x80 << 7; |
116 | 0 | b = unsafe { *bytes.get_unchecked(6) }; |
117 | 0 | part1 += u32::from(b) << 14; |
118 | 0 | if b < 0x80 { |
119 | 0 | return Ok((value + (u64::from(part1) << 28), 7)); |
120 | 0 | }; |
121 | 0 | part1 -= 0x80 << 14; |
122 | 0 | b = unsafe { *bytes.get_unchecked(7) }; |
123 | 0 | part1 += u32::from(b) << 21; |
124 | 0 | if b < 0x80 { |
125 | 0 | return Ok((value + (u64::from(part1) << 28), 8)); |
126 | 0 | }; |
127 | 0 | part1 -= 0x80 << 21; |
128 | 0 | let value = value + ((u64::from(part1)) << 28); |
129 | | |
130 | 0 | b = unsafe { *bytes.get_unchecked(8) }; |
131 | 0 | let mut part2: u32 = u32::from(b); |
132 | 0 | if b < 0x80 { |
133 | 0 | return Ok((value + (u64::from(part2) << 56), 9)); |
134 | 0 | }; |
135 | 0 | part2 -= 0x80; |
136 | 0 | b = unsafe { *bytes.get_unchecked(9) }; |
137 | 0 | part2 += u32::from(b) << 7; |
138 | | // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details. |
139 | | // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358 |
140 | 0 | if b < 0x02 { |
141 | 0 | return Ok((value + (u64::from(part2) << 56), 10)); |
142 | 0 | }; |
143 | | |
144 | | // We have overrun the maximum size of a varint (10 bytes) or the final byte caused an overflow. |
145 | | // Assume the data is corrupt. |
146 | 0 | Err(DecodeError::new("invalid varint")) |
147 | 0 | } |
148 | | |
149 | | /// Decodes a LEB128-encoded variable length integer from the buffer, advancing the buffer as |
150 | | /// necessary. |
151 | | /// |
152 | | /// Contains a varint overflow check from [`ConsumeVarint`][1]. |
153 | | /// |
154 | | /// [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358 |
155 | | #[inline(never)] |
156 | | #[cold] |
157 | 0 | fn decode_varint_slow(buf: &mut impl Buf) -> Result<u64, DecodeError> { |
158 | 0 | let mut value = 0; |
159 | 0 | for count in 0..min(10, buf.remaining()) { |
160 | 0 | let byte = buf.get_u8(); |
161 | 0 | value |= u64::from(byte & 0x7F) << (count * 7); |
162 | 0 | if byte <= 0x7F { |
163 | | // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details. |
164 | | // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358 |
165 | 0 | if count == 9 && byte >= 0x02 { |
166 | 0 | return Err(DecodeError::new("invalid varint")); |
167 | | } else { |
168 | 0 | return Ok(value); |
169 | | } |
170 | 0 | } |
171 | | } |
172 | | |
173 | 0 | Err(DecodeError::new("invalid varint")) |
174 | 0 | } |
175 | | |
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips values on both sides of every 7-bit boundary through the encoder, the length
    /// calculation, and both decoders.
    #[test]
    fn varint() {
        fn check(value: u64, encoded: &[u8]) {
            // Small buffer.
            let mut buf = Vec::with_capacity(1);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            // Large buffer.
            let mut buf = Vec::with_capacity(100);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            assert_eq!(encoded_len_varint(value), encoded.len());

            // See: https://github.com/tokio-rs/prost/pull/1008 for copying reasoning.
            let mut encoded_copy = encoded;
            let roundtrip_value = decode_varint(&mut encoded_copy).expect("decoding failed");
            assert_eq!(value, roundtrip_value);

            let mut encoded_copy = encoded;
            let roundtrip_value =
                decode_varint_slow(&mut encoded_copy).expect("slow decoding failed");
            assert_eq!(value, roundtrip_value);
        }

        check(2u64.pow(0) - 1, &[0x00]);
        check(2u64.pow(0), &[0x01]);

        check(2u64.pow(7) - 1, &[0x7F]);
        check(2u64.pow(7), &[0x80, 0x01]);
        check(300, &[0xAC, 0x02]);

        check(2u64.pow(14) - 1, &[0xFF, 0x7F]);
        check(2u64.pow(14), &[0x80, 0x80, 0x01]);

        check(2u64.pow(21) - 1, &[0xFF, 0xFF, 0x7F]);
        check(2u64.pow(21), &[0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(28), &[0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(35), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(42), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(
            2u64.pow(49) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(49),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(56) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(56),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(63) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(63),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            u64::MAX,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
        );
    }

    /// Ten-byte encoding whose final byte sets a bit beyond the 64th, i.e. `u64::MAX + 1`.
    const U64_MAX_PLUS_ONE: &[u8] = &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];

    #[test]
    fn varint_overflow() {
        let mut copy = U64_MAX_PLUS_ONE;
        decode_varint(&mut copy).expect_err("decoding u64::MAX + 1 succeeded");
    }

    #[test]
    fn varint_slow_overflow() {
        let mut copy = U64_MAX_PLUS_ONE;
        decode_varint_slow(&mut copy).expect_err("slow decoding u64::MAX + 1 succeeded");
    }

    /// Input that ends before a terminating byte must be rejected rather than decoded.
    #[test]
    fn varint_truncated() {
        let mut empty: &[u8] = &[];
        decode_varint(&mut empty).expect_err("decoding an empty buffer succeeded");

        let mut dangling: &[u8] = &[0x80];
        decode_varint(&mut dangling).expect_err("decoding a truncated varint succeeded");

        let mut dangling: &[u8] = &[0x80];
        decode_varint_slow(&mut dangling).expect_err("slow decoding a truncated varint succeeded");
    }
}