use crate::{Engine, FuncType, Trap, ValRaw};
use anyhow::Result;
use std::panic::{self, AssertUnwindSafe};
use std::ptr::NonNull;
use wasmtime_jit::{CodeMemory, ProfilingAgent};
use wasmtime_runtime::{
VMContext, VMHostFuncContext, VMOpaqueContext, VMSharedSignatureIndex, VMTrampoline,
};
/// Host state stored in the `VMHostFuncContext` built by `create_function`:
/// the user-supplied host closure together with the executable memory that
/// holds the trampolines compiled for it.
struct TrampolineState<F> {
    /// Host closure invoked by `stub_fn` with the caller's `VMContext` and the
    /// raw values buffer.
    func: F,
    // Never read directly: it is owned here purely so the published
    // trampoline code stays allocated for as long as this state (and the
    // host-function context that holds it) is alive.
    #[allow(dead_code)]
    code_memory: CodeMemory,
}
/// Native entry point for a host closure of type `F`; its address is passed
/// to the trampoline emitter in `create_function`.
///
/// The closure runs under `catch_unwind` so that a Rust panic stops at this
/// `extern "C"` boundary. A `Trap` returned by the closure is raised as a
/// user trap, and a caught panic is handed to `resume_panic`.
///
/// # Safety
///
/// `vmctx` must point to a valid `VMHostFuncContext` whose host state is a
/// `TrampolineState<F>` (this is only asserted in debug builds), and
/// `values_vec` must be valid for reads and writes of `values_vec_len`
/// elements. NOTE(review): the buffer presumably carries arguments on entry
/// and results on return — confirm against the trampoline generator.
unsafe extern "C" fn stub_fn<F>(
    vmctx: *mut VMOpaqueContext,
    caller_vmctx: *mut VMContext,
    values_vec: *mut ValRaw,
    values_vec_len: usize,
) where
    F: Fn(*mut VMContext, &mut [ValRaw]) -> Result<(), Trap> + 'static,
{
    let outcome = panic::catch_unwind(AssertUnwindSafe(|| {
        let host_ctx = VMHostFuncContext::from_opaque(vmctx);
        let any_state = (*host_ctx).host_state();
        // Release builds trust the caller about the concrete state type.
        debug_assert!(any_state.is::<TrampolineState<F>>());
        let state = &*(any_state as *const _ as *const TrampolineState<F>);
        let args_and_results = std::slice::from_raw_parts_mut(values_vec, values_vec_len);
        (state.func)(caller_vmctx, args_and_results)
    }));

    match outcome {
        // Unwinding was stopped above; propagate the panic payload onwards.
        Err(payload) => wasmtime_runtime::resume_panic(payload),
        // The host closure asked to trap.
        Ok(Err(trap)) => wasmtime_runtime::raise_user_trap(trap.into()),
        Ok(Ok(())) => {}
    }
}
/// Reports every function symbol defined in the freshly published trampoline
/// `image` to `profiler`, so that samples landing in trampoline code can be
/// attributed to a name.
///
/// This is best-effort. If the image has no text section, the section data
/// cannot be read, or a symbol name cannot be decoded, the affected
/// trampolines are skipped silently and function creation still succeeds.
#[cfg(compiler)]
fn register_trampolines(profiler: &dyn ProfilingAgent, image: &object::File<'_>) {
    use object::{Object as _, ObjectSection, ObjectSymbol, SectionKind, SymbolKind};

    let pid = std::process::id();
    // NOTE(review): no real thread id is gathered here; the process id is
    // reused as the thread id.
    let tid = pid;

    // Symbol addresses are offsets from the start of the text section, so
    // resolve where that section actually lives in memory.
    let text_base = match image.sections().find(|s| s.kind() == SectionKind::Text) {
        Some(section) => match section.data() {
            Ok(data) => data.as_ptr() as usize,
            Err(_) => return,
        },
        None => return,
    };

    for sym in image.symbols() {
        if !sym.is_definition() || sym.kind() != SymbolKind::Text {
            continue;
        }
        // Zero-sized symbols cover no code. An address of 0 is deliberately
        // NOT filtered out: addresses are text-section offsets, so 0 is just
        // the first symbol defined in that section and must be reported too.
        let size = sym.size();
        if size == 0 {
            continue;
        }
        if let Ok(name) = sym.name() {
            let addr = text_base + sym.address() as usize;
            profiler.load_single_trampoline(name, addr as *const u8, size as usize, pid, tid);
        }
    }
}
/// Compiles the trampolines needed to expose `func` with signature `ft` and
/// wraps everything into a host-function context.
///
/// The address of `stub_fn::<F>` is handed to the trampoline emitter. The
/// resulting code is published as executable memory and announced to the
/// engine's profiler.
///
/// Returns the context, which owns both `func` and the trampoline code; the
/// signature index registered with the engine; and the host trampoline.
///
/// # Errors
///
/// Propagates failures from object creation, trampoline emission, conversion
/// of the object to an mmap, and publishing the code memory.
#[cfg(compiler)]
pub fn create_function<F>(
    ft: &FuncType,
    func: F,
    engine: &Engine,
) -> Result<(Box<VMHostFuncContext>, VMSharedSignatureIndex, VMTrampoline)>
where
    F: Fn(*mut VMContext, &mut [ValRaw]) -> Result<(), Trap> + Send + Sync + 'static,
{
    let compiler = engine.compiler();
    let wasm_ty = ft.as_wasm_func_type();

    // Emit both trampolines into one object and publish it as executable code.
    let mut obj = compiler.object()?;
    let (host_loc, wasm_loc) =
        compiler.emit_trampoline_obj(wasm_ty, stub_fn::<F> as usize, &mut obj)?;
    let mut code_memory = CodeMemory::new(wasmtime_jit::mmap_vec_from_obj(obj)?);
    let published = code_memory.publish()?;
    register_trampolines(engine.profiler(), &published.obj);

    // Locate each trampoline's entry point inside the published text.
    let host_start = host_loc.start as usize;
    let host_len = host_loc.length as usize;
    let host_entry = published.text[host_start..][..host_len].as_ptr();
    let wasm_entry =
        NonNull::new(published.text[wasm_loc.start as usize..].as_ptr() as *mut _).unwrap();

    let sig = engine.signatures().register(wasm_ty);

    // SAFETY: the trampoline code lives in `code_memory`, which is moved into
    // the context's host state here, so it stays allocated for as long as the
    // context does.
    let ctx = unsafe {
        VMHostFuncContext::new(
            wasm_entry,
            sig,
            Box::new(TrampolineState { func, code_memory }),
        )
    };
    // SAFETY: `host_entry` is the start of the host trampoline emitted for
    // `ft`. NOTE(review): this relies on the emitter producing code with the
    // `VMTrampoline` ABI. The pointer is only valid while `ctx`, which owns
    // the code memory, is alive.
    let host_trampoline = unsafe { std::mem::transmute::<*const u8, VMTrampoline>(host_entry) };

    Ok((ctx, sig, host_trampoline))
}