Last active
May 9, 2019 08:02
-
-
Save ashnair1/735975c3405c65375011b8d4d5da1d0c to your computer and use it in GitHub Desktop.
Count occurrences of each category in COCO-format annotation files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import json | |
import matplotlib.pyplot as plt | |
from collections import Counter | |
# For jupyter notebook | |
#%matplotlib inline | |
# TODO: Clean up this code | |
def _category_names(coco):
    """Return the category name of every annotation in a COCO-style dict.

    Annotations whose ``category_id`` does not appear in ``coco['categories']``
    are skipped (the same outcome as the old nested-loop search finding no
    match).
    """
    # Build the id -> name lookup once instead of scanning the whole category
    # list for every annotation (O(A + C) instead of O(A * C)).
    id_to_name = {cat['id']: cat['name'] for cat in coco['categories']}
    return [
        id_to_name[ann['category_id']]
        for ann in coco['annotations']
        if ann['category_id'] in id_to_name
    ]


def cat_count(annotations=None):
    """Count annotations per category name for a train/val pair of COCO dicts.

    Args:
        annotations: A two-element sequence ``(gtrain, gval)`` of COCO-style
            dicts. Each must have an ``'annotations'`` list whose items carry
            ``'category_id'``, and a ``'categories'`` list whose items carry
            ``'id'`` and ``'name'``.

    Returns:
        ``[train_count_dict, val_count_dict, [tadd, vadd]]``.
        ``train_count_dict`` and ``val_count_dict`` map category name to
        annotation count. ``tadd`` and ``vadd`` hold the category names
        (taken from ``gtrain['categories']``) that had no annotations in the
        train or val split respectively, each mapped to 0. These zero entries
        are also merged into the corresponding count dict.

    Raises:
        TypeError: If ``annotations`` is None.
    """
    if annotations is None:
        raise TypeError("annotations must be a (train, val) pair of COCO dicts")
    gtrain, gval = annotations

    train_count_dict = dict(Counter(_category_names(gtrain)))
    val_count_dict = dict(Counter(_category_names(gval)))

    # Read categories by key. The previous positional unpacking of
    # ``dict.values()`` assumed exactly three keys in (id, name, supercategory)
    # order. It silently mis-mapped dicts that list 'supercategory' first and
    # crashed on dicts with extra keys such as 'keypoints'.
    all_names = [cat['name'] for cat in gtrain['categories']]

    # Categories with no annotations get an explicit zero count. They are
    # built in category-list order so the output order is deterministic.
    tadd = {name: 0 for name in all_names if name not in train_count_dict}
    vadd = {name: 0 for name in all_names if name not in val_count_dict}
    train_count_dict.update(tadd)
    val_count_dict.update(vadd)
    return [train_count_dict, val_count_dict, [tadd, vadd]]
def _plot_distribution(labels, values, name, bar):
    """Draw one class-distribution bar chart, save it as ``<name>.jpg`` and show it.

    ``bar`` is ``"h"`` for horizontal bars or ``"v"`` for vertical bars.
    """
    indexes = np.arange(len(labels))
    width = 0.5
    # Bars use align="edge", so bar i spans [i, i + width]. Ticks and value
    # labels go at the bar centre so they line up with their bar.
    centres = indexes + width / 2
    fig_size = (20, 10) if bar == "v" else (8, 10)
    fig = plt.figure(figsize=fig_size)
    if bar == "h":
        plt.barh(indexes, values, width, align="edge")
        plt.yticks(centres, labels)
        plt.ylabel('Classes')
        plt.xlabel('Count')
        for i, v in enumerate(values):
            plt.text(v + 3, centres[i], str(v), color='blue',
                     fontweight='bold', va='center')
    else:
        plt.bar(indexes, values, width, align="edge")
        # Ticks previously sat at the bar's left edge (i), not under the bar.
        plt.xticks(centres, labels, rotation='vertical')
        plt.xlabel('Classes')
        plt.ylabel('Count')
        for i, v in enumerate(values):
            plt.text(centres[i], v + 5, str(v), color='blue',
                     fontsize=0.6 * fig_size[0], fontweight='bold', ha='center')
    plt.title('Class Distribution')
    # Call tight_layout once the title and value labels exist, so they are
    # accounted for when the margins are computed.
    plt.tight_layout()
    plt.savefig(name + ".jpg")
    plt.show()
    # Release the figure; otherwise figures accumulate on non-interactive backends.
    plt.close(fig)


def show_class_distribution_both(annotations=None, dist="train", bar="h"):
    """Plot per-category annotation counts for both the train and val splits.

    One figure is drawn per split and saved as ``train.jpg`` / ``val.jpg``.

    Args:
        annotations: A ``(gtrain, gval)`` pair of COCO-style dicts, passed to
            ``cat_count``.
        dist: Must be ``"train"`` or ``"val"``. It is validated but currently
            unused; both splits are always plotted.
        bar: ``"h"`` for horizontal bars, ``"v"`` for vertical bars.

    Raises:
        ValueError: If ``dist`` or ``bar`` is not one of the allowed values.
    """
    if dist not in ("train", "val"):
        raise ValueError("Has to be either 'train' or 'val' data")
    if bar not in ("h", "v"):
        raise ValueError("bar has to be either 'h' (horizontal) or 'v' (vertical)")
    gtrain, gval = annotations
    train_counts, val_counts, _ = cat_count([gtrain, gval])
    for name, counts in (("train", train_counts), ("val", val_counts)):
        # list(...) rather than zip(*items) so an empty split does not crash
        # on unpacking.
        _plot_distribution(list(counts), list(counts.values()), name, bar)
def main(train_path='NEW_ANNOTATIONS_18CLASSES/instances_train.json',
         val_path='NEW_ANNOTATIONS_18CLASSES/instances_val.json'):
    """Load the train/val COCO annotation files and plot their class distributions.

    Args:
        train_path: Path to the training-split annotation JSON.
        val_path: Path to the validation-split annotation JSON.
    """
    # JSON text is UTF-8 by specification, so do not rely on the platform's
    # default encoding.
    with open(train_path, encoding='utf-8') as gt:
        gtrain = json.load(gt)
    with open(val_path, encoding='utf-8') as gv:
        gval = json.load(gv)
    show_class_distribution_both([gtrain, gval], dist="train", bar="h")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment