Created
December 13, 2019 19:50
-
-
Save aiwithshekhar/a2ac9587af1a2400c1c3322ef774b0c6 to your computer and use it in GitHub Desktop.
Training loop implementation starts here.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Trainer(object):
    """Train/validate a binary-segmentation model with gradient accumulation.

    Expects the module-level names ``df``, ``img_fol``, ``mask_fol``, ``mean``,
    ``std``, ``CarDataloader``, ``Scores`` and ``epoch_log`` to be defined
    elsewhere in the file (they are referenced but not shown here).
    Checkpoints the best-validation-loss state to ``./model_office.pth``.
    """

    def __init__(self, model):
        self.num_workers = 4
        self.batch_size = {'train': 1, 'val': 1}
        # Effective batch size of 4 is simulated by accumulating gradients
        # over `accumulation_steps` micro-batches before each optimizer step.
        self.accumulation_steps = 4 // self.batch_size['train']
        self.lr = 5e-4
        self.num_epochs = 10
        self.phases = ['train', 'val']
        self.best_loss = float('inf')
        # Fall back to CPU when CUDA is absent (original hard-required cuda:0).
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if self.device.type == "cuda":
            torch.set_default_tensor_type("torch.cuda.FloatTensor")
        self.net = model.to(self.device)
        # cudnn autotuner: picks fastest conv algorithms for fixed input shapes.
        cudnn.benchmark = True
        self.criterion = torch.nn.BCEWithLogitsLoss()
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min',
                                           patience=3, verbose=True)
        self.dataloaders = {phase: CarDataloader(df, img_fol,
                                                 mask_fol, mean, std,
                                                 phase=phase,
                                                 batch_size=self.batch_size[phase],
                                                 num_workers=self.num_workers)
                            for phase in self.phases}
        self.losses = {phase: [] for phase in self.phases}
        self.dice_score = {phase: [] for phase in self.phases}

    def forward(self, inp_images, tar_mask):
        """Run one forward pass; return (BCE-with-logits loss, raw logits)."""
        inp_images = inp_images.to(self.device)
        tar_mask = tar_mask.to(self.device)
        pred_mask = self.net(inp_images)
        loss = self.criterion(pred_mask, tar_mask)
        return loss, pred_mask

    def iterate(self, epoch, phase):
        """Run one full epoch for `phase` ('train' or 'val'); return mean loss."""
        measure = Scores(phase, epoch)
        start = time.strftime("%H:%M:%S")
        print(f"Starting epoch: {epoch} | phase:{phase} | 🙊':{start}")
        self.net.train(phase == "train")
        dataloader = self.dataloaders[phase]
        running_loss = 0.0
        total_batches = len(dataloader)
        self.optimizer.zero_grad()
        for itr, batch in enumerate(dataloader):
            images, mask_target = batch
            loss, pred_mask = self.forward(images, mask_target)
            # Scale down so accumulated gradients match one large-batch step.
            loss = loss / self.accumulation_steps
            if phase == 'train':
                loss.backward()
                if (itr + 1) % self.accumulation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
            running_loss += loss.item()
            pred_mask = pred_mask.detach().cpu()
            measure.update(mask_target, pred_mask)
        # Undo the accumulation scaling to report the true mean per-batch loss.
        epoch_loss = (running_loss * self.accumulation_steps) / total_batches
        dice = epoch_log(phase, epoch, epoch_loss, measure, start)
        self.losses[phase].append(epoch_loss)
        self.dice_score[phase].append(dice)
        torch.cuda.empty_cache()
        return epoch_loss

    def start(self):
        """Train for `num_epochs`, validating each epoch; save the best state."""
        for epoch in range(self.num_epochs):
            self.iterate(epoch, "train")
            state = {
                "epoch": epoch,
                "best_loss": self.best_loss,
                "state_dict": self.net.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            }
            with torch.no_grad():
                val_loss = self.iterate(epoch, "val")
                self.scheduler.step(val_loss)
            if val_loss < self.best_loss:
                print("******** New optimal found, saving state ********")
                state["best_loss"] = self.best_loss = val_loss
                torch.save(state, "./model_office.pth")
            print()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment