/src/kea-fuzzer/mysqlmock.cc
Line | Count | Source |
1 | | // Copyright (C) 2025 Ada Logics Ltd. |
2 | | // |
3 | | // This Source Code Form is subject to the terms of the Mozilla Public |
4 | | // License, v. 2.0. If a copy of the MPL was not distributed with this |
5 | | // file, You can obtain one at http://mozilla.org/MPL/2.0/. |
6 | | //////////////////////////////////////////////////////////////////////////////// |
7 | | #include <fuzzer/FuzzedDataProvider.h> |
8 | | |
9 | | #include <mariadb/mysql.h> |
10 | | |
11 | | #include <stdint.h> |
12 | | #include <string.h> |
13 | | #include <stdlib.h> |
14 | | #include <vector> |
15 | | #include <string> |
16 | | #include <algorithm> |
17 | | #include <unordered_map> |
18 | | |
19 | | struct StmtState; |
20 | | |
21 | | static thread_local FuzzedDataProvider* g_fdp = nullptr; |
22 | | static thread_local std::string g_mysql_error; |
23 | | static thread_local std::string g_stmt_error; |
24 | | static thread_local std::string g_tls_cipher; |
25 | | static thread_local std::unordered_map<MYSQL_STMT*, StmtState*> g_stmt_state; |
26 | | static thread_local std::vector<MYSQL_STMT*> g_all_stmts; |
27 | | |
28 | 6.76k | extern "C" void mysqlmock_load_bytes(const uint8_t* data, size_t size) { |
29 | 6.76k | for (auto* st : g_all_stmts) { |
30 | 0 | free(st); |
31 | 0 | } |
32 | 6.76k | g_all_stmts.clear(); |
33 | 6.76k | for (auto& kv : g_stmt_state) { |
34 | 0 | delete kv.second; |
35 | 0 | } |
36 | 6.76k | g_stmt_state.clear(); |
37 | 6.76k | g_stmt_state.rehash(0); |
38 | 6.76k | delete g_fdp; |
39 | 6.76k | g_fdp = new FuzzedDataProvider(data, size); |
40 | 6.76k | } |
41 | | |
// One cell of a mocked result row. `kind` says which payload member carries
// the value: u64/u32/u8 for the integer kinds, `s` for CK_STRING and `blob`
// for CK_BLOB.
struct MockResRow {
  enum ColKind {
    CK_UINT64, CK_UINT32, CK_UINT8, CK_STRING, CK_BLOB
  };
  // Defaulted so a default-constructed cell never exposes an indeterminate
  // enumerator; every other member already has a well-defined initial value.
  ColKind kind = CK_UINT64;
  std::string s;
  std::vector<uint8_t> blob;
  uint64_t u64 = 0;
  uint32_t u32 = 0;
  uint8_t u8 = 0;
};
52 | | |
53 | | struct StmtState { |
54 | | MYSQL* mysql; |
55 | | std::string sql; |
56 | | std::vector<std::vector<MockResRow>> rows; |
57 | | size_t fetch_index = 0; |
58 | | MYSQL_BIND* res_binds = nullptr; |
59 | | size_t res_binds_count = 0; |
60 | | bool has_rows = false; |
61 | | bool is_version_stmt = false; |
62 | | unsigned int field_count = 0; |
63 | | }; |
64 | | |
65 | | static thread_local std::vector<StmtState*> g_live_stmts; |
66 | | |
// Case-insensitive substring test: true when `needle` occurs anywhere in
// `hay`, ignoring ASCII case. A null or empty needle matches every haystack,
// including an empty one.
//
// Compares in place with std::search instead of building lower-cased copies
// of both strings. Each byte goes through unsigned char before ::tolower,
// because passing a negative char (any non-ASCII byte on signed-char
// platforms) to ::tolower is undefined behaviour.
static bool is_like(const std::string& hay, const char* needle) {
  if (needle == nullptr || *needle == '\0') {
    return true;
  }
  const auto fold = [](char c) { return ::tolower(static_cast<unsigned char>(c)); };
  const auto same = [&fold](char a, char b) { return fold(a) == fold(b); };
  const char* const needle_end = needle + strlen(needle);
  return std::search(hay.begin(), hay.end(), needle, needle_end, same) != hay.end();
}
74 | | |
75 | 0 | static MYSQL_STMT* make_stmt() { |
76 | 0 | auto s = new StmtState(); |
77 | 0 | s->mysql = reinterpret_cast<MYSQL*>(0x1); |
78 | 0 | g_live_stmts.push_back(s); |
79 | 0 | return reinterpret_cast<MYSQL_STMT*>(s); |
80 | 0 | } |
81 | | |
82 | 1.55M | static StmtState* SS(MYSQL_STMT* st) { |
83 | 1.55M | if (!st) { |
84 | 0 | return nullptr; |
85 | 0 | } |
86 | 1.55M | auto it = g_stmt_state.find(st); |
87 | 1.55M | return (it == g_stmt_state.end()) ? nullptr : it->second; |
88 | 1.55M | } |
89 | | |
90 | 1.49M | static StmtState* ensure_state(MYSQL_STMT* st) { |
91 | 1.49M | auto s = SS(st); |
92 | 1.49M | if (s) return s; |
93 | 733k | s = new StmtState(); |
94 | 733k | s->mysql = reinterpret_cast<MYSQL*>(0x1); |
95 | 733k | g_stmt_state[st] = s; |
96 | 733k | g_live_stmts.push_back(s); |
97 | 733k | return s; |
98 | 1.49M | } |
99 | | |
// True for the separator characters the projection parser understands. The
// same set is used to trim the projection text.
static bool sql_is_space(char c) {
  return c == ' ' || c == '\t' || c == '\r' || c == '\n';
}

// Scans the lower-cased `s` from `start` for the FROM keyword that ends a
// SELECT projection. That keyword is "from" preceded and followed by
// whitespace, outside single/double quotes, at parenthesis depth 0.
// Returns the offset of the whitespace before "from", or std::string::npos.
static size_t find_projection_end(const std::string& s, size_t start) {
  int paren = 0;
  bool in_s = false;
  bool in_d = false;
  for (size_t i = start; i < s.size(); ++i) {
    const char c = s[i];
    const bool escaped = (i > start) && s[i - 1] == '\\';
    if (!in_d && c == '\'' && !escaped) {
      in_s = !in_s;
    } else if (!in_s && c == '"' && !escaped) {
      in_d = !in_d;
    } else if (!in_s && !in_d) {
      if (c == '(') {
        paren++;
      } else if (c == ')' && paren > 0) {
        paren--;
      } else if (paren == 0 && sql_is_space(c) && i + 5 < s.size() &&
                 s.compare(i + 1, 4, "from") == 0 && sql_is_space(s[i + 5])) {
        return i;
      }
    }
  }
  return std::string::npos;
}

// Estimates how many columns a query returns by counting the top-level,
// comma-separated expressions between the first "select" and its FROM.
// Returns 0 when the text has no "select" or no terminating FROM.
// A leading DISTINCT is skipped.
//
// Compared with matching the literal " from ":
//  - FROM preceded or followed by a tab or newline is recognised.
//  - FROM inside parentheses (EXTRACT(x FROM y), subqueries) or quoted text
//    no longer cuts the projection short.
static unsigned int infer_field_count_from_sql(const std::string& sql) {
  std::string s = sql;
  // Route bytes through unsigned char: ::tolower on a negative char (any
  // non-ASCII byte where char is signed) is undefined behaviour.
  std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) {
    return static_cast<char>(::tolower(c));
  });

  const size_t psel = s.find("select");
  if (psel == std::string::npos) {
    return 0;
  }
  const size_t start = psel + 6;  // first byte after "select"
  const size_t pfrom = find_projection_end(s, start);
  if (pfrom == std::string::npos) {
    return 0;
  }
  std::string proj = s.substr(start, pfrom - start);

  proj.erase(0, proj.find_first_not_of(" \t\r\n"));
  proj.erase(proj.find_last_not_of(" \t\r\n") + 1);
  if (proj.rfind("distinct", 0) == 0) {
    proj.erase(0, 8);
    proj.erase(0, proj.find_first_not_of(" \t\r\n"));
  }

  // Count commas that separate projection items, ignoring commas nested in
  // function calls or quoted literals.
  unsigned int cols = 0;
  int paren = 0;
  bool in_s = false;
  bool in_d = false;
  for (size_t i = 0; i < proj.size(); ++i) {
    const char c = proj[i];
    if (!in_d && c == '\'' && (i == 0 || proj[i - 1] != '\\')) {
      in_s = !in_s;
    } else if (!in_s && c == '"' && (i == 0 || proj[i - 1] != '\\')) {
      in_d = !in_d;
    } else if (!in_s && !in_d) {
      if (c == '(') {
        paren++;
      } else if (c == ')' && paren > 0) {
        paren--;
      } else if (c == ',' && paren == 0) {
        cols++;
      }
    }
  }
  if (!proj.empty()) {
    cols++;
  }
  return cols;
}
144 | | |
145 | 12.9k | static void fill_version_stmt(StmtState* s) { |
146 | 12.9k | s->is_version_stmt = true; |
147 | 12.9k | s->has_rows = true; |
148 | 12.9k | s->rows.clear(); |
149 | 12.9k | s->field_count = 2; |
150 | 12.9k | std::vector<MockResRow> r; |
151 | 12.9k | MockResRow c1; |
152 | 12.9k | c1.kind = MockResRow::CK_UINT32; |
153 | 12.9k | c1.u32 = 32; |
154 | 12.9k | MockResRow c2; |
155 | 12.9k | c2.kind = MockResRow::CK_UINT32; |
156 | 12.9k | c2.u32 = 0; |
157 | 12.9k | r.push_back(c1); |
158 | 12.9k | r.push_back(c2); |
159 | 12.9k | s->rows.push_back(r); |
160 | 12.9k | } |
161 | | |
162 | 682k | static void fill_no_rows(StmtState* s, unsigned int cols = 0) { |
163 | 682k | s->is_version_stmt = false; |
164 | 682k | s->has_rows = false; |
165 | 682k | s->rows.clear(); |
166 | 682k | s->fetch_index = 0; |
167 | 682k | s->field_count = cols; |
168 | 682k | } |
169 | | |
170 | 37.5k | static void fill_fuzz_rows(StmtState* s, unsigned int ncols) { |
171 | 37.5k | s->is_version_stmt = false; |
172 | 37.5k | s->has_rows = true; |
173 | 37.5k | s->rows.clear(); |
174 | 37.5k | s->fetch_index = 0; |
175 | 37.5k | s->field_count = ncols ? ncols : 1;; |
176 | 37.5k | unsigned int nrows = g_fdp->ConsumeIntegralInRange<unsigned int>(1, 3); |
177 | 112k | for (unsigned int r = 0; r < nrows; ++r) { |
178 | 74.8k | std::vector<MockResRow> row; |
179 | 74.8k | row.reserve(ncols); |
180 | 1.57M | for (unsigned int c = 0; c < ncols; ++c) { |
181 | 1.50M | int k = g_fdp->ConsumeIntegralInRange<int>(0, 4); |
182 | 1.50M | MockResRow cell; |
183 | 1.50M | switch (k) { |
184 | 473k | case 0: |
185 | 473k | cell.kind = MockResRow::CK_UINT64; |
186 | 473k | cell.u64 = g_fdp->ConsumeIntegral<uint64_t>(); |
187 | 473k | break; |
188 | 90.9k | case 1: |
189 | 90.9k | cell.kind = MockResRow::CK_UINT32; |
190 | 90.9k | cell.u32 = g_fdp->ConsumeIntegral<uint32_t>(); |
191 | 90.9k | break; |
192 | 566k | case 2: |
193 | 566k | cell.kind = MockResRow::CK_UINT8; |
194 | 566k | cell.u8 = g_fdp->ConsumeIntegral<uint8_t>(); |
195 | 566k | break; |
196 | 239k | case 3: |
197 | 239k | cell.kind = MockResRow::CK_STRING; |
198 | 239k | cell.s = g_fdp->ConsumeRandomLengthString(32); |
199 | 239k | break; |
200 | 132k | case 4: { |
201 | 132k | cell.kind = MockResRow::CK_BLOB; |
202 | 132k | size_t n = g_fdp->ConsumeIntegralInRange<size_t>(0, 32); |
203 | 132k | cell.blob = g_fdp->ConsumeBytes<uint8_t>(n); |
204 | 132k | break; |
205 | 0 | } |
206 | 1.50M | } |
207 | 1.50M | row.push_back(cell); |
208 | 1.50M | } |
209 | 74.8k | s->rows.push_back(std::move(row)); |
210 | 74.8k | } |
211 | 37.5k | } |
212 | | |
213 | | extern "C" { |
214 | 26.7k | MYSQL* mysql_init(MYSQL* in) { |
215 | 26.7k | return in ? in : reinterpret_cast<MYSQL*>(0x1); |
216 | 26.7k | } |
217 | | |
218 | 26.7k | void mysql_close(MYSQL*) {} |
219 | | |
220 | 0 | void mysql_free_result(MYSQL_RES*) {} |
221 | | |
222 | 0 | void mysql_server_end(void) {} |
223 | | |
224 | | MYSQL* mysql_real_connect(MYSQL* mysql, |
225 | | const char*, |
226 | | const char*, |
227 | | const char*, |
228 | | const char*, |
229 | | unsigned int, |
230 | | const char*, |
231 | 25.8k | unsigned long) { |
232 | 25.8k | return mysql ? mysql : reinterpret_cast<MYSQL*>(0x1); |
233 | 25.8k | } |
234 | | |
235 | 0 | unsigned int mysql_errno(MYSQL*) { |
236 | 0 | return 0; |
237 | 0 | } |
238 | | |
239 | 0 | const char* mysql_error(MYSQL*) { |
240 | 0 | g_mysql_error = g_fdp ? g_fdp->ConsumeRandomLengthString(32) : std::string(); |
241 | 0 | return g_mysql_error.c_str(); |
242 | 0 | } |
243 | | |
244 | 733k | MYSQL_STMT* mysql_stmt_init(MYSQL*) { |
245 | 733k | auto* stmt = static_cast<MYSQL_STMT*>(std::calloc(1, sizeof(MYSQL_STMT))); |
246 | 733k | if (!stmt) { |
247 | 0 | return nullptr; |
248 | 0 | } |
249 | 733k | stmt->mysql = reinterpret_cast<MYSQL*>(0x1); |
250 | 733k | g_all_stmts.push_back(stmt); |
251 | 733k | auto* s = ensure_state(stmt); |
252 | 733k | s->mysql = stmt->mysql; |
253 | 733k | return stmt; |
254 | 733k | } |
255 | | |
256 | 733k | int mysql_stmt_prepare(MYSQL_STMT* stmt, const char* q, unsigned long len) { |
257 | 733k | if (!stmt) { |
258 | 0 | return 1; |
259 | 0 | } |
260 | 733k | stmt->mysql = reinterpret_cast<MYSQL*>(0x1); |
261 | 733k | auto* s = ensure_state(stmt); |
262 | | |
263 | 733k | const uintptr_t pq = reinterpret_cast<uintptr_t>(q); |
264 | 733k | const bool ptr_ok = (q != nullptr) && (pq > 0x10000); |
265 | 733k | const bool len_ok = (len > 0); |
266 | 733k | if (ptr_ok && len_ok) { |
267 | 733k | size_t copy_len = std::min<size_t>(len, 4096); |
268 | 733k | s->sql.assign(q, copy_len); |
269 | 733k | } else { |
270 | 0 | s->sql = "SELECT version, minor FROM schema_version"; |
271 | 0 | } |
272 | | |
273 | 733k | if (is_like(s->sql, "schema_version") || is_like(s->sql, "select version") || is_like(s->sql, "get_version")) { |
274 | 12.9k | fill_version_stmt(s); |
275 | 12.9k | return 0; |
276 | 12.9k | } |
277 | | |
278 | 720k | unsigned int cols = infer_field_count_from_sql(s->sql); |
279 | 720k | if (cols == 0) { |
280 | 348k | cols = g_fdp->ConsumeIntegralInRange<unsigned int>(1, 8); |
281 | 348k | } |
282 | 720k | s->field_count = cols; |
283 | | |
284 | 720k | if (g_fdp->ConsumeBool()) { |
285 | 37.5k | fill_fuzz_rows(s, cols); |
286 | 682k | } else { |
287 | 682k | fill_no_rows(s); |
288 | 682k | } |
289 | 720k | return 0; |
290 | 733k | } |
291 | | |
292 | 733k | my_bool mysql_stmt_close(MYSQL_STMT* stmt) { |
293 | 733k | if (stmt) { |
294 | 733k | auto it = g_stmt_state.find(stmt); |
295 | 733k | if (it != g_stmt_state.end()) { |
296 | 733k | delete it->second; |
297 | 733k | g_stmt_state.erase(it); |
298 | 733k | if (g_stmt_state.empty()) { |
299 | 19.3k | g_stmt_state.rehash(0); |
300 | 19.3k | } |
301 | 733k | } |
302 | 733k | auto it2 = std::find(g_all_stmts.begin(), g_all_stmts.end(), stmt); |
303 | 733k | if (it2 != g_all_stmts.end()) { |
304 | 733k | g_all_stmts.erase(it2); |
305 | 733k | } |
306 | 733k | free(stmt); |
307 | 733k | } |
308 | 733k | return 0; |
309 | 733k | } |
310 | | |
311 | 24.9k | my_bool mysql_stmt_bind_result(MYSQL_STMT* stmt, MYSQL_BIND* bnd) { |
312 | 24.9k | auto* s = ensure_state(stmt); |
313 | 24.9k | s->res_binds = bnd; |
314 | 24.9k | s->res_binds_count = bnd ? static_cast<size_t>(mysql_stmt_field_count(stmt)) : 0; |
315 | 24.9k | return 0; |
316 | 24.9k | } |
317 | | |
318 | 76.2k | int mysql_stmt_execute(MYSQL_STMT*) { |
319 | 76.2k | return 0; |
320 | 76.2k | } |
321 | | |
322 | 12.0k | int mysql_stmt_store_result(MYSQL_STMT*) { |
323 | 12.0k | return 0; |
324 | 12.0k | } |
325 | | |
326 | 12.0k | my_bool mysql_stmt_free_result(MYSQL_STMT* stmt) { |
327 | 12.0k | auto s = SS(stmt); |
328 | 12.0k | if (s) { |
329 | 12.0k | s->fetch_index = 0; |
330 | 12.0k | } |
331 | 12.0k | if (stmt) { |
332 | 12.0k | stmt->mysql = reinterpret_cast<MYSQL*>(0x1); |
333 | 12.0k | } |
334 | 12.0k | return 0; |
335 | 12.0k | } |
336 | | |
337 | 12.7k | my_ulonglong mysql_stmt_affected_rows(MYSQL_STMT*) { |
338 | 12.7k | return 0ULL; |
339 | 12.7k | } |
340 | | |
341 | 0 | my_bool mysql_stmt_reset(MYSQL_STMT* stmt) { |
342 | 0 | auto s = SS(stmt); |
343 | 0 | if (s) { |
344 | 0 | s->fetch_index = 0; |
345 | 0 | } |
346 | 0 | if (stmt) { |
347 | 0 | stmt->mysql = reinterpret_cast<MYSQL*>(0x1); |
348 | 0 | } |
349 | 0 | return 0; |
350 | 0 | } |
351 | | |
352 | 0 | unsigned int mysql_stmt_errno(MYSQL_STMT*) { |
353 | 0 | return 0; |
354 | 0 | } |
355 | | |
356 | 0 | const char* mysql_stmt_error(MYSQL_STMT*) { |
357 | 0 | g_stmt_error = g_fdp ? g_fdp->ConsumeRandomLengthString(32) : std::string(); |
358 | 0 | return g_stmt_error.c_str(); |
359 | 0 | } |
360 | | |
361 | 82.6k | int mysql_options(MYSQL*, enum mysql_option, const void*) { |
362 | 82.6k | return 0; |
363 | 82.6k | } |
364 | | |
365 | 25.8k | my_bool mysql_autocommit(MYSQL*, my_bool) { |
366 | 25.8k | return 0; |
367 | 25.8k | } |
368 | | |
369 | 17.2k | my_bool mysql_commit(MYSQL*) { |
370 | 17.2k | return 0; |
371 | 17.2k | } |
372 | | |
373 | 73 | my_bool mysql_rollback(MYSQL*) { |
374 | 73 | return 0; |
375 | 73 | } |
376 | | |
377 | 17.2k | int mysql_query(MYSQL*, const char*) { |
378 | 17.2k | return 0; |
379 | 17.2k | } |
380 | | |
381 | 57.2k | my_bool mysql_stmt_bind_param(MYSQL_STMT*, MYSQL_BIND*) { |
382 | 57.2k | return 0; |
383 | 57.2k | } |
384 | | |
385 | 24.9k | unsigned int mysql_stmt_field_count(MYSQL_STMT* stmt) { |
386 | 24.9k | auto s = SS(stmt); |
387 | 24.9k | return s ? s->field_count : 0u; |
388 | 24.9k | } |
389 | | |
390 | 0 | MYSQL_RES* mysql_stmt_result_metadata(MYSQL_STMT* stmt) { |
391 | 0 | auto s = SS(stmt); |
392 | 0 | if (!s) { |
393 | 0 | return reinterpret_cast<MYSQL_RES*>(0x1); |
394 | 0 | } |
395 | 0 | if (s->is_version_stmt) { |
396 | 0 | return reinterpret_cast<MYSQL_RES*>(0x1); |
397 | 0 | } |
398 | 0 | return nullptr; |
399 | 0 | } |
400 | | |
401 | 6.38k | my_ulonglong mysql_insert_id(MYSQL*) { |
402 | 6.38k | return g_fdp ? g_fdp->ConsumeIntegral<my_ulonglong>() : 0ULL; |
403 | 6.38k | } |
404 | | |
405 | 492 | const char* mysql_get_ssl_cipher(MYSQL*) { |
406 | 492 | if (g_fdp && g_fdp->ConsumeBool()) { |
407 | 385 | g_tls_cipher = g_fdp->ConsumeRandomLengthString(64); |
408 | 385 | return g_tls_cipher.c_str(); |
409 | 385 | } |
410 | 107 | return "TLS_FAKE_CIPHER_WITH_FAKE_SHA256"; |
411 | 492 | } |
412 | | |
413 | 26.3k | int mysql_stmt_fetch(MYSQL_STMT* stmt) { |
414 | 26.3k | auto s = SS(stmt); |
415 | | |
416 | 26.3k | if (!s || !s->has_rows){ |
417 | 8.46k | return MYSQL_NO_DATA; |
418 | 8.46k | } |
419 | 17.8k | if (s->fetch_index >= s->rows.size()) { |
420 | 258 | return MYSQL_NO_DATA; |
421 | 258 | } |
422 | 17.5k | if (!s->res_binds) { |
423 | 0 | return MYSQL_NO_DATA; |
424 | 0 | } |
425 | | |
426 | 17.5k | const auto& row = s->rows[s->fetch_index++]; |
427 | 17.5k | size_t cols = row.size(); |
428 | 17.5k | if (s->field_count && cols > s->field_count) { |
429 | 0 | cols = s->field_count; |
430 | 0 | } |
431 | 17.5k | if (cols > s->res_binds_count) { |
432 | 0 | cols = s->res_binds_count; |
433 | 0 | } |
434 | | |
435 | 139k | for (size_t i = 0; i < cols; ++i) { |
436 | 121k | const auto& cell = row[i]; |
437 | 121k | MYSQL_BIND& b = s->res_binds[i]; |
438 | 121k | if (!b.buffer) { |
439 | 0 | continue; |
440 | 0 | } |
441 | 121k | switch (cell.kind) { |
442 | 30.0k | case MockResRow::CK_UINT32: { |
443 | 30.0k | uint32_t v = cell.u32; |
444 | 30.0k | if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) { |
445 | 29.2k | memcpy(b.buffer, &v, sizeof(v)); |
446 | 29.2k | if (b.length) *b.length = sizeof(v); |
447 | 29.2k | } |
448 | 30.0k | break; |
449 | 0 | } |
450 | 62.2k | case MockResRow::CK_UINT64: { |
451 | 62.2k | uint64_t v = cell.u64; |
452 | 62.2k | if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) { |
453 | 36.4k | memcpy(b.buffer, &v, sizeof(v)); |
454 | 36.4k | if (b.length) *b.length = sizeof(v); |
455 | 36.4k | } |
456 | 62.2k | break; |
457 | 0 | } |
458 | 14.9k | case MockResRow::CK_UINT8: { |
459 | 14.9k | uint8_t v = cell.u8; |
460 | 14.9k | if (b.buffer_length == 0 || b.buffer_length >= sizeof(v)) { |
461 | 14.9k | memcpy(b.buffer, &v, sizeof(v)); |
462 | 14.9k | if (b.length) *b.length = sizeof(v); |
463 | 14.9k | } |
464 | 14.9k | break; |
465 | 0 | } |
466 | 6.67k | case MockResRow::CK_STRING: { |
467 | 6.67k | if (b.buffer_length > 0) { |
468 | 6.67k | size_t n = std::min<size_t>(b.buffer_length - 1, cell.s.size()); |
469 | 6.67k | memcpy(b.buffer, cell.s.data(), n); |
470 | 6.67k | reinterpret_cast<char*>(b.buffer)[n] = '\0'; |
471 | 6.67k | if (b.length) *b.length = n; |
472 | 6.67k | } |
473 | 6.67k | break; |
474 | 0 | } |
475 | 8.04k | case MockResRow::CK_BLOB: { |
476 | 8.04k | if (b.buffer_length > 0) { |
477 | 8.04k | size_t n = std::min<size_t>(b.buffer_length, cell.blob.size()); |
478 | 8.04k | memcpy(b.buffer, cell.blob.data(), n); |
479 | 8.04k | if (b.length) *b.length = n; |
480 | 8.04k | } |
481 | 8.04k | break; |
482 | 0 | } |
483 | 121k | } |
484 | 121k | } |
485 | 17.5k | return 0; |
486 | 17.5k | } |
487 | | } |