Skip to content

Commit bc662ea

Browse files
authored
added pre and post health check methods (#147)
1 parent 204a975 commit bc662ea

File tree

4 files changed

+142
-17
lines changed

4 files changed

+142
-17
lines changed

datasource.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ type SQLDatasource struct {
5050
CustomRoutes map[string]func(http.ResponseWriter, *http.Request)
5151
metrics Metrics
5252
EnableMultipleConnections bool
53+
// PreCheckHealth (optional). Performs custom health check before the Connect method
54+
PreCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult
55+
// PostCheckHealth (optional).Performs custom health check after the Connect method
56+
PostCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult
5357
}
5458

5559
// NewDatasource creates a new `SQLDatasource`.
@@ -252,8 +256,10 @@ func (ds *SQLDatasource) CheckHealth(ctx context.Context, req *backend.CheckHeal
252256
ctx, req = checkHealthMutator.MutateCheckHealth(ctx, req)
253257
}
254258
healthChecker := &HealthChecker{
255-
Connector: ds.connector,
256-
Metrics: ds.metrics.WithEndpoint(EndpointHealth),
259+
Connector: ds.connector,
260+
Metrics: ds.metrics.WithEndpoint(EndpointHealth),
261+
PreCheckHealth: ds.PreCheckHealth,
262+
PostCheckHealth: ds.PostCheckHealth,
257263
}
258264
return healthChecker.Check(ctx, req)
259265
}

driver-mock.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"encoding/json"
7+
"errors"
78
"os"
89
"path/filepath"
910
"strings"
@@ -17,7 +18,8 @@ import (
1718

1819
// SQLMock connects to a local folder with csv files
1920
type SQLMock struct {
20-
folder string
21+
folder string
22+
ShouldFailToConnect bool
2123
}
2224

2325
func (h *SQLMock) Settings(_ context.Context, _ backend.DataSourceInstanceSettings) DriverSettings {
@@ -31,6 +33,9 @@ func (h *SQLMock) Settings(_ context.Context, _ backend.DataSourceInstanceSettin
3133

3234
// Connect opens a sql.DB connection using datasource settings
3335
func (h *SQLMock) Connect(_ context.Context, _ backend.DataSourceInstanceSettings, msg json.RawMessage) (*sql.DB, error) {
36+
if h.ShouldFailToConnect {
37+
return nil, errors.New("failed to create mock")
38+
}
3439
backend.Logger.Debug("connecting to mock data")
3540
folder := h.folder
3641
if folder == "" {

health.go

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,30 @@ import (
88
)
99

1010
type HealthChecker struct {
11-
Connector *Connector
12-
Metrics Metrics
11+
Connector *Connector
12+
Metrics Metrics
13+
PreCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult
14+
PostCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult
1315
}
1416

1517
func (hc *HealthChecker) Check(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
1618
start := time.Now()
17-
18-
_, err := hc.Connector.Connect(ctx, req.GetHTTPHeaders())
19-
if err != nil {
19+
if hc.PreCheckHealth != nil {
20+
if res := hc.PreCheckHealth(ctx, req); res != nil && res.Status == backend.HealthStatusError {
21+
hc.Metrics.CollectDuration(SourceDownstream, StatusError, time.Since(start).Seconds())
22+
return res, nil
23+
}
24+
}
25+
if _, err := hc.Connector.Connect(ctx, req.GetHTTPHeaders()); err != nil {
2026
hc.Metrics.CollectDuration(SourceDownstream, StatusError, time.Since(start).Seconds())
21-
return &backend.CheckHealthResult{
22-
Status: backend.HealthStatusError,
23-
Message: err.Error(),
24-
}, DownstreamError(err)
27+
return &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: err.Error()}, nil
28+
}
29+
if hc.PostCheckHealth != nil {
30+
if res := hc.PostCheckHealth(ctx, req); res != nil && res.Status == backend.HealthStatusError {
31+
hc.Metrics.CollectDuration(SourceDownstream, StatusError, time.Since(start).Seconds())
32+
return res, nil
33+
}
2534
}
2635
hc.Metrics.CollectDuration(SourceDownstream, StatusOK, time.Since(start).Seconds())
27-
28-
return &backend.CheckHealthResult{
29-
Status: backend.HealthStatusOk,
30-
Message: "Data source is working",
31-
}, nil
36+
return &backend.CheckHealthResult{Status: backend.HealthStatusOk, Message: "Data source is working"}, nil
3237
}

health_test.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package sqlds_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/grafana/grafana-plugin-sdk-go/backend"
8+
sqlds "github.com/grafana/sqlds/v4"
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func getFakeConnector(t *testing.T, shouldFail bool) *sqlds.Connector {
14+
t.Helper()
15+
c, _ := sqlds.NewConnector(context.TODO(), &sqlds.SQLMock{ShouldFailToConnect: shouldFail}, backend.DataSourceInstanceSettings{}, false)
16+
return c
17+
}
18+
19+
func TestHealthChecker_Check(t *testing.T) {
20+
tests := []struct {
21+
name string
22+
Connector *sqlds.Connector
23+
Metrics sqlds.Metrics
24+
PreCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult
25+
PostCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult
26+
ctx context.Context
27+
req *backend.CheckHealthRequest
28+
want *backend.CheckHealthResult
29+
wantErr error
30+
}{
31+
{
32+
name: "default health check should return valid result",
33+
Connector: getFakeConnector(t, false),
34+
},
35+
{
36+
name: "should not error when pre check succeed",
37+
Connector: getFakeConnector(t, false),
38+
PreCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult {
39+
return &backend.CheckHealthResult{Status: backend.HealthStatusOk}
40+
},
41+
},
42+
{
43+
name: "should error when pre check failed",
44+
Connector: getFakeConnector(t, false),
45+
PreCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult {
46+
return &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unknown error"}
47+
},
48+
want: &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unknown error"},
49+
},
50+
{
51+
name: "should return actual error when pre and post health check succeed but actual connect failed",
52+
Connector: getFakeConnector(t, true),
53+
PreCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult {
54+
return &backend.CheckHealthResult{Status: backend.HealthStatusOk}
55+
},
56+
PostCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult {
57+
return &backend.CheckHealthResult{Status: backend.HealthStatusOk}
58+
},
59+
want: &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unable to get default db connection"},
60+
},
61+
{
62+
name: "should not error when post check succeed",
63+
Connector: getFakeConnector(t, false),
64+
PostCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult {
65+
return &backend.CheckHealthResult{Status: backend.HealthStatusOk}
66+
},
67+
},
68+
{
69+
name: "should error when post check failed",
70+
Connector: getFakeConnector(t, false),
71+
PostCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult {
72+
return &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unknown error"}
73+
},
74+
want: &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unknown error"},
75+
},
76+
}
77+
for _, tt := range tests {
78+
t.Run(tt.name, func(t *testing.T) {
79+
connector := tt.Connector
80+
if connector == nil {
81+
connector = &sqlds.Connector{}
82+
}
83+
req := tt.req
84+
if req == nil {
85+
req = &backend.CheckHealthRequest{}
86+
}
87+
want := tt.want
88+
if want == nil {
89+
want = &backend.CheckHealthResult{Status: backend.HealthStatusOk, Message: "Data source is working"}
90+
}
91+
hc := &sqlds.HealthChecker{
92+
Connector: connector,
93+
Metrics: tt.Metrics,
94+
PreCheckHealth: tt.PreCheckHealth,
95+
PostCheckHealth: tt.PostCheckHealth,
96+
}
97+
got, err := hc.Check(tt.ctx, req)
98+
if tt.wantErr != nil {
99+
require.NotNil(t, err)
100+
assert.Equal(t, tt.wantErr.Error(), err.Error())
101+
return
102+
}
103+
require.Nil(t, err)
104+
require.NotNil(t, got)
105+
assert.Equal(t, want.Message, got.Message)
106+
assert.Equal(t, want.Status, got.Status)
107+
})
108+
}
109+
}

0 commit comments

Comments
 (0)