Skip to content

Instantly share code, notes, and snippets.

@pseudomuto
Last active October 12, 2024 14:49
Show Gist options
  • Save pseudomuto/0900a7a3605470760579752fcf0fc2b7 to your computer and use it in GitHub Desktop.
Save pseudomuto/0900a7a3605470760579752fcf0fc2b7 to your computer and use it in GitHub Desktop.
Blog Code: Clean SQL Transactions in Golang
package main
import (
"database/sql"
"log"
)
// main demonstrates manual transaction handling: every failure after Begin
// rolls the transaction back before the program exits, and the transaction is
// committed only after both inserts succeed.
func main() {
	db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE")
	handleError(err)
	defer db.Close()

	tx, err := db.Begin()
	handleError(err)

	// abortOnError rolls back the open transaction and exits when stepErr is
	// non-nil. It is used after every step, including LastInsertId, so no
	// failure path exits without rolling back.
	abortOnError := func(stepErr error) {
		if stepErr != nil {
			_ = tx.Rollback() // best effort; stepErr is what gets reported
			log.Fatal(stepErr)
		}
	}

	// insert a record into table1
	res, err := tx.Exec("INSERT INTO table1(name) VALUES(?)", "some name")
	abortOnError(err)

	// fetch the auto incremented id
	id, err := res.LastInsertId()
	abortOnError(err)

	// insert record into table2, referencing the first record from table1
	_, err = tx.Exec("INSERT INTO table2(table1_id, name) VALUES(?, ?)", id, "other name")
	abortOnError(err)

	// commit the transaction
	handleError(tx.Commit())
	log.Println("Done.")
}
// handleError terminates the program via log.Fatal when err is non-nil.
// A nil error is a no-op.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
package main
import (
"database/sql"
"log"
)
// main performs two dependent inserts through WithTransaction, which commits
// when the closure returns nil and rolls back when it returns an error.
func main() {
	db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE")
	handleError(err)
	defer db.Close()

	err = WithTransaction(db, func(tx Transaction) error {
		// insert a record into table1
		res, err := tx.Exec("INSERT INTO table1(name) VALUES(?)", "some name")
		if err != nil {
			return err
		}

		// fetch the auto incremented id
		id, err := res.LastInsertId()
		if err != nil {
			return err
		}

		// insert record into table2, referencing the first record from table1.
		// The closure must return on every path; returning this error (nil on
		// success) is what lets WithTransaction decide to commit.
		_, err = tx.Exec("INSERT INTO table2(table1_id, name) VALUES(?, ?)", id, "other name")
		return err
	})
	handleError(err)
	log.Println("Done.")
}
// handleError terminates the program via log.Fatal when err is non-nil.
// A nil error is a no-op.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
package main
import (
"database/sql"
"log"
)
// main runs two dependent inserts as a single pipeline inside a transaction.
// The second statement uses the {LAST_INS_ID} placeholder to reference the
// row created by the first.
func main() {
	db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE")
	handleError(err)
	defer db.Close()

	insertName := NewPipelineStmt("INSERT INTO table1(name) VALUES(?)", "some name")
	insertLinked := NewPipelineStmt("INSERT INTO table2(table1_id, name) VALUES({LAST_INS_ID}, ?)", "other name")

	handleError(WithTransaction(db, func(tx Transaction) error {
		_, pipeErr := RunPipeline(tx, insertName, insertLinked)
		return pipeErr
	}))

	log.Println("Done.")
}
// handleError terminates the program via log.Fatal when err is non-nil.
// A nil error is a no-op.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
package main
import (
"database/sql"
"strconv"
"strings"
)
// PipelineStmt pairs a SQL query with the arguments that are bound to it when
// the statement is executed.
type PipelineStmt struct {
	query string        // SQL text; may contain the {LAST_INS_ID} placeholder
	args  []interface{} // positional arguments passed along with query
}

// NewPipelineStmt builds a PipelineStmt from a query and its arguments.
// When no arguments are given, args is left nil.
func NewPipelineStmt(query string, args ...interface{}) *PipelineStmt {
	return &PipelineStmt{query: query, args: args}
}
// Exec runs the statement within the supplied transaction. Every occurrence of
// the literal string `{LAST_INS_ID}` in the query is replaced with
// lastInsertId, which makes chaining PipelineStmt objects together simple.
//
// The id is formatted with strconv.FormatInt so the full int64 range is
// preserved; converting through int would truncate it on 32-bit platforms.
func (ps *PipelineStmt) Exec(tx Transaction, lastInsertId int64) (sql.Result, error) {
	query := strings.Replace(ps.query, "{LAST_INS_ID}", strconv.FormatInt(lastInsertId, 10), -1)
	return tx.Exec(query, ps.args...)
}
// RunPipeline runs the supplied statements in order within the transaction and
// returns the result of the last one. It stops at the first failure and
// returns that error unchanged. It does not roll back itself; that is left to
// the caller, such as WithTransaction.
//
// A statement's `{LAST_INS_ID}` placeholder is replaced with the LastInsertId
// of the statement immediately before it (0 for the first statement).
// LastInsertId is requested only when the next statement actually uses the
// placeholder. Not every driver supports it, and a pipeline that never chains
// ids should not fail for that reason.
//
// With no statements, RunPipeline returns (nil, nil).
func RunPipeline(tx Transaction, stmts ...*PipelineStmt) (sql.Result, error) {
	var (
		res       sql.Result
		lastInsID int64
	)
	for i, ps := range stmts {
		// The placeholder must match the one PipelineStmt.Exec substitutes.
		if i > 0 && strings.Contains(ps.query, "{LAST_INS_ID}") {
			id, err := res.LastInsertId()
			if err != nil {
				return nil, err
			}
			lastInsID = id
		}

		var err error
		res, err = ps.Exec(tx, lastInsID)
		if err != nil {
			return nil, err
		}
	}
	return res, nil
}
package main
import (
"database/sql"
)
// Transaction exposes the statement- and query-running methods of a
// `database/sql` transaction.
//
// Commit and Rollback are deliberately omitted so that a TxFn cannot end the
// transaction itself; finishing it is WithTransaction's job.
type Transaction interface {
	// Exec runs a statement that returns no rows.
	Exec(query string, args ...interface{}) (sql.Result, error)
	// Prepare creates a prepared statement.
	Prepare(query string) (*sql.Stmt, error)
	// Query runs a query that returns rows.
	Query(query string, args ...interface{}) (*sql.Rows, error)
	// QueryRow runs a query expected to return at most one row.
	QueryRow(query string, args ...interface{}) *sql.Row
}

