Skip to content

Instantly share code, notes, and snippets.

@sobamchan
Last active April 22, 2025 14:03
Show Gist options
  • Save sobamchan/745a701ca997b4bcdb1968cc3f857856 to your computer and use it in GitHub Desktop.
Save sobamchan/745a701ca997b4bcdb1968cc3f857856 to your computer and use it in GitHub Desktop.
asyncio api encoder
import asyncio
import json
from dataclasses import dataclass

from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio

# NOTE(review): `Cache` and `AutoTokenizer` are used below but never imported
# in this file — presumably `diskcache.Cache` and `transformers.AutoTokenizer`;
# confirm and add the matching imports.
@dataclass
class AsyncEncoder:
    """Concurrent text-embedding client with an optional cache.

    Each text is sent to the embeddings endpoint of ``client``; a semaphore
    sized by ``concurrency`` bounds how many texts are processed at once.
    When ``db`` is set, embeddings are stored JSON-encoded under the raw
    input text and reused on later requests for the same text.

    Use :meth:`init` to build an instance from connection settings.
    """

    # Async client whose ``embeddings.create`` endpoint is called.
    client: AsyncOpenAI
    # Model identifier sent with every request; ``init`` also loads the
    # tokenizer from this name.
    model_name: str
    # Bounds the number of in-flight tasks; replaced on every ``encode`` call.
    semaphore: asyncio.Semaphore | None
    # Location handed to ``Cache`` by ``init``; ``None`` means no cache.
    cache_path: str | None = None
    # Mapping-like store: input text -> JSON-encoded embedding.
    db: Cache | None = None
    # Size of the semaphore created for each run.
    concurrency: int = 30
    # Used to truncate texts before sending them; ``None`` skips truncation.
    tokenizer: AutoTokenizer | None = None

    @classmethod
    def init(
        cls,
        base_url: str,
        api_key: str,
        model_name: str,
        concurrency: int = 30,
        cache_path: str | None = None,
    ) -> "AsyncEncoder":
        """Build an encoder from connection settings.

        Args:
            base_url: Base URL passed to ``AsyncOpenAI``.
            api_key: API key passed to ``AsyncOpenAI``.
            model_name: Model used for embedding requests and for loading the
                tokenizer via ``AutoTokenizer.from_pretrained``.
            concurrency: Maximum number of texts processed at the same time.
            cache_path: Where to keep the cache; ``None`` disables caching.

        Returns:
            A ready-to-use encoder. The semaphore is left unset and is
            created when encoding starts.
        """
        db = Cache(cache_path) if cache_path is not None else None
        client = AsyncOpenAI(base_url=base_url, api_key=api_key)
        return cls(
            client,
            model_name,
            None,
            cache_path,
            db,
            concurrency=concurrency,
            tokenizer=AutoTokenizer.from_pretrained(model_name),
        )

    def _check_cache(self, text: str) -> bool | list[float]:
        """Return the cached embedding for *text*, or ``False`` on a miss.

        ``False`` is also returned when no cache is configured.
        """
        # Compare against None explicitly: the truthiness of a cache object
        # may depend on how many entries it holds.
        if self.db is not None and text in self.db:
            return json.loads(self.db[text])
        return False

    def _save_cache(self, text: str, emb: list[float]) -> None:
        """Store *emb* JSON-encoded under *text*; no-op without a cache."""
        if self.db is not None:
            self.db[text] = json.dumps(emb)

    async def task(self, text: str) -> list[float]:
        """Return the embedding for *text*, consulting the cache first.

        The whole lookup/compute/store sequence runs while holding the
        semaphore, so at most ``concurrency`` texts are handled at once.
        """
        if self.semaphore is None:
            # Allow direct use without going through ``encode``.
            self.semaphore = asyncio.Semaphore(self.concurrency)
        async with self.semaphore:
            cached = self._check_cache(text)
            # Test the type rather than truthiness so that a cached empty
            # embedding still counts as a hit.
            if isinstance(cached, list):
                return cached
            # Cache miss: compute the embedding and remember it.
            emb = await self.call_api(text)
            self._save_cache(text, emb)
            return emb

    async def call_api(self, text: str) -> list[float]:
        """Request the embedding of *text* from the API.

        When a tokenizer is available the text is first encoded with
        ``truncation=True`` and decoded again without special tokens, so an
        over-long input is shortened before it is sent (presumably to stay
        within the model's input limit).
        """
        if self.tokenizer is not None:
            token_ids = self.tokenizer.encode(text, truncation=True)
            text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
        res = await self.client.embeddings.create(
            model=self.model_name,
            input=text,
            encoding_format="float",
        )
        return res.data[0].embedding

    async def _encode(self, texts: list[str]) -> list[list[float]]:
        """Run one :meth:`task` per text with a progress bar; return a list."""
        results = await tqdm_asyncio.gather(*(self.task(text) for text in texts))
        return list(results)

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* synchronously and return one embedding per text.

        Starts its own event loop via ``asyncio.run``, so it must not be
        called from code that is already running inside an event loop.

        Args:
            texts: Input strings to embed.

        Returns:
            The gathered embeddings as a list.
        """
        # A fresh semaphore per call: each ``asyncio.run`` uses a new loop.
        self.semaphore = asyncio.Semaphore(self.concurrency)
        try:
            return asyncio.run(self._encode(texts))
        finally:
            # Close the cache even when a request failed. The semaphore needs
            # no manual release: ``async with`` in ``task`` already pairs
            # every acquire with a release.
            if self.db is not None:
                self.db.close()
if __name__ == "__main__":
    # Usage example: replace the placeholders with a real endpoint, key and
    # model before running.
    encoder = AsyncEncoder.init(
        base_url="<URL>",
        api_key="<API-KEY>",
        model_name="<MODEL-NAME>",
    )
    # 1000 copies of the same sentence as a simple load sample.
    texts = ["hello world"] * 1000
    # Result is a list of embeddings: list[list[float]].
    out = encoder.encode(texts)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment