Skip to content

Instantly share code, notes, and snippets.

@Yotsuyubi
Last active March 10, 2020 04:53
Show Gist options
  • Save Yotsuyubi/fcbe0d5c1808f8a6226a7b8babf97277 to your computer and use it in GitHub Desktop.
Save Yotsuyubi/fcbe0d5c1808f8a6226a7b8babf97277 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import requests
from PIL import Image
import io
import os
from time import sleep
class Danbooru(Dataset):
    """Dataset of random Danbooru images labelled by the tag used to find them.

    The dataset exposes ``len(tags) * num_images_per_tag`` slots. Slot ``idx``
    belongs to the tag at position ``idx % len(tags)``. The first access to a
    slot downloads a random post for that tag; when a cache directory is
    configured, the image is written there as ``<md5>.png`` and later accesses
    to the same slot (within this object's lifetime) read it back from disk.

    Args:
        tags: Sequence of tag query strings; the position of a tag is its label.
        num_images_per_tag: Number of slots reserved for each tag.
        cache: Directory for downloaded images, or ``None`` to disable caching
            (every access then downloads a fresh random image).
        transform: Optional callable applied to the image before it is returned.
    """

    # Pause after each HTTP request, in seconds, to stay gentle on the server.
    request_delay = 0.5
    # Per-request timeout in seconds so a stalled connection cannot hang forever.
    request_timeout = 30
    # Upper bound on download attempts for one slot before giving up.
    max_fetch_attempts = 100

    def __init__(self, tags, *, num_images_per_tag=100, cache='./danbooru_cache', transform=None):
        super().__init__()
        self.tags = tags
        self.num_images_per_tag = num_images_per_tag
        self.cache = cache
        self.transform = transform
        # One entry per slot: md5 of the image cached for it, or None if the
        # slot has not been downloaded (or caching is disabled).
        self.cache_list = [None] * (len(self.tags) * self.num_images_per_tag)
        if self.cache is not None:
            os.makedirs(self.cache, exist_ok=True)

    def __len__(self):
        """Return the total number of slots across all tags."""
        return len(self.tags) * self.num_images_per_tag

    def __getitem__(self, idx):
        """Return ``(image, tag_index)`` for slot ``idx``.

        The image is passed through ``transform`` when one was given.
        """
        tag_idx = idx % len(self.tags)
        tag = self.tags[tag_idx]
        md5 = self.cache_list[idx]
        img = self._get(idx, tag, md5=md5)
        x = self.transform(img) if self.transform is not None else img
        return x, tag_idx

    def _cache_path(self, md5):
        """Return the on-disk location of the cached image with this md5."""
        return '{}/{}.png'.format(self.cache, md5)

    def _get(self, idx, tag, *, md5=None):
        """Return the image for slot ``idx``, downloading it if not yet cached.

        Args:
            idx: Slot index, used to record the md5 of a newly cached image.
            tag: Tag query used when a download is required.
            md5: md5 of the image already cached for this slot, or ``None``.
        """
        if md5 is not None:
            image = Image.open(self._cache_path(md5))
            # Image.open is lazy; loading now reads the pixels so the
            # underlying file does not have to stay open after we return.
            image.load()
            return image

        image, md5 = self._fetch(tag)
        if self.cache is not None:
            # Save first, record second: if the write fails the slot stays
            # marked as uncached instead of pointing at a missing file.
            image.save(self._cache_path(md5))
            self.cache_list[idx] = md5
        return image

    def _fetch(self, tag):
        """Download a random image for ``tag`` and return ``(image, md5)``.

        Retries when the chosen post has no downloadable file or the file
        cannot be read as an image.

        Raises:
            IndexError: The search for ``tag`` returned no posts.
            RuntimeError: No usable image was obtained within
                ``max_fetch_attempts`` attempts.
        """
        for _ in range(self.max_fetch_attempts):
            result = self._fetch_once(tag)
            if result is not None:
                return result
        raise RuntimeError(
            'could not fetch an image for tag {!r} after {} attempts'.format(
                tag, self.max_fetch_attempts))

    def _fetch_once(self, tag):
        """Make one download attempt; return ``(image, md5)`` or ``None``.

        ``None`` means the attempt should be retried: the post exposed no
        ``file_url`` or fetching/decoding its file raised ``OSError``.
        """
        url = "https://danbooru.donmai.us/posts.json"
        params = {'tags': tag, 'random': 'true'}
        posts = requests.get(url, params=params, timeout=self.request_timeout).json()
        sleep(self.request_delay)
        if not posts:
            raise IndexError('no Danbooru posts found for tag {!r}'.format(tag))
        post = posts[0]
        if 'file_url' not in post:
            return None

        md5 = post['md5']
        try:
            download = requests.get(post['file_url'], timeout=self.request_timeout)
            sleep(self.request_delay)
            image = Image.open(io.BytesIO(download.content))
            # Force decoding here so truncated or corrupt downloads are
            # detected inside this try block and trigger a retry.
            image.load()
        except OSError:
            return None
        return image, md5
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment