Last active
August 27, 2024 08:14
-
-
Save dixyes/47c0e76a12845b786a962c4e6dade345 to your computer and use it in GitHub Desktop.
api.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import os | |
import json | |
import datetime | |
from typing import Iterable, Optional, TypedDict, Union | |
import urllib.parse | |
import hashlib | |
import hmac | |
import secrets | |
import requests | |
class APIException(Exception):
    """Error reported in the JSON body of a Tencent Cloud API call.

    Built from the decoded body of a failed call, read as
    ``{"Response": {"Error": {"Code": ..., "Message": ...}, "RequestId": ...}}``.

    Attributes:
        code: ``Response.Error.Code`` (``None`` if absent).
        message: ``Response.Error.Message`` (``None`` if absent).
        requestId: ``Response.RequestId`` (``None`` if absent).
    """

    def __init__(self, response: dict):
        # Hand the raw response to Exception so ``args`` is populated; this
        # keeps the exception picklable/copyable, because unpickling
        # re-creates it as ``APIException(*args)``.
        super().__init__(response)
        # Missing or null sections degrade to None instead of AttributeError.
        r = response.get("Response") or {}
        error = r.get("Error") or {}
        self.code = error.get("Code")
        self.message = error.get("Message")
        self.requestId = r.get("RequestId")

    def __str__(self):
        return f"[{self.requestId}]{self.code}: {self.message}"
class API:
    """Minimal Tencent Cloud API client that signs requests with TC3-HMAC-SHA256.

    ``API`` itself cannot be instantiated (``__init__`` raises
    ``NotImplementedError``). A subclass's ``__init__`` must set:

    * ``endpoint``: host name requests are sent to (also the signed ``host``);
    * ``service``: service name used in the credential scope;
    * ``accessKey`` / ``secretKey``: credentials used for signing.
    """

    def __init__(self, region: Optional[str] = None):
        raise NotImplementedError("stub only")

    def req(
        self,
        uri: str,
        method: str = "GET",
        params: Optional[dict] = None,
        headers: Optional[dict[str, str]] = None,
        data: Optional[Union[bytes, str, dict]] = None,
        stream: bool = False,
    ) -> Union[dict, Iterable[bytes]]:
        """Send a TC3-signed request to ``https://{endpoint}{uri}``.

        Args:
            uri: request path; also used as the canonical URI when signing.
            method: HTTP method.
            params: query parameters (see ``_encodeParams``).
            headers: extra request headers. Must contain ``x-tc-action``
                (matched case-insensitively); every ``x-tc-*`` header is
                included in the signature.
            data: request body (see ``_encodeBody``).
            stream: when true, return the raw body as an iterator of
                1024-byte chunks instead of decoding it.

        Returns:
            The decoded JSON body, or a chunk iterator when *stream* is true.

        Raises:
            ValueError: ``x-tc-action`` is missing (or a caller-supplied
                ``x-tc-timestamp`` is not an integer).
            TypeError: *data* has an unsupported type.
            APIException: the JSON body contains ``Response.Error``.
            Exception: a non-stream body is not valid JSON.
        """
        headers = headers or {}
        paramsStr = self._encodeParams(params)

        lcHeaders = {k.lower(): v for k, v in headers.items()}
        if not lcHeaders.get("x-tc-action"):
            raise ValueError("missing x-tc-action in headers")
        _data = self._encodeBody(data)

        signedHeaders = {
            "host": self.endpoint,
            "x-tc-action": lcHeaders["x-tc-action"],
            "x-tc-timestamp": str(int(datetime.datetime.now(datetime.UTC).timestamp())),
        }
        # Every caller-supplied x-tc-* header is signed; this may override the
        # generated x-tc-timestamp.
        for k, v in lcHeaders.items():
            if k.startswith("x-tc-"):
                signedHeaders[k] = v
        signedHeaders["x-nonce"] = secrets.token_urlsafe(16)
        if _data:
            signedHeaders["content-type"] = lcHeaders.get(
                "content-type", "application/json; charset=utf-8"
            )

        timestamp = signedHeaders["x-tc-timestamp"]
        # The credential-scope date must match the signed timestamp. Derive it
        # from that value instead of reading the clock a second time, which
        # could land on the next UTC day.
        date = datetime.datetime.fromtimestamp(int(timestamp), datetime.UTC).strftime(
            "%Y-%m-%d"
        )

        signedNames = sorted(signedHeaders)  # keys are all lowercase already
        canonicalReq = "\n".join(
            [
                method.upper(),
                uri,
                paramsStr,
                *(f"{k}:{signedHeaders[k].strip()}".lower() for k in signedNames),
                # Empty element: the canonical header block ends with "\n".
                "",
                ";".join(signedNames),
                hashlib.sha256(_data).hexdigest(),
            ]
        )
        logging.debug(canonicalReq)

        scope = f"{date}/{self.service}/tc3_request"
        stringToSign = "\n".join(
            [
                "TC3-HMAC-SHA256",
                timestamp,
                scope,
                hashlib.sha256(canonicalReq.encode()).hexdigest(),
            ]
        )
        logging.debug(stringToSign)
        signedHeaders["authorization"] = (
            f"TC3-HMAC-SHA256 Credential={self.accessKey}/{scope}, "
            f"SignedHeaders={';'.join(signedNames)}, "
            f"Signature={self._sign(date, stringToSign)}"
        )

        # Caller headers are forwarded too; signed ones come last so they take
        # precedence on a clash.
        sendHeaders = {**headers, **signedHeaders}
        logging.debug(
            "%s %s%s\n%s\n%s",
            method,
            f"{self.endpoint}{uri}",
            paramsStr,
            sendHeaders,
            _data or None,
        )
        res = requests.request(
            method=method,
            url=f"https://{self.endpoint}{uri}",
            params=paramsStr,
            headers=sendHeaders,
            data=_data or None,
            stream=stream,
        )
        if stream:
            return res.iter_content(chunk_size=1024)
        return self._decodeResponse(res.text)

    @staticmethod
    def _encodeParams(params: Optional[dict]) -> str:
        """Build the query string that is both signed and sent.

        ``None`` values are dropped, non-string values are JSON-encoded, pairs
        are sorted by (key, value) and percent-encoded except for ``-_.~``.
        (``quote`` emits ``%20`` for spaces, never ``+``.)
        """
        pairs = sorted(
            (k, v if isinstance(v, str) else json.dumps(v))
            for k, v in (params or {}).items()
            if v is not None
        )
        return "&".join(
            f'{urllib.parse.quote(k, safe="-_.~")}={urllib.parse.quote(v, safe="-_.~")}'
            for k, v in pairs
        )

    @staticmethod
    def _encodeBody(data: Optional[Union[bytes, str, dict]]) -> bytes:
        """Serialize a request body: dict as JSON, str as UTF-8, bytes as-is, None as b""."""
        if data is None:
            return b""
        if isinstance(data, dict):
            return json.dumps(data).encode()
        if isinstance(data, str):
            return data.encode()
        if isinstance(data, (bytes, bytearray)):
            return bytes(data)
        raise TypeError(f"unsupported request body type: {type(data).__name__}")

    def _sign(self, date: str, stringToSign: str) -> str:
        """Return the hex TC3-HMAC-SHA256 signature of *stringToSign*.

        The signing key is HMAC-SHA256 chained over *date*, ``self.service``
        and ``tc3_request``, starting from ``"TC3" + self.secretKey``.
        """
        key = ("TC3" + self.secretKey).encode()
        for part in (date.encode(), self.service.encode(), b"tc3_request"):
            key = hmac.new(key, part, digestmod=hashlib.sha256).digest()
        return hmac.new(key, stringToSign.encode(), digestmod=hashlib.sha256).hexdigest()

    @staticmethod
    def _decodeResponse(text: str):
        """Parse a JSON response body; raise APIException if it reports an error."""
        try:
            resJson = json.loads(text)
        except json.decoder.JSONDecodeError as err:
            raise Exception(f'cannot decode "{text}"') from err
        if isinstance(resJson, dict) and (resJson.get("Response") or {}).get("Error"):
            raise APIException(resJson)
        return resJson
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import logging
import os
from enum import Enum
from typing import Iterator, Literal, NotRequired, Optional, TypedDict, Union

from .api import API, APIException
class ContentDict(TypedDict): | |
Type: Literal["text", "image_url"] | |
Text: Optional[str] | |
ImageUrl: Optional[str] | |
class ToolFunction(TypedDict): | |
Name: str | |
Parameters: str # json encoded string | |
Description: Optional[str] | |
class Tool(TypedDict): | |
Type: Literal["function"] | |
Function: ToolFunction | |
class ToolCallFunction(TypedDict): | |
Name: str | |
Arguments: str # json encoded string | |
class ToolCall(TypedDict): | |
Id: str | |
Type: Literal["function"] | |
Function: ToolCallFunction | |
class Message(TypedDict):
    """A chat message: an element of the request's ``Messages`` and a stream ``Delta``.

    Only ``Role`` and ``Content`` are required; the other keys may be omitted
    (or null), so ``{"Role": "user", "Content": "..."}`` is a valid Message.
    """

    Role: str
    Content: str
    # Multi-part (text / image) content.
    Contents: NotRequired[Optional[list[ContentDict]]]
    # Presumably the id of the tool call this message answers — verify.
    ToolCallId: NotRequired[Optional[str]]
    # Presumably the tool invocations requested by the model — verify.
    ToolCalls: NotRequired[Optional[list[ToolCall]]]


class UserMessage(Message):
    """A message whose ``Role`` is fixed to ``"user"``."""

    Role: Literal["user"]
    Content: str
# Alias so the ``Message`` field below refers unambiguously to the Message
# TypedDict rather than to the field of the same name.
_ChatMessage = Message


class ChatCompletionChoice(TypedDict):
    """One choice of a non-streamed completion."""

    # The model's reply. Typed as the general Message: its Role is the
    # model's, so UserMessage (Role == "user") does not fit.
    Message: _ChatMessage
    FinishReason: str  # Literal["", "stop", "token_limit"]


class UsageStatistics(TypedDict):
    """Token counts reported with a completion."""

    PromptTokens: int
    CompletionTokens: int
    TotalTokens: int


class ChatCompletionsResponse(TypedDict):
    """``Response`` object of a non-streamed ChatCompletions call."""

    RequestId: str
    # Fixed AI-generated-content disclaimer included in the response.
    Note: Literal["以上内容为AI生成,不代表开发者立场,请勿删除或修改本标记"]
    Choices: list[ChatCompletionChoice]
    Created: int  # epoch seconds
    Id: str
    Usage: UsageStatistics


class ChatCompletionStreamChoice(TypedDict):
    """One choice of a streamed event; ``Delta`` holds the incremental message."""

    Delta: Message
    FinishReason: str  # Literal["", "stop", "token_limit"]


class ChatCompletionsStreamResponse(TypedDict):
    """One decoded event of a streamed ChatCompletions call."""

    # Fixed AI-generated-content disclaimer included in each event.
    Note: Literal["以上内容为AI生成,不代表开发者立场,请勿删除或修改本标记"]
    Choices: list[ChatCompletionStreamChoice]
    Created: int  # epoch seconds
    Usage: UsageStatistics
class Hunyuan(API):
    """Client for the Tencent Hunyuan chat API.

    Requests go to ``hunyuan.tencentcloudapi.com`` and are signed with the
    credentials read from the ``QCLOUD_ACCESS_KEY`` / ``QCLOUD_SECRET_KEY``
    environment variables at construction time.
    """

    # Action/version headers shared by every ChatCompletions request.
    _CHAT_HEADERS = {
        "x-tc-action": "ChatCompletions",
        "x-tc-version": "2023-09-01",
    }

    def __init__(self, region: Optional[str] = None):
        # NOTE(review): `region` is accepted but currently unused.
        self.accessKey = os.environ.get("QCLOUD_ACCESS_KEY")
        self.secretKey = os.environ.get("QCLOUD_SECRET_KEY")
        self.endpoint = "hunyuan.tencentcloudapi.com"
        self.service = "hunyuan"

    def createChatCompletions(
        self,
        model: Literal[
            "hunyuan-pro",
            "hunyuan-standard",
            "hunyuan-standard-256K",
            "hunyuan-lite",
            "hunyuan-code",
            "hunyuan-role",
            "hunyuan-functioncall",
            "hunyuan-vision",
        ],
        messages: list[Message],
        stream: bool = False,
        streamModeration: bool = False,
        topP: Optional[float] = None,
        temperature: Optional[float] = None,
        enableEnhancement: Optional[bool] = None,
        tools: Optional[list[Tool]] = None,
        toolChoice: Optional[Literal["none", "auto", "custom"]] = None,
        customTool: Optional[Tool] = None,
        searchInfo: Optional[bool] = None,
        citation: Optional[bool] = None,
        enableSpeedSearch: Optional[bool] = None,
    ) -> Union[ChatCompletionsResponse, Iterator[ChatCompletionsStreamResponse]]:
        """Run the ``ChatCompletions`` action.

        Args:
            model: model name, sent as ``Model``.
            messages: conversation, sent as ``Messages``.
            stream: return an iterator over streamed events instead of a
                single response. The HTTP request is only sent once that
                iterator is first advanced.
            streamModeration: sent as ``StreamModeration`` (stream mode only,
                and only when true).
            topP, temperature, enableEnhancement: sent whenever not ``None``.
            tools, toolChoice, customTool, searchInfo, citation,
                enableSpeedSearch: sent only when truthy.

        Returns:
            The ``Response`` object, or (``stream=True``) an iterator of
            decoded stream events.

        Raises:
            APIException: the API reported an error.
        """
        request: dict = {"Model": model, "Messages": messages}
        # Sampling knobs are sent whenever given, even when 0 / False.
        for key, value in (
            ("TopP", topP),
            ("Temperature", temperature),
            ("EnableEnhancement", enableEnhancement),
        ):
            if value is not None:
                request[key] = value
        # Tool settings and feature switches are sent only when truthy.
        for key, value in (
            ("Tools", tools),
            ("ToolChoice", toolChoice),
            ("CustomTool", customTool),
            ("SearchInfo", searchInfo),
            ("Citation", citation),
            ("EnableSpeedSearch", enableSpeedSearch),
        ):
            if value:
                request[key] = value

        if not stream:
            response = self.req(
                method="POST", uri="/", headers=self._CHAT_HEADERS, data=request
            )
            return response["Response"]

        request["Stream"] = True
        if streamModeration:
            request["StreamModeration"] = True

        def generator():
            chunks = self.req(
                method="POST",
                uri="/",
                headers=self._CHAT_HEADERS,
                data=request,
                stream=True,
            )
            yield from self._iterEvents(chunks)

        return generator()

    @staticmethod
    def _iterEvents(chunks: Iterator[bytes]) -> Iterator[dict]:
        """Decode a server-sent-events byte stream into JSON ``data`` payloads.

        Events are separated by a blank line. A final event lacking the
        separator is still decoded when the stream ends. If the leftover
        bytes instead form a JSON object carrying ``Response.Error`` (a plain
        error body rather than an event stream), APIException is raised.
        """
        buffer = b""
        for chunk in chunks:
            buffer += chunk
            # Test the accumulated buffer, not just this chunk: the
            # separator may straddle two chunks.
            while b"\n\n" in buffer:
                eventBytes, buffer = buffer.split(b"\n\n", 1)
                yield from Hunyuan._parseEvent(eventBytes)

        rest = buffer.strip()
        if not rest:
            return
        if rest.startswith(b"{"):
            try:
                body = json.loads(rest)
            except json.JSONDecodeError:
                body = None
            if isinstance(body, dict) and (body.get("Response") or {}).get("Error"):
                raise APIException(body)
        yield from Hunyuan._parseEvent(rest)

    @staticmethod
    def _parseEvent(eventBytes: bytes) -> Iterator[dict]:
        """Yield the JSON payload of each ``data:`` line of one event; log other fields."""
        for line in eventBytes.splitlines():
            if not line:
                continue
            # partition (not split+unpack) tolerates lines without a colon.
            field, _, value = line.partition(b":")
            if field == b"data":
                yield json.loads(value)
            else:
                logging.debug("unhandled event: %s: %s", field, value)
if __name__ == "__main__":
    # Demo: stream one reply to stdout.
    # Enable to log canonical requests, strings-to-sign and HTTP traffic:
    # logging.basicConfig(level=logging.DEBUG)
    # For a single non-streamed reply, call createChatCompletions without
    # stream=True and print the returned Response object.
    hunyuan = Hunyuan()
    stream = hunyuan.createChatCompletions(
        model="hunyuan-pro",
        messages=[{"Role": "user", "Content": "还记得我第一次丢人,就被投诉高空抛物"}],
        temperature=0.9,
        stream=True,
    )
    for event in stream:
        choices = event.get("Choices") or []
        if choices:
            # flush so each delta is shown as soon as it arrives
            print(choices[0]["Delta"].get("Content") or "", end="", flush=True)
    print()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment