Coverage Report

Created: 2025-11-16 07:29

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/kea-fuzzer/mysqlmock.cc
Line
Count
Source
1
// Copyright (C) 2025 Ada Logcis Ltd.
2
//
3
// This Source Code Form is subject to the terms of the Mozilla Public
4
// License, v. 2.0. If a copy of the MPL was not distributed with this
5
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
////////////////////////////////////////////////////////////////////////////////
7
#include <fuzzer/FuzzedDataProvider.h>
8
9
#include <mariadb/mysql.h>
10
11
#include <stdint.h>
12
#include <string.h>
13
#include <stdlib.h>
14
#include <vector>
15
#include <string>
16
#include <algorithm>
17
#include <unordered_map>
18
19
struct StmtState;
20
21
static thread_local FuzzedDataProvider* g_fdp = nullptr;
22
static thread_local std::string g_mysql_error;
23
static thread_local std::string g_stmt_error;
24
static thread_local std::string g_tls_cipher;
25
static thread_local std::unordered_map<MYSQL_STMT*, StmtState*> g_stmt_state;
26
static thread_local std::vector<MYSQL_STMT*> g_all_stmts;
27
28
6.76k
extern "C" void mysqlmock_load_bytes(const uint8_t* data, size_t size) {
29
6.76k
    for (auto* st : g_all_stmts) {
30
0
       free(st);
31
0
    }
32
6.76k
    g_all_stmts.clear();
33
6.76k
    for (auto& kv : g_stmt_state) {
34
0
        delete kv.second;
35
0
    }
36
6.76k
    g_stmt_state.clear();
37
6.76k
    g_stmt_state.rehash(0);
38
6.76k
    delete g_fdp;
39
6.76k
    g_fdp = new FuzzedDataProvider(data, size);
40
6.76k
}
41
42
struct MockResRow {
43
    enum ColKind {
44
        CK_UINT64, CK_UINT32, CK_UINT8, CK_STRING, CK_BLOB
45
    } kind;
46
    std::string s;
47
    std::vector<uint8_t> blob;
48
    uint64_t u64 = 0;
49
    uint32_t u32 = 0;
50
    uint8_t u8 = 0;
51
};
52
53
struct StmtState {
54
    MYSQL* mysql;
55
    std::string sql;
56
    std::vector<std::vector<MockResRow>> rows;
57
    size_t fetch_index = 0;
58
    MYSQL_BIND* res_binds = nullptr;
59
    size_t res_binds_count = 0;
60
    bool has_rows = false;
61
    bool is_version_stmt = false;
62
    unsigned int field_count = 0;
63
};
64
65
static thread_local std::vector<StmtState*> g_live_stmts;
66
67
2.17M
static bool is_like(const std::string& hay, const char* needle) {
68
2.17M
    std::string h = hay;
69
2.17M
    std::string n = needle ? needle : "";
70
2.17M
    std::transform(h.begin(), h.end(), h.begin(), ::tolower);
71
2.17M
    std::transform(n.begin(), n.end(), n.begin(), ::tolower);
72
2.17M
    return h.find(n) != std::string::npos;
73
2.17M
}
74
75
0
static MYSQL_STMT* make_stmt() {
76
0
    auto s = new StmtState();
77
0
    s->mysql = reinterpret_cast<MYSQL*>(0x1);
78
0
    g_live_stmts.push_back(s);
79
0
    return reinterpret_cast<MYSQL_STMT*>(s);
80
0
}
81
82
1.55M
static StmtState* SS(MYSQL_STMT* st) {
83
1.55M
    if (!st) {
84
0
        return nullptr;
85
0
    }
86
1.55M
    auto it = g_stmt_state.find(st);
87
1.55M
    return (it == g_stmt_state.end()) ? nullptr : it->second;
88
1.55M
}
89
90
1.49M
static StmtState* ensure_state(MYSQL_STMT* st) {
91
1.49M
    auto s = SS(st);
92
1.49M
    if (s) return s;
93
733k
    s = new StmtState();
94
733k
    s->mysql = reinterpret_cast<MYSQL*>(0x1);
95
733k
    g_stmt_state[st] = s;
96
733k
    g_live_stmts.push_back(s);
97
733k
    return s;
98
1.49M
}
99
100
720k
static unsigned int infer_field_count_from_sql(const std::string& sql) {
101
720k
    std::string s = sql;
102
720k
    std::transform(s.begin(), s.end(), s.begin(), ::tolower);
103
720k
    auto psel = s.find("select");
104
720k
    if (psel == std::string::npos) {
105
348k
        return 0;
106
348k
    }
107
371k
    auto pfrom = s.find(" from ", psel);
108
371k
    if (pfrom == std::string::npos) {
109
0
        return 0;
110
0
    }
111
371k
    std::string proj = s.substr(psel + 6, pfrom - (psel + 6));
112
113
371k
    proj.erase(0, proj.find_first_not_of(" \t\r\n"));
114
371k
    proj.erase(proj.find_last_not_of(" \t\r\n")+1);
115
371k
    if (proj.rfind("distinct", 0) == 0) {
116
0
        proj.erase(0, 8);
117
0
        proj.erase(0, proj.find_first_not_of(" \t\r\n"));
118
0
    }
119
120
371k
    unsigned int cols = 0;
121
371k
    int paren = 0;
122
371k
    bool in_s = false, in_d = false;
123
206M
    for (size_t i = 0; i < proj.size(); ++i) {
124
206M
        char c = proj[i];
125
206M
        if (!in_d && c=='\'' && (i==0 || proj[i-1] != '\\')) {
126
0
            in_s = !in_s;
127
206M
        } else if (!in_s && c=='"' && (i==0 || proj[i-1] != '\\')) {
128
0
            in_d = !in_d;
129
206M
        } else if (!in_s && !in_d) {
130
206M
            if (c=='(') {
131
0
                paren++;
132
206M
            } else if (c==')' && paren>0) {
133
0
                paren--;
134
206M
            } else if (c==',' && paren==0) {
135
11.4M
                cols++;
136
11.4M
            }
137
206M
        }
138
206M
    }
139
371k
    if (!proj.empty()) {
140
371k
        cols++;
141
371k
    }
142
371k
    return cols;
143
371k
}
144
145
12.9k
static void fill_version_stmt(StmtState* s) {
146
12.9k
    s->is_version_stmt = true;
147
12.9k
    s->has_rows = true;
148
12.9k
    s->rows.clear();
149
12.9k
    s->field_count = 2;
150
12.9k
    std::vector<MockResRow> r;
151
12.9k
    MockResRow c1;
152
12.9k
    c1.kind = MockResRow::CK_UINT32;
153
12.9k
    c1.u32 = 32;
154
12.9k
    MockResRow c2;
155
12.9k
    c2.kind = MockResRow::CK_UINT32;
156
12.9k
    c2.u32 = 0;
157
12.9k
    r.push_back(c1);
158
12.9k
    r.push_back(c2);
159
12.9k
    s->rows.push_back(r);
160
12.9k
}
161
162
682k
static void fill_no_rows(StmtState* s, unsigned int cols = 0) {
163
682k
    s->is_version_stmt = false;
164
682k
    s->has_rows = false;
165
682k
    s->rows.clear();
166
682k
    s->fetch_index = 0;
167
682k
    s->field_count = cols;
168
682k
}
169
170
37.5k
static void fill_fuzz_rows(StmtState* s, unsigned int ncols) {
171
37.5k
    s->is_version_stmt = false;
172
37.5k
    s->has_rows = true;
173
37.5k
    s->rows.clear();
174
37.5k
    s->fetch_index = 0;
175
37.5k
    s->field_count = ncols ? ncols : 1;;
176
37.5k
    unsigned int nrows = g_fdp->ConsumeIntegralInRange<unsigned int>(1, 3);
177
112k
    for (unsigned int r = 0; r < nrows; ++r) {
178
74.8k
        std::vector<MockResRow> row;
179
74.8k
        row.reserve(ncols);
180
1.57M
        for (unsigned int c = 0; c < ncols; ++c) {
181
1.50M
            int k = g_fdp->ConsumeIntegralInRange<int>(0, 4);
182
1.50M
            MockResRow cell;
183
1.50M
            switch (k) {
184
473k
                case 0:
185
473k
                    cell.kind = MockResRow::CK_UINT64;
186
473k
                    cell.u64 = g_fdp->ConsumeIntegral<uint64_t>();
187
473k
                    break;
188
90.9k
                case 1:
189
90.9k
                    cell.kind = MockResRow::CK_UINT32;
190
90.9k
                    cell.u32 = g_fdp->ConsumeIntegral<uint32_t>();
191
90.9k
                    break;
192
566k
                case 2:
193
566k
                    cell.kind = MockResRow::CK_UINT8;
194
566k
                    cell.u8 = g_fdp->ConsumeIntegral<uint8_t>();
195
566k
                    break;
196
239k
                case 3:
197
239k
                    cell.kind = MockResRow::CK_STRING;
198
239k
                    cell.s = g_fdp->ConsumeRandomLengthString(32);
199
239k
                    break;
200
132k
                case 4: {
201
132k
                    cell.kind = MockResRow::CK_BLOB;
202
132k
                    size_t n = g_fdp->ConsumeIntegralInRange<size_t>(0, 32);
203
132k
                    cell.blob = g_fdp->ConsumeBytes<uint8_t>(n);
204
132k
                    break;
205
0
                }
206
1.50M
            }
207
1.50M
            row.push_back(cell);
208
1.50M
        }
209
74.8k
        s->rows.push_back(std::move(row));
210
74.8k
    }
211
37.5k
}
212
213
extern "C" {
214
26.7k
    MYSQL* mysql_init(MYSQL* in) {
215
26.7k
        return in ? in : reinterpret_cast<MYSQL*>(0x1);
216
26.7k
    }
217
218
26.7k
    void mysql_close(MYSQL*) {}
219
220
0
    void mysql_free_result(MYSQL_RES*) {}
221
222
0
    void mysql_server_end(void) {}
223
224
    MYSQL* mysql_real_connect(MYSQL* mysql,
225
                              const char*,
226
                              const char*,
227
                              const char*,
228
                              const char*,
229
                              unsigned int,
230
                              const char*,
231
25.8k
                              unsigned long) {
232
25.8k
        return mysql ? mysql : reinterpret_cast<MYSQL*>(0x1);
233
25.8k
    }
234
235
0
    unsigned int mysql_errno(MYSQL*) {
236
0
        return 0;
237
0
    }
238
239
0
    const char* mysql_error(MYSQL*) {
240
0
        g_mysql_error = g_fdp ? g_fdp->ConsumeRandomLengthString(32) : std::string();
241
0
        return g_mysql_error.c_str();
242
0
    }
243
244
733k
    MYSQL_STMT* mysql_stmt_init(MYSQL*) {
245
733k
        auto* stmt = static_cast<MYSQL_STMT*>(std::calloc(1, sizeof(MYSQL_STMT)));
246
733k
        if (!stmt) {
247
0
            return nullptr;
248
0
        }
249
733k
        stmt->mysql = reinterpret_cast<MYSQL*>(0x1);
250
733k
        g_all_stmts.push_back(stmt);
251
733k
        auto* s = ensure_state(stmt);
252
733k
        s->mysql = stmt->mysql;
253
733k
        return stmt;
254
733k
    }
255
256
733k
    int mysql_stmt_prepare(MYSQL_STMT* stmt, const char* q, unsigned long len) {
257
733k
        if (!stmt) {
258
0
            return 1;
259
0
        }
260
733k
        stmt->mysql = reinterpret_cast<MYSQL*>(0x1);
261
733k
        auto* s = ensure_state(stmt);
262
263
733k
        const uintptr_t pq = reinterpret_cast<uintptr_t>(q);
264
733k
        const bool ptr_ok  = (q != nullptr) && (pq > 0x10000);
265
733k
        const bool len_ok  = (len > 0);
266
733k
        if (ptr_ok && len_ok) {
267
733k
            size_t copy_len = std::min<size_t>(len, 4096);
268
733k
            s->sql.assign(q, copy_len);
269
733k
        } else {
270
0
            s->sql = "SELECT version, minor FROM schema_version";
271
0
        }
272
273
733k
        if (is_like(s->sql, "schema_version") || is_like(s->sql, "select version") || is_like(s->sql, "get_version")) {
274
12.9k
            fill_version_stmt(s);
275
12.9k
            return 0;
276
12.9k
        }
277
278
720k
        unsigned int cols = infer_field_count_from_sql(s->sql);
279
720k
        if (cols == 0) {
280
348k
            cols = g_fdp->ConsumeIntegralInRange<unsigned int>(1, 8);
281
348k
        }
282
720k
        s->field_count = cols;
283
284
720k
        if (g_fdp->ConsumeBool()) {
285
37.5k
            fill_fuzz_rows(s, cols);
286
682k
        } else {
287
682k
            fill_no_rows(s);
288
682k
        }
289
720k
        return 0;
290
733k
    }
291
292
733k
    my_bool mysql_stmt_close(MYSQL_STMT* stmt) {
293
733k
        if (stmt) {
294
733k
            auto it = g_stmt_state.find(stmt);
295
733k
            if (it != g_stmt_state.end()) {
296
733k
                delete it->second;
297
733k
                g_stmt_state.erase(it);
298
733k
                if (g_stmt_state.empty()) {
299
19.3k
                    g_stmt_state.rehash(0);
300
19.3k
                }
301
733k
            }
302
733k
            auto it2 = std::find(g_all_stmts.begin(), g_all_stmts.end(), stmt);
303
733k
            if (it2 != g_all_stmts.end()) {
304
733k
                g_all_stmts.erase(it2);
305
733k
            }
306
733k
            free(stmt);
307
733k
        }
308
733k
        return 0;
309
733k
    }
310
311
24.9k
    my_bool mysql_stmt_bind_result(MYSQL_STMT* stmt, MYSQL_BIND* bnd) {
312
24.9k
        auto* s = ensure_state(stmt);
313
24.9k
        s->res_binds = bnd;
314
24.9k
        s->res_binds_count = bnd ? static_cast<size_t>(mysql_stmt_field_count(stmt)) : 0;
315
24.9k
        return 0;
316
24.9k
    }
317
318
76.2k
    int mysql_stmt_execute(MYSQL_STMT*) {
319
76.2k
        return 0;
320
76.2k
    }
321
322
12.0k
    int mysql_stmt_store_result(MYSQL_STMT*) {
323
12.0k
        return 0;
324
12.0k
    }
325
326
12.0k
    my_bool mysql_stmt_free_result(MYSQL_STMT* stmt) {
327
12.0k
        auto s = SS(stmt);
328
12.0k
        if (s) {
329
12.0k
            s->fetch_index = 0;
330
12.0k
        }
331
12.0k
        if (stmt) {
332
12.0k
            stmt->mysql = reinterpret_cast<MYSQL*>(0x1);
333
12.0k
        }
334
12.0k
        return 0;
335
12.0k
    }
336
337
12.7k
    my_ulonglong mysql_stmt_affected_rows(MYSQL_STMT*) {
338
12.7k
        return 0ULL;
339
12.7k
    }
340
341
0
    my_bool mysql_stmt_reset(MYSQL_STMT* stmt) {
342
0
        auto s = SS(stmt);
343
0
        if (s) {
344
0
            s->fetch_index = 0;
345
0
        }
346
0
        if (stmt) {
347
0
            stmt->mysql = reinterpret_cast<MYSQL*>(0x1);
348
0
        }
349
0
        return 0;
350
0
    }
351
352
0
    unsigned int mysql_stmt_errno(MYSQL_STMT*) {
353
0
        return 0;
354
0
    }
355
356
0
    const char* mysql_stmt_error(MYSQL_STMT*) {
357
0
        g_stmt_error = g_fdp ? g_fdp->ConsumeRandomLengthString(32) : std::string();
358
0
        return g_stmt_error.c_str();
359
0
    }
360
361
82.6k
    int mysql_options(MYSQL*, enum mysql_option, const void*) {
362
82.6k
        return 0;
363
82.6k
    }
364
365
25.8k
    my_bool mysql_autocommit(MYSQL*, my_bool) {
366
25.8k
        return 0;
367
25.8k
    }
368
369
17.2k
    my_bool mysql_commit(MYSQL*) {
370
17.2k
        return 0;
371
17.2k
    }
372
373
73
    my_bool mysql_rollback(MYSQL*) {
374
73
        return 0;
375
73
    }
376
377
17.2k
    int mysql_query(MYSQL*, const char*) {
378
17.2k
        return 0;
379
17.2k
    }
380
381
57.2k
    my_bool mysql_stmt_bind_param(MYSQL_STMT*, MYSQL_BIND*) {
382
57.2k
        return 0;
383
57.2k
    }
384
385
24.9k
    unsigned int mysql_stmt_field_count(MYSQL_STMT* stmt) {
386
24.9k
        auto s = SS(stmt);
387
24.9k
        return s ? s->field_count : 0u;
388
24.9k
    }
389
390
0
    MYSQL_RES* mysql_stmt_result_metadata(MYSQL_STMT* stmt) {
391
0
        auto s = SS(stmt);
392
0
        if (!s) {
393
0
            return reinterpret_cast<MYSQL_RES*>(0x1);
394
0
        }
395
0
        if (s->is_version_stmt) {
396
0
            return reinterpret_cast<MYSQL_RES*>(0x1);
397
0
        }
398
0
        return nullptr;
399
0
    }
400
401
6.38k
    my_ulonglong mysql_insert_id(MYSQL*) {
402
6.38k
        return g_fdp ? g_fdp->ConsumeIntegral<my_ulonglong>() : 0ULL;
403
6.38k
    }
404
405
492
    const char* mysql_get_ssl_cipher(MYSQL*) {
406
492
        if (g_fdp && g_fdp->ConsumeBool()) {
407
385
            g_tls_cipher = g_fdp->ConsumeRandomLengthString(64);
408
385
            return g_tls_cipher.c_str();
409
385
        }
410
107
        return "TLS_FAKE_CIPHER_WITH_FAKE_SHA256";
411
492
    }
412
413
26.3k
    int mysql_stmt_fetch(MYSQL_STMT* stmt) {
414
26.3k
        auto s = SS(stmt);
415
416
26.3k
        if (!s || !s->has_rows){
417
8.46k
             return MYSQL_NO_DATA;
418
8.46k
        }
419
17.8k
        if (s->fetch_index >= s->rows.size()) {
420
258
            return MYSQL_NO_DATA;
421
258
        }
422
17.5k
        if (!s->res_binds) {
423
0
            return MYSQL_NO_DATA;
424
0
        }
425
426
17.5k
        const auto& row = s->rows[s->fetch_index++];
427
17.5k
        size_t cols = row.size();
428
17.5k
        if (s->field_count && cols > s->field_count) {
429
0
            cols = s->field_count;
430
0
        }
431
17.5k
        if (cols > s->res_binds_count) {
432
0
            cols = s->res_binds_count;
433
0
        }
434
435
139k
        for (size_t i = 0; i < cols; ++i) {
436
121k
            const auto& cell = row[i];
437
121k
            MYSQL_BIND& b = s->res_binds[i];
438
121k
            if (!b.buffer) {
439
0
                continue;
440
0
            }
441
121k
            switch (cell.kind) {
442
30.0k
                case MockResRow::CK_UINT32: {
443
30.0k
                    uint32_t v = cell.u32;
444
30.0k
                    if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) {
445
29.2k
                        memcpy(b.buffer, &v, sizeof(v));
446
29.2k
                        if (b.length) *b.length = sizeof(v);
447
29.2k
                    }
448
30.0k
                    break;
449
0
                }
450
62.2k
                case MockResRow::CK_UINT64: {
451
62.2k
                    uint64_t v = cell.u64;
452
62.2k
                    if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) {
453
36.4k
                        memcpy(b.buffer, &v, sizeof(v));
454
36.4k
                        if (b.length) *b.length = sizeof(v);
455
36.4k
                    }
456
62.2k
                    break;
457
0
                }
458
14.9k
                case MockResRow::CK_UINT8: {
459
14.9k
                    uint8_t v = cell.u8;
460
14.9k
                    if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) {
461
14.9k
                        memcpy(b.buffer, &v, sizeof(v));
462
14.9k
                        if (b.length) *b.length = sizeof(v);
463
14.9k
                    }
464
14.9k
                    break;
465
0
                }
466
6.67k
                case MockResRow::CK_STRING: {
467
6.67k
                    if (b.buffer_length > 0) {
468
6.67k
                        size_t n = std::min<size_t>(b.buffer_length - 1, cell.s.size());
469
6.67k
                        memcpy(b.buffer, cell.s.data(), n);
470
6.67k
                        reinterpret_cast<char*>(b.buffer)[n] = '\0';
471
6.67k
                        if (b.length) *b.length = n;
472
6.67k
                    }
473
6.67k
                    break;
474
0
                }
475
8.04k
                case MockResRow::CK_BLOB: {
476
8.04k
                    if (b.buffer_length > 0) {
477
8.04k
                        size_t n = std::min<size_t>(b.buffer_length, cell.blob.size());
478
8.04k
                        memcpy(b.buffer, cell.blob.data(), n);
479
8.04k
                        if (b.length) *b.length = n;
480
8.04k
                    }
481
8.04k
                    break;
482
0
                }
483
121k
            }
484
121k
        }
485
17.5k
        return 0;
486
17.5k
    }
487
}