Last active
September 27, 2024 15:09
-
-
Save sadimanna/3f79ad35ec3ec2ef6a5e58c5d3b4af1f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# SimCLR-style contrastive pre-training loop.
#
# Each epoch: one optimisation pass over `train_loader`, one learning-rate
# scheduler step, one gradient-free pass over `valid_loader`, then bookkeeping
# (mean losses, timing, periodic checkpoint and feature plot).
#
# Relies on names defined elsewhere in the file: `time`, `torch`, `model`,
# `optimizer`, `criterion`, `train_loader`, `valid_loader`, `warmupscheduler`,
# `mainscheduler`, `save_model`, `plot_features` and `dg`.

nr = 0  # Guard flag checked before logging/scheduling; never reassigned here.
current_epoch = 0  # Number of completed epochs, handed to save_model.
epochs = 100  # Total number of training epochs.
warmup_epochs = 10  # Epochs driven by warmupscheduler before mainscheduler.
# The `{}` placeholder is passed to save_model unformatted; presumably
# save_model fills it in with the epoch number -- TODO confirm.
checkpoint_name = "SimCLR_CIFAR10_RN50_P128_LR0P2_LWup10_Cos500_T0p5_B128_checkpoint_{}_260621.pt"
tr_loss = []  # Mean training loss per epoch.
val_loss = []  # Mean validation loss per epoch.


def _to_model_input(batch):
    """Drop singleton dimensions, move the batch to the GPU and cast to float32.

    NOTE(review): `squeeze()` without a `dim` argument removes *every*
    size-1 dimension, so a batch containing a single sample would also lose
    its batch axis. Presumably the loaders always yield more than one sample
    per batch -- confirm, or pass an explicit `dim` to `squeeze`.
    """
    return batch.squeeze().to('cuda:0').float()


# Iterate over `epochs` instead of a second hard-coded 100 so that the loop
# length and the "[epoch/epochs]" progress messages cannot drift apart.
for epoch in range(epochs):
    print(f"Epoch [{epoch}/{epochs}]\t")
    stime = time.time()

    # ---- training pass ----
    model.train()
    tr_loss_epoch = 0
    for step, (x_i, x_j) in enumerate(train_loader):
        optimizer.zero_grad()
        x_i = _to_model_input(x_i)
        x_j = _to_model_input(x_j)

        # positive pair, with encoding
        z_i = model(x_i)
        z_j = model(x_j)

        loss = criterion(z_i, z_j)
        loss.backward()
        optimizer.step()

        if nr == 0 and step % 50 == 0:
            print(f"Step [{step}/{len(train_loader)}]\t Loss: {round(loss.item(), 5)}")

        tr_loss_epoch += loss.item()

    # ---- learning-rate schedule: warm-up first, main schedule afterwards ----
    if nr == 0:
        if epoch < warmup_epochs:
            warmupscheduler.step()
        else:
            mainscheduler.step()

    lr = optimizer.param_groups[0]["lr"]

    # Periodic checkpoint every 50 epochs.
    if nr == 0 and (epoch + 1) % 50 == 0:
        save_model(model, optimizer, mainscheduler, current_epoch, checkpoint_name)

    # ---- validation pass (no gradients) ----
    model.eval()
    with torch.no_grad():
        val_loss_epoch = 0
        for step, (x_i, x_j) in enumerate(valid_loader):
            x_i = _to_model_input(x_i)
            x_j = _to_model_input(x_j)

            # positive pair, with encoding
            z_i = model(x_i)
            z_j = model(x_j)

            loss = criterion(z_i, z_j)

            if nr == 0 and step % 50 == 0:
                print(f"Step [{step}/{len(valid_loader)}]\t Loss: {round(loss.item(),5)}")

            val_loss_epoch += loss.item()

    # ---- epoch bookkeeping ----
    if nr == 0:
        # Average over the loaders that were actually iterated above; the
        # original divided by len(dl) / len(vdl), which are different names.
        mean_tr_loss = tr_loss_epoch / len(train_loader)
        mean_val_loss = val_loss_epoch / len(valid_loader)
        tr_loss.append(mean_tr_loss)
        val_loss.append(mean_val_loss)
        print(f"Epoch [{epoch}/{epochs}]\t Training Loss: {mean_tr_loss}\t lr: {round(lr, 5)}")
        print(f"Epoch [{epoch}/{epochs}]\t Validation Loss: {mean_val_loss}\t lr: {round(lr, 5)}")
        current_epoch += 1

    # `dg` is not defined in this block; presumably the training data
    # generator, whose end-of-epoch hook is invoked here -- TODO confirm.
    dg.on_epoch_end()

    time_taken = (time.time() - stime) / 60
    print(f"Epoch [{epoch}/{epochs}]\t Time Taken: {time_taken} minutes")

    # Feature plot every 10 epochs.
    if (epoch + 1) % 10 == 0:
        plot_features(model.pretrained, 10, 2048, 128)

# Final checkpoint once all epochs have run.
save_model(model, optimizer, mainscheduler, current_epoch, checkpoint_name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@sayantan-kuila check the dimensions of your input to the resnet. It should be of the shape N x C x H x W