Skip to content

Commit 67b5b2d

Browse files
committed
[Custom Page Sizes] Support Custom Page Sizes proposal (#6873)
1 parent 8470f1b commit 67b5b2d

33 files changed

+691
-192
lines changed

scripts/test/shared.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,8 @@ def get_tests(test_dir, extensions=[], recursive=False):
405405

406406
# Unlinkable module accepted
407407
'linking.wast',
408+
'memory_max.wast',
409+
'memory_max_i64.wast',
408410

409411
# Invalid module accepted
410412
'unreached-invalid.wast',

src/ir/module-utils.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ Memory* copyMemory(const Memory* memory, Module& out) {
173173
ret->hasExplicitName = memory->hasExplicitName;
174174
ret->initial = memory->initial;
175175
ret->max = memory->max;
176+
ret->pageSizeLog2 = memory->pageSizeLog2;
176177
ret->shared = memory->shared;
177178
ret->addressType = memory->addressType;
178179
ret->module = memory->module;

src/js/binaryen.js-post.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ function initializeConstants() {
180180
'FP16',
181181
'BulkMemoryOpt',
182182
'CallIndirectOverlong',
183+
'CustomPageSizes',
183184
'All'
184185
].forEach(name => {
185186
Module['Features'][name] = Module['_BinaryenFeature' + name]();

src/parser/context-decls.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ Result<Memory*> ParseDeclsCtx::addMemoryDecl(Index pos,
143143
m->initial = type.limits.initial;
144144
m->max = type.limits.max ? *type.limits.max : Memory::kUnlimitedSize;
145145
m->shared = type.shared;
146+
m->pageSizeLog2 = type.pageSizeLog2;
146147
if (name) {
147148
// TODO: if the existing memory is not explicitly named, fix its name
148149
// and continue.

src/parser/contexts.h

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ struct Limits {
4949
struct MemType {
5050
Type addressType;
5151
Limits limits;
52+
uint8_t pageSizeLog2;
5253
bool shared;
5354
};
5455

@@ -353,7 +354,9 @@ template<typename Ctx> struct TypeParserCtx {
353354
Result<LimitsT> makeLimits(uint64_t, std::optional<uint64_t>) { return Ok{}; }
354355
LimitsT getLimitsFromData(DataStringT) { return Ok{}; }
355356

356-
MemTypeT makeMemType(Type, LimitsT, bool) { return Ok{}; }
357+
MemTypeT makeMemType(Type, LimitsT, bool, std::optional<uint8_t>) {
358+
return Ok{};
359+
}
357360

358361
HeapType getBlockTypeFromResult(const std::vector<Type> results) {
359362
assert(results.size() == 1);
@@ -1052,13 +1055,18 @@ struct ParseDeclsCtx : NullTypeParserCtx, NullInstrParserCtx {
10521055
data.insert(data.end(), str.begin(), str.end());
10531056
}
10541057

1055-
Limits getLimitsFromData(const std::vector<char>& data) {
1056-
uint64_t size = (data.size() + Memory::kPageSize - 1) / Memory::kPageSize;
1058+
Limits getLimitsFromData(const std::vector<char>& data, std::optional<uint8_t> pageSizeLog2) {
1059+
uint8_t _pageSizeLog2 = pageSizeLog2.value_or(16);
1060+
uint64_t size = (data.size() + (1<<_pageSizeLog2) - 1) / (1<<_pageSizeLog2);
10571061
return {size, size};
10581062
}
10591063

1060-
MemType makeMemType(Type addressType, Limits limits, bool shared) {
1061-
return {addressType, limits, shared};
1064+
MemType makeMemType(Type addressType,
1065+
Limits limits,
1066+
bool shared,
1067+
std::optional<uint8_t> pageSize) {
1068+
uint8_t pageSizeLog2 = pageSize.value_or(16);
1069+
return {addressType, limits, pageSizeLog2, shared};
10621070
}
10631071

10641072
Result<TypeUseT>
@@ -1400,8 +1408,10 @@ struct ParseModuleTypesCtx : TypeParserCtx<ParseModuleTypesCtx>,
14001408

14011409
Type makeTableType(Type addressType, LimitsT, Type type) { return type; }
14021410

1403-
LimitsT getLimitsFromData(DataStringT) { return Ok{}; }
1404-
MemTypeT makeMemType(Type, LimitsT, bool) { return Ok{}; }
1411+
LimitsT getLimitsFromData(DataStringT, std::optional<uint8_t>) {
1412+
return Ok{};
1413+
}
1414+
MemTypeT makeMemType(Type, LimitsT, bool, std::optional<uint8_t>) { return Ok{}; }
14051415

14061416
Result<> addFunc(Name name,
14071417
const std::vector<Name>&,

src/parser/parsers.h

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,36 @@ template<typename Ctx> Result<typename Ctx::LimitsT> limits64(Ctx& ctx) {
818818
return ctx.makeLimits(uint64_t(*n), m);
819819
}
820820

821-
// memtype ::= (limits32 | 'i32' limits32 | 'i64' limit64) shared?
821+
// mempagesize? ::= ('(' 'pagesize' u64 ')') ?
822+
template<typename Ctx> Result<std::optional<uint8_t>> mempagesize(Ctx& ctx) {
823+
if (!ctx.in.takeSExprStart("pagesize"sv)) {
824+
return std::nullopt; // No pagesize specified
825+
}
826+
auto pageSize = ctx.in.takeU64();
827+
if (!pageSize) {
828+
return ctx.in.err("expected page size");
829+
}
830+
831+
if (!Bits::isPowerOf2(*pageSize)) {
832+
return ctx.in.err("page size must be a power of two");
833+
}
834+
835+
if (!ctx.in.takeRParen()) {
836+
return ctx.in.err("expected end of mempagesize");
837+
}
838+
839+
// return the log2 of the page size, which is the number of trailing zeros
840+
uint8_t pageSizeLog2 = (uint8_t) Bits::ceilLog2(*pageSize);
841+
842+
if (pageSizeLog2 != 0 && pageSizeLog2 != Memory::kDefaultPageSizeLog2) {
843+
return ctx.in.err("memory page size can only be 1 or 64 KiB");
844+
}
845+
846+
return std::make_optional<uint8_t>(pageSizeLog2);
847+
848+
}
849+
850+
// memtype ::= (limits32 | 'i32' limits32 | 'i64' limit64) shared? mempagesize?
822851
// note: the index type 'i32' or 'i64' is already parsed to simplify parsing of
823852
// memory abbreviations.
824853
template<typename Ctx> Result<typename Ctx::MemTypeT> memtype(Ctx& ctx) {
@@ -840,7 +869,9 @@ Result<typename Ctx::MemTypeT> memtypeContinued(Ctx& ctx, Type addressType) {
840869
if (ctx.in.takeKeyword("shared"sv)) {
841870
shared = true;
842871
}
843-
return ctx.makeMemType(addressType, *limits, shared);
872+
auto pageSize = mempagesize(ctx);
873+
CHECK_ERR(pageSize);
874+
return ctx.makeMemType(addressType, *limits, shared, *pageSize);
844875
}
845876

846877
// memorder ::= '' | 'seqcst' | 'acqrel'
@@ -3434,6 +3465,8 @@ template<typename Ctx> MaybeResult<> memory(Ctx& ctx) {
34343465

34353466
std::optional<typename Ctx::MemTypeT> mtype;
34363467
std::optional<typename Ctx::DataStringT> data;
3468+
auto mempageSize = mempagesize(ctx);
3469+
CHECK_ERR(mempageSize);
34373470
if (ctx.in.takeSExprStart("data"sv)) {
34383471
if (import) {
34393472
return ctx.in.err("imported memories cannot have inline data");
@@ -3443,9 +3476,13 @@ template<typename Ctx> MaybeResult<> memory(Ctx& ctx) {
34433476
if (!ctx.in.takeRParen()) {
34443477
return ctx.in.err("expected end of inline data");
34453478
}
3446-
mtype =
3447-
ctx.makeMemType(addressType, ctx.getLimitsFromData(*datastr), false);
3479+
mtype = ctx.makeMemType(
3480+
addressType, ctx.getLimitsFromData(*datastr, *mempageSize), false, *mempageSize);
34483481
data = *datastr;
3482+
} else if ((*mempageSize).has_value()) {
3483+
// If we have a memory page size not within a memtype expression, we expect
3484+
// a memory abbreviation.
3485+
return ctx.in.err("expected data segment in memory abbreviation");
34493486
} else {
34503487
auto type = memtypeContinued(ctx, addressType);
34513488
CHECK_ERR(type);

src/passes/LLVMMemoryCopyFillLowering.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,15 @@ struct LLVMMemoryCopyFillLowering
117117
void createMemoryCopyFunc(Module* module) {
118118
Builder b(*module);
119119
Index dst = 0, src = 1, size = 2, start = 3, end = 4, step = 5, i = 6;
120-
Name memory = module->memories.front()->name;
120+
Name memory_name = module->memories.front()->name;
121+
Address::address32_t memory_page_size = module->memories.front()->pageSizeLog2;
121122
Block* body = b.makeBlock();
122123
// end = memory size in bytes
123124
body->list.push_back(
124125
b.makeLocalSet(end,
125-
b.makeBinary(BinaryOp::MulInt32,
126-
b.makeMemorySize(memory),
127-
b.makeConst(Memory::kPageSize))));
126+
b.makeBinary(BinaryOp::ShlInt32,
127+
b.makeMemorySize(memory_name),
128+
b.makeConst(memory_page_size))));
128129
// if dst + size > memsize or src + size > memsize, then trap.
129130
body->list.push_back(b.makeIf(
130131
b.makeBinary(BinaryOp::OrInt32,
@@ -187,9 +188,9 @@ struct LLVMMemoryCopyFillLowering
187188
b.makeLocalGet(src, Type::i32),
188189
b.makeLocalGet(i, Type::i32)),
189190
Type::i32,
190-
memory),
191+
memory_name),
191192
Type::i32,
192-
memory),
193+
memory_name),
193194
// i += step
194195
b.makeLocalSet(i,
195196
b.makeBinary(BinaryOp::AddInt32,
@@ -203,7 +204,9 @@ struct LLVMMemoryCopyFillLowering
203204
void createMemoryFillFunc(Module* module) {
204205
Builder b(*module);
205206
Index dst = 0, val = 1, size = 2;
206-
Name memory = module->memories.front()->name;
207+
Name memory_name = module->memories.front()->name;
208+
Address::address32_t memory_page_size =
209+
module->memories.front()->pageSizeLog2;
207210
Block* body = b.makeBlock();
208211

209212
// if dst + size > memsize in bytes, then trap.
@@ -212,9 +215,9 @@ struct LLVMMemoryCopyFillLowering
212215
b.makeBinary(BinaryOp::AddInt32,
213216
b.makeLocalGet(dst, Type::i32),
214217
b.makeLocalGet(size, Type::i32)),
215-
b.makeBinary(BinaryOp::MulInt32,
216-
b.makeMemorySize(memory),
217-
b.makeConst(Memory::kPageSize))),
218+
b.makeBinary(BinaryOp::ShlInt32,
219+
b.makeMemorySize(memory_name),
220+
b.makeConst(memory_page_size))),
218221
b.makeUnreachable()));
219222

220223
body->list.push_back(b.makeBlock(
@@ -241,7 +244,7 @@ struct LLVMMemoryCopyFillLowering
241244
b.makeLocalGet(size, Type::i32)),
242245
b.makeLocalGet(val, Type::i32),
243246
Type::i32,
244-
memory),
247+
memory_name),
245248
b.makeBreak("copy", nullptr)}))));
246249
module->getFunction(memFillFuncName)->body = body;
247250
}

src/passes/Memory64Lowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,8 @@ struct Memory64Lowering : public WalkerPass<PostWalker<Memory64Lowering>> {
293293
for (auto& memory : module->memories) {
294294
if (memory->is64()) {
295295
memory->addressType = Type::i32;
296-
if (memory->hasMax() && memory->max > Memory::kMaxSize32) {
297-
memory->max = Memory::kMaxSize32;
296+
if (memory->hasMax() && memory->max > memory->maxSize32()) {
297+
memory->max = memory->maxSize32();
298298
}
299299
}
300300
}

src/passes/MemoryPacking.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ void MemoryPacking::calculateRanges(Module* module,
315315
// Check if we can rule out a trap by it being in bounds.
316316
if (auto* c = segment->offset->dynCast<Const>()) {
317317
auto* memory = module->getMemory(segment->memory);
318-
auto memorySize = memory->initial * Memory::kPageSize;
318+
auto memorySize = memory->initial << memory->pageSizeLog2;
319319
Index start = c->value.getUnsigned();
320320
Index size = segment->data.size();
321321
Index end;

src/passes/MultiMemoryLowering.cpp

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,11 @@ struct MultiMemoryLowering : public Pass {
6868
// properties will be set
6969
Name module;
7070
Name base;
71-
// The initial page size of the combined memory
71+
// The page size of the combined memory
72+
uint8_t pageSizeLog2;
73+
// The initial page count of the combined memory
7274
Address totalInitialPages;
73-
// The max page size of the combined memory
75+
// The max page count of the combined memory
7476
Address totalMaxPages;
7577
// There is no offset for the first memory, so offsetGlobalNames will always
7678
// have a size that is one less than the count of memories at the time this
@@ -435,6 +437,7 @@ struct MultiMemoryLowering : public Pass {
435437
: Builder::MemoryInfo::Memory64;
436438
isShared = getFirstMemory().shared;
437439
isImported = getFirstMemory().imported();
440+
pageSizeLog2 = Memory::kDefaultPageSizeLog2;
438441
for (auto& memory : wasm->memories) {
439442
// We are assuming that each memory is configured the same as the first
440443
// and assert if any of the memories does not match this configuration
@@ -446,18 +449,21 @@ struct MultiMemoryLowering : public Pass {
446449
Fatal() << "MultiMemoryLowering: only the first memory can be imported";
447450
}
448451

452+
// Calculating the page size of the combined memory.
453+
// This corresponds to the smaller granularity among combined memories
454+
pageSizeLog2 = std::min(pageSizeLog2,memory->pageSizeLog2);
455+
449456
// Calculating the total initial and max page size for the combined memory
450457
// by totaling the initial and max page sizes for the memories in the
451458
// module
452-
totalInitialPages = totalInitialPages + memory->initial;
459+
totalInitialPages = totalInitialPages + (memory->initial << (memory->pageSizeLog2 - pageSizeLog2));
453460
if (memory->hasMax()) {
454-
totalMaxPages = totalMaxPages + memory->max;
461+
totalMaxPages = totalMaxPages + (memory->max << (memory->pageSizeLog2 - pageSizeLog2));
455462
}
456463
}
457464
// Ensuring valid initial and max page sizes that do not exceed the number
458465
// of pages addressable by the pointerType
459-
Address maxSize =
460-
pointerType == Type::i32 ? Memory::kMaxSize32 : Memory::kMaxSize64;
466+
Address maxSize = pointerType == Type::i32 ? 1ull<<(32-pageSizeLog2) : 1ull<<(64-pageSizeLog2);
461467
if (totalMaxPages > maxSize || totalMaxPages == 0) {
462468
totalMaxPages = Memory::kUnlimitedSize;
463469
}
@@ -504,9 +510,9 @@ struct MultiMemoryLowering : public Pass {
504510
Name name = Names::getValidGlobalName(
505511
*wasm, memory->name.toString() + "_byte_offset");
506512
offsetGlobalNames.push_back(std::move(name));
507-
addGlobal(name, offsetRunningTotal * Memory::kPageSize);
513+
addGlobal(name, offsetRunningTotal << pageSizeLog2);
508514
}
509-
offsetRunningTotal += memory->initial;
515+
offsetRunningTotal += memory->initial << (memory->pageSizeLog2 - pageSizeLog2);
510516
}
511517
}
512518

@@ -554,10 +560,10 @@ struct MultiMemoryLowering : public Pass {
554560
functionName, Signature(pointerType, pointerType), {});
555561
function->setLocalName(0, "page_delta");
556562
auto pageSizeConst = [&]() {
557-
return builder.makeConst(Literal(Memory::kPageSize));
563+
return builder.makeConst(Literal(wasm->memories[memIdx]->pageSizeLog2));
558564
};
559565
auto getOffsetDelta = [&]() {
560-
return builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Mul),
566+
return builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Shl),
561567
builder.makeLocalGet(0, pointerType),
562568
pageSizeConst());
563569
};
@@ -588,7 +594,8 @@ struct MultiMemoryLowering : public Pass {
588594
builder.makeBinary(
589595
EqInt32,
590596
builder.makeMemoryGrow(
591-
builder.makeLocalGet(0, pointerType), combinedMemory, memoryInfo),
597+
builder.makeBinary( Abstract::getBinary(pointerType, Abstract::Shl),
598+
builder.makeLocalGet(0, pointerType), builder.makeConst(Literal(wasm->memories[memIdx]->pageSizeLog2 - pageSizeLog2))) , combinedMemory, memoryInfo),
592599
builder.makeConst(-1)),
593600
builder.makeReturn(builder.makeConst(-1))));
594601

@@ -609,7 +616,7 @@ struct MultiMemoryLowering : public Pass {
609616
// size
610617
builder.makeBinary(
611618
Abstract::getBinary(pointerType, Abstract::Sub),
612-
builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Mul),
619+
builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Shl),
613620
builder.makeLocalGet(sizeLocal, pointerType),
614621
pageSizeConst()),
615622
getMoveSource(offsetGlobalName)),
@@ -646,11 +653,11 @@ struct MultiMemoryLowering : public Pass {
646653
functionName, Signature(Type::none, pointerType), {});
647654
Expression* functionBody;
648655
auto pageSizeConst = [&]() {
649-
return builder.makeConst(Literal(Memory::kPageSize));
656+
return builder.makeConst(Literal(pageSizeLog2));
650657
};
651658
auto getOffsetInPageUnits = [&](Name global) {
652659
return builder.makeBinary(
653-
Abstract::getBinary(pointerType, Abstract::DivU),
660+
Abstract::getBinary(pointerType, Abstract::ShrU),
654661
builder.makeGlobalGet(global, pointerType),
655662
pageSizeConst());
656663
};
@@ -697,6 +704,7 @@ struct MultiMemoryLowering : public Pass {
697704
memory->base = base;
698705
memory->module = module;
699706
}
707+
memory->pageSizeLog2 = pageSizeLog2;
700708
wasm->addMemory(std::move(memory));
701709
}
702710

0 commit comments

Comments
 (0)