// Source blob: d2a7e5c482dcd43b8c144dd11e9ac6f9796c7e57
/*
* Copyright 2024 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/names.h"
#include "pass.h"
#include "wasm-builder.h"
#include "wasm.h"
// Replace memory.copy and memory.fill with a call to a function that
// implements the same semantics. This is intended to be used with LLVM output,
// so anything considered undefined behavior in LLVM is ignored. (In
// particular, pointer overflow is UB and not handled here).
namespace wasm {
// Replaces every memory.copy / memory.fill with a call to a generated helper
// function that performs the operation with single-byte loads and stores, and
// then turns off the bulk-memory-opt feature on the module.
//
// Preconditions enforced by run(): no memory64, no multi-memory, no passive
// data segments, and no element segments without a table. table.copy and
// table.fill are rejected when the walker visits them.
struct LLVMMemoryCopyFillLowering
  : public WalkerPass<PostWalker<LLVMMemoryCopyFillLowering>> {
  // Set during the walk when a memory.copy / memory.fill was replaced, so the
  // corresponding helper gets a real body (otherwise its stub is removed).
  bool needsMemoryCopy = false;
  bool needsMemoryFill = false;
  // Names actually given to the helpers. They are chosen in run() before the
  // walk via Names::getValidFunctionName, so they may differ from
  // "__memory_copy" / "__memory_fill" if the module already uses those names.
  Name memCopyFuncName;
  Name memFillFuncName;

  // memory.copy(dest, source, size) -> call $memCopyFuncName(dest, source, size)
  void visitMemoryCopy(MemoryCopy* curr) {
    assert(curr->destMemory ==
           curr->sourceMemory); // multi-memory not supported.
    Builder builder(*getModule());
    // Use the name picked in run(), not the literal "__memory_copy": on a name
    // collision the literal would target the wrong (pre-existing) function
    // while the generated helper would never be called.
    replaceCurrent(builder.makeCall(
      memCopyFuncName, {curr->dest, curr->source, curr->size}, Type::none));
    needsMemoryCopy = true;
  }

  // memory.fill(dest, value, size) -> call $memFillFuncName(dest, value, size)
  void visitMemoryFill(MemoryFill* curr) {
    Builder builder(*getModule());
    // As in visitMemoryCopy, call the name actually registered in run().
    replaceCurrent(builder.makeCall(
      memFillFuncName, {curr->dest, curr->value, curr->size}, Type::none));
    needsMemoryFill = true;
  }

  // Validates the module, creates stub helpers, rewrites all memory.copy /
  // memory.fill via the walk, then fills in (or removes) the helpers and
  // clears the bulk-memory-opt feature.
  void run(Module* module) override {
    // Without bulk-memory-opt the module should contain no memory.copy /
    // memory.fill, so there is nothing to lower.
    if (!module->features.hasBulkMemoryOpt()) {
      return;
    }
    if (module->features.hasMemory64() || module->features.hasMultiMemory()) {
      Fatal()
        << "Memory64 and multi-memory not supported by memory.copy lowering";
    }
    // Check for the presence of any passive data or table segments.
    for (auto& segment : module->dataSegments) {
      if (segment->isPassive) {
        Fatal() << "memory.copy lowering should only be run on modules with "
                   "no passive segments";
      }
    }
    // An element segment that is not attached to a table is not active.
    for (auto& segment : module->elementSegments) {
      if (!segment->table.is()) {
        Fatal() << "memory.copy lowering should only be run on modules with"
                   " no passive segments";
      }
    }
    // In order to introduce a call to a function, it must first exist, so
    // create empty stubs. Their names are fixed here, before the walk, so the
    // visitors can emit calls to them.
    Builder b(*module);
    memCopyFuncName = Names::getValidFunctionName(*module, "__memory_copy");
    memFillFuncName = Names::getValidFunctionName(*module, "__memory_fill");
    // Params: dst(0), src(1), size(2). Vars: start(3), end(4), step(5), i(6).
    auto memCopyFunc = b.makeFunction(
      memCopyFuncName,
      {{"dst", Type::i32}, {"src", Type::i32}, {"size", Type::i32}},
      Signature({Type::i32, Type::i32, Type::i32}, {Type::none}),
      {{"start", Type::i32},
       {"end", Type::i32},
       {"step", Type::i32},
       {"i", Type::i32}});
    memCopyFunc->body = b.makeBlock();
    module->addFunction(memCopyFunc.release());
    // Params: dst(0), val(1), size(2). No extra locals.
    auto memFillFunc = b.makeFunction(
      memFillFuncName,
      {{"dst", Type::i32}, {"val", Type::i32}, {"size", Type::i32}},
      Signature({Type::i32, Type::i32, Type::i32}, {Type::none}),
      {});
    memFillFunc->body = b.makeBlock();
    module->addFunction(memFillFunc.release());
    Super::run(module);
    // Give used helpers their real bodies; drop stubs nobody calls.
    if (needsMemoryCopy) {
      createMemoryCopyFunc(module);
    } else {
      module->removeFunction(memCopyFuncName);
    }
    if (needsMemoryFill) {
      createMemoryFillFunc(module);
    } else {
      module->removeFunction(memFillFuncName);
    }
    // memory.copy / memory.fill no longer appear in the module.
    module->features.setBulkMemoryOpt(false);
  }

  // Builds the body of the memory.copy helper. Local indices match the stub
  // created in run(): params dst(0), src(1), size(2); vars start(3), end(4),
  // step(5), i(6).
  void createMemoryCopyFunc(Module* module) {
    Builder b(*module);
    Index dst = 0, src = 1, size = 2, start = 3, end = 4, step = 5, i = 6;
    Name memory = module->memories.front()->name;
    Block* body = b.makeBlock();
    // end = memory size in bytes
    // NOTE(review): this product is computed in i32. Assuming
    // Memory::kPageSize is the 64 KiB wasm page, a 65536-page (4 GiB) memory
    // wraps it to 0 and every non-empty copy would trap below; consider
    // doing the bounds check in i64.
    body->list.push_back(
      b.makeLocalSet(end,
                     b.makeBinary(BinaryOp::MulInt32,
                                  b.makeMemorySize(memory),
                                  b.makeConst(Memory::kPageSize))));
    // if dst + size > memsize or src + size > memsize, then trap. This runs
    // before any byte is written. Wrap-around of the i32 sums is not
    // detected (pointer overflow is treated as UB by this pass).
    body->list.push_back(b.makeIf(
      b.makeBinary(BinaryOp::OrInt32,
                   b.makeBinary(BinaryOp::GtUInt32,
                                b.makeBinary(BinaryOp::AddInt32,
                                             b.makeLocalGet(dst, Type::i32),
                                             b.makeLocalGet(size, Type::i32)),
                                b.makeLocalGet(end, Type::i32)),
                   b.makeBinary(BinaryOp::GtUInt32,
                                b.makeBinary(BinaryOp::AddInt32,
                                             b.makeLocalGet(src, Type::i32),
                                             b.makeLocalGet(size, Type::i32)),
                                b.makeLocalGet(end, Type::i32))),
      b.makeUnreachable()));
    // start and end are the starting and past-the-end indexes
    // if src < dest: start = size - 1, end = -1, step = -1
    // else: start = 0, end = size, step = 1
    // Copying backward when src < dst keeps overlapping ranges correct:
    // source bytes are read before the copy can overwrite them.
    body->list.push_back(
      b.makeIf(b.makeBinary(BinaryOp::LtUInt32,
                            b.makeLocalGet(src, Type::i32),
                            b.makeLocalGet(dst, Type::i32)),
               b.makeBlock({
                 b.makeLocalSet(start,
                                b.makeBinary(BinaryOp::SubInt32,
                                             b.makeLocalGet(size, Type::i32),
                                             b.makeConst(1))),
                 b.makeLocalSet(end, b.makeConst(-1U)),
                 b.makeLocalSet(step, b.makeConst(-1U)),
               }),
               b.makeBlock({
                 b.makeLocalSet(start, b.makeConst(0)),
                 b.makeLocalSet(end, b.makeLocalGet(size, Type::i32)),
                 b.makeLocalSet(step, b.makeConst(1)),
               })));
    // i = start
    body->list.push_back(b.makeLocalSet(i, b.makeLocalGet(start, Type::i32)));
    // block $out { loop $copy { br_if $out (i == end); dst[i] = src[i];
    //                           i += step; br $copy } }
    body->list.push_back(b.makeBlock(
      "out",
      b.makeLoop(
        "copy",
        b.makeBlock(
          {// break if i == end
           b.makeBreak("out",
                       nullptr,
                       b.makeBinary(BinaryOp::EqInt32,
                                    b.makeLocalGet(i, Type::i32),
                                    b.makeLocalGet(end, Type::i32))),
           // dst[i] = src[i] (one byte)
           b.makeStore(1,
                       0,
                       1,
                       b.makeBinary(BinaryOp::AddInt32,
                                    b.makeLocalGet(dst, Type::i32),
                                    b.makeLocalGet(i, Type::i32)),
                       b.makeLoad(1,
                                  false,
                                  0,
                                  1,
                                  b.makeBinary(BinaryOp::AddInt32,
                                               b.makeLocalGet(src, Type::i32),
                                               b.makeLocalGet(i, Type::i32)),
                                  Type::i32,
                                  memory),
                       Type::i32,
                       memory),
           // i += step
           b.makeLocalSet(i,
                          b.makeBinary(BinaryOp::AddInt32,
                                       b.makeLocalGet(i, Type::i32),
                                       b.makeLocalGet(step, Type::i32))),
           // loop
           b.makeBreak("copy", nullptr)}))));
    module->getFunction(memCopyFuncName)->body = body;
  }

  // Builds the body of the memory.fill helper: params dst(0), val(1),
  // size(2). Bytes are written from the highest address downward, using
  // `size` itself as the loop counter.
  void createMemoryFillFunc(Module* module) {
    Builder b(*module);
    Index dst = 0, val = 1, size = 2;
    Name memory = module->memories.front()->name;
    Block* body = b.makeBlock();
    // if dst + size > memsize in bytes, then trap.
    // NOTE(review): same i32 wrap caveats as in createMemoryCopyFunc (4 GiB
    // memory size product, unchecked dst + size overflow).
    body->list.push_back(
      b.makeIf(b.makeBinary(BinaryOp::GtUInt32,
                            b.makeBinary(BinaryOp::AddInt32,
                                         b.makeLocalGet(dst, Type::i32),
                                         b.makeLocalGet(size, Type::i32)),
                            b.makeBinary(BinaryOp::MulInt32,
                                         b.makeMemorySize(memory),
                                         b.makeConst(Memory::kPageSize))),
               b.makeUnreachable()));
    body->list.push_back(b.makeBlock(
      "out",
      b.makeLoop(
        "copy",
        b.makeBlock(
          {// break if size == 0
           b.makeBreak(
             "out",
             nullptr,
             b.makeUnary(UnaryOp::EqZInt32, b.makeLocalGet(size, Type::i32))),
           // size--
           b.makeLocalSet(size,
                          b.makeBinary(BinaryOp::SubInt32,
                                       b.makeLocalGet(size, Type::i32),
                                       b.makeConst(1))),
           // *(dst+size) = val (one byte)
           b.makeStore(1,
                       0,
                       1,
                       b.makeBinary(BinaryOp::AddInt32,
                                    b.makeLocalGet(dst, Type::i32),
                                    b.makeLocalGet(size, Type::i32)),
                       b.makeLocalGet(val, Type::i32),
                       Type::i32,
                       memory),
           b.makeBreak("copy", nullptr)}))));
    module->getFunction(memFillFuncName)->body = body;
  }

  // Bulk table operations are not supported by this pass. These must be
  // spelled visitTableCopy / visitTableFill with the matching expression
  // types for the walker to dispatch to them; the previous capitalized
  // "Visit*" spellings were never called.
  void visitTableCopy(TableCopy* curr) {
    Fatal() << "table.copy instruction found. Memory copy lowering is not "
               "designed to work on modules with bulk table operations";
  }

  void visitTableFill(TableFill* curr) {
    Fatal() << "table.fill instruction found. Memory copy lowering is not "
               "designed to work on modules with bulk table operations";
  }
};
// Creates a fresh instance of the memory.copy / memory.fill lowering pass.
// The object is heap-allocated and not retained here; the caller owns it.
Pass* createLLVMMemoryCopyFillLoweringPass() {
  Pass* const pass = new LLVMMemoryCopyFillLowering{};
  return pass;
}
} // namespace wasm