Skip to content

Commit e30c0ed

Browse files
authored
Fix connection leak on query retry reconnect (#79)
* Close released conn on a reconnect attempt * fix err message * tests: do actual DB connection on each driver connect
1 parent cd68d26 commit e30c0ed

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

datasource.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,10 @@ func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery,
227227
if errors.Is(err, ErrorQuery) && !errors.Is(err, context.DeadlineExceeded) {
228228
for i := 0; i < ds.driverSettings.Retries; i++ {
229229
backend.Logger.Warn(fmt.Sprintf("query failed. retrying %d times", i))
230-
db, err := ds.c.Connect(dbConn.settings, q.ConnectionArgs)
230+
db, err := ds.dbReconnect(dbConn, q, cacheKey)
231231
if err != nil {
232232
return nil, err
233233
}
234-
ds.storeDBConnection(cacheKey, dbConnection{db, dbConn.settings})
235234

236235
if ds.driverSettings.Pause > 0 {
237236
time.Sleep(time.Duration(ds.driverSettings.Pause * int(time.Second)))
@@ -247,11 +246,10 @@ func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery,
247246
if errors.Is(err, context.DeadlineExceeded) {
248247
for i := 0; i < ds.driverSettings.Retries; i++ {
249248
backend.Logger.Warn(fmt.Sprintf("connection timed out. retrying %d times", i))
250-
db, err := ds.c.Connect(dbConn.settings, q.ConnectionArgs)
249+
db, err := ds.dbReconnect(dbConn, q, cacheKey)
251250
if err != nil {
252251
continue
253252
}
254-
ds.storeDBConnection(cacheKey, dbConnection{db, dbConn.settings})
255253

256254
res, err = QueryDB(ctx, db, ds.c.Converters(), fillMode, q)
257255
if err == nil {
@@ -263,6 +261,19 @@ func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery,
263261
return nil, err
264262
}
265263

264+
func (ds *SQLDatasource) dbReconnect(dbConn dbConnection, q *Query, cacheKey string) (*sql.DB, error) {
265+
if err := dbConn.db.Close(); err != nil {
266+
backend.Logger.Warn(fmt.Sprintf("closing existing connection failed: %s", err.Error()))
267+
}
268+
269+
db, err := ds.c.Connect(dbConn.settings, q.ConnectionArgs)
270+
if err != nil {
271+
return nil, err
272+
}
273+
ds.storeDBConnection(cacheKey, dbConnection{db, dbConn.settings})
274+
return db, nil
275+
}
276+
266277
// CheckHealth pings the connected SQL database
267278
func (ds *SQLDatasource) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
268279
key := defaultKey(getDatasourceUID(*req.PluginContext.DataSourceInstanceSettings))

datasource_test.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ import (
1818
)
1919

2020
type fakeDriver struct {
21-
db *sql.DB
21+
openDBfn func() (*sql.DB, error)
2222

2323
Driver
2424
}
2525

2626
func (d fakeDriver) Connect(backend.DataSourceInstanceSettings, json.RawMessage) (db *sql.DB, err error) {
27-
return d.db, nil
27+
return d.openDBfn()
2828
}
2929

3030
func (d fakeDriver) Macros() Macros {
@@ -41,7 +41,7 @@ func Test_getDBConnectionFromQuery(t *testing.T) {
4141
db := &sql.DB{}
4242
db2 := &sql.DB{}
4343
db3 := &sql.DB{}
44-
d := &fakeDriver{db: db3}
44+
d := &fakeDriver{openDBfn: func() (*sql.DB, error) { return db3, nil }}
4545
tests := []struct {
4646
desc string
4747
dsUID string
@@ -144,7 +144,7 @@ func Test_timeout_retries(t *testing.T) {
144144
t.Errorf("failed to connect to mock driver: %v", err)
145145
}
146146
timeoutDriver := fakeDriver{
147-
db: db,
147+
openDBfn: func() (*sql.DB, error) { return db, nil },
148148
}
149149
retries := 5
150150
max := time.Duration(testTimeout) * time.Second
@@ -178,12 +178,15 @@ func Test_error_retries(t *testing.T) {
178178
}
179179
mockDriver := "sqlmock-error"
180180
mock.RegisterDriver(mockDriver, handler)
181-
db, err := sql.Open(mockDriver, "")
182-
if err != nil {
183-
t.Errorf("failed to connect to mock driver: %v", err)
184-
}
181+
185182
timeoutDriver := fakeDriver{
186-
db: db,
183+
openDBfn: func() (*sql.DB, error) {
184+
db, err := sql.Open(mockDriver, "")
185+
if err != nil {
186+
t.Errorf("failed to connect to mock driver: %v", err)
187+
}
188+
return db, nil
189+
},
187190
}
188191
retries := 5
189192
max := time.Duration(10) * time.Second
@@ -192,6 +195,7 @@ func Test_error_retries(t *testing.T) {
192195

193196
key := defaultKey(dsUID)
194197
// Add the mandatory default db
198+
db, _ := timeoutDriver.Connect(settings, nil)
195199
ds.storeDBConnection(key, dbConnection{db, settings})
196200
ctx := context.Background()
197201

0 commit comments

Comments
 (0)