Skip to content

Instantly share code, notes, and snippets.

@marciol
Last active February 7, 2017 16:33
Show Gist options
  • Save marciol/e5ec8f8ec5253272ab0750cf768e6d0e to your computer and use it in GitHub Desktop.
Save marciol/e5ec8f8ec5253272ab0750cf768e6d0e to your computer and use it in GitHub Desktop.
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"github.com/lib/pq"
_ "github.com/lib/pq"
"math/rand"
"net/url"
"strings"
"log"
)
type DB struct {
sqldb pq.DB
schema string
}
// NewDB returns a DB that delegates to db and scopes transactions started
// with Begin to schema.
func NewDB(db DBDriver, schema string) *DB {
	wrapped := &DB{
		sqldb:  db,
		schema: schema,
	}
	return wrapped
}
// Exec executes query with args on the wrapped driver and returns its
// result. No search_path is set here.
func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
	res, err := db.sqldb.Exec(query, args...)
	return res, err
}
// Query runs query with args on the wrapped driver and returns the
// resulting rows. The caller is responsible for closing them.
func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
	rows, err := db.sqldb.Query(query, args...)
	return rows, err
}
// QueryRow runs query with args on the wrapped driver and returns the
// single-row handle it produces.
func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row {
	row := db.sqldb.QueryRow(query, args...)
	return row
}
// Select forwards to the package-level Select helper, passing db as the
// query target, and returns whatever error that helper reports.
func (db *DB) Select(dest interface{}, query string, args ...interface{}) error {
	err := Select(db, dest, query, args...)
	return err
}
type Tx struct {
sqltx pq.DB
schema string
}
// Begin starts a transaction on the wrapped driver and sets its search_path
// to the DB's schema.
//
// If setting the schema fails, the newly started transaction is rolled back
// so it is not left open. The schema error is then returned.
func (db *DB) Begin() (*Tx, error) {
	tx, err := db.begin()
	if err != nil {
		return nil, err
	}
	if err := setSchema(tx, db.schema); err != nil {
		// Best-effort cleanup: the schema error is what the caller needs, so a
		// secondary rollback failure is deliberately dropped.
		// NOTE(review): assumes TxDriver exposes Rollback, as (*Tx).Rollback
		// relies on — confirm against the interface definition.
		_ = tx.Rollback()
		return nil, err
	}
	return &Tx{sqltx: tx, schema: db.schema}, nil
}
// begin starts a raw transaction on the wrapped driver.
//
// The original referenced an undefined package-level `sqldb` and forced the
// result into a *Tx. This version calls the receiver's driver instead.
//
// On failure it returns a literal nil TxDriver, so callers never receive a
// non-nil interface that wraps a nil transaction pointer.
// NOTE(review): assumes DBDriver.Begin's first result is assignable to
// TxDriver — confirm against the interface definitions.
func (db *DB) begin() (TxDriver, error) {
	tx, err := db.sqldb.Begin()
	if err != nil {
		return nil, err
	}
	return tx, nil
}
// SetSchema changes the search_path of the running transaction to schema.
// It does not modify the schema recorded on tx.
func (tx *Tx) SetSchema(schema string) error {
	if err := setSchema(tx.sqltx, schema); err != nil {
		return err
	}
	return nil
}
// setSchema issues `SET search_path TO <schema>` on txdrive and discards the
// statement's result.
//
// The schema name goes through pq.QuoteIdentifier, so it is emitted as a
// quoted identifier rather than spliced in as raw SQL.
func setSchema(txdrive TxDriver, schema string) error {
	stmt := fmt.Sprintf(`SET search_path TO %s`, pq.QuoteIdentifier(schema))
	if _, err := txdrive.Exec(stmt); err != nil {
		return err
	}
	return nil
}
// Exec executes query with args inside the transaction and returns the
// underlying result.
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
	res, err := tx.sqltx.Exec(query, args...)
	return res, err
}
// Query runs query with args inside the transaction and returns the rows it
// produces. The caller is responsible for closing them.
func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
	rows, err := tx.sqltx.Query(query, args...)
	return rows, err
}
// QueryRow runs query with args inside the transaction and returns the
// single-row handle.
func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row {
	row := tx.sqltx.QueryRow(query, args...)
	return row
}
// Commit commits the underlying transaction and reports any error it returns.
func (tx *Tx) Commit() error {
	if err := tx.sqltx.Commit(); err != nil {
		return err
	}
	return nil
}
// Rollback aborts the underlying transaction and reports any error it returns.
func (tx *Tx) Rollback() error {
	if err := tx.sqltx.Rollback(); err != nil {
		return err
	}
	return nil
}
// Select forwards to the package-level Select helper, passing tx as the
// query target, and returns whatever error that helper reports.
func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error {
	err := Select(tx, dest, query, args...)
	return err
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment