Skip to content

Instantly share code, notes, and snippets.

@l4rz
Last active April 10, 2021 17:27
Show Gist options
  • Save l4rz/b42292ef07efdba538e41a56af989a86 to your computer and use it in GitHub Desktop.
Save l4rz/b42292ef07efdba538e41a56af989a86 to your computer and use it in GitHub Desktop.
import copy
import os
import click
import numpy as np
import PIL.Image
import PIL.ImageOps
import torch
import dnnlib
import legacy
from tqdm import tqdm
from kmeans_pytorch import kmeans
#
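# Copy clusterize.py into the stylegan2-ada-pytorch directory and run it from
# there (it imports that repo's dnnlib and legacy modules).
# To assemble the output images into a grid, use ImageMagick's montage, e.g.
# (an 8x8 tile matches the default of 64 clusters):
# montage *.png -tile 8x8 -geometry 256x256+1+1 montage.jpg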
#
# -l4rz
def clusterize(
    G,
    *,
    samples = 25000, # 10k is ok. 30k is better.
    clusters = 64,
    device: torch.device
):
    """Cluster StyleGAN2 w latents with k-means and render one image per centroid.

    Draws ``samples`` random z vectors, maps them through ``G.mapping``, runs
    k-means (``kmeans_pytorch.kmeans``) on the first w row of every sample,
    saves the centroids to ``kmeans-pod.pt`` and writes ``kmean-0000.png``,
    ``kmean-0001.png``, ... -- all relative to the current working directory.

    Args:
        G: generator loaded from a network pickle. A deep copy is moved to
            ``device`` and used, so the caller's module is not modified.
        samples: number of random z vectors to map and cluster.
        clusters: number of k-means clusters, i.e. number of images written.
        device: torch device used for the generator and for k-means.

    Returns:
        None; all results are written to files.

    Raises:
        ValueError: if ``samples`` or ``clusters`` is < 1, or if ``clusters``
            exceeds ``samples`` (there must be at least as many points as
            clusters).
    """
    # Fail fast, before the (potentially slow) mapping pass.
    if samples < 1 or clusters < 1:
        raise ValueError(f'samples ({samples}) and clusters ({clusters}) must both be >= 1')
    if clusters > samples:
        raise ValueError(f'clusters ({clusters}) must not exceed samples ({samples})')

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device)

    print(f'G.mapping {samples} samples')
    z_samples = np.random.randn(samples, G.z_dim)
    with torch.no_grad():
        # NOTE(review): None is passed as the label argument -- presumably
        # only unconditional models are supported; confirm.
        w_all = G.mapping(torch.from_numpy(z_samples).to(device), None)
        # Keep only the first w row of each sample, as a 2-D
        # (samples, w_dim) float32 matrix on the CPU. Assumes the mapping
        # output is broadcast along axis 1, so one row represents the sample.
        # kmeans_pytorch expects an N x D matrix; a 3-D input makes its
        # centre-shift convergence check sum over the wrong axis.
        w_rows = w_all[:, 0, :].to(device='cpu', dtype=torch.float32)
    del w_all  # release the full (samples, num_ws, w_dim) tensor early

    print(f'Performing kmeans clustering {samples} latents into {clusters} clusters')
    # thx @pbaylies for kmeans idea https://gist.github.com/pbaylies/671ef8434fd11f056bab4330e0e7c365
    _cluster_ids, cluster_centers = kmeans(
        X=w_rows, num_clusters=clusters, distance='euclidean', device=device
    )

    # Reshape to (clusters, 1, w_dim), the layout previously saved to
    # kmeans-pod.pt, so downstream consumers keep working. Copy instead of
    # torch.tensor(tensor), which PyTorch warns against. The saved tensor
    # still requires grad, as before.
    pod = cluster_centers.reshape(clusters, 1, -1).to(device=device, dtype=torch.float32)
    pod = pod.clone().requires_grad_(True)  # pod is array
    print('kmeans pod shape', pod.shape)
    # the whole point
    torch.save(pod, 'kmeans-pod.pt')

    print(f'Generating images for {clusters} centroid w\'s')
    # Inference only: without no_grad, the grad-requiring pod would make
    # every synthesis call build an unused autograd graph.
    with torch.no_grad():
        for i in tqdm(range(clusters)):
            # (1, w_dim) -> (1, num_ws, w_dim): same w fed to every layer.
            wt = pod[i].repeat([1, G.mapping.num_ws, 1])
            snap = G.synthesis(wt, noise_mode='const')
            # Map synthesis output (presumably in [-1, 1]) to [0, 255],
            # NCHW -> NHWC, then take the single image in the batch.
            img = (snap + 1) * (255/2)
            img = img.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
            PIL.Image.fromarray(img, 'RGB').save(f'kmean-{i:04}.png')
    return None
#----------------------------------------------------------------------------
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename, e.g. stylegan2-ffhq-config-f.pkl', required=True)
@click.option('--samples', help='Number of random samples ', type=int, default=1000, show_default=True)
@click.option('--clusters', help='Number of clusters ', type=int, default=64, show_default=True)
def run_clusterize(
    network_pkl: str,
    samples: int,
    clusters: int
):
    """Derive a number of 𝑤 latents from random 𝑧 values using the mapping network of StyleGAN2, cluster them and generate sample images corresponding to cluster centroids.
    Samples will be saved to the current directory.
    \b
    """
    # CLI entry point. The docstring above is click's --help text, so it is
    # kept verbatim. Steps: read the pickle, take its 'G_ema' generator,
    # freeze it on the CUDA device, then delegate to clusterize().
    print(f'Loading networks from "{network_pkl}"...')
    cuda_device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as pkl_stream:
        networks = legacy.load_network_pkl(pkl_stream)
    generator = networks['G_ema'].requires_grad_(False).to(cuda_device)  # type: ignore
    clusterize(generator, samples=samples, clusters=clusters, device=cuda_device)
if __name__ == '__main__':
    # Script entry point: run_clusterize is a click command, so calling it
    # with no arguments makes click parse the command line.
    run_clusterize()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment