Last active
April 3, 2017 23:19
-
-
Save kmtr/4143761b73ecd7f470f3e8a7e6360df7 to your computer and use it in GitHub Desktop.
PostgreSQL struct generator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"database/sql" | |
"fmt" | |
"io" | |
"log" | |
"os" | |
"strings" | |
"text/template" | |
_ "github.com/lib/pq" | |
) | |
func main() { | |
datasourceName := "" | |
dbName := "" | |
db, err := sql.Open("postgres", datasourceName) | |
if err != nil { | |
fmt.Fprint(os.Stdout, err) | |
os.Exit(1) | |
} | |
columns, err := fetchColumnsSchema(db, dbName, "public") | |
if err != nil { | |
fmt.Fprintln(os.Stdout, err) | |
os.Exit(1) | |
} | |
tables := NewTables(columns) | |
fmt.Fprintln(os.Stdout, "package main") | |
for _, table := range tables { | |
if err := table.WriteTemplate(os.Stdout); err != nil { | |
fmt.Fprintln(os.Stdout, err) | |
os.Exit(1) | |
} | |
} | |
db.Close() | |
} | |
// YesOrNo is a boolean stored as the strings "YES" and "NO". It is used for
// the information_schema.columns is_nullable field.
type YesOrNo bool

// Scan implements sql.Scanner. It accepts the value as a string or as a
// []byte, since drivers may deliver text columns in either form. Any other
// type, including NULL, is rejected, as is any text other than "YES" or "NO".
func (yn *YesOrNo) Scan(src interface{}) error {
	var s string
	switch t := src.(type) {
	case string:
		s = t
	case []byte:
		s = string(t)
	default:
		// %T reports the dynamic type, which is what the message claims to show.
		return fmt.Errorf("Unexpected type %T", t)
	}
	switch s {
	case "YES":
		*yn = true
	case "NO":
		*yn = false
	default:
		return fmt.Errorf("Unexpected string %s", s)
	}
	return nil
}
type Table struct { | |
Name string | |
Columns []InformationSchemaColumns | |
} | |
func (table *Table) WriteTemplate(w io.Writer) error { | |
tmpl, err := template.New("dbstruct").Parse(`type {{.Name}} struct { | |
{{range .Columns}}{{.Generate}} | |
{{end}} | |
} | |
`) | |
if err != nil { | |
return err | |
} | |
return tmpl.Execute(w, table) | |
} | |
func NewTables(columns []InformationSchemaColumns) []Table { | |
tmap := map[string]Table{} | |
for _, column := range columns { | |
table, ok := tmap[column.TableName] | |
if !ok { | |
table = Table{} | |
table.Name = toCamel(column.TableName) | |
table.Columns = []InformationSchemaColumns{} | |
} | |
table.Columns = append(table.Columns, column) | |
tmap[column.TableName] = table | |
} | |
tables := []Table{} | |
for _, t := range tmap { | |
tables = append(tables, t) | |
} | |
return tables | |
} | |
type InformationSchemaColumns struct { | |
TableName string | |
ColumnName string | |
OrdinalPosition int | |
IsNullable YesOrNo | |
DataType string | |
NumericPrecision sql.NullInt64 | |
NumericPrecisionRadix sql.NullInt64 | |
NumericScale sql.NullInt64 | |
DateTimePrecision sql.NullInt64 | |
} | |
func (isc InformationSchemaColumns) Generate() string { | |
name := isc.ColumnName | |
fieldName := toCamel(name) | |
typeName := detectTypeName(isc.DataType, isc.NumericPrecision.Int64, bool(isc.IsNullable)) | |
return fmt.Sprintf("%s %s `json:\"%s\" db:\"%s\"`", fieldName, typeName, name, name) | |
} | |
// detectTypeName maps a PostgreSQL information_schema data_type to the Go
// type used for a generated struct field.
//
// numPrec is the column's numeric_precision. It is kept for compatibility
// but no longer affects the result. PostgreSQL fixes the precision of every
// type handled here: 16/32/64 bits for smallint/integer/bigint and 24/53
// bits for real/double precision. The old "== 64" test therefore never
// matched double precision (53), which was always mapped to float32.
//
// isNil marks a nullable column. Such columns get a pointer type so NULL can
// be represented. The exception is []byte, whose nil value already does
// that.
//
// An unrecognised data type is logged and mapped to the placeholder
// "unknown".
func detectTypeName(dataType string, numPrec int64, isNil bool) string {
	var typeName string
	switch dataType {
	case "character varying", "character", "text":
		typeName = "string"
	case "smallint":
		typeName = "int16"
	case "integer":
		typeName = "int"
	case "bigint":
		typeName = "int64"
	case "real":
		typeName = "float32"
	case "double precision":
		typeName = "float64"
	case "boolean":
		typeName = "bool"
	case "timestamp without time zone", "timestamp with time zone", "date":
		typeName = "time.Time"
	case "bytea":
		// A nil slice already represents NULL; no pointer needed.
		return "[]byte"
	default:
		log.Printf("%s\n", dataType)
		typeName = "unknown"
	}
	if isNil {
		return "*" + typeName
	}
	return typeName
}
func fetchColumnsSchema(db *sql.DB, tableCatalog, schema string) ([]InformationSchemaColumns, error) { | |
query := `SELECT | |
table_name, | |
column_name, | |
ordinal_position, | |
is_nullable, | |
data_type, | |
numeric_precision, | |
numeric_precision_radix, | |
numeric_scale, | |
datetime_precision | |
FROM information_schema.columns | |
WHERE table_catalog = $1 AND table_schema = $2 | |
ORDER BY table_name, ordinal_position | |
` | |
rows, err := db.Query(query, tableCatalog, schema) | |
if err != nil { | |
return nil, err | |
} | |
columns := []InformationSchemaColumns{} | |
for rows.Next() { | |
isc := InformationSchemaColumns{} | |
err := rows.Scan( | |
&isc.TableName, | |
&isc.ColumnName, | |
&isc.OrdinalPosition, | |
&isc.IsNullable, | |
&isc.DataType, | |
&isc.NumericPrecision, | |
&isc.NumericPrecisionRadix, | |
&isc.NumericScale, | |
&isc.DateTimePrecision, | |
) | |
if err != nil { | |
return nil, err | |
} | |
columns = append(columns, isc) | |
} | |
return columns, nil | |
} | |
// toCamel converts a name such as "user_accounts" into CamelCase
// ("UserAccounts") for use as an exported Go identifier.
//
// Leading and trailing spaces are trimmed. The characters '_' and ' ' are
// word separators: the next lowercase ASCII letter after one (and at the
// very start) is upper-cased, and the separators themselves are dropped.
// ASCII digits and uppercase letters are kept as-is. Every other character
// is dropped and does not start a new word.
//
// NOTE(review): a name starting with a digit still yields an invalid Go
// identifier.
func toCamel(s string) string {
	s = strings.Trim(s, " ")
	// A Builder avoids the quadratic cost of string += in a loop.
	var b strings.Builder
	b.Grow(len(s))
	capNext := true
	for _, r := range s {
		switch {
		case r >= '0' && r <= '9', r >= 'A' && r <= 'Z':
			b.WriteRune(r)
		case r >= 'a' && r <= 'z':
			if capNext {
				b.WriteRune(r - ('a' - 'A'))
			} else {
				b.WriteRune(r)
			}
		}
		capNext = r == '_' || r == ' '
	}
	return b.String()
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment