| /* |
| * Copyright 2024 WebAssembly Community Group participants |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| // |
| // Utilities for lowering strings into simpler things. |
| // |
| // StringGathering collects all string.const operations and stores them in |
| // globals, avoiding them appearing in code that can run more than once (which |
| // can have overhead in VMs). |
| // |
| // StringLowering does the same, and also replaces those new globals with |
| // imported globals of type externref, for use with the string imports proposal. |
| // |
| // A pass argument allows customizing the module name for string constants: |
| // |
| // --pass-arg=string-constants-module@MODULE_NAME |
| // |
| // Specs: |
| // https://github.com/WebAssembly/stringref/blob/main/proposals/stringref/Overview.md |
| // https://github.com/WebAssembly/js-string-builtins/blob/main/proposals/js-string-builtins/Overview.md |
| // |
| |
| #include <algorithm> |
| |
| #include "ir/module-utils.h" |
| #include "ir/names.h" |
| #include "ir/type-updating.h" |
| #include "ir/utils.h" |
| #include "pass.h" |
| #include "passes/string-utils.h" |
| #include "support/string.h" |
| #include "wasm-builder.h" |
| #include "wasm-type.h" |
| #include "wasm.h" |
| |
| namespace wasm { |
| |
| struct StringGathering : public Pass { |
| // All the strings we found in the module. |
| std::vector<Name> strings; |
| |
| // Pointers to all StringConsts, so that we can replace them. |
| using StringPtrs = std::vector<Expression**>; |
| StringPtrs stringPtrs; |
| |
| // Main entry point. |
| void run(Module* module) override { |
| processModule(module); |
| addGlobals(module); |
| replaceStrings(module); |
| } |
| |
| // Scan the entire wasm to find the relevant strings to populate our global |
| // data structures. |
| void processModule(Module* module) { |
| struct StringWalker : public PostWalker<StringWalker> { |
| StringPtrs& stringPtrs; |
| |
| StringWalker(StringPtrs& stringPtrs) : stringPtrs(stringPtrs) {} |
| |
| void visitStringConst(StringConst* curr) { |
| stringPtrs.push_back(getCurrentPointer()); |
| } |
| }; |
| |
| ModuleUtils::ParallelFunctionAnalysis<StringPtrs> analysis( |
| *module, [&](Function* func, StringPtrs& stringPtrs) { |
| if (!func->imported()) { |
| StringWalker(stringPtrs).walk(func->body); |
| } |
| }); |
| |
| // Also walk the global module code (for simplicity, also add it to the |
| // function map, using a "function" key of nullptr). |
| auto& globalStrings = analysis.map[nullptr]; |
| StringWalker(globalStrings).walkModuleCode(module); |
| |
| // Combine all the strings. |
| std::unordered_set<Name> stringSet; |
| for (auto& [_, currStringPtrs] : analysis.map) { |
| for (auto** stringPtr : currStringPtrs) { |
| stringSet.insert((*stringPtr)->cast<StringConst>()->string); |
| stringPtrs.push_back(stringPtr); |
| } |
| } |
| |
| // Sort the strings for determinism (alphabetically). |
| strings = std::vector<Name>(stringSet.begin(), stringSet.end()); |
| std::sort(strings.begin(), strings.end()); |
| } |
| |
| // For each string, the name of the global that replaces it. |
| std::unordered_map<Name, Name> stringToGlobalName; |
| |
| Type nnstringref = Type(HeapType::string, NonNullable); |
| |
| // Existing globals already in the form we emit can be reused. That is, if |
| // we see |
| // |
| // (global $foo (ref string) (string.const ..)) |
| // |
| // then we can just use that as the global for that string. This avoids |
| // repeated executions of the pass adding more and more globals. |
| // |
| // Any time we reuse a global, we must not modify its body (or else we'd |
| // replace the global that all others read from); we note them here and |
| // avoid them in replaceStrings later to avoid such trampling. |
| std::unordered_set<Expression**> stringPtrsToPreserve; |
| |
| void addGlobals(Module* module) { |
| // The names of the globals that define a string. Such globals may be |
| // referred to by others, and so we will need to sort them, later. |
| std::unordered_set<Name> definingNames; |
| |
| // Find globals to reuse (see comment on stringPtrsToPreserve for context). |
| for (auto& global : module->globals) { |
| if (global->type == nnstringref && !global->imported() && |
| !global->mutable_) { |
| if (auto* stringConst = global->init->dynCast<StringConst>()) { |
| auto& globalName = stringToGlobalName[stringConst->string]; |
| if (!globalName.is()) { |
| // This is the first global for this string, use it. |
| globalName = global->name; |
| stringPtrsToPreserve.insert(&global->init); |
| } |
| } |
| } |
| } |
| |
| Builder builder(*module); |
| for (Index i = 0; i < strings.size(); i++) { |
| auto& globalName = stringToGlobalName[strings[i]]; |
| if (globalName.is()) { |
| // We are reusing a global for this one, with its existing name. |
| definingNames.insert(globalName); |
| continue; |
| } |
| |
| auto& string = strings[i]; |
| // Re-encode from WTF-16 to WTF-8 to make the name easier to read. |
| std::stringstream wtf8; |
| [[maybe_unused]] bool valid = |
| String::convertWTF16ToWTF8(wtf8, string.str); |
| assert(valid); |
| // Then escape it because identifiers must be valid UTF-8. |
| // TODO: Use wtf8.view() and escaped.view() once we have C++20. |
| std::stringstream escaped; |
| String::printEscaped(escaped, wtf8.str()); |
| auto name = Names::getValidGlobalName( |
| *module, std::string("string.const_") + std::string(escaped.str())); |
| globalName = name; |
| definingNames.insert(name); |
| auto* stringConst = builder.makeStringConst(string); |
| auto global = |
| builder.makeGlobal(name, nnstringref, stringConst, Builder::Immutable); |
| module->addGlobal(std::move(global)); |
| } |
| |
| // Sort defining globals to the start, as other global initializers may use |
| // them (and it would be invalid for us to appear after a use). This sort is |
| // a simple way to ensure that we validate, but it may be unoptimal (we |
| // leave that for reorder-globals). |
| std::stable_sort( |
| module->globals.begin(), |
| module->globals.end(), |
| [&](const std::unique_ptr<Global>& a, const std::unique_ptr<Global>& b) { |
| return definingNames.count(a->name) && !definingNames.count(b->name); |
| }); |
| } |
| |
| void replaceStrings(Module* module) { |
| Builder builder(*module); |
| for (auto** stringPtr : stringPtrs) { |
| if (stringPtrsToPreserve.count(stringPtr)) { |
| continue; |
| } |
| auto* stringConst = (*stringPtr)->cast<StringConst>(); |
| auto globalName = stringToGlobalName[stringConst->string]; |
| *stringPtr = builder.makeGlobalGet(globalName, nnstringref); |
| } |
| } |
| }; |
| |
| struct StringLowering : public StringGathering { |
| // If true, then encode well-formed strings as (import "'" "string...") |
| // instead of emitting them into the JSON custom section. |
| bool useMagicImports; |
| |
| // Whether to throw a fatal error on non-UTF8 strings that would not be able |
| // to use the "magic import" mechanism. Only usable in conjunction with magic |
| // imports. |
| bool assertUTF8; |
| |
| StringLowering(bool useMagicImports = false, bool assertUTF8 = false) |
| : useMagicImports(useMagicImports), assertUTF8(assertUTF8) { |
| // If we are asserting valid UTF-8, we must be using magic imports. |
| assert(!assertUTF8 || useMagicImports); |
| } |
| |
| void run(Module* module) override { |
| if (!module->features.has(FeatureSet::Strings)) { |
| return; |
| } |
| |
| // First, run the gathering operation so all string.consts are in one place. |
| StringGathering::run(module); |
| |
| // Remove all HeapType::string etc. in favor of externref. |
| updateTypes(module); |
| |
| // Lower the string.const globals into imports. |
| makeImports(module); |
| |
| // Replace string.* etc. operations with imported ones. |
| replaceInstructions(module); |
| |
| // ReFinalize to apply all the above changes. |
| ReFinalize().run(getPassRunner(), module); |
| |
| // Disable the feature here after we lowered everything away. |
| module->features.disable(FeatureSet::Strings); |
| } |
| |
| void makeImports(Module* module) { |
| Name stringConstsModule = |
| getArgumentOrDefault("string-constants-module", WasmStringConstsModule); |
| Index jsonImportIndex = 0; |
| std::stringstream json; |
| bool first = true; |
| for (auto& global : module->globals) { |
| if (global->init) { |
| if (auto* c = global->init->dynCast<StringConst>()) { |
| std::stringstream utf8; |
| if (useMagicImports && |
| String::convertUTF16ToUTF8(utf8, c->string.str)) { |
| global->module = stringConstsModule; |
| global->base = Name(utf8.str()); |
| } else { |
| if (assertUTF8) { |
| std::stringstream escaped; |
| String::printEscaped(escaped, utf8.str()); |
| Fatal() << "Cannot lower non-UTF-16 string " << escaped.str() |
| << '\n'; |
| } |
| global->module = "string.const"; |
| global->base = std::to_string(jsonImportIndex); |
| if (first) { |
| first = false; |
| } else { |
| json << ','; |
| } |
| String::printEscapedJSON(json, c->string.str); |
| jsonImportIndex++; |
| } |
| global->init = nullptr; |
| } |
| } |
| } |
| |
| auto jsonString = json.str(); |
| if (!jsonString.empty()) { |
| // If we are asserting UTF8, then we shouldn't be generating any JSON. |
| assert(!assertUTF8); |
| // Add a custom section with the JSON. |
| auto str = '[' + jsonString + ']'; |
| auto vec = std::vector<char>(str.begin(), str.end()); |
| module->customSections.emplace_back( |
| CustomSection{"string.consts", std::move(vec)}); |
| } |
| } |
| |
| // Common types used in imports. |
| Type nullArray16 = Type(HeapTypes::getMutI16Array(), Nullable); |
| Type nullExt = Type(HeapType::ext, Nullable); |
| Type nnExt = Type(HeapType::ext, NonNullable); |
| |
| void updateTypes(Module* module) { |
| TypeMapper::TypeUpdates updates; |
| |
| // TypeMapper will not handle public types, but we do want to modify them as |
| // well: we are modifying the public ABI here. We can't simply tell |
| // TypeMapper to consider them private, as then they'd end up in the new big |
| // rec group with the private types (and as they are public, that would make |
| // the entire rec group public, and all types in the module with it). |
| // Instead, manually handle singleton-rec groups of function types. This |
| // keeps them at size 1, as expected, and handles the cases of function |
| // imports and exports. If we need more (non-function types, non-singleton |
| // rec groups, etc.) then more work will be necessary TODO |
| // |
| // Note that we do this before TypeMapper, which allows it to then fix up |
| // things like the types of parameters (which depend on the type of the |
| // function, which must be modified either in TypeMapper - but as just |
| // explained we cannot do that - or before it, which is what we do here). |
| for (auto& func : module->functions) { |
| if (func->type.getHeapType().getRecGroup().size() != 1 || |
| !func->type.getFeatures().hasStrings()) { |
| continue; |
| } |
| |
| // Fix up the stringrefs in this type that uses strings and is in a |
| // singleton rec group. |
| std::vector<Type> params, results; |
| auto fix = [](Type t) { |
| if (t.isRef() && t.getHeapType().isMaybeShared(HeapType::string)) { |
| auto share = t.getHeapType().getShared(); |
| t = Type(HeapTypes::ext.getBasic(share), t.getNullability()); |
| } |
| return t; |
| }; |
| for (auto param : func->type.getHeapType().getSignature().params) { |
| params.push_back(fix(param)); |
| } |
| for (auto result : func->type.getHeapType().getSignature().results) { |
| results.push_back(fix(result)); |
| } |
| |
| // In addition to doing the update, mark it in the map of updates for |
| // TypeMapper, so RefFuncs with this type get updated. |
| auto old = func->type; |
| func->type = func->type.with(Signature(params, results)); |
| updates[old.getHeapType()] = func->type.getHeapType(); |
| } |
| |
| // Strings turn into externref. |
| updates[HeapType::string] = HeapType::ext; |
| |
| TypeMapper(*module, updates).map(); |
| } |
| |
| // Imported string functions. |
| Name fromCharCodeArrayImport; |
| Name intoCharCodeArrayImport; |
| Name fromCodePointImport; |
| Name concatImport; |
| Name equalsImport; |
| Name testImport; |
| Name compareImport; |
| Name lengthImport; |
| Name charCodeAtImport; |
| Name substringImport; |
| |
| // Creates an imported string function, returning its name (which is equal to |
| // the true name of the import, if there is no conflict). |
| Name addImport(Module* module, Name trueName, Type params, Type results) { |
| auto name = Names::getValidFunctionName(*module, trueName); |
| auto sig = Signature(params, results); |
| Builder builder(*module); |
| auto* func = module->addFunction( |
| builder.makeFunction(name, Type(sig, NonNullable, Inexact), {})); |
| func->module = WasmStringsModule; |
| func->base = trueName; |
| return name; |
| } |
| |
| void replaceInstructions(Module* module) { |
| // Add all the possible imports up front, to avoid adding them during |
| // parallel work. Optimizations can remove unneeded ones later. |
| |
| // string.fromCharCodeArray: array, start, end -> ext |
| fromCharCodeArrayImport = addImport( |
| module, "fromCharCodeArray", {nullArray16, Type::i32, Type::i32}, nnExt); |
| // string.fromCodePoint: codepoint -> ext |
| fromCodePointImport = addImport(module, "fromCodePoint", Type::i32, nnExt); |
| // string.concat: string, string -> string |
| concatImport = addImport(module, "concat", {nullExt, nullExt}, nnExt); |
| // string.intoCharCodeArray: string, array, start -> num written |
| intoCharCodeArrayImport = addImport(module, |
| "intoCharCodeArray", |
| {nullExt, nullArray16, Type::i32}, |
| Type::i32); |
| // string.equals: string, string -> i32 |
| equalsImport = addImport(module, "equals", {nullExt, nullExt}, Type::i32); |
| // string.test: externref -> i32 |
| testImport = addImport(module, "test", {nullExt}, Type::i32); |
| // string.compare: string, string -> i32 |
| compareImport = addImport(module, "compare", {nullExt, nullExt}, Type::i32); |
| // string.length: string -> i32 |
| lengthImport = addImport(module, "length", nullExt, Type::i32); |
| // string.codePointAt: string, offset -> i32 |
| charCodeAtImport = |
| addImport(module, "charCodeAt", {nullExt, Type::i32}, Type::i32); |
| // string.substring: string, start, end -> string |
| substringImport = |
| addImport(module, "substring", {nullExt, Type::i32, Type::i32}, nnExt); |
| |
| // Replace the string instructions in parallel. |
| struct Replacer : public WalkerPass<PostWalker<Replacer>> { |
| bool isFunctionParallel() override { return true; } |
| |
| StringLowering& lowering; |
| |
| std::unique_ptr<Pass> create() override { |
| return std::make_unique<Replacer>(lowering); |
| } |
| |
| Replacer(StringLowering& lowering) : lowering(lowering) {} |
| |
| void visitStringNew(StringNew* curr) { |
| Builder builder(*getModule()); |
| switch (curr->op) { |
| case StringNewWTF16Array: |
| replaceCurrent(builder.makeCall(lowering.fromCharCodeArrayImport, |
| {curr->ref, curr->start, curr->end}, |
| lowering.nnExt)); |
| return; |
| case StringNewFromCodePoint: |
| replaceCurrent(builder.makeCall( |
| lowering.fromCodePointImport, {curr->ref}, lowering.nnExt)); |
| return; |
| default: |
| WASM_UNREACHABLE("TODO: all of string.new*"); |
| } |
| } |
| |
| void visitStringConcat(StringConcat* curr) { |
| Builder builder(*getModule()); |
| replaceCurrent(builder.makeCall( |
| lowering.concatImport, {curr->left, curr->right}, lowering.nnExt)); |
| } |
| |
| void visitStringEncode(StringEncode* curr) { |
| Builder builder(*getModule()); |
| switch (curr->op) { |
| case StringEncodeWTF16Array: |
| replaceCurrent( |
| builder.makeCall(lowering.intoCharCodeArrayImport, |
| {curr->str, curr->array, curr->start}, |
| Type::i32)); |
| return; |
| default: |
| WASM_UNREACHABLE("TODO: all of string.encode*"); |
| } |
| } |
| |
| void visitStringEq(StringEq* curr) { |
| Builder builder(*getModule()); |
| switch (curr->op) { |
| case StringEqEqual: |
| replaceCurrent(builder.makeCall( |
| lowering.equalsImport, {curr->left, curr->right}, Type::i32)); |
| return; |
| case StringEqCompare: |
| replaceCurrent(builder.makeCall( |
| lowering.compareImport, {curr->left, curr->right}, Type::i32)); |
| return; |
| default: |
| WASM_UNREACHABLE("invalid string.eq*"); |
| } |
| } |
| |
| void visitStringTest(StringTest* curr) { |
| Builder builder(*getModule()); |
| replaceCurrent( |
| builder.makeCall(lowering.testImport, {curr->ref}, Type::i32)); |
| } |
| |
| void visitStringMeasure(StringMeasure* curr) { |
| Builder builder(*getModule()); |
| replaceCurrent( |
| builder.makeCall(lowering.lengthImport, {curr->ref}, Type::i32)); |
| } |
| |
| void visitStringWTF16Get(StringWTF16Get* curr) { |
| Builder builder(*getModule()); |
| replaceCurrent(builder.makeCall( |
| lowering.charCodeAtImport, {curr->ref, curr->pos}, Type::i32)); |
| } |
| |
| void visitStringSliceWTF(StringSliceWTF* curr) { |
| Builder builder(*getModule()); |
| replaceCurrent(builder.makeCall(lowering.substringImport, |
| {curr->ref, curr->start, curr->end}, |
| lowering.nnExt)); |
| } |
| }; |
| |
| Replacer replacer(*this); |
| replacer.run(getPassRunner(), module); |
| replacer.walkModuleCode(module); |
| } |
| }; |
| |
| Pass* createStringGatheringPass() { return new StringGathering(); } |
| Pass* createStringLoweringPass() { return new StringLowering(); } |
| Pass* createStringLoweringMagicImportPass() { return new StringLowering(true); } |
| Pass* createStringLoweringMagicImportAssertPass() { |
| return new StringLowering(true, true); |
| } |
| |
| } // namespace wasm |