Last active
April 22, 2025 14:03
-
-
Save sobamchan/745a701ca997b4bcdb1968cc3f857856 to your computer and use it in GitHub Desktop.
asyncio api encoder
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import json
from dataclasses import dataclass

from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio
@dataclass | |
class AsyncEncoder: | |
client: AsyncOpenAI | |
model_name: str | |
semaphore: asyncio.Semaphore | None | |
cache_path: str | None = None | |
db: Cache | None = None | |
concurrency: int = 30 | |
tokenizer: AutoTokenizer | None = None | |
@classmethod | |
def init( | |
cls, | |
base_url: str, | |
api_key: str, | |
model_name: str, | |
concurrency: int = 30, | |
cache_path: str | None = None, | |
) -> "AsyncEncoder": | |
db = Cache(cache_path) if cache_path is not None else None | |
client = AsyncOpenAI(base_url=base_url, api_key=api_key) | |
return cls( | |
client, | |
model_name, | |
None, | |
cache_path, | |
db, | |
concurrency=concurrency, | |
tokenizer=AutoTokenizer.from_pretrained(model_name), | |
) | |
def _check_cache(self, text: str) -> bool | list[float]: | |
if self.db: | |
if text in self.db: | |
return json.loads(self.db[text]) | |
else: | |
return False | |
else: | |
return False | |
def _save_cache(self, text: str, emb: list[float]) -> None: | |
if self.db is not None: | |
self.db[text] = json.dumps(emb) | |
async def task(self, text: str) -> list[float]: | |
assert self.semaphore | |
async with self.semaphore: | |
false_or_cache = self._check_cache(text) | |
if not false_or_cache: | |
# No cache. Compute embedding | |
emb = await self.call_api(text) | |
self._save_cache(text, emb) | |
return emb | |
else: | |
# Reuse data from cache | |
assert isinstance(false_or_cache, list) | |
return false_or_cache | |
async def call_api(self, text: str) -> list[float]: | |
text = self.tokenizer.decode( | |
self.tokenizer.encode(text, truncation=True), skip_special_tokens=True | |
) | |
res = await self.client.embeddings.create( | |
model=self.model_name, | |
input=text, | |
encoding_format="float", | |
) | |
return res.data[0].embedding | |
async def _encode(self, texts: list[str]) -> list[list[float]]: | |
tasks = [self.task(text) for text in texts] | |
results = await tqdm_asyncio.gather(*tasks) | |
return list(results) | |
def encode(self, texts: list[str]) -> list[list[float]]: | |
self.semaphore = asyncio.Semaphore(self.concurrency) | |
out = asyncio.run(self._encode(texts)) | |
if self.db: | |
self.db.close() | |
self.semaphore.release() | |
return out | |
if __name__ == "__main__":
    # Demo run: embed the same sentence 1000 times.  Replace the three
    # placeholders with a real endpoint URL, API key and model name.
    texts = ["hello world"] * 1000
    encoder = AsyncEncoder.init("<URL>", "<API-KEY>", "<MODEL-NAME>")
    out = encoder.encode(texts)  # -> list of embeddings (list[list[float]])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment