Created
October 2, 2024 13:17
-
-
Save yuri-potatoq/20bab227ca80fb8c3abbcd773b2afcb5 to your computer and use it in GitHub Desktop.
Transaction management with Go.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package enrollment | |
import ( | |
"context" | |
"database/sql" | |
"github.com/jmoiron/sqlx" | |
"github.com/yuri-potatoq/generic-profile/infra/db" | |
) | |
type Repository interface { | |
db.TxManager | |
NewEnrollment(ctx context.Context) db.Querier[int] | |
GetEnrollmentState(ctx context.Context, id int) db.Querier[EnrollmentState] | |
CheckEnrollment(ctx context.Context, id int) db.Querier[bool] | |
} | |
type repository struct { | |
db.TxManager | |
} | |
func NewEnrollmentRepository(d *sqlx.DB) Repository { | |
return &repository{ | |
TxManager: db.NewTxManager(d), | |
} | |
} | |
func (r *repository) NewEnrollment(ctx context.Context) db.Querier[int] { | |
return db.NewQuerier[int](ctx, r.TxManager, func(iCtx context.Context, tx *sql.Tx) (int, error) { | |
rs, err := tx.ExecContext(iCtx, "INSERT INTO enrollments DEFAULT VALUES RETURNING ID;") | |
id, err := rs.LastInsertId() | |
return int(id), err | |
}) | |
} | |
func (r *repository) CheckEnrollment(ctx context.Context, id int) db.Querier[bool] { | |
return db.NewQuerier[bool](ctx, r.TxManager, func(iCtx context.Context, tx *sql.Tx) (bool, error) { | |
var total int | |
if err := tx.QueryRowContext(ctx, | |
"SELECT COUNT(*) FROM enrollments WHERE ID = $1;", id, | |
).Scan(&total); err != nil { | |
return false, err | |
} | |
return total > 0, nil | |
}) | |
} | |
func (r *repository) GetEnrollmentState(ctx context.Context, id int) db.Querier[EnrollmentState] { | |
return db.NewQuerier(ctx, r.TxManager, func(iCtx context.Context, tx *sql.Tx) (EnrollmentState, error) { | |
return EnrollmentState{}, nil | |
}) | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package enrollment | |
import ( | |
"context" | |
"database/sql" | |
) | |
type Service interface { | |
NewEnrollment(ctx context.Context) (id int, err error) | |
GetEnrollmentState(ctx context.Context, id int) (EnrollmentState, error) | |
} | |
type service struct { | |
r Repository | |
} | |
func NewEnrollmentService(r Repository) Service { | |
return &service{ | |
r: r, | |
} | |
} | |
func (s *service) NewEnrollment(ctx context.Context) (int, error) { | |
return s.r.NewEnrollment(ctx).Atomic() | |
} | |
func (s *service) GetEnrollmentState(ctx context.Context, id int) (EnrollmentState, error) { | |
return s.r.GetEnrollmentState(ctx, id).Atomic() | |
} | |
func (s *service) BulkUpdate(ctx context.Context, id int) (EnrollmentState, error) { | |
var stt EnrollmentState | |
err := s.r.WithTx(ctx, func(tx *sql.Tx) error { | |
var err error | |
enrollmentId := id | |
exists, err := s.r.CheckEnrollment(ctx, id).Tx(tx) | |
if err != nil { | |
return err | |
} | |
if !exists { | |
enrollmentId, err = s.r.NewEnrollment(ctx).Tx(tx) | |
if err != nil { | |
return err | |
} | |
} | |
stt, err = s.r.GetEnrollmentState(ctx, enrollmentId).Tx(tx) | |
return err | |
}) | |
return stt, err | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package db | |
import (
	"context"
	"database/sql"
	"fmt"
	"runtime/debug"

	"github.com/jmoiron/sqlx"
)
// TxManager runs a unit of work inside a database transaction.
//
// WithTx begins a transaction and passes it to fn. Implementations are
// expected to commit when fn returns nil and roll back otherwise.
type TxManager interface {
	WithTx(ctx context.Context, fn func(tx *sql.Tx) error) error
}
type commonManager struct { | |
db *sqlx.DB | |
} | |
func NewTxManager(db *sqlx.DB) TxManager { | |
return &commonManager{db} | |
} | |
func (m *commonManager) WithTx(ctx context.Context, op func(tx *sql.Tx) error) error { | |
var err error | |
c, err := m.db.Conn(ctx) | |
if err != nil { | |
return err | |
} | |
tx, err := c.BeginTx(ctx, nil) | |
if err != nil { | |
return err | |
} | |
defer func() { | |
if p := recover(); p != nil { | |
err = tx.Rollback() | |
debug.PrintStack() | |
} else if err != nil { | |
tx.Rollback() | |
} else { | |
err = tx.Commit() | |
} | |
}() | |
return op(tx) | |
} | |
// Querier is a deferred query that yields a T. Build one with NewQuerier,
// then choose how it runs:
//   - Tx runs it on an existing transaction owned by the caller;
//   - Atomic runs it in a transaction of its own.
type Querier[T any] interface {
	Tx(tx *sql.Tx) (T, error)
	Atomic() (T, error)
}
type querier[T any] struct { | |
TxManager | |
ctx context.Context | |
queryF func(ctx context.Context, tx *sql.Tx) (T, error) | |
} | |
func (q *querier[T]) Atomic() (T, error) { | |
var rs T | |
return rs, q.TxManager.WithTx(q.ctx, func(tx *sql.Tx) error { | |
var err error | |
rs, err = q.queryF(q.ctx, tx) | |
return err | |
}) | |
} | |
func (q *querier[T]) Tx(tx *sql.Tx) (T, error) { | |
return q.queryF(q.ctx, tx) | |
} | |
func NewQuerier[T any]( | |
ctx context.Context, | |
tx TxManager, | |
f func(ctx context.Context, tx *sql.Tx, | |
) (T, error)) Querier[T] { | |
return &querier[T]{ | |
ctx: ctx, | |
queryF: f, | |
TxManager: tx, | |
} | |
} | |
// Executer is the result-less counterpart of Querier. Atomic runs the
// statement in a transaction of its own; Tx runs it on a transaction supplied
// by the caller.
type Executer interface {
	Atomic() error
	Tx(tx *sql.Tx) error
}
type executer struct { | |
TxManager | |
ctx context.Context | |
execF func(ctx context.Context, tx *sql.Tx) error | |
} | |
func (q *executer) Atomic() error { | |
return q.TxManager.WithTx(q.ctx, func(tx *sql.Tx) error { | |
return q.execF(q.ctx, tx) | |
}) | |
} | |
func (q *executer) Tx(tx *sql.Tx) error { | |
return q.execF(q.ctx, tx) | |
} | |
// TODO: missing some executer implementations |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment