Last active
October 12, 2024 14:49
-
-
Save pseudomuto/0900a7a3605470760579752fcf0fc2b7 to your computer and use it in GitHub Desktop.
Blog Code: Clean SQL Transactions in Golang
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"database/sql" | |
"log" | |
) | |
func main() { | |
db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE") | |
handleError(err) | |
defer db.Close() | |
tx, err := db.Begin() | |
handleError(err) | |
// insert a record into table1 | |
res, err := tx.Exec("INSERT INTO table1(name) VALUES(?)", "some name") | |
if err != nil { | |
tx.Rollback() | |
log.Fatal(err) | |
} | |
// fetch the auto incremented id | |
id, err := res.LastInsertId() | |
handleError(err) | |
// insert record into table2, referencing the first record from table1 | |
res, err = tx.Exec("INSERT INTO table2(table1_id, name) VALUES(?, ?)", id, "other name") | |
if err != nil { | |
tx.Rollback() | |
log.Fatal(err) | |
} | |
// commit the transaction | |
handleError(tx.Commit()) | |
log.Println("Done.") | |
} | |
// handleError is a no-op for a nil error; any other error is passed to
// log.Fatal, which logs it and terminates the program.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"database/sql" | |
"log" | |
) | |
func main() { | |
db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE") | |
handleError(err) | |
defer db.Close() | |
err = WithTransaction(db, func(tx Transaction) error { | |
// insert a record into table1 | |
res, err := tx.Exec("INSERT INTO table1(name) VALUES(?)", "some name") | |
if err != nil { | |
return err | |
} | |
id, err := res.LastInsertId() | |
if err != nil { | |
return err | |
} | |
res, err = tx.Exec("INSERT INTO table2(table1_id, name) VALUES(?, ?)", id, "other name") | |
if err != nil { | |
return err | |
} | |
}) | |
handleError(err) | |
log.Println("Done.") | |
} | |
// handleError is a no-op for a nil error; any other error is passed to
// log.Fatal, which logs it and terminates the program.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"database/sql" | |
"log" | |
) | |
func main() { | |
db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE") | |
handleError(err) | |
defer db.Close() | |
stmts := []*PipelineStmt{ | |
NewPipelineStmt("INSERT INTO table1(name) VALUES(?)", "some name"), | |
NewPipelineStmt("INSERT INTO table2(table1_id, name) VALUES({LAST_INS_ID}, ?)", "other name"), | |
} | |
err = WithTransaction(db, func(tx Transaction) error { | |
_, err := RunPipeline(tx, stmts...) | |
return err | |
}) | |
handleError(err) | |
log.Println("Done.") | |
} | |
// handleError is a no-op for a nil error; any other error is passed to
// log.Fatal, which logs it and terminates the program.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"database/sql" | |
"strconv" | |
"strings" | |
) | |
// PipelineStmt pairs a SQL query with the arguments to bind to it, so that a
// sequence of statements can be executed in order by RunPipeline.
type PipelineStmt struct {
	query string        // SQL text; may contain the {LAST_INS_ID} placeholder
	args  []interface{} // arguments passed along with query when it is executed
}

