import argparse
import copy
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import tqdm
from einops import rearrange
from PIL import Image
from sklearn.cluster import KMeans as _KMeans

from detectron2.data import detection_utils as utils, MetadataCatalog
from detectron2.data import get_detection_dataset_dicts
from detectron2.utils.env import seed_all_rng
from detectron2.utils.visualizer import Visualizer

from odise.modeling.diffusion import create_gaussian_diffusion
from odise.modeling.meta_arch.ldm import LatentDiffusion, LdmExtractor

class SlideInference(nn.Module):
    def __init__(self, model):
        super().__init__()
        assert isinstance(model, LdmExtractor)
        self.model = model

    def single_forward(self, img):
        features = self.model(dict(img=img))
        return features

    def slide_forward(self, img):
        batch_size, _, h_img, w_img = img.shape
        num_features = len(self.model.feature_strides)
        output_features = [None] * num_features
        for idx in range(num_features):
            stride = self.model.feature_strides[idx]
            channel = self.model.feature_dims[idx]
            output_features[idx] = torch.zeros(
                (batch_size, channel, h_img // stride, w_img // stride),
                dtype=img.dtype,
                device=img.device,
            )
        count_mats = [torch.zeros_like(v) for v in output_features]
        # if not slide training then use the shorter side to crop
        short_side = min(img.shape[-2:])
        h_crop = w_crop = short_side
        h_stride = w_stride = short_side
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                assert crop_img.shape[-2:] == (h_crop, w_crop), f"{crop_img.shape} from {img.shape}"
                crop_features = self.single_forward(crop_img)
                for idx in range(num_features):
                    feature_stride = self.model.feature_strides[idx]
                    k_x1 = x1 // feature_stride
                    k_x2 = x2 // feature_stride
                    k_y1 = y1 // feature_stride
                    k_y2 = y2 // feature_stride
                    output_features[idx][:, :, k_y1:k_y2, k_x1:k_x2] += crop_features[idx]
                    count_mats[idx][..., k_y1:k_y2, k_x1:k_x2] += 1
        # every spatial location of every feature map must be covered by at least one crop
        assert all((count_mats[k] == 0).sum() == 0 for k in range(num_features))
        for idx in range(num_features):
            output_features[idx] /= count_mats[idx]
        return output_features

    def forward(self, img):
        return self.slide_forward(img)
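
# Note on SlideInference.slide_forward above: crops are squares whose side is the
# image's shorter dimension, and the stride equals the crop size, so only the last
# crop along the longer axis is shifted back to fit. For example, a 512x768 input
# yields two 512x512 crops (x offsets 0 and 256); features in the overlapping
# 256-pixel band are averaged via count_mats.
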
class KMeans(nn.Module):
    def __init__(self, num_clusters=5, normalized=False):
        super().__init__()
        self.kmeans = _KMeans(
            n_clusters=num_clusters,
            tol=1e-5,
            n_init=100,
            max_iter=10000,
            random_state=42,
            algorithm="elkan",
        )
        self.normalized = normalized
        self.num_clusters = num_clusters

    def forward(self, image_feature):
        batch_size, num_channels, height, width = image_feature.shape
        image_groups = []
        for i in range(batch_size):
            cur_feature = image_feature[i]
            cur_feature = rearrange(
                cur_feature, "c h w -> (h w) c", h=height, w=width, c=num_channels
            )
            if self.normalized:
                cur_feature = F.normalize(cur_feature, dim=-1)
            # [HxW, N]
            similarity = -self.kmeans.fit_transform(cur_feature.cpu().numpy())
            similarity = torch.from_numpy(similarity).to(image_feature.device)
            # [N, H, W]
            similarity = rearrange(
                similarity, "(h w) n -> n h w", n=self.num_clusters, h=height, w=width
            )
            image_groups.append(similarity)
        # [B, N, H, W]
        image_groups = torch.stack(image_groups, dim=0)
        return image_groups
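
# The KMeans wrapper above returns negated distances to the fitted cluster centers,
# so argmax over the cluster dimension recovers per-pixel cluster assignments.
# Minimal usage sketch (assumes `feat` is a [1, C, H, W] float feature map on CPU):
#
#     groups = KMeans(num_clusters=8)(feat)   # [1, 8, H, W] negative distances
#     labels = groups.argmax(dim=1)           # [1, H, W] cluster index per pixel
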
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A script that visualizes groupings on the COCO panoptic dataset."
    )
    parser.add_argument("--output-dir", default="viz_dataset", help="path to output directory")
    parser.add_argument("--tag", type=str, help="tag for output directory", default="")
    parser.add_argument("--vis-num", default=1, type=int, help="number of images to visualize")
    args = parser.parse_args()

    output_dir = args.output_dir
    if args.tag:
        output_dir = os.path.join(output_dir, args.tag)

    # dataset_name = "coco_2017_train_panoptic_with_sem_seg"
    dataset_name = "coco_2017_val_panoptic_with_sem_seg"
    metadata = MetadataCatalog.get(dataset_name)
    pseudo_metadata = copy.deepcopy(metadata)
    # remove classes to show without labels
    thing_classes = pseudo_metadata.thing_classes
    stuff_classes = pseudo_metadata.stuff_classes
    delattr(pseudo_metadata, "thing_classes")
    delattr(metadata, "thing_colors")
    delattr(pseudo_metadata, "stuff_classes")
    pseudo_metadata.thing_classes = ["" for _ in thing_classes]
    pseudo_metadata.stuff_classes = ["" for _ in stuff_classes]
    delattr(pseudo_metadata, "stuff_colors")

    dicts = get_detection_dataset_dicts(
        names="coco_2017_train_panoptic_with_sem_seg",
        filter_empty=False,
    )
    seed_all_rng(42)
    random.shuffle(dicts)
    dicts = dicts[: args.vis_num]

    os.makedirs(output_dir, exist_ok=True)

    seed_all_rng(42)
    feature_extractor = LdmExtractor(
        encoder_block_indices=(),
        unet_block_indices=(7, 8, 9),
        decoder_block_indices=(),
        steps=(0,),
        ldm=LatentDiffusion(
            diffusion=create_gaussian_diffusion(
                steps=1000,
                learn_sigma=False,
                noise_schedule="ldm_linear",
            ),
        ),
    )
    feature_extractor = SlideInference(feature_extractor)
    feature_extractor.cuda()

    transform = T.Compose([T.Resize(512), T.ToTensor()])
    crop_resize_transform = T.Compose([T.Resize(512)])

    for dic in tqdm.tqdm(dicts):
        img = utils.read_image(dic["file_name"], "RGB")
        visualizer = Visualizer(img, metadata=metadata)
        dic.pop("annotations", None)
        dic.pop("sem_seg_file_name", None)
        vis = visualizer.draw_dataset_dict(dic)

        img_tensor = transform(Image.fromarray(img)).unsqueeze(0).cuda()
        features = feature_extractor(img_tensor)
        for idx, feature in enumerate(features):
            visualizer = Visualizer(
                np.asarray(crop_resize_transform(Image.fromarray(img))), metadata=pseudo_metadata
            )
            groups = KMeans(num_clusters=32, normalized=False)(feature.cpu())
            groups = F.interpolate(
                groups, size=img_tensor.shape[2:], mode="bilinear", align_corners=False
            )
            # [H, W]
            seg_map = (groups.squeeze(0).cpu()).argmax(0)
            vis = visualizer.draw_sem_seg(seg_map, area_threshold=0, alpha=0.5)
            vis.save(
                os.path.join(
                    output_dir,
                    f'{os.path.splitext(os.path.basename(dic["file_name"]))[0]}_{idx:02d}.png',
                )
            )
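
# Usage sketch: assuming ODISE and detectron2 are installed, the Stable Diffusion
# checkpoint used by LdmExtractor is available, and the COCO panoptic dataset is
# registered (e.g. via the DETECTRON2_DATASETS environment variable), the script
# can be run as follows (the file name here is hypothetical):
#
#     python visualize_groupings.py --output-dir viz_dataset --tag unet_789 --vis-num 4
#
# Each sampled image should then yield one k-means overlay per extracted feature
# level (UNet blocks 7, 8, 9), saved as <image-stem>_<feature-index>.png.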