Skip to content

Commit df26f96

Browse files
authored
Stack switching: fix some optimization passes (#7271)
This continues #7041 by adapting the optimizations passes to work with the stack switching instructions.
1 parent 52ac4c1 commit df26f96

15 files changed

+715
-22
lines changed

scripts/test/fuzzing.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@
105105
'stack_switching_resume.wast',
106106
'stack_switching_resume_throw.wast',
107107
'stack_switching_switch.wast',
108+
'stack_switching_switch_2.wast',
109+
'O3_stack-switching.wast',
110+
'coalesce-locals-stack-switching.wast',
111+
'dce-stack-switching.wast',
112+
'precompute-stack-switching.wast',
113+
'vacuum-stack-switching.wast'
108114
# TODO: fuzzer support for custom descriptors
109115
'custom-descriptors.wast',
110116
]

src/cfg/cfg-traversal.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,19 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> {
444444
self->tryStack.pop_back();
445445
}
446446

447+
static void doEndResume(SubType* self, Expression** currp) {
448+
auto* module = self->getModule();
449+
if (!module || module->features.hasExceptionHandling()) {
450+
// This resume might throw, so run the code to handle that.
451+
doEndThrowingInst(self, currp);
452+
}
453+
auto handlerBlocks = BranchUtils::getUniqueTargets(*currp);
454+
// Add branches to the targets.
455+
for (auto target : handlerBlocks) {
456+
self->branches[target].push_back(self->currBasicBlock);
457+
}
458+
}
459+
447460
static bool isReturnCall(Expression* curr) {
448461
switch (curr->_id) {
449462
case Expression::Id::CallId:
@@ -521,6 +534,20 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> {
521534
self->pushTask(SubType::doEndThrow, currp);
522535
break;
523536
}
537+
case Expression::Id::ResumeId:
538+
case Expression::Id::ResumeThrowId: {
539+
self->pushTask(SubType::doEndResume, currp);
540+
break;
541+
}
542+
case Expression::Id::SuspendId:
543+
case Expression::Id::StackSwitchId: {
544+
auto* module = self->getModule();
545+
if (!module || module->features.hasExceptionHandling()) {
546+
// This might throw, so run the code to handle that.
547+
self->pushTask(SubType::doEndCall, currp);
548+
}
549+
break;
550+
}
524551
default: {
525552
if (Properties::isBranch(curr)) {
526553
self->pushTask(SubType::doEndBranch, currp);

src/ir/ReFinalize.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,18 @@ void ReFinalize::visitStringSliceWTF(StringSliceWTF* curr) { curr->finalize(); }
183183
void ReFinalize::visitContNew(ContNew* curr) { curr->finalize(); }
184184
void ReFinalize::visitContBind(ContBind* curr) { curr->finalize(); }
185185
void ReFinalize::visitSuspend(Suspend* curr) { curr->finalize(getModule()); }
186-
void ReFinalize::visitResume(Resume* curr) { curr->finalize(); }
187-
void ReFinalize::visitResumeThrow(ResumeThrow* curr) { curr->finalize(); }
186+
void ReFinalize::visitResume(Resume* curr) {
187+
curr->finalize();
188+
for (size_t i = 0; i < curr->handlerBlocks.size(); i++) {
189+
updateBreakValueType(curr->handlerBlocks[i], curr->sentTypes[i]);
190+
}
191+
}
192+
void ReFinalize::visitResumeThrow(ResumeThrow* curr) {
193+
curr->finalize();
194+
for (size_t i = 0; i < curr->handlerBlocks.size(); i++) {
195+
updateBreakValueType(curr->handlerBlocks[i], curr->sentTypes[i]);
196+
}
197+
}
188198
void ReFinalize::visitStackSwitch(StackSwitch* curr) { curr->finalize(); }
189199

190200
void ReFinalize::visitExport(Export* curr) { WASM_UNREACHABLE("unimp"); }

src/ir/branch-utils.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,15 @@ void operateOnScopeNameUsesAndSentTypes(Expression* expr, T func) {
8383
}
8484
}
8585
} else if (auto* r = expr->dynCast<Resume>()) {
86-
for (Index i = 0; i < r->handlerTags.size(); i++) {
87-
auto dest = r->handlerTags[i];
86+
for (Index i = 0; i < r->handlerBlocks.size(); i++) {
87+
auto dest = r->handlerBlocks[i];
8888
if (!dest.isNull() && dest == name) {
8989
func(name, r->sentTypes[i]);
9090
}
9191
}
9292
} else if (auto* r = expr->dynCast<ResumeThrow>()) {
93-
for (Index i = 0; i < r->handlerTags.size(); i++) {
94-
auto dest = r->handlerTags[i];
93+
for (Index i = 0; i < r->handlerBlocks.size(); i++) {
94+
auto dest = r->handlerBlocks[i];
9595
if (!dest.isNull() && dest == name) {
9696
func(name, r->sentTypes[i]);
9797
}

src/ir/subtypes.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ struct SubTypes {
126126
basic = HeapTypes::array.getBasic(share);
127127
break;
128128
case HeapTypeKind::Cont:
129-
WASM_UNREACHABLE("TODO: cont");
129+
basic = HeapTypes::cont.getBasic(share);
130+
break;
130131
case HeapTypeKind::Basic:
131132
WASM_UNREACHABLE("unexpected kind");
132133
}

src/ir/type-updating.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,12 @@ GlobalTypeRewriter::TypeMap GlobalTypeRewriter::rebuildTypes(
150150
typeBuilder[i] = newArray;
151151
break;
152152
}
153-
case HeapTypeKind::Cont:
154-
WASM_UNREACHABLE("TODO: cont");
153+
case HeapTypeKind::Cont: {
154+
auto newCont = HeapType(typeBuilder[i]).getContinuation();
155+
modifyContinuation(type, newCont);
156+
typeBuilder[i] = newCont;
157+
break;
158+
}
155159
case HeapTypeKind::Basic:
156160
WASM_UNREACHABLE("unexpected kind");
157161
}

src/ir/type-updating.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ class GlobalTypeRewriter {
382382
// used to define the new type in the TypeBuilder.
383383
virtual void modifyStruct(HeapType oldType, Struct& struct_) {}
384384
virtual void modifyArray(HeapType oldType, Array& array) {}
385+
virtual void modifyContinuation(HeapType oldType, Continuation& sig) {}
385386
virtual void modifySignature(HeapType oldType, Signature& sig) {}
386387

387388
// This additional hook is called after modify* and other operations, and
@@ -490,16 +491,19 @@ class TypeMapper : public GlobalTypeRewriter {
490491
mapTypes(newMapping);
491492
}
492493

494+
HeapType getNewHeapType(HeapType type) {
495+
auto iter = mapping.find(type);
496+
if (iter != mapping.end()) {
497+
return iter->second;
498+
}
499+
return type;
500+
}
501+
493502
Type getNewType(Type type) {
494503
if (!type.isRef()) {
495504
return type;
496505
}
497-
auto heapType = type.getHeapType();
498-
auto iter = mapping.find(heapType);
499-
if (iter != mapping.end()) {
500-
return getTempType(Type(iter->second, type.getNullability()));
501-
}
502-
return getTempType(type);
506+
return getTempType(type.with(getNewHeapType(type.getHeapType())));
503507
}
504508

505509
void modifyStruct(HeapType oldType, Struct& struct_) override {
@@ -513,6 +517,10 @@ class TypeMapper : public GlobalTypeRewriter {
513517
void modifyArray(HeapType oldType, Array& array) override {
514518
array.element.type = getNewType(oldType.getArray().element.type);
515519
}
520+
void modifyContinuation(HeapType oldType,
521+
Continuation& continuation) override {
522+
continuation.type = getNewHeapType(oldType.getContinuation().type);
523+
}
516524
void modifySignature(HeapType oldSignatureType, Signature& sig) override {
517525
auto getUpdatedTypeList = [&](Type type) {
518526
std::vector<Type> vec;

src/wasm-interpreter.h

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2617,15 +2617,29 @@ class ConstantExpressionRunner : public ExpressionRunner<SubType> {
26172617
}
26182618
return ExpressionRunner<SubType>::visitRefAs(curr);
26192619
}
2620-
Flow visitContNew(ContNew* curr) { WASM_UNREACHABLE("unimplemented"); }
2621-
Flow visitContBind(ContBind* curr) { WASM_UNREACHABLE("unimplemented"); }
2622-
Flow visitSuspend(Suspend* curr) { WASM_UNREACHABLE("unimplemented"); }
2623-
Flow visitResume(Resume* curr) { WASM_UNREACHABLE("unimplemented"); }
2620+
Flow visitContNew(ContNew* curr) {
2621+
NOTE_ENTER("ContNew");
2622+
return Flow(NONCONSTANT_FLOW);
2623+
}
2624+
Flow visitContBind(ContBind* curr) {
2625+
NOTE_ENTER("ContBind");
2626+
return Flow(NONCONSTANT_FLOW);
2627+
}
2628+
Flow visitSuspend(Suspend* curr) {
2629+
NOTE_ENTER("Suspend");
2630+
return Flow(NONCONSTANT_FLOW);
2631+
}
2632+
Flow visitResume(Resume* curr) {
2633+
NOTE_ENTER("Resume");
2634+
return Flow(NONCONSTANT_FLOW);
2635+
}
26242636
Flow visitResumeThrow(ResumeThrow* curr) {
2625-
WASM_UNREACHABLE("unimplemented");
2637+
NOTE_ENTER("ResumeThrow");
2638+
return Flow(NONCONSTANT_FLOW);
26262639
}
26272640
Flow visitStackSwitch(StackSwitch* curr) {
2628-
WASM_UNREACHABLE("unimplemented");
2641+
NOTE_ENTER("StackSwitch");
2642+
return Flow(NONCONSTANT_FLOW);
26292643
}
26302644

26312645
void trap(const char* why) override { throw NonconstantException(); }

src/wasm/wasm.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1443,8 +1443,12 @@ void StackSwitch::finalize() {
14431443
}
14441444

14451445
assert(this->cont->type.isContinuation());
1446-
type =
1446+
Type params =
14471447
this->cont->type.getHeapType().getContinuation().type.getSignature().params;
1448+
assert(params.size() > 0);
1449+
Type cont = params[params.size() - 1];
1450+
assert(cont.isContinuation());
1451+
type = cont.getHeapType().getContinuation().type.getSignature().params;
14481452
}
14491453

14501454
size_t Function::getNumParams() { return getParams().size(); }
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
;; NOTE: Assertions have been generated by update_lit_checks.py and should not be edited.
2+
;; RUN: wasm-opt -all %s -S -o - | filecheck %s
3+
4+
5+
(module
6+
;; CHECK: (type $function (func (param i64)))
7+
(type $function (func (param i64)))
8+
;; CHECK: (type $cont (cont $function))
9+
(type $cont (cont $function))
10+
;; CHECK: (type $function_2 (func (param i32 (ref $cont))))
11+
(type $function_2 (func (param i32 (ref $cont))))
12+
;; CHECK: (type $cont_2 (cont $function_2))
13+
(type $cont_2 (cont $function_2))
14+
;; CHECK: (tag $tag (type $4))
15+
(tag $tag)
16+
17+
;; CHECK: (func $switch (type $5) (param $c (ref $cont_2)) (result i64)
18+
;; CHECK-NEXT: (switch $cont_2 $tag
19+
;; CHECK-NEXT: (i32.const 0)
20+
;; CHECK-NEXT: (local.get $c)
21+
;; CHECK-NEXT: )
22+
;; CHECK-NEXT: )
23+
(func $switch (param $c (ref $cont_2)) (result i64)
24+
(switch $cont_2 $tag
25+
(i32.const 0)
26+
(local.get $c)
27+
)
28+
)
29+
)

0 commit comments

Comments
 (0)