Skip to content

Commit ee48e78

Browse files
dvilaverdemefcorvi
andauthored
allow disabling the default golang database/sql retry behavior (#899)
* allow disabling the default golang database retry behavior * fixing comment * fixing comment * fix(canal): handle fake rotate events correctly for MariaDB 11.4 (#894) After upgrading to MariaDB 11.4, the canal module stopped detecting row updates within transactions due to incorrect handling of fake rotate events. MariaDB 11.4 does not set LogPos for certain events, causing these events to be ignored. This fix modifies the handling to consider fake rotate events only for ROTATE_EVENTs with timestamp = 0, aligning with MariaDB and MySQL documentation. * incorporating PR feedback --------- Co-authored-by: Bulat Aikaev <[email protected]>
1 parent 3b1dd0e commit ee48e78

File tree

3 files changed

+115
-9
lines changed

3 files changed

+115
-9
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,18 @@ golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format.
426426
| --------- | --------- | ----------------------------------------------- |
427427
| duration | 0 | user:pass@localhost/mydb?writeTimeout=1m30s |
428428

429+
#### `retries`
430+
431+
Allows disabling the golang `database/sql` default behavior to retry errors
432+
when `ErrBadConn` is returned by the driver. When retries are disabled
433+
this driver will not return `ErrBadConn` from the `database/sql` package.
434+
435+
Valid values are `on` (default) and `off`.
436+
437+
| Type | Default | Example |
438+
| --------- | --------- | ----------------------------------------------- |
439+
| string | on | user:pass@localhost/mydb?retries=off |
440+
429441
### Custom Driver Options
430442

431443
The driver package exposes the function `SetDSNOptions`, allowing for modification of the

driver/driver.go

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,11 @@ func parseDSN(dsn string) (connInfo, error) {
9797
// Open takes a supplied DSN string and opens a connection
9898
// See ParseDSN for more information on the form of the DSN
9999
func (d driver) Open(dsn string) (sqldriver.Conn, error) {
100-
var c *client.Conn
100+
var (
101+
c *client.Conn
102+
// by default database/sql driver retries will be enabled
103+
retries = true
104+
)
101105

102106
ci, err := parseDSN(dsn)
103107

@@ -134,6 +138,10 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
134138
if timeout, err = time.ParseDuration(value[0]); err != nil {
135139
return nil, errors.Wrap(err, "invalid duration value for timeout option")
136140
}
141+
} else if key == "retries" && len(value) > 0 {
142+
// by default keep the golang database/sql retry behavior enabled unless
143+
// the retries driver option is explicitly set to 'off'
144+
retries = !strings.EqualFold(value[0], "off")
137145
} else {
138146
if option, ok := options[key]; ok {
139147
opt := func(o DriverOption, v string) client.Option {
@@ -161,15 +169,28 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
161169
return nil, err
162170
}
163171

164-
return &conn{c}, nil
172+
// if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3
173+
// retries by the database/sql package. If retries are 'off' then we'll return
174+
// the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry.
175+
// In this case the sqldriver.Validator interface is implemented and will return
176+
// false for IsValid() signaling the connection is bad and should be discarded.
177+
return &conn{Conn: c, state: &state{valid: true, useStdLibErrors: retries}}, nil
165178
}
166179

167180
type CheckNamedValueFunc func(*sqldriver.NamedValue) error
168181

169182
var _ sqldriver.NamedValueChecker = &conn{}
183+
var _ sqldriver.Validator = &conn{}
184+
185+
type state struct {
186+
valid bool
187+
// when true, the driver connection will return ErrBadConn from the golang Standard Library
188+
useStdLibErrors bool
189+
}
170190

171191
type conn struct {
172192
*client.Conn
193+
state *state
173194
}
174195

175196
func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error {
@@ -190,13 +211,17 @@ func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error {
190211
return sqldriver.ErrSkip
191212
}
192213

214+
func (c *conn) IsValid() bool {
215+
return c.state.valid
216+
}
217+
193218
func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
194219
st, err := c.Conn.Prepare(query)
195220
if err != nil {
196221
return nil, errors.Trace(err)
197222
}
198223

199-
return &stmt{st}, nil
224+
return &stmt{Stmt: st, connectionState: c.state}, nil
200225
}
201226

202227
func (c *conn) Close() error {
@@ -222,10 +247,16 @@ func buildArgs(args []sqldriver.Value) []interface{} {
222247
return a
223248
}
224249

225-
func replyError(err error) error {
226-
if mysql.ErrorEqual(err, mysql.ErrBadConn) {
250+
func (st *state) replyError(err error) error {
251+
isBadConnection := mysql.ErrorEqual(err, mysql.ErrBadConn)
252+
253+
if st.useStdLibErrors && isBadConnection {
227254
return sqldriver.ErrBadConn
228255
} else {
256+
// if we have a bad connection, this mark the state of this connection as not valid
257+
// do the database/sql package can discard it instead of placing it back in the
258+
// sql.DB pool.
259+
st.valid = !isBadConnection
229260
return errors.Trace(err)
230261
}
231262
}
@@ -234,7 +265,7 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err
234265
a := buildArgs(args)
235266
r, err := c.Conn.Execute(query, a...)
236267
if err != nil {
237-
return nil, replyError(err)
268+
return nil, c.state.replyError(err)
238269
}
239270
return &result{r}, nil
240271
}
@@ -243,13 +274,14 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro
243274
a := buildArgs(args)
244275
r, err := c.Conn.Execute(query, a...)
245276
if err != nil {
246-
return nil, replyError(err)
277+
return nil, c.state.replyError(err)
247278
}
248279
return newRows(r.Resultset)
249280
}
250281

251282
type stmt struct {
252283
*client.Stmt
284+
connectionState *state
253285
}
254286

255287
func (s *stmt) Close() error {
@@ -264,7 +296,7 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
264296
a := buildArgs(args)
265297
r, err := s.Stmt.Execute(a...)
266298
if err != nil {
267-
return nil, replyError(err)
299+
return nil, s.connectionState.replyError(err)
268300
}
269301
return &result{r}, nil
270302
}
@@ -273,7 +305,7 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
273305
a := buildArgs(args)
274306
r, err := s.Stmt.Execute(a...)
275307
if err != nil {
276-
return nil, replyError(err)
308+
return nil, s.connectionState.replyError(err)
277309
}
278310
return newRows(r.Resultset)
279311
}

driver/driver_options_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,51 @@ type testServer struct {
3131
}
3232

3333
type mockHandler struct {
34+
// the number of times a query executed
35+
queryCount int
36+
}
37+
38+
func TestDriverOptions_SetRetriesOn(t *testing.T) {
39+
log.SetLevel(log.LevelDebug)
40+
srv := CreateMockServer(t)
41+
defer srv.Stop()
42+
43+
conn, err := sql.Open("mysql", "[email protected]:3307/test?readTimeout=1s")
44+
defer func() {
45+
_ = conn.Close()
46+
}()
47+
require.NoError(t, err)
48+
49+
rows, err := conn.QueryContext(context.TODO(), "select * from slow;")
50+
require.Nil(t, rows)
51+
52+
// we want to get a golang database/sql/driver ErrBadConn
53+
require.ErrorIs(t, err, sqlDriver.ErrBadConn)
54+
55+
// here we issue assert that even though we only issued 1 query, that the retries
56+
// remained on and there were 3 calls to the DB.
57+
require.Equal(t, 3, srv.handler.queryCount)
58+
}
59+
60+
func TestDriverOptions_SetRetriesOff(t *testing.T) {
61+
log.SetLevel(log.LevelDebug)
62+
srv := CreateMockServer(t)
63+
defer srv.Stop()
64+
65+
conn, err := sql.Open("mysql", "[email protected]:3307/test?readTimeout=1s&retries=off")
66+
defer func() {
67+
_ = conn.Close()
68+
}()
69+
require.NoError(t, err)
70+
71+
rows, err := conn.QueryContext(context.TODO(), "select * from slow;")
72+
require.Nil(t, rows)
73+
// we want the native error from this driver implementation
74+
require.ErrorIs(t, err, mysql.ErrBadConn)
75+
76+
// here we issue assert that even though we only issued 1 query, that the retries
77+
// remained on and there were 3 calls to the DB.
78+
require.Equal(t, 1, srv.handler.queryCount)
3479
}
3580

3681
func TestDriverOptions_SetCollation(t *testing.T) {
@@ -65,6 +110,9 @@ func TestDriverOptions_ConnectTimeout(t *testing.T) {
65110
defer srv.Stop()
66111

67112
conn, err := sql.Open("mysql", "[email protected]:3307/test?timeout=1s")
113+
defer func() {
114+
_ = conn.Close()
115+
}()
68116
require.NoError(t, err)
69117

70118
rows, err := conn.QueryContext(context.TODO(), "select * from table;")
@@ -88,6 +136,9 @@ func TestDriverOptions_BufferSize(t *testing.T) {
88136
})
89137

90138
conn, err := sql.Open("mysql", "[email protected]:3307/test?bufferSize=4096")
139+
defer func() {
140+
_ = conn.Close()
141+
}()
91142
require.NoError(t, err)
92143

93144
rows, err := conn.QueryContext(context.TODO(), "select * from table;")
@@ -103,6 +154,9 @@ func TestDriverOptions_ReadTimeout(t *testing.T) {
103154
defer srv.Stop()
104155

105156
conn, err := sql.Open("mysql", "[email protected]:3307/test?readTimeout=1s")
157+
defer func() {
158+
_ = conn.Close()
159+
}()
106160
require.NoError(t, err)
107161

108162
rows, err := conn.QueryContext(context.TODO(), "select * from slow;")
@@ -134,11 +188,15 @@ func TestDriverOptions_writeTimeout(t *testing.T) {
134188
require.Contains(t, err.Error(), "missing unit in duration")
135189
require.Error(t, err)
136190
require.Nil(t, result)
191+
require.NoError(t, conn.Close())
137192

138193
// use an almost zero (1ns) writeTimeout to ensure the insert statement
139194
// can't write before the timeout. Just want to make sure ExecContext()
140195
// will throw an error.
141196
conn, err = sql.Open("mysql", "[email protected]:3307/test?writeTimeout=1ns")
197+
defer func() {
198+
_ = conn.Close()
199+
}()
142200
require.NoError(t, err)
143201

144202
// ExecContext() should fail due to the write timeout of 1ns
@@ -165,6 +223,9 @@ func TestDriverOptions_namedValueChecker(t *testing.T) {
165223
srv := CreateMockServer(t)
166224
defer srv.Stop()
167225
conn, err := sql.Open("mysql", "[email protected]:3307/test?writeTimeout=1s")
226+
defer func() {
227+
_ = conn.Close()
228+
}()
168229
require.NoError(t, err)
169230
defer conn.Close()
170231

@@ -248,6 +309,7 @@ func (h *mockHandler) UseDB(dbName string) error {
248309
}
249310

250311
func (h *mockHandler) handleQuery(query string, binary bool, args []interface{}) (*mysql.Result, error) {
312+
h.queryCount++
251313
ss := strings.Split(query, " ")
252314
switch strings.ToLower(ss[0]) {
253315
case "select":

0 commit comments

Comments
 (0)