Last active
          April 30, 2022 08:45 
        
      - 
      
 - 
        
Save dvgodoy/dd2696cf490c9e3d2eb0342590ab428f to your computer and use it in GitHub Desktop.  
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import numpy as np | |
| import matplotlib.pyplot as plt | |
| import torch | |
| from PIL import Image | |
| from torch.utils.data import TensorDataset, DataLoader | |
def draw_circle(radius, center_x=0.5, center_y=0.5, size=28):
    """Render the outline of a circle as a square grayscale image.

    The circle is drawn with matplotlib on a 1x1-inch figure whose axes are
    hidden, rasterized through the canvas buffer, converted to grayscale with
    PIL and finally resized to ``size`` x ``size`` pixels.

    Args:
        radius: Radius of the circle, in the coordinates of the axes.
        center_x: X coordinate of the circle's center.
        center_y: Y coordinate of the circle's center.
        size: Side length of the returned image in pixels. It is passed
            through ``int()``, so float values (e.g. coming from a NumPy
            row) are accepted.

    Returns:
        A NumPy array holding the grayscale ('L' mode) image.
    """
    # Unfilled black outline, described by its center and radius.
    circle = plt.Circle((center_x, center_y), radius, color='k', fill=False)
    fig, ax = plt.subplots(figsize=(1, 1))
    try:
        ax.add_patch(circle)
        ax.axis('off')
        # print_to_buffer() returns (raw RGBA bytes, (width, height)).
        rgba_bytes, canvas_size = fig.canvas.print_to_buffer()
    finally:
        # Close this specific figure, even if drawing failed, so repeated
        # calls never leave figures open.
        plt.close(fig)
    # Convert the raw buffer into a PIL image, make it grayscale, resize it.
    side = int(size)
    image = Image.frombuffer('RGBA', canvas_size, rgba_bytes)
    return np.array(image.convert('L').resize((side, side)))
def gen_circles(n, size=28):
    """Generate ``n`` random circle images together with their radii.

    Args:
        n: Number of circles to generate.
        size: Side length, in pixels, requested from ``draw_circle``.

    Returns:
        A tuple ``(circles, radius)``: ``circles`` is the array of images
        produced by calling ``draw_circle`` once per sample, and ``radius``
        is the ``(n, 1)`` column of radii used to draw them.
    """

    def random_column(low, high):
        # n uniform samples in [low, high), shaped as a single column.
        return np.random.uniform(low, high, size=n).reshape(-1, 1)

    def render(row):
        # Each row is (radius, center_x, center_y, size).
        return draw_circle(*row)

    # Sampling order (x, y, then radius) is kept so a seeded run draws the
    # same numbers. Centers are offset from 0.5 by at most 0.03.
    center_x = random_column(0.0, 0.03) + .5
    center_y = random_column(0.0, 0.03) + .5
    # Radii between 0.03 and 0.47.
    radius = random_column(0.03, 0.47)
    # Constant column carrying the requested image size for every sample.
    sizes = size * np.ones((n, 1))
    params = np.concatenate((radius, center_x, center_y, sizes), axis=1)
    # One draw_circle call per row of parameters.
    circles = np.apply_along_axis(render, 1, params)
    return circles, radius
# Seed NumPy's global generator so the generated circles are reproducible.
np.random.seed(42)

# Generate 1,000 circle images together with their radii.
circles, radius = gen_circles(n=1000)

# Features: images cast to float, given a channel axis and divided by 255.
# Targets: the radii converted as-is.
# NOTE(review): the targets keep the dtype NumPy produced (presumably
# float64) while the images are float32 -- confirm the consumer expects that.
circles_ds = TensorDataset(
    torch.as_tensor(circles).float().unsqueeze(1) / 255,
    torch.as_tensor(radius),
)

# Shuffled mini-batches of 32; the last incomplete batch is dropped.
circles_dl = DataLoader(
    dataset=circles_ds,
    batch_size=32,
    shuffle=True,
    drop_last=True,
)
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment