/*
* Copyright 2023 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <unordered_map>
#include <unordered_set>
#include "ir/branch-utils.h"
#include "ir/module-utils.h"
#include "ir/subtypes.h"
#include "ir/type-updating.h"
#include "ir/utils.h"
#include "pass.h"
#include "support/unique_deferring_queue.h"
#include "wasm-traversal.h"
#include "wasm-type.h"
#include "wasm.h"
// Compute and use the minimal subtype relation required to maintain module
// validity and behavior. This minimal relation will be a subset of the original
// subtype relation. Start by walking the IR and collecting pairs of types that
// need to be in the subtype relation for each expression to validate. For
// example, a local.set requires that the type of its operand be a subtype of
// the local's type. Casts do not generate subtypings at this point because it
// is not necessary for the cast target to be a subtype of the cast source for
// the cast to validate.
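//
// For instance, a local.set like the following (hypothetical types $sub and
// $super) generates the single requirement that $sub remain a subtype of
// $super:
//
//   (local $l (ref null $super))
//   (local.set $l (struct.new $sub))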
//
// From that initial subtype relation, we then start finding new subtypings that
// are required by the subtypings we have found already. These transitively
// required subtypings come from two sources.
//
// The first source is type definitions. Consider these type definitions:
//
//   (type $A (sub (struct (field (ref $X)))))
//   (type $B (sub $A (struct (field (ref $Y)))))
//
// If we have determined that $B must remain a subtype of $A, then we know that
// $Y must remain a subtype of $X as well, since the type definitions would not
// be valid otherwise. Similarly, knowing that $Y must remain a subtype of $X
// may transitively require other subtypings as well based on their type
// definitions.
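//
// For example, if $X and $Y are themselves struct types (hypothetical
// definitions):
//
//   (type $X (sub (struct (field (ref $P)))))
//   (type $Y (sub $X (struct (field (ref $Q)))))
//
// then keeping $Y <: $X in turn requires keeping $Q <: $P.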
//
// The second source of transitive subtyping requirements is casts. Although
// casting from one type to another does not necessarily require that those
// types are related, we do need to make sure that we do not change the
// behavior of casts by removing subtype relationships they might observe. For
// example, consider this module:
//
//   (module
//     ;; original subtyping: $bot <: $mid <: $top
//     (type $top (sub (struct)))
//     (type $mid (sub $top (struct)))
//     (type $bot (sub $mid (struct)))
//
//     (func $f
//       (local $top (ref $top))
//       (local $mid (ref $mid))
//
//       ;; Requires $bot <: $top
//       (local.set $top (struct.new $bot))
//
//       ;; Cast $top to $mid
//       (local.set $mid (ref.cast (ref $mid) (local.get $top)))
//     )
//   )
//
// The only subtype relation directly required by the IR for this module is $bot
// <: $top. However, if we optimized the module so that $bot <: $top was the
// only subtype relation, we would change the behavior of the cast. In the
// original module, a value of type (ref $bot) is cast to (ref $mid). The cast
// succeeds because in the original module, $bot <: $mid. If we optimize so that
// we have $bot <: $top and no other subtypings, though, the cast will fail
// because the value of type (ref $bot) no longer inhabits (ref $mid). To
// prevent the cast's behavior from changing, we need to ensure that $bot <:
// $mid.
//
// The set of subtyping requirements generated by a cast from $src to $dest is
// that for every known remaining subtype $v of $src, if $v <: $dest in the
// original module, then $v <: $dest in the optimized module. In other words,
// for every type $v of values we know can flow into the cast, if the cast would
// have succeeded for values of type $v before, then we know the cast must
// continue to succeed for values of type $v. These requirements arising from
// casts can also generate transitive requirements because we learn about new
// types of values that can flow into casts as we learn about new subtypes of
// cast sources.
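//
// Schematically, writing <:new for the relation being computed and <:orig for
// the original subtype relation, the requirement is:
//
//   for each cast (src, dest):
//     for each type v with v <:new src:
//       if v <:orig dest, then require v <:new dest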
//
// Starting with the initial subtype relation determined by walking the IR,
// repeatedly search for new subtypings by analyzing type definitions and casts
// in lock step until we reach a fixed point. The result is the minimal subtype
// relation that preserves module validity and behavior, short of a more
// precise analysis of the types that might flow into each cast.
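//
// In outline, the pass below proceeds as follows:
//
//   1. Record required subtypings from public types and from each expression.
//   2. Until a fixed point is reached:
//      a. Propagate requirements through the type definitions.
//      b. Re-analyze casts against the updated relation.
//   3. Rewrite the type definitions using the minimal supertypes found.
//   4. ReFinalize, since cast types may be refinable once their source and
//      target types are no longer related.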
namespace wasm {
namespace {
struct Unsubtyping
: WalkerPass<ControlFlowWalker<Unsubtyping, OverriddenVisitor<Unsubtyping>>> {
// The new set of supertype relations.
std::unordered_map<HeapType, HeapType> supertypes;
// Map from cast source types to their destinations.
std::unordered_map<HeapType, std::unordered_set<HeapType>> castTypes;
// The set of subtypes that need to have their type definitions analyzed to
// transitively find other subtype relations they depend on. We add to it
// every time we find a new subtype relationship we need to keep.
UniqueDeferredQueue<HeapType> work;
void run(Module* wasm) override {
if (!wasm->features.hasGC()) {
return;
}
analyzePublicTypes(*wasm);
walkModule(wasm);
analyzeTransitiveDependencies();
optimizeTypes(*wasm);
// Cast types may be refinable if their source and target types are no
// longer related. TODO: Experiment with running this only after checking
// whether it is necessary.
ReFinalize().run(getPassRunner(), wasm);
}
// Note that sub must remain a subtype of super.
void noteSubtype(HeapType sub, HeapType super) {
if (sub == super || sub.isBottom() || super.isBottom()) {
return;
}
auto [it, inserted] = supertypes.insert({sub, super});
if (inserted) {
work.push(sub);
// TODO: Incrementally check all subtypes (inclusive) of sub against super
// and all its supertypes if we have already analyzed casts.
return;
}
// We already had a recorded supertype. The new supertype might be deeper,
// shallower, or identical to the old supertype.
auto oldSuper = it->second;
if (super == oldSuper) {
return;
}
// There are two different supertypes, but each type can only have a single
    // direct supertype, so the supertype chain cannot fork and one of the
// supertypes must be a supertype of the other. Recursively record that
// relationship as well.
if (HeapType::isSubType(super, oldSuper)) {
// sub <: super <: oldSuper
it->second = super;
work.push(sub);
// TODO: Incrementally check all subtypes (inclusive) of sub against super
// if we have already analyzed casts.
noteSubtype(super, oldSuper);
} else {
// sub <: oldSuper <: super
noteSubtype(oldSuper, super);
}
}
void noteSubtype(Type sub, Type super) {
if (sub.isTuple()) {
assert(super.isTuple() && sub.size() == super.size());
for (size_t i = 0, size = sub.size(); i < size; ++i) {
noteSubtype(sub[i], super[i]);
}
return;
}
if (!sub.isRef() || !super.isRef()) {
return;
}
noteSubtype(sub.getHeapType(), super.getHeapType());
}
void noteCast(HeapType src, HeapType dest) {
if (src == dest || dest.isBottom()) {
return;
}
assert(HeapType::isSubType(dest, src));
castTypes[src].insert(dest);
}
void noteCast(Type src, Type dest) {
assert(!src.isTuple() && !dest.isTuple());
if (src == Type::unreachable) {
return;
}
assert(src.isRef() && dest.isRef());
noteCast(src.getHeapType(), dest.getHeapType());
}
void analyzePublicTypes(Module& wasm) {
// We cannot change supertypes for anything public.
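    // Their definitions, including their declared supertypes, are part of
    // their identity and are observable from outside the module.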
for (auto type : ModuleUtils::getPublicHeapTypes(wasm)) {
if (auto super = type.getDeclaredSuperType()) {
noteSubtype(type, *super);
}
}
}
void analyzeTransitiveDependencies() {
// While we have found new subtypings and have not reached a fixed point...
while (!work.empty()) {
// Subtype relationships that we are keeping might depend on other subtype
// relationships that we are not yet planning to keep. Transitively find
// all the relationships we need to keep all our type definitions valid.
while (!work.empty()) {
auto type = work.pop();
auto super = supertypes.at(type);
if (super.isBasic()) {
continue;
}
if (type.isStruct()) {
const auto& fields = type.getStruct().fields;
const auto& superFields = super.getStruct().fields;
for (size_t i = 0, size = superFields.size(); i < size; ++i) {
noteSubtype(fields[i].type, superFields[i].type);
}
} else if (type.isArray()) {
auto elem = type.getArray().element;
noteSubtype(elem.type, super.getArray().element.type);
} else {
assert(type.isSignature());
auto sig = type.getSignature();
auto superSig = super.getSignature();
noteSubtype(superSig.params, sig.params);
noteSubtype(sig.results, superSig.results);
}
}
// Analyze all casts at once.
// TODO: This is expensive. Analyze casts incrementally after we
// initially analyze them.
analyzeCasts();
}
}
void analyzeCasts() {
// For each cast (src, dest) pair, any type that remains a subtype of src
// (meaning its values can inhabit locations typed src) and that was
// originally a subtype of dest (meaning its values would have passed the
// cast) should remain a subtype of dest so that its values continue to pass
// the cast.
//
// For every type, walk up its new supertype chain to find cast sources and
// compare against their associated cast destinations.
for (auto it = supertypes.begin(); it != supertypes.end(); ++it) {
auto type = it->first;
for (auto srcIt = it; srcIt != supertypes.end();
srcIt = supertypes.find(srcIt->second)) {
auto src = srcIt->second;
auto destsIt = castTypes.find(src);
if (destsIt == castTypes.end()) {
continue;
}
for (auto dest : destsIt->second) {
if (HeapType::isSubType(type, dest)) {
noteSubtype(type, dest);
}
}
}
}
}
void optimizeTypes(Module& wasm) {
struct Rewriter : GlobalTypeRewriter {
Unsubtyping& parent;
Rewriter(Unsubtyping& parent, Module& wasm)
: GlobalTypeRewriter(wasm), parent(parent) {}
std::optional<HeapType> getDeclaredSuperType(HeapType type) override {
if (auto it = parent.supertypes.find(type);
it != parent.supertypes.end() && !it->second.isBasic()) {
return it->second;
}
return std::nullopt;
}
};
Rewriter(*this, wasm).update();
}
void doWalkModule(Module* wasm) {
// Visit the functions in parallel, filling in `supertypes` and `castTypes`
// on separate instances which will later be merged.
ModuleUtils::ParallelFunctionAnalysis<Unsubtyping> analysis(
*wasm, [&](Function* func, Unsubtyping& unsubtyping) {
if (!func->imported()) {
unsubtyping.walkFunctionInModule(func, wasm);
}
});
// Collect the results from the functions.
for (auto& [_, unsubtyping] : analysis.map) {
for (auto [sub, super] : unsubtyping.supertypes) {
noteSubtype(sub, super);
}
for (auto& [src, dests] : unsubtyping.castTypes) {
for (auto dest : dests) {
noteCast(src, dest);
}
}
}
// Collect constraints from top-level items.
for (auto& global : wasm->globals) {
visitGlobal(global.get());
}
for (auto& seg : wasm->elementSegments) {
visitElementSegment(seg.get());
}
// Visit the rest of the code that is not in functions.
walkModuleCode(wasm);
}
void visitFunction(Function* func) {
if (func->body) {
noteSubtype(func->body->type, func->getResults());
}
}
void visitGlobal(Global* global) {
if (global->init) {
noteSubtype(global->init->type, global->type);
}
}
void visitElementSegment(ElementSegment* seg) {
if (seg->offset) {
noteSubtype(seg->type, getModule()->getTable(seg->table)->type);
}
for (auto init : seg->data) {
noteSubtype(init->type, seg->type);
}
}
void visitNop(Nop* curr) {}
void visitBlock(Block* curr) {
if (!curr->list.empty()) {
noteSubtype(curr->list.back()->type, curr->type);
}
}
void visitIf(If* curr) {
if (curr->ifFalse) {
noteSubtype(curr->ifTrue->type, curr->type);
noteSubtype(curr->ifFalse->type, curr->type);
}
}
void visitLoop(Loop* curr) { noteSubtype(curr->body->type, curr->type); }
void visitBreak(Break* curr) {
if (curr->value) {
noteSubtype(curr->value->type, findBreakTarget(curr->name)->type);
}
}
void visitSwitch(Switch* curr) {
if (curr->value) {
for (auto name : BranchUtils::getUniqueTargets(curr)) {
noteSubtype(curr->value->type, findBreakTarget(name)->type);
}
}
}
template<typename T> void handleCall(T* curr, Signature sig) {
assert(curr->operands.size() == sig.params.size());
for (size_t i = 0, size = sig.params.size(); i < size; ++i) {
noteSubtype(curr->operands[i]->type, sig.params[i]);
}
if (curr->isReturn) {
noteSubtype(sig.results, getFunction()->getResults());
}
}
void visitCall(Call* curr) {
handleCall(curr, getModule()->getFunction(curr->target)->getSig());
}
void visitCallIndirect(CallIndirect* curr) {
handleCall(curr, curr->heapType.getSignature());
auto* table = getModule()->getTable(curr->table);
auto tableType = table->type.getHeapType();
if (HeapType::isSubType(tableType, curr->heapType)) {
// Unlike other casts, where cast targets are always subtypes of cast
// sources, call_indirect target types may be supertypes of their source
// table types. In this case, the cast will always succeed, but only if we
// keep the types related.
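      // For example, if the table has type (ref null $f) and the call is
      // annotated with a supertype $g of $f, every function in the table has
      // a type that is a subtype of $g, so the cast always succeeds, but only
      // while $f remains a subtype of $g.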
noteSubtype(tableType, curr->heapType);
} else if (HeapType::isSubType(curr->heapType, tableType)) {
noteCast(tableType, curr->heapType);
} else {
// The types are unrelated and the cast will fail. We can keep the types
// unrelated.
}
}
void visitLocalGet(LocalGet* curr) {}
void visitLocalSet(LocalSet* curr) {
noteSubtype(curr->value->type, getFunction()->getLocalType(curr->index));
}
void visitGlobalGet(GlobalGet* curr) {}
void visitGlobalSet(GlobalSet* curr) {
noteSubtype(curr->value->type, getModule()->getGlobal(curr->name)->type);
}
void visitLoad(Load* curr) {}
void visitStore(Store* curr) {}
void visitAtomicRMW(AtomicRMW* curr) {}
void visitAtomicCmpxchg(AtomicCmpxchg* curr) {}
void visitAtomicWait(AtomicWait* curr) {}
void visitAtomicNotify(AtomicNotify* curr) {}
void visitAtomicFence(AtomicFence* curr) {}
void visitSIMDExtract(SIMDExtract* curr) {}
void visitSIMDReplace(SIMDReplace* curr) {}
void visitSIMDShuffle(SIMDShuffle* curr) {}
void visitSIMDTernary(SIMDTernary* curr) {}
void visitSIMDShift(SIMDShift* curr) {}
void visitSIMDLoad(SIMDLoad* curr) {}
void visitSIMDLoadStoreLane(SIMDLoadStoreLane* curr) {}
void visitMemoryInit(MemoryInit* curr) {}
void visitDataDrop(DataDrop* curr) {}
void visitMemoryCopy(MemoryCopy* curr) {}
void visitMemoryFill(MemoryFill* curr) {}
void visitConst(Const* curr) {}
void visitUnary(Unary* curr) {}
void visitBinary(Binary* curr) {}
void visitSelect(Select* curr) {
noteSubtype(curr->ifTrue->type, curr->type);
noteSubtype(curr->ifFalse->type, curr->type);
}
void visitDrop(Drop* curr) {}
void visitReturn(Return* curr) {
if (curr->value) {
noteSubtype(curr->value->type, getFunction()->getResults());
}
}
void visitMemorySize(MemorySize* curr) {}
void visitMemoryGrow(MemoryGrow* curr) {}
void visitUnreachable(Unreachable* curr) {}
void visitPop(Pop* curr) {}
void visitRefNull(RefNull* curr) {}
void visitRefIsNull(RefIsNull* curr) {}
void visitRefFunc(RefFunc* curr) {}
void visitRefEq(RefEq* curr) {}
void visitTableGet(TableGet* curr) {}
void visitTableSet(TableSet* curr) {
noteSubtype(curr->value->type, getModule()->getTable(curr->table)->type);
}
void visitTableSize(TableSize* curr) {}
void visitTableGrow(TableGrow* curr) {}
void visitTableFill(TableFill* curr) {
noteSubtype(curr->value->type, getModule()->getTable(curr->table)->type);
}
void visitTry(Try* curr) {
noteSubtype(curr->body->type, curr->type);
for (auto* body : curr->catchBodies) {
noteSubtype(body->type, curr->type);
}
}
void visitThrow(Throw* curr) {
Type params = getModule()->getTag(curr->tag)->sig.params;
assert(params.size() == curr->operands.size());
for (size_t i = 0, size = curr->operands.size(); i < size; ++i) {
noteSubtype(curr->operands[i]->type, params[i]);
}
}
void visitRethrow(Rethrow* curr) {}
void visitTupleMake(TupleMake* curr) {}
void visitTupleExtract(TupleExtract* curr) {}
void visitRefI31(RefI31* curr) {}
void visitI31Get(I31Get* curr) {}
void visitCallRef(CallRef* curr) {
if (!curr->target->type.isSignature()) {
return;
}
handleCall(curr, curr->target->type.getHeapType().getSignature());
}
void visitRefTest(RefTest* curr) {
noteCast(curr->ref->type, curr->castType);
}
void visitRefCast(RefCast* curr) { noteCast(curr->ref->type, curr->type); }
void visitBrOn(BrOn* curr) {
if (curr->op == BrOnCast || curr->op == BrOnCastFail) {
noteCast(curr->ref->type, curr->castType);
}
noteSubtype(curr->getSentType(), findBreakTarget(curr->name)->type);
}
void visitStructNew(StructNew* curr) {
if (!curr->type.isStruct() || curr->isWithDefault()) {
return;
}
const auto& fields = curr->type.getHeapType().getStruct().fields;
assert(fields.size() == curr->operands.size());
for (size_t i = 0, size = fields.size(); i < size; ++i) {
noteSubtype(curr->operands[i]->type, fields[i].type);
}
}
void visitStructGet(StructGet* curr) {}
void visitStructSet(StructSet* curr) {
if (!curr->ref->type.isStruct()) {
return;
}
const auto& fields = curr->ref->type.getHeapType().getStruct().fields;
noteSubtype(curr->value->type, fields[curr->index].type);
}
void visitArrayNew(ArrayNew* curr) {
if (!curr->type.isArray() || curr->isWithDefault()) {
return;
}
auto array = curr->type.getHeapType().getArray();
noteSubtype(curr->init->type, array.element.type);
}
void visitArrayNewData(ArrayNewData* curr) {}
void visitArrayNewElem(ArrayNewElem* curr) {
if (!curr->type.isArray()) {
return;
}
auto array = curr->type.getHeapType().getArray();
auto* seg = getModule()->getElementSegment(curr->segment);
noteSubtype(seg->type, array.element.type);
}
void visitArrayNewFixed(ArrayNewFixed* curr) {
if (!curr->type.isArray()) {
return;
}
auto array = curr->type.getHeapType().getArray();
for (auto* value : curr->values) {
noteSubtype(value->type, array.element.type);
}
}
void visitArrayGet(ArrayGet* curr) {}
void visitArraySet(ArraySet* curr) {
if (!curr->ref->type.isArray()) {
return;
}
auto array = curr->ref->type.getHeapType().getArray();
noteSubtype(curr->value->type, array.element.type);
}
void visitArrayLen(ArrayLen* curr) {}
void visitArrayCopy(ArrayCopy* curr) {
if (!curr->srcRef->type.isArray() || !curr->destRef->type.isArray()) {
return;
}
auto src = curr->srcRef->type.getHeapType().getArray();
auto dest = curr->destRef->type.getHeapType().getArray();
noteSubtype(src.element.type, dest.element.type);
}
void visitArrayFill(ArrayFill* curr) {
if (!curr->ref->type.isArray()) {
return;
}
auto array = curr->ref->type.getHeapType().getArray();
noteSubtype(curr->value->type, array.element.type);
}
void visitArrayInitData(ArrayInitData* curr) {}
void visitArrayInitElem(ArrayInitElem* curr) {
if (!curr->ref->type.isArray()) {
return;
}
auto array = curr->ref->type.getHeapType().getArray();
auto* seg = getModule()->getElementSegment(curr->segment);
noteSubtype(seg->type, array.element.type);
}
void visitRefAs(RefAs* curr) {}
void visitStringNew(StringNew* curr) {}
void visitStringConst(StringConst* curr) {}
void visitStringMeasure(StringMeasure* curr) {}
void visitStringEncode(StringEncode* curr) {}
void visitStringConcat(StringConcat* curr) {}
void visitStringEq(StringEq* curr) {}
void visitStringAs(StringAs* curr) {}
void visitStringWTF8Advance(StringWTF8Advance* curr) {}
void visitStringWTF16Get(StringWTF16Get* curr) {}
void visitStringIterNext(StringIterNext* curr) {}
void visitStringIterMove(StringIterMove* curr) {}
void visitStringSliceWTF(StringSliceWTF* curr) {}
void visitStringSliceIter(StringSliceIter* curr) {}
};
} // anonymous namespace
Pass* createUnsubtypingPass() { return new Unsubtyping(); }
} // namespace wasm