blob: 41012fccbe466a4da41d28cc85d83014a4409105 [file] [log] [blame] [edit]
/*
* Copyright 2016 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//
// Removes dead, i.e. unreachable, code.
//
// We track whether control flow is reachable. When it isn't, we kill the
// code (turn it into an unreachable). We then fold away entire unreachable
// expressions.
//
// When dead code prevents an operation from happening, like a store, a call
// or an add, we replace the operation with a block containing only what does
// happen. That isn't necessarily smaller, but blocks are friendlier to other
// optimizations: blocks can be merged and eliminated, and they clearly
// have no side effects.
//
#include <wasm.h>
#include <pass.h>
#include <ast_utils.h>
#include <wasm-builder.h>
namespace wasm {
struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination>> {
  bool isFunctionParallel() override { return true; }

  Pass* create() override { return new DeadCodeElimination; }

  // Whether control flow can reach the point currently being walked.
  bool reachable;

  void doWalkFunction(Function* func) {
    reachable = true;
    walk(func->body);
  }

  // Labels targeted by a break/switch that can actually execute. A label is
  // erased again when the walk finishes the block or loop that defines it.
  std::set<Name> reachableBreaks;

  // Records that a branch to |name| may be taken.
  // If control flow already ended while evaluating the branch's operands, the
  // branch can never execute, so it must not make its target reachable. One
  // example is a value that is a block ending in an unreachable that was not
  // folded into a single unreachable node. This case is handled gracefully
  // instead of asserting.
  void addBreak(Name name) {
    if (reachable) {
      reachableBreaks.insert(name);
    }
  }

  // Whether |curr| has already been reduced to an unreachable node.
  bool isDead(Expression* curr) {
    return curr && curr->is<Unreachable>();
  }

  // Replaces the current node with a block that first evaluates |live| for
  // its side effects (dropping its value), and then the dead expression
  // |dead|. This preserves the original evaluation order.
  void replaceWithDropThenDead(Expression* live, Expression* dead) {
    auto* block = getModule()->allocator.alloc<Block>();
    block->list.resize(2);
    block->list[0] = drop(live);
    block->list[1] = dead;
    block->finalize();
    replaceCurrent(block);
  }

  // things that stop control flow

  void visitBreak(Break* curr) {
    if (isDead(curr->value)) {
      // The value is evaluated before the condition. So if the value is
      // unreachable, the whole break is unreachable.
      replaceCurrent(curr->value);
      return;
    }
    if (isDead(curr->condition)) {
      // The branch is never reached; keep only the value's side effects.
      if (curr->value) {
        replaceWithDropThenDead(curr->value, curr->condition);
      } else {
        replaceCurrent(curr->condition);
      }
      return;
    }
    addBreak(curr->name);
    if (!curr->condition) {
      reachable = false;
    }
  }

  void visitSwitch(Switch* curr) {
    if (isDead(curr->value)) {
      // The value is evaluated before the condition.
      replaceCurrent(curr->value);
      return;
    }
    if (isDead(curr->condition)) {
      // The table branch is never reached; keep only the value's side effects.
      if (curr->value) {
        replaceWithDropThenDead(curr->value, curr->condition);
      } else {
        replaceCurrent(curr->condition);
      }
      return;
    }
    for (auto target : curr->targets) {
      addBreak(target);
    }
    addBreak(curr->default_);
    reachable = false;
  }

  void visitReturn(Return* curr) {
    if (isDead(curr->value)) {
      replaceCurrent(curr->value);
      return;
    }
    reachable = false;
  }

  void visitUnreachable(Unreachable* curr) {
    reachable = false;
  }

  // For each block being walked, this stack holds the index of the element
  // currently being visited in that block.
  std::vector<Index> blockStack;

  static void doPreBlock(DeadCodeElimination* self, Expression** currp) {
    self->blockStack.push_back(0);
  }

  static void doAfterBlockElement(DeadCodeElimination* self, Expression** currp) {
    auto* block = (*currp)->cast<Block>();
    Index i = self->blockStack.back();
    self->blockStack.back()++;
    if (!self->reachable) {
      // Control flow ended in the middle of the block, so we can truncate the
      // rest. We still visit the remaining elements, so if the list was
      // already truncated, do not lengthen it. Visiting them after shrinking
      // is ok: our arena vectors leave things as they are when shrinking.
      if (block->list.size() > i + 1) {
        // It is not legal to truncate a block if that leaves a bad last
        // element, given the wasm type rules. For example, if the last
        // element is a return, the block doesn't care about it for type
        // checking purposes. But if removing it would leave an element of
        // type none as the last element, that could be a problem; see
        // https://github.com/WebAssembly/spec/issues/355
        if (!(isConcreteWasmType(block->type) && block->list[i]->type == none)) {
          block->list.resize(i + 1);
          // We do *not* finalize here. It is incorrect to re-finalize a block
          // after removing elements: it may no longer have branches to it
          // that would determine its type, so re-finalizing would wipe out a
          // type it already had.
        }
      }
    }
  }

  void visitBlock(Block* curr) {
    blockStack.pop_back();
    if (curr->name.is()) {
      // A reachable branch to this block makes the code after it reachable.
      reachable = reachable || reachableBreaks.count(curr->name);
      reachableBreaks.erase(curr->name);
    }
    if (curr->list.size() == 1 && isDead(curr->list[0])) {
      replaceCurrent(curr->list[0]);
    }
  }

  void visitLoop(Loop* curr) {
    if (curr->name.is()) {
      // Branches to a loop go back to its top, so they do not affect
      // reachability after it.
      reachableBreaks.erase(curr->name);
    }
    if (isDead(curr->body)) {
      replaceCurrent(curr->body);
      return;
    }
  }

  // Ifs need special handling: this stack holds reachability states, for
  // forking at the condition and joining after the arms.
  std::vector<bool> ifStack;

  static void doAfterIfCondition(DeadCodeElimination* self, Expression** currp) {
    self->ifStack.push_back(self->reachable);
  }

  static void doAfterIfElseTrue(DeadCodeElimination* self, Expression** currp) {
    assert((*currp)->cast<If>()->ifFalse);
    bool reachableBefore = self->ifStack.back();
    self->ifStack.pop_back();
    // Remember how the true arm ended, and start the false arm from the state
    // after the condition.
    self->ifStack.push_back(self->reachable);
    self->reachable = reachableBefore;
  }

  void visitIf(If* curr) {
    // The ifStack holds the path that joins us. That is the state after the
    // condition for a plain if, or the end of the true arm for an if-else.
    reachable = reachable || ifStack.back();
    ifStack.pop_back();
    if (isDead(curr->condition)) {
      replaceCurrent(curr->condition);
    }
  }

  static void scan(DeadCodeElimination* self, Expression** currp) {
    if (!self->reachable) {
      // Convert to an unreachable. Do this without UB, even though we have no
      // destructors on AST nodes.
#define DELEGATE(CLASS_TO_VISIT) \
      { ExpressionManipulator::convert<CLASS_TO_VISIT, Unreachable>(static_cast<CLASS_TO_VISIT*>(*currp)); break; }
      switch ((*currp)->_id) {
        case Expression::Id::BlockId: DELEGATE(Block);
        case Expression::Id::IfId: DELEGATE(If);
        case Expression::Id::LoopId: DELEGATE(Loop);
        case Expression::Id::BreakId: DELEGATE(Break);
        case Expression::Id::SwitchId: DELEGATE(Switch);
        case Expression::Id::CallId: DELEGATE(Call);
        case Expression::Id::CallImportId: DELEGATE(CallImport);
        case Expression::Id::CallIndirectId: DELEGATE(CallIndirect);
        case Expression::Id::GetLocalId: DELEGATE(GetLocal);
        case Expression::Id::SetLocalId: DELEGATE(SetLocal);
        case Expression::Id::GetGlobalId: DELEGATE(GetGlobal);
        case Expression::Id::SetGlobalId: DELEGATE(SetGlobal);
        case Expression::Id::LoadId: DELEGATE(Load);
        case Expression::Id::StoreId: DELEGATE(Store);
        case Expression::Id::ConstId: DELEGATE(Const);
        case Expression::Id::UnaryId: DELEGATE(Unary);
        case Expression::Id::BinaryId: DELEGATE(Binary);
        case Expression::Id::SelectId: DELEGATE(Select);
        case Expression::Id::DropId: DELEGATE(Drop);
        case Expression::Id::ReturnId: DELEGATE(Return);
        case Expression::Id::HostId: DELEGATE(Host);
        case Expression::Id::NopId: DELEGATE(Nop);
        case Expression::Id::UnreachableId: break;
        case Expression::Id::InvalidId:
        default: WASM_UNREACHABLE();
      }
#undef DELEGATE
      return;
    }
    auto* curr = *currp;
    if (curr->is<If>()) {
      auto* iff = curr->cast<If>();
      // Tasks run in LIFO order: condition, fork, true arm, [swap, false arm], join.
      self->pushTask(DeadCodeElimination::doVisitIf, currp);
      if (iff->ifFalse) {
        self->pushTask(DeadCodeElimination::scan, &iff->ifFalse);
        self->pushTask(DeadCodeElimination::doAfterIfElseTrue, currp);
      }
      self->pushTask(DeadCodeElimination::scan, &iff->ifTrue);
      self->pushTask(DeadCodeElimination::doAfterIfCondition, currp);
      self->pushTask(DeadCodeElimination::scan, &iff->condition);
    } else if (curr->is<Block>()) {
      self->pushTask(DeadCodeElimination::doVisitBlock, currp);
      auto& list = curr->cast<Block>()->list;
      // Pushed in reverse so that elements are scanned first to last, each
      // followed by doAfterBlockElement.
      for (int i = int(list.size()) - 1; i >= 0; i--) {
        self->pushTask(DeadCodeElimination::doAfterBlockElement, currp);
        self->pushTask(DeadCodeElimination::scan, &list[i]);
      }
      self->pushTask(DeadCodeElimination::doPreBlock, currp);
    } else {
      WalkerPass<PostWalker<DeadCodeElimination>>::scan(self, currp);
    }
  }

  // other things

  // Wraps |toDrop| in a drop, unless it is already an unreachable node.
  Expression* drop(Expression* toDrop) {
    if (toDrop->is<Unreachable>()) return toDrop;
    return Builder(*getModule()).makeDrop(toDrop);
  }

  // If some operand of the call is dead, replaces the call with the operands
  // evaluated up to and including the first dead one. Returns the
  // replacement, or |curr| if nothing was replaced.
  template<typename T>
  Expression* handleCall(T* curr) {
    for (Index i = 0; i < curr->operands.size(); i++) {
      if (!isDead(curr->operands[i])) continue;
      if (i == 0) {
        return replaceCurrent(curr->operands[0]);
      }
      // Keep the side effects of the operands evaluated before the dead one.
      auto* block = getModule()->allocator.alloc<Block>();
      block->list.resize(i + 1);
      for (Index j = 0; j <= i; j++) {
        block->list[j] = drop(curr->operands[j]);
      }
      block->finalize();
      return replaceCurrent(block);
    }
    return curr;
  }

  void visitCall(Call* curr) {
    handleCall(curr);
  }

  void visitCallImport(CallImport* curr) {
    handleCall(curr);
  }

  void visitCallIndirect(CallIndirect* curr) {
    if (handleCall(curr) != curr) return;
    if (isDead(curr->target)) {
      // The target is evaluated after all operands.
      auto* block = getModule()->allocator.alloc<Block>();
      for (auto* operand : curr->operands) {
        block->list.push_back(drop(operand));
      }
      block->list.push_back(curr->target);
      block->finalize();
      replaceCurrent(block);
    }
  }

  void visitSetLocal(SetLocal* curr) {
    if (isDead(curr->value)) {
      replaceCurrent(curr->value);
    }
  }

  void visitSetGlobal(SetGlobal* curr) {
    if (isDead(curr->value)) {
      replaceCurrent(curr->value);
    }
  }

  void visitDrop(Drop* curr) {
    if (isDead(curr->value)) {
      replaceCurrent(curr->value);
    }
  }

  void visitLoad(Load* curr) {
    if (isDead(curr->ptr)) {
      replaceCurrent(curr->ptr);
    }
  }

  void visitStore(Store* curr) {
    if (isDead(curr->ptr)) {
      replaceCurrent(curr->ptr);
      return;
    }
    if (isDead(curr->value)) {
      replaceWithDropThenDead(curr->ptr, curr->value);
    }
  }

  void visitUnary(Unary* curr) {
    if (isDead(curr->value)) {
      replaceCurrent(curr->value);
    }
  }

  void visitBinary(Binary* curr) {
    if (isDead(curr->left)) {
      replaceCurrent(curr->left);
      return;
    }
    if (isDead(curr->right)) {
      replaceWithDropThenDead(curr->left, curr->right);
    }
  }

  void visitSelect(Select* curr) {
    // Operands are evaluated in the order ifTrue, ifFalse, condition.
    if (isDead(curr->ifTrue)) {
      replaceCurrent(curr->ifTrue);
      return;
    }
    if (isDead(curr->ifFalse)) {
      replaceWithDropThenDead(curr->ifTrue, curr->ifFalse);
      return;
    }
    if (isDead(curr->condition)) {
      auto* block = getModule()->allocator.alloc<Block>();
      block->list.resize(3);
      block->list[0] = drop(curr->ifTrue);
      block->list[1] = drop(curr->ifFalse);
      block->list[2] = curr->condition;
      block->finalize();
      replaceCurrent(block);
      return;
    }
  }

  void visitHost(Host* curr) {
    // TODO
  }

  void visitFunction(Function* curr) {
    // Every labeled block/loop erases its own name, so nothing may remain.
    assert(reachableBreaks.size() == 0);
  }
};
// Factory for the dead code elimination pass.
// Returns a newly heap-allocated DeadCodeElimination instance. Presumably
// ownership passes to the caller (e.g. a pass runner) — confirm.
Pass* createDeadCodeEliminationPass() {
  auto* pass = new DeadCodeElimination();
  return pass;
}
} // namespace wasm