Created
July 21, 2020 10:06
-
-
Save liprais/18f8f5c4b9ea12a4943527de057eecca to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1. | |
把已有的模型导出成pb格式 | |
保存时出现的日志 INFO:tensorflow:Unsupported signature for serialization 貌似可以忽略 | |
python code: | |
# Export an existing Keras .h5 model to the TensorFlow SavedModel (pb) format
# so that it can be served by TensorFlow Serving.
import os

# Must be set before bert4keras is imported: it selects the tf.keras backend.
os.environ['TF_KERAS'] = '1'

import numpy as np
from bert4keras.backend import K as K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import AutoRegressiveDecoder
from keras.models import load_model
import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution

# Run in graph mode (eager execution off) for the load/export below.
disable_eager_execution()

# Fine-tuned checkpoint to convert, and the directory that receives the export.
model = '/Users/liuxiao/Downloads/gpt2-ml-finetuned-lyrics/h5/short-articales-finetuned-gpt2-ml-150k.h5'
base = '/Volumes/Untitled/pb'

# compile=False: only the graph and weights are loaded, not the optimizer/loss.
keras_model = load_model(model, compile=False)

# NOTE: the trailing "1" in the export path is the model version number.
# It is mandatory; without it TF Serving reports that it cannot find a
# servable model.
keras_model.save(base + '/150k/1', save_format='tf')
2. | |
用docker启动server | |
docker run -p 8501:8501 --mount type=bind,source=/Volumes/Untitled/pb/150k/,target=/models/my_model -e MODEL_NAME=my_model -t tensorflow/serving | |
在这个页面查看模型的元数据: | |
http://localhost:8501/v1/models/my_model/metadata | |
inputs后面的就是api要求的参数 | |
"inputs": { "Input-Token": { "dtype": "DT_FLOAT","tensor_shape": {"dim": [{"size": "-1","name": ""},{"size": "-1","name": ""}],"unknown_rank": false},"name": "serving_default_Input-Token:0"}} | |
3.用requests调用 | |
python code: | |
# Call the TensorFlow Serving REST predict endpoint with one encoded sentence.
import requests
import json

# Token ids produced by the tokenizer for the Chinese text: 天青色等烟雨
payload = [[1921, 7471, 5682, 5023, 4170, 7433]]

# Request body: the key under "inputs" must match the input name reported
# by the model metadata endpoint ("Input-Token").
d = {
    "signature_name": "serving_default",
    "inputs": {"Input-Token": payload},
}

r = requests.post('http://127.0.0.1:8501/v1/models/my_model:predict', json=d)
print(r.json())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
def infer(url='http://127.0.0.1:8501/v1/models/my_model:predict',
          payload=None, timeout=None):
    """Send a predict request to a TensorFlow Serving REST endpoint.

    Args:
        url: Full URL of the model's ``:predict`` endpoint.
        payload: Batch of token-id sequences sent as the ``Input-Token``
            input. When ``None``, the sample sequence
            ``[[1921, 7471, 5682, 5023, 4170, 7433]]`` is used.
        timeout: Passed through to ``requests.post``; ``None`` (the
            default) waits indefinitely, as the call did before this
            parameter existed.

    Returns:
        The response object returned by ``requests.post``; call ``.json()``
        on it to read the response body.
    """
    if payload is None:
        payload = [[1921, 7471, 5682, 5023, 4170, 7433]]

    body = {
        'signature_name': 'serving_default',
        'inputs': {'Input-Token': payload},
    }
    return requests.post(url, json=body, timeout=timeout)