Last active
April 14, 2023 03:20
-
-
Save compnski/a89a5e53eb308671bd6e to your computer and use it in GitHub Desktop.
Builds a SQL statement that hashes a table in PostgreSQL, MySQL, or Redshift, producing the same result on all three.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"flag" | |
"fmt" | |
"strings" | |
) | |
// Database dialects accepted by the -dbType flag. The ALL_CAPS names are
// kept because the rest of the program refers to them.
const (
	DB_REDSHIFT = "redshift"
	DB_POSTGRES = "postgres"
	DB_MYSQL    = "mysql"
)
type Table struct { | |
Name string | |
Schema string | |
Columns []Column | |
} | |
// Column is one hashed column: its name and database type name (which selects
// the type-specific normalization in the colAsString_* helpers). NotNull, when
// set, skips the coalesce(..., ' ') NULL guard. Encoding is carried along but
// is not read when building SQL in this file.
type Column struct {
	Name, Type string
	NotNull    bool
	Encoding   string
}
// Program settings. All but tablePath are bound to command-line flags in
// init; tablePath is derived by main from schemaName and tableName before
// any SQL is generated.
var (
	tableName, schemaName, idCol string
	lowerBound, upperBound       string
	colList, dbType, tablePath   string
	debug                        bool
)
func init() { | |
flag.StringVar(&tableName, "table", "", "Table name.") | |
flag.StringVar(&schemaName, "schema", "", "Schema name.") | |
flag.StringVar(&colList, "cols", "", "Comma separated list of column names/type pairs, E.g. (name 1, type1, name2, type2, ...).") | |
flag.StringVar(&idCol, "id", "id", "Id column") | |
flag.StringVar(&lowerBound, "lower", "0", "Lower value of id column, exclusive, to hash.") | |
flag.StringVar(&upperBound, "upper", "1000000", "Upper value of id column, inclusive, to hash.") | |
flag.StringVar(&dbType, "dbType", DB_POSTGRES, "Type of database, only postgres / redshift are supported") | |
flag.BoolVar(&debug, "debug", false, "debug") | |
} | |
func main() { | |
flag.Parse() | |
if colList == "" { | |
colList = strings.Join(flag.Args(), " ") | |
} | |
if tableName == "" || colList == "" { | |
flag.Usage() | |
return | |
} | |
table := makeTable(tableName, schemaName, strings.Split(colList, ",")...) | |
fmt.Println("") | |
if schemaName != "" { | |
tablePath = fmt.Sprintf(`%s.%s`, schemaName, tableName) | |
} else { | |
tablePath = tableName | |
} | |
fmt.Println(hashAllCols(table, idCol, lowerBound, upperBound)) | |
} | |
func makeTable(name, schema string, colPairs ...string) *Table { | |
cols := []Column{} | |
for i := 0; i < len(colPairs); i += 2 { | |
cols = append(cols, Column{Name: strings.TrimSpace(colPairs[i]), Type: strings.TrimSpace(colPairs[i+1])}) | |
} | |
return &Table{Name: name, Schema: schema, Columns: cols} | |
} | |
func hashAllCols(table *Table, idCol, lowerBound, upperBound string) string { | |
var ( | |
colSqls = make([]string, len(table.Columns)) | |
rowSql string | |
) | |
switch dbType { | |
case DB_MYSQL: | |
for idx, col := range table.Columns { | |
colSqls[idx] = colAsString_mysql(&col) | |
} | |
rowSql = fmt.Sprintf("concat(%s)", strings.Join(colSqls, ", ")) | |
if debug { | |
rowSql = fmt.Sprintf("concat_ws(',',%s)", strings.Join(colSqls, ",")) | |
} | |
//colAsString = colAsString_mysql | |
default: | |
for idx, col := range table.Columns { | |
colSqls[idx] = colAsString_pg(&col) | |
} | |
rowSql = strings.Join(colSqls, " || ") | |
if debug { | |
rowSql = strings.Join(colSqls, " || ',' || ") | |
} | |
} | |
if debug { | |
return fmt.Sprintf(`select md5(%s), %s from %s where %s > '%s' and %s <= '%s' order by %s desc`, | |
strings.Replace(rowSql, "md5", "", -1), strings.Replace(rowSql, "md5", "", -1), tablePath, idCol, lowerBound, idCol, upperBound, idCol) | |
} | |
innerQuery := fmt.Sprintf(`select md5(%s) as hash from %s where %s > '%s' and %s <= '%s'`, | |
rowSql, tablePath, idCol, lowerBound, idCol, upperBound) | |
outerQuery := fmt.Sprintf("select %s from (%s) a;", getSumOfHash("hash"), innerQuery) | |
return outerQuery | |
} | |
func colAsString_pg(col *Column) string { | |
colSql := fmt.Sprintf(`"%s"`, col.Name) | |
if col.Type == "date" { | |
colSql = fmt.Sprintf("(%s - '0001-01-01'::date)", colSql) | |
} | |
if strings.Contains(col.Type, "timestamp") { | |
colSql = fmt.Sprintf("floor(extract(epoch from %s))", colSql) | |
} | |
if col.Type == "boolean" { | |
colSql = colSql + "::integer" | |
} | |
if strings.Contains(col.Type, "varchar") { | |
colSql = fmt.Sprintf("md5(%s)", colSql) | |
} else { | |
colSql = fmt.Sprintf("md5(%s::text)", colSql) | |
} | |
if !col.NotNull { | |
colSql = fmt.Sprintf("coalesce(%s, ' ')", colSql) | |
} | |
return colSql | |
} | |
func colAsString_mysql(col *Column) string { | |
colSql := fmt.Sprintf("%s", col.Name) | |
if col.Type == "date" { | |
colSql = fmt.Sprintf("(to_days(%s) - 366)", colSql) //366 to represent diff from Day 1, Year 1, not Year 0 which never existed. | |
} | |
if strings.Contains(col.Type, "timestamp") || col.Type == "datetime" { | |
colSql = fmt.Sprintf("floor(unix_timestamp(%s - interval 7 hour))", colSql) | |
} | |
colSql = fmt.Sprintf("md5(%s)", colSql) | |
if !col.NotNull { | |
colSql = fmt.Sprintf("coalesce(%s, ' ')", colSql) | |
} | |
return colSql | |
} | |
func getSumOfHash(col string) string { | |
var queryPart string | |
switch dbType { | |
case DB_POSTGRES: | |
queryPart = `sum(('x'||substring(%s,%d,8))::bit(32)::bigint)` | |
case DB_REDSHIFT: | |
queryPart = `sum(trunc(strtol(substring(%s,%d,8),16)))` | |
case DB_MYSQL: | |
queryPart = `sum(cast(conv(substring(%s,%d,8), 16, 10) as unsigned))` | |
} | |
queryParts := []string{fmt.Sprintf(queryPart, col, 1), fmt.Sprintf(queryPart, col, 9), fmt.Sprintf(queryPart, col, 17), fmt.Sprintf(queryPart, col, 25)} | |
return strings.Join(queryParts, ", ") | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
A collision is certainly possible but exceedingly unlikely; I'm not sure it would ever come up.
This was my attempt at rsync-like syncing. We never ended up using this matching much. In the end, we mostly relied on a lastModifiedAt / lastUpdatedAt timestamp to know when to pull more rows.
There may also be other (faster) hash functions available by now — Redshift was quite lacking at the time but has made good progress.