// TxFn is the unit of work that WithTransaction runs with an open
// Transaction. Returning a non-nil error causes the transaction to be rolled
// back instead of committed.
type TxFn func(Transaction) error
// WithTransaction begins a transaction on db, runs fn with it, and finishes
// the transaction according to the outcome:
//
//   - fn panics: the transaction is rolled back and the panic is re-raised.
//   - fn returns an error: the transaction is rolled back and that error is
//     returned (a Rollback error is ignored).
//   - fn returns nil: the transaction is committed and Commit's error, if
//     any, is returned.
//
// An error from db.Begin is returned without calling fn.
func WithTransaction(db *sql.DB, fn TxFn) (err error) {
	tx, err := db.Begin()
	if err != nil {
		return err
	}

	defer func() {
		p := recover()
		switch {
		case p != nil:
			tx.Rollback()
			panic(p)
		case err != nil:
			tx.Rollback()
		default:
			err = tx.Commit()
		}
	}()

	err = fn(tx)
	return err
}
@lzap
Copy link

lzap commented May 25, 2022

Hey, thanks. This helped me to break down the problem, I've implemented similar approach in our project. The key thing was abstracting the Tx interface (we use sqlx). Here is what I ended up with: https://stackoverflow.com/questions/65024138/how-to-do-transaction-management/72377629#72377629

@pranayhere
Copy link

I want to be able to create transactions at the service/use-case layer. The problem I'm facing is that the db *sql.DB required by func WithTransaction(db *sql.DB, fn TxFn) is not available in the service layer. Is there a way to get the transaction at the service layer?

@umardev500
Copy link

I want to be able to create transactions at the service/use-case layer. The problem I'm facing is that the db *sql.DB required by func WithTransaction(db *sql.DB, fn TxFn) is not available in the service layer. Is there a way to get the transaction at the service layer?

package tx

import (
	"context"
	"database/sql"
	"noname/constant"

	"github.com/rs/zerolog/log"
)

// Queries lists the statement and query methods, with and without a context,
// whose signatures match those of *sql.DB and *sql.Tx.
type Queries interface {
	// Statements that return no rows.
	Exec(query string, args ...interface{}) (sql.Result, error)
	// Prepared statements.
	Prepare(query string) (*sql.Stmt, error)
	// Multi-row queries.
	Query(query string, args ...interface{}) (*sql.Rows, error)
	// Single-row queries.
	QueryRow(query string, args ...interface{}) *sql.Row

	// Context-aware variants of the methods above.
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

// DB is the transaction-starting subset of a database handle; its method
// signatures match those of *sql.DB.
type DB interface {
	Begin() (*sql.Tx, error)
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}

// Tr starts transactions on the DB it wraps.
type Tr struct {
	db DB
}

// NewTransaction returns a Tr that starts its transactions on db.
func NewTransaction(db DB) *Tr {
	return &Tr{db: db}
}

// TxFn is the unit of work run by Tr.WithTransaction. It receives the context
// carrying the open transaction; returning a non-nil error requests rollback.
type TxFn func(context.Context) error

// WithTransaction begins a transaction on tr's database and stores it in the
// context passed to fn under constant.KeyTx. It then finishes the transaction
// based on fn's outcome:
//
//   - fn returns nil: the transaction is committed and Commit's error, if any,
//     is returned.
//   - fn returns an error: the transaction is rolled back and fn's error is
//     returned.
//   - fn panics: the transaction is rolled back and the panic is re-raised.
//
// An error from BeginTx is returned without calling fn.
func (tr *Tr) WithTransaction(ctx context.Context, fn TxFn) (err error) {
	tx, err := tr.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	defer func() {
		if p := recover(); p != nil {
			// err is still nil while panicking, so without this branch a
			// half-finished unit of work would be committed.
			_ = tx.Rollback()
			panic(p)
		}
		if err != nil {
			log.Info().Msg("Rollback")
			// Best effort: fn's error is the one the caller needs to see.
			_ = tx.Rollback()
			return
		}
		// Return Commit's error so callers never see success for work that
		// was not persisted.
		err = tx.Commit()
	}()

	ctx = context.WithValue(ctx, constant.KeyTx, tx)
	err = fn(ctx)
	return err
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment