| /* |
| * Copyright 2016 WebAssembly Community Group participants |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
//
// Removes module elements that are never used: functions and globals,
// which may be imported or not, and function types (which we merge
// and remove if unneeded). The memory and table are also removed when
// nothing exports, uses, or initializes them.
//
| |
| |
| #include <memory> |
| |
| #include "wasm.h" |
| #include "pass.h" |
| #include "ir/utils.h" |
| #include "asm_v_wasm.h" |
| |
| namespace wasm { |
| |
// The kinds of module elements whose reachability this pass tracks.
enum class ModuleElementKind {
  // A function (defined in the module or imported).
  Function,
  // A global (defined in the module or imported).
  Global
};
| |
| typedef std::pair<ModuleElementKind, Name> ModuleElement; |
| |
| // Finds reachabilities |
| |
| struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> { |
| Module* module; |
| std::vector<ModuleElement> queue; |
| std::set<ModuleElement> reachable; |
| bool usesMemory = false; |
| bool usesTable = false; |
| |
| ReachabilityAnalyzer(Module* module, const std::vector<ModuleElement>& roots) : module(module) { |
| queue = roots; |
| // Globals used in memory/table init expressions are also roots |
| for (auto& segment : module->memory.segments) { |
| walk(segment.offset); |
| } |
| for (auto& segment : module->table.segments) { |
| walk(segment.offset); |
| } |
| // main loop |
| while (queue.size()) { |
| auto& curr = queue.back(); |
| queue.pop_back(); |
| if (reachable.count(curr) == 0) { |
| reachable.insert(curr); |
| if (curr.first == ModuleElementKind::Function) { |
| // if not an import, walk it |
| auto* func = module->getFunctionOrNull(curr.second); |
| if (func) { |
| walk(func->body); |
| } |
| } else { |
| // if not imported, it has an init expression we need to walk |
| auto* glob = module->getGlobalOrNull(curr.second); |
| if (glob) { |
| walk(glob->init); |
| } |
| } |
| } |
| } |
| } |
| |
| void visitCall(Call* curr) { |
| if (reachable.count(ModuleElement(ModuleElementKind::Function, curr->target)) == 0) { |
| queue.emplace_back(ModuleElementKind::Function, curr->target); |
| } |
| } |
| void visitCallImport(CallImport* curr) { |
| if (reachable.count(ModuleElement(ModuleElementKind::Function, curr->target)) == 0) { |
| queue.emplace_back(ModuleElementKind::Function, curr->target); |
| } |
| } |
| void visitCallIndirect(CallIndirect* curr) { |
| usesTable = true; |
| } |
| |
| void visitGetGlobal(GetGlobal* curr) { |
| if (reachable.count(ModuleElement(ModuleElementKind::Global, curr->name)) == 0) { |
| queue.emplace_back(ModuleElementKind::Global, curr->name); |
| } |
| } |
| void visitSetGlobal(SetGlobal* curr) { |
| if (reachable.count(ModuleElement(ModuleElementKind::Global, curr->name)) == 0) { |
| queue.emplace_back(ModuleElementKind::Global, curr->name); |
| } |
| } |
| |
| void visitLoad(Load* curr) { |
| usesMemory = true; |
| } |
| void visitStore(Store* curr) { |
| usesMemory = true; |
| } |
| void visitAtomicCmpxchg(AtomicCmpxchg* curr) { |
| usesMemory = true; |
| } |
| void visitAtomicRMW(AtomicRMW* curr) { |
| usesMemory = true; |
| } |
| void visitAtomicWait(AtomicWait* curr) { |
| usesMemory = true; |
| } |
| void visitAtomicWake(AtomicWake* curr) { |
| usesMemory = true; |
| } |
| void visitHost(Host* curr) { |
| if (curr->op == CurrentMemory || curr->op == GrowMemory) { |
| usesMemory = true; |
| } |
| } |
| }; |
| |
| // Finds function type usage |
| |
| struct FunctionTypeAnalyzer : public PostWalker<FunctionTypeAnalyzer> { |
| std::vector<Import*> functionImports; |
| std::vector<Function*> functions; |
| std::vector<CallIndirect*> indirectCalls; |
| |
| void visitImport(Import* curr) { |
| if (curr->kind == ExternalKind::Function && curr->functionType.is()) { |
| functionImports.push_back(curr); |
| } |
| } |
| |
| void visitFunction(Function* curr) { |
| if (curr->type.is()) { |
| functions.push_back(curr); |
| } |
| } |
| |
| void visitCallIndirect(CallIndirect* curr) { |
| indirectCalls.push_back(curr); |
| } |
| }; |
| |
| struct RemoveUnusedModuleElements : public Pass { |
| void run(PassRunner* runner, Module* module) override { |
| optimizeGlobalsAndFunctions(module); |
| optimizeFunctionTypes(module); |
| } |
| |
| void optimizeGlobalsAndFunctions(Module* module) { |
| std::vector<ModuleElement> roots; |
| // Module start is a root. |
| if (module->start.is()) { |
| auto startFunction = module->getFunction(module->start); |
| // Can be skipped if the start function is empty. |
| if (startFunction->body->is<Nop>()) { |
| module->start.clear(); |
| } else { |
| roots.emplace_back(ModuleElementKind::Function, module->start); |
| } |
| } |
| // Exports are roots. |
| bool exportsMemory = false; |
| bool exportsTable = false; |
| for (auto& curr : module->exports) { |
| if (curr->kind == ExternalKind::Function) { |
| roots.emplace_back(ModuleElementKind::Function, curr->value); |
| } else if (curr->kind == ExternalKind::Global) { |
| roots.emplace_back(ModuleElementKind::Global, curr->value); |
| } else if (curr->kind == ExternalKind::Memory) { |
| exportsMemory = true; |
| } else if (curr->kind == ExternalKind::Table) { |
| exportsTable = true; |
| } |
| } |
| // For now, all functions that can be called indirectly are marked as roots. |
| for (auto& segment : module->table.segments) { |
| for (auto& curr : segment.data) { |
| roots.emplace_back(ModuleElementKind::Function, curr); |
| } |
| } |
| // Compute reachability starting from the root set. |
| ReachabilityAnalyzer analyzer(module, roots); |
| // Remove unreachable elements. |
| { |
| auto& v = module->functions; |
| v.erase(std::remove_if(v.begin(), v.end(), [&](const std::unique_ptr<Function>& curr) { |
| return analyzer.reachable.count(ModuleElement(ModuleElementKind::Function, curr->name)) == 0; |
| }), v.end()); |
| } |
| { |
| auto& v = module->globals; |
| v.erase(std::remove_if(v.begin(), v.end(), [&](const std::unique_ptr<Global>& curr) { |
| return analyzer.reachable.count(ModuleElement(ModuleElementKind::Global, curr->name)) == 0; |
| }), v.end()); |
| } |
| { |
| auto& v = module->imports; |
| v.erase(std::remove_if(v.begin(), v.end(), [&](const std::unique_ptr<Import>& curr) { |
| if (curr->kind == ExternalKind::Function) { |
| return analyzer.reachable.count(ModuleElement(ModuleElementKind::Function, curr->name)) == 0; |
| } else if (curr->kind == ExternalKind::Global) { |
| return analyzer.reachable.count(ModuleElement(ModuleElementKind::Global, curr->name)) == 0; |
| } |
| return false; |
| }), v.end()); |
| } |
| module->updateMaps(); |
| // Handle the memory and table |
| if (!exportsMemory && !analyzer.usesMemory && module->memory.segments.empty()) { |
| module->memory.exists = false; |
| module->memory.imported = false; |
| module->memory.initial = 0; |
| module->memory.max = 0; |
| removeImport(ExternalKind::Memory, module); |
| } |
| if (!exportsTable && !analyzer.usesTable && module->table.segments.empty()) { |
| module->table.exists = false; |
| module->table.imported = false; |
| module->table.initial = 0; |
| module->table.max = 0; |
| removeImport(ExternalKind::Table, module); |
| } |
| } |
| |
| void removeImport(ExternalKind kind, Module* module) { |
| auto& v = module->imports; |
| v.erase(std::remove_if(v.begin(), v.end(), [&](const std::unique_ptr<Import>& curr) { |
| return curr->kind == kind; |
| }), v.end()); |
| } |
| |
| void optimizeFunctionTypes(Module* module) { |
| FunctionTypeAnalyzer analyzer; |
| analyzer.walkModule(module); |
| // maps each string signature to a single canonical function type |
| std::unordered_map<std::string, FunctionType*> canonicals; |
| std::unordered_set<FunctionType*> needed; |
| auto canonicalize = [&](Name name) { |
| if (!name.is()) return name; |
| FunctionType* type = module->getFunctionType(name); |
| auto sig = getSig(type); |
| auto iter = canonicals.find(sig); |
| if (iter == canonicals.end()) { |
| needed.insert(type); |
| canonicals[sig] = type; |
| return type->name; |
| } else { |
| return iter->second->name; |
| } |
| }; |
| // canonicalize all uses of function types |
| for (auto* import : analyzer.functionImports) { |
| import->functionType = canonicalize(import->functionType); |
| } |
| for (auto* func : analyzer.functions) { |
| func->type = canonicalize(func->type); |
| } |
| for (auto* call : analyzer.indirectCalls) { |
| call->fullType = canonicalize(call->fullType); |
| } |
| // remove no-longer used types |
| module->functionTypes.erase(std::remove_if(module->functionTypes.begin(), module->functionTypes.end(), [&needed](std::unique_ptr<FunctionType>& type) { |
| return needed.count(type.get()) == 0; |
| }), module->functionTypes.end()); |
| module->updateMaps(); |
| } |
| }; |
| |
| Pass* createRemoveUnusedModuleElementsPass() { |
| return new RemoveUnusedModuleElements(); |
| } |
| |
| } // namespace wasm |