-
-
Save puma007/fc055dc6fd95a2d9543cef53e568c287 to your computer and use it in GitHub Desktop.
Tensorflow Serving Go client for the inception model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// TensorFlow Serving Go client for the Inception model.
// First, compile the proto files:
// git clone --recursive https://github.com/tensorflow/serving.git | |
// protoc -I=serving -I serving/tensorflow --go_out=plugins=grpc:$GOPATH/src serving/tensorflow_serving/apis/*.proto | |
// protoc -I=serving/tensorflow --go_out=plugins=grpc:$GOPATH/src serving/tensorflow/tensorflow/core/framework/*.proto | |
// protoc -I=serving/tensorflow --go_out=plugins=grpc:$GOPATH/src serving/tensorflow/tensorflow/core/protobuf/{saver,meta_graph}.proto | |
// protoc -I=serving/tensorflow --go_out=plugins=grpc:$GOPATH/src serving/tensorflow/tensorflow/core/example/*.proto | |
package main | |
import (
	"context"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"path/filepath"
	"time"

	google_protobuf "github.com/golang/protobuf/ptypes/wrappers"
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"google.golang.org/grpc"

	tf_core_framework "tensorflow/core/framework"
	pb "tensorflow_serving/apis"
)
func main() { | |
servingAddress := flag.String("serving-address", "localhost:9000", "The tensorflow serving address") | |
flag.Parse() | |
if flag.NArg() != 1 { | |
fmt.Println("Usage: " + os.Args[0] + " --serving-address localhost:9000 path/to/img.png") | |
os.Exit(1) | |
} | |
imgPath, err := filepath.Abs(flag.Arg(0)) | |
if err != nil { | |
log.Fatalln(err) | |
} | |
imageBytes, err := ioutil.ReadFile(imgPath) | |
if err != nil { | |
log.Fatalln(err) | |
} | |
tensor, err := tf.NewTensor(string(imageBytes)) | |
if err != nil { | |
log.Fatalln("Cannot read image file") | |
} | |
tensorString, ok := tensor.Value().(string) | |
if !ok { | |
log.Fatalln("Cannot type assert tensor value to string") | |
} | |
request := &pb.PredictRequest{ | |
ModelSpec: &pb.ModelSpec{ | |
Name: "inception", | |
SignatureName: "predict_images", | |
Version: &google_protobuf.Int64Value{ | |
Value: int64(1), | |
}, | |
}, | |
Inputs: map[string]*tf_core_framework.TensorProto{ | |
"images": &tf_core_framework.TensorProto{ | |
Dtype: tf_core_framework.DataType_DT_STRING, | |
TensorShape: &tf_core_framework.TensorShapeProto{ | |
Dim: []*tf_core_framework.TensorShapeProto_Dim{ | |
&tf_core_framework.TensorShapeProto_Dim{ | |
Size: int64(1), | |
}, | |
}, | |
}, | |
StringVal: [][]byte{[]byte(tensorString)}, | |
}, | |
}, | |
} | |
conn, err := grpc.Dial(*servingAddress, grpc.WithInsecure()) | |
if err != nil { | |
log.Fatalf("Cannot connect to the grpc server: %v\n", err) | |
} | |
defer conn.Close() | |
client := pb.NewPredictionServiceClient(conn) | |
resp, err := client.Predict(context.Background(), request) | |
if err != nil { | |
log.Fatalln(err) | |
} | |
log.Println(resp) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment