blob: c4fd51b16dafb008eeac8a111ce06f86997ecd44 [file] [log] [blame] [edit]
/*
* Copyright 2023 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//
// Optimize away trivial tuples. When values are bundled together in a tuple, we
// are limited in how we can optimize them in the various local-related passes,
// like this:
//
// (local.set $tuple
// (tuple.make 3 (A) (B) (C)))
// (use
// (tuple.extract 3 0
// (local.get $tuple)))
//
// If there are no other uses, then we just need one of the three elements. By
// lowering them to three separate locals, other passes can remove the other two.
//
// Specifically, this pass seeks out tuple locals that have these properties:
//
// * They are always written from either a tuple.make or another tuple local with
// these properties.
// * They are always used either in tuple.extract or they are copied to another
// tuple local with these properties.
//
// The set of those tuple locals can be easily optimized into individual locals,
// as the tuple does not "escape" into, say, a return value.
//
// TODO: Blocks etc. might be handled here, but it's not clear if we want to:
// there are situations where multivalue leads to smaller code using
// those constructs. Atm this pass should only remove things that are
// definitely worth lowering.
//
#include <pass.h>
#include <support/unique_deferring_queue.h>
#include <wasm-builder.h>
#include <wasm.h>
namespace wasm {
// Function-parallel pass that splits tuple locals into one plain local per
// tuple element, when every use of the tuple local is either a tuple.make
// written into it, a tuple.extract read from it, or a copy to/from another
// tuple local that also qualifies.
//
// The pass works in two phases: the walker methods on this struct count uses
// and "valid" uses per local, and optimize() then decides which locals can be
// lowered and rewrites the function using MapApplier.
struct TupleOptimization : public WalkerPass<PostWalker<TupleOptimization>> {
  bool isFunctionParallel() override { return true; }

  std::unique_ptr<Pass> create() override {
    return std::make_unique<TupleOptimization>();
  }

  // Track the number of uses for each tuple local. We consider a use as a
  // local.get, a set, or a tee. A tee counts as two uses (since it both sets
  // and gets, and so we must see that it is both used and uses properly).
  //
  // Indexed by local index; only entries of tuple locals are ever incremented.
  std::vector<Index> uses;

  // Tracks which tuple local uses are valid, that is, follow the properties
  // above. If we have more uses than valid uses then we must have an invalid
  // one, and the local cannot be optimized.
  //
  // Indexed by local index. optimize() asserts validUses[i] <= uses[i], so
  // this must only ever be incremented for tuple locals.
  std::vector<Index> validUses;

  // When one tuple local copies the value of another, we need to track the
  // index that was copied, as if the source ends up bad then the target is bad
  // as well.
  //
  // This is a symmetrical map, that is, we consider copies to work both ways:
  //
  //    x \in copiedIndexes[y]    <==>    y \in copiedIndexes[x]
  //
  std::vector<std::unordered_set<Index>> copiedIndexes;

  // Entry point for a single function: bails out early when there is nothing
  // to do, otherwise sizes the per-local tables, walks the body to fill them
  // in, and then runs optimize().
  void doWalkFunction(Function* func) {
    // If tuples are not enabled, or there are no tuple locals, then there is no
    // work to do.
    if (!getModule()->features.hasMultivalue()) {
      return;
    }
    bool hasTuple = false;
    for (auto var : func->vars) {
      if (var.isTuple()) {
        hasTuple = true;
        break;
      }
    }
    if (!hasTuple) {
      return;
    }

    // Prepare global data structures before we collect info.
    auto numLocals = func->getNumLocals();
    uses.resize(numLocals);
    validUses.resize(numLocals);
    copiedIndexes.resize(numLocals);

    // Walk the code to collect info.
    Super::doWalkFunction(func);

    // Analyze and optimize.
    optimize(func);
  }

  // Every get of a tuple-typed local is one use. Whether it is a valid use is
  // decided by the parent (a tuple.extract or a copying local.set).
  void visitLocalGet(LocalGet* curr) {
    if (curr->type.isTuple()) {
      uses[curr->index]++;
    }
  }

  // Counts the uses of a set/tee of a tuple local, and marks as valid those
  // whose value is a tuple.make or a copy from another local (get or tee).
  // Copies are also recorded in copiedIndexes, in both directions.
  void visitLocalSet(LocalSet* curr) {
    if (getFunction()->getLocalType(curr->index).isTuple()) {
      // See comment above about tees (we consider their set and get each a
      // separate use).
      uses[curr->index] += curr->isTee() ? 2 : 1;
      auto* value = curr->value;

      // We need the input to the local to be another such local (from a tee, or
      // a get), or a tuple.make.
      if (auto* tee = value->dynCast<LocalSet>()) {
        assert(tee->isTee());
        // We don't want to count anything as valid if the inner tee is
        // unreachable. In that case the outer tee is also unreachable, of
        // course, and in fact they might not even have the same tuple type or
        // the inner one might not even be a tuple (since we are in unreachable
        // code, that is possible). We could try to optimize unreachable tees in
        // some cases, but it's simpler to let DCE simplify the code first.
        if (tee->type != Type::unreachable) {
          validUses[tee->index]++;
          validUses[curr->index]++;
          copiedIndexes[tee->index].insert(curr->index);
          copiedIndexes[curr->index].insert(tee->index);
        }
      } else if (auto* get = value->dynCast<LocalGet>()) {
        validUses[get->index]++;
        validUses[curr->index]++;
        copiedIndexes[get->index].insert(curr->index);
        copiedIndexes[curr->index].insert(get->index);
      } else if (value->is<TupleMake>()) {
        validUses[curr->index]++;
      }
    }
  }

  // A tuple.extract whose input is directly a local (get or tee) is a valid
  // use of that local.
  void visitTupleExtract(TupleExtract* curr) {
    // We need the input to be a local, either from a tee or a get.
    Index index;
    if (auto* set = curr->tuple->dynCast<LocalSet>()) {
      index = set->index;
    } else if (auto* get = curr->tuple->dynCast<LocalGet>()) {
      index = get->index;
    } else {
      return;
    }

    // Only tuple locals have their uses counted (see visitLocalGet and
    // visitLocalSet), so only they may have valid uses counted. In unreachable
    // code the input can be a tee of a non-tuple local whose value is
    // unreachable; counting that would leave validUses > uses for that local
    // and break the invariant that optimize() asserts on.
    if (getFunction()->getLocalType(index).isTuple()) {
      validUses[index]++;
    }
  }

  // Uses the collected counts to find the tuple locals that can be lowered,
  // adds the replacement locals to the function, and rewrites the body.
  void optimize(Function* func) {
    auto numLocals = func->getNumLocals();

    // Find the set of bad indexes. We add each such candidate to a worklist
    // that we will then flow to find all those corrupted.
    std::vector<bool> bad(numLocals);
    UniqueDeferredQueue<Index> work;

    for (Index i = 0; i < uses.size(); i++) {
      assert(validUses[i] <= uses[i]);
      if (uses[i] > 0 && validUses[i] < uses[i]) {
        // This is a bad tuple.
        work.push(i);
      }
    }

    // Flow badness along copies. As copiedIndexes is symmetrical, a bad local
    // corrupts both the locals it is copied to and those it is copied from.
    while (!work.empty()) {
      auto i = work.pop();
      if (bad[i]) {
        continue;
      }
      bad[i] = true;
      for (auto target : copiedIndexes[i]) {
        work.push(target);
      }
    }

    // Good indexes we can optimize are tuple locals with uses that are not bad.
    std::vector<bool> good(numLocals);
    bool hasGood = false;
    for (Index i = 0; i < uses.size(); i++) {
      if (uses[i] > 0 && !bad[i]) {
        good[i] = true;
        hasGood = true;
      }
    }

    if (!hasGood) {
      return;
    }

    // We found things to optimize! Create new non-tuple locals for their
    // contents, and then rewrite the code to use those according to the
    // mapping from tuple locals to normal ones. The mapping maps a tuple local
    // to the base index used for its contents: an index and several others
    // right after it, depending on the tuple size.
    std::unordered_map<Index, Index> tupleToNewBaseMap;
    for (Index i = 0; i < good.size(); i++) {
      if (!good[i]) {
        continue;
      }

      auto newBase = func->getNumLocals();
      tupleToNewBaseMap[i] = newBase;
      Index lastNewIndex = 0;
      for (auto t : func->getLocalType(i)) {
        Index newIndex = Builder::addVar(func, t);
        if (lastNewIndex == 0) {
          // This is the first new local we added (0 is an impossible value,
          // since tuple locals exist, hence index 0 was already taken), so it
          // must be equal to the base.
          assert(newIndex == newBase);
        } else {
          // This must be right after the former.
          assert(newIndex == lastNewIndex + 1);
        }
        lastNewIndex = newIndex;
      }
    }

    MapApplier mapApplier(tupleToNewBaseMap);
    mapApplier.walkFunctionInModule(func, getModule());
  }

  // Rewrites a function so that the tuple locals listed in the given map are
  // replaced by the individual locals starting at the mapped base index.
  struct MapApplier : public PostWalker<MapApplier> {
    // Maps an optimized tuple local to the index of the first of its
    // replacement locals. Held by reference; the caller owns the map.
    std::unordered_map<Index, Index>& tupleToNewBaseMap;

    MapApplier(std::unordered_map<Index, Index>& tupleToNewBaseMap)
      : tupleToNewBaseMap(tupleToNewBaseMap) {}

    // Gets the new base index if there is one, or 0 if not (0 is an impossible
    // value for a new index, as local index 0 was taken before, as tuple
    // locals existed).
    Index getNewBaseIndex(Index i) {
      auto iter = tupleToNewBaseMap.find(i);
      if (iter == tupleToNewBaseMap.end()) {
        return 0;
      }
      return iter->second;
    }

    // Given a local.get or local.set, return the new base index for the local
    // index used there. Returns 0 (an impossible value, see above) otherwise.
    Index getSetOrGetBaseIndex(Expression* setOrGet) {
      Index index;
      if (auto* set = setOrGet->dynCast<LocalSet>()) {
        index = set->index;
      } else if (auto* get = setOrGet->dynCast<LocalGet>()) {
        index = get->index;
      } else {
        return 0;
      }

      return getNewBaseIndex(index);
    }

    // Replacing a local.tee requires some care, since we might have
    //
    //  (local.set
    //    (local.tee
    //      ..
    //
    // We replace the local.tee with a block of sets of the new non-tuple
    // locals, and the outer set must then (1) keep those around and also (2)
    // identify the local that was tee'd, so we know what to get (which has been
    // replaced by the block). To make that simple keep a map of the things that
    // replaced tees.
    std::unordered_map<Expression*, LocalSet*> replacedTees;

    // Replaces a set/tee of an optimized tuple local with a block of sets to
    // the replacement locals. Sets of other locals are left untouched.
    void visitLocalSet(LocalSet* curr) {
      // Installs the replacement, remembering the original if it was a tee so
      // that the parent can still find out which local was tee'd.
      auto replace = [&](Expression* replacement) {
        if (curr->isTee()) {
          replacedTees[replacement] = curr;
        }
        replaceCurrent(replacement);
      };

      if (auto targetBase = getNewBaseIndex(curr->index)) {
        Builder builder(*getModule());
        auto type = getFunction()->getLocalType(curr->index);

        auto* value = curr->value;
        if (auto* make = value->dynCast<TupleMake>()) {
          // Write each of the tuple.make fields into the proper local.
          std::vector<Expression*> sets;
          for (Index i = 0; i < type.size(); i++) {
            auto* operand = make->operands[i];
            sets.push_back(builder.makeLocalSet(targetBase + i, operand));
          }
          replace(builder.makeBlock(sets));
          return;
        }

        std::vector<Expression*> contents;

        auto iter = replacedTees.find(value);
        if (iter != replacedTees.end()) {
          // The input to us was a tee that has been replaced. The actual value
          // we read from (the tee) can be found in replacedTees. Also, we
          // need to keep around the replacement of the tee.
          contents.push_back(value);
          value = iter->second;
        }

        // This is a copy of a tuple local into another. Copy all the fields
        // between them.
        Index sourceBase = getSetOrGetBaseIndex(value);

        // The target is being optimized, so the source must be as well, or else
        // we were confused earlier and the target should not be.
        assert(sourceBase);

        // The source and target may have different element types due to
        // subtyping (but their sizes must be equal).
        auto sourceType = value->type;
        assert(sourceType.size() == type.size());

        for (Index i = 0; i < type.size(); i++) {
          auto* get = builder.makeLocalGet(sourceBase + i, sourceType[i]);
          contents.push_back(builder.makeLocalSet(targetBase + i, get));
        }
        replace(builder.makeBlock(contents));
      }
    }

    // Replaces a tuple.extract from an optimized tuple local with a get of the
    // replacement local for the extracted element.
    void visitTupleExtract(TupleExtract* curr) {
      auto* value = curr->tuple;
      Expression* extraContents = nullptr;

      auto iter = replacedTees.find(value);
      if (iter != replacedTees.end()) {
        // The input to us was a tee that has been replaced. Handle it as in
        // visitLocalSet.
        extraContents = value;
        value = iter->second;
      }

      // An unreachable input has no tuple type to index into below, so leave
      // such an extract as it is.
      auto type = value->type;
      if (type == Type::unreachable) {
        return;
      }

      // Nothing to do if the input is not an optimized tuple local.
      Index sourceBase = getSetOrGetBaseIndex(value);
      if (!sourceBase) {
        return;
      }

      Builder builder(*getModule());
      auto i = curr->index;
      auto* get = builder.makeLocalGet(sourceBase + i, type[i]);
      if (extraContents) {
        // Keep the sets that replaced the tee, and read the element after them.
        replaceCurrent(builder.makeSequence(extraContents, get));
      } else {
        replaceCurrent(get);
      }
    }
  };
};
// Factory for this pass: allocates a new TupleOptimization and returns it
// through the generic Pass interface.
Pass* createTupleOptimizationPass() {
  auto* pass = new TupleOptimization();
  return pass;
}
} // namespace wasm