Serving a bert4keras model via TensorFlow Serving
0.
Based on TF 2.0
1.
Export the existing model to the SavedModel (pb) format
The INFO:tensorflow:Unsupported signature for serialization message printed during saving can apparently be ignored
Python code:

import os
os.environ['TF_KERAS'] = '1'  # make bert4keras use tf.keras (needed on TF 2.x)

import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution
from tensorflow.keras.models import load_model  # use tf.keras to match TF_KERAS='1'
from bert4keras.models import build_transformer_model  # importing bert4keras registers its custom layers

disable_eager_execution()

model = 'gptml.h5'
base = '/Volumes/Untitled/pb'
keras_model = load_model(model, compile=False)
# Note: the trailing "1" in the model path is the version number. It is required;
# without it TF Serving reports that it cannot find a servable model version.
keras_model.save(base + '/150k/1', save_format='tf')
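
To sanity-check the export, you can load the SavedModel back and list its signatures (a minimal sketch; run it in a fresh Python process, since the disable_eager_execution() call above interferes with eager loading):

Python code:

import tensorflow as tf
loaded = tf.saved_model.load('/Volumes/Untitled/pb/150k/1')
print(list(loaded.signatures.keys()))  # should include 'serving_default'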
2.
Start the server with Docker
docker run -p 8501:8501 --mount type=bind,source=/Volumes/Untitled/pb/150k/,target=/models/my_model -e MODEL_NAME=my_model -t tensorflow/serving
Check the model's metadata at this page:
http://localhost:8501/v1/models/my_model/metadata
The entries under inputs are the parameters the API expects:
"inputs": {
  "Input-Token": {
    "dtype": "DT_FLOAT",
    "tensor_shape": {
      "dim": [
        {"size": "-1", "name": ""},
        {"size": "-1", "name": ""}
      ],
      "unknown_rank": false
    },
    "name": "serving_default_Input-Token:0"
  }
}
Or, without Docker, you can do the following on Ubuntu.
First install tensorflow_model_server,
then run the command:
tensorflow_model_server --model_base_path="/Volumes/Untitled/pb/150k" --rest_api_port=8501 --model_name="my_model"
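
Either way, you can verify that the model loaded by querying the model status endpoint (a small sketch using TF Serving's standard REST API):

Python code:

import requests
print(requests.get('http://localhost:8501/v1/models/my_model').json())  # model version status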
3.
Call it with requests
Python code:

import requests

payload = [[1921, 7471, 5682, 5023, 4170, 7433]]  # tokenizer-encoded Chinese text: 天青色等烟雨
d = {"signature_name": "serving_default",
     "inputs": {"Input-Token": payload}}
r = requests.post('http://127.0.0.1:8501/v1/models/my_model:predict', json=d)
print(r.json())
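
The response is JSON with the model's output under the "outputs" key. A hedged sketch of decoding it (the shape assumes a language model that returns per-position token probabilities, as GPT2-ML does):

Python code:

import numpy as np
probas = np.array(r.json()['outputs'])       # assumed shape: (batch, seq_len, vocab_size)
next_token_id = int(probas[0, -1].argmax())  # greedy pick of the next token id
print(next_token_id)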
4.
Take https://github.com/bojone/bert4keras/blob/master/examples/basic_language_model_gpt2_ml.py as an example
Modify the class there to add a remote-call method; it depends on requests and numpy:
import requests
import numpy as np
from bert4keras.snippets import AutoRegressiveDecoder

class ArticleCompletion(AutoRegressiveDecoder):
    """Article continuation based on random sampling."""

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        token_ids = np.concatenate([inputs[0], output_ids], 1)
        # Call TF Serving instead of the local model:
        return self.remote_call(token_ids)[:, -1]

    def generate(self, text, n=1, topk=5):
        # `tokenizer` is set up as in the linked example script
        token_ids, _ = tokenizer.encode(text)
        results = self.random_sample([token_ids], n, topk)  # random sampling
        return [text + tokenizer.decode(ids) for ids in results]

    def remote_call(self, token_ids):
        payload = token_ids.tolist()
        d = {"signature_name": "serving_default", "inputs": {"Input-Token": payload}}
        r = requests.post('http://127.0.0.1:8501/v1/models/my_model:predict', json=d)
        return np.array(r.json()['outputs'])
Then it's ready to run.
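
For example (a hedged usage sketch; the tokenizer setup and the start_id/end_id/maxlen values follow the linked GPT2-ML example script, and 'gpt2_ml_vocab.txt' is a placeholder path for your vocabulary file):

Python code:

from bert4keras.tokenizers import Tokenizer

tokenizer = Tokenizer('gpt2_ml_vocab.txt', token_start=None, token_end=None, do_lower_case=True)
article_completion = ArticleCompletion(start_id=None, end_id=511, maxlen=256, minlen=128)
print(article_completion.generate(u'天青色等烟雨'))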