Enhance interpolateParams to correctly handle placeholders (#1732)

Make client-side parameter interpolation (interpolateParams) skip `?`
characters inside comments, string literals, and backtick-quoted
identifiers, so that only real placeholders are bound to arguments.
diff --git a/connection.go b/connection.go
index b1660a5..ab15668 100644
--- a/connection.go
+++ b/connection.go
@@ -172,7 +172,7 @@
 }
 
 // Closes the network connection and unsets internal variables. Do not call this
-// function after successfully authentication, call Close instead. This function
+// function after successful authentication, call Close instead. This function
 // is called before auth or on auth failure because MySQL will have already
 // closed the network connection.
 func (mc *mysqlConn) cleanup() {
@@ -246,100 +246,178 @@
 }
 
 func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
-	// Number of ? should be same to len(args)
-	if strings.Count(query, "?") != len(args) {
-		return "", driver.ErrSkip
-	}
+	// A small lexer walks the query so that a '?' inside a string literal, a
+	// backtick-quoted identifier, or a comment is copied verbatim instead of
+	// being bound to an argument.
+	noBackslashEscapes := (mc.status & statusNoBackslashEscapes) != 0
+	const (
+		stateNormal = iota
+		stateString
+		stateEscape
+		stateEOLComment
+		stateSlashStarComment
+		stateBacktick
+	)
 
 	buf, err := mc.buf.takeCompleteBuffer()
 	if err != nil {
 		// can not take the buffer. Something must be wrong with the connection
 		mc.cleanup()
 		// interpolateParams would be called before sending any query.
 		// So its safe to retry.
 		return "", driver.ErrBadConn
 	}
 	buf = buf[:0]
+	state := stateNormal
+	singleQuotes := false
+	lastChar := byte(0)
 	argPos := 0
+	lenQuery := len(query)
+	lastIdx := 0
 
-	for i := 0; i < len(query); i++ {
-		q := strings.IndexByte(query[i:], '?')
-		if q == -1 {
-			buf = append(buf, query[i:]...)
-			break
-		}
-		buf = append(buf, query[i:i+q]...)
-		i += q
-
-		arg := args[argPos]
-		argPos++
-
-		if arg == nil {
-			buf = append(buf, "NULL"...)
+	for i := 0; i < lenQuery; i++ {
+		currentChar := query[i]
+		if state == stateEscape {
+			// The byte after a backslash belongs to the string literal
+			// whatever it is, including a quote character.
+			state = stateString
+			lastChar = currentChar
 			continue
 		}
-
-		switch v := arg.(type) {
-		case int64:
-			buf = strconv.AppendInt(buf, v, 10)
-		case uint64:
-			// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
-			buf = strconv.AppendUint(buf, v, 10)
-		case float64:
-			buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
-		case bool:
-			if v {
-				buf = append(buf, '1')
-			} else {
-				buf = append(buf, '0')
+		switch currentChar {
+		case '*':
+			if state == stateNormal && lastChar == '/' {
+				state = stateSlashStarComment
+				// Clear lastChar so the '*' that opened the comment cannot
+				// also close it: "/*/" does not end the comment.
+				lastChar = 0
+				continue
 			}
-		case time.Time:
-			if v.IsZero() {
-				buf = append(buf, "'0000-00-00'"...)
-			} else {
-				buf = append(buf, '\'')
-				buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
-				if err != nil {
-					return "", err
-				}
-				buf = append(buf, '\'')
+		case '/':
+			if state == stateSlashStarComment && lastChar == '*' {
+				state = stateNormal
+				// Clear lastChar so the '/' that closed the comment isn't
+				// reused to start a new comment with a following '*'.
+				lastChar = 0
+				continue
 			}
-		case json.RawMessage:
-			buf = append(buf, '\'')
-			if mc.status&statusNoBackslashEscapes == 0 {
-				buf = escapeBytesBackslash(buf, v)
-			} else {
-				buf = escapeBytesQuotes(buf, v)
+		case '#':
+			if state == stateNormal {
+				state = stateEOLComment
 			}
-			buf = append(buf, '\'')
-		case []byte:
-			if v == nil {
-				buf = append(buf, "NULL"...)
-			} else {
-				buf = append(buf, "_binary'"...)
-				if mc.status&statusNoBackslashEscapes == 0 {
-					buf = escapeBytesBackslash(buf, v)
+		case '-':
+			if state == stateNormal && lastChar == '-' {
+				// "--" starts a comment only when followed by whitespace, a
+				// control character, or the end of the query; otherwise (as
+				// in "2--1") it is two minus operators.
+				if i+1 < lenQuery {
+					nextChar := query[i+1]
+					if nextChar <= ' ' || nextChar == 0x7f {
+						state = stateEOLComment
+					}
 				} else {
-					buf = escapeBytesQuotes(buf, v)
+					state = stateEOLComment
 				}
-				buf = append(buf, '\'')
 			}
-		case string:
-			buf = append(buf, '\'')
-			if mc.status&statusNoBackslashEscapes == 0 {
-				buf = escapeStringBackslash(buf, v)
-			} else {
-				buf = escapeStringQuotes(buf, v)
+		case '\n':
+			if state == stateEOLComment {
+				state = stateNormal
 			}
-			buf = append(buf, '\'')
-		default:
-			return "", driver.ErrSkip
-		}
+		// A doubled quote inside a literal ('it''s') closes and immediately
+		// reopens the string, so it needs no special handling.
+		case '"':
+			if state == stateNormal {
+				state = stateString
+				singleQuotes = false
+			} else if state == stateString && !singleQuotes {
+				state = stateNormal
+			}
+		case '\'':
+			if state == stateNormal {
+				state = stateString
+				singleQuotes = true
+			} else if state == stateString && singleQuotes {
+				state = stateNormal
+			}
+		case '\\':
+			if state == stateString && !noBackslashEscapes {
+				state = stateEscape
+			}
+		case '?':
+			if state == stateNormal {
+				if argPos >= len(args) {
+					return "", driver.ErrSkip
+				}
+				buf = append(buf, query[lastIdx:i]...)
+				arg := args[argPos]
+				argPos++
 
-		if len(buf)+4 > mc.maxAllowedPacket {
-			return "", driver.ErrSkip
+				switch v := arg.(type) {
+				case nil:
+					buf = append(buf, "NULL"...)
+				case int64:
+					buf = strconv.AppendInt(buf, v, 10)
+				case uint64:
+					// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
+					buf = strconv.AppendUint(buf, v, 10)
+				case float64:
+					buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
+				case bool:
+					if v {
+						buf = append(buf, '1')
+					} else {
+						buf = append(buf, '0')
+					}
+				case time.Time:
+					if v.IsZero() {
+						buf = append(buf, "'0000-00-00'"...)
+					} else {
+						buf = append(buf, '\'')
+						buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
+						if err != nil {
+							return "", err
+						}
+						buf = append(buf, '\'')
+					}
+				case json.RawMessage:
+					if noBackslashEscapes {
+						buf = escapeBytesQuotes(buf, v, false)
+					} else {
+						buf = escapeBytesBackslash(buf, v, false)
+					}
+				case []byte:
+					if v == nil {
+						buf = append(buf, "NULL"...)
+					} else if noBackslashEscapes {
+						buf = escapeBytesQuotes(buf, v, true)
+					} else {
+						buf = escapeBytesBackslash(buf, v, true)
+					}
+				case string:
+					if noBackslashEscapes {
+						buf = escapeStringQuotes(buf, v)
+					} else {
+						buf = escapeStringBackslash(buf, v)
+					}
+				default:
+					return "", driver.ErrSkip
+				}
+
+				if len(buf)+4 > mc.maxAllowedPacket {
+					return "", driver.ErrSkip
+				}
+				lastIdx = i + 1
+			}
+		case '`':
+			if state == stateBacktick {
+				state = stateNormal
+			} else if state == stateNormal {
+				state = stateBacktick
+			}
 		}
+		lastChar = currentChar
 	}
+	buf = append(buf, query[lastIdx:]...)
 	if argPos != len(args) {
 		return "", driver.ErrSkip
 	}
diff --git a/connection_test.go b/connection_test.go
index d489c1e..2827fd0 100644
--- a/connection_test.go
+++ b/connection_test.go
@@ -80,24 +80,6 @@
 	}
 }
 
-// We don't support placeholder in string literal for now.
-// https://github.com/go-sql-driver/mysql/pull/490
-func TestInterpolateParamsPlaceholderInString(t *testing.T) {
-	mc := &mysqlConn{
-		buf:              newBuffer(),
-		maxAllowedPacket: maxPacketSize,
-		cfg: &Config{
-			InterpolateParams: true,
-		},
-	}
-
-	q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
-	// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
-	if err != driver.ErrSkip {
-		t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
-	}
-}
-
 func TestInterpolateParamsUint64(t *testing.T) {
 	mc := &mysqlConn{
 		buf:              newBuffer(),
@@ -206,6 +188,71 @@
 	return nil
 }
 
+func TestInterpolateParamsWithComments(t *testing.T) {
+	mc := &mysqlConn{
+		buf:              newBuffer(),
+		maxAllowedPacket: maxPacketSize,
+		cfg: &Config{
+			InterpolateParams: true,
+		},
+	}
+
+	tests := []struct {
+		query      string
+		args       []driver.Value
+		expected   string
+		shouldSkip bool
+	}{
+		// ? in single-line comment (--) should not be replaced
+		{"SELECT 1 -- ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 -- ?\n, 42", false},
+		// ? in single-line comment (#) should not be replaced
+		{"SELECT 1 # ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 # ?\n, 42", false},
+		// ? in multi-line comment should not be replaced
+		{"SELECT /* ? */ ?", []driver.Value{int64(42)}, "SELECT /* ? */ 42", false},
+		// ? in string literal should not be replaced
+		{"SELECT '?', ?", []driver.Value{int64(42)}, "SELECT '?', 42", false},
+		// ? in backtick identifier should not be replaced
+		{"SELECT `?`, ?", []driver.Value{int64(42)}, "SELECT `?`, 42", false},
+		// ? in backslash-escaped string literal should not be replaced
+		{"SELECT 'C:\\path\\?x.txt', ?", []driver.Value{int64(42)}, "SELECT 'C:\\path\\?x.txt', 42", false},
+		// ? after an escaped quote (\') is still inside the string literal
+		{"SELECT '\\'?', col FROM tbl WHERE id = ? AND desc = 'foo\\'bar?'", []driver.Value{int64(42)}, "SELECT '\\'?', col FROM tbl WHERE id = 42 AND desc = 'foo\\'bar?'", false},
+		// Four placeholders outside comments but only three args: must fall back with ErrSkip
+		{"SELECT ? -- comment ?\n, ? /* ? */ , ? # ?\n, ?", []driver.Value{int64(1), int64(2), int64(3)}, "", true},
+		// 2--1: -- followed by digit is NOT a comment (it's the number 2 minus minus 1)
+		{"SELECT ?--1", []driver.Value{int64(2)}, "SELECT 2--1", false},
+		// /* */*: After closing block comment, */* should NOT start a new comment
+		{"SELECT /* comment */* ?, ?", []driver.Value{int64(1), int64(2)}, "SELECT /* comment */* 1, 2", false},
+		// /* */*: More complex case with actual comment after
+		{"SELECT /* c1 */*/* c2 */ ?, ?", []driver.Value{int64(1), int64(2)}, "SELECT /* c1 */*/* c2 */ 1, 2", false},
+		// ? in double-quoted string literal should not be replaced
+		{"SELECT \"?\", ?", []driver.Value{int64(42)}, "SELECT \"?\", 42", false},
+		// A doubled quote ('') does not end the string literal
+		{"SELECT 'it''s ?', ?", []driver.Value{int64(42)}, "SELECT 'it''s ?', 42", false},
+		// # comment running to the end of the query without a newline
+		{"SELECT ? # ?", []driver.Value{int64(42)}, "SELECT 42 # ?", false},
+		// Fewer real placeholders than args: must fall back with ErrSkip
+		{"SELECT '?'", []driver.Value{int64(42)}, "", true},
+	}
+
+	for i, test := range tests {
+		q, err := mc.interpolateParams(test.query, test.args)
+		if test.shouldSkip {
+			if err != driver.ErrSkip {
+				t.Errorf("Test %d: Expected driver.ErrSkip, got err=%#v, q=%#v", i, err, q)
+			}
+			continue
+		}
+		if err != nil {
+			t.Errorf("Test %d: Expected err=nil, got %#v", i, err)
+			continue
+		}
+		if q != test.expected {
+			t.Errorf("Test %d: Expected: %q\nGot: %q", i, test.expected, q)
+		}
+	}
+}
+
 // chunkedConn is a net.Conn that serves pre-built data chunks, one per Read
 // call. This simulates the behavior seen with TLS connections, where the
 // server's TLS library typically produces a separate TLS record per write
diff --git a/utils.go b/utils.go
index b041804..2dccb7d 100644
--- a/utils.go
+++ b/utils.go
@@ -625,108 +625,86 @@
 	return buf[:newSize]
 }
 
-// escapeBytesBackslash escapes []byte with backslashes (\)
-// This escapes the contents of a string (provided as []byte) by adding backslashes before special
-// characters, and turning others into specific escape sequences, such as
-// turning newlines into \n and null bytes into \0.
-// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
-func escapeBytesBackslash(buf, v []byte) []byte {
-	pos := len(buf)
-	buf = reserveBuffer(buf, len(v)*2)
-
-	for _, c := range v {
-		switch c {
-		case '\x00':
-			buf[pos+1] = '0'
-			buf[pos] = '\\'
-			pos += 2
-		case '\n':
-			buf[pos+1] = 'n'
-			buf[pos] = '\\'
-			pos += 2
-		case '\r':
-			buf[pos+1] = 'r'
-			buf[pos] = '\\'
-			pos += 2
-		case '\x1a':
-			buf[pos+1] = 'Z'
-			buf[pos] = '\\'
-			pos += 2
-		case '\'':
-			buf[pos+1] = '\''
-			buf[pos] = '\\'
-			pos += 2
-		case '"':
-			buf[pos+1] = '"'
-			buf[pos] = '\\'
-			pos += 2
-		case '\\':
-			buf[pos+1] = '\\'
-			buf[pos] = '\\'
-			pos += 2
-		default:
-			buf[pos] = c
-			pos++
-		}
-	}
-
-	return buf[:pos]
+// backslashEscapeTable maps each byte that must be backslash-escaped inside a
+// quoted string literal to the character written after the backslash, e.g. a
+// newline becomes \n and a NUL byte becomes \0. A zero entry means the byte is
+// copied unchanged. Shared by escapeStringBackslash and escapeBytesBackslash.
+// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
+var backslashEscapeTable = [256]byte{
+	'\x00': '0',
+	'\n':   'n',
+	'\r':   'r',
+	'\x1a': 'Z',
+	'\'':   '\'',
+	'"':    '"',
+	'\\':   '\\',
 }
 
 // escapeStringBackslash is similar to escapeBytesBackslash but for string.
 func escapeStringBackslash(buf []byte, v string) []byte {
 	pos := len(buf)
-	buf = reserveBuffer(buf, len(v)*2)
-
+	buf = reserveBuffer(buf, len(v)*2+2)
+	buf[pos] = '\''
+	pos++
 	for i := range len(v) {
 		c := v[i]
-		switch c {
-		case '\x00':
-			buf[pos+1] = '0'
+		if esc := backslashEscapeTable[c]; esc != 0 {
+			buf[pos+1] = esc
 			buf[pos] = '\\'
 			pos += 2
-		case '\n':
-			buf[pos+1] = 'n'
-			buf[pos] = '\\'
-			pos += 2
-		case '\r':
-			buf[pos+1] = 'r'
-			buf[pos] = '\\'
-			pos += 2
-		case '\x1a':
-			buf[pos+1] = 'Z'
-			buf[pos] = '\\'
-			pos += 2
-		case '\'':
-			buf[pos+1] = '\''
-			buf[pos] = '\\'
-			pos += 2
-		case '"':
-			buf[pos+1] = '"'
-			buf[pos] = '\\'
-			pos += 2
-		case '\\':
-			buf[pos+1] = '\\'
-			buf[pos] = '\\'
-			pos += 2
-		default:
+		} else {
 			buf[pos] = c
 			pos++
 		}
 	}
-
+	buf[pos] = '\''
+	pos++
 	return buf[:pos]
 }
 
-// escapeBytesQuotes escapes apostrophes in []byte by doubling them up.
-// This escapes the contents of a string by doubling up any apostrophes that
-// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
-// effect on the server.
-// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
-func escapeBytesQuotes(buf, v []byte) []byte {
+// escapeBytesBackslash appends v to buf as a single-quoted string literal,
+// escaping special bytes with backslashes (see backslashEscapeTable). When
+// binary is true the literal is written as _binary'...'.
+func escapeBytesBackslash(buf, v []byte, binary bool) []byte {
 	pos := len(buf)
-	buf = reserveBuffer(buf, len(v)*2)
+	if binary {
+		buf = reserveBuffer(buf, len(v)*2+9)
+		pos += copy(buf[pos:], "_binary'")
+	} else {
+		buf = reserveBuffer(buf, len(v)*2+2)
+		buf[pos] = '\''
+		pos++
+	}
+	for _, c := range v {
+		if esc := backslashEscapeTable[c]; esc != 0 {
+			buf[pos+1] = esc
+			buf[pos] = '\\'
+			pos += 2
+		} else {
+			buf[pos] = c
+			pos++
+		}
+	}
+	buf[pos] = '\''
+	pos++
+	return buf[:pos]
+}
 
+// escapeBytesQuotes appends v to buf as a single-quoted string literal,
+// escaping apostrophes by doubling them up. This is used when the
+// NO_BACKSLASH_ESCAPES SQL_MODE is in effect on the server. When binary is
+// true the literal is written as _binary'...'.
+// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
+func escapeBytesQuotes(buf, v []byte, binary bool) []byte {
+	pos := len(buf)
+	if binary {
+		buf = reserveBuffer(buf, len(v)*2+9)
+		pos += copy(buf[pos:], "_binary'")
+	} else {
+		buf = reserveBuffer(buf, len(v)*2+2)
+		buf[pos] = '\''
+		pos++
+	}
 	for _, c := range v {
 		if c == '\'' {
 			buf[pos+1] = '\''
@@ -737,15 +715,17 @@
 			pos++
 		}
 	}
-
+	buf[pos] = '\''
+	pos++
 	return buf[:pos]
 }
 
 // escapeStringQuotes is similar to escapeBytesQuotes but for string.
 func escapeStringQuotes(buf []byte, v string) []byte {
 	pos := len(buf)
-	buf = reserveBuffer(buf, len(v)*2)
-
+	buf = reserveBuffer(buf, len(v)*2+2)
+	buf[pos] = '\''
+	pos++
 	for i := range len(v) {
 		c := v[i]
 		if c == '\'' {
@@ -757,7 +737,8 @@
 			pos++
 		}
 	}
-
+	buf[pos] = '\''
+	pos++
 	return buf[:pos]
 }
 
diff --git a/utils_test.go b/utils_test.go
index 42a8839..4c171f6 100644
--- a/utils_test.go
+++ b/utils_test.go
@@ -120,7 +120,7 @@
 
 func TestEscapeBackslash(t *testing.T) {
 	expect := func(expected, value string) {
-		actual := string(escapeBytesBackslash([]byte{}, []byte(value)))
+		actual := string(escapeBytesBackslash([]byte{}, []byte(value), false))
 		if actual != expected {
 			t.Errorf(
 				"expected %s, got %s",
@@ -137,18 +137,36 @@
 		}
 	}
 
-	expect("foo\\0bar", "foo\x00bar")
-	expect("foo\\nbar", "foo\nbar")
-	expect("foo\\rbar", "foo\rbar")
-	expect("foo\\Zbar", "foo\x1abar")
-	expect("foo\\\"bar", "foo\"bar")
-	expect("foo\\\\bar", "foo\\bar")
-	expect("foo\\'bar", "foo'bar")
+	expect("'foo\\0bar'", "foo\x00bar")
+	expect("'foo\\nbar'", "foo\nbar")
+	expect("'foo\\rbar'", "foo\rbar")
+	expect("'foo\\Zbar'", "foo\x1abar")
+	expect("'foo\\\"bar'", "foo\"bar")
+	expect("'foo\\\\bar'", "foo\\bar")
+	expect("'foo\\'bar'", "foo'bar")
+
+	// Test binary flag for escapeBytesBackslash
+	binExpect := func(expected, value string) {
+		actual := string(escapeBytesBackslash([]byte{}, []byte(value), true))
+		if actual != expected {
+			t.Errorf(
+				"expected %s, got %s (binary)",
+				expected, actual,
+			)
+		}
+	}
+	binExpect("_binary'foo\\0bar'", "foo\x00bar")
+	binExpect("_binary'foo\\nbar'", "foo\nbar")
+	binExpect("_binary'foo\\rbar'", "foo\rbar")
+	binExpect("_binary'foo\\Zbar'", "foo\x1abar")
+	binExpect("_binary'foo\\\"bar'", "foo\"bar")
+	binExpect("_binary'foo\\\\bar'", "foo\\bar")
+	binExpect("_binary'foo\\'bar'", "foo'bar")
 }
 
 func TestEscapeQuotes(t *testing.T) {
 	expect := func(expected, value string) {
-		actual := string(escapeBytesQuotes([]byte{}, []byte(value)))
+		actual := string(escapeBytesQuotes([]byte{}, []byte(value), false))
 		if actual != expected {
 			t.Errorf(
 				"expected %s, got %s",
@@ -165,12 +183,29 @@
 		}
 	}
 
-	expect("foo\x00bar", "foo\x00bar") // not affected
-	expect("foo\nbar", "foo\nbar")     // not affected
-	expect("foo\rbar", "foo\rbar")     // not affected
-	expect("foo\x1abar", "foo\x1abar") // not affected
-	expect("foo''bar", "foo'bar")      // affected
-	expect("foo\"bar", "foo\"bar")     // not affected
+	expect("'foo\x00bar'", "foo\x00bar") // not affected
+	expect("'foo\nbar'", "foo\nbar")     // not affected
+	expect("'foo\rbar'", "foo\rbar")     // not affected
+	expect("'foo\x1abar'", "foo\x1abar") // not affected
+	expect("'foo''bar'", "foo'bar")      // affected
+	expect("'foo\"bar'", "foo\"bar")     // not affected
+
+	// Test binary flag for escapeBytesQuotes
+	binExpect := func(expected, value string) {
+		actual := string(escapeBytesQuotes([]byte{}, []byte(value), true))
+		if actual != expected {
+			t.Errorf(
+				"expected %s, got %s (binary)",
+				expected, actual,
+			)
+		}
+	}
+	binExpect("_binary'foo\x00bar'", "foo\x00bar")
+	binExpect("_binary'foo\nbar'", "foo\nbar")
+	binExpect("_binary'foo\rbar'", "foo\rbar")
+	binExpect("_binary'foo\x1abar'", "foo\x1abar")
+	binExpect("_binary'foo''bar'", "foo'bar")
+	binExpect("_binary'foo\"bar'", "foo\"bar")
 }
 
 func TestAtomicError(t *testing.T) {