Skip to content

Instantly share code, notes, and snippets.

@yuri-potatoq
Created October 2, 2024 13:17
Show Gist options
  • Save yuri-potatoq/20bab227ca80fb8c3abbcd773b2afcb5 to your computer and use it in GitHub Desktop.
Save yuri-potatoq/20bab227ca80fb8c3abbcd773b2afcb5 to your computer and use it in GitHub Desktop.
Transaction management with Go.
package enrollment
import (
"context"
"database/sql"
"github.com/jmoiron/sqlx"
"github.com/yuri-potatoq/generic-profile/infra/db"
)
// Repository is the persistence port for enrollments.
//
// Every query method returns a db.Querier, which can either run in a
// transaction of its own (Atomic) or inside one the caller already holds
// (Tx). The embedded db.TxManager lets callers open such a transaction.
type Repository interface {
	db.TxManager

	// NewEnrollment inserts a new enrollment and yields its ID.
	NewEnrollment(ctx context.Context) db.Querier[int]

	// GetEnrollmentState is meant to load the state of enrollment id.
	GetEnrollmentState(ctx context.Context, id int) db.Querier[EnrollmentState]

	// CheckEnrollment reports whether an enrollment with the given id exists.
	CheckEnrollment(ctx context.Context, id int) db.Querier[bool]
}
// repository is the default Repository implementation. It gets its
// transaction handling (WithTx) from the embedded db.TxManager, which is also
// handed to every Querier it builds.
type repository struct {
	db.TxManager
}
// NewEnrollmentRepository builds a Repository backed by d. All queries it
// produces share the single transaction manager created here from d.
func NewEnrollmentRepository(d *sqlx.DB) Repository {
	txm := db.NewTxManager(d)
	return &repository{TxManager: txm}
}
// NewEnrollment returns a Querier that inserts a row into enrollments, using
// only column defaults, and yields the generated ID.
//
// The ID is read from the RETURNING clause with QueryRowContext/Scan rather
// than Result.LastInsertId. Some drivers (e.g. lib/pq for PostgreSQL) do not
// implement LastInsertId, while scanning a RETURNING row works wherever
// RETURNING is supported. Any error from the statement is returned before the
// result is touched, so a failed insert is reported instead of panicking.
func (r *repository) NewEnrollment(ctx context.Context) db.Querier[int] {
	return db.NewQuerier[int](ctx, r.TxManager, func(iCtx context.Context, tx *sql.Tx) (int, error) {
		var id int
		if err := tx.QueryRowContext(iCtx,
			"INSERT INTO enrollments DEFAULT VALUES RETURNING ID;",
		).Scan(&id); err != nil {
			return 0, err
		}
		return id, nil
	})
}
// CheckEnrollment returns a Querier that reports whether an enrollment with
// the given id exists. It counts the matching rows in enrollments.
//
// The statement runs with the context passed to the callback (iCtx) rather
// than the outer ctx captured by the closure. That way the query always uses
// the context the Querier itself executes with, just as NewEnrollment does.
func (r *repository) CheckEnrollment(ctx context.Context, id int) db.Querier[bool] {
	return db.NewQuerier[bool](ctx, r.TxManager, func(iCtx context.Context, tx *sql.Tx) (bool, error) {
		var total int
		if err := tx.QueryRowContext(iCtx,
			"SELECT COUNT(*) FROM enrollments WHERE ID = $1;", id,
		).Scan(&total); err != nil {
			return false, err
		}
		return total > 0, nil
	})
}
// GetEnrollmentState returns a Querier for the state of enrollment id.
//
// NOTE(review): not implemented yet. The query callback ignores id and the
// transaction, and always yields a zero EnrollmentState with a nil error.
func (r *repository) GetEnrollmentState(ctx context.Context, id int) db.Querier[EnrollmentState] {
	stub := func(context.Context, *sql.Tx) (EnrollmentState, error) {
		var zero EnrollmentState
		return zero, nil
	}
	return db.NewQuerier[EnrollmentState](ctx, r.TxManager, stub)
}
package enrollment
import (
"context"
"database/sql"
)
// Service is the application-facing API for enrollments.
type Service interface {
	// NewEnrollment creates an enrollment and returns its ID.
	NewEnrollment(ctx context.Context) (id int, err error)

	// GetEnrollmentState returns the state of enrollment id.
	GetEnrollmentState(ctx context.Context, id int) (EnrollmentState, error)
}
// service implements Service on top of a Repository.
type service struct {
	// r provides data access and is also used to open transactions.
	r Repository
}
// NewEnrollmentService wires a Service around the given Repository.
func NewEnrollmentService(r Repository) Service {
	s := new(service)
	s.r = r
	return s
}
// NewEnrollment inserts a new enrollment in a transaction of its own and
// returns its ID.
func (s *service) NewEnrollment(ctx context.Context) (int, error) {
	q := s.r.NewEnrollment(ctx)
	return q.Atomic()
}
// GetEnrollmentState fetches the state of enrollment id in a transaction of
// its own.
func (s *service) GetEnrollmentState(ctx context.Context, id int) (EnrollmentState, error) {
	q := s.r.GetEnrollmentState(ctx, id)
	return q.Atomic()
}
// BulkUpdate ensures an enrollment exists and returns its state, all within a
// single transaction opened through the repository.
//
// If no enrollment with the given id exists, a new one is created and the
// state of that new enrollment is returned instead. On any error, including
// one reported by the transaction manager when finishing the transaction, the
// zero EnrollmentState is returned, so callers never receive a partially
// populated result next to a non-nil error.
//
// NOTE(review): despite its name, nothing is updated here yet; confirm the
// intended behavior.
func (s *service) BulkUpdate(ctx context.Context, id int) (EnrollmentState, error) {
	var state EnrollmentState
	err := s.r.WithTx(ctx, func(tx *sql.Tx) error {
		enrollmentID := id

		exists, err := s.r.CheckEnrollment(ctx, id).Tx(tx)
		if err != nil {
			return err
		}
		if !exists {
			if enrollmentID, err = s.r.NewEnrollment(ctx).Tx(tx); err != nil {
				return err
			}
		}

		state, err = s.r.GetEnrollmentState(ctx, enrollmentID).Tx(tx)
		return err
	})
	if err != nil {
		return EnrollmentState{}, err
	}
	return state, nil
}
package db
import (
	"context"
	"database/sql"
	"fmt"
	"runtime/debug"

	"github.com/jmoiron/sqlx"
)
// TxManager runs a unit of work inside a database transaction.
type TxManager interface {
	// WithTx begins a transaction and passes it to op. Implementations are
	// expected to commit when op succeeds and roll back when it fails.
	WithTx(ctx context.Context, op func(tx *sql.Tx) error) error
}
// commonManager is the default TxManager. Every transaction it opens is
// obtained from db.
type commonManager struct {
	db *sqlx.DB
}
// NewTxManager returns a TxManager that opens its transactions on db.
func NewTxManager(db *sqlx.DB) TxManager {
	m := commonManager{db: db}
	return &m
}
// WithTx runs op inside a new transaction on m.db and finishes the
// transaction based on op's outcome:
//
//   - op returns nil: the transaction is committed and the Commit error, if
//     any, is returned.
//   - op returns an error: the transaction is rolled back and op's error is
//     returned unchanged.
//   - op panics: the transaction is rolled back, the stack is printed, and
//     the panic is turned into a returned error, so it is not reported as
//     success.
//
// The transaction is started with m.db.BeginTx directly. It does not use a
// dedicated *sql.Conn, which would have to be Closed afterwards to return it
// to the pool. err is a named result so the deferred function can both see
// op's error and replace the return value with the commit/panic outcome.
func (m *commonManager) WithTx(ctx context.Context, op func(tx *sql.Tx) error) (err error) {
	tx, err := m.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	defer func() {
		if p := recover(); p != nil {
			debug.PrintStack()
			// The panic is the primary failure; a rollback error is only
			// folded into the message for diagnosis.
			if rbErr := tx.Rollback(); rbErr != nil {
				err = fmt.Errorf("db: transaction panicked: %v (rollback failed: %v)", p, rbErr)
			} else {
				err = fmt.Errorf("db: transaction panicked: %v", p)
			}
			return
		}
		if err != nil {
			// Return op's error unchanged; a rollback failure here is ignored.
			_ = tx.Rollback()
			return
		}
		err = tx.Commit()
	}()

	return op(tx)
}
// Querier is a deferred, typed query. Atomic runs it in a transaction of its
// own; Tx runs it inside a transaction the caller already holds.
type Querier[T any] interface {
	// Atomic executes the query in a fresh transaction and returns its result.
	Atomic() (T, error)

	// Tx executes the query within tx and leaves committing or rolling back
	// to the caller.
	Tx(tx *sql.Tx) (T, error)
}
// querier is the Querier built by NewQuerier. It keeps the context supplied
// at construction and reuses it on every run.
//
// NOTE(review): storing a context in a struct goes against the context
// package guidance; it is kept because the Querier methods take no ctx.
type querier[T any] struct {
	// TxManager opens the transaction used by Atomic.
	TxManager

	// ctx is passed to queryF each time the query runs.
	ctx context.Context

	// queryF performs the actual query inside the given transaction.
	queryF func(ctx context.Context, tx *sql.Tx) (T, error)
}
// Atomic runs the query in a transaction of its own and returns its result.
// The transaction is opened through the embedded TxManager with the context
// stored at construction.
//
// WithTx is called in its own statement before rs is read. In a combined
// `return rs, WithTx(...)`, the Go spec leaves unspecified whether rs is read
// before or after the call, so the callback's assignment could be missed.
// On error the zero T is returned instead of a partial result.
func (q *querier[T]) Atomic() (T, error) {
	var rs T
	err := q.TxManager.WithTx(q.ctx, func(tx *sql.Tx) error {
		var qErr error
		rs, qErr = q.queryF(q.ctx, tx)
		return qErr
	})
	if err != nil {
		var zero T
		return zero, err
	}
	return rs, nil
}
// Tx runs the query inside the caller-owned transaction tx, using the stored
// context. The transaction is neither committed nor rolled back here.
func (q *querier[T]) Tx(tx *sql.Tx) (T, error) {
	run := q.queryF
	return run(q.ctx, tx)
}
// NewQuerier wraps f in a Querier. The result runs f either in its own
// transaction opened through tx (Atomic) or inside an existing transaction
// (Tx). ctx is stored and passed to f on every run.
func NewQuerier[T any](
	ctx context.Context,
	tx TxManager,
	f func(ctx context.Context, tx *sql.Tx) (T, error),
) Querier[T] {
	q := &querier[T]{TxManager: tx}
	q.ctx = ctx
	q.queryF = f
	return q
}
// Executer is the result-less counterpart of Querier: a deferred statement
// that runs either in its own transaction (Atomic) or inside one the caller
// already holds (Tx).
type Executer interface {
	// Atomic executes the statement in a fresh transaction.
	Atomic() error

	// Tx executes the statement within tx and leaves committing or rolling
	// back to the caller.
	Tx(tx *sql.Tx) error
}
// executer is the default Executer. It keeps a context and reuses it each
// time the statement runs.
type executer struct {
	// TxManager opens the transaction used by Atomic.
	TxManager

	// ctx is passed to execF each time the statement runs.
	ctx context.Context

	// execF performs the statement inside the given transaction.
	execF func(ctx context.Context, tx *sql.Tx) error
}
// Atomic runs the statement in a transaction of its own. The transaction is
// opened through the embedded TxManager with the stored context.
func (e *executer) Atomic() error {
	inTx := func(tx *sql.Tx) error {
		return e.execF(e.ctx, tx)
	}
	return e.TxManager.WithTx(e.ctx, inTx)
}
// Tx runs the statement inside the caller-owned transaction tx, using the
// stored context. Committing or rolling back remains the caller's job.
func (e *executer) Tx(tx *sql.Tx) error {
	exec := e.execF
	return exec(e.ctx, tx)
}
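// TODO: finish the executer support, e.g. add a NewExecuter constructor that mirrors NewQuerier (nothing visible here constructs an executer yet).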
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment