Skip to content

Commit 1674490

Browse files
committed
db: define manual action SQL queries
Here, we manually define some queries for the actions store. We do this so that we can manually build the "SELECT" and only add "WHERE" clauses that are actually needed for the query and hence ensure that available indexes are used.
1 parent 65e4309 commit 1674490

File tree

3 files changed

+234
-3
lines changed

3 files changed

+234
-3
lines changed

db/interfaces.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,13 @@ type BatchedQuerier interface {
8787
// create a batched version of the normal methods they need.
8888
sqlc.Querier
8989

90+
// CustomQueries is the set of custom queries that we have manually
91+
// defined in addition to the ones generated by sqlc.
92+
sqlc.CustomQueries
93+
9094
// BeginTx creates a new database transaction given the set of
9195
// transaction options.
9296
BeginTx(ctx context.Context, options TxOptions) (*sql.Tx, error)
93-
94-
// Backend returns the type of the database backend used.
95-
Backend() sqlc.BackendType
9697
}
9798

9899
// txExecutorOptions is a struct that holds the options for the transaction

db/sqlc/actions_custom.go

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
package sqlc
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"strconv"
7+
"strings"
8+
)
9+
10+
// ActionQueryParams defines the parameters for querying actions.
11+
type ActionQueryParams struct {
12+
SessionID sql.NullInt64
13+
AccountID sql.NullInt64
14+
FeatureName sql.NullString
15+
ActorName sql.NullString
16+
RpcMethod sql.NullString
17+
State sql.NullInt16
18+
EndTime sql.NullTime
19+
StartTime sql.NullTime
20+
GroupID sql.NullInt64
21+
}
22+
23+
// ListActionsParams defines the parameters for listing actions, including
24+
// the ActionQueryParams for filtering and a Pagination struct for
25+
// pagination. The Reversed field indicates whether the results should be
26+
// returned in reverse order based on the created_at timestamp.
27+
type ListActionsParams struct {
28+
ActionQueryParams
29+
Reversed bool
30+
*Pagination
31+
}
32+
33+
// Pagination defines the pagination parameters for listing actions.
34+
type Pagination struct {
35+
NumOffset int32
36+
NumLimit int32
37+
}
38+
39+
// ListActions retrieves a list of actions based on the provided
40+
// ListActionsParams.
41+
func (q *Queries) ListActions(ctx context.Context,
42+
arg ListActionsParams) ([]Action, error) {
43+
44+
query, args := buildListActionsQuery(arg)
45+
rows, err := q.db.QueryContext(ctx, fillPlaceHolders(query), args...)
46+
if err != nil {
47+
return nil, err
48+
}
49+
defer rows.Close()
50+
var items []Action
51+
for rows.Next() {
52+
var i Action
53+
if err := rows.Scan(
54+
&i.ID,
55+
&i.SessionID,
56+
&i.AccountID,
57+
&i.MacaroonIdentifier,
58+
&i.ActorName,
59+
&i.FeatureName,
60+
&i.ActionTrigger,
61+
&i.Intent,
62+
&i.StructuredJsonData,
63+
&i.RpcMethod,
64+
&i.RpcParamsJson,
65+
&i.CreatedAt,
66+
&i.ActionState,
67+
&i.ErrorReason,
68+
); err != nil {
69+
return nil, err
70+
}
71+
items = append(items, i)
72+
}
73+
if err := rows.Close(); err != nil {
74+
return nil, err
75+
}
76+
if err := rows.Err(); err != nil {
77+
return nil, err
78+
}
79+
return items, nil
80+
}
81+
82+
// CountActions returns the number of actions that match the provided
83+
// ActionQueryParams.
84+
func (q *Queries) CountActions(ctx context.Context,
85+
arg ActionQueryParams) (int64, error) {
86+
87+
query, args := buildActionsQuery(arg, true)
88+
row := q.db.QueryRowContext(ctx, query, args...)
89+
90+
var count int64
91+
err := row.Scan(&count)
92+
93+
return count, err
94+
}
95+
96+
// buildActionsQuery constructs a SQL query to retrieve actions based on the
97+
// provided parameters. We do this manually so that if, for example, we have
98+
// a sessionID we are filtering by, then this appears in the query as:
99+
// `WHERE a.session_id = ?` which will properly make use of the underlying
100+
// index. If we were instead to use a single SQLC query, it would include many
101+
// WHERE clauses like:
102+
// "WHERE a.session_id = COALESCE(sqlc.narg('session_id'), a.session_id)".
103+
// This would use the index if run against postres but not when run against
104+
// sqlite.
105+
//
106+
// The 'count' param indicates whether the query should return a count of
107+
// actions that match the criteria or the actions themselves.
108+
func buildActionsQuery(params ActionQueryParams, count bool) (string, []any) {
109+
var (
110+
conditions []string
111+
args []any
112+
)
113+
114+
if params.SessionID.Valid {
115+
conditions = append(conditions, "a.session_id = ?")
116+
args = append(args, params.SessionID.Int64)
117+
}
118+
if params.AccountID.Valid {
119+
conditions = append(conditions, "a.account_id = ?")
120+
args = append(args, params.AccountID.Int64)
121+
}
122+
if params.FeatureName.Valid {
123+
conditions = append(conditions, "a.feature_name = ?")
124+
args = append(args, params.FeatureName.String)
125+
}
126+
if params.ActorName.Valid {
127+
conditions = append(conditions, "a.actor_name = ?")
128+
args = append(args, params.ActorName.String)
129+
}
130+
if params.RpcMethod.Valid {
131+
conditions = append(conditions, "a.rpc_method = ?")
132+
args = append(args, params.RpcMethod.String)
133+
}
134+
if params.State.Valid {
135+
conditions = append(conditions, "a.action_state = ?")
136+
args = append(args, params.State.Int16)
137+
}
138+
if params.EndTime.Valid {
139+
conditions = append(conditions, "a.created_at <= ?")
140+
args = append(args, params.EndTime.Time)
141+
}
142+
if params.StartTime.Valid {
143+
conditions = append(conditions, "a.created_at >= ?")
144+
args = append(args, params.StartTime.Time)
145+
}
146+
if params.GroupID.Valid {
147+
conditions = append(conditions, `
148+
EXISTS (
149+
SELECT 1
150+
FROM sessions s
151+
WHERE s.id = a.session_id AND s.group_id = ?
152+
)`)
153+
args = append(args, params.GroupID.Int64)
154+
}
155+
156+
query := "SELECT a.* FROM actions a"
157+
if count {
158+
query = "SELECT COUNT(*) FROM actions a"
159+
}
160+
if len(conditions) > 0 {
161+
query += " WHERE " + strings.Join(conditions, " AND ")
162+
}
163+
164+
return query, args
165+
}
166+
167+
// buildListActionsQuery constructs a SQL query to retrieve a list of actions
168+
// based on the provided parameters. It builds upon the `buildActionsQuery`
169+
// function, adding pagination and ordering based on the reversed parameter.
170+
func buildListActionsQuery(params ListActionsParams) (string, []interface{}) {
171+
query, args := buildActionsQuery(params.ActionQueryParams, false)
172+
173+
// Determine order direction.
174+
order := "ASC"
175+
if params.Reversed {
176+
order = "DESC"
177+
}
178+
query += " ORDER BY a.created_at " + order
179+
180+
// Maybe paginate.
181+
if params.Pagination != nil {
182+
query += " LIMIT ? OFFSET ?"
183+
args = append(args, params.NumLimit, params.NumOffset)
184+
}
185+
186+
return query, args
187+
}
188+
189+
// fillPlaceHolders replaces all '?' placeholders in the SQL query with
190+
// positional placeholders like $1, $2, etc. This is necessary for
191+
// compatibility with Postgres.
192+
func fillPlaceHolders(query string) string {
193+
var (
194+
sb strings.Builder
195+
argNum = 1
196+
)
197+
198+
for i := range len(query) {
199+
if query[i] != '?' {
200+
sb.WriteByte(query[i])
201+
continue
202+
}
203+
204+
sb.WriteString("$")
205+
sb.WriteString(strconv.Itoa(argNum))
206+
argNum++
207+
}
208+
209+
return sb.String()
210+
}

db/sqlc/db_custom.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
package sqlc
22

3+
import (
4+
"context"
5+
)
6+
37
// BackendType is an enum that represents the type of database backend we're
48
// using.
59
type BackendType uint8
@@ -44,3 +48,19 @@ func NewSqlite(db DBTX) *Queries {
4448
func NewPostgres(db DBTX) *Queries {
4549
return &Queries{db: &wrappedTX{db, BackendTypePostgres}}
4650
}
51+
52+
// CustomQueries defines a set of custom queries that we define in addition
53+
// to the ones generated by sqlc.
54+
type CustomQueries interface {
55+
// CountActions returns the number of actions that match the provided
56+
// ActionQueryParams.
57+
CountActions(ctx context.Context, arg ActionQueryParams) (int64, error)
58+
59+
// ListActions retrieves a list of actions based on the provided
60+
// ListActionsParams.
61+
ListActions(ctx context.Context,
62+
arg ListActionsParams) ([]Action, error)
63+
64+
// Backend returns the type of the database backend used.
65+
Backend() BackendType
66+
}

0 commit comments

Comments
 (0)