import argparse
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from detectron2.data import MetadataCatalog
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES

from modeling import build_model
from modeling.BaseModel import BaseModel
from modeling.language.loss import vl_similarity
from utils.arguments import load_opt_from_config_files
from utils.constants import COCO_PANOPTIC_CLASSES
from utils.distributed import init_distributed
from utils.visualizer import Visualizer


def parse_option():
    parser = argparse.ArgumentParser('SEEM Demo', add_help=False)
    parser.add_argument('--conf_files', default="configs/seem/focall_unicl_lang_demo.yaml", metavar="FILE", help='path to config file')
    cfg = parser.parse_args()
    return cfg


if __name__ == "__main__":
    cfg = parse_option()
    opt = load_opt_from_config_files([cfg.conf_files])
    opt = init_distributed(opt)

    # META DATA: download the SEEM Focal-L checkpoint if it is not already present
    pretrained_pth = os.path.join("seem_focall_v0.pt")
    if not os.path.exists(pretrained_pth):
        os.system("wget {}".format("https://huggingface.co/xdecoder/SEEM/resolve/main/seem_focall_v0.pt"))
    cur_model = 'Focal-L'
    model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()

    with torch.no_grad():
        # pre-compute text embeddings for the COCO panoptic classes plus "background"
        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(COCO_PANOPTIC_CLASSES + ["background"], is_eval=True)

    # resize the shorter side to 512 pixels before inference
    transform = transforms.Compose([transforms.Resize(512, interpolation=Image.BICUBIC)])
    metadata = MetadataCatalog.get('coco_2017_train_panoptic')
    all_classes = [name.replace('-other', '').replace('-merged', '') for name in COCO_PANOPTIC_CLASSES] + ["others"]
    colors_list = [(np.array(color['color']) / 255).tolist() for color in COCO_CATEGORIES] + [[1, 1, 1]]

    dir_path = "src_path"
"""
For each image in the directory, the model will predict the mask of the person and save it in the mask
directory. The mask directory will have the same structure as the source directory.
|---src_path
| |---child_dir1
| | |---img1.jpg
| | |---img2.jpg
| |---child_dir2
| | |---img1.jpg
| | |---img2.jpg
"""
    mask_dir = "dst_path"
    reftxt = "person"

    # enable only the text-grounding task; spatial/visual/audio prompts are off
    model.model.task_switch['spatial'] = False
    model.model.task_switch['visual'] = False
    model.model.task_switch['grounding'] = True
    model.model.task_switch['audio'] = False
    tasks = ["Text"]

    for child_dir in os.listdir(dir_path):
        print(child_dir)
        write_dir = os.path.join(mask_dir, child_dir)
        os.makedirs(write_dir, exist_ok=True)
        for img_name in tqdm(os.listdir(os.path.join(dir_path, child_dir))):
            img_path = os.path.join(dir_path, child_dir, img_name)
            name, ext = os.path.splitext(img_name)
            write_path = os.path.join(write_dir, name + ".png")
            with torch.no_grad():
                img = Image.open(img_path)
                image_ori = transform(img).convert('RGB')
                # visual = Visualizer(image_ori, metadata=metadata)
                width = image_ori.size[0]
                height = image_ori.size[1]
                image_ori = np.asarray(image_ori)
                images = torch.from_numpy(image_ori.copy()).permute(2, 0, 1).cuda()

                data = {"image": images, "height": height, "width": width, "text": [reftxt]}
                batch_inputs = [data]
                results, image_size, extra = model.model.evaluate_demo(batch_inputs)
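                # What evaluate_demo returns, inferred from how its outputs are
                # consumed below (not a documented API contract): 'pred_masks' are
                # per-query mask logits, 'pred_captions' are per-query visual
                # embeddings, image_size is the (padded) input size, and
                # extra['grounding_class'] is the text embedding for reftxt.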
                pred_masks = results['pred_masks'][0]
                v_emb = results['pred_captions'][0]
                t_emb = extra['grounding_class']

                # L2-normalise both embeddings before computing similarity
                t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
                v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

                temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
                out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
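                # out_prob is a [num_queries, num_texts] score matrix; with a single
                # referring text it is effectively a column vector. Assuming SEEM's
                # definition of vl_similarity, this amounts to a temperature-scaled
                # dot product of the normalised embeddings, roughly:
                #   out_prob ≈ temperature.exp() * (v_emb @ t_emb.t())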
                # pick the mask query whose embedding best matches the text
                matched_id = out_prob.max(0)[1]
                pred_masks_pos = pred_masks[matched_id, :, :]
                pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]

                # upsample the mask to the padded input size, crop the padding away,
                # and binarise at 0
                pred_masks_pos = (F.interpolate(pred_masks_pos[None,], size=image_size[-2:], mode='bilinear')[0, :, :data['height'], :data['width']] > 0.0).float().cpu().numpy()
                texts = [all_classes[pred_class[0]]]
                for idx, mask in enumerate(pred_masks_pos):
                    # color = random_color(rgb=True, maximum=1).astype(np.int32).tolist()
                    out_txt = texts[idx] if 'Text' not in tasks else reftxt
                    if out_txt == reftxt:
                        # write the binary mask as an 8-bit PNG
                        cv2.imwrite(write_path, (mask * 255).astype(np.uint8))
                        break
                else:
                    # for-else: reached only if no mask matched the referring text
                    print("No person detected", img_path)
            torch.cuda.empty_cache()
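
# Usage (assumed: run from the SEEM repo root so that `modeling/` and `utils/`
# are importable; the filename below is a placeholder for this gist's script):
#   python extract_person_masks.py --conf_files configs/seem/focall_unicl_lang_demo.yaml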