Coverage Report

Created: 2025-12-08 07:54

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/kea-fuzzer/mysqlmock.cc
Line
Count
Source
1
// Copyright (C) 2025 Ada Logcis Ltd.
2
//
3
// This Source Code Form is subject to the terms of the Mozilla Public
4
// License, v. 2.0. If a copy of the MPL was not distributed with this
5
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
////////////////////////////////////////////////////////////////////////////////
7
#include <fuzzer/FuzzedDataProvider.h>
8
9
#include <mariadb/mysql.h>
10
11
#include <stdint.h>
12
#include <string.h>
13
#include <stdlib.h>
14
#include <vector>
15
#include <string>
16
#include <algorithm>
17
#include <unordered_map>
18
19
struct StmtState;
20
21
static thread_local FuzzedDataProvider* g_fdp = nullptr;
22
static thread_local std::string g_mysql_error;
23
static thread_local std::string g_stmt_error;
24
static thread_local std::string g_tls_cipher;
25
26
struct MockResRow {
27
    enum ColKind {
28
        CK_UINT64, CK_UINT32, CK_UINT8, CK_STRING, CK_BLOB
29
    } kind;
30
    std::string s;
31
    std::vector<uint8_t> blob;
32
    uint64_t u64 = 0;
33
    uint32_t u32 = 0;
34
    uint8_t u8 = 0;
35
};
36
37
struct StmtState {
38
    MYSQL* mysql;
39
    std::string sql;
40
    std::vector<std::vector<MockResRow>> rows;
41
    size_t fetch_index = 0;
42
    MYSQL_BIND* res_binds = nullptr;
43
    size_t res_binds_count = 0;
44
    bool has_rows = false;
45
    bool is_version_stmt = false;
46
    unsigned int field_count = 0;
47
};
48
49
static thread_local std::unordered_map<MYSQL_STMT*, StmtState*> g_stmt_state;
50
static thread_local std::vector<MYSQL_STMT*> g_all_stmts;
51
static thread_local std::vector<StmtState*> g_live_stmts;
52
53
6.99k
extern "C" void mysqlmock_load_bytes(const uint8_t* data, size_t size) {
54
6.99k
    for (auto* st : g_all_stmts) {
55
0
       free(st);
56
0
    }
57
6.99k
    g_all_stmts.clear();
58
6.99k
    for (auto& kv : g_stmt_state) {
59
0
        delete kv.second;
60
0
    }
61
6.99k
    g_stmt_state.clear();
62
6.99k
    g_stmt_state.rehash(0);
63
6.99k
    delete g_fdp;
64
6.99k
    g_fdp = new FuzzedDataProvider(data, size);
65
6.99k
}
66
67
2.25M
static bool is_like(const std::string& hay, const char* needle) {
68
2.25M
    std::string h = hay;
69
2.25M
    std::string n = needle ? needle : "";
70
2.25M
    std::transform(h.begin(), h.end(), h.begin(), ::tolower);
71
2.25M
    std::transform(n.begin(), n.end(), n.begin(), ::tolower);
72
2.25M
    return h.find(n) != std::string::npos;
73
2.25M
}
74
75
0
static MYSQL_STMT* make_stmt() {
76
0
    auto s = new StmtState();
77
0
    s->mysql = reinterpret_cast<MYSQL*>(0x1);
78
0
    g_live_stmts.push_back(s);
79
0
    return reinterpret_cast<MYSQL_STMT*>(s);
80
0
}
81
82
1.61M
static StmtState* SS(MYSQL_STMT* st) {
83
1.61M
    if (!st) {
84
0
        return nullptr;
85
0
    }
86
1.61M
    auto it = g_stmt_state.find(st);
87
1.61M
    return (it == g_stmt_state.end()) ? nullptr : it->second;
88
1.61M
}
89
90
1.54M
static StmtState* ensure_state(MYSQL_STMT* st) {
91
1.54M
    auto s = SS(st);
92
1.54M
    if (s) return s;
93
761k
    s = new StmtState();
94
761k
    s->mysql = reinterpret_cast<MYSQL*>(0x1);
95
761k
    g_stmt_state[st] = s;
96
761k
    g_live_stmts.push_back(s);
97
761k
    return s;
98
1.54M
}
99
100
747k
static unsigned int infer_field_count_from_sql(const std::string& sql) {
101
747k
    std::string s = sql;
102
747k
    std::transform(s.begin(), s.end(), s.begin(), ::tolower);
103
747k
    auto psel = s.find("select");
104
747k
    if (psel == std::string::npos) {
105
361k
        return 0;
106
361k
    }
107
385k
    auto pfrom = s.find(" from ", psel);
108
385k
    if (pfrom == std::string::npos) {
109
0
        return 0;
110
0
    }
111
385k
    std::string proj = s.substr(psel + 6, pfrom - (psel + 6));
112
113
385k
    proj.erase(0, proj.find_first_not_of(" \t\r\n"));
114
385k
    proj.erase(proj.find_last_not_of(" \t\r\n")+1);
115
385k
    if (proj.rfind("distinct", 0) == 0) {
116
0
        proj.erase(0, 8);
117
0
        proj.erase(0, proj.find_first_not_of(" \t\r\n"));
118
0
    }
119
120
385k
    unsigned int cols = 0;
121
385k
    int paren = 0;
122
385k
    bool in_s = false, in_d = false;
123
215M
    for (size_t i = 0; i < proj.size(); ++i) {
124
214M
        char c = proj[i];
125
214M
        if (!in_d && c=='\'' && (i==0 || proj[i-1] != '\\')) {
126
0
            in_s = !in_s;
127
214M
        } else if (!in_s && c=='"' && (i==0 || proj[i-1] != '\\')) {
128
0
            in_d = !in_d;
129
214M
        } else if (!in_s && !in_d) {
130
214M
            if (c=='(') {
131
0
                paren++;
132
214M
            } else if (c==')' && paren>0) {
133
0
                paren--;
134
214M
            } else if (c==',' && paren==0) {
135
11.9M
                cols++;
136
11.9M
            }
137
214M
        }
138
214M
    }
139
385k
    if (!proj.empty()) {
140
385k
        cols++;
141
385k
    }
142
385k
    return cols;
143
385k
}
144
145
13.3k
static void fill_version_stmt(StmtState* s) {
146
13.3k
    s->is_version_stmt = true;
147
13.3k
    s->has_rows = true;
148
13.3k
    s->rows.clear();
149
13.3k
    s->field_count = 2;
150
13.3k
    std::vector<MockResRow> r;
151
13.3k
    MockResRow c1;
152
13.3k
    c1.kind = MockResRow::CK_UINT32;
153
13.3k
    c1.u32 = 32;
154
13.3k
    MockResRow c2;
155
13.3k
    c2.kind = MockResRow::CK_UINT32;
156
13.3k
    c2.u32 = 0;
157
13.3k
    r.push_back(c1);
158
13.3k
    r.push_back(c2);
159
13.3k
    s->rows.push_back(r);
160
13.3k
}
161
162
709k
static void fill_no_rows(StmtState* s, unsigned int cols = 0) {
163
709k
    s->is_version_stmt = false;
164
709k
    s->has_rows = false;
165
709k
    s->rows.clear();
166
709k
    s->fetch_index = 0;
167
709k
    s->field_count = cols;
168
709k
}
169
170
38.1k
static void fill_fuzz_rows(StmtState* s, unsigned int ncols) {
171
38.1k
    s->is_version_stmt = false;
172
38.1k
    s->has_rows = true;
173
38.1k
    s->rows.clear();
174
38.1k
    s->fetch_index = 0;
175
38.1k
    s->field_count = ncols ? ncols : 1;;
176
38.1k
    unsigned int nrows = g_fdp->ConsumeIntegralInRange<unsigned int>(1, 3);
177
115k
    for (unsigned int r = 0; r < nrows; ++r) {
178
77.4k
        std::vector<MockResRow> row;
179
77.4k
        row.reserve(ncols);
180
1.66M
        for (unsigned int c = 0; c < ncols; ++c) {
181
1.58M
            int k = g_fdp->ConsumeIntegralInRange<int>(0, 4);
182
1.58M
            MockResRow cell;
183
1.58M
            switch (k) {
184
537k
                case 0:
185
537k
                    cell.kind = MockResRow::CK_UINT64;
186
537k
                    cell.u64 = g_fdp->ConsumeIntegral<uint64_t>();
187
537k
                    break;
188
90.0k
                case 1:
189
90.0k
                    cell.kind = MockResRow::CK_UINT32;
190
90.0k
                    cell.u32 = g_fdp->ConsumeIntegral<uint32_t>();
191
90.0k
                    break;
192
560k
                case 2:
193
560k
                    cell.kind = MockResRow::CK_UINT8;
194
560k
                    cell.u8 = g_fdp->ConsumeIntegral<uint8_t>();
195
560k
                    break;
196
261k
                case 3:
197
261k
                    cell.kind = MockResRow::CK_STRING;
198
261k
                    cell.s = g_fdp->ConsumeRandomLengthString(32);
199
261k
                    break;
200
139k
                case 4: {
201
139k
                    cell.kind = MockResRow::CK_BLOB;
202
139k
                    size_t n = g_fdp->ConsumeIntegralInRange<size_t>(0, 32);
203
139k
                    cell.blob = g_fdp->ConsumeBytes<uint8_t>(n);
204
139k
                    break;
205
0
                }
206
1.58M
            }
207
1.58M
            row.push_back(cell);
208
1.58M
        }
209
77.4k
        s->rows.push_back(std::move(row));
210
77.4k
    }
211
38.1k
}
212
213
extern "C" {
214
4
    int mysql_server_init(int argc, char **argv, char **groups) {
215
4
        return 0;
216
4
    }
217
218
27.6k
    MYSQL* mysql_init(MYSQL* in) {
219
27.6k
        return in ? in : reinterpret_cast<MYSQL*>(0x1);
220
27.6k
    }
221
222
27.6k
    void mysql_close(MYSQL*) {}
223
224
0
    void mysql_free_result(MYSQL_RES*) {}
225
226
0
    void mysql_server_end(void) {}
227
228
    MYSQL* mysql_real_connect(MYSQL* mysql,
229
                              const char*,
230
                              const char*,
231
                              const char*,
232
                              const char*,
233
                              unsigned int,
234
                              const char*,
235
26.7k
                              unsigned long) {
236
26.7k
        return mysql ? mysql : reinterpret_cast<MYSQL*>(0x1);
237
26.7k
    }
238
239
0
    unsigned int mysql_errno(MYSQL*) {
240
0
        return 0;
241
0
    }
242
243
0
    const char* mysql_error(MYSQL*) {
244
0
        g_mysql_error = g_fdp ? g_fdp->ConsumeRandomLengthString(32) : std::string();
245
0
        return g_mysql_error.c_str();
246
0
    }
247
248
761k
    MYSQL_STMT* mysql_stmt_init(MYSQL*) {
249
761k
        auto* stmt = static_cast<MYSQL_STMT*>(std::calloc(1, sizeof(MYSQL_STMT)));
250
761k
        if (!stmt) {
251
0
            return nullptr;
252
0
        }
253
761k
        stmt->mysql = reinterpret_cast<MYSQL*>(0x1);
254
761k
        g_all_stmts.push_back(stmt);
255
761k
        auto* s = ensure_state(stmt);
256
761k
        s->mysql = stmt->mysql;
257
761k
        return stmt;
258
761k
    }
259
260
761k
    int mysql_stmt_prepare(MYSQL_STMT* stmt, const char* q, unsigned long len) {
261
761k
        if (!stmt) {
262
0
            return 1;
263
0
        }
264
761k
        stmt->mysql = reinterpret_cast<MYSQL*>(0x1);
265
761k
        auto* s = ensure_state(stmt);
266
267
761k
        const uintptr_t pq = reinterpret_cast<uintptr_t>(q);
268
761k
        const bool ptr_ok  = (q != nullptr) && (pq > 0x10000);
269
761k
        const bool len_ok  = (len > 0);
270
761k
        if (ptr_ok && len_ok) {
271
761k
            size_t copy_len = std::min<size_t>(len, 4096);
272
761k
            s->sql.assign(q, copy_len);
273
761k
        } else {
274
0
            s->sql = "SELECT version, minor FROM schema_version";
275
0
        }
276
277
761k
        if (is_like(s->sql, "schema_version") || is_like(s->sql, "select version") || is_like(s->sql, "get_version")) {
278
13.3k
            fill_version_stmt(s);
279
13.3k
            return 0;
280
13.3k
        }
281
282
747k
        unsigned int cols = infer_field_count_from_sql(s->sql);
283
747k
        if (cols == 0) {
284
361k
            cols = g_fdp->ConsumeIntegralInRange<unsigned int>(1, 8);
285
361k
        }
286
747k
        s->field_count = cols;
287
288
747k
        if (g_fdp->ConsumeBool()) {
289
38.1k
            fill_fuzz_rows(s, cols);
290
709k
        } else {
291
709k
            fill_no_rows(s);
292
709k
        }
293
747k
        return 0;
294
761k
    }
295
296
761k
    my_bool mysql_stmt_close(MYSQL_STMT* stmt) {
297
761k
        if (stmt) {
298
761k
            auto it = g_stmt_state.find(stmt);
299
761k
            if (it != g_stmt_state.end()) {
300
761k
                delete it->second;
301
761k
                g_stmt_state.erase(it);
302
761k
                if (g_stmt_state.empty()) {
303
20.0k
                    g_stmt_state.rehash(0);
304
20.0k
                }
305
761k
            }
306
761k
            auto it2 = std::find(g_all_stmts.begin(), g_all_stmts.end(), stmt);
307
761k
            if (it2 != g_all_stmts.end()) {
308
761k
                g_all_stmts.erase(it2);
309
761k
            }
310
761k
            free(stmt);
311
761k
        }
312
761k
        return 0;
313
761k
    }
314
315
25.8k
    my_bool mysql_stmt_bind_result(MYSQL_STMT* stmt, MYSQL_BIND* bnd) {
316
25.8k
        auto* s = ensure_state(stmt);
317
25.8k
        s->res_binds = bnd;
318
25.8k
        s->res_binds_count = bnd ? static_cast<size_t>(mysql_stmt_field_count(stmt)) : 0;
319
25.8k
        return 0;
320
25.8k
    }
321
322
79.5k
    int mysql_stmt_execute(MYSQL_STMT*) {
323
79.5k
        return 0;
324
79.5k
    }
325
326
12.4k
    int mysql_stmt_store_result(MYSQL_STMT*) {
327
12.4k
        return 0;
328
12.4k
    }
329
330
12.4k
    my_bool mysql_stmt_free_result(MYSQL_STMT* stmt) {
331
12.4k
        auto s = SS(stmt);
332
12.4k
        if (s) {
333
12.4k
            s->fetch_index = 0;
334
12.4k
        }
335
12.4k
        if (stmt) {
336
12.4k
            stmt->mysql = reinterpret_cast<MYSQL*>(0x1);
337
12.4k
        }
338
12.4k
        return 0;
339
12.4k
    }
340
341
13.2k
    my_ulonglong mysql_stmt_affected_rows(MYSQL_STMT*) {
342
13.2k
        return 0ULL;
343
13.2k
    }
344
345
0
    my_bool mysql_stmt_reset(MYSQL_STMT* stmt) {
346
0
        auto s = SS(stmt);
347
0
        if (s) {
348
0
            s->fetch_index = 0;
349
0
        }
350
0
        if (stmt) {
351
0
            stmt->mysql = reinterpret_cast<MYSQL*>(0x1);
352
0
        }
353
0
        return 0;
354
0
    }
355
356
0
    unsigned int mysql_stmt_errno(MYSQL_STMT*) {
357
0
        return 0;
358
0
    }
359
360
0
    const char* mysql_stmt_error(MYSQL_STMT*) {
361
0
        g_stmt_error = g_fdp ? g_fdp->ConsumeRandomLengthString(32) : std::string();
362
0
        return g_stmt_error.c_str();
363
0
    }
364
365
86.4k
    int mysql_options(MYSQL*, enum mysql_option, const void*) {
366
86.4k
        return 0;
367
86.4k
    }
368
369
26.7k
    my_bool mysql_autocommit(MYSQL*, my_bool) {
370
26.7k
        return 0;
371
26.7k
    }
372
373
17.9k
    my_bool mysql_commit(MYSQL*) {
374
17.9k
        return 0;
375
17.9k
    }
376
377
78
    my_bool mysql_rollback(MYSQL*) {
378
78
        return 0;
379
78
    }
380
381
18.0k
    int mysql_query(MYSQL*, const char*) {
382
18.0k
        return 0;
383
18.0k
    }
384
385
59.8k
    my_bool mysql_stmt_bind_param(MYSQL_STMT*, MYSQL_BIND*) {
386
59.8k
        return 0;
387
59.8k
    }
388
389
25.8k
    unsigned int mysql_stmt_field_count(MYSQL_STMT* stmt) {
390
25.8k
        auto s = SS(stmt);
391
25.8k
        return s ? s->field_count : 0u;
392
25.8k
    }
393
394
0
    MYSQL_RES* mysql_stmt_result_metadata(MYSQL_STMT* stmt) {
395
0
        auto s = SS(stmt);
396
0
        if (!s) {
397
0
            return reinterpret_cast<MYSQL_RES*>(0x1);
398
0
        }
399
0
        if (s->is_version_stmt) {
400
0
            return reinterpret_cast<MYSQL_RES*>(0x1);
401
0
        }
402
0
        return nullptr;
403
0
    }
404
405
6.61k
    my_ulonglong mysql_insert_id(MYSQL*) {
406
6.61k
        return g_fdp ? g_fdp->ConsumeIntegral<my_ulonglong>() : 0ULL;
407
6.61k
    }
408
409
570
    const char* mysql_get_ssl_cipher(MYSQL*) {
410
570
        if (g_fdp && g_fdp->ConsumeBool()) {
411
444
            g_tls_cipher = g_fdp->ConsumeRandomLengthString(64);
412
444
            return g_tls_cipher.c_str();
413
444
        }
414
126
        return "TLS_FAKE_CIPHER_WITH_FAKE_SHA256";
415
570
    }
416
417
27.2k
    int mysql_stmt_fetch(MYSQL_STMT* stmt) {
418
27.2k
        auto s = SS(stmt);
419
420
27.2k
        if (!s || !s->has_rows){
421
8.61k
             return MYSQL_NO_DATA;
422
8.61k
        }
423
18.6k
        if (s->fetch_index >= s->rows.size()) {
424
280
            return MYSQL_NO_DATA;
425
280
        }
426
18.3k
        if (!s->res_binds) {
427
0
            return MYSQL_NO_DATA;
428
0
        }
429
430
18.3k
        const auto& row = s->rows[s->fetch_index++];
431
18.3k
        size_t cols = row.size();
432
18.3k
        if (s->field_count && cols > s->field_count) {
433
0
            cols = s->field_count;
434
0
        }
435
18.3k
        if (cols > s->res_binds_count) {
436
0
            cols = s->res_binds_count;
437
0
        }
438
439
171k
        for (size_t i = 0; i < cols; ++i) {
440
152k
            const auto& cell = row[i];
441
152k
            MYSQL_BIND& b = s->res_binds[i];
442
152k
            if (!b.buffer) {
443
0
                continue;
444
0
            }
445
152k
            switch (cell.kind) {
446
30.6k
                case MockResRow::CK_UINT32: {
447
30.6k
                    uint32_t v = cell.u32;
448
30.6k
                    if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) {
449
29.8k
                        memcpy(b.buffer, &v, sizeof(v));
450
29.8k
                        if (b.length) *b.length = sizeof(v);
451
29.8k
                    }
452
30.6k
                    break;
453
0
                }
454
90.9k
                case MockResRow::CK_UINT64: {
455
90.9k
                    uint64_t v = cell.u64;
456
90.9k
                    if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) {
457
52.4k
                        memcpy(b.buffer, &v, sizeof(v));
458
52.4k
                        if (b.length) *b.length = sizeof(v);
459
52.4k
                    }
460
90.9k
                    break;
461
0
                }
462
15.4k
                case MockResRow::CK_UINT8: {
463
15.4k
                    uint8_t v = cell.u8;
464
15.4k
                    if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) {
465
15.4k
                        memcpy(b.buffer, &v, sizeof(v));
466
15.4k
                        if (b.length) *b.length = sizeof(v);
467
15.4k
                    }
468
15.4k
                    break;
469
0
                }
470
7.45k
                case MockResRow::CK_STRING: {
471
7.45k
                    if (b.buffer_length > 0) {
472
7.45k
                        size_t n = std::min<size_t>(b.buffer_length - 1, cell.s.size());
473
7.45k
                        memcpy(b.buffer, cell.s.data(), n);
474
7.45k
                        reinterpret_cast<char*>(b.buffer)[n] = '\0';
475
7.45k
                        if (b.length) *b.length = n;
476
7.45k
                    }
477
7.45k
                    break;
478
0
                }
479
8.40k
                case MockResRow::CK_BLOB: {
480
8.40k
                    if (b.buffer_length > 0) {
481
8.40k
                        size_t n = std::min<size_t>(b.buffer_length, cell.blob.size());
482
8.40k
                        memcpy(b.buffer, cell.blob.data(), n);
483
8.40k
                        if (b.length) *b.length = n;
484
8.40k
                    }
485
8.40k
                    break;
486
0
                }
487
152k
            }
488
152k
        }
489
18.3k
        return 0;
490
18.3k
    }
491
}