Skip to content

Add support for getting db session form context #3772

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/reference/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ The `gen` mapping supports the following keys:
that returns all valid enum values.
- `emit_sql_as_comment`:
- If true, emits the SQL statement as a code-block comment above the generated function, appending to any existing comments. Defaults to `false`.
- `get_db_from_context`:
- If true, emits `New` method for `Querier` with a function argument which accepts a ctx argument and returns DBTX. Defaults to `false`.
- `build_tags`:
- If set, add a `//go:build <build_tags>` directive at the beginning of each generated Go file.
- `initialisms`:
Expand Down Expand Up @@ -414,6 +416,7 @@ packages:
emit_pointers_for_null_types: false
emit_enum_valid_method: false
emit_all_enum_values: false
get_db_from_context: false
build_tags: "some_tag"
json_tags_case_style: "camel"
omit_unused_structs: false
Expand Down Expand Up @@ -469,6 +472,8 @@ Each mapping in the `packages` collection has the following keys:
- `emit_all_enum_values`:
- If true, emit a function per enum type
that returns all valid enum values.
- `get_db_from_context`:
- If true, emits `New` method for `Querier` with a function argument which accepts a ctx argument and returns DBTX. Defaults to `false`.
- `build_tags`:
- If set, add a `//go:build <build_tags>` directive at the beginning of each generated Go file.
- `json_tags_case_style`:
Expand Down
5 changes: 5 additions & 0 deletions internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type tmplCtx struct {
EmitMethodsWithDBArgument bool
EmitEnumValidMethod bool
EmitAllEnumValues bool
GetDBFromContext bool
UsesCopyFrom bool
UsesBatch bool
OmitSqlcVersion bool
Expand Down Expand Up @@ -65,6 +66,9 @@ func (t *tmplCtx) codegenQueryMethod(q Query) string {
if t.EmitMethodsWithDBArgument {
db = "db"
}
if t.GetDBFromContext {
db = "q.getDBFromContext(ctx)"
}

switch q.Cmd {
case ":one":
Expand Down Expand Up @@ -177,6 +181,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
EmitMethodsWithDBArgument: options.EmitMethodsWithDbArgument,
EmitEnumValidMethod: options.EmitEnumValidMethod,
EmitAllEnumValues: options.EmitAllEnumValues,
GetDBFromContext: options.GetDBFromContext,
UsesCopyFrom: usesCopyFrom(queries),
UsesBatch: usesBatch(queries),
SQLDriver: parseDriver(options.SqlPackage),
Expand Down
7 changes: 7 additions & 0 deletions internal/codegen/golang/opts/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Options struct {
EmitEnumValidMethod bool `json:"emit_enum_valid_method,omitempty" yaml:"emit_enum_valid_method"`
EmitAllEnumValues bool `json:"emit_all_enum_values,omitempty" yaml:"emit_all_enum_values"`
EmitSqlAsComment bool `json:"emit_sql_as_comment,omitempty" yaml:"emit_sql_as_comment"`
GetDBFromContext bool `json:"get_db_from_context,omitempty" yaml:"get_db_from_context"`
JsonTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"`
Package string `json:"package" yaml:"package"`
Out string `json:"out" yaml:"out"`
Expand Down Expand Up @@ -147,6 +148,12 @@ func ValidateOpts(opts *Options) error {
if opts.EmitMethodsWithDbArgument && opts.EmitPreparedQueries {
return fmt.Errorf("invalid options: emit_methods_with_db_argument and emit_prepared_queries options are mutually exclusive")
}
if opts.GetDBFromContext && opts.EmitPreparedQueries {
return fmt.Errorf("invalid options: get_db_from_context and emit_prepared_queries options are mutually exclusive")
}
if opts.GetDBFromContext && opts.EmitMethodsWithDbArgument {
return fmt.Errorf("invalid options: get_db_from_context and emit_methods_with_db_argument options are mutually exclusive")
}
if *opts.QueryParameterLimit < 0 {
return fmt.Errorf("invalid options: query parameter limit must not be negative")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ func (q *Queries) {{.MethodName}}(ctx context.Context{{if $.EmitMethodsWithDBArg
go convertRowsFor{{.MethodName}}(pw, {{.Arg.Name}})
// The string interpolation is necessary because LOAD DATA INFILE requires
// the file name to be given as a literal string.
result, err := {{if (not $.EmitMethodsWithDBArgument)}}q.{{end}}db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE {{.TableIdentifierForMySQL}} %s ({{range $index, $name := .Arg.ColumnNames}}{{if gt $index 0}}, {{end}}{{$name}}{{end}})", "Reader::" + rh, mysqltsv.Escaping))
{{- $db := "q.db"}}
{{- if $.EmitMethodsWithDBArgument}}
{{- $db = "db"}}
{{- else if $.GetDBFromContext}}
{{- $db = "q.getDBFromContext(ctx)"}}
{{- end}}
result, err := {{$db}}.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE {{.TableIdentifierForMySQL}} %s ({{range $index, $name := .Arg.ColumnNames}}{{if gt $index 0}}, {{end}}{{$name}}{{end}})", "Reader::" + rh, mysqltsv.Escaping))
if err != nil {
return 0, err
}
Expand Down
8 changes: 7 additions & 1 deletion internal/codegen/golang/templates/pgx/batchCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ if $.EmitMethodsWithDB
}
batch.Queue({{.ConstantName}}, vals...)
}
br := {{if not $.EmitMethodsWithDBArgument}}q.{{end}}db.SendBatch(ctx, batch)
{{- $db := "q.db"}}
{{- if $.EmitMethodsWithDBArgument}}
{{- $db = "db"}}
{{- else if $.GetDBFromContext}}
{{- $db = "q.getDBFromContext(ctx)"}}
{{- end}}
br := {{$db}}.SendBatch(ctx, batch)
return &{{.MethodName}}BatchResults{br,len({{.Arg.Name}}),false}
}

Expand Down
7 changes: 6 additions & 1 deletion internal/codegen/golang/templates/pgx/copyfromCopy.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair
return db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) {
return q.db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})

{{- $db := "db"}}
{{- if $.GetDBFromContext}}
{{- $db = "getDBFromContext(ctx)"}}
{{- end}}
return q.{{$db}}.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
{{- end}}
}

Expand Down
9 changes: 7 additions & 2 deletions internal/codegen/golang/templates/pgx/dbCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,24 @@ type DBTX interface {
{{ if .EmitMethodsWithDBArgument}}
func New() *Queries {
return &Queries{}
{{- else if .GetDBFromContext}}
func New(getDBFromContext func(context.Context) DBTX) *Queries {
return &Queries{getDBFromContext: getDBFromContext}
{{- else -}}
func New(db DBTX) *Queries {
return &Queries{db: db}
{{- end}}
}

type Queries struct {
{{if not .EmitMethodsWithDBArgument}}
{{- if .GetDBFromContext}}
getDBFromContext func(context.Context) DBTX
{{- else if not .EmitMethodsWithDBArgument}}
db DBTX
{{end}}
}

{{if not .EmitMethodsWithDBArgument}}
{{if and (not .EmitMethodsWithDBArgument) (not .GetDBFromContext)}}
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
return &Queries{
db: tx,
Expand Down
15 changes: 10 additions & 5 deletions internal/codegen/golang/templates/pgx/queryCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}}
{{$.Q}}
{{end}}

{{- $db := "db" }}
{{- if $.GetDBFromContext}}
{{- $db = "getDBFromContext(ctx)"}}
{{- end}}

{{if ne (hasPrefix .Cmd ":batch") true}}
{{if .Arg.EmitStruct}}
type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}}
Expand All @@ -31,7 +36,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (
row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) {
row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}})
row := q.{{$db}}.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
{{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }}
var {{.Ret.Name}} {{.Ret.Type}}
Expand All @@ -49,7 +54,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (
rows, err := db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) {
rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}})
rows, err := q.{{$db}}.Query(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
if err != nil {
return nil, err
Expand Down Expand Up @@ -82,7 +87,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) e
_, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error {
_, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
_, err := q.{{$db}}.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
return err
}
Expand All @@ -96,7 +101,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (
result, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) {
result, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
result, err := q.{{$db}}.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
if err != nil {
return 0, err
Expand All @@ -113,7 +118,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (
return db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) {
return q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
return q.{{$db}}.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
}
{{end}}
Expand Down
9 changes: 7 additions & 2 deletions internal/codegen/golang/templates/stdlib/dbCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ type DBTX interface {
{{ if .EmitMethodsWithDBArgument}}
func New() *Queries {
return &Queries{}
{{- else if .GetDBFromContext}}
func New(getDBFromContext func(context.Context) DBTX) *Queries {
return &Queries{getDBFromContext: getDBFromContext}
{{- else -}}
func New(db DBTX) *Queries {
return &Queries{db: db}
Expand Down Expand Up @@ -77,7 +80,9 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar
{{end}}

type Queries struct {
{{- if not .EmitMethodsWithDBArgument}}
{{- if .GetDBFromContext}}
getDBFromContext func(context.Context) DBTX
{{- else if not .EmitMethodsWithDBArgument}}
db DBTX
{{- end}}

Expand All @@ -89,7 +94,7 @@ type Queries struct {
{{- end}}
}

{{if not .EmitMethodsWithDBArgument}}
{{if and (not .EmitMethodsWithDBArgument) (not .GetDBFromContext)}}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
Expand Down
2 changes: 2 additions & 0 deletions internal/config/v_one.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type v1PackageSettings struct {
EmitEnumValidMethod bool `json:"emit_enum_valid_method,omitempty" yaml:"emit_enum_valid_method"`
EmitAllEnumValues bool `json:"emit_all_enum_values,omitempty" yaml:"emit_all_enum_values"`
EmitSqlAsComment bool `json:"emit_sql_as_comment,omitempty" yaml:"emit_sql_as_comment"`
GetDBFromContext bool `json:"get_db_from_context,omitempty" yaml:"get_db_from_context"`
JSONTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"`
SQLPackage string `json:"sql_package" yaml:"sql_package"`
SQLDriver string `json:"sql_driver" yaml:"sql_driver"`
Expand Down Expand Up @@ -152,6 +153,7 @@ func (c *V1GenerateSettings) Translate() Config {
EmitEnumValidMethod: pkg.EmitEnumValidMethod,
EmitAllEnumValues: pkg.EmitAllEnumValues,
EmitSqlAsComment: pkg.EmitSqlAsComment,
GetDBFromContext: pkg.GetDBFromContext,
Package: pkg.Name,
Out: pkg.Path,
SqlPackage: pkg.SQLPackage,
Expand Down
3 changes: 3 additions & 0 deletions internal/config/v_one.json
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@
"emit_sql_as_comment": {
"type": "boolean"
},
"get_db_from_context": {
"type": "boolean"
},
"build_tags": {
"type": "string"
},
Expand Down
3 changes: 3 additions & 0 deletions internal/config/v_two.json
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@
"emit_sql_as_comment": {
"type": "boolean"
},
"get_db_from_context": {
"type": "boolean"
},
"build_tags": {
"type": "string"
},
Expand Down
25 changes: 25 additions & 0 deletions internal/endtoend/testdata/get_db_from_context/mysql/go/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions internal/endtoend/testdata/get_db_from_context/mysql/go/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
/* name: GetAll :many */
SELECT * FROM users;
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
CREATE TABLE users (
id integer NOT NULL AUTO_INCREMENT PRIMARY KEY,
first_name varchar(255) NOT NULL,
last_name varchar(255),
age integer NOT NULL
);
13 changes: 13 additions & 0 deletions internal/endtoend/testdata/get_db_from_context/mysql/sqlc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"version": "1",
"packages": [
{
"name": "querytest",
"path": "go",
"schema": "schema.sql",
"queries": "query.sql",
"engine": "mysql",
"get_db_from_context": true
}
]
}
Loading
Loading