@liprais
Created July 21, 2020 10:06
1.
Export the existing model to pb (SavedModel) format
The "INFO:tensorflow:Unsupported signature for serialization" message logged while saving can apparently be ignored
python code:
import os
os.environ['TF_KERAS'] = '1'
import numpy as np
from bert4keras.backend import K as K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import AutoRegressiveDecoder
from keras.models import load_model
import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()
model = '/Users/liuxiao/Downloads/gpt2-ml-finetuned-lyrics/h5/short-articales-finetuned-gpt2-ml-150k.h5'
base = '/Volumes/Untitled/pb'
keras_model = load_model(model,compile=False)
keras_model.save(base + '/150k/1',save_format='tf') # <==== note: the 1 in the model path is the version number; it is required, otherwise TF Serving reports that it cannot find a servable model
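TensorFlow Serving scans the model's base path for numeric version subdirectories, which is why the save path above ends in /1. A minimal sketch of building such a path and checking which versions exist on disk; export_dir and list_versions are hypothetical helper names, not part of TensorFlow or bert4keras.

```python
import os

def export_dir(base, model_name, version):
    """Return base/model_name/version, the directory to pass to model.save()."""
    version = int(version)
    if version < 0:
        raise ValueError("version must be a non-negative integer")
    return os.path.join(base, model_name, str(version))

def list_versions(model_dir):
    """Return the numeric version subdirectories under model_dir, sorted."""
    if not os.path.isdir(model_dir):
        return []
    return sorted(
        int(name)
        for name in os.listdir(model_dir)
        if name.isdigit() and os.path.isdir(os.path.join(model_dir, name))
    )
```

An empty list from list_versions on the directory that gets mounted into the container is a quick hint that the server will find nothing to serve.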
2.
Start the server with docker
docker run -p 8501:8501 --mount type=bind,source=/Volumes/Untitled/pb/150k/,target=/models/my_model -e MODEL_NAME=my_model -t tensorflow/serving
View the model's metadata at this page:
http://localhost:8501/v1/models/my_model/metadata
The entries under inputs are the parameters the API requires
"inputs": { "Input-Token": { "dtype": "DT_FLOAT","tensor_shape": {"dim": [{"size": "-1","name": ""},{"size": "-1","name": ""}],"unknown_rank": false},"name": "serving_default_Input-Token:0"}}
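The input names can also be read programmatically instead of by eye. A minimal sketch, assuming the metadata response nests the signature under metadata -> signature_def -> signature_def -> serving_default; input_specs is a hypothetical helper, and the sample dict only mirrors the fragment shown above.

```python
def input_specs(metadata, signature="serving_default"):
    """Map each input name to (dtype, shape) from a metadata response dict."""
    sig = metadata["metadata"]["signature_def"]["signature_def"][signature]
    specs = {}
    for name, info in sig["inputs"].items():
        dims = info.get("tensor_shape", {}).get("dim", [])
        # Sizes arrive as strings; -1 marks a dimension of unknown length.
        specs[name] = (info["dtype"], [int(d["size"]) for d in dims])
    return specs

# Sample built from the metadata fragment above (assumed nesting).
sample = {"metadata": {"signature_def": {"signature_def": {"serving_default": {
    "inputs": {"Input-Token": {
        "dtype": "DT_FLOAT",
        "tensor_shape": {"dim": [{"size": "-1", "name": ""},
                                 {"size": "-1", "name": ""}],
                         "unknown_rank": False},
        "name": "serving_default_Input-Token:0"}}}}}}}

print(input_specs(sample))  # {'Input-Token': ('DT_FLOAT', [-1, -1])}
```

Against a running server, the same function could be fed requests.get('http://localhost:8501/v1/models/my_model/metadata').json().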
3. Call the server with requests
python code:
import requests
import json
payload = [[1921,7471,5682,5023,4170,7433]] # <=== tokenizer-encoded Chinese text: 天青色等烟雨
d = {"signature_name": "serving_default",
"inputs": {"Input-Token": payload}} # <=== payload
r = requests.post('http://127.0.0.1:8501/v1/models/my_model:predict',json=d)
print(r.json())
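With the columnar "inputs" request format used above, TF Serving puts the result under an "outputs" key (the row-oriented "instances" format returns "predictions" instead). A minimal sketch of turning that response into a next-token guess, assuming the model outputs a per-position distribution over the vocabulary with shape [batch, seq_len, vocab]; that shape is an assumption about this GPT-2 model, and next_token_ids is a hypothetical helper.

```python
import numpy as np

def next_token_ids(response_json):
    """Greedy pick of the next token id for each sequence in the batch."""
    if "error" in response_json:
        raise RuntimeError(response_json["error"])
    outputs = response_json["outputs"]
    # With several named outputs the value is a dict; take the first tensor.
    if isinstance(outputs, dict):
        outputs = next(iter(outputs.values()))
    probs = np.asarray(outputs)  # assumed shape: [batch, seq_len, vocab]
    # Use the distribution at the last position of each sequence.
    return probs[:, -1, :].argmax(axis=-1).tolist()
```

Called as next_token_ids(r.json()), this yields one id per input sequence; decoding the ids back to text still requires the same tokenizer that produced the payload.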

liprais commented Jul 21, 2020

import requests

def infer(url='http://127.0.0.1:8501/v1/models/my_model:predict'):
    # token ids of the tokenizer-encoded text 天青色等烟雨
    payload = [[1921, 7471, 5682, 5023, 4170, 7433]]
    d = {}
    d['signature_name'] = 'serving_default'
    h = {}
    h["Input-Token"] = payload
    d['inputs'] = h
    # equivalent literal form:
    #d = {"signature_name": "serving_default","inputs": {"Input-Token": [[1921, 7471, 5682, 5023, 4170, 7433]]}}
    r = requests.post(url, json=d)
    return r
