# cluster_odise.py
# Forked from xvjiarui/cluster_odise.py (gist by @JihoChoi, 2024-01-23).
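"""Visualize k-means groupings of diffusion (LDM/ODISE) features on COCO images.

Pipeline: extract multi-scale UNet features with sliding-window inference,
cluster each feature map with k-means, upsample the cluster assignments to
image size, and overlay them on the image with detectron2's Visualizer.
"""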
import argparse
import copy
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import tqdm
from einops import rearrange
from PIL import Image
from sklearn.cluster import KMeans as _KMeans
from detectron2.data import detection_utils as utils, MetadataCatalog
from detectron2.data import get_detection_dataset_dicts
from detectron2.utils.env import seed_all_rng
from detectron2.utils.visualizer import Visualizer
from odise.modeling.diffusion import create_gaussian_diffusion
from odise.modeling.meta_arch.ldm import LatentDiffusion, LdmExtractor


class SlideInference(nn.Module):
    def __init__(self, model):
        super().__init__()
        assert isinstance(model, LdmExtractor)
        self.model = model

    def single_forward(self, img):
        features = self.model(dict(img=img))
        return features
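
    # Tile the image with square crops, run the extractor on each crop, and
    # average the per-crop features back into full-resolution feature maps.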
    def slide_forward(self, img):
        batch_size, _, h_img, w_img = img.shape
        num_features = len(self.model.feature_strides)
        output_features = [None] * num_features
        for idx in range(num_features):
            stride = self.model.feature_strides[idx]
            channel = self.model.feature_dims[idx]
            output_features[idx] = torch.zeros(
                (batch_size, channel, h_img // stride, w_img // stride),
                dtype=img.dtype,
                device=img.device,
            )
        count_mats = [torch.zeros_like(v) for v in output_features]
        # use the shorter image side as both the crop size and the stride
        short_side = min(img.shape[-2:])
        h_crop = w_crop = short_side
        h_stride = w_stride = short_side
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
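        # ceil division: the last window on each axis is clamped back inside the
        # image (the max(...) calls below), and any overlap is averaged via count_mats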
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                assert crop_img.shape[-2:] == (h_crop, w_crop), f"{crop_img.shape} from {img.shape}"
                crop_features = self.single_forward(crop_img)
                for idx in range(num_features):
                    feature_stride = self.model.feature_strides[idx]
                    k_x1 = x1 // feature_stride
                    k_x2 = x2 // feature_stride
                    k_y1 = y1 // feature_stride
                    k_y2 = y2 // feature_stride
                    output_features[idx][:, :, k_y1:k_y2, k_x1:k_x2] += crop_features[idx]
                    count_mats[idx][..., k_y1:k_y2, k_x1:k_x2] += 1
        # check every feature map was fully covered (the original indexed count_mats
        # with the stale `idx` here instead of the generator variable)
        assert all((count_mats[k] == 0).sum() == 0 for k in range(num_features))
        for idx in range(num_features):
            output_features[idx] /= count_mats[idx]
        return output_features

    def forward(self, img):
        return self.slide_forward(img)


class KMeans(nn.Module):
    def __init__(self, num_clusters=5, normalized=False):
        super().__init__()
        self.kmeans = _KMeans(
            n_clusters=num_clusters,
            tol=1e-5,
            n_init=100,
            max_iter=10000,
            random_state=42,
            algorithm="elkan",
        )
        self.normalized = normalized
        self.num_clusters = num_clusters
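
    # With normalized=True, features are L2-normalized first, so Euclidean
    # k-means behaves like clustering by cosine similarity.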
    def forward(self, image_feature):
        batch_size, num_channels, height, width = image_feature.shape
        image_groups = []
        for i in range(batch_size):
            cur_feature = image_feature[i]
            cur_feature = rearrange(
                cur_feature, "c h w -> (h w) c", h=height, w=width, c=num_channels
            )
            if self.normalized:
                cur_feature = F.normalize(cur_feature, dim=-1)
            # [HxW, N]
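            # fit_transform returns each point's distance to every centroid;
            # negating turns it into a similarity, so argmax picks the nearest cluster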
            similarity = -self.kmeans.fit_transform(cur_feature.cpu().numpy())
            similarity = torch.from_numpy(similarity).to(image_feature.device)
            # [N, H, W]
            similarity = rearrange(
                similarity, "(h w) n -> n h w", n=self.num_clusters, h=height, w=width
            )
            image_groups.append(similarity)
        # [B, N, H, W]
        image_groups = torch.stack(image_groups, dim=0)
        return image_groups
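
# Example usage (hypothetical shapes): cluster a random feature map into 4 groups.
#   feats = torch.randn(1, 256, 32, 32)      # [B, C, H, W]
#   groups = KMeans(num_clusters=4)(feats)   # [B, 4, 32, 32] similarity maps
#   seg = groups.argmax(1)                   # [B, 32, 32] per-pixel cluster ids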


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A script that visualizes groupings on the COCO dataset."
    )
    parser.add_argument("--output-dir", default="viz_dataset", help="path to output directory")
    parser.add_argument("--tag", type=str, help="tag for output directory", default="")
    parser.add_argument("--vis-num", default=1, type=int, help="number of images to visualize")
    args = parser.parse_args()

    output_dir = args.output_dir
    if args.tag:
        output_dir = os.path.join(output_dir, args.tag)

    # dataset_name = "coco_2017_train_panoptic_with_sem_seg"
    dataset_name = "coco_2017_val_panoptic_with_sem_seg"
    metadata = MetadataCatalog.get(dataset_name)
    pseudo_metadata = copy.deepcopy(metadata)
    # blank out class names so groupings are drawn without labels
    thing_classes = pseudo_metadata.thing_classes
    stuff_classes = pseudo_metadata.stuff_classes
    delattr(pseudo_metadata, "thing_classes")
    delattr(metadata, "thing_colors")
    delattr(pseudo_metadata, "stuff_classes")
    pseudo_metadata.thing_classes = ["" for _ in thing_classes]
    pseudo_metadata.stuff_classes = ["" for _ in stuff_classes]
    delattr(pseudo_metadata, "stuff_colors")

    dicts = get_detection_dataset_dicts(
        names="coco_2017_train_panoptic_with_sem_seg",
        filter_empty=False,
    )
    seed_all_rng(42)
    random.shuffle(dicts)
    dicts = dicts[: args.vis_num]
    os.makedirs(output_dir, exist_ok=True)

    seed_all_rng(42)
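
    # Frozen LDM (Stable Diffusion) feature extractor from ODISE; as configured
    # it taps UNet blocks 7-9 at diffusion step 0, with no VAE encoder/decoder taps.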
    feature_extractor = LdmExtractor(
        encoder_block_indices=(),
        unet_block_indices=(7, 8, 9),
        decoder_block_indices=(),
        steps=(0,),
        ldm=LatentDiffusion(
            diffusion=create_gaussian_diffusion(
                steps=1000,
                learn_sigma=False,
                noise_schedule="ldm_linear",
            ),
        ),
    )
    feature_extractor = SlideInference(feature_extractor)
    feature_extractor.cuda()

    transform = T.Compose([T.Resize(512), T.ToTensor()])
    crop_resize_transform = T.Compose([T.Resize(512)])
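
    # T.Resize(512) with an int scales the shorter side to 512, so the tensor fed
    # to the extractor matches the resized image handed to the Visualizer.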
    for dic in tqdm.tqdm(dicts):
        img = utils.read_image(dic["file_name"], "RGB")
        visualizer = Visualizer(img, metadata=metadata)
        dic.pop("annotations", None)
        dic.pop("sem_seg_file_name", None)
        vis = visualizer.draw_dataset_dict(dic)
        img_tensor = transform(Image.fromarray(img)).unsqueeze(0).cuda()
        with torch.no_grad():  # inference only; avoids holding activation graphs
            features = feature_extractor(img_tensor)
        for idx, feature in enumerate(features):
            visualizer = Visualizer(
                np.asarray(crop_resize_transform(Image.fromarray(img))), metadata=pseudo_metadata
            )
            groups = KMeans(num_clusters=32, normalized=False)(feature.cpu())
            groups = F.interpolate(
                groups, size=img_tensor.shape[2:], mode="bilinear", align_corners=False
            )
            # [H, W]
            seg_map = groups.squeeze(0).cpu().argmax(0)
            vis = visualizer.draw_sem_seg(seg_map, area_threshold=0, alpha=0.5)
            vis.save(
                os.path.join(
                    output_dir,
                    f'{os.path.splitext(os.path.basename(dic["file_name"]))[0]}_{idx:02d}.png',
                )
            )
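
# Usage (sketch; assumes an ODISE install with its pretrained LDM weights and the
# COCO 2017 panoptic datasets registered in detectron2):
#   python cluster_odise.py --output-dir viz_dataset --tag demo --vis-num 4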