nonebot_zhipuai_bot
from nonebot.adapters import Event, Message
from nonebot.params import EventPlainText, CommandArg
from nonebot.plugin.on import on_message, on_command
from nonebot.adapters.onebot.v11.message import MessageSegment
import os
import ast
import json
import shelve
import aiohttp
import zhipuai

# Persistent conversation store: maps repr-ed keys to (role, content, parent_message_id).
db = shelve.open('db')
zhipuai.api_key = os.environ["ZHIPU_API_KEY"]
# Roles alternate along a reply chain: a reply to a user message comes from the assistant, and vice versa.
each_other = {"user": "assistant", "assistant": "user"}
async def should_reply_checker(event: Event) -> bool:
    # Only reply in private chats (OneBot v11 message events carry message_type).
    return event.message_type != "group"


def construct_history(chat_id: int, message_id: int):
    # Walk the reply chain backwards through the stored parent ids,
    # then reverse it into chronological order for the model prompt.
    result = []
    while message_id is not None:
        role, content, message_id = db[repr((chat_id, message_id))]
        result.append({"role": role, "content": content})
    result.reverse()
    return result
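
# Illustrative sketch (not from the original gist) of the shelve layout that the
# handlers below produce and construct_history consumes; the ids are made-up examples:
#
#   db["(12345, 1001)"] = ("user", "hello", None)                   # root of a chain
#   db["(12345, 1002)"] = ("assistant", "[chatglm_lite] hi", 1001)  # bot reply, parent 1001
#   db["12345"]         = "chatglm_std"                             # per-user model choice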
async def invoke(**params):
    # Call the ZhipuAI v3 model API directly over aiohttp so the request is
    # non-blocking; the token is generated with the SDK's JWT helper.
    url = "https://open.bigmodel.cn/api/paas/v3/model-api/%s/invoke" % params["model"]
    token = zhipuai.utils.jwt_token.generate_token(zhipuai.api_key)
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json; charset=UTF-8",
        "Authorization": token,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(url, data=json.dumps(params), headers=headers) as response:
            return json.loads(await response.text())
async def completion(chat_id: int, message_id: int):
    # Use the model chosen via the "model" command, defaulting to chatglm_lite.
    model = db.get(repr(chat_id), "chatglm_lite")
    result = await invoke(
        model=model,
        prompt=construct_history(chat_id, message_id),
    )
    # The API returns the text as a quoted string literal, so literal_eval unwraps it.
    return f"[{model}] " + ast.literal_eval(result["data"]["choices"][0]["content"])
should_reply = on_message(rule=should_reply_checker, priority=20, block=True)


@should_reply.handle()
async def _(event: Event, message: str = EventPlainText()):
    chat_id = event.user_id
    message_id = event.message_id
    if event.reply is None:
        # A plain message starts a new conversation.
        parent = None
        role = "user"
    else:
        # A reply continues an existing chain; the role alternates from the quoted message.
        parent = event.reply.message_id
        last_role, _, _ = db[repr((chat_id, parent))]
        role = each_other[last_role]
    db[repr((chat_id, message_id))] = (role, message, parent)
    if role == "user":
        result = await completion(chat_id, message_id)
        reply = await should_reply.send(MessageSegment.reply(message_id) + result)
        # Store the bot's own reply so that replying to it extends the history.
        db[repr((chat_id, reply["message_id"]))] = ("assistant", result, message_id)
choose_model = on_command("model", priority=10, block=True)


@choose_model.handle()
async def _(event: Event, model: Message = CommandArg()):
    chat_id = event.user_id
    model = model.extract_plain_text()
    available_models = ["chatglm_lite", "chatglm_std", "chatglm_pro"]
    if model not in available_models:
        await choose_model.send("invalid model")
    else:
        # Remember the per-user model choice under the chat id key.
        db[repr(chat_id)] = model
        await choose_model.send("model is set to " + model)
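
# A minimal sketch (an assumption, not part of the gist) of how this plugin could be
# loaded in a NoneBot 2 project using the OneBot v11 adapter; the module name
# "nonebot_zhipuai_bot" and the file layout are illustrative:
#
#   # bot.py
#   import nonebot
#   from nonebot.adapters.onebot.v11 import Adapter as OneBotV11Adapter
#
#   nonebot.init()                                 # reads host/port etc. from .env
#   nonebot.get_driver().register_adapter(OneBotV11Adapter)
#   nonebot.load_plugin("nonebot_zhipuai_bot")     # this file, on the import path
#   nonebot.run()
#
# Export ZHIPU_API_KEY in the environment before starting the bot.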