Last active
April 10, 2021 17:27
-
-
Save l4rz/b42292ef07efdba538e41a56af989a86 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
import os | |
import click | |
import numpy as np | |
import PIL.Image | |
import PIL.ImageOps | |
import torch | |
import dnnlib | |
import legacy | |
from tqdm import tqdm | |
from kmeans_pytorch import kmeans | |
#
# Copy clusterize.py into the stylegan2-ada-pytorch directory and run it from there.
# To assemble the generated images into a grid, use ImageMagick's montage, e.g.
#   montage *.png -tile 8x8 -geometry 256x256+1+1 montage.jpg
#
# -l4rz
def clusterize(
    G,
    *,
    samples=25000,  # 10k is ok. 30k is better.
    clusters=64,
    device: torch.device,
):
    """Cluster mapped w latents with k-means and render one image per centroid.

    Draws ``samples`` random z vectors, maps them to w space with
    ``G.mapping``, clusters the first w vector of every sample with k-means
    (Euclidean distance), saves the centroids to ``kmeans-pod.pt`` and writes
    one ``kmean-XXXX.png`` image per centroid into the current directory.

    Args:
        G: Generator exposing ``z_dim``, ``mapping`` (with ``num_ws``) and
            ``synthesis``; it is deep-copied, so the caller's object is not
            modified.
        samples: Number of random z vectors to draw and cluster.
        clusters: Number of k-means clusters, i.e. number of images written.
        device: Device on which the generator and k-means run.

    Returns:
        None. All results are written to disk.

    Raises:
        ValueError: If ``clusters`` is not in ``[1, samples]`` -- k-means
            cannot form more clusters than there are samples.
    """
    # Validate before touching G so bad arguments fail fast and cheaply.
    if clusters < 1:
        raise ValueError(f'clusters must be >= 1, got {clusters}')
    if samples < clusters:
        raise ValueError(f'samples ({samples}) must be >= clusters ({clusters})')

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device)

    print(f'G.mapping {samples} samples')
    z_samples = np.random.randn(samples, G.z_dim)
    with torch.no_grad():
        w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)
    # w_samples is rank 3 (samples, num_ws, w_dim). Only the first w of each
    # sample is clustered (the per-layer ws are presumably identical copies
    # when no truncation/mixing is applied -- TODO confirm). Hand k-means a
    # plain 2-D (samples, w_dim) float32 matrix on the CPU.
    w_first = w_samples[:, 0, :].to(dtype=torch.float32).cpu()
    del w_samples

    print(f'Performing kmeans clustering {samples} latents into {clusters} clusters')
    # thx @pbaylies for kmeans idea https://gist.github.com/pbaylies/671ef8434fd11f056bab4330e0e7c365
    _, cluster_centers = kmeans(
        X=w_first, num_clusters=clusters, distance='euclidean', device=device
    )

    # Saved as (clusters, 1, w_dim) float32 with requires_grad=True, exactly
    # as before, so existing consumers of kmeans-pod.pt keep working.
    # as_tensor + detach/clone avoids torch.tensor()'s copy-construct warning.
    pod = (
        torch.as_tensor(cluster_centers, dtype=torch.float32)
        .reshape(clusters, 1, -1)
        .detach()
        .clone()
        .to(device)
        .requires_grad_(True)
    )
    print('kmeans pod shape', pod.shape)
    # the whole point
    torch.save(pod, 'kmeans-pod.pt')

    print(f'Generating images for {clusters} centroid w\'s')
    num_ws = G.mapping.num_ws
    # no_grad: pod requires grad, so without this every synthesis pass would
    # build (and then discard) a full autograd graph.
    with torch.no_grad():
        for i in tqdm(range(clusters)):
            # Broadcast the centroid w to every synthesis layer: (1, num_ws, w_dim).
            wt = pod[i].repeat([1, num_ws, 1])
            snap = G.synthesis(wt, noise_mode='const')
            # Map generator output from [-1, 1] to [0, 255] and convert NCHW -> HWC.
            img = (snap + 1) * (255 / 2)
            img = img.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
            PIL.Image.fromarray(img, 'RGB').save(f'kmean-{i:04}.png')
    return None
#----------------------------------------------------------------------------

@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename, e.g. stylegan2-ffhq-config-f.pkl', required=True)
@click.option('--samples', help='Number of random samples ', type=click.IntRange(min=1), default=1000, show_default=True)
@click.option('--clusters', help='Number of clusters ', type=click.IntRange(min=1), default=64, show_default=True)
def run_clusterize(
    network_pkl: str,
    samples: int,
    clusters: int,
):
    """Cluster StyleGAN2 w latents and render an image per cluster centroid.

    Derives a number of w latents from random z values using the mapping
    network of StyleGAN2, clusters them and generates sample images
    corresponding to the cluster centroids. Samples will be saved to the
    current directory.
    """
    # Reject an impossible combination before the (slow) network load.
    if clusters > samples:
        raise click.BadParameter(
            f'must not exceed --samples ({samples})', param_hint="'--clusters'"
        )

    # take off
    print(f'Loading networks from "{network_pkl}"...')
    # Fall back to the CPU (much slower) when no CUDA device is available
    # instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)  # type: ignore

    # do it
    clusterize(
        G,
        samples=samples,
        clusters=clusters,
        device=device,
    )
if __name__ == '__main__':
    # Script entry point: run the click command defined above.
    run_clusterize()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment