[Parser] Lex strings (#4687)

diff --git a/src/wasm/wat-parser-internal.h b/src/wasm/wat-parser-internal.h
index 2f7e108..7879139 100644
--- a/src/wasm/wat-parser-internal.h
+++ b/src/wasm/wat-parser-internal.h
@@ -28,6 +28,7 @@
 #include <cctype>
 #include <iostream>
 #include <optional>
+#include <sstream>
 #include <variant>
 
 using namespace std::string_view_literals;
@@ -216,6 +217,75 @@
   }
 };
 
+struct LexStrResult : LexResult {
+  // Allocate a string only if there are escape sequences, otherwise just use
+  // the original string_view.
+  std::optional<std::string> str;
+};
+
+struct LexStrCtx : LexCtx {
+private:
+  // Whether we are building a string due to the presence of escape
+  // sequences.
+  bool building = false;
+  std::stringstream ss;
+
+public:
+  LexStrCtx(std::string_view in) : LexCtx(in) {}
+
+  std::optional<LexStrResult> lexed() {
+    if (auto basic = LexCtx::lexed()) {
+      auto str = building ? std::optional<std::string>{ss.str()} : std::nullopt;
+      return {LexStrResult{*basic, str}};
+    }
+    return {};
+  }
+
+  void takeChar() {
+    if (building) {
+      ss << peek();
+    }
+    LexCtx::take(1);
+  }
+
+  void ensureBuilding() {
+    if (building) {
+      return;
+    }
+    // Drop the opening '"'.
+    ss << LexCtx::lexed()->span.substr(1);
+    building = true;
+  }
+
+  void appendEscaped(char c) { ss << c; }
+
+  bool appendUnicode(uint64_t u) {
+    if ((0xd800 <= u && u < 0xe000) || 0x110000 <= u) {
+      return false;
+    }
+    if (u < 0x80) {
+      // 0xxxxxxx
+      ss << uint8_t(u);
+    } else if (u < 0x800) {
+      // 110xxxxx 10xxxxxx
+      ss << uint8_t(0b11000000 | ((u >> 6) & 0b00011111));
+      ss << uint8_t(0b10000000 | ((u >> 0) & 0b00111111));
+    } else if (u < 0x10000) {
+      // 1110xxxx 10xxxxxx 10xxxxxx
+      ss << uint8_t(0b11100000 | ((u >> 12) & 0b00001111));
+      ss << uint8_t(0b10000000 | ((u >> 6) & 0b00111111));
+      ss << uint8_t(0b10000000 | ((u >> 0) & 0b00111111));
+    } else {
+      // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+      ss << uint8_t(0b11110000 | ((u >> 18) & 0b00000111));
+      ss << uint8_t(0b10000000 | ((u >> 12) & 0b00111111));
+      ss << uint8_t(0b10000000 | ((u >> 6) & 0b00111111));
+      ss << uint8_t(0b10000000 | ((u >> 0) & 0b00111111));
+    }
+    return true;
+  }
+};
+
 std::optional<LexResult> lparen(std::string_view in) {
   LexCtx ctx(in);
   ctx.takePrefix("("sv);
@@ -441,6 +511,80 @@
   return {};
 }
 
+// string     ::= '"' (b*:stringelem)* '"'  => concat((b*)*)
+//                    (if |concat((b*)*)| < 2^32)
+// stringelem ::= c:stringchar              => utf8(c)
+//              | '\' n:hexdigit m:hexdigit => 16*n + m
+// stringchar ::= c:char                    => c
+//                    (if c >= U+20 && c != U+7f && c != '"' && c != '\')
+//              | '\t' => \t | '\n' => \n | '\r' => \r
+//              | '\\' => \ | '\"' => " | '\'' => '
+//              | '\u{' n:hexnum '}'        => U+(n)
+//                    (if n < 0xD800 and 0xE000 <= n <= 0x110000)
+std::optional<LexStrResult> str(std::string_view in) {
+  LexStrCtx ctx(in);
+  if (!ctx.takePrefix("\""sv)) {
+    return {};
+  }
+  while (!ctx.takePrefix("\""sv)) {
+    if (ctx.empty()) {
+      // TODO: Add error production for unterminated string.
+      return {};
+    }
+    if (ctx.startsWith("\\"sv)) {
+      // Escape sequences
+      ctx.ensureBuilding();
+      ctx.take(1);
+      if (ctx.takePrefix("t"sv)) {
+        ctx.appendEscaped('\t');
+      } else if (ctx.takePrefix("n"sv)) {
+        ctx.appendEscaped('\n');
+      } else if (ctx.takePrefix("r"sv)) {
+        ctx.appendEscaped('\r');
+      } else if (ctx.takePrefix("\\"sv)) {
+        ctx.appendEscaped('\\');
+      } else if (ctx.takePrefix("\""sv)) {
+        ctx.appendEscaped('"');
+      } else if (ctx.takePrefix("'"sv)) {
+        ctx.appendEscaped('\'');
+      } else if (ctx.takePrefix("u{"sv)) {
+        auto lexed = hexnum(ctx.next());
+        if (!lexed) {
+          // TODO: Add error production for malformed unicode escapes.
+          return {};
+        }
+        ctx.take(*lexed);
+        if (!ctx.takePrefix("}"sv)) {
+          // TODO: Add error production for malformed unicode escapes.
+          return {};
+        }
+        if (!ctx.appendUnicode(lexed->n)) {
+          // TODO: Add error production for invalid unicode values.
+          return {};
+        }
+      } else {
+        LexIntCtx ictx(ctx.next());
+        if (!ictx.takeHexdigit() || !ictx.takeHexdigit()) {
+          // TODO: Add error production for unrecognized escape sequence.
+          return {};
+        }
+        auto lexed = *ictx.lexed();
+        ctx.take(lexed);
+        ctx.appendEscaped(char(lexed.n));
+      }
+    } else {
+      // Normal characters
+      if (uint8_t c = ctx.peek(); c >= 0x20 && c != 0x7F) {
+        ctx.takeChar();
+      } else {
+        // TODO: Add error production for unescaped control characters.
+        return {};
+      }
+    }
+  }
+  return ctx.lexed();
+}
+
 // ======
 // Tokens
 // ======
@@ -482,8 +626,25 @@
   friend bool operator==(const IdTok&, const IdTok&) { return true; }
 };
 
+struct StringTok {
+  std::optional<std::string> str;
+
+  friend std::ostream& operator<<(std::ostream& os, const StringTok& tok) {
+    if (tok.str) {
+      os << '"' << *tok.str << '"';
+    } else {
+      os << "(raw string)";
+    }
+    return os;
+  }
+
+  friend bool operator==(const StringTok& t1, const StringTok& t2) {
+    return t1.str == t2.str;
+  }
+};
+
 struct Token {
-  using Data = std::variant<LParenTok, RParenTok, IntTok, IdTok>;
+  using Data = std::variant<LParenTok, RParenTok, IntTok, IdTok, StringTok>;
 
   std::string_view span;
   Data data;
@@ -571,6 +732,8 @@
       tok = Token{t->span, IdTok{}};
     } else if (auto t = integer(next())) {
       tok = Token{t->span, IntTok{t->n, t->signedness}};
+    } else if (auto t = str(next())) {
+      tok = Token{t->span, StringTok{t->str}};
     } else {
       // TODO: Do something about lexing errors.
       curr = std::nullopt;
diff --git a/test/gtest/wat-parser.cpp b/test/gtest/wat-parser.cpp
index 2ddb781..be6d76e 100644
--- a/test/gtest/wat-parser.cpp
+++ b/test/gtest/wat-parser.cpp
@@ -367,3 +367,105 @@
     EXPECT_EQ(lexer, lexer.end());
   }
 }
+
+TEST(ParserTest, LexString) {
+  {
+    auto pangram = "\"The quick brown fox jumps over the lazy dog\""sv;
+    Lexer lexer(pangram);
+    ASSERT_NE(lexer, lexer.end());
+    Token expected{pangram, StringTok{{}}};
+    EXPECT_EQ(*lexer, expected);
+  }
+  {
+    auto chars = "\"`~!@#$%^&*()_-+0123456789|,.<>/?;:'\""sv;
+    Lexer lexer(chars);
+    ASSERT_NE(lexer, lexer.end());
+    Token expected{chars, StringTok{{}}};
+    EXPECT_EQ(*lexer, expected);
+  }
+  {
+    auto escapes = "\"_\\t_\\n_\\r_\\\\_\\\"_\\'_\""sv;
+    Lexer lexer(escapes);
+    ASSERT_NE(lexer, lexer.end());
+    Token expected{escapes, StringTok{{"_\t_\n_\r_\\_\"_'_"}}};
+    EXPECT_EQ(*lexer, expected);
+  }
+  {
+    auto escapes = "\"_\\00_\\07_\\20_\\5A_\\7F_\\ff_\\ffff_\""sv;
+    Lexer lexer(escapes);
+    ASSERT_NE(lexer, lexer.end());
+    std::string escaped{"_\0_\7_ _Z_\x7f_\xff_\xff"
+                        "ff_"sv};
+    Token expected{escapes, StringTok{{escaped}}};
+    EXPECT_EQ(*lexer, expected);
+  }
+  {
+    // _$_£_€_𐍈_
+    auto unicode = "\"_\\u{24}_\\u{00a3}_\\u{20AC}_\\u{10348}_\""sv;
+    Lexer lexer(unicode);
+    ASSERT_NE(lexer, lexer.end());
+    std::string escaped{"_$_\xC2\xA3_\xE2\x82\xAC_\xF0\x90\x8D\x88_"};
+    Token expected{unicode, StringTok{{escaped}}};
+    EXPECT_EQ(*lexer, expected);
+  }
+  {
+    // _$_£_€_𐍈_
+    auto unicode = "\"_$_\xC2\xA3_\xE2\x82\xAC_\xF0\x90\x8D\x88_\""sv;
+    Lexer lexer(unicode);
+    ASSERT_NE(lexer, lexer.end());
+    Token expected{unicode, StringTok{{}}};
+    EXPECT_EQ(*lexer, expected);
+  }
+  {
+    Lexer lexer("\"unterminated"sv);
+    ASSERT_EQ(lexer, lexer.end());
+  }
+  {
+    Lexer lexer("\"unescaped nul\0\"");
+    ASSERT_EQ(lexer, lexer.end());
+  }
+  {
+    Lexer lexer("\"unescaped U+19\x19\"");
+    ASSERT_EQ(lexer, lexer.end());
+  }
+  {
+    Lexer lexer("\"unescaped U+7f\x7f\"");
+    ASSERT_EQ(lexer, lexer.end());
+  }
+  {
+    Lexer lexer("\"\\ stray backslash\"");
+    ASSERT_EQ(lexer, lexer.end());
+  }
+  {
+    Lexer lexer("\"short \\f hex escape\"");
+    ASSERT_EQ(lexer, lexer.end());
+  }
+  {
+    Lexer lexer("\"bad hex \\gg\"");
+    ASSERT_EQ(lexer, lexer.end());
+  }
+  {
+    Lexer lexer("\"empty unicode \\u{}\"");
+    ASSERT_EQ(lexer, lexer.end());
+  }
+  {
+    Lexer lexer("\"not unicode \\u{abcdefg}\"");
+    ASSERT_EQ(lexer, lexer.end());
+  }
+  {
+    Lexer lexer("\"extra chars \\u{123(}\"");
+    ASSERT_EQ(lexer, lexer.end());
+  }
+  {
+    Lexer lexer("\"unpaired surrogate unicode crimes \\u{d800}\"");
+    ASSERT_EQ(lexer, lexer.end());
+  }
+  {
+    Lexer lexer("\"more surrogate unicode crimes \\u{dfff}\"");
+    ASSERT_EQ(lexer, lexer.end());
+  }
+  {
+    Lexer lexer("\"too big \\u{110000}\"");
+    ASSERT_EQ(lexer, lexer.end());
+  }
+}