blob: db7f0bdfcc55224fe0bca2da36e23b5a0fd236f7 [file] [log] [blame]
/*
* Copyright 2016 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//
// Expression analyzer utility
//
// Superoptimization is based on Bansal, Sorav; Aiken, Alex (21–25 October 2006): "Automatic Generation of Peephole Superoptimizers".
//
#include "support/colors.h"
#include "support/command-line.h"
#include "support/file.h"
#include "support/hash.h"
#include "support/permutation.h"
#include "wasm-s-parser.h"
#include "wasm-traversal.h"
#include "wasm-printing.h"
#include "wasm-interpreter.h"
#include "wasm-io.h"
#include "ast_utils.h"
#include "ast/cost.h"
using namespace cashew;
using namespace wasm;
// limits on what we care about
// MAX_EXPRESSION_SIZE: normalize() drops any normalized expression whose
//   Measurer size exceeds this.
// MAX_LOCAL: size of the per-expression local tables (ScanLocals,
//   permutation lists, the params considered by alreadyOptimizable).
//   normalize() rejects expressions that need MAX_LOCAL or more locals.
#define MAX_EXPRESSION_SIZE 20
#define MAX_LOCAL 4
// special values to make sure to consider in execution hashing
// Per type, in order:
//  - i32/i64: signed min, signed max, 0 (unsigned min), all bits set
//    (unsigned max, i.e. -1), 0xfffff, -0xfffff.
//  - f32/f64: numeric_limits::min() (the smallest positive *normal* value,
//    not the most negative one), the largest finite value, a quiet NaN,
//    +infinity, 0xfffff, -0xfffff.
#define NUM_LIMITS 6
static int32_t LIMIT_I32S[NUM_LIMITS] = { std::numeric_limits<int32_t>::min(), std::numeric_limits<int32_t>::max(), int32_t(std::numeric_limits<uint32_t>::min()), int32_t(std::numeric_limits<uint32_t>::max()), 0xfffff, -0xfffff };
static int64_t LIMIT_I64S[NUM_LIMITS] = { std::numeric_limits<int64_t>::min(), std::numeric_limits<int64_t>::max(), int64_t(std::numeric_limits<uint64_t>::min()), int64_t(std::numeric_limits<uint64_t>::max()), 0xfffffLL, -0xfffffLL };
static float LIMIT_F32S[NUM_LIMITS] = { std::numeric_limits<float>::min(), std::numeric_limits<float>::max(), std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::infinity(), float(0xfffff), float(-0xfffff) };
static double LIMIT_F64S[NUM_LIMITS] = { std::numeric_limits<double>::min(), std::numeric_limits<double>::max(), std::numeric_limits<double>::quiet_NaN(), std::numeric_limits<double>::infinity(), double(0xfffff), double(-0xfffff) };
// The "small" specials are the integers -MAX_SMALL .. MAX_SMALL, converted to
// each type (see LocalGenerator::get).
#define MAX_SMALL 260
#define NUM_SMALLS (MAX_SMALL + MAX_SMALL + 1) /* negatives, positives, and zero */
#define NUM_SPECIALS (NUM_LIMITS + NUM_SMALLS)
#define NUM_RANDOMS 1000
// Number of seeds (0 .. NUM_EXECUTIONS-1) each expression is executed with,
// both when hashing executions and when double-checking a candidate rule.
#define NUM_EXECUTIONS (NUM_SPECIALS + NUM_RANDOMS)
// An expression with a cached hash value, so hash-map lookups use the stored
// value instead of re-walking the expression tree.
//
// A null expression is allowed and gets hash 0. Previously `hash` was left
// uninitialized in that case.
//
// The constructor is intentionally implicit, so a plain Expression* can be
// passed where a HashedExpression key is expected.
// Copying is member-wise (Rule of Zero), so no copy constructor is declared.
struct HashedExpression {
  Expression* expr;
  size_t hash = 0;
  HashedExpression(Expression* expr) : expr(expr) {
    if (expr) {
      hash = ExpressionAnalyzer::hash(expr);
    }
  }
};
// Hash functor for HashedExpression keys: returns the cached hash, so lookups
// never walk the expression tree. The key is taken by const reference, like
// std::hash, instead of being copied on every call.
struct ExpressionHasher {
  size_t operator()(const HashedExpression& value) const {
    return value.hash;
  }
};
// Equality functor for HashedExpression keys. The cheap cached-hash check runs
// first; only on a hash match is the full structural comparison done. Keys
// are taken by const reference rather than copied on every call.
struct ExpressionComparer {
  bool operator()(const HashedExpression& a, const HashedExpression& b) const {
    if (a.hash != b.hash) return false;
    return ExpressionAnalyzer::equal(a.expr, b.expr);
  }
};
// expression -> a count
// Keys are compared structurally (ExpressionHasher / ExpressionComparer), so
// equal patterns coming from different modules share a single entry.
class ExpressionIntMap : public std::unordered_map<HashedExpression, size_t, ExpressionHasher, ExpressionComparer> {};
// global expression state
// `global`: expressions kept as keys of `freqs` (and generalized rule inputs)
// are copied into this module via ExpressionManipulator::copy(..., global),
// so they outlive the per-file input modules.
// `freqs`: normalized expression -> number of occurrences in non-advice
// inputs. A value of 0 marks a pattern kept only as a possible optimization
// output: either it was seen only in advice inputs, or it is a local
// permutation of a seen pattern.
Module global; // a module that persists til the end
ExpressionIntMap freqs; // expression -> its frequency
// Normalize an expression into a canonical pattern. The result is a fresh
// copy created in `wasm`, built through ExpressionManipulator::flexibleCopy,
// with these rewrites:
//  * get_locals are renumbered densely from 0, in first-encounter order
//    during the copy;
//  * a tee (set_local with a value) is replaced by its value;
//  * a load is generalized to a fresh get_local, i.e. "some arbitrary value";
//  * any node without a concrete type becomes `unreachable`, and so do calls,
//    global reads, host operations and control transfers.
// Returns nullptr if the expression is irrelevant, meaning any of:
//  * the result has no concrete type;
//  * it needs MAX_LOCAL or more locals;
//  * it is larger than MAX_EXPRESSION_SIZE.
static Expression* normalize(Expression* expr, Module& wasm) {
  struct Normalizer {
    Module& wasm;
    Builder builder;
    Index nextLocal = 0;
    std::unordered_map<Index, Index> localMap; // old local index => new
    Normalizer(Module& wasm) : wasm(wasm), builder(wasm) {}
    // Copy a subtree, applying this normalizer's hook to its nodes.
    Expression* parentCopy(Expression* curr) {
      return ExpressionManipulator::flexibleCopy(curr, wasm, [&](Expression* curr) { return this->copy(curr); });
    }
    // flexibleCopy hook: return a replacement for `curr`, or nullptr to let
    // the default copy handle this node.
    Expression* copy(Expression* curr) {
      // For now, we only handle math-type expressions: having a return value and no side effects
      // TODO: do more stuff, modeling side effects etc.
      if (!isConcreteWasmType(curr->type)) {
        return builder.makeUnreachable();
      }
      if (auto* get = curr->dynCast<GetLocal>()) {
        Index newIndex;
        auto iter = localMap.find(get->index);
        if (iter == localMap.end()) {
          newIndex = nextLocal++;
          localMap[get->index] = newIndex;
        } else {
          newIndex = iter->second;
        }
        return builder.makeGetLocal(newIndex, get->type);
      }
      if (auto* set = curr->dynCast<SetLocal>()) {
        // the concrete-type check above guarantees this is a tee
        assert(set->type != none);
        // look through the tee
        return parentCopy(set->value);
      }
      if (curr->is<Load>()) {
        // consider the general case of an arbitrary expression here
        return builder.makeGetLocal(nextLocal++, curr->type);
      }
      // (Load is handled above, so it is not repeated in this list.)
      if (curr->is<Host>() || curr->is<Call>() || curr->is<CallImport>() || curr->is<CallIndirect>() || curr->is<GetGlobal>() || curr->is<Return>() || curr->is<Break>() || curr->is<Switch>()) {
        return builder.makeUnreachable();
      }
      return nullptr; // allow the default copy to proceed
    }
  } normalizer(wasm);
  auto* ret = ExpressionManipulator::flexibleCopy(expr, wasm, [&](Expression* curr) { return normalizer.copy(curr); });
  if (!isConcreteWasmType(ret->type) || normalizer.nextLocal >= MAX_LOCAL || Measurer::measure(ret) > MAX_EXPRESSION_SIZE) {
    return nullptr;
  }
  return ret;
}
// Collect the types of the locals an expression reads. After construction,
// localTypes[i] holds the type of get_local i, or `none` if local i is never
// read. Every local index must be below MAX_LOCAL (asserted).
struct ScanLocals : public WalkerPass<PostWalker<ScanLocals, Visitor<ScanLocals>>> {
  WasmType localTypes[MAX_LOCAL];
  ScanLocals(Expression* expr) {
    for (auto& localType : localTypes) {
      localType = none;
    }
    walk(expr);
  }
  void visitGetLocal(GetLocal* curr) {
    Index index = curr->index;
    assert(index < MAX_LOCAL);
    localTypes[index] = curr->type;
  }
};
// Remap locals: rewrite, in place and during construction, every get_local
// and set_local index in `expr` through `mapping` (mapping[old] == new).
// The mapping is only read, so it is held by const reference. Both the old
// index (must be inside `mapping`) and the new one (must be below MAX_LOCAL)
// are asserted.
struct RemapLocals : public WalkerPass<PostWalker<RemapLocals, Visitor<RemapLocals>>> {
  const std::vector<Index>& mapping;
  RemapLocals(Expression* expr, const std::vector<Index>& mapping) : mapping(mapping) {
    walk(expr);
  }
  void visitGetLocal(GetLocal* curr) {
    curr->index = remapIndex(curr->index);
  }
  void visitSetLocal(SetLocal* curr) {
    curr->index = remapIndex(curr->index);
  }
  // Shared lookup for gets and sets, with bounds checks on both sides.
  Index remapIndex(Index index) const {
    assert(index < mapping.size());
    Index mapped = mapping[index];
    assert(mapped < MAX_LOCAL);
    return mapped;
  }
};
// Per-input-file options handed to each Scan pass.
struct ScanSettings {
  ScanSettings(Index* totalExpressions, bool adviseOnly)
    : totalExpressions(totalExpressions),
      adviseOnly(adviseOnly) {}

  // Shared counter of relevant expressions. Scan only increments it when
  // adviseOnly is false.
  Index* totalExpressions;
  // When true, Scan only registers new patterns (with frequency 0) and does
  // not count occurrences.
  bool adviseOnly;
};
// Scan a module for expressions. Every expression in every function body is
// normalized.
//  - Relevant ones are counted in the global `freqs` map, unless
//    adviseOnly is set.
//  - Each newly seen pattern is copied into the persistent `global` module.
//  - The local permutations of a new pattern are registered with frequency 0,
//    as optimization targets only.
struct Scan : public WalkerPass<PostWalker<Scan, UnifiedExpressionVisitor<Scan>>> {
  ScanSettings settings;
  Scan(ScanSettings settings) : settings(settings) {}
  void doWalkFunction(Function* func) {
    //std::cout << " [" << func->name << ']' << '\n';
    walk(func->body);
  }
  void visitExpression(Expression* curr) {
    // normalize the expression, creating a temporary copy in this module,
    // which is ephemeral TODO: avoid keeping them alive til the end of
    // module processing to reduce peak mem usage?
    auto* normalized = normalize(curr, *getModule());
    if (!normalized) return;
    if (!settings.adviseOnly) {
      (*settings.totalExpressions)++; // this is relevant, count it
    }
    HashedExpression hashed(normalized);
    auto iter = freqs.find(hashed);
    if (iter != freqs.end()) {
      if (!settings.adviseOnly) {
        iter->second++; // just increment it
      }
      return;
    }
    // Create a persistent copy in the global module. The copy is structurally
    // equal to `normalized`, so it reuses the cached hash instead of
    // rehashing the tree.
    HashedExpression persistent = hashed;
    persistent.expr = ExpressionManipulator::copy(normalized, global);
    freqs[persistent] = settings.adviseOnly ? 0 : 1;
#if 1
    addLocalPermutations(persistent.expr);
#endif
  }
  // Add the permutations on the locals of `copy` as well, with freq 0. We
  // just want to use them as optimization targets; we don't need to optimize
  // them, since we optimize the canonical first form.
  // Only expressions with at least two locals have anything to permute.
  void addLocalPermutations(Expression* copy) {
    ScanLocals scanner(copy);
    if (scanner.localTypes[1] == none) return;
    struct PermutationsLister {
      std::vector<std::vector<std::vector<Index>>> list; // index => list of permutations of that size
      PermutationsLister() {
        list.resize(MAX_LOCAL + 1);
        for (size_t i = 1; i < MAX_LOCAL + 1; i++) {
          list[i] = Permutation::makeAllPermutations(i);
        }
      }
    };
    static PermutationsLister permutationsLister;
    Index numLocals = 2;
    while (numLocals < MAX_LOCAL && scanner.localTypes[numLocals] != none) {
      numLocals++;
    }
    assert(numLocals <= MAX_LOCAL);
    auto& perms = permutationsLister.list.at(numLocals);
    // All permutations are tried, including one equal to `copy` itself.
    // emplace() never overwrites an existing entry, so already-known patterns
    // (such as `copy`) keep their frequency.
    for (auto& perm : perms) {
      auto* remapped = ExpressionManipulator::copy(copy, global);
      RemapLocals remapper(remapped, perm);
      freqs.emplace(HashedExpression(remapped), 0);
    }
  }
};
// Generate local values deterministically, using a seed.
//
// For seeds 0 .. NUM_SPECIALS-1, every local gets special value number
// `seed`, the same for all locals. For larger seeds, local `index` gets
// special value number `seed - NUM_SPECIALS * (index + 1)` when that is in
// range. Otherwise it gets a pseudo-random value derived from (seed, index).
//
// Special value s is one of two things:
//  - s < NUM_LIMITS: the s-th LIMIT_* entry for the type;
//  - otherwise: the integer s - NUM_LIMITS - MAX_SMALL, which lies in
//    [-MAX_SMALL, MAX_SMALL], converted to the type.
class LocalGenerator {
  Index seed;
public:
  LocalGenerator(Index seed) : seed(seed) {}
  Literal get(Index index, WasmType type) {
    // use low indexes to ensure we get representation of a few special values
    // TODO: get each of the MAX_LOCALS to all of its NUM_SPECIALS values
    int64_t special = seed; // start with 0-NS having them all taking the same value
    if (special >= NUM_SPECIALS) { // then give each a range for itself
      special = int64_t(seed) - int64_t(NUM_SPECIALS * (index + 1));
    }
    if (special >= 0 && special < NUM_SPECIALS) {
      if (special < NUM_LIMITS) {
        switch (type) {
          case i32: return Literal(LIMIT_I32S[special]);
          case i64: return Literal(LIMIT_I64S[special]);
          case f32: return Literal(LIMIT_F32S[special]);
          case f64: return Literal(LIMIT_F64S[special]);
          default: WASM_UNREACHABLE();
        }
      } else {
        special -= NUM_LIMITS;
        assert(special >= 0 && special < NUM_SMALLS);
        special -= MAX_SMALL;
        assert(special >= -MAX_SMALL && special <= MAX_SMALL);
        switch (type) {
          case i32: return Literal(int32_t(special));
          case i64: return Literal(int64_t(special));
          case f32: return Literal(float(special));
          case f64: return Literal(double(special));
          default: WASM_UNREACHABLE();
        }
      }
    }
    // a general "random"/deterministic value
    auto base = rehash(seed, index);
    switch (type) {
      case i32:
      case f32: {
        auto ret = Literal(rehash(base, Index(type)));
        if (type == f32) ret = ret.castToF32();
        return ret;
      }
      case i64:
      case f64: {
        // Combine two 32-bit hashes in unsigned arithmetic. Shifting an
        // int64_t holding a value >= 2^31 left by 32 overflows, which is
        // undefined behavior.
        uint64_t low = rehash(base, Index(type));
        uint64_t high = rehash(base, Index(type + 1000));
        auto ret = Literal(int64_t(low | (high << 32)));
        if (type == f64) ret = ret.castToF64();
        return ret;
      }
      default: WASM_UNREACHABLE();
    }
  }
};
// Thrown by Runner::trap() to abandon an execution that traps.
struct TrapException {}; // TODO: use a flow label for optimization?
// Execute the expression over a set of local values.
//  - get_local reads are answered by a LocalGenerator.
//  - Any trap, including entering a loop, is raised as a TrapException.
//  - Node kinds that normalize() rewrites away (calls, local/global
//    accesses other than get_local, memory and host operations) must not
//    appear here; they abort the process.
class Runner : public ExpressionRunner<Runner> {
  LocalGenerator& generator;
public:
  Runner(LocalGenerator& localGenerator) : generator(localGenerator) {}

  Flow visitGetLocal(GetLocal* curr) {
    return Flow(generator.get(curr->index, curr->type));
  }

  // Loops might be infinite, and without state we cannot tell that one
  // terminates, so they are treated as impossible to optimize for now.
  Flow visitLoop(Loop*) {
    trap("loop");
    WASM_UNREACHABLE();
  }

  void trap(const char*) override {
    throw TrapException();
  }

  // None of these should survive normalization.
  Flow visitCall(Call*) { abort(); }
  Flow visitCallImport(CallImport*) { abort(); }
  Flow visitCallIndirect(CallIndirect*) { abort(); }
  Flow visitSetLocal(SetLocal*) { abort(); }
  Flow visitGetGlobal(GetGlobal*) { abort(); }
  Flow visitSetGlobal(SetGlobal*) { abort(); }
  Flow visitLoad(Load*) { abort(); }
  Flow visitStore(Store*) { abort(); }
  Flow visitHost(Host*) { abort(); }
};
// Calculate a hash value based on executing an expression
struct ExecutionHasher {
std::unordered_map<size_t, std::vector<Expression*>> hashClasses; // hash value => list of expressions that have it, so they may be equal
void note(Expression* expr) {
size_t hash;
try {
hash = doHash(expr);
} catch (TrapException& e) {
// we don't bother trying to handle things that trap TODO: maybe abort the whole thing, move try out, for speed?
return;
}
hashClasses[hash].push_back(expr); // we depend on expr being unique, so the classes are mathematical sets
}
size_t doHash(Expression* expr) {
// combine the result of multiple executions into the final hash
size_t hash = 0;
for (Index i = 0; i < NUM_EXECUTIONS; i++) {
LocalGenerator localGenerator(i);
Flow flow = Runner(localGenerator).visit(expr);
if (flow.breaking()) {
hash = rehash(hash, 1);
hash = rehash(hash, 2);
hash = rehash(hash, 3);
hash = rehash(hash, size_t(flow.breakTo.str));
} else {
hash = rehash(hash, 4);
hash = rehash(hash, flow.value.type);
switch (flow.value.type) {
case f32: flow.value = flow.value.castToI32(); break;
case f64: flow.value = flow.value.castToI64(); break;
default: {}
}
switch (flow.value.type) {
case none: hash = rehash(hash, 5); hash = rehash(hash, 6); break;
case i32: hash = rehash(hash, flow.value.geti32()); hash = rehash(hash, 7); break;
case i64: hash = rehash(hash, flow.value.geti64()); hash = rehash(hash, flow.value.geti64() >> 32); break;
default: WASM_UNREACHABLE();
}
}
}
return hash;
}
};
// Weight of an expression: the quantity the rule search tries to minimize.
// It is currently just the node count reported by Measurer.
Index calcWeight(Expression* expr) {
  // TODO: consider adding CostAnalyzer(expr).cost
  const Index size = Measurer::measure(expr);
  return size;
}
// Is the candidate rule `input` => `output` redundant? Returns true when the
// default optimization pipeline already turns `input` into something whose
// calcWeight is no larger than `output`'s.
//
// `localTypes[i]` is the type of local i in `input`, or `none` if unused (as
// filled in by ScanLocals). The input is wrapped in an exported function
// whose params are those locals. The default passes are run on it, and the
// optimized body is weighed against `output`.
static bool alreadyOptimizable(Expression* input, WasmType localTypes[MAX_LOCAL], Expression* output) {
  Module temp;
  // make a single function that receives the expression's locals as params and returns its value
  auto* func = new Function();
  func->name = Name("temp");
  func->result = input->type;
  // Only declare params up to the highest local that is actually used: a
  // `none`-typed param is not a valid local. A gap below that should not
  // occur for normalized expressions, whose locals are numbered densely from
  // 0; if one does, it gets a placeholder type.
  Index numParams = 0;
  for (Index i = 0; i < MAX_LOCAL; i++) {
    if (localTypes[i] != none) {
      numParams = i + 1;
    }
  }
  for (Index i = 0; i < numParams; i++) {
    func->params.push_back(localTypes[i] != none ? localTypes[i] : i32);
  }
  func->body = ExpressionManipulator::copy(input, temp);
  temp.addFunction(func);
  // export the function, so optimizations don't kill it!
  auto* export_ = new Export();
  export_->name = Name("export");
  export_->value = func->name;
  export_->kind = ExternalKind::Function;
  temp.addExport(export_);
  // run the optimizer
  PassRunner passRunner(&temp);
  passRunner.addDefaultOptimizationPasses();
  passRunner.run();
  // compare the optimized input against the proposed output
  return calcWeight(func->body) <= calcWeight(output);
}
// Given two expressions that hashing suggests might be the same, try harder
// directly on the pair to prove or disprove equivalence. Returns true only if
// all of the following hold:
//  - a and b have the same type;
//  - they read exactly the same locals with the same types;
//  - they produce equal values on every one of the NUM_EXECUTIONS inputs;
//  - the existing optimizer does not already reduce a to b's weight or less.
// Runner traps are not caught here. Execution hashing drops expressions that
// trap, so a and b are expected not to trap on these inputs.
bool looksValid(Expression* a, Expression* b) {
  // hash collision, these are not even the same type
  if (a->type != b->type) {
    return false;
  }
  // local types must be identical, otherwise the rule isn't even valid to apply
  ScanLocals aScanner(a), bScanner(b);
  for (Index local = 0; local < MAX_LOCAL; local++) {
    bool sameType = aScanner.localTypes[local] == bScanner.localTypes[local];
    if (!sameType) {
      return false; // mismatching local types
    }
  }
  // Brute force: rerun every input used for hashing, and compare the two
  // results case by case instead of via a single summary hash.
  for (Index seed = 0; seed < NUM_EXECUTIONS; seed++) {
    LocalGenerator generator(seed);
    Flow aFlow = Runner(generator).visit(a);
    Flow bFlow = Runner(generator).visit(b);
    // TODO: breaking
    if (aFlow.value != bFlow.value) {
      return false;
    }
  }
  // If optimizing a already gets something as good as b, the rule adds
  // nothing; otherwise these two may be joined together in holy optimony.
  return !alreadyOptimizable(a, aScanner.localTypes, b);
}
// Generalize an expression: return a copy of `expr`, created in `wasm`, in
// which every constant is replaced by `unreachable`. Patterns that differ only
// in constant values therefore share one generalized form.
// TODO: generalize more, e.g. maybe fold away differences in shifts?
Expression* generalize(Expression* expr, Module& wasm) {
  Builder builder(wasm);
  auto replaceConstants = [&](Expression* curr) -> Expression* {
    if (!curr->is<Const>()) {
      return nullptr; // allow the default copy to proceed
    }
    return builder.makeUnreachable();
  };
  return ExpressionManipulator::flexibleCopy(expr, wasm, replaceConstants);
}
int main(int argc, const char *argv[]) {
// receive arguments
std::vector<std::string> filenames;
Options options("wasm-analyze", "Analyze a set of wasm modules. Provide a set of input files, optionally split by 'advice:' (in which case files afterwards are just advice, used to find optimization outputs but not inputs we focus on optimizing)");
options.add_positional("INFILES", Options::Arguments::N,
[&](Options *o, const std::string &argument) {
filenames.push_back(argument);
});
options.parse(argc, argv);
Index totalExpressions = 0;
bool adviseOnly = false;
// read inputs
for (auto& filename : filenames) {
if (filename == "advice:" || filename == "advise:") {
adviseOnly = true;
std::cerr << "[advice-only from here]\n";
continue;
}
auto input(read_file<std::string>(filename, Flags::Text, Flags::Release));
Module wasm;
std::cerr << "[processing: " << filename << ']' << '\n';
try {
ModuleReader reader;
reader.read(filename, wasm);
} catch (ParseException& p) {
p.dump(std::cerr);
Fatal() << "error in parsing input " << filename;
}
// scan all expressions in all functions, optimized and not
PassRunner passRunner(&wasm);
passRunner.add<Scan>(ScanSettings(&totalExpressions, adviseOnly));
passRunner.addDefaultOptimizationPasses();
passRunner.add<Scan>(ScanSettings(&totalExpressions, adviseOnly));
passRunner.run();
}
// print frequencies
#if 0
std::cout << "Frequencies:\n";
std::vector<HashedExpression> sorted;
for (auto& iter : freqs) {
sorted.push_back(iter.first);
}
std::sort(sorted.begin(), sorted.end(), [&](const HashedExpression& a, const HashedExpression& b) {
auto diff = int64_t(freqs[a]) - int64_t(freqs[b]);
if (diff > 0) return true;
if (diff < 0) return false;
return size_t(a.expr) < size_t(b.expr);
});
for (auto& item : sorted) {
std::cout << freqs[item] << " : " << item.expr << '\n';
}
#endif
// perform execution hashing, looking for expressions that are functionally equivalent,
// so one can be optimized to the other
std::cerr << "[hashing executions]\n";
ExecutionHasher executionHasher;
for (auto& iter : freqs) {
auto* expr = iter.first.expr;
executionHasher.note(expr);
}
// Basic statistics
std::cerr << "[writing basic output]\n";
std::cout << "Execution hashing info:\n";
std::cout << " num expression nodes in total: " << totalExpressions << '\n';
std::cout << " num unique expressions: " << freqs.size() << '\n';
{
size_t total = 0;
for (auto& pair : executionHasher.hashClasses) {
total += pair.second.size();
}
std::cout << " num relevant expressions: " << total << '\n';
}
std::cout << " num execution classes: " << executionHasher.hashClasses.size() << '\n';
{
size_t max = 0;
for (auto& pair : executionHasher.hashClasses) {
max = std::max(max, pair.second.size());
}
std::cout << " max class size: " << max << '\n';
}
// Detailed output
{
// a rule is a connection between one pattern and another, which we think may be equivalent to it,
// and which may provide a measured benefit
// TODO: test rules on more random inputs, trying to prove they are not equivalent?
struct Rule {
Expression* from;
Expression* to;
size_t benefit;
Rule(Expression* from, Expression* to, size_t benefit) : from(from), to(to), benefit(benefit) {}
};
std::vector<Rule> rules;
std::cerr << "[finding rules]\n";
for (auto& pair : executionHasher.hashClasses) {
auto& clazz = pair.second;
Index size = clazz.size();
if (size < 2) continue;
// consider all pairs, since some may be spurious hash collisions
for (Index i = 0; i < size; i++) {
auto* iExpr = clazz[i];
auto iFreq = freqs[iExpr];
if (iFreq == 0) continue; // no frequency means no benefit to optimize it; this expression is just a target of optimization, not an origin
Index iSize = calcWeight(iExpr);
Expression* best = nullptr;
Index bestSize = -1;
for (Index j = 0; j < size; j++) {
if (i == j) continue;
auto* jExpr = clazz[j];
Index jSize = calcWeight(jExpr);
// we are looking for a rule where i => j, so we need j to be smaller
if (iSize <= jSize) continue; // TODO: for equality, look not just at size, but cost etc.
// a likely candidate, if direct attempts to prove they differ fail, this is worth reporting to the user
if (best && jSize >= bestSize) continue; // we can't do better
if (looksValid(iExpr, jExpr)) {
best = jExpr;
bestSize = jSize;
}
}
if (best) {
rules.emplace_back(iExpr, best, (iSize - bestSize) * iFreq);
}
}
}
// Many rules are part of a more general pattern, for example x + 1 + 1 === x + 2 is
// closely related to x + 1 + 2 === x + 3. The generalized rule is what the human would
// write in the optimizer, so to assess the benefit of rules, we must generalize in
// our output.
std::cerr << "[generalizing]\n";
struct GeneralizedRule : public Rule {
std::vector<Rule*> rules; // the specific rules underlying this generalization
GeneralizedRule(Expression* from, Rule* rule) : Rule(from, nullptr, 0) {
addRule(rule);
}
void addRule(Rule* rule) {
benefit += rule->benefit;
rules.push_back(rule);
}
};
// hashed from expression => the generalized rules for that expression
std::unordered_map<HashedExpression, GeneralizedRule, ExpressionHasher, ExpressionComparer> generalizedRules;
for (auto& rule : rules) {
auto generalizedFrom = HashedExpression(generalize(rule.from, global)); // TODO: save memory, don't use global unless needed?
auto iter = generalizedRules.find(generalizedFrom);
if (iter == generalizedRules.end()) {
generalizedRules.emplace(generalizedFrom, GeneralizedRule(generalizedFrom.expr, &rule));
} else {
iter->second.addRule(&rule);
}
}
// final sorting and output
std::cerr << "[sorting generalized rules]\n";
std::vector<GeneralizedRule*> sortedGeneralizedRules;
for (auto& pair : generalizedRules) {
sortedGeneralizedRules.push_back(&pair.second);
}
auto ruleSorter = [](const Rule* a, const Rule* b) {
// primary sorting criteria is the size benefit
auto diff = int64_t(a->benefit) - int64_t(b->benefit);
if (diff > 0) return true;
if (diff < 0) return false;
return size_t(a->from) < size_t(b->from);
};
std::sort(sortedGeneralizedRules.begin(), sortedGeneralizedRules.end(), [&ruleSorter](const GeneralizedRule* a, const GeneralizedRule* b) {
return ruleSorter(a, b);
});
std::cout << "sorted possible optimization rules:\n";
Index totalWeight = totalExpressions * 2; // Just an estimate FIXME
size_t i = 0;
for (auto* item : sortedGeneralizedRules) {
std::cout << "\n[generalized rule " << (i++) << ": benefit: " << item->benefit << ", (" << (100*double(item->benefit)/totalWeight) << "%)], input pattern:\n" << item->from << '\n';
// show the specific rules underlying the generalized one
std::sort(item->rules.begin(), item->rules.end(), ruleSorter);
for (auto* rule : item->rules) {
std::cout << "\n[child specific rule benefit: " << rule->benefit << ", (" << (100*double(rule->benefit)/totalWeight) << "%)], possible rule:\n" << rule->from << "\n =->\n" << rule->to << '\n';
}
}
}
// TODO TODO: if all execution hashes of expr are the same, it might be constant (avoids needing to have all constants hashed)
}