Created
June 11, 2020 02:40
-
-
Save d4l3k/ef2edb288608d2037abfd57e9fb138b9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"encoding/binary" | |
"flag" | |
"io" | |
"io/ioutil" | |
"log" | |
"os" | |
"./fbs/ap" | |
caffepb "./proto" | |
"github.com/gogo/protobuf/proto" | |
"github.com/pkg/errors" | |
) | |
// fbsFile is the path of the flatbuffer model file to decode (-fbs flag).
var fbsFile = flag.String("fbs", "share/vision/fisheye_int8.fbs", "flatbuffer file to load")
func main() { | |
log.SetFlags(log.Lshortfile | log.Flags()) | |
flag.Parse() | |
if err := run2(); err != nil { | |
log.Fatalf("%+v", err) | |
} | |
} | |
func getDims(w ap.Weights) []uint32 { | |
var dims []uint32 | |
for i := 0; i < w.DimsLength(); i++ { | |
dims = append(dims, w.Dims(i)) | |
} | |
return dims | |
} | |
func getWeights(w ap.Tensor) []float32 { | |
var weights []float32 | |
for i := 0; i < w.DataLength(); i++ { | |
weights = append(weights, w.Data(i)) | |
} | |
return weights | |
} | |
func run2() error { | |
buf, err := ioutil.ReadFile(*fbsFile) | |
if err != nil { | |
return err | |
} | |
root := ap.GetRootAsRoot(buf, 0) | |
log.Printf("root %+v", root.LayersLength()) | |
for i := 0; i < root.LayersLength(); i++ { | |
var layer ap.Layer | |
if !root.Layers(&layer, i) { | |
return errors.Errorf("failed to load layer %d", i) | |
} | |
log.Printf("layer %d %s: %d weights", i, layer.Name(), layer.WeightsLength()) | |
for j := 0; j < layer.WeightsLength(); j++ { | |
var weights ap.Weights | |
if !layer.Weights(&weights, j) { | |
return errors.Errorf("failed to load weights %d", j) | |
} | |
log.Printf("weight %d::%d %+v %v", i, j, getDims(weights), weights.A()) | |
var tensor ap.Tensor | |
if weights.Tensor(&tensor) == nil { | |
return errors.Errorf("failed to load tensor") | |
} | |
log.Printf(" - %v", getWeights(tensor)) | |
} | |
} | |
return nil | |
} | |
type Reader struct { | |
reader io.ReadSeeker | |
offset int64 | |
length int64 | |
} | |
func (r *Reader) Len() int64 { | |
return r.length | |
} | |
func (r *Reader) ReadOffset() (int64, error) { | |
curOffset := r.offset | |
var offset uint32 | |
if err := binary.Read(r.reader, binary.LittleEndian, &offset); err != nil { | |
return 0, errors.Wrapf(err, "reading %d", curOffset) | |
} | |
r.offset += 4 | |
return curOffset + int64(offset), nil | |
} | |
func (r *Reader) ReadSOffset() (int64, error) { | |
curOffset := r.offset | |
var offset uint32 | |
if err := binary.Read(r.reader, binary.LittleEndian, &offset); err != nil { | |
return 0, errors.Wrapf(err, "reading %d", curOffset) | |
} | |
r.offset += 4 | |
soffset := 1<<32 - int64(offset) | |
return curOffset + soffset, nil | |
} | |
func (r *Reader) ReadUint16() (int, error) { | |
curOffset := r.offset | |
var v uint16 | |
if err := binary.Read(r.reader, binary.LittleEndian, &v); err != nil { | |
return 0, errors.Wrapf(err, "reading %d", curOffset) | |
} | |
r.offset += 2 | |
return int(v), nil | |
} | |
func (r *Reader) ReadUint32() (int, error) { | |
curOffset := r.offset | |
var v uint32 | |
if err := binary.Read(r.reader, binary.LittleEndian, &v); err != nil { | |
return 0, errors.Wrapf(err, "reading %d", curOffset) | |
} | |
r.offset += 4 | |
return int(v), nil | |
} | |
func (r *Reader) ReadFloat32() (float32, error) { | |
curOffset := r.offset | |
var v float32 | |
if err := binary.Read(r.reader, binary.LittleEndian, &v); err != nil { | |
return 0, errors.Wrapf(err, "reading %d", curOffset) | |
} | |
r.offset += 4 | |
return v, nil | |
} | |
func (r *Reader) Seek(offset int64) error { | |
if _, err := r.reader.Seek(offset, io.SeekStart); err != nil { | |
return errors.Wrapf(err, "seeking %d", offset) | |
} | |
r.offset = offset | |
return nil | |
} | |
func (r *Reader) Offset() int64 { | |
return r.offset | |
} | |
func (r *Reader) PrintDebug(n int) { | |
offset := r.Offset() | |
buf := make([]byte, n) | |
n, err := r.reader.Read(buf) | |
if err != nil { | |
log.Fatalf("%+v", err) | |
} | |
r.offset += int64(n) | |
buf = buf[:n] | |
log.Printf("%d: %+v |%s|", offset, buf, buf) | |
if err := r.Seek(offset); err != nil { | |
log.Fatalf("%+v", err) | |
} | |
} | |
func loadNet() (*caffepb.NetParameter, error) { | |
in, err := ioutil.ReadFile("share/vision/fisheye.prototxt") | |
if err != nil { | |
return nil, err | |
} | |
var net caffepb.NetParameter | |
if err := proto.UnmarshalText(string(in), &net); err != nil { | |
return nil, err | |
} | |
return &net, nil | |
} | |
// vtable holds a decoded FlatBuffers vtable header and its slots.
type vtable struct {
	// Position is the absolute byte offset at which the vtable starts.
	Position int64
	// VTableSize and TableSize are the two leading uint16 header words: the
	// vtable's own size in bytes and the size of the table it describes.
	VTableSize, TableSize int
	// Entries holds the remaining uint16 slots, presumably one per field,
	// each an offset relative to the table start (0 = field absent).
	Entries []int
}
func (r *Reader) ReadVTable(offset int64) (*vtable, error) { | |
curOffset := r.Offset() | |
if err := r.Seek(offset); err != nil { | |
return nil, err | |
} | |
vtableLength, err := r.ReadUint16() | |
if err != nil { | |
return nil, err | |
} | |
tableLength, err := r.ReadUint16() | |
if err != nil { | |
return nil, err | |
} | |
entryCount := (vtableLength - 4) / 2 | |
var entries []int | |
for i := 0; i < entryCount; i++ { | |
offset, err := r.ReadUint16() | |
if err != nil { | |
return nil, err | |
} | |
entries = append(entries, offset) | |
} | |
if err := r.Seek(curOffset); err != nil { | |
return nil, err | |
} | |
table := vtable{ | |
Position: offset, | |
VTableSize: vtableLength, | |
TableSize: tableLength, | |
Entries: entries, | |
} | |
return &table, nil | |
} | |
// ReadTable decodes the FlatBuffers table at the reader's current position.
// It logs what it finds and recurses into child tables.
//
// The table's leading soffset is resolved to a vtable with ReadSOffset and
// ReadVTable. The table kind is then chosen purely by comparing the vtable's
// absolute Position against three hard-coded values (1866322, 1866070,
// 1866292). NOTE(review): these look like vtable positions observed in one
// specific .fbs file; any other file will most likely hit the
// "unknown vtable" error. Confirm before reusing.
//
// Fields are read positionally from consecutive 4-byte words after the
// soffset. vtable.Entries is logged but never consulted, so this assumes each
// expected field is present and stored in that order. TODO confirm.
//
// Returns the first read/seek error, or an error for an unrecognised vtable.
// The reader position is left wherever the last read or recursive call ended.
func (r *Reader) ReadTable() error { | |
log.Printf("loading table %d", r.Offset()) | |
vtableAddr, err := r.ReadSOffset() | |
if err != nil { | |
return err | |
} | |
vtable, err := r.ReadVTable(vtableAddr) | |
if err != nil { | |
return err | |
} | |
log.Printf("vtable %#v", vtable) | |
//r.PrintDebug(100) | |
// Kind 1: the table's first field is an offset to a vector of offsets to
// child tables. Each child is decoded recursively.
if vtable.Position == 1866322 { // ptr to vector | |
log.Printf("loading ptr to vector...") | |
r.PrintDebug(128) | |
vec, err := r.ReadOffset() | |
if err != nil { | |
return err | |
} | |
log.Printf("offset %d", vec) | |
if err := r.Seek(vec); err != nil { | |
return err | |
} | |
vecLen, err := r.ReadUint32() | |
if err != nil { | |
return err | |
} | |
log.Printf("vec length = %d", vecLen) | |
var subtables []int64 | |
for i := 0; i < vecLen; i++ { | |
subtable, err := r.ReadOffset() | |
if err != nil { | |
return err | |
} | |
log.Printf("subtable %d", subtable) | |
// Offsets at or past r.Len() are treated as invalid. Only element 0 may
// be invalid (it is skipped); any later invalid element aborts. With the
// Reader's length unset (0), every offset counts as invalid.
if subtable >= r.Len() { | |
log.Printf("invalid subtable hmm %d", subtable) | |
if i == 0 { | |
continue | |
} else { | |
return errors.Errorf("too many invalid") | |
} | |
} | |
subtables = append(subtables, subtable) | |
} | |
// NOTE(review): this always drops the first collected offset. If element 0
// was already skipped as invalid above, the first *valid* child is dropped
// too. If nothing was collected (e.g. vecLen == 0), slicing the nil slice
// with [1:] panics. Confirm intent.
subtables = subtables[1:] | |
for _, subtable := range subtables { | |
log.Printf("loading subtable %d", subtable) | |
if err := r.Seek(subtable); err != nil { | |
return err | |
} | |
if err := r.ReadTable(); err != nil { | |
return err | |
} | |
} | |
return nil | |
// Kind 2 (author's label: layer): field 0 is an offset to a vector of
// child-table offsets, field 1 an offset to the layer's name string.
} else if vtable.Position == 1866070 { // layer | |
log.Printf("loading layer entry...") | |
dataAddr, err := r.ReadOffset() | |
if err != nil { | |
return err | |
} | |
nameAddr, err := r.ReadOffset() | |
if err != nil { | |
return err | |
} | |
if err := r.Seek(nameAddr); err != nil { | |
return err | |
} | |
name, err := r.ReadString() | |
if err != nil { | |
return err | |
} | |
log.Printf("name = %q", name) | |
if err := r.Seek(dataAddr); err != nil { | |
return err | |
} | |
subtableCount, err := r.ReadUint32() | |
if err != nil { | |
return err | |
} | |
log.Printf("subtable count = %d", subtableCount) | |
// Collect every child offset first, then decode them (each ReadTable call
// moves the reader, so the vector cannot be read interleaved with it).
var subtables []int64 | |
for i := 0; i < subtableCount; i++ { | |
subtable, err := r.ReadOffset() | |
if err != nil { | |
return err | |
} | |
subtables = append(subtables, subtable) | |
} | |
for _, subtable := range subtables { | |
if err := r.Seek(subtable); err != nil { | |
return err | |
} | |
if err := r.ReadTable(); err != nil { | |
return err | |
} | |
} | |
return nil | |
// Kind 3 (weights): field 0 is an offset to a nested table holding float
// data, field 1 an unidentified uint32, field 2 an offset to a uint32 dims
// vector.
} else if vtable.Position == 1866292 { | |
log.Printf("loading weights entry...") | |
datatable, err := r.ReadOffset() // table of vec of bytes | |
if err != nil { | |
return err | |
} | |
if _, err := r.ReadUint32(); err != nil { // unknown (always seems to be 1) | |
return err | |
} | |
dimsAddr, err := r.ReadOffset() | |
if err != nil { | |
return err | |
} | |
if err := r.Seek(dimsAddr); err != nil { | |
return err | |
} | |
dims, err := r.ReadUint32s() | |
if err != nil { | |
return err | |
} | |
log.Printf("dims = %+v", dims) | |
// Decode the nested data table inline. Its vtable must be the kind-1
// vtable (1866322), and its first field is read as an offset to a float32
// vector.
if err := r.Seek(datatable); err != nil { | |
return err | |
} | |
vtableAddr, err := r.ReadSOffset() | |
if err != nil { | |
return err | |
} | |
// Note: this declaration shadows the outer vtable for the rest of the branch.
vtable, err := r.ReadVTable(vtableAddr) | |
if err != nil { | |
return err | |
} | |
if vtable.Position != 1866322 { | |
return errors.Errorf("got unexpected table") | |
} | |
dataAddr, err := r.ReadOffset() | |
if err != nil { | |
return err | |
} | |
if err := r.Seek(dataAddr); err != nil { | |
return err | |
} | |
data, err := r.ReadFloat32s() | |
if err != nil { | |
return err | |
} | |
log.Printf("read data (%d) = %v", len(data), data) | |
// Dump the bytes following the float vector; PrintDebug restores the position.
r.PrintDebug(100) | |
return nil | |
} else { | |
return errors.Errorf("unknown vtable %#v", vtable) | |
} | |
} | |
func (r *Reader) ReadUint32s() ([]int, error) { | |
length, err := r.ReadUint32() | |
if err != nil { | |
return nil, err | |
} | |
var nums []int | |
for i := 0; i < length; i++ { | |
v, err := r.ReadUint32() | |
if err != nil { | |
return nil, err | |
} | |
nums = append(nums, v) | |
} | |
return nums, nil | |
} | |
func (r *Reader) ReadFloat32s() ([]float32, error) { | |
length, err := r.ReadUint32() | |
if err != nil { | |
return nil, err | |
} | |
var nums []float32 | |
for i := 0; i < length; i++ { | |
v, err := r.ReadFloat32() | |
if err != nil { | |
return nil, err | |
} | |
nums = append(nums, v) | |
} | |
return nums, nil | |
} | |
func (r *Reader) ReadBytes() ([]byte, error) { | |
length, err := r.ReadUint32() | |
if err != nil { | |
return nil, err | |
} | |
buf := make([]byte, length) | |
n, err := r.reader.Read(buf) | |
if err != nil { | |
return nil, err | |
} | |
if n != length { | |
return nil, errors.Errorf("failed to read entire string: expected %d, got %d", length, n) | |
} | |
return buf, nil | |
} | |
func (r *Reader) ReadString() (string, error) { | |
buf, err := r.ReadBytes() | |
if err != nil { | |
return "", err | |
} | |
return string(buf), nil | |
} | |
// run is an earlier, exploratory decoder. It walks the flatbuffer at
// *fbsFile by hand with Reader instead of the generated ap bindings. It is
// not called from main in the visible code (main calls run2).
//
// Steps:
//  1. Load layer names from the Caffe prototxt via loadNet.
//  2. Read the whole flatbuffer into data.
//  3. Scan every 4-byte word from offset 0 as a candidate table soffset, and
//     record those whose vtable position lands inside the file.
//  4. Follow the root offset at position 0 and decode it with ReadTable.
//
// NOTE(review): the function returns right after ReadTable (the first
// `return nil` below). Everything after that point is unreachable leftover
// exploration code (go vet reports it as unreachable); it is documented only
// for whoever revives it.
func run() error { | |
net, err := loadNet() | |
if err != nil { | |
return err | |
} | |
var layers []string | |
for _, layer := range net.Layer { | |
// NOTE(review): dereferences layer.Name without a nil check; a layer with
// no name set would panic here.
layers = append(layers, *layer.Name) | |
} | |
log.Printf("layers (%d): %+v", len(layers), layers) | |
f, err := os.Open(*fbsFile) | |
if err != nil { | |
return err | |
} | |
defer f.Close() | |
// ReadAll leaves f at EOF; the Seek(0) below rewinds it before the Reader
// starts reading from f.
data, err := ioutil.ReadAll(f) | |
if err != nil { | |
return err | |
} | |
r := Reader{ | |
reader: f, | |
offset: 0, | |
length: int64(len(data)), | |
} | |
if err := r.Seek(0); err != nil { | |
return err | |
} | |
log.Printf("finding tables") | |
vtables := map[int64]vtable{} | |
vtablesCounts := map[int64]int{} | |
var tables []int64 | |
// Heuristic scan: treat every 4-byte word as a table's soffset.
// ReadSOffset advances 4 bytes and ReadVTable restores the position, so the
// loop steps word by word until a read reports io.EOF. Any other error,
// including a ReadVTable failure on a bogus candidate, aborts run.
for { | |
offset := r.Offset() | |
position, err := r.ReadSOffset() | |
if errors.Is(err, io.EOF) { | |
break | |
} else if err != nil { | |
return err | |
} | |
if position >= 0 && position < int64(len(data)) { | |
entry, err := r.ReadVTable(position) | |
if err != nil { | |
return err | |
} | |
tables = append(tables, offset) | |
vtables[position] = *entry | |
vtablesCounts[position] += 1 | |
log.Printf("candidate %d: %#v", offset, entry) | |
} | |
} | |
log.Printf("found %d distinct vtables", len(vtables)) | |
log.Printf("found %d distinct tables", len(tables)) | |
log.Printf("found table counts %+v", vtablesCounts) | |
log.Printf("loading root table") | |
// Position 0 holds the uoffset of the root table.
if err := r.Seek(0); err != nil { | |
return err | |
} | |
rootTableOffset, err := r.ReadOffset() | |
if err != nil { | |
return err | |
} | |
if err := r.Seek(rootTableOffset); err != nil { | |
return err | |
} | |
if err := r.ReadTable(); err != nil { | |
return err | |
} | |
return nil | |
// ---- unreachable from here on (see the note above the function) ----
if err := r.Seek(0); err != nil { | |
return err | |
} | |
// find layer names in flatbuffer | |
// For each layer name, find a copy in data that is preceded by a matching
// uint32 length prefix, and record the position of that prefix.
layerLocs := map[string]int{} | |
for _, layer := range layers { | |
search := 0 | |
for { | |
match := bytes.Index(data[search:], []byte(layer)) | |
if match < 0 { | |
return errors.Errorf("failed to find %q", layer) | |
} | |
match += search | |
search = match + 1 | |
// NOTE(review): match - 4 is negative when the name occurs within the
// first 4 bytes, which makes the Seek below fail.
length := match - 4 | |
if err := r.Seek(int64(length)); err != nil { | |
return err | |
} | |
elements, err := r.ReadUint32() | |
if err != nil { | |
return err | |
} | |
if elements != len(layer) { | |
// NOTE(review): data[match:match+elements] panics if elements exceeds
// the bytes remaining after match.
log.Printf("# elements doesn't match str: %d != len(%q), %s", elements, layer, data[match:match+elements]) | |
continue | |
} | |
log.Printf("found %q = %d", layer, match) | |
layerLocs[layer] = length | |
break | |
} | |
} | |
log.Printf("found %d layerLocs", len(layerLocs)) | |
// Find references to layer strings | |
// For each string, scan every word from 0 for a uoffset that points at it.
// If no such word exists, the scan ends with a read error, which is returned.
layerRefs := map[string]int64{} | |
for layer, target := range layerLocs { | |
if err := r.Seek(0); err != nil { | |
return err | |
} | |
for { | |
off := r.Offset() | |
match, err := r.ReadOffset() | |
if err != nil { | |
return err | |
} | |
if match == int64(target) { | |
log.Printf("found %q ref at %d", layer, off) | |
// NOTE(review): this stores match (== target, the string's own
// position), not off (where the reference lives), although the log
// line reports off. The probe loop below therefore searches backwards
// from the string rather than from the referencing table. Possibly
// meant to be off; confirm.
layerRefs[layer] = match | |
break | |
} | |
} | |
} | |
// Find VTables for table | |
// Probe 25 words at and before each ref (i = 0, 4, ..., 96) and stop at the
// first plausible soffset. Unlike the scan above, there is no position >= 0
// check here.
for layer, ref := range layerRefs { | |
found := false | |
for i := 0; i < 100; i += 4 { | |
attempt := ref - int64(i) | |
if err := r.Seek(attempt); err != nil { | |
return err | |
} | |
soffset, err := r.ReadSOffset() | |
if err != nil { | |
return err | |
} | |
if soffset < int64(len(data)) { | |
log.Printf("%q (layerref %d): %d soffset %d", layer, ref, attempt, soffset) | |
if _, err := r.ReadVTable(soffset); err != nil { | |
return err | |
} | |
r.PrintDebug(64) | |
found = true | |
break | |
} | |
} | |
if !found { | |
return errors.Errorf("failed to find %q", layer) | |
} | |
} | |
// Manually decode the root table's vtable header and first field.
// (The first of the two identical Seeks below is redundant.)
if err := r.Seek(rootTableOffset); err != nil { | |
return err | |
} | |
log.Printf("root table offset %d", rootTableOffset) | |
if err := r.Seek(rootTableOffset); err != nil { | |
return err | |
} | |
tableOffset := r.Offset() | |
vTableOffset, err := r.ReadSOffset() | |
if err != nil { | |
return err | |
} | |
log.Printf("vtable offset %d", vTableOffset) | |
if err := r.Seek(vTableOffset); err != nil { | |
return err | |
} | |
vtableLength, err := r.ReadUint16() | |
if err != nil { | |
return err | |
} | |
log.Printf("vtable length %d", vtableLength) | |
tableLength, err := r.ReadUint16() | |
if err != nil { | |
return err | |
} | |
log.Printf("table length %d", tableLength) | |
var entries []int | |
// NOTE(review): vtableLength/2 - 1 reads one more uint16 than ReadVTable's
// (vtableLength-4)/2, i.e. one slot past the end of the vtable.
for i := 0; i < (vtableLength/2 - 1); i++ { | |
offset, err := r.ReadUint16() | |
if err != nil { | |
return err | |
} | |
entries = append(entries, offset) | |
} | |
log.Printf("entries %+v", entries) | |
// NOTE(review): entries[0] panics if the loop above collected no slots.
if err := r.Seek(tableOffset + int64(entries[0])); err != nil { | |
return err | |
} | |
val1, err := r.ReadUint16() | |
if err != nil { | |
return err | |
} | |
val2, err := r.ReadUint16() | |
if err != nil { | |
return err | |
} | |
log.Printf("table entries %d %d", val1, val2) | |
r.PrintDebug(300) | |
return nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
When I run `go run decode.go`, I receive the following errors: