(WIP) Point at a folder of images, get box labels with probabilities written to an output folder
# %%
# NOTE: run this from the root of a Detic checkout
# (https://github.com/facebookresearch/Detic), e.g. `cd Detic/` first,
# since the config, weight, metadata, and third_party paths below are all
# relative to the repo root.
# %%
import os
import sys
import time
from pathlib import Path
from random import randint

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

from detectron2.utils.logger import setup_logger
setup_logger()

# common detectron2 utilities
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog

# Detic libraries (the vendored CenterNet2 checkout must be on the path first)
sys.path.insert(0, 'third_party/CenterNet2/projects/CenterNet2/')
from centernet.config import add_centernet_config
from detic.config import add_detic_config
from detic.modeling.utils import reset_cls_test
# %%
# Build the detector and download the pretrained weights
cfg = get_cfg()
add_centernet_config(cfg)
add_detic_config(cfg)
cfg.merge_from_file("configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml")
cfg.MODEL.WEIGHTS = 'https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth'
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.3  # detection score threshold for this model
cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'rand'
cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = False  # False keeps all classes per proposal; set True for cleaner visualizations
predictor = DefaultPredictor(cfg)
# %%
# Set up the model's vocabulary using built-in datasets
BUILDIN_CLASSIFIER = {
    'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy',
    'objects365': 'datasets/metadata/o365_clip_a+cnamefix.npy',
    'openimages': 'datasets/metadata/oid_clip_a+cname.npy',
    'coco': 'datasets/metadata/coco_clip_a+cname.npy',
}
BUILDIN_METADATA_PATH = {
    'lvis': 'lvis_v1_val',
    'objects365': 'objects365_v2_val',
    'openimages': 'oid_val_expanded',
    'coco': 'coco_2017_val',
}
vocabulary = 'lvis'  # one of 'lvis', 'objects365', 'openimages', or 'coco'
metadata = MetadataCatalog.get(BUILDIN_METADATA_PATH[vocabulary])
classifier = BUILDIN_CLASSIFIER[vocabulary]
num_classes = len(metadata.thing_classes)
reset_cls_test(predictor.model, classifier, num_classes)
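# %%
# Optional: Detic can also score an arbitrary class list by embedding the
# names with its CLIP text encoder, as in the custom-vocabulary path of the
# Detic demo. This is only a sketch (the class names below are example
# placeholders); uncomment to try:
#
# from detic.modeling.text.text_encoder import build_text_encoder
#
# def get_clip_embeddings(vocabulary, prompt='a '):
#     text_encoder = build_text_encoder(pretrain=True)
#     text_encoder.eval()
#     texts = [prompt + x for x in vocabulary]
#     return text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
#
# custom_metadata = MetadataCatalog.get("__custom_vocab")
# custom_metadata.thing_classes = ['painting', 'sculpture', 'frame']  # example classes
# classifier = get_clip_embeddings(custom_metadata.thing_classes)
# reset_cls_test(predictor.model, classifier, len(custom_metadata.thing_classes))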
# %%
class ImageDataset(Dataset):
    def __init__(self, folder, shuffle=False):
        """
        @param folder: Folder searched recursively for png/jpg/jpeg/bmp images,
            keyed by each path's stem.
        @param shuffle: If True, an unreadable image is replaced by a random
            sample instead of the next sequential one.
        """
        super().__init__()
        self.shuffle = shuffle
        path = Path(folder)
        image_files = [
            *path.glob('**/*.png'), *path.glob('**/*.jpg'),
            *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
        ]
        self.image_files = {image_file.stem: image_file for image_file in image_files}
        self.keys = list(self.image_files.keys())

    def __len__(self):
        return len(self.keys)

    def random_sample(self):
        return self.__getitem__(randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        if ind >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(ind + 1)

    def skip_sample(self, ind):
        if self.shuffle:
            return self.random_sample()
        return self.sequential_sample(ind=ind)

    def __getitem__(self, ind):
        key = self.keys[ind]
        image_file = self.image_files[key]
        resize_value = 512
        try:
            # NOTE: resizing to a fixed square distorts the aspect ratio.
            pil_image = Image.open(image_file).resize((resize_value, resize_value)).convert('RGB')
        except OSError:
            return self.skip_sample(ind)
        # DefaultPredictor expects BGR input (cfg.INPUT.FORMAT defaults to "BGR"),
        # so flip the channel order; copy to keep strides contiguous for collation.
        np_output = np.ascontiguousarray(np.array(pil_image)[:, :, ::-1])
        return np_output, key
def log_class_prediction(caption, key, outdir='./caption/'):
    """Write one text file of box labels per image, named after the image's stem."""
    os.makedirs(outdir, exist_ok=True)
    save_path = os.path.join(outdir, f'{key}.txt')
    with open(save_path, "w") as text_file:
        text_file.write(f"{caption}")

batch_size = 64
dataset = ImageDataset(folder='/home/samsepiol/DatasetWorkspace/CurrentDatasets/WIKIART/')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
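# Quick sanity check (optional; assumes the folder is non-empty): pull one
# sample before running the full loop.
# sample_image, sample_key = dataset[0]
# print(sample_key, sample_image.shape)  # expect e.g. ('some_stem', (512, 512, 3))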
# %%
with torch.no_grad():
    for batch, keys in dataloader:
        batch_start = time.time()
        batch_np = batch.cpu().numpy()
        # DefaultPredictor only accepts a single image, so iterate over the batch.
        predictions = [predictor(x_np)['instances'] for x_np in batch_np]
        for i, prediction in enumerate(predictions):
            # Format each detection as "label probability", one per line.
            labels = [metadata.thing_classes[c] for c in prediction.pred_classes.tolist()]
            scores = prediction.scores.tolist()
            caption = "\n".join(f"{label} {score:.3f}" for label, score in zip(labels, scores))
            print(f"{i} {prediction.pred_classes}")
            log_class_prediction(caption, keys[i])
        elapsed_time = time.time() - batch_start
        print(f"Elapsed time: {elapsed_time:.2f}s for a batch of {batch_size}")
# %%
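# NOTE: the loop above runs images sequentially, since DefaultPredictor
# preprocesses and predicts one image per call. A sketch of true batched
# inference with the underlying detectron2 model (an assumption-heavy
# alternative: it bypasses DefaultPredictor's test-time resize augmentation
# and feeds the 512x512 BGR arrays directly):
#
# with torch.no_grad():
#     for batch, keys in dataloader:
#         batch_np = batch.cpu().numpy()
#         inputs = [
#             {"image": torch.as_tensor(x.astype("float32").transpose(2, 0, 1)),
#              "height": x.shape[0], "width": x.shape[1]}
#             for x in batch_np
#         ]
#         instances = [out["instances"] for out in predictor.model(inputs)]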