Created
January 8, 2018 10:34
-
-
Save plutov/b88486ac41678a88b8edc2332f936a1f to your computer and use it in GitHub Desktop.
tensorflow4.go
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
func normalizeImage(body io.ReadCloser) (*tensorflow.Tensor, error) { | |
var buf bytes.Buffer | |
io.Copy(&buf, body) | |
tensor, err := tensorflow.NewTensor(buf.String()) | |
if err != nil { | |
return nil, err | |
} | |
graph, input, output, err := getNormalizedGraph() | |
if err != nil { | |
return nil, err | |
} | |
session, err := tensorflow.NewSession(graph, nil) | |
if err != nil { | |
return nil, err | |
} | |
normalized, err := session.Run( | |
map[tensorflow.Output]*tensorflow.Tensor{ | |
input: tensor, | |
}, | |
[]tensorflow.Output{ | |
output, | |
}, | |
nil) | |
if err != nil { | |
return nil, err | |
} | |
return normalized[0], nil | |
} | |
// Creates a graph to decode, rezise and normalize an image | |
func getNormalizedGraph() (graph *tensorflow.Graph, input, output tensorflow.Output, err error) { | |
s := op.NewScope() | |
input = op.Placeholder(s, tensorflow.String) | |
// 3 return RGB image | |
decode := op.DecodeJpeg(s, input, op.DecodeJpegChannels(3)) | |
// Sub: returns x - y element-wise | |
output = op.Sub(s, | |
// make it 224x224: inception specific | |
op.ResizeBilinear(s, | |
// inserts a dimension of 1 into a tensor's shape. | |
op.ExpandDims(s, | |
// cast image to float type | |
op.Cast(s, decode, tensorflow.Float), | |
op.Const(s.SubScope("make_batch"), int32(0))), | |
op.Const(s.SubScope("size"), []int32{224, 224})), | |
// mean = 117: inception specific | |
op.Const(s.SubScope("mean"), float32(117))) | |
graph, err = s.Finalize() | |
return graph, input, output, err | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment