Skip to content

Instantly share code, notes, and snippets.

@enijkamp
Created December 28, 2018 18:59
Show Gist options
  • Save enijkamp/d297d8f75b47c1a257f39a4bbea95121 to your computer and use it in GitHub Desktop.
Save enijkamp/d297d8f75b47c1a257f39a4bbea95121 to your computer and use it in GitHub Desktop.
minimal dataset
import torch.utils
import torch.nn.utils
from torch.utils import data
from torchvision import transforms
from PIL import Image, ImageDraw
import numpy as np
def plot_page(rng=None):
    """Render one 28x28 black RGB image containing a white square and a white triangle.

    The square's top-left corner is placed with x in [4, 24) and y in [0, 10)
    and spans to (x + 6, y + 6), so it may run past the right edge and be
    clipped. The triangle has its apex at (x, y) with x in [4, 24) and
    y in [12, 24), and its base 4 pixels below the apex, spanning x - 4 .. x + 4.

    Args:
        rng: Optional source of randomness exposing ``randint(low, high)``
            with an exclusive upper bound (e.g. ``np.random.RandomState``).
            Defaults to the global ``np.random`` module, which matches the
            previous behaviour. Pass a seeded ``RandomState`` to get
            reproducible images.

    Returns:
        A new PIL image in mode ``'RGB'`` of size (28, 28).
    """
    if rng is None:
        rng = np.random

    w, h = 28, 28
    image = Image.new('RGB', (w, h), (0, 0, 0))
    # A single drawing context serves both shapes (the second Draw was redundant).
    draw = ImageDraw.Draw(image)

    def randi(low, high):
        # Upper bound is exclusive, as with np.random.randint.
        return int(rng.randint(low, high))

    # Square in the upper band of the image.
    x = randi(4, 24)
    y = randi(0, 10)
    draw.rectangle((x, y, x + 6, y + 6), fill='white')

    # Triangle with its apex at (x, y), in the lower band of the image.
    x = randi(4, 24)
    y = randi(12, 24)
    draw.polygon([(x, y), (x - 4, y + 4), (x + 4, y + 4)], fill='white')
    return image
def create_pages_images(size=1):
    """Build and return a list of ``size`` independently generated shape images.

    Each entry comes from its own ``plot_page()`` call. A non-positive
    ``size`` produces an empty list.
    """
    pages = []
    for _ in range(size):
        pages.append(plot_page())
    return pages
class ShapesDataset(data.Dataset):
    """Map-style dataset over synthetic shape images generated up front.

    The constructor renders all ``size`` images eagerly via
    ``create_pages_images`` and keeps them in a list. Indexing returns the
    stored image, passed through ``transform`` when one was supplied.
    """

    def __init__(self, size=10000, transform=None):
        self.transform = transform
        self.images = create_pages_images(size)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Named `sample` (not `data`) so it does not shadow the
        # `torch.utils.data` module imported at the top of the file.
        sample = self.images[index]
        return sample if self.transform is None else self.transform(sample)
# NOTE(review): `batch_size` was referenced here but never defined in this
# snippet, so importing the module raised NameError. 64 is a placeholder
# value; adjust it to the training setup.
batch_size = 64

# Grayscale() collapses the RGB shape images to one channel before
# ToTensor() converts each PIL image to a tensor.
train_loader = torch.utils.data.DataLoader(
    ShapesDataset(
        transform=transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
    ),
    batch_size=batch_size,
    shuffle=False,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment