Skip to content

Instantly share code, notes, and snippets.

@mataney
Created June 22, 2018 12:47
Show Gist options
  • Save mataney/8caef039d0c5c1211f357cbed0d51b1f to your computer and use it in GitHub Desktop.
Save mataney/8caef039d0c5c1211f357cbed0d51b1f to your computer and use it in GitHub Desktop.
class SegDataset(Dataset):
    """Pre-augmented in-memory dataset of (raw image, segmentation image) pairs.

    Every pair listed by ``read_csv(csv_loc)`` is loaded once and expanded into
    ``augments`` randomly cropped / flipped samples at construction time. The
    samples are shuffled and stored as dicts ``{'raw': ..., 'seg': ...}`` whose
    tensors have already been moved to the module-level ``device``.

    Args:
        csv_loc: Location passed to ``read_csv``. Each returned row is indexed
            as ``row[0]`` (raw image file name) and ``row[1]`` (segmentation
            file name), both relative to ``data_dir``.
        data_dir: Directory the file names from the CSV are joined onto.
        augments: Number of augmented samples generated per CSV row.
        crop_size: ``(height, width)`` output size of the random crop.
            Defaults to the previously hard-coded ``(128, 128)``.
    """

    def __init__(self, csv_loc, data_dir, augments=200, crop_size=(128, 128)):
        self.data_dir = data_dir
        # Must be set before prepare_images(), which reads it. The original
        # never stored it, so construction failed with AttributeError.
        self.augments = augments
        self.crop_size = crop_size
        self.images_data = read_csv(csv_loc)
        self.images = self.prepare_images()

    def transform(self, raw, seg):
        """Apply the same random crop and random flips to ``raw`` and ``seg``.

        Args:
            raw: Raw input image, as returned by ``_read_image``.
            seg: Segmentation image. The crop window is sampled from ``raw``
                only, so ``seg`` is assumed to have the same spatial size
                (TODO confirm for the dataset in use).

        Returns:
            dict with ``'raw'`` (float tensor) and ``'seg'`` (long tensor),
            both moved to ``device``.
        """
        # Sample the crop window once and reuse it so the mask stays aligned.
        top, left, height, width = transforms.RandomCrop.get_params(
            raw, output_size=self.crop_size)
        raw = trans_f.crop(raw, top, left, height, width)
        seg = trans_f.crop(seg, top, left, height, width)
        # Each flip is decided once and applied to both images.
        if random.random() > 0.5:
            raw = trans_f.hflip(raw)
            seg = trans_f.hflip(seg)
        if random.random() > 0.5:
            raw = trans_f.vflip(raw)
            seg = trans_f.vflip(seg)
        # to_tensor scales 8-bit images into [0, 1]; mul(255) undoes that.
        raw = trans_f.to_tensor(raw).mul(255).float().to(device)
        # round() before long(): the float32 x/255*255 round-trip may land a
        # hair below the integer, and long() truncates, shifting the label id.
        seg = trans_f.to_tensor(seg).mul(255).round().long().to(device)
        return {'raw': raw, 'seg': seg}

    def _read_image(self, file_name):
        """Open ``data_dir/file_name`` and return a fully loaded in-memory copy."""
        path = os.path.join(self.data_dir, file_name)
        # Image.open is lazy and keeps the file open. copy() forces the pixel
        # data into memory so the with-block can release the handle.
        with Image.open(path) as img:
            return img.copy()

    def prepare_images(self):
        """Load every CSV pair and expand it into ``self.augments`` samples.

        Returns:
            A shuffled list of the dicts produced by ``transform``.
        """
        images = []
        for row in self.images_data:
            # Column 0 holds the raw image file, column 1 the segmentation file.
            raw = self._read_image(row[0])
            seg = self._read_image(row[1])
            images.extend(self.transform(raw, seg) for _ in range(self.augments))
        random.shuffle(images)
        return images

    def __len__(self):
        """Return the number of pre-generated samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the pre-generated sample dict at position ``idx``."""
        return self.images[idx]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment