/src/kea-fuzzer/mysqlmock.cc
Line | Count | Source |
1 | | // Copyright (C) 2025 Ada Logics Ltd. |
2 | | // |
3 | | // This Source Code Form is subject to the terms of the Mozilla Public |
4 | | // License, v. 2.0. If a copy of the MPL was not distributed with this |
5 | | // file, You can obtain one at http://mozilla.org/MPL/2.0/. |
6 | | //////////////////////////////////////////////////////////////////////////////// |
7 | | #include <fuzzer/FuzzedDataProvider.h> |
8 | | |
9 | | #include <mariadb/mysql.h> |
10 | | |
11 | | #include <stdint.h> |
12 | | #include <string.h> |
13 | | #include <stdlib.h> |
14 | | #include <vector> |
15 | | #include <string> |
16 | | #include <algorithm> |
17 | | #include <unordered_map> |
18 | | |
19 | | struct StmtState; |
20 | | |
21 | | static thread_local FuzzedDataProvider* g_fdp = nullptr; |
22 | | static thread_local std::string g_mysql_error; |
23 | | static thread_local std::string g_stmt_error; |
24 | | static thread_local std::string g_tls_cipher; |
25 | | |
// One cell of a mocked result-set row. `kind` selects which of the payload
// members below mysql_stmt_fetch() copies into the caller's bind buffer.
struct MockResRow {
  enum ColKind {
    CK_UINT64, CK_UINT32, CK_UINT8, CK_STRING, CK_BLOB
  };
  // Default-initialized so a cell that never has its kind assigned is still
  // well-defined (the member was previously left indeterminate).
  ColKind kind = CK_UINT64;
  std::string s;               // payload for CK_STRING
  std::vector<uint8_t> blob;   // payload for CK_BLOB
  uint64_t u64 = 0;            // payload for CK_UINT64
  uint32_t u32 = 0;            // payload for CK_UINT32
  uint8_t u8 = 0;              // payload for CK_UINT8
};
36 | | |
37 | | struct StmtState { |
38 | | MYSQL* mysql; |
39 | | std::string sql; |
40 | | std::vector<std::vector<MockResRow>> rows; |
41 | | size_t fetch_index = 0; |
42 | | MYSQL_BIND* res_binds = nullptr; |
43 | | size_t res_binds_count = 0; |
44 | | bool has_rows = false; |
45 | | bool is_version_stmt = false; |
46 | | unsigned int field_count = 0; |
47 | | }; |
48 | | |
49 | | static thread_local std::unordered_map<MYSQL_STMT*, StmtState*> g_stmt_state; |
50 | | static thread_local std::vector<MYSQL_STMT*> g_all_stmts; |
51 | | static thread_local std::vector<StmtState*> g_live_stmts; |
52 | | |
53 | 2.08k | extern "C" void mysqlmock_load_bytes(const uint8_t* data, size_t size) { |
54 | 2.08k | for (auto* st : g_all_stmts) { |
55 | 0 | free(st); |
56 | 0 | } |
57 | 2.08k | g_all_stmts.clear(); |
58 | 2.08k | for (auto& kv : g_stmt_state) { |
59 | 0 | delete kv.second; |
60 | 0 | } |
61 | 2.08k | g_stmt_state.clear(); |
62 | 2.08k | g_stmt_state.rehash(0); |
63 | 2.08k | delete g_fdp; |
64 | 2.08k | g_fdp = new FuzzedDataProvider(data, size); |
65 | 2.08k | } |
66 | | |
// Case-insensitive substring test: true if `needle` occurs anywhere in `hay`,
// comparing characters after ::tolower (ASCII folding in the default C
// locale). A null needle is treated as "" and, like an empty needle, matches.
static bool is_like(const std::string& hay, const char* needle) {
  const std::string n = needle ? needle : "";
  if (n.empty()) {
    return true;
  }
  // ::tolower requires a value representable as unsigned char (or EOF);
  // passing a plain char with the high bit set is undefined behavior.
  const auto fold = [](char c) {
    return ::tolower(static_cast<unsigned char>(c));
  };
  // Compare in place instead of building two lower-cased copies.
  const auto hit = std::search(hay.begin(), hay.end(), n.begin(), n.end(),
                               [&fold](char a, char b) { return fold(a) == fold(b); });
  return hit != hay.end();
}
74 | | |
75 | 0 | static MYSQL_STMT* make_stmt() { |
76 | 0 | auto s = new StmtState(); |
77 | 0 | s->mysql = reinterpret_cast<MYSQL*>(0x1); |
78 | 0 | g_live_stmts.push_back(s); |
79 | 0 | return reinterpret_cast<MYSQL_STMT*>(s); |
80 | 0 | } |
81 | | |
82 | 8.49k | static StmtState* SS(MYSQL_STMT* st) { |
83 | 8.49k | if (!st) { |
84 | 0 | return nullptr; |
85 | 0 | } |
86 | 8.49k | auto it = g_stmt_state.find(st); |
87 | 8.49k | return (it == g_stmt_state.end()) ? nullptr : it->second; |
88 | 8.49k | } |
89 | | |
90 | 5.09k | static StmtState* ensure_state(MYSQL_STMT* st) { |
91 | 5.09k | auto s = SS(st); |
92 | 5.09k | if (s) return s; |
93 | 1.69k | s = new StmtState(); |
94 | 1.69k | s->mysql = reinterpret_cast<MYSQL*>(0x1); |
95 | 1.69k | g_stmt_state[st] = s; |
96 | 1.69k | g_live_stmts.push_back(s); |
97 | 1.69k | return s; |
98 | 5.09k | } |
99 | | |
// Best-effort guess at how many columns a SELECT returns.
//
// Counts the top-level commas in the projection between "select" and
// " from ". Commas inside parentheses or inside single-/double-quoted
// literals are ignored, and a leading DISTINCT is skipped. Returns 0 when no
// "select ... from" shape is found or the projection is empty.
static unsigned int infer_field_count_from_sql(const std::string& sql) {
  // Lower-case for keyword matching, and fold tab/CR/LF into spaces so that a
  // FROM starting on a new line ("a, b\nFROM t") is still recognised.
  // The unsigned char cast keeps ::tolower defined for bytes >= 0x80.
  std::string s = sql;
  for (char& c : s) {
    if (c == '\t' || c == '\r' || c == '\n') {
      c = ' ';
    } else {
      c = static_cast<char>(::tolower(static_cast<unsigned char>(c)));
    }
  }

  static const char kSelect[] = "select";
  const size_t kSelectLen = sizeof(kSelect) - 1;
  const size_t psel = s.find(kSelect);
  if (psel == std::string::npos) {
    return 0;
  }
  const size_t pfrom = s.find(" from ", psel);
  if (pfrom == std::string::npos) {
    return 0;
  }
  // "select" contains no space, so pfrom >= psel + kSelectLen here.
  std::string proj = s.substr(psel + kSelectLen, pfrom - (psel + kSelectLen));

  static const char kWs[] = " \t\r\n";
  const auto trim = [](std::string& str) {
    str.erase(0, str.find_first_not_of(kWs));
    str.erase(str.find_last_not_of(kWs) + 1);  // npos + 1 == 0 clears all-blank
  };
  trim(proj);
  if (proj.rfind("distinct", 0) == 0) {
    proj.erase(0, 8);
    trim(proj);
  }

  unsigned int cols = 0;
  int depth = 0;
  bool in_single = false;
  bool in_double = false;
  for (size_t i = 0; i < proj.size(); ++i) {
    const char c = proj[i];
    const bool escaped = (i > 0 && proj[i - 1] == '\\');
    if (!in_double && c == '\'' && !escaped) {
      in_single = !in_single;
    } else if (!in_single && c == '"' && !escaped) {
      in_double = !in_double;
    } else if (!in_single && !in_double) {
      if (c == '(') {
        ++depth;
      } else if (c == ')' && depth > 0) {
        --depth;
      } else if (c == ',' && depth == 0) {
        ++cols;
      }
    }
  }
  // N top-level commas separate N + 1 columns.
  return proj.empty() ? cols : cols + 1;
}
144 | | |
145 | 1.69k | static void fill_version_stmt(StmtState* s) { |
146 | 1.69k | s->is_version_stmt = true; |
147 | 1.69k | s->has_rows = true; |
148 | 1.69k | s->rows.clear(); |
149 | 1.69k | s->field_count = 2; |
150 | 1.69k | std::vector<MockResRow> r; |
151 | 1.69k | MockResRow c1; |
152 | 1.69k | c1.kind = MockResRow::CK_UINT32; |
153 | 1.69k | c1.u32 = 32; |
154 | 1.69k | MockResRow c2; |
155 | 1.69k | c2.kind = MockResRow::CK_UINT32; |
156 | 1.69k | c2.u32 = 0; |
157 | 1.69k | r.push_back(c1); |
158 | 1.69k | r.push_back(c2); |
159 | 1.69k | s->rows.push_back(r); |
160 | 1.69k | } |
161 | | |
162 | 0 | static void fill_no_rows(StmtState* s, unsigned int cols = 0) { |
163 | 0 | s->is_version_stmt = false; |
164 | 0 | s->has_rows = false; |
165 | 0 | s->rows.clear(); |
166 | 0 | s->fetch_index = 0; |
167 | 0 | s->field_count = cols; |
168 | 0 | } |
169 | | |
170 | 0 | static void fill_fuzz_rows(StmtState* s, unsigned int ncols) { |
171 | 0 | s->is_version_stmt = false; |
172 | 0 | s->has_rows = true; |
173 | 0 | s->rows.clear(); |
174 | 0 | s->fetch_index = 0; |
175 | 0 | s->field_count = ncols ? ncols : 1;; |
176 | 0 | unsigned int nrows = g_fdp->ConsumeIntegralInRange<unsigned int>(1, 3); |
177 | 0 | for (unsigned int r = 0; r < nrows; ++r) { |
178 | 0 | std::vector<MockResRow> row; |
179 | 0 | row.reserve(ncols); |
180 | 0 | for (unsigned int c = 0; c < ncols; ++c) { |
181 | 0 | int k = g_fdp->ConsumeIntegralInRange<int>(0, 4); |
182 | 0 | MockResRow cell; |
183 | 0 | switch (k) { |
184 | 0 | case 0: |
185 | 0 | cell.kind = MockResRow::CK_UINT64; |
186 | 0 | cell.u64 = g_fdp->ConsumeIntegral<uint64_t>(); |
187 | 0 | break; |
188 | 0 | case 1: |
189 | 0 | cell.kind = MockResRow::CK_UINT32; |
190 | 0 | cell.u32 = g_fdp->ConsumeIntegral<uint32_t>(); |
191 | 0 | break; |
192 | 0 | case 2: |
193 | 0 | cell.kind = MockResRow::CK_UINT8; |
194 | 0 | cell.u8 = g_fdp->ConsumeIntegral<uint8_t>(); |
195 | 0 | break; |
196 | 0 | case 3: |
197 | 0 | cell.kind = MockResRow::CK_STRING; |
198 | 0 | cell.s = g_fdp->ConsumeRandomLengthString(32); |
199 | 0 | break; |
200 | 0 | case 4: { |
201 | 0 | cell.kind = MockResRow::CK_BLOB; |
202 | 0 | size_t n = g_fdp->ConsumeIntegralInRange<size_t>(0, 32); |
203 | 0 | cell.blob = g_fdp->ConsumeBytes<uint8_t>(n); |
204 | 0 | break; |
205 | 0 | } |
206 | 0 | } |
207 | 0 | row.push_back(cell); |
208 | 0 | } |
209 | 0 | s->rows.push_back(std::move(row)); |
210 | 0 | } |
211 | 0 | } |
212 | | |
213 | | extern "C" { |
214 | 4 | int mysql_server_init(int argc, char **argv, char **groups) { |
215 | 4 | return 0; |
216 | 4 | } |
217 | | |
218 | 4.54k | MYSQL* mysql_init(MYSQL* in) { |
219 | 4.54k | return in ? in : reinterpret_cast<MYSQL*>(0x1); |
220 | 4.54k | } |
221 | | |
222 | 4.54k | void mysql_close(MYSQL*) {} |
223 | | |
224 | 0 | void mysql_free_result(MYSQL_RES*) {} |
225 | | |
226 | 0 | void mysql_server_end(void) {} |
227 | | |
228 | | MYSQL* mysql_real_connect(MYSQL* mysql, |
229 | | const char*, |
230 | | const char*, |
231 | | const char*, |
232 | | const char*, |
233 | | unsigned int, |
234 | | const char*, |
235 | 1.69k | unsigned long) { |
236 | 1.69k | return mysql ? mysql : reinterpret_cast<MYSQL*>(0x1); |
237 | 1.69k | } |
238 | | |
239 | 0 | unsigned int mysql_errno(MYSQL*) { |
240 | 0 | return 0; |
241 | 0 | } |
242 | | |
243 | 0 | const char* mysql_error(MYSQL*) { |
244 | 0 | g_mysql_error = g_fdp ? g_fdp->ConsumeRandomLengthString(32) : std::string(); |
245 | 0 | return g_mysql_error.c_str(); |
246 | 0 | } |
247 | | |
248 | 1.69k | MYSQL_STMT* mysql_stmt_init(MYSQL*) { |
249 | 1.69k | auto* stmt = static_cast<MYSQL_STMT*>(std::calloc(1, sizeof(MYSQL_STMT))); |
250 | 1.69k | if (!stmt) { |
251 | 0 | return nullptr; |
252 | 0 | } |
253 | 1.69k | stmt->mysql = reinterpret_cast<MYSQL*>(0x1); |
254 | 1.69k | g_all_stmts.push_back(stmt); |
255 | 1.69k | auto* s = ensure_state(stmt); |
256 | 1.69k | s->mysql = stmt->mysql; |
257 | 1.69k | return stmt; |
258 | 1.69k | } |
259 | | |
260 | 1.69k | int mysql_stmt_prepare(MYSQL_STMT* stmt, const char* q, unsigned long len) { |
261 | 1.69k | if (!stmt) { |
262 | 0 | return 1; |
263 | 0 | } |
264 | 1.69k | stmt->mysql = reinterpret_cast<MYSQL*>(0x1); |
265 | 1.69k | auto* s = ensure_state(stmt); |
266 | | |
267 | 1.69k | const uintptr_t pq = reinterpret_cast<uintptr_t>(q); |
268 | 1.69k | const bool ptr_ok = (q != nullptr) && (pq > 0x10000); |
269 | 1.69k | const bool len_ok = (len > 0); |
270 | 1.69k | if (ptr_ok && len_ok) { |
271 | 1.69k | size_t copy_len = std::min<size_t>(len, 4096); |
272 | 1.69k | s->sql.assign(q, copy_len); |
273 | 1.69k | } else { |
274 | 0 | s->sql = "SELECT version, minor FROM schema_version"; |
275 | 0 | } |
276 | | |
277 | 1.69k | if (is_like(s->sql, "schema_version") || is_like(s->sql, "select version") || is_like(s->sql, "get_version")) { |
278 | 1.69k | fill_version_stmt(s); |
279 | 1.69k | return 0; |
280 | 1.69k | } |
281 | | |
282 | 0 | unsigned int cols = infer_field_count_from_sql(s->sql); |
283 | 0 | if (cols == 0) { |
284 | 0 | cols = g_fdp->ConsumeIntegralInRange<unsigned int>(1, 8); |
285 | 0 | } |
286 | 0 | s->field_count = cols; |
287 | |
|
288 | 0 | if (g_fdp->ConsumeBool()) { |
289 | 0 | fill_fuzz_rows(s, cols); |
290 | 0 | } else { |
291 | 0 | fill_no_rows(s); |
292 | 0 | } |
293 | 0 | return 0; |
294 | 1.69k | } |
295 | | |
296 | 1.69k | my_bool mysql_stmt_close(MYSQL_STMT* stmt) { |
297 | 1.69k | if (stmt) { |
298 | 1.69k | auto it = g_stmt_state.find(stmt); |
299 | 1.69k | if (it != g_stmt_state.end()) { |
300 | 1.69k | delete it->second; |
301 | 1.69k | g_stmt_state.erase(it); |
302 | 1.69k | if (g_stmt_state.empty()) { |
303 | 1.69k | g_stmt_state.rehash(0); |
304 | 1.69k | } |
305 | 1.69k | } |
306 | 1.69k | auto it2 = std::find(g_all_stmts.begin(), g_all_stmts.end(), stmt); |
307 | 1.69k | if (it2 != g_all_stmts.end()) { |
308 | 1.69k | g_all_stmts.erase(it2); |
309 | 1.69k | } |
310 | 1.69k | free(stmt); |
311 | 1.69k | } |
312 | 1.69k | return 0; |
313 | 1.69k | } |
314 | | |
315 | 1.69k | my_bool mysql_stmt_bind_result(MYSQL_STMT* stmt, MYSQL_BIND* bnd) { |
316 | 1.69k | auto* s = ensure_state(stmt); |
317 | 1.69k | s->res_binds = bnd; |
318 | 1.69k | s->res_binds_count = bnd ? static_cast<size_t>(mysql_stmt_field_count(stmt)) : 0; |
319 | 1.69k | return 0; |
320 | 1.69k | } |
321 | | |
322 | 1.69k | int mysql_stmt_execute(MYSQL_STMT*) { |
323 | 1.69k | return 0; |
324 | 1.69k | } |
325 | | |
326 | 0 | int mysql_stmt_store_result(MYSQL_STMT*) { |
327 | 0 | return 0; |
328 | 0 | } |
329 | | |
330 | 0 | my_bool mysql_stmt_free_result(MYSQL_STMT* stmt) { |
331 | 0 | auto s = SS(stmt); |
332 | 0 | if (s) { |
333 | 0 | s->fetch_index = 0; |
334 | 0 | } |
335 | 0 | if (stmt) { |
336 | 0 | stmt->mysql = reinterpret_cast<MYSQL*>(0x1); |
337 | 0 | } |
338 | 0 | return 0; |
339 | 0 | } |
340 | | |
341 | 0 | my_ulonglong mysql_stmt_affected_rows(MYSQL_STMT*) { |
342 | 0 | return 0ULL; |
343 | 0 | } |
344 | | |
345 | 0 | my_bool mysql_stmt_reset(MYSQL_STMT* stmt) { |
346 | 0 | auto s = SS(stmt); |
347 | 0 | if (s) { |
348 | 0 | s->fetch_index = 0; |
349 | 0 | } |
350 | 0 | if (stmt) { |
351 | 0 | stmt->mysql = reinterpret_cast<MYSQL*>(0x1); |
352 | 0 | } |
353 | 0 | return 0; |
354 | 0 | } |
355 | | |
356 | 0 | unsigned int mysql_stmt_errno(MYSQL_STMT*) { |
357 | 0 | return 0; |
358 | 0 | } |
359 | | |
360 | 0 | const char* mysql_stmt_error(MYSQL_STMT*) { |
361 | 0 | g_stmt_error = g_fdp ? g_fdp->ConsumeRandomLengthString(32) : std::string(); |
362 | 0 | return g_stmt_error.c_str(); |
363 | 0 | } |
364 | | |
365 | 5.56k | int mysql_options(MYSQL*, enum mysql_option, const void*) { |
366 | 5.56k | return 0; |
367 | 5.56k | } |
368 | | |
369 | 1.69k | my_bool mysql_autocommit(MYSQL*, my_bool) { |
370 | 1.69k | return 0; |
371 | 1.69k | } |
372 | | |
373 | 0 | my_bool mysql_commit(MYSQL*) { |
374 | 0 | return 0; |
375 | 0 | } |
376 | | |
377 | 0 | my_bool mysql_rollback(MYSQL*) { |
378 | 0 | return 0; |
379 | 0 | } |
380 | | |
381 | 0 | int mysql_query(MYSQL*, const char*) { |
382 | 0 | return 0; |
383 | 0 | } |
384 | | |
385 | 0 | my_bool mysql_stmt_bind_param(MYSQL_STMT*, MYSQL_BIND*) { |
386 | 0 | return 0; |
387 | 0 | } |
388 | | |
389 | 1.69k | unsigned int mysql_stmt_field_count(MYSQL_STMT* stmt) { |
390 | 1.69k | auto s = SS(stmt); |
391 | 1.69k | return s ? s->field_count : 0u; |
392 | 1.69k | } |
393 | | |
394 | 0 | MYSQL_RES* mysql_stmt_result_metadata(MYSQL_STMT* stmt) { |
395 | 0 | auto s = SS(stmt); |
396 | 0 | if (!s) { |
397 | 0 | return reinterpret_cast<MYSQL_RES*>(0x1); |
398 | 0 | } |
399 | 0 | if (s->is_version_stmt) { |
400 | 0 | return reinterpret_cast<MYSQL_RES*>(0x1); |
401 | 0 | } |
402 | 0 | return nullptr; |
403 | 0 | } |
404 | | |
405 | 0 | my_ulonglong mysql_insert_id(MYSQL*) { |
406 | 0 | return g_fdp ? g_fdp->ConsumeIntegral<my_ulonglong>() : 0ULL; |
407 | 0 | } |
408 | | |
409 | 0 | const char* mysql_get_ssl_cipher(MYSQL*) { |
410 | 0 | if (g_fdp && g_fdp->ConsumeBool()) { |
411 | 0 | g_tls_cipher = g_fdp->ConsumeRandomLengthString(64); |
412 | 0 | return g_tls_cipher.c_str(); |
413 | 0 | } |
414 | 0 | return "TLS_FAKE_CIPHER_WITH_FAKE_SHA256"; |
415 | 0 | } |
416 | | |
417 | 1.69k | int mysql_stmt_fetch(MYSQL_STMT* stmt) { |
418 | 1.69k | auto s = SS(stmt); |
419 | | |
420 | 1.69k | if (!s || !s->has_rows){ |
421 | 0 | return MYSQL_NO_DATA; |
422 | 0 | } |
423 | 1.69k | if (s->fetch_index >= s->rows.size()) { |
424 | 0 | return MYSQL_NO_DATA; |
425 | 0 | } |
426 | 1.69k | if (!s->res_binds) { |
427 | 0 | return MYSQL_NO_DATA; |
428 | 0 | } |
429 | | |
430 | 1.69k | const auto& row = s->rows[s->fetch_index++]; |
431 | 1.69k | size_t cols = row.size(); |
432 | 1.69k | if (s->field_count && cols > s->field_count) { |
433 | 0 | cols = s->field_count; |
434 | 0 | } |
435 | 1.69k | if (cols > s->res_binds_count) { |
436 | 0 | cols = s->res_binds_count; |
437 | 0 | } |
438 | | |
439 | 5.09k | for (size_t i = 0; i < cols; ++i) { |
440 | 3.39k | const auto& cell = row[i]; |
441 | 3.39k | MYSQL_BIND& b = s->res_binds[i]; |
442 | 3.39k | if (!b.buffer) { |
443 | 0 | continue; |
444 | 0 | } |
445 | 3.39k | switch (cell.kind) { |
446 | 3.39k | case MockResRow::CK_UINT32: { |
447 | 3.39k | uint32_t v = cell.u32; |
448 | 3.39k | if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) { |
449 | 3.39k | memcpy(b.buffer, &v, sizeof(v)); |
450 | 3.39k | if (b.length) *b.length = sizeof(v); |
451 | 3.39k | } |
452 | 3.39k | break; |
453 | 0 | } |
454 | 0 | case MockResRow::CK_UINT64: { |
455 | 0 | uint64_t v = cell.u64; |
456 | 0 | if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) { |
457 | 0 | memcpy(b.buffer, &v, sizeof(v)); |
458 | 0 | if (b.length) *b.length = sizeof(v); |
459 | 0 | } |
460 | 0 | break; |
461 | 0 | } |
462 | 0 | case MockResRow::CK_UINT8: { |
463 | 0 | uint8_t v = cell.u8; |
464 | 0 | if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) { |
465 | 0 | memcpy(b.buffer, &v, sizeof(v)); |
466 | 0 | if (b.length) *b.length = sizeof(v); |
467 | 0 | } |
468 | 0 | break; |
469 | 0 | } |
470 | 0 | case MockResRow::CK_STRING: { |
471 | 0 | if (b.buffer_length > 0) { |
472 | 0 | size_t n = std::min<size_t>(b.buffer_length - 1, cell.s.size()); |
473 | 0 | memcpy(b.buffer, cell.s.data(), n); |
474 | 0 | reinterpret_cast<char*>(b.buffer)[n] = '\0'; |
475 | 0 | if (b.length) *b.length = n; |
476 | 0 | } |
477 | 0 | break; |
478 | 0 | } |
479 | 0 | case MockResRow::CK_BLOB: { |
480 | 0 | if (b.buffer_length > 0) { |
481 | 0 | size_t n = std::min<size_t>(b.buffer_length, cell.blob.size()); |
482 | 0 | memcpy(b.buffer, cell.blob.data(), n); |
483 | 0 | if (b.length) *b.length = n; |
484 | 0 | } |
485 | 0 | break; |
486 | 0 | } |
487 | 3.39k | } |
488 | 3.39k | } |
489 | 1.69k | return 0; |
490 | 1.69k | } |
491 | | } |