@Birch-san
Created October 17, 2022 00:12
plotting da histograms (partial snippet from Jupyter notebook)
import matplotlib.pyplot as plt
import torch
from torch import FloatTensor, Tensor
# …
latents: FloatTensor = self.inner_model(x, sigma, cond=cond, **kwargs)
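# undo the latent scaling (the axis labels below assume scale_factor is 0.18215)
unscaled: Tensor = latents / self.scale_factor
# torch.histogram is not implemented on every device (e.g. CUDA), so bin the first batch item on the CPU in float32
sample: Tensor = unscaled[0].detach().float().cpu()
# one histogram per channel: flatten(1) turns [C, H, W] into [C, H*W]
chs = [torch.histogram(c) for c in sample.flatten(1)]
# one histogram over all channels together
h = torch.histogram(sample.ravel())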
plt.figure(figsize=(10,2))
plt.title('Per-channel latent values after denoising sigma %.3f at CFG scale %d' % (sigma.item(), cfg_scale))
for ch, col in zip(chs, ('red', 'green', 'blue', 'purple')):
    # replay the precomputed counts: left edges as data, edges as bins, counts as weights
    plt.hist(ch.bin_edges[:-1].cpu(), ch.bin_edges.cpu(), weights=ch.hist.cpu(), color=col, alpha=0.4)
plt.xlabel('Latent value ÷ 0.18215')
plt.ylabel('Count')
plt.legend(['Ch0','Ch1','Ch2','Ch3'])
plt.show()
plt.figure(figsize=(10,2))
plt.title('Global latent values after denoising sigma %.3f at CFG scale %d' % (sigma.item(), cfg_scale))
plt.hist(h.bin_edges[:-1].cpu(), h.bin_edges.cpu(), weights=h.hist.cpu())
plt.xlabel('Latent value ÷ 0.18215')
plt.ylabel('Count')
plt.legend(['All'])
plt.show()
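The `plt.hist(bin_edges[:-1], bin_edges, weights=hist)` calls above do not re-bin the raw latents. They replay counts that `torch.histogram` already computed. Below is a rough pure-Python sketch of why that works, assuming NumPy-style bins (half-open, except the last bin, which includes its right edge). The helpers `histogram` and `weighted_hist` are hypothetical stand-ins written for illustration. They are not the PyTorch or matplotlib implementations.

```python
from bisect import bisect_right

def weighted_hist(x, edges, weights):
    """Sum weights into bins; the last bin also includes its right edge."""
    counts = [0.0] * (len(edges) - 1)
    for xi, w in zip(x, weights):
        # bisect_right finds the bin whose left edge is <= xi;
        # clamp so that xi == edges[-1] lands in the last bin
        i = min(bisect_right(edges, xi) - 1, len(counts) - 1)
        counts[i] += w
    return counts

def histogram(values, bins=100):
    """Equal-width histogram over [min, max]; returns (counts, bin_edges)."""
    lo, hi = min(values), max(values)
    width = (hi - lo) / bins
    edges = [lo + i * width for i in range(bins)] + [hi]
    return weighted_hist(values, edges, [1.0] * len(values)), edges

# illustrative data only: quarter steps from -3.0 to 3.0
values = [k / 4 for k in range(-12, 13)]
counts, edges = histogram(values, bins=8)

# feed each bin's left edge back in as a data point, weighted by that bin's count
replayed = weighted_hist(edges[:-1], edges, counts)
assert replayed == counts
```

Each left edge falls in exactly one bin, so weighting it by that bin's count reproduces the original histogram without touching the raw values again.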
@Birch-san

produces output such as:

image
