| /* |
| * Copyright 2024 WebAssembly Community Group participants |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "ir/names.h" |
| #include "pass.h" |
| #include "wasm-builder.h" |
| #include "wasm.h" |
| |
| // Replace memory.copy and memory.fill with a call to a function that |
| // implements the same semantics. This is intended to be used with LLVM output, |
| // so anything considered undefined behavior in LLVM is ignored. (In |
| // particular, pointer overflow is UB and not handled here). |
| |
| namespace wasm { |
| struct LLVMMemoryCopyFillLowering |
| : public WalkerPass<PostWalker<LLVMMemoryCopyFillLowering>> { |
| bool needsMemoryCopy = false; |
| bool needsMemoryFill = false; |
| Name memCopyFuncName; |
| Name memFillFuncName; |
| |
| void visitMemoryCopy(MemoryCopy* curr) { |
| assert(curr->destMemory == |
| curr->sourceMemory); // multi-memory not supported. |
| Builder builder(*getModule()); |
| replaceCurrent(builder.makeCall( |
| "__memory_copy", {curr->dest, curr->source, curr->size}, Type::none)); |
| needsMemoryCopy = true; |
| } |
| |
| void visitMemoryFill(MemoryFill* curr) { |
| Builder builder(*getModule()); |
| replaceCurrent(builder.makeCall( |
| "__memory_fill", {curr->dest, curr->value, curr->size}, Type::none)); |
| needsMemoryFill = true; |
| } |
| |
| void run(Module* module) override { |
| if (!module->features.hasBulkMemoryOpt()) { |
| return; |
| } |
| if (module->features.hasMemory64() || module->features.hasMultiMemory()) { |
| Fatal() |
| << "Memory64 and multi-memory not supported by memory.copy lowering"; |
| } |
| |
| // Check for the presence of any passive data or table segments. |
| for (auto& segment : module->dataSegments) { |
| if (segment->isPassive) { |
| Fatal() << "memory.copy lowering should only be run on modules with " |
| "no passive segments"; |
| } |
| } |
| for (auto& segment : module->elementSegments) { |
| if (!segment->table.is()) { |
| Fatal() << "memory.copy lowering should only be run on modules with" |
| " no passive segments"; |
| } |
| } |
| |
| // In order to introduce a call to a function, it must first exist, so |
| // create an empty stub. |
| Builder b(*module); |
| |
| memCopyFuncName = Names::getValidFunctionName(*module, "__memory_copy"); |
| memFillFuncName = Names::getValidFunctionName(*module, "__memory_fill"); |
| auto memCopyFunc = b.makeFunction( |
| memCopyFuncName, |
| {{"dst", Type::i32}, {"src", Type::i32}, {"size", Type::i32}}, |
| Signature({Type::i32, Type::i32, Type::i32}, {Type::none}), |
| {{"start", Type::i32}, |
| {"end", Type::i32}, |
| {"step", Type::i32}, |
| {"i", Type::i32}}); |
| memCopyFunc->body = b.makeBlock(); |
| module->addFunction(memCopyFunc.release()); |
| auto memFillFunc = b.makeFunction( |
| memFillFuncName, |
| {{"dst", Type::i32}, {"val", Type::i32}, {"size", Type::i32}}, |
| Signature({Type::i32, Type::i32, Type::i32}, {Type::none}), |
| {}); |
| memFillFunc->body = b.makeBlock(); |
| module->addFunction(memFillFunc.release()); |
| |
| Super::run(module); |
| |
| if (needsMemoryCopy) { |
| createMemoryCopyFunc(module); |
| } else { |
| module->removeFunction(memCopyFuncName); |
| } |
| |
| if (needsMemoryFill) { |
| createMemoryFillFunc(module); |
| } else { |
| module->removeFunction(memFillFuncName); |
| } |
| module->features.setBulkMemoryOpt(false); |
| } |
| |
| void createMemoryCopyFunc(Module* module) { |
| Builder b(*module); |
| Index dst = 0, src = 1, size = 2, start = 3, end = 4, step = 5, i = 6; |
| Name memory = module->memories.front()->name; |
| Block* body = b.makeBlock(); |
| // end = memory size in bytes |
| body->list.push_back( |
| b.makeLocalSet(end, |
| b.makeBinary(BinaryOp::MulInt32, |
| b.makeMemorySize(memory), |
| b.makeConst(Memory::kPageSize)))); |
| // if dst + size > memsize or src + size > memsize, then trap. |
| body->list.push_back(b.makeIf( |
| b.makeBinary(BinaryOp::OrInt32, |
| b.makeBinary(BinaryOp::GtUInt32, |
| b.makeBinary(BinaryOp::AddInt32, |
| b.makeLocalGet(dst, Type::i32), |
| b.makeLocalGet(size, Type::i32)), |
| b.makeLocalGet(end, Type::i32)), |
| b.makeBinary(BinaryOp::GtUInt32, |
| b.makeBinary(BinaryOp::AddInt32, |
| b.makeLocalGet(src, Type::i32), |
| b.makeLocalGet(size, Type::i32)), |
| b.makeLocalGet(end, Type::i32))), |
| b.makeUnreachable())); |
| // start and end are the starting and past-the-end indexes |
| // if src < dest: start = size - 1, end = -1, step = -1 |
| // else: start = 0, end = size, step = 1 |
| body->list.push_back( |
| b.makeIf(b.makeBinary(BinaryOp::LtUInt32, |
| b.makeLocalGet(src, Type::i32), |
| b.makeLocalGet(dst, Type::i32)), |
| b.makeBlock({ |
| b.makeLocalSet(start, |
| b.makeBinary(BinaryOp::SubInt32, |
| b.makeLocalGet(size, Type::i32), |
| b.makeConst(1))), |
| b.makeLocalSet(end, b.makeConst(-1U)), |
| b.makeLocalSet(step, b.makeConst(-1U)), |
| }), |
| b.makeBlock({ |
| b.makeLocalSet(start, b.makeConst(0)), |
| b.makeLocalSet(end, b.makeLocalGet(size, Type::i32)), |
| b.makeLocalSet(step, b.makeConst(1)), |
| }))); |
| // i = start |
| body->list.push_back(b.makeLocalSet(i, b.makeLocalGet(start, Type::i32))); |
| body->list.push_back(b.makeBlock( |
| "out", |
| b.makeLoop( |
| "copy", |
| b.makeBlock( |
| {// break if i == end |
| b.makeBreak("out", |
| nullptr, |
| b.makeBinary(BinaryOp::EqInt32, |
| b.makeLocalGet(i, Type::i32), |
| b.makeLocalGet(end, Type::i32))), |
| // dst[i] = src[i] |
| b.makeStore(1, |
| 0, |
| 1, |
| b.makeBinary(BinaryOp::AddInt32, |
| b.makeLocalGet(dst, Type::i32), |
| b.makeLocalGet(i, Type::i32)), |
| b.makeLoad(1, |
| false, |
| 0, |
| 1, |
| b.makeBinary(BinaryOp::AddInt32, |
| b.makeLocalGet(src, Type::i32), |
| b.makeLocalGet(i, Type::i32)), |
| Type::i32, |
| memory), |
| Type::i32, |
| memory), |
| // i += step |
| b.makeLocalSet(i, |
| b.makeBinary(BinaryOp::AddInt32, |
| b.makeLocalGet(i, Type::i32), |
| b.makeLocalGet(step, Type::i32))), |
| // loop |
| b.makeBreak("copy", nullptr)})))); |
| module->getFunction(memCopyFuncName)->body = body; |
| } |
| |
| void createMemoryFillFunc(Module* module) { |
| Builder b(*module); |
| Index dst = 0, val = 1, size = 2; |
| Name memory = module->memories.front()->name; |
| Block* body = b.makeBlock(); |
| |
| // if dst + size > memsize in bytes, then trap. |
| body->list.push_back( |
| b.makeIf(b.makeBinary(BinaryOp::GtUInt32, |
| b.makeBinary(BinaryOp::AddInt32, |
| b.makeLocalGet(dst, Type::i32), |
| b.makeLocalGet(size, Type::i32)), |
| b.makeBinary(BinaryOp::MulInt32, |
| b.makeMemorySize(memory), |
| b.makeConst(Memory::kPageSize))), |
| b.makeUnreachable())); |
| |
| body->list.push_back(b.makeBlock( |
| "out", |
| b.makeLoop( |
| "copy", |
| b.makeBlock( |
| {// break if size == 0 |
| b.makeBreak( |
| "out", |
| nullptr, |
| b.makeUnary(UnaryOp::EqZInt32, b.makeLocalGet(size, Type::i32))), |
| // size-- |
| b.makeLocalSet(size, |
| b.makeBinary(BinaryOp::SubInt32, |
| b.makeLocalGet(size, Type::i32), |
| b.makeConst(1))), |
| // *(dst+size) = val |
| b.makeStore(1, |
| 0, |
| 1, |
| b.makeBinary(BinaryOp::AddInt32, |
| b.makeLocalGet(dst, Type::i32), |
| b.makeLocalGet(size, Type::i32)), |
| b.makeLocalGet(val, Type::i32), |
| Type::i32, |
| memory), |
| b.makeBreak("copy", nullptr)})))); |
| module->getFunction(memFillFuncName)->body = body; |
| } |
| |
| void VisitTableCopy(TableCopy* curr) { |
| Fatal() << "table.copy instruction found. Memory copy lowering is not " |
| "designed to work on modules with bulk table operations"; |
| } |
| void VisitTableFill(TableCopy* curr) { |
| Fatal() << "table.fill instruction found. Memory copy lowering is not " |
| "designed to work on modules with bulk table operations"; |
| } |
| }; |
| |
| Pass* createLLVMMemoryCopyFillLoweringPass() { |
| return new LLVMMemoryCopyFillLowering(); |
| } |
| |
| } // namespace wasm |