Created
November 12, 2015 23:17
-
-
Save troyk/5731f86dd2c1e9459141 to your computer and use it in GitHub Desktop.
golang sql query abstraction
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package db | |
import (
	"database/sql"
	"log"
	"os"
	"strings"
	"sync"

	_ "github.com/lib/pq" // required for database/sql
	"golang.org/x/net/context"
)
// db is the package-wide connection pool. It is opened in init and serves as
// the fallback Queryer whenever no other one is supplied.
var db *sql.DB

type (
	// Queryer is the subset of *sql.DB / *sql.Tx used by this package, so the
	// same helpers can run either against the pool or inside a transaction.
	Queryer interface {
		Query(query string, args ...interface{}) (*sql.Rows, error)
		Exec(query string, args ...interface{}) (sql.Result, error)
	}

	// ctxkey is unexported so context values stored under it cannot collide
	// with keys defined by other packages.
	ctxkey int
)

// ctxQueryer is the context key under which FromContext looks up a Queryer.
const ctxQueryer = ctxkey(0)
// FromContext retrieves the Queryer from the current context, | |
// this is how transactions, current user, etc can be passed around | |
func FromContext(ctx context.Context) Queryer { | |
q, ok := ctx.Value(ctxQueryer).(Queryer) | |
if !ok { | |
return db | |
} | |
return q | |
} | |
var schemaCache struct { | |
sync.Mutex | |
tables map[string][]column | |
} | |
// column describes one column of a database table, as read by GetSchema from
// information_schema.columns.
//
// GetSchema builds values with an unkeyed composite literal, so the field
// order below is significant.
type column struct {
	Table        string // name of the table the column belongs to
	Column       string // column name
	DefaultValue string // column_default expression text
	IsPrimaryKey bool   // column participates in a PRIMARY KEY constraint
	IsNullable   bool   // information_schema is_nullable, cast to bool
}
func init() { | |
db, _ = sql.Open("postgres", "postgres://[email protected]:5432/assurehire?sslmode=disable") | |
db.SetMaxIdleConns(10) | |
db.SetMaxOpenConns(10) | |
// return | |
} | |
// GetSchema returns []columm for a table | |
func GetSchema(table string) []column { | |
schemaCache.Lock() | |
columns := schemaCache.tables[table] | |
// load schema if tables is empty | |
if columns == nil && len(schemaCache.tables) == 0 { | |
rows, err := db.Query(`select c.table_schema,c.table_name,c.column_name,c.column_default,c.is_nullable::bool, (select array_to_string(array_agg(tc.constraint_type::text),',') from information_schema.key_column_usage kc join information_schema.table_constraints tc on kc.constraint_name = tc.constraint_name where kc.table_schema=c.table_schema and kc.table_name=c.table_name and kc.column_name=c.column_name) as constraints from information_schema.columns c where c.table_schema = 'public' order by c.table_name,c.ordinal_position;`) | |
if err != nil { | |
log.Fatal(err) | |
} | |
defer rows.Close() | |
for rows.Next() { | |
var schema, table, name, defvalue, constraints string | |
var nullable bool | |
if err := rows.Scan(&schema, &table, &name, &defvalue, &nullable, &constraints); err != nil { | |
log.Fatal(err) | |
} | |
columns := schemaCache.tables[table] | |
if columns == nil { | |
columns = []column{} | |
} | |
schemaCache.tables[table] = append(columns, column{table, name, defvalue, strings.Contains(constraints, "PRIMARY KEY"), nullable}) | |
} | |
columns = schemaCache.tables[table] | |
} | |
schemaCache.Unlock() | |
return columns | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package db | |
import ( | |
"fmt" | |
"strconv" | |
"strings" | |
"testing" | |
. "github.com/smartystreets/goconvey/convey" | |
) | |
func TestSqlBuilder(t *testing.T) { | |
Convey("replaceHolders", t, func() { | |
So(replaceHolders("select $1, $2", 0), ShouldEqual, "select $1, $2") | |
So(replaceHolders("select $1, $2", 1), ShouldEqual, "select $2, $3") | |
}) | |
Convey("Select Builder", t, func() { | |
q := From("users") | |
So(q.from, ShouldEqual, "users") | |
q.Where("email=$1", "foo") | |
So(len(q.wheres), ShouldEqual, 1) | |
sql, args := q.ToSQL() | |
So(sql, ShouldEqual, "SELECT * FROM users WHERE email=$1") | |
So(args, ShouldContain, "foo") | |
q.Where("username=$1", "bar") | |
sql, args = q.ToSQL() | |
So(sql, ShouldEqual, "SELECT * FROM users WHERE email=$1 AND username=$2") | |
So(len(args), ShouldEqual, 2) | |
So(args[0], ShouldEqual, "foo") | |
So(args[1], ShouldEqual, "bar") | |
q = From("background_checks as bg"). | |
Where("bg.id=$1 and bg.account_id=$2", ":id", ":account_id"). | |
Select("bg.*,(select jp from job_positions jp where jp.id=bg.job_position_id limit 1) as job_position") | |
sql, args = q.ToSQL() | |
So(len(args), ShouldEqual, 2) | |
So(sql, ShouldEqual, "SELECT bg.*,(select jp from job_positions jp where jp.id=bg.job_position_id limit 1) as job_position FROM background_checks as bg WHERE bg.id=$1 and bg.account_id=$2") | |
}) | |
Convey("Get", t, func() { | |
id := "00000000-0000-0000-0000-000000000001" | |
user := map[string]interface{}{} | |
err := From("users").Get(&user, id) | |
So(err, ShouldBeNil) | |
So(user["id"], ShouldEqual, id) | |
}) | |
} | |
// The benchmarks below compare ways of formatting a "$N" placeholder to find
// the fastest. Each iteration formats "$<i+1>" and then "$<i>".

// benchSink receives every formatted placeholder so the compiler cannot
// discard the work as dead code (assigning to _ allows that).
var benchSink string

// BenchmarkStringCat measures plain concatenation with strconv.Itoa.
func BenchmarkStringCat(b *testing.B) {
	for i := 0; i < b.N; i++ {
		benchSink = "$" + strconv.Itoa(i+1)
		benchSink = "$" + strconv.Itoa(i)
	}
}

// BenchmarkStringFmt measures fmt.Sprintf.
func BenchmarkStringFmt(b *testing.B) {
	for i := 0; i < b.N; i++ {
		benchSink = fmt.Sprintf("$%d", i+1)
		benchSink = fmt.Sprintf("$%d", i)
	}
}

// BenchmarkStringJoin measures strings.Join. The second call previously used
// i+1 again; it now formats i like the other benchmarks, so they do
// comparable work.
func BenchmarkStringJoin(b *testing.B) {
	for i := 0; i < b.N; i++ {
		benchSink = strings.Join([]string{"$", strconv.Itoa(i + 1)}, "")
		benchSink = strings.Join([]string{"$", strconv.Itoa(i)}, "")
	}
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package db | |
import ( | |
"bytes" | |
"database/sql" | |
"encoding/json" | |
"fmt" | |
"reflect" | |
"strconv" | |
"strings" | |
) | |
// ErrNoRows is returned by Get (and by the query builder's Get) when the query
// produces no row. It is the same value as sql.ErrNoRows, so callers may compare
// against either. In that case dest is left untouched, so checking dest for nil
// is not a reliable indicator of row status.
var ErrNoRows = sql.ErrNoRows
// Get wraps the query in a pg to_json (json_agg if the scan target is a slice). | |
// If db is nill then the default *db will be used, and is perfectly ok unless | |
// running the query in a tx (in this case, pass in the tx -- it satisifies the | |
// Queryer interface) | |
// Get will return a ErrNoRows if no rows are found. Use Scan if this is not | |
// desired | |
func Get(tx Queryer, dest interface{}, query string, args ...interface{}) error { | |
if tx == nil { | |
tx = db | |
} | |
qbuf := &bytes.Buffer{} | |
desttype := reflect.TypeOf(dest) | |
if desttype.Kind() == reflect.Slice || (desttype.Kind() == reflect.Ptr && desttype.Elem().Kind() == reflect.Slice) { | |
qbuf.WriteString("SELECT json_agg(r) FROM(") | |
} else { | |
qbuf.WriteString("SELECT to_json(r) FROM(") | |
} | |
qbuf.WriteString(query) | |
qbuf.WriteString(")r") | |
fmt.Printf("\n%v\n%v", qbuf.String(), args) | |
rows, err := tx.Query(qbuf.String(), args...) | |
if err != nil { | |
return err | |
} | |
defer rows.Close() | |
if !rows.Next() { | |
return ErrNoRows | |
} | |
return Scan(rows, dest) | |
} | |
// Scan decodes the current row of rows, which must hold a single JSON column,
// into dest using encoding/json.
//
// A SQL NULL in that column is reported as ErrNoRows (sql.ErrNoRows), and dest
// is left untouched. Postgres' json_agg returns NULL, not an empty array, when
// it aggregates zero rows. Previously that NULL reached json.Unmarshal and
// failed with "unexpected end of JSON input" instead of reporting that nothing
// matched.
func Scan(row *sql.Rows, dest interface{}) error {
	var data []byte // TODO: use a byte pool
	if err := row.Scan(&data); err != nil {
		return err
	}
	if data == nil {
		return sql.ErrNoRows
	}
	return json.Unmarshal(data, dest)
}
// From returns a new sqlselect that lets you compose sql queries | |
// pg $ placeholders are used only they are localized to the query part | |
// being used so you dont have to keep track of the correct position arg | |
// notice in the example below the name arg used $1 and not $3 because it is | |
// the $1 pos for it's term | |
// e.g. From("users").where("id=$1 and last_login<$2",id,date).where("name=$1",name) | |
func From(table string, queryer ...Queryer) *sqlselect { | |
s := (&sqlselect{}).From(table) | |
if len(queryer) > 0 { | |
s.db = queryer[0] | |
} else { | |
s.db = db | |
} | |
return s | |
} | |
func (s *sqlselect) Select(cols string, args ...interface{}) *sqlselect { | |
s.selects = append(s.selects, sqlterm{cols, args}) | |
return s | |
} | |
func (s *sqlselect) From(table string) *sqlselect { | |
s.from = table | |
return s | |
} | |
func (s *sqlselect) Join(join string, args ...interface{}) *sqlselect { | |
s.joins = append(s.joins, sqlterm{join, args}) | |
return s | |
} | |
func (s *sqlselect) Where(condition string, args ...interface{}) *sqlselect { | |
s.wheres = append(s.wheres, sqlterm{condition, args}) | |
return s | |
} | |
func (s *sqlselect) OrderBy(orderBy string, args ...interface{}) *sqlselect { | |
s.orders = append(s.orders, sqlterm{orderBy, args}) | |
return s | |
} | |
func (s *sqlselect) Limit(count interface{}) *sqlselect { | |
s.limit, _ = strconv.Atoi(fmt.Sprintf("%v", count)) | |
return s | |
} | |
// Scan runs the query and scans in the result. id is variadic and if used | |
// will also set limit to 1. Returns ErrNoRows if no rows found | |
func (s *sqlselect) Get(dest interface{}, id ...interface{}) error { | |
if len(id) > 0 { | |
s.Where("id=$1", id[0]).Limit(1) | |
} | |
query, args := s.ToSQL() | |
return Get(s.db, dest, query, args...) | |
} | |
// Scan runs the query and scans in the result. ErrNoRows is ignored, if you | |
// need ErrNoRows use Get without an ID | |
func (s *sqlselect) Scan(dest interface{}) error { | |
query, args := s.ToSQL() | |
err := Get(s.db, dest, query, args...) | |
if err == nil || err == ErrNoRows { | |
return nil | |
} | |
return err | |
} | |
func (s *sqlselect) ToSQL() (string, []interface{}) { | |
buf := &bytes.Buffer{} | |
args := []interface{}{} | |
buf.WriteString("SELECT ") | |
var term string | |
if len(s.selects) > 0 { | |
term, args = s.selects.merge(", ", args) | |
buf.WriteString(term) | |
} else { | |
buf.WriteString("*") | |
} | |
buf.WriteString(" FROM ") | |
buf.WriteString(s.from) | |
if len(s.joins) > 0 { | |
buf.WriteString(" ") | |
term, args = s.joins.merge(", ", args) | |
buf.WriteString(term) | |
} | |
if len(s.wheres) > 0 { | |
buf.WriteString(" WHERE ") | |
term, args = s.wheres.merge(" AND ", args) | |
buf.WriteString(term) | |
} | |
if len(s.orders) > 0 { | |
buf.WriteString(" ORDER BY ") | |
term, args = s.orders.merge(", ", args) | |
buf.WriteString(term) | |
} | |
if s.limit > 0 { | |
buf.WriteString(" LIMIT ") | |
buf.WriteString(strconv.Itoa(s.limit)) | |
} | |
// sql.WriteString(s.from) | |
//return sql.String(), args | |
return buf.String(), args | |
} | |
type sqlterm struct { | |
term string | |
args []interface{} | |
} | |
type sqlterms []sqlterm | |
func (t sqlterms) merge(sep string, args []interface{}) (string, []interface{}) { | |
terms := make([]string, 0, len(t)) | |
for _, term := range t { | |
terms = append(terms, replaceHolders(term.term, len(args))) | |
args = append(args, term.args...) | |
} | |
return strings.Join(terms, sep), args | |
} | |
type sqlselect struct { | |
db Queryer | |
from string | |
selects sqlterms | |
joins sqlterms | |
wheres sqlterms | |
groups sqlterms | |
havings sqlterms | |
orders sqlterms | |
limit int | |
} | |
// replaceHolders returns sqlstr with every Postgres positional placeholder
// ($1, $2, ...) outside quoted text renumbered by adding offset.
//
// Text inside '...' strings, "..." identifiers and dollar-quoted bodies
// ($$...$$ or $tag$...$tag$) is copied verbatim. Doubled quotes ('' or "")
// work naturally: they close and immediately reopen the quote. A quote is
// closed only by the same delimiter that opened it, so mixed quotes such as
// 'a"b' are handled. A '$' that starts neither a placeholder nor a dollar
// quote is copied as-is.
//
// This replaces a byte-toggle scanner that let any quote character close any
// other. It also spliced strconv error text into the SQL for a bare '$', and
// panicked (capture[-1]) when input ended in '$' plus a non-digit, e.g.
// "... $$body$$;".
//
// Not handled: E'...' backslash escapes and -- or /* */ comments.
func replaceHolders(sqlstr string, offset int) string {
	var out bytes.Buffer
	out.Grow(len(sqlstr) + 4)

	closer := "" // delimiter ending the current quoted section; "" when unquoted
	for i := 0; i < len(sqlstr); {
		if closer != "" {
			if strings.HasPrefix(sqlstr[i:], closer) {
				out.WriteString(closer)
				i += len(closer)
				closer = ""
			} else {
				out.WriteByte(sqlstr[i])
				i++
			}
			continue
		}

		c := sqlstr[i]
		switch {
		case c == '\'' || c == '"':
			closer = sqlstr[i : i+1]
			out.WriteByte(c)
			i++
		case c == '$':
			j := i + 1
			for j < len(sqlstr) && sqlstr[j] >= '0' && sqlstr[j] <= '9' {
				j++
			}
			if j > i+1 {
				digits := sqlstr[i+1 : j]
				out.WriteByte('$')
				if pos, err := strconv.Atoi(digits); err == nil {
					out.WriteString(strconv.Itoa(pos + offset))
				} else {
					// Only possible on overflow; leave the text for Postgres to reject.
					out.WriteString(digits)
				}
				i = j
				continue
			}
			if tag, ok := dollarQuoteTag(sqlstr[i:]); ok {
				closer = tag
				out.WriteString(tag)
				i += len(tag)
				continue
			}
			out.WriteByte(c)
			i++
		default:
			out.WriteByte(c)
			i++
		}
	}
	return out.String()
}

// dollarQuoteTag returns the dollar-quote delimiter ("$$" or "$tag$") that s
// begins with, or ok=false when the leading '$' does not open one. Tag
// characters follow Postgres identifier rules: letters, '_', non-ASCII bytes,
// and digits after the first character.
func dollarQuoteTag(s string) (tag string, ok bool) {
	for j := 1; j < len(s); j++ {
		c := s[j]
		if c == '$' {
			return s[:j+1], true
		}
		isLetter := c == '_' || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c >= 0x80
		isDigit := c >= '0' && c <= '9'
		if !isLetter && !(isDigit && j > 1) {
			return "", false
		}
	}
	return "", false
}
// The commented-out version below used ? placeholders; we decided it is probably best to stick with Postgres's native $N placeholders.
// func replaceHolders(sql string, args []interface{}) string { | |
// src := []byte(sql) | |
// dest := make([]byte, 0, len(src)) | |
// inQuote := false | |
// i := 0 | |
// for _, b := range src { | |
// if b == '"' || b == '\'' { | |
// inQuote = !inQuote | |
// dest = append(dest, b) | |
// continue | |
// } | |
// if inQuote { | |
// dest = append(dest, b) | |
// continue | |
// } | |
// if b == '?' { | |
// i++ | |
// dest = append(dest, '$') | |
// dest = append(dest, []byte(strconv.Itoa(i))...) | |
// // For debugging only, comment the above and uncomment below to have the | |
// // args inlined in the sql. this is good to copy/paste into psql but can | |
// // really ruin your day with SQL injection | |
// // dest = append(dest, '\'') | |
// // dest = append(dest, []byte(fmt.Sprintf("%v", args[i-1]))...) | |
// // dest = append(dest, '\'') | |
// } else { | |
// dest = append(dest, b) | |
// } | |
// } | |
// return string(dest) | |
// } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment