BigGAN + CLIP, Langevin dynamics method
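The first file below is a self-contained copy of the pretrained BigGAN-deep generator: the download/caching utilities, config class, and model definition adapted from Hugging Face's pytorch-pretrained-BigGAN, with the generator modified to take a separate latent and class vector for each layer. The second file imports it as biggan and samples those per-layer latents with CLIP-guided, preconditioned Langevin dynamics.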
import copy
from functools import wraps
from hashlib import sha256
from io import open
import json
import math
import logging
import os
from pathlib import Path
import shutil
import sys
import tempfile
from urllib.parse import urlparse

import boto3
from botocore.exceptions import ClientError
import requests
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
                                                 Path.home() / '.cache/pytorch_pretrained_biggan'))

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'biggan-deep-128':
        'https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-pytorch_model.bin',
    'biggan-deep-256':
        'https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-pytorch_model.bin',
    'biggan-deep-512':
        'https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-pytorch_model.bin',
}

PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'biggan-deep-128':
        'https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json',
    'biggan-deep-256':
        'https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json',
    'biggan-deep-512':
        'https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json',
}

WEIGHTS_NAME = 'pytorch_model.bin'
CONFIG_NAME = 'config.json'

logger = logging.getLogger(__name__)


def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    url_bytes = url.encode('utf-8')
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode('utf-8')
        etag_hash = sha256(etag_bytes)
        filename += '.' + etag_hash.hexdigest()

    return filename


def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError('file {} not found'.format(cache_path))

    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise EnvironmentError('file {} not found'.format(meta_path))

    with open(meta_path, encoding='utf-8') as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag


def cached_path(url_or_filename, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise EnvironmentError('file {} not found'.format(url_or_filename))
    else:
        # Something unknown
        raise ValueError('unable to parse {} as a URL or as a local path'.format(url_or_filename))


def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError('bad s3 path {}'.format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path


def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """
    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response['Error']['Code']) == 404:
                raise EnvironmentError('file {} not found'.format(url))
            else:
                raise

    return wrapper


@s3_request
def s3_etag(url):
    """Check ETag on S3 object."""
    s3_resource = boto3.resource('s3')
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag


@s3_request
def s3_get(url, temp_file):
    """Pull a file directly from S3."""
    s3_resource = boto3.resource('s3')
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def http_get(url, temp_file):
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith('s3://'):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError('HEAD request failed for url {} with status code {}'
                          .format(url, response.status_code))
        etag = response.headers.get('ETag')

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info('%s not found in cache, downloading to %s', url, temp_file.name)

            # GET file object
            if url.startswith('s3://'):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info('copying %s to cache at %s', temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info('creating metadata file for %s', cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w', encoding='utf-8') as meta_file:
                json.dump(meta, meta_file)

            logger.info('removing temp file %s', temp_file.name)

    return cache_path


def read_set_from_file(filename):
    """
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    """
    collection = set()
    with open(filename, 'r', encoding='utf-8') as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection


def get_file_extension(path, dot=True, lower=True):
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext


class BigGANConfig(object):
    """Configuration class to store the configuration of a `BigGAN`. Defaults are for the
    128x128 model. The layers tuple is (up-sample in the layer?, input channels, output
    channels).
    """
    def __init__(self,
                 output_dim=128,
                 z_dim=128,
                 class_embed_dim=128,
                 channel_width=128,
                 num_classes=1000,
                 layers=[(False, 16, 16),
                         (True, 16, 16),
                         (False, 16, 16),
                         (True, 16, 8),
                         (False, 8, 8),
                         (True, 8, 4),
                         (False, 4, 4),
                         (True, 4, 2),
                         (False, 2, 2),
                         (True, 2, 1)],
                 attention_layer_position=8,
                 eps=1e-4,
                 n_stats=51):
        """Constructs BigGANConfig."""
        self.output_dim = output_dim
        self.z_dim = z_dim
        self.class_embed_dim = class_embed_dim
        self.channel_width = channel_width
        self.num_classes = num_classes
        self.layers = layers
        self.attention_layer_position = attention_layer_position
        self.eps = eps
        self.n_stats = n_stats

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `BigGANConfig` from a Python dictionary of parameters."""
        config = BigGANConfig()
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BigGANConfig` from a json file of parameters."""
        with open(json_file, 'r', encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=4, sort_keys=True) + '\n'


def snconv2d(eps=1e-12, **kwargs):
    return nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)


def snlinear(eps=1e-12, **kwargs):
    return nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps)


def sn_embedding(eps=1e-12, **kwargs):
    return nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps)


class SelfAttn(nn.Module):
    """Self-attention layer."""
    def __init__(self, in_channels, eps=1e-12):
        super().__init__()
        self.in_channels = in_channels
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels // 8,
                                        kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels // 8,
                                      kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels // 2,
                                    kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_o_conv = snconv2d(in_channels=in_channels // 2, out_channels=in_channels,
                                         kernel_size=1, bias=False, eps=eps)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        _, ch, h, w = x.size()
        # Theta path
        theta = self.snconv1x1_theta(x)
        theta = theta.view(-1, ch // 8, h * w)
        # Phi path
        phi = self.snconv1x1_phi(x)
        phi = self.maxpool(phi)
        phi = phi.view(-1, ch // 8, h * w // 4)
        # Attn map
        attn = torch.bmm(theta.permute(0, 2, 1), phi)
        attn = self.softmax(attn)
        # g path
        g = self.snconv1x1_g(x)
        g = self.maxpool(g)
        g = g.view(-1, ch // 2, h * w // 4)
        # Attn_g - o_conv
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
        attn_g = attn_g.view(-1, ch // 2, h, w)
        attn_g = self.snconv1x1_o_conv(attn_g)
        # Out
        out = x + self.gamma * attn_g
        return out


class BigGANBatchNorm(nn.Module):
    """This is a batch norm module that can handle conditional input and can be provided with
    pre-computed activation means and variances for various truncation parameters.

    We cannot just rely on torch.batch_norm since it cannot handle batched weights (pytorch
    1.0.1). We compute batch_norm ourselves without updating running means and variances.
    If you want to train this model you should add running means and variance computation
    logic.
    """
    def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4,
                 conditional=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.conditional = conditional

        # We use pre-computed statistics for n_stats values of truncation between 0 and 1
        self.register_buffer('running_means', torch.zeros(n_stats, num_features))
        self.register_buffer('running_vars', torch.ones(n_stats, num_features))
        self.step_size = 1. / (n_stats - 1)

        if conditional:
            assert condition_vector_dim is not None
            self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features,
                                  bias=False, eps=eps)
            self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features,
                                   bias=False, eps=eps)
        else:
            self.weight = nn.Parameter(torch.Tensor(num_features))
            self.bias = nn.Parameter(torch.Tensor(num_features))

    def forward(self, x, truncation, condition_vector=None):
        # Retrieve pre-computed statistics associated to this truncation
        coef, start_idx = math.modf(truncation / self.step_size)
        start_idx = int(start_idx)
        if coef:  # Interpolate
            running_mean = self.running_means[start_idx] * coef + \
                self.running_means[start_idx + 1] * (1 - coef)
            running_var = self.running_vars[start_idx] * coef + \
                self.running_vars[start_idx + 1] * (1 - coef)
        else:
            running_mean = self.running_means[start_idx]
            running_var = self.running_vars[start_idx]

        if self.conditional:
            running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

            weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
            bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)

            out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
        else:
            out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
                               training=False, momentum=0., eps=self.eps)

        return out


class GenBlock(nn.Module):
    def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4,
                 up_sample=False, n_stats=51, eps=1e-12):
        super().__init__()
        self.up_sample = up_sample
        self.drop_channels = (in_size != out_size)
        middle_size = in_size // reduction_factor

        self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats,
                                    eps=eps, conditional=True)
        self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1,
                               eps=eps)

        self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats,
                                    eps=eps, conditional=True)
        self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3,
                               padding=1, eps=eps)

        self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats,
                                    eps=eps, conditional=True)
        self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3,
                               padding=1, eps=eps)

        self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats,
                                    eps=eps, conditional=True)
        self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1,
                               eps=eps)

        self.relu = nn.ReLU()

    def forward(self, x, cond_vector, truncation):
        x0 = x

        x = self.bn_0(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_0(x)

        x = self.bn_1(x, truncation, cond_vector)
        x = self.relu(x)
        if self.up_sample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv_1(x)

        x = self.bn_2(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_2(x)

        x = self.bn_3(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_3(x)

        if self.drop_channels:
            new_channels = x0.shape[1] // 2
            x0 = x0[:, :new_channels, ...]
        if self.up_sample:
            x0 = F.interpolate(x0, scale_factor=2, mode='nearest')

        out = x + x0
        return out


class Generator(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        ch = config.channel_width
        condition_vector_dim = config.z_dim * 2

        self.gen_z = snlinear(in_features=condition_vector_dim,
                              out_features=4 * 4 * 16 * ch, eps=config.eps)

        layers = []
        for i, layer in enumerate(config.layers):
            if i == config.attention_layer_position:
                layers.append(SelfAttn(ch * layer[1], eps=config.eps))
            layers.append(GenBlock(ch * layer[1],
                                   ch * layer[2],
                                   condition_vector_dim,
                                   up_sample=layer[0],
                                   n_stats=config.n_stats,
                                   eps=config.eps))
        self.layers = nn.ModuleList(layers)

        self.bn = BigGANBatchNorm(ch, n_stats=config.n_stats, eps=config.eps, conditional=False)
        self.relu = nn.ReLU()
        self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1,
                                    eps=config.eps)
        self.tanh = nn.Tanh()

    def forward(self, cond_vector, truncation):
        z = self.gen_z(cond_vector[:, 0])

        # We use this conversion step to be able to use TF weights:
        # TF convention on shape is [batch, height, width, channels]
        # PT convention on shape is [batch, channels, height, width]
        z = z.view(-1, 4, 4, 16 * self.config.channel_width)
        z = z.permute(0, 3, 1, 2).contiguous()

        for i, layer in enumerate(self.layers):
            if isinstance(layer, GenBlock):
                z = layer(z, cond_vector[:, i + 1], truncation)
            else:
                z = layer(z)

        z = self.bn(z, truncation)
        z = self.relu(z)
        z = self.conv_to_rgb(z)
        z = z[:, :3, ...]
        z = self.tanh(z)
        return (z + 1) / 2


class BigGAN(nn.Module):
    """BigGAN Generator."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

        try:
            resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            msg = 'Wrong model name, should be a valid path to a folder containing a {} file ' \
                  'and a {} file or a model name in {}'
            logger.error(msg.format(WEIGHTS_NAME, CONFIG_NAME,
                                    PRETRAINED_MODEL_ARCHIVE_MAP.keys()))
            raise

        logger.info('loading model {} from cache at {}'.format(
            pretrained_model_name_or_path, resolved_model_file))

        # Load config
        config = BigGANConfig.from_json_file(resolved_config_file)
        logger.info('Model config {}'.format(config))

        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        state_dict = torch.load(resolved_model_file, map_location='cpu')
        model.load_state_dict(state_dict, strict=False)
        return model

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False)
        self.generator = Generator(config)

    def forward(self, z, class_label, truncation):
        assert 0 < truncation <= 1

        n, s, c = class_label.shape
        embed = self.embeddings(class_label.view([-1, c])).view([n, s, -1])
        cond_vector = torch.cat((z, embed), dim=2)

        z = self.generator(cond_vector, truncation)
        return z
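
For orientation, here is a minimal sketch (not part of the original gist) of how the modified generator above can be used on its own. It assumes the file is saved as biggan.py, which is what the script below imports; that the biggan-deep-512 weights are used, whose generator expects 16 conditioning slices (matching the Latent module below); and that class index 207 is just an arbitrary ImageNet label chosen for illustration.

import torch
from biggan import BigGAN

model = BigGAN.from_pretrained('biggan-deep-512').eval()
z = torch.randn([1, 16, 128])        # one latent slice per conditioning input
cls = torch.zeros([1, 16, 1000])
cls[:, :, 207] = 1.                  # an arbitrary one-hot ImageNet class
with torch.no_grad():
    image = model(z, cls, truncation=1.)  # shape [1, 3, 512, 512], values in [0, 1]

The Langevin sampling script follows.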
#!/usr/bin/env python3

"""BigGAN + CLIP, Langevin dynamics method."""

import argparse
from copy import deepcopy

import clip
import torch
from torch import distributions as dists, nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from rich import print
from rich.align import Align
from rich.panel import Panel
from rich.traceback import install

from biggan import BigGAN


@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Incorporates updated model parameters into an exponential moving averaged
    version of a model. It should be called after each optimizer step."""
    model_params = dict(model.named_parameters())
    averaged_params = dict(averaged_model.named_parameters())
    assert model_params.keys() == averaged_params.keys()

    for name, param in model_params.items():
        averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)

    model_buffers = dict(model.named_buffers())
    averaged_buffers = dict(averaged_model.named_buffers())
    assert model_buffers.keys() == averaged_buffers.keys()

    for name, buf in model_buffers.items():
        averaged_buffers[name].copy_(buf)
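
# In main() below, ema_update(latent, latent_ema, args.ema_decay) runs after every Langevin
# step, so latent_ema holds an exponentially weighted average of the sampled latents
# (parameters blended as decay * old + (1 - decay) * new, buffers copied verbatim);
# checkin() renders its output images from latent_ema rather than from the raw iterate.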


class Latent(nn.Module):
    def __init__(self):
        super().__init__()
        self.z = nn.Parameter(torch.randn([1, 16, 128]))
        self.cls = nn.Parameter(torch.randn([1, 16, 1000]))

    def z_prior(self):
        return dists.Normal(0, 1).log_prob(self.z).sum()

    def cls_prior(self):
        return dists.Normal(0, 1).log_prob(self.cls).sum()

    def forward(self):
        return self.z, self.cls / 1000 ** 0.5


def endless_range(start=0, step=1):
    i = start
    while True:
        yield i
        i = i + step


def main():
    install(max_frames=10)

    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('prompt', type=str,
                   help='the prompt to maximize')
    p.add_argument('--clip', type=str, default='ViT-B/16', choices=clip.available_models(),
                   help='the CLIP model')
    p.add_argument('--eps', type=float, default=0.05,
                   help='the initial step size')
    p.add_argument('--gamma', type=float, default=1.,
                   help='the decay rate for the step size')
    p.add_argument('--kappa', type=float, default=5000.,
                   help='the CLIP guidance scale')
    p.add_argument('--ema-decay', type=float, default=0.95,
                   help='the decay rate for iterate averaging')
    p.add_argument('--cutn', type=int, default=64,
                   help='the number of random crops shown to CLIP per iteration')
    p.add_argument('--display-freq', type=int, default=25,
                   help='display every this many steps')
    p.add_argument('--seed', type=int, default=0,
                   help='the random seed')
    p.add_argument('--verbose', '-v', action='store_true',
                   help='print out more information')
    args = p.parse_args()

    print(Panel(Align('BigGAN + CLIP, Langevin dynamics method', 'center')))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: "{device}"')

    model = BigGAN.from_pretrained('biggan-deep-512').to(device).eval().requires_grad_(False)
    side_x = side_y = 512

    perceptor = clip.load(args.clip)[0].to(device).eval().requires_grad_(False)
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    clip_size = perceptor.visual.input_resolution

    print(f'Using seed: {args.seed}')
    torch.manual_seed(args.seed)

    latent = Latent().to(device)
    latent_ema = deepcopy(latent)
    beta, lmbda = 0.99, 1e-5
    t = torch.tensor(0., device=device)
    v = torch.tensor(0., device=device)

    toks = clip.tokenize(args.prompt, truncate=True)
    text_embed = perceptor.encode_text(toks.to(device)).float()

    @torch.no_grad()
    def checkin(i, energies):
        print(f'step: {i}, total energy: {sum(energies).item():g}, '
              f'z: {energies[0].item():g}, cls: {energies[1].item():g}, prompts:',
              ' '.join(f'{e.item():g}' for e in energies[2:]))
        image = model(*latent_ema(), 1)
        TF.to_pil_image(image[0].cpu()).save(f'out_{i:05}.png')

    def eval_energies():
        energies = [latent.z_prior(), latent.cls_prior()]
        out = model(*latent(), 1)
        crops = []
        for ch in range(args.cutn):
            size = max(int(side_x * dists.Normal(.8, .3).sample([]).clamp(.5, .95)), clip_size)
            offsetx = torch.randint(0, side_x - size + 1, ())
            offsety = torch.randint(0, side_y - size + 1, ())
            crop = out[:, :, offsety:offsety + size, offsetx:offsetx + size]
            crop = F.interpolate(crop, (clip_size, clip_size), mode='bilinear',
                                 align_corners=False, antialias=True)
            crops.append(crop)
        crops = torch.cat(crops)
        image_embeds = perceptor.encode_image(normalize(crops)).float()
        energies.append(args.kappa * torch.cosine_similarity(image_embeds, text_embed, dim=-1).mean())
        return energies
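
    # The update in step() below is a preconditioned stochastic gradient Langevin dynamics
    # step: each parameter moves by (g * eps / 2) times the gradient of the total energy
    # (the Gaussian log-priors on z and cls plus kappa times the mean CLIP cosine
    # similarity over the random crops), then receives Gaussian noise with variance
    # g * eps. The scalar preconditioner g is an RMSProp-style inverse root-mean-square of
    # the gradient (bias-corrected via v_hat), and the step size eps decays as
    # args.eps / (1 + args.gamma * t), where t accumulates the preconditioned step sizes
    # taken so far.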

    def step(i):
        nonlocal t, v
        energies = eval_energies()
        if i % args.display_freq == 0:
            checkin(i, energies)
        loss = sum(energies)
        latent.zero_grad()
        loss.backward()
        eps = args.eps * (1 + t * args.gamma) ** -1
        with torch.no_grad():
            grad_mean_sq = sum(p.grad.pow(2).sum() for p in latent.parameters()) \
                / sum(p.grad.numel() for p in latent.parameters())
            v.mul_(beta).add_(grad_mean_sq, alpha=1 - beta)
            v_hat = v / (1 - beta ** (i + 1))
            g = 1 / (v_hat.sqrt() + lmbda)
            if args.verbose and i % 25 == 0:
                print(f'step: {i}, t: {t.item():g}, g: {g.item():g}, eps: {eps.item():g}')
            t += g * eps
            for p in latent.parameters():
                p += (g * eps / 2) * p.grad
                p += (g * eps) ** 0.5 * torch.randn_like(p)
        ema_update(latent, latent_ema, args.ema_decay)

    try:
        for i in endless_range():
            step(i)
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()