diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index ac91cc537f..0b89e9802a 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -39,6 +39,7 @@ type tmplCtx struct { EmitAllEnumValues bool UsesCopyFrom bool UsesBatch bool + EmitBegin bool OmitSqlcVersion bool BuildTags string WrapErrors bool @@ -181,6 +182,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, EmitMethodsWithDBArgument: options.EmitMethodsWithDbArgument, EmitEnumValidMethod: options.EmitEnumValidMethod, EmitAllEnumValues: options.EmitAllEnumValues, + EmitBegin: options.EmitBegin, UsesCopyFrom: usesCopyFrom(queries), UsesBatch: usesBatch(queries), SQLDriver: parseDriver(options.SqlPackage), diff --git a/internal/codegen/golang/opts/options.go b/internal/codegen/golang/opts/options.go index 0d5d51c2dd..82bb6d2dcd 100644 --- a/internal/codegen/golang/opts/options.go +++ b/internal/codegen/golang/opts/options.go @@ -21,6 +21,7 @@ type Options struct { EmitResultStructPointers bool `json:"emit_result_struct_pointers" yaml:"emit_result_struct_pointers"` EmitParamsStructPointers bool `json:"emit_params_struct_pointers" yaml:"emit_params_struct_pointers"` EmitMethodsWithDbArgument bool `json:"emit_methods_with_db_argument,omitempty" yaml:"emit_methods_with_db_argument"` + EmitBegin bool `json:"emit_begin,omitempty" yaml:"emit_begin"` EmitPointersForNullTypes bool `json:"emit_pointers_for_null_types" yaml:"emit_pointers_for_null_types"` EmitEnumValidMethod bool `json:"emit_enum_valid_method,omitempty" yaml:"emit_enum_valid_method"` EmitAllEnumValues bool `json:"emit_all_enum_values,omitempty" yaml:"emit_all_enum_values"` diff --git a/internal/codegen/golang/templates/pgx/dbCode.tmpl b/internal/codegen/golang/templates/pgx/dbCode.tmpl index 236554d9f2..dc91c0e01a 100644 --- a/internal/codegen/golang/templates/pgx/dbCode.tmpl +++ b/internal/codegen/golang/templates/pgx/dbCode.tmpl @@ -1,6 +1,10 @@ {{define "dbCodeTemplatePgx"}} type DBTX interface { +{{- if .EmitBegin }} + Begin(context.Context) (pgx.Tx, error) + BeginTx(context.Context, txOptions TxOptions) (pgx.Tx, error) +{{- end }} Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row @@ -33,5 +37,21 @@ func (q *Queries) WithTx(tx pgx.Tx) *Queries { db: tx, } } +{{- if .EmitBegin }} +func (q *Queries) Begin(ctx context.Context) (*Queries, error) { + tx, err := q.db.Begin(ctx) + if (err != nil { + return nil, err + } + return q.WithTx(tx), nil +} +func (q *Queries) BeginTx(ctx context.Context, txOptions TxOptions) (*Queries, error) { + tx, err := q.db.BeginTx(ctx, txOptions) + if (err != nil { + return nil, err + } + return q.WithTx(tx), nil +} +{{- end }} {{end}} {{end}} diff --git a/internal/codegen/golang/templates/stdlib/dbCode.tmpl b/internal/codegen/golang/templates/stdlib/dbCode.tmpl index 7433d522f6..59081bd4ea 100644 --- a/internal/codegen/golang/templates/stdlib/dbCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/dbCode.tmpl @@ -1,5 +1,8 @@ {{define "dbCodeTemplateStd"}} type DBTX interface { +{{- if .EmitBegin }} + BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) +{{end}} ExecContext(context.Context, string, ...interface{}) (sql.Result, error) PrepareContext(context.Context, string) (*sql.Stmt, error) QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) @@ -101,5 +104,14 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { {{- end}} } } +{{- if .EmitBegin }} +func (q *Queries) BeginTx(ctx context.Context, opts driver.TxOptions) (*Queries, error) { + tx, err := q.db.BeginTx(ctx, opts) + if (err != nil { + return nil, err + } + return q.WithTx(tx), nil +} +{{end}} {{end}} {{end}}