Skip to content

Commit 29ba2a5

Browse files
authored
forward headers (#90)
forward headers
1 parent 196f681 commit 29ba2a5

File tree

6 files changed

+342
-144
lines changed

6 files changed

+342
-144
lines changed

datasource.go

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ const defaultKeySuffix = "default"
2222
var (
2323
ErrorMissingMultipleConnectionsConfig = errors.New("received connection arguments but the feature is not enabled")
2424
ErrorMissingDBConnection = errors.New("unable to get default db connection")
25-
25+
HeaderKey = "grafana-http-headers"
2626
// Deprecated: ErrorMissingMultipleConnectionsConfig should be used instead
2727
MissingMultipleConnectionsConfig = ErrorMissingMultipleConnectionsConfig
2828
// Deprecated: ErrorMissingDBConnection should be used instead
@@ -114,6 +114,8 @@ func (ds *SQLDatasource) Dispose() {
114114

115115
// QueryData creates the Responses list and executes each query
116116
func (ds *SQLDatasource) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
117+
headers := req.GetHTTPHeaders()
118+
117119
var (
118120
response = NewResponse(backend.NewQueryDataResponse())
119121
wg = sync.WaitGroup{}
@@ -124,7 +126,7 @@ func (ds *SQLDatasource) QueryData(ctx context.Context, req *backend.QueryDataRe
124126
// Execute each query and store the results by query RefID
125127
for _, q := range req.Queries {
126128
go func(query backend.DataQuery) {
127-
frames, err := ds.handleQuery(ctx, query, getDatasourceUID(*req.PluginContext.DataSourceInstanceSettings))
129+
frames, err := ds.handleQuery(ctx, query, getDatasourceUID(*req.PluginContext.DataSourceInstanceSettings), headers)
128130
if err == nil {
129131
if responseMutator, ok := ds.c.(ResponseMutator); ok {
130132
frames, err = responseMutator.MutateResponse(ctx, frames)
@@ -150,7 +152,7 @@ func (ds *SQLDatasource) GetDBFromQuery(q *Query, datasourceUID string) (*sql.DB
150152
}
151153

152154
func (ds *SQLDatasource) getDBConnectionFromQuery(q *Query, datasourceUID string) (string, dbConnection, error) {
153-
if !ds.EnableMultipleConnections && len(q.ConnectionArgs) > 0 {
155+
if !ds.EnableMultipleConnections && !ds.driverSettings.ForwardHeaders && len(q.ConnectionArgs) > 0 {
154156
return "", dbConnection{}, MissingMultipleConnectionsConfig
155157
}
156158
// The database connection may vary depending on query arguments
@@ -182,13 +184,13 @@ func (ds *SQLDatasource) getDBConnectionFromQuery(q *Query, datasourceUID string
182184
}
183185

184186
// handleQuery will call query, and attempt to reconnect if the query failed
185-
func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery, datasourceUID string) (data.Frames, error) {
187+
func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery, datasourceUID string, headers http.Header) (data.Frames, error) {
186188
if queryMutator, ok := ds.c.(QueryMutator); ok {
187189
ctx, req = queryMutator.MutateQuery(ctx, req)
188190
}
189191

190192
// Convert the backend.DataQuery into a Query object
191-
q, err := GetQuery(req)
193+
q, err := GetQuery(req, headers, ds.driverSettings.ForwardHeaders)
192194
if err != nil {
193195
return nil, err
194196
}
@@ -302,19 +304,31 @@ func (ds *SQLDatasource) CheckHealth(ctx context.Context, req *backend.CheckHeal
302304
return ds.check(dbConn)
303305
}
304306

305-
return ds.checkWithRetries(dbConn)
307+
return ds.checkWithRetries(dbConn, key, req.GetHTTPHeaders())
306308
}
307309

308310
func (ds *SQLDatasource) DriverSettings() DriverSettings {
309311
return ds.driverSettings
310312
}
311313

312-
func (ds *SQLDatasource) checkWithRetries(conn dbConnection) (*backend.CheckHealthResult, error) {
314+
func (ds *SQLDatasource) checkWithRetries(conn dbConnection, key string, headers http.Header) (*backend.CheckHealthResult, error) {
313315
var result *backend.CheckHealthResult
314-
var err error
316+
317+
q := &Query{}
318+
if ds.driverSettings.ForwardHeaders {
319+
applyHeaders(q, headers)
320+
}
315321

316322
for i := 0; i < ds.driverSettings.Retries; i++ {
317-
result, err = ds.check(conn)
323+
db, err := ds.dbReconnect(conn, q, key)
324+
if err != nil {
325+
return nil, err
326+
}
327+
c := dbConnection{
328+
db: db,
329+
settings: conn.settings,
330+
}
331+
result, err = ds.check(c)
318332
if err == nil {
319333
return result, err
320334
}

datasource_test.go

Lines changed: 154 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ import (
1818
)
1919

2020
type fakeDriver struct {
21-
openDBfn func() (*sql.DB, error)
21+
openDBfn func(msg json.RawMessage) (*sql.DB, error)
2222

2323
Driver
2424
}
2525

26-
func (d fakeDriver) Connect(backend.DataSourceInstanceSettings, json.RawMessage) (db *sql.DB, err error) {
27-
return d.openDBfn()
26+
func (d fakeDriver) Connect(settings backend.DataSourceInstanceSettings, msg json.RawMessage) (db *sql.DB, err error) {
27+
return d.openDBfn(msg)
2828
}
2929

3030
func (d fakeDriver) Macros() Macros {
@@ -35,13 +35,11 @@ func (d fakeDriver) Converters() []sqlutil.Converter {
3535
return []sqlutil.Converter{}
3636
}
3737

38-
// func (d fakeDriver) Settings(backend.DataSourceInstanceSettings) DriverSettings
39-
4038
func Test_getDBConnectionFromQuery(t *testing.T) {
4139
db := &sql.DB{}
4240
db2 := &sql.DB{}
4341
db3 := &sql.DB{}
44-
d := &fakeDriver{openDBfn: func() (*sql.DB, error) { return db3, nil }}
42+
d := &fakeDriver{openDBfn: func(msg json.RawMessage) (*sql.DB, error) { return db3, nil }}
4543
tests := []struct {
4644
desc string
4745
dsUID string
@@ -136,15 +134,21 @@ func Test_timeout_retries(t *testing.T) {
136134
dsUID := "timeout"
137135
settings := backend.DataSourceInstanceSettings{UID: dsUID}
138136

139-
handler := testSqlHandler{}
137+
handler := &testSqlHandler{}
140138
mockDriver := "sqlmock"
141139
mock.RegisterDriver(mockDriver, handler)
142140
db, err := sql.Open(mockDriver, "")
143141
if err != nil {
144142
t.Errorf("failed to connect to mock driver: %v", err)
145143
}
146144
timeoutDriver := fakeDriver{
147-
openDBfn: func() (*sql.DB, error) { return db, nil },
145+
openDBfn: func(msg json.RawMessage) (*sql.DB, error) {
146+
db, err := sql.Open(mockDriver, "")
147+
if err != nil {
148+
t.Errorf("failed to connect to mock driver: %v", err)
149+
}
150+
return db, nil
151+
},
148152
}
149153
retries := 5
150154
max := time.Duration(testTimeout) * time.Second
@@ -173,14 +177,14 @@ func Test_error_retries(t *testing.T) {
173177
dsUID := "error"
174178
settings := backend.DataSourceInstanceSettings{UID: dsUID}
175179

176-
handler := testSqlHandler{
180+
handler := &testSqlHandler{
177181
error: errors.New("foo"),
178182
}
179183
mockDriver := "sqlmock-error"
180184
mock.RegisterDriver(mockDriver, handler)
181185

182186
timeoutDriver := fakeDriver{
183-
openDBfn: func() (*sql.DB, error) {
187+
openDBfn: func(msg json.RawMessage) (*sql.DB, error) {
184188
db, err := sql.Open(mockDriver, "")
185189
if err != nil {
186190
t.Errorf("failed to connect to mock driver: %v", err)
@@ -220,19 +224,154 @@ func Test_error_retries(t *testing.T) {
220224

221225
}
222226

227+
func Test_query_apply_headers(t *testing.T) {
228+
testCounter = 0
229+
dsUID := "headers"
230+
settings := backend.DataSourceInstanceSettings{UID: dsUID}
231+
232+
// first query will fail since the connection is missing tokens
233+
handler := &testSqlHandler{
234+
error: errors.New("missing token"),
235+
}
236+
mockDriver := "sqlmock-query-error"
237+
mock.RegisterDriver(mockDriver, handler)
238+
239+
opened := false
240+
var message json.RawMessage
241+
timeoutDriver := fakeDriver{
242+
openDBfn: func(msg json.RawMessage) (*sql.DB, error) {
243+
if opened {
244+
// second query attempt will have tokens and won't return an error
245+
handler = &testSqlHandler{}
246+
mockDriver = "sqlmock-query-token"
247+
mock.RegisterDriver(mockDriver, handler)
248+
}
249+
db, err := sql.Open(mockDriver, "")
250+
if err != nil {
251+
t.Errorf("failed to connect to mock driver: %v", err)
252+
}
253+
opened = true
254+
message = msg
255+
return db, nil
256+
},
257+
}
258+
max := time.Duration(10) * time.Second
259+
// retry once for token errors since the first connection will not have the token and throw a connection error
260+
driverSettings := DriverSettings{Retries: 1, Timeout: max, Pause: 1, RetryOn: []string{"token"}, ForwardHeaders: true}
261+
ds := &SQLDatasource{c: timeoutDriver, driverSettings: driverSettings}
262+
263+
key := defaultKey(dsUID)
264+
// Add the mandatory default db
265+
db, _ := timeoutDriver.Connect(settings, nil)
266+
ds.storeDBConnection(key, dbConnection{db, settings})
267+
ctx := context.Background()
268+
269+
qry := `{ "rawSql": "foo" }`
270+
271+
req := &backend.QueryDataRequest{
272+
PluginContext: backend.PluginContext{
273+
DataSourceInstanceSettings: &settings,
274+
},
275+
Queries: []backend.DataQuery{
276+
{
277+
RefID: "foo",
278+
JSON: []byte(qry),
279+
},
280+
},
281+
}
282+
req.SetHTTPHeader("hey", "scott")
283+
284+
data, err := ds.QueryData(ctx, req)
285+
assert.Nil(t, err)
286+
assert.NotNil(t, data.Responses)
287+
288+
res := data.Responses["foo"]
289+
assert.Nil(t, res.Error)
290+
291+
assert.Contains(t, string(message), "scott")
292+
}
293+
294+
func Test_check_health_with_headers(t *testing.T) {
295+
dsUID := "headers"
296+
settings := backend.DataSourceInstanceSettings{UID: dsUID}
297+
298+
// first check will fail since the connection is missing tokens
299+
handler := &testSqlHandler{
300+
error: errors.New("missing token"),
301+
}
302+
mockDriver := "sqlmock-header-error"
303+
mock.RegisterDriver(mockDriver, handler)
304+
305+
opened := false
306+
var message json.RawMessage
307+
timeoutDriver := fakeDriver{
308+
openDBfn: func(msg json.RawMessage) (*sql.DB, error) {
309+
if opened {
310+
// second query attempt will have tokens and won't return an error
311+
handler = &testSqlHandler{
312+
ping: true,
313+
checks: handler.checks,
314+
}
315+
mockDriver = "sqlmock-header-token"
316+
mock.RegisterDriver(mockDriver, handler)
317+
}
318+
db, err := sql.Open(mockDriver, "")
319+
if err != nil {
320+
t.Errorf("failed to connect to mock driver: %v", err)
321+
}
322+
opened = true
323+
message = msg
324+
return db, nil
325+
},
326+
}
327+
max := time.Duration(10) * time.Second
328+
// retry once for token errors since the first connection will not have the token and throw a connection error
329+
driverSettings := DriverSettings{Retries: 1, Timeout: max, Pause: 1, RetryOn: []string{"token"}, ForwardHeaders: true}
330+
ds := &SQLDatasource{c: timeoutDriver, driverSettings: driverSettings}
331+
332+
key := defaultKey(dsUID)
333+
// Add the mandatory default db
334+
db, _ := timeoutDriver.Connect(settings, nil)
335+
ds.storeDBConnection(key, dbConnection{db, settings})
336+
ctx := context.Background()
337+
338+
headers := map[string]string{}
339+
headers["foo"] = "bar"
340+
req := &backend.CheckHealthRequest{
341+
PluginContext: backend.PluginContext{
342+
DataSourceInstanceSettings: &settings,
343+
},
344+
Headers: headers,
345+
}
346+
347+
req.SetHTTPHeader("foo", "bar")
348+
349+
res, err := ds.CheckHealth(ctx, req)
350+
assert.Nil(t, err)
351+
assert.Equal(t, "Data source is working", res.Message)
352+
353+
assert.Contains(t, string(message), "bar")
354+
}
355+
223356
var testCounter = 0
224357
var testTimeout = 1
225358
var testRows = 0
226359

227360
type testSqlHandler struct {
228361
mock.DBHandler
229362
error
363+
ping bool
364+
checks int
230365
}
231366

232-
func (s testSqlHandler) Ping(ctx context.Context) error {
367+
func (s *testSqlHandler) Ping(ctx context.Context) error {
368+
s.checks += 1
233369
if s.error != nil {
234370
return s.error
235371
}
372+
if s.ping {
373+
return nil
374+
}
236375
testCounter++ // track the retries for the test assertion
237376
time.Sleep(time.Duration(testTimeout + 1)) // simulate a connection delay
238377
select {
@@ -242,7 +381,7 @@ func (s testSqlHandler) Ping(ctx context.Context) error {
242381
}
243382
}
244383

245-
func (s testSqlHandler) Query(args []driver.Value) (driver.Rows, error) {
384+
func (s *testSqlHandler) Query(args []driver.Value) (driver.Rows, error) {
246385
fmt.Println("query")
247386
if s.error != nil {
248387
testCounter++
@@ -251,11 +390,11 @@ func (s testSqlHandler) Query(args []driver.Value) (driver.Rows, error) {
251390
return s, nil
252391
}
253392

254-
func (s testSqlHandler) Columns() []string {
393+
func (s *testSqlHandler) Columns() []string {
255394
return []string{"foo", "bar"}
256395
}
257396

258-
func (s testSqlHandler) Next(dest []driver.Value) error {
397+
func (s *testSqlHandler) Next(dest []driver.Value) error {
259398
testRows++
260399
if testRows > 5 {
261400
return io.EOF
@@ -265,6 +404,6 @@ func (s testSqlHandler) Next(dest []driver.Value) error {
265404
return nil
266405
}
267406

268-
func (s testSqlHandler) Close() error {
407+
func (s *testSqlHandler) Close() error {
269408
return nil
270409
}

driver.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@ import (
1212
)
1313

1414
type DriverSettings struct {
15-
Timeout time.Duration
16-
FillMode *data.FillMissing
17-
Retries int
18-
Pause int
19-
RetryOn []string
15+
Timeout time.Duration
16+
FillMode *data.FillMissing
17+
Retries int
18+
Pause int
19+
RetryOn []string
20+
ForwardHeaders bool
2021
}
2122

2223
// Driver is a simple interface that defines how to connect to a backend SQL datasource

0 commit comments

Comments
 (0)