Last active
July 17, 2022 22:45
-
-
Save yashsavani/bf243f3d7b80d4951c92b805f34e1040 to your computer and use it in GitHub Desktop.
Train ResNet on CIFAR10 in a single file using PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Train ResNet on CIFAR10 in a single file using PyTorch.""" | |
import argparse | |
import json | |
import os | |
import pandas as pd | |
import time | |
import torch | |
import torch.optim as optim | |
import torch.nn as nn | |
from torch.utils.data import DataLoader, random_split | |
import torchvision | |
from torchvision.transforms import Compose, ToTensor, Normalize | |
# Prefer the first CUDA device when one is available; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"{device = }")

# Command-line interface of the training script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--num_epochs",
    type=int,
    default=11,
    help="number of train/val epochs to run",
)
parser.add_argument(
    "--lr",
    type=float,
    default=1e-4,
    help="learning rate handed to the Adam optimizer",
)
# --loc is the prefix of every output file (weights, per-epoch stats, test
# stats). It has no sensible default: when it was left unset the outputs were
# written to files literally named "None_wts.pt", "None_stats.csv", ..., so it
# is mandatory now and argparse reports a clear usage error instead.
parser.add_argument(
    "--loc",
    type=str,
    required=True,
    help="output path prefix; '_wts.pt', '_stats.csv' and '_test_stats.json' are appended",
)
args = parser.parse_args()
# Load Data
# Fraction of the official CIFAR-10 training set used for training; the
# remainder becomes the validation split.
train_split = 0.8
classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)
# NOTE(review): these look like the commonly used ImageNet channel statistics
# rather than values computed on CIFAR-10 -- confirm this is intended.
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
fulltrainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=Compose([ToTensor(), normalize])
)
train_split_len = int(train_split * len(fulltrainset))
val_split_len = len(fulltrainset) - train_split_len
# Seed the split so that every run draws the same partition. The script can
# resume from a saved checkpoint; with an unseeded split a resumed run would
# train on images that the previous run held out for validation.
trainset, valset = random_split(
    fulltrainset,
    [train_split_len, val_split_len],
    generator=torch.Generator().manual_seed(0),
)
testset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=Compose([ToTensor(), normalize]),
)
# Only the training loader needs shuffling; the evaluation loaders are read
# in a fixed order.
trainloader = DataLoader(
    trainset, batch_size=32, shuffle=True, num_workers=2, pin_memory=True
)
valloader = DataLoader(
    valset, batch_size=32, shuffle=False, num_workers=2, pin_memory=True
)
testloader = DataLoader(
    testset, batch_size=32, shuffle=False, num_workers=2, pin_memory=True
)
# Load Model
# Build a ResNet-34 (no pretrained weights are requested here) and swap its
# final fully connected layer for a 10-way classifier.
model = torchvision.models.resnet34()
in_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=in_ftrs, out_features=10)
# nn.Module.to moves the parameters in place and returns the module itself.
model = model.to(device=device)
# Train Model
save_loc = f"{args.loc}_wts.pt"
stats_loc = f"{args.loc}_stats.csv"
# Resume from an existing checkpoint when one is present. map_location remaps
# the stored tensors onto the current device, so weights saved on a GPU can
# still be loaded on a CPU-only machine.
if os.path.exists(save_loc):
    model.load_state_dict(torch.load(save_loc, map_location=device))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# One entry per (epoch, phase) pair; written out as a CSV table below.
stats = {
    "epoch_num": [],
    "epoch_phase": [],
    "epoch_loss": [],
    "epoch_acc": [],
    "epoch_time": [],
}
for epoch in range(args.num_epochs):
    for phase in ["train", "val"]:
        is_train = phase == "train"
        # train(False) is the same as eval().
        model.train(is_train)
        start = time.time()
        losses = 0.0
        corrects = 0
        data_loader = trainloader if is_train else valloader
        for inps, labels in data_loader:
            inps = inps.to(device)
            labels = labels.to(device)
            # Track gradients only while training; validation runs the
            # forward pass and the loss without building a graph.
            with torch.set_grad_enabled(is_train):
                logits = model(inps)
                loss = criterion(logits, labels)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            _, preds = torch.max(logits, 1)
            # criterion returns the batch mean, so weight it by the batch
            # size to get a correct average over the whole split.
            losses += loss.item() * inps.size(0)
            # Accumulate a plain int so the totals stay valid even if the
            # loader yields no batches.
            corrects += torch.sum(preds == labels).item()
        data_len = len(trainset) if is_train else len(valset)
        epoch_loss = losses / data_len
        epoch_acc = corrects / data_len
        epoch_time = time.time() - start
        print(
            f"{epoch}/{args.num_epochs} {phase}: {epoch_loss=:.4f}, {epoch_acc=:.4f}, {epoch_time=:.4f}"
        )
        stats["epoch_num"].append(epoch)
        stats["epoch_phase"].append(phase)
        stats["epoch_loss"].append(epoch_loss)
        stats["epoch_acc"].append(epoch_acc)
        stats["epoch_time"].append(epoch_time)
    # Checkpoint every fifth epoch and always after the last one, so the
    # final weights and statistics are kept for any --num_epochs value.
    # NOTE(review): assumes this save block sits at epoch-loop level; the
    # indentation was lost in the scraped source.
    is_last_epoch = epoch == args.num_epochs - 1
    if epoch % 5 == 0 or is_last_epoch:
        torch.save(model.state_dict(), save_loc)
        print(f"Saved model parameters to {save_loc}")
        pd.DataFrame(stats).to_csv(stats_loc, index=False)
        print(f"Saved training statistics to {stats_loc}")
# Evaluate Model
# Score the in-memory model on the held-out test set and store the summary
# (mean loss, accuracy, wall-clock time) as JSON next to the other outputs.
model.eval()
losses, corrects = 0.0, 0
start = time.time()
test_stats_loc = f"{args.loc}_test_stats.json"
for inps, labels in testloader:
    inps, labels = inps.to(device), labels.to(device)
    with torch.no_grad():
        logits = model(inps)
    preds = torch.max(logits, 1)[1]
    batch_loss = criterion(logits, labels)
    # The criterion yields a batch mean; scale by batch size before summing.
    losses += batch_loss.item() * inps.size(0)
    corrects += (preds == labels.data).sum()
num_test = len(testset)
test_loss = losses / num_test
test_acc = corrects.double() / num_test
test_time = time.time() - start
print(f"Test Stats: {test_loss=:.4f}, {test_acc=:.4f}, {test_time=:.4f}")
test_stats = {
    "test_loss": test_loss,
    "test_acc": test_acc.item(),
    "test_time": test_time,
}
with open(test_stats_loc, "w") as fh:
    json.dump(test_stats, fh)
print(f"Saved test statistics to {test_stats_loc}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Script to run parallel slurm jobs""" | |
import argparse | |
import submitit | |
from pathlib import Path | |
# Command-line interface of the launcher.
parser = argparse.ArgumentParser(description="ResNet CIFAR-10")
parser.add_argument("--num_epochs", type=int, default=11)
parser.add_argument("--job_name", type=str, default="job_%j")
parser.add_argument("--pathname", type=str, default="results")
args = parser.parse_args()

