From 66edfad9d2c331e4a463251262604d2b5642bb53 Mon Sep 17 00:00:00 2001 From: Diego Dupin Date: Sun, 6 Jul 2025 21:58:55 +0200 Subject: [PATCH] nhance `interpolateParams` to correctly handle placeholders in queries with comments, strings, and backticks. * Add `findParamPositions` to identify real parameter positions * Update and expand related tests. --- connection.go | 226 ++++++++++++++++++++++++++++++--------------- connection_test.go | 70 ++++++++++---- utils.go | 149 +++++++++++++----------------- utils_test.go | 65 ++++++++++--- 4 files changed, 313 insertions(+), 197 deletions(-) diff --git a/connection.go b/connection.go index 5648e47d..94d54634 100644 --- a/connection.go +++ b/connection.go @@ -172,7 +172,7 @@ func (mc *mysqlConn) close() { } // Closes the network connection and unsets internal variables. Do not call this -// function after successfully authentication, call Close instead. This function +// function after successful authentication, call Close instead. This function // is called before auth or on auth failure because MySQL will have already // closed the network connection. func (mc *mysqlConn) cleanup() { @@ -246,100 +246,172 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { } func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { - // Number of ? should be same to len(args) - if strings.Count(query, "?") != len(args) { - return "", driver.ErrSkip - } + noBackslashEscapes := (mc.status & statusNoBackslashEscapes) != 0 + const ( + stateNormal = iota + stateString + stateEscape + stateEOLComment + stateSlashStarComment + stateBacktick + ) + + var ( + QUOTE_BYTE = byte('\'') + DBL_QUOTE_BYTE = byte('"') + BACKSLASH_BYTE = byte('\\') + QUESTION_MARK_BYTE = byte('?') + SLASH_BYTE = byte('/') + STAR_BYTE = byte('*') + HASH_BYTE = byte('#') + MINUS_BYTE = byte('-') + LINE_FEED_BYTE = byte('\n') + RADICAL_BYTE = byte('`') + ) buf, err := mc.buf.takeCompleteBuffer() if err != nil { - // can not take the buffer. Something must be wrong with the connection mc.cleanup() - // interpolateParams would be called before sending any query. - // So its safe to retry. return "", driver.ErrBadConn } buf = buf[:0] + state := stateNormal + singleQuotes := false + lastChar := byte(0) argPos := 0 - - for i := 0; i < len(query); i++ { - q := strings.IndexByte(query[i:], '?') - if q == -1 { - buf = append(buf, query[i:]...) - break - } - buf = append(buf, query[i:i+q]...) - i += q - - arg := args[argPos] - argPos++ - - if arg == nil { - buf = append(buf, "NULL"...) + lenQuery := len(query) + lastIdx := 0 + + for i := 0; i < lenQuery; i++ { + currentChar := query[i] + if state == stateEscape && !((currentChar == QUOTE_BYTE && singleQuotes) || (currentChar == DBL_QUOTE_BYTE && !singleQuotes)) { + state = stateString + lastChar = currentChar continue } - - switch v := arg.(type) { - case int64: - buf = strconv.AppendInt(buf, v, 10) - case uint64: - // Handle uint64 explicitly because our custom ConvertValue emits unsigned values - buf = strconv.AppendUint(buf, v, 10) - case float64: - buf = strconv.AppendFloat(buf, v, 'g', -1, 64) - case bool: - if v { - buf = append(buf, '1') - } else { - buf = append(buf, '0') + switch currentChar { + case STAR_BYTE: + if state == stateNormal && lastChar == SLASH_BYTE { + state = stateSlashStarComment } - case time.Time: - if v.IsZero() { - buf = append(buf, "'0000-00-00'"...) - } else { - buf = append(buf, '\'') - buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate) - if err != nil { - return "", err - } - buf = append(buf, '\'') + case SLASH_BYTE: + if state == stateSlashStarComment && lastChar == STAR_BYTE { + state = stateNormal } - case json.RawMessage: - buf = append(buf, '\'') - if mc.status&statusNoBackslashEscapes == 0 { - buf = escapeBytesBackslash(buf, v) - } else { - buf = escapeBytesQuotes(buf, v) + case HASH_BYTE: + if state == stateNormal { + state = stateEOLComment } - buf = append(buf, '\'') - case []byte: - if v == nil { - buf = append(buf, "NULL"...) - } else { - buf = append(buf, "_binary'"...) - if mc.status&statusNoBackslashEscapes == 0 { - buf = escapeBytesBackslash(buf, v) - } else { - buf = escapeBytesQuotes(buf, v) - } - buf = append(buf, '\'') + case MINUS_BYTE: + if state == stateNormal && lastChar == MINUS_BYTE { + state = stateEOLComment } - case string: - buf = append(buf, '\'') - if mc.status&statusNoBackslashEscapes == 0 { - buf = escapeStringBackslash(buf, v) - } else { - buf = escapeStringQuotes(buf, v) + case LINE_FEED_BYTE: + if state == stateEOLComment { + state = stateNormal } - buf = append(buf, '\'') - default: - return "", driver.ErrSkip - } + case DBL_QUOTE_BYTE: + if state == stateNormal { + state = stateString + singleQuotes = false + } else if state == stateString && !singleQuotes { + state = stateNormal + } else if state == stateEscape { + state = stateString + } + case QUOTE_BYTE: + if state == stateNormal { + state = stateString + singleQuotes = true + } else if state == stateString && singleQuotes { + state = stateNormal + } else if state == stateEscape { + state = stateString + } + case BACKSLASH_BYTE: + if state == stateString && !noBackslashEscapes { + state = stateEscape + } + case QUESTION_MARK_BYTE: + if state == stateNormal { + if argPos >= len(args) { + return "", driver.ErrSkip + } + buf = append(buf, query[lastIdx:i]...) + arg := args[argPos] + argPos++ + + if arg == nil { + buf = append(buf, "NULL"...) + lastIdx = i + 1 + break + } + + switch v := arg.(type) { + case int64: + buf = strconv.AppendInt(buf, v, 10) + case uint64: + buf = strconv.AppendUint(buf, v, 10) + case float64: + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + if v { + buf = append(buf, '1') + } else { + buf = append(buf, '0') + } + case time.Time: + if v.IsZero() { + buf = append(buf, "'0000-00-00'"...) + } else { + buf = append(buf, '\'') + buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate) + if err != nil { + return "", err + } + buf = append(buf, '\'') + } + case json.RawMessage: + if noBackslashEscapes { + buf = escapeBytesQuotes(buf, v, false) + } else { + buf = escapeBytesBackslash(buf, v, false) + } + case []byte: + if v == nil { + buf = append(buf, "NULL"...) + } else { + if noBackslashEscapes { + buf = escapeBytesQuotes(buf, v, true) + } else { + buf = escapeBytesBackslash(buf, v, true) + } + } + case string: + if noBackslashEscapes { + buf = escapeStringQuotes(buf, v) + } else { + buf = escapeStringBackslash(buf, v) + } + default: + return "", driver.ErrSkip + } - if len(buf)+4 > mc.maxAllowedPacket { - return "", driver.ErrSkip + if len(buf)+4 > mc.maxAllowedPacket { + return "", driver.ErrSkip + } + lastIdx = i + 1 + } + case RADICAL_BYTE: + if state == stateBacktick { + state = stateNormal + } else if state == stateNormal { + state = stateBacktick + } } + lastChar = currentChar } + buf = append(buf, query[lastIdx:]...) if argPos != len(args) { return "", driver.ErrSkip } diff --git a/connection_test.go b/connection_test.go index 440ecbff..3b8c43bf 100644 --- a/connection_test.go +++ b/connection_test.go @@ -79,24 +79,6 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { } } -// We don't support placeholder in string literal for now. -// https://github.com/go-sql-driver/mysql/pull/490 -func TestInterpolateParamsPlaceholderInString(t *testing.T) { - mc := &mysqlConn{ - buf: newBuffer(), - maxAllowedPacket: maxPacketSize, - cfg: &Config{ - InterpolateParams: true, - }, - } - - q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) - // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` - if err != driver.ErrSkip { - t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q) - } -} - func TestInterpolateParamsUint64(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(), @@ -204,3 +186,55 @@ func (bc badConnection) Write(b []byte) (n int, err error) { func (bc badConnection) Close() error { return nil } + +func TestInterpolateParamsWithComments(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + tests := []struct { + query string + args []driver.Value + expected string + shouldSkip bool + }{ + // ? in single-line comment (--) should not be replaced + {"SELECT 1 -- ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 -- ?\n, 42", false}, + // ? in single-line comment (#) should not be replaced + {"SELECT 1 # ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 # ?\n, 42", false}, + // ? in multi-line comment should not be replaced + {"SELECT /* ? */ ?", []driver.Value{int64(42)}, "SELECT /* ? */ 42", false}, + // ? in string literal should not be replaced + {"SELECT '?', ?", []driver.Value{int64(42)}, "SELECT '?', 42", false}, + // ? in backtick identifier should not be replaced + {"SELECT `?`, ?", []driver.Value{int64(42)}, "SELECT `?`, 42", false}, + // ? in backslash-escaped string literal should not be replaced + {"SELECT 'C:\\path\\?x.txt', ?", []driver.Value{int64(42)}, "SELECT 'C:\\path\\?x.txt', 42", false}, + // ? in backslash-escaped string literal should not be replaced + {"SELECT '\\'?', col FROM tbl WHERE id = ? AND desc = 'foo\\'bar?'", []driver.Value{int64(42)}, "SELECT '\\'?', col FROM tbl WHERE id = 42 AND desc = 'foo\\'bar?'", false}, + // Multiple comments and real placeholders + {"SELECT ? -- comment ?\n, ? /* ? */ , ? # ?\n, ?", []driver.Value{int64(1), int64(2), int64(3)}, "SELECT 1 -- comment ?\n, 2 /* ? */ , 3 # ?\n, ?", true}, + } + + for i, test := range tests { + + q, err := mc.interpolateParams(test.query, test.args) + if test.shouldSkip { + if err != driver.ErrSkip { + t.Errorf("Test %d: Expected driver.ErrSkip, got err=%#v, q=%#v", i, err, q) + } + continue + } + if err != nil { + t.Errorf("Test %d: Expected err=nil, got %#v", i, err) + continue + } + if q != test.expected { + t.Errorf("Test %d: Expected: %q\nGot: %q", i, test.expected, q) + } + } +} diff --git a/utils.go b/utils.go index b041804d..18d8519b 100644 --- a/utils.go +++ b/utils.go @@ -625,139 +625,114 @@ func reserveBuffer(buf []byte, appendSize int) []byte { return buf[:newSize] } -// escapeBytesBackslash escapes []byte with backslashes (\) -// This escapes the contents of a string (provided as []byte) by adding backslashes before special -// characters, and turning others into specific escape sequences, such as -// turning newlines into \n and null bytes into \0. -// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932 -func escapeBytesBackslash(buf, v []byte) []byte { - pos := len(buf) - buf = reserveBuffer(buf, len(v)*2) +// Lookup table for backslash escapes (used for both string and bytes) +var backslashEscapeTable [256]byte - for _, c := range v { - switch c { - case '\x00': - buf[pos+1] = '0' - buf[pos] = '\\' - pos += 2 - case '\n': - buf[pos+1] = 'n' - buf[pos] = '\\' - pos += 2 - case '\r': - buf[pos+1] = 'r' - buf[pos] = '\\' - pos += 2 - case '\x1a': - buf[pos+1] = 'Z' - buf[pos] = '\\' - pos += 2 - case '\'': - buf[pos+1] = '\'' - buf[pos] = '\\' - pos += 2 - case '"': - buf[pos+1] = '"' - buf[pos] = '\\' - pos += 2 - case '\\': - buf[pos+1] = '\\' +func init() { + backslashEscapeTable['\x00'] = '0' + backslashEscapeTable['\n'] = 'n' + backslashEscapeTable['\r'] = 'r' + backslashEscapeTable['\x1a'] = 'Z' + backslashEscapeTable['\''] = '\'' + backslashEscapeTable['"'] = '"' + backslashEscapeTable['\\'] = '\\' +} + +// escapeStringBackslash is similar to escapeBytesBackslash but for string. +func escapeStringBackslash(buf []byte, v string) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2+2) + buf[pos] = '\'' + pos++ + for i := 0; i < len(v); i++ { + c := v[i] + if esc := backslashEscapeTable[c]; esc != 0 { buf[pos] = '\\' + buf[pos+1] = esc pos += 2 - default: + } else { buf[pos] = c pos++ } } - + buf[pos] = '\'' + pos++ return buf[:pos] } -// escapeStringBackslash is similar to escapeBytesBackslash but for string. -func escapeStringBackslash(buf []byte, v string) []byte { +// escapeBytesBackslash appends _binary'...' or '...' with backslash escaping for bytes. +func escapeBytesBackslash(buf, v []byte, binary bool) []byte { pos := len(buf) - buf = reserveBuffer(buf, len(v)*2) - - for i := range len(v) { - c := v[i] - switch c { - case '\x00': - buf[pos+1] = '0' - buf[pos] = '\\' - pos += 2 - case '\n': - buf[pos+1] = 'n' - buf[pos] = '\\' - pos += 2 - case '\r': - buf[pos+1] = 'r' - buf[pos] = '\\' - pos += 2 - case '\x1a': - buf[pos+1] = 'Z' - buf[pos] = '\\' - pos += 2 - case '\'': - buf[pos+1] = '\'' - buf[pos] = '\\' - pos += 2 - case '"': - buf[pos+1] = '"' - buf[pos] = '\\' - pos += 2 - case '\\': - buf[pos+1] = '\\' + if binary { + buf = reserveBuffer(buf, len(v)*2+9) + copy(buf[pos:], []byte("_binary'")) + pos += 8 + } else { + buf = reserveBuffer(buf, len(v)*2+2) + buf[pos] = '\'' + pos++ + } + for _, c := range v { + if esc := backslashEscapeTable[c]; esc != 0 { buf[pos] = '\\' + buf[pos+1] = esc pos += 2 - default: + } else { buf[pos] = c pos++ } } - + buf[pos] = '\'' + pos++ return buf[:pos] } -// escapeBytesQuotes escapes apostrophes in []byte by doubling them up. -// This escapes the contents of a string by doubling up any apostrophes that -// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in -// effect on the server. -// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038 -func escapeBytesQuotes(buf, v []byte) []byte { +// escapeBytesQuotes appends _binary'...' or '...' with single-quote escaping for bytes. +func escapeBytesQuotes(buf, v []byte, binary bool) []byte { pos := len(buf) - buf = reserveBuffer(buf, len(v)*2) - + if binary { + buf = reserveBuffer(buf, len(v)*2+9) + copy(buf[pos:], []byte("_binary'")) + pos += 8 + } else { + buf = reserveBuffer(buf, len(v)*2+2) + buf[pos] = '\'' + pos++ + } for _, c := range v { if c == '\'' { - buf[pos+1] = '\'' buf[pos] = '\'' + buf[pos+1] = '\'' pos += 2 } else { buf[pos] = c pos++ } } - + buf[pos] = '\'' + pos++ return buf[:pos] } // escapeStringQuotes is similar to escapeBytesQuotes but for string. func escapeStringQuotes(buf []byte, v string) []byte { pos := len(buf) - buf = reserveBuffer(buf, len(v)*2) - + buf = reserveBuffer(buf, len(v)*2+2) + buf[pos] = '\'' + pos++ for i := range len(v) { c := v[i] if c == '\'' { - buf[pos+1] = '\'' buf[pos] = '\'' + buf[pos+1] = '\'' pos += 2 } else { buf[pos] = c pos++ } } - + buf[pos] = '\'' + pos++ return buf[:pos] } diff --git a/utils_test.go b/utils_test.go index 42a88393..4c171f62 100644 --- a/utils_test.go +++ b/utils_test.go @@ -120,7 +120,7 @@ func TestFormatBinaryTime(t *testing.T) { func TestEscapeBackslash(t *testing.T) { expect := func(expected, value string) { - actual := string(escapeBytesBackslash([]byte{}, []byte(value))) + actual := string(escapeBytesBackslash([]byte{}, []byte(value), false)) if actual != expected { t.Errorf( "expected %s, got %s", @@ -137,18 +137,36 @@ func TestEscapeBackslash(t *testing.T) { } } - expect("foo\\0bar", "foo\x00bar") - expect("foo\\nbar", "foo\nbar") - expect("foo\\rbar", "foo\rbar") - expect("foo\\Zbar", "foo\x1abar") - expect("foo\\\"bar", "foo\"bar") - expect("foo\\\\bar", "foo\\bar") - expect("foo\\'bar", "foo'bar") + expect("'foo\\0bar'", "foo\x00bar") + expect("'foo\\nbar'", "foo\nbar") + expect("'foo\\rbar'", "foo\rbar") + expect("'foo\\Zbar'", "foo\x1abar") + expect("'foo\\\"bar'", "foo\"bar") + expect("'foo\\\\bar'", "foo\\bar") + expect("'foo\\'bar'", "foo'bar") + + // Test binary flag for escapeBytesBackslash + binExpect := func(expected, value string) { + actual := string(escapeBytesBackslash([]byte{}, []byte(value), true)) + if actual != expected { + t.Errorf( + "expected %s, got %s (binary)", + expected, actual, + ) + } + } + binExpect("_binary'foo\\0bar'", "foo\x00bar") + binExpect("_binary'foo\\nbar'", "foo\nbar") + binExpect("_binary'foo\\rbar'", "foo\rbar") + binExpect("_binary'foo\\Zbar'", "foo\x1abar") + binExpect("_binary'foo\\\"bar'", "foo\"bar") + binExpect("_binary'foo\\\\bar'", "foo\\bar") + binExpect("_binary'foo\\'bar'", "foo'bar") } func TestEscapeQuotes(t *testing.T) { expect := func(expected, value string) { - actual := string(escapeBytesQuotes([]byte{}, []byte(value))) + actual := string(escapeBytesQuotes([]byte{}, []byte(value), false)) if actual != expected { t.Errorf( "expected %s, got %s", @@ -165,12 +183,29 @@ func TestEscapeQuotes(t *testing.T) { } } - expect("foo\x00bar", "foo\x00bar") // not affected - expect("foo\nbar", "foo\nbar") // not affected - expect("foo\rbar", "foo\rbar") // not affected - expect("foo\x1abar", "foo\x1abar") // not affected - expect("foo''bar", "foo'bar") // affected - expect("foo\"bar", "foo\"bar") // not affected + expect("'foo\x00bar'", "foo\x00bar") // not affected + expect("'foo\nbar'", "foo\nbar") // not affected + expect("'foo\rbar'", "foo\rbar") // not affected + expect("'foo\x1abar'", "foo\x1abar") // not affected + expect("'foo''bar'", "foo'bar") // affected + expect("'foo\"bar'", "foo\"bar") // not affected + + // Test binary flag for escapeBytesQuotes + binExpect := func(expected, value string) { + actual := string(escapeBytesQuotes([]byte{}, []byte(value), true)) + if actual != expected { + t.Errorf( + "expected %s, got %s (binary)", + expected, actual, + ) + } + } + binExpect("_binary'foo\x00bar'", "foo\x00bar") + binExpect("_binary'foo\nbar'", "foo\nbar") + binExpect("_binary'foo\rbar'", "foo\rbar") + binExpect("_binary'foo\x1abar'", "foo\x1abar") + binExpect("_binary'foo''bar'", "foo'bar") + binExpect("_binary'foo\"bar'", "foo\"bar") } func TestAtomicError(t *testing.T) {