Skip to content

dbutil: add option to detect database calls with incorrect context #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions dbutil/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"regexp"
"strings"
"time"

"go.mau.fi/util/exsync"
)

type Dialect int
Expand Down Expand Up @@ -114,12 +116,16 @@ type Database struct {
Dialect Dialect
UpgradeTable UpgradeTable

txnCtxKey contextKey
txnCtxKey contextKey
txnDeadlockMap *exsync.Set[int64]

IgnoreForeignTables bool
IgnoreUnsupportedDatabase bool
DeadlockDetection bool
}

var ForceDeadlockDetection bool

var positionalParamPattern = regexp.MustCompile(`\$(\d+)`)

func (db *Database) mutateQuery(query string) string {
Expand All @@ -144,10 +150,12 @@ func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log Da
Log: log,
Dialect: db.Dialect,

txnCtxKey: db.txnCtxKey,
txnCtxKey: db.txnCtxKey,
txnDeadlockMap: db.txnDeadlockMap,

IgnoreForeignTables: true,
IgnoreUnsupportedDatabase: db.IgnoreUnsupportedDatabase,
DeadlockDetection: db.DeadlockDetection,
}
}

Expand All @@ -164,7 +172,10 @@ func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) {
IgnoreForeignTables: true,
VersionTable: "version",

txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)),
txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)),
txnDeadlockMap: exsync.NewSet[int64](),

DeadlockDetection: ForceDeadlockDetection,
}
wrappedDB.LoggingDB.UnderlyingExecable = db
wrappedDB.LoggingDB.db = wrappedDB
Expand Down Expand Up @@ -194,6 +205,8 @@ type PoolConfig struct {
type Config struct {
PoolConfig `yaml:",inline"`
ReadOnlyPool PoolConfig `yaml:"ro_pool"`

DeadlockDetection bool `yaml:"deadlock_detection"`
}

func (db *Database) Close() error {
Expand All @@ -211,6 +224,8 @@ func (db *Database) Close() error {
}

func (db *Database) Configure(cfg Config) error {
db.DeadlockDetection = cfg.DeadlockDetection || ForceDeadlockDetection

if err := db.configure(db.ReadOnlyDB, cfg.ReadOnlyPool); err != nil {
return err
}
Expand Down
140 changes: 140 additions & 0 deletions dbutil/deadlock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package dbutil_test

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"go.mau.fi/util/dbutil"
_ "go.mau.fi/util/dbutil/litestream"
)

func initTestDB(t *testing.T) *dbutil.Database {
db, err := dbutil.NewFromConfig("", dbutil.Config{
PoolConfig: dbutil.PoolConfig{
Type: "sqlite3-fk-wal",
URI: ":memory:?_txlock=immediate",
MaxOpenConns: 1,
MaxIdleConns: 1,
},
DeadlockDetection: true,
}, nil)
require.NoError(t, err)
ctx := context.Background()
_, err = db.Exec(ctx, `
CREATE TABLE meow (id INTEGER PRIMARY KEY, value TEXT);
INSERT INTO meow (id, value) VALUES (1, 'meow');
INSERT INTO meow (id, value) VALUES (2, 'meow 2');
INSERT INTO meow (value) VALUES ('meow 3');
`)
require.NoError(t, err)
return db
}

func getMeow(ctx context.Context, db dbutil.Execable, id int) (value string, err error) {
err = db.QueryRowContext(ctx, "SELECT value FROM meow WHERE id = ?", id).Scan(&value)
return
}

func TestDatabase_NoDeadlock(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
require.NoError(t, db.DoTxn(ctx, nil, func(ctx context.Context) error {
_, err := db.Exec(ctx, "INSERT INTO meow (value) VALUES ('meow 4');")
require.NoError(t, err)
return nil
}))
val, err := getMeow(ctx, db.Execable(ctx), 4)
require.NoError(t, err)
require.Equal(t, "meow 4", val)
}

func TestDatabase_NoDeadlock_Goroutine(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
require.NoError(t, db.DoTxn(ctx, nil, func(ctx context.Context) error {
_, err := db.Exec(ctx, "INSERT INTO meow (value) VALUES ('meow 4');")
require.NoError(t, err)
go func() {
_, err := db.Exec(context.Background(), "INSERT INTO meow (value) VALUES ('meow 5');")
require.NoError(t, err)
}()
time.Sleep(50 * time.Millisecond)
return nil
}))
val, err := getMeow(ctx, db.Execable(ctx), 4)
require.NoError(t, err)
require.Equal(t, "meow 4", val)
val, err = getMeow(ctx, db.Execable(ctx), 5)
require.NoError(t, err)
require.Equal(t, "meow 5", val)
}

func TestDatabase_Deadlock(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
_ = db.DoTxn(ctx, nil, func(ctx context.Context) error {
assert.PanicsWithError(t, dbutil.ErrQueryDeadlock.Error(), func() {
_, _ = db.Exec(context.Background(), "INSERT INTO meow (value) VALUES ('meow 4');")
})
return fmt.Errorf("meow")
})
}

func TestDatabase_Deadlock_Acquire(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
_ = db.DoTxn(ctx, nil, func(ctx context.Context) error {
assert.PanicsWithError(t, dbutil.ErrAcquireDeadlock.Error(), func() {
_, _ = db.AcquireConn(context.Background())
})
return fmt.Errorf("meow")
})
}

func TestDatabase_Deadlock_Txn(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
_ = db.DoTxn(ctx, nil, func(ctx context.Context) error {
assert.PanicsWithError(t, dbutil.ErrTransactionDeadlock.Error(), func() {
_ = db.DoTxn(context.Background(), nil, func(ctx context.Context) error {
return nil
})
})
return fmt.Errorf("meow")
})
}

func TestDatabase_Deadlock_Child(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
childDB := db.Child("", nil, nil)
_ = db.DoTxn(ctx, nil, func(ctx context.Context) error {
assert.PanicsWithError(t, dbutil.ErrQueryDeadlock.Error(), func() {
_, _ = childDB.Exec(context.Background(), "INSERT INTO meow (value) VALUES ('meow 4');")
})
return fmt.Errorf("meow")
})
}

func TestDatabase_Deadlock_Child2(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
childDB := db.Child("", nil, nil)
_ = childDB.DoTxn(ctx, nil, func(ctx context.Context) error {
assert.PanicsWithError(t, dbutil.ErrQueryDeadlock.Error(), func() {
_, _ = db.Exec(context.Background(), "INSERT INTO meow (value) VALUES ('meow 4');")
})
return fmt.Errorf("meow")
})
}
17 changes: 17 additions & 0 deletions dbutil/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"sync/atomic"
"time"

"github.com/petermattis/goid"
"github.com/rs/zerolog"

"go.mau.fi/util/exerrors"
Expand Down Expand Up @@ -51,6 +52,10 @@ func (db *Database) QueryRow(ctx context.Context, query string, args ...any) *sq
return db.Execable(ctx).QueryRowContext(ctx, query, args...)
}

var ErrTransactionDeadlock = errors.New("attempt to start new transaction in goroutine with transaction")
var ErrQueryDeadlock = errors.New("attempt to query without context in goroutine with transaction")
var ErrAcquireDeadlock = errors.New("attempt to acquire connection without context in goroutine with transaction")

func (db *Database) BeginTx(ctx context.Context, opts *TxnOptions) (*LoggingTxn, error) {
if ctx == nil {
panic("BeginTx() called with nil ctx")
Expand All @@ -65,6 +70,12 @@ func (db *Database) DoTxn(ctx context.Context, opts *TxnOptions, fn func(ctx con
if ctx.Value(db.txnCtxKey) != nil {
zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one")
return fn(ctx)
} else if db.DeadlockDetection {
goroutineID := goid.Get()
if !db.txnDeadlockMap.Add(goroutineID) {
panic(ErrTransactionDeadlock)
}
defer db.txnDeadlockMap.Remove(goroutineID)
}

log := zerolog.Ctx(ctx).With().Str("db_txn_id", random.String(12)).Logger()
Expand Down Expand Up @@ -141,6 +152,9 @@ func (db *Database) Execable(ctx context.Context) Execable {
if ok {
return txn
}
if db.DeadlockDetection && db.txnDeadlockMap.Has(goid.Get()) {
panic(ErrQueryDeadlock)
}
return &db.LoggingDB
}

Expand All @@ -152,6 +166,9 @@ func (db *Database) AcquireConn(ctx context.Context) (Conn, error) {
if ok {
return nil, fmt.Errorf("cannot acquire connection while in a transaction")
}
if db.DeadlockDetection && db.txnDeadlockMap.Has(goid.Get()) {
panic(ErrAcquireDeadlock)
}
conn, err := db.RawDB.Conn(ctx)
if err != nil {
return nil, err
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.21
require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/mattn/go-sqlite3 v1.14.22
github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6
github.com/rs/zerolog v1.33.0
github.com/stretchr/testify v1.9.0
golang.org/x/exp v0.0.0-20240707233637-46b078467d37
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6 h1:DUDJI8T/9NcGbbL+AWk6vIYlmQ8ZBS8LZqVre6zbkPQ=
github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
Expand Down