Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Last active November 1, 2022 06:11
Show Gist options
  • Save crowsonkb/043684ffc3370c321d2e2ee64cf500a4 to your computer and use it in GitHub Desktop.
Save crowsonkb/043684ffc3370c321d2e2ee64cf500a4 to your computer and use it in GitHub Desktop.
BigGAN + CLIP, Langevin dynamics method
import copy
from functools import wraps
from hashlib import sha256
from io import open
import json
import math
import logging
import os
from pathlib import Path
import shutil
import sys
import tempfile
from urllib.parse import urlparse
import boto3
from botocore.exceptions import ClientError
import requests
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
# Download cache directory; the PYTORCH_PRETRAINED_BIGGAN_CACHE environment
# variable overrides the default under the user's home directory.
PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(
    os.environ.get('PYTORCH_PRETRAINED_BIGGAN_CACHE',
                   Path.home() / '.cache/pytorch_pretrained_biggan'))

# Hosted checkpoint weights, keyed by model name (insertion order is the
# order shown in error messages).
PRETRAINED_MODEL_ARCHIVE_MAP = {
    name: 'https://s3.amazonaws.com/models.huggingface.co/biggan/' + name + '-pytorch_model.bin'
    for name in ('biggan-deep-128', 'biggan-deep-256', 'biggan-deep-512')
}

# Matching JSON configs for every model in PRETRAINED_MODEL_ARCHIVE_MAP.
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    name: 'https://s3.amazonaws.com/models.huggingface.co/biggan/' + name + '-config.json'
    for name in PRETRAINED_MODEL_ARCHIVE_MAP
}

# File names expected inside a local model directory.
WEIGHTS_NAME, CONFIG_NAME = 'pytorch_model.bin', 'config.json'

logger = logging.getLogger(__name__)
def url_to_filename(url, etag=None):
    """Map ``url`` (and optionally its ``etag``) to a deterministic cache file name.

    The name is the SHA-256 hex digest of the UTF-8 encoded url; when ``etag``
    is truthy, ``'.'`` plus the SHA-256 hex digest of the etag is appended.
    """
    def _hex_digest(text):
        return sha256(text.encode('utf-8')).hexdigest()

    base_name = _hex_digest(url)
    return base_name + '.' + _hex_digest(etag) if etag else base_name
def filename_to_url(filename, cache_dir=None):
    """Look up the ``(url, etag)`` pair recorded for a cached ``filename``.

    The pair is read from the JSON sidecar ``<filename>.json`` in ``cache_dir``
    (default: PYTORCH_PRETRAINED_BIGGAN_CACHE). ``etag`` may be ``None``.

    Raises:
        EnvironmentError: if the cached file or its sidecar is missing.
    """
    cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE if cache_dir is None else cache_dir
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    meta_path = cache_path + '.json'
    # Check the data file first, then its metadata sidecar.
    for required in (cache_path, meta_path):
        if not os.path.exists(required):
            raise EnvironmentError('file {} not found'.format(required))

    with open(meta_path, encoding='utf-8') as handle:
        metadata = json.load(handle)
    return metadata['url'], metadata['etag']
def cached_path(url_or_filename, cache_dir=None):
    """Resolve ``url_or_filename`` to a path of a readable local file.

    http://, https:// and s3:// URLs go through the download cache in
    ``cache_dir`` (default: PYTORCH_PRETRAINED_BIGGAN_CACHE) and the cached
    file's path is returned. An existing local path is returned as-is
    (``Path`` arguments are converted to ``str``).

    Raises:
        EnvironmentError: a scheme-less path that does not exist.
        ValueError: input that is neither a supported URL nor an existing path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    running_py3 = sys.version_info[0] == 3
    if running_py3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if running_py3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    scheme = urlparse(url_or_filename).scheme
    if scheme in ('http', 'https', 's3'):
        # Remote resource: served from (and if needed downloaded into) the cache.
        return get_from_cache(url_or_filename, cache_dir)
    if os.path.exists(url_or_filename):
        return url_or_filename
    if scheme == '':
        raise EnvironmentError('file {} not found'.format(url_or_filename))
    raise ValueError('unable to parse {} as a URL or as a local path'.format(url_or_filename))
def split_s3_path(url):
    """Split ``s3://bucket/key`` into ``(bucket, key)``, dropping one leading '/' from the key.

    Raises:
        ValueError: if the URL has no bucket (netloc) or no key (path).
    """
    pieces = urlparse(url)
    bucket, key = pieces.netloc, pieces.path
    if not (bucket and key):
        raise ValueError('bad s3 path {}'.format(url))
    return bucket, key[1:] if key.startswith('/') else key
def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.

    A botocore ``ClientError`` whose error code means "object not found"
    ('404' from HEAD-style requests, 'NoSuchKey' from GET-style requests) is
    re-raised as ``EnvironmentError('file <url> not found')``. Every other
    ``ClientError`` propagates unchanged.
    """
    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            # S3 error codes are strings and are often symbolic ('AccessDenied',
            # 'NoSuchBucket', ...). Compare as strings: int() on a symbolic code
            # raises ValueError and hides the real error.
            code = str(exc.response.get('Error', {}).get('Code'))
            if code in ('404', 'NoSuchKey'):
                raise EnvironmentError('file {} not found'.format(url)) from exc
            raise
    return wrapper
@s3_request
def s3_etag(url):
    """Return the ETag of the S3 object addressed by the ``s3://bucket/key`` URL ``url``."""
    s3 = boto3.resource('s3')
    bucket, key = split_s3_path(url)
    return s3.Object(bucket, key).e_tag
@s3_request
def s3_get(url, temp_file):
    """Download the S3 object at ``url`` into the writable binary file object ``temp_file``."""
    s3 = boto3.resource('s3')
    bucket, key = split_s3_path(url)
    target_bucket = s3.Bucket(bucket)
    target_bucket.download_fileobj(key, temp_file)
def http_get(url, temp_file):
    """Stream ``url`` into the writable binary file object ``temp_file`` with a progress bar.

    Raises:
        requests.HTTPError: on a 4xx/5xx response. Without this check an error
            page would be written to ``temp_file`` and then cached as if it
            were the requested file. ``requests.HTTPError`` is an ``IOError``
            subclass, so callers catching ``EnvironmentError`` still see it.
    """
    # The context managers release the connection and close the progress bar
    # even when the download fails part-way.
    with requests.get(url, stream=True) as req:
        req.raise_for_status()
        content_length = req.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None
        with tqdm(unit="B", total=total) as progress:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    The cache file name is ``url_to_filename(url, etag)``, so a changed remote
    ETag produces a new cache entry. A JSON sidecar ``<cache_path>.json``
    records the url and etag; ``filename_to_url`` reads it back.

    Raises:
        IOError: if the HEAD request for an http(s) URL does not return 200.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    # exist_ok: another process may create the directory concurrently.
    # A separate exists() check followed by makedirs() would race.
    os.makedirs(cache_dir, exist_ok=True)

    # Get eTag to add to filename, if it exists.
    if url.startswith('s3://'):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError('HEAD request failed for url {} with status code {}'
                          .format(url, response.status_code))
        etag = response.headers.get('ETag')

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info('%s not found in cache, downloading to %s', url, temp_file.name)

            # GET file object
            if url.startswith('s3://'):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info('copying %s to cache at %s', temp_file.name, cache_path)
            # The mere existence of cache_path counts as a cache hit. So copy
            # into a process-unique sibling file and atomically rename it into
            # place: an interrupted copy can then never leave a truncated
            # cache entry behind.
            partial_path = '{}.{}.partial'.format(cache_path, os.getpid())
            try:
                with open(partial_path, 'wb') as cache_file:
                    shutil.copyfileobj(temp_file, cache_file)
                os.replace(partial_path, cache_path)
            finally:
                # Only still present if the copy or rename failed.
                if os.path.exists(partial_path):
                    os.remove(partial_path)

            logger.info('creating metadata file for %s', cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w', encoding='utf-8') as meta_file:
                json.dump(meta, meta_file)

            logger.info('removing temp file %s', temp_file.name)

    return cache_path
def read_set_from_file(filename):
    """Return the set of distinct lines in the UTF-8 text file ``filename``.

    Trailing whitespace (including the newline) is stripped from each line
    before it is added.
    """
    with open(filename, 'r', encoding='utf-8') as handle:
        return {raw_line.rstrip() for raw_line in handle}
def get_file_extension(path, dot=True, lower=True):
    """Return the extension of ``path`` as given by ``os.path.splitext``.

    ``dot=False`` drops the leading '.'; ``lower=False`` keeps the original case.
    A path without an extension yields ''.
    """
    _, suffix = os.path.splitext(path)
    if not dot:
        suffix = suffix[1:]
    if lower:
        suffix = suffix.lower()
    return suffix
class BigGANConfig(object):
    """Configuration class to store the configuration of a `BigGAN`.

    Defaults are for the 128x128 model. Each entry of ``layers`` is a tuple
    ``(up-sample in the layer?, input channels, output channels)``. The two
    channel counts are multipliers of ``channel_width``.
    """

    def __init__(self,
                 output_dim=128,
                 z_dim=128,
                 class_embed_dim=128,
                 channel_width=128,
                 num_classes=1000,
                 layers=None,
                 attention_layer_position=8,
                 eps=1e-4,
                 n_stats=51):
        """Constructs BigGANConfig.

        Args:
            output_dim: presumably the output image side (128 for the default
                128x128 model); not read by the visible generator code.
            z_dim: size of each noise slice and of each class embedding.
            class_embed_dim: stored for completeness; not read by the visible code.
            channel_width: base channel count multiplied by the ``layers`` entries.
            num_classes: input size of the class-embedding layer.
            layers: list of ``(up_sample, in_mult, out_mult)`` tuples. ``None``
                selects the default 128x128 stack.
            attention_layer_position: index in ``layers`` before which a
                self-attention layer is inserted.
            eps: epsilon for batch norm and spectral norm.
            n_stats: number of precomputed batch-norm statistic rows.
        """
        if layers is None:
            # Built fresh on every call. A list default would be one object
            # shared by all instances, so mutating one config's layers would
            # change every other default config.
            layers = [(False, 16, 16),
                      (True, 16, 16),
                      (False, 16, 16),
                      (True, 16, 8),
                      (False, 8, 8),
                      (True, 8, 4),
                      (False, 4, 4),
                      (True, 4, 2),
                      (False, 2, 2),
                      (True, 2, 1)]
        self.output_dim = output_dim
        self.z_dim = z_dim
        self.class_embed_dim = class_embed_dim
        self.channel_width = channel_width
        self.num_classes = num_classes
        self.layers = layers
        self.attention_layer_position = attention_layer_position
        self.eps = eps
        self.n_stats = n_stats

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config from a Python dictionary of parameters.

        Starts from the defaults and overrides every key present in
        ``json_object``. It uses ``cls`` so that subclasses get instances of
        themselves.
        """
        config = cls()
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config from a UTF-8 json file of parameters."""
        with open(json_file, 'r', encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy of its attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to an indented, key-sorted JSON string ending in a newline."""
        return json.dumps(self.to_dict(), indent=4, sort_keys=True) + '\n'
def snconv2d(eps=1e-12, **kwargs):
    """Build ``nn.Conv2d(**kwargs)`` wrapped in spectral normalization with epsilon ``eps``."""
    conv = nn.Conv2d(**kwargs)
    return nn.utils.spectral_norm(conv, eps=eps)
def snlinear(eps=1e-12, **kwargs):
    """Build ``nn.Linear(**kwargs)`` wrapped in spectral normalization with epsilon ``eps``."""
    linear = nn.Linear(**kwargs)
    return nn.utils.spectral_norm(linear, eps=eps)
def sn_embedding(eps=1e-12, **kwargs):
    """Build ``nn.Embedding(**kwargs)`` wrapped in spectral normalization with epsilon ``eps``."""
    embedding = nn.Embedding(**kwargs)
    return nn.utils.spectral_norm(embedding, eps=eps)
class SelfAttn(nn.Module):
    """Self-attention layer.

    Every spatial position of the input attends over a 2x2 max-pooled copy of
    the feature map. The attended features are projected back to
    ``in_channels`` and added to the input, scaled by the learned scalar
    ``gamma``. ``gamma`` is initialised to zero, so an untrained layer is the
    identity.
    """
    def __init__(self, in_channels, eps=1e-12):
        super().__init__()
        self.in_channels = in_channels
        # Attribute names below become state-dict keys, which must match the
        # checkpoint loaded by BigGAN.from_pretrained; do not rename them.
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels // 8,
                                        kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels // 8,
                                      kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels // 2,
                                    kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_o_conv = snconv2d(in_channels=in_channels//2, out_channels=in_channels,
                                         kernel_size=1, bias=False, eps=eps)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        self.gamma = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        # x is 4-D (batch, ch, h, w), with ch == in_channels (the 1x1 convs
        # require it). h and w are assumed even, so that 2x2 pooling leaves
        # exactly h * w // 4 positions for the .view calls below.
        _, ch, h, w = x.size()
        # Theta path
        theta = self.snconv1x1_theta(x)
        theta = theta.view(-1, ch // 8, h * w)  # (batch, ch/8, h*w)
        # Phi path
        phi = self.snconv1x1_phi(x)
        phi = self.maxpool(phi)
        phi = phi.view(-1, ch // 8, h * w // 4)  # (batch, ch/8, h*w/4)
        # Attn map
        attn = torch.bmm(theta.permute(0, 2, 1), phi)  # (batch, h*w, h*w/4)
        attn = self.softmax(attn)  # normalised over the pooled positions
        # g path
        g = self.snconv1x1_g(x)
        g = self.maxpool(g)
        g = g.view(-1, ch // 2, h * w // 4)  # (batch, ch/2, h*w/4)
        # Attn_g - o_conv
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))  # (batch, ch/2, h*w)
        attn_g = attn_g.view(-1, ch // 2, h, w)
        attn_g = self.snconv1x1_o_conv(attn_g)  # back to ch channels
        # Out
        out = x + self.gamma * attn_g
        return out
class BigGANBatchNorm(nn.Module):
    """This is a batch norm module that can handle conditional input and can be provided with
    pre-computed activation means and variances for various truncation parameters.

    We cannot just rely on torch.batch_norm since it cannot handle batched weights (pytorch
    1.0.1). We compute batch_norm ourself without updating running means and variances.
    If you want to train this model you should add running means and variance computation
    logic.

    Row ``k`` of ``running_means``/``running_vars`` holds the statistics for
    truncation ``k * step_size``, with ``step_size = 1 / (n_stats - 1)``.
    Truncations between rows are linearly interpolated.
    """

    def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4,
                 conditional=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.conditional = conditional

        # We use pre-computed statistics for n_stats values of truncation between 0 and 1
        self.register_buffer('running_means', torch.zeros(n_stats, num_features))
        self.register_buffer('running_vars', torch.ones(n_stats, num_features))
        self.step_size = 1. / (n_stats - 1)

        if conditional:
            assert condition_vector_dim is not None
            # Per-sample scale/offset predicted from the condition vector.
            self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features,
                                  bias=False, eps=eps)
            self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features,
                                   bias=False, eps=eps)
        else:
            # Start as the identity affine transform. torch.Tensor(n) would
            # leave these as uninitialised memory whenever a checkpoint does
            # not overwrite them.
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))

    def _truncation_stats(self, truncation):
        """Return ``(mean, var)`` for ``truncation``, linearly interpolated between table rows."""
        frac, start_idx = math.modf(truncation / self.step_size)
        start_idx = int(start_idx)
        # The second condition guards float round-off that lands a hair above
        # the last grid point: there is no row start_idx + 1 to blend with.
        if frac and start_idx + 1 < self.running_means.shape[0]:
            # truncation sits `frac` of the way from row start_idx to row
            # start_idx + 1, so the nearer row must get the larger weight.
            running_mean = self.running_means[start_idx] * (1 - frac) + \
                self.running_means[start_idx + 1] * frac
            running_var = self.running_vars[start_idx] * (1 - frac) + \
                self.running_vars[start_idx + 1] * frac
        else:
            running_mean = self.running_means[start_idx]
            running_var = self.running_vars[start_idx]
        return running_mean, running_var

    def forward(self, x, truncation, condition_vector=None):
        """Normalise ``x`` with the statistics for ``truncation``, then apply the affine transform.

        Args:
            x: 4-D activations; the statistics broadcast over dims 0, 2, 3.
            truncation: float in [0, 1] selecting the precomputed statistics.
            condition_vector: required when ``conditional``; maps to the
                per-sample scale (as ``1 + scale``) and offset.
        """
        running_mean, running_var = self._truncation_stats(truncation)

        if self.conditional:
            running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
            bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)
            out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
        else:
            out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
                               training=False, momentum=0., eps=self.eps)
        return out
class GenBlock(nn.Module):
    """Residual generator block conditioned on a per-block vector.

    Main path: four stages of (conditional BigGANBatchNorm -> ReLU -> conv),
    with kernel sizes 1, 3, 3, 1. The first conv narrows to
    ``in_size // reduction_factor`` channels and the last widens to
    ``out_size``. With ``up_sample`` the activations are nearest-neighbour
    upsampled 2x just before ``conv_1``.

    Skip path: the input, upsampled the same way, and channel-truncated when
    ``in_size != out_size``. The two paths are summed.
    """
    def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4,
                 up_sample=False, n_stats=51, eps=1e-12):
        super().__init__()
        self.up_sample = up_sample
        self.drop_channels = (in_size != out_size)
        middle_size = in_size // reduction_factor
        # Attribute names are state-dict keys for the pretrained checkpoint.
        self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats,
                                    eps=eps, conditional=True)
        self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1,
                               eps=eps)
        self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats,
                                    eps=eps, conditional=True)
        self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3,
                               padding=1, eps=eps)
        self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats,
                                    eps=eps, conditional=True)
        self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3,
                               padding=1, eps=eps)
        self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats,
                                    eps=eps, conditional=True)
        self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1,
                               eps=eps)
        self.relu = nn.ReLU()
    def forward(self, x, cond_vector, truncation):
        # Keep the block input for the residual (skip) connection.
        x0 = x
        x = self.bn_0(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_0(x)
        x = self.bn_1(x, truncation, cond_vector)
        x = self.relu(x)
        if self.up_sample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv_1(x)
        x = self.bn_2(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_2(x)
        x = self.bn_3(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_3(x)
        if self.drop_channels:
            # Keeps the first half of the input channels. This equals out_size
            # only when out_size == in_size // 2, which holds for every channel
            # change in BigGANConfig's default layers.
            # NOTE(review): other ratios would fail at the addition below.
            new_channels = x0.shape[1] // 2
            x0 = x0[:, :new_channels, ...]
        if self.up_sample:
            x0 = F.interpolate(x0, scale_factor=2, mode='nearest')
        out = x + x0
        return out
class Generator(nn.Module):
    """BigGAN generator trunk: per-layer condition vectors -> images with values in [0, 1].

    A linear layer maps the first condition vector to a 4x4 feature map. That
    map is passed through the configured GenBlocks (plus one SelfAttn layer),
    then a final batch norm, ReLU, 3x3 conv and tanh.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        ch = config.channel_width
        # Each condition vector is a noise slice concatenated with a class
        # embedding of size z_dim (see BigGAN.forward). This assumes the noise
        # slices also have z_dim features.
        condition_vector_dim = config.z_dim * 2

        # Projects the first condition vector to a 4x4 map with 16 * ch channels.
        self.gen_z = snlinear(in_features=condition_vector_dim,
                              out_features=4 * 4 * 16 * ch, eps=config.eps)

        layers = []
        for i, layer in enumerate(config.layers):
            # The attention layer is inserted before the GenBlock at this index
            # and works at that block's input width.
            if i == config.attention_layer_position:
                layers.append(SelfAttn(ch * layer[1], eps=config.eps))
            layers.append(GenBlock(ch * layer[1],
                                   ch * layer[2],
                                   condition_vector_dim,
                                   up_sample=layer[0],
                                   n_stats=config.n_stats,
                                   eps=config.eps))
        self.layers = nn.ModuleList(layers)

        # NOTE(review): sized for `ch` channels, i.e. it assumes the last
        # layer's output multiplier is 1 (true for the default config).
        self.bn = BigGANBatchNorm(ch, n_stats=config.n_stats, eps=config.eps, conditional=False)
        self.relu = nn.ReLU()
        # Produces ch channels; forward keeps only the first 3 as RGB.
        self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1,
                                    eps=config.eps)
        self.tanh = nn.Tanh()
    def forward(self, cond_vector, truncation):
        # cond_vector is indexed along dim 1 by slot: slot 0 feeds gen_z, and
        # entry i of self.layers reads slot i + 1. The SelfAttn layer consumes
        # an index without reading its slot, so at least len(self.layers) + 1
        # slots are required.
        z = self.gen_z(cond_vector[:, 0])

        # We use this conversion step to be able to use TF weights:
        # TF convention on shape is [batch, height, width, channels]
        # PT convention on shape is [batch, channels, height, width]
        z = z.view(-1, 4, 4, 16 * self.config.channel_width)
        z = z.permute(0, 3, 1, 2).contiguous()

        for i, layer in enumerate(self.layers):
            if isinstance(layer, GenBlock):
                z = layer(z, cond_vector[:, i+1], truncation)
            else:
                z = layer(z)

        z = self.bn(z, truncation)
        z = self.relu(z)
        z = self.conv_to_rgb(z)
        z = z[:, :3, ...]
        z = self.tanh(z)
        # Map tanh's [-1, 1] range to [0, 1].
        return (z + 1) / 2
class BigGAN(nn.Module):
    """BigGAN Generator.

    Wraps a linear class-embedding layer and a `Generator`. Use
    `BigGAN.from_pretrained` to build it with pretrained weights.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """Build a BigGAN and load checkpoint weights into it.

        Args:
            pretrained_model_name_or_path: a key of PRETRAINED_MODEL_ARCHIVE_MAP
                (e.g. 'biggan-deep-512'), or a directory containing WEIGHTS_NAME
                and CONFIG_NAME.
            cache_dir: download cache directory, passed to cached_path
                (which defaults to PYTORCH_PRETRAINED_BIGGAN_CACHE).
            *inputs, **kwargs: forwarded to the constructor after the config.

        Returns:
            The model, with the weights loaded onto the CPU.

        Raises:
            EnvironmentError: re-raised after logging if the weights or config
                cannot be found or downloaded.
        """
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

        try:
            resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            msg = 'Wrong model name, should be a valid path to a folder containing a {} file ' \
                  'and a {} file or a model name in {}'
            logger.error(msg.format(WEIGHTS_NAME, CONFIG_NAME,
                                    PRETRAINED_MODEL_ARCHIVE_MAP.keys()))
            raise

        logger.info('loading model {} from cache at {}'.format(
            pretrained_model_name_or_path, resolved_model_file))

        # Load config
        config = BigGANConfig.from_json_file(resolved_config_file)
        logger.info('Model config {}'.format(config))

        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        # SECURITY NOTE(review): torch.load unpickles the checkpoint, which can
        # execute arbitrary code. Only load checkpoints from trusted sources.
        # Recent PyTorch versions accept weights_only=True to restrict this.
        state_dict = torch.load(resolved_model_file, map_location='cpu')
        # strict=False tolerates key mismatches. Report them instead of silently
        # leaving parameters at their initial values.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.warning('Missing keys when loading %s (left at initialization): %s',
                           pretrained_model_name_or_path, incompatible.missing_keys)
        if incompatible.unexpected_keys:
            logger.warning('Unexpected keys in checkpoint %s (ignored): %s',
                           pretrained_model_name_or_path, incompatible.unexpected_keys)
        return model

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Maps (soft) class-label vectors of size num_classes to z_dim embeddings.
        self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False)
        self.generator = Generator(config)

    def forward(self, z, class_label, truncation):
        """Render images.

        Args:
            z: per-slot noise, concatenated with the class embeddings along
                dim 2; its first two dims must match class_label's (n, s).
            class_label: (n, s, num_classes) class vectors, one per layer slot.
            truncation: in (0, 1]; selects the precomputed batch-norm statistics.

        Returns:
            Images from the Generator, shape (n, 3, H, W), with values in [0, 1].
        """
        assert 0 < truncation <= 1

        n, s, c = class_label.shape
        embed = self.embeddings(class_label.view([-1, c])).view([n, s, -1])
        cond_vector = torch.cat((z, embed), dim=2)

        z = self.generator(cond_vector, truncation)
        return z
#!/usr/bin/env python3
"""BigGAN + CLIP, Langevin dynamics method."""
import argparse
from copy import deepcopy
import clip
import torch
from torch import distributions as dists, nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from rich import print
from rich.align import Align
from rich.panel import Panel
from rich.traceback import install
from biggan import BigGAN
@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Blend ``model`` into ``averaged_model`` as an exponential moving average, in place.

    Each averaged parameter becomes ``decay * averaged + (1 - decay) * current``.
    Buffers are copied over verbatim rather than averaged. Call this after
    every optimisation step. Both models must expose identical parameter
    and buffer names.
    """
    def _paired(current_named, averaged_named):
        current, averaged = dict(current_named), dict(averaged_named)
        assert current.keys() == averaged.keys()
        return ((averaged[name], tensor) for name, tensor in current.items())

    for avg_param, cur_param in _paired(model.named_parameters(),
                                        averaged_model.named_parameters()):
        avg_param.mul_(decay).add_(cur_param, alpha=1 - decay)

    for avg_buf, cur_buf in _paired(model.named_buffers(),
                                    averaged_model.named_buffers()):
        avg_buf.copy_(cur_buf)
class Latent(nn.Module):
    """Trainable BigGAN inputs: 16 noise slots of size 128 and 16 class-logit slots of size 1000.

    Both tensors start as standard-normal draws (``z`` first, then ``cls``) and
    carry a standard-normal prior.
    """

    def __init__(self):
        super().__init__()
        self.z = nn.Parameter(torch.randn((1, 16, 128)))
        self.cls = nn.Parameter(torch.randn((1, 16, 1000)))

    @staticmethod
    def _std_normal_log_density(tensor):
        """Summed log-density of every entry of ``tensor`` under N(0, 1)."""
        prior = dists.Normal(0, 1)
        return prior.log_prob(tensor).sum()

    def z_prior(self):
        return self._std_normal_log_density(self.z)

    def cls_prior(self):
        return self._std_normal_log_density(self.cls)

    def forward(self):
        # Dividing by sqrt(1000) keeps the summed class vector at unit scale.
        cls_scale = 1000 ** 0.5
        return self.z, self.cls / cls_scale
def endless_range(start=0, step=1):
    """Yield ``start``, ``start + step``, ``start + 2*step``, ... forever.

    Delegates to ``itertools.count`` instead of hand-rolling the counter. It
    stays a generator function, so existing callers see the same object type.
    """
    import itertools
    yield from itertools.count(start, step)
def main():
    """Run BigGAN + CLIP Langevin sampling for a text prompt until interrupted (Ctrl-C).

    Every ``--display-freq`` steps the energies are printed and an image
    rendered from the exponentially averaged latent is saved as
    ``out_{step:05}.png`` in the working directory.
    """
    install(max_frames=10)

    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('prompt', type=str,
                   help='the prompt to maximize')
    p.add_argument('--clip', type=str, default='ViT-B/16', choices=clip.available_models(),
                   help='the CLIP model')
    p.add_argument('--eps', type=float, default=0.05,
                   help='the initial step size')
    p.add_argument('--gamma', type=float, default=1.,
                   help='the decay rate for the step size')
    p.add_argument('--kappa', type=float, default=5000.,
                   help='the CLIP guidance scale')
    p.add_argument('--ema-decay', type=float, default=0.95,
                   help='the decay rate for iterate averaging')
    p.add_argument('--cutn', type=int, default=64,
                   help='the number of random crops shown to CLIP per iteration')
    p.add_argument('--display-freq', type=int, default=25,
                   help='display every this many steps')
    p.add_argument('--seed', type=int, default=0,
                   help='the random seed')
    p.add_argument('--verbose', '-v', action='store_true',
                   help='print out more information')
    args = p.parse_args()

    print(Panel(Align('BigGAN + CLIP, Langevin dynamics method', 'center')))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: "{device}"')

    model = BigGAN.from_pretrained('biggan-deep-512').to(device).eval().requires_grad_(False)
    side_x = side_y = 512

    perceptor = clip.load(args.clip)[0].to(device).eval().requires_grad_(False)
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    clip_size = perceptor.visual.input_resolution

    print(f'Using seed: {args.seed}')
    torch.manual_seed(args.seed)

    latent = Latent().to(device)
    latent_ema = deepcopy(latent)

    # beta: decay of the running mean-squared-gradient estimate v.
    # lmbda: damping added to sqrt(v_hat) before it is inverted into the
    # step preconditioner g.
    beta, lmbda = 0.99, 1e-5
    # t accumulates g * eps and drives the step-size decay eps / (1 + gamma * t).
    t = torch.tensor(0., device=device)
    v = torch.tensor(0., device=device)

    toks = clip.tokenize(args.prompt, truncate=True)
    text_embed = perceptor.encode_text(toks.to(device)).float()

    @torch.no_grad()
    def checkin(i, energies):
        """Print the energy terms for step ``i`` and save an image of the averaged latent."""
        print(f'step: {i}, total energy: {sum(energies).item():g}, z: {energies[0].item():g}, cls: {energies[1].item():g}, prompts:', ' '.join(f'{e.item():g}' for e in energies[2:]))
        image = model(*latent_ema(), 1)
        TF.to_pil_image(image[0].cpu()).save(f'out_{i:05}.png')

    def eval_energies():
        """Return [z log-prior, cls log-prior, kappa * mean CLIP cosine similarity].

        Larger is better: the sampler moves the latent up the gradient of their sum.
        """
        energies = [latent.z_prior(), latent.cls_prior()]
        out = model(*latent(), 1)
        crops = []
        for _ in range(args.cutn):
            # Random square crop whose side is a N(0.8, 0.3) fraction of the
            # image, clamped to [0.5, 0.95] but never below CLIP's input size.
            size = max(int(side_x * dists.Normal(.8, .3).sample([]).clamp(.5, .95)), clip_size)
            offsetx = torch.randint(0, side_x - size + 1, ())
            offsety = torch.randint(0, side_y - size + 1, ())
            crop = out[:, :, offsety:offsety + size, offsetx:offsetx + size]
            crop = F.interpolate(crop, (clip_size, clip_size), mode='bilinear', align_corners=False, antialias=True)
            crops.append(crop)
        crops = torch.cat(crops)
        image_embeds = perceptor.encode_image(normalize(crops)).float()
        energies.append(args.kappa * torch.cosine_similarity(image_embeds, text_embed, dim=-1).mean())
        return energies

    def step(i):
        """One preconditioned Langevin step on the latent, then an EMA update of latent_ema."""
        nonlocal t, v
        energies = eval_energies()
        if i % args.display_freq == 0:
            checkin(i, energies)
        # Despite its name this is a log-density to be *increased*: the update
        # below adds the gradient rather than subtracting it.
        loss = sum(energies)
        latent.zero_grad()
        loss.backward()
        eps = args.eps * (1 + t * args.gamma) ** -1
        with torch.no_grad():
            grad_mean_sq = sum(param.grad.pow(2).sum() for param in latent.parameters()) / sum(param.grad.numel() for param in latent.parameters())
            v.mul_(beta).add_(grad_mean_sq, alpha=1 - beta)
            # Bias correction for v starting at zero.
            v_hat = v / (1 - beta ** (i + 1))
            g = 1 / (v_hat.sqrt() + lmbda)
            # Diagnostics follow the same cadence as the image check-ins.
            if args.verbose and i % args.display_freq == 0:
                print(f'step: {i}, t: {t.item():g}, g: {g.item():g}, eps: {eps.item():g}')
            t += g * eps
            for param in latent.parameters():
                # Drift along the gradient plus Gaussian noise with variance g * eps.
                param += (g * eps / 2) * param.grad
                param += (g * eps) ** 0.5 * torch.randn_like(param)
        ema_update(latent, latent_ema, args.ema_decay)

    try:
        for i in endless_range():
            step(i)
    except KeyboardInterrupt:
        pass
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment