From 7f4bbfd319f26b1c84228fd2bf6f8ca122855ed4 Mon Sep 17 00:00:00 2001 From: Eyal Halpern Shalev Date: Fri, 18 Jul 2025 04:01:52 +0300 Subject: [PATCH 1/3] Add Begin to pgx db template --- .../codegen/golang/templates/pgx/dbCode.tmpl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/internal/codegen/golang/templates/pgx/dbCode.tmpl b/internal/codegen/golang/templates/pgx/dbCode.tmpl index 236554d9f2..19309b9a5a 100644 --- a/internal/codegen/golang/templates/pgx/dbCode.tmpl +++ b/internal/codegen/golang/templates/pgx/dbCode.tmpl @@ -1,6 +1,8 @@ {{define "dbCodeTemplatePgx"}} type DBTX interface { + Begin(context.Context) (pgx.Tx, error) + BeginTx(context.Context, txOptions TxOptions) (pgx.Tx, error) Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row @@ -33,5 +35,19 @@ func (q *Queries) WithTx(tx pgx.Tx) *Queries { db: tx, } } +func (q *Queries) Begin(ctx context.Context) (*Queries, error) { + tx, err := q.db.Begin(ctx) + if (err != nil { + return nil, err + } + return q.WithTx(tx), nil +} +func (q *Queries) BeginTx(ctx context.Context, txOptions TxOptions) (*Queries, error) { + tx, err := q.db.BeginTx(ctx, txOptions) + if (err != nil { + return nil, err + } + return q.WithTx(tx), nil +} {{end}} {{end}} From 647b6502ddd325d199ac65a1f40155b36038a3ef Mon Sep 17 00:00:00 2001 From: Eyal Halpern Shalev Date: Fri, 18 Jul 2025 04:11:13 +0300 Subject: [PATCH 2/3] update stdlib db template with BeginTx --- internal/codegen/golang/templates/stdlib/dbCode.tmpl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/internal/codegen/golang/templates/stdlib/dbCode.tmpl b/internal/codegen/golang/templates/stdlib/dbCode.tmpl index 7433d522f6..05c5dea855 100644 --- a/internal/codegen/golang/templates/stdlib/dbCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/dbCode.tmpl @@ -1,5 +1,6 @@ {{define "dbCodeTemplateStd"}} type DBTX interface { + BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) ExecContext(context.Context, string, ...interface{}) (sql.Result, error) PrepareContext(context.Context, string) (*sql.Stmt, error) QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) @@ -101,5 +102,12 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { {{- end}} } } +func (q *Queries) BeginTx(ctx context.Context, opts driver.TxOptions) (*Queries, error) { + tx, err := q.db.BeginTx(ctx, opts) + if (err != nil { + return nil, err + } + return q.WithTx(tx), nil +} {{end}} {{end}} From 29949b990a40555783d1ba13cbb3a5c6dd098a92 Mon Sep 17 00:00:00 2001 From: Eyal Halpern Shalev Date: Sat, 19 Jul 2025 10:38:00 +0300 Subject: [PATCH 3/3] Hide the Begin methods behind an EmitBegin flag --- internal/codegen/golang/gen.go | 2 ++ internal/codegen/golang/opts/options.go | 1 + internal/codegen/golang/templates/pgx/dbCode.tmpl | 4 ++++ internal/codegen/golang/templates/stdlib/dbCode.tmpl | 4 ++++ 4 files changed, 11 insertions(+) diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index ac91cc537f..0b89e9802a 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -39,6 +39,7 @@ type tmplCtx struct { EmitAllEnumValues bool UsesCopyFrom bool UsesBatch bool + EmitBegin bool OmitSqlcVersion bool BuildTags string WrapErrors bool @@ -181,6 +182,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, EmitMethodsWithDBArgument: options.EmitMethodsWithDbArgument, EmitEnumValidMethod: options.EmitEnumValidMethod, EmitAllEnumValues: options.EmitAllEnumValues, + EmitBegin: options.EmitBegin, UsesCopyFrom: usesCopyFrom(queries), UsesBatch: usesBatch(queries), SQLDriver: parseDriver(options.SqlPackage), diff --git a/internal/codegen/golang/opts/options.go b/internal/codegen/golang/opts/options.go index 0d5d51c2dd..82bb6d2dcd 100644 --- a/internal/codegen/golang/opts/options.go +++ b/internal/codegen/golang/opts/options.go @@ -21,6 +21,7 @@ type Options struct { EmitResultStructPointers bool `json:"emit_result_struct_pointers" yaml:"emit_result_struct_pointers"` EmitParamsStructPointers bool `json:"emit_params_struct_pointers" yaml:"emit_params_struct_pointers"` EmitMethodsWithDbArgument bool `json:"emit_methods_with_db_argument,omitempty" yaml:"emit_methods_with_db_argument"` + EmitBegin bool `json:"emit_begin,omitempty" yaml:"emit_begin"` EmitPointersForNullTypes bool `json:"emit_pointers_for_null_types" yaml:"emit_pointers_for_null_types"` EmitEnumValidMethod bool `json:"emit_enum_valid_method,omitempty" yaml:"emit_enum_valid_method"` EmitAllEnumValues bool `json:"emit_all_enum_values,omitempty" yaml:"emit_all_enum_values"` diff --git a/internal/codegen/golang/templates/pgx/dbCode.tmpl b/internal/codegen/golang/templates/pgx/dbCode.tmpl index 19309b9a5a..dc91c0e01a 100644 --- a/internal/codegen/golang/templates/pgx/dbCode.tmpl +++ b/internal/codegen/golang/templates/pgx/dbCode.tmpl @@ -1,8 +1,10 @@ {{define "dbCodeTemplatePgx"}} type DBTX interface { +{{- if .EmitBegin }} Begin(context.Context) (pgx.Tx, error) BeginTx(context.Context, txOptions TxOptions) (pgx.Tx, error) +{{- end }} Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row @@ -35,6 +37,7 @@ func (q *Queries) WithTx(tx pgx.Tx) *Queries { db: tx, } } +{{- if .EmitBegin }} func (q *Queries) Begin(ctx context.Context) (*Queries, error) { tx, err := q.db.Begin(ctx) if (err != nil { @@ -49,5 +52,6 @@ func (q *Queries) BeginTx(ctx context.Context, txOptions TxOptions) (*Queries, e } return q.WithTx(tx), nil } +{{- end }} {{end}} {{end}} diff --git a/internal/codegen/golang/templates/stdlib/dbCode.tmpl b/internal/codegen/golang/templates/stdlib/dbCode.tmpl index 05c5dea855..59081bd4ea 100644 --- a/internal/codegen/golang/templates/stdlib/dbCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/dbCode.tmpl @@ -1,6 +1,8 @@ {{define "dbCodeTemplateStd"}} type DBTX interface { +{{- if .EmitBegin }} BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) +{{end}} ExecContext(context.Context, string, ...interface{}) (sql.Result, error) PrepareContext(context.Context, string) (*sql.Stmt, error) QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) @@ -102,6 +104,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { {{- end}} } } +{{- if .EmitBegin }} func (q *Queries) BeginTx(ctx context.Context, opts driver.TxOptions) (*Queries, error) { tx, err := q.db.BeginTx(ctx, opts) if (err != nil { @@ -111,3 +114,4 @@ func (q *Queries) BeginTx(ctx context.Context, opts driver.TxOptions) (*Queries, } {{end}} {{end}} +{{end}}