// NewPipelineStmt returns a PipelineStmt holding query and args.
func NewPipelineStmt(query string, args ...interface{}) *PipelineStmt {
	stmt := PipelineStmt{query: query, args: args}
	return &stmt
}
// Executes the statement within supplied transaction. The literal string `{LAST_INS_ID}` | |
// will be replaced with the supplied value to make chaining `PipelineStmt` objects together | |
// simple. | |
func (ps *PipelineStmt) Exec(tx Transaction, lastInsertId int64) (sql.Result, error) { | |
query := strings.Replace(ps.query, "{LAST_INS_ID}", strconv.Itoa(int(lastInsertId)), -1) | |
return tx.Exec(query, ps.args...) | |
} | |
// Runs the supplied statements within the transaction. If any statement fails, the transaction | |
// is rolled back, and the original error is returned. | |
// | |
// The `LastInsertId` from the previous statement will be passed to `Exec`. The zero-value (0) is | |
// used initially. | |
func RunPipeline(tx Transaction, stmts ...*PipelineStmt) (sql.Result, error) { | |
var res sql.Result | |
var err error | |
var lastInsId int64 | |
for _, ps := range stmts { | |
res, err = ps.Exec(tx, lastInsId) | |
if err != nil { | |
return nil, err | |
} | |
lastInsId, err = res.LastInsertId() | |
if err != nil { | |
return nil, err | |
} | |
} | |
return res, nil | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"database/sql" | |
) | |
// Transaction is the statement-running subset of the standard transaction
// type in `database/sql`.
//
// Commit and Rollback are intentionally absent so that a TxFn cannot finish
// the transaction on its own; that is left to WithTransaction.
type Transaction interface {
	// Exec mirrors (*sql.Tx).Exec.
	Exec(query string, args ...interface{}) (sql.Result, error)
	// Query mirrors (*sql.Tx).Query.
	Query(query string, args ...interface{}) (*sql.Rows, error)
	// QueryRow mirrors (*sql.Tx).QueryRow.
	QueryRow(query string, args ...interface{}) *sql.Row
	// Prepare mirrors (*sql.Tx).Prepare.
	Prepare(query string) (*sql.Stmt, error)
}
// A Txfn is a function that will be called with an initialized `Transaction` object | |
// that can be used for executing statements and queries against a database. | |
type TxFn func(Transaction) error | |
// WithTransaction creates a new transaction and handles rollback/commit based on the | |
// error object returned by the `TxFn` | |
func WithTransaction(db *sql.DB, fn TxFn) (err error) { | |
tx, err := db.Begin() | |
if err != nil { | |
return | |
} | |
defer func() { | |
if p := recover(); p != nil { | |
// a panic occurred, rollback and repanic | |
tx.Rollback() | |
panic(p) | |
} else if err != nil { | |
// something went wrong, rollback | |
tx.Rollback() | |
} else { | |
// all good, commit | |
err = tx.Commit() | |
} | |
}() | |
err = fn(tx) | |
return err | |
} |
I want to be able to create transactions at the service/use-case layer. The problem I'm facing is that `db *sql.DB`, which is
required by `func WithTransaction(db *sql.DB, fn TxFn)`,
is not available in the service layer. Is there a way to obtain the transaction at the service layer?
I want to be able to create transactions at the service/use-case layer. The problem I'm facing is that
`db *sql.DB`,
which is required by `func WithTransaction(db *sql.DB, fn TxFn)`,
is not available in the service layer. Is there a way to obtain the transaction at the service layer?
package tx
import (
"context"
"database/sql"
"noname/constant"
"github.com/rs/zerolog/log"
)
// Queries lists the statement-running methods, in both plain and
// context-aware forms, that *sql.DB and *sql.Tx have in common.
type Queries interface {
	Exec(query string, args ...interface{}) (sql.Result, error)
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)

	Prepare(query string) (*sql.Stmt, error)
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)

	Query(query string, args ...interface{}) (*sql.Rows, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)

	QueryRow(query string, args ...interface{}) *sql.Row
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
// DB abstracts the transaction-starting methods of *sql.DB that Tr relies on.
type DB interface {
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
	Begin() (*sql.Tx, error)
}
// Tr starts transactions on a DB and runs units of work inside them.
type Tr struct {
	db DB // source of new transactions
}

// NewTransaction returns a Tr that begins its transactions on db.
func NewTransaction(db DB) *Tr {
	tr := new(Tr)
	tr.db = db
	return tr
}
// TxFn is a unit of work run by Tr.WithTransaction. The ctx it receives
// carries the active *sql.Tx under the constant.KeyTx context key.
type TxFn func(context.Context) error

// WithTransaction begins a transaction, stores it in the context passed to fn
// under constant.KeyTx, and finishes it based on fn's outcome:
//
//   - fn returns nil: the transaction is committed, and a Commit failure is
//     returned to the caller instead of being silently dropped.
//   - fn returns an error or panics: the transaction is rolled back; the
//     error is returned, or the panic is re-raised.
//
// Rollback failures are logged rather than returned, so fn's error (or the
// panic) remains the primary failure the caller sees.
func (tr *Tr) WithTransaction(ctx context.Context, fn TxFn) (err error) {
	tx, err := tr.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	defer func() {
		p := recover()
		if p == nil && err == nil {
			err = tx.Commit()
			return
		}

		log.Info().Msg("Rollback")
		if rbErr := tx.Rollback(); rbErr != nil {
			log.Error().Err(rbErr).Msg("rollback failed")
		}
		if p != nil {
			panic(p)
		}
	}()

	return fn(context.WithValue(ctx, constant.KeyTx, tx))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hey, thanks. This helped me break down the problem; I've implemented a similar approach in our project. The key was abstracting the Tx interface (we use sqlx). Here is what I ended up with: https://stackoverflow.com/questions/65024138/how-to-do-transaction-management/72377629#72377629