Created
November 14, 2023 02:11
-
-
Save andiac/5e704aaefa66c88bfd4fe3d08a7a0898 to your computer and use it in GitHub Desktop.
Restrict Imagenet 128
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv
import os

from PIL import Image

# Inputs.
# Synset mapping file, from
# https://www.kaggle.com/competitions/imagenet-object-localization-challenge/overview
file_path = './LOC_synset_mapping.txt'
# Root of the ImageNet folder; the script later reads its "train" and "val" subfolders.
imagenet_path = './imagenet'

# Output root for the 128x128 restricted ImageNet, created up front.
out_path = './imagenet128'
os.makedirs(out_path, exist_ok=True)

# Synset IDs in mapping-file order; populated when the mapping file is read.
names = []
# Restricted-ImageNet super-classes, following appendix A of
# https://arxiv.org/pdf/1805.12152.pdf.
# Each value is an inclusive (first, last) pair of 0-based line indices into the
# synset mapping file; the synsets in that range form the super-class.
class_to_index = {
    'dog': (151, 268),
    'cat': (281, 285),
    'frog': (30, 32),
    'turtle': (33, 37),
    'bird': (80, 100),
    'primate': (365, 382),
    'fish': (389, 397),
    'crab': (118, 121),
    'insect': (300, 319),
}

# Integer training label of each super-class = its position in class_to_index
# (dog -> 0, cat -> 1, ..., insect -> 8). Dicts preserve insertion order.
class_to_label = {
    super_class: label for label, super_class in enumerate(class_to_index)
}
# Read the synset mapping file: keep the first space-separated token of each
# line (the synset ID), preserving file order so indices match class_to_index.
with open(file_path, 'r', encoding='utf-8') as file:
    names.extend(line.strip().split(" ")[0] for line in file)

# Guard against a truncated or wrong mapping file: without this, the slices
# below would silently yield fewer (or zero) synsets for some super-classes.
required_lines = max(last for _, last in class_to_index.values()) + 1
if len(names) < required_lines:
    raise ValueError(
        f"{file_path} has {len(names)} lines, expected at least {required_lines}"
    )

# name_dict: super-class -> set of synset IDs; name_to_class: synset ID -> super-class.
# Keys come from class_to_index so the two definitions cannot drift apart.
name_dict = {class_name: set() for class_name in class_to_index}
name_to_class = {}
for class_name, (first, last) in class_to_index.items():
    for name in names[first:last + 1]:  # index ranges are inclusive
        name_dict[class_name].add(name)
        name_to_class[name] = class_name
print(name_dict)
print(name_to_class)

# Per-super-class list of training image paths, filled by the train pass.
path_dict = {class_name: [] for class_name in name_dict}
# Convert train: first collect the training image paths of every super-class.
imagenet_train_path = os.path.join(imagenet_path, "train")
out_train_path = os.path.join(out_path, "train")
os.makedirs(out_train_path, exist_ok=True)
# Number of .JPEG files found per super-class (counted before any per-class cap).
class_num = {k: 0 for k in name_dict.keys()}
for class_name in name_dict:
    # name_dict values are sets of strings, whose iteration order changes between
    # interpreter runs (str hash randomization), and os.listdir order depends on
    # the filesystem. Sort both so path_dict -- and therefore which images are
    # kept when the list is later truncated -- is reproducible.
    for folder_name in sorted(name_dict[class_name]):
        folder_path = os.path.join(imagenet_train_path, folder_name)
        if not os.path.isdir(folder_path):
            raise FileNotFoundError(f"imagenet folder not exist! {folder_path}")
        for image_name in sorted(os.listdir(folder_path)):
            if not image_name.endswith(".JPEG"):
                continue
            class_num[class_name] += 1
            path_dict[class_name].append(os.path.join(folder_path, image_name))
# Resize the collected training images to 128x128 RGB PNGs and record their labels.
# At most this many images are kept per super-class (the first N in path_dict order).
max_train_images_per_class = 10000
image_names_and_labels = []
for class_name, image_paths in path_dict.items():
    label = class_to_label[class_name]
    for image_path in image_paths[:max_train_images_per_class]:
        # The context manager closes the source file promptly instead of leaving
        # the handle open until garbage collection.
        with Image.open(image_path) as source:
            rgb = source if source.mode == "RGB" else source.convert("RGB")
            image = rgb.resize((128, 128), Image.LANCZOS)
        # Use os.path so the platform separator is handled (splitting on "/"
        # breaks when os.path.join produced backslashes).
        image_name = os.path.splitext(os.path.basename(image_path))[0] + ".png"
        image.save(os.path.join(out_train_path, image_name))
        image_names_and_labels.append((image_name, label))

# Write the label CSV next to the images, derived from out_train_path instead of a
# hard-coded path. newline='' is what the csv module requires for files it writes.
with open(os.path.join(out_train_path, "imagenet128_train.csv"), 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["ImageId", "Label"])
    writer.writerows(image_names_and_labels)
print(class_num)
# Convert val: images are looked up directly under val/ by ImageId, and labels come
# from the validation solution CSV.
imagenet_val_path = os.path.join(imagenet_path, "val")
out_val_path = os.path.join(out_path, "val")
os.makedirs(out_val_path, exist_ok=True)
# https://www.kaggle.com/competitions/imagenet-object-localization-challenge/overview
val_solution_path = "LOC_val_solution.csv"
image_names_and_labels = []
# newline='' is what the csv module requires for files it reads or writes.
with open(val_solution_path, 'r', newline='') as file:
    for row in csv.reader(file):
        # Skip blank lines and the header row.
        if not row or row[0] == "ImageId":
            continue
        # The first space-separated token of the second column is the synset ID.
        class_name = name_to_class.get(row[1].split(" ")[0])
        if class_name is None:
            continue  # not one of the restricted super-classes
        image_path = os.path.join(imagenet_val_path, row[0] + ".JPEG")
        # The context manager closes the source file promptly.
        with Image.open(image_path) as source:
            rgb = source if source.mode == "RGB" else source.convert("RGB")
            image = rgb.resize((128, 128), Image.LANCZOS)
        image_name = row[0] + ".png"
        image.save(os.path.join(out_val_path, image_name))
        image_names_and_labels.append((image_name, class_to_label[class_name]))

# Write the label CSV next to the images, derived from out_val_path instead of a
# hard-coded path.
with open(os.path.join(out_val_path, "imagenet128_val.csv"), 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["ImageId", "Label"])
    writer.writerows(image_names_and_labels)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment