Skip to content

Instantly share code, notes, and snippets.

@enijkamp
Created December 28, 2018 18:59
Show Gist options
  • Save enijkamp/06c862625989fc295d2e469eea5825c9 to your computer and use it in GitHub Desktop.
Save enijkamp/06c862625989fc295d2e469eea5825c9 to your computer and use it in GitHub Desktop.
minimal
import torch.utils
import torch.nn.utils
from torch.utils import data
from torchvision import transforms
from PIL import Image, ImageDraw
import numpy as np
def plot_page(rng=None):
    """Render one 28x28 RGB image containing a white square and a white triangle.

    The square spans ``(x, y)`` to ``(x + 6, y + 6)`` with ``x`` drawn from
    [4, 24) and ``y`` from [0, 10). Because ``x + 6`` can reach 29 on a
    28-pixel-wide canvas, the square may be cut off at the right edge.

    The triangle has its apex at ``(x, y)`` with ``x`` drawn from [4, 24) and
    ``y`` from [12, 24). Its base runs from ``(x - 4, y + 4)`` to
    ``(x + 4, y + 4)``.

    Args:
        rng: Optional random source exposing ``randint(low, high)`` with an
            exclusive upper bound, e.g. ``numpy.random.RandomState``. When
            ``None`` (the default), the global ``numpy.random`` state is used,
            exactly as before.

    Returns:
        A new ``PIL.Image.Image`` in mode ``'RGB'`` with a black background.
    """
    if rng is None:
        rng = np.random

    def randint(low, high):
        # Upper bound is exclusive (numpy randint semantics). Parameter names
        # avoid shadowing the builtins ``min``/``max``.
        return int(rng.randint(low, high))

    width, height = 28, 28
    image = Image.new('RGB', (width, height), (0, 0, 0))
    # A single drawing context serves both shapes; re-creating it was redundant.
    draw = ImageDraw.Draw(image)

    # Square. The random draws stay in the original x, y, x, y order, so a
    # given seed yields the same images as before.
    x = randint(4, 24)
    y = randint(0, 10)
    draw.rectangle((x, y, x + 6, y + 6), fill='white')

    # Triangle with apex at (x, y) and a horizontal base 4 px below it.
    x = randint(4, 24)
    y = randint(12, 24)
    draw.polygon([(x, y), (x - 4, y + 4), (x + 4, y + 4)], fill='white')
    return image
def create_pages_images(size=1):
    """Build and return a list of ``size`` images, each made by ``plot_page()``.

    Every element is generated independently by a separate call.
    """
    pages = []
    for _ in range(size):
        pages.append(plot_page())
    return pages
class DotsDataset(data.Dataset):
    """Map-style dataset over images generated by ``create_pages_images``.

    All ``size`` images are generated eagerly in ``__init__`` and kept in a
    list in memory. Items are returned without a label.

    Args:
        size: Number of images to generate.
        transform: Optional callable applied to each stored image when it is
            fetched via ``__getitem__``; ``None`` returns the image unchanged.
    """

    def __init__(self, size=10000, transform=None):
        self.transform = transform
        self.images = create_pages_images(size)

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the image at ``index``, passed through ``transform`` if set."""
        # Named ``sample`` rather than ``data`` so the file-level
        # ``torch.utils.data`` module is not shadowed inside this method.
        sample = self.images[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
# The original line used ``batch_size`` without defining it anywhere, which
# raised NameError at import. 64 is an arbitrary placeholder — adjust as needed.
batch_size = 64

# Grayscale() collapses the RGB images to a single channel; ToTensor() turns
# each PIL image into a tensor for batching.
train_transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
train_loader = torch.utils.data.DataLoader(
    DotsDataset(transform=train_transform),
    batch_size=batch_size,
    shuffle=False,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment