Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Created June 30, 2021 10:58
Show Gist options
  • Save sadimanna/9aa11e892398a3cba82669afd1eb67ab to your computer and use it in GitHub Desktop.
Save sadimanna/9aa11e892398a3cba82669afd1eb67ab to your computer and use it in GitHub Desktop.
def save_model(model, optimizer, scheduler, current_epoch, name,
               save_dir='/content/saved_models/'):
    """Write a training checkpoint for ``current_epoch`` to disk.

    The checkpoint is a dict holding the model, optimizer and scheduler
    ``state_dict()``s plus the epoch number, serialized with ``torch.save``.

    Args:
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored.
        current_epoch: epoch number; substituted into ``name`` via
            ``str.format`` and stored under the ``'epoch'`` key.
        name: file-name template formatted with ``current_epoch``.
        save_dir: directory that receives the checkpoint; created if it
            does not exist yet (``torch.save`` would otherwise fail).

    Returns:
        The path the checkpoint was written to.
    """
    os.makedirs(save_dir, exist_ok=True)
    out = os.path.join(save_dir, name.format(current_epoch))
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                # Stored so training can resume from the right epoch.
                'epoch': current_epoch}, out)
    return out
def plot_features(model, num_classes, num_feats, batch_size,
                  dataloader=None, labels=None, device='cuda:0'):
    """Project the model's features to 2-D with t-SNE and scatter them by class.

    Runs ``model`` under ``torch.no_grad()`` in eval mode over every batch of
    the loader, stacks the outputs into a ``(num_samples, num_feats)`` array,
    embeds it with ``TSNE(n_components=2, perplexity=50)`` and draws one
    scatter series per class label, then shows the figure.

    Args:
        model: torch module mapping a batch of inputs to feature rows.
        num_classes: labels ``0 .. num_classes - 1`` are plotted, one series
            each, with a matching legend.
        num_feats: number of columns each output row must have.
        batch_size: kept for backward compatibility; the number of plotted
            samples is taken from the features actually collected.
        dataloader: iterable of ``(x1, x2)`` pairs; only ``x1`` is embedded.
            Defaults to the module-level ``vdl``.
        labels: per-sample class labels in loader order. Defaults to the
            module-level ``vallabels``.
        device: device the inputs are moved to before the forward pass.
    """
    if dataloader is None:
        dataloader = vdl  # module-level validation loader defined elsewhere
    if labels is None:
        labels = vallabels  # module-level validation labels defined elsewhere

    was_training = model.training
    model.eval()
    # Seed with an empty (0, num_feats) block so concatenation enforces the
    # column count (and float64 dtype) exactly like the original np.append.
    chunks = [np.array([]).reshape((0, num_feats))]
    with torch.no_grad():
        for x1, _ in dataloader:
            # NOTE(review): squeeze() also drops the batch axis for a batch
            # of one sample — confirm the loader never yields such a batch.
            x1 = x1.squeeze().to(device=device, dtype=torch.float)
            chunks.append(model(x1).detach().cpu().numpy())
    model.train(was_training)  # leave the model in the mode we found it
    # One concatenation instead of np.append per batch (which was O(n^2)).
    feats = np.concatenate(chunks, axis=0)

    tsne = TSNE(n_components=2, perplexity=50)
    x_feats = tsne.fit_transform(feats)

    # Align labels with the samples actually embedded (the loader may have
    # dropped a final partial batch); assumes the loader does not shuffle.
    sample_labels = np.asarray(labels)[:len(x_feats)]
    for i in range(num_classes):
        mask = sample_labels == i
        plt.scatter(x_feats[mask, 1], x_feats[mask, 0])
    plt.legend([str(i) for i in range(num_classes)])
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment