@@ -97,7 +97,11 @@ func parseDSN(dsn string) (connInfo, error) {
97
97
// Open takes a supplied DSN string and opens a connection
98
98
// See ParseDSN for more information on the form of the DSN
99
99
func (d driver ) Open (dsn string ) (sqldriver.Conn , error ) {
100
- var c * client.Conn
100
+ var (
101
+ c * client.Conn
102
+ // by default database/sql driver retries will be enabled
103
+ retries = true
104
+ )
101
105
102
106
ci , err := parseDSN (dsn )
103
107
@@ -134,6 +138,10 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
134
138
if timeout , err = time .ParseDuration (value [0 ]); err != nil {
135
139
return nil , errors .Wrap (err , "invalid duration value for timeout option" )
136
140
}
141
+ } else if key == "retries" && len (value ) > 0 {
142
+ // by default keep the golang database/sql retry behavior enabled unless
143
+ // the retries driver option is explicitly set to 'off'
144
+ retries = ! strings .EqualFold (value [0 ], "off" )
137
145
} else {
138
146
if option , ok := options [key ]; ok {
139
147
opt := func (o DriverOption , v string ) client.Option {
@@ -161,15 +169,28 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
161
169
return nil , err
162
170
}
163
171
164
- return & conn {c }, nil
172
+ // if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3
173
+ // retries by the database/sql package. If retries are 'off' then we'll return
174
+ // the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry.
175
+ // In this case the sqldriver.Validator interface is implemented and will return
176
+ // false for IsValid() signaling the connection is bad and should be discarded.
177
+ return & conn {Conn : c , state : & state {valid : true , useStdLibErrors : retries }}, nil
165
178
}
166
179
167
180
type CheckNamedValueFunc func (* sqldriver.NamedValue ) error
168
181
169
182
var _ sqldriver.NamedValueChecker = & conn {}
183
+ var _ sqldriver.Validator = & conn {}
184
+
185
+ type state struct {
186
+ valid bool
187
+ // when true, the driver connection will return ErrBadConn from the golang Standard Library
188
+ useStdLibErrors bool
189
+ }
170
190
171
191
type conn struct {
172
192
* client.Conn
193
+ state * state
173
194
}
174
195
175
196
func (c * conn ) CheckNamedValue (nv * sqldriver.NamedValue ) error {
@@ -190,13 +211,17 @@ func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error {
190
211
return sqldriver .ErrSkip
191
212
}
192
213
214
+ func (c * conn ) IsValid () bool {
215
+ return c .state .valid
216
+ }
217
+
193
218
func (c * conn ) Prepare (query string ) (sqldriver.Stmt , error ) {
194
219
st , err := c .Conn .Prepare (query )
195
220
if err != nil {
196
221
return nil , errors .Trace (err )
197
222
}
198
223
199
- return & stmt {st }, nil
224
+ return & stmt {Stmt : st , connectionState : c . state }, nil
200
225
}
201
226
202
227
func (c * conn ) Close () error {
@@ -222,10 +247,16 @@ func buildArgs(args []sqldriver.Value) []interface{} {
222
247
return a
223
248
}
224
249
225
- func replyError (err error ) error {
226
- if mysql .ErrorEqual (err , mysql .ErrBadConn ) {
250
+ func (st * state ) replyError (err error ) error {
251
+ isBadConnection := mysql .ErrorEqual (err , mysql .ErrBadConn )
252
+
253
+ if st .useStdLibErrors && isBadConnection {
227
254
return sqldriver .ErrBadConn
228
255
} else {
256
+ // if we have a bad connection, this mark the state of this connection as not valid
257
+ // do the database/sql package can discard it instead of placing it back in the
258
+ // sql.DB pool.
259
+ st .valid = ! isBadConnection
229
260
return errors .Trace (err )
230
261
}
231
262
}
@@ -234,7 +265,7 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err
234
265
a := buildArgs (args )
235
266
r , err := c .Conn .Execute (query , a ... )
236
267
if err != nil {
237
- return nil , replyError (err )
268
+ return nil , c . state . replyError (err )
238
269
}
239
270
return & result {r }, nil
240
271
}
@@ -243,13 +274,14 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro
243
274
a := buildArgs (args )
244
275
r , err := c .Conn .Execute (query , a ... )
245
276
if err != nil {
246
- return nil , replyError (err )
277
+ return nil , c . state . replyError (err )
247
278
}
248
279
return newRows (r .Resultset )
249
280
}
250
281
251
282
type stmt struct {
252
283
* client.Stmt
284
+ connectionState * state
253
285
}
254
286
255
287
func (s * stmt ) Close () error {
@@ -264,7 +296,7 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
264
296
a := buildArgs (args )
265
297
r , err := s .Stmt .Execute (a ... )
266
298
if err != nil {
267
- return nil , replyError (err )
299
+ return nil , s . connectionState . replyError (err )
268
300
}
269
301
return & result {r }, nil
270
302
}
@@ -273,7 +305,7 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
273
305
a := buildArgs (args )
274
306
r , err := s .Stmt .Execute (a ... )
275
307
if err != nil {
276
- return nil , replyError (err )
308
+ return nil , s . connectionState . replyError (err )
277
309
}
278
310
return newRows (r .Resultset )
279
311
}
0 commit comments