Skip to content

Instantly share code, notes, and snippets.

@hasithsura
Created January 9, 2020 09:40
Show Gist options
  • Save hasithsura/5613a93aa68196cbcda6da23782f4e19 to your computer and use it in GitHub Desktop.
Save hasithsura/5613a93aa68196cbcda6da23782f4e19 to your computer and use it in GitHub Desktop.
# --- Training configuration -------------------------------------------------
epochs = 50
learning_rate = 2e-4  # initial learning rate; lr_decay also reads this global

# Classification loss and Adam optimizer over the ResNet's parameters.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet_model.parameters(), lr=learning_rate)

# History buffers filled by train(): one list of per-batch losses per epoch.
resnet_train_losses, resnet_valid_losses = [], []
def lr_decay(optimizer, epoch, base_lr=None, step_size=10, factor=10):
    """Step-decay learning-rate schedule, usable as ``train``'s ``change_lr`` hook.

    Whenever ``epoch`` is a multiple of ``step_size``, the learning rate is set to
    ``base_lr / factor ** (epoch // step_size)`` via ``setlr`` and the new value
    is printed. On other epochs the optimizer is returned untouched.

    Args:
        optimizer: optimizer whose learning rate may be updated.
        epoch: current epoch number (``train`` passes 1-based epochs).
        base_lr: undecayed learning rate; defaults to the module-level
            ``learning_rate`` read at call time.
        step_size: number of epochs between decay steps (must be non-zero).
        factor: divisor applied once per completed ``step_size`` epochs.

    Returns:
        The value returned by ``setlr`` when a decay step happens, otherwise
        ``optimizer`` itself.
    """
    if epoch % step_size == 0:
        if base_lr is None:
            base_lr = learning_rate
        new_lr = base_lr / (factor ** (epoch // step_size))
        # NOTE(review): setlr is defined elsewhere in the project; presumably it
        # writes new_lr into optimizer.param_groups — confirm.
        optimizer = setlr(optimizer, new_lr)
        print(f'Changed learning rate to {new_lr}')
    return optimizer
def _run_train_epoch(model, loss_fn, train_loader, optimizer):
    """Run one optimisation pass over ``train_loader`` and return its per-batch losses."""
    model.train()
    batch_losses = []
    for x, y in train_loader:
        optimizer.zero_grad()
        x = x.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.long)
        loss = loss_fn(model(x), y)
        loss.backward()
        batch_losses.append(loss.item())
        optimizer.step()
    return batch_losses


def _run_validation(model, loss_fn, valid_loader):
    """Evaluate ``model`` on ``valid_loader`` and return ``(batch_losses, accuracy)``.

    The forward passes run under ``torch.no_grad()`` so no autograd graph is
    kept. ``accuracy`` is NaN when ``valid_loader`` yields no batches.
    """
    model.eval()
    batch_losses = []
    trace_y = []
    trace_yhat = []
    with torch.no_grad():
        for x, y in valid_loader:
            x = x.to(device, dtype=torch.float32)
            y = y.to(device, dtype=torch.long)
            y_hat = model(x)
            batch_losses.append(loss_fn(y_hat, y).item())
            trace_y.append(y.cpu().numpy())
            trace_yhat.append(y_hat.cpu().numpy())
    if not trace_y:
        # np.concatenate raises on an empty list; report NaN accuracy instead.
        return batch_losses, float('nan')
    trace_y = np.concatenate(trace_y)
    trace_yhat = np.concatenate(trace_yhat)
    # NOTE(review): argmax over axis 1 assumes model outputs (batch, n_classes)
    # scores — confirm against the model definition.
    accuracy = np.mean(trace_yhat.argmax(axis=1) == trace_y)
    return batch_losses, accuracy


def train(model, loss_fn, train_loader, valid_loader, epochs, optimizer, train_losses, valid_losses, change_lr=None):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Each epoch first optionally calls ``change_lr(optimizer, epoch)`` and uses
    the optimizer it returns. It then runs a training pass and a gradient-free
    validation pass, printing the mean losses and the validation accuracy.

    Args:
        model: module to train; switched between ``train()`` and ``eval()`` modes.
        loss_fn: callable ``loss_fn(y_hat, y)`` returning a scalar loss tensor.
        train_loader: iterable of ``(x, y)`` batches; ``x`` is cast to float32,
            ``y`` to long, and both are moved to the module-level ``device``.
        valid_loader: iterable of ``(x, y)`` batches, handled like ``train_loader``.
        epochs: number of epochs; epochs are numbered from 1.
        optimizer: optimizer stepping ``model``'s parameters.
        train_losses: list appended with one list of per-batch losses per epoch.
        valid_losses: list appended with one list of per-batch losses per epoch.
        change_lr: optional ``callable(optimizer, epoch) -> optimizer`` hook.
    """
    for epoch in tqdm(range(1, epochs + 1)):
        if change_lr is not None:
            optimizer = change_lr(optimizer, epoch)
        train_losses.append(_run_train_epoch(model, loss_fn, train_loader, optimizer))
        print(f'Epoch - {epoch} Train-Loss : {np.mean(train_losses[-1])}')
        batch_losses, accuracy = _run_validation(model, loss_fn, valid_loader)
        valid_losses.append(batch_losses)
        print(f'Epoch - {epoch} Valid-Loss : {np.mean(valid_losses[-1])} Valid-Accuracy : {accuracy}')
# Kick off training of the ResNet with the step-decay learning-rate hook.
train(
    model=resnet_model,
    loss_fn=loss_fn,
    train_loader=train_loader,
    valid_loader=valid_loader,
    epochs=epochs,
    optimizer=optimizer,
    train_losses=resnet_train_losses,
    valid_losses=resnet_valid_losses,
    change_lr=lr_decay,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment