blob: b7970cc7b9f3cc7c3eacb4a9ee6301a3c24c9346 [file] [log] [blame]
/*
 * Copyright 2022 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// Refine uses of locals where possible. For example, consider this:
// (some.operation
// (ref.cast .. (local.get $ref))
// (local.get $ref)
// )
// The second use might as well use the refined/cast value as well:
// (some.operation
// (local.tee $temp
// (ref.cast .. (local.get $ref))
// )
// (local.get $temp)
// )
// This change adds a local but it switches some local.gets to use a local of a
// more refined type. That can help other optimizations later.
// An important example of a pattern this handles is itable calls:
// (call_ref
// (ref.cast $actual.type
// (local.get $object)
// )
// (struct.get $vtable ..
// (ref.cast $vtable
// (struct.get $itable ..
// (local.get $object)
// )
// )
// )
// )
// We cast to the actual type for the |this| parameter, but we technically do
// not need to do so for reading its itable - since the itable may be of a
// generic type, and we cast the vtable afterwards anyhow. But since we cast
// |this|, we can use the cast value for the itable get, which may then lead to
// removing the vtable cast after we refine the itable type. And that can lead
// to devirtualization later.
// Closely related things appear in other passes:
// * SimplifyLocals will find locals already containing a more refined type and
// switch to them. RedundantSetElimination does the same across basic blocks.
// In theory one of them could be extended to also add new locals, and then
// they would be doing something similar to this pass.
// * LocalCSE finds repeated expressions and stores them in locals for use
// later. In theory that pass could be extended to look not for exact copies
// but for equivalent things through a cast, and then it would be doing
// something similar to this pass.
// However, while those other passes could be extended to cover what this pass
// does, we will have further cast-specific optimizations to add, which make
// sense in a new pass anyhow, and things should be simpler overall to keep such
// casts all in one pass, here.
// TODO: Move casts earlier in a basic block as well, at least in traps-never-
// happen mode where we can assume they never fail.
// TODO: Look past individual basic blocks?
// TODO: When looking at local.sets, check fallthroughs/descendants for casts
// instead of just the immediate child.
#include "ir/linear-execution.h"
#include "ir/properties.h"
#include "ir/utils.h"
#include "pass.h"
#include "wasm-builder.h"
#include "wasm.h"
namespace wasm {
namespace {
// Find the best casted versions of local.gets: other local.gets with the same
// value, but cast to a more refined type.
struct BestCastFinder : public LinearExecutionWalker<BestCastFinder> {

  // Pass options; forwarded to Properties::getFallthrough() when looking
  // through a cast to find the value that flows into it.
  PassOptions options;

  // Map local indices to the most refined downcastings of local.gets from those
  // indices.
  //
  // This is tracked in each basic block, and cleared between them.
  std::unordered_map<Index, Expression*> mostCastedGets;

  // Map local indices to the current downcasting of local.set to those indices.
  //
  // Also tracked in each basic block and cleared between them.
  std::unordered_map<Index, Expression*> curCastedSets;

  // For each most-downcasted local.get or local.set, a vector of other
  // local.gets that could be replaced with gets of the downcasted value.
  //
  // This is tracked until the end of the entire function, and contains the
  // information we need to optimize later. That is, entries here are things we
  // want to apply.
  std::unordered_map<Expression*, std::vector<LocalGet*>> lessCastedGets;

  // Forget all per-basic-block state. Judging by its name and by the notes on
  // the two maps above, this runs when execution stops being linear, that is,
  // at a basic block boundary. lessCastedGets is deliberately kept, as it is
  // function-wide.
  static void doNoteNonLinear(BestCastFinder* self, Expression** currp) {
    self->mostCastedGets.clear();
    self->curCastedSets.clear();
  }

  // A write to a local invalidates what we knew about casts of its old value,
  // and may itself provide a new cast value for later gets.
  void visitLocalSet(LocalSet* curr) {
    // Clear any information about this local; it has a new value here.
    mostCastedGets.erase(curr->index);

    // This only checks the immediate child for casts. This should be extended
    // to look deeper for casts.
    if (curr->value->dynCast<RefAs>() || curr->value->dynCast<RefCast>()) {
      curCastedSets[curr->index] = curr->value;
    } else {
      // If the local.set doesn't use a cast, get rid of any old cast
      // information.
      curCastedSets.erase(curr->index);
    }
  }

  // If a more refined value of this local is known in the current basic block
  // (from a cast of an earlier get, or from a cast that was set to the local),
  // record that this get would like to use it.
  void visitLocalGet(LocalGet* curr) {
    auto getIter = mostCastedGets.find(curr->index);
    auto setIter = curCastedSets.find(curr->index);
    if (getIter != mostCastedGets.end()) {
      auto* bestCast = getIter->second;
      if (setIter != curCastedSets.end()) {
        // Always use a cast in local.set if it is equal or better than
        // a local.get since we know it is always before any gets that
        // retrieve the set value from the index.
        if (bestCast->type == setIter->second->type ||
            Type::isSubType(setIter->second->type, bestCast->type)) {
          bestCast = setIter->second;
        }
      }
      if (curr->type != bestCast->type &&
          Type::isSubType(bestCast->type, curr->type)) {
        // The best cast has a more refined type, note that we want to use it.
        lessCastedGets[bestCast].push_back(curr);
      }
    } else if (setIter != curCastedSets.end()) {
      // No cast of a get is known, but the value set to the local was a cast.
      auto* setCast = setIter->second;
      if (curr->type != setCast->type &&
          Type::isSubType(setCast->type, curr->type)) {
        lessCastedGets[setCast].push_back(curr);
      }
    }
  }

  void visitRefAs(RefAs* curr) { handleRefinement(curr); }

  void visitRefCast(RefCast* curr) { handleRefinement(curr); }

  // Note a cast whose fallthrough value is a local.get, keeping for each local
  // index only the most refined such cast seen so far in this basic block.
  void handleRefinement(Expression* curr) {
    auto* fallthrough = Properties::getFallthrough(curr, options, *getModule());
    if (auto* get = fallthrough->dynCast<LocalGet>()) {
      // operator[] inserts a null entry if none exists yet, which is how the
      // "first cast" case is detected below.
      auto*& bestCast = mostCastedGets[get->index];
      if (!bestCast) {
        // This is the first.
        bestCast = curr;
        return;
      }

      // See if we are better than the current best.
      if (curr->type != bestCast->type &&
          Type::isSubType(curr->type, bestCast->type)) {
        bestCast = curr;
      }
    }
  }
};
// Given a set of best casts, apply them: save each best cast in a local and use
// it in the places that want to.
// It is simpler to do this in another pass after BestCastFinder so that we do
// not need to worry about corner cases with invalidation of pointers in things
// we've already walked past.
struct FindingApplier : public PostWalker<FindingApplier> {
  // The finder whose lessCastedGets map lists, per best cast, the local.gets
  // that should be switched over to the cast value.
  BestCastFinder& finder;

  explicit FindingApplier(BestCastFinder& finder) : finder(finder) {}

  void visitRefAs(RefAs* curr) { handleRefinement(curr); }

  void visitRefCast(RefCast* curr) { handleRefinement(curr); }

  // If |curr| was recorded as the best cast for some gets, store its value in
  // a fresh local and retarget those gets to read that local.
  void handleRefinement(Expression* curr) {
    auto iter = finder.lessCastedGets.find(curr);
    if (iter == finder.lessCastedGets.end()) {
      // No get asked to use this cast; leave it untouched.
      return;
    }

    // This expression was the best cast for some gets. Add a new local to
    // store this value, then use it for the gets.
    auto var = Builder::addVar(getFunction(), curr->type);
    auto& gets = iter->second;
    for (auto* get : gets) {
      get->index = var;
      get->type = curr->type;
    }

    // Replace ourselves with a tee.
    replaceCurrent(Builder(*getModule()).makeLocalTee(var, curr, curr->type));
  }
};
} // anonymous namespace
// The pass itself: for each function, find the best casts of local values and
// then rewrite less-refined local.gets to use them.
struct OptimizeCasts : public WalkerPass<PostWalker<OptimizeCasts>> {
  bool isFunctionParallel() override { return true; }

  std::unique_ptr<Pass> create() override {
    return std::make_unique<OptimizeCasts>();
  }

  void doWalkFunction(Function* func) {
    // This pass only does anything when the GC feature is enabled.
    if (!getModule()->features.hasGC()) {
      return;
    }

    // First, find the best casts that we want to use.
    BestCastFinder finder;
    finder.options = getPassOptions();
    finder.walkFunctionInModule(func, getModule());

    if (finder.lessCastedGets.empty()) {
      // Nothing to do.
      return;
    }

    // Apply the requests: use the best casts.
    FindingApplier applier(finder);
    applier.walkFunctionInModule(func, getModule());

    // LocalGet type changes must be propagated.
    ReFinalize().walkFunctionInModule(func, getModule());
  }
};
// Factory function: hands back a freshly allocated OptimizeCasts pass as a
// raw Pass pointer.
Pass* createOptimizeCastsPass() {
  Pass* pass = new OptimizeCasts();
  return pass;
}
} // namespace wasm