SPLADE on Modal

A Modal web endpoint that generates SPLADE sparse vectors, for use as the sparse half of Pinecone-style hybrid search.
from fastapi.responses import JSONResponse
from modal import Image, NetworkFileSystem, Secret, Stub, web_endpoint
from pydantic import BaseModel
# This is copied from: https://github.com/pinecone-io/examples/blob/2f51ddfd12a08f2963cc2849661fab51afdeedc6/learn/search/semantic-search/sparse/splade/splade-vector-generation.ipynb
# which is recommended here: https://docs.pinecone.io/docs/hybrid-search
stub = Stub("splade")
image = Image.debian_slim().pip_install("torch", "transformers")

# Persistent network file system so the model weights are downloaded once and
# shared across containers.
volume = NetworkFileSystem.persisted("splade-model-cache-vol-gcp", cloud="gcp")
CACHE_DIR = "/cache"
class Body(BaseModel):
    text: str
@stub.cls(
    image=image,
    cloud="gcp",
    cpu=2,
    memory=2048,
    keep_warm=40,  # keep 40 containers warm to avoid cold starts
    container_idle_timeout=120,
    network_file_systems={CACHE_DIR: volume},
    # point the torch/transformers caches at the shared volume
    secret=Secret.from_dict({"TORCH_HOME": CACHE_DIR, "TRANSFORMERS_CACHE": CACHE_DIR}),
)
class SPLADE:
    def __enter__(self):
        import torch
        from transformers import AutoModelForMaskedLM, AutoTokenizer

        model = "naver/splade-cocondenser-ensembledistil"

        # check device (this app requests CPU-only containers, so on Modal
        # this will be "cpu")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModelForMaskedLM.from_pretrained(model)

        # move to gpu if available
        self.model.to(self.device)
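    # SPLADE turns the MLM logits into a sparse lexical vector over the BERT
    # vocabulary (~30k terms). For input token i and vocabulary term j:
    #
    #     w_j = max_i log(1 + ReLU(logit_ij))
    #
    # ReLU keeps only positive evidence, log1p dampens large activations, and
    # max-pooling over input tokens yields one weight per vocabulary term.
    # Terms with weight 0 are dropped, which is what makes the vector sparse.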
    @web_endpoint(method="POST")
    def vector(self, body: Body):
        import torch
        from transformers.tokenization_utils_base import TruncationStrategy

        text = body.text
        max_length = self.tokenizer.model_max_length
        inputs = self.tokenizer(
            text,
            truncation=TruncationStrategy.LONGEST_FIRST,
            max_length=max_length,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits

        inter = torch.log1p(torch.relu(logits[0]))
        token_max = torch.max(inter, dim=0)  # max-pool over input tokens
        nz_tokens = torch.where(token_max.values > 0)[0]
        nz_weights = token_max.values[nz_tokens]

        # sort terms by weight, descending
        order = torch.sort(nz_weights, descending=True)
        nz_weights = nz_weights[order.indices]
        nz_tokens = nz_tokens[order.indices]

        response = {
            "indices": nz_tokens.cpu().numpy().tolist(),
            "values": nz_weights.cpu().numpy().tolist(),
        }
        return JSONResponse(content=response)
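
Once the app is deployed (modal deploy splade.py), Modal prints the URL of the web endpoint. A minimal client sketch, assuming the placeholder URL below stands in for the real deployment URL:

    import requests

    # Hypothetical URL; Modal prints the actual endpoint URL on deploy.
    SPLADE_URL = "https://example--splade-splade-vector.modal.run"

    resp = requests.post(SPLADE_URL, json={"text": "programming books about distributed systems"})
    resp.raise_for_status()
    sparse = resp.json()  # {"indices": [...], "values": [...]}

The indices/values pair is the sparse_values format Pinecone's hybrid search expects. A sketch of an upsert with the v2 pinecone-client, reusing the sparse response from above and assuming a pre-existing dotproduct index plus a dense embedding from some other model:

    import pinecone

    pinecone.init(api_key="YOUR_API_KEY", environment="us-west1-gcp")  # hypothetical credentials
    index = pinecone.Index("hybrid")  # assumed to exist; hybrid indexes need the dotproduct metric

    dense_embedding = [0.1] * 1536  # placeholder; use a real dense embedding model in practice

    index.upsert(vectors=[{
        "id": "doc-0",
        "values": dense_embedding,
        "sparse_values": {"indices": sparse["indices"], "values": sparse["values"]},
    }])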