Serving a bert4keras model with TensorFlow Serving
0. Based on TF 2.0.
1. Export the existing model to the SavedModel (pb) format.
The INFO:tensorflow:Unsupported signature for serialization messages printed while saving can apparently be ignored.
python code:
import os
os.environ['TF_KERAS'] = '1'  # make bert4keras use tf.keras
import numpy as np
from bert4keras.backend import K as K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import AutoRegressiveDecoder
from tensorflow.keras.models import load_model  # tf.keras, to match TF_KERAS=1 and save_format='tf'
import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()
model = 'gptml.h5'
base = '/Volumes/Untitled/pb'
keras_model = load_model(model, compile=False)
keras_model.save(base + '/150k/1', save_format='tf')  # <==== NOTE: the 1 in the model path is the version number; TF Serving requires it, otherwise it reports that no servable model can be found
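Before starting the server, the exported signature can be sanity-checked with saved_model_cli, which ships with TensorFlow (a sketch; the path matches the save above):
saved_model_cli show --dir /Volumes/Untitled/pb/150k/1 --tag_set serve --signature_def serving_default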
2. Start the server with Docker:
docker run -p 8501:8501 --mount type=bind,source=/Volumes/Untitled/pb/150k/,target=/models/my_model -e MODEL_NAME=my_model -t tensorflow/serving
Check the model's metadata at:
http://localhost:8501/v1/models/my_model/metadata
The fields under inputs are the parameters the API expects:
"inputs": {
  "Input-Token": {
    "dtype": "DT_FLOAT",
    "tensor_shape": {
      "dim": [
        {"size": "-1", "name": ""},
        {"size": "-1", "name": ""}
      ],
      "unknown_rank": false
    },
    "name": "serving_default_Input-Token:0"
  }
}
Alternatively, without Docker, you can do the following on Ubuntu.
Install tensorflow_model_server first, then run:
tensorflow_model_server --model_base_path="/Volumes/Untitled/pb/150k" --rest_api_port=8501 --model_name="my_model" | |
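Either way, you can confirm the server has loaded a servable version via TF Serving's standard model-status endpoint:
curl http://localhost:8501/v1/models/my_model
A healthy server reports the version with state AVAILABLE.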
3. Call it with requests
python code:
import requests

payload = [[1921, 7471, 5682, 5023, 4170, 7433]]  # <=== the tokenizer-encoded Chinese text: 天青色等烟雨
d = {"signature_name": "serving_default",
     "inputs": {"Input-Token": payload}}
r = requests.post('http://127.0.0.1:8501/v1/models/my_model:predict', json=d)
print(r.json())
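The response body is JSON of the form {"outputs": [...]}; for this language model the array should have shape [batch, sequence_length, vocab_size], matching the model's output tensor. Converting it back is a one-liner (assuming numpy is imported as np):
probs = np.array(r.json()['outputs'])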
4. Taking https://github.com/bojone/bert4keras/blob/master/examples/basic_language_model_gpt2_ml.py as an example
Modify this class to add a remote-call method.
It needs requests and numpy:
import requests
import numpy as np

class ArticleCompletion(AutoRegressiveDecoder):
    """Article continuation based on random sampling.
    """
    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        token_ids = np.concatenate([inputs[0], output_ids], 1)
        # call TF Serving instead of the local model, keeping only the last position's probabilities
        return self.remote_call(token_ids)[:, -1]

    def generate(self, text, n=1, topk=5):
        token_ids, _ = tokenizer.encode(text)  # tokenizer is defined in the referenced example script
        results = self.random_sample([token_ids], n, topk)  # random-sampling decoding
        return [text + tokenizer.decode(ids) for ids in results]

    def remote_call(self, token_ids):
        payload = token_ids.tolist()
        d = {"signature_name": "serving_default", "inputs": {"Input-Token": payload}}
        r = requests.post('http://127.0.0.1:8501/v1/models/my_model:predict', json=d)
        return np.array(r.json()['outputs'])
Then it's ready to run.
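A minimal usage sketch (the constructor arguments, end_id=511 for the Chinese full stop and the length limits, follow the referenced gpt2_ml example and are assumptions here; tokenizer is likewise set up as in that example):
python code:
article_completion = ArticleCompletion(
    start_id=None,
    end_id=511,   # assumed: token id of the Chinese full stop, as in the gpt2_ml example
    maxlen=256,
    minlen=128
)
print(article_completion.generate(u'今天天气不错'))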