Last active
April 12, 2016 17:42
-
-
Save ericlagergren/a7f05b9ad12b627cf9994fdbca001769 to your computer and use it in GitHub Desktop.
Parse PostgreSQL arrays
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"database/sql" | |
"errors" | |
"fmt" | |
"strconv" | |
"strings" | |
) | |
func main() { | |
data := map[string]sql.Scanner{ | |
`{google.com,https://foo.bar.com,www.example.com?abc=123#foo}`: &StringArray{}, | |
`{}`: &StringArray{}, | |
`{NULL,NULL,NULL}`: &NullStringArray{}, | |
`{"something coo\"l"}`: &StringArray{}, | |
`{foo\,bar,baz}`: &StringArray{}, | |
`{100,200,300,500,10000,-1}`: &IntArray{}, | |
`{NULL,NULL,NULL,NULL}`: &NullIntArray{}, | |
`{1234,0}`: &IntArray{}, | |
`{ }`: &IntArray{}, | |
} | |
for k, v := range data { | |
v.Scan([]byte(k)) | |
var length int | |
switch t := v.(type) { | |
case *IntArray: | |
length = len(*t) | |
case *NullIntArray: | |
length = len(*t) | |
case *StringArray: | |
length = len(*t) | |
case *NullStringArray: | |
length = len(*t) | |
} | |
fmt.Printf("%s (%d): %v\n", k, length, v) | |
} | |
} | |
// ParseError reports a byte that the element predicate passed to parse
// rejected.
type ParseError struct {
	c byte
}

// Error implements the error interface.
func (p *ParseError) Error() string {
	return fmt.Sprintf("%c is an invalid character", p.c)
}

// parse splits the text form of a one-dimensional PostgreSQL array, such as
// {a,"b,c",NULL}, into elements and calls get once per element, in order.
//
// get receives each element's raw bytes exactly as they appear in data.
// Surrounding double quotes and backslash escapes are left in place, so a
// caller can still tell an unquoted NULL from the quoted string "NULL". The
// slice aliases data and must not be retained after get returns.
//
// A comma separates elements only when it is outside double quotes and not
// escaped by a backslash. Every other byte is checked with pred, and the first
// rejected byte is returned as a *ParseError. An error from get is returned
// unchanged.
//
// A nil data (what getData returns for a nil value) yields no elements and no
// error, as does {}. Input that is not wrapped in braces, contains an empty or
// unterminated element, or contains a nested array is rejected.
func parse(data []byte, pred func(byte) bool, get func([]byte) error) error {
	if data == nil {
		return nil
	}
	if len(data) < 2 {
		return errors.New("len(data) < 2")
	}
	if data[0] != '{' || data[len(data)-1] != '}' {
		return errors.New("array literal must be enclosed in braces")
	}
	data = data[1 : len(data)-1]
	if len(data) == 0 {
		return nil // {}
	}

	// emit rejects empty elements. PostgreSQL writes an empty string as "",
	// so a bare empty element (e.g. {a,,b} or {a,}) means malformed input.
	emit := func(elem []byte) error {
		if len(elem) == 0 {
			return errors.New("empty array element")
		}
		return get(elem)
	}

	var (
		mark    int  // index where the current element starts
		quoted  bool // currently inside a double-quoted section
		escaped bool // the previous byte was a backslash escaping this one
	)
	for i, c := range data {
		if c == ',' && !quoted && !escaped {
			if err := emit(data[mark:i]); err != nil {
				return err
			}
			mark = i + 1
			continue
		}
		if !pred(c) {
			return &ParseError{c: c}
		}
		switch {
		case escaped:
			// This byte is literal content, whatever it is.
			escaped = false
		case c == '\\':
			escaped = true
		case c == '"':
			quoted = !quoted
		case !quoted && (c == '{' || c == '}'):
			return errors.New("nested arrays are not supported")
		}
	}
	if quoted || escaped {
		return errors.New("unterminated quote or escape in array literal")
	}
	return emit(data[mark:])
}
// getData extracts the raw text of a value passed to a Scan method.
//
// A nil value is accepted and yields nil data. A []byte is returned as-is,
// aliasing the caller's buffer. A string is accepted too, since drivers may
// report text columns that way, and is copied into a new []byte. Any other
// type yields ok == false.
func getData(val interface{}) ([]byte, bool) {
	switch v := val.(type) {
	case nil:
		return nil, true
	case []byte:
		return v, true
	case string:
		return []byte(v), true
	default:
		return nil, false
	}
}
type IntArray []int64 | |
func (i *IntArray) Scan(val interface{}) error { | |
data, ok := getData(val) | |
if !ok { | |
return errors.New("IntArray.Scan: invalid type") | |
} | |
return parse(data, isnum, func(data []byte) error { | |
v, err := strconv.ParseInt(string(data), 10, 64) | |
if err != nil { | |
return err | |
} | |
*i = append(*i, v) | |
return nil | |
}) | |
} | |
func isnum(c byte) bool { | |
return (c >= '0' && c <= '9') || c == '-' | |
} | |
type NullIntArray []sql.NullInt64 | |
func (i *NullIntArray) Scan(val interface{}) error { | |
data, ok := getData(val) | |
if !ok { | |
return errors.New("NullIntArray.Scan: invalid type") | |
} | |
var n sql.NullInt64 | |
return parse(data, isNullNum, func(data []byte) error { | |
// If data != []byte("NULL") | |
// Avoiding an alloc if possible. | |
n.Valid = len(data) != 4 || | |
data[0] != 'N' || data[1] != 'U' || | |
data[2] != 'L' || data[3] != 'L' | |
if n.Valid { | |
v, err := strconv.ParseInt(string(data), 10, 64) | |
if err != nil { | |
return err | |
} | |
n.Int64 = v | |
} | |
*i = append(*i, n) | |
return nil | |
}) | |
} | |
func isNullNum(c byte) bool { | |
return (c >= '0' && c <= '9') || c == '-' || | |
c == 'N' || c == 'U' || c == 'L' | |
} | |
var repl = strings.NewReplacer(`\\`, `\`, `\`, ``) | |
type StringArray []string | |
func (s *StringArray) Scan(val interface{}) error { | |
data, ok := getData(val) | |
if !ok { | |
return errors.New("StringArray.Scan: invalid type") | |
} | |
return parse(data, func(byte) bool { return true }, func(data []byte) error { | |
*s = append(*s, repl.Replace(string(data))) | |
return nil | |
}) | |
} | |
type NullStringArray []sql.NullString | |
func (n *NullStringArray) Scan(val interface{}) error { | |
data, ok := getData(val) | |
if !ok { | |
return errors.New("NullStringArray.Scan: invalid type") | |
} | |
var s sql.NullString | |
return parse(data, func(byte) bool { return true }, func(data []byte) error { | |
// Avoiding an alloc if possible. | |
s.Valid = len(data) != 4 || | |
data[0] != 'N' || data[1] != 'U' || | |
data[2] != 'L' || data[3] != 'L' | |
// If data != []byte("NULL") | |
if s.Valid { | |
s.String = repl.Replace(string(data)) | |
} | |
*n = append(*n, s) | |
return nil | |
}) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment