Skip to content

Enhance interpolateParams to correctly handle placeholders #1732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 149 additions & 77 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
}

// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// function after successful authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
Expand Down Expand Up @@ -246,100 +246,172 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
}

func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
// Number of ? should be same to len(args)
if strings.Count(query, "?") != len(args) {
return "", driver.ErrSkip
}
noBackslashEscapes := (mc.status & statusNoBackslashEscapes) != 0
const (
stateNormal = iota
stateString
stateEscape
stateEOLComment
stateSlashStarComment
stateBacktick
)

var (
QUOTE_BYTE = byte('\'')
DBL_QUOTE_BYTE = byte('"')
BACKSLASH_BYTE = byte('\\')
QUESTION_MARK_BYTE = byte('?')
SLASH_BYTE = byte('/')
STAR_BYTE = byte('*')
HASH_BYTE = byte('#')
MINUS_BYTE = byte('-')
LINE_FEED_BYTE = byte('\n')
RADICAL_BYTE = byte('`')
)

buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
// can not take the buffer. Something must be wrong with the connection
mc.cleanup()
// interpolateParams would be called before sending any query.
// So its safe to retry.
return "", driver.ErrBadConn
}
buf = buf[:0]
state := stateNormal
singleQuotes := false
lastChar := byte(0)
argPos := 0

for i := 0; i < len(query); i++ {
q := strings.IndexByte(query[i:], '?')
if q == -1 {
buf = append(buf, query[i:]...)
break
}
buf = append(buf, query[i:i+q]...)
i += q

arg := args[argPos]
argPos++

if arg == nil {
buf = append(buf, "NULL"...)
lenQuery := len(query)
lastIdx := 0

for i := 0; i < lenQuery; i++ {
currentChar := query[i]
if state == stateEscape && !((currentChar == QUOTE_BYTE && singleQuotes) || (currentChar == DBL_QUOTE_BYTE && !singleQuotes)) {
state = stateString
lastChar = currentChar
continue
}

switch v := arg.(type) {
case int64:
buf = strconv.AppendInt(buf, v, 10)
case uint64:
// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
buf = strconv.AppendUint(buf, v, 10)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
switch currentChar {
case STAR_BYTE:
if state == stateNormal && lastChar == SLASH_BYTE {
state = stateSlashStarComment
}
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
buf = append(buf, '\'')
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return "", err
}
buf = append(buf, '\'')
case SLASH_BYTE:
if state == stateSlashStarComment && lastChar == STAR_BYTE {
state = stateNormal
}
case json.RawMessage:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
} else {
buf = escapeBytesQuotes(buf, v)
case HASH_BYTE:
if state == stateNormal {
state = stateEOLComment
}
buf = append(buf, '\'')
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
} else {
buf = append(buf, "_binary'"...)
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
} else {
buf = escapeBytesQuotes(buf, v)
}
buf = append(buf, '\'')
case MINUS_BYTE:
if state == stateNormal && lastChar == MINUS_BYTE {
state = stateEOLComment
}
case string:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeStringBackslash(buf, v)
} else {
buf = escapeStringQuotes(buf, v)
case LINE_FEED_BYTE:
if state == stateEOLComment {
state = stateNormal
}
buf = append(buf, '\'')
default:
return "", driver.ErrSkip
}
case DBL_QUOTE_BYTE:
if state == stateNormal {
state = stateString
singleQuotes = false
} else if state == stateString && !singleQuotes {
state = stateNormal
} else if state == stateEscape {
state = stateString
}
case QUOTE_BYTE:
if state == stateNormal {
state = stateString
singleQuotes = true
} else if state == stateString && singleQuotes {
state = stateNormal
} else if state == stateEscape {
state = stateString
}
case BACKSLASH_BYTE:
if state == stateString && !noBackslashEscapes {
state = stateEscape
}
case QUESTION_MARK_BYTE:
if state == stateNormal {
if argPos >= len(args) {
return "", driver.ErrSkip
}
buf = append(buf, query[lastIdx:i]...)
arg := args[argPos]
argPos++

if arg == nil {
buf = append(buf, "NULL"...)
lastIdx = i + 1
break
}

switch v := arg.(type) {
case int64:
buf = strconv.AppendInt(buf, v, 10)
case uint64:
buf = strconv.AppendUint(buf, v, 10)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
}
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
buf = append(buf, '\'')
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return "", err
}
buf = append(buf, '\'')
}
case json.RawMessage:
if noBackslashEscapes {
buf = escapeBytesQuotes(buf, v, false)
} else {
buf = escapeBytesBackslash(buf, v, false)
}
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
} else {
if noBackslashEscapes {
buf = escapeBytesQuotes(buf, v, true)
} else {
buf = escapeBytesBackslash(buf, v, true)
}
}
case string:
if noBackslashEscapes {
buf = escapeStringQuotes(buf, v)
} else {
buf = escapeStringBackslash(buf, v)
}
default:
return "", driver.ErrSkip
}

if len(buf)+4 > mc.maxAllowedPacket {
return "", driver.ErrSkip
if len(buf)+4 > mc.maxAllowedPacket {
return "", driver.ErrSkip
}
lastIdx = i + 1
}
case RADICAL_BYTE:
if state == stateBacktick {
state = stateNormal
} else if state == stateNormal {
state = stateBacktick
}
}
lastChar = currentChar
}
buf = append(buf, query[lastIdx:]...)
if argPos != len(args) {
return "", driver.ErrSkip
}
Expand Down
70 changes: 52 additions & 18 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,24 +79,6 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
}
}

// We don't support placeholder in string literal for now.
// https://github.com/go-sql-driver/mysql/pull/490
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}

q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
if err != driver.ErrSkip {
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
}
}

func TestInterpolateParamsUint64(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(),
Expand Down Expand Up @@ -204,3 +186,55 @@ func (bc badConnection) Write(b []byte) (n int, err error) {
func (bc badConnection) Close() error {
return nil
}

func TestInterpolateParamsWithComments(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}

tests := []struct {
query string
args []driver.Value
expected string
shouldSkip bool
}{
// ? in single-line comment (--) should not be replaced
{"SELECT 1 -- ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 -- ?\n, 42", false},
// ? in single-line comment (#) should not be replaced
{"SELECT 1 # ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 # ?\n, 42", false},
// ? in multi-line comment should not be replaced
{"SELECT /* ? */ ?", []driver.Value{int64(42)}, "SELECT /* ? */ 42", false},
// ? in string literal should not be replaced
{"SELECT '?', ?", []driver.Value{int64(42)}, "SELECT '?', 42", false},
// ? in backtick identifier should not be replaced
{"SELECT `?`, ?", []driver.Value{int64(42)}, "SELECT `?`, 42", false},
// ? in backslash-escaped string literal should not be replaced
{"SELECT 'C:\\path\\?x.txt', ?", []driver.Value{int64(42)}, "SELECT 'C:\\path\\?x.txt', 42", false},
// ? in backslash-escaped string literal should not be replaced
{"SELECT '\\'?', col FROM tbl WHERE id = ? AND desc = 'foo\\'bar?'", []driver.Value{int64(42)}, "SELECT '\\'?', col FROM tbl WHERE id = 42 AND desc = 'foo\\'bar?'", false},
// Multiple comments and real placeholders
{"SELECT ? -- comment ?\n, ? /* ? */ , ? # ?\n, ?", []driver.Value{int64(1), int64(2), int64(3)}, "SELECT 1 -- comment ?\n, 2 /* ? */ , 3 # ?\n, ?", true},
}

for i, test := range tests {

q, err := mc.interpolateParams(test.query, test.args)
if test.shouldSkip {
if err != driver.ErrSkip {
t.Errorf("Test %d: Expected driver.ErrSkip, got err=%#v, q=%#v", i, err, q)
}
continue
}
if err != nil {
t.Errorf("Test %d: Expected err=nil, got %#v", i, err)
continue
}
if q != test.expected {
t.Errorf("Test %d: Expected: %q\nGot: %q", i, test.expected, q)
}
}
}
Loading