This gist is associated with the blog post "Building safe-by-default tools in our Go web application".
It contains the gorm hooks that we use to ensure our queries are correctly scoped.
package safedb | |
import ( | |
"context" | |
"fmt" | |
"gorm.io/gorm" | |
"gorm.io/gorm/callbacks" | |
"gorm.io/gorm/clause" | |
) | |
func Apply(db *gorm.DB) { | |
EnforceUpdateByID(db) | |
EnforceScopeByOrganisationID(db) | |
} | |
// ctxKey is the package-private type used for the context keys below. Using a
// dedicated type keeps these keys distinct from keys of any other type.
type ctxKey string

// skipUpdateWithoutID is the context key that, when set to true, makes
// EnforceUpdateByID let the query through.
const skipUpdateWithoutID ctxKey = "safedb.SkipEnforceUpdateByID"

// skipScopeByOrganisationID is the context key that, when set to true, makes
// EnforceScopeByOrganisationID let the query through.
const skipScopeByOrganisationID ctxKey = "safedb.SkipEnforceScopeByOrganisationID"
type ErrEnforceUpdateByID struct { | |
info QueryInfo | |
} | |
func (e ErrEnforceUpdateByID) Error() string { | |
return fmt.Sprintf("UPDATE is missing WHERE id = ? clause: %s", e.info.Query) | |
} | |
// SkipEnforceUpdateByID allows skipping of the EnforceUpdateByID, when updating in bulk | |
// is a deliberate choice. | |
func SkipEnforceUpdateByID(ctx context.Context) context.Context { | |
return context.WithValue(ctx, skipUpdateWithoutID, true) | |
} | |
// EnforceUpdateByID will fail any database query that tries issuing an update without a | |
// where clause specifying an ID. | |
// | |
// It is highly unusual that the app should make bulk updates, as you almost always want | |
// to be updating row-by-row. This guard helps catch any accidental ommissions that could | |
// potentially update all rows in a table, if left unchecked. | |
func EnforceUpdateByID(db *gorm.DB) { | |
db.Callback().Update().Before("gorm:update").Register("safedb:ban_bulk_updates", func(db *gorm.DB) { | |
if skip, _ := db.Statement.Context.Value(skipUpdateWithoutID).(bool); skip { | |
return | |
} | |
if db.Error != nil { | |
return | |
} | |
if _, found := hasColumnWhereClause(db, "id", "update"); !found { | |
db.AddError(ErrEnforceUpdateByID{buildInfo(db, "update")}) | |
} | |
}) | |
} | |
type ErrEnforceScopeByOrganisationID struct { | |
info QueryInfo | |
} | |
func (e ErrEnforceScopeByOrganisationID) Error() string { | |
return fmt.Sprintf("query is missing WHERE organisation_id = ? clause: %s", e.info.Query) | |
} | |
// SkipEnforceScopeByOrganisationID allows skipping of the EnforceScopeByOrganisationID, | |
// when global access is explicitly requested. | |
func SkipEnforceScopeByOrganisationID(ctx context.Context) context.Context { | |
return context.WithValue(ctx, skipScopeByOrganisationID, true) | |
} | |
// EnforceScopeByOrganisationID restricts read and write database queries by failing | |
// whenever the query lacks an organisation ID. This allows us to protect against | |
// accidental exposure, or mis-scoping. | |
func EnforceScopeByOrganisationID(db *gorm.DB) { | |
enforceScope := func(operation string, db *gorm.DB) { | |
if skip, _ := db.Statement.Context.Value(skipScopeByOrganisationID).(bool); skip { | |
return | |
} | |
if db.Error != nil { | |
return | |
} | |
switch operation { | |
// These operations have where conditions, which we can search for organisation_id | |
case "update", "query": | |
columnName := "organisation_id" | |
if db.Statement.Table == "organisations" { | |
columnName = "id" | |
} | |
value, found := hasColumnWhereClause(db, columnName, operation) | |
if !found { | |
db.AddError(ErrEnforceScopeByOrganisationID{buildInfo(db, operation)}) | |
return | |
} | |
// If we have preloads, we want to automatically add organisation_id conditions to | |
// them. This is because preloads will get executed as yet-another-gorm-query, which | |
// will also be subject to our guard, which would otherwise block the preload for no | |
// organisation ID. | |
// | |
// I'm not sure we want to do this, as I'd prefer us to skip the guards only when | |
// executing preload queries. I think we'd need to make upstream changes to gorm to | |
// support this though, which I want to avoid for now. | |
// | |
// We'll know if this goes wrong though, as it's likely to cause slow queries | |
// against tables which aren't favourable for foreign-key + organisation_id in the | |
// where clauses. | |
// | |
// For clarity, this should cause a gorm query of: | |
// | |
// db.Model(Message{}). | |
// Where("organisation_id = ?", "org-id"). | |
// Preload("Votes"). | |
// Find(&messages) | |
// | |
// To execute the following: | |
// | |
// SELECT * FROM "safedb"."messages" WHERE organisation_id = 'org-id' | |
// SELECT * FROM "safedb"."votes" WHERE "votes"."message_id" = '01FF0C4Z8ZN6RH7RQDXAX9GYVR' AND (organisation_id = 'org-id') | |
// | |
for name, query := range db.Statement.Preloads { | |
if query == nil && name != "Organisation" { | |
db.Statement.Preloads[name] = []interface{}{"organisation_id = ?", value} | |
} | |
} | |
} | |
} | |
var hookName = "safedb:enforce_scope_by_organisation" | |
db.Callback().Query().Before("gorm:query"). | |
Register(hookName, func(db *gorm.DB) { enforceScope("query", db) }) | |
} | |
func hasColumnWhereClause(db *gorm.DB, columnName, operation string) (interface{}, bool) { | |
// Without this, we may not have populated our query clauses | |
stmt := buildQuery(db, operation).Statement | |
where, ok := stmt.Clauses["WHERE"] | |
if !ok { | |
return nil, false | |
} | |
// Check all our where clauses. If we see a clause specifying `<column> = ?` and a valid | |
// non-empty value, we should return as we know this guard has passed. | |
if where, ok := where.Expression.(clause.Where); ok { | |
for _, expr := range where.Exprs { | |
switch expr := expr.(type) { | |
// For magic-string where clauses | |
case clause.Expr: | |
if expr.SQL == fmt.Sprintf("%s = ?", columnName) && len(expr.Vars) == 1 && expr.Vars[0].(string) != "" { | |
return expr.Vars[0].(string), true | |
} | |
// For type-safe where clauses | |
case clause.Eq: | |
switch column := expr.Column.(type) { | |
case clause.Column: | |
if column.Name == columnName && expr.Value != "" { | |
return expr.Value, true | |
} | |
case string: | |
if column == columnName && expr.Value != "" { | |
return expr.Value, true | |
} | |
} | |
} | |
} | |
} | |
return nil, false | |
} | |
type QueryInfo struct { | |
Query string | |
Statement gorm.Statement | |
} | |
// buildQuery runs the appropriate callbacks to populate the query clauses/string on the | |
// gorm.DB. It returns a copy, which has been run under dry-run mode. | |
func buildQuery(db *gorm.DB, operation string) *gorm.DB { | |
db = db.Session(&gorm.Session{DryRun: true}) | |
switch operation { | |
case "query": | |
callbacks.BuildQuerySQL(db) | |
case "update": | |
callbacks.Update(db) | |
} | |
return db | |
} | |
// buildInfo helps generate a spew-able structure that can identify the query, should it | |
// be banned by one of our guards. gorm.DB is cyclic, which doesn't lend itself to | |
// debugging- figuring out how to generate the actual SQL query associated with the handle | |
// is also difficult. | |
func buildInfo(db *gorm.DB, operation string) QueryInfo { | |
db = buildQuery(db, operation) | |
return QueryInfo{ | |
Query: db.Statement.SQL.String(), | |
Statement: gorm.Statement{ | |
Table: db.Statement.Table, | |
TableExpr: db.Statement.TableExpr, | |
Clauses: db.Statement.Clauses, | |
BuildClauses: db.Statement.BuildClauses, | |
Selects: db.Statement.Selects, | |
SQL: db.Statement.SQL, | |
Preloads: db.Statement.Preloads, | |
Joins: db.Statement.Joins, | |
}, | |
} | |
} |
package safedb_test | |
import ( | |
"context" | |
"github.com/incident-io/core/server/safedb" | |
"gorm.io/gorm" | |
. "github.com/onsi/ginkgo" | |
. "github.com/onsi/gomega" | |
) | |
type Message struct { | |
ID string `json:"id" gorm:"type:text;primaryKey;default:generate_ulid()"` | |
OrganisationID string `json:"organisation_id"` | |
Message string `json:"message"` | |
Votes []Vote | |
} | |
func (Message) TableName() string { | |
return "safedb.messages" | |
} | |
// Vote is the test model for rows of the safedb.votes table. Each vote
// references a message and carries its own organisation ID.
type Vote struct {
	ID             string `json:"id" gorm:"type:text;primaryKey;default:generate_ulid()"`
	OrganisationID string `json:"organisation_id"`
	MessageID      string `json:"message_id"`
}

// TableName gives gorm the schema-qualified table backing Vote.
func (Vote) TableName() string {
	const table = "safedb.votes"
	return table
}
// Organisation is the test model for rows of the safedb.organisations table.
type Organisation struct {
	ID   string `json:"id" gorm:"type:text;primaryKey;default:generate_ulid()"`
	Name string `json:"name"`
}

// TableName gives gorm the schema-qualified table backing Organisation.
func (Organisation) TableName() string {
	const table = "safedb.organisations"
	return table
}
var _ = Describe("safedb", func() { | |
var ( | |
db *gorm.DB | |
msg Message | |
vote Vote | |
org Organisation | |
) | |
BeforeEach(func() { | |
var err error | |
db = pg.GetRaw() | |
db.Exec(`drop schema safedb cascade;`) | |
err = db.Exec(`create schema safedb;`).Error | |
Expect(err).NotTo(HaveOccurred()) | |
err = db.Exec(` | |
create table safedb.messages ( | |
id text primary key default generate_ulid(), | |
organisation_id text not null, | |
message text not null | |
)`).Error | |
Expect(err).NotTo(HaveOccurred()) | |
err = db.Exec(` | |
create table safedb.votes ( | |
id text primary key default generate_ulid(), | |
organisation_id text not null, | |
message_id text not null references safedb.messages(id) | |
)`).Error | |
Expect(err).NotTo(HaveOccurred()) | |
err = db.Exec(` | |
create table safedb.organisations ( | |
id text primary key default generate_ulid(), | |
name text not null | |
)`).Error | |
Expect(err).NotTo(HaveOccurred()) | |
msg = Message{Message: "hello world", OrganisationID: "org-id"} | |
err = db.Create(&msg).Error | |
Expect(err).NotTo(HaveOccurred()) | |
vote = Vote{MessageID: msg.ID, OrganisationID: msg.OrganisationID} | |
err = db.Create(&vote).Error | |
Expect(err).NotTo(HaveOccurred()) | |
org = Organisation{Name: "Skynet"} | |
err = db.Create(&org).Error | |
Expect(err).NotTo(HaveOccurred()) | |
}) | |
Describe("EnforceUpdateByID", func() { | |
BeforeEach(func() { | |
safedb.EnforceUpdateByID(db) | |
}) | |
Context("with bulk updates", func() { | |
It("fails the query", func() { | |
err := db.Model(Message{}). | |
Where("organisation_id = ?", "org-id"). | |
Updates(Message{ | |
Message: "clobber me", | |
}). | |
Error | |
Expect(err).To(BeAssignableToTypeOf(safedb.ErrEnforceUpdateByID{})) | |
}) | |
}) | |
Context("calling Save()", func() { | |
It("permits the query", func() { | |
err := db.Model(msg).Save(&msg).Error | |
Expect(err).NotTo(HaveOccurred()) | |
}) | |
}) | |
Context("with bulk updates with SkipEnforceUpdateByID", func() { | |
It("permits the query", func() { | |
err := db.Model(Message{}). | |
WithContext(safedb.SkipEnforceUpdateByID(context.Background())). | |
Where("organisation_id = ?", "org-id"). | |
Updates(Message{ | |
Message: "clobber me", | |
}). | |
Error | |
Expect(err).NotTo(HaveOccurred()) | |
}) | |
}) | |
Context("with ID in text clause", func() { | |
It("permits the query", func() { | |
err := db.Model(Message{}). | |
Where("id = ?", "message-id"). | |
Updates(Message{ | |
Message: "give me treats", | |
}). | |
Error | |
Expect(err).NotTo(HaveOccurred()) | |
}) | |
}) | |
Context("with ID in struct clause", func() { | |
It("permits the query", func() { | |
err := db.Model(Message{}). | |
Where(Message{ID: "message-id"}). | |
Updates(Message{ | |
Message: "give me treats", | |
}). | |
Error | |
Expect(err).NotTo(HaveOccurred()) | |
}) | |
}) | |
}) | |
Describe("EnforceScopeByOrganisationID", func() { | |
BeforeEach(func() { | |
safedb.EnforceScopeByOrganisationID(db) | |
}) | |
Context("with scoped select", func() { | |
It("permits query", func() { | |
var count int64 | |
err := db.Model(Message{}). | |
Where("organisation_id = ?", "org-id"). | |
Count(&count). | |
Error | |
Expect(err).NotTo(HaveOccurred()) | |
}) | |
}) | |
Context("with preloads", func() { | |
It("permits the preload query", func() { | |
var messages []*Message | |
err := db.Model(Message{}). | |
Where("organisation_id = ?", "org-id"). | |
Preload("Votes", "organisation_id = ?", msg.OrganisationID). | |
Find(&messages). | |
Error | |
Expect(err).NotTo(HaveOccurred()) | |
Expect(messages).NotTo(BeEmpty()) | |
Expect(messages[0].Votes).NotTo(BeEmpty()) | |
}) | |
}) | |
Context("with scoped select using struct", func() { | |
It("permits query", func() { | |
var count int64 | |
err := db.Model(Message{}). | |
Where(Message{OrganisationID: "org-id"}). | |
Count(&count). | |
Error | |
Expect(err).NotTo(HaveOccurred()) | |
}) | |
}) | |
Context("with select against organisations with ID", func() { | |
It("permits query", func() { | |
var res Organisation | |
err := db.Model(Organisation{}). | |
Where(Organisation{ID: org.ID}). | |
First(&res). | |
Error | |
Expect(err).NotTo(HaveOccurred()) | |
}) | |
}) | |
// Create and Save are effectively the same, and go via the same hooks. | |
Context("with scoped create/save", func() { | |
It("permits query", func() { | |
msg := Message{Message: "boop", OrganisationID: "org-id"} | |
err := db.Create(&msg).Error | |
Expect(err).NotTo(HaveOccurred()) | |
}) | |
}) | |
Context("with unscoped select", func() { | |
query := func(ctx context.Context) error { | |
var count int64 | |
return db.Model(Message{}). | |
WithContext(ctx). | |
Where("id = ?", "message-id"). | |
Count(&count). | |
Error | |
} | |
It("fails query", func() { | |
ctx := context.Background() | |
Expect(query(ctx)). | |
To(BeAssignableToTypeOf(safedb.ErrEnforceScopeByOrganisationID{})) | |
}) | |
Context("with SkipEnforceScopeByOrganisationID", func() { | |
It("permits query", func() { | |
ctx := safedb.SkipEnforceScopeByOrganisationID(context.Background()) | |
Expect(query(ctx)). | |
To(Succeed()) | |
}) | |
}) | |
}) | |
}) | |
}) |