Created
May 5, 2020 20:16
-
-
Save burrussmp/45d7dbe5f0c9831e20593f548f73aaf4 to your computer and use it in GitHub Desktop.
Train a model in PyTorch.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Hyper-parameters and runtime configuration for the training script.
batch_size = 32  # presumably consumed by DataLoaders built elsewhere — TODO confirm
learning_rate = 0.001  # initial learning rate handed to the optimizer
scheduler_step = 30  # NOTE(review): not referenced in this script; StepLR uses lr_scheduler_step_size
epochs = 190
gamma = 0.5  # StepLR decay factor: the LR is multiplied by this at each decay
lr_scheduler_step_size = 12  # StepLR step_size: scheduler.step() calls between decays
adam_betas = (0.9, 0.999)  # Adam (beta1, beta2)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Extra loader options, presumably forwarded to DataLoader; only enabled with CUDA.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
# Checkpoint location. Set restart = True to ignore an existing checkpoint
# and train from scratch.
restart = False
pathToModel = os.path.join(BASEDIR, 'weights2.pt')

# Seed right before model construction, presumably so UNet2D's weight
# initialisation is reproducible.
torch.manual_seed(65675)
print('Initializing model')
# .to(device) moves the parameters to the GPU when CUDA is available and is a
# no-op on CPU; it returns the module itself.
model = UNet2D().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=adam_betas)
# Multiply the learning rate by `gamma` every `lr_scheduler_step_size` scheduler steps.
scheduler = StepLR(optimizer, step_size=lr_scheduler_step_size, gamma=gamma)
# Resume from the saved checkpoint (weights, loss history and best validation
# loss) unless it is missing or `restart` is set; otherwise start fresh.
if os.path.isfile(pathToModel) and not restart:
    print('Loading model.....')
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # written on a GPU machine can also be loaded on a CPU-only one.
    model.load_state_dict(torch.load(pathToModel, map_location=device))
    # lowest.npy holds the best validation loss recorded by the training loop
    # (presumably a 0-d array — TODO confirm).
    best_loss = torch.tensor(np.load(os.path.join(BASEDIR, 'lowest.npy')).tolist()).to(device)
    train_loss_save = np.load(os.path.join(BASEDIR, 'train_loss.npy')).tolist()
    val_loss_save = np.load(os.path.join(BASEDIR, 'val_loss.npy')).tolist()
    # NOTE(review): optimizer and scheduler state are not restored here, so the
    # learning-rate schedule starts again from learning_rate on every resume.
else:
    # Keep best_loss a tensor on `device` in both branches so tensor operations
    # applied to it later (e.g. .cpu()) work regardless of which branch ran.
    best_loss = torch.tensor(math.inf, device=device)
    train_loss_save = []
    val_loss_save = []
# Train for `epochs` epochs. After each epoch: record both losses, checkpoint
# the weights whenever validation loss reaches a new minimum, advance the LR
# schedule, and persist the loss history plus best loss.
#
# best_loss may arrive as a Python float or as a single-element tensor,
# depending on how it was initialised; normalise it to a float so the
# comparison and np.save below work either way.
best_loss = float(best_loss)
for epoch in range(1, epochs + 1):
    train_loss = train(model, device, train_loader, optimizer, epoch)
    val_loss = validate(model, device, validation_loader)
    # assumes train()/validate() return loss tensors — TODO confirm
    train_loss_save.append(train_loss.detach().cpu().numpy())
    val_loss_save.append(val_loss.detach().cpu().numpy())
    current_loss = float(val_loss)  # requires validate() to return a single-element tensor
    if current_loss < best_loss:
        print(f'Loss improved from {best_loss} to {current_loss}: Saving new model to {pathToModel}')
        best_loss = current_loss
        torch.save(model.state_dict(), pathToModel)
    scheduler.step()
    # Saved every epoch rather than only once at the end, so an interrupted run
    # leaves the history files consistent with the last saved checkpoint.
    np.save(os.path.join(BASEDIR, 'val_loss.npy'), np.array(val_loss_save))
    np.save(os.path.join(BASEDIR, 'train_loss.npy'), np.array(train_loss_save))
    np.save(os.path.join(BASEDIR, 'lowest.npy'), np.array(best_loss))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment