ALEPH by @advadnoun but for local execution
#
# ALEPH by Advadnoun, https://colab.research.google.com/drive/1Q-TbYvASMPRMXCOQjkxxf72CXYjR_8Vp
# "This is a notebook that uses DALL-E's decoder and CLIP to generate images from text. I will very likely make this better & easier to use in the future."
#
# rearranged to run locally on a faster GPU
#
# directions:
# clone https://github.com/openai/DALL-E/ and https://github.com/openai/CLIP
# copy the relevant files into one dir with this script (a setup sketch follows below)
# install torch==1.7.1 and the other dependencies
# change the text prompt
# run
#
# (loss -6.38, 4000 iters, lr=0.5)
#
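# a minimal setup sketch (the shell commands are assumptions; only torch==1.7.1
# is specified above, the rest is inferred from the imports below, and the file
# layout may differ between repo revisions):
#
#   git clone https://github.com/openai/DALL-E && git clone https://github.com/openai/CLIP
#   cp -r DALL-E/dall_e .      # the dall_e package
#   cp -r CLIP/clip* .         # the clip module and its tokenizer files
#   pip install torch==1.7.1 torchvision==0.8.2 ftfy regex requests imageio Pillow numpy
#   curl -LO https://cdn.openai.com/dall-e/decoder.pkl   # decoder weights, see load_model() below
#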
import io
import random

import numpy as np
import PIL
import imageio
import requests
import torch
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as TF

import clip
from dall_e import map_pixels, unmap_pixels, load_model

print('Available CLIP models:', clip.available_models())
def save_img(step, img):
    # img arrives as a (C, H, W) float array in [0, 1]; make it (H, W, C) uint8 for imageio
    img = np.transpose(np.array(img), (1, 2, 0))
    img = (np.clip(img, 0., 1.) * 255).astype(np.uint8)
    imageio.imwrite(str(step) + '.png', img)
target_image_size = 256  # per the DALL-E usage example; only needed by preprocess() below

def preprocess(img):
    # DALL-E's reference preprocessing for feeding a real image to the encoder;
    # not used in the optimization loop below
    s = min(img.size)
    if s < target_image_size:
        raise ValueError(f'min dim for image {s} < {target_image_size}')
    r = target_image_size / s
    s = (round(r * img.size[1]), round(r * img.size[0]))
    img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    return map_pixels(img)
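# preprocess() is only needed if you want to start from a real image instead of
# random latents. A sketch of that path, assuming the encoder weights have been
# downloaded from https://cdn.openai.com/dall-e/encoder.pkl (not run here):
#
#   enc = load_model("encoder.pkl", 'cuda')
#   x = preprocess(PIL.Image.open('init.png').convert('RGB')).cuda()
#   z_logits = enc(x)   # codebook logits, (1, 8192, target_image_size/8, target_image_size/8)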
class Pars(torch.nn.Module):
    # the trainable latent: 8192-way DALL-E codebook logits on a 64x64 grid
    def __init__(self):
        super(Pars, self).__init__()
        self.normu = torch.nn.Parameter(torch.randn(1, 8192, 64, 64).cuda())

    def forward(self):
        # Gumbel-softmax relaxation of the logits; tau is the temperature (default 1.0)
        return torch.nn.functional.gumbel_softmax(
            self.normu.view(1, 8192, -1), dim=-1, tau=1.4).view(1, 8192, 64, 64)
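# shape sanity check, a sketch (the DALL-E decoder emits 6 channels per pixel,
# parameterizing a logit-Laplace distribution; the first 3 are taken as the RGB
# means below, hence the [:, :3] slices):
#
#   z = Pars().cuda()()     # (1, 8192, 64, 64)
#   x_stats = model(z)      # (1, 6, 512, 512), once model is loaded below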
def checkin(step, loss):
    print('Step', step, 'loss', loss)
    with torch.no_grad():
        # decode the current latents, keeping the 3 RGB mean channels
        al = unmap_pixels(torch.sigmoid(model(lats())[:, :3]).cpu().float()).numpy()
        for img in al:
            save_img(step, img)
def ascend_txt():
    out = unmap_pixels(torch.sigmoid(model(lats())[:, :3].float()))
    cutn = 128  # number of random crops scored per step; improves quality (was 64)
    p_s = []
    for ch in range(cutn):
        # random square crop covering roughly 50-98% of the image side
        size = int(sideX * torch.zeros(1,).normal_(mean=.8, std=.3).clip(.5, .98))
        offsetx = torch.randint(0, sideX - size, ())
        offsety = torch.randint(0, sideY - size, ())
        apper = out[:, :, offsety:offsety + size, offsetx:offsetx + size]
        # resize each crop to CLIP's 224x224 input resolution
        p_s.append(torch.nn.functional.interpolate(apper, (224, 224), mode='bilinear'))
    into = nom(torch.cat(p_s, 0))  # normalize with CLIP's expected image statistics
    iii = perceptor.encode_image(into)
    lat_l = 0  # placeholder for a latent regularizer (unused)
    # minimize the (scaled, negated) mean cosine similarity between text and crops
    return [lat_l, 10 * -torch.cosine_similarity(t, iii).view(-1, 1).T.mean(1)]
def train(i):
    lat_l, sim_l = ascend_txt()
    loss = (lat_l + sim_l).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        checkin(i, loss.item())
#
# Begin
#
text = "Moscow never sleeps"
print('Text:', text)
print('Loading models and stuff')
lats = Pars().cuda()
mapper = [lats.normu]
optimizer = torch.optim.Adam([{'params': mapper, 'lr': .075}])  # was .1
model = load_model("decoder.pkl", 'cuda')
#model = load_model("https://cdn.openai.com/dall-e/decoder.pkl", 'cuda')
print('Generator loaded')
# note: clip's own preprocess is unused here, so bind it to a separate name
perceptor, clip_preprocess = clip.load('ViT-B/32', jit=True)
perceptor = perceptor.eval()
im_shape = [512, 512, 3]
sideX, sideY, channels = im_shape
tx = clip.tokenize(text)
t = perceptor.encode_text(tx.cuda()).detach().clone()
print('Perceptor loaded')
# CLIP's image normalization statistics
nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
print("Starting")
for itt in range(10000):
    train(itt)
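# optional: stitch the periodic checkin frames (0.png, 100.png, ...) into an
# animation using the already-imported imageio; a sketch that assumes the full
# 10000-iteration run above completed and left every frame on disk
frames = [imageio.imread(f'{s}.png') for s in range(0, 10000, 100)]
imageio.mimsave('progress.gif', frames, fps=10)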