blob: 4c797351b8b386bd9f20d20d5b503631378d73d3 [file] [log] [blame] [edit]
/*
* Copyright 2020 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//
// Enforce stack pointer limits. This pass will add checks around all
// assignments to the __stack_pointer global that LLVM uses for its
// shadow stack.
//
#include "abi/js.h"
#include "ir/import-utils.h"
#include "pass.h"
#include "shared-constants.h"
#include "support/debug.h"
#include "wasm-emscripten.h"
#define DEBUG_TYPE "stack-check"
namespace wasm {
// Name of the mutable global that holds the stack limit. The global is created
// in StackCheck::run and written by the generated setter function.
static Name STACK_LIMIT("__stack_limit");
// Name of the generated, exported function that stores its i32 argument into
// the STACK_LIMIT global (see generateSetStackLimitFunction).
static Name SET_STACK_LIMIT("__set_stack_limit");
// Make sure `module` imports a function called `name` from the ENV module,
// taking no parameters and returning nothing. If such an import is already
// present, the module is left untouched.
static void importStackOverflowHandler(Module& module, Name name) {
  ImportInfo info(module);
  if (info.getImportedFunction(ENV, name)) {
    // The handler is already imported; nothing to add.
    return;
  }

  auto* handlerImport = new Function;
  handlerImport->name = name;
  handlerImport->module = ENV;
  handlerImport->base = name;
  handlerImport->sig = Signature(Type::none, Type::none);
  module.addFunction(handlerImport);
}
// Add `function` to `module` and export it under the function's own name.
static void addExportedFunction(Module& module, Function* function) {
  module.addFunction(function);

  auto* functionExport = new Export;
  // Both the external name and the referenced internal item are the
  // function's name.
  functionExport->value = function->name;
  functionExport->name = function->name;
  functionExport->kind = ExternalKind::Function;
  module.addExport(functionExport);
}
// Create and export the SET_STACK_LIMIT function. It takes a single i32
// parameter and its body is one global.set that writes that parameter into
// the STACK_LIMIT global.
static void generateSetStackLimitFunction(Module& module) {
  Builder builder(module);
  auto* setter =
    builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {});
  // Local 0 is the function's only parameter: the new limit.
  auto* newLimit = builder.makeLocalGet(0, Type::i32);
  setter->body = builder.makeGlobalSet(STACK_LIMIT, newLimit);
  addExportedFunction(module, setter);
}
// Function-parallel walker that replaces every global.set of the stack pointer
// global with a sequence that first compares the new value against the stack
// limit global and then performs the assignment.
struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> {
  EnforceStackLimit(Global* stackPointer,
                    Global* stackLimit,
                    Builder& builder,
                    Name handler)
    : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder),
      handler(handler) {}

  bool isFunctionParallel() override { return true; }

  // Every instance created for parallel execution is given the same globals,
  // the same builder reference and the same handler name as this one.
  Pass* create() override {
    return new EnforceStackLimit(stackPointer, stackLimit, builder, handler);
  }

  // Build the checked replacement for an assignment of `value` to the stack
  // pointer:
  //
  //   (if (i32.lt_u (local.tee $newSP (...value...))
  //                 (global.get $__stack_limit))
  //     (...call handler, or unreachable...))
  //   (global.set $__stack_pointer (local.get $newSP))
  //
  // A new local is added to `func` to hold the value. The two expressions are
  // combined with Builder::blockify and returned.
  Expression* stackBoundsCheck(Function* func,
                               Expression* value,
                               Global* stackPointer,
                               Global* stackLimit) {
    // The new stack pointer value is used twice - once in the comparison and
    // once in the actual assignment - so it is kept in a fresh local.
    const auto pointerType = stackPointer->type;
    const auto tempLocal = Builder::addVar(func, pointerType);

    // Without a handler the check simply traps. With one, the handler is
    // called, which can show a nice error in JS.
    Expression* onOverflow;
    if (!handler.is()) {
      onOverflow = builder.makeUnreachable();
    } else {
      onOverflow = builder.makeCall(handler, {}, Type::none);
    }

    // Condition: the (tee'd) new value is below the limit, unsigned compare.
    auto* teeNewValue = builder.makeLocalTee(tempLocal, value, pointerType);
    auto* currentLimit =
      builder.makeGlobalGet(stackLimit->name, stackLimit->type);
    auto* belowLimit =
      builder.makeBinary(BinaryOp::LtUInt32, teeNewValue, currentLimit);
    auto* check = builder.makeIf(belowLimit, onOverflow);

    // The real assignment, reading the value back from the local.
    auto* readNewValue = builder.makeLocalGet(tempLocal, pointerType);
    auto* assignPointer =
      builder.makeGlobalSet(stackPointer->name, readNewValue);

    return builder.blockify(check, assignPointer);
  }

  // Only sets whose target resolves to the stack pointer global are rewritten;
  // all other global.set instructions are left alone.
  void visitGlobalSet(GlobalSet* curr) {
    if (getModule()->getGlobalOrNull(curr->name) != stackPointer) {
      return;
    }
    auto* checkedSet =
      stackBoundsCheck(getFunction(), curr->value, stackPointer, stackLimit);
    replaceCurrent(checkedSet);
  }

private:
  Global* stackPointer;
  Global* stackLimit;
  // Reference to a builder owned by whoever constructed the pass.
  Builder& builder;
  // Name of the function to call on overflow; if not set, the check traps.
  Name handler;
};
struct StackCheck : public Pass {
void run(PassRunner* runner, Module* module) override {
Global* stackPointer = getStackPointerGlobal(*module);
if (!stackPointer) {
BYN_DEBUG(std::cerr << "no stack pointer found\n");
return;
}
Name handler;
auto handlerName =
runner->options.getArgumentOrDefault("stack-check-handler", "");
if (handlerName != "") {
handler = handlerName;
importStackOverflowHandler(*module, handler);
}
Builder builder(*module);
Global* stackLimit = builder.makeGlobal(STACK_LIMIT,
stackPointer->type,
builder.makeConst(int32_t(0)),
Builder::Mutable);
module->addGlobal(stackLimit);
PassRunner innerRunner(module);
EnforceStackLimit(stackPointer, stackLimit, builder, handler)
.run(&innerRunner, module);
generateSetStackLimitFunction(*module);
}
};
// Factory for the stack check pass; returns a newly allocated StackCheck.
Pass* createStackCheckPass() {
  return new StackCheck;
}
} // namespace wasm