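"""
Batch text-grounded segmentation with SEEM (Focal-L): for every image under
`dir_path`, query the model with the text prompt in `reftxt` ("person") and save
the best-matching binary mask as a PNG under `mask_dir`, mirroring the source
directory layout. Meant to be run from the root of the SEEM repo so that the
`modeling`/`utils` imports resolve (an assumption about the original setup).
"""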
import os
import argparse

import numpy as np
import torch
import torch.nn.functional as F
import cv2
from PIL import Image
from tqdm import tqdm
from torchvision import transforms

from detectron2.data import MetadataCatalog
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES

from modeling.BaseModel import BaseModel
from modeling import build_model
from modeling.language.loss import vl_similarity
from utils.distributed import init_distributed
from utils.arguments import load_opt_from_config_files
from utils.constants import COCO_PANOPTIC_CLASSES
from utils.visualizer import Visualizer


def parse_option():
    parser = argparse.ArgumentParser('SEEM Demo', add_help=False)
    parser.add_argument('--conf_files', default="configs/seem/focall_unicl_lang_demo.yaml", metavar="FILE", help='path to config file')
    cfg = parser.parse_args()
    return cfg
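

# Example invocation (illustrative; the script filename `seem_batch_mask.py` is
# hypothetical, save this file under any name inside the SEEM repo root):
#
#   python seem_batch_mask.py --conf_files configs/seem/focall_unicl_lang_demo.yaml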

if __name__ == "__main__":
    cfg = parse_option()
    opt = load_opt_from_config_files([cfg.conf_files])
    opt = init_distributed(opt)

    # META DATA: download the SEEM Focal-L checkpoint if it is not present locally
    pretrained_pth = os.path.join("seem_focall_v0.pt")
    if not os.path.exists(pretrained_pth):
        os.system("wget {}".format("https://huggingface.co/xdecoder/SEEM/resolve/main/seem_focall_v0.pt"))
    cur_model = 'Focal-L'

    model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
    with torch.no_grad():
        # pre-compute text embeddings for the COCO panoptic vocabulary (+ background)
        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(COCO_PANOPTIC_CLASSES + ["background"], is_eval=True)

    # resize the shorter image edge to 512 px before inference
    transform = transforms.Compose([transforms.Resize(512, interpolation=Image.BICUBIC)])
    metadata = MetadataCatalog.get('coco_2017_train_panoptic')
    all_classes = [name.replace('-other', '').replace('-merged', '') for name in COCO_PANOPTIC_CLASSES] + ["others"]
    colors_list = [(np.array(color['color']) / 255).tolist() for color in COCO_CATEGORIES] + [[1, 1, 1]]
dir_path = "src_path" | |
""" | |
For each image in the directory, the model will predict the mask of the person and save it in the mask | |
directory. The mask directory will have the same structure as the source directory. | |
|---src_path | |
| |---child_dir1 | |
| | |---img1.jpg | |
| | |---img2.jpg | |
| |---child_dir2 | |
| | |---img1.jpg | |
| | |---img2.jpg | |
""" | |
mask_dir = "dst_path" | |
reftxt = "person" | |

    model.model.task_switch['spatial'] = False
    model.model.task_switch['visual'] = False
    model.model.task_switch['grounding'] = True
    model.model.task_switch['audio'] = False
    tasks = ["Text"]
    for child_dir in os.listdir(dir_path):
        print(child_dir)
        write_dir = os.path.join(mask_dir, child_dir)
        os.makedirs(write_dir, exist_ok=True)
        for img_name in tqdm(os.listdir(os.path.join(dir_path, child_dir))):
            img_path = os.path.join(dir_path, child_dir, img_name)
            name, ext = os.path.splitext(img_name)
            write_path = os.path.join(write_dir, name + ".png")
            with torch.no_grad():
                img = Image.open(img_path)
                image_ori = transform(img).convert('RGB')
                # visual = Visualizer(image_ori, metadata=metadata)
                width = image_ori.size[0]
                height = image_ori.size[1]
                image_ori = np.asarray(image_ori)
                images = torch.from_numpy(image_ori.copy()).permute(2, 0, 1).cuda()

                data = {"image": images, "height": height, "width": width, "text": [reftxt]}
                batch_inputs = [data]
                results, image_size, extra = model.model.evaluate_demo(batch_inputs)

                pred_masks = results['pred_masks'][0]
                v_emb = results['pred_captions'][0]  # visual embeddings of the predicted masks
                t_emb = extra['grounding_class']     # text embedding of the grounding query

                # L2-normalize both embedding sets before computing similarity
                t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
                v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

                # pick the mask whose visual embedding best matches the text query
                temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
                out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
                matched_id = out_prob.max(0)[1]
                pred_masks_pos = pred_masks[matched_id, :, :]
                pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]

                # interpolate the mask back to the (resized) image size and binarize
                pred_masks_pos = (F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0, :, :data['height'], :data['width']] > 0.0).float().cpu().numpy()
                texts = [all_classes[pred_class[0]]]

                for idx, mask in enumerate(pred_masks_pos):
                    out_txt = texts[idx] if 'Text' not in tasks else reftxt
                    if out_txt == reftxt:
                        # cast to uint8 so cv2 writes a proper 8-bit PNG
                        cv2.imwrite(write_path, (mask * 255).astype(np.uint8))
                        break
                else:
                    # for-else: no mask matched the text query
                    print("No person detected:", img_path)
            torch.cuda.empty_cache()
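
    # Optional sanity check (an illustrative sketch, not part of the original gist):
    # overlay a saved mask on its source image to verify alignment. The two paths
    # below are hypothetical placeholders.
    #
    #   img = cv2.imread("src_path/child_dir1/img1.jpg")
    #   msk = cv2.imread("dst_path/child_dir1/img1.png", cv2.IMREAD_GRAYSCALE)
    #   msk = cv2.resize(msk, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)
    #   overlay = cv2.addWeighted(img, 0.6, cv2.cvtColor(msk, cv2.COLOR_GRAY2BGR), 0.4, 0)
    #   cv2.imwrite("overlay.png", overlay)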