#!/usr/bin/env python3
"""Generates images from saved embeddings with CLIP."""

import argparse
from concurrent import futures
import sys

import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms, utils
from torchvision.transforms import functional as TF
from tqdm import tqdm, trange

from CLIP import clip


def setup_exceptions():
    try:
        from IPython.core.ultratb import FormattedTB
        sys.excepthook = FormattedTB(mode='Plain', color_scheme='Neutral')
    except ImportError:
        pass


class TVLoss(nn.Module):
    """L2 total variation loss, as in Mahendran et al."""

    def forward(self, input):
        input = F.pad(input, (0, 1, 0, 1), 'replicate')
        x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
        y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
        return (x_diff**2 + y_diff**2).mean()


class Prompt(nn.Module):
    def __init__(self, embed, weight=100.):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        return self.weight * dists.mean()
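
# Note on Prompt: for unit vectors u and v, ||u - v|| = 2 * sin(theta / 2), where
# theta is the angle between them, so dists above equals theta**2 / 2: half the
# squared great-circle (spherical) distance between the two embeddings.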


def main():
    setup_exceptions()

    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('prompt', type=str,
                   help='the file containing the saved embedding to maximize')
    p.add_argument('--clip-model', type=str, default='ViT-B/32', choices=clip.available_models(),
                   help='the CLIP model to use')
    p.add_argument('--tv-weight', '-tw', type=float, default=100.,
                   help='the smoothing prior weight')
    p.add_argument('--step-size', '-ss', type=float, default=0.05,
                   help='the step size')
    p.add_argument('--betas', type=float, default=[0.9, 0.999], nargs=2,
                   metavar=('BETA_1', 'BETA_2'), help='the Adam beta parameters')
    p.add_argument('--weight-decay', '-wd', type=float, default=0.1,
                   help='the weight decay')
    p.add_argument('--iterations', '-i', type=int, default=600,
                   help='the number of iterations')
    p.add_argument('--display-freq', type=int, default=25,
                   help='display every this many steps')
    p.add_argument('--seed', type=int, default=0,
                   help='the random seed')
    args = p.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    perceptor = clip.load(args.clip_model, jit=False)[0].to(device).eval().requires_grad_(False)

    # Background thread pool so image saving does not block the optimization loop.
    pool = futures.ThreadPoolExecutor()
    # Small random affine jitter: the loss is averaged over slightly shifted,
    # rotated, and rescaled views of the batch.
    aug = transforms.RandomAffine(5, (1/20, 1/20), (0.95, 1.05),
                                  interpolation=transforms.InterpolationMode.BILINEAR)
    tv_loss = TVLoss()

    batch_size = 16
    sideX = sideY = perceptor.visual.input_resolution

    torch.manual_seed(args.seed)

    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    pMs = []
    # The saved embedding is read as one float per line; [None] adds a batch axis.
    embed = torch.tensor([float(x) for x in open(args.prompt, 'r').readlines()])[None]
    pMs.append(Prompt(embed).to(device))

    # Optimize pixels directly, parameterized in [-1, 1] and later mapped to [0, 1].
    image = torch.randn([batch_size, 3, sideY, sideX], device=device) / 5
    image.clamp_(-1, 1).requires_grad_()

    opt = optim.AdamW([image], lr=args.step_size, betas=args.betas,
                      weight_decay=args.weight_decay)
    # opt = optim.SGD([image], lr=args.step_size, momentum=args.betas[0])
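
    # Note: AdamW's decoupled weight decay shrinks the raw pixel tensor toward 0
    # (mid-gray after the [-1, 1] -> [0, 1] mapping), acting as an extra
    # regularizer alongside the TV loss.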

    def save_image(t, name):
        TF.to_pil_image(t).save(name)

    # Log the current losses and save an image grid of the batch.
    @torch.no_grad()
    def checkin(i, losses):
        losses_str = ' '.join(f'{loss.item():g}' for loss in losses)
        tqdm.write(f'{i} {sum(losses).item():g} {losses_str}')
        grid = utils.make_grid(image.add(1).div(2), 4)
        pool.submit(save_image, grid.cpu(), f'out_{i:05}.png')

    def ascend_txt(i):
        # Map the image from [-1, 1] to [0, 1], augment, normalize, and encode.
        iii = perceptor.encode_image(normalize(aug(image.add(1).div(2))))
        result = []
        if args.tv_weight:
            result.append(tv_loss(image) * args.tv_weight / 4)
        for prompt in pMs:
            result.append(prompt(iii))
        return result

    def train(i):
        lossAll = ascend_txt(i)
        if i % args.display_freq == 0:
            checkin(i, lossAll)
        loss = sum(lossAll)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Project the image back onto the valid pixel range after each step.
        with torch.no_grad():
            image.clamp_(-1, 1)

    try:
        for i in trange(args.iterations + 1):
            train(i)
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()
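
# A hypothetical companion snippet (an assumption, not part of the original
# script): one way to save a CLIP text embedding in the format read above,
# one float per line. The path 'embedding.txt' is illustrative:
#
#     import torch
#     from CLIP import clip
#
#     model = clip.load('ViT-B/32', device='cpu', jit=False)[0].eval()
#     with torch.no_grad():
#         embed = model.encode_text(clip.tokenize('a photo of a dog'))[0]
#     with open('embedding.txt', 'w') as f:
#         f.writelines(f'{x}\n' for x in embed.tolist())
#
# The script can then be run as: python3 <this script> embedding.txt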