Skip to content

Instantly share code, notes, and snippets.

@dvgodoy
Last active April 30, 2022 08:45
Show Gist options
  • Save dvgodoy/dd2696cf490c9e3d2eb0342590ab428f to your computer and use it in GitHub Desktop.
Save dvgodoy/dd2696cf490c9e3d2eb0342590ab428f to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch.utils.data import TensorDataset, DataLoader
def draw_circle(radius, center_x=0.5, center_y=0.5, size=28):
    """Render one unfilled black circle and return it as a grayscale array.

    The circle is drawn on a 1x1-inch matplotlib figure with the axis hidden,
    rasterized to an RGBA buffer, converted to 8-bit grayscale (PIL mode "L")
    and resized to ``size`` x ``size`` pixels.

    Args:
        radius: circle radius, in the axes' data coordinates.
        center_x: x coordinate of the circle's center, in data coordinates.
        center_y: y coordinate of the circle's center, in data coordinates.
        size: side length of the output image in pixels; may be a float,
            it is truncated with ``int``.

    Returns:
        A ``(size, size)`` uint8 NumPy array of grayscale pixel values.
    """
    side = int(size)
    fig, ax = plt.subplots(figsize=(1, 1))
    try:
        ax.add_patch(plt.Circle((center_x, center_y), radius, color='k', fill=False))
        ax.axis('off')
        # print_to_buffer returns (raw RGBA bytes, (width, height)).
        rgba, (width, height) = fig.canvas.print_to_buffer()
    finally:
        # Close *this* figure, even if rendering failed, so repeated calls do
        # not accumulate open figures; a bare plt.close() would close whatever
        # figure happens to be "current".
        plt.close(fig)
    # Pass the raw decoder arguments explicitly: older Pillow releases
    # defaulted frombuffer to a bottom-up (vertically flipped) row order.
    image = Image.frombuffer('RGBA', (width, height), rgba, 'raw', 'RGBA', 0, 1)
    # Convert to grayscale and downsample to the requested square size.
    return np.array(image.convert('L').resize((side, side)))
def gen_circles(n, size=28):
    """Generate ``n`` random circle images together with their radii.

    Each center coordinate is ``0.5`` plus a uniform offset in ``[0, 0.03)``,
    so centers lie in ``[0.5, 0.53)`` on both axes (offset only in the
    positive direction). Radii are uniform in ``[0.03, 0.47)``.

    Args:
        n: number of images to generate (0 yields an empty batch).
        size: side length, in pixels, of each square image.

    Returns:
        ``(circles, radius)``: ``circles`` stacks one ``draw_circle`` image per
        sample, giving shape ``(n, size, size)``; ``radius`` is an ``(n, 1)``
        float array holding the radius used for each image.
    """
    # The draw order (x, y, radius) is kept fixed so that a given
    # np.random seed always reproduces the same dataset.
    center_x = np.random.uniform(0.0, 0.03, size=n).reshape(-1, 1) + .5
    center_y = np.random.uniform(0.0, 0.03, size=n).reshape(-1, 1) + .5
    radius = np.random.uniform(0.03, 0.47, size=n).reshape(-1, 1)
    if n == 0:
        # Stacking zero images is not possible; return an empty batch with
        # the same trailing shape and dtype (grayscale "L" -> uint8) that
        # draw_circle produces.
        side = int(size)
        return np.empty((0, side, side), dtype=np.uint8), radius
    circles = np.stack([
        draw_circle(r, cx, cy, size)
        for r, cx, cy in zip(radius[:, 0], center_x[:, 0], center_y[:, 0])
    ])
    return circles, radius
# Fix the global NumPy seed so the generated dataset is reproducible.
np.random.seed(42)
# generates 1,000 circle images and their radii
circles, radius = gen_circles(1000)
# Features: add a channel dimension -> (N, 1, size, size) and scale the
# uint8 pixel values from 0..255 to [0, 1] as float32.
# Targets: cast the radii to float32 as well. Left as float64 they would not
# match the dtype of a float32 model's predictions, and losses such as
# MSELoss can then fail with a dtype-mismatch error during training.
circles_ds = TensorDataset(
    torch.as_tensor(circles).unsqueeze(1).float() / 255,
    torch.as_tensor(radius).float(),
)
# With 1,000 samples and batch_size=32, drop_last=True discards the final
# partial batch (8 samples) each epoch.
circles_dl = DataLoader(circles_ds, batch_size=32, shuffle=True, drop_last=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment