Skip to content

[stack-switching] JIT support for stack switching #397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/engine/BytecodeIterator.v3
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class BytecodeIterator {
def read_ZEROB = read_u8;
def read_VALTS = cp.read_ValueTypeCodes;
def read_CATCHES = cp.read_catches;
def read_HANDLERS = cp.read_suspension_handlers;

// Dispatch to appropriate visit_OP() method.
// XXX: generate match on operator and read of immediates from Opcodes table.
Expand Down Expand Up @@ -767,8 +768,8 @@ class BytecodeIterator {
CONT_NEW => v.visit_CONT_NEW(read_CONT());
CONT_BIND => v.visit_CONT_BIND(read_CONT(), read_CONT());
SUSPEND => v.visit_SUSPEND(read_TAG());
RESUME => v.visit_RESUME(read_CONT()); // TODO
RESUME_THROW => v.visit_RESUME_THROW(read_CONT(), read_TAG());
RESUME => v.visit_RESUME(read_CONT(), read_HANDLERS());
RESUME_THROW => v.visit_RESUME_THROW(read_CONT(), read_TAG(), read_HANDLERS());
SWITCH => v.visit_SWITCH(read_CONT(), read_TAG());
}
}
Expand Down
11 changes: 11 additions & 0 deletions src/engine/CodePtr.v3
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ class CodePtr extends DataReader {
for (i < length) result[i] = (read_uleb31(), read_uleb31());
return result;
}
// Decodes a vector of suspension handlers: a uleb31 element count followed by
// that many entries, each decoded by {read_suspension_handler}.
def read_suspension_handlers() -> Array<SuspensionHandler> {
	var count = read_uleb31();
	var handlers = Array<SuspensionHandler>.new(count);
	for (i < count) handlers[i] = read_suspension_handler();
	return handlers;
}
def read_catches() -> Array<BpCatchCode> {
var length = read_uleb31();
var result = Array<BpCatchCode>.new(length);
Expand All @@ -73,6 +79,11 @@ class CodePtr extends DataReader {
var d: BpCatchCode;
return d; // TODO: error
}
// Decodes one suspension handler entry. The entry starts with a uleb31 kind:
// kind 0 is a {Suspend} handler (tag index, then label depth); every other kind
// is decoded as a {Switch} handler (tag index only).
def read_suspension_handler() -> SuspensionHandler {
	var kind = read_uleb31();
	if (kind != 0) return SuspensionHandler.Switch(read_uleb31());
	// Immediates are read in encoding order: tag index first, then depth.
	var tag_index = read_uleb31();
	var depth = read_uleb31();
	return SuspensionHandler.Suspend(tag_index, depth);
}
def iterate_local_codes<T>(f: (u32, ValueTypeCode) -> T) -> int {
var bcount = int.!(read_uleb32()); // pairs count
for (i < bcount) {
Expand Down
16 changes: 9 additions & 7 deletions src/engine/CodeValidator.v3
Original file line number Diff line number Diff line change
Expand Up @@ -1184,7 +1184,7 @@ class CodeValidator(extensions: Extension.set, limits: Limits, module: Module, e
RESUME => {
var ct = parser.readCont();
if (ct == null) return;
var handlers = parser.readSusHandlers();
var handlers = parser.readSuspensionHandlers();

popE(ValueTypes.Ref(true, ct));
ctlxfer.refR(ct.sig.params.length, handlers.length);
Expand All @@ -1197,7 +1197,7 @@ class CodeValidator(extensions: Extension.set, limits: Limits, module: Module, e
RESUME_THROW => {
var ct = parser.readCont();
if (ct == null) return;
var handlers = parser.readSusHandlers();
var handlers = parser.readSuspensionHandlers();

popE(ValueTypes.Ref(true, ct));
ctlxfer.refR(ct.sig.params.length, handlers.length);
Expand Down Expand Up @@ -1427,22 +1427,24 @@ class CodeValidator(extensions: Extension.set, limits: Limits, module: Module, e
}
return SigDecl.!(tag_type);
}
def readAndCheckContHandlerTable(cont: ContDecl, handlers: Array<SusHandler>) {
def readAndCheckContHandlerTable(cont: ContDecl, handlers: Array<SuspensionHandler>) {
var sidetable_pos = ctlxfer.sidetable.length;
for (i < handlers.length) {
// Same sidetable structure is used for both {suspend} and {switch}.
var sidetable_entry = sidetable_pos + (i * Sidetable_CatchEntry.size / 4);
var info = ExHandlerInfo.Sidetable(false, sidetable_entry);
var resume_pos = opcode_pos - func_start_pos;
match (handlers[i]) {
Suspend(tag, depth) => {
Suspend(tag_index, depth) => {
var tag = module.tags[tag_index];
var target = getControl(depth);
if (target == null) return;
checkContHandle(i, tag, cont, target);
ctlxfer.refC(target, tag, false, "refH[suspend]");
suspend_handlers.put(ExHandlerEntry(tag.tag_index, resume_pos, resume_pos + 1, info));
suspend_handlers.put(ExHandlerEntry(tag_index, resume_pos, resume_pos + 1, info));
}
Switch(tag) => {
Switch(tag_index) => {
var tag = module.tags[tag_index];
var expected = cont.sig.results;
var tag_type = module.heaptypes[tag.sig_index];
if (!SigDecl.?(tag_type)) return err_atpc().ExpectedSignature(tag_type);
Expand All @@ -1457,7 +1459,7 @@ class CodeValidator(extensions: Extension.set, limits: Limits, module: Module, e
}
}
ctlxfer.refC(null, tag, false, "refH[switch]");
switch_handlers.put(ExHandlerEntry(tag.tag_index, resume_pos, resume_pos + 1, info));
switch_handlers.put(ExHandlerEntry(tag_index, resume_pos, resume_pos + 1, info));
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/engine/Opcodes.v3
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ class InstrTracer {
out.put1("%d...", handlers.length);
}
SUS_HANDLERS => {
var handlers = parser.readSusHandlers();
var handlers = parser.readSuspensionHandlers();
out.put1("%d...", handlers.length);
}
CATCHES => {
Expand Down
3 changes: 2 additions & 1 deletion src/engine/Runtime.v3
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,8 @@ component Runtime {
if (func == null) return stack.trap(TrapReason.NULL_DEREF);

var new_stack = make_stack(func);
stack.push(Value.Ref(Continuation.new(new_stack, new_stack)));
var cont = Continuation.new(new_stack, new_stack);
stack.push(Value.Ref(cont));
return null;
}
def CONT_BIND(stack: WasmStack, instance: Instance, in_cont_index: u31, out_cont_index: u31) -> Throwable {
Expand Down
1 change: 1 addition & 0 deletions src/engine/Type.v3
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ component ValueTypes {

def ONE_STRUCTREF_TYPE: Array<ValueType> = [STRUCTREF];
def ONE_ARRAYREF_TYPE: Array<ValueType> = [ARRAYREF];
def ONE_CONTREF_TYPE: Array<ValueType> = [CONTREF];

// Helper utility for a final signature type with no supertypes.
def newSig = SigDecl.new(true, NO_HEAPTYPES, _, _);
Expand Down
16 changes: 8 additions & 8 deletions src/engine/WasmParser.v3
Original file line number Diff line number Diff line change
Expand Up @@ -374,15 +374,15 @@ class WasmParser(extensions: Extension.set, limits: Limits, module: Module,
return result;
}
// Reads the same content as readHandlers but with extra checks.
def readSusHandlers() -> Array<SusHandler> {
def readSuspensionHandlers() -> Array<SuspensionHandler> {
var pt = decoder.pos;
var count = readU32("sus handler count", limits.max_func_size);
var result = Array<SusHandler>.new(int.!(count));
var result = Array<SuspensionHandler>.new(int.!(count));
for (i < count) {
var kind = readU32("sus handler kind", 1);
var ch: SusHandler;
if (kind == 0) ch = SusHandler.Suspend(readTagRef(), readLabel());
else ch = SusHandler.Switch(readTagRef());
var ch: SuspensionHandler;
if (kind == 0) ch = SuspensionHandler.Suspend(u31.!(readTagRef().tag_index), readLabel());
else ch = SuspensionHandler.Switch(u31.!(readTagRef().tag_index));
result[i] = ch;
}
return result;
Expand Down Expand Up @@ -536,7 +536,7 @@ class WasmParser(extensions: Extension.set, limits: Limits, module: Module,
}
type Catch(tag: TagDecl, exnref: bool, depth: u32) {
}
type SusHandler {
case Suspend(t: TagDecl, depth: u32);
case Switch(t: TagDecl);
type SuspensionHandler {
case Suspend(tag_index: u31, depth: u32);
case Switch(tag_index: u31);
}
4 changes: 4 additions & 0 deletions src/engine/compiler/MacroAssembler.v3
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class MacroAssembler(valuerep: Tagging, regConfig: RegConfig) {
source_loc = src;
}

def recordRetSourceLoc();

def emit_intentional_crash() {
}

Expand Down Expand Up @@ -307,6 +309,7 @@ class MacroAssembler(valuerep: Tagging, regConfig: RegConfig) {
def emit_jump_HostCallStub();

def emit_call_runtime_callHost(func_arg: Reg);
def emit_call_runtime_SUSPEND();
def emit_jump_to_trap_at(reason: TrapReason);
def emit_call_runtime_op(op: Opcode);
def emit_get_curstack(r: Reg);
Expand Down Expand Up @@ -336,6 +339,7 @@ class MacroAssembler(valuerep: Tagging, regConfig: RegConfig) {
def emit_chain_cont_to_parent(parent: Reg, cont: Reg);
// Called as the last step during resume. Switches {curStack} to the top of {cont}.
def emit_switch_curStack_to_cont(cont: Reg);
def emit_cont_mv(from_vsp: Reg, cont: Reg, n_vals: Reg, tmp1: Reg, tmp2: Reg, xmm0: Reg);

def fatalUnimplemented();

Expand Down
17 changes: 17 additions & 0 deletions src/engine/compiler/SinglePassCompiler.v3
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,17 @@ class SinglePassCompiler(xenv: SpcExecEnv, masm: MacroAssembler, regAlloc: RegAl
state.overwrite(sv.kindFlagsAndTag(IN_REG | (sv.flags & IS_STORED)), reg, 0);
}
}
// ext: stack-switching
// Compiles {cont.new} as a call to the runtime operation {Opcode.CONT_NEW}, passing
// {cont_index} as the immediate, an operand count of 1, and a single continuation
// reference as the result type.
def visit_CONT_NEW(cont_index: u31) {
// NOTE(review): {decl} is not used below; the statement only performs the cast of the
// heap type to {ContDecl}. Presumably kept as a sanity check — confirm before removing.
var decl = ContDecl.!(module.heaptypes[cont_index]);
emit_call_runtime_op1n(Opcode.CONT_NEW, cont_index, 1, ValueTypes.ONE_CONTREF_TYPE, true);
}
// Compiles {cont.bind} as a call to the runtime operation {Opcode.CONT_BIND}.
// The operand count handed to the runtime call is the difference between the
// parameter counts of the input and output continuation types, plus one
// (presumably for the continuation reference operand itself — confirm).
def visit_CONT_BIND(in_cont_id: u31, out_cont_id: u31) {
	var src = ContDecl.!(module.heaptypes[in_cont_id]);
	var dst = ContDecl.!(module.heaptypes[out_cont_id]);
	var src_arity = src.sig.params.length;
	var dst_arity = dst.sig.params.length;
	var operand_count = u32.!(src_arity - dst_arity) + 1;
	emit_call_runtime_op2n(Opcode.CONT_BIND, in_cont_id, out_cont_id, operand_count, ValueTypes.ONE_CONTREF_TYPE, true);
}
def visit_STRUCT_NEW(struct_index: u31) {
var decl = StructDecl.!(module.heaptypes[struct_index]);
emit_call_runtime_op1n(Opcode.STRUCT_NEW, struct_index, u32.!(decl.field_types.length), ValueTypes.ONE_STRUCTREF_TYPE, true);
Expand Down Expand Up @@ -2168,6 +2179,7 @@ class SpcState(regAlloc: RegAlloc) {
def emitFallthru(resolver: SpcMoveResolver) {
emitTransfer(ctl_stack.peek(), resolver);
}
// [rv]: use this, copy {target}
def emitTransfer(target: SpcControl, resolver: SpcMoveResolver) {
if (!ctl_stack.peek().reachable) {
if (Trace.compiler) OUT.puts(" xfer not reachable").ln();
Expand Down Expand Up @@ -2292,6 +2304,11 @@ class SpcState(regAlloc: RegAlloc) {
push(typeToKindFlags(t) | TAG_STORED | IS_STORED, NO_REG, 0);
}
}
// Pushes one abstract value per type in {results}, in order. Each value is
// pushed with no register and with both {TAG_STORED} and {IS_STORED} set.
def pushResults(results: Range<ValueType>) {
	for (i < results.length) {
		var flags = typeToKindFlags(results[i]) | TAG_STORED | IS_STORED;
		push(flags, NO_REG, 0);
	}
}
def peek() -> SpcVal {
return state[sp - 1];
}
Expand Down
7 changes: 6 additions & 1 deletion src/engine/x86-64/X86_64MacroAssembler.v3
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,9 @@ class X86_64MacroAssembler extends MacroAssembler {
def emit_call_runtime_TRAP() {
emit_call_runtime(X86_64RT.runtime_TRAP);
}
// Emits a call to the runtime entry {X86_64RT.runtime_handle_suspend} through the
// generic {emit_call_runtime} helper, mirroring {emit_call_runtime_TRAP}.
def emit_call_runtime_SUSPEND() {
emit_call_runtime(X86_64RT.runtime_handle_suspend);
}
def emit_jump_to_trap_at(reason: TrapReason) {
var ip = trap_stubs.getIpForReason(reason);
asm.movq_r_l(scratch, ip - Pointer.NULL);
Expand Down Expand Up @@ -874,6 +877,8 @@ class X86_64MacroAssembler extends MacroAssembler {
ARRAY_COPY => emit_call_runtime(RT.ARRAY_COPY);
ARRAY_INIT_DATA => emit_call_runtime(RT.ARRAY_INIT_DATA);
ARRAY_INIT_ELEM => emit_call_runtime(RT.ARRAY_INIT_ELEM);
CONT_NEW => emit_call_runtime(X86_64RT.runtime_CONT_NEW);
CONT_BIND => emit_call_runtime(RT.CONT_BIND);
_ => unimplemented();
}
}
Expand Down Expand Up @@ -1521,7 +1526,7 @@ class X86_64MacroAssembler extends MacroAssembler {
emit_br_r(cont, MasmBrCond.REF_NULL, newTrapLabel(TrapReason.NULL_DEREF));
emit_mov_r_m(ValueKind.REF, scratch, MasmAddr(cont, offsets.Continuation_used));
emit_br_r(scratch, MasmBrCond.I64_NONZERO, newTrapLabel(TrapReason.USED_CONTINUATION));
emit_mov_m_i(MasmAddr(cont, offsets.Continuation_used), 1);
emit_mov_m_i(MasmAddr(cont, offsets.Continuation_used), 1);
}
def emit_chain_cont_to_parent(parent: Reg, cont: Reg) {
var scratch = regConfig.scratch;
Expand Down
2 changes: 1 addition & 1 deletion src/engine/x86-64/X86_64MasmRegs.v3
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ component X86_64MasmRegs {
// local state used during building of {SET} and {CONFIG}
def NO_REG = Reg(0);
private def GPRS = Array<X86_64Gpr>.new(256); // fast mapping byte -> GPR
private def XMMS = Array<X86_64Xmmr>.new(256); // fast mapping byte -> XMM
def XMMS = Array<X86_64Xmmr>.new(256); // fast mapping byte -> XMM
private var all = Vector<Reg>.new().grow(32).put(NO_REG);
private var ints = Vector<Reg>.new().grow(16);
private var floats = Vector<Reg>.new().grow(16);
Expand Down
2 changes: 1 addition & 1 deletion src/engine/x86-64/X86_64Runtime.v3
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ component X86_64Runtime {
if (l.head.0 == pc) {
var entrypoint = osr.spc_entry + l.head.1;
if (Debug.runtime) Trace.OUT.put1(" tierup to 0x%x", entrypoint - Pointer.NULL).ln();
var retaddr_ptr = CiRuntime.callerSp() + -Pointer.SIZE;
var retaddr_ptr = stack.rsp;
retaddr_ptr.store<Pointer>(entrypoint); // overwrite return address to return to JIT code
return;
}
Expand Down
122 changes: 122 additions & 0 deletions src/engine/x86-64/X86_64SinglePassCompiler.v3
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,128 @@ class X86_64SinglePassCompiler extends SinglePassCompiler {
state.push(a.kindFlagsMatching(ValueKind.V128, IN_REG), a.reg, 0);
}

// Emits code for {resume} on the continuation type {cont_id}. In order, the emitted
// sequence: validates and consumes the continuation operand, moves the continuation's
// parameters with {emit_cont_mv}, records the current stack's vsp/rsp, chains the
// continuation to the current stack, switches {curStack} to it, and calls the local
// {stub_resume} stub, which loads the new stack's pointers and jumps to an address
// popped from that stack. After the call returns, the continuation's result types are
// pushed onto the abstract state.
// NOTE(review): {handlers} is never referenced in this body, so no suspension handlers
// are installed by this code — confirm this is intentional for the draft.
def visit_RESUME(cont_id: u31, handlers: Range<SuspensionHandler>) {
var offsets = V3Offsets.new();
var stub_resume = X86_64Label.new(), end = X86_64Label.new();
var cont_decl = ContDecl.!(module.heaptypes[cont_id]);
// A constant-null operand has already had its trap emitted and the code marked
// unreachable by {checkForConstNull}; nothing more to emit.
if (checkForConstNull(state.peek())) return;

var vsp = allocTmpFixed(ValueKind.REF, regs.vsp);
var top = popReg();
// Copy the operand into a fresh temporary before validating/consuming it.
var cont = allocTmp(ValueKind.REF);
masm.emit_mov_r_r(ValueKind.REF, cont, top.reg);
masm.emit_validate_and_consume_cont(cont);

// transfer params to child stack
// XXX: not necessary to store the entire value stack to memory
state.emitSaveAll(resolver, SpillMode.SAVE_AND_FREE_REGS);
var curStack = allocTmp(ValueKind.REF);
var nvals = allocTmp(ValueKind.REF);
var tmp0 = allocTmp(ValueKind.REF);
var tmp1 = allocTmp(ValueKind.REF);
var tmp2 = allocTmp(ValueKind.V128);

// call to transfer
emit_compute_vsp(regs.vsp, state.sp);
// {nvals} holds the number of parameters of the continuation's signature.
masm.emit_mov_r_i(nvals, cont_decl.sig.params.length);
masm.emit_cont_mv(vsp, cont, nvals, tmp0, tmp1, tmp2);
// The transferred parameters are removed from the abstract state.
dropN(u32.!(cont_decl.sig.params.length));
state.emitRestoreAll(resolver);

// context switch: store to old stack
masm.emit_get_curstack(curStack);
// vsp is recomputed because {state.sp} changed after {dropN} above.
emit_compute_vsp(regs.vsp, state.sp);
emit_spill_vsp(xenv.vsp);
masm.emit_v3_set_X86_64Stack_vsp_r_r(curStack, xenv.vsp);
masm.emit_v3_set_X86_64Stack_rsp_r_r(curStack, xenv.sp);

// reserve space for call addr
// (subtracts one pointer size from the rsp value just saved in {curStack})
asm.sub_m_i(G(curStack).plus(offsets.X86_64Stack_rsp), Pointer.SIZE);
masm.emit_chain_cont_to_parent(curStack, cont);
masm.emit_switch_curStack_to_cont(cont);

asm.call_rel_far(stub_resume);
/* === CHILD STACK RUNNING === */

// clean up after stack switch
masm.recordRetSourceLoc();
state.pushResults(cont_decl.sig.results);
emit_reload_regs();
// NOTE(review): restore is conditional on {runtimeSpillMode.free_regs} here, although the
// save above used {SpillMode.SAVE_AND_FREE_REGS} explicitly — confirm the two agree.
if (!runtimeSpillMode.free_regs) state.emitRestoreAll(resolver);

// skip {stub_resume}
asm.jmp_rel_near(end);
asm.bind(stub_resume); {
// load new stack pointers
masm.emit_get_curstack(curStack);
masm.emit_v3_X86_64Stack_vsp_r_r(xenv.vsp, curStack);
masm.emit_v3_X86_64Stack_rsp_r_r(xenv.sp, curStack);
// pop address from new stack and go to it
masm.emit_pop_r(ValueKind.REF, xenv.scratch);
masm.emit_jump_r(xenv.scratch);
}
asm.bind(end);
}

// def visit_SUSPEND(tag_index: u31) {
// var offsets = V3Offsets.new();
// var stub_suspend = X86_64Label.new(), end = X86_64Label.new();

// state.emitSaveAll(resolver, SpillMode.SAVE_AND_FREE_REGS);

// emit_compute_vsp(regs.vsp, state.sp);
// emit_spill_vsp(xenv.vsp);
// masm.emit_store_curstack_vsp(regs.vsp);
// masm.emit_get_curstack(regs.runtime_arg0);
// masm.emit_v3_set_X86_64Stack_rsp_r_r(regs.runtime_arg0, regs.sp);
// masm.emit_push_X86_64Stack_rsp_r_r(regs.runtime_arg0);
// emit_load_instance(regs.runtime_arg1);
// masm.emit_mov_r_i(regs.runtime_arg2, tag_index);
// masm.emit_call_runtime_SUSPEND();

// var sig_index = module.tags[tag_index].sig_index;
// state.popArgsAndPushResults(SigDecl.!(module.heaptypes[sig_index]));
// asm.call_rel_far(stub_suspend);

// /* === PARENT STACK RUNNING === */

// state.emitRestoreAll(resolver);
// emit_reload_regs();

// // skip {stub_suspend}
// asm.jmp_rel_near(end);
// asm.bind(stub_suspend); {
// // load new stack pointers
// masm.emit_get_curstack(regs.scratch);
// // unlike fast-int, the ptr offset for RT calls is done outside of the `call_runtime_...`
// // shorthands, so manually popping %rsp is needed
// masm.emit_pop_X86_64Stack_rsp_r_r(regs.scratch);
// masm.emit_v3_X86_64Stack_vsp_r_r(xenv.vsp, regs.scratch);
// masm.emit_v3_X86_64Stack_rsp_r_r(xenv.sp, regs.scratch);
// // restore state
// emit_reload_regs();
// }
// asm.bind(end);
// }

// e.g., try_table:
// at each label, record if it can be reached by EH
// collect all possible labels and their abstract states

// asm.bind(label_try_table);
//
// throw() -> stub which loads the states -> label
// addr of stub should be in handler table of function

// Handles an operand that is statically known to be null. If {sv} is a constant
// whose value is 0, emits a {NULL_DEREF} trap, marks the rest of the code
// unreachable and returns {true}; otherwise emits nothing and returns {false}.
private def checkForConstNull(sv: SpcVal) -> bool {
	var is_null_const = sv.isConst() && sv.const == 0;
	if (!is_null_const) return false;
	emitTrap(TrapReason.NULL_DEREF);
	setUnreachable();
	return true;
}

private def visit_V128_SHIFT1<T>(masm_shift: (X86_64Xmmr, X86_64Gpr, X86_64Gpr, X86_64Xmmr, X86_64Xmmr) -> T) {
var b = popReg();
var a = popRegToOverwrite();
Expand Down
Loading
Loading