blob: 07d1507048eab32a52c289b187a4caf8e1821bd8 [file] [log] [blame] [edit]
/*
* Copyright 2024 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "support/command-line.h"
#include "wasm-io.h"
#include <iostream>
#include <z3++.h>
using namespace wasm;
// Translates the body of a WebAssembly function into a Z3 expression.
//
// Each function parameter is modeled as a free bitvector constant named after
// the local. Only a small subset of expressions is supported so far;
// anything else hits WASM_UNREACHABLE / assert.
struct ToSMT : UnifiedExpressionVisitor<ToSMT, z3::expr> {
  z3::context& ctx;
  Function* func;
  // One symbolic value per parameter, indexed by local index.
  std::vector<z3::expr> params;

  ToSMT(z3::context& ctx, Function* func) : ctx(ctx), func(func) {
    initParams(func);
  }

  // Creates a named bitvector constant for every parameter of `func`.
  // i32/i64/v128 map to bitvectors of width 32/64/128; other types are not
  // supported yet.
  void initParams(Function* func) {
    for (Index i = 0; i < func->getNumParams(); ++i) {
      auto type = func->getLocalType(i);
      // Copy the name so bv_const always receives a null-terminated string,
      // independent of how the underlying interned string is stored.
      std::string name(func->getLocalNameOrGeneric(i).str);
      if (type.isBasic()) {
        switch (type.getBasic()) {
          case Type::none:
          case Type::unreachable:
          case Type::f32:
          case Type::f64:
            break;
          case Type::i32:
            params.push_back(ctx.bv_const(name.c_str(), 32));
            continue;
          case Type::i64:
            params.push_back(ctx.bv_const(name.c_str(), 64));
            continue;
          case Type::v128:
            params.push_back(ctx.bv_const(name.c_str(), 128));
            continue;
        }
      }
      WASM_UNREACHABLE("unimplemented param type");
    }
  }

  // Fallback for every expression kind without a dedicated visitor.
  z3::expr visitExpression(Expression* curr) {
    WASM_UNREACHABLE("unimplemented expression");
  }

  // Only parameters are supported; other locals would need SSA-style
  // modeling of local.set.
  z3::expr visitLocalGet(LocalGet* curr) {
    assert(curr->index < func->getNumParams() && "TODO");
    return params[curr->index];
  }

  // Integer constants become bitvector literals of the matching width.
  z3::expr visitConst(Const* curr) {
    assert(curr->type.isBasic());
    switch (curr->type.getBasic()) {
      case Type::none:
      case Type::unreachable:
        break;
      case Type::f32:
      case Type::f64:
        WASM_UNREACHABLE("TODO: fp const");
      case Type::i32:
        return ctx.bv_val(curr->value.geti32(), 32);
      case Type::i64:
        return ctx.bv_val(curr->value.geti64(), 64);
      case Type::v128:
        WASM_UNREACHABLE("TODO: v128.const");
    }
    WASM_UNREACHABLE("unexpected type");
  }

  // Integer binary operators that map directly onto Z3 bitvector operations.
  // Validation guarantees both operands have the same width, so the 32- and
  // 64-bit variants share an encoding.
  z3::expr visitBinary(Binary* curr) {
    auto lhs = visit(curr->left);
    auto rhs = visit(curr->right);
    switch (curr->op) {
      case AddInt32:
      case AddInt64:
        return lhs + rhs;
      case SubInt32:
      case SubInt64:
        return lhs - rhs;
      case MulInt32:
      case MulInt64:
        return lhs * rhs;
      case AndInt32:
      case AndInt64:
        return lhs & rhs;
      case OrInt32:
      case OrInt64:
        return lhs | rhs;
      case XorInt32:
      case XorInt64:
        return lhs ^ rhs;
      case ShlInt32:
      case ShlInt64:
        return z3::shl(lhs, maskShiftCount(rhs));
      case ShrUInt32:
      case ShrUInt64:
        return z3::lshr(lhs, maskShiftCount(rhs));
      case ShrSInt32:
      case ShrSInt64:
        return z3::ashr(lhs, maskShiftCount(rhs));
      default:
        break;
    }
    WASM_UNREACHABLE("unimplemented binary op");
  }

  // WebAssembly shifts use the count modulo the operand width (i32.shl by 33
  // shifts by 1). SMT-LIB bvshl/bvlshr/bvashr instead saturate for counts
  // >= width, so mask the count with (width - 1) to match Wasm semantics.
  z3::expr maskShiftCount(const z3::expr& count) {
    unsigned width = count.get_sort().bv_size();
    return count & ctx.bv_val(width - 1, width);
  }
};
// Encodes the whole body of `func` as a single Z3 expression created in `ctx`,
// with the function's parameters represented as free bitvector constants.
z3::expr funcToSMT(z3::context& ctx, Function* func) {
  ToSMT translator(ctx, func);
  return translator.visit(func->body);
}
// Builds the condition "tgt refines src". For now refinement is simply value
// equality between the two encodings.
z3::expr refinedBy(const z3::expr& src, const z3::expr& tgt) {
  // TODO: Something more complicated!
  const z3::expr sameValue = (tgt == src);
  return sameValue;
}
// Tries to prove `conjecture` valid by asking Z3 whether its negation is
// satisfiable, and prints the outcome to stdout:
//   - unsat:   nothing falsifies the conjecture, so it holds ("proved!").
//   - sat:     the model is a concrete assignment that falsifies it.
//   - unknown: Z3 could not decide; this is neither a proof nor a
//              counterexample, so print Z3's reason instead of a model.
void prove(const z3::expr& conjecture) {
  z3::context& ctx = conjecture.ctx();
  z3::solver solver(ctx);
  solver.add(!conjecture);
  std::cout << "Proving conjecture:\n" << conjecture << "\n";
  switch (solver.check()) {
    case z3::unsat:
      std::cout << "proved!\n";
      return;
    case z3::sat:
      std::cout << "counterexample:\n" << solver.get_model() << "\n";
      return;
    case z3::unknown:
      std::cout << "unknown: " << solver.reason_unknown() << "\n";
      return;
  }
}
// Checks whether `tgt` refines `src`. Both functions are encoded in one shared
// Z3 context, so their symbolic parameters live in the same context, and the
// resulting conjecture is handed to prove().
void checkRefinement(Function* src, Function* tgt) {
  z3::context ctx;
  const z3::expr sourceExpr = funcToSMT(ctx, src);
  const z3::expr targetExpr = funcToSMT(ctx, tgt);
  const z3::expr conjecture = refinedBy(sourceExpr, targetExpr);
  prove(conjecture);
}
// Command-line options for the tool: the path of the original ("source")
// module and of the transformed ("target") module.
struct ValidateRefinementOptions : Options {
  std::string source;
  std::string target;

  ValidateRefinementOptions(const std::string& command, const std::string& desc)
    : Options(command, desc) {
    // Each handler just records the single argument in the matching field.
    add("--source",
        "-s",
        "The original module",
        "",
        Arguments::One,
        [this](Options*, const std::string& path) { source = path; });
    add("--target",
        "-t",
        "The transformed module",
        "",
        Arguments::One,
        [this](Options*, const std::string& path) { target = path; });
  }
};
// Entry point. Reads the source and target modules and, for every defined
// (non-imported) function in the source, checks that the function at the same
// index in the target refines it. Returns 1 on missing options or when the
// target has no defined counterpart for a source function.
int main(int argc, const char* argv[]) {
  ValidateRefinementOptions options(
    "wasm-validate-refinement",
    "Bounded translation validation for WebAssembly");
  options.parse(argc, argv);
  if (options.source.empty()) {
    std::cerr << "Source module must be provided (--source)\n";
    return 1;
  }
  if (options.target.empty()) {
    std::cerr << "Target module must be provided (--target)\n";
    return 1;
  }
  Module src, tgt;
  ModuleReader().read(options.source, src);
  ModuleReader().read(options.target, tgt);
  // TODO: Verify that src and tgt have matching global structures, including
  // function signatures.
  for (size_t i = 0; i < src.functions.size(); ++i) {
    if (src.functions[i]->imported()) {
      continue;
    }
    // Check the pairing explicitly rather than with assert(): in builds with
    // NDEBUG an assert vanishes and tgt.functions[i] could be out of bounds.
    if (i >= tgt.functions.size() || tgt.functions[i]->imported()) {
      std::cerr << "Target module has no defined function matching source "
                << "function " << i << " (" << src.functions[i]->name
                << ")\n";
      return 1;
    }
    checkRefinement(src.functions[i].get(), tgt.functions[i].get());
  }
  return 0;
}