# Make sure the directory that the training runs write into exists.
results_dir = Path(args.pathname)
results_dir.mkdir(parents=True, exist_ok=True)

# --job_name is used as the executor's folder; the Slurm job name itself is
# the fixed "resnet_cifar10" set in the parameters below.
executor = submitit.SlurmExecutor(folder=args.job_name)
slurm_options = {
    "time": 4000,
    "cpus_per_task": 2,
    "gres": "gpu:1",
    "job_name": "resnet_cifar10",
    "ntasks_per_node": 1,
    "exclude": "locus-1-[21,25,29]",
    "mem": "12G",
    "array_parallelism": 4,
}
executor.update_parameters(**slurm_options)

# Learning-rate sweep: one training run per (output name, learning rate) pair.
lr_by_fname = {
    "tmp2": 1e-2,
    "tmp3": 1e-3,
    "tmp4": 1e-4,
    "tmp5": 1e-5,
    "tmp6": 1e-6,
}

# All submissions are issued inside the executor's batch context.
with executor.batch():
    for loc, lr in lr_by_fname.items():
        command = ["python", "-u", "resnet_cifar10.py"]
        command += ["--num_epochs", str(args.num_epochs)]
        command += ["--lr", str(lr)]
        command += ["--loc", f"{args.pathname}/{loc}"]
        function = submitit.helpers.CommandFunction(command)
        print(f"Submitting job: {' '.join(function.command)}")
        job = executor.submit(function)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
epoch_num | epoch_phase | epoch_loss | epoch_acc | epoch_time | |
---|---|---|---|---|---|
0 | train | 1.6960874133110047 | 0.38380000000000003 | 18.53317618370056 | |
0 | val | 1.4451372964859008 | 0.47600000000000003 | 1.47853422164917 | |
1 | train | 1.3361847997188567 | 0.519475 | 17.766621828079224 | |
1 | val | 1.2464430358886718 | 0.5504 | 1.4541492462158203 | |
2 | train | 1.143345989561081 | 0.59435 | 17.87923002243042 | |
2 | val | 1.1588681091308595 | 0.5906 | 1.3571674823760986 | |
3 | train | 0.9828106349229813 | 0.6540250000000001 | 17.691423177719116 | |
3 | val | 1.0863204523086547 | 0.6179 | 1.4985730648040771 | |
4 | train | 0.8367472008705139 | 0.707275 | 18.27773690223694 | |
4 | val | 1.083212512397766 | 0.6246 | 1.3585000038146973 | |
5 | train | 0.6975694972991944 | 0.7554500000000001 | 18.061813354492188 | |
5 | val | 1.084663610458374 | 0.6403 | 1.4291815757751465 | |
6 | train | 0.5639139029979706 | 0.805725 | 17.99892497062683 | |
6 | val | 1.1505574979782105 | 0.6363 | 1.3442208766937256 | |
7 | train | 0.44514585880041124 | 0.8430500000000001 | 17.89025568962097 | |
7 | val | 1.1511975234985352 | 0.6521 | 1.3732786178588867 | |
8 | train | 0.3563057611465454 | 0.87505 | 17.896596670150757 | |
8 | val | 1.2635577613830566 | 0.6413 | 1.364490270614624 | |
9 | train | 0.2955851906552911 | 0.8973000000000001 | 18.068241834640503 | |
9 | val | 1.3302849162101746 | 0.6376000000000001 | 1.352651596069336 | |
10 | train | 0.24327969312220812 | 0.9142750000000001 | 18.12368679046631 | |
10 | val | 1.3919539432525634 | 0.6415000000000001 | 1.468806266784668 | |
11 | train | 0.2088443217486143 | 0.929025 | 17.909857034683228 | |
11 | val | 1.3832221105575562 | 0.641 | 1.4864187240600586 | |
12 | train | 0.18264484179019927 | 0.936675 | 17.711808443069458 | |
12 | val | 1.4087063361167909 | 0.6495000000000001 | 1.3579418659210205 | |
13 | train | 0.17399553629234432 | 0.9413 | 18.009772062301636 | |
13 | val | 1.4485021294593812 | 0.6552 | 1.4399867057800293 | |
14 | train | 0.14986198884062468 | 0.94745 | 17.799652099609375 | |
14 | val | 1.5679083934783935 | 0.6445000000000001 | 1.3729536533355713 | |
15 | train | 0.14712372142113744 | 0.9488500000000001 | 17.783864498138428 | |
15 | val | 1.5645678364753723 | 0.6424000000000001 | 1.4078621864318848 | |
16 | train | 0.12653812631219624 | 0.9565250000000001 | 17.860257148742676 | |
16 | val | 1.595202984046936 | 0.647 | 1.4989902973175049 | |
17 | train | 0.12441516583785414 | 0.9570000000000001 | 17.95743441581726 | |
17 | val | 1.5260233039855957 | 0.6572 | 1.3984713554382324 | |
18 | train | 0.11859542702510953 | 0.959025 | 18.011947870254517 | |
18 | val | 1.61346548538208 | 0.6492 | 1.3637778759002686 | |
19 | train | 0.10850345106683672 | 0.9620500000000001 | 18.012967109680176 | |
19 | val | 1.558018336057663 | 0.6579 | 1.3531913757324219 | |
20 | train | 0.09907145978957414 | 0.965825 | 17.901024103164673 | |
20 | val | 1.6670068212032318 | 0.653 | 1.36391282081604 | |
21 | train | 0.10149896029718221 | 0.9647 | 17.875645637512207 | |
21 | val | 1.6091691967010497 | 0.6574 | 1.3599228858947754 | |
22 | train | 0.101146349629201 | 0.96555 | 17.86266827583313 | |
22 | val | 1.6013954856872559 | 0.6645 | 1.3542509078979492 | |
23 | train | 0.08616407853569835 | 0.970825 | 17.84755301475525 | |
23 | val | 1.6136671280384063 | 0.6622 | 1.3616793155670166 | |
24 | train | 0.08264649133831263 | 0.97145 | 18.134830474853516 | |
24 | val | 1.6646786427497864 | 0.6656000000000001 | 1.3589999675750732 | |
25 | train | 0.09348004345707596 | 0.9683750000000001 | 17.80931305885315 | |
25 | val | 1.5970482931137084 | 0.6564 | 1.3304085731506348 | |
26 | train | 0.08165445998329669 | 0.973025 | 17.87434983253479 | |
26 | val | 1.6293686190605163 | 0.6691 | 1.3875038623809814 | |
27 | train | 0.07648255949728192 | 0.97385 | 17.848639488220215 | |
27 | val | 1.6804926796913147 | 0.6632 | 1.3649330139160156 | |
28 | train | 0.07254848490022123 | 0.9748 | 17.872262954711914 | |
28 | val | 1.7689826406478881 | 0.663 | 1.356893539428711 | |
29 | train | 0.07388110288595781 | 0.9742500000000001 | 17.95309567451477 | |
29 | val | 1.7121417123794556 | 0.6645 | 1.4937670230865479 | |
30 | train | 0.07134598800907843 | 0.9755750000000001 | 17.860339164733887 | |
30 | val | 1.7459123119354247 | 0.6582 | 1.362112045288086 | |
31 | train | 0.06858938560741953 | 0.97645 | 17.98039150238037 | |
31 | val | 1.7511834965705873 | 0.6667000000000001 | 1.366523027420044 | |
32 | train | 0.06496722006741912 | 0.977175 | 18.98650360107422 | |
32 | val | 1.726856561088562 | 0.662 | 1.3619060516357422 | |
33 | train | 0.06336505509754643 | 0.9789 | 17.844940900802612 | |
33 | val | 1.8165611315727235 | 0.6567000000000001 | 1.4037866592407227 | |
34 | train | 0.06330432953285053 | 0.97875 | 18.013700246810913 | |
34 | val | 1.701840454006195 | 0.6666000000000001 | 1.3784518241882324 | |
35 | train | 0.05898183861435391 | 0.9796 | 17.929861783981323 | |
35 | val | 1.7428623735427857 | 0.6607000000000001 | 1.3617007732391357 | |
36 | train | 0.05972674607567024 | 0.9786 | 17.958882093429565 | |
36 | val | 1.734776219367981 | 0.6628000000000001 | 1.3758878707885742 | |
37 | train | 0.05497746549076401 | 0.9815250000000001 | 18.005401611328125 | |
37 | val | 1.7338352456092834 | 0.6675 | 1.3590779304504395 | |
38 | train | 0.057928202350577336 | 0.9800000000000001 | 17.92861580848694 | |
38 | val | 1.7371590656280518 | 0.6723 | 1.4081010818481445 | |
39 | train | 0.055695211252570154 | 0.98035 | 17.94972324371338 | |
39 | val | 1.7491377750396728 | 0.6721 | 1.395190954208374 | |
40 | train | 0.05344161576689221 | 0.9825 | 18.079487323760986 | |
40 | val | 1.7496719292223453 | 0.6734 | 1.371354341506958 | |
41 | train | 0.05053124641787726 | 0.983025 | 17.683712482452393 | |
41 | val | 1.738305499649048 | 0.6738000000000001 | 1.4246728420257568 | |
42 | train | 0.04965392700391821 | 0.9833000000000001 | 18.31919765472412 | |
42 | val | 1.8942454456329345 | 0.6601 | 1.3650131225585938 | |
43 | train | 0.05288469399437308 | 0.9816750000000001 | 17.857980728149414 | |
43 | val | 1.785586702156067 | 0.6691 | 1.3628778457641602 | |
44 | train | 0.04939625953314826 | 0.9830500000000001 | 18.075785636901855 | |
44 | val | 1.74666238155365 | 0.6852 | 1.3637464046478271 | |
45 | train | 0.048314340169890786 | 0.9831500000000001 | 18.01528525352478 | |
45 | val | 1.767422658443451 | 0.6733 | 1.4817347526550293 | |
46 | train | 0.04522587890475988 | 0.984925 | 18.123121976852417 | |
46 | val | 1.8413782642364502 | 0.667 | 1.371633529663086 | |
47 | train | 0.0468662763066357 | 0.9840500000000001 | 18.42077875137329 | |
47 | val | 1.8086338933944701 | 0.6736000000000001 | 1.4757544994354248 | |
48 | train | 0.04183294942476787 | 0.9854750000000001 | 18.07581353187561 | |
48 | val | 1.8136471565246581 | 0.6738000000000001 | 1.3467607498168945 | |
49 | train | 0.04488923963187262 | 0.9846 | 17.93657636642456 | |
49 | val | 1.7563840320587158 | 0.6757000000000001 | 1.3584692478179932 | |
50 | train | 0.04078557941949693 | 0.98665 | 18.037089586257935 | |
50 | val | 1.8805511571884155 | 0.6698000000000001 | 1.370103359222412 | |
51 | train | 0.0449939735007938 | 0.9844750000000001 | 17.983230590820312 | |
51 | val | 1.7697006483078004 | 0.6837000000000001 | 1.5036170482635498 | |
52 | train | 0.03993258131076582 | 0.9868750000000001 | 17.734033584594727 | |
52 | val | 1.722931281900406 | 0.6847000000000001 | 1.401460886001587 | |
53 | train | 0.04125696243355051 | 0.9857 | 17.950467824935913 | |
53 | val | 1.8739290427207946 | 0.6805 | 1.350996971130371 | |
54 | train | 0.044291711033822505 | 0.985125 | 17.83350110054016 | |
54 | val | 1.884322718524933 | 0.6692 | 1.3733716011047363 | |
55 | train | 0.0368670628109132 | 0.9877750000000001 | 17.655507564544678 | |
55 | val | 1.8255374110221863 | 0.6716000000000001 | 1.4257173538208008 | |
56 | train | 0.03812791590410052 | 0.9875 | 17.964482307434082 | |
56 | val | 1.8464583566665649 | 0.6758000000000001 | 1.3690247535705566 | |
57 | train | 0.03927929636144545 | 0.986975 | 17.887848615646362 | |
57 | val | 1.7753132797241211 | 0.6763 | 1.3629875183105469 | |
58 | train | 0.039536747192405165 | 0.987275 | 17.97967004776001 | |
58 | val | 1.8691625988960265 | 0.6744 | 1.360184669494629 | |
59 | train | 0.0362498815687024 | 0.987425 | 17.78330373764038 | |
59 | val | 1.8468463715553283 | 0.6687000000000001 | 1.3876428604125977 | |
60 | train | 0.036601675694051664 | 0.98755 | 17.874656438827515 | |
60 | val | 1.9155296007156373 | 0.6727000000000001 | 1.3647675514221191 | |
61 | train | 0.0382174508746597 | 0.9866750000000001 | 17.840538501739502 | |
61 | val | 1.8789127073287963 | 0.6766 | 1.4696545600891113 | |
62 | train | 0.03441867789547541 | 0.988275 | 18.26482391357422 | |
62 | val | 1.8357850158691407 | 0.6809000000000001 | 1.4053137302398682 | |
63 | train | 0.03626724732222501 | 0.9877 | 17.874390363693237 | |
63 | val | 1.7776766492843628 | 0.6856 | 1.373171091079712 | |
64 | train | 0.03346482290996937 | 0.98895 | 17.747379541397095 | |
64 | val | 1.8953405448913574 | 0.6755 | 1.364102840423584 | |
65 | train | 0.035797213427070526 | 0.9878 | 17.966055870056152 | |
65 | val | 1.833572540283203 | 0.672 | 1.3602638244628906 | |
66 | train | 0.03064623904817272 | 0.98945 | 18.041853666305542 | |
66 | val | 1.8723105999946594 | 0.6819000000000001 | 1.474238634109497 | |
67 | train | 0.036478065293820694 | 0.9873000000000001 | 17.91158652305603 | |
67 | val | 1.774871664428711 | 0.6803 | 1.4380147457122803 | |
68 | train | 0.030640873152774292 | 0.9900500000000001 | 17.92766284942627 | |
68 | val | 1.9281403480529786 | 0.671 | 1.4862265586853027 | |
69 | train | 0.030735508631228002 | 0.9897750000000001 | 18.08848547935486 | |
69 | val | 1.9348541726112365 | 0.6696000000000001 | 1.358482837677002 | |
70 | train | 0.0359612195932772 | 0.98865 | 17.99178695678711 | |
70 | val | 1.8559286540985107 | 0.6821 | 1.3594703674316406 | |
71 | train | 0.032673440153233244 | 0.9887 | 17.912135124206543 | |
71 | val | 1.9335292773246766 | 0.6739 | 1.4792242050170898 | |
72 | train | 0.03238569093749102 | 0.9895250000000001 | 17.859195470809937 | |
72 | val | 1.8416625398635864 | 0.6813 | 1.4943647384643555 | |
73 | train | 0.026340200062800433 | 0.9911000000000001 | 18.02529287338257 | |
73 | val | 1.9810066781997682 | 0.6746 | 1.3622326850891113 | |
74 | train | 0.032096733767678964 | 0.9889 | 18.489730834960938 | |
74 | val | 1.9609797005653382 | 0.6809000000000001 | 1.4511082172393799 | |
75 | train | 0.027843941722798627 | 0.99095 | 17.916367292404175 | |
75 | val | 1.9956219348907471 | 0.6745 | 1.3590576648712158 | |
76 | train | 0.035282389413286 | 0.988325 | 18.173656940460205 | |
76 | val | 1.8228674980163575 | 0.6805 | 1.365363597869873 | |
77 | train | 0.0272451836361608 | 0.9904000000000001 | 17.84825849533081 | |
77 | val | 1.8854433339118957 | 0.6789000000000001 | 1.3784475326538086 | |
78 | train | 0.02900064716959605 | 0.98985 | 17.863919734954834 | |
78 | val | 1.9110873949050904 | 0.6863 | 1.3668856620788574 | |
79 | train | 0.02595402797064744 | 0.991175 | 17.846888542175293 | |
79 | val | 1.8648885075569153 | 0.6835 | 1.3857927322387695 | |
80 | train | 0.02733516783758532 | 0.9907 | 17.959612131118774 | |
80 | val | 1.9358659605026245 | 0.675 | 1.5023152828216553 | |
81 | train | 0.029890539337060183 | 0.990125 | 17.813404083251953 | |
81 | val | 1.8661690349578857 | 0.6795 | 1.394080638885498 | |
82 | train | 0.026497844307951164 | 0.991075 | 18.057520627975464 | |
82 | val | 1.998090232849121 | 0.6817000000000001 | 1.3574254512786865 | |
83 | train | 0.030876639142358907 | 0.9898 | 17.807169914245605 | |
83 | val | 1.836445219230652 | 0.6845 | 1.361199140548706 | |
84 | train | 0.022354049466736615 | 0.9928250000000001 | 17.959235429763794 | |
84 | val | 1.919796456718445 | 0.6845 | 1.4114432334899902 | |
85 | train | 0.030048751522362 | 0.9904000000000001 | 17.936451196670532 | |
85 | val | 1.93991791973114 | 0.6812 | 1.438070297241211 | |
86 | train | 0.022835028435260755 | 0.99195 | 18.048775672912598 | |
86 | val | 2.0077165143013 | 0.6727000000000001 | 1.36210036277771 | |
87 | train | 0.027225513198939733 | 0.991025 | 18.1739342212677 | |
87 | val | 1.8998400712966919 | 0.682 | 1.3633615970611572 | |
88 | train | 0.027063793912390246 | 0.9906250000000001 | 18.00027298927307 | |
88 | val | 1.8794661586761474 | 0.6912 | 1.3572006225585938 | |
89 | train | 0.025147190348108417 | 0.9916250000000001 | 17.79403519630432 | |
89 | val | 1.9119867372512818 | 0.6799000000000001 | 1.3673436641693115 | |
90 | train | 0.025222583868273068 | 0.9915 | 18.32296848297119 | |
90 | val | 1.8864133625030517 | 0.6919000000000001 | 1.3622722625732422 | |
91 | train | 0.025201019231416284 | 0.9916250000000001 | 17.968748331069946 | |
91 | val | 1.9845755340576172 | 0.6804 | 1.3651256561279297 | |
92 | train | 0.024359502122670528 | 0.9915250000000001 | 17.84384036064148 | |
92 | val | 1.995139317703247 | 0.6834 | 1.3745269775390625 | |
93 | train | 0.02280836629884725 | 0.99275 | 17.975011110305786 | |
93 | val | 2.0460741375923157 | 0.6803 | 1.3680858612060547 | |
94 | train | 0.027009409671777392 | 0.99065 | 18.093051195144653 | |
94 | val | 1.9071997953414916 | 0.6868000000000001 | 1.3569996356964111 | |
95 | train | 0.02445775417679106 | 0.9919250000000001 | 17.94615912437439 | |
95 | val | 1.9767365124702454 | 0.676 | 1.3571388721466064 | |
96 | train | 0.02153194406411494 | 0.993125 | 17.908305406570435 | |
96 | val | 1.9767985507965087 | 0.6836 | 1.4256927967071533 | |
97 | train | 0.024142693555586448 | 0.9917 | 17.901453495025635 | |
97 | val | 2.030321930885315 | 0.6802 | 1.3615853786468506 | |
98 | train | 0.020131318463997742 | 0.993075 | 18.418582439422607 | |
98 | val | 1.8558624132156372 | 0.6895 | 1.37446928024292 | |
99 | train | 0.024965851304757234 | 0.9918 | 17.981637477874756 | |
99 | val | 1.8886236921787263 | 0.6794 | 1.3605659008026123 | |
100 | train | 0.019026363392782512 | 0.994325 | 18.326786518096924 | |
100 | val | 1.984264079284668 | 0.6893 | 1.377504587173462 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{"test_loss": 2.019672994709015, "test_acc": 0.6923, "test_time": 1.195981502532959} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment