Skip to content

Instantly share code, notes, and snippets.

@gingerhot
Forked from troyk/db.go
Created May 2, 2017 01:13
Show Gist options
  • Save gingerhot/ce0097bcd8872ed6a683049feb043690 to your computer and use it in GitHub Desktop.
Save gingerhot/ce0097bcd8872ed6a683049feb043690 to your computer and use it in GitHub Desktop.
golang sql query abstraction
package db
import (
"database/sql"
"log"
"strings"
"sync"
_ "github.com/lib/pq" // required for database/sql
"golang.org/x/net/context"
)
// db is the package-wide default connection pool. It is opened in init and is
// the Queryer used whenever a caller does not supply one of its own.
var db *sql.DB
// Queryer is the minimal query surface this package needs. Expressing it as an
// interface lets callers pass either the plain connection pool or a
// transaction, so the same helpers work inside and outside a tx.
type Queryer interface {
// Query runs a statement and returns its result rows.
Query(query string, args ...interface{}) (*sql.Rows, error)
// Exec runs a statement and returns its sql.Result instead of rows.
Exec(query string, args ...interface{}) (sql.Result, error)
}
// ctxkey is a private type for this package's context keys, so that the keys
// cannot collide with keys defined by other packages.
type ctxkey int
// ctxQueryer is the context key under which FromContext looks up a Queryer.
const ctxQueryer ctxkey = 0
// FromContext returns the Queryer stored in ctx under ctxQueryer. This is how
// a transaction (or any other request-scoped Queryer) is handed down a call
// chain. When ctx carries no Queryer the package-level db is returned instead.
func FromContext(ctx context.Context) Queryer {
	if q, ok := ctx.Value(ctxQueryer).(Queryer); ok {
		return q
	}
	return db
}
// schemaCache holds the column metadata loaded by GetSchema. The embedded
// mutex is locked by GetSchema around every access to tables.
var schemaCache struct {
sync.Mutex
// tables maps a table name to that table's columns.
tables map[string][]column
}
// column is the internal representation of a single column of a database
// table, as read from information_schema by GetSchema.
// NOTE(review): the original comment also mentions a relation to a struct
// field; no such mapping is visible in this file — confirm against callers.
type column struct {
// Table is the owning table's name, Column is the column's name and
// DefaultValue is the column default reported by the database.
Table, Column, DefaultValue string
// IsPrimaryKey is true when the column's constraints include PRIMARY KEY.
IsPrimaryKey bool
// IsNullable mirrors information_schema's is_nullable flag.
IsNullable bool
}
// init opens the package-level connection pool and caps it at 10 idle and 10
// open connections.
//
// A failure from sql.Open is fatal: without this check db would stay nil and
// the very next call would panic with a nil-pointer dereference that hides
// the real cause.
func init() {
	var err error
	db, err = sql.Open("postgres", "postgres://[email protected]:5432/assurehire?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	db.SetMaxIdleConns(10)
	db.SetMaxOpenConns(10)
}
// GetSchema returns the columns of table, ordered by ordinal position, or nil
// when the table is not known.
//
// The metadata for every table in the public schema is loaded from
// information_schema on the first call (more precisely: whenever the cache is
// still empty) and served from schemaCache afterwards. Any database error
// while loading terminates the process through log.Fatal.
func GetSchema(table string) []column {
	schemaCache.Lock()
	defer schemaCache.Unlock()

	// A populated cache is authoritative, even for tables it does not contain.
	if len(schemaCache.tables) > 0 {
		return schemaCache.tables[table]
	}

	rows, err := db.Query(`select c.table_schema,c.table_name,c.column_name,c.column_default,c.is_nullable::bool, (select array_to_string(array_agg(tc.constraint_type::text),',') from information_schema.key_column_usage kc join information_schema.table_constraints tc on kc.constraint_name = tc.constraint_name where kc.table_schema=c.table_schema and kc.table_name=c.table_name and kc.column_name=c.column_name) as constraints from information_schema.columns c where c.table_schema = 'public' order by c.table_name,c.ordinal_position;`)
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	// Build into a fresh map: schemaCache.tables starts out nil and writing to
	// a nil map panics.
	tables := make(map[string][]column)
	for rows.Next() {
		var (
			schema, tableName, name string
			// column_default is NULL for columns without a default, and the
			// constraints subquery is NULL for unconstrained columns, so both
			// need a NULL-aware scan target.
			defvalue, constraints sql.NullString
			nullable              bool
		)
		if err := rows.Scan(&schema, &tableName, &name, &defvalue, &nullable, &constraints); err != nil {
			log.Fatal(err)
		}
		tables[tableName] = append(tables[tableName], column{
			Table:        tableName,
			Column:       name,
			DefaultValue: defvalue.String,
			IsPrimaryKey: strings.Contains(constraints.String, "PRIMARY KEY"),
			IsNullable:   nullable,
		})
	}
	// rows.Next also returns false on an iteration error; do not cache a
	// partially loaded schema.
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}

	schemaCache.tables = tables
	return tables[table]
}
package db
import (
"fmt"
"strconv"
"strings"
"testing"
. "github.com/smartystreets/goconvey/convey"
)
// TestSqlBuilder covers placeholder renumbering, the SQL text and argument
// list produced by the select builder, and a lookup by id through Get.
func TestSqlBuilder(t *testing.T) {
	Convey("replaceHolders", t, func() {
		const input = "select $1, $2"
		// Offset 0 leaves the placeholders alone; offset 1 shifts each by one.
		So(replaceHolders(input, 0), ShouldEqual, "select $1, $2")
		So(replaceHolders(input, 1), ShouldEqual, "select $2, $3")
	})

	Convey("Select Builder", t, func() {
		users := From("users")
		So(users.from, ShouldEqual, "users")

		users.Where("email=$1", "foo")
		So(len(users.wheres), ShouldEqual, 1)
		query, params := users.ToSQL()
		So(query, ShouldEqual, "SELECT * FROM users WHERE email=$1")
		So(params, ShouldContain, "foo")

		// The second term numbers its placeholder locally ($1); the builder
		// is expected to renumber it to $2 in the final statement.
		users.Where("username=$1", "bar")
		query, params = users.ToSQL()
		So(query, ShouldEqual, "SELECT * FROM users WHERE email=$1 AND username=$2")
		So(len(params), ShouldEqual, 2)
		So(params[0], ShouldEqual, "foo")
		So(params[1], ShouldEqual, "bar")

		checks := From("background_checks as bg").
			Where("bg.id=$1 and bg.account_id=$2", ":id", ":account_id").
			Select("bg.*,(select jp from job_positions jp where jp.id=bg.job_position_id limit 1) as job_position")
		query, params = checks.ToSQL()
		So(len(params), ShouldEqual, 2)
		So(query, ShouldEqual, "SELECT bg.*,(select jp from job_positions jp where jp.id=bg.job_position_id limit 1) as job_position FROM background_checks as bg WHERE bg.id=$1 and bg.account_id=$2")
	})

	// NOTE(review): this case runs a real query, so it presumably needs a
	// reachable database that contains this user row — confirm before relying
	// on it in CI.
	Convey("Get", t, func() {
		const id = "00000000-0000-0000-0000-000000000001"
		row := map[string]interface{}{}
		err := From("users").Get(&row, id)
		So(err, ShouldBeNil)
		So(row["id"], ShouldEqual, id)
	})
}
// BenchmarkStringCat measures building a "$N" placeholder with plain string
// concatenation. Together with BenchmarkStringFmt and BenchmarkStringJoin it
// confirms which placeholder format is the fastest.
func BenchmarkStringCat(b *testing.B) {
	for n := 0; n < b.N; n++ {
		next := n + 1
		_ = "$" + strconv.Itoa(next)
		_ = "$" + strconv.Itoa(n)
	}
}
// BenchmarkStringFmt measures building a "$N" placeholder with fmt.Sprintf,
// formatting i+1 and i on every iteration.
func BenchmarkStringFmt(b *testing.B) {
	const layout = "$%d"
	for n := 0; n < b.N; n++ {
		_ = fmt.Sprintf(layout, n+1)
		_ = fmt.Sprintf(layout, n)
	}
}
// BenchmarkStringJoin measures building a "$N" placeholder with strings.Join.
// Like the other placeholder benchmarks it formats i+1 and then i on every
// iteration, so the three results are comparable.
func BenchmarkStringJoin(b *testing.B) {
	for i := 0; i < b.N; i++ {
		_ = strings.Join([]string{"$", strconv.Itoa(i + 1)}, "")
		_ = strings.Join([]string{"$", strconv.Itoa(i)}, "")
	}
}
package db
import (
"bytes"
"database/sql"
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"
)
// ErrNoRows is returned by Get when the query yields no row. In that case
// dest is left untouched, so checking dest for nil is not a good indicator of
// row status. It is the same value as sql.ErrNoRows.
var ErrNoRows = sql.ErrNoRows
// Get runs query and decodes its result into dest via JSON.
//
// The query is wrapped in PostgreSQL's to_json, or in json_agg when dest is a
// slice or a pointer to a slice, so the database returns a single JSON column
// that is then unmarshalled into dest.
//
// If tx is nil the package default db is used, which is fine unless the query
// must run inside a transaction; in that case pass the tx (it satisfies
// Queryer).
//
// Get returns ErrNoRows when the wrapped query produces no row, an error when
// dest is nil, and otherwise any error from running the query, iterating the
// rows or decoding. Use the builder's Scan when ErrNoRows is not wanted.
func Get(tx Queryer, dest interface{}, query string, args ...interface{}) error {
	// reflect.TypeOf(nil) is nil, and calling Kind on it would panic.
	if dest == nil {
		return fmt.Errorf("db: Get called with nil dest")
	}
	if tx == nil {
		tx = db
	}

	// Aggregate into a JSON array for slice targets, a JSON object otherwise.
	wrapper := "to_json"
	target := reflect.TypeOf(dest)
	if target.Kind() == reflect.Ptr {
		target = target.Elem()
	}
	if target.Kind() == reflect.Slice {
		wrapper = "json_agg"
	}

	rows, err := tx.Query("SELECT "+wrapper+"(r) FROM("+query+")r", args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	if !rows.Next() {
		// Next is also false after an iteration error; report that error
		// rather than disguising it as "no rows".
		if err := rows.Err(); err != nil {
			return err
		}
		return ErrNoRows
	}
	return Scan(rows, dest)
}
// Scan reads the current row of row, which must consist of a single JSON
// column, and unmarshals that JSON into dest.
//
// A SQL NULL in that column (for example what json_agg yields when the
// aggregated set is empty) is reported as sql.ErrNoRows — the value the
// package-level ErrNoRows aliases — and dest is left untouched. Previously
// NULL reached json.Unmarshal as empty input and surfaced as a JSON syntax
// error.
func Scan(row *sql.Rows, dest interface{}) error {
	var data []byte // TODO: use a byte pool
	if err := row.Scan(&data); err != nil {
		return err
	}
	if data == nil {
		return sql.ErrNoRows
	}
	return json.Unmarshal(data, dest)
}
// From starts a new select builder for table.
//
// Terms added to the builder use PostgreSQL $ placeholders that are numbered
// locally, per term, so callers never have to track the global argument
// position. In the example below the name term uses $1, not $3, because it is
// the first placeholder of its own term:
//
//	From("users").Where("id=$1 and last_login<$2", id, date).Where("name=$1", name)
//
// The optional first queryer is the Queryer the builder runs on; without one
// the package default db is used.
func From(table string, queryer ...Queryer) *sqlselect {
	s := &sqlselect{from: table, db: db}
	if len(queryer) > 0 {
		s.db = queryer[0]
	}
	return s
}
// Select appends a select-list term with its placeholder arguments and
// returns the builder for chaining.
func (s *sqlselect) Select(cols string, args ...interface{}) *sqlselect {
	s.selects = append(s.selects, sqlterm{term: cols, args: args})
	return s
}
// From sets the FROM target of the query, replacing any previous one, and
// returns the builder for chaining.
func (s *sqlselect) From(table string) *sqlselect {
	s.from = table
	return s
}
// Join appends a join term with its placeholder arguments and returns the
// builder for chaining.
func (s *sqlselect) Join(join string, args ...interface{}) *sqlselect {
	s.joins = append(s.joins, sqlterm{term: join, args: args})
	return s
}
// Where appends a condition with its placeholder arguments and returns the
// builder for chaining. ToSQL combines all conditions with AND.
func (s *sqlselect) Where(condition string, args ...interface{}) *sqlselect {
	s.wheres = append(s.wheres, sqlterm{term: condition, args: args})
	return s
}
// OrderBy appends an ORDER BY term with its placeholder arguments and returns
// the builder for chaining.
func (s *sqlselect) OrderBy(orderBy string, args ...interface{}) *sqlselect {
	s.orders = append(s.orders, sqlterm{term: orderBy, args: args})
	return s
}
// Limit sets the LIMIT of the query and returns the builder for chaining.
//
// count may be any value whose default formatting is a decimal integer (an
// int, a numeric string, ...). The parse is best-effort: the conversion error
// is deliberately discarded and whatever strconv.Atoi returned is stored.
// ToSQL only emits a LIMIT clause for values greater than zero.
func (s *sqlselect) Limit(count interface{}) *sqlselect {
	n, _ := strconv.Atoi(fmt.Sprint(count))
	s.limit = n
	return s
}
// Get runs the query and decodes the result into dest. id is variadic: when
// given, its first element is added as an "id=$1" condition and the limit is
// set to 1 (both changes stay on the builder). Returns ErrNoRows if no rows
// are found.
func (s *sqlselect) Get(dest interface{}, id ...interface{}) error {
	if len(id) != 0 {
		s.Where("id=$1", id[0])
		s.Limit(1)
	}
	sqlstr, params := s.ToSQL()
	return Get(s.db, dest, sqlstr, params...)
}
// Scan runs the query and decodes the result into dest. Unlike Get it treats
// ErrNoRows as success and returns nil; call Get without an id when the
// "no rows" case must be detected.
func (s *sqlselect) Scan(dest interface{}) error {
	sqlstr, params := s.ToSQL()
	switch err := Get(s.db, dest, sqlstr, params...); err {
	case nil, ErrNoRows:
		return nil
	default:
		return err
	}
}
// ToSQL renders the builder into a SQL string and the flat argument list that
// goes with it.
//
// Clauses are emitted in the order SELECT, FROM, joins, WHERE, ORDER BY,
// LIMIT. Each term's locally numbered $ placeholders are renumbered so they
// line up with the term's position in the returned argument list. Without any
// Select term the query selects "*"; a limit of zero or less emits no LIMIT.
// The groups and havings fields are not rendered.
//
// The returned argument slice is never nil.
func (s *sqlselect) ToSQL() (string, []interface{}) {
	var (
		buf  bytes.Buffer
		term string
	)
	args := []interface{}{}

	buf.WriteString("SELECT ")
	if len(s.selects) > 0 {
		term, args = s.selects.merge(", ", args)
		buf.WriteString(term)
	} else {
		buf.WriteString("*")
	}

	buf.WriteString(" FROM ")
	buf.WriteString(s.from)

	if len(s.joins) > 0 {
		buf.WriteString(" ")
		// Join clauses follow one another separated by whitespace; a comma
		// between "JOIN a ON ..." and "JOIN b ON ..." is a syntax error.
		term, args = s.joins.merge(" ", args)
		buf.WriteString(term)
	}

	if len(s.wheres) > 0 {
		buf.WriteString(" WHERE ")
		term, args = s.wheres.merge(" AND ", args)
		buf.WriteString(term)
	}

	if len(s.orders) > 0 {
		buf.WriteString(" ORDER BY ")
		term, args = s.orders.merge(", ", args)
		buf.WriteString(term)
	}

	if s.limit > 0 {
		buf.WriteString(" LIMIT ")
		buf.WriteString(strconv.Itoa(s.limit))
	}

	return buf.String(), args
}
// sqlterm is one SQL fragment together with the arguments for its
// placeholders. Placeholders in term are numbered locally, starting at $1.
type sqlterm struct {
term string
args []interface{}
}
// sqlterms is an ordered list of fragments that merge combines into a single
// clause.
type sqlterms []sqlterm
// merge joins all terms with sep and appends their arguments to args.
//
// Before a term is added, its placeholders are shifted by the number of
// arguments collected so far, so a term's local $1 refers to the first of its
// own arguments in the combined list. It returns the joined SQL and the
// extended argument list.
func (t sqlterms) merge(sep string, args []interface{}) (string, []interface{}) {
	parts := make([]string, len(t))
	for i := range t {
		parts[i] = replaceHolders(t[i].term, len(args))
		args = append(args, t[i].args...)
	}
	return strings.Join(parts, sep), args
}
// sqlselect is a composable SELECT statement. It is created by From, extended
// through its chainable methods and rendered by ToSQL.
type sqlselect struct {
// db is the Queryer the statement is executed on.
db Queryer
// from is the text placed after FROM.
from string
// selects, joins, wheres and orders hold the terms of their clauses in the
// order they were added.
selects sqlterms
joins sqlterms
wheres sqlterms
// NOTE(review): groups and havings are not set or rendered by any code
// visible in this file — presumably reserved for GROUP BY / HAVING support.
groups sqlterms
havings sqlterms
orders sqlterms
// limit is the LIMIT value; ToSQL omits the clause unless it is > 0.
limit int
}
// replaceHolders returns sqlstr with every PostgreSQL positional placeholder
// ($1, $2, ...) shifted by offset, so "$1" with offset 2 becomes "$3".
//
// Text inside single-quoted strings, double-quoted identifiers and
// dollar-quoted strings ($$...$$ or $tag$...$tag$) is copied verbatim; an
// unterminated quote extends to the end of the input. A '$' that starts
// neither a placeholder nor a dollar quote is copied as is, and so is a
// placeholder whose number does not fit in an int.
//
// Backslash escapes (E'...' strings) and SQL comments are not interpreted.
func replaceHolders(sqlstr string, offset int) string {
	dest := make([]byte, 0, len(sqlstr)+4)
	for i := 0; i < len(sqlstr); {
		c := sqlstr[i]
		switch {
		case c == '\'' || c == '"':
			// Copy through the matching closing quote. A doubled quote
			// ('' or "") simply reads as two adjacent quoted runs.
			end := len(sqlstr)
			if k := strings.IndexByte(sqlstr[i+1:], c); k >= 0 {
				end = i + 1 + k + 1
			}
			dest = append(dest, sqlstr[i:end]...)
			i = end
		case c == '$':
			j := i + 1
			for j < len(sqlstr) && sqlstr[j] >= '0' && sqlstr[j] <= '9' {
				j++
			}
			if j > i+1 {
				// "$<digits>": a positional placeholder.
				dest = append(dest, '$')
				if pos, err := strconv.Atoi(sqlstr[i+1 : j]); err == nil {
					dest = strconv.AppendInt(dest, int64(pos+offset), 10)
				} else {
					dest = append(dest, sqlstr[i+1:j]...)
				}
				i = j
			} else if end := dollarQuoteEnd(sqlstr, i); end > i {
				dest = append(dest, sqlstr[i:end]...)
				i = end
			} else {
				dest = append(dest, c)
				i++
			}
		default:
			dest = append(dest, c)
			i++
		}
	}
	return string(dest)
}

// dollarQuoteEnd reports where the dollar-quoted string opening at s[start]
// (which must be '$') ends: the index just past its closing delimiter, or
// len(s) when it is never closed. It returns -1 when s[start] does not open a
// dollar quote.
func dollarQuoteEnd(s string, start int) int {
	isIdent := func(b byte, digitOK bool) bool {
		switch {
		case b == '_', b >= 'a' && b <= 'z', b >= 'A' && b <= 'Z', b >= 0x80:
			return true
		case b >= '0' && b <= '9':
			return digitOK
		}
		return false
	}
	// Directly after an identifier character a '$' belongs to that
	// identifier and cannot open a dollar quote.
	if start > 0 && isIdent(s[start-1], true) {
		return -1
	}
	// The tag is empty or an identifier that does not start with a digit.
	j := start + 1
	for j < len(s) && isIdent(s[j], j > start+1) {
		j++
	}
	if j >= len(s) || s[j] != '$' {
		return -1
	}
	delim := s[start : j+1]
	body := j + 1
	k := strings.Index(s[body:], delim)
	if k < 0 {
		return len(s)
	}
	return body + k + len(delim)
}
// The func below uses ? placeholders, but we decided it is probably best to keep with pg's $
// func replaceHolders(sql string, args []interface{}) string {
// src := []byte(sql)
// dest := make([]byte, 0, len(src))
// inQuote := false
// i := 0
// for _, b := range src {
// if b == '"' || b == '\'' {
// inQuote = !inQuote
// dest = append(dest, b)
// continue
// }
// if inQuote {
// dest = append(dest, b)
// continue
// }
// if b == '?' {
// i++
// dest = append(dest, '$')
// dest = append(dest, []byte(strconv.Itoa(i))...)
// // For debugging only, comment the above and uncomment below to have the
// // args inlined in the sql. this is good to copy/paste into psql but can
// // really ruin your day with SQL injection
// // dest = append(dest, '\'')
// // dest = append(dest, []byte(fmt.Sprintf("%v", args[i-1]))...)
// // dest = append(dest, '\'')
// } else {
// dest = append(dest, b)
// }
// }
// return string(dest)
// }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment