Created February 20, 2023 16:34
-
-
Save jszym/7671798dec882e5a36c80b45871015ac to your computer and use it in GitHub Desktop.
PyTorch QuickDraw dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# adapted from https://github.com/nateraw/quickdraw-pytorch/blob/main/quickdraw.ipynb
import math
import urllib.request
from pathlib import Path
from typing import List, Optional

import numpy as np
import requests
import torch
from tqdm.auto import tqdm
def get_quickdraw_class_names(timeout: float = 30.0) -> List[str]:
    """Fetch the QuickDraw category names from the dataset's GitHub repository.

    Spaces inside each category name are replaced with underscores so the
    names can be used directly as file stems.

    Args:
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The category names, in the order they appear in ``categories.txt``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    url = "https://raw.githubusercontent.com/googlecreativelab/quickdraw-dataset/master/categories.txt"
    # Without a timeout, requests may block indefinitely on a stalled connection.
    r = requests.get(url, timeout=timeout)
    # Fail loudly rather than treating the body of an error page as class names.
    r.raise_for_status()
    # Skip blank lines so a stray empty line never becomes a bogus class name.
    return [line.strip().replace(' ', '_') for line in r.text.splitlines() if line.strip()]
def download_quickdraw_dataset(root="./data", limit: Optional[int] = None, class_names: Optional[List[str]] = None):
    """Download QuickDraw ``numpy_bitmap`` files as ``<root>/<class_name>.npy``.

    Files that already exist are skipped, so repeated calls only fetch what is
    missing.

    Args:
        root: Target directory; created (with parents) if it does not exist.
        limit: If given, only the first ``limit`` entries of ``class_names`` are
            downloaded.
        class_names: Class names with spaces written as underscores. Defaults to
            the full list returned by ``get_quickdraw_class_names()``.
    """
    if class_names is None:
        class_names = get_quickdraw_class_names()
    root = Path(root)
    root.mkdir(exist_ok=True, parents=True)
    url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'
    print("Downloading Quickdraw Dataset...")
    for class_name in tqdm(class_names[:limit]):
        fpath = root / f"{class_name}.npy"
        if fpath.exists():
            continue
        # Download under a temporary name and rename only once the transfer has
        # finished. Otherwise an interrupted download would leave a truncated
        # .npy that the exists() check above would skip on every later call.
        # The ".part" name does not match the '*.npy' pattern used for loading.
        tmp_path = fpath.with_name(fpath.name + ".part")
        try:
            # The remote file names use URL-encoded spaces instead of underscores.
            urllib.request.urlretrieve(f"{url}{class_name.replace('_', '%20')}.npy", tmp_path)
        except BaseException:
            # Clean up the partial file (also on Ctrl-C), then propagate.
            if tmp_path.exists():
                tmp_path.unlink()
            raise
        tmp_path.replace(fpath)
def load_quickdraw_data(root="./data", max_items_per_class=5000):
    """Load up to ``max_items_per_class`` rows from every ``*.npy`` file in ``root``.

    Files are processed in sorted path order. Each file's position in that order
    is its integer label, so the same set of files always gets the same labels.

    Args:
        root: Directory containing ``<class_name>.npy`` files.
        max_items_per_class: Maximum number of rows taken from the start of each
            file.

    Returns:
        A tuple ``(x, y, class_names)``:

        - ``x`` stacks the selected rows. With no files it is an empty
          ``(0, 784)`` ``uint8`` array. It assumes every file holds 2-D rows of
          784 values; TODO confirm for any non-QuickDraw files placed in ``root``.
        - ``y`` is an ``int64`` array with one label per row of ``x``.
        - ``class_names[label]`` is the file stem for that label.
    """
    all_files = sorted(Path(root).glob('*.npy'))
    print(f"Loading {max_items_per_class} examples for each class from the Quickdraw Dataset...")
    # Seed with empty arrays so the result dtype/shape matches the historical
    # behaviour, including when root contains no .npy files.
    x_chunks = [np.empty([0, 784], dtype=np.uint8)]
    # np.int64 instead of np.long: np.long was removed in NumPy 1.24-1.26, and
    # in NumPy 2.x it is a platform-width C long (32-bit on Windows).
    y_chunks = [np.empty([0], dtype=np.int64)]
    class_names = []
    for idx, file in enumerate(tqdm(all_files)):
        # mmap so that only the requested leading rows are read from disk.
        data = np.load(file, mmap_mode='r')[0:max_items_per_class, :]
        x_chunks.append(data)
        y_chunks.append(np.full(data.shape[0], idx, dtype=np.int64))
        class_names.append(file.stem)
    # Concatenate once. Growing the arrays inside the loop re-copies everything
    # loaded so far on each iteration, which is quadratic in the total size.
    x = np.concatenate(x_chunks, axis=0)
    y = np.concatenate(y_chunks)
    return x, y, class_names
class QuickDrawDataset(torch.utils.data.Dataset):
    """PyTorch ``Dataset`` over QuickDraw bitmap drawings stored as ``.npy`` files.

    Construction downloads any missing class files into ``root`` (requires
    network access) and loads up to ``max_items_per_class`` rows per file.
    Each item is ``(image, label)``:

    - ``image`` is the stored row divided by 255, as a ``float32`` tensor of
      shape ``(1, 28, 28)``.
    - ``label`` is a Python ``int`` indexing ``self.classes``.
    """

    def __init__(self, root, max_items_per_class=5000, class_limit=None):
        super().__init__()
        self.root = root
        self.max_items_per_class = max_items_per_class
        self.class_limit = class_limit
        download_quickdraw_dataset(self.root, self.class_limit)
        # NOTE(review): load_quickdraw_data reads every *.npy in root. Files left
        # over from an earlier run with a larger class_limit are therefore loaded
        # too, so len(self.classes) may exceed class_limit.
        self.X, self.Y, self.classes = load_quickdraw_data(self.root, self.max_items_per_class)

    def __getitem__(self, idx):
        """Return ``(image, label)``; see the class docstring for their formats."""
        x = (self.X[idx] / 255.).astype(np.float32).reshape(1, 28, 28)
        y = self.Y[idx]
        return torch.from_numpy(x), y.item()

    def __len__(self):
        """Return the number of loaded drawings."""
        return len(self.X)

    def collate_fn(self, batch):
        """Stack ``(image, label)`` pairs into a dict batch.

        Returns:
            ``{'pixel_values': Tensor[B, 1, 28, 28], 'labels': LongTensor[B]}``.
        """
        x = torch.stack([item[0] for item in batch])
        y = torch.LongTensor([item[1] for item in batch])
        return {'pixel_values': x, 'labels': y}

    def split(self, pct=0.1, generator: Optional[torch.Generator] = None):
        """Randomly split the dataset into train and validation subsets.

        Args:
            pct: Fraction of items assigned to the validation subset, rounded
                down. Must be in ``[0, 1]``.
            generator: Optional ``torch.Generator`` for a reproducible split.

        Returns:
            ``(train_ds, val_ds)`` as ``torch.utils.data.Subset`` objects. Their
            index sets are disjoint and together cover the whole dataset.

        Raises:
            ValueError: If ``pct`` is outside ``[0, 1]``.
        """
        if not 0.0 <= pct <= 1.0:
            raise ValueError(f"pct must be in [0, 1], got {pct}")
        indices = torch.randperm(len(self), generator=generator).tolist()
        n_val = math.floor(len(indices) * pct)
        # Slice by the train size, not by -n_val. When n_val == 0, indices[:-0]
        # is empty and indices[-0:] is everything, which would silently put the
        # whole dataset into the validation subset.
        n_train = len(indices) - n_val
        train_ds = torch.utils.data.Subset(self, indices[:n_train])
        val_ds = torch.utils.data.Subset(self, indices[n_train:])
        return train_ds, val_ds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.