Skip to content

Instantly share code, notes, and snippets.

@paprikati
Last active August 4, 2024 09:25
Show Gist options
  • Save paprikati/746a3dfe0d99e5ece012642fa1d7a354 to your computer and use it in GitHub Desktop.
Save paprikati/746a3dfe0d99e5ece012642fa1d7a354 to your computer and use it in GitHub Desktop.
safedb: using a gorm hook to check that queries are scoped by organisation ID

SafeDB

This gist is associated with the blog post "Building safe-by-default tools in our Go web application".

It contains the gorm hooks that we use to ensure our queries are correctly scoped.

package safedb
import (
"context"
"fmt"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
)
// Apply installs every safedb guard on the given gorm handle: first the update-by-ID
// guard, then the organisation-scoping guard.
func Apply(db *gorm.DB) {
	guards := []func(*gorm.DB){
		EnforceUpdateByID,
		EnforceScopeByOrganisationID,
	}
	for _, install := range guards {
		install(db)
	}
}
// ctxKey is the dedicated type for the context keys defined by this package.
type ctxKey string
// Context keys read by the guards. A guard returns early when its key holds the bool
// value true on the statement's context.
const (
// skipUpdateWithoutID is set by SkipEnforceUpdateByID and read by EnforceUpdateByID.
skipUpdateWithoutID ctxKey = "safedb.SkipEnforceUpdateByID"
// skipScopeByOrganisationID is set by SkipEnforceScopeByOrganisationID and read by
// EnforceScopeByOrganisationID.
skipScopeByOrganisationID ctxKey = "safedb.SkipEnforceScopeByOrganisationID"
)
// ErrEnforceUpdateByID is the error the EnforceUpdateByID guard adds to a statement it
// rejects. It carries the details of the offending query.
type ErrEnforceUpdateByID struct {
	info QueryInfo
}

// Error implements the error interface, appending the SQL of the rejected query to a
// fixed explanation.
func (e ErrEnforceUpdateByID) Error() string {
	const prefix = "UPDATE is missing WHERE id = ? clause: "
	return prefix + e.info.Query
}
// SkipEnforceUpdateByID returns a child of ctx that tells the EnforceUpdateByID guard to
// stand down. Use it when updating in bulk is a deliberate choice.
func SkipEnforceUpdateByID(ctx context.Context) context.Context {
	skipping := context.WithValue(ctx, skipUpdateWithoutID, true)
	return skipping
}
// EnforceUpdateByID registers a callback, run before gorm's own update callback, that
// fails any update whose where clause does not specify an ID.
//
// Bulk updates are highly unusual in the app: you almost always want to update
// row-by-row. This guard catches accidental omissions that could otherwise update every
// row in a table.
//
// Statements whose context was produced by SkipEnforceUpdateByID, and statements that
// already carry an error, are left untouched.
func EnforceUpdateByID(db *gorm.DB) {
	banBulkUpdates := func(tx *gorm.DB) {
		skip, _ := tx.Statement.Context.Value(skipUpdateWithoutID).(bool)
		if skip || tx.Error != nil {
			return
		}

		_, found := hasColumnWhereClause(tx, "id", "update")
		if found {
			return
		}
		tx.AddError(ErrEnforceUpdateByID{buildInfo(tx, "update")})
	}

	db.Callback().Update().Before("gorm:update").Register("safedb:ban_bulk_updates", banBulkUpdates)
}
// ErrEnforceScopeByOrganisationID is the error the EnforceScopeByOrganisationID guard
// adds to a statement it rejects. It carries the details of the offending query.
type ErrEnforceScopeByOrganisationID struct {
	info QueryInfo
}

// Error implements the error interface, appending the SQL of the rejected query to a
// fixed explanation.
func (e ErrEnforceScopeByOrganisationID) Error() string {
	const prefix = "query is missing WHERE organisation_id = ? clause: "
	return prefix + e.info.Query
}
// SkipEnforceScopeByOrganisationID returns a child of ctx that tells the
// EnforceScopeByOrganisationID guard to stand down. Use it when global access is
// explicitly requested.
func SkipEnforceScopeByOrganisationID(ctx context.Context) context.Context {
	skipping := context.WithValue(ctx, skipScopeByOrganisationID, true)
	return skipping
}
// EnforceScopeByOrganisationID registers a callback, run before gorm's own query
// callback, that fails any query whose where clause lacks an organisation ID. This
// protects against accidental exposure, or mis-scoping.
//
// Queries against the "organisations" table are scoped by their own id column; every
// other table must be filtered on organisation_id. Statements whose context was produced
// by SkipEnforceScopeByOrganisationID, and statements that already carry an error, are
// left untouched.
//
// NOTE(review): the checker below understands both "query" and "update" operations, but
// only the query callback is registered here.
func EnforceScopeByOrganisationID(db *gorm.DB) {
	const hookName = "safedb:enforce_scope_by_organisation"

	check := func(operation string, tx *gorm.DB) {
		skip, _ := tx.Statement.Context.Value(skipScopeByOrganisationID).(bool)
		if skip || tx.Error != nil {
			return
		}

		// Only these operations have where conditions we can search for organisation_id.
		if operation != "update" && operation != "query" {
			return
		}

		scopeColumn := "organisation_id"
		if tx.Statement.Table == "organisations" {
			scopeColumn = "id"
		}

		orgID, scoped := hasColumnWhereClause(tx, scopeColumn, operation)
		if !scoped {
			tx.AddError(ErrEnforceScopeByOrganisationID{buildInfo(tx, operation)})
			return
		}

		// Preloads are executed as yet another gorm query, which is subject to this same
		// guard and would be blocked for lacking an organisation ID. To avoid that, any
		// preload without conditions of its own is given an organisation_id condition
		// automatically.
		//
		// Skipping the guard only for preload queries would be preferable, but that
		// probably needs upstream changes to gorm, which we want to avoid for now. If
		// this goes wrong it is likely to show up as slow queries against tables that
		// aren't favourable for foreign-key + organisation_id where clauses.
		//
		// For clarity, a gorm query of:
		//
		//   db.Model(Message{}).
		//     Where("organisation_id = ?", "org-id").
		//     Preload("Votes").
		//     Find(&messages)
		//
		// should execute the following:
		//
		//   SELECT * FROM "safedb"."messages" WHERE organisation_id = 'org-id'
		//   SELECT * FROM "safedb"."votes" WHERE "votes"."message_id" = '01FF0C4Z8ZN6RH7RQDXAX9GYVR' AND (organisation_id = 'org-id')
		//
		preloads := tx.Statement.Preloads
		for name, conds := range preloads {
			if name == "Organisation" || conds != nil {
				continue
			}
			preloads[name] = []interface{}{"organisation_id = ?", orgID}
		}
	}

	db.Callback().Query().Before("gorm:query").Register(hookName, func(tx *gorm.DB) {
		check("query", tx)
	})
}
// hasColumnWhereClause reports whether the statement's WHERE clause contains a top-level
// equality condition on columnName, returning the value the column is compared against.
//
// The statement is first built via buildQuery, as the query clauses may not have been
// populated before then; operation selects which builder runs ("query" or "update").
//
// Two condition shapes are recognised:
//   - magic-string conditions written exactly as "<column> = ?" with a single bind
//     variable, e.g. Where("organisation_id = ?", id)
//   - type-safe clause.Eq conditions whose column is a clause.Column or a plain string
//
// An empty string never satisfies the check, and neither does a nil bind variable in a
// magic-string condition. When no matching condition is found the result is (nil, false).
func hasColumnWhereClause(db *gorm.DB, columnName, operation string) (interface{}, bool) {
	// Without this, we may not have populated our query clauses.
	stmt := buildQuery(db, operation).Statement

	whereClause, ok := stmt.Clauses["WHERE"]
	if !ok {
		return nil, false
	}
	where, ok := whereClause.Expression.(clause.Where)
	if !ok {
		return nil, false
	}

	// The SQL we expect for a magic-string condition is the same for every expression.
	wantSQL := fmt.Sprintf("%s = ?", columnName)

	for _, expr := range where.Exprs {
		switch expr := expr.(type) {
		// For magic-string where clauses
		case clause.Expr:
			if expr.SQL != wantSQL || len(expr.Vars) != 1 {
				continue
			}
			// The bind variable is not necessarily a plain string (integer IDs, UUID
			// types, named string types), so inspect it without a type assertion: an
			// unchecked assertion here would panic inside the gorm callback.
			if value := expr.Vars[0]; value != nil && value != "" {
				return value, true
			}

		// For type-safe where clauses
		case clause.Eq:
			var name string
			switch column := expr.Column.(type) {
			case clause.Column:
				name = column.Name
			case string:
				name = column
			default:
				continue
			}
			if name == columnName && expr.Value != "" {
				return expr.Value, true
			}
		}
	}

	return nil, false
}
// QueryInfo describes a query rejected by one of the guards. It is produced by buildInfo
// and embedded in the guard error types.
type QueryInfo struct {
// Query is the SQL string of the built statement.
Query string
// Statement is a partial copy of the gorm statement: buildInfo fills in the table,
// clauses, selects, SQL, preloads and joins, and leaves every other field zero.
Statement gorm.Statement
}
// buildQuery opens a dry-run session on db and runs the gorm callback matching operation
// ("query" or "update") against it, so that the query clauses/string get populated. The
// dry-run session is returned; for any other operation it is returned without running a
// callback.
func buildQuery(db *gorm.DB, operation string) *gorm.DB {
	dry := db.Session(&gorm.Session{DryRun: true})

	if operation == "query" {
		callbacks.BuildQuerySQL(dry)
	} else if operation == "update" {
		callbacks.Update(dry)
	}

	return dry
}
// buildInfo produces a spew-able description of the query behind db, for use when one of
// our guards bans it. gorm.DB is cyclic, which makes it awkward to debug directly, and
// recovering the actual SQL from the handle is not straightforward either, so the
// statement is built via buildQuery and only a fixed set of its fields is copied out.
func buildInfo(db *gorm.DB, operation string) QueryInfo {
	built := buildQuery(db, operation).Statement

	var info QueryInfo
	info.Query = built.SQL.String()

	snapshot := &info.Statement
	snapshot.Table = built.Table
	snapshot.TableExpr = built.TableExpr
	snapshot.Clauses = built.Clauses
	snapshot.BuildClauses = built.BuildClauses
	snapshot.Selects = built.Selects
	snapshot.SQL = built.SQL
	snapshot.Preloads = built.Preloads
	snapshot.Joins = built.Joins

	return info
}
package safedb_test
import (
"context"
"github.com/incident-io/core/server/safedb"
"gorm.io/gorm"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// Message is the gorm model for rows of the safedb.messages table created by the specs.
type Message struct {
ID string `json:"id" gorm:"type:text;primaryKey;default:generate_ulid()"`
OrganisationID string `json:"organisation_id"`
Message string `json:"message"`
// Votes is the association loaded by the Preload("Votes") spec.
Votes []Vote
}
// TableName returns the schema-qualified table name for Message.
func (Message) TableName() string {
return "safedb.messages"
}
// Vote is the gorm model for rows of the safedb.votes table created by the specs. Each
// vote references a message through MessageID.
type Vote struct {
ID string `json:"id" gorm:"type:text;primaryKey;default:generate_ulid()"`
OrganisationID string `json:"organisation_id"`
MessageID string `json:"message_id"`
}
// TableName returns the schema-qualified table name for Vote.
func (Vote) TableName() string {
return "safedb.votes"
}
// Organisation is the gorm model for rows of the safedb.organisations table created by
// the specs. It has no organisation_id column; its own ID is the scoping column.
type Organisation struct {
ID string `json:"id" gorm:"type:text;primaryKey;default:generate_ulid()"`
Name string `json:"name"`
}
// TableName returns the schema-qualified table name for Organisation.
func (Organisation) TableName() string {
return "safedb.organisations"
}
var _ = Describe("safedb", func() {
// State shared by every spec; all four values are (re)assigned by the outer BeforeEach.
var (
// db is the handle the guards are registered on and the specs query through.
db *gorm.DB
// msg, vote and org are the rows seeded before each spec.
msg Message
vote Vote
org Organisation
)
// Rebuild the safedb schema from scratch before every spec, then seed one message, one
// vote on that message, and one organisation for the specs to work with.
BeforeEach(func() {
	var err error

	db = pg.GetRaw()

	// "if exists" keeps the very first run (when there is no schema yet) working, while
	// letting us check the error instead of discarding it: a failed drop would otherwise
	// only surface as a confusing failure of the create statements below.
	err = db.Exec(`drop schema if exists safedb cascade;`).Error
	Expect(err).NotTo(HaveOccurred())

	err = db.Exec(`create schema safedb;`).Error
	Expect(err).NotTo(HaveOccurred())

	err = db.Exec(`
create table safedb.messages (
id text primary key default generate_ulid(),
organisation_id text not null,
message text not null
)`).Error
	Expect(err).NotTo(HaveOccurred())

	err = db.Exec(`
create table safedb.votes (
id text primary key default generate_ulid(),
organisation_id text not null,
message_id text not null references safedb.messages(id)
)`).Error
	Expect(err).NotTo(HaveOccurred())

	err = db.Exec(`
create table safedb.organisations (
id text primary key default generate_ulid(),
name text not null
)`).Error
	Expect(err).NotTo(HaveOccurred())

	// Seed data: the vote shares the message's organisation so scoped preloads find it.
	msg = Message{Message: "hello world", OrganisationID: "org-id"}
	err = db.Create(&msg).Error
	Expect(err).NotTo(HaveOccurred())

	vote = Vote{MessageID: msg.ID, OrganisationID: msg.OrganisationID}
	err = db.Create(&vote).Error
	Expect(err).NotTo(HaveOccurred())

	org = Organisation{Name: "Skynet"}
	err = db.Create(&org).Error
	Expect(err).NotTo(HaveOccurred())
})
// Specs for the update guard: updates must name an id in their where clause unless the
// context opts out via SkipEnforceUpdateByID.
Describe("EnforceUpdateByID", func() {
// Register the guard on the shared handle before every spec in this group.
BeforeEach(func() {
safedb.EnforceUpdateByID(db)
})
// Filtering on organisation_id alone is still a bulk update, so the guard must reject it.
Context("with bulk updates", func() {
It("fails the query", func() {
err := db.Model(Message{}).
Where("organisation_id = ?", "org-id").
Updates(Message{
Message: "clobber me",
}).
Error
Expect(err).To(BeAssignableToTypeOf(safedb.ErrEnforceUpdateByID{}))
})
})
// Save on the seeded message is expected to pass the guard.
Context("calling Save()", func() {
It("permits the query", func() {
err := db.Model(msg).Save(&msg).Error
Expect(err).NotTo(HaveOccurred())
})
})
// The same bulk update as above is allowed once the context carries the skip marker.
Context("with bulk updates with SkipEnforceUpdateByID", func() {
It("permits the query", func() {
err := db.Model(Message{}).
WithContext(safedb.SkipEnforceUpdateByID(context.Background())).
Where("organisation_id = ?", "org-id").
Updates(Message{
Message: "clobber me",
}).
Error
Expect(err).NotTo(HaveOccurred())
})
})
// ID supplied as a magic-string condition.
Context("with ID in text clause", func() {
It("permits the query", func() {
err := db.Model(Message{}).
Where("id = ?", "message-id").
Updates(Message{
Message: "give me treats",
}).
Error
Expect(err).NotTo(HaveOccurred())
})
})
// ID supplied through a struct condition.
Context("with ID in struct clause", func() {
It("permits the query", func() {
err := db.Model(Message{}).
Where(Message{ID: "message-id"}).
Updates(Message{
Message: "give me treats",
}).
Error
Expect(err).NotTo(HaveOccurred())
})
})
})
// Specs for the organisation-scoping guard: queries must filter on organisation_id (or on
// id for the organisations table) unless the context opts out via
// SkipEnforceScopeByOrganisationID.
Describe("EnforceScopeByOrganisationID", func() {
// Register the guard on the shared handle before every spec in this group.
BeforeEach(func() {
safedb.EnforceScopeByOrganisationID(db)
})
// organisation_id supplied as a magic-string condition.
Context("with scoped select", func() {
It("permits query", func() {
var count int64
err := db.Model(Message{}).
Where("organisation_id = ?", "org-id").
Count(&count).
Error
Expect(err).NotTo(HaveOccurred())
})
})
// NOTE(review): this spec passes an explicit organisation_id condition to Preload, so it
// does not exercise the guard's automatic scoping of condition-less preloads.
Context("with preloads", func() {
It("permits the preload query", func() {
var messages []*Message
err := db.Model(Message{}).
Where("organisation_id = ?", "org-id").
Preload("Votes", "organisation_id = ?", msg.OrganisationID).
Find(&messages).
Error
Expect(err).NotTo(HaveOccurred())
Expect(messages).NotTo(BeEmpty())
Expect(messages[0].Votes).NotTo(BeEmpty())
})
})
// organisation_id supplied through a struct condition.
Context("with scoped select using struct", func() {
It("permits query", func() {
var count int64
err := db.Model(Message{}).
Where(Message{OrganisationID: "org-id"}).
Count(&count).
Error
Expect(err).NotTo(HaveOccurred())
})
})
// The organisations table has no organisation_id column; scoping by its id is accepted.
Context("with select against organisations with ID", func() {
It("permits query", func() {
var res Organisation
err := db.Model(Organisation{}).
Where(Organisation{ID: org.ID}).
First(&res).
Error
Expect(err).NotTo(HaveOccurred())
})
})
// Create and Save are effectively the same, and go via the same hooks.
// Only Create is exercised here.
Context("with scoped create/save", func() {
It("permits query", func() {
msg := Message{Message: "boop", OrganisationID: "org-id"}
err := db.Create(&msg).Error
Expect(err).NotTo(HaveOccurred())
})
})
Context("with unscoped select", func() {
// query counts messages filtered by id only, i.e. without any organisation scope,
// using the supplied context so specs can toggle the skip marker.
query := func(ctx context.Context) error {
var count int64
return db.Model(Message{}).
WithContext(ctx).
Where("id = ?", "message-id").
Count(&count).
Error
}
It("fails query", func() {
ctx := context.Background()
Expect(query(ctx)).
To(BeAssignableToTypeOf(safedb.ErrEnforceScopeByOrganisationID{}))
})
// The same unscoped query is allowed once the context carries the skip marker.
Context("with SkipEnforceScopeByOrganisationID", func() {
It("permits query", func() {
ctx := safedb.SkipEnforceScopeByOrganisationID(context.Background())
Expect(query(ctx)).
To(Succeed())
})
})
})
})
})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment