Coverage Report

Created: 2025-12-31 07:33

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/kea-fuzzer/mysqlmock.cc
Line
Count
Source
1
// Copyright (C) 2025 Ada Logcis Ltd.
2
//
3
// This Source Code Form is subject to the terms of the Mozilla Public
4
// License, v. 2.0. If a copy of the MPL was not distributed with this
5
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
////////////////////////////////////////////////////////////////////////////////
7
#include <fuzzer/FuzzedDataProvider.h>
8
9
#include <mariadb/mysql.h>
10
11
#include <stdint.h>
12
#include <string.h>
13
#include <stdlib.h>
14
#include <vector>
15
#include <string>
16
#include <algorithm>
17
#include <unordered_map>
18
19
struct StmtState;
20
21
static thread_local FuzzedDataProvider* g_fdp = nullptr;
22
static thread_local std::string g_mysql_error;
23
static thread_local std::string g_stmt_error;
24
static thread_local std::string g_tls_cipher;
25
26
struct MockResRow {
27
    enum ColKind {
28
        CK_UINT64, CK_UINT32, CK_UINT8, CK_STRING, CK_BLOB
29
    } kind;
30
    std::string s;
31
    std::vector<uint8_t> blob;
32
    uint64_t u64 = 0;
33
    uint32_t u32 = 0;
34
    uint8_t u8 = 0;
35
};
36
37
struct StmtState {
38
    MYSQL* mysql;
39
    std::string sql;
40
    std::vector<std::vector<MockResRow>> rows;
41
    size_t fetch_index = 0;
42
    MYSQL_BIND* res_binds = nullptr;
43
    size_t res_binds_count = 0;
44
    bool has_rows = false;
45
    bool is_version_stmt = false;
46
    unsigned int field_count = 0;
47
};
48
49
static thread_local std::unordered_map<MYSQL_STMT*, StmtState*> g_stmt_state;
50
static thread_local std::vector<MYSQL_STMT*> g_all_stmts;
51
static thread_local std::vector<StmtState*> g_live_stmts;
52
53
2.08k
extern "C" void mysqlmock_load_bytes(const uint8_t* data, size_t size) {
54
2.08k
    for (auto* st : g_all_stmts) {
55
0
       free(st);
56
0
    }
57
2.08k
    g_all_stmts.clear();
58
2.08k
    for (auto& kv : g_stmt_state) {
59
0
        delete kv.second;
60
0
    }
61
2.08k
    g_stmt_state.clear();
62
2.08k
    g_stmt_state.rehash(0);
63
2.08k
    delete g_fdp;
64
2.08k
    g_fdp = new FuzzedDataProvider(data, size);
65
2.08k
}
66
67
1.69k
static bool is_like(const std::string& hay, const char* needle) {
68
1.69k
    std::string h = hay;
69
1.69k
    std::string n = needle ? needle : "";
70
1.69k
    std::transform(h.begin(), h.end(), h.begin(), ::tolower);
71
1.69k
    std::transform(n.begin(), n.end(), n.begin(), ::tolower);
72
1.69k
    return h.find(n) != std::string::npos;
73
1.69k
}
74
75
0
static MYSQL_STMT* make_stmt() {
76
0
    auto s = new StmtState();
77
0
    s->mysql = reinterpret_cast<MYSQL*>(0x1);
78
0
    g_live_stmts.push_back(s);
79
0
    return reinterpret_cast<MYSQL_STMT*>(s);
80
0
}
81
82
8.49k
static StmtState* SS(MYSQL_STMT* st) {
83
8.49k
    if (!st) {
84
0
        return nullptr;
85
0
    }
86
8.49k
    auto it = g_stmt_state.find(st);
87
8.49k
    return (it == g_stmt_state.end()) ? nullptr : it->second;
88
8.49k
}
89
90
5.09k
static StmtState* ensure_state(MYSQL_STMT* st) {
91
5.09k
    auto s = SS(st);
92
5.09k
    if (s) return s;
93
1.69k
    s = new StmtState();
94
1.69k
    s->mysql = reinterpret_cast<MYSQL*>(0x1);
95
1.69k
    g_stmt_state[st] = s;
96
1.69k
    g_live_stmts.push_back(s);
97
1.69k
    return s;
98
5.09k
}
99
100
0
static unsigned int infer_field_count_from_sql(const std::string& sql) {
101
0
    std::string s = sql;
102
0
    std::transform(s.begin(), s.end(), s.begin(), ::tolower);
103
0
    auto psel = s.find("select");
104
0
    if (psel == std::string::npos) {
105
0
        return 0;
106
0
    }
107
0
    auto pfrom = s.find(" from ", psel);
108
0
    if (pfrom == std::string::npos) {
109
0
        return 0;
110
0
    }
111
0
    std::string proj = s.substr(psel + 6, pfrom - (psel + 6));
112
113
0
    proj.erase(0, proj.find_first_not_of(" \t\r\n"));
114
0
    proj.erase(proj.find_last_not_of(" \t\r\n")+1);
115
0
    if (proj.rfind("distinct", 0) == 0) {
116
0
        proj.erase(0, 8);
117
0
        proj.erase(0, proj.find_first_not_of(" \t\r\n"));
118
0
    }
119
120
0
    unsigned int cols = 0;
121
0
    int paren = 0;
122
0
    bool in_s = false, in_d = false;
123
0
    for (size_t i = 0; i < proj.size(); ++i) {
124
0
        char c = proj[i];
125
0
        if (!in_d && c=='\'' && (i==0 || proj[i-1] != '\\')) {
126
0
            in_s = !in_s;
127
0
        } else if (!in_s && c=='"' && (i==0 || proj[i-1] != '\\')) {
128
0
            in_d = !in_d;
129
0
        } else if (!in_s && !in_d) {
130
0
            if (c=='(') {
131
0
                paren++;
132
0
            } else if (c==')' && paren>0) {
133
0
                paren--;
134
0
            } else if (c==',' && paren==0) {
135
0
                cols++;
136
0
            }
137
0
        }
138
0
    }
139
0
    if (!proj.empty()) {
140
0
        cols++;
141
0
    }
142
0
    return cols;
143
0
}
144
145
1.69k
static void fill_version_stmt(StmtState* s) {
146
1.69k
    s->is_version_stmt = true;
147
1.69k
    s->has_rows = true;
148
1.69k
    s->rows.clear();
149
1.69k
    s->field_count = 2;
150
1.69k
    std::vector<MockResRow> r;
151
1.69k
    MockResRow c1;
152
1.69k
    c1.kind = MockResRow::CK_UINT32;
153
1.69k
    c1.u32 = 32;
154
1.69k
    MockResRow c2;
155
1.69k
    c2.kind = MockResRow::CK_UINT32;
156
1.69k
    c2.u32 = 0;
157
1.69k
    r.push_back(c1);
158
1.69k
    r.push_back(c2);
159
1.69k
    s->rows.push_back(r);
160
1.69k
}
161
162
0
static void fill_no_rows(StmtState* s, unsigned int cols = 0) {
163
0
    s->is_version_stmt = false;
164
0
    s->has_rows = false;
165
0
    s->rows.clear();
166
0
    s->fetch_index = 0;
167
0
    s->field_count = cols;
168
0
}
169
170
0
static void fill_fuzz_rows(StmtState* s, unsigned int ncols) {
171
0
    s->is_version_stmt = false;
172
0
    s->has_rows = true;
173
0
    s->rows.clear();
174
0
    s->fetch_index = 0;
175
0
    s->field_count = ncols ? ncols : 1;;
176
0
    unsigned int nrows = g_fdp->ConsumeIntegralInRange<unsigned int>(1, 3);
177
0
    for (unsigned int r = 0; r < nrows; ++r) {
178
0
        std::vector<MockResRow> row;
179
0
        row.reserve(ncols);
180
0
        for (unsigned int c = 0; c < ncols; ++c) {
181
0
            int k = g_fdp->ConsumeIntegralInRange<int>(0, 4);
182
0
            MockResRow cell;
183
0
            switch (k) {
184
0
                case 0:
185
0
                    cell.kind = MockResRow::CK_UINT64;
186
0
                    cell.u64 = g_fdp->ConsumeIntegral<uint64_t>();
187
0
                    break;
188
0
                case 1:
189
0
                    cell.kind = MockResRow::CK_UINT32;
190
0
                    cell.u32 = g_fdp->ConsumeIntegral<uint32_t>();
191
0
                    break;
192
0
                case 2:
193
0
                    cell.kind = MockResRow::CK_UINT8;
194
0
                    cell.u8 = g_fdp->ConsumeIntegral<uint8_t>();
195
0
                    break;
196
0
                case 3:
197
0
                    cell.kind = MockResRow::CK_STRING;
198
0
                    cell.s = g_fdp->ConsumeRandomLengthString(32);
199
0
                    break;
200
0
                case 4: {
201
0
                    cell.kind = MockResRow::CK_BLOB;
202
0
                    size_t n = g_fdp->ConsumeIntegralInRange<size_t>(0, 32);
203
0
                    cell.blob = g_fdp->ConsumeBytes<uint8_t>(n);
204
0
                    break;
205
0
                }
206
0
            }
207
0
            row.push_back(cell);
208
0
        }
209
0
        s->rows.push_back(std::move(row));
210
0
    }
211
0
}
212
213
extern "C" {
214
4
    int mysql_server_init(int argc, char **argv, char **groups) {
215
4
        return 0;
216
4
    }
217
218
4.54k
    MYSQL* mysql_init(MYSQL* in) {
219
4.54k
        return in ? in : reinterpret_cast<MYSQL*>(0x1);
220
4.54k
    }
221
222
4.54k
    void mysql_close(MYSQL*) {}
223
224
0
    void mysql_free_result(MYSQL_RES*) {}
225
226
0
    void mysql_server_end(void) {}
227
228
    MYSQL* mysql_real_connect(MYSQL* mysql,
229
                              const char*,
230
                              const char*,
231
                              const char*,
232
                              const char*,
233
                              unsigned int,
234
                              const char*,
235
1.69k
                              unsigned long) {
236
1.69k
        return mysql ? mysql : reinterpret_cast<MYSQL*>(0x1);
237
1.69k
    }
238
239
0
    unsigned int mysql_errno(MYSQL*) {
240
0
        return 0;
241
0
    }
242
243
0
    const char* mysql_error(MYSQL*) {
244
0
        g_mysql_error = g_fdp ? g_fdp->ConsumeRandomLengthString(32) : std::string();
245
0
        return g_mysql_error.c_str();
246
0
    }
247
248
1.69k
    MYSQL_STMT* mysql_stmt_init(MYSQL*) {
249
1.69k
        auto* stmt = static_cast<MYSQL_STMT*>(std::calloc(1, sizeof(MYSQL_STMT)));
250
1.69k
        if (!stmt) {
251
0
            return nullptr;
252
0
        }
253
1.69k
        stmt->mysql = reinterpret_cast<MYSQL*>(0x1);
254
1.69k
        g_all_stmts.push_back(stmt);
255
1.69k
        auto* s = ensure_state(stmt);
256
1.69k
        s->mysql = stmt->mysql;
257
1.69k
        return stmt;
258
1.69k
    }
259
260
1.69k
    int mysql_stmt_prepare(MYSQL_STMT* stmt, const char* q, unsigned long len) {
261
1.69k
        if (!stmt) {
262
0
            return 1;
263
0
        }
264
1.69k
        stmt->mysql = reinterpret_cast<MYSQL*>(0x1);
265
1.69k
        auto* s = ensure_state(stmt);
266
267
1.69k
        const uintptr_t pq = reinterpret_cast<uintptr_t>(q);
268
1.69k
        const bool ptr_ok  = (q != nullptr) && (pq > 0x10000);
269
1.69k
        const bool len_ok  = (len > 0);
270
1.69k
        if (ptr_ok && len_ok) {
271
1.69k
            size_t copy_len = std::min<size_t>(len, 4096);
272
1.69k
            s->sql.assign(q, copy_len);
273
1.69k
        } else {
274
0
            s->sql = "SELECT version, minor FROM schema_version";
275
0
        }
276
277
1.69k
        if (is_like(s->sql, "schema_version") || is_like(s->sql, "select version") || is_like(s->sql, "get_version")) {
278
1.69k
            fill_version_stmt(s);
279
1.69k
            return 0;
280
1.69k
        }
281
282
0
        unsigned int cols = infer_field_count_from_sql(s->sql);
283
0
        if (cols == 0) {
284
0
            cols = g_fdp->ConsumeIntegralInRange<unsigned int>(1, 8);
285
0
        }
286
0
        s->field_count = cols;
287
288
0
        if (g_fdp->ConsumeBool()) {
289
0
            fill_fuzz_rows(s, cols);
290
0
        } else {
291
0
            fill_no_rows(s);
292
0
        }
293
0
        return 0;
294
1.69k
    }
295
296
1.69k
    my_bool mysql_stmt_close(MYSQL_STMT* stmt) {
297
1.69k
        if (stmt) {
298
1.69k
            auto it = g_stmt_state.find(stmt);
299
1.69k
            if (it != g_stmt_state.end()) {
300
1.69k
                delete it->second;
301
1.69k
                g_stmt_state.erase(it);
302
1.69k
                if (g_stmt_state.empty()) {
303
1.69k
                    g_stmt_state.rehash(0);
304
1.69k
                }
305
1.69k
            }
306
1.69k
            auto it2 = std::find(g_all_stmts.begin(), g_all_stmts.end(), stmt);
307
1.69k
            if (it2 != g_all_stmts.end()) {
308
1.69k
                g_all_stmts.erase(it2);
309
1.69k
            }
310
1.69k
            free(stmt);
311
1.69k
        }
312
1.69k
        return 0;
313
1.69k
    }
314
315
1.69k
    my_bool mysql_stmt_bind_result(MYSQL_STMT* stmt, MYSQL_BIND* bnd) {
316
1.69k
        auto* s = ensure_state(stmt);
317
1.69k
        s->res_binds = bnd;
318
1.69k
        s->res_binds_count = bnd ? static_cast<size_t>(mysql_stmt_field_count(stmt)) : 0;
319
1.69k
        return 0;
320
1.69k
    }
321
322
1.69k
    int mysql_stmt_execute(MYSQL_STMT*) {
323
1.69k
        return 0;
324
1.69k
    }
325
326
0
    int mysql_stmt_store_result(MYSQL_STMT*) {
327
0
        return 0;
328
0
    }
329
330
0
    my_bool mysql_stmt_free_result(MYSQL_STMT* stmt) {
331
0
        auto s = SS(stmt);
332
0
        if (s) {
333
0
            s->fetch_index = 0;
334
0
        }
335
0
        if (stmt) {
336
0
            stmt->mysql = reinterpret_cast<MYSQL*>(0x1);
337
0
        }
338
0
        return 0;
339
0
    }
340
341
0
    my_ulonglong mysql_stmt_affected_rows(MYSQL_STMT*) {
342
0
        return 0ULL;
343
0
    }
344
345
0
    my_bool mysql_stmt_reset(MYSQL_STMT* stmt) {
346
0
        auto s = SS(stmt);
347
0
        if (s) {
348
0
            s->fetch_index = 0;
349
0
        }
350
0
        if (stmt) {
351
0
            stmt->mysql = reinterpret_cast<MYSQL*>(0x1);
352
0
        }
353
0
        return 0;
354
0
    }
355
356
0
    unsigned int mysql_stmt_errno(MYSQL_STMT*) {
357
0
        return 0;
358
0
    }
359
360
0
    const char* mysql_stmt_error(MYSQL_STMT*) {
361
0
        g_stmt_error = g_fdp ? g_fdp->ConsumeRandomLengthString(32) : std::string();
362
0
        return g_stmt_error.c_str();
363
0
    }
364
365
5.56k
    int mysql_options(MYSQL*, enum mysql_option, const void*) {
366
5.56k
        return 0;
367
5.56k
    }
368
369
1.69k
    my_bool mysql_autocommit(MYSQL*, my_bool) {
370
1.69k
        return 0;
371
1.69k
    }
372
373
0
    my_bool mysql_commit(MYSQL*) {
374
0
        return 0;
375
0
    }
376
377
0
    my_bool mysql_rollback(MYSQL*) {
378
0
        return 0;
379
0
    }
380
381
0
    int mysql_query(MYSQL*, const char*) {
382
0
        return 0;
383
0
    }
384
385
0
    my_bool mysql_stmt_bind_param(MYSQL_STMT*, MYSQL_BIND*) {
386
0
        return 0;
387
0
    }
388
389
1.69k
    unsigned int mysql_stmt_field_count(MYSQL_STMT* stmt) {
390
1.69k
        auto s = SS(stmt);
391
1.69k
        return s ? s->field_count : 0u;
392
1.69k
    }
393
394
0
    MYSQL_RES* mysql_stmt_result_metadata(MYSQL_STMT* stmt) {
395
0
        auto s = SS(stmt);
396
0
        if (!s) {
397
0
            return reinterpret_cast<MYSQL_RES*>(0x1);
398
0
        }
399
0
        if (s->is_version_stmt) {
400
0
            return reinterpret_cast<MYSQL_RES*>(0x1);
401
0
        }
402
0
        return nullptr;
403
0
    }
404
405
0
    my_ulonglong mysql_insert_id(MYSQL*) {
406
0
        return g_fdp ? g_fdp->ConsumeIntegral<my_ulonglong>() : 0ULL;
407
0
    }
408
409
0
    const char* mysql_get_ssl_cipher(MYSQL*) {
410
0
        if (g_fdp && g_fdp->ConsumeBool()) {
411
0
            g_tls_cipher = g_fdp->ConsumeRandomLengthString(64);
412
0
            return g_tls_cipher.c_str();
413
0
        }
414
0
        return "TLS_FAKE_CIPHER_WITH_FAKE_SHA256";
415
0
    }
416
417
1.69k
    int mysql_stmt_fetch(MYSQL_STMT* stmt) {
418
1.69k
        auto s = SS(stmt);
419
420
1.69k
        if (!s || !s->has_rows){
421
0
             return MYSQL_NO_DATA;
422
0
        }
423
1.69k
        if (s->fetch_index >= s->rows.size()) {
424
0
            return MYSQL_NO_DATA;
425
0
        }
426
1.69k
        if (!s->res_binds) {
427
0
            return MYSQL_NO_DATA;
428
0
        }
429
430
1.69k
        const auto& row = s->rows[s->fetch_index++];
431
1.69k
        size_t cols = row.size();
432
1.69k
        if (s->field_count && cols > s->field_count) {
433
0
            cols = s->field_count;
434
0
        }
435
1.69k
        if (cols > s->res_binds_count) {
436
0
            cols = s->res_binds_count;
437
0
        }
438
439
5.09k
        for (size_t i = 0; i < cols; ++i) {
440
3.39k
            const auto& cell = row[i];
441
3.39k
            MYSQL_BIND& b = s->res_binds[i];
442
3.39k
            if (!b.buffer) {
443
0
                continue;
444
0
            }
445
3.39k
            switch (cell.kind) {
446
3.39k
                case MockResRow::CK_UINT32: {
447
3.39k
                    uint32_t v = cell.u32;
448
3.39k
                    if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) {
449
3.39k
                        memcpy(b.buffer, &v, sizeof(v));
450
3.39k
                        if (b.length) *b.length = sizeof(v);
451
3.39k
                    }
452
3.39k
                    break;
453
0
                }
454
0
                case MockResRow::CK_UINT64: {
455
0
                    uint64_t v = cell.u64;
456
0
                    if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) {
457
0
                        memcpy(b.buffer, &v, sizeof(v));
458
0
                        if (b.length) *b.length = sizeof(v);
459
0
                    }
460
0
                    break;
461
0
                }
462
0
                case MockResRow::CK_UINT8: {
463
0
                    uint8_t v = cell.u8;
464
0
                    if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) {
465
0
                        memcpy(b.buffer, &v, sizeof(v));
466
0
                        if (b.length) *b.length = sizeof(v);
467
0
                    }
468
0
                    break;
469
0
                }
470
0
                case MockResRow::CK_STRING: {
471
0
                    if (b.buffer_length > 0) {
472
0
                        size_t n = std::min<size_t>(b.buffer_length - 1, cell.s.size());
473
0
                        memcpy(b.buffer, cell.s.data(), n);
474
0
                        reinterpret_cast<char*>(b.buffer)[n] = '\0';
475
0
                        if (b.length) *b.length = n;
476
0
                    }
477
0
                    break;
478
0
                }
479
0
                case MockResRow::CK_BLOB: {
480
0
                    if (b.buffer_length > 0) {
481
0
                        size_t n = std::min<size_t>(b.buffer_length, cell.blob.size());
482
0
                        memcpy(b.buffer, cell.blob.data(), n);
483
0
                        if (b.length) *b.length = n;
484
0
                    }
485
0
                    break;
486
0
                }
487
3.39k
            }
488
3.39k
        }
489
1.69k
        return 0;
490
1.69k
    }
491
}