Skip to content

Instantly share code, notes, and snippets.

@asears
Created October 13, 2021 08:56
Show Gist options
  • Select an option

  • Save asears/b1198c05b4ad3556f276b095dedbb501 to your computer and use it in GitHub Desktop.

Select an option

Save asears/b1198c05b4ad3556f276b095dedbb501 to your computer and use it in GitHub Desktop.
kornia - Connected components
Display the source blob
Display the rendered blob
Raw
%%capture
!pip install kornia
%%capture
!wget https://github.com/kornia/data/raw/main/cells_binary.png
from typing import Dict, Iterable, Tuple, Union

import cv2
import kornia as K
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
def create_random_labels_map(classes: Union[int, Iterable]) -> Dict[int, Tuple[int, int, int]]:
    """Build a mapping from label id to a random RGB colour, with label 0 mapped to black.

    Args:
        classes: either the number of classes (ids ``0 .. classes - 1``) or the label ids
            themselves, e.g. the 1-D tensor returned by ``torch.unique``. Tensor elements are
            converted to Python scalars (floats for a floating-point tensor).

    Returns:
        Dict from label id to an ``(r, g, b)`` tuple of ints in ``[0, 255]``. Label 0 is
        always ``(0, 0, 0)``.
    """
    if isinstance(classes, int):
        label_ids: Iterable = range(classes)
    elif isinstance(classes, torch.Tensor):
        # 0-d tensors hash by identity, so use plain Python scalars as keys; otherwise the
        # explicit background entry below would sit alongside tensor(0) instead of replacing it.
        label_ids = classes.flatten().tolist()
    else:
        label_ids = classes

    labels_map: Dict[int, Tuple[int, int, int]] = {}
    for label_id in label_ids:
        # randint's upper bound is exclusive, so 256 yields the full 0..255 range
        labels_map[label_id] = tuple(torch.randint(0, 256, (3,)).tolist())
    labels_map[0] = (0, 0, 0)  # background is always black
    return labels_map
def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, int, int]]) -> torch.Tensor:
    """Render a 2-D image of label ids as a 3 x H x W uint8 RGB tensor for visualisation.

    Args:
        img_labels: 2-D tensor (H x W) whose values are label ids.
        labels_map: mapping from label id to a 3-element colour (indexable by 0, 1, 2).

    Returns:
        uint8 tensor of shape (3, H, W) on the same device as ``img_labels``. Pixels whose
        label does not appear in ``labels_map`` are black.

    Raises:
        ValueError: if ``img_labels`` is not 2-D.
    """
    if img_labels.dim() != 2:
        raise ValueError(f"expected a 2-D label image, got shape {tuple(img_labels.shape)}")
    height, width = img_labels.shape
    # zeros (not empty) so unmapped labels render black instead of uninitialised memory;
    # same device as the labels so masked_fill_ does not mix CPU and CUDA tensors
    out = torch.zeros(3, height, width, dtype=torch.uint8, device=img_labels.device)
    for label_id, label_val in labels_map.items():
        mask = img_labels == label_id
        for channel in range(3):
            out[channel].masked_fill_(mask, label_val[channel])
    return out
def show_components(img, labels):
    """Plot an image next to a random-colour rendering of its connected components.

    Args:
        img: the original image as a NumPy array. A 2-D (grayscale) array is drawn with a
            gray colormap; otherwise it is assumed to be in OpenCV's BGR channel order and is
            converted to RGB for display.
        labels: 2-D tensor of component ids; id 0 is drawn black.
    """
    color_ids = torch.unique(labels)
    labels_map = create_random_labels_map(color_ids)
    labels_img = labels_to_image(labels, labels_map)

    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 12))

    # Showing Original Image
    if img.ndim == 2:
        # cv2.COLOR_BGR2RGB rejects single-channel input, so draw grayscale directly
        ax1.imshow(img, cmap="gray")
    else:
        ax1.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    ax1.axis("off")
    ax1.set_title("Original Image")

    # Showing Image after Component Labeling: (3, H, W) -> (H, W, 3) for imshow
    ax2.imshow(labels_img.permute(1, 2, 0).cpu().numpy())
    ax2.axis("off")
    ax2.set_title("Component Labeling")
img_path = "cells_binary.png"
img: np.ndarray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
if img is None:
    # cv2.imread signals a missing/unreadable file by returning None rather than raising
    raise FileNotFoundError(f"could not read {img_path!r}; download it first (see the wget cell)")
img_t: torch.Tensor = K.utils.image_to_tensor(img)  # CxHxW (C == 1 for this grayscale read)
img_t = img_t[None, ...].float() / 255.0  # add a batch dim and scale to [0, 1]
print(img_t.shape)
# output: torch.Size([1, 1, 602, 602])

labels_out = K.contrib.connected_components(img_t, num_iterations=150)
print(labels_out.shape)
# output: torch.Size([1, 1, 602, 602])

show_components(img_t.numpy().squeeze(), labels_out.squeeze())
print(torch.unique(labels_out))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment