import argparse
import copy
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import tqdm
from einops import rearrange
from PIL import Image
from sklearn.cluster import KMeans as _KMeans
from detectron2.data import detection_utils as utils, MetadataCatalog
from detectron2.data import get_detection_dataset_dicts
from detectron2.utils.env import seed_all_rng
from detectron2.utils.visualizer import Visualizer
from odise.modeling.diffusion import create_gaussian_diffusion
from odise.modeling.meta_arch.ldm import LatentDiffusion, LdmExtractor


class SlideInference(nn.Module):
    def __init__(self, model):
        super().__init__()
        assert isinstance(model, LdmExtractor)
        self.model = model

    def single_forward(self, img):
        features = self.model(dict(img=img))
        return features

    def slide_forward(self, img):
        batch_size, _, h_img, w_img = img.shape
        num_features = len(self.model.feature_strides)
        output_features = [None] * num_features
        for idx in range(num_features):
            stride = self.model.feature_strides[idx]
            channel = self.model.feature_dims[idx]
            output_features[idx] = torch.zeros(
                (batch_size, channel, h_img // stride, w_img // stride),
                dtype=img.dtype,
                device=img.device,
            )
        count_mats = [torch.zeros_like(v) for v in output_features]
        # slide square crops of the shorter side over the image
        short_side = min(img.shape[-2:])
        h_crop = w_crop = short_side
        h_stride = w_stride = short_side
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                assert crop_img.shape[-2:] == (h_crop, w_crop), f"{crop_img.shape} from {img.shape}"
                crop_features = self.single_forward(crop_img)
                for idx in range(num_features):
                    feature_stride = self.model.feature_strides[idx]
                    k_x1 = x1 // feature_stride
                    k_x2 = x2 // feature_stride
                    k_y1 = y1 // feature_stride
                    k_y2 = y2 // feature_stride
                    output_features[idx][:, :, k_y1:k_y2, k_x1:k_x2] += crop_features[idx]
                    count_mats[idx][..., k_y1:k_y2, k_x1:k_x2] += 1
        # every output location must be covered by at least one crop
        assert all((count_mats[k] == 0).sum() == 0 for k in range(num_features))
        for idx in range(num_features):
            output_features[idx] /= count_mats[idx]
        return output_features

    def forward(self, img):
        return self.slide_forward(img)
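
# Usage sketch (illustrative, not part of the original gist): SlideInference tiles a
# non-square image into square crops of its shorter side, averages overlapping feature
# maps, and returns one tensor per feature level, roughly:
#
#   extractor = SlideInference(some_ldm_extractor).cuda()   # some_ldm_extractor: any configured LdmExtractor
#   feats = extractor(torch.randn(1, 3, 512, 768).cuda())   # list of [1, C_i, 512 // s_i, 768 // s_i]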


class KMeans(nn.Module):
    def __init__(self, num_clusters=5, normalized=False):
        super().__init__()
        self.kmeans = _KMeans(
            n_clusters=num_clusters,
            tol=1e-5,
            n_init=100,
            max_iter=10000,
            random_state=42,
            algorithm="elkan",
        )
        self.normalized = normalized
        self.num_clusters = num_clusters

    def forward(self, image_feature):
        batch_size, num_channels, height, width = image_feature.shape
        image_groups = []
        for i in range(batch_size):
            cur_feature = image_feature[i]
            cur_feature = rearrange(
                cur_feature, "c h w -> (h w) c", h=height, w=width, c=num_channels
            )
            if self.normalized:
                cur_feature = F.normalize(cur_feature, dim=-1)
            # [HxW, N]
            similarity = -self.kmeans.fit_transform(cur_feature.cpu().numpy())
            similarity = torch.from_numpy(similarity).to(image_feature.device)
            # [N, H, W]
            similarity = rearrange(
                similarity, "(h w) n -> n h w", n=self.num_clusters, h=height, w=width
            )
            image_groups.append(similarity)
        # [B, N, H, W]
        image_groups = torch.stack(image_groups, dim=0)
        return image_groups
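
# Usage sketch (illustrative, not part of the original gist): KMeans returns negated
# distances to each centroid with shape [B, N, H, W], so an argmax over the cluster
# dimension recovers a hard group assignment per pixel, e.g.:
#
#   feat = torch.randn(1, 256, 32, 32)
#   groups = KMeans(num_clusters=8, normalized=True)(feat)   # [1, 8, 32, 32]
#   seg = groups.argmax(dim=1)                               # [1, 32, 32] cluster ids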


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A script that visualizes k-means feature groupings on the COCO panoptic dataset."
    )
parser.add_argument("--output-dir", default="viz_dataset", help="path to output directory")
parser.add_argument("--tag", type=str, help="tag for output directory", default="")
parser.add_argument("--vis-num", default=1, type=int, help="number of images to visualize")
args = parser.parse_args()
output_dir = args.output_dir
if args.tag:
output_dir = os.path.join(output_dir, args.tag)
# dataset_name = "coco_2017_train_panoptic_with_sem_seg"
dataset_name = "coco_2017_val_panoptic_with_sem_seg"
metadata = MetadataCatalog.get(dataset_name)
pseudo_metadata = copy.deepcopy(metadata)
# remove classes to show without labels
thing_classes = pseudo_metadata.thing_classes
stuff_classes = pseudo_metadata.stuff_classes
delattr(pseudo_metadata, "thing_classes")
delattr(metadata, "thing_colors")
delattr(pseudo_metadata, "stuff_classes")
pseudo_metadata.thing_classes = ["" for _ in thing_classes]
pseudo_metadata.stuff_classes = ["" for _ in stuff_classes]
delattr(pseudo_metadata, "stuff_colors")

    dicts = get_detection_dataset_dicts(
        names=dataset_name,
        filter_empty=False,
    )
    seed_all_rng(42)
    random.shuffle(dicts)
    dicts = dicts[: args.vis_num]
    os.makedirs(output_dir, exist_ok=True)

    seed_all_rng(42)
    # extract features from UNet blocks 7, 8, 9 at diffusion timestep 0
    feature_extractor = LdmExtractor(
        encoder_block_indices=(),
        unet_block_indices=(7, 8, 9),
        decoder_block_indices=(),
        steps=(0,),
        ldm=LatentDiffusion(
            diffusion=create_gaussian_diffusion(
                steps=1000,
                learn_sigma=False,
                noise_schedule="ldm_linear",
            ),
        ),
    )
    feature_extractor = SlideInference(feature_extractor)
    feature_extractor.cuda()

    transform = T.Compose([T.Resize(512), T.ToTensor()])
    crop_resize_transform = T.Compose([T.Resize(512)])
    for dic in tqdm.tqdm(dicts):
        img = utils.read_image(dic["file_name"], "RGB")
        visualizer = Visualizer(img, metadata=metadata)
        dic.pop("annotations", None)
        dic.pop("sem_seg_file_name", None)
        vis = visualizer.draw_dataset_dict(dic)

        img_tensor = transform(Image.fromarray(img)).unsqueeze(0).cuda()
        features = feature_extractor(img_tensor)
        for idx, feature in enumerate(features):
            visualizer = Visualizer(
                np.asarray(crop_resize_transform(Image.fromarray(img))), metadata=pseudo_metadata
            )
            # cluster the feature map into 32 groups, then upsample scores to the resized image
            groups = KMeans(num_clusters=32, normalized=False)(feature.cpu())
            groups = F.interpolate(
                groups, size=img_tensor.shape[2:], mode="bilinear", align_corners=False
            )
            # [H, W]
            seg_map = (groups.squeeze(0).cpu()).argmax(0)
            vis = visualizer.draw_sem_seg(seg_map, area_threshold=0, alpha=0.5)
            vis.save(
                os.path.join(
                    output_dir,
                    f'{os.path.splitext(os.path.basename(dic["file_name"]))[0]}_{idx:02d}.png',
                )
            )
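
# Example invocation (the script filename is an assumption; the flags come from the
# argparse definitions above):
#
#   python viz_groups.py --output-dir viz_dataset --tag kmeans32 --vis-num 4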