/*
 * Copyright 2025 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// Lift JS string imports into wasm strings in Binaryen IR, which can then be
// fully optimized. Typically StringLowering would be run later to lower them
// back down.
//
// A pass argument allows customizing the module name for string constants:
//
//   --pass-arg=string-constants-module@MODULE_NAME
//

#include "ir/utils.h"
#include "pass.h"
#include "passes/string-utils.h"
#include "support/json.h"
#include "support/string.h"
#include "wasm-builder.h"
#include "wasm.h"

namespace wasm {

struct StringLifting : public Pass {
  // Maps the global name of an imported string to the actual string.
  std::unordered_map<Name, Name> importedStrings;

  // Imported string functions. Imports that do not exist remain null.
  Name fromCharCodeArrayImport;
  Name intoCharCodeArrayImport;
  Name fromCodePointImport;
  Name concatImport;
  Name equalsImport;
  Name testImport;
  Name compareImport;
  Name lengthImport;
  Name charCodeAtImport;
  Name substringImport;

  // Shared imported string functions.
  Name fromCharCodeArraySharedImport;
  Name intoCharCodeArraySharedImport;
  Name fromCodePointSharedImport;
  Name concatSharedImport;
  Name equalsSharedImport;
  Name testSharedImport;
  Name compareSharedImport;
  Name lengthSharedImport;
  Name charCodeAtSharedImport;
  Name substringSharedImport;

  void run(Module* module) override {
    // Whether we found any work to do.
    bool found = false;

    // Imported string constants look like
    //
    //   (import "\'" "bar" (global $string.bar.internal.name (ref extern)))
    //
    // That is, they are imported from module "'" and the basename is the
    // actual string. Find them all so we can apply them.
    Name stringConstsModule =
      getArgumentOrDefault("string-constants-module", WasmStringConstsModule);
    for (auto& global : module->globals) {
      if (!global->imported()) {
        continue;
      }
      if (global->module == stringConstsModule) {
        // Encode from WTF-8 to WTF-16.
        auto wtf8 = global->base;
        std::stringstream wtf16;
        bool valid = String::convertWTF8ToWTF16(wtf16, wtf8.str);
        if (!valid) {
          Fatal() << "Bad string to lift: " << wtf8;
        }
        importedStrings[global->name] = wtf16.str();
        found = true;
      }
    }

    // Imported strings may also be found in the string section.
    auto stringSectionIter = std::find_if(
      module->customSections.begin(),
      module->customSections.end(),
      [&](CustomSection& section) { return section.name == "string.consts"; });
    if (stringSectionIter != module->customSections.end()) {
      // We found the string consts section. Parse it.
      auto& section = *stringSectionIter;
      auto copy = section.data;
      json::Value array;
      array.parse(copy.data(), json::Value::WTF16);
      if (!array.isArray()) {
        Fatal() << "StringLifting: string.const section should be a JSON array";
      }

      // We have the array of constants from the section. Find globals that
      // refer to it.
      for (auto& global : module->globals) {
        if (!global->imported() || global->module != "string.const") {
          continue;
        }
        // The index in the array is the basename.
        Index index = std::stoi(std::string(global->base.str));
        if (index >= array.size()) {
          Fatal() << "StringLifting: bad index in string.const section";
        }
        auto item = array[index];
        if (!item->isString()) {
          Fatal()
            << "StringLifting: string.const section entry is not a string";
        }
        if (importedStrings.count(global->name)) {
          Fatal() << "StringLifting: string.const section tramples other const";
        }
        importedStrings[global->name] = item->getIString();
      }

      // Remove the custom section: After lifting it has no purpose (and could
      // cause problems with repeated lifting/lowering).
      module->customSections.erase(stringSectionIter);
    }

    auto array16 = Type(HeapTypes::getMutI16Array(), Nullable);
    auto refExtern = Type(HeapType::ext, NonNullable);
    auto externref = Type(HeapType::ext, Nullable);
    auto i32 = Type::i32;

    auto sharedArray16 = Type(HeapTypes::getSharedMutI16Array(), Nullable);
    auto refSharedExtern =
      Type(HeapType(HeapType::ext).getBasic(Shared), NonNullable);
    auto sharedExternref =
      Type(HeapType(HeapType::ext).getBasic(Shared), Nullable);

    // Find imported string functions.
    for (auto& func : module->functions) {
      if (!func->imported() || func->module != WasmStringsModule) {
        continue;
      }
      auto type = func->type;
      if (func->base == "fromCharCodeArray") {
        if (type.getHeapType() == Signature(Type({array16, i32, i32}), refExtern)) {
          fromCharCodeArrayImport = func->name;
          found = true;
        } else if (type.getHeapType() ==
                   Signature(Type({sharedArray16, i32, i32}), refSharedExtern)) {
          fromCharCodeArraySharedImport = func->name;
          found = true;
        } else {
          Fatal() << "StringLifting: bad type for fromCharCodeArray: " << type;
        }
      } else if (func->base == "fromCodePoint") {
        if (type.getHeapType() == Signature(i32, refExtern)) {
          fromCodePointImport = func->name;
          found = true;
        } else if (type.getHeapType() == Signature(i32, refSharedExtern)) {
          fromCodePointSharedImport = func->name;
          found = true;
        } else {
          Fatal() << "StringLifting: bad type for fromCodePoint: " << type;
        }
      } else if (func->base == "concat") {
        if (type.getHeapType() ==
            Signature(Type({externref, externref}), refExtern)) {
          concatImport = func->name;
          found = true;
        } else if (type.getHeapType() == Signature(Type({sharedExternref,
                                                         sharedExternref}),
                                                   refSharedExtern)) {
          concatSharedImport = func->name;
          found = true;
        } else {
          Fatal() << "StringLifting: bad type for concat: " << type;
        }
      } else if (func->base == "intoCharCodeArray") {
        if (type.getHeapType() == Signature(Type({externref, array16, i32}), i32)) {
          intoCharCodeArrayImport = func->name;
          found = true;
        } else if (type.getHeapType() == Signature(Type({sharedExternref,
                                                         sharedArray16,
                                                         i32}),
                                                   i32)) {
          intoCharCodeArraySharedImport = func->name;
          found = true;
        } else {
          Fatal() << "StringLifting: bad type for intoCharCodeArray: " << type;
        }
      } else if (func->base == "equals") {
        if (type.getHeapType() == Signature(Type({externref, externref}), i32)) {
          equalsImport = func->name;
          found = true;
        } else if (type.getHeapType() == Signature(Type({sharedExternref,
                                                         sharedExternref}),
                                                   i32)) {
          equalsSharedImport = func->name;
          found = true;
        } else {
          Fatal() << "StringLifting: bad type for equals: " << type;
        }
      } else if (func->base == "test") {
        if (type.getHeapType() == Signature(Type({externref}), i32)) {
          testImport = func->name;
          found = true;
        } else if (type.getHeapType() == Signature(Type({sharedExternref}), i32)) {
          testSharedImport = func->name;
          found = true;
        } else {
          Fatal() << "StringLifting: bad type for test: " << type;
        }
      } else if (func->base == "compare") {
        if (type.getHeapType() == Signature(Type({externref, externref}), i32)) {
          compareImport = func->name;
          found = true;
        } else if (type.getHeapType() == Signature(Type({sharedExternref,
                                                         sharedExternref}),
                                                   i32)) {
          compareSharedImport = func->name;
          found = true;
        } else {
          Fatal() << "StringLifting: bad type for compare: " << type;
        }
      } else if (func->base == "length") {
        if (type.getHeapType() == Signature(Type({externref}), i32)) {
          lengthImport = func->name;
          found = true;
        } else if (type.getHeapType() == Signature(Type({sharedExternref}), i32)) {
          lengthSharedImport = func->name;
          found = true;
        } else {
          Fatal() << "StringLifting: bad type for length: " << type;
        }
      } else if (func->base == "charCodeAt") {
        if (type.getHeapType() == Signature(Type({externref, i32}), i32)) {
          charCodeAtImport = func->name;
          found = true;
        } else if (type.getHeapType() == Signature(Type({sharedExternref, i32}), i32)) {
          charCodeAtSharedImport = func->name;
          found = true;
        } else {
          Fatal() << "StringLifting: bad type for charCodeAt: " << type;
        }
      } else if (func->base == "substring") {
        if (type.getHeapType() == Signature(Type({externref, i32, i32}), refExtern)) {
          substringImport = func->name;
          found = true;
        } else if (type.getHeapType() == Signature(Type({sharedExternref,
                                                         i32,
                                                         i32}),
                                                   refSharedExtern)) {
          substringSharedImport = func->name;
          found = true;
        } else {
          Fatal() << "StringLifting: bad type for substring: " << type;
        }
      } else {
        std::cerr << "warning: unknown strings import: " << func->base << '\n';
      }
    }

    if (!found) {
      // Nothing to do.
      return;
    }

    struct StringApplier : public WalkerPass<PostWalker<StringApplier>> {
      bool isFunctionParallel() override { return true; }

      const StringLifting& parent;

      StringApplier(const StringLifting& parent) : parent(parent) {}

      std::unique_ptr<Pass> create() override {
        return std::make_unique<StringApplier>(parent);
      }

      bool modified = false;

      void visitGlobalGet(GlobalGet* curr) {
        // Replace global.gets of imported strings with string.const.
        auto iter = parent.importedStrings.find(curr->name);
        if (iter != parent.importedStrings.end()) {
          auto wtf16 = iter->second;
          replaceCurrent(Builder(*getModule()).makeStringConst(wtf16.str));
          modified = true;
        }
      }

      void visitCall(Call* curr) {
        Builder builder(*getModule());
        // Replace calls of imported string methods with stringref operations.
        if (curr->target == parent.fromCharCodeArrayImport ||
            curr->target == parent.fromCharCodeArraySharedImport) {
          replaceCurrent(builder.makeStringNew(StringNewWTF16Array,
                                               curr->operands[0],
                                               curr->operands[1],
                                               curr->operands[2]));
        } else if (curr->target == parent.fromCodePointImport ||
                   curr->target == parent.fromCodePointSharedImport) {
          replaceCurrent(builder.makeStringNew(StringNewFromCodePoint,
                                               curr->operands[0]));
        } else if (curr->target == parent.concatImport ||
                   curr->target == parent.concatSharedImport) {
          replaceCurrent(builder.makeStringConcat(curr->operands[0],
                                                  curr->operands[1]));
        } else if (curr->target == parent.intoCharCodeArrayImport ||
                   curr->target == parent.intoCharCodeArraySharedImport) {
          replaceCurrent(builder.makeStringEncode(StringEncodeWTF16Array,
                                                  curr->operands[0],
                                                  curr->operands[1],
                                                  curr->operands[2]));
        } else if (curr->target == parent.equalsImport ||
                   curr->target == parent.equalsSharedImport) {
          replaceCurrent(builder.makeStringEq(StringEqEqual,
                                              curr->operands[0],
                                              curr->operands[1]));
        } else if (curr->target == parent.testImport ||
                   curr->target == parent.testSharedImport) {
          replaceCurrent(builder.makeStringTest(curr->operands[0]));
        } else if (curr->target == parent.compareImport ||
                   curr->target == parent.compareSharedImport) {
          replaceCurrent(builder.makeStringEq(StringEqCompare,
                                              curr->operands[0],
                                              curr->operands[1]));
        } else if (curr->target == parent.lengthImport ||
                   curr->target == parent.lengthSharedImport) {
          replaceCurrent(builder.makeStringMeasure(StringMeasureWTF16,
                                                   curr->operands[0]));
        } else if (curr->target == parent.charCodeAtImport ||
                   curr->target == parent.charCodeAtSharedImport) {
          replaceCurrent(builder.makeStringWTF16Get(curr->operands[0],
                                                    curr->operands[1]));
        } else if (curr->target == parent.substringImport ||
                   curr->target == parent.substringSharedImport) {
          replaceCurrent(builder.makeStringSliceWTF(curr->operands[0],
                                                    curr->operands[1],
                                                    curr->operands[2]));
        }
      }

      void visitFunction(Function* curr) {
        // If we made modifications then we need to refinalize, as we replace
        // externrefs with stringrefs, a subtype.
        if (modified) {
          ReFinalize().walkFunctionInModule(curr, getModule());
        }
      }
    };

    StringApplier applier(*this);
    applier.run(getPassRunner(), module);
    applier.walkModuleCode(module);

    // TODO: Add casts. We generate new string.* instructions, and all their
    //       string inputs should be stringref, not externref, but we have not
    //       converted all externrefs to stringrefs (since some externrefs might
    //       be something else). It is not urgent to fix this as the validator
    //       accepts externrefs there atm, and since toolchains will lower
    //       strings out at the end anyhow (which would remove such casts). Note
    //       that if we add a type import for stringref then this problem would
    //       become a lot simpler (we'd convert that type to stringref).

    // Enable the feature so the module validates.
    module->features.enable(FeatureSet::Strings);
  }
};

// Factory for the pass: allocates a new StringLifting instance and returns it
// as a raw Pass pointer.
Pass* createStringLiftingPass() {
  auto* lifting = new StringLifting();
  return lifting;
}

} // namespace wasm
