This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): truncated gist-preview fragment — cut off mid ax.bar(...) call
# (no closing parenthesis visible) and polluted with " | |" scrape artifacts;
# indentation was lost in extraction, so the code is left byte-identical.
# Visible intent: draw one grouped-bar series per player, offsetting each
# series around the x tick positions by the bar width w.
fig,ax = plt.subplots(1,1, figsize=(10,10)) | |
b = [] | |
for i in range(num_plyrs): | |
b.append(ax.bar(x - (i - num_plyrs/2 + 0.5)*w, | |
stats.loc[i].values[1:], | |
width=w, | |
color=colors(i), | |
align='center', | |
edgecolor = 'black', |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): truncated gist-preview fragment — the body of the
# `with torch.no_grad():` block is missing; code left byte-identical.
# Builds the test-split dataset/loader and puts the downstream model in
# eval mode before a no-grad evaluation pass.
tdg = DSDataGen('test', testimages, testlabels, num_classes=10) | |
tdl = DataLoader(tdg, batch_size = 32, drop_last = True) | |
# Switch off dropout/batch-norm updates for evaluation.
dsmodel.eval() | |
# Per-batch loss/accuracy accumulators for this evaluation pass.
loss_sublist = np.array([]) | |
acc_sublist = np.array([]) | |
with torch.no_grad(): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): truncated gist-preview fragment — the rest of the epoch body
# is missing; code left byte-identical. Preamble of a 20-epoch training loop:
# timestamps the epoch, resets the per-epoch accumulators, and puts the
# downstream model into training mode.
for epoch in range(20): | |
stime = time.time() | |
print("=============== Epoch : %3d ==============="%(epoch+1)) | |
# Fresh per-batch loss/accuracy accumulators each epoch.
loss_sublist = np.array([]) | |
acc_sublist = np.array([]) | |
#iter_num = 0 | |
dsmodel.train() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Bookkeeping for the downstream-classifier training run: per-epoch loss and
# accuracy histories for the train and validation splits.
tr_ep_loss = []
tr_ep_acc = []
val_ep_loss = []
val_ep_acc = []

# Best validation loss seen so far; start high so the first epoch improves it.
min_val_loss = 100.0

# Run hyper-parameters.
EPOCHS = 10   # number of fine-tuning epochs
num_cl = 10   # number of target classes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): truncated gist-preview fragment — __init__ (and the rest of
# the Dataset subclass) continues past this preview; code left byte-identical.
# Dataset wrapper holding an image array plus labels; builds a
# RandomResizedCrop(32, scale=(0.8, 1.0)) transform, presumably used for
# train-time augmentation — TODO confirm against the missing __getitem__.
class DSDataGen(Dataset): | |
def __init__(self, phase, imgarr,labels,num_classes): | |
self.phase = phase | |
self.num_classes = num_classes | |
self.imgarr = imgarr | |
self.labels = labels | |
self.randomcrop = transforms.RandomResizedCrop(32,(0.8,1.0)) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): truncated gist-preview fragment — the class body continues
# past this preview (no head/forward visible); code left byte-identical.
# Downstream model wrapping a pre-trained encoder: every parameter of
# `premodel` is frozen (requires_grad = False), so only layers added after
# this point would be trained.
class DSModel(nn.Module): | |
def __init__(self,premodel,num_classes): | |
super().__init__() | |
self.premodel = premodel | |
self.num_classes = num_classes | |
for p in self.premodel.parameters(): | |
p.requires_grad = False | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): truncated gist-preview fragment — the epoch body is missing;
# code left byte-identical. Preamble of a 100-epoch (pre-)training loop.
# NOTE(review): `range(100)` duplicates the `epochs` constant — if the two
# ever diverge the printed "[epoch/epochs]" header will lie; prefer
# `range(epochs)` once the full loop is visible.
nr = 0 | |
current_epoch = 0 | |
epochs = 100 | |
tr_loss = [] | |
val_loss = [] | |
for epoch in range(100): | |
print(f"Epoch [{epoch}/{epochs}]\t") | |
stime = time.time() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def save_model(model, optimizer, scheduler, current_epoch, name):
    """Checkpoint model, optimizer and scheduler state to /content/saved_models/.

    Args:
        model / optimizer / scheduler: objects exposing ``state_dict()``.
        current_epoch: epoch number substituted into ``name``.
        name: filename format string, e.g. ``'ckpt_{}.pt'`` ->
            ``ckpt_<current_epoch>.pt`` via ``str.format``.
    """
    out_dir = '/content/saved_models/'
    # Fix: torch.save raises FileNotFoundError if the directory does not
    # exist (fresh Colab runtime) — create it idempotently first.
    os.makedirs(out_dir, exist_ok=True)
    out = os.path.join(out_dir, name.format(current_epoch))
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict()}, out)
# NOTE(review): truncated gist-preview fragment — the function body continues
# past this preview; code left byte-identical. Initializes empty (0, 1)
# accumulators, presumably filled with per-batch predictions and ground-truth
# labels further down — TODO confirm against the missing body.
def plot_features(model, num_classes, num_feats, batch_size): | |
preds = np.array([]).reshape((0,1)) | |
gt = np.array([]).reshape((0,1)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Optimizer: LARS over only the trainable parameters — the frozen,
# pre-trained encoder weights (requires_grad == False) are excluded.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = LARS(
    trainable_params,
    lr=0.2,
    weight_decay=1e-6,
    exclude_from_weight_decay=["batch_normalization", "bias"],
)

# "decay the learning rate with the cosine decay schedule without restarts"
# Scheduler / linear warmup:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.optim.optimizer import Optimizer, required | |
import re | |
# Default "trust" (eeta) coefficient for the LARS optimizer defined below.
EETA_DEFAULT = 0.001 | |
# NOTE(review): truncated gist-preview fragment — the class body and even its
# docstring are cut mid-sentence; code left byte-identical.
class LARS(Optimizer): | |
""" | |
Layer-wise Adaptive Rate Scaling for large batch training. | |
Introduced by "Large Batch Training of Convolutional Networks" by Y. You, |