Skip to content

Instantly share code, notes, and snippets.

@mataney
Last active June 24, 2018 16:04
Show Gist options
  • Save mataney/d5c0bb444fb0d3862ea3affff9ef40b8 to your computer and use it in GitHub Desktop.
# Use the GPU when CUDA is available; otherwise fall back to the CPU so the
# dataset/model code (which calls .to(device)) still runs without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SegDataset(Dataset):
    """In-memory segmentation dataset of (raw image, segmentation mask) pairs.

    Every pair listed in the CSV is loaded, center-cropped to 128x128,
    converted to tensors on the module-level ``device`` and shuffled once,
    all inside ``__init__``. Each item is a dict
    ``{'raw': float tensor, 'seg': long tensor}``.

    Args:
        csv_loc: path of the CSV index, passed to ``read_csv``.
        data_dir: directory that the file names in the CSV are joined onto.

    NOTE(review): ``prepare_images`` indexes ``self.images_data[idx][0]`` (raw
    file name) and ``[idx][1]`` (mask file name). This assumes ``read_csv``
    returns a sequence of rows; a pandas DataFrame would treat ``[idx]`` as a
    column lookup instead. Confirm which ``read_csv`` is imported.
    """

    def __init__(self, csv_loc, data_dir):
        self.data_dir = data_dir
        self.images_data = read_csv(csv_loc)
        self.images = self.prepare_images()

    def transform(self, raw, seg):
        """Center-crop both images to 128x128 and turn them into tensors on ``device``.

        ``to_tensor`` scales 8-bit images into [0, 1]. Multiplying by 255 undoes
        that, so ``raw`` keeps 0-255 intensities and ``seg`` yields integer labels.
        """
        crop = transforms.CenterCrop(128)
        raw = crop(raw)
        seg = crop(seg)
        raw = trans_f.to_tensor(raw).mul(255).float().to(device)
        # The float32 divide-by-255 / multiply-by-255 round trip can land a hair
        # below the original integer. .long() truncates, so round first, or a
        # label such as 2 could silently become 1.
        seg = trans_f.to_tensor(seg).mul(255).round().long().to(device)
        return {'raw': raw, 'seg': seg}

    def prepare_images(self):
        """Load, transform and shuffle every (raw, seg) pair listed in the CSV."""
        def image_path(idx, raw_image=True):
            # Column 0 holds the raw image name, column 1 the mask name.
            return os.path.join(self.data_dir,
                                self.images_data[idx][int(not raw_image)])

        images = []
        for idx in range(len(self.images_data)):
            # Context managers close the files deterministically. transform()
            # fully reads the pixels, so nothing touches the files after exit.
            with Image.open(image_path(idx)) as raw, \
                    Image.open(image_path(idx, False)) as seg:
                images.append(self.transform(raw, seg))
        random.shuffle(images)
        return images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx]
# Pixel-wise classification loss shared by train_batch and validate_batch.
criterion = torch.nn.CrossEntropyLoss()
def train_model(model):
    """Train ``model`` for ``EPOCHS`` epochs, running a validation pass after each.

    Data comes from ``Data/train.csv`` and ``Data/val.csv``, with file names
    relative to ``Data``. Batch size and epoch count come from the module-level
    ``BATCH_SIZE`` and ``EPOCHS`` constants.

    ``model`` must be callable on a batch and expose ``zero_grad`` and ``optim``
    (used by ``train_batch``). It is presumably an ``nn.Module``, since this
    function also calls its ``train()``/``eval()`` methods.
    """
    train_data = SegDataset(csv_loc='Data/train.csv', data_dir='Data')
    # SegDataset shuffles only once at load time; reshuffle every epoch here.
    train_iter = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE,
                                             shuffle=True)
    val_data = SegDataset(csv_loc='Data/val.csv', data_dir='Data')
    val_iter = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE)
    for _ in range(EPOCHS):
        model.train()
        run_proc_on_data(train_batch, model, train_iter)
        # Switch dropout / batch-norm layers to inference behavior for validation.
        model.eval()
        run_proc_on_data(validate_batch, model, val_iter)
def run_proc_on_data(func, model, data_iter):
    """Apply ``func`` to every batch of ``data_iter`` and report the mean stats.

    Args:
        func: callable ``func(batch_index, model, batch) -> (loss, jaccard)``,
            e.g. ``train_batch`` or ``validate_batch``.
        model: passed through unchanged to ``func``.
        data_iter: any iterable of batches (a ``DataLoader``, a list, ...).

    Returns:
        ``(mean_loss, mean_jaccard)``, averaged per batch (a smaller final batch
        gets the same weight as the others), or ``(nan, nan)`` when
        ``data_iter`` yields no batches.
    """
    print(func.__name__)
    loss, jac, n_batches = 0, 0, 0
    for i, batch in enumerate(data_iter):
        curr_loss, curr_jac = func(i, model, batch)
        loss += curr_loss
        jac += curr_jac
        n_batches += 1
    if n_batches == 0:
        # Nothing to average: NaN avoids ZeroDivisionError without faking a 0 loss.
        loss = jac = float('nan')
    else:
        loss /= n_batches
        jac /= n_batches
    print("loss: {} jaccard: {}".format(loss, jac))
    return loss, jac
def jaccard_index(pred_labels, target_labels, num_classes=3):
    """Return the mean per-class Jaccard index (IoU) of two integer label tensors.

    Classes absent from both tensors are skipped, because their IoU is
    undefined. Returns NaN when no class in ``range(num_classes)`` occurs at all.
    """
    scores = []
    for c in range(num_classes):
        pred_c = pred_labels == c
        target_c = target_labels == c
        union = (pred_c | target_c).sum().item()
        if union == 0:
            continue
        intersection = (pred_c & target_c).sum().item()
        scores.append(intersection / union)
    return sum(scores) / len(scores) if scores else float('nan')


def _loss_and_jaccard(pred, seg, num_classes=3):
    """Compute the cross-entropy loss tensor and the mean IoU (float) for one batch."""
    # NOTE(review): view(-1, num_classes) treats the LAST axis of pred as the
    # class axis (channels-last). If the model returns (B, C, H, W), rows here
    # mix pixels; use pred.permute(0, 2, 3, 1).reshape(-1, num_classes). Confirm.
    logits = pred.view(-1, num_classes)
    target = seg.view(-1)
    loss = criterion(logits, target)
    jac = jaccard_index(logits.argmax(dim=1), target, num_classes)
    return loss, jac


def train_batch(batch_id, model, batch):
    """Run one optimization step on ``batch`` ({'raw': ..., 'seg': ...}).

    Returns ``(loss, jaccard)`` as Python floats. ``batch_id`` is unused and
    kept for the ``run_proc_on_data`` callback signature.
    """
    model.zero_grad()
    pred = model(batch['raw'])
    loss, jac = _loss_and_jaccard(pred, batch['seg'])
    loss.backward()
    model.optim.step()
    return loss.item(), jac


def validate_batch(batch_id, model, batch):
    """Evaluate ``model`` on ``batch`` without gradients; return ``(loss, jaccard)`` floats."""
    # No backward pass happens here, so skip building the autograd graph.
    with torch.no_grad():
        pred = model(batch['raw'])
        loss, jac = _loss_and_jaccard(pred, batch['seg'])
    return loss.item(), jac
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment