/src/WasmEdge/lib/llvm/lazyjit.cpp
Line | Count | Source |
1 | | // SPDX-License-Identifier: Apache-2.0 |
2 | | // SPDX-FileCopyrightText: Copyright The WasmEdge Authors |
3 | | |
4 | | #include "llvm/lazyjit.h" |
5 | | |
6 | | #include "ast/instruction.h" |
7 | | #include "ast/module.h" |
8 | | #include "common/spdlog.h" |
9 | | #include "runtime/instance/module.h" |
10 | | #include "llvm/compiler.h" |
11 | | #include "llvm/data.h" |
12 | | #include "llvm/jit.h" |
13 | | |
14 | | #include <algorithm> |
15 | | #include <cstdint> |
16 | | #include <memory> |
17 | | #include <mutex> |
18 | | #include <shared_mutex> |
19 | | #include <string_view> |
20 | | #include <unordered_map> |
21 | | #include <utility> |
22 | | #include <vector> |
23 | | |
24 | | namespace WasmEdge::LLVM { |
25 | | |
26 | | namespace { |
27 | | |
28 | | using namespace std::literals; |
29 | | |
30 | | // Collect the not-yet-compiled local functions reachable from the seed |
31 | | // function through direct calls and function references, so one lazy |
32 | | // compilation batch covers the whole call graph of the entry. |
33 | | std::vector<uint32_t> collectCallGraphBatch( |
34 | | uint32_t LocalSeed, const AST::Module &Module, uint32_t ImportFuncCount, |
35 | 0 | const std::unordered_map<uint32_t, WasmFunctionCodeAddress> &Compiled) { |
36 | 0 | std::vector<uint32_t> SortedLocals; |
37 | 0 | const auto &CodeSec = Module.getCodeSection().getContent(); |
38 | 0 | const uint32_t DefinedCount = Module.getDefinedFuncCount(); |
39 | | |
40 | | // The caller's findPendingCompile guarantees a valid, not-yet-compiled |
41 | | // seed. |
42 | 0 | assuming(LocalSeed < DefinedCount && Compiled.count(LocalSeed) == 0); |
43 | | |
44 | 0 | std::vector<uint8_t> Visited(DefinedCount, 0); |
45 | 0 | std::vector<uint32_t> Stack; |
46 | 0 | Stack.reserve(64); |
47 | |
|
48 | 0 | Visited[LocalSeed] = 1; |
49 | 0 | Stack.push_back(LocalSeed); |
50 | 0 | SortedLocals.push_back(LocalSeed); |
51 | |
|
52 | 0 | while (!Stack.empty()) { |
53 | 0 | const uint32_t L = Stack.back(); |
54 | 0 | Stack.pop_back(); |
55 | |
|
56 | 0 | for (const auto &Instr : CodeSec[L].getExpr().getInstrs()) { |
57 | 0 | const auto Op = Instr.getOpCode(); |
58 | 0 | if (Op == OpCode::Call || Op == OpCode::Return_call || |
59 | 0 | Op == OpCode::Ref__func) { |
60 | 0 | const uint32_t Target = Instr.getTargetIndex(); |
61 | 0 | if (Target >= ImportFuncCount) { |
62 | 0 | const uint32_t LocalIdx = Target - ImportFuncCount; |
63 | 0 | if (LocalIdx < DefinedCount && !Visited[LocalIdx] && |
64 | 0 | Compiled.count(LocalIdx) == 0) { |
65 | 0 | Visited[LocalIdx] = 1; |
66 | 0 | Stack.push_back(LocalIdx); |
67 | 0 | SortedLocals.push_back(LocalIdx); |
68 | 0 | } |
69 | 0 | } |
70 | 0 | } |
71 | 0 | } |
72 | 0 | } |
73 | |
|
74 | 0 | std::sort(SortedLocals.begin(), SortedLocals.end()); |
75 | 0 | return SortedLocals; |
76 | 0 | } |
77 | | |
78 | | // Upgrade the function instance at GlobalFuncIdx of a bound module instance |
79 | | // to run the compiled code at Address. Shared by the fresh-batch path and the |
80 | | // re-instantiation restore path. |
81 | | void upgradeToCompiled( |
82 | | Span<const Runtime::Instance::FunctionInstance *const> FuncInsts, |
83 | | size_t GlobalFuncIdx, JITLibrary &JITLib, |
84 | 0 | WasmFunctionCodeAddress Address) noexcept { |
85 | | // A fully instantiated instance of the same AST module covers every |
86 | | // compiled local function index. |
87 | 0 | assuming(GlobalFuncIdx < FuncInsts.size()); |
88 | | // The function instances are owned mutable by the module instance; the |
89 | | // accessor only adds constness. Upgrading them to compiled mode is the |
90 | | // purpose of this engine. Non-wasm functions are declined by |
91 | | // unsafeUpgradeToCompiled itself. |
92 | 0 | auto *FuncInst = const_cast<Runtime::Instance::FunctionInstance *>( |
93 | 0 | FuncInsts[GlobalFuncIdx]); |
94 | 0 | FuncInst->unsafeUpgradeToCompiled(JITLib.createCodeSymbol(Address)); |
95 | 0 | } |
96 | | |
97 | | // True while someone outside the engine still holds the AST module and could |
98 | | // re-instantiate it; a state failing this can never be rebound, so keeping it |
99 | | // would leak its JIT and compiled code for the lifetime of the engine. |
100 | | bool isReinstantiable( |
101 | 0 | const std::shared_ptr<const AST::Module> &Module) noexcept { |
102 | 0 | return Module != nullptr && Module.use_count() > 1; |
103 | 0 | } |
104 | | |
105 | | } // namespace |
106 | | |
107 | | struct LazyJITEngine::Impl { |
108 | | struct ModuleState { |
109 | | /// Shared ownership of the AST module for on-demand compilation. Held |
110 | | /// from prepare time on, so the AST module pointers used as map keys |
111 | | /// below can never dangle or be reused by another allocation. |
112 | | std::shared_ptr<const AST::Module> Module; |
113 | | /// Per-module LLVM data holding the thread-safe context. |
114 | | Data LLData; |
115 | | /// Per-module ORC LLJIT holding the generated code. |
116 | | std::shared_ptr<JITLibrary> JITLib; |
117 | | /// Resolved machine code addresses of lazily compiled functions, keyed |
118 | | /// by local function index. Survives re-instantiations, so rebinding |
119 | | /// restores compiled functions without further JIT symbol lookups. |
120 | | std::unordered_map<uint32_t, WasmFunctionCodeAddress> CompiledCode; |
121 | | /// Reverse lookup from function instance to function index. |
122 | | std::unordered_map<const Runtime::Instance::FunctionInstance *, uint32_t> |
123 | | FuncIndices; |
124 | | /// Number of imported functions of the module. |
125 | | uint32_t ImportFuncCount = 0; |
126 | | }; |
127 | | |
128 | 0 | Impl(const Configure &C) noexcept : Conf(C) {} |
129 | | |
130 | | /// Locate the bound state and the local function index when the function |
131 | | /// still needs lazy compilation. Returns {nullptr, 0} when there is |
132 | | /// nothing to do. The caller must hold Mutex (shared or exclusive); all |
133 | | /// writers hold it exclusively, so shared-locked reads are race-free. |
134 | | std::pair<ModuleState *, uint32_t> |
135 | | findPendingCompile(const Runtime::Instance::ModuleInstance *ModInst, |
136 | 0 | const Runtime::Instance::FunctionInstance *FuncInst) { |
137 | | // Already compiled or not a wasm function: nothing to do. Checked first |
138 | | // because it needs no map lookup, so the steady state where every call |
139 | | // probes an already-compiled function short-circuits here. The check is |
140 | | // done under the engine mutex to avoid racing with the upgrade in |
141 | | // compileOnDemand. |
142 | 0 | if (!FuncInst->isWasmFunction() || FuncInst->isCompiledFunction()) { |
143 | 0 | return {nullptr, 0}; |
144 | 0 | } |
145 | 0 | auto StateIt = States.find(ModInst); |
146 | 0 | if (StateIt == States.end()) { |
147 | 0 | return {nullptr, 0}; |
148 | 0 | } |
149 | 0 | auto &State = StateIt->second; |
150 | 0 | auto IdxIt = State.FuncIndices.find(FuncInst); |
151 | 0 | if (IdxIt == State.FuncIndices.end()) { |
152 | | // A bound module knows all of its function instances; reaching here |
153 | | // indicates a foreign or stale instance. |
154 | 0 | spdlog::debug( |
155 | 0 | "[lazy-jit]: function instance not bound to its module state"sv); |
156 | 0 | return {nullptr, 0}; |
157 | 0 | } |
158 | 0 | const uint32_t FuncIdx = IdxIt->second; |
159 | 0 | if (FuncIdx < State.ImportFuncCount) { |
160 | 0 | return {nullptr, 0}; |
161 | 0 | } |
162 | 0 | const uint32_t LocalFuncIdx = FuncIdx - State.ImportFuncCount; |
163 | 0 | if (State.CompiledCode.count(LocalFuncIdx) > 0) { |
164 | 0 | return {nullptr, 0}; |
165 | 0 | } |
166 | 0 | return {&State, LocalFuncIdx}; |
167 | 0 | } |
168 | | |
169 | | const Configure Conf; |
170 | | mutable std::shared_mutex Mutex; |
171 | | /// States prepared but not yet bound to a module instance, keyed by the |
172 | | /// AST module owned by the state itself. |
173 | | std::unordered_map<const AST::Module *, ModuleState> PendingStates; |
174 | | /// States bound to instantiated module instances. |
175 | | std::unordered_map<const Runtime::Instance::ModuleInstance *, ModuleState> |
176 | | States; |
177 | | }; |
178 | | |
// Create an engine with no prepared or bound modules. Conf is copied into
// the implementation object and used for every later compiler and JIT
// invocation.
LazyJITEngine::LazyJITEngine(const Configure &Conf) noexcept
    : PImpl(std::make_unique<Impl>(Conf)) {}
181 | | |
// Defaulted in this translation unit, where Impl is a complete type.
LazyJITEngine::~LazyJITEngine() noexcept = default;
183 | | |
184 | | Expect<std::shared_ptr<Executable>> |
185 | 0 | LazyJITEngine::prepare(std::shared_ptr<const AST::Module> Module) { |
186 | 0 | if (!Module) { |
187 | 0 | return Unexpect(ErrCode::Value::WrongVMWorkflow); |
188 | 0 | } |
189 | 0 | Impl::ModuleState State; |
190 | 0 | State.Module = std::move(Module); |
191 | |
|
192 | 0 | Compiler InfraCompiler(PImpl->Conf); |
193 | 0 | EXPECTED_TRY(InfraCompiler.checkConfigure()); |
194 | | |
195 | 0 | EXPECTED_TRY(State.LLData, |
196 | 0 | InfraCompiler.compileInfrastructure(*State.Module)); |
197 | |
|
198 | 0 | JIT JITEngine(PImpl->Conf); |
199 | 0 | EXPECTED_TRY(auto Exec, JITEngine.loadLazy(State.LLData)); |
200 | 0 | State.JITLib = std::static_pointer_cast<JITLibrary>(Exec); |
201 | |
|
202 | 0 | State.ImportFuncCount = State.Module->getImportFuncCount(); |
203 | |
|
204 | 0 | std::unique_lock Lock(PImpl->Mutex); |
205 | | // Prune pending states nobody can re-instantiate anymore. |
206 | 0 | for (auto It = PImpl->PendingStates.begin(); |
207 | 0 | It != PImpl->PendingStates.end();) { |
208 | 0 | if (isReinstantiable(It->second.Module)) { |
209 | 0 | ++It; |
210 | 0 | } else { |
211 | 0 | It = PImpl->PendingStates.erase(It); |
212 | 0 | } |
213 | 0 | } |
214 | 0 | const auto *Key = State.Module.get(); |
215 | 0 | PImpl->PendingStates.insert_or_assign(Key, std::move(State)); |
216 | 0 | return Exec; |
217 | 0 | } |
218 | | |
219 | | void LazyJITEngine::registerInstance( |
220 | | const Runtime::Instance::ModuleInstance &ModInst, |
221 | 0 | std::shared_ptr<const AST::Module> Module) noexcept { |
222 | 0 | if (!Module) { |
223 | 0 | return; |
224 | 0 | } |
225 | 0 | std::unique_lock Lock(PImpl->Mutex); |
226 | 0 | auto It = PImpl->PendingStates.find(Module.get()); |
227 | 0 | if (It == PImpl->PendingStates.end()) { |
228 | 0 | return; |
229 | 0 | } |
230 | 0 | auto State = std::move(It->second); |
231 | 0 | PImpl->PendingStates.erase(It); |
232 | |
|
233 | 0 | const auto FuncInsts = ModInst.getFunctionInstances(); |
234 | 0 | State.FuncIndices.reserve(FuncInsts.size()); |
235 | 0 | for (uint32_t I = 0; I < FuncInsts.size(); ++I) { |
236 | 0 | State.FuncIndices.emplace(FuncInsts[I], I); |
237 | 0 | } |
238 | | |
239 | | // Rebinding after a re-instantiation: the fresh function instances start |
240 | | // in interpreter mode, so restore the functions already compiled in |
241 | | // earlier instantiations from their persisted code addresses. |
242 | 0 | for (const auto &[LocalFuncIdx, Address] : State.CompiledCode) { |
243 | 0 | upgradeToCompiled(FuncInsts, size_t{State.ImportFuncCount} + LocalFuncIdx, |
244 | 0 | *State.JITLib, Address); |
245 | 0 | } |
246 | |
|
247 | 0 | PImpl->States.insert_or_assign(&ModInst, std::move(State)); |
248 | 0 | } |
249 | | |
250 | | void LazyJITEngine::unregisterInstance( |
251 | 0 | const Runtime::Instance::ModuleInstance &ModInst) noexcept { |
252 | 0 | std::unique_lock Lock(PImpl->Mutex); |
253 | 0 | auto It = PImpl->States.find(&ModInst); |
254 | 0 | if (It == PImpl->States.end()) { |
255 | 0 | return; |
256 | 0 | } |
257 | | // Drop only the per-instance bindings; the module-level JIT state moves |
258 | | // back to the pending map so a later instantiation of the same AST module |
259 | | // (which skips prepare because its executable is already hooked) rebinds |
260 | | // it instead of silently losing lazy compilation. |
261 | 0 | auto State = std::move(It->second); |
262 | 0 | PImpl->States.erase(It); |
263 | 0 | State.FuncIndices.clear(); |
264 | | // Keep the state only while it can be rebound; otherwise drop it instead |
265 | | // of leaking the JIT and its compiled code. |
266 | 0 | if (const auto *Key = State.Module.get(); isReinstantiable(State.Module)) { |
267 | 0 | PImpl->PendingStates.insert_or_assign(Key, std::move(State)); |
268 | 0 | } |
269 | 0 | } |
270 | | |
271 | | Expect<void> LazyJITEngine::compileOnDemand( |
272 | 0 | const Runtime::Instance::FunctionInstance *FuncInst) { |
273 | 0 | if (FuncInst == nullptr) { |
274 | 0 | return {}; |
275 | 0 | } |
276 | 0 | const auto *ModInst = FuncInst->getModule(); |
277 | 0 | if (ModInst == nullptr) { |
278 | 0 | return {}; |
279 | 0 | } |
280 | | |
281 | | // Fast path: the common no-work cases (unbound module, host function, |
282 | | // already compiled) need only read access. All writers hold the exclusive |
283 | | // lock on the same mutex, so the shared lock keeps this race-free without |
284 | | // serializing concurrent callers. |
285 | 0 | { |
286 | 0 | std::shared_lock SharedLock(PImpl->Mutex); |
287 | 0 | if (PImpl->findPendingCompile(ModInst, FuncInst).first == nullptr) { |
288 | 0 | return {}; |
289 | 0 | } |
290 | 0 | } |
291 | | |
292 | 0 | std::unique_lock Lock(PImpl->Mutex); |
293 | | // Re-locate under the exclusive lock; the state may have changed between |
294 | | // the two locks. |
295 | 0 | auto [StatePtr, LocalFuncIdx] = PImpl->findPendingCompile(ModInst, FuncInst); |
296 | 0 | if (StatePtr == nullptr) { |
297 | 0 | return {}; |
298 | 0 | } |
299 | 0 | auto &State = *StatePtr; |
300 | |
|
301 | 0 | auto BatchLocals = collectCallGraphBatch( |
302 | 0 | LocalFuncIdx, *State.Module, State.ImportFuncCount, State.CompiledCode); |
303 | |
|
304 | 0 | spdlog::debug( |
305 | 0 | "[lazy-jit]: lazy compiling batch ({} local funcs) for entry local {}"sv, |
306 | 0 | BatchLocals.size(), LocalFuncIdx); |
307 | |
|
308 | 0 | const auto LogError = [](std::string_view Stage) { |
309 | 0 | return [Stage](auto Err) { |
310 | 0 | spdlog::error("[lazy-jit]: {} failed: {}"sv, Stage, Err); |
311 | 0 | return Err; |
312 | 0 | }; |
313 | 0 | }; |
314 | | |
315 | | // The configure was already validated by checkConfigure() in prepare(), |
316 | | // and a state only exists after a successful prepare, so re-validating |
317 | | // here would only repeat its per-proposal warnings once per batch. |
318 | 0 | Compiler BatchCompiler(PImpl->Conf); |
319 | 0 | EXPECTED_TRY( |
320 | 0 | auto CompiledData, |
321 | 0 | BatchCompiler |
322 | 0 | .compileFunctions(std::move(State.LLData), *State.Module, BatchLocals) |
323 | 0 | .map_error(LogError("lazy JIT function compilation"sv))); |
324 | 0 | State.LLData = std::move(CompiledData); |
325 | |
|
326 | 0 | std::vector<uint32_t> BatchGlobal; |
327 | 0 | BatchGlobal.reserve(BatchLocals.size()); |
328 | 0 | for (uint32_t L : BatchLocals) { |
329 | 0 | BatchGlobal.push_back(State.ImportFuncCount + L); |
330 | 0 | } |
331 | | |
332 | | // The JIT library is created in prepare() and lives as long as the state. |
333 | 0 | assuming(State.JITLib); |
334 | 0 | JIT JITEngine(PImpl->Conf); |
335 | 0 | EXPECTED_TRY(auto ResolvedAddresses, |
336 | 0 | JITEngine.add(*State.JITLib, State.LLData, BatchGlobal) |
337 | 0 | .map_error(LogError("lazy JIT add"sv))); |
338 | | |
339 | | // The machine code now lives in the persisted JIT regardless of the |
340 | | // instance bindings, so record each address before upgrading its instance. |
341 | 0 | const auto FuncInsts = ModInst->getFunctionInstances(); |
342 | 0 | for (size_t I = 0; I < BatchLocals.size(); ++I) { |
343 | 0 | State.CompiledCode.emplace(BatchLocals[I], ResolvedAddresses[I]); |
344 | 0 | upgradeToCompiled(FuncInsts, BatchGlobal[I], *State.JITLib, |
345 | 0 | ResolvedAddresses[I]); |
346 | 0 | } |
347 | |
|
348 | 0 | spdlog::debug( |
349 | 0 | "[lazy-jit]: lazy compilation completed for batch of {} functions, " |
350 | 0 | "total compiled: {}"sv, |
351 | 0 | BatchLocals.size(), State.CompiledCode.size()); |
352 | |
|
353 | 0 | return {}; |
354 | 0 | } |
355 | | |
356 | 0 | uint32_t LazyJITEngine::compiledFunctionCount() const noexcept { |
357 | 0 | std::shared_lock Lock(PImpl->Mutex); |
358 | 0 | uint32_t Count = 0; |
359 | 0 | for (const auto &Pair : PImpl->States) { |
360 | 0 | Count += static_cast<uint32_t>(Pair.second.CompiledCode.size()); |
361 | 0 | } |
362 | | // Pending states of unbound modules still hold live compiled code. |
363 | 0 | for (const auto &Pair : PImpl->PendingStates) { |
364 | 0 | Count += static_cast<uint32_t>(Pair.second.CompiledCode.size()); |
365 | 0 | } |
366 | 0 | return Count; |
367 | 0 | } |
368 | | |
369 | 0 | void LazyJITEngine::clear() noexcept { |
370 | 0 | std::unique_lock Lock(PImpl->Mutex); |
371 | 0 | PImpl->PendingStates.clear(); |
372 | 0 | PImpl->States.clear(); |
373 | 0 | } |
374 | | |
375 | | } // namespace WasmEdge::LLVM |