This gist is associated with the blog post "Building safe-by-default tools in our Go web application".
It contains the gorm hooks that we use to ensure our queries are correctly scoped.
| package safedb | |
| import ( | |
| "context" | |
| "fmt" | |
| "gorm.io/gorm" | |
| "gorm.io/gorm/callbacks" | |
| "gorm.io/gorm/clause" | |
| ) | |
| func Apply(db *gorm.DB) { | |
| EnforceUpdateByID(db) | |
| EnforceScopeByOrganisationID(db) | |
| } | |
// ctxKey is an unexported type for the keys this package stores in a
// context.Context. Because the type is private to safedb, these keys cannot collide
// with context values set by other packages.
type ctxKey string

// Context keys that, when set to true, opt a single query out of a guard. They are
// set by SkipEnforceUpdateByID and SkipEnforceScopeByOrganisationID respectively.
const (
	skipUpdateWithoutID       = ctxKey("safedb.SkipEnforceUpdateByID")
	skipScopeByOrganisationID = ctxKey("safedb.SkipEnforceScopeByOrganisationID")
)
| type ErrEnforceUpdateByID struct { | |
| info QueryInfo | |
| } | |
| func (e ErrEnforceUpdateByID) Error() string { | |
| return fmt.Sprintf("UPDATE is missing WHERE id = ? clause: %s", e.info.Query) | |
| } | |
| // SkipEnforceUpdateByID allows skipping of the EnforceUpdateByID, when updating in bulk | |
| // is a deliberate choice. | |
| func SkipEnforceUpdateByID(ctx context.Context) context.Context { | |
| return context.WithValue(ctx, skipUpdateWithoutID, true) | |
| } | |
| // EnforceUpdateByID will fail any database query that tries issuing an update without a | |
| // where clause specifying an ID. | |
| // | |
| // It is highly unusual that the app should make bulk updates, as you almost always want | |
| // to be updating row-by-row. This guard helps catch any accidental ommissions that could | |
| // potentially update all rows in a table, if left unchecked. | |
| func EnforceUpdateByID(db *gorm.DB) { | |
| db.Callback().Update().Before("gorm:update").Register("safedb:ban_bulk_updates", func(db *gorm.DB) { | |
| if skip, _ := db.Statement.Context.Value(skipUpdateWithoutID).(bool); skip { | |
| return | |
| } | |
| if db.Error != nil { | |
| return | |
| } | |
| if _, found := hasColumnWhereClause(db, "id", "update"); !found { | |
| db.AddError(ErrEnforceUpdateByID{buildInfo(db, "update")}) | |
| } | |
| }) | |
| } | |
| type ErrEnforceScopeByOrganisationID struct { | |
| info QueryInfo | |
| } | |
| func (e ErrEnforceScopeByOrganisationID) Error() string { | |
| return fmt.Sprintf("query is missing WHERE organisation_id = ? clause: %s", e.info.Query) | |
| } | |
| // SkipEnforceScopeByOrganisationID allows skipping of the EnforceScopeByOrganisationID, | |
| // when global access is explicitly requested. | |
| func SkipEnforceScopeByOrganisationID(ctx context.Context) context.Context { | |
| return context.WithValue(ctx, skipScopeByOrganisationID, true) | |
| } | |
| // EnforceScopeByOrganisationID restricts read and write database queries by failing | |
| // whenever the query lacks an organisation ID. This allows us to protect against | |
| // accidental exposure, or mis-scoping. | |
| func EnforceScopeByOrganisationID(db *gorm.DB) { | |
| enforceScope := func(operation string, db *gorm.DB) { | |
| if skip, _ := db.Statement.Context.Value(skipScopeByOrganisationID).(bool); skip { | |
| return | |
| } | |
| if db.Error != nil { | |
| return | |
| } | |
| switch operation { | |
| // These operations have where conditions, which we can search for organisation_id | |
| case "update", "query": | |
| columnName := "organisation_id" | |
| if db.Statement.Table == "organisations" { | |
| columnName = "id" | |
| } | |
| value, found := hasColumnWhereClause(db, columnName, operation) | |
| if !found { | |
| db.AddError(ErrEnforceScopeByOrganisationID{buildInfo(db, operation)}) | |
| return | |
| } | |
| // If we have preloads, we want to automatically add organisation_id conditions to | |
| // them. This is because preloads will get executed as yet-another-gorm-query, which | |
| // will also be subject to our guard, which would otherwise block the preload for no | |
| // organisation ID. | |
| // | |
| // I'm not sure we want to do this, as I'd prefer us to skip the guards only when | |
| // executing preload queries. I think we'd need to make upstream changes to gorm to | |
| // support this though, which I want to avoid for now. | |
| // | |
| // We'll know if this goes wrong though, as it's likely to cause slow queries | |
| // against tables which aren't favourable for foreign-key + organisation_id in the | |
| // where clauses. | |
| // | |
| // For clarity, this should cause a gorm query of: | |
| // | |
| // db.Model(Message{}). | |
| // Where("organisation_id = ?", "org-id"). | |
| // Preload("Votes"). | |
| // Find(&messages) | |
| // | |
| // To execute the following: | |
| // | |
| // SELECT * FROM "safedb"."messages" WHERE organisation_id = 'org-id' | |
| // SELECT * FROM "safedb"."votes" WHERE "votes"."message_id" = '01FF0C4Z8ZN6RH7RQDXAX9GYVR' AND (organisation_id = 'org-id') | |
| // | |
| for name, query := range db.Statement.Preloads { | |
| if query == nil && name != "Organisation" { | |
| db.Statement.Preloads[name] = []interface{}{"organisation_id = ?", value} | |
| } | |
| } | |
| } | |
| } | |
| var hookName = "safedb:enforce_scope_by_organisation" | |
| db.Callback().Query().Before("gorm:query"). | |
| Register(hookName, func(db *gorm.DB) { enforceScope("query", db) }) | |
| } | |
| func hasColumnWhereClause(db *gorm.DB, columnName, operation string) (interface{}, bool) { | |
| // Without this, we may not have populated our query clauses | |
| stmt := buildQuery(db, operation).Statement | |
| where, ok := stmt.Clauses["WHERE"] | |
| if !ok { | |
| return nil, false | |
| } | |
| // Check all our where clauses. If we see a clause specifying `<column> = ?` and a valid | |
| // non-empty value, we should return as we know this guard has passed. | |
| if where, ok := where.Expression.(clause.Where); ok { | |
| for _, expr := range where.Exprs { | |
| switch expr := expr.(type) { | |
| // For magic-string where clauses | |
| case clause.Expr: | |
| if expr.SQL == fmt.Sprintf("%s = ?", columnName) && len(expr.Vars) == 1 && expr.Vars[0].(string) != "" { | |
| return expr.Vars[0].(string), true | |
| } | |
| // For type-safe where clauses | |
| case clause.Eq: | |
| switch column := expr.Column.(type) { | |
| case clause.Column: | |
| if column.Name == columnName && expr.Value != "" { | |
| return expr.Value, true | |
| } | |
| case string: | |
| if column == columnName && expr.Value != "" { | |
| return expr.Value, true | |
| } | |
| } | |
| } | |
| } | |
| } | |
| return nil, false | |
| } | |
// QueryInfo is a debugging snapshot of a query rejected by one of the safedb guards.
// It is populated by buildInfo, and is used by ErrEnforceUpdateByID and
// ErrEnforceScopeByOrganisationID to describe the offending query.
type QueryInfo struct {
	// Query is the SQL gorm built for the statement in dry-run mode.
	Query string
	// Statement holds a subset of the statement's fields (table, clauses, selects,
	// SQL, preloads, joins), copied out because gorm.DB is cyclic and hard to dump.
	Statement gorm.Statement
}
| // buildQuery runs the appropriate callbacks to populate the query clauses/string on the | |
| // gorm.DB. It returns a copy, which has been run under dry-run mode. | |
| func buildQuery(db *gorm.DB, operation string) *gorm.DB { | |
| db = db.Session(&gorm.Session{DryRun: true}) | |
| switch operation { | |
| case "query": | |
| callbacks.BuildQuerySQL(db) | |
| case "update": | |
| callbacks.Update(db) | |
| } | |
| return db | |
| } | |
| // buildInfo helps generate a spew-able structure that can identify the query, should it | |
| // be banned by one of our guards. gorm.DB is cyclic, which doesn't lend itself to | |
| // debugging- figuring out how to generate the actual SQL query associated with the handle | |
| // is also difficult. | |
| func buildInfo(db *gorm.DB, operation string) QueryInfo { | |
| db = buildQuery(db, operation) | |
| return QueryInfo{ | |
| Query: db.Statement.SQL.String(), | |
| Statement: gorm.Statement{ | |
| Table: db.Statement.Table, | |
| TableExpr: db.Statement.TableExpr, | |
| Clauses: db.Statement.Clauses, | |
| BuildClauses: db.Statement.BuildClauses, | |
| Selects: db.Statement.Selects, | |
| SQL: db.Statement.SQL, | |
| Preloads: db.Statement.Preloads, | |
| Joins: db.Statement.Joins, | |
| }, | |
| } | |
| } |
| package safedb_test | |
| import ( | |
| "context" | |
| "github.com/incident-io/core/server/safedb" | |
| "gorm.io/gorm" | |
| . "github.com/onsi/ginkgo" | |
| . "github.com/onsi/gomega" | |
| ) | |
// Message is a test model stored in safedb.messages. Each message carries an
// organisation ID, and the preload specs load its associated Votes.
type Message struct {
	ID             string `json:"id" gorm:"type:text;primaryKey;default:generate_ulid()"`
	OrganisationID string `json:"organisation_id"`
	Message        string `json:"message"`
	Votes          []Vote
}

// TableName places Message in the test-only safedb schema.
func (Message) TableName() string {
	return "safedb.messages"
}
// Vote is a test model stored in safedb.votes. It references its Message via
// MessageID and carries its own organisation ID, so preloaded votes can be scoped.
type Vote struct {
	ID             string `json:"id" gorm:"type:text;primaryKey;default:generate_ulid()"`
	OrganisationID string `json:"organisation_id"`
	MessageID      string `json:"message_id"`
}

// TableName places Vote in the test-only safedb schema.
func (Vote) TableName() string {
	return "safedb.votes"
}
// Organisation is a test model stored in safedb.organisations. It has no
// organisation_id column, so the scoping guard expects queries against it to filter
// by id instead.
type Organisation struct {
	ID   string `json:"id" gorm:"type:text;primaryKey;default:generate_ulid()"`
	Name string `json:"name"`
}

// TableName places Organisation in the test-only safedb schema.
func (Organisation) TableName() string {
	return "safedb.organisations"
}
| var _ = Describe("safedb", func() { | |
| var ( | |
| db *gorm.DB | |
| msg Message | |
| vote Vote | |
| org Organisation | |
| ) | |
| BeforeEach(func() { | |
| var err error | |
| db = pg.GetRaw() | |
| db.Exec(`drop schema safedb cascade;`) | |
| err = db.Exec(`create schema safedb;`).Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| err = db.Exec(` | |
| create table safedb.messages ( | |
| id text primary key default generate_ulid(), | |
| organisation_id text not null, | |
| message text not null | |
| )`).Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| err = db.Exec(` | |
| create table safedb.votes ( | |
| id text primary key default generate_ulid(), | |
| organisation_id text not null, | |
| message_id text not null references safedb.messages(id) | |
| )`).Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| err = db.Exec(` | |
| create table safedb.organisations ( | |
| id text primary key default generate_ulid(), | |
| name text not null | |
| )`).Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| msg = Message{Message: "hello world", OrganisationID: "org-id"} | |
| err = db.Create(&msg).Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| vote = Vote{MessageID: msg.ID, OrganisationID: msg.OrganisationID} | |
| err = db.Create(&vote).Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| org = Organisation{Name: "Skynet"} | |
| err = db.Create(&org).Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| }) | |
| Describe("EnforceUpdateByID", func() { | |
| BeforeEach(func() { | |
| safedb.EnforceUpdateByID(db) | |
| }) | |
| Context("with bulk updates", func() { | |
| It("fails the query", func() { | |
| err := db.Model(Message{}). | |
| Where("organisation_id = ?", "org-id"). | |
| Updates(Message{ | |
| Message: "clobber me", | |
| }). | |
| Error | |
| Expect(err).To(BeAssignableToTypeOf(safedb.ErrEnforceUpdateByID{})) | |
| }) | |
| }) | |
| Context("calling Save()", func() { | |
| It("permits the query", func() { | |
| err := db.Model(msg).Save(&msg).Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| }) | |
| }) | |
| Context("with bulk updates with SkipEnforceUpdateByID", func() { | |
| It("permits the query", func() { | |
| err := db.Model(Message{}). | |
| WithContext(safedb.SkipEnforceUpdateByID(context.Background())). | |
| Where("organisation_id = ?", "org-id"). | |
| Updates(Message{ | |
| Message: "clobber me", | |
| }). | |
| Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| }) | |
| }) | |
| Context("with ID in text clause", func() { | |
| It("permits the query", func() { | |
| err := db.Model(Message{}). | |
| Where("id = ?", "message-id"). | |
| Updates(Message{ | |
| Message: "give me treats", | |
| }). | |
| Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| }) | |
| }) | |
| Context("with ID in struct clause", func() { | |
| It("permits the query", func() { | |
| err := db.Model(Message{}). | |
| Where(Message{ID: "message-id"}). | |
| Updates(Message{ | |
| Message: "give me treats", | |
| }). | |
| Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| }) | |
| }) | |
| }) | |
| Describe("EnforceScopeByOrganisationID", func() { | |
| BeforeEach(func() { | |
| safedb.EnforceScopeByOrganisationID(db) | |
| }) | |
| Context("with scoped select", func() { | |
| It("permits query", func() { | |
| var count int64 | |
| err := db.Model(Message{}). | |
| Where("organisation_id = ?", "org-id"). | |
| Count(&count). | |
| Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| }) | |
| }) | |
| Context("with preloads", func() { | |
| It("permits the preload query", func() { | |
| var messages []*Message | |
| err := db.Model(Message{}). | |
| Where("organisation_id = ?", "org-id"). | |
| Preload("Votes", "organisation_id = ?", msg.OrganisationID). | |
| Find(&messages). | |
| Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| Expect(messages).NotTo(BeEmpty()) | |
| Expect(messages[0].Votes).NotTo(BeEmpty()) | |
| }) | |
| }) | |
| Context("with scoped select using struct", func() { | |
| It("permits query", func() { | |
| var count int64 | |
| err := db.Model(Message{}). | |
| Where(Message{OrganisationID: "org-id"}). | |
| Count(&count). | |
| Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| }) | |
| }) | |
| Context("with select against organisations with ID", func() { | |
| It("permits query", func() { | |
| var res Organisation | |
| err := db.Model(Organisation{}). | |
| Where(Organisation{ID: org.ID}). | |
| First(&res). | |
| Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| }) | |
| }) | |
| // Create and Save are effectively the same, and go via the same hooks. | |
| Context("with scoped create/save", func() { | |
| It("permits query", func() { | |
| msg := Message{Message: "boop", OrganisationID: "org-id"} | |
| err := db.Create(&msg).Error | |
| Expect(err).NotTo(HaveOccurred()) | |
| }) | |
| }) | |
| Context("with unscoped select", func() { | |
| query := func(ctx context.Context) error { | |
| var count int64 | |
| return db.Model(Message{}). | |
| WithContext(ctx). | |
| Where("id = ?", "message-id"). | |
| Count(&count). | |
| Error | |
| } | |
| It("fails query", func() { | |
| ctx := context.Background() | |
| Expect(query(ctx)). | |
| To(BeAssignableToTypeOf(safedb.ErrEnforceScopeByOrganisationID{})) | |
| }) | |
| Context("with SkipEnforceScopeByOrganisationID", func() { | |
| It("permits query", func() { | |
| ctx := safedb.SkipEnforceScopeByOrganisationID(context.Background()) | |
| Expect(query(ctx)). | |
| To(Succeed()) | |
| }) | |
| }) | |
| }) | |
| }) | |
| }) |