@crowsonkb · Created February 10, 2021
import argparse
import csv
from pathlib import Path

import torch
from torch import optim, nn
from torch.nn import functional as F
from torch.utils import data
from torchvision import datasets, transforms, utils
from torchvision.transforms import functional as TF
from tqdm import tqdm

from ema_debiased import EMA

BATCH_SIZE = 100  # images per training batch
BIG_BATCH_SIZE = 500  # latent codes sampled per training step
LATENT_SIZE = 256  # dimensionality of the latent space
PREFIX = Path(__file__).stem


class ResidualBlock(nn.Sequential):
    """A residual block with an identity shortcut connection."""

    def forward(self, input):
        output = input
        for module in self:
            output = module(output)
        return input + output


class ConvBlock(nn.Sequential):
    """A 3x3 convolution followed by a ReLU."""

    def __init__(self, c_in, c_out):
        super().__init__(
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.ReLU(inplace=True),
        )


class ResConvBlock(ResidualBlock):
    """Two 3x3 convolution/ReLU pairs wrapped in an identity shortcut."""

    def __init__(self, c):
        super().__init__(
            nn.Conv2d(c, c, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c, c, 3, padding=1),
            nn.ReLU(inplace=True),
        )
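
# Each ResConvBlock preserves spatial size and channel count (3x3 convs with
# padding 1), which is what lets the blocks below stack freely between the
# pooling and upsampling stages. Since the residual branch ends in a ReLU,
# each block adds a non-negative residual to its input.
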
class Discriminator(nn.Sequential):
    """Maps 32x32 RGB images in [0, 1] to unnormalized realism scores."""

    def __init__(self, c):
        super().__init__(
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                 std=[0.2470, 0.2435, 0.2616]),
            ConvBlock(3, c),
            ResConvBlock(c),
            ResConvBlock(c),
            nn.AvgPool2d(2),
            ResConvBlock(c),
            ResConvBlock(c),
            nn.AvgPool2d(2),
            ResConvBlock(c),
            ResConvBlock(c),
            nn.AvgPool2d(2),
            nn.Flatten(),
            nn.Linear(c * 4 * 4, LATENT_SIZE),
            nn.ReLU(inplace=True),
            nn.Linear(LATENT_SIZE, 1),
        )
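
# The three AvgPool2d stages halve the spatial resolution 32 -> 16 -> 8 -> 4,
# which is why the first Linear layer expects c * 4 * 4 input features. A
# quick shape check (a sketch, not part of the training script):
#
#   d = Discriminator(32)
#   assert d(torch.rand(2, 3, 32, 32)).shape == (2, 1)
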
class Generator(nn.Sequential):
    """Decodes latent vectors into 32x32 RGB images in [0, 1]."""

    def __init__(self, c):
        super().__init__(
            nn.Linear(LATENT_SIZE, c * 4 * 4),
            nn.Unflatten(-1, (c, 4, 4)),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            ResConvBlock(c),
            ResConvBlock(c),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            ResConvBlock(c),
            ResConvBlock(c),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            ResConvBlock(c),
            ResConvBlock(c),
            nn.Conv2d(c, 3, 3, padding=1),
            nn.Sigmoid(),
        )
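
# The generator mirrors the discriminator: the latent vector becomes a
# c x 4 x 4 feature map, three bilinear upsamples take it 4 -> 8 -> 16 -> 32,
# and the final Sigmoid keeps pixel values in [0, 1] to match ToTensor()
# inputs. Shape check (sketch):
#
#   g = Generator(32)
#   assert g(torch.randn(2, LATENT_SIZE)).shape == (2, 3, 32, 32)
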
class Normalize(nn.Module):
    """Standardizes each latent vector to zero mean and unit std over dim 1."""

    def forward(self, input):
        mean = input.mean(1, keepdim=True)
        std = input.std(1, unbiased=False, keepdim=True)
        return (input - mean) / std
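
# The Encoder below ends in this per-sample standardization, presumably so
# that encoded latents live on roughly the same scale as the N(0, 1) codes
# the generator sees at sampling time. Sketch of the effect:
#
#   out = Normalize()(torch.randn(4, LATENT_SIZE) * 10 + 3)
#   # out.mean(1) is ~0 and out.std(1, unbiased=False) is ~1 for every row
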
class Encoder(nn.Sequential):
    """Maps 32x32 RGB images to standardized latent vectors."""

    def __init__(self, c):
        super().__init__(
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                 std=[0.2470, 0.2435, 0.2616]),
            ConvBlock(3, c),
            ResConvBlock(c),
            ResConvBlock(c),
            nn.AvgPool2d(2),
            ResConvBlock(c),
            ResConvBlock(c),
            nn.AvgPool2d(2),
            ResConvBlock(c),
            ResConvBlock(c),
            nn.AvgPool2d(2),
            nn.Flatten(),
            nn.Linear(c * 4 * 4, LATENT_SIZE),
            Normalize(),
        )


class DiscriminatorLoss(nn.Module):
    """Pairwise logistic loss pushing real scores above fake scores."""

    def forward(self, input, target):
        return -F.logsigmoid(target - input.T).mean()


class GeneratorLoss(nn.Module):
    """Pairwise logistic loss pushing fake scores above real scores."""

    def forward(self, input, target):
        return -F.logsigmoid(input - target.T).mean()
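
# Both losses compare every fake score against every real score at once: in
# the training loop below, input = d(g(z)) has shape [BIG_BATCH_SIZE, 1] and
# target = d(x) has shape [BATCH_SIZE, 1], so the subtraction broadcasts to a
# matrix of pairwise score differences. This resembles a relativistic
# pairwise logistic GAN loss.
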
class IMLELoss(nn.Module):
    """For each target, the mean squared distance to its nearest input."""

    def forward(self, input, target):
        input = input.flatten(1)
        target = target.flatten(1)
        out = input.unsqueeze(0) - target.unsqueeze(1)
        out = out.pow(2).mean(2)
        return out.min(1).values.mean()
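
# Implicit Maximum Likelihood Estimation (IMLE): sample many candidates and
# pull the nearest candidate toward each target, so every real example has
# some latent code that maps close to it. An equivalent formulation using
# torch.cdist (a sketch with a hypothetical helper, not used by the script):
#
#   def imle_loss(candidates, targets):
#       d2 = torch.cdist(targets, candidates).pow(2) / candidates.shape[1]
#       return d2.min(dim=1).values.mean()
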
class MetricsLogger:
    """Accumulates named metrics and writes their means to a CSV file."""

    def __init__(self, csv_path, *names):
        self.names = names
        self.csv_file = open(csv_path, 'w')
        self.csv_writer = csv.writer(self.csv_file)
        self.csv_writer.writerow(self.names)
        self.csv_file.flush()
        self.clear()

    def __str__(self):
        return ' '.join(f'{name}={metric:g}' for name, metric in zip(self.names, self.get()))

    def clear(self):
        self.metrics = [list() for _ in self.names]

    def get(self):
        return [sum(metrics) / len(metrics) for metrics in self.metrics]

    def put(self, *metrics):
        for lst, metric in zip(self.metrics, metrics):
            lst.append(metric.item() if hasattr(metric, 'item') else metric)

    def write(self):
        self.csv_writer.writerow(self.get())
        self.csv_file.flush()
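
# Usage sketch ('run.csv' is a hypothetical path): put() records one value
# per named column, converting tensors with .item(); get() averages
# everything recorded since the last clear(); write() appends those averages
# as a CSV row:
#
#   losses = MetricsLogger('run.csv', 'd', 'g')
#   losses.put(0.7, 0.6)
#   losses.put(0.5, 0.8)
#   losses.write()  # appends the row 0.6, 0.7
#   losses.clear()
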
def nparams(model):
    """Returns the total number of parameters in a model."""
    return sum(p.numel() for p in model.parameters())


def main():
    p = argparse.ArgumentParser()
    p.add_argument('--seed', type=int, default=0,
                   help='the random seed')
    p.add_argument('--checkpoint', type=str, default=None,
                   help='the checkpoint to restart from')
    args = p.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)
    torch.manual_seed(args.seed)

    tf = transforms.ToTensor()
    train_set = datasets.CIFAR10('data/cifar10', download=True, transform=tf)
    # train_subset = data.Subset(train_set, range(5000))
    train_dl = data.DataLoader(train_set, BATCH_SIZE, shuffle=True,
                               num_workers=1, pin_memory=True)
    val_set = datasets.CIFAR10('data/cifar10', download=True, train=False, transform=tf)
    val_dl = data.DataLoader(val_set, BATCH_SIZE, pin_memory=True)

    d = Discriminator(32).to(device)
    g = Generator(32).to(device)
    e = Encoder(32).to(device)
    print('D parameters:', nparams(d))
    print('G parameters:', nparams(g))
    print('E parameters:', nparams(e))
    g = EMA(g, 0.99)  # wrap G to maintain an exponential moving average of its weights

    d_crit = DiscriminatorLoss()
    g_crit = GeneratorLoss()
    e_crit = nn.MSELoss()
    imle_crit = IMLELoss()

    lr = 2e-4
    wd = 1e-2
    opt_d = optim.AdamW(d.parameters(), lr=lr, weight_decay=wd)
    opt_g = optim.AdamW(g.parameters(), lr=lr, weight_decay=wd)
    opt_e = optim.AdamW(e.parameters(), lr=lr, weight_decay=wd)

    epoch = 1

    if args.checkpoint:
        state = torch.load(args.checkpoint, map_location=device)
        d.load_state_dict(state['d'])
        g.load_state_dict(state['g'])
        e.load_state_dict(state['e'])
        opt_d.load_state_dict(state['opt_d'])
        opt_g.load_state_dict(state['opt_g'])
        opt_e.load_state_dict(state['opt_e'])
        epoch = state['epoch']

    def train():
        with tqdm(total=len(train_set), unit='examples', dynamic_ncols=True) as pbar:
            d.train(), g.train(), e.train()
            losses = MetricsLogger(PREFIX + '.csv', 'd', 'g', 'e')
            i = 0
            for x, _ in train_dl:
                i += 1
                x = x.to(device, non_blocking=True)

                # Discriminator step: push real scores above fake scores.
                z = torch.randn([BIG_BATCH_SIZE, LATENT_SIZE], device=device)
                dx = d(x)
                loss_d = d_crit(d(g(z)), dx)
                opt_d.zero_grad()
                loss_d.backward()
                opt_d.step()

                # Generator step: IMLE loss in the encoder's latent space, a
                # pixel reconstruction loss from encoded latents, and a small
                # adversarial term.
                z = torch.randn([BIG_BATCH_SIZE, LATENT_SIZE], device=device)
                ex = e(x)
                gz = g(z)
                egz = e(gz)
                loss_g = imle_crit(egz, ex.detach())
                loss_g += e_crit(g(ex.detach()), x)
                loss_g += g_crit(d(gz), dx.detach()) * 1e-2
                opt_g.zero_grad()
                loss_g.backward()
                opt_g.step()
                g.update()

                # Encoder step: encoded latents should reconstruct their images.
                loss_e = e_crit(g(ex), x)
                opt_e.zero_grad()
                loss_e.backward()
                opt_e.step()

                losses.put(loss_d, loss_g, loss_e)
                pbar.update(len(x))
                if i % 25 == 0:
                    tqdm.write(f'{i * BATCH_SIZE} {losses!s}')
                    losses.write()
                    losses.clear()
                if i % 250 == 0:
                    demo()
                    g.train()

    @torch.no_grad()
    @torch.random.fork_rng()
    def demo():
        g.eval()
        torch.manual_seed(0)
        x = next(iter(val_dl))[0][:10].to(device)
        z = torch.randn([80, LATENT_SIZE], device=device)
        # Grid rows: validation images, their reconstructions g(e(x)), then samples g(z).
        demo = torch.cat([x, g(e(x)), g(z)])
        grid = utils.make_grid(demo, 10).cpu()
        TF.to_pil_image(grid).save('demo.png')
        tqdm.write('Wrote examples to demo.png.')

    def save():
        torch.save({'d': d.state_dict(),
                    'g': g.state_dict(),
                    'e': e.state_dict(),
                    'opt_d': opt_d.state_dict(),
                    'opt_g': opt_g.state_dict(),
                    'opt_e': opt_e.state_dict(),
                    'epoch': epoch}, PREFIX + '.pth')
        print(f'Wrote checkpoint to {PREFIX}.pth.')

    try:
        while True:
            print('Epoch', epoch)
            train()
            demo()
            epoch += 1
            save()
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()