Skip to content

Instantly share code, notes, and snippets.

@efirdc
Created June 8, 2020 01:37
Show Gist options
  • Save efirdc/5d8bd66859e574c683a504a4690ae8bc to your computer and use it in GitHub Desktop.
Save efirdc/5d8bd66859e574c683a504a4690ae8bc to your computer and use it in GitHub Desktop.
Connected components in PyTorch
# Demonstration: https://www.youtube.com/watch?v=5AvHrIK-Kjc&feature=youtu.be
# rand_cmap from https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from visualizations import rand_cmap
# Compute device for every tensor in this script. Defined here so the script
# runs stand-alone (it was previously referenced without being defined).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image size in pixels. Tensors below are laid out as (rows, cols) = (H, W).
W = H = 64

# Smooth white noise with repeated 3x3 box blurs so that thresholding yields
# blob-shaped foreground regions rather than isolated salt-and-pepper pixels.
img = torch.randn(H, W).to(device)
for _ in range(6):
    # avg_pool2d expects (N, C, H, W); index [0, 0] instead of squeeze() so a
    # 1-pixel-wide or 1-pixel-tall image keeps its 2-D shape.
    img = F.avg_pool2d(img[None, None], kernel_size=3, stride=1, padding=1)[0, 0]

# Binarize: 1.0 = foreground, 0.0 = background. A single comparison is used
# (instead of two sequential in-place masked writes) so the mask is correct for
# any threshold, including threshold >= 1.0, where the freshly written 1.0
# values would otherwise also satisfy `img <= threshold` and be zeroed.
threshold = 0.1
img = (img > threshold).float()

# Give every foreground pixel a unique positive label 1..H*W and background
# the label 0. Starting at 1 keeps the top-left pixel from sharing the
# background label. float32 represents these labels exactly while H*W <= 2**24.
components = torch.arange(1, H * W + 1, device=device).reshape(H, W).float()
components[img != 1] = 0
# Random "bright" colormap with one colour per possible pixel label.
# first_color_black=True presumably makes the lowest value (the background
# label 0) render black -- confirm against the rand_cmap helper.
new_cmap = rand_cmap(
    W * H,
    type='bright',
    first_color_black=True,
    last_color_black=False,
    verbose=False,
)

# Display the binary foreground mask on a large canvas with all ticks and
# tick labels hidden.
plt.figure(figsize=(12, 12))
plt.tick_params(
    which='both',
    bottom=False,
    top=False,
    left=False,
    labelbottom=False,
    labelleft=False,
)
plt.imshow(img.cpu().numpy(), cmap="gray")
# NOTE(review): `interact` is never imported in this file; it is presumably
# `ipywidgets.interact` from the original notebook -- confirm and import it.
@interact(iterations=(0, 50))
def connected_components(iterations: int) -> None:
    """Propagate labels over 8-connected foreground regions of ``img`` and plot them.

    Each step replaces every foreground pixel's label with the maximum label
    in its 3x3 neighbourhood (a 3x3 max pool). Background pixels (``img != 1``)
    are never overwritten, so they keep their label. Because the initial
    labels are non-negative and background is 0, labels only flow between
    8-adjacent foreground pixels. With enough steps, every pixel in a region
    ends up with that region's largest initial label.

    Args:
        iterations: Number of propagation steps to run (slider range 0-50).
            Propagation stops early once a step leaves the labels unchanged.
            This draws the same picture as running the remaining steps.

    Reads the module-level ``img``, ``components`` and ``new_cmap``. The
    ``components`` tensor itself is not modified (a clone is used).
    """
    foreground = img == 1  # loop-invariant, so compute the mask only once
    labels = components.clone()
    for _ in range(iterations):
        # max_pool2d expects (N, C, H, W); [0, 0] drops the two added dims
        # without also collapsing a size-1 spatial dim as squeeze() would.
        pooled = F.max_pool2d(labels[None, None], kernel_size=3, stride=1, padding=1)[0, 0]
        updated = torch.where(foreground, pooled, labels)
        if torch.equal(updated, labels):
            break  # fixed point reached: further steps cannot change anything
        labels = updated

    plt.figure(figsize=(12, 12))
    plt.tick_params(which='both', bottom=False, top=False, left=False, labelbottom=False, labelleft=False)
    plt.imshow(labels.cpu().numpy(), cmap=new_cmap)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment