# Modified StyleGAN2 Projector with CLIP, additional losses, kmeans, etc.
# by Peter Baylies, 2021 -- @pbaylies on Twitter
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Project a given image and/or text prompt to the latent space of a pretrained network pickle."""
import copy
import os
from time import perf_counter

import click
import imageio
import numpy as np
import PIL.Image
import PIL.ImageFilter
import torch
import torch.nn.functional as F

import clip
import dnnlib
import legacy
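# Mean and std used by CLIP's image preprocessing (RGB); images fed to the CLIP
# encoders below are scaled to [0,1] and then normalized with these constants.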
image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()
def score_images(G, model, text, latents, device, label_class = 0, batch_size = 8):
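    """Synthesize images from `latents` in batches and score them against the
    tokenized `text` with CLIP; lower scores indicate a better match.
    Returns (scores, images) as numpy arrays. (`label_class` is currently unused.)
    """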
    scores = []
    all_images = []
    for i in range(latents.shape[0]//batch_size):
        images = G.synthesis(torch.tensor(latents[i*batch_size:(i+1)*batch_size,:,:], dtype=torch.float32, device=device), noise_mode='const')
        with torch.no_grad():
            image_input = (torch.clamp(images, -1, 1) + 1) * 0.5
            image_input = F.interpolate(image_input, size=(256, 256), mode='area')
            image_input = image_input[:, :, 16:240, 16:240] # 256 -> 224, center crop
            image_input -= image_mean[None, :, None, None]
            image_input /= image_std[None, :, None, None]
            score = model(image_input, text)[0]
            scores.append(score.cpu().numpy())
            all_images.append(images.cpu().numpy())
    scores = np.array(scores)
    scores = scores.reshape(-1, *scores.shape[2:]).squeeze()
    scores = 1 - scores / np.linalg.norm(scores)
    all_images = np.array(all_images)
    all_images = all_images.reshape(-1, *all_images.shape[2:])
    return scores, all_images
def project(
    G,
    target_image: torch.Tensor, # [C,H,W], dynamic range [0,255]; W & H must match G output resolution
    target_text,
    *,
    num_steps = 300,
    w_avg_samples = 8192,
    initial_learning_rate = 0.02,
    initial_latent = None,
    initial_noise_factor = 0.01,
    lr_rampdown_length = 0.10,
    lr_rampup_length = 0.5,
    noise_ramp_length = 0.75,
    latent_range = 2.0,
    max_noise = 0.5,
    min_threshold = 0.6,
    use_vgg = True,
    use_clip = True,
    use_pixel = True,
    use_penalty = True,
    use_center = True,
    regularize_noise_weight = 1e5,
    kmeans = True,
    kmeans_clusters = 64,
    verbose = False,
    device: torch.device
):
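    """Optimize a W+ latent (and the per-layer noise buffers) so that G's output
    matches `target_image` and/or `target_text`; returns a [num_steps, L, C]
    tensor holding the projected W at every optimization step.
    """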
    if target_image is not None:
        assert target_image.shape == (G.img_channels, G.img_resolution, G.img_resolution)
    else:
        use_vgg = False
        use_pixel = False

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore
    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    labels = None
    if (G.mapping.c_dim):
        labels = torch.from_numpy(0.5*np.random.RandomState(123).randn(w_avg_samples, G.mapping.c_dim)).to(device)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), labels) # [N, L, C]
    w_samples = w_samples.cpu().numpy().astype(np.float32)               # [N, L, C]
    w_samples_1d = w_samples[:, :1, :].astype(np.float32)
    w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, L, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
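    # w_std is a single scalar (the RMS deviation of all W components from the
    # mean); it scales the exploratory noise added to w_opt during optimization.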
    kmeans_latents = None
    if initial_latent is not None:
        w_avg = initial_latent
    else:
        if kmeans and use_clip and target_text is not None:
            from kmeans_pytorch import kmeans
            # data
            data_size, dims, num_clusters = w_avg_samples, G.z_dim, kmeans_clusters
            x = torch.from_numpy(w_samples_1d)
            # kmeans
            logprint(f'Performing kmeans clustering using {w_avg_samples} latents into {kmeans_clusters} clusters...')
            cluster_ids_x, cluster_centers = kmeans(
                X=x, num_clusters=num_clusters, distance='euclidean', device=device
            )
            #logprint(f'\nGenerating images from kmeans latents...')
            kmeans_latents = torch.tensor(cluster_centers, dtype=torch.float32, device=device, requires_grad=True)
    # Setup noise inputs.
    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }

    # Load VGG16 feature detector.
    if use_vgg:
        url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
        with dnnlib.util.open_url(url) as f:
            vgg16 = torch.jit.load(f).eval().to(device)

    # Load CLIP.
    if use_clip:
        model, transform = clip.load("ViT-B/32", device=device)
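        # (The `transform` preprocessing pipeline returned by clip.load is unused;
        # inputs are normalized manually with image_mean/image_std instead.)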
    # Features for target image.
    if target_image is not None:
        target_images = target_image.unsqueeze(0).to(device).to(torch.float32)
        small_target = F.interpolate(target_images, size=(64, 64), mode='area')
        if use_center:
            center_target = F.interpolate(target_images, size=(448, 448), mode='area')[:, :, 112:336, 112:336]
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
        target_images = target_images[:, :, 16:240, 16:240] # 256 -> 224, center crop
        if use_vgg:
            vgg_target_features = vgg16(target_images, resize_images=False, return_lpips=True)
            if use_center:
                vgg_target_center = vgg16(center_target, resize_images=False, return_lpips=True)
    if use_clip:
        if target_image is not None:
            with torch.no_grad():
                clip_target_features = model.encode_image(((target_images / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]).float()
                if use_center:
                    clip_target_center = model.encode_image(((center_target / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]).float()
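    # Seed the search with the most promising kmeans centers: render each cluster
    # center, score it against the text prompt with CLIP, and take the median of
    # the 4 best-scoring centers as the starting latent.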
    if kmeans_latents is not None and use_clip and target_text is not None:
        scores, kmeans_images = score_images(G, model, target_text, kmeans_latents.repeat([1, G.mapping.num_ws, 1]), device=device)
        ind = np.argpartition(scores, 4)[:4] # indices of the 4 lowest (best) scores
        w_avg = torch.median(kmeans_latents[ind], dim=0, keepdim=True)[0].repeat([1, G.mapping.num_ws, 1])
    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
    w_avg_tensor = w_opt.clone()
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
    optimizer = torch.optim.AdamW([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True
    for step in range(num_steps):
        # Learning rate schedule: linear ramp-up over the first lr_rampup_length
        # of the run, cosine ramp-down over the final lr_rampdown_length; the
        # exploratory latent noise decays quadratically over noise_ramp_length.
        t = step / num_steps
        w_noise_scale = max_noise * w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = w_opt + w_noise
        synth_images = G.synthesis(torch.clamp(ws, -latent_range, latent_range), noise_mode='const')

        # Scale the output to [0,255] and downsample; CLIP was built for 224x224
        # inputs, so the 256x256 image is center-cropped to 224x224 below.
        synth_images = (torch.clamp(synth_images, -1, 1) + 1) * (255/2)
        small_synth = F.interpolate(synth_images, size=(64, 64), mode='area')
        if use_center:
            center_synth = F.interpolate(synth_images, size=(448, 448), mode='area')[:, :, 112:336, 112:336]
        synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_images = synth_images[:, :, 16:240, 16:240] # 256 -> 224, center crop
        dist = 0
        if use_vgg:
            vgg_synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
            vgg_dist = (vgg_target_features - vgg_synth_features).square().sum()
            if use_center:
                vgg_synth_center = vgg16(center_synth, resize_images=False, return_lpips=True)
                vgg_dist += (vgg_target_center - vgg_synth_center).square().sum()
            vgg_dist *= 6
            dist += F.relu(vgg_dist*vgg_dist - min_threshold)
        if use_clip:
            clip_synth_image = ((synth_images / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]
            clip_synth_features = model.encode_image(clip_synth_image).float()
            adj_center = 2.0
            if use_center:
                clip_synth_center_image = ((center_synth / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]
                adj_center = 1.0
                clip_synth_center = model.encode_image(clip_synth_center_image).float()
            if target_image is not None:
                clip_dist = (clip_target_features - clip_synth_features).square().sum()
                if use_center:
                    clip_dist += (clip_target_center - clip_synth_center).square().sum()
                dist += F.relu(0.5 + adj_center*clip_dist - min_threshold)
            if target_text is not None:
                clip_text = 1 - model(clip_synth_image, target_text)[0].sum() / 100
                if use_center:
                    clip_text += 1 - model(clip_synth_center_image, target_text)[0].sum() / 100
                dist += 2*F.relu(adj_center*clip_text*clip_text - min_threshold / adj_center)
        if use_pixel:
            pixel_dist = (target_images - synth_images).abs().sum() / 2000000.0
            if use_center:
                pixel_dist += (center_target - center_synth).abs().sum() / 2000000.0
            pixel_dist += (small_target - small_synth).square().sum() / 800000.0
            pixel_dist /= 4
            dist += F.relu(lr_ramp * pixel_dist - min_threshold)
        if use_penalty:
            l1_penalty = (w_opt - w_avg_tensor).abs().sum() / 5000.0
            dist += F.relu(lr_ramp * l1_penalty - min_threshold)
        # Noise regularization: penalize spatial autocorrelation of each noise map
        # at multiple scales so the optimized noise stays roughly white.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        #print(vgg_dist, clip_dist, pixel_dist, l1_penalty, reg_loss * regularize_noise_weight)
        loss = dist + reg_loss * regularize_noise_weight
        # Step.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
        with torch.no_grad():
            torch.clamp(w_opt, -latent_range, latent_range, out=w_opt)

        # Save projected W for each optimization step.
        w_out[step] = w_opt.detach()[0]
        # Normalize noise buffers to zero mean and unit variance.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out
#----------------------------------------------------------------------------
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--target-image', 'target_fname', help='Target image file to project to', required=False, metavar='FILE', default=None)
@click.option('--target-text', help='Target text to project to', required=False, default=None)
@click.option('--initial-latent', help='Initial latent as an .npz file (e.g. a previously saved projected_w.npz)', default=None)
@click.option('--lr', help='Learning rate', type=float, default=0.1, show_default=True)
@click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
@click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
@click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
@click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
@click.option('--use-vgg', help='Use VGG16 in the loss', type=bool, default=True, show_default=True)
@click.option('--use-clip', help='Use CLIP in the loss', type=bool, default=True, show_default=True)
@click.option('--use-pixel', help='Use L1/L2 distance on pixels in the loss', type=bool, default=True, show_default=True)
@click.option('--use-penalty', help='Penalize latent values far from the mean in the loss', type=bool, default=True, show_default=True)
@click.option('--use-center', help='Optimize against an additional center image crop', type=bool, default=True, show_default=True)
@click.option('--kmeans', help='Perform kmeans clustering to select initial latents', type=bool, default=True, show_default=True)
def run_projection(
    network_pkl: str,
    target_fname: str,
    target_text: str,
    initial_latent: str,
    outdir: str,
    save_video: bool,
    seed: int,
    lr: float,
    num_steps: int,
    use_vgg: bool,
    use_clip: bool,
    use_pixel: bool,
    use_penalty: bool,
    use_center: bool,
    kmeans: bool,
):
"""Project given image to the latent space of pretrained network pickle. | |
Examples: | |
\b | |
python projector.py --outdir=out --target=~/mytargetimg.png \\ | |
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl | |
""" | |
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

    # Load target image.
    target_image = None
    if target_fname:
        target_pil = PIL.Image.open(target_fname).convert('RGB').filter(PIL.ImageFilter.SHARPEN)
        w, h = target_pil.size
        s = min(w, h)
        target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
        target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
        target_uint8 = np.array(target_pil, dtype=np.uint8)
        target_image = torch.tensor(target_uint8.transpose([2, 0, 1]), device=device)
    if target_text:
        target_text = torch.cat([clip.tokenize(target_text)]).to(device)
    if initial_latent is not None:
        initial_latent = np.load(initial_latent)
        initial_latent = initial_latent[initial_latent.files[0]]
    # Optimize projection.
    start_time = perf_counter()
    projected_w_steps = project(
        G,
        target_image=target_image,
        target_text=target_text,
        initial_latent=initial_latent,
        initial_learning_rate=lr,
        num_steps=num_steps,
        use_vgg=use_vgg,
        use_clip=use_clip,
        use_pixel=use_pixel,
        use_penalty=use_penalty,
        use_center=use_center,
        kmeans=kmeans,
        device=device,
        verbose=True
    )
    print(f'Elapsed: {(perf_counter()-start_time):.1f} s')
    # Render debug output: optional video, plus the projected image and W vector.
    os.makedirs(outdir, exist_ok=True)
    if save_video:
        video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
        print(f'Saving optimization progress video "{outdir}/proj.mp4"')
        for projected_w in projected_w_steps:
            synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
            synth_image = (synth_image + 1) * (255/2)
            synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
            if target_fname:
                video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
            else:
                video.append_data(synth_image)
        video.close()
    # Save final projected frame and W vector.
    if target_fname:
        target_pil.save(f'{outdir}/target.png')
    projected_w = projected_w_steps[-1]
    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
    np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
#----------------------------------------------------------------------------

if __name__ == "__main__":
    run_projection() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
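# Reloading a saved projection (a minimal sketch, not part of this script; it
# assumes the same `G` and `device` as above, and the hypothetical path 'out/'):
#
#   ws = np.load('out/projected_w.npz')['w']  # [1, L, C]
#   img = G.synthesis(torch.tensor(ws, device=device), noise_mode='const')
#   img = (img.permute(0, 2, 3, 1) * 127.5 + 127.5).clamp(0, 255).to(torch.uint8)
#   PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save('out/reprojected.png')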