/src/serenity/AK/Base64.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | /* |
2 | | * Copyright (c) 2020-2022, Andreas Kling <kling@serenityos.org> |
3 | | * |
4 | | * SPDX-License-Identifier: BSD-2-Clause |
5 | | */ |
6 | | |
7 | | #include <AK/Assertions.h> |
8 | | #include <AK/Base64.h> |
9 | | #include <AK/Error.h> |
10 | | #include <AK/StringBuilder.h> |
11 | | #include <AK/Types.h> |
12 | | #include <AK/Vector.h> |
13 | | |
14 | | namespace AK { |
15 | | |
16 | | size_t calculate_base64_decoded_length(StringView input) |
17 | 6.95k | { |
18 | 6.95k | auto length = input.length() * 3 / 4; |
19 | | |
20 | 6.95k | if (input.ends_with("="sv)) |
21 | 2.18k | --length; |
22 | 6.95k | if (input.ends_with("=="sv)) |
23 | 2.06k | --length; |
24 | | |
25 | 6.95k | return length; |
26 | 6.95k | } |
27 | | |
28 | | size_t calculate_base64_encoded_length(ReadonlyBytes input) |
29 | 251 | { |
30 | 251 | return ((4 * input.size() / 3) + 3) & ~3; |
31 | 251 | } |
32 | | |
33 | | static ErrorOr<ByteBuffer> decode_base64_impl(StringView input, ReadonlySpan<i16> alphabet_lookup_table) |
34 | 7.20k | { |
35 | 7.20k | input = input.trim_whitespace(); |
36 | | |
37 | 7.20k | if (input.length() % 4 != 0) |
38 | 247 | return Error::from_string_literal("Invalid length of Base64 encoded string"); |
39 | | |
40 | 54.4M | auto get = [&](size_t offset, bool* is_padding) -> ErrorOr<u8> { |
41 | 54.4M | if (offset >= input.length()) |
42 | 0 | return 0; |
43 | | |
44 | 54.4M | auto ch = static_cast<unsigned char>(input[offset]); |
45 | 54.4M | if (ch == '=') { |
46 | 5.18k | if (!is_padding) |
47 | 11 | return Error::from_string_literal("Invalid '=' character outside of padding in base64 data"); |
48 | 5.17k | *is_padding = true; |
49 | 5.17k | return 0; |
50 | 5.18k | } |
51 | | |
52 | 54.4M | i16 result = alphabet_lookup_table[ch]; |
53 | 54.4M | if (result < 0) |
54 | 163 | return Error::from_string_literal("Invalid character in base64 data"); |
55 | 54.4M | VERIFY(result < 256); |
56 | 54.4M | return { result }; |
57 | 54.4M | }; |
58 | | |
59 | 6.95k | ByteBuffer output; |
60 | 6.95k | TRY(output.try_resize(calculate_base64_decoded_length(input))); |
61 | | |
62 | 0 | size_t input_offset = 0; |
63 | 6.95k | size_t output_offset = 0; |
64 | | |
65 | 13.6M | while (input_offset < input.length()) { |
66 | 13.6M | bool in2_is_padding = false; |
67 | 13.6M | bool in3_is_padding = false; |
68 | | |
69 | 13.6M | u8 const in0 = TRY(get(input_offset++, nullptr)); |
70 | 13.6M | u8 const in1 = TRY(get(input_offset++, nullptr)); |
71 | 13.6M | u8 const in2 = TRY(get(input_offset++, &in2_is_padding)); |
72 | 13.6M | u8 const in3 = TRY(get(input_offset++, &in3_is_padding)); |
73 | | |
74 | 0 | output[output_offset++] = (in0 << 2) | ((in1 >> 4) & 3); |
75 | | |
76 | 13.6M | if (!in2_is_padding) |
77 | 13.6M | output[output_offset++] = ((in1 & 0xf) << 4) | ((in2 >> 2) & 0xf); |
78 | | |
79 | 13.6M | if (!in3_is_padding) |
80 | 13.6M | output[output_offset++] = ((in2 & 0x3) << 6) | in3; |
81 | 13.6M | } |
82 | | |
83 | 6.77k | return output; |
84 | 6.95k | } |
85 | | |
86 | | static ErrorOr<String> encode_base64_impl(ReadonlyBytes input, ReadonlySpan<char> alphabet) |
87 | 251 | { |
88 | 251 | Vector<u8> output; |
89 | 251 | TRY(output.try_ensure_capacity(calculate_base64_encoded_length(input))); |
90 | | |
91 | 30.4M | auto get = [&](size_t const offset, bool* need_padding = nullptr) -> u8 { |
92 | 30.4M | if (offset >= input.size()) { |
93 | 314 | if (need_padding) |
94 | 314 | *need_padding = true; |
95 | 314 | return 0; |
96 | 314 | } |
97 | 30.4M | return input[offset]; |
98 | 30.4M | }; |
99 | | |
100 | 10.1M | for (size_t i = 0; i < input.size(); i += 3) { |
101 | 10.1M | bool is_8bit = false; |
102 | 10.1M | bool is_16bit = false; |
103 | | |
104 | 10.1M | u8 const in0 = get(i); |
105 | 10.1M | u8 const in1 = get(i + 1, &is_16bit); |
106 | 10.1M | u8 const in2 = get(i + 2, &is_8bit); |
107 | | |
108 | 10.1M | u8 const index0 = (in0 >> 2) & 0x3f; |
109 | 10.1M | u8 const index1 = ((in0 << 4) | (in1 >> 4)) & 0x3f; |
110 | 10.1M | u8 const index2 = ((in1 << 2) | (in2 >> 6)) & 0x3f; |
111 | 10.1M | u8 const index3 = in2 & 0x3f; |
112 | | |
113 | 10.1M | output.unchecked_append(alphabet[index0]); |
114 | 10.1M | output.unchecked_append(alphabet[index1]); |
115 | 10.1M | output.unchecked_append(is_16bit ? '=' : alphabet[index2]); |
116 | 10.1M | output.unchecked_append(is_8bit ? '=' : alphabet[index3]); |
117 | 10.1M | } |
118 | | |
119 | 251 | return String::from_utf8_without_validation(output); |
120 | 251 | } |
121 | | |
122 | | ErrorOr<ByteBuffer> decode_base64(StringView input) |
123 | 7.20k | { |
124 | 7.20k | static constexpr auto lookup_table = base64_lookup_table(); |
125 | 7.20k | return decode_base64_impl(input, lookup_table); |
126 | 7.20k | } |
127 | | |
128 | | ErrorOr<ByteBuffer> decode_base64url(StringView input) |
129 | 0 | { |
130 | 0 | static constexpr auto lookup_table = base64url_lookup_table(); |
131 | 0 | return decode_base64_impl(input, lookup_table); |
132 | 0 | } |
133 | | |
134 | | ErrorOr<String> encode_base64(ReadonlyBytes input) |
135 | 251 | { |
136 | 251 | return encode_base64_impl(input, base64_alphabet); |
137 | 251 | } |
138 | | ErrorOr<String> encode_base64url(ReadonlyBytes input) |
139 | 0 | { |
140 | 0 | return encode_base64_impl(input, base64url_alphabet); |
141 | 0 | } |
142 | | |
143 | | } |