From 008b48b92c0f53ba79ef345c723853f9846804ee Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Fri, 8 Nov 2024 21:10:22 -0500 Subject: [PATCH 1/2] test: add Exec tests and benchmarks --- sqlite3_test.go | 133 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 132 insertions(+), 1 deletion(-) diff --git a/sqlite3_test.go b/sqlite3_test.go index 94de7386..67aa6ba4 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -10,6 +10,7 @@ package sqlite3 import ( "bytes" + "context" "database/sql" "database/sql/driver" "errors" @@ -1090,6 +1091,67 @@ func TestExecer(t *testing.T) { } } +func TestExecDriverResult(t *testing.T) { + setup := func(t *testing.T) *sql.DB { + db, err := sql.Open("sqlite3", t.TempDir()+"/test.sqlite3") + if err != nil { + t.Fatal("Failed to open database:", err) + } + if _, err := db.Exec(`CREATE TABLE foo (id INTEGER PRIMARY KEY);`); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { db.Close() }) + return db + } + + test := func(t *testing.T, execStmt string, args ...any) { + db := setup(t) + res, err := db.Exec(execStmt, args...) + if err != nil { + t.Fatal(err) + } + rows, err := res.RowsAffected() + if err != nil { + t.Fatal(err) + } + // We only return the changes from the last statement. + if rows != 1 { + t.Errorf("RowsAffected got: %d want: %d", rows, 1) + } + id, err := res.LastInsertId() + if err != nil { + t.Fatal(err) + } + if id != 3 { + t.Errorf("LastInsertId got: %d want: %d", id, 3) + } + var count int64 + err = db.QueryRow(`SELECT COUNT(*) FROM foo WHERE id IN (1, 2, 3);`).Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 3 { + t.Errorf("Expected count to be %d got: %d", 3, count) + } + } + + t.Run("NoArgs", func(t *testing.T) { + const stmt = ` + INSERT INTO foo(id) VALUES(1); + INSERT INTO foo(id) VALUES(2); + INSERT INTO foo(id) VALUES(3);` + test(t, stmt) + }) + + t.Run("WithArgs", func(t *testing.T) { + const stmt = ` + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?);` + test(t, stmt, 1, 2, 3) + }) +} + func TestQueryer(t *testing.T) { tempFilename := TempFilename(t) defer os.Remove(tempFilename) @@ -2106,6 +2168,10 @@ var tests = []testing.InternalTest{ var benchmarks = []testing.InternalBenchmark{ {Name: "BenchmarkExec", F: benchmarkExec}, + {Name: "BenchmarkExecContext", F: benchmarkExecContext}, + {Name: "BenchmarkExecStep", F: benchmarkExecStep}, + {Name: "BenchmarkExecContextStep", F: benchmarkExecContextStep}, + {Name: "BenchmarkExecTx", F: benchmarkExecTx}, {Name: "BenchmarkQuery", F: benchmarkQuery}, {Name: "BenchmarkParams", F: benchmarkParams}, {Name: "BenchmarkStmt", F: benchmarkStmt}, @@ -2459,13 +2525,78 @@ func testExecEmptyQuery(t *testing.T) { // benchmarkExec is benchmark for exec func benchmarkExec(b *testing.B) { + b.Run("Params", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := db.Exec("select ?;", int64(1)); err != nil { + panic(err) + } + } + }) + b.Run("NoParams", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := db.Exec("select 1;"); err != nil { + panic(err) + } + } + }) +} + +func benchmarkExecContext(b *testing.B) { + b.Run("Params", func(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for i := 0; i < b.N; i++ { + if _, err := db.ExecContext(ctx, "select ?;", int64(1)); err != nil { + panic(err) + } + } + }) + b.Run("NoParams", func(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for i := 0; i < b.N; i++ { + if _, err := db.ExecContext(ctx, "select 1;"); err != nil { + panic(err) + } + } + }) +} + +func benchmarkExecTx(b *testing.B) { for i := 0; i < b.N; i++ { - if _, err := db.Exec("select 1"); err != nil { + tx, err := db.Begin() + if err != nil { + panic(err) + } + if _, err := tx.Exec("select 1;"); err != nil { + panic(err) + } + if err := tx.Commit(); err != nil { panic(err) } } } +var largeSelectStmt = strings.Repeat("select 1;\n", 1_000) + +func benchmarkExecStep(b *testing.B) { + for n := 0; n < b.N; n++ { + if _, err := db.Exec(largeSelectStmt); err != nil { + b.Fatal(err) + } + } +} + +func benchmarkExecContextStep(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for n := 0; n < b.N; n++ { + if _, err := db.ExecContext(ctx, largeSelectStmt); err != nil { + b.Fatal(err) + } + } +} + // benchmarkQuery is benchmark for query func benchmarkQuery(b *testing.B) { for i := 0; i < b.N; i++ { From 4974017ad3cfa6f16059294a741a541509c018f7 Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Fri, 8 Nov 2024 21:10:49 -0500 Subject: [PATCH 2/2] Fix exponential memory allocation in Exec and improve performance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit changes SQLiteConn.Exec to use the raw Go query string instead of repeatedly converting it to a C string (which it would do for every statement in the provided query). This yields a ~20% performance improvement for a query containing one statement and a significantly larger improvement when the query contains multiple statements as is common when importing a SQL dump (our benchmark shows a 5x improvement for handling 1k SQL statements). Additionally, this commit improves the performance of Exec by 2x or more and makes number and size of allocations constant when there are no bind parameters (the performance improvement scales with the number of SQL statements in the query). This is achieved by having the entire query processed in C code thus requiring only one CGO call. The speedup for Exec'ing single statement queries means that wrapping simple statements in a transaction is now twice as fast. This commit also improves the test coverage of Exec, which previously failed to test that Exec could process multiple statements like INSERT. It also adds some Exec specific benchmarks that highlight both the improvements here and the overhead of using a cancellable Context. This commit is a slimmed down and improved version of PR #1133: https://github.com/mattn/go-sqlite3/pull/1133 ``` goos: darwin goarch: arm64 pkg: github.com/mattn/go-sqlite3 cpu: Apple M1 Max │ b.txt │ n.txt │ │ sec/op │ sec/op vs base │ Suite/BenchmarkExec/Params-10 1.434µ ± 1% 1.186µ ± 0% -17.27% (p=0.000 n=10) Suite/BenchmarkExec/NoParams-10 1267.5n ± 0% 759.2n ± 1% -40.10% (p=0.000 n=10) Suite/BenchmarkExecContext/Params-10 2.886µ ± 0% 2.517µ ± 0% -12.80% (p=0.000 n=10) Suite/BenchmarkExecContext/NoParams-10 2.605µ ± 1% 1.829µ ± 1% -29.81% (p=0.000 n=10) Suite/BenchmarkExecStep-10 1852.6µ ± 1% 582.3µ ± 0% -68.57% (p=0.000 n=10) Suite/BenchmarkExecContextStep-10 3053.3µ ± 3% 582.0µ ± 0% -80.94% (p=0.000 n=10) Suite/BenchmarkExecTx-10 4.126µ ± 2% 2.200µ ± 1% -46.67% (p=0.000 n=10) geomean 16.40µ 8.455µ -48.44% │ b.txt │ n.txt │ │ B/op │ B/op vs base │ Suite/BenchmarkExec/Params-10 248.0 ± 0% 240.0 ± 0% -3.23% (p=0.000 n=10) Suite/BenchmarkExec/NoParams-10 128.00 ± 0% 64.00 ± 0% -50.00% (p=0.000 n=10) Suite/BenchmarkExecContext/Params-10 408.0 ± 0% 400.0 ± 0% -1.96% (p=0.000 n=10) Suite/BenchmarkExecContext/NoParams-10 288.0 ± 0% 208.0 ± 0% -27.78% (p=0.000 n=10) Suite/BenchmarkExecStep-10 5406674.50 ± 0% 64.00 ± 0% -100.00% (p=0.000 n=10) Suite/BenchmarkExecContextStep-10 5566758.5 ± 0% 208.0 ± 0% -100.00% (p=0.000 n=10) Suite/BenchmarkExecTx-10 712.0 ± 0% 520.0 ± 0% -26.97% (p=0.000 n=10) geomean 4.899Ki 189.7 -96.22% │ b.txt │ n.txt │ │ allocs/op │ allocs/op vs base │ Suite/BenchmarkExec/Params-10 10.000 ± 0% 9.000 ± 0% -10.00% (p=0.000 n=10) Suite/BenchmarkExec/NoParams-10 7.000 ± 0% 4.000 ± 0% -42.86% (p=0.000 n=10) Suite/BenchmarkExecContext/Params-10 12.00 ± 0% 11.00 ± 0% -8.33% (p=0.000 n=10) Suite/BenchmarkExecContext/NoParams-10 9.000 ± 0% 6.000 ± 0% -33.33% (p=0.000 n=10) Suite/BenchmarkExecStep-10 7000.000 ± 0% 4.000 ± 0% -99.94% (p=0.000 n=10) Suite/BenchmarkExecContextStep-10 9001.000 ± 0% 6.000 ± 0% -99.93% (p=0.000 n=10) Suite/BenchmarkExecTx-10 27.00 ± 0% 18.00 ± 0% -33.33% (p=0.000 n=10) geomean 74.60 7.224 -90.32% ``` --- sqlite3.go | 171 +++++++++++++++++++++++++++++++++++++++++------- unsafe_go120.go | 17 +++++ unsafe_go121.go | 23 +++++++ 3 files changed, 189 insertions(+), 22 deletions(-) create mode 100644 unsafe_go120.go create mode 100644 unsafe_go121.go diff --git a/sqlite3.go b/sqlite3.go index 3025a500..aca480a5 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -137,6 +137,61 @@ _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_ } #endif +static int _sqlite3_prepare_v2(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, int *oBytes) { + const char *tail; + int rv = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail); + if (rv != SQLITE_OK) { + return rv; + } + if (tail) { + // Set oBytes to the number of bytes consumed instead of using the + // **pzTail out param since that requires storing a Go pointer in + // a C pointer, which is not allowed by CGO and will cause + // runtime.cgoCheckPointer to fail. + *oBytes = tail - zSql; + } else { + // NB: this should not happen, but if it does advance oBytes to the + // end of the string so that we do not loop infinitely. + *oBytes = nBytes; + } + return SQLITE_OK; +} + +// _sqlite3_exec_no_args executes all of the statements in zSql. None of the +// statements are allowed to have positional arguments. +int _sqlite3_exec_no_args(sqlite3 *db, const char *zSql, int nBytes, int64_t *rowid, int64_t *changes) { + while (*zSql && nBytes > 0) { + sqlite3_stmt *stmt; + const char *tail; + int rv = sqlite3_prepare_v2(db, zSql, nBytes, &stmt, &tail); + if (rv != SQLITE_OK) { + return rv; + } + + // Process statement + do { + rv = _sqlite3_step_internal(stmt); + } while (rv == SQLITE_ROW); + + // Only record the number of changes made by the last statement. + *changes = sqlite3_changes64(db); + *rowid = sqlite3_last_insert_rowid(db); + + if (rv != SQLITE_OK && rv != SQLITE_DONE) { + sqlite3_finalize(stmt); + return rv; + } + rv = sqlite3_finalize(stmt); + if (rv != SQLITE_OK) { + return rv; + } + + nBytes -= tail - zSql; + zSql = tail; + } + return SQLITE_OK; +} + void _sqlite3_result_text(sqlite3_context* ctx, const char* s) { sqlite3_result_text(ctx, s, -1, &free); } @@ -858,54 +913,119 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err } func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - start := 0 + // Trim the query. This is mostly important for getting rid + // of any trailing space. + query = strings.TrimSpace(query) + if len(args) > 0 { + return c.execArgs(ctx, query, args) + } + return c.execNoArgs(ctx, query) +} + +func (c *SQLiteConn) execArgs(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + var ( + stmtArgs []driver.NamedValue + start int + s SQLiteStmt // escapes to the heap so reuse it + sz C.int // number of query bytes consumed: escapes to the heap + ) for { - s, err := c.prepare(ctx, query) - if err != nil { - return nil, err + s = SQLiteStmt{c: c} // reset + sz = 0 + rv := C._sqlite3_prepare_v2(c.db, (*C.char)(unsafe.Pointer(stringData(query))), + C.int(len(query)), &s.s, &sz) + if rv != C.SQLITE_OK { + return nil, c.lastError() } + query = strings.TrimSpace(query[sz:]) + var res driver.Result - if s.(*SQLiteStmt).s != nil { - stmtArgs := make([]driver.NamedValue, 0, len(args)) + if s.s != nil { na := s.NumInput() if len(args)-start < na { - s.Close() + s.finalize() return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) } // consume the number of arguments used in the current // statement and append all named arguments not // contained therein - if len(args[start:start+na]) > 0 { - stmtArgs = append(stmtArgs, args[start:start+na]...) - for i := range args { - if (i < start || i >= na) && args[i].Name != "" { - stmtArgs = append(stmtArgs, args[i]) - } - } - for i := range stmtArgs { - stmtArgs[i].Ordinal = i + 1 + stmtArgs = append(stmtArgs[:0], args[start:start+na]...) + for i := range args { + if (i < start || i >= na) && args[i].Name != "" { + stmtArgs = append(stmtArgs, args[i]) } } - res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs) + for i := range stmtArgs { + stmtArgs[i].Ordinal = i + 1 + } + var err error + res, err = s.exec(ctx, stmtArgs) if err != nil && err != driver.ErrSkip { - s.Close() + s.finalize() return nil, err } start += na } - tail := s.(*SQLiteStmt).t - s.Close() - if tail == "" { + s.finalize() + if len(query) == 0 { if res == nil { // https://github.com/mattn/go-sqlite3/issues/963 res = &SQLiteResult{0, 0} } return res, nil } - query = tail } } +// execNoArgsSync processes every SQL statement in query. All processing occurs +// in C code, which reduces the overhead of CGO calls. +func (c *SQLiteConn) execNoArgsSync(query string) (_ driver.Result, err error) { + var rowid, changes C.int64_t + rv := C._sqlite3_exec_no_args(c.db, (*C.char)(unsafe.Pointer(stringData(query))), + C.int(len(query)), &rowid, &changes) + if rv != C.SQLITE_OK { + err = c.lastError() + } + return &SQLiteResult{id: int64(rowid), changes: int64(changes)}, err +} + +func (c *SQLiteConn) execNoArgs(ctx context.Context, query string) (driver.Result, error) { + done := ctx.Done() + if done == nil { + return c.execNoArgsSync(query) + } + + // Fast check if the Context is cancelled + if err := ctx.Err(); err != nil { + return nil, err + } + + ch := make(chan struct{}) + defer close(ch) + go func() { + select { + case <-done: + C.sqlite3_interrupt(c.db) + // Wait until signaled. We need to ensure that this goroutine + // will not call interrupt after this method returns, which is + // why we can't check if only done is closed when waiting below. + <-ch + case <-ch: + } + }() + + res, err := c.execNoArgsSync(query) + + // Stop the goroutine and make sure we're at a point where + // sqlite3_interrupt cannot be called again. + ch <- struct{}{} + + if isInterruptErr(err) { + err = ctx.Err() + } + return res, err +} + // Query implements Queryer. func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { list := make([]driver.NamedValue, len(args)) @@ -1914,6 +2034,13 @@ func (s *SQLiteStmt) Close() error { return nil } +func (s *SQLiteStmt) finalize() { + if s.s != nil { + C.sqlite3_finalize(s.s) + s.s = nil + } +} + // NumInput return a number of parameters. func (s *SQLiteStmt) NumInput() int { return int(C.sqlite3_bind_parameter_count(s.s)) diff --git a/unsafe_go120.go b/unsafe_go120.go new file mode 100644 index 00000000..95d673ed --- /dev/null +++ b/unsafe_go120.go @@ -0,0 +1,17 @@ +//go:build !go1.21 +// +build !go1.21 + +package sqlite3 + +import "unsafe" + +// stringData is a safe version of unsafe.StringData that handles empty strings. +func stringData(s string) *byte { + if len(s) != 0 { + b := *(*[]byte)(unsafe.Pointer(&s)) + return &b[0] + } + // The return value of unsafe.StringData + // is unspecified if the string is empty. + return &placeHolder[0] +} diff --git a/unsafe_go121.go b/unsafe_go121.go new file mode 100644 index 00000000..b9c00a12 --- /dev/null +++ b/unsafe_go121.go @@ -0,0 +1,23 @@ +//go:build go1.21 +// +build go1.21 + +// The unsafe.StringData function was made available in Go 1.20 but it +// was not until Go 1.21 that Go was changed to interpret the Go version +// in go.mod (1.19 as of writing this) as the minimum version required +// instead of the exact version. +// +// See: https://github.com/golang/go/issues/59033 + +package sqlite3 + +import "unsafe" + +// stringData is a safe version of unsafe.StringData that handles empty strings. +func stringData(s string) *byte { + if len(s) != 0 { + return unsafe.StringData(s) + } + // The return value of unsafe.StringData + // is unspecified if the string is empty. + return &placeHolder[0] +}