/*
* Copyright 2016 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//
// Optimize combinations of instructions
//
#include <algorithm>
#include <wasm.h>
#include <pass.h>
#include <wasm-s-parser.h>
#include <support/threads.h>
#include <ast_utils.h>
#include <ast/cost.h>
#include <ast/properties.h>
namespace wasm {
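// Names of the fake import calls used as wildcards in the pattern DSL: in a
// pattern, a call to one of these imports with a single constant i32 operand
// stands for "any expression of that type", and the constant identifies which
// wildcard slot it binds to (see Match::compare below). any.expr places no
// constraint on the type.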
Name I32_EXPR = "i32.expr",
I64_EXPR = "i64.expr",
F32_EXPR = "f32.expr",
F64_EXPR = "f64.expr",
ANY_EXPR = "any.expr";
// A pattern
struct Pattern {
Expression* input;
Expression* output;
Pattern(Expression* input, Expression* output) : input(input), output(output) {}
};
// Database of patterns
struct PatternDatabase {
Module wasm;
char* input;
std::map<Expression::Id, std::vector<Pattern>> patternMap; // root expression id => list of all patterns for it. TODO: optimize more
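// The generated module is expected to contain a function named "patterns"
// whose body is a block of two-element blocks, each pairing a pattern input
// with its replacement output. Schematically (an illustrative pair, not
// necessarily one present in the actual .wast file):
//   (block
//     (i32.and (call_import $i32.expr (i32.const 0)) (i32.const -1)) ;; input
//     (call_import $i32.expr (i32.const 0))                          ;; output
//   )
// i.e. "x & -1 => x", with wildcard #0 standing for x.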
PatternDatabase() {
// generate module
input = strdup(
#include "OptimizeInstructions.wast.processed"
);
try {
SExpressionParser parser(input);
Element& root = *parser.root;
SExpressionWasmBuilder builder(wasm, *root[0]);
// parse module form
auto* func = wasm.getFunction("patterns");
auto* body = func->body->cast<Block>();
for (auto* item : body->list) {
auto* pair = item->cast<Block>();
patternMap[pair->list[0]->_id].emplace_back(pair->list[0], pair->list[1]);
}
} catch (ParseException& p) {
p.dump(std::cerr);
Fatal() << "error in parsing wasm text";
}
}
~PatternDatabase() {
free(input);
};
};
static PatternDatabase* database = nullptr;
struct DatabaseEnsurer {
DatabaseEnsurer() {
assert(!database);
database = new PatternDatabase;
}
};
// Check for matches and apply them
struct Match {
Module& wasm;
Pattern& pattern;
Match(Module& wasm, Pattern& pattern) : wasm(wasm), pattern(pattern) {}
std::vector<Expression*> wildcards; // wildcard index in i32.expr(id) etc. => the expression it represents in this match
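// For example (illustrative): if the pattern input is "wildcard #0 & -1" and
// the expression seen is (i32.and (get_local 3) (i32.const -1)), then after a
// successful check() wildcards[0] points at the (get_local 3) node.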
// Comparing/checking
// Check if we can match to this pattern, updating ourselves with the info if so
bool check(Expression* seen) {
// compare seen to the pattern input, doing a special operation for our "wildcards"
assert(wildcards.size() == 0);
return ExpressionAnalyzer::flexibleEqual(pattern.input, seen, *this);
}
bool compare(Expression* subInput, Expression* subSeen) {
CallImport* call = subInput->dynCast<CallImport>();
if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return false;
Index index = call->operands[0]->cast<Const>()->value.geti32();
// handle our special functions
auto checkMatch = [&](WasmType type) {
if (type != none && subSeen->type != type) return false;
while (index >= wildcards.size()) {
wildcards.push_back(nullptr);
}
if (!wildcards[index]) {
// new wildcard
wildcards[index] = subSeen; // NB: no need to copy
return true;
} else {
// We are seeing this index for a second or later time; check that it matches
return ExpressionAnalyzer::equal(subSeen, wildcards[index]);
}
};
if (call->target == I32_EXPR) {
if (checkMatch(i32)) return true;
} else if (call->target == I64_EXPR) {
if (checkMatch(i64)) return true;
} else if (call->target == F32_EXPR) {
if (checkMatch(f32)) return true;
} else if (call->target == F64_EXPR) {
if (checkMatch(f64)) return true;
} else if (call->target == ANY_EXPR) {
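// any.expr matches an expression of any type: checkMatch(none) skips the type check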
if (checkMatch(none)) return true;
}
return false;
}
// Applying/copying
// Apply the match: generate the replacement by copying the pattern output, substituting the bound wildcards as necessary
Expression* apply() {
return ExpressionManipulator::flexibleCopy(pattern.output, wasm, *this);
}
// When copying a wildcard, perform the substitution.
// TODO: we can reuse nodes, not copying a wildcard when it appears just once, and we can reuse other individual nodes when they are discarded anyhow.
Expression* copy(Expression* curr) {
CallImport* call = curr->dynCast<CallImport>();
if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return nullptr;
Index index = call->operands[0]->cast<Const>()->value.geti32();
// handle our special functions
if (call->target == I32_EXPR || call->target == I64_EXPR || call->target == F32_EXPR || call->target == F64_EXPR || call->target == ANY_EXPR) {
return ExpressionManipulator::copy(wildcards.at(index), wasm);
}
return nullptr;
}
};
// Main pass class
struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>> {
bool isFunctionParallel() override { return true; }
Pass* create() override { return new OptimizeInstructions; }
void prepareToRun(PassRunner* runner, Module* module) override {
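// the function-local static ensures the shared pattern database is constructed
// only once, no matter how many parallel instances of this pass are created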
static DatabaseEnsurer ensurer;
}
void visitExpression(Expression* curr) {
// we may be able to apply multiple patterns, as one may open up opportunities for others that look deeper. NB: patterns must not have cycles
while (1) {
auto* handOptimized = handOptimize(curr);
if (handOptimized) {
curr = handOptimized;
replaceCurrent(curr);
continue;
}
auto iter = database->patternMap.find(curr->_id);
if (iter == database->patternMap.end()) return;
auto& patterns = iter->second;
bool more = false;
for (auto& pattern : patterns) {
Match match(*getModule(), pattern);
if (match.check(curr)) {
curr = match.apply();
replaceCurrent(curr);
more = true;
break; // exit pattern for loop, return to main while loop
}
}
if (!more) break;
}
}
// Optimizations that don't yet fit in the pattern DSL, but might eventually
Expression* handOptimize(Expression* curr) {
if (auto* binary = curr->dynCast<Binary>()) {
if (Properties::isSymmetric(binary)) {
// canonicalize a const to the second position
if (binary->left->is<Const>() && !binary->right->is<Const>()) {
std::swap(binary->left, binary->right);
}
}
// pattern-match a sign extension of a small load: a load of 8 bits followed by a shl of 24 and a shr_s of 24 (or a 16-bit load with shifts of 16)
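// e.g., roughly in wasm text:
//   (i32.shr_s (i32.shl (i32.load8_u ...) (i32.const 24)) (i32.const 24))
//     => (i32.load8_s ...)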
if (binary->op == BinaryOp::ShrSInt32 && binary->right->is<Const>()) {
auto shifts = binary->right->cast<Const>()->value.geti32();
if (shifts == 24 || shifts == 16) {
auto* left = binary->left->dynCast<Binary>();
if (left && left->op == ShlInt32 && left->right->is<Const>() && left->right->cast<Const>()->value.geti32() == shifts) {
auto* load = left->left->dynCast<Load>();
if (load && ((load->bytes == 1 && shifts == 24) || (load->bytes == 2 && shifts == 16))) {
load->signed_ = true;
return load;
}
}
}
} else if (binary->op == EqInt32) {
if (auto* c = binary->right->dynCast<Const>()) {
if (c->value.geti32() == 0) {
// equal 0 => eqz
return Builder(*getModule()).makeUnary(EqZInt32, binary->left);
}
}
if (auto* c = binary->left->dynCast<Const>()) {
if (c->value.geti32() == 0) {
// equal 0 => eqz
return Builder(*getModule()).makeUnary(EqZInt32, binary->right);
}
}
} else if (binary->op == AndInt32) {
if (auto* right = binary->right->dynCast<Const>()) {
if (right->type == i32) {
auto mask = right->value.geti32();
// and with -1 does nothing (common in asm.js output)
if (mask == -1) {
return binary->left;
}
// small loads do not need to be masked, the load itself masks
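// e.g. (i32.and (i32.load8_s ...) (i32.const 255)) => (i32.load8_u ...)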
if (auto* load = binary->left->dynCast<Load>()) {
if ((load->bytes == 1 && mask == 0xff) ||
(load->bytes == 2 && mask == 0xffff)) {
load->signed_ = false;
return load;
}
} else if (mask == 1 && Properties::emitsBoolean(binary->left)) {
// (bool) & 1 does not need the outer mask
return binary->left;
}
}
}
return conditionalizeExpensiveOnBitwise(binary);
} else if (binary->op == OrInt32) {
return conditionalizeExpensiveOnBitwise(binary);
}
} else if (auto* unary = curr->dynCast<Unary>()) {
// fold an eqz into the comparison below it, by inverting the comparison
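// e.g. (i32.eqz (i32.lt_s X Y)) => (i32.ge_s X Y)
// note that float comparisons other than eq/ne are not inverted here: because
// of NaN, !(x < y) is not equivalent to (x >= y)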
if (unary->op == EqZInt32) {
if (auto* inner = unary->value->dynCast<Binary>()) {
switch (inner->op) {
case EqInt32: inner->op = NeInt32; return inner;
case NeInt32: inner->op = EqInt32; return inner;
case LtSInt32: inner->op = GeSInt32; return inner;
case LtUInt32: inner->op = GeUInt32; return inner;
case LeSInt32: inner->op = GtSInt32; return inner;
case LeUInt32: inner->op = GtUInt32; return inner;
case GtSInt32: inner->op = LeSInt32; return inner;
case GtUInt32: inner->op = LeUInt32; return inner;
case GeSInt32: inner->op = LtSInt32; return inner;
case GeUInt32: inner->op = LtUInt32; return inner;
case EqInt64: inner->op = NeInt64; return inner;
case NeInt64: inner->op = EqInt64; return inner;
case LtSInt64: inner->op = GeSInt64; return inner;
case LtUInt64: inner->op = GeUInt64; return inner;
case LeSInt64: inner->op = GtSInt64; return inner;
case LeUInt64: inner->op = GtUInt64; return inner;
case GtSInt64: inner->op = LeSInt64; return inner;
case GtUInt64: inner->op = LeUInt64; return inner;
case GeSInt64: inner->op = LtSInt64; return inner;
case GeUInt64: inner->op = LtUInt64; return inner;
case EqFloat32: inner->op = NeFloat32; return inner;
case NeFloat32: inner->op = EqFloat32; return inner;
case EqFloat64: inner->op = NeFloat64; return inner;
case NeFloat64: inner->op = EqFloat64; return inner;
default: {}
}
}
}
} else if (auto* set = curr->dynCast<SetGlobal>()) {
// optimize out a set of a get
auto* get = set->value->dynCast<GetGlobal>();
if (get && get->name == set->name) {
ExpressionManipulator::nop(curr);
}
} else if (auto* iff = curr->dynCast<If>()) {
iff->condition = optimizeBoolean(iff->condition);
if (iff->ifFalse) {
if (auto* unary = iff->condition->dynCast<Unary>()) {
if (unary->op == EqZInt32) {
// flip if-else arms to get rid of an eqz
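// e.g. (if (i32.eqz C) A B) => (if C B A)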
iff->condition = unary->value;
std::swap(iff->ifTrue, iff->ifFalse);
}
}
}
} else if (auto* select = curr->dynCast<Select>()) {
select->condition = optimizeBoolean(select->condition);
auto* condition = select->condition->dynCast<Unary>();
if (condition && condition->op == EqZInt32) {
// flip select to remove eqz, if we can reorder
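// note that a select always evaluates both arms, so swapping them reorders
// their side effects; hence the invalidation check below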
EffectAnalyzer ifTrue(select->ifTrue);
EffectAnalyzer ifFalse(select->ifFalse);
if (!ifTrue.invalidates(ifFalse)) {
select->condition = condition->value;
std::swap(select->ifTrue, select->ifFalse);
}
}
} else if (auto* br = curr->dynCast<Break>()) {
if (br->condition) {
br->condition = optimizeBoolean(br->condition);
}
} else if (auto* load = curr->dynCast<Load>()) {
optimizeMemoryAccess(load->ptr, load->offset);
} else if (auto* store = curr->dynCast<Store>()) {
optimizeMemoryAccess(store->ptr, store->offset);
// a store of fewer bits truncates anyhow
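// e.g. (i32.store8 ... (i32.and X (i32.const 255))) => (i32.store8 ... X)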
if (auto* binary = store->value->dynCast<Binary>()) {
if (binary->op == AndInt32) {
if (auto* right = binary->right->dynCast<Const>()) {
if (right->type == i32) {
auto mask = right->value.geti32();
if ((store->bytes == 1 && mask == 0xff) ||
(store->bytes == 2 && mask == 0xffff)) {
store->value = binary->left;
}
}
}
}
} else if (auto* unary = store->value->dynCast<Unary>()) {
if (unary->op == WrapInt64) {
// instead of wrapping to 32, just store the low bits of the i64 directly
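// e.g., roughly, (i32.store ptr (i32.wrap/i64 X)) becomes a store with an i64
// value type, i.e. (i64.store32 ptr X) in the text format (smaller stores
// become i64.store16 / i64.store8 similarly)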
store->valueType = i64;
store->value = unary->value;
}
}
}
return nullptr;
}
private:
// Optimize given that the expression is flowing into a boolean context
Expression* optimizeBoolean(Expression* boolean) {
if (auto* unary = boolean->dynCast<Unary>()) {
if (unary->op == EqZInt32) {
auto* unary2 = unary->value->dynCast<Unary>();
if (unary2 && unary2->op == EqZInt32) {
// double eqz
return unary2->value;
}
}
} else if (auto* binary = boolean->dynCast<Binary>()) {
if (binary->op == OrInt32) {
// an or flowing into a boolean context can consider each input as boolean
binary->left = optimizeBoolean(binary->left);
binary->right = optimizeBoolean(binary->right);
} else if (binary->op == NeInt32) {
// x != 0 is just x if it's used as a bool
if (auto* num = binary->right->dynCast<Const>()) {
if (num->value.geti32() == 0) {
return binary->left;
}
}
}
} else if (auto* block = boolean->dynCast<Block>()) {
if (block->type == i32 && block->list.size() > 0) {
block->list.back() = optimizeBoolean(block->list.back());
}
} else if (auto* iff = boolean->dynCast<If>()) {
if (iff->type == i32) {
iff->ifTrue = optimizeBoolean(iff->ifTrue);
iff->ifFalse = optimizeBoolean(iff->ifFalse);
}
}
// TODO: recurse into br values?
return boolean;
}
// expensive1 | expensive2 can be turned into expensive1 ? 1 : expensive2, and
// expensive | cheap can be turned into cheap ? 1 : expensive,
// so that we can avoid one expensive computation, if it has no side effects.
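// Schematically: (i32.or A B)  => (if A (i32.const 1) B)
//                (i32.and A B) => (if A B (i32.const 0))
// where the arm that may be skipped is the expensive, side-effect-free one.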
Expression* conditionalizeExpensiveOnBitwise(Binary* binary) {
// this operation can increase code size, so don't always do it
auto& options = getPassRunner()->options;
if (options.optimizeLevel < 2 || options.shrinkLevel > 0) return nullptr;
const auto MIN_COST = 7;
assert(binary->op == AndInt32 || binary->op == OrInt32);
if (binary->right->is<Const>()) return nullptr; // trivial
// a bitwise logical operator on two non-constant values; check if they are boolean
auto* left = binary->left;
auto* right = binary->right;
if (!Properties::emitsBoolean(left) || !Properties::emitsBoolean(right)) return nullptr;
auto leftEffects = EffectAnalyzer(left).hasSideEffects();
auto rightEffects = EffectAnalyzer(right).hasSideEffects();
if (leftEffects && rightEffects) return nullptr; // both must execute
// canonicalize with side effects, if any, happening on the left
if (rightEffects) {
if (CostAnalyzer(left).cost < MIN_COST) return nullptr; // avoidable code is too cheap
std::swap(left, right);
} else if (leftEffects) {
if (CostAnalyzer(right).cost < MIN_COST) return nullptr; // avoidable code is too cheap
} else {
// no side effects, reorder based on cost estimation
auto leftCost = CostAnalyzer(left).cost;
auto rightCost = CostAnalyzer(right).cost;
if (std::max(leftCost, rightCost) < MIN_COST) return nullptr; // avoidable code is too cheap
// canonicalize with expensive code on the right
if (leftCost > rightCost) {
std::swap(left, right);
}
}
// worth it! perform conditionalization
Builder builder(*getModule());
if (binary->op == OrInt32) {
return builder.makeIf(left, builder.makeConst(Literal(int32_t(1))), right);
} else { // &
return builder.makeIf(left, right, builder.makeConst(Literal(int32_t(0))));
}
}
// combine a constant pointer with the offset
void optimizeMemoryAccess(Expression*& ptr, Address& offset) {
// ptr may be a const, but it isn't worth folding that in (we still have a const); in fact,
// it's better to do the opposite for gzip purposes as well as for readability.
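// e.g. (i32.load offset=8 (i32.const 100)) => (i32.load (i32.const 108))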
auto* last = ptr->dynCast<Const>();
if (last) {
last->value = Literal(int32_t(last->value.geti32() + offset));
offset = 0;
}
}
};
Pass *createOptimizeInstructionsPass() {
return new OptimizeInstructions();
}
} // namespace wasm