Enhance interpolateParams to correctly handle placeholders (#1732)
Improve client-side parameter interpolation so that `?` characters inside
comments, string literals, and backtick-quoted identifiers are not treated as placeholders.
diff --git a/connection.go b/connection.go
index b1660a5..ab15668 100644
--- a/connection.go
+++ b/connection.go
@@ -172,7 +172,7 @@
}
// Closes the network connection and unsets internal variables. Do not call this
-// function after successfully authentication, call Close instead. This function
+// function after successful authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
@@ -246,100 +246,184 @@
}
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
- // Number of ? should be same to len(args)
- if strings.Count(query, "?") != len(args) {
- return "", driver.ErrSkip
- }
+ noBackslashEscapes := (mc.status & statusNoBackslashEscapes) != 0
+ const (
+ stateNormal = iota
+ stateString
+ stateEscape
+ stateEOLComment
+ stateSlashStarComment
+ stateBacktick
+ )
+
+ const (
+ QUOTE_BYTE = byte('\'')
+ DBL_QUOTE_BYTE = byte('"')
+ BACKSLASH_BYTE = byte('\\')
+ QUESTION_MARK_BYTE = byte('?')
+ SLASH_BYTE = byte('/')
+ STAR_BYTE = byte('*')
+ HASH_BYTE = byte('#')
+ MINUS_BYTE = byte('-')
+ LINE_FEED_BYTE = byte('\n')
+ BACKTICK_BYTE = byte('`')
+ )
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
- // can not take the buffer. Something must be wrong with the connection
mc.cleanup()
- // interpolateParams would be called before sending any query.
- // So its safe to retry.
return "", driver.ErrBadConn
}
buf = buf[:0]
+ state := stateNormal
+ singleQuotes := false
+ lastChar := byte(0)
argPos := 0
+ lenQuery := len(query)
+ lastIdx := 0
- for i := 0; i < len(query); i++ {
- q := strings.IndexByte(query[i:], '?')
- if q == -1 {
- buf = append(buf, query[i:]...)
- break
- }
- buf = append(buf, query[i:i+q]...)
- i += q
-
- arg := args[argPos]
- argPos++
-
- if arg == nil {
- buf = append(buf, "NULL"...)
+ for i := 0; i < lenQuery; i++ {
+ currentChar := query[i]
+ if state == stateEscape && !((currentChar == QUOTE_BYTE && singleQuotes) || (currentChar == DBL_QUOTE_BYTE && !singleQuotes)) {
+ state = stateString
+ lastChar = currentChar
continue
}
-
- switch v := arg.(type) {
- case int64:
- buf = strconv.AppendInt(buf, v, 10)
- case uint64:
- // Handle uint64 explicitly because our custom ConvertValue emits unsigned values
- buf = strconv.AppendUint(buf, v, 10)
- case float64:
- buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
- case bool:
- if v {
- buf = append(buf, '1')
- } else {
- buf = append(buf, '0')
+ switch currentChar {
+ case STAR_BYTE:
+ if state == stateNormal && lastChar == SLASH_BYTE {
+ state = stateSlashStarComment
}
- case time.Time:
- if v.IsZero() {
- buf = append(buf, "'0000-00-00'"...)
- } else {
- buf = append(buf, '\'')
- buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
- if err != nil {
- return "", err
- }
- buf = append(buf, '\'')
+ case SLASH_BYTE:
+ if state == stateSlashStarComment && lastChar == STAR_BYTE {
+ state = stateNormal
+ // Clear lastChar so the '/' that closed the comment isn't
+ // reused to start a new comment with a following '*'.
+ lastChar = 0
+ continue
}
- case json.RawMessage:
- buf = append(buf, '\'')
- if mc.status&statusNoBackslashEscapes == 0 {
- buf = escapeBytesBackslash(buf, v)
- } else {
- buf = escapeBytesQuotes(buf, v)
+ case HASH_BYTE:
+ if state == stateNormal {
+ state = stateEOLComment
}
- buf = append(buf, '\'')
- case []byte:
- if v == nil {
- buf = append(buf, "NULL"...)
- } else {
- buf = append(buf, "_binary'"...)
- if mc.status&statusNoBackslashEscapes == 0 {
- buf = escapeBytesBackslash(buf, v)
+ case MINUS_BYTE:
+ if state == stateNormal && lastChar == MINUS_BYTE {
+ // "--" starts a comment only when followed by a space, tab, newline, or carriage return, or when it ends the query
+ if i+1 < lenQuery {
+ nextChar := query[i+1]
+ if nextChar == ' ' || nextChar == '\t' || nextChar == '\n' || nextChar == '\r' {
+ state = stateEOLComment
+ }
} else {
- buf = escapeBytesQuotes(buf, v)
+ state = stateEOLComment
}
- buf = append(buf, '\'')
}
- case string:
- buf = append(buf, '\'')
- if mc.status&statusNoBackslashEscapes == 0 {
- buf = escapeStringBackslash(buf, v)
- } else {
- buf = escapeStringQuotes(buf, v)
+ case LINE_FEED_BYTE:
+ if state == stateEOLComment {
+ state = stateNormal
}
- buf = append(buf, '\'')
- default:
- return "", driver.ErrSkip
- }
+ case DBL_QUOTE_BYTE:
+ if state == stateNormal {
+ state = stateString
+ singleQuotes = false
+ } else if state == stateString && !singleQuotes {
+ state = stateNormal
+ } else if state == stateEscape {
+ state = stateString
+ }
+ case QUOTE_BYTE:
+ if state == stateNormal {
+ state = stateString
+ singleQuotes = true
+ } else if state == stateString && singleQuotes {
+ state = stateNormal
+ } else if state == stateEscape {
+ state = stateString
+ }
+ case BACKSLASH_BYTE:
+ if state == stateString && !noBackslashEscapes {
+ state = stateEscape
+ }
+ case QUESTION_MARK_BYTE:
+ if state == stateNormal {
+ if argPos >= len(args) {
+ return "", driver.ErrSkip
+ }
+ buf = append(buf, query[lastIdx:i]...)
+ arg := args[argPos]
+ argPos++
- if len(buf)+4 > mc.maxAllowedPacket {
- return "", driver.ErrSkip
+ if arg == nil {
+ buf = append(buf, "NULL"...)
+ lastIdx = i + 1
+ break
+ }
+
+ switch v := arg.(type) {
+ case int64:
+ buf = strconv.AppendInt(buf, v, 10)
+ case uint64:
+ buf = strconv.AppendUint(buf, v, 10)
+ case float64:
+ buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
+ case bool:
+ if v {
+ buf = append(buf, '1')
+ } else {
+ buf = append(buf, '0')
+ }
+ case time.Time:
+ if v.IsZero() {
+ buf = append(buf, "'0000-00-00'"...)
+ } else {
+ buf = append(buf, '\'')
+ buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
+ if err != nil {
+ return "", err
+ }
+ buf = append(buf, '\'')
+ }
+ case json.RawMessage:
+ if noBackslashEscapes {
+ buf = escapeBytesQuotes(buf, v, false)
+ } else {
+ buf = escapeBytesBackslash(buf, v, false)
+ }
+ case []byte:
+ if v == nil {
+ buf = append(buf, "NULL"...)
+ } else {
+ if noBackslashEscapes {
+ buf = escapeBytesQuotes(buf, v, true)
+ } else {
+ buf = escapeBytesBackslash(buf, v, true)
+ }
+ }
+ case string:
+ if noBackslashEscapes {
+ buf = escapeStringQuotes(buf, v)
+ } else {
+ buf = escapeStringBackslash(buf, v)
+ }
+ default:
+ return "", driver.ErrSkip
+ }
+
+ if len(buf)+4 > mc.maxAllowedPacket {
+ return "", driver.ErrSkip
+ }
+ lastIdx = i + 1
+ }
+ case BACKTICK_BYTE:
+ if state == stateBacktick {
+ state = stateNormal
+ } else if state == stateNormal {
+ state = stateBacktick
+ }
}
+ lastChar = currentChar
}
+ buf = append(buf, query[lastIdx:]...)
if argPos != len(args) {
return "", driver.ErrSkip
}
diff --git a/connection_test.go b/connection_test.go
index d489c1e..2827fd0 100644
--- a/connection_test.go
+++ b/connection_test.go
@@ -80,24 +80,6 @@
}
}
-// We don't support placeholder in string literal for now.
-// https://github.com/go-sql-driver/mysql/pull/490
-func TestInterpolateParamsPlaceholderInString(t *testing.T) {
- mc := &mysqlConn{
- buf: newBuffer(),
- maxAllowedPacket: maxPacketSize,
- cfg: &Config{
- InterpolateParams: true,
- },
- }
-
- q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
- // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
- if err != driver.ErrSkip {
- t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
- }
-}
-
func TestInterpolateParamsUint64(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(),
@@ -206,6 +188,64 @@
return nil
}
+func TestInterpolateParamsWithComments(t *testing.T) {
+ mc := &mysqlConn{
+ buf: newBuffer(),
+ maxAllowedPacket: maxPacketSize,
+ cfg: &Config{
+ InterpolateParams: true,
+ },
+ }
+
+ tests := []struct {
+ query string
+ args []driver.Value
+ expected string
+ shouldSkip bool
+ }{
+ // ? in single-line comment (--) should not be replaced
+ {"SELECT 1 -- ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 -- ?\n, 42", false},
+ // ? in single-line comment (#) should not be replaced
+ {"SELECT 1 # ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 # ?\n, 42", false},
+ // ? in multi-line comment should not be replaced
+ {"SELECT /* ? */ ?", []driver.Value{int64(42)}, "SELECT /* ? */ 42", false},
+ // ? in string literal should not be replaced
+ {"SELECT '?', ?", []driver.Value{int64(42)}, "SELECT '?', 42", false},
+ // ? in backtick identifier should not be replaced
+ {"SELECT `?`, ?", []driver.Value{int64(42)}, "SELECT `?`, 42", false},
+ // ? in backslash-escaped string literal should not be replaced
+ {"SELECT 'C:\\path\\?x.txt', ?", []driver.Value{int64(42)}, "SELECT 'C:\\path\\?x.txt', 42", false},
+ // ? following a backslash-escaped quote inside a string literal should not be replaced
+ {"SELECT '\\'?', col FROM tbl WHERE id = ? AND desc = 'foo\\'bar?'", []driver.Value{int64(42)}, "SELECT '\\'?', col FROM tbl WHERE id = 42 AND desc = 'foo\\'bar?'", false},
+ // Multiple comments and real placeholders: four real placeholders but only three args, so ErrSkip is expected
+ {"SELECT ? -- comment ?\n, ? /* ? */ , ? # ?\n, ?", []driver.Value{int64(1), int64(2), int64(3)}, "SELECT 1 -- comment ?\n, 2 /* ? */ , 3 # ?\n, ?", true},
+ // 2--1: -- followed by digit is NOT a comment (it's the number 2 minus minus 1)
+ {"SELECT ?--1", []driver.Value{int64(2)}, "SELECT 2--1", false},
+ // /* */*: After closing block comment, */* should NOT start a new comment
+ {"SELECT /* comment */* ?, ?", []driver.Value{int64(1), int64(2)}, "SELECT /* comment */* 1, 2", false},
+ // /* */*: More complex case with actual comment after
+ {"SELECT /* c1 */*/* c2 */ ?, ?", []driver.Value{int64(1), int64(2)}, "SELECT /* c1 */*/* c2 */ 1, 2", false},
+ }
+
+ for i, test := range tests {
+
+ q, err := mc.interpolateParams(test.query, test.args)
+ if test.shouldSkip {
+ if err != driver.ErrSkip {
+ t.Errorf("Test %d: Expected driver.ErrSkip, got err=%#v, q=%#v", i, err, q)
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("Test %d: Expected err=nil, got %#v", i, err)
+ continue
+ }
+ if q != test.expected {
+ t.Errorf("Test %d: Expected: %q\nGot: %q", i, test.expected, q)
+ }
+ }
+}
+
// chunkedConn is a net.Conn that serves pre-built data chunks, one per Read
// call. This simulates the behavior seen with TLS connections, where the
// server's TLS library typically produces a separate TLS record per write
diff --git a/utils.go b/utils.go
index b041804..2dccb7d 100644
--- a/utils.go
+++ b/utils.go
@@ -625,108 +625,80 @@
return buf[:newSize]
}
-// escapeBytesBackslash escapes []byte with backslashes (\)
-// This escapes the contents of a string (provided as []byte) by adding backslashes before special
-// characters, and turning others into specific escape sequences, such as
-// turning newlines into \n and null bytes into \0.
-// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
-func escapeBytesBackslash(buf, v []byte) []byte {
- pos := len(buf)
- buf = reserveBuffer(buf, len(v)*2)
+// Lookup table for backslash escapes (used for both string and bytes)
+var backslashEscapeTable [256]byte
- for _, c := range v {
- switch c {
- case '\x00':
- buf[pos+1] = '0'
- buf[pos] = '\\'
- pos += 2
- case '\n':
- buf[pos+1] = 'n'
- buf[pos] = '\\'
- pos += 2
- case '\r':
- buf[pos+1] = 'r'
- buf[pos] = '\\'
- pos += 2
- case '\x1a':
- buf[pos+1] = 'Z'
- buf[pos] = '\\'
- pos += 2
- case '\'':
- buf[pos+1] = '\''
- buf[pos] = '\\'
- pos += 2
- case '"':
- buf[pos+1] = '"'
- buf[pos] = '\\'
- pos += 2
- case '\\':
- buf[pos+1] = '\\'
- buf[pos] = '\\'
- pos += 2
- default:
- buf[pos] = c
- pos++
- }
- }
-
- return buf[:pos]
+func init() {
+ backslashEscapeTable['\x00'] = '0'
+ backslashEscapeTable['\n'] = 'n'
+ backslashEscapeTable['\r'] = 'r'
+ backslashEscapeTable['\x1a'] = 'Z'
+ backslashEscapeTable['\''] = '\''
+ backslashEscapeTable['"'] = '"'
+ backslashEscapeTable['\\'] = '\\'
}
// escapeStringBackslash is similar to escapeBytesBackslash but for string.
func escapeStringBackslash(buf []byte, v string) []byte {
pos := len(buf)
- buf = reserveBuffer(buf, len(v)*2)
-
- for i := range len(v) {
+ buf = reserveBuffer(buf, len(v)*2+2)
+ buf[pos] = '\''
+ pos++
+ for i := 0; i < len(v); i++ {
c := v[i]
- switch c {
- case '\x00':
- buf[pos+1] = '0'
+ if esc := backslashEscapeTable[c]; esc != 0 {
+ buf[pos+1] = esc
buf[pos] = '\\'
pos += 2
- case '\n':
- buf[pos+1] = 'n'
- buf[pos] = '\\'
- pos += 2
- case '\r':
- buf[pos+1] = 'r'
- buf[pos] = '\\'
- pos += 2
- case '\x1a':
- buf[pos+1] = 'Z'
- buf[pos] = '\\'
- pos += 2
- case '\'':
- buf[pos+1] = '\''
- buf[pos] = '\\'
- pos += 2
- case '"':
- buf[pos+1] = '"'
- buf[pos] = '\\'
- pos += 2
- case '\\':
- buf[pos+1] = '\\'
- buf[pos] = '\\'
- pos += 2
- default:
+ } else {
buf[pos] = c
pos++
}
}
-
+ buf[pos] = '\''
+ pos++
return buf[:pos]
}
-// escapeBytesQuotes escapes apostrophes in []byte by doubling them up.
-// This escapes the contents of a string by doubling up any apostrophes that
-// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
-// effect on the server.
-// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
-func escapeBytesQuotes(buf, v []byte) []byte {
+// escapeBytesBackslash appends _binary'...' or '...' with backslash escaping for bytes.
+func escapeBytesBackslash(buf, v []byte, binary bool) []byte {
pos := len(buf)
- buf = reserveBuffer(buf, len(v)*2)
+ if binary {
+ buf = reserveBuffer(buf, len(v)*2+9)
+ copy(buf[pos:], []byte("_binary'"))
+ pos += 8
+ } else {
+ buf = reserveBuffer(buf, len(v)*2+2)
+ buf[pos] = '\''
+ pos++
+ }
+ for _, c := range v {
+ if esc := backslashEscapeTable[c]; esc != 0 {
+ buf[pos+1] = esc
+ buf[pos] = '\\'
+ pos += 2
+ } else {
+ buf[pos] = c
+ pos++
+ }
+ }
+ buf[pos] = '\''
+ pos++
+ return buf[:pos]
+}
+// escapeBytesQuotes appends _binary'...' or '...' with single-quote escaping for bytes.
+func escapeBytesQuotes(buf, v []byte, binary bool) []byte {
+ pos := len(buf)
+ if binary {
+ buf = reserveBuffer(buf, len(v)*2+9)
+ copy(buf[pos:], []byte("_binary'"))
+ pos += 8
+ } else {
+ buf = reserveBuffer(buf, len(v)*2+2)
+ buf[pos] = '\''
+ pos++
+ }
for _, c := range v {
if c == '\'' {
buf[pos+1] = '\''
@@ -737,15 +709,17 @@
pos++
}
}
-
+ buf[pos] = '\''
+ pos++
return buf[:pos]
}
// escapeStringQuotes is similar to escapeBytesQuotes but for string.
func escapeStringQuotes(buf []byte, v string) []byte {
pos := len(buf)
- buf = reserveBuffer(buf, len(v)*2)
-
+ buf = reserveBuffer(buf, len(v)*2+2)
+ buf[pos] = '\''
+ pos++
for i := range len(v) {
c := v[i]
if c == '\'' {
@@ -757,7 +731,8 @@
pos++
}
}
-
+ buf[pos] = '\''
+ pos++
return buf[:pos]
}
diff --git a/utils_test.go b/utils_test.go
index 42a8839..4c171f6 100644
--- a/utils_test.go
+++ b/utils_test.go
@@ -120,7 +120,7 @@
func TestEscapeBackslash(t *testing.T) {
expect := func(expected, value string) {
- actual := string(escapeBytesBackslash([]byte{}, []byte(value)))
+ actual := string(escapeBytesBackslash([]byte{}, []byte(value), false))
if actual != expected {
t.Errorf(
"expected %s, got %s",
@@ -137,18 +137,36 @@
}
}
- expect("foo\\0bar", "foo\x00bar")
- expect("foo\\nbar", "foo\nbar")
- expect("foo\\rbar", "foo\rbar")
- expect("foo\\Zbar", "foo\x1abar")
- expect("foo\\\"bar", "foo\"bar")
- expect("foo\\\\bar", "foo\\bar")
- expect("foo\\'bar", "foo'bar")
+ expect("'foo\\0bar'", "foo\x00bar")
+ expect("'foo\\nbar'", "foo\nbar")
+ expect("'foo\\rbar'", "foo\rbar")
+ expect("'foo\\Zbar'", "foo\x1abar")
+ expect("'foo\\\"bar'", "foo\"bar")
+ expect("'foo\\\\bar'", "foo\\bar")
+ expect("'foo\\'bar'", "foo'bar")
+
+ // Test binary flag for escapeBytesBackslash
+ binExpect := func(expected, value string) {
+ actual := string(escapeBytesBackslash([]byte{}, []byte(value), true))
+ if actual != expected {
+ t.Errorf(
+ "expected %s, got %s (binary)",
+ expected, actual,
+ )
+ }
+ }
+ binExpect("_binary'foo\\0bar'", "foo\x00bar")
+ binExpect("_binary'foo\\nbar'", "foo\nbar")
+ binExpect("_binary'foo\\rbar'", "foo\rbar")
+ binExpect("_binary'foo\\Zbar'", "foo\x1abar")
+ binExpect("_binary'foo\\\"bar'", "foo\"bar")
+ binExpect("_binary'foo\\\\bar'", "foo\\bar")
+ binExpect("_binary'foo\\'bar'", "foo'bar")
}
func TestEscapeQuotes(t *testing.T) {
expect := func(expected, value string) {
- actual := string(escapeBytesQuotes([]byte{}, []byte(value)))
+ actual := string(escapeBytesQuotes([]byte{}, []byte(value), false))
if actual != expected {
t.Errorf(
"expected %s, got %s",
@@ -165,12 +183,29 @@
}
}
- expect("foo\x00bar", "foo\x00bar") // not affected
- expect("foo\nbar", "foo\nbar") // not affected
- expect("foo\rbar", "foo\rbar") // not affected
- expect("foo\x1abar", "foo\x1abar") // not affected
- expect("foo''bar", "foo'bar") // affected
- expect("foo\"bar", "foo\"bar") // not affected
+ expect("'foo\x00bar'", "foo\x00bar") // not affected
+ expect("'foo\nbar'", "foo\nbar") // not affected
+ expect("'foo\rbar'", "foo\rbar") // not affected
+ expect("'foo\x1abar'", "foo\x1abar") // not affected
+ expect("'foo''bar'", "foo'bar") // affected
+ expect("'foo\"bar'", "foo\"bar") // not affected
+
+ // Test binary flag for escapeBytesQuotes
+ binExpect := func(expected, value string) {
+ actual := string(escapeBytesQuotes([]byte{}, []byte(value), true))
+ if actual != expected {
+ t.Errorf(
+ "expected %s, got %s (binary)",
+ expected, actual,
+ )
+ }
+ }
+ binExpect("_binary'foo\x00bar'", "foo\x00bar")
+ binExpect("_binary'foo\nbar'", "foo\nbar")
+ binExpect("_binary'foo\rbar'", "foo\rbar")
+ binExpect("_binary'foo\x1abar'", "foo\x1abar")
+ binExpect("_binary'foo''bar'", "foo'bar")
+ binExpect("_binary'foo\"bar'", "foo\"bar")
}
func TestAtomicError(t *testing.T) {