Last active
March 11, 2024 12:20
-
-
Save s3rj1k/a21a9486b2298312034ed09ccae9d999 to your computer and use it in GitHub Desktop.
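Benchmarks of Go SQL query-construction approaches against an in-memory SQLite database:
- plain database/sql;
- squirrel;
- sqlf;
- text/template (map and struct data);
- pongo2;
- builq.

The gist also includes sqlmock tests illustrating SQL injection.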
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"database/sql" | |
"fmt" | |
"testing" | |
"text/template" | |
sq "github.com/Masterminds/squirrel" | |
"github.com/cristalhq/builq" | |
"github.com/flosch/pongo2/v6" | |
"github.com/keegancsmith/sqlf" | |
_ "github.com/mattn/go-sqlite3" | |
) | |
// go test -bench=. -benchmem

// BenchmarkSQLiteInsertSelectUpdate is the baseline. It uses hand-written SQL
// with `?` placeholders passed directly to database/sql. The statement text is
// static and every value is passed as a bind argument, so this variant is not
// open to SQL injection.
//
// Each iteration does four things:
//   - inserts a row;
//   - looks up its id by name;
//   - updates value1;
//   - reads value1 back, failing the benchmark on any mismatch.
func BenchmarkSQLiteInsertSelectUpdate(b *testing.B) { // static, parameterized, safe SQL
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		b.Fatalf("could not open sqlite3 database: %v", err)
	}
	defer db.Close()

	// Every connection to ":memory:" opens its own, empty database. Pin the
	// pool to a single connection so no statement can land on a connection
	// that never saw the CREATE TABLE below.
	db.SetMaxOpenConns(1)

	_, err = db.Exec(`CREATE TABLE benchmark (
		id INTEGER PRIMARY KEY,
		name TEXT,
		value1 REAL,
		value2 REAL,
		value3 REAL,
		value4 REAL,
		value5 REAL,
		value6 REAL,
		value7 REAL,
		value8 REAL,
		value9 REAL
	)`)
	if err != nil {
		b.Fatalf("could not create table: %v", err)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := db.Exec("INSERT INTO benchmark(name, value1, value2, value3, value4, value5, value6, value7, value8, value9) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
			fmt.Sprintf("Name%d", i), float64(i), float64(i+1), float64(i+2), float64(i+3), float64(i+4), float64(i+5), float64(i+6), float64(i+7), float64(i+8))
		if err != nil {
			b.Fatalf("could not execute statement: %v", err)
		}

		// The table starts empty, so the row from iteration i is expected to have id i+1.
		var id int
		err = db.QueryRow("SELECT id FROM benchmark WHERE name = ?", fmt.Sprintf("Name%d", i)).Scan(&id)
		if err != nil {
			b.Fatalf("could not execute select statement: %v", err)
		}
		if id != (i + 1) {
			b.Fatalf("data mismatch: expected %d, got %d.", i+1, id)
		}

		_, err = db.Exec("UPDATE benchmark SET value1 = ? WHERE id = ?", float64(i+10), id)
		if err != nil {
			b.Fatalf("could not execute update statement: %v", err)
		}

		var val float64
		err = db.QueryRow("SELECT value1 FROM benchmark WHERE id = ?", id).Scan(&val)
		if err != nil {
			b.Fatalf("could not execute select statement: %v", err)
		}
		if val != float64(i+10) {
			b.Fatalf("data mismatch: expected %f, got %f.", float64(i+10), val)
		}
	}
}
func BenchmarkSQLiteInsertSelectUpdateUsingSquirrel(b *testing.B) { // typed, safe SQL | |
db, err := sql.Open("sqlite3", ":memory:") | |
if err != nil { | |
b.Fatalf("could not open sqlite3 database: %v", err) | |
} | |
defer db.Close() | |
_, err = db.Exec(`CREATE TABLE benchmark ( | |
id INTEGER PRIMARY KEY, | |
name TEXT, | |
value1 REAL, | |
value2 REAL, | |
value3 REAL, | |
value4 REAL, | |
value5 REAL, | |
value6 REAL, | |
value7 REAL, | |
value8 REAL, | |
value9 REAL | |
)`) | |
if err != nil { | |
b.Fatalf("could not create table: %v", err) | |
} | |
psql := sq.StatementBuilder.PlaceholderFormat(sq.Question) | |
b.ResetTimer() | |
for i := 0; i < b.N; i++ { | |
query, args, err := psql.Insert("benchmark").Columns("name", "value1", "value2", "value3", "value4", "value5", "value6", "value7", "value8", "value9"). | |
Values(fmt.Sprintf("Name%d", i), float64(i), float64(i+1), float64(i+2), float64(i+3), float64(i+4), float64(i+5), float64(i+6), float64(i+7), float64(i+8)).ToSql() | |
if err != nil { | |
b.Fatalf("could not build insert SQL: %v", err) | |
} | |
_, err = db.Exec(query, args...) | |
if err != nil { | |
b.Fatalf("could not execute insert statement: %v", err) | |
} | |
query, args, err = psql.Select("id").From("benchmark").Where(sq.Eq{"name": fmt.Sprintf("Name%d", i)}).ToSql() | |
if err != nil { | |
b.Fatalf("could not build select SQL: %v", err) | |
} | |
var id int | |
err = db.QueryRow(query, args...).Scan(&id) | |
if err != nil { | |
b.Fatalf("could not execute select statement: %v", err) | |
} | |
if id != (i + 1) { | |
b.Fatalf("data mismatch: expected %d, got %d.", i+1, id) | |
} | |
query, args, err = psql.Update("benchmark").Set("value1", float64(i+10)).Where(sq.Eq{"id": id}).ToSql() | |
if err != nil { | |
b.Fatalf("could not build update SQL: %v", err) | |
} | |
_, err = db.Exec(query, args...) | |
if err != nil { | |
b.Fatalf("could not execute update statement: %v", err) | |
} | |
query, args, err = psql.Select("value1").From("benchmark").Where(sq.Eq{"id": id}).ToSql() | |
if err != nil { | |
b.Fatalf("could not build select SQL: %v", err) | |
} | |
var val float64 | |
err = db.QueryRow(query, args...).Scan(&val) | |
if err != nil { | |
b.Fatalf("could not execute select statement: %v", err) | |
} | |
if val != float64(i+10) { | |
b.Fatalf("data mismatch: expected %f, got %f.", float64(i+10), val) | |
} | |
} | |
} | |
func BenchmarkSQLiteInsertSelectUpdateUsingSqlf(b *testing.B) { // semi-dynamic, safe SQL | |
db, err := sql.Open("sqlite3", ":memory:") | |
if err != nil { | |
b.Fatalf("could not open sqlite3 database: %v", err) | |
} | |
defer db.Close() | |
_, err = db.Exec(`CREATE TABLE benchmark ( | |
id INTEGER PRIMARY KEY, | |
name TEXT, | |
value1 REAL, | |
value2 REAL, | |
value3 REAL, | |
value4 REAL, | |
value5 REAL, | |
value6 REAL, | |
value7 REAL, | |
value8 REAL, | |
value9 REAL | |
)`) | |
if err != nil { | |
b.Fatalf("could not create table: %v", err) | |
} | |
b.ResetTimer() | |
for i := 0; i < b.N; i++ { | |
query := sqlf.Sprintf("INSERT INTO benchmark(name, value1, value2, value3, value4, value5, value6, value7, value8, value9) VALUES(%s, %f, %f, %f, %f, %f, %f, %f, %f, %f)", | |
fmt.Sprintf("Name%d", i), float64(i), float64(i+1), float64(i+2), float64(i+3), float64(i+4), float64(i+5), float64(i+6), float64(i+7), float64(i+8)) | |
_, err := db.Exec(query.Query(sqlf.SQLServerBindVar), query.Args()...) | |
if err != nil { | |
b.Fatalf("could not execute insert statement: %v", err) | |
} | |
query = sqlf.Sprintf("SELECT id FROM benchmark WHERE name = %s", fmt.Sprintf("Name%d", i)) | |
var id int | |
err = db.QueryRow(query.Query(sqlf.SQLServerBindVar), query.Args()...).Scan(&id) | |
if err != nil { | |
b.Fatalf("could not execute select statement: %v", err) | |
} | |
if id != (i + 1) { | |
b.Fatalf("data mismatch: expected %d, got %d.", i+1, id) | |
} | |
query = sqlf.Sprintf("UPDATE benchmark SET value1 = %f WHERE id = %d", float64(i+10), id) | |
_, err = db.Exec(query.Query(sqlf.SQLServerBindVar), query.Args()...) | |
if err != nil { | |
b.Fatalf("could not execute update statement: %v", err) | |
} | |
query = sqlf.Sprintf("SELECT value1 FROM benchmark WHERE id = %s", id) | |
var val float64 | |
err = db.QueryRow(query.Query(sqlf.SQLServerBindVar), query.Args()...).Scan(&val) | |
if err != nil { | |
b.Fatalf("could not execute select statement: %v", err) | |
} | |
if val != float64(i+10) { | |
b.Fatalf("data mismatch: expected %f, got %f.", float64(i+10), val) | |
} | |
} | |
} | |
// BenchmarkSQLiteInsertSelectUpdateUsingTemplateWithMap performs the same
// round trip as the baseline. Each complete SQL statement is rendered with
// text/template from a map[string]any and executed without bind arguments.
//
// This is UNSAFE SQL: text/template does no SQL escaping. A Name containing a
// single quote would break out of the '{{.Name}}' literal.
func BenchmarkSQLiteInsertSelectUpdateUsingTemplateWithMap(b *testing.B) { // unsafe SQL
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		b.Fatalf("could not open sqlite3 database: %v", err)
	}
	defer db.Close()

	// Every connection to ":memory:" opens its own, empty database. Pin the
	// pool to one connection so all statements see the table created below.
	db.SetMaxOpenConns(1)

	_, err = db.Exec(`CREATE TABLE benchmark (
		id INTEGER PRIMARY KEY,
		name TEXT,
		value1 REAL,
		value2 REAL,
		value3 REAL,
		value4 REAL,
		value5 REAL,
		value6 REAL,
		value7 REAL,
		value8 REAL,
		value9 REAL
	)`)
	if err != nil {
		b.Fatalf("could not create table: %v", err)
	}

	// Templates are parsed once, outside the timed loop.
	insertTemplate := `INSERT INTO benchmark(name, value1, value2, value3, value4, value5, value6, value7, value8, value9) VALUES('{{.Name}}', {{.Value1}}, {{.Value2}}, {{.Value3}}, {{.Value4}}, {{.Value5}}, {{.Value6}}, {{.Value7}}, {{.Value8}}, {{.Value9}})`
	selectIDTemplate := `SELECT id FROM benchmark WHERE name = '{{.Name}}'`
	updateTemplate := `UPDATE benchmark SET value1 = {{.Value1}} WHERE id = {{.Id}}`
	selectValueTemplate := `SELECT value1 FROM benchmark WHERE id = {{.Id}}`
	tmplInsert := template.Must(template.New("insert").Parse(insertTemplate))
	tmplSelectID := template.Must(template.New("selectId").Parse(selectIDTemplate))
	tmplUpdate := template.Must(template.New("update").Parse(updateTemplate))
	tmplSelectValue := template.Must(template.New("selectValue").Parse(selectValueTemplate))

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// One buffer per iteration, reset after each statement is consumed.
		var buf bytes.Buffer

		err = tmplInsert.Execute(&buf, map[string]any{
			"Name":   fmt.Sprintf("Name%d", i),
			"Value1": float64(i),
			"Value2": float64(i + 1),
			"Value3": float64(i + 2),
			"Value4": float64(i + 3),
			"Value5": float64(i + 4),
			"Value6": float64(i + 5),
			"Value7": float64(i + 6),
			"Value8": float64(i + 7),
			"Value9": float64(i + 8),
		})
		if err != nil {
			b.Fatalf("could not execute insert template: %v", err)
		}
		_, err = db.Exec(buf.String())
		buf.Reset()
		if err != nil {
			b.Fatalf("could not execute insert statement: %v", err)
		}

		err = tmplSelectID.Execute(&buf, map[string]any{
			"Name": fmt.Sprintf("Name%d", i),
		})
		if err != nil {
			b.Fatalf("could not execute select template: %v", err)
		}
		var id int
		err = db.QueryRow(buf.String()).Scan(&id)
		buf.Reset()
		if err != nil {
			b.Fatalf("could not execute select statement: %v", err)
		}
		if id != (i + 1) {
			b.Fatalf("data mismatch: expected %d, got %d.", i+1, id)
		}

		err = tmplUpdate.Execute(&buf, map[string]any{
			"Value1": float64(i + 10),
			"Id":     id,
		})
		if err != nil {
			b.Fatalf("could not execute update template: %v", err)
		}
		_, err = db.Exec(buf.String())
		buf.Reset()
		if err != nil {
			b.Fatalf("could not execute update statement: %v", err)
		}

		err = tmplSelectValue.Execute(&buf, map[string]any{
			"Id": id,
		})
		if err != nil {
			b.Fatalf("could not execute select template: %v", err)
		}
		var val float64
		err = db.QueryRow(buf.String()).Scan(&val)
		buf.Reset()
		if err != nil {
			b.Fatalf("could not execute select statement: %v", err)
		}
		if val != float64(i+10) {
			b.Fatalf("data mismatch: expected %f, got %f.", float64(i+10), val)
		}
	}
}
// BenchmarkSQLiteInsertSelectUpdateUsingTemplateWithStruct is the struct-data
// twin of the map-based template benchmark. The same text/template statements
// are rendered from small struct values instead of map[string]any and
// executed without bind arguments.
//
// This is UNSAFE SQL: text/template does no SQL escaping of the interpolated
// values.
func BenchmarkSQLiteInsertSelectUpdateUsingTemplateWithStruct(b *testing.B) { // unsafe SQL
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		b.Fatalf("could not open sqlite3 database: %v", err)
	}
	defer db.Close()

	// Every connection to ":memory:" opens its own, empty database. Pin the
	// pool to one connection so all statements see the table created below.
	db.SetMaxOpenConns(1)

	_, err = db.Exec(`CREATE TABLE benchmark (
		id INTEGER PRIMARY KEY,
		name TEXT,
		value1 REAL,
		value2 REAL,
		value3 REAL,
		value4 REAL,
		value5 REAL,
		value6 REAL,
		value7 REAL,
		value8 REAL,
		value9 REAL
	)`)
	if err != nil {
		b.Fatalf("could not create table: %v", err)
	}

	insertTemplate := `INSERT INTO benchmark(name, value1, value2, value3, value4, value5, value6, value7, value8, value9) VALUES('{{.Name}}', {{.Value1}}, {{.Value2}}, {{.Value3}}, {{.Value4}}, {{.Value5}}, {{.Value6}}, {{.Value7}}, {{.Value8}}, {{.Value9}})`
	selectIDTemplate := `SELECT id FROM benchmark WHERE name = '{{.Name}}'`
	updateTemplate := `UPDATE benchmark SET value1 = {{.Value1}} WHERE id = {{.Id}}`
	selectValueTemplate := `SELECT value1 FROM benchmark WHERE id = {{.Id}}`
	tmplInsert := template.Must(template.New("insert").Parse(insertTemplate))
	tmplSelectID := template.Must(template.New("selectId").Parse(selectIDTemplate))
	tmplUpdate := template.Must(template.New("update").Parse(updateTemplate))
	tmplSelectValue := template.Must(template.New("selectValue").Parse(selectValueTemplate))

	// Template data shapes. Field names must match the {{.Field}} references
	// above, which is why the id field is spelled "Id".
	type insertData struct {
		Name   string
		Value1 float64
		Value2 float64
		Value3 float64
		Value4 float64
		Value5 float64
		Value6 float64
		Value7 float64
		Value8 float64
		Value9 float64
	}
	type nameData struct {
		Name string
	}
	type updateData struct {
		Value1 float64
		Id     int
	}
	type idData struct {
		Id int
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		var buf bytes.Buffer

		err = tmplInsert.Execute(&buf, insertData{
			Name:   fmt.Sprintf("Name%d", i),
			Value1: float64(i),
			Value2: float64(i + 1),
			Value3: float64(i + 2),
			Value4: float64(i + 3),
			Value5: float64(i + 4),
			Value6: float64(i + 5),
			Value7: float64(i + 6),
			Value8: float64(i + 7),
			Value9: float64(i + 8),
		})
		if err != nil {
			b.Fatalf("could not execute insert template: %v", err)
		}
		_, err = db.Exec(buf.String())
		buf.Reset()
		if err != nil {
			b.Fatalf("could not execute insert statement: %v", err)
		}

		err = tmplSelectID.Execute(&buf, nameData{Name: fmt.Sprintf("Name%d", i)})
		if err != nil {
			b.Fatalf("could not execute select template: %v", err)
		}
		var id int
		err = db.QueryRow(buf.String()).Scan(&id)
		buf.Reset()
		if err != nil {
			b.Fatalf("could not execute select statement: %v", err)
		}
		if id != (i + 1) {
			b.Fatalf("data mismatch: expected %d, got %d.", i+1, id)
		}

		err = tmplUpdate.Execute(&buf, updateData{Value1: float64(i + 10), Id: id})
		if err != nil {
			b.Fatalf("could not execute update template: %v", err)
		}
		_, err = db.Exec(buf.String())
		buf.Reset()
		if err != nil {
			b.Fatalf("could not execute update statement: %v", err)
		}

		err = tmplSelectValue.Execute(&buf, idData{Id: id})
		if err != nil {
			b.Fatalf("could not execute select template: %v", err)
		}
		var val float64
		err = db.QueryRow(buf.String()).Scan(&val)
		buf.Reset()
		if err != nil {
			b.Fatalf("could not execute select statement: %v", err)
		}
		if val != float64(i+10) {
			b.Fatalf("data mismatch: expected %f, got %f.", float64(i+10), val)
		}
	}
}
func BenchmarkSQLiteInsertSelectUpdateUsingPongo2(b *testing.B) { // unsafe SQL | |
db, err := sql.Open("sqlite3", ":memory:") | |
if err != nil { | |
b.Fatalf("could not open sqlite3 database: %v", err) | |
} | |
defer db.Close() | |
_, err = db.Exec(`CREATE TABLE benchmark ( | |
id INTEGER PRIMARY KEY, | |
name TEXT, | |
value1 REAL, | |
value2 REAL, | |
value3 REAL, | |
value4 REAL, | |
value5 REAL, | |
value6 REAL, | |
value7 REAL, | |
value8 REAL, | |
value9 REAL | |
)`) | |
if err != nil { | |
b.Fatalf("could not create table: %v", err) | |
} | |
insertTplString := "INSERT INTO benchmark(name, value1, value2, value3, value4, value5, value6, value7, value8, value9) VALUES('{{ name }}', {{ value1 }}, {{ value2 }}, {{ value3 }}, {{ value4 }}, {{ value5 }}, {{ value6 }}, {{ value7 }}, {{ value8 }}, {{ value9 }})" | |
selectIdTplString := "SELECT id FROM benchmark WHERE name = '{{ name }}'" | |
updateTplString := "UPDATE benchmark SET value1 = {{ value1 }} WHERE id = {{ id }}" | |
selectValueTplString := "SELECT value1 FROM benchmark WHERE id = {{ id }}" | |
insertTpl, err := pongo2.FromString(insertTplString) | |
if err != nil { | |
b.Fatalf("could not parse insert template: %v", err) | |
} | |
selectIdTpl, err := pongo2.FromString(selectIdTplString) | |
if err != nil { | |
b.Fatalf("could not parse select template: %v", err) | |
} | |
updateTpl, err := pongo2.FromString(updateTplString) | |
if err != nil { | |
b.Fatalf("could not parse update template: %v", err) | |
} | |
selectValueTpl, err := pongo2.FromString(selectValueTplString) | |
if err != nil { | |
b.Fatalf("could not parse select template: %v", err) | |
} | |
b.ResetTimer() | |
for i := 0; i < b.N; i++ { | |
sql, err := insertTpl.Execute(pongo2.Context{ | |
"name": fmt.Sprintf("Name%d", i), | |
"value1": float64(i), | |
"value2": float64(i + 1), | |
"value3": float64(i + 2), | |
"value4": float64(i + 3), | |
"value5": float64(i + 4), | |
"value6": float64(i + 5), | |
"value7": float64(i + 6), | |
"value8": float64(i + 7), | |
"value9": float64(i + 8), | |
}) | |
if err != nil { | |
b.Fatalf("could not execute insert template: %v", err) | |
} | |
_, err = db.Exec(sql) | |
if err != nil { | |
b.Fatalf("could not execute insert statement: %v", err) | |
} | |
sql, err = selectIdTpl.Execute(pongo2.Context{"name": fmt.Sprintf("Name%d", i)}) | |
if err != nil { | |
b.Fatalf("could not execute select template: %v", err) | |
} | |
var id int | |
err = db.QueryRow(sql).Scan(&id) | |
if err != nil { | |
b.Fatalf("could not execute select statement: %v", err) | |
} | |
if id != (i + 1) { | |
b.Fatal("invalid data coming from DB") | |
} | |
sql, err = updateTpl.Execute(pongo2.Context{"value1": float64(i + 10), "id": id}) | |
if err != nil { | |
b.Fatalf("could not execute update template: %v", err) | |
} | |
_, err = db.Exec(sql) | |
if err != nil { | |
b.Fatalf("could not execute update statement: %v", err) | |
} | |
sql, err = selectValueTpl.Execute(pongo2.Context{"id": id}) | |
if err != nil { | |
b.Fatalf("could not execute select template: %v", err) | |
} | |
var val float64 | |
err = db.QueryRow(sql).Scan(&val) | |
if err != nil { | |
b.Fatalf("could not execute select statement: %v", err) | |
} | |
if val != float64(i+10) { | |
b.Fatalf("data mismatch: expected %f, got %f.", float64(i+10), val) | |
} | |
} | |
} | |
func BenchmarkSQLiteInsertSelectUpdateUsingBuilq(b *testing.B) { // dynamic, safe SQL | |
db, err := sql.Open("sqlite3", ":memory:") | |
if err != nil { | |
b.Fatalf("could not open sqlite3 database: %v", err) | |
} | |
defer db.Close() | |
_, err = db.Exec(`CREATE TABLE benchmark ( | |
id INTEGER PRIMARY KEY, | |
name TEXT, | |
value1 REAL, | |
value2 REAL, | |
value3 REAL, | |
value4 REAL, | |
value5 REAL, | |
value6 REAL, | |
value7 REAL, | |
value8 REAL, | |
value9 REAL | |
)`) | |
if err != nil { | |
b.Fatalf("could not create table: %v", err) | |
} | |
b.ResetTimer() | |
for i := 0; i < b.N; i++ { | |
bb := builq.Builder{} | |
bb.Addf("INSERT INTO benchmark (%s)", | |
builq.Columns{"name", "value1", "value2", "value3", "value4", "value5", "value6", "value7", "value8", "value9"}, | |
) | |
bb.Addf("VALUES (%+$)", | |
[]any{fmt.Sprintf("Name%d", i), float64(i), float64(i + 1), float64(i + 2), float64(i + 3), float64(i + 4), float64(i + 5), float64(i + 6), float64(i + 7), float64(i + 8)}, | |
) | |
query, args, err := bb.Build() | |
if err != nil { | |
b.Fatalf("could not build insert query: %v", err) | |
} | |
_, err = db.Exec(query, args...) | |
if err != nil { | |
b.Fatalf("could not execute insert statement: %v", err) | |
} | |
bf := builq.New() | |
bf("SELECT %s FROM %s", "id", "benchmark") | |
bf("WHERE %s = %$", "name", fmt.Sprintf("Name%d", i)) | |
query, args, err = bf.Build() | |
if err != nil { | |
b.Fatalf("could not build select query: %v", err) | |
} | |
var id int | |
err = db.QueryRow(query, args...).Scan(&id) | |
if err != nil { | |
b.Fatalf("could not execute select statement: %v", err) | |
} | |
if id != (i + 1) { | |
b.Fatalf("data mismatch: expected %d, got %d.", i+1, id) | |
} | |
bb = builq.Builder{} | |
bb.Addf("UPDATE benchmark SET value1 = %$ WHERE id = %$", float64(i+10), id) | |
query, args, err = bb.Build() | |
if err != nil { | |
b.Fatalf("could not build update query: %v", err) | |
} | |
_, err = db.Exec(query, args...) | |
if err != nil { | |
b.Fatalf("could not execute update statement: %v", err) | |
} | |
bf = builq.New() | |
bf("SELECT %s FROM %s", "value1", "benchmark") | |
bf("WHERE %s = %$", "id", id) | |
query, args, err = bf.Build() | |
if err != nil { | |
b.Fatalf("could not build select query: %v", err) | |
} | |
var val float64 | |
err = db.QueryRow(query, args...).Scan(&val) | |
if err != nil { | |
b.Fatalf("could not execute select statement: %v", err) | |
} | |
if val != float64(i+10) { | |
b.Fatalf("data mismatch: expected %f, got %f.", float64(i+10), val) | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"database/sql" | |
"errors" | |
"fmt" | |
"regexp" | |
"testing" | |
"github.com/DATA-DOG/go-sqlmock" | |
"github.com/cristalhq/builq" | |
) | |
func TestSQLInjectionWithSqlmock(t *testing.T) { | |
db, mock, err := sqlmock.New() | |
if err != nil { | |
t.Fatalf("An error '%s' was not expected when opening a stub database connection", err) | |
} | |
defer db.Close() | |
qf := func(db *sql.DB, tableName, username string, unsafeQuery bool) (*sql.Rows, error) { | |
if unsafeQuery { | |
query := fmt.Sprintf("SELECT * FROM %s WHERE username = '%s'", tableName, username) | |
return db.Query(query) | |
} | |
return db.Query("SELECT * FROM users WHERE username = ?", username) | |
} | |
tests := []struct { | |
name string | |
user string | |
table string | |
unsafe bool | |
prepare func(mock sqlmock.Sqlmock, user string) | |
}{ | |
{ | |
name: "Tautologies", | |
user: "anything' OR 'x'='x", | |
prepare: func(mock sqlmock.Sqlmock, user string) { | |
mock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM users WHERE username = ?")). | |
WithArgs(user). | |
WillReturnRows(sqlmock.NewRows([]string{"id", "username"})) | |
}, | |
}, | |
{ | |
name: "Illegal/Logically Incorrect Queries", | |
user: "admin' AND 1=2 UNION SELECT * FROM users --", | |
prepare: func(mock sqlmock.Sqlmock, user string) { | |
mock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM users WHERE username = ?")). | |
WithArgs(user). | |
WillReturnRows(sqlmock.NewRows([]string{"id", "username"})) | |
}, | |
}, | |
{ | |
name: "Union Query", | |
user: "admin' UNION SELECT * FROM users --", | |
prepare: func(mock sqlmock.Sqlmock, user string) { | |
mock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM users WHERE username = ?")). | |
WithArgs(user). | |
WillReturnRows(sqlmock.NewRows([]string{"id", "username"})) | |
}, | |
}, | |
{ | |
name: "Piggy-Backed Queries", | |
user: "admin'; DROP TABLE users; --", | |
prepare: func(mock sqlmock.Sqlmock, user string) { | |
mock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM users WHERE username = ?")). | |
WithArgs(user). | |
WillReturnError(sql.ErrNoRows) | |
}, | |
}, | |
{ | |
name: "FmtSprintf Injection", | |
user: "'; DROP TABLE users; --", | |
unsafe: true, | |
table: "users", | |
prepare: func(mock sqlmock.Sqlmock, user string) { | |
mock.ExpectQuery("SELECT \\* FROM users WHERE username = '.*; DROP TABLE users; --'"). | |
WillReturnRows(sqlmock.NewRows([]string{"id", "username"})) | |
}, | |
}, | |
{ | |
name: "TableName Injection", | |
user: "admin", | |
table: "users; DROP TABLE sensitive_data; --", | |
unsafe: true, | |
prepare: func(mock sqlmock.Sqlmock, user string) { | |
mock.ExpectQuery(".+ DROP TABLE sensitive_data;.+"). | |
WillReturnRows(sqlmock.NewRows([]string{"id", "username"})) | |
}, | |
}, | |
} | |
for _, tc := range tests { | |
t.Run(tc.name, func(t *testing.T) { | |
tc.prepare(mock, tc.user) | |
_, err := qf(db, tc.table, tc.user, tc.unsafe) | |
if err != nil && !errors.Is(err, sql.ErrNoRows) { | |
t.Errorf("Unexpected error: %v", err) | |
} | |
if err := mock.ExpectationsWereMet(); err != nil { | |
t.Errorf("There were unfulfilled expectations: %s", err) | |
} | |
}) | |
} | |
} | |
func TestSQLInjectionPreventionUsingBuilq(t *testing.T) { | |
db, mock, err := sqlmock.New() | |
if err != nil { | |
t.Fatalf("An error '%s' was not expected when opening a stub database connection", err) | |
} | |
defer db.Close() | |
tests := []struct { | |
name string | |
user string | |
bb func(user string) *builq.Builder | |
prepare func(mock sqlmock.Sqlmock, user string) | |
}{ | |
{ | |
name: "Tautologies", | |
user: "anything' OR 'x'='x", | |
bb: func(user string) *builq.Builder { | |
return builq.New()("SELECT %s FROM %s WHERE username = %$", builq.Columns{"username"}, "users", user) | |
}, | |
prepare: func(mock sqlmock.Sqlmock, user string) { | |
mock.ExpectQuery("SELECT username FROM users WHERE username = \\$1"). | |
WithArgs(user). | |
WillReturnRows(sqlmock.NewRows([]string{"id", "username"})) | |
}, | |
}, | |
{ | |
name: "Illegal/Logically Incorrect Queries", | |
user: "admin' AND 1=2 UNION SELECT * FROM users --", | |
bb: func(user string) *builq.Builder { | |
return builq.New()("SELECT %s FROM %s WHERE username = %$", builq.Columns{"username"}, "users", user) | |
}, | |
prepare: func(mock sqlmock.Sqlmock, user string) { | |
mock.ExpectQuery(regexp.QuoteMeta("SELECT username FROM users WHERE username = $1")). | |
WithArgs(user). | |
WillReturnRows(sqlmock.NewRows([]string{"id", "username"})) | |
}, | |
}, | |
{ | |
name: "Union Query", | |
user: "admin' UNION SELECT * FROM users --", | |
bb: func(user string) *builq.Builder { | |
return builq.New()("SELECT %s FROM %s WHERE username = %$", builq.Columns{"username"}, "users", user) | |
}, | |
prepare: func(mock sqlmock.Sqlmock, user string) { | |
mock.ExpectQuery(regexp.QuoteMeta("SELECT username FROM users WHERE username = $1")). | |
WithArgs(user). | |
WillReturnRows(sqlmock.NewRows([]string{"id", "username"})) | |
}, | |
}, | |
{ | |
name: "Piggy-Backed Queries", | |
user: "admin'; DROP TABLE users; --", | |
bb: func(user string) *builq.Builder { | |
return builq.New()("SELECT %s FROM %s WHERE username = %$", builq.Columns{"username"}, "users", user) | |
}, | |
prepare: func(mock sqlmock.Sqlmock, user string) { | |
mock.ExpectQuery(regexp.QuoteMeta("SELECT username FROM users WHERE username = $1")). | |
WithArgs(user). | |
WillReturnError(sql.ErrNoRows) | |
}, | |
}, | |
{ | |
name: "FmtSprintf Injection", | |
user: "'; DROP TABLE users; --", | |
bb: func(user string) *builq.Builder { | |
return builq.New()("SELECT * FROM users WHERE username = '%s'", user) | |
}, | |
prepare: func(mock sqlmock.Sqlmock, user string) { | |
mock.ExpectQuery("SELECT \\* FROM users WHERE username = '.*; DROP TABLE users; --'"). | |
WillReturnRows(sqlmock.NewRows([]string{"id", "username"})) | |
}, | |
}, | |
{ | |
name: "TableName Injection", | |
user: "admin", | |
bb: func(user string) *builq.Builder { | |
tableName := "users; DROP TABLE sensitive_data; --" | |
return builq.New()("SELECT * FROM %s WHERE username = '%s'", tableName, user) | |
}, | |
prepare: func(mock sqlmock.Sqlmock, user string) { | |
mock.ExpectQuery(".+ DROP TABLE sensitive_data;.+"). | |
WillReturnRows(sqlmock.NewRows([]string{"id", "username"})) | |
}, | |
}, | |
} | |
for _, tc := range tests { | |
t.Run(tc.name, func(t *testing.T) { | |
tc.prepare(mock, tc.user) | |
query, args, err := tc.bb(tc.user).Build() | |
if err != nil { | |
t.Fatalf("could not build query: %v", err) | |
} | |
_, err = db.Query(query, args...) | |
if err != nil && !errors.Is(err, sql.ErrNoRows) { | |
t.Errorf("Unexpected error: %v", err) | |
} | |
if err := mock.ExpectationsWereMet(); err != nil { | |
t.Errorf("There were unfulfilled expectations: %s", err) | |
} | |
}) | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$ go test -bench=. -benchmem
goos: linux
goarch: amd64
pkg: code.local/benchmark
cpu: 12th Gen Intel(R) Core(TM) i5-1240P
BenchmarkSQLiteInsertSelectUpdate-16 10000 117530 ns/op 2897 B/op 82 allocs/op
BenchmarkSQLiteInsertSelectUpdateUsingSquirrel-16 10000 134923 ns/op 14105 B/op 303 allocs/op
BenchmarkSQLiteInsertSelectUpdateUsingSqlf-16 10000 124792 ns/op 4867 B/op 148 allocs/op
BenchmarkSQLiteInsertSelectUpdateUsingTemplateWithMap-16 10000 123096 ns/op 4342 B/op 113 allocs/op
BenchmarkSQLiteInsertSelectUpdateUsingTemplateWithStruct-16 10000 121335 ns/op 2367 B/op 81 allocs/op
BenchmarkSQLiteInsertSelectUpdateUsingPongo2-16 10000 132478 ns/op 7332 B/op 140 allocs/op
BenchmarkSQLiteInsertSelectUpdateUsingBuilq-16 10000 130760 ns/op 6884 B/op 118 allocs/op
PASS
ok code.local/benchmark 8.893s
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment