Train ResNet on CIFAR10 in a single file using PyTorch
"""Train ResNet on CIFAR10 in a single file using PyTorch."""
import argparse
import json
import os
import pandas as pd
import time
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision.transforms import Compose, ToTensor, Normalize
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"{device = }")
parser = argparse.ArgumentParser()
parser.add_argument("--num_epochs", type=int, default=11)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--loc", type=str)
args = parser.parse_args()
# Load Data
train_split = 0.8
classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)
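# NOTE: these are the standard ImageNet channel statistics; CIFAR-10-specific values
# (roughly mean (0.49, 0.48, 0.45) and std (0.25, 0.24, 0.26)) are a common alternative.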
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
fulltrainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=Compose([ToTensor(), normalize])
)
train_split_len = int(train_split * len(fulltrainset))
val_split_len = len(fulltrainset) - train_split_len
trainset, valset = random_split(fulltrainset, [train_split_len, val_split_len])
testset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=Compose([ToTensor(), normalize]),
)
trainloader = DataLoader(
    trainset, batch_size=32, shuffle=True, num_workers=2, pin_memory=True
)
valloader = DataLoader(
    valset, batch_size=32, shuffle=True, num_workers=2, pin_memory=True
)
testloader = DataLoader(
    testset, batch_size=32, shuffle=True, num_workers=2, pin_memory=True
)
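# shuffle=True on the val/test loaders is not required for correctness; the full split is
# evaluated each time, so shuffle=False would produce identical metrics.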
# Load Model
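# resnet34() is created with randomly initialized weights (no pretrained weights are
# requested), and its 1000-way ImageNet head is swapped for a 10-way CIFAR-10 classifier.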
model = torchvision.models.resnet34()
in_ftrs = model.fc.in_features
model.fc = nn.Linear(in_ftrs, 10)
model = model.to(device)
# Train Model
save_loc = f"{args.loc}_wts.pt"
stats_loc = f"{args.loc}_stats.csv"
if os.path.exists(save_loc):
    model.load_state_dict(torch.load(save_loc))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
stats = {
    "epoch_num": [],
    "epoch_phase": [],
    "epoch_loss": [],
    "epoch_acc": [],
    "epoch_time": [],
}
for epoch in range(args.num_epochs):
    for phase in ["train", "val"]:
        if phase == "train":
            model.train()
        else:
            model.eval()
        start = time.time()
        losses = 0.0
        corrects = 0
        data_loader = trainloader if phase == "train" else valloader
        for inps, labels in data_loader:
            inps = inps.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            if phase == "train":
                logits = model(inps)
            else:
                with torch.no_grad():
                    logits = model(inps)
            _, preds = torch.max(logits, 1)
            loss = criterion(logits, labels)
            if phase == "train":
                loss.backward()
                optimizer.step()
            losses += loss.item() * inps.size(0)
            corrects += torch.sum(preds == labels.data)
        data_len = len(trainset) if phase == "train" else len(valset)
        epoch_loss = losses / data_len
        epoch_acc = corrects.double() / data_len
        epoch_time = time.time() - start
        print(
            f"{epoch}/{args.num_epochs} {phase}: {epoch_loss=:.4f}, {epoch_acc=:.4f}, {epoch_time=:.4f}"
        )
        stats["epoch_num"].append(epoch)
        stats["epoch_phase"].append(phase)
        stats["epoch_loss"].append(epoch_loss)
        stats["epoch_acc"].append(epoch_acc.item())
        stats["epoch_time"].append(epoch_time)
    if not epoch % 5:
        torch.save(model.state_dict(), save_loc)
        print(f"Saved model parameters to {save_loc}")
        pd.DataFrame(stats).to_csv(stats_loc, index=False)
        print(f"Saved training statistics to {stats_loc}")
# Evaluate Model
model.eval()
losses = 0.0
corrects = 0
start = time.time()
test_stats_loc = f"{args.loc}_test_stats.json"
for inps, labels in testloader:
    inps = inps.to(device)
    labels = labels.to(device)
    with torch.no_grad():
        logits = model(inps)
    _, preds = torch.max(logits, 1)
    loss = criterion(logits, labels)
    losses += loss.item() * inps.size(0)
    corrects += torch.sum(preds == labels.data)
test_loss = losses / len(testset)
test_acc = corrects.double() / len(testset)
test_time = time.time() - start
print(f"Test Stats: {test_loss=:.4f}, {test_acc=:.4f}, {test_time=:.4f}")
with open(test_stats_loc, "w") as fh:
    json.dump(
        {"test_loss": test_loss, "test_acc": test_acc.item(), "test_time": test_time},
        fh,
    )
print(f"Saved test statistics to {test_stats_loc}")
"""Script to run parallel slurm jobs"""
import argparse
import submitit
from pathlib import Path
parser = argparse.ArgumentParser(description="ResNet CIFAR-10")
parser.add_argument("--num_epochs", type=int, default=11)
parser.add_argument("--job_name", type=str, default="job_%j")
parser.add_argument("--pathname", type=str, default="results")
args = parser.parse_args()
Path(args.pathname).mkdir(parents=True, exist_ok=True)
executor = submitit.SlurmExecutor(folder=args.job_name)
executor.update_parameters(
    time=4000,
    cpus_per_task=2,
    gres="gpu:1",
    job_name="resnet_cifar10",
    ntasks_per_node=1,
    exclude="locus-1-[21,25,29]",
    mem="12G",
    array_parallelism=4,
)
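# array_parallelism=4 caps how many of the five jobs submitted below run concurrently.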
with executor.batch():
    lrs = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6]
    fnames = ["tmp2", "tmp3", "tmp4", "tmp5", "tmp6"]
    for lr, loc in zip(lrs, fnames):
        function = submitit.helpers.CommandFunction(
            [
                "python",
                "-u",
                "resnet_cifar10.py",
                "--num_epochs",
                f"{args.num_epochs}",
                "--lr",
                f"{lr}",
                "--loc",
                f"{args.pathname}/{loc}",
            ]
        )
        print(f"Submitting job: {' '.join(function.command)}")
        job = executor.submit(function)
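Once the sweep finishes, a short sketch like the one below (not part of the original scripts; the paths assume the default --pathname of "results" and the tmp2-tmp6 prefixes above) can summarize each run's best validation accuracy from its stats CSV:
"""Sketch: summarize the learning-rate sweep from the per-run stats CSVs."""
import pandas as pd
lrs = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6]
fnames = ["tmp2", "tmp3", "tmp4", "tmp5", "tmp6"]
for lr, loc in zip(lrs, fnames):
    # columns: epoch_num, epoch_phase, epoch_loss, epoch_acc, epoch_time
    df = pd.read_csv(f"results/{loc}_stats.csv")
    best_val_acc = df[df["epoch_phase"] == "val"]["epoch_acc"].max()
    print(f"lr={lr}: best val acc = {best_val_acc:.4f}")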
epoch_num epoch_phase epoch_loss epoch_acc epoch_time
0 train 1.6960874133110047 0.38380000000000003 18.53317618370056
0 val 1.4451372964859008 0.47600000000000003 1.47853422164917
1 train 1.3361847997188567 0.519475 17.766621828079224
1 val 1.2464430358886718 0.5504 1.4541492462158203
2 train 1.143345989561081 0.59435 17.87923002243042
2 val 1.1588681091308595 0.5906 1.3571674823760986
3 train 0.9828106349229813 0.6540250000000001 17.691423177719116
3 val 1.0863204523086547 0.6179 1.4985730648040771
4 train 0.8367472008705139 0.707275 18.27773690223694
4 val 1.083212512397766 0.6246 1.3585000038146973
5 train 0.6975694972991944 0.7554500000000001 18.061813354492188
5 val 1.084663610458374 0.6403 1.4291815757751465
6 train 0.5639139029979706 0.805725 17.99892497062683
6 val 1.1505574979782105 0.6363 1.3442208766937256
7 train 0.44514585880041124 0.8430500000000001 17.89025568962097
7 val 1.1511975234985352 0.6521 1.3732786178588867
8 train 0.3563057611465454 0.87505 17.896596670150757
8 val 1.2635577613830566 0.6413 1.364490270614624
9 train 0.2955851906552911 0.8973000000000001 18.068241834640503
9 val 1.3302849162101746 0.6376000000000001 1.352651596069336
10 train 0.24327969312220812 0.9142750000000001 18.12368679046631
10 val 1.3919539432525634 0.6415000000000001 1.468806266784668
11 train 0.2088443217486143 0.929025 17.909857034683228
11 val 1.3832221105575562 0.641 1.4864187240600586
12 train 0.18264484179019927 0.936675 17.711808443069458
12 val 1.4087063361167909 0.6495000000000001 1.3579418659210205
13 train 0.17399553629234432 0.9413 18.009772062301636
13 val 1.4485021294593812 0.6552 1.4399867057800293
14 train 0.14986198884062468 0.94745 17.799652099609375
14 val 1.5679083934783935 0.6445000000000001 1.3729536533355713
15 train 0.14712372142113744 0.9488500000000001 17.783864498138428
15 val 1.5645678364753723 0.6424000000000001 1.4078621864318848
16 train 0.12653812631219624 0.9565250000000001 17.860257148742676
16 val 1.595202984046936 0.647 1.4989902973175049
17 train 0.12441516583785414 0.9570000000000001 17.95743441581726
17 val 1.5260233039855957 0.6572 1.3984713554382324
18 train 0.11859542702510953 0.959025 18.011947870254517
18 val 1.61346548538208 0.6492 1.3637778759002686
19 train 0.10850345106683672 0.9620500000000001 18.012967109680176
19 val 1.558018336057663 0.6579 1.3531913757324219
20 train 0.09907145978957414 0.965825 17.901024103164673
20 val 1.6670068212032318 0.653 1.36391282081604
21 train 0.10149896029718221 0.9647 17.875645637512207
21 val 1.6091691967010497 0.6574 1.3599228858947754
22 train 0.101146349629201 0.96555 17.86266827583313
22 val 1.6013954856872559 0.6645 1.3542509078979492
23 train 0.08616407853569835 0.970825 17.84755301475525
23 val 1.6136671280384063 0.6622 1.3616793155670166
24 train 0.08264649133831263 0.97145 18.134830474853516
24 val 1.6646786427497864 0.6656000000000001 1.3589999675750732
25 train 0.09348004345707596 0.9683750000000001 17.80931305885315
25 val 1.5970482931137084 0.6564 1.3304085731506348
26 train 0.08165445998329669 0.973025 17.87434983253479
26 val 1.6293686190605163 0.6691 1.3875038623809814
27 train 0.07648255949728192 0.97385 17.848639488220215
27 val 1.6804926796913147 0.6632 1.3649330139160156
28 train 0.07254848490022123 0.9748 17.872262954711914
28 val 1.7689826406478881 0.663 1.356893539428711
29 train 0.07388110288595781 0.9742500000000001 17.95309567451477
29 val 1.7121417123794556 0.6645 1.4937670230865479
30 train 0.07134598800907843 0.9755750000000001 17.860339164733887
30 val 1.7459123119354247 0.6582 1.362112045288086
31 train 0.06858938560741953 0.97645 17.98039150238037
31 val 1.7511834965705873 0.6667000000000001 1.366523027420044
32 train 0.06496722006741912 0.977175 18.98650360107422
32 val 1.726856561088562 0.662 1.3619060516357422
33 train 0.06336505509754643 0.9789 17.844940900802612
33 val 1.8165611315727235 0.6567000000000001 1.4037866592407227
34 train 0.06330432953285053 0.97875 18.013700246810913
34 val 1.701840454006195 0.6666000000000001 1.3784518241882324
35 train 0.05898183861435391 0.9796 17.929861783981323
35 val 1.7428623735427857 0.6607000000000001 1.3617007732391357
36 train 0.05972674607567024 0.9786 17.958882093429565
36 val 1.734776219367981 0.6628000000000001 1.3758878707885742
37 train 0.05497746549076401 0.9815250000000001 18.005401611328125
37 val 1.7338352456092834 0.6675 1.3590779304504395
38 train 0.057928202350577336 0.9800000000000001 17.92861580848694
38 val 1.7371590656280518 0.6723 1.4081010818481445
39 train 0.055695211252570154 0.98035 17.94972324371338
39 val 1.7491377750396728 0.6721 1.395190954208374
40 train 0.05344161576689221 0.9825 18.079487323760986
40 val 1.7496719292223453 0.6734 1.371354341506958
41 train 0.05053124641787726 0.983025 17.683712482452393
41 val 1.738305499649048 0.6738000000000001 1.4246728420257568
42 train 0.04965392700391821 0.9833000000000001 18.31919765472412
42 val 1.8942454456329345 0.6601 1.3650131225585938
43 train 0.05288469399437308 0.9816750000000001 17.857980728149414
43 val 1.785586702156067 0.6691 1.3628778457641602
44 train 0.04939625953314826 0.9830500000000001 18.075785636901855
44 val 1.74666238155365 0.6852 1.3637464046478271
45 train 0.048314340169890786 0.9831500000000001 18.01528525352478
45 val 1.767422658443451 0.6733 1.4817347526550293
46 train 0.04522587890475988 0.984925 18.123121976852417
46 val 1.8413782642364502 0.667 1.371633529663086
47 train 0.0468662763066357 0.9840500000000001 18.42077875137329
47 val 1.8086338933944701 0.6736000000000001 1.4757544994354248
48 train 0.04183294942476787 0.9854750000000001 18.07581353187561
48 val 1.8136471565246581 0.6738000000000001 1.3467607498168945
49 train 0.04488923963187262 0.9846 17.93657636642456
49 val 1.7563840320587158 0.6757000000000001 1.3584692478179932
50 train 0.04078557941949693 0.98665 18.037089586257935
50 val 1.8805511571884155 0.6698000000000001 1.370103359222412
51 train 0.0449939735007938 0.9844750000000001 17.983230590820312
51 val 1.7697006483078004 0.6837000000000001 1.5036170482635498
52 train 0.03993258131076582 0.9868750000000001 17.734033584594727
52 val 1.722931281900406 0.6847000000000001 1.401460886001587
53 train 0.04125696243355051 0.9857 17.950467824935913
53 val 1.8739290427207946 0.6805 1.350996971130371
54 train 0.044291711033822505 0.985125 17.83350110054016
54 val 1.884322718524933 0.6692 1.3733716011047363
55 train 0.0368670628109132 0.9877750000000001 17.655507564544678
55 val 1.8255374110221863 0.6716000000000001 1.4257173538208008
56 train 0.03812791590410052 0.9875 17.964482307434082
56 val 1.8464583566665649 0.6758000000000001 1.3690247535705566
57 train 0.03927929636144545 0.986975 17.887848615646362
57 val 1.7753132797241211 0.6763 1.3629875183105469
58 train 0.039536747192405165 0.987275 17.97967004776001
58 val 1.8691625988960265 0.6744 1.360184669494629
59 train 0.0362498815687024 0.987425 17.78330373764038
59 val 1.8468463715553283 0.6687000000000001 1.3876428604125977
60 train 0.036601675694051664 0.98755 17.874656438827515
60 val 1.9155296007156373 0.6727000000000001 1.3647675514221191
61 train 0.0382174508746597 0.9866750000000001 17.840538501739502
61 val 1.8789127073287963 0.6766 1.4696545600891113
62 train 0.03441867789547541 0.988275 18.26482391357422
62 val 1.8357850158691407 0.6809000000000001 1.4053137302398682
63 train 0.03626724732222501 0.9877 17.874390363693237
63 val 1.7776766492843628 0.6856 1.373171091079712
64 train 0.03346482290996937 0.98895 17.747379541397095
64 val 1.8953405448913574 0.6755 1.364102840423584
65 train 0.035797213427070526 0.9878 17.966055870056152
65 val 1.833572540283203 0.672 1.3602638244628906
66 train 0.03064623904817272 0.98945 18.041853666305542
66 val 1.8723105999946594 0.6819000000000001 1.474238634109497
67 train 0.036478065293820694 0.9873000000000001 17.91158652305603
67 val 1.774871664428711 0.6803 1.4380147457122803
68 train 0.030640873152774292 0.9900500000000001 17.92766284942627
68 val 1.9281403480529786 0.671 1.4862265586853027
69 train 0.030735508631228002 0.9897750000000001 18.08848547935486
69 val 1.9348541726112365 0.6696000000000001 1.358482837677002
70 train 0.0359612195932772 0.98865 17.99178695678711
70 val 1.8559286540985107 0.6821 1.3594703674316406
71 train 0.032673440153233244 0.9887 17.912135124206543
71 val 1.9335292773246766 0.6739 1.4792242050170898
72 train 0.03238569093749102 0.9895250000000001 17.859195470809937
72 val 1.8416625398635864 0.6813 1.4943647384643555
73 train 0.026340200062800433 0.9911000000000001 18.02529287338257
73 val 1.9810066781997682 0.6746 1.3622326850891113
74 train 0.032096733767678964 0.9889 18.489730834960938
74 val 1.9609797005653382 0.6809000000000001 1.4511082172393799
75 train 0.027843941722798627 0.99095 17.916367292404175
75 val 1.9956219348907471 0.6745 1.3590576648712158
76 train 0.035282389413286 0.988325 18.173656940460205
76 val 1.8228674980163575 0.6805 1.365363597869873
77 train 0.0272451836361608 0.9904000000000001 17.84825849533081
77 val 1.8854433339118957 0.6789000000000001 1.3784475326538086
78 train 0.02900064716959605 0.98985 17.863919734954834
78 val 1.9110873949050904 0.6863 1.3668856620788574
79 train 0.02595402797064744 0.991175 17.846888542175293
79 val 1.8648885075569153 0.6835 1.3857927322387695
80 train 0.02733516783758532 0.9907 17.959612131118774
80 val 1.9358659605026245 0.675 1.5023152828216553
81 train 0.029890539337060183 0.990125 17.813404083251953
81 val 1.8661690349578857 0.6795 1.394080638885498
82 train 0.026497844307951164 0.991075 18.057520627975464
82 val 1.998090232849121 0.6817000000000001 1.3574254512786865
83 train 0.030876639142358907 0.9898 17.807169914245605
83 val 1.836445219230652 0.6845 1.361199140548706
84 train 0.022354049466736615 0.9928250000000001 17.959235429763794
84 val 1.919796456718445 0.6845 1.4114432334899902
85 train 0.030048751522362 0.9904000000000001 17.936451196670532
85 val 1.93991791973114 0.6812 1.438070297241211
86 train 0.022835028435260755 0.99195 18.048775672912598
86 val 2.0077165143013 0.6727000000000001 1.36210036277771
87 train 0.027225513198939733 0.991025 18.1739342212677
87 val 1.8998400712966919 0.682 1.3633615970611572
88 train 0.027063793912390246 0.9906250000000001 18.00027298927307
88 val 1.8794661586761474 0.6912 1.3572006225585938
89 train 0.025147190348108417 0.9916250000000001 17.79403519630432
89 val 1.9119867372512818 0.6799000000000001 1.3673436641693115
90 train 0.025222583868273068 0.9915 18.32296848297119
90 val 1.8864133625030517 0.6919000000000001 1.3622722625732422
91 train 0.025201019231416284 0.9916250000000001 17.968748331069946
91 val 1.9845755340576172 0.6804 1.3651256561279297
92 train 0.024359502122670528 0.9915250000000001 17.84384036064148
92 val 1.995139317703247 0.6834 1.3745269775390625
93 train 0.02280836629884725 0.99275 17.975011110305786
93 val 2.0460741375923157 0.6803 1.3680858612060547
94 train 0.027009409671777392 0.99065 18.093051195144653
94 val 1.9071997953414916 0.6868000000000001 1.3569996356964111
95 train 0.02445775417679106 0.9919250000000001 17.94615912437439
95 val 1.9767365124702454 0.676 1.3571388721466064
96 train 0.02153194406411494 0.993125 17.908305406570435
96 val 1.9767985507965087 0.6836 1.4256927967071533
97 train 0.024142693555586448 0.9917 17.901453495025635
97 val 2.030321930885315 0.6802 1.3615853786468506
98 train 0.020131318463997742 0.993075 18.418582439422607
98 val 1.8558624132156372 0.6895 1.37446928024292
99 train 0.024965851304757234 0.9918 17.981637477874756
99 val 1.8886236921787263 0.6794 1.3605659008026123
100 train 0.019026363392782512 0.994325 18.326786518096924
100 val 1.984264079284668 0.6893 1.377504587173462
{"test_loss": 2.019672994709015, "test_acc": 0.6923, "test_time": 1.195981502532959}