/src/serenity/AK/Base64.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) 2020-2022, Andreas Kling <kling@serenityos.org> |
3 | | * |
4 | | * SPDX-License-Identifier: BSD-2-Clause |
5 | | */ |
6 | | |
7 | | #include <AK/Assertions.h> |
8 | | #include <AK/Base64.h> |
9 | | #include <AK/Error.h> |
10 | | #include <AK/StringBuilder.h> |
11 | | #include <AK/Types.h> |
12 | | #include <AK/Vector.h> |
13 | | |
14 | | namespace AK { |
15 | | |
16 | | size_t calculate_base64_decoded_length(StringView input) |
17 | 5.67k | { |
18 | 5.67k | auto length = input.length() * 3 / 4; |
19 | | |
20 | 5.67k | if (input.ends_with("="sv)) |
21 | 2.01k | --length; |
22 | 5.67k | if (input.ends_with("=="sv)) |
23 | 1.85k | --length; |
24 | | |
25 | 5.67k | return length; |
26 | 5.67k | } |
27 | | |
28 | | size_t calculate_base64_encoded_length(ReadonlyBytes input) |
29 | 280 | { |
30 | 280 | return ((4 * input.size() / 3) + 3) & ~3; |
31 | 280 | } |
32 | | |
33 | | static ErrorOr<ByteBuffer> decode_base64_impl(StringView input, ReadonlySpan<i16> alphabet_lookup_table) |
34 | 5.93k | { |
35 | 5.93k | input = input.trim_whitespace(); |
36 | | |
37 | 5.93k | if (input.length() % 4 != 0) |
38 | 262 | return Error::from_string_literal("Invalid length of Base64 encoded string"); |
39 | | |
40 | 64.6M | auto get = [&](size_t offset, bool* is_padding) -> ErrorOr<u8> { |
41 | 64.6M | if (offset >= input.length()) |
42 | 0 | return 0; |
43 | | |
44 | 64.6M | auto ch = static_cast<unsigned char>(input[offset]); |
45 | 64.6M | if (ch == '=') { |
46 | 4.98k | if (!is_padding) |
47 | 6 | return Error::from_string_literal("Invalid '=' character outside of padding in base64 data"); |
48 | 4.97k | *is_padding = true; |
49 | 4.97k | return 0; |
50 | 4.98k | } |
51 | | |
52 | 64.6M | i16 result = alphabet_lookup_table[ch]; |
53 | 64.6M | if (result < 0) |
54 | 205 | return Error::from_string_literal("Invalid character in base64 data"); |
55 | 64.6M | VERIFY(result < 256); |
56 | 64.6M | return { result }; |
57 | 64.6M | }; |
58 | | |
59 | 5.67k | ByteBuffer output; |
60 | 5.67k | TRY(output.try_resize(calculate_base64_decoded_length(input))); |
61 | | |
62 | 5.67k | size_t input_offset = 0; |
63 | 5.67k | size_t output_offset = 0; |
64 | | |
65 | 16.1M | while (input_offset < input.length()) { |
66 | 16.1M | bool in2_is_padding = false; |
67 | 16.1M | bool in3_is_padding = false; |
68 | | |
69 | 16.1M | u8 const in0 = TRY(get(input_offset++, nullptr)); |
70 | 16.1M | u8 const in1 = TRY(get(input_offset++, nullptr)); |
71 | 16.1M | u8 const in2 = TRY(get(input_offset++, &in2_is_padding)); |
72 | 16.1M | u8 const in3 = TRY(get(input_offset++, &in3_is_padding)); |
73 | | |
74 | 16.1M | output[output_offset++] = (in0 << 2) | ((in1 >> 4) & 3); |
75 | | |
76 | 16.1M | if (!in2_is_padding) |
77 | 16.1M | output[output_offset++] = ((in1 & 0xf) << 4) | ((in2 >> 2) & 0xf); |
78 | | |
79 | 16.1M | if (!in3_is_padding) |
80 | 16.1M | output[output_offset++] = ((in2 & 0x3) << 6) | in3; |
81 | 16.1M | } |
82 | | |
83 | 5.67k | return output; |
84 | 5.67k | } |
85 | | |
86 | | static ErrorOr<String> encode_base64_impl(ReadonlyBytes input, ReadonlySpan<char> alphabet) |
87 | 280 | { |
88 | 280 | Vector<u8> output; |
89 | 280 | TRY(output.try_ensure_capacity(calculate_base64_encoded_length(input))); |
90 | | |
91 | 38.3M | auto get = [&](size_t const offset, bool* need_padding = nullptr) -> u8 { |
92 | 38.3M | if (offset >= input.size()) { |
93 | 351 | if (need_padding) |
94 | 351 | *need_padding = true; |
95 | 351 | return 0; |
96 | 351 | } |
97 | 38.3M | return input[offset]; |
98 | 38.3M | }; |
99 | | |
100 | 12.7M | for (size_t i = 0; i < input.size(); i += 3) { |
101 | 12.7M | bool is_8bit = false; |
102 | 12.7M | bool is_16bit = false; |
103 | | |
104 | 12.7M | u8 const in0 = get(i); |
105 | 12.7M | u8 const in1 = get(i + 1, &is_16bit); |
106 | 12.7M | u8 const in2 = get(i + 2, &is_8bit); |
107 | | |
108 | 12.7M | u8 const index0 = (in0 >> 2) & 0x3f; |
109 | 12.7M | u8 const index1 = ((in0 << 4) | (in1 >> 4)) & 0x3f; |
110 | 12.7M | u8 const index2 = ((in1 << 2) | (in2 >> 6)) & 0x3f; |
111 | 12.7M | u8 const index3 = in2 & 0x3f; |
112 | | |
113 | 12.7M | output.unchecked_append(alphabet[index0]); |
114 | 12.7M | output.unchecked_append(alphabet[index1]); |
115 | 12.7M | output.unchecked_append(is_16bit ? '=' : alphabet[index2]); |
116 | 12.7M | output.unchecked_append(is_8bit ? '=' : alphabet[index3]); |
117 | 12.7M | } |
118 | | |
119 | 280 | return String::from_utf8_without_validation(output); |
120 | 280 | } |
121 | | |
122 | | ErrorOr<ByteBuffer> decode_base64(StringView input) |
123 | 5.93k | { |
124 | 5.93k | static constexpr auto lookup_table = base64_lookup_table(); |
125 | 5.93k | return decode_base64_impl(input, lookup_table); |
126 | 5.93k | } |
127 | | |
128 | | ErrorOr<ByteBuffer> decode_base64url(StringView input) |
129 | 0 | { |
130 | 0 | static constexpr auto lookup_table = base64url_lookup_table(); |
131 | 0 | return decode_base64_impl(input, lookup_table); |
132 | 0 | } |
133 | | |
134 | | ErrorOr<String> encode_base64(ReadonlyBytes input) |
135 | 280 | { |
136 | 280 | return encode_base64_impl(input, base64_alphabet); |
137 | 280 | } |
138 | | ErrorOr<String> encode_base64url(ReadonlyBytes input) |
139 | 0 | { |
140 | 0 | return encode_base64_impl(input, base64url_alphabet); |
141 | 0 | } |
142 | | |
143 | | } |