Created November 3, 2018 08:02
Save acoshift/aad4f3843f8cc7b146370cba4c76322b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package sqldb | |
import ( | |
"context" | |
"database/sql" | |
"net/http" | |
"github.com/acoshift/middleware" | |
"github.com/acoshift/pgsql" | |
) | |
// Context keys. Each is an unexported empty struct type, so values stored
// under them cannot collide with context keys defined by other packages.
type (
	// ctxKeyDB holds the *sql.DB installed by Middleware; RunInTx reads it
	// to begin new transactions.
	ctxKeyDB struct{}

	// ctxKeyQueryer holds what QueryRow, Query and Exec run against: the
	// *sql.DB installed by Middleware, or the *sql.Tx installed by RunInTx
	// while a transaction is active.
	ctxKeyQueryer struct{}
)
// Abort aborts tx | |
var Abort = pgsql.ErrAbortTx | |
// Middleware injects db into context | |
func Middleware(db *sql.DB) middleware.Middleware { | |
return func(h http.Handler) http.Handler { | |
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
ctx := r.Context() | |
ctx = context.WithValue(ctx, ctxKeyDB{}, db) | |
ctx = context.WithValue(ctx, ctxKeyQueryer{}, db) | |
r = r.WithContext(ctx) | |
h.ServeHTTP(w, r) | |
}) | |
} | |
} | |
// queryer is the subset of methods shared by *sql.DB and *sql.Tx that
// QueryRow, Query and Exec need. It lets them run either directly against
// the database or inside the transaction opened by RunInTx.
type queryer interface {
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}

// Compile-time checks that both concrete types this package stores under
// ctxKeyQueryer satisfy queryer.
var (
	_ queryer = (*sql.DB)(nil)
	_ queryer = (*sql.Tx)(nil)
)
func q(ctx context.Context) queryer { | |
return ctx.Value(ctxKeyQueryer{}).(queryer) | |
} | |
// QueryRow runs query row | |
func QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row { | |
return q(ctx).QueryRowContext(ctx, query, args...) | |
} | |
// Query runs query | |
func Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { | |
return q(ctx).QueryContext(ctx, query, args...) | |
} | |
// Exec runs exec | |
func Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | |
return q(ctx).ExecContext(ctx, query, args...) | |
} | |
// RunInTx runs f in tx | |
func RunInTx(ctx context.Context, f func(context.Context) error) error { | |
if _, ok := ctx.Value(ctxKeyQueryer{}).(*sql.Tx); ok { | |
return f(ctx) | |
} | |
db := ctx.Value(ctxKeyDB{}).(*sql.DB) | |
return pgsql.RunInTxContext(ctx, db, nil, func(tx *sql.Tx) error { | |
ctx := context.WithValue(ctx, ctxKeyQueryer{}, tx) | |
return f(ctx) | |
}) | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.