Created
July 12, 2024 12:57
-
-
Save thiagozs/772fd1246ef7f6cbca06f6a1fbf6dc4e to your computer and use it in GitHub Desktop.
Dynamically scan your DB query results into Go structs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package dynscan | |
import ( | |
"database/sql" | |
"fmt" | |
"reflect" | |
"strings" | |
"time" | |
"github.com/google/uuid" | |
) | |
// DynamicScan scans a single row or multiple rows into the provided destination struct or slice | |
func DynamicScan(db *sql.DB, result any, tableName string, dest any) error { | |
switch res := result.(type) { | |
case *sql.Row: | |
columns, err := fetchColumns(db, tableName) | |
if err != nil { | |
return fmt.Errorf("error fetching columns: %w", err) | |
} | |
return scanRow(res, columns, dest) | |
case *sql.Rows: | |
columns, err := res.Columns() | |
if err != nil { | |
return fmt.Errorf("error fetching columns from rows: %w", err) | |
} | |
return scanRows(res, columns, dest) | |
default: | |
return fmt.Errorf("unsupported result type %T", result) | |
} | |
} | |
// scanRow scans a single row into the provided destination struct | |
func scanRow(row *sql.Row, columns []string, dest any) error { | |
fieldPtrs, fieldMap, err := prepareFieldPointers(dest, columns) | |
if err != nil { | |
return err | |
} | |
if err := row.Scan(fieldPtrs...); err != nil { | |
return fmt.Errorf("error scanning row: %w", err) | |
} | |
if err := setStructFields(dest, columns, fieldPtrs, fieldMap); err != nil { | |
return fmt.Errorf("error setting struct fields: %w", err) | |
} | |
return nil | |
} | |
// scanRows scans multiple rows into the provided destination slice | |
func scanRows(rows *sql.Rows, columns []string, dest any) error { | |
destVal := reflect.ValueOf(dest) | |
if destVal.Kind() != reflect.Ptr || destVal.Elem().Kind() != reflect.Slice { | |
return fmt.Errorf("dest must be a pointer to a slice") | |
} | |
destSlice := destVal.Elem() | |
elemType := destSlice.Type().Elem() | |
for rows.Next() { | |
elem := reflect.New(elemType).Elem() | |
fieldPtrs, fieldMap, err := prepareFieldPointers(elem.Addr().Interface(), columns) | |
if err != nil { | |
return err | |
} | |
if err := rows.Scan(fieldPtrs...); err != nil { | |
return fmt.Errorf("error scanning rows: %w", err) | |
} | |
if err := setStructFields(elem.Addr().Interface(), columns, fieldPtrs, fieldMap); err != nil { | |
return fmt.Errorf("error setting struct fields: %w", err) | |
} | |
destSlice.Set(reflect.Append(destSlice, elem)) | |
} | |
return rows.Err() | |
} | |
// prepareFieldPointers prepares pointers for fields based on the column names | |
func prepareFieldPointers(dest any, columns []string) ([]interface{}, map[string]int, error) { | |
v := reflect.ValueOf(dest).Elem() | |
t := v.Type() | |
fieldPtrs := make([]interface{}, len(columns)) | |
fieldMap := make(map[string]int) | |
for i := 0; i < t.NumField(); i++ { | |
field := t.Field(i) | |
fieldMap[cleanTag(field.Tag.Get("json"))] = i | |
} | |
for i, column := range columns { | |
idx, ok := fieldMap[column] | |
if !ok { | |
fieldPtrs[i] = new(interface{}) // Unmapped columns | |
continue | |
} | |
field := v.Field(idx) | |
fieldType := field.Type() | |
switch fieldType { | |
case reflect.TypeOf(time.Time{}): | |
fieldPtrs[i] = new(sql.NullTime) | |
case reflect.TypeOf(bool(false)): | |
fieldPtrs[i] = new(sql.NullBool) | |
case reflect.TypeOf(""): | |
fieldPtrs[i] = new(sql.NullString) | |
case reflect.TypeOf(uuid.UUID{}): | |
fieldPtrs[i] = new(sql.NullString) | |
case reflect.TypeOf([]byte{}): | |
fieldPtrs[i] = new([]byte) | |
default: | |
if fieldType.Kind() == reflect.Ptr { | |
fieldPtrs[i] = reflect.New(fieldType.Elem()).Interface() | |
} else { | |
fieldPtrs[i] = field.Addr().Interface() | |
} | |
} | |
} | |
return fieldPtrs, fieldMap, nil | |
} | |
// setStructFields sets the scanned values to the struct fields | |
func setStructFields(dest any, columns []string, fieldPtrs []interface{}, fieldMap map[string]int) error { | |
v := reflect.ValueOf(dest).Elem() | |
for i, column := range columns { | |
idx, ok := fieldMap[column] | |
if !ok { | |
continue | |
} | |
field := v.Field(idx) | |
fieldType := field.Type() | |
if !field.CanSet() { | |
return fmt.Errorf("cannot set field %s", column) | |
} | |
switch ptr := fieldPtrs[i].(type) { | |
case *sql.NullTime: | |
if ptr.Valid { | |
field.Set(reflect.ValueOf(ptr.Time)) | |
} else { | |
field.Set(reflect.ValueOf(time.Time{})) | |
} | |
case *sql.NullBool: | |
if ptr.Valid { | |
field.SetBool(ptr.Bool) | |
} else { | |
field.SetBool(false) | |
} | |
case *sql.NullString: | |
if ptr.Valid { | |
if fieldType == reflect.TypeOf(uuid.UUID{}) { | |
u, err := uuid.Parse(ptr.String) | |
if err != nil { | |
return fmt.Errorf("error parsing UUID: %w", err) | |
} | |
field.Set(reflect.ValueOf(u)) | |
} else { | |
field.SetString(ptr.String) | |
} | |
} else { | |
field.Set(reflect.Zero(fieldType)) | |
} | |
case *[]byte: | |
field.SetBytes(*ptr) | |
default: | |
if field.Kind() == reflect.Ptr { | |
if reflect.ValueOf(ptr).Elem().IsValid() { | |
field.Set(reflect.ValueOf(ptr).Elem()) | |
} else { | |
field.Set(reflect.Zero(fieldType)) | |
} | |
} else { | |
field.Set(reflect.ValueOf(ptr).Elem()) | |
} | |
} | |
} | |
return nil | |
} | |
// fetchColumns returns the column names of tableName, in the order produced by
// "SELECT * FROM <tableName> LIMIT 1". Query and Columns errors are returned
// unwrapped; the probe rows are always closed before returning.
//
// tableName is spliced into the SQL text verbatim, so it must come from
// trusted code and never from user input. NOTE(review): LIMIT is not accepted
// by every SQL dialect — confirm for the target database.
func fetchColumns(db *sql.DB, tableName string) ([]string, error) {
	probe := "SELECT * FROM " + tableName + " LIMIT 1"

	rs, qErr := db.Query(probe)
	if qErr != nil {
		return nil, qErr
	}
	defer rs.Close()

	return rs.Columns()
}
// cleanTag returns the name portion of a struct tag value, i.e. everything
// before the first comma ("id,omitempty" -> "id"). An empty tag yields "".
func cleanTag(tag string) string {
	name, _, _ := strings.Cut(tag, ",")
	return name
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment