| /* |
| * Copyright 2017 WebAssembly Community Group participants |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include <mutex> |
| #include <set> |
| #include <sstream> |
| #include <unordered_set> |
| |
| #include "wasm.h" |
| #include "wasm-printing.h" |
| #include "wasm-validator.h" |
| #include "ast_utils.h" |
| #include "ast/branch-utils.h" |
| #include "support/colors.h" |
| |
| |
| namespace wasm { |
| |
| // Print anything that can be streamed to an ostream |
| template <typename T, |
| typename std::enable_if< |
| !std::is_base_of<Expression, typename std::remove_pointer<T>::type>::value |
| >::type* = nullptr> |
| inline std::ostream& printModuleComponent(T curr, std::ostream& stream) { |
| stream << curr << std::endl; |
| return stream; |
| } |
| |
| // Extra overload for Expressions, to print type info too |
| inline std::ostream& printModuleComponent(Expression* curr, std::ostream& stream) { |
| WasmPrinter::printExpression(curr, stream, false, true) << std::endl; |
| return stream; |
| } |
| |
| // For parallel validation, we have a helper struct for coordination |
| struct ValidationInfo { |
| bool validateWeb; |
| bool validateGlobally; |
| bool quiet; |
| |
| std::atomic<bool> valid; |
| |
| // a stream of error test for each function. we print in the right order at |
| // the end, for deterministic output |
| // note errors are rare/unexpected, so it's ok to use a slow mutex here |
| std::mutex mutex; |
| std::unordered_map<Function*, std::unique_ptr<std::ostringstream>> outputs; |
| |
| ValidationInfo() { |
| valid.store(true); |
| } |
| |
| std::ostringstream& getStream(Function* func) { |
| std::unique_lock<std::mutex> lock(mutex); |
| auto iter = outputs.find(func); |
| if (iter != outputs.end()) return *(iter->second.get()); |
| auto& ret = outputs[func] = make_unique<std::ostringstream>(); |
| return *ret.get(); |
| } |
| |
| // printing and error handling support |
| |
| template <typename T, typename S> |
| std::ostream& fail(S text, T curr, Function* func) { |
| valid.store(false); |
| auto& stream = getStream(func); |
| if (quiet) return stream; |
| auto& ret = printFailureHeader(func); |
| ret << text << ", on \n"; |
| return printModuleComponent(curr, ret); |
| } |
| |
| std::ostream& printFailureHeader(Function* func) { |
| auto& stream = getStream(func); |
| if (quiet) return stream; |
| Colors::red(stream); |
| if (func) { |
| stream << "[wasm-validator error in function "; |
| Colors::green(stream); |
| stream << func->name; |
| Colors::red(stream); |
| stream << "] "; |
| } else { |
| stream << "[wasm-validator error in module] "; |
| } |
| Colors::normal(stream); |
| return stream; |
| } |
| |
| // checking utilities |
| |
| template<typename T> |
| bool shouldBeTrue(bool result, T curr, const char* text, Function* func = nullptr) { |
| if (!result) { |
| fail("unexpected false: " + std::string(text), curr, func); |
| return false; |
| } |
| return result; |
| } |
| template<typename T> |
| bool shouldBeFalse(bool result, T curr, const char* text, Function* func = nullptr) { |
| if (result) { |
| fail("unexpected true: " + std::string(text), curr, func); |
| return false; |
| } |
| return result; |
| } |
| |
| template<typename T, typename S> |
| bool shouldBeEqual(S left, S right, T curr, const char* text, Function* func = nullptr) { |
| if (left != right) { |
| std::ostringstream ss; |
| ss << left << " != " << right << ": " << text; |
| fail(ss.str(), curr, func); |
| return false; |
| } |
| return true; |
| } |
| |
| template<typename T, typename S> |
| bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text, Function* func = nullptr) { |
| if (left != unreachable && left != right) { |
| std::ostringstream ss; |
| ss << left << " != " << right << ": " << text; |
| fail(ss.str(), curr, func); |
| return false; |
| } |
| return true; |
| } |
| |
| template<typename T, typename S> |
| bool shouldBeUnequal(S left, S right, T curr, const char* text, Function* func = nullptr) { |
| if (left == right) { |
| std::ostringstream ss; |
| ss << left << " == " << right << ": " << text; |
| fail(ss.str(), curr, func); |
| return false; |
| } |
| return true; |
| } |
| |
| void shouldBeIntOrUnreachable(WasmType ty, Expression* curr, const char* text, Function* func = nullptr) { |
| switch (ty) { |
| case i32: |
| case i64: |
| case unreachable: { |
| break; |
| } |
| default: fail(text, curr, func); |
| } |
| } |
| |
| }; |
| |
| struct FunctionValidator : public WalkerPass<PostWalker<FunctionValidator>> { |
| bool isFunctionParallel() override { return true; } |
| |
| Pass* create() override { return new FunctionValidator(&info); } |
| |
| ValidationInfo& info; |
| |
| FunctionValidator(ValidationInfo* info) : info(*info) {} |
| |
| struct BreakInfo { |
| WasmType type; |
| Index arity; |
| BreakInfo() {} |
| BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {} |
| }; |
| |
| std::map<Name, Expression*> breakTargets; |
| std::map<Expression*, BreakInfo> breakInfos; |
| |
| WasmType returnType = unreachable; // type used in returns |
| |
| std::set<Name> labelNames; // Binaryen IR requires that label names must be unique - IR generators must ensure that |
| |
| std::unordered_set<Expression*> seenExpressions; // expressions must not appear twice |
| |
| void noteLabelName(Name name); |
| |
| public: |
| // visitors |
| |
| static void visitPreBlock(FunctionValidator* self, Expression** currp) { |
| auto* curr = (*currp)->cast<Block>(); |
| if (curr->name.is()) self->breakTargets[curr->name] = curr; |
| } |
| |
| void visitBlock(Block *curr); |
| |
| static void visitPreLoop(FunctionValidator* self, Expression** currp) { |
| auto* curr = (*currp)->cast<Loop>(); |
| if (curr->name.is()) self->breakTargets[curr->name] = curr; |
| } |
| |
| void visitLoop(Loop *curr); |
| void visitIf(If *curr); |
| |
| // override scan to add a pre and a post check task to all nodes |
| static void scan(FunctionValidator* self, Expression** currp) { |
| PostWalker<FunctionValidator>::scan(self, currp); |
| |
| auto* curr = *currp; |
| if (curr->is<Block>()) self->pushTask(visitPreBlock, currp); |
| if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp); |
| } |
| |
| void noteBreak(Name name, Expression* value, Expression* curr); |
| void visitBreak(Break *curr); |
| void visitSwitch(Switch *curr); |
| void visitCall(Call *curr); |
| void visitCallImport(CallImport *curr); |
| void visitCallIndirect(CallIndirect *curr); |
| void visitGetLocal(GetLocal* curr); |
| void visitSetLocal(SetLocal *curr); |
| void visitLoad(Load *curr); |
| void visitStore(Store *curr); |
| void visitAtomicRMW(AtomicRMW *curr); |
| void visitAtomicCmpxchg(AtomicCmpxchg *curr); |
| void visitAtomicWait(AtomicWait *curr); |
| void visitAtomicWake(AtomicWake *curr); |
| void visitBinary(Binary *curr); |
| void visitUnary(Unary *curr); |
| void visitSelect(Select* curr); |
| void visitDrop(Drop* curr); |
| void visitReturn(Return* curr); |
| void visitHost(Host* curr); |
| void visitFunction(Function *curr); |
| |
| // helpers |
| private: |
| std::ostream& getStream() { |
| return info.getStream(getFunction()); |
| } |
| |
| template<typename T> |
| bool shouldBeTrue(bool result, T curr, const char* text) { |
| return info.shouldBeTrue(result, curr, text, getFunction()); |
| } |
| template<typename T> |
| bool shouldBeFalse(bool result, T curr, const char* text) { |
| return info.shouldBeFalse(result, curr, text, getFunction()); |
| } |
| |
| template<typename T, typename S> |
| bool shouldBeEqual(S left, S right, T curr, const char* text) { |
| return info.shouldBeEqual(left, right, curr, text, getFunction()); |
| } |
| |
| template<typename T, typename S> |
| bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) { |
| return info.shouldBeEqualOrFirstIsUnreachable(left, right, curr, text, getFunction()); |
| } |
| |
| template<typename T, typename S> |
| bool shouldBeUnequal(S left, S right, T curr, const char* text) { |
| return info.shouldBeUnequal(left, right, curr, text, getFunction()); |
| } |
| |
| void shouldBeIntOrUnreachable(WasmType ty, Expression* curr, const char* text) { |
| return info.shouldBeIntOrUnreachable(ty, curr, text, getFunction()); |
| } |
| |
| void validateAlignment(size_t align, WasmType type, Index bytes, bool isAtomic, |
| Expression* curr); |
| void validateMemBytes(uint8_t bytes, WasmType type, Expression* curr); |
| }; |
| |
| void FunctionValidator::noteLabelName(Name name) { |
| if (!name.is()) return; |
| shouldBeTrue(labelNames.find(name) == labelNames.end(), name, "names in Binaryen IR must be unique - IR generators must ensure that"); |
| labelNames.insert(name); |
| } |
| |
| void FunctionValidator::visitBlock(Block *curr) { |
| // if we are break'ed to, then the value must be right for us |
| if (curr->name.is()) { |
| noteLabelName(curr->name); |
| if (breakInfos.count(curr) > 0) { |
| auto& info = breakInfos[curr]; |
| if (isConcreteWasmType(curr->type)) { |
| shouldBeTrue(info.arity != 0, curr, "break arities must be > 0 if block has a value"); |
| } else { |
| shouldBeTrue(info.arity == 0, curr, "break arities must be 0 if block has no value"); |
| } |
| // none or unreachable means a poison value that we should ignore - if consumed, it will error |
| if (isConcreteWasmType(info.type) && isConcreteWasmType(curr->type)) { |
| shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks return a value"); |
| } |
| if (isConcreteWasmType(curr->type) && info.arity && info.type != unreachable) { |
| shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks have arity"); |
| } |
| shouldBeTrue(info.arity != Index(-1), curr, "break arities must match"); |
| if (curr->list.size() > 0) { |
| auto last = curr->list.back()->type; |
| if (isConcreteWasmType(last) && info.type != unreachable) { |
| shouldBeEqual(last, info.type, curr, "block+breaks must have right type if block ends with a reachable value"); |
| } |
| if (last == none) { |
| shouldBeTrue(info.arity == Index(0), curr, "if block ends with a none, breaks cannot send a value of any type"); |
| } |
| } |
| } |
| breakTargets.erase(curr->name); |
| } |
| if (curr->list.size() > 1) { |
| for (Index i = 0; i < curr->list.size() - 1; i++) { |
| if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)") && !info.quiet) { |
| getStream() << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n"; |
| } |
| } |
| } |
| if (curr->list.size() > 0) { |
| auto backType = curr->list.back()->type; |
| if (!isConcreteWasmType(curr->type)) { |
| shouldBeFalse(isConcreteWasmType(backType), curr, "if block is not returning a value, final element should not flow out a value"); |
| } else { |
| if (isConcreteWasmType(backType)) { |
| shouldBeEqual(curr->type, backType, curr, "block with value and last element with value must match types"); |
| } else { |
| shouldBeUnequal(backType, none, curr, "block with value must not have last element that is none"); |
| } |
| } |
| } |
| if (isConcreteWasmType(curr->type)) { |
| shouldBeTrue(curr->list.size() > 0, curr, "block with a value must not be empty"); |
| } |
| } |
| |
| void FunctionValidator::visitLoop(Loop *curr) { |
| if (curr->name.is()) { |
| noteLabelName(curr->name); |
| breakTargets.erase(curr->name); |
| if (breakInfos.count(curr) > 0) { |
| auto& info = breakInfos[curr]; |
| shouldBeEqual(info.arity, Index(0), curr, "breaks to a loop cannot pass a value"); |
| } |
| } |
| if (curr->type == none) { |
| shouldBeFalse(isConcreteWasmType(curr->body->type), curr, "bad body for a loop that has no value"); |
| } |
| } |
| |
| void FunctionValidator::visitIf(If *curr) { |
| shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "if condition must be valid"); |
| if (!curr->ifFalse) { |
| shouldBeFalse(isConcreteWasmType(curr->ifTrue->type), curr, "if without else must not return a value in body"); |
| if (curr->condition->type != unreachable) { |
| shouldBeEqual(curr->type, none, curr, "if without else and reachable condition must be none"); |
| } |
| } else { |
| if (curr->type != unreachable) { |
| shouldBeEqualOrFirstIsUnreachable(curr->ifTrue->type, curr->type, curr, "returning if-else's true must have right type"); |
| shouldBeEqualOrFirstIsUnreachable(curr->ifFalse->type, curr->type, curr, "returning if-else's false must have right type"); |
| } else { |
| if (curr->condition->type != unreachable) { |
| shouldBeEqual(curr->ifTrue->type, unreachable, curr, "unreachable if-else must have unreachable true"); |
| shouldBeEqual(curr->ifFalse->type, unreachable, curr, "unreachable if-else must have unreachable false"); |
| } |
| } |
| if (isConcreteWasmType(curr->ifTrue->type)) { |
| shouldBeEqual(curr->type, curr->ifTrue->type, curr, "if type must match concrete ifTrue"); |
| shouldBeEqualOrFirstIsUnreachable(curr->ifFalse->type, curr->ifTrue->type, curr, "other arm must match concrete ifTrue"); |
| } |
| if (isConcreteWasmType(curr->ifFalse->type)) { |
| shouldBeEqual(curr->type, curr->ifFalse->type, curr, "if type must match concrete ifFalse"); |
| shouldBeEqualOrFirstIsUnreachable(curr->ifTrue->type, curr->ifFalse->type, curr, "other arm must match concrete ifFalse"); |
| } |
| } |
| } |
| |
| void FunctionValidator::noteBreak(Name name, Expression* value, Expression* curr) { |
| WasmType valueType = none; |
| Index arity = 0; |
| if (value) { |
| valueType = value->type; |
| shouldBeUnequal(valueType, none, curr, "breaks must have a valid value"); |
| arity = 1; |
| } |
| if (!shouldBeTrue(breakTargets.count(name) > 0, curr, "all break targets must be valid")) return; |
| auto* target = breakTargets[name]; |
| if (breakInfos.count(target) == 0) { |
| breakInfos[target] = BreakInfo(valueType, arity); |
| } else { |
| auto& info = breakInfos[target]; |
| if (info.type == unreachable) { |
| info.type = valueType; |
| } else if (valueType != unreachable) { |
| if (valueType != info.type) { |
| info.type = none; // a poison value that must not be consumed |
| } |
| } |
| if (arity != info.arity) { |
| info.arity = Index(-1); // a poison value |
| } |
| } |
| } |
| void FunctionValidator::visitBreak(Break *curr) { |
| noteBreak(curr->name, curr->value, curr); |
| if (curr->condition) { |
| shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32"); |
| } |
| } |
| |
| void FunctionValidator::visitSwitch(Switch *curr) { |
| for (auto& target : curr->targets) { |
| noteBreak(target, curr->value, curr); |
| } |
| noteBreak(curr->default_, curr->value, curr); |
| shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32"); |
| } |
| |
| void FunctionValidator::visitCall(Call *curr) { |
| if (!info.validateGlobally) return; |
| auto* target = getModule()->getFunctionOrNull(curr->target); |
| if (!shouldBeTrue(!!target, curr, "call target must exist")) { |
| if (getModule()->getImportOrNull(curr->target) && !info.quiet) { |
| getStream() << "(perhaps it should be a CallImport instead of Call?)\n"; |
| } |
| return; |
| } |
| if (!shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match")) return; |
| for (size_t i = 0; i < curr->operands.size(); i++) { |
| if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match") && !info.quiet) { |
| getStream() << "(on argument " << i << ")\n"; |
| } |
| } |
| } |
| |
| void FunctionValidator::visitCallImport(CallImport *curr) { |
| if (!info.validateGlobally) return; |
| auto* import = getModule()->getImportOrNull(curr->target); |
| if (!shouldBeTrue(!!import, curr, "call_import target must exist")) return; |
| if (!shouldBeTrue(!!import->functionType.is(), curr, "called import must be function")) return; |
| auto* type = getModule()->getFunctionType(import->functionType); |
| if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return; |
| for (size_t i = 0; i < curr->operands.size(); i++) { |
| if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match") && !info.quiet) { |
| getStream() << "(on argument " << i << ")\n"; |
| } |
| } |
| } |
| |
| void FunctionValidator::visitCallIndirect(CallIndirect *curr) { |
| if (!info.validateGlobally) return; |
| auto* type = getModule()->getFunctionTypeOrNull(curr->fullType); |
| if (!shouldBeTrue(!!type, curr, "call_indirect type must exist")) return; |
| shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32"); |
| if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return; |
| for (size_t i = 0; i < curr->operands.size(); i++) { |
| if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match") && !info.quiet) { |
| getStream() << "(on argument " << i << ")\n"; |
| } |
| } |
| } |
| |
| void FunctionValidator::visitGetLocal(GetLocal* curr) { |
| shouldBeTrue(isConcreteWasmType(curr->type), curr, "get_local must have a valid type - check what you provided when you constructed the node"); |
| } |
| |
| void FunctionValidator::visitSetLocal(SetLocal *curr) { |
| shouldBeTrue(curr->index < getFunction()->getNumLocals(), curr, "set_local index must be small enough"); |
| if (curr->value->type != unreachable) { |
| if (curr->type != none) { // tee is ok anyhow |
| shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct"); |
| } |
| shouldBeEqual(getFunction()->getLocalType(curr->index), curr->value->type, curr, "set_local type must match function"); |
| } |
| } |
| |
| void FunctionValidator::visitLoad(Load *curr) { |
| shouldBeFalse(curr->isAtomic && !getModule()->memory.shared, curr, "Atomic operation with non-shared memory"); |
| validateMemBytes(curr->bytes, curr->type, curr); |
| validateAlignment(curr->align, curr->type, curr->bytes, curr->isAtomic, curr); |
| shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32"); |
| if (curr->isAtomic) shouldBeFalse(curr->signed_, curr, "atomic loads must be unsigned"); |
| } |
| |
| void FunctionValidator::visitStore(Store *curr) { |
| shouldBeFalse(curr->isAtomic && !getModule()->memory.shared, curr, "Atomic operation with non-shared memory"); |
| validateMemBytes(curr->bytes, curr->valueType, curr); |
| validateAlignment(curr->align, curr->type, curr->bytes, curr->isAtomic, curr); |
| shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32"); |
| shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none"); |
| shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->valueType, curr, "store value type must match"); |
| } |
| |
| void FunctionValidator::visitAtomicRMW(AtomicRMW* curr) { |
| shouldBeFalse(!getModule()->memory.shared, curr, "Atomic operation with non-shared memory"); |
| validateMemBytes(curr->bytes, curr->type, curr); |
| shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "AtomicRMW pointer type must be i32"); |
| shouldBeEqualOrFirstIsUnreachable(curr->type, curr->value->type, curr, "AtomicRMW result type must match operand"); |
| shouldBeIntOrUnreachable(curr->type, curr, "Atomic operations are only valid on int types"); |
| } |
| |
| void FunctionValidator::visitAtomicCmpxchg(AtomicCmpxchg* curr) { |
| shouldBeFalse(!getModule()->memory.shared, curr, "Atomic operation with non-shared memory"); |
| validateMemBytes(curr->bytes, curr->type, curr); |
| shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "cmpxchg pointer type must be i32"); |
| if (curr->expected->type != unreachable && curr->replacement->type != unreachable) { |
| shouldBeEqual(curr->expected->type, curr->replacement->type, curr, "cmpxchg operand types must match"); |
| } |
| shouldBeEqualOrFirstIsUnreachable(curr->type, curr->expected->type, curr, "Cmpxchg result type must match expected"); |
| shouldBeEqualOrFirstIsUnreachable(curr->type, curr->replacement->type, curr, "Cmpxchg result type must match replacement"); |
| shouldBeIntOrUnreachable(curr->expected->type, curr, "Atomic operations are only valid on int types"); |
| } |
| |
| void FunctionValidator::visitAtomicWait(AtomicWait* curr) { |
| shouldBeFalse(!getModule()->memory.shared, curr, "Atomic operation with non-shared memory"); |
| shouldBeEqualOrFirstIsUnreachable(curr->type, i32, curr, "AtomicWait must have type i32"); |
| shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "AtomicWait pointer type must be i32"); |
| shouldBeIntOrUnreachable(curr->expected->type, curr, "AtomicWait expected type must be int"); |
| shouldBeEqualOrFirstIsUnreachable(curr->expected->type, curr->expectedType, curr, "AtomicWait expected type must match operand"); |
| shouldBeEqualOrFirstIsUnreachable(curr->timeout->type, i64, curr, "AtomicWait timeout type must be i64"); |
| } |
| |
| void FunctionValidator::visitAtomicWake(AtomicWake* curr) { |
| shouldBeFalse(!getModule()->memory.shared, curr, "Atomic operation with non-shared memory"); |
| shouldBeEqualOrFirstIsUnreachable(curr->type, i32, curr, "AtomicWake must have type i32"); |
| shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "AtomicWake pointer type must be i32"); |
| shouldBeEqualOrFirstIsUnreachable(curr->wakeCount->type, i32, curr, "AtomicWake wakeCount type must be i32"); |
| } |
| |
| void FunctionValidator::validateMemBytes(uint8_t bytes, WasmType type, Expression* curr) { |
| switch (bytes) { |
| case 1: |
| case 2: |
| case 4: break; |
| case 8: { |
| // if we have a concrete type for the load, then we know the size of the mem operation and |
| // can validate it |
| if (type != unreachable) { |
| shouldBeEqual(getWasmTypeSize(type), 8U, curr, "8-byte mem operations are only allowed with 8-byte wasm types"); |
| } |
| break; |
| } |
| default: info.fail("Memory operations must be 1,2,4, or 8 bytes", curr, getFunction()); |
| } |
| } |
| |
| void FunctionValidator::visitBinary(Binary *curr) { |
| if (curr->left->type != unreachable && curr->right->type != unreachable) { |
| shouldBeEqual(curr->left->type, curr->right->type, curr, "binary child types must be equal"); |
| } |
| switch (curr->op) { |
| case AddInt32: |
| case SubInt32: |
| case MulInt32: |
| case DivSInt32: |
| case DivUInt32: |
| case RemSInt32: |
| case RemUInt32: |
| case AndInt32: |
| case OrInt32: |
| case XorInt32: |
| case ShlInt32: |
| case ShrUInt32: |
| case ShrSInt32: |
| case RotLInt32: |
| case RotRInt32: |
| case EqInt32: |
| case NeInt32: |
| case LtSInt32: |
| case LtUInt32: |
| case LeSInt32: |
| case LeUInt32: |
| case GtSInt32: |
| case GtUInt32: |
| case GeSInt32: |
| case GeUInt32: { |
| shouldBeEqualOrFirstIsUnreachable(curr->left->type, i32, curr, "i32 op"); |
| break; |
| } |
| case AddInt64: |
| case SubInt64: |
| case MulInt64: |
| case DivSInt64: |
| case DivUInt64: |
| case RemSInt64: |
| case RemUInt64: |
| case AndInt64: |
| case OrInt64: |
| case XorInt64: |
| case ShlInt64: |
| case ShrUInt64: |
| case ShrSInt64: |
| case RotLInt64: |
| case RotRInt64: |
| case EqInt64: |
| case NeInt64: |
| case LtSInt64: |
| case LtUInt64: |
| case LeSInt64: |
| case LeUInt64: |
| case GtSInt64: |
| case GtUInt64: |
| case GeSInt64: |
| case GeUInt64: { |
| shouldBeEqualOrFirstIsUnreachable(curr->left->type, i64, curr, "i64 op"); |
| break; |
| } |
| case AddFloat32: |
| case SubFloat32: |
| case MulFloat32: |
| case DivFloat32: |
| case CopySignFloat32: |
| case MinFloat32: |
| case MaxFloat32: |
| case EqFloat32: |
| case NeFloat32: |
| case LtFloat32: |
| case LeFloat32: |
| case GtFloat32: |
| case GeFloat32: { |
| shouldBeEqualOrFirstIsUnreachable(curr->left->type, f32, curr, "f32 op"); |
| break; |
| } |
| case AddFloat64: |
| case SubFloat64: |
| case MulFloat64: |
| case DivFloat64: |
| case CopySignFloat64: |
| case MinFloat64: |
| case MaxFloat64: |
| case EqFloat64: |
| case NeFloat64: |
| case LtFloat64: |
| case LeFloat64: |
| case GtFloat64: |
| case GeFloat64: { |
| shouldBeEqualOrFirstIsUnreachable(curr->left->type, f64, curr, "f64 op"); |
| break; |
| } |
| default: WASM_UNREACHABLE(); |
| } |
| } |
| |
| void FunctionValidator::visitUnary(Unary *curr) { |
| shouldBeUnequal(curr->value->type, none, curr, "unaries must not receive a none as their input"); |
| if (curr->value->type == unreachable) return; // nothing to check |
| switch (curr->op) { |
| case ClzInt32: |
| case CtzInt32: |
| case PopcntInt32: { |
| shouldBeEqual(curr->value->type, i32, curr, "i32 unary value type must be correct"); |
| break; |
| } |
| case ClzInt64: |
| case CtzInt64: |
| case PopcntInt64: { |
| shouldBeEqual(curr->value->type, i64, curr, "i64 unary value type must be correct"); |
| break; |
| } |
| case NegFloat32: |
| case AbsFloat32: |
| case CeilFloat32: |
| case FloorFloat32: |
| case TruncFloat32: |
| case NearestFloat32: |
| case SqrtFloat32: { |
| shouldBeEqual(curr->value->type, f32, curr, "f32 unary value type must be correct"); |
| break; |
| } |
| case NegFloat64: |
| case AbsFloat64: |
| case CeilFloat64: |
| case FloorFloat64: |
| case TruncFloat64: |
| case NearestFloat64: |
| case SqrtFloat64: { |
| shouldBeEqual(curr->value->type, f64, curr, "f64 unary value type must be correct"); |
| break; |
| } |
| case EqZInt32: { |
| shouldBeTrue(curr->value->type == i32, curr, "i32.eqz input must be i32"); |
| break; |
| } |
| case EqZInt64: { |
| shouldBeTrue(curr->value->type == i64, curr, "i64.eqz input must be i64"); |
| break; |
| } |
| case ExtendSInt32: |
| case ExtendUInt32: |
| case ExtendS8Int32: |
| case ExtendS16Int32: { |
| shouldBeEqual(curr->value->type, i32, curr, "extend type must be correct"); break; |
| } |
| case ExtendS8Int64: |
| case ExtendS16Int64: |
| case ExtendS32Int64: { |
| shouldBeEqual(curr->value->type, i64, curr, "extend type must be correct"); break; |
| } |
| case WrapInt64: shouldBeEqual(curr->value->type, i64, curr, "wrap type must be correct"); break; |
| case TruncSFloat32ToInt32: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break; |
| case TruncSFloat32ToInt64: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break; |
| case TruncUFloat32ToInt32: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break; |
| case TruncUFloat32ToInt64: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break; |
| case TruncSFloat64ToInt32: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break; |
| case TruncSFloat64ToInt64: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break; |
| case TruncUFloat64ToInt32: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break; |
| case TruncUFloat64ToInt64: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break; |
| case ReinterpretFloat32: shouldBeEqual(curr->value->type, f32, curr, "reinterpret/f32 type must be correct"); break; |
| case ReinterpretFloat64: shouldBeEqual(curr->value->type, f64, curr, "reinterpret/f64 type must be correct"); break; |
| case ConvertUInt32ToFloat32: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break; |
| case ConvertUInt32ToFloat64: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break; |
| case ConvertSInt32ToFloat32: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break; |
| case ConvertSInt32ToFloat64: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break; |
| case ConvertUInt64ToFloat32: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break; |
| case ConvertUInt64ToFloat64: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break; |
| case ConvertSInt64ToFloat32: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break; |
| case ConvertSInt64ToFloat64: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break; |
| case PromoteFloat32: shouldBeEqual(curr->value->type, f32, curr, "promote type must be correct"); break; |
| case DemoteFloat64: shouldBeEqual(curr->value->type, f64, curr, "demote type must be correct"); break; |
| case ReinterpretInt32: shouldBeEqual(curr->value->type, i32, curr, "reinterpret/i32 type must be correct"); break; |
| case ReinterpretInt64: shouldBeEqual(curr->value->type, i64, curr, "reinterpret/i64 type must be correct"); break; |
| default: abort(); |
| } |
| } |
| |
| void FunctionValidator::visitSelect(Select* curr) { |
| shouldBeUnequal(curr->ifTrue->type, none, curr, "select left must be valid"); |
| shouldBeUnequal(curr->ifFalse->type, none, curr, "select right must be valid"); |
| shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "select condition must be valid"); |
| if (curr->ifTrue->type != unreachable && curr->ifFalse->type != unreachable) { |
| shouldBeEqual(curr->ifTrue->type, curr->ifFalse->type, curr, "select sides must be equal"); |
| } |
| } |
| |
| void FunctionValidator::visitDrop(Drop* curr) { |
| shouldBeTrue(isConcreteWasmType(curr->value->type) || curr->value->type == unreachable, curr, "can only drop a valid value"); |
| } |
| |
| void FunctionValidator::visitReturn(Return* curr) { |
| if (curr->value) { |
| if (returnType == unreachable) { |
| returnType = curr->value->type; |
| } else if (curr->value->type != unreachable) { |
| shouldBeEqual(curr->value->type, returnType, curr, "function results must match"); |
| } |
| } else { |
| returnType = none; |
| } |
| } |
| |
| void FunctionValidator::visitHost(Host* curr) { |
| switch (curr->op) { |
| case GrowMemory: { |
| shouldBeEqual(curr->operands.size(), size_t(1), curr, "grow_memory must have 1 operand"); |
| shouldBeEqualOrFirstIsUnreachable(curr->operands[0]->type, i32, curr, "grow_memory must have i32 operand"); |
| break; |
| } |
| case PageSize: |
| case CurrentMemory: |
| case HasFeature: break; |
| default: WASM_UNREACHABLE(); |
| } |
| } |
| |
| void FunctionValidator::visitFunction(Function *curr) { |
| // if function has no result, it is ignored |
| // if body is unreachable, it might be e.g. a return |
| if (curr->body->type != unreachable) { |
| shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns"); |
| } |
| if (returnType != unreachable) { |
| shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function has returns"); |
| } |
| shouldBeTrue(breakTargets.empty(), curr->body, "all named break targets must exist"); |
| returnType = unreachable; |
| labelNames.clear(); |
| // expressions must not be seen more than once |
| struct Walker : public PostWalker<Walker, UnifiedExpressionVisitor<Walker>> { |
| std::unordered_set<Expression*>& seen; |
| std::vector<Expression*> dupes; |
| |
| Walker(std::unordered_set<Expression*>& seen) : seen(seen) {} |
| |
| void visitExpression(Expression* curr) { |
| bool inserted; |
| std::tie(std::ignore, inserted) = seen.insert(curr); |
| if (!inserted) dupes.push_back(curr); |
| } |
| }; |
| Walker walker(seenExpressions); |
| walker.walk(curr->body); |
| for (auto* bad : walker.dupes) { |
| info.fail("expression seen more than once in the tree", bad, getFunction()); |
| } |
| } |
| |
| static bool checkOffset(Expression* curr, Address add, Address max) { |
| if (curr->is<GetGlobal>()) return true; |
| auto* c = curr->dynCast<Const>(); |
| if (!c) return false; |
| uint64_t raw = c->value.getInteger(); |
| if (raw > std::numeric_limits<Address::address_t>::max()) { |
| return false; |
| } |
| if (raw + uint64_t(add) > std::numeric_limits<Address::address_t>::max()) { |
| return false; |
| } |
| Address offset = raw; |
| return offset + add <= max; |
| } |
| |
| void FunctionValidator::validateAlignment(size_t align, WasmType type, Index bytes, |
| bool isAtomic, Expression* curr) { |
| if (isAtomic) { |
| shouldBeEqual(align, (size_t)bytes, curr, "atomic accesses must have natural alignment"); |
| return; |
| } |
| switch (align) { |
| case 1: |
| case 2: |
| case 4: |
| case 8: break; |
| default:{ |
| info.fail("bad alignment: " + std::to_string(align), curr, getFunction()); |
| break; |
| } |
| } |
| shouldBeTrue(align <= bytes, curr, "alignment must not exceed natural"); |
| switch (type) { |
| case i32: |
| case f32: { |
| shouldBeTrue(align <= 4, curr, "alignment must not exceed natural"); |
| break; |
| } |
| case i64: |
| case f64: { |
| shouldBeTrue(align <= 8, curr, "alignment must not exceed natural"); |
| break; |
| } |
| default: {} |
| } |
| } |
| |
| static void validateBinaryenIR(Module& wasm, ValidationInfo& info) { |
| struct BinaryenIRValidator : public PostWalker<BinaryenIRValidator, UnifiedExpressionVisitor<BinaryenIRValidator>> { |
| ValidationInfo& info; |
| |
| BinaryenIRValidator(ValidationInfo& info) : info(info) {} |
| |
| void visitExpression(Expression* curr) { |
| // check if a node type is 'stale', i.e., we forgot to finalize() the node. |
| auto oldType = curr->type; |
| ReFinalizeNode().visit(curr); |
| auto newType = curr->type; |
| if (newType != oldType) { |
| // We accept concrete => undefined, |
| // e.g. |
| // |
| // (drop (block (result i32) (unreachable))) |
| // |
| // The block has an added type, not derived from the ast itself, so it is |
| // ok for it to be either i32 or unreachable. |
| if (!(isConcreteWasmType(oldType) && newType == unreachable)) { |
| std::ostringstream ss; |
| ss << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printWasmType(oldType) << ", should be " << printWasmType(newType) << ")\n"; |
| info.fail(ss.str(), curr, getFunction()); |
| } |
| curr->type = oldType; |
| } |
| } |
| }; |
| BinaryenIRValidator binaryenIRValidator(info); |
| binaryenIRValidator.walkModule(&wasm); |
| } |
| |
// Module-level validation (run from WasmValidator::validate when validating globally)
| |
| static void validateImports(Module& module, ValidationInfo& info) { |
| for (auto& curr : module.imports) { |
| if (curr->kind == ExternalKind::Function) { |
| if (info.validateWeb) { |
| auto* functionType = module.getFunctionType(curr->functionType); |
| info.shouldBeUnequal(functionType->result, i64, curr->name, "Imported function must not have i64 return type"); |
| for (WasmType param : functionType->params) { |
| info.shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters"); |
| } |
| } |
| } |
| if (curr->kind == ExternalKind::Table) { |
| info.shouldBeTrue(module.table.imported, curr->name, "Table import record exists but table is not marked as imported"); |
| } |
| if (curr->kind == ExternalKind::Memory) { |
| info.shouldBeTrue(module.memory.imported, curr->name, "Memory import record exists but memory is not marked as imported"); |
| } |
| } |
| } |
| |
| static void validateExports(Module& module, ValidationInfo& info) { |
| for (auto& curr : module.exports) { |
| if (curr->kind == ExternalKind::Function) { |
| if (info.validateWeb) { |
| Function* f = module.getFunction(curr->value); |
| info.shouldBeUnequal(f->result, i64, f->name, "Exported function must not have i64 return type"); |
| for (auto param : f->params) { |
| info.shouldBeUnequal(param, i64, f->name, "Exported function must not have i64 parameters"); |
| } |
| } |
| } |
| } |
| std::set<Name> exportNames; |
| for (auto& exp : module.exports) { |
| Name name = exp->value; |
| if (exp->kind == ExternalKind::Function) { |
| bool found = false; |
| for (auto& func : module.functions) { |
| if (func->name == name) { |
| found = true; |
| break; |
| } |
| } |
| info.shouldBeTrue(found, name, "module function exports must be found"); |
| } else if (exp->kind == ExternalKind::Global) { |
| info.shouldBeTrue(module.getGlobalOrNull(name), name, "module global exports must be found"); |
| } else if (exp->kind == ExternalKind::Table) { |
| info.shouldBeTrue(name == Name("0") || name == module.table.name, name, "module table exports must be found"); |
| } else if (exp->kind == ExternalKind::Memory) { |
| info.shouldBeTrue(name == Name("0") || name == module.memory.name, name, "module memory exports must be found"); |
| } else { |
| WASM_UNREACHABLE(); |
| } |
| Name exportName = exp->name; |
| info.shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique"); |
| exportNames.insert(exportName); |
| } |
| } |
| |
| static void validateGlobals(Module& module, ValidationInfo& info) { |
| for (auto& curr : module.globals) { |
| info.shouldBeTrue(curr->init != nullptr, curr->name, "global init must be non-null"); |
| info.shouldBeTrue(curr->init->is<Const>() || curr->init->is<GetGlobal>(), curr->name, "global init must be valid"); |
| if (!info.shouldBeEqual(curr->type, curr->init->type, curr->init, "global init must have correct type") && !info.quiet) { |
| info.getStream(nullptr) << "(on global " << curr->name << ")\n"; |
| } |
| } |
| } |
| |
| static void validateMemory(Module& module, ValidationInfo& info) { |
| auto& curr = module.memory; |
| info.shouldBeFalse(curr.initial > curr.max, "memory", "memory max >= initial"); |
| info.shouldBeTrue(curr.max <= Memory::kMaxSize, "memory", "max memory must be <= 4GB"); |
| info.shouldBeTrue(!curr.shared || curr.hasMax(), "memory", "shared memory must have max size"); |
| Index mustBeGreaterOrEqual = 0; |
| for (auto& segment : curr.segments) { |
| if (!info.shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32")) continue; |
| info.shouldBeTrue(checkOffset(segment.offset, segment.data.size(), module.memory.initial * Memory::kPageSize), segment.offset, "segment offset should be reasonable"); |
| Index size = segment.data.size(); |
| info.shouldBeTrue(size <= curr.initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory"); |
| if (segment.offset->is<Const>()) { |
| Index start = segment.offset->cast<Const>()->value.geti32(); |
| Index end = start + size; |
| info.shouldBeTrue(end <= curr.initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory"); |
| info.shouldBeTrue(start >= mustBeGreaterOrEqual, segment.data.size(), "segment size should fit in memory"); |
| mustBeGreaterOrEqual = end; |
| } |
| } |
| } |
| |
| static void validateTable(Module& module, ValidationInfo& info) { |
| auto& curr = module.table; |
| for (auto& segment : curr.segments) { |
| info.shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32"); |
| info.shouldBeTrue(checkOffset(segment.offset, segment.data.size(), module.table.initial * Table::kPageSize), segment.offset, "segment offset should be reasonable"); |
| for (auto name : segment.data) { |
| info.shouldBeTrue(module.getFunctionOrNull(name) || module.getImportOrNull(name), name, "segment name should be valid"); |
| } |
| } |
| } |
| |
| static void validateModule(Module& module, ValidationInfo& info) { |
| // start |
| if (module.start.is()) { |
| auto func = module.getFunctionOrNull(module.start); |
| if (info.shouldBeTrue(func != nullptr, module.start, "start must be found")) { |
| info.shouldBeTrue(func->params.size() == 0, module.start, "start must have 0 params"); |
| info.shouldBeTrue(func->result == none, module.start, "start must not return a value"); |
| } |
| } |
| } |
| |
// TODO: If we want the validator to be part of libwasm rather than libpasses,
// then using PassRunner::getPassDebug causes a circular dependence. We should
// fix that, perhaps by moving some of the pass infrastructure into libsupport.
| bool WasmValidator::validate(Module& module, Flags flags) { |
| ValidationInfo info; |
| info.validateWeb = flags & Web; |
| info.validateGlobally = flags & Globally; |
| info.quiet = flags & Quiet; |
| // parallel wasm logic validation |
| PassRunner runner(&module); |
| runner.add<FunctionValidator>(&info); |
| runner.setIsNested(true); |
| runner.run(); |
| // validate globally |
| if (info.validateGlobally) { |
| validateImports(module, info); |
| validateExports(module, info); |
| validateGlobals(module, info); |
| validateMemory(module, info); |
| validateTable(module, info); |
| validateModule(module, info); |
| } |
| // validate additional internal IR details when in pass-debug mode |
| if (PassRunner::getPassDebug()) { |
| validateBinaryenIR(module, info); |
| } |
| // print all the data |
| if (!info.valid.load() && !info.quiet) { |
| for (auto& func : module.functions) { |
| std::cerr << info.getStream(func.get()).str(); |
| } |
| std::cerr << info.getStream(nullptr).str(); |
| // also print the module |
| WasmPrinter::printModule(&module, std::cerr); |
| } |
| return info.valid.load(); |
| } |
| |
| } // namespace wasm |