Skip to content

Instantly share code, notes, and snippets.

@enjalot
Created February 21, 2023 17:31
Show Gist options
  • Save enjalot/2228952d82e54f68e45e307258d82dc6 to your computer and use it in GitHub Desktop.
Save enjalot/2228952d82e54f68e45e307258d82dc6 to your computer and use it in GitHub Desktop.
Get Stable Diffusion-friendly CLIP text embeddings for SD 1.x (OpenAI CLIP) and SD 2.x (OpenCLIP).
print("importing")
from flask import Flask, request
from flask_cors import CORS
import json
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import transformers
transformers.logging.set_verbosity_error()
import open_clip
from transformers import CLIPTokenizerFast, CLIPTokenizer, CLIPTextModel
# Module-level device handle. Within this file the encoder classes take their
# own ``device`` parameter (which shadows this name) rather than reading it.
mps_device_type = "mps"
device = torch.device(mps_device_type)
del mps_device_type
print("defining")
class AbstractEncoder(nn.Module):
    """Minimal ``nn.Module`` base for the text encoders in this file.

    Subclasses are expected to override :meth:`encode`; the base version
    only raises ``NotImplementedError``.
    """

    def __init__(self):
        # No state of its own: just run nn.Module's initializer.
        super().__init__()

    def encode(self, *args, **kwargs):
        """Abstract hook; concrete encoders must provide an implementation."""
        raise NotImplementedError
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface).

    Default weights are ``openai/clip-vit-large-patch14``. Calling the module
    (``forward``/``encode``) returns ``[embed, z, tokens]``, all on CPU:

    * ``embed``  -- ``pooler_output`` with a singleton axis inserted at dim 1
      (``pooler_output[:, None, :]``).
    * ``z``      -- per-token features: ``last_hidden_state``, or
      ``hidden_states[layer_idx]`` when ``layer == "hidden"``.
    * ``tokens`` -- the ``input_ids`` produced by the tokenizer, padded to
      ``max_length``.
    """

    LAYERS = [
        "last",
        "pooled",
        "hidden",
    ]

    def __init__(self, version="openai/clip-vit-large-patch14", device="mps", max_length=77,
                 freeze=True, layer="pooled", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.transformer.to(device)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            # NOTE(review): the bound of 12 presumably matches a 12-layer text
            # tower; confirm it for non-default ``version`` values.
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        """Switch the transformer to eval mode and stop gradients on every parameter."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and run the CLIP text model.

        Args:
            text: a prompt string or list of prompt strings.

        Returns:
            ``[embed, z, tokens]`` as described in the class docstring.
        """
        batch_encoding = self.tokenizer(text,
                                        truncation=True,
                                        max_length=self.max_length,
                                        return_length=True,
                                        return_overflowing_tokens=False,
                                        padding="max_length",
                                        return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        # Hidden states are only materialized when the "hidden" layer is requested.
        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
        if self.layer == "hidden":
            # Previously ``layer_idx`` was validated but ignored; honor it here.
            z = outputs.hidden_states[self.layer_idx]
        else:
            z = outputs.last_hidden_state
        embed = outputs.pooler_output[:, None, :]
        # Return everything on CPU so callers never mix devices.
        return [embed.to("cpu"), z.to("cpu"), tokens.to("cpu")]

    def encode(self, text):
        """Alias for calling the module: returns ``[embed, z, tokens]``."""
        return self(text)
class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text.

    Default weights are ``ViT-H-14`` / ``laion2b_s32b_b79k``. The image tower
    (``model.visual``) is deleted after loading because only text features are
    used.

    ``layer`` picks how many residual blocks run: ``"last"`` runs all of them,
    ``"penultimate"`` stops one block early. ``ln_final`` is applied in both
    cases.

    ``max_length`` is stored but not passed to ``open_clip.tokenize`` here.
    """

    LAYERS = [
        # "pooled",
        "last",
        "penultimate",
    ]

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cpu", max_length=77,
                 freeze=True, layer="last"):
        super().__init__()
        assert layer in self.LAYERS
        # Load on CPU, drop the image tower, then move the text model to the
        # requested device. Previously it stayed on CPU while tokens were sent
        # to ``device``, which broke any non-CPU device.
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
        del model.visual
        self.model = model.to(device)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        """Switch the model to eval mode and stop gradients on every parameter."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return ``[z, tokens]``.

        ``z`` is the per-token output of :meth:`encode_with_transformer` on
        ``self.device``. ``tokens`` is the CPU tensor from ``open_clip.tokenize``.
        """
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return [z, tokens]

    def full_encode(self, text):
        """Like :meth:`forward`, but also returns the projected text embedding.

        Returns:
            ``[embed, z, tokens]``. ``embed`` is the feature at each sequence's
            end-of-text position multiplied by ``model.text_projection``. With
            ``layer="penultimate"`` it is therefore computed from
            penultimate-layer features.
        """
        z, tokens = self(text)
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # Index tensors must live on the same device as ``z``.
        eot = tokens.argmax(dim=-1).to(z.device)
        rows = torch.arange(z.shape[0], device=z.device)
        embed = z[rows, eot] @ self.model.text_projection
        return [embed, z, tokens]

    def encode_with_transformer(self, text):
        """Embed token ids and run them through the (possibly truncated) text transformer."""
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run the residual blocks, skipping the last ``self.layer_idx`` of them."""
        # ``checkpoint`` was previously referenced without ever being imported,
        # so enabling grad checkpointing raised NameError.
        from torch.utils.checkpoint import checkpoint

        resblocks = self.model.transformer.resblocks
        stop = len(resblocks) - self.layer_idx
        for i, r in enumerate(resblocks):
            if i == stop:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        """Alias for calling the module: returns ``[z, tokens]``."""
        return self(text)
# ===============================================
# ===============================================
# Server code
# ===============================================
# ===============================================
# Both encoders are built once, at import time, and shared by every request.
print("loading old clip embedder")
fce = FrozenCLIPEmbedder()  # SD 1.x style encoder, default arguments
print("loading open clip embedder")
foce = FrozenOpenCLIPEmbedder()  # SD 2.x style encoder, default arguments

# Flask application wrapped with flask_cors using its default settings.
app = Flask(__name__)
cors = CORS(app)
@app.route('/api/openclip', methods=['GET'])
def get_openclip():
    """Encode the ``prompt`` query parameter with the OpenCLIP encoder.

    Responds with a JSON object ``{"embed", "z", "tokens"}``. ``z`` and
    ``tokens`` are cut off just after the end-of-text token. Responds 400
    when ``prompt`` is absent; an empty prompt is still encoded.
    """
    # Get the text prompt from the query string
    prompt = request.args.get('prompt')
    if prompt is None:
        # Previously None reached the tokenizer and surfaced as a 500.
        return app.response_class(
            json.dumps({"error": "missing 'prompt' query parameter"}),
            status=400,
            mimetype="application/json",
        )
    [embed, z, tokens] = foce.full_encode(prompt)
    # End-of-text is the highest id in the sequence; include it in the slice.
    end = int(tokens[0].argmax(dim=-1)) + 1
    # Return the numbers as a JSON response
    return app.response_class(
        json.dumps({
            "embed": embed[0].tolist(),
            "z": z[0][0:end].tolist(),
            "tokens": tokens[0][0:end].tolist(),
        }),
        mimetype="application/json",
    )
@app.route('/api/oldclip', methods=['GET'])
def get_oldclip():
    """Encode the ``prompt`` query parameter with the Hugging Face CLIP encoder.

    Responds with a JSON object ``{"embed", "z", "tokens"}``. ``embed`` is the
    pooled vector. ``z`` and ``tokens`` are cut at the first occurrence of the
    largest token id. Responds 400 when ``prompt`` is absent; an empty prompt
    is still encoded.
    """
    # Get the text prompt from the query string
    prompt = request.args.get('prompt')
    if prompt is None:
        # Previously None reached the tokenizer and surfaced as a 500.
        return app.response_class(
            json.dumps({"error": "missing 'prompt' query parameter"}),
            status=400,
            mimetype="application/json",
        )
    # Call the module rather than ``.forward`` so nn.Module hooks still run.
    [embed, z, tokens] = fce(prompt)
    # NOTE(review): assumes the end-of-text id is the largest id in the padded
    # sequence (as on the OpenCLIP route); argmax returns its first occurrence.
    end = int(tokens[0].argmax(dim=-1)) + 1
    # Return the numbers as a JSON response
    return app.response_class(
        json.dumps({
            "embed": embed[0][0].tolist(),
            "z": z[0][0:end].tolist(),
            "tokens": tokens[0][0:end].tolist(),
        }),
        mimetype="application/json",
    )
if __name__ == '__main__':
    # Script entry point: start Flask's built-in server with default host/port.
    app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment