Created
June 8, 2020 01:37
-
-
Save efirdc/5d8bd66859e574c683a504a4690ae8bc to your computer and use it in GitHub Desktop.
Connected components in pytorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Demonstration: https://www.youtube.com/watch?v=5AvHrIK-Kjc&feature=youtu.be
# rand_cmap from https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

from visualizations import rand_cmap
# `device` is not defined or imported anywhere in this file, so choose one
# here: the GPU when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

W = H = 64

# Build a random binary test image: smooth Gaussian noise with six passes of a
# 3x3 box filter, then threshold it.
img = torch.randn(W, H, device=device)
for _ in range(6):
    # avg_pool2d needs (N, C, H, W) input. Indexing [0, 0] removes exactly the
    # two axes added here, unlike squeeze(), which would also drop any real
    # image dimension of size 1.
    img = F.avg_pool2d(img.unsqueeze(0).unsqueeze(0), kernel_size=3, stride=1, padding=1)[0, 0]

# Binarize: 1.0 marks foreground pixels, 0.0 marks background.
threshold = 0.1
img = (img > threshold).float()

# Seed every pixel with a unique label, then zero out the background. Labels
# start at 1 so that no foreground pixel can share the background value 0.
components = torch.arange(1, W * H + 1, device=device).reshape((W, H)).float()
components[img != 1] = 0

# W * H possible labels plus the background value 0.
new_cmap = rand_cmap(W * H + 1, type='bright', first_color_black=True, last_color_black=False, verbose=False)

# Show the binary input image without axis ticks or labels.
plt.figure(figsize=(12, 12))
plt.tick_params(which='both', bottom=False, top=False, left=False, labelbottom=False, labelleft=False)
plt.imshow(img.cpu().numpy(), cmap="gray")
# NOTE(review): `interact` is not imported in this file; presumably it is
# `ipywidgets.interact` provided by the surrounding notebook -- confirm.
@interact(iterations=(0, 50))
def connected_components(iterations):
    """Propagate component labels for ``iterations`` steps and plot the result.

    Starts from a copy of the module-level ``components`` label image and, on
    each step, replaces every foreground pixel (where ``img == 1``) with the
    maximum label found in its 3x3 neighbourhood. Background pixels are never
    written. The module-level ``components`` tensor is left unmodified.

    Args:
        iterations: Number of max-pooling propagation steps to apply.
    """
    # The foreground mask does not change between steps, so build it once.
    foreground = img == 1

    comp = components.clone()
    for _ in range(iterations):
        # max_pool2d needs (N, C, H, W) input. Indexing [0, 0] removes exactly
        # the two axes added here, unlike squeeze(), which would also drop any
        # real image dimension of size 1.
        neighbourhood_max = F.max_pool2d(
            comp.unsqueeze(0).unsqueeze(0), kernel_size=3, stride=1, padding=1
        )[0, 0]
        comp[foreground] = neighbourhood_max[foreground]

    # Draw the label image, one colormap entry per label, without axis ticks.
    plt.figure(figsize=(12, 12))
    plt.tick_params(which='both', bottom=False, top=False, left=False, labelbottom=False, labelleft=False)
    plt.imshow(comp.cpu().numpy(), cmap=new_cmap)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment