/src/llama.cpp/fuzzers/fuzz_inference.cpp
Line | Count | Source |
1 | | /* Copyright 2024 Google LLC |
2 | | Licensed under the Apache License, Version 2.0 (the "License"); |
3 | | you may not use this file except in compliance with the License. |
4 | | You may obtain a copy of the License at |
5 | | http://www.apache.org/licenses/LICENSE-2.0 |
6 | | Unless required by applicable law or agreed to in writing, software |
7 | | distributed under the License is distributed on an "AS IS" BASIS, |
8 | | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
9 | | See the License for the specific language governing permissions and |
10 | | limitations under the License. |
11 | | */ |
12 | | |
13 | | #include "common.h" |
14 | | #include "llama.h" |
15 | | |
16 | | #include <fuzzer/FuzzedDataProvider.h> |
17 | | #include <iostream> |
18 | | #include <setjmp.h> |
19 | | #include <unistd.h> |
20 | | #include <vector> |
21 | | |
22 | | jmp_buf fuzzing_jmp_buf; |
23 | | |
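| | // abort() is redirected here (the harness is presumably linked with -Wl,--wrap=abort), so fatal errors inside llama.cpp longjmp back into the fuzzer instead of killing the process.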
24 | 55 | extern "C" void __wrap_abort(void) { longjmp(fuzzing_jmp_buf, 1); } |
25 | | |
26 | 864 | extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { |
27 | | |
28 | 864 | FuzzedDataProvider fdp(data, size); |
29 | | |
30 | 864 | std::string model_payload = fdp.ConsumeRandomLengthString(); |
31 | 864 | if (model_payload.size() < 10) { |
32 | 2 | return 0; |
33 | 2 | } |
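| | // Force the GGUF magic bytes so the payload gets past the loader's initial file-format check.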
34 | 862 | model_payload[0] = 'G'; |
35 | 862 | model_payload[1] = 'G'; |
36 | 862 | model_payload[2] = 'U'; |
37 | 862 | model_payload[3] = 'F'; |
38 | | |
39 | 862 | std::string prompt = fdp.ConsumeRandomLengthString(); |
40 | | |
41 | 862 | llama_backend_init(); |
42 | | |
43 | 862 | common_params params; |
44 | 862 | params.prompt = prompt;
45 | 862 | params.n_predict = 4; |
46 | | |
47 | | // Create and load the model: llama_load_model_from_file() takes a file path, so write the fuzzed payload to a per-process temp file first.
48 | 862 | char filename[256]; |
49 | 862 | snprintf(filename, sizeof(filename), "/tmp/libfuzzer.%d", getpid());
50 | | |
51 | 862 | FILE *fp = fopen(filename, "wb"); |
52 | 862 | if (!fp) { |
53 | 0 | return 0; |
54 | 0 | } |
55 | 862 | fwrite(model_payload.c_str(), model_payload.size(), 1, fp); |
56 | 862 | fclose(fp); |
57 | | |
58 | 862 | llama_model_params model_params = common_model_params_to_llama(params); |
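| | // Disable mmap so the model bytes are read into regular heap buffers rather than mapped from the temp file.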
59 | 862 | model_params.use_mmap = false; |
60 | | |
61 | 862 | const int n_predict = params.n_predict; |
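| | // Any abort() raised while loading or running the model lands back here via __wrap_abort's longjmp.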
62 | 862 | if (setjmp(fuzzing_jmp_buf) == 0) { |
63 | 862 | auto *model = llama_load_model_from_file(filename, model_params); |
64 | 862 | if (model != nullptr) { |
65 | | |
66 | | // Model loaded; create an inference context (the generation loop below is currently commented out).
67 | 0 | llama_context_params ctx_params = |
68 | 0 | common_context_params_to_llama(params); |
69 | 0 | llama_context *ctx = llama_new_context_with_model(model, ctx_params); |
70 | 0 | if (ctx != NULL) { |
71 | | /* |
72 | | std::vector<llama_token> tokens_list; |
73 | | tokens_list = ::llama_tokenize(ctx, params.prompt, true); |
74 | | |
75 | | const int n_ctx = llama_n_ctx(ctx); |
76 | | const int n_kv_req = |
77 | | tokens_list.size() + (n_predict - tokens_list.size()); |
78 | | |
79 | | if (n_kv_req <= n_ctx) { |
80 | | llama_batch batch = llama_batch_init(512, 0, 1); |
81 | | |
82 | | for (size_t i = 0; i < tokens_list.size(); i++) { |
83 | | llama_batch_add(batch, tokens_list[i], i, {0}, false); |
84 | | } |
85 | | |
86 | | // set to only output logits for last token |
87 | | batch.logits[batch.n_tokens - 1] = true; |
88 | | if (llama_decode(ctx, batch) == 0) { |
89 | | int n_cur = batch.n_tokens; |
90 | | while (n_cur <= n_predict) { |
91 | | { |
92 | | auto n_vocab = llama_n_vocab(model); |
93 | | auto *logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); |
94 | | |
95 | | std::vector<llama_token_data> candidates; |
96 | | candidates.reserve(n_vocab); |
97 | | |
98 | | for (llama_token token_id = 0; token_id < n_vocab; token_id++) { |
99 | | candidates.emplace_back( |
100 | | llama_token_data{token_id, logits[token_id], 0.0f}); |
101 | | } |
102 | | |
103 | | llama_token_data_array candidates_p = { |
104 | | candidates.data(), candidates.size(), false}; |
105 | | |
106 | | // sample the most likely token |
107 | | const llama_token new_token_id = |
108 | | llama_sample_token_greedy(ctx, &candidates_p); |
109 | | |
110 | | // exit if end of generation |
111 | | if (llama_token_is_eog(model, new_token_id) || |
112 | | n_cur == n_predict) { |
113 | | break; |
114 | | } |
115 | | |
116 | | // Prepare for next iteration |
117 | | llama_batch_clear(batch); |
118 | | llama_batch_add(batch, new_token_id, n_cur, {0}, true); |
119 | | } |
120 | | |
121 | | n_cur += 1; |
122 | | |
123 | | if (llama_decode(ctx, batch)) { |
124 | | break; |
125 | | } |
126 | | } |
127 | | } |
128 | | llama_batch_free(batch); |
129 | | } |
130 | | */ |
131 | 0 | llama_free(ctx); |
132 | 0 | } |
133 | |
134 | 0 | llama_free_model(model); |
135 | 0 | } |
136 | 862 | } |
137 | 862 | llama_backend_free(); |
138 | | |
139 | 862 | unlink(filename); |
140 | 862 | return 0; |
141 | 862 | } |