#!/usr/bin/env python3

"""Generates images from saved embeddings with CLIP."""

import argparse
from concurrent import futures
import sys

import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms, utils
from torchvision.transforms import functional as TF
from tqdm import tqdm, trange

from CLIP import clip


def setup_exceptions():
    try:
        from IPython.core.ultratb import FormattedTB
        sys.excepthook = FormattedTB(mode='Plain', color_scheme='Neutral')
    except ImportError:
        pass


class TVLoss(nn.Module):
    """L2 total variation loss, as in Mahendran et al."""

    def forward(self, input):
        # Replicate-pad the right and bottom edges so border pixels contribute
        # zero difference terms, then take forward differences in x and y.
        input = F.pad(input, (0, 1, 0, 1), 'replicate')
        x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
        y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
        return (x_diff**2 + y_diff**2).mean()


class Prompt(nn.Module):
    """Penalizes the batch's distance from a target embedding on the unit sphere."""

    def __init__(self, embed, weight=100.):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        # Half the squared great-circle distance between the normalized
        # embeddings: for unit vectors, the chord |u - v| equals 2 sin(theta / 2).
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        return self.weight * dists.mean()


def main():
    setup_exceptions()

    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('prompt', type=str,
                   help='the saved embedding (a text file, one float per line) to maximize')
    p.add_argument('--clip-model', type=str, default='ViT-B/32', choices=clip.available_models(),
                   help='the CLIP model to use')
    p.add_argument('--tv-weight', '-tw', type=float, default=100.,
                   help='the smoothing prior weight')
    p.add_argument('--step-size', '-ss', type=float, default=0.05,
                   help='the step size')
    p.add_argument('--betas', type=float, default=[0.9, 0.999], nargs=2,
                   metavar=('BETA_1', 'BETA_2'), help='the Adam beta parameters')
    p.add_argument('--weight-decay', '-wd', type=float, default=0.1,
                   help='the weight decay')
    p.add_argument('--iterations', '-i', type=int, default=600,
                   help='the number of iterations')
    p.add_argument('--display-freq', type=int, default=25,
                   help='display every this many steps')
    p.add_argument('--seed', type=int, default=0,
                   help='the random seed')
    args = p.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    perceptor = clip.load(args.clip_model, jit=False)[0].to(device).eval().requires_grad_(False)

    pool = futures.ThreadPoolExecutor()
    # Small random affine jitter applied before each CLIP encoding.
    aug = transforms.RandomAffine(5, (1/20, 1/20), (0.95, 1.05),
                                  interpolation=transforms.InterpolationMode.BILINEAR)
    tv_loss = TVLoss()
    batch_size = 16
    sideX = sideY = perceptor.visual.input_resolution

    torch.manual_seed(args.seed)

    # CLIP's input normalization statistics.
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    pMs = []
    # Load the saved embedding: a text file with one float per line.
    embed = torch.tensor([float(x) for x in open(args.prompt, 'r').readlines()])[None]
    pMs.append(Prompt(embed).to(device))
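
    # A hedged sketch of how a compatible embedding file might be produced
    # beforehand (a companion step assumed here, not part of this script; the
    # prompt text and filename are placeholders):
    #
    #   model = clip.load('ViT-B/32', jit=False)[0].eval()
    #   with torch.no_grad():
    #       embed = model.encode_text(clip.tokenize('a prompt')).squeeze(0)
    #   with open('embed.txt', 'w') as f:
    #       f.write('\n'.join(str(float(x)) for x in embed))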

    # Optimize the pixels directly in [-1, 1], starting from Gaussian noise.
    image = torch.randn([batch_size, 3, sideY, sideX], device=device) / 5
    image.clamp_(-1, 1).requires_grad_()
    opt = optim.AdamW([image], lr=args.step_size, betas=args.betas,
                      weight_decay=args.weight_decay)
    # opt = optim.SGD([image], lr=args.step_size, momentum=args.betas[0])

    def save_image(t, name):
        TF.to_pil_image(t).save(name)

    @torch.no_grad()
    def checkin(i, losses):
        losses_str = ' '.join(f'{loss.item():g}' for loss in losses)
        tqdm.write(f'{i} {sum(losses).item():g} {losses_str}')
        # Save a 4-wide grid of the batch asynchronously.
        grid = utils.make_grid(image.add(1).div(2), 4)
        pool.submit(save_image, grid.cpu(), f'out_{i:05}.png')

    def ascend_txt(i):
        # Map the image from [-1, 1] to [0, 1], augment, and encode with CLIP.
        iii = perceptor.encode_image(normalize(aug(image.add(1).div(2))))
        result = []
        if args.tv_weight:
            result.append(tv_loss(image) * args.tv_weight / 4)
        for prompt in pMs:
            result.append(prompt(iii))
        return result

    def train(i):
        lossAll = ascend_txt(i)
        if i % args.display_freq == 0:
            checkin(i, lossAll)
        loss = sum(lossAll)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Keep the image within the valid range after each step.
        with torch.no_grad():
            image.clamp_(-1, 1)

    try:
        for i in trange(args.iterations + 1):
            train(i)
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()