Last active
January 18, 2024 07:00
-
-
Save ionling/10f50bf3d77040fa8bb4f6695c23befe to your computer and use it in GitHub Desktop.
Check PostgreSQL data differences
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import (
	"context"
	"database/sql"
	"flag"
	"fmt"
	"log/slog"
	"math/rand"
	"net/url"
	"os"
	"reflect"
	"strings"

	"github.com/uptrace/bun"
	"github.com/uptrace/bun/dialect/pgdialect"
	"github.com/uptrace/bun/driver/pgdriver"
	"github.com/uptrace/bun/extra/bundebug"
	"github.com/uptrace/bun/extra/bunotel"
	"golang.org/x/sync/errgroup"
)
// maxCheckRows caps how many rows CheckRows fetches from each side per table.
const maxCheckRows = 3000
var ( | |
srcDSN = os.Getenv("SRC_DSN") | |
dstDSN = os.Getenv("DST_DSN") | |
checkF = flag.Bool("check", false, | |
"Check the difference between the source and destination table") | |
syncSeqsF = flag.Bool("sync-seqs", false, | |
"Sync the last value of sequences from source to destination") | |
schemaF = flag.String("schema", "public", "PostgreSQL database schema") | |
tablesF = flag.String("tables", "", | |
"Table names separated by ',', default to all tables in database") | |
orderByF = flag.String("orderby", "", "Order by clause used in query") | |
) | |
// main parses flags, builds a debug-level text logger on stdout and runs do.
// A failure is reported on stderr and the process exits with status 1, so
// scripts and CI can detect it. Previously the error went to stdout and the
// exit status was 0.
func main() {
	flag.Parse()
	l := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
		Level: slog.LevelDebug,
	}))
	if err := do(context.Background(), l); err != nil {
		fmt.Fprintln(os.Stderr, "ERR:", err)
		// do has already returned, so its deferred cleanups have run.
		os.Exit(1)
	}
}
// do runs the single action selected on the command line against the source
// (SRC_DSN) and destination (DST_DSN) databases. -check takes precedence over
// -sync-seqs. When neither flag is set it returns "no action" before opening
// any connection. Both handles are closed before do returns.
func do(ctx context.Context, l *slog.Logger) error {
	// Validate the requested action first so a missing flag doesn't cost
	// two connection attempts (or surface a misleading connection error).
	if !*checkF && !*syncSeqsF {
		return fmt.Errorf("no action")
	}

	srcDB, srcCleanup, err := newPostgres(srcDSN, *schemaF)
	if err != nil {
		return fmt.Errorf("new src db: %w", err)
	}
	defer srcCleanup()

	dstDB, dstCleanup, err := newPostgres(dstDSN, *schemaF)
	if err != nil {
		return fmt.Errorf("new dst db: %w", err)
	}
	defer dstCleanup()

	lr := LogicalRep{
		SrcDB:  srcDB,
		DstDB:  dstDB,
		Schema: *schemaF,
		l:      l,
	}

	if *checkF {
		return lr.Check(ctx)
	}

	res, err := lr.SyncSequences(ctx)
	if err != nil {
		return err
	}
	// %+v prints field names (Total, ErrCount, ...) instead of bare values.
	fmt.Printf("%+v\n", res)
	return nil
}
// newPostgres opens a bun handle for the PostgreSQL URL dsn and pings it.
// The connection is tagged with application_name=logicalrep and, when schema
// is non-empty, search_path=schema. Query hooks for bundebug and bunotel are
// installed.
//
// On success the caller must call cleanup to close the handle. On error,
// db and cleanup are nil and any handle that was opened is already closed.
func newPostgres(dsn, schema string) (db *bun.DB, cleanup func(), err error) {
	// With an empty DSN the driver would silently fall back to its defaults
	// (e.g. a local server), which is never what an unset env var means.
	if dsn == "" {
		return nil, nil, fmt.Errorf("empty DSN")
	}

	// Append the extra parameters URL-encoded, starting the query string
	// with '?' when the DSN has none yet and continuing it with '&' otherwise.
	params := url.Values{}
	params.Set("application_name", "logicalrep")
	if schema != "" {
		params.Set("search_path", schema)
	}
	sep := "&"
	if !strings.Contains(dsn, "?") {
		sep = "?"
	}
	dsn += sep + params.Encode()

	// The config has defaults for timeouts,
	// so it's not necessary to specify them again:
	// - DialTimeout: 5 * time.Second
	// - ReadTimeout: 10 * time.Second
	// - WriteTimeout: 5 * time.Second
	connector := pgdriver.NewConnector(pgdriver.WithDSN(dsn))
	sqldb := sql.OpenDB(connector)
	db = bun.NewDB(sqldb, pgdialect.New(), bun.WithDiscardUnknownColumns())
	db.AddQueryHook(bundebug.NewQueryHook())
	db.AddQueryHook(bunotel.NewQueryHook(
		bunotel.WithDBName(connector.Config().Database),
	))

	if err := db.Ping(); err != nil {
		// Callers skip cleanup on error, so close the pool here to avoid a leak.
		_ = db.Close()
		return nil, nil, fmt.Errorf("ping: %w", err)
	}
	return db, func() { _ = db.Close() }, nil
}
// Check fills lr.Tables, then runs CheckCount followed by CheckRows and logs
// both results.
//
// Table names come from the comma-separated -tables flag. Blank entries are
// ignored, and a list with no names at all is an error. When the flag is
// empty, the names returned by ListTables are appended to lr.Tables instead.
func (lr *LogicalRep) Check(ctx context.Context) error {
	logger := lr.l.With("func", "Check")

	if *tablesF == "" {
		listed, err := lr.ListTables(ctx)
		if err != nil {
			return fmt.Errorf("list tables: %w", err)
		}
		for _, tbl := range listed {
			lr.Tables = append(lr.Tables, tbl.TableName)
		}
	} else {
		var names []string
		for _, raw := range strings.Split(*tablesF, ",") {
			if name := strings.TrimSpace(raw); name != "" {
				names = append(names, name)
			}
		}
		if len(names) == 0 {
			return fmt.Errorf("no table provided")
		}
		lr.Tables = names
	}

	countRes, err := lr.CheckCount(ctx)
	if err != nil {
		return fmt.Errorf("check count: %w", err)
	}
	rowsRes, err := lr.CheckRows(ctx, countRes)
	if err != nil {
		return fmt.Errorf("check rows: %w", err)
	}

	logger.InfoContext(ctx, "check count", "result", countRes)
	logger.InfoContext(ctx, "check rows", "result", rowsRes)
	return nil
}
type LogicalRep struct { | |
SrcDB *bun.DB | |
DstDB *bun.DB | |
Schema string | |
Tables []string | |
l *slog.Logger | |
} | |
// Table receives one row of information_schema.tables. Only the table_name
// column maps to a field (bun's snake_case naming of TableName).
type Table struct {
	TableName string // name of the table, without schema
}
// ListTables returns the base tables of lr.Schema in the source database,
// ordered by name.
//
// information_schema.tables also lists views and foreign tables. Logical
// replication does not carry those, and counting one that is absent on the
// destination would abort the whole check, so only table_type = 'BASE TABLE'
// rows are kept.
func (lr *LogicalRep) ListTables(ctx context.Context) (res []*Table, err error) {
	// REF https://stackoverflow.com/a/2276722/7134763
	err = lr.SrcDB.NewSelect().
		Table("information_schema.tables").
		Column("table_name").
		Where("table_schema = ?", lr.Schema).
		Where("table_type = ?", "BASE TABLE").
		Order("table_name"). // deterministic check/log order
		Scan(ctx, &res)
	return res, err
}
// TableCount holds one table's row count on each side.
type TableCount struct {
	Src int // rows in the source table
	Dst int // rows in the destination table
}
type CheckCountRes struct { | |
Count map[string]TableCount // table -> count | |
BadCount int | |
} | |
// CheckCount counts the rows of every table in lr.Tables on both databases.
// The two counts of each table are fetched concurrently. It records the
// counts, logs each comparison (warning on a mismatch) and returns the number
// of mismatching tables. The first query error aborts the run.
func (lr *LogicalRep) CheckCount(ctx context.Context) (res *CheckCountRes, err error) {
	logger := lr.l.With("func", "CheckCount")
	res = &CheckCountRes{Count: make(map[string]TableCount)}

	for _, table := range lr.Tables {
		var counts TableCount
		g, gctx := errgroup.WithContext(ctx)
		// Each goroutine writes a distinct field of counts.
		g.Go(func() (err error) {
			counts.Src, err = lr.SrcDB.NewSelect().Table(table).Count(gctx)
			return wrap(err, "count src")
		})
		g.Go(func() (err error) {
			counts.Dst, err = lr.DstDB.NewSelect().Table(table).Count(gctx)
			return wrap(err, "count dst")
		})
		if err := g.Wait(); err != nil {
			return nil, err
		}
		res.Count[table] = counts

		tableLog := logger.With("table", table).With("src", counts.Src, "dst", counts.Dst)
		if counts.Src != counts.Dst {
			res.BadCount++
			tableLog.WarnContext(gctx, "compare")
			continue
		}
		tableLog.InfoContext(gctx, "compare")
	}
	return res, nil
}
// CheckRowsRes is the outcome of CheckRows.
type CheckRowsRes struct {
	// BadCount is the number of tables whose sampled rows differ between the
	// source and the destination.
	BadCount int
}
// CheckRows compares a window of at most maxCheckRows rows of each table in
// lr.Tables between the source and destination databases. A table counts as
// bad when the two scanned row slices are not reflect.DeepEqual.
//
// If the larger of the two row counts recorded in countRes exceeds
// maxCheckRows, the window starts at a random offset. The same query, with the
// same offset, runs on both sides. Rows are ordered by the -orderby clause if
// one is given. Without it PostgreSQL guarantees no row order, so a mismatch
// may be reported even when the data are identical.
func (lr *LogicalRep) CheckRows(
	ctx context.Context, countRes *CheckCountRes,
) (res *CheckRowsRes, err error) {
	l := lr.l.With("func", "CheckRows")
	res = &CheckRowsRes{}

	var orderBy string
	if *orderByF != "" {
		orderBy = "ORDER BY " + *orderByF
	}
	// The table name is bound through bun.Ident, so it is quoted the same way
	// as Table(t) in CheckCount; unquoted mixed-case names would be folded to
	// lower case. The user-supplied ORDER BY clause goes through bun.Safe, so
	// a '?' inside it is not mistaken for a placeholder.
	const query = "SELECT * FROM ? ? LIMIT ? OFFSET ?"

	for _, t := range lr.Tables {
		l := l.With("table", t)
		count := countRes.Count[t]
		maxN := max(count.Src, count.Dst)

		offset := 0
		if maxN > maxCheckRows {
			// Intn's bound is exclusive; +1 makes the last full window
			// (offset maxN-maxCheckRows) reachable.
			offset = rand.Intn(maxN - maxCheckRows + 1)
		}
		// Random sampling (e.g. TABLESAMPLE SYSTEM(3000 / n)) is difficult
		// to achieve because we don't have fixed columns to order the sample.
		args := []any{bun.Ident(t), bun.Safe(orderBy), maxCheckRows, offset}

		var srcs, dsts []map[string]any
		eg, gctx := errgroup.WithContext(ctx)
		eg.Go(func() error {
			err := lr.SrcDB.NewRaw(query, args...).Scan(gctx, &srcs)
			return wrap(err, "query src")
		})
		eg.Go(func() error {
			err := lr.DstDB.NewRaw(query, args...).Scan(gctx, &dsts)
			return wrap(err, "query dst")
		})
		if err := eg.Wait(); err != nil {
			return nil, err
		}

		l = l.With("len(srcs)", len(srcs), "len(dsts)", len(dsts))
		if ok := reflect.DeepEqual(srcs, dsts); ok {
			l.InfoContext(ctx, "compare", "ok", ok)
		} else {
			res.BadCount++
			l.WarnContext(ctx, "compare", "ok", ok)
		}
	}
	return res, nil
}
// SyncSequencesRes is the outcome of SyncSequences.
type SyncSequencesRes struct {
	Total    int // sequences found on the source
	ErrCount int // sequences whose destination update failed
	OKCount  int // sequences updated successfully

	// Errs maps a sequence name to the error its update returned.
	Errs map[string]error
}
// SyncSequences reads every sequence of lr.Schema from the source (all
// schemas when lr.Schema is empty) and restarts the same-named sequence on the
// destination at 110% of the source's last value. When that increases nothing,
// it uses last value + 10 instead. The headroom presumably covers values the
// source hands out after this snapshot — TODO confirm intent.
//
// A failure on one sequence is logged and recorded in res.Errs; the loop
// continues. Only a failure to list the sequences is returned as an error.
func (lr *LogicalRep) SyncSequences(ctx context.Context) (res *SyncSequencesRes, err error) {
	seqs, err := lr.listSequences(ctx)
	if err != nil {
		return nil, fmt.Errorf("list seqs: %w", err)
	}
	l := lr.l.With("func", "SyncSequences")
	res = &SyncSequencesRes{
		Total: len(seqs),
		// Must be allocated: assigning into a nil map panics on the first
		// failed sequence.
		Errs: make(map[string]error),
	}
	for _, seq := range seqs {
		// Equal to LastValue*110/100 (Go's truncating division), but the
		// multiplication can no longer overflow for large bigint values.
		lv := seq.LastValue + seq.LastValue/10
		if lv == seq.LastValue {
			lv += 10
		}
		l := l.With("seq_name", seq.Name, "src_last_value", seq.LastValue, "dst_last_value", lv)
		if err := lr.setSeqLastValue(ctx, seq.Name, lv); err != nil {
			res.ErrCount++
			res.Errs[seq.Name] = err
			l.ErrorContext(ctx, "set seq last value", "err", err)
			continue
		}
		res.OKCount++
		l.InfoContext(ctx, "ok")
	}
	return res, nil
}
type Sequence struct { | |
bun.BaseModel `bun:"table:pg_sequences,alias:s"` | |
Schema string `bun:"schemaname"` | |
Name string `bun:"sequencename"` | |
LastValue int | |
} | |
// listSequences loads the pg_sequences rows of the source database, limited
// to lr.Schema unless it is empty.
func (lr *LogicalRep) listSequences(ctx context.Context) (res []*Sequence, err error) {
	query := lr.SrcDB.NewSelect().Model(&res)
	if lr.Schema == "" {
		err = query.Scan(ctx)
		return res, err
	}
	err = query.Where("schemaname = ?", lr.Schema).Scan(ctx)
	return res, err
}
// setSeqLastValue runs ALTER SEQUENCE ... RESTART lastValue on the destination
// database, so the next nextval() returns lastValue. The sequence is qualified
// with lr.Schema when that is non-empty.
//
// Both identifiers go through bun.Ident, so they are quoted. Plain string
// concatenation broke on mixed-case names and produced invalid SQL
// (".name") for an empty schema.
func (lr *LogicalRep) setSeqLastValue(ctx context.Context, name string, lastValue int) error {
	seq := bun.Ident(name)
	if lr.Schema == "" {
		_, err := lr.DstDB.NewRaw("ALTER SEQUENCE ? RESTART ?", seq, lastValue).Exec(ctx)
		return err
	}
	_, err := lr.DstDB.NewRaw("ALTER SEQUENCE ?.? RESTART ?",
		bun.Ident(lr.Schema), seq, lastValue).Exec(ctx)
	return err
}
// max returns the larger of x and y. Within this package it shadows the
// built-in max.
func max(x, y int) int {
	if y > x {
		return y
	}
	return x
}
// wrap prefixes err with msg ("msg: err") and keeps err as the wrapped cause.
// A nil err stays nil, so callers can wrap a result unconditionally.
func wrap(err error, msg string) error {
	if err != nil {
		return fmt.Errorf("%s: %w", msg, err)
	}
	return nil
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment