/src/boringssl/fuzz/bn_mod_exp.cc
Line | Count | Source |
1 | | // Copyright 2017 The BoringSSL Authors |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // https://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
#include <stdio.h>
#include <stdlib.h>

#include <openssl/bn.h>
#include <openssl/bytestring.h>
#include <openssl/mem.h>
18 | | |
// Aborts the process if |expr| evaluates to false, after reporting the text of
// the failing expression. |expr| is evaluated exactly once. The message is
// written to stderr rather than stdout: abort() does not flush stdio buffers,
// so a message left in a fully-buffered stdout (e.g. redirected to a pipe or
// file by the fuzzing driver) would be lost, while stderr is unbuffered.
#define CHECK(expr)                          \
  do {                                       \
    if (!(expr)) {                           \
      fprintf(stderr, "%s failed\n", #expr); \
      abort();                               \
    }                                        \
  } while (false)
26 | | |
27 | | // Basic implementation of mod_exp using square and multiple method. |
28 | | int mod_exp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p, const BIGNUM *m, |
29 | 1.90k | BN_CTX *ctx) { |
30 | 1.90k | if (BN_is_one(m)) { |
31 | 25 | BN_zero(r); |
32 | 25 | return 1; |
33 | 25 | } |
34 | | |
35 | 1.88k | bssl::UniquePtr<BIGNUM> exp(BN_dup(p)); |
36 | 1.88k | bssl::UniquePtr<BIGNUM> base(BN_new()); |
37 | 1.88k | if (!exp || !base) { |
38 | 0 | return 0; |
39 | 0 | } |
40 | 1.88k | if (!BN_one(r) || !BN_nnmod(base.get(), a, m, ctx)) { |
41 | 0 | return 0; |
42 | 0 | } |
43 | | |
44 | 245k | while (!BN_is_zero(exp.get())) { |
45 | 243k | if (BN_is_odd(exp.get())) { |
46 | 90.1k | if (!BN_mul(r, r, base.get(), ctx) || !BN_nnmod(r, r, m, ctx)) { |
47 | 0 | return 0; |
48 | 0 | } |
49 | 90.1k | } |
50 | 243k | if (!BN_rshift1(exp.get(), exp.get()) || |
51 | 243k | !BN_mul(base.get(), base.get(), base.get(), ctx) || |
52 | 243k | !BN_nnmod(base.get(), base.get(), m, ctx)) { |
53 | 0 | return 0; |
54 | 0 | } |
55 | 243k | } |
56 | | |
57 | 1.88k | return 1; |
58 | 1.88k | } |
59 | | |
60 | 2.00k | extern "C" int LLVMFuzzerTestOneInput(const uint8_t *buf, size_t len) { |
61 | 2.00k | CBS cbs, child0, child1, child2; |
62 | 2.00k | uint8_t sign; |
63 | 2.00k | CBS_init(&cbs, buf, len); |
64 | 2.00k | if (!CBS_get_u16_length_prefixed(&cbs, &child0) || |
65 | 1.98k | !CBS_get_u8(&child0, &sign) || |
66 | 1.98k | CBS_len(&child0) == 0 || |
67 | 1.97k | !CBS_get_u16_length_prefixed(&cbs, &child1) || |
68 | 1.96k | CBS_len(&child1) == 0 || |
69 | 1.95k | !CBS_get_u16_length_prefixed(&cbs, &child2) || |
70 | 1.94k | CBS_len(&child2) == 0) { |
71 | 64 | return 0; |
72 | 64 | } |
73 | | |
74 | | // Don't fuzz inputs larger than 512 bytes (4096 bits). This isn't ideal, but |
75 | | // the naive |mod_exp| above is somewhat slow, so this otherwise causes the |
76 | | // fuzzers to spend a lot of time exploring timeouts. |
77 | 1.93k | if (CBS_len(&child0) > 512 || |
78 | 1.93k | CBS_len(&child1) > 512 || |
79 | 1.92k | CBS_len(&child2) > 512) { |
80 | 29 | return 0; |
81 | 29 | } |
82 | | |
83 | 1.91k | bssl::UniquePtr<BIGNUM> base( |
84 | 1.91k | BN_bin2bn(CBS_data(&child0), CBS_len(&child0), nullptr)); |
85 | 1.91k | BN_set_negative(base.get(), sign % 2); |
86 | 1.91k | bssl::UniquePtr<BIGNUM> power( |
87 | 1.91k | BN_bin2bn(CBS_data(&child1), CBS_len(&child1), nullptr)); |
88 | 1.91k | bssl::UniquePtr<BIGNUM> modulus( |
89 | 1.91k | BN_bin2bn(CBS_data(&child2), CBS_len(&child2), nullptr)); |
90 | | |
91 | 1.91k | if (BN_is_zero(modulus.get())) { |
92 | 2 | return 0; |
93 | 2 | } |
94 | | |
95 | 1.90k | bssl::UniquePtr<BN_CTX> ctx(BN_CTX_new()); |
96 | 1.90k | bssl::UniquePtr<BIGNUM> result(BN_new()); |
97 | 1.90k | bssl::UniquePtr<BIGNUM> expected(BN_new()); |
98 | 1.90k | CHECK(ctx); |
99 | 1.90k | CHECK(result); |
100 | 1.90k | CHECK(expected); |
101 | | |
102 | 1.90k | CHECK(mod_exp(expected.get(), base.get(), power.get(), modulus.get(), |
103 | 1.90k | ctx.get())); |
104 | 1.90k | CHECK(BN_mod_exp(result.get(), base.get(), power.get(), modulus.get(), |
105 | 1.90k | ctx.get())); |
106 | 1.90k | CHECK(BN_cmp(result.get(), expected.get()) == 0); |
107 | | |
108 | 1.90k | if (BN_is_odd(modulus.get())) { |
109 | 1.14k | bssl::UniquePtr<BN_MONT_CTX> mont( |
110 | 1.14k | BN_MONT_CTX_new_for_modulus(modulus.get(), ctx.get())); |
111 | 1.14k | CHECK(mont); |
112 | | // |BN_mod_exp_mont| and |BN_mod_exp_mont_consttime| require reduced inputs. |
113 | 1.14k | CHECK(BN_nnmod(base.get(), base.get(), modulus.get(), ctx.get())); |
114 | 1.14k | CHECK(BN_mod_exp_mont(result.get(), base.get(), power.get(), modulus.get(), |
115 | 1.14k | ctx.get(), mont.get())); |
116 | 1.14k | CHECK(BN_cmp(result.get(), expected.get()) == 0); |
117 | 1.14k | CHECK(BN_mod_exp_mont_consttime(result.get(), base.get(), power.get(), |
118 | 1.14k | modulus.get(), ctx.get(), mont.get())); |
119 | 1.14k | CHECK(BN_cmp(result.get(), expected.get()) == 0); |
120 | 1.14k | } |
121 | | |
122 | 1.90k | uint8_t *data = (uint8_t *)OPENSSL_malloc(BN_num_bytes(result.get())); |
123 | 1.90k | BN_bn2bin(result.get(), data); |
124 | 1.90k | OPENSSL_free(data); |
125 | | |
126 | 1.90k | return 0; |
127 | 1.90k | } |