blob: 7b8f456217f30bb508097389db92a4f939a84739 [file] [log] [blame] [edit]
/*
* Copyright 2025 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//
// Lift JS string imports into wasm strings in Binaryen IR, which can then be
// fully optimized. Typically StringLowering would be run later to lower them
// back down.
//
// A pass argument allows customizing the module name for string constants:
//
// --pass-arg=string-constants-module@MODULE_NAME
//
#include "ir/utils.h"
#include "pass.h"
#include "passes/string-utils.h"
#include "support/json.h"
#include "support/string.h"
#include "wasm-builder.h"
#include "wasm.h"
namespace wasm {
struct StringLifting : public Pass {
// Maps the global name of an imported string to the actual string.
std::unordered_map<Name, Name> importedStrings;
// Imported string functions. Imports that do not exist remain null.
Name fromCharCodeArrayImport;
Name intoCharCodeArrayImport;
Name fromCodePointImport;
Name concatImport;
Name equalsImport;
Name testImport;
Name compareImport;
Name lengthImport;
Name charCodeAtImport;
Name substringImport;
void run(Module* module) override {
// Whether we found any work to do.
bool found = false;
// Imported string constants look like
//
// (import "\'" "bar" (global $string.bar.internal.name (ref extern)))
//
// That is, they are imported from module "'" and the basename is the
// actual string. Find them all so we can apply them.
Name stringConstsModule =
getArgumentOrDefault("string-constants-module", WasmStringConstsModule);
for (auto& global : module->globals) {
if (!global->imported()) {
continue;
}
if (global->module == stringConstsModule) {
// Encode from WTF-8 to WTF-16.
auto wtf8 = global->base;
std::stringstream wtf16;
bool valid = String::convertWTF8ToWTF16(wtf16, wtf8.str);
if (!valid) {
Fatal() << "Bad string to lift: " << wtf8;
}
importedStrings[global->name] = wtf16.str();
found = true;
}
}
// Imported strings may also be found in the string section.
auto stringSectionIter = std::find_if(
module->customSections.begin(),
module->customSections.end(),
[&](CustomSection& section) { return section.name == "string.consts"; });
if (stringSectionIter != module->customSections.end()) {
// We found the string consts section. Parse it.
auto& section = *stringSectionIter;
auto copy = section.data;
json::Value array;
array.parse(copy.data(), json::Value::WTF16);
if (!array.isArray()) {
Fatal() << "StringLifting: string.const section should be a JSON array";
}
// We have the array of constants from the section. Find globals that
// refer to it.
for (auto& global : module->globals) {
if (!global->imported() || global->module != "string.const") {
continue;
}
// The index in the array is the basename.
Index index = std::stoi(std::string(global->base.str));
if (index >= array.size()) {
Fatal() << "StringLifting: bad index in string.const section";
}
auto item = array[index];
if (!item->isString()) {
Fatal()
<< "StringLifting: string.const section entry is not a string";
}
if (importedStrings.count(global->name)) {
Fatal() << "StringLifting: string.const section tramples other const";
}
importedStrings[global->name] = item->getIString();
}
// Remove the custom section: After lifting it has no purpose (and could
// cause problems with repeated lifting/lowering).
module->customSections.erase(stringSectionIter);
}
auto array16 = Type(Array(Field(Field::i16, Mutable)), Nullable);
auto refExtern = Type(HeapType::ext, NonNullable);
auto externref = Type(HeapType::ext, Nullable);
auto i32 = Type::i32;
// Find imported string functions.
for (auto& func : module->functions) {
if (!func->imported() || func->module != WasmStringsModule) {
continue;
}
// TODO: Check exactness here too.
auto type = func->type;
if (func->base == "fromCharCodeArray") {
if (type.getHeapType() != Signature({array16, i32, i32}, refExtern)) {
Fatal() << "StringLifting: bad type for fromCharCodeArray: " << type;
}
fromCharCodeArrayImport = func->name;
found = true;
} else if (func->base == "fromCodePoint") {
if (type.getHeapType() != Signature(i32, refExtern)) {
Fatal() << "StringLifting: bad type for fromCodePoint: " << type;
}
fromCodePointImport = func->name;
found = true;
} else if (func->base == "concat") {
if (type.getHeapType() !=
Signature({externref, externref}, refExtern)) {
Fatal() << "StringLifting: bad type for concat: " << type;
}
concatImport = func->name;
found = true;
} else if (func->base == "intoCharCodeArray") {
if (type.getHeapType() != Signature({externref, array16, i32}, i32)) {
Fatal() << "StringLifting: bad type for intoCharCodeArray: " << type;
}
intoCharCodeArrayImport = func->name;
found = true;
} else if (func->base == "equals") {
if (type.getHeapType() != Signature({externref, externref}, i32)) {
Fatal() << "StringLifting: bad type for equals: " << type;
}
equalsImport = func->name;
found = true;
} else if (func->base == "test") {
if (type.getHeapType() != Signature({externref}, i32)) {
Fatal() << "StringLifting: bad type for test: " << type;
}
testImport = func->name;
found = true;
} else if (func->base == "compare") {
if (type.getHeapType() != Signature({externref, externref}, i32)) {
Fatal() << "StringLifting: bad type for compare: " << type;
}
compareImport = func->name;
found = true;
} else if (func->base == "length") {
if (type.getHeapType() != Signature({externref}, i32)) {
Fatal() << "StringLifting: bad type for length: " << type;
}
lengthImport = func->name;
found = true;
} else if (func->base == "charCodeAt") {
if (type.getHeapType() != Signature({externref, i32}, i32)) {
Fatal() << "StringLifting: bad type for charCodeAt: " << type;
}
charCodeAtImport = func->name;
found = true;
} else if (func->base == "substring") {
if (type.getHeapType() != Signature({externref, i32, i32}, refExtern)) {
Fatal() << "StringLifting: bad type for substring: " << type;
}
substringImport = func->name;
found = true;
} else {
std::cerr << "warning: unknown strings import: " << func->base << '\n';
}
}
if (!found) {
// Nothing to do.
return;
}
struct StringApplier : public WalkerPass<PostWalker<StringApplier>> {
bool isFunctionParallel() override { return true; }
const StringLifting& parent;
StringApplier(const StringLifting& parent) : parent(parent) {}
std::unique_ptr<Pass> create() override {
return std::make_unique<StringApplier>(parent);
}
bool modified = false;
void visitGlobalGet(GlobalGet* curr) {
// Replace global.gets of imported strings with string.const.
auto iter = parent.importedStrings.find(curr->name);
if (iter != parent.importedStrings.end()) {
auto wtf16 = iter->second;
replaceCurrent(Builder(*getModule()).makeStringConst(wtf16.str));
modified = true;
}
}
void visitCall(Call* curr) {
// Replace calls of imported string methods with stringref operations.
if (curr->target == parent.fromCharCodeArrayImport) {
replaceCurrent(Builder(*getModule())
.makeStringNew(StringNewWTF16Array,
curr->operands[0],
curr->operands[1],
curr->operands[2]));
} else if (curr->target == parent.fromCodePointImport) {
replaceCurrent(
Builder(*getModule())
.makeStringNew(StringNewFromCodePoint, curr->operands[0]));
} else if (curr->target == parent.concatImport) {
replaceCurrent(
Builder(*getModule())
.makeStringConcat(curr->operands[0], curr->operands[1]));
} else if (curr->target == parent.intoCharCodeArrayImport) {
replaceCurrent(Builder(*getModule())
.makeStringEncode(StringEncodeWTF16Array,
curr->operands[0],
curr->operands[1],
curr->operands[2]));
} else if (curr->target == parent.equalsImport) {
replaceCurrent(Builder(*getModule())
.makeStringEq(StringEqEqual,
curr->operands[0],
curr->operands[1]));
} else if (curr->target == parent.testImport) {
replaceCurrent(
Builder(*getModule()).makeStringTest(curr->operands[0]));
} else if (curr->target == parent.compareImport) {
replaceCurrent(Builder(*getModule())
.makeStringEq(StringEqCompare,
curr->operands[0],
curr->operands[1]));
} else if (curr->target == parent.lengthImport) {
replaceCurrent(
Builder(*getModule())
.makeStringMeasure(StringMeasureWTF16, curr->operands[0]));
} else if (curr->target == parent.charCodeAtImport) {
replaceCurrent(
Builder(*getModule())
.makeStringWTF16Get(curr->operands[0], curr->operands[1]));
} else if (curr->target == parent.substringImport) {
replaceCurrent(Builder(*getModule())
.makeStringSliceWTF(curr->operands[0],
curr->operands[1],
curr->operands[2]));
}
}
void visitFunction(Function* curr) {
// If we made modifications then we need to refinalize, as we replace
// externrefs with stringrefs, a subtype.
if (modified) {
ReFinalize().walkFunctionInModule(curr, getModule());
}
}
};
StringApplier applier(*this);
applier.run(getPassRunner(), module);
applier.walkModuleCode(module);
// TODO: Add casts. We generate new string.* instructions, and all their
// string inputs should be stringref, not externref, but we have not
// converted all externrefs to stringrefs (since some externrefs might
// be something else). It is not urgent to fix this as the validator
// accepts externrefs there atm, and since toolchains will lower
// strings out at the end anyhow (which would remove such casts). Note
// that if we add a type import for stringref then this problem would
// become a lot simpler (we'd convert that type to stringref).
// Enable the feature so the module validates.
module->features.enable(FeatureSet::Strings);
}
};
// Allocates a fresh StringLifting pass and returns it as a generic Pass
// pointer. NOTE(review): presumably the caller takes ownership of the
// returned object - confirm against the pass registration code.
Pass* createStringLiftingPass() {
  Pass* lifting = new StringLifting();
  return lifting;
}
} // namespace wasm