Get Stable Diffusion-friendly CLIP text embeddings for SD 1.x (OpenAI CLIP) and SD 2.x (OpenCLIP).
print("importing") | |
from flask import Flask, request | |
from flask_cors import CORS | |
import json | |
import torch | |
import torch.nn as nn | |
import pandas as pd | |
import numpy as np | |
import transformers | |
transformers.logging.set_verbosity_error() | |
import open_clip | |
from transformers import CLIPTokenizerFast, CLIPTokenizer, CLIPTextModel | |
device = torch.device("mps") | |
print("defining") | |
class AbstractEncoder(nn.Module): | |
def __init__(self): | |
super().__init__() | |
def encode(self, *args, **kwargs): | |
raise NotImplementedError | |
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]

    def __init__(self, version="openai/clip-vit-large-patch14", device="mps", max_length=77,
                 freeze=True, layer="pooled", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.transformer.to(device)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text,
                                        truncation=True,
                                        max_length=self.max_length,
                                        return_length=True,
                                        return_overflowing_tokens=False,
                                        padding="max_length",
                                        return_tensors="pt")
        print("tokens", batch_encoding)
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
        print("after outputs")
        z = outputs.last_hidden_state.to("cpu")
        embed = outputs.pooler_output[:, None, :].to("cpu")
        return [embed, z, tokens]

    def encode(self, text):
        return self(text)
class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    LAYERS = [
        # "pooled",
        "last",
        "penultimate"
    ]

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cpu", max_length=77,
                 freeze=True, layer="last"):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
        del model.visual
        self.model = model
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return [z, tokens]

    def full_encode(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        embed = z[torch.arange(z.shape[0]), tokens.argmax(dim=-1)] @ self.model.text_projection
        return [embed, z, tokens]

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)
# ===============================================
# ===============================================
# Server code
# ===============================================
# ===============================================
print("loading old clip embedder")
fce = FrozenCLIPEmbedder()
print("loading open clip embedder")
foce = FrozenOpenCLIPEmbedder()

app = Flask(__name__)
cors = CORS(app)

@app.route('/api/openclip', methods=['GET'])
def get_openclip():
    # Get the text prompt from the query string
    prompt = request.args.get('prompt')
    [embed, z, tokens] = foce.full_encode(prompt)
    end = tokens[0].argmax(dim=-1) + 1
    # Return the numbers as a JSON response
    return json.dumps({
        "embed": embed[0].tolist(),
        "z": z[0][0:end].tolist(),
        "tokens": tokens[0][0:end].tolist()
    })

@app.route('/api/oldclip', methods=['GET'])
def get_oldclip():
    # Get the text prompt from the query string
    prompt = request.args.get('prompt')
    [embed, z, tokens] = fce.forward(prompt)
    end = tokens[0].argmax(dim=-1) + 1
    # Return the numbers as a JSON response
    return json.dumps({
        "embed": embed[0][0].tolist(),
        "z": z[0][0:end].tolist(),
        "tokens": tokens[0][0:end].tolist()
    })

if __name__ == '__main__':
    app.run()
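
With the server running, a small client script is a quick way to sanity-check both endpoints. This is a minimal sketch, not part of the gist: it assumes Flask's default host and port (127.0.0.1:5000) and that the requests package is installed; the prompt string is just an example.

import requests

prompt = "a photograph of an astronaut riding a horse"
for endpoint in ("oldclip", "openclip"):
    resp = requests.get(f"http://127.0.0.1:5000/api/{endpoint}",
                        params={"prompt": prompt})
    data = resp.json()
    # "embed" is the pooled prompt embedding, "z" the per-token hidden states
    # (truncated at the end-of-text token), and "tokens" the token ids.
    print(endpoint,
          "embed dim:", len(data["embed"]),
          "tokens:", len(data["tokens"]),
          "z rows:", len(data["z"]))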