| /* |
| * Copyright 2015 WebAssembly Community Group participants |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| // |
| // Simple WebAssembly module validator. |
| // |
| |
| #ifndef wasm_wasm_validator_h |
| #define wasm_wasm_validator_h |
| |
| #include "wasm.h" |
| #include "wasm-printing.h" |
| |
| namespace wasm { |
| |
| struct WasmValidator : public PostWalker<WasmValidator, Visitor<WasmValidator>> { |
| bool valid = true; |
| bool validateWebConstraints = false; |
| |
| std::map<Name, WasmType> breakTypes; // breaks to a label must all have the same type, and the right type |
| WasmType returnType = unreachable; // type used in returns |
| |
| public: |
| bool validate(Module& module, bool validateWeb=false) { |
| validateWebConstraints = validateWeb; |
| walkModule(&module); |
| return valid; |
| } |
| |
| // visitors |
| |
| void visitBlock(Block *curr) { |
| // if we are break'ed to, then the value must be right for us |
| if (curr->name.is()) { |
| // none or unreachable means a poison value that we should ignore - if consumed, it will error |
| if (breakTypes.count(curr->name) > 0 && isConcreteWasmType(breakTypes[curr->name]) && isConcreteWasmType(curr->type)) { |
| shouldBeEqual(curr->type, breakTypes[curr->name], curr, "block+breaks must have right type if breaks return a value"); |
| } |
| breakTypes.erase(curr->name); |
| } |
| } |
| void visitIf(If *curr) { |
| shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid"); |
| } |
| void visitLoop(Loop *curr) { |
| if (curr->in.is()) { |
| breakTypes.erase(curr->in); |
| } |
| if (curr->out.is()) { |
| breakTypes.erase(curr->out); |
| } |
| } |
| void noteBreak(Name name, Expression* value) { |
| WasmType valueType = none; |
| if (value) { |
| valueType = value->type; |
| } |
| if (breakTypes.count(name) == 0) { |
| breakTypes[name] = valueType; |
| } else { |
| if (breakTypes[name] == unreachable) { |
| breakTypes[name] = valueType; |
| } else if (valueType != unreachable) { |
| if (valueType != breakTypes[name]) { |
| breakTypes[name] = none; // a poison value that must not be consumed |
| } |
| } |
| } |
| } |
| void visitBreak(Break *curr) { |
| noteBreak(curr->name, curr->value); |
| if (curr->condition) { |
| shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32"); |
| } |
| } |
| void visitSwitch(Switch *curr) { |
| for (auto& target : curr->targets) { |
| noteBreak(target, curr->value); |
| } |
| noteBreak(curr->default_, curr->value); |
| shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32"); |
| } |
| void visitCall(Call *curr) { |
| auto* target = getModule()->checkFunction(curr->target); |
| if (!shouldBeTrue(!!target, curr, "call target must exist")) return; |
| if (!shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match")) return; |
| for (size_t i = 0; i < curr->operands.size(); i++) { |
| shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match"); |
| } |
| } |
| void visitCallImport(CallImport *curr) { |
| auto* import = getModule()->checkImport(curr->target); |
| if (!shouldBeTrue(!!import, curr, "call_import target must exist")) return; |
| auto* type = import->type; |
| if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return; |
| for (size_t i = 0; i < curr->operands.size(); i++) { |
| shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match"); |
| } |
| } |
| void visitCallIndirect(CallIndirect *curr) { |
| auto* type = getModule()->checkFunctionType(curr->fullType); |
| if (!shouldBeTrue(!!type, curr, "call_indirect type must exist")) return; |
| shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32"); |
| if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return; |
| for (size_t i = 0; i < curr->operands.size(); i++) { |
| shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match"); |
| } |
| } |
| void visitSetLocal(SetLocal *curr) { |
| if (curr->value->type != unreachable) { |
| shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct"); |
| } |
| } |
| void visitLoad(Load *curr) { |
| validateAlignment(curr->align); |
| shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32"); |
| } |
| void visitStore(Store *curr) { |
| validateAlignment(curr->align); |
| shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32"); |
| shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match"); |
| } |
| void visitBinary(Binary *curr) { |
| if (curr->left->type != unreachable && curr->right->type != unreachable) { |
| shouldBeEqual(curr->left->type, curr->right->type, curr, "binary child types must be equal"); |
| } |
| } |
| void visitUnary(Unary *curr) { |
| shouldBeUnequal(curr->value->type, none, curr, "unaries must not receive a none as their input"); |
| switch (curr->op) { |
| case EqZInt32: |
| case EqZInt64: { |
| shouldBeEqual(curr->type, i32, curr, "eqz must return i32"); |
| break; |
| } |
| default: {} |
| } |
| if (curr->value->type == unreachable) return; |
| switch (curr->op) { |
| case ClzInt32: |
| case CtzInt32: |
| case PopcntInt32: |
| case NegFloat32: |
| case AbsFloat32: |
| case CeilFloat32: |
| case FloorFloat32: |
| case TruncFloat32: |
| case NearestFloat32: |
| case SqrtFloat32: |
| case ClzInt64: |
| case CtzInt64: |
| case PopcntInt64: |
| case NegFloat64: |
| case AbsFloat64: |
| case CeilFloat64: |
| case FloorFloat64: |
| case TruncFloat64: |
| case NearestFloat64: |
| case SqrtFloat64: { |
| if (curr->value->type != unreachable) { |
| shouldBeEqual(curr->value->type, curr->type, curr, "non-conversion unaries must return the same type"); |
| } |
| break; |
| } |
| case EqZInt32: |
| case EqZInt64: { |
| shouldBeTrue(curr->value->type == i32 || curr->value->type == i64, curr, "eqz input must be i32 or i64"); |
| break; |
| } |
| case ExtendSInt32: shouldBeEqual(curr->value->type, i32, curr, "extend type must be correct"); break; |
| case ExtendUInt32: shouldBeEqual(curr->value->type, i32, curr, "extend type must be correct"); break; |
| case WrapInt64: shouldBeEqual(curr->value->type, i64, curr, "wrap type must be correct"); break; |
| case TruncSFloat32ToInt32: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break; |
| case TruncSFloat32ToInt64: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break; |
| case TruncUFloat32ToInt32: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break; |
| case TruncUFloat32ToInt64: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break; |
| case TruncSFloat64ToInt32: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break; |
| case TruncSFloat64ToInt64: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break; |
| case TruncUFloat64ToInt32: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break; |
| case TruncUFloat64ToInt64: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break; |
| case ReinterpretFloat32: shouldBeEqual(curr->value->type, f32, curr, "reinterpret/f32 type must be correct"); break; |
| case ReinterpretFloat64: shouldBeEqual(curr->value->type, f64, curr, "reinterpret/f64 type must be correct"); break; |
| case ConvertUInt32ToFloat32: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break; |
| case ConvertUInt32ToFloat64: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break; |
| case ConvertSInt32ToFloat32: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break; |
| case ConvertSInt32ToFloat64: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break; |
| case ConvertUInt64ToFloat32: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break; |
| case ConvertUInt64ToFloat64: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break; |
| case ConvertSInt64ToFloat32: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break; |
| case ConvertSInt64ToFloat64: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break; |
| case PromoteFloat32: shouldBeEqual(curr->value->type, f32, curr, "promote type must be correct"); break; |
| case DemoteFloat64: shouldBeEqual(curr->value->type, f64, curr, "demote type must be correct"); break; |
| case ReinterpretInt32: shouldBeEqual(curr->value->type, i32, curr, "reinterpret/i32 type must be correct"); break; |
| case ReinterpretInt64: shouldBeEqual(curr->value->type, i64, curr, "reinterpret/i64 type must be correct"); break; |
| default: abort(); |
| } |
| } |
| |
| void visitReturn(Return* curr) { |
| if (curr->value) { |
| if (returnType == unreachable) { |
| returnType = curr->value->type; |
| } else if (curr->value->type != unreachable && returnType != curr->value->type) { |
| returnType = none; // poison |
| } |
| } else { |
| returnType = none; |
| } |
| } |
| |
| void visitHost(Host* curr) { |
| switch (curr->op) { |
| case GrowMemory: { |
| shouldBeEqual(curr->operands.size(), size_t(1), curr, "grow_memory must have 1 operand"); |
| shouldBeEqualOrFirstIsUnreachable(curr->operands[0]->type, i32, curr, "grow_memory must have i32 operand"); |
| break; |
| } |
| case PageSize: |
| case CurrentMemory: |
| case HasFeature: break; |
| default: WASM_UNREACHABLE(); |
| } |
| } |
| |
| void visitImport(Import* curr) { |
| if (!validateWebConstraints) return; |
| shouldBeUnequal(curr->type->result, i64, curr->name, "Imported function must not have i64 return type"); |
| for (WasmType param : curr->type->params) { |
| shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters"); |
| } |
| } |
| |
| void visitExport(Export* curr) { |
| if (!validateWebConstraints) return; |
| Function* f = getModule()->getFunction(curr->value); |
| shouldBeUnequal(f->result, i64, f->name, "Exported function must not have i64 return type"); |
| for (auto param : f->params) { |
| shouldBeUnequal(param, i64, f->name, "Exported function must not have i64 parameters"); |
| } |
| } |
| |
| void visitFunction(Function *curr) { |
| // if function has no result, it is ignored |
| // if body is unreachable, it might be e.g. a return |
| if (curr->result != none) { |
| if (curr->body->type != unreachable) { |
| shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns"); |
| } |
| if (returnType != unreachable) { |
| shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns"); |
| } |
| } |
| returnType = unreachable; |
| } |
| void visitMemory(Memory *curr) { |
| shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial"); |
| shouldBeTrue(curr->max <= Memory::kMaxSize, "memory", "max memory must be <= 4GB"); |
| size_t top = 0; |
| for (auto& segment : curr->segments) { |
| shouldBeFalse(segment.offset < top, "memory", "segment offset is small enough"); |
| top = segment.offset + segment.data.size(); |
| } |
| shouldBeFalse(top > curr->initial * Memory::kPageSize, "memory", "total segments must be small enough"); |
| } |
| void visitModule(Module *curr) { |
| // exports |
| std::set<Name> exportNames; |
| for (auto& exp : curr->exports) { |
| Name name = exp->value; |
| bool found = false; |
| for (auto& func : curr->functions) { |
| if (func->name == name) { |
| found = true; |
| break; |
| } |
| } |
| shouldBeTrue(found, name, "module exports must be found"); |
| Name exportName = exp->name; |
| shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique"); |
| exportNames.insert(exportName); |
| } |
| // start |
| if (curr->start.is()) { |
| auto func = curr->checkFunction(curr->start); |
| if (shouldBeTrue(func != nullptr, curr->start, "start must be found")) { |
| shouldBeTrue(func->params.size() == 0, curr, "start must have 0 params"); |
| shouldBeTrue(func->result == none, curr, "start must not return a value"); |
| } |
| } |
| } |
| |
| void doWalkFunction(Function* func) { |
| PostWalker<WasmValidator, Visitor<WasmValidator>>::doWalkFunction(func); |
| shouldBeTrue(breakTypes.size() == 0, "break targets", "all break targets must be valid"); |
| } |
| |
| private: |
| |
| // helpers |
| |
| std::ostream& fail() { |
| Colors::red(std::cerr); |
| if (getFunction()) { |
| std::cerr << "[wasm-validator error in function "; |
| Colors::green(std::cerr); |
| std::cerr << getFunction()->name; |
| Colors::red(std::cerr); |
| std::cerr << "] "; |
| } else { |
| std::cerr << "[wasm-validator error in module] "; |
| } |
| Colors::normal(std::cerr); |
| return std::cerr; |
| } |
| |
| template<typename T> |
| bool shouldBeTrue(bool result, T curr, const char* text) { |
| if (!result) { |
| fail() << "unexpected false: " << text << ", on \n" << curr << std::endl; |
| valid = false; |
| return false; |
| } |
| return result; |
| } |
| template<typename T> |
| bool shouldBeFalse(bool result, T curr, const char* text) { |
| if (result) { |
| fail() << "unexpected true: " << text << ", on \n" << curr << std::endl; |
| valid = false; |
| return false; |
| } |
| return result; |
| } |
| |
| template<typename T, typename S> |
| bool shouldBeEqual(S left, S right, T curr, const char* text) { |
| if (left != right) { |
| fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << std::endl; |
| valid = false; |
| return false; |
| } |
| return true; |
| } |
| template<typename T, typename S, typename U> |
| bool shouldBeEqual(S left, S right, T curr, U other, const char* text) { |
| if (left != right) { |
| fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << " / " << other << std::endl; |
| valid = false; |
| return false; |
| } |
| return true; |
| } |
| |
| template<typename T, typename S> |
| bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) { |
| if (left != unreachable && left != right) { |
| fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << std::endl; |
| valid = false; |
| return false; |
| } |
| return true; |
| } |
| |
| template<typename T, typename S> |
| bool shouldBeUnequal(S left, S right, T curr, const char* text) { |
| if (left == right) { |
| fail() << "" << left << " == " << right << ": " << text << ", on \n" << curr << std::endl; |
| valid = false; |
| return false; |
| } |
| return true; |
| } |
| |
| void validateAlignment(size_t align) { |
| switch (align) { |
| case 1: |
| case 2: |
| case 4: |
| case 8: break; |
| default:{ |
| fail() << "bad alignment: " << align << std::endl; |
| valid = false; |
| break; |
| } |
| } |
| } |
| }; |
| |
| } // namespace wasm |
| |
| #endif // wasm_wasm_validator_h |