Skip to content

Commit 31cc2a2

Browse files
authored
validate sql (#21)
validate sql
1 parent 50cd057 commit 31cc2a2

File tree

4 files changed

+190
-26
lines changed

4 files changed

+190
-26
lines changed

duck/duckdb.go

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ import (
77
"fmt"
88
"os"
99
"os/exec"
10+
"strings"
1011

1112
"github.com/grafana/grafana-plugin-sdk-go/backend/log"
1213
sdk "github.com/grafana/grafana-plugin-sdk-go/data"
1314
"github.com/grafana/grafana-plugin-sdk-go/data/framestruct"
1415
"github.com/hairyhenderson/go-which"
1516
"github.com/iancoleman/orderedmap"
17+
"github.com/jeremywohl/flatten"
1618
"github.com/scottlepp/go-duck/duck/data"
1719
)
1820

@@ -105,6 +107,10 @@ func (d *DuckDB) Query(query string) (string, error) {
105107

106108
// QueryFrames will load a dataframe into a view named RefID, and run the query against that view
107109
func (d *DuckDB) QueryFrames(name string, query string, frames []*sdk.Frame) (string, bool, error) {
110+
err := d.validate(query)
111+
if err != nil {
112+
return "", false, err
113+
}
108114
data := FrameData{
109115
cacheDuration: d.cacheDuration,
110116
cache: &d.cache,
@@ -123,15 +129,21 @@ func wipe(dirs map[string]string) {
123129
}
124130
}
125131

126-
func (d *DuckDB) QueryFramesInto(name string, query string, frames []*sdk.Frame, f *sdk.Frame) error {
132+
func (d *DuckDB) QueryFramesToFrames(name string, query string, frames []*sdk.Frame) (*sdk.Frame, error) {
133+
err := d.validate(query)
134+
if err != nil {
135+
return nil, err
136+
}
137+
138+
f := &sdk.Frame{}
127139
res, cached, err := d.QueryFrames(name, query, frames)
128140
if err != nil {
129-
return err
141+
return nil, err
130142
}
131143

132144
err = resultsToFrame(name, res, f, frames)
133145
if err != nil {
134-
return err
146+
return nil, err
135147
}
136148
if cached {
137149
for _, frame := range frames {
@@ -145,7 +157,7 @@ func (d *DuckDB) QueryFramesInto(name string, query string, frames []*sdk.Frame,
145157
frame.Meta.Notices = append(frame.Meta.Notices, notice)
146158
}
147159
}
148-
return nil
160+
return f, nil
149161
}
150162

151163
// Destroy will remove database files created by duckdb
@@ -294,3 +306,89 @@ func getTempDir() string {
294306
}
295307
return temp
296308
}
309+
310+
// TABLE_NAME holds the literal "table_name". It is not referenced by the
// validation code visible in this change (presumably an AST key; confirm
// against callers).
const TABLE_NAME = "table_name"

// ERROR is the key suffix validate looks for in the flattened, dot-separated
// AST to find nodes whose error flag is set.
const ERROR = ".error"

// ERROR_MESSAGE holds the literal ".error_message". It is not referenced by
// the validation code visible in this change (presumably the companion key
// suffix to ERROR; confirm against callers).
const ERROR_MESSAGE = ".error_message"
315+
316+
func (d *DuckDB) validate(rawSQL string) error {
317+
rawSQL = strings.Replace(rawSQL, "'", "''", -1)
318+
cmd := fmt.Sprintf("SELECT json_serialize_sql('%s')", rawSQL)
319+
ret, err := d.RunCommands([]string{cmd})
320+
if err != nil {
321+
logger.Error("error validating sql", "error", err.Error(), "sql", rawSQL, "cmd", cmd)
322+
return fmt.Errorf("error validating sql: %s", err.Error())
323+
}
324+
325+
result := []map[string]any{}
326+
err = json.Unmarshal([]byte(ret), &result)
327+
if err != nil {
328+
logger.Error("error converting json sql to ast", "error", err.Error(), "ret", ret)
329+
return fmt.Errorf("error converting json to ast: %s", err.Error())
330+
}
331+
332+
if len(result) == 0 {
333+
logger.Error("no ast returned", "ret", ret)
334+
}
335+
336+
var ast map[string]any
337+
for _, v := range result[0] {
338+
validAst, ok := v.(map[string]any)
339+
if !ok {
340+
logger.Error("invalid sql", "sql", ret)
341+
return fmt.Errorf("invalid sql: %s", ret)
342+
}
343+
ast = validAst
344+
break
345+
}
346+
347+
errMsg := ast["error"]
348+
if errMsg != nil {
349+
errMsgBool, ok := errMsg.(bool)
350+
if !ok {
351+
logger.Error("error in ast", "error", ret)
352+
return fmt.Errorf("error in ast: %v", ret)
353+
}
354+
if errMsgBool {
355+
logger.Error("error in ast", "error", ret)
356+
return fmt.Errorf("error in ast: %v", ret)
357+
}
358+
}
359+
360+
statements := ast["statements"]
361+
if statements == nil {
362+
logger.Error("no statements in ast", "ast", ast)
363+
return fmt.Errorf("no statements in ast: %v", ast)
364+
}
365+
366+
flat, err := flatten.Flatten(ast, "", flatten.DotStyle)
367+
if err != nil {
368+
logger.Error("error flattening ast", "error", err.Error(), "ast", ast)
369+
return fmt.Errorf("error flattening ast: %s", err.Error())
370+
}
371+
372+
for k, v := range flat {
373+
if strings.HasSuffix(k, ERROR) {
374+
v, ok := v.(bool)
375+
if ok && v {
376+
logger.Error("error in sql", "error", k)
377+
return fmt.Errorf("error flattening ast: %s", k)
378+
}
379+
}
380+
if strings.Contains(k, "from_table.function.function_name") {
381+
logger.Error("function not allowed", "function", v)
382+
return fmt.Errorf("function not allowed: %s", v)
383+
}
384+
if strings.HasSuffix(k, "from_table.table_name") {
385+
v, ok := v.(string)
386+
if ok && strings.Contains(v, ".") {
387+
logger.Error("table names with . not allowed", "table", v)
388+
return fmt.Errorf("table names with . not allowed: %s", v)
389+
}
390+
}
391+
}
392+
393+
return nil
394+
}

duck/duckdb_test.go

Lines changed: 85 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@ func TestCommands(t *testing.T) {
2626
assert.Contains(t, res, `[{"i":1,"j":5}]`)
2727
}
2828

29+
func TestDotCommands(t *testing.T) {
30+
db := NewInMemoryDB()
31+
32+
commands := []string{
33+
".databases",
34+
}
35+
res, err := db.RunCommands(commands)
36+
if err != nil {
37+
t.Fail()
38+
return
39+
}
40+
assert.Contains(t, res, `memory`)
41+
}
42+
2943
func TestCommandsDocker(t *testing.T) {
3044
db := NewInMemoryDB(Opts{Docker: true})
3145

@@ -74,6 +88,66 @@ func TestQueryFrame(t *testing.T) {
7488
assert.Contains(t, res, `[{"value":"test"}]`)
7589
}
7690

91+
func TestQueryAgg(t *testing.T) {
92+
db := NewInMemoryDB()
93+
94+
var values = []string{"test"}
95+
frame := data.NewFrame("foo", data.NewField("value", nil, values))
96+
frame.RefID = "foo"
97+
frames := []*data.Frame{frame}
98+
99+
res, _, err := db.QueryFrames("foo", "select min(value) as value from foo", frames)
100+
assert.Nil(t, err)
101+
102+
assert.Contains(t, res, `[{"value":"test"}]`)
103+
}
104+
105+
func TestQueryJson(t *testing.T) {
106+
db := NewInMemoryDB()
107+
108+
var values = []string{"test"}
109+
frame := data.NewFrame("foo", data.NewField("value", nil, values))
110+
frame.RefID = "foo"
111+
frames := []*data.Frame{frame}
112+
113+
_, _, err := db.QueryFrames("foo", "SELECT * FROM read_json('todos.json')", frames)
114+
assert.NotNil(t, err)
115+
}
116+
117+
func TestValid(t *testing.T) {
118+
db := NewInMemoryDB()
119+
120+
var values = []string{"test"}
121+
frame := data.NewFrame("foo", data.NewField("value", nil, values))
122+
frame.RefID = "foo"
123+
frames := []*data.Frame{frame}
124+
125+
query := fmt.Sprintf(".databases %s", newline)
126+
_, _, err := db.QueryFrames("foo", query, frames)
127+
assert.NotNil(t, err)
128+
}
129+
130+
func TestQueryFrameNoFileRead(t *testing.T) {
131+
db := NewInMemoryDB()
132+
133+
var values = []string{"test"}
134+
frame := data.NewFrame("foo", data.NewField("value", nil, values))
135+
frame.RefID = "foo"
136+
frames := []*data.Frame{frame}
137+
138+
_, _, err := db.QueryFrames("foo", "SELECT * FROM read_csv('flights.csv')", frames)
139+
assert.NotNil(t, err)
140+
141+
_, _, err = db.QueryFrames("foo", "SELECT * FROM read_json('flights.json')", frames)
142+
assert.NotNil(t, err)
143+
144+
_, _, err = db.QueryFrames("foo", "SELECT * FROM 'test.parquet'", frames)
145+
assert.NotNil(t, err)
146+
147+
_, _, err = db.QueryFrames("foo", "COPY test FROM 'test.parquet'", frames)
148+
assert.NotNil(t, err)
149+
}
150+
77151
func TestQueryFrameCache(t *testing.T) {
78152
opts := Opts{
79153
CacheDuration: 5,
@@ -153,8 +227,7 @@ func TestQueryFrameIntoFrame(t *testing.T) {
153227

154228
frames := []*data.Frame{frame, frame2}
155229

156-
model := &data.Frame{}
157-
err := db.QueryFramesInto("foo", "select * from foo order by value desc", frames, model)
230+
model, err := db.QueryFramesToFrames("foo", "select * from foo order by value desc", frames)
158231
assert.Nil(t, err)
159232

160233
assert.Equal(t, 2, model.Rows())
@@ -178,8 +251,7 @@ func TestQueryFrameIntoFrameDocker(t *testing.T) {
178251

179252
frames := []*data.Frame{frame, frame2}
180253

181-
model := &data.Frame{}
182-
err := db.QueryFramesInto("foo", "select * from foo order by value desc", frames, model)
254+
model, err := db.QueryFramesToFrames("foo", "select * from foo order by value desc", frames)
183255
assert.Nil(t, err)
184256

185257
assert.Equal(t, 2, model.Rows())
@@ -203,8 +275,7 @@ func TestQueryFrameIntoFrameMultipleColumns(t *testing.T) {
203275

204276
frames := []*data.Frame{frame}
205277

206-
model := &data.Frame{}
207-
err := db.QueryFramesInto("B", "select * from A", frames, model)
278+
model, err := db.QueryFramesToFrames("B", "select * from A", frames)
208279
assert.Nil(t, err)
209280

210281
assert.Equal(t, "Z State", model.Fields[0].Name)
@@ -230,8 +301,7 @@ func TestMultiFrame(t *testing.T) {
230301

231302
frames := []*data.Frame{frame, frame2}
232303

233-
model := &data.Frame{}
234-
err := db.QueryFramesInto("foo", "select * from foo", frames, model)
304+
model, err := db.QueryFramesToFrames("foo", "select * from foo", frames)
235305
assert.Nil(t, err)
236306

237307
assert.Equal(t, 2, model.Rows())
@@ -257,8 +327,7 @@ func TestMultiFrame2(t *testing.T) {
257327

258328
frames := []*data.Frame{frame, frame2}
259329

260-
model := &data.Frame{}
261-
err := db.QueryFramesInto("foo", "select * from foo", frames, model)
330+
model, err := db.QueryFramesToFrames("foo", "select * from foo", frames)
262331
assert.Nil(t, err)
263332

264333
assert.Equal(t, 2, model.Rows())
@@ -281,8 +350,7 @@ func TestTimestamps(t *testing.T) {
281350

282351
frames := []*data.Frame{frame}
283352

284-
model := &data.Frame{}
285-
err = db.QueryFramesInto("foo", "select * from foo", frames, model)
353+
model, err := db.QueryFramesToFrames("foo", "select * from foo", frames)
286354
assert.Nil(t, err)
287355

288356
assert.Equal(t, 1, model.Rows())
@@ -311,8 +379,7 @@ func TestTimeSeries(t *testing.T) {
311379

312380
frames := []*data.Frame{frame}
313381

314-
model := &data.Frame{}
315-
err = db.QueryFramesInto("foo", "select * from foo", frames, model)
382+
model, err := db.QueryFramesToFrames("foo", "select * from foo", frames)
316383
assert.Nil(t, err)
317384

318385
assert.Equal(t, data.FrameTypeTimeSeriesWide, model.Meta.Type)
@@ -341,8 +408,7 @@ func TestTimeSeriesWide(t *testing.T) {
341408

342409
frames := []*data.Frame{frame}
343410

344-
model := &data.Frame{}
345-
err = db.QueryFramesInto("foo", "select * from foo", frames, model)
411+
model, err := db.QueryFramesToFrames("foo", "select * from foo", frames)
346412
assert.Nil(t, err)
347413

348414
assert.Equal(t, data.FrameTypeTimeSeriesWide, model.Meta.Type)
@@ -377,8 +443,7 @@ func TestLabels(t *testing.T) {
377443

378444
frames := []*data.Frame{frame, frame2}
379445

380-
model := &data.Frame{}
381-
err := db.QueryFramesInto("foo", "select * from foo", frames, model)
446+
model, err := db.QueryFramesToFrames("foo", "select * from foo", frames)
382447
assert.Nil(t, err)
383448

384449
assert.Equal(t, 2, model.Rows())
@@ -427,8 +492,7 @@ func TestLabelsMultiFrame(t *testing.T) {
427492
frames := []*data.Frame{frame, frame2}
428493

429494
// TODO - ordering is broken!
430-
model := &data.Frame{}
431-
err = db.QueryFramesInto("foo", "select * from foo order by timestamp desc", frames, model)
495+
model, err := db.QueryFramesToFrames("foo", "select * from foo order by timestamp desc", frames)
432496
assert.Nil(t, err)
433497

434498
assert.Equal(t, 4, model.Rows())
@@ -459,8 +523,7 @@ func TestTimeSeriesAggregate(t *testing.T) {
459523

460524
frames := []*data.Frame{frame}
461525

462-
model := &data.Frame{}
463-
err = db.QueryFramesInto("foo", "select CURRENT_TIMESTAMP, min(time) as t, 1 as j from foo group by category", frames, model)
526+
model, err := db.QueryFramesToFrames("foo", "select min(time) as t, 1 as j from foo group by category", frames)
464527
assert.Nil(t, err)
465528

466529
assert.Equal(t, data.FrameTypeTimeSeriesWide, model.Meta.Type)

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ require (
2323
github.com/google/flatbuffers v24.3.25+incompatible // indirect
2424
github.com/google/go-cmp v0.6.0 // indirect
2525
github.com/hashicorp/go-hclog v1.6.3 // indirect
26+
github.com/jeremywohl/flatten v1.0.1
2627
github.com/json-iterator/go v1.1.12 // indirect
2728
github.com/klauspost/asmfmt v1.3.2 // indirect
2829
github.com/klauspost/compress v1.17.9 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ github.com/iancoleman/orderedmap v0.3.0 h1:5cbR2grmZR/DiVt+VJopEhtVs9YGInGIxAoMJ
8282
github.com/iancoleman/orderedmap v0.3.0/go.mod h1:XuLcCUkdL5owUCQeF2Ue9uuw1EptkJDkXXS7VoV7XGE=
8383
github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY=
8484
github.com/invopop/yaml v0.2.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q=
85+
github.com/jeremywohl/flatten v1.0.1 h1:LrsxmB3hfwJuE+ptGOijix1PIfOoKLJ3Uee/mzbgtrs=
86+
github.com/jeremywohl/flatten v1.0.1/go.mod h1:4AmD/VxjWcI5SRB0n6szE2A6s2fsNHDLO0nAlMHgfLQ=
8587
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
8688
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
8789
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=

0 commit comments

Comments
 (0)