Created
November 28, 2023 09:29
-
-
Save twmht/53b313f2c9aa340ef1f2223b746d62d9 to your computer and use it in GitHub Desktop.
Training code for the Stanford Cars dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Inspired by https://www.kaggle.com/code/deepbear/pytorch-car-classifier-90-accuracy
# The dataset was obtained via https://github.com/pytorch/vision/issues/7545
import torchvision | |
import torch | |
import torch.utils.data | |
import torchvision.transforms as transforms | |
import torchvision.models as models | |
import time | |
import torch.optim as optim | |
import torch.nn as nn | |
def _tensor_tail():
    """Return fresh ToTensor + Normalize steps shared by both pipelines.

    Normalization uses a per-channel mean of 0.5 and std of 0.5.
    """
    half = (0.5, 0.5, 0.5)
    return [transforms.ToTensor(), transforms.Normalize(half, half)]


# Training pipeline: resize to 400x400, then random horizontal flip and a
# random rotation of up to 15 degrees, followed by the shared tensor tail.
train_tfms = transforms.Compose(
    [
        transforms.Resize((400, 400)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
    ]
    + _tensor_tail()
)

# Test pipeline: the same resize and tensor tail, with no random augmentation.
test_tfms = transforms.Compose(
    [transforms.Resize((400, 400))] + _tensor_tail()
)
# Number of samples per mini-batch for both loaders.
batch_size = 32

# Datasets are read from a local copy under `root`; nothing is downloaded.
train_data = torchvision.datasets.StanfordCars(
    root="/home/acer", download=False, transform=train_tfms, split='train')
# The test split must use the deterministic test pipeline; applying the
# training augmentations here would randomly perturb the evaluation images.
test_data = torchvision.datasets.StanfordCars(
    root="/home/acer", download=False, transform=test_tfms, split='test')

# dataloader for data
trainloader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True, num_workers=4)
# Evaluation only accumulates counts, so the sample order does not need to be
# shuffled.
testloader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=False, num_workers=4)
# Use the first GPU when one is available, otherwise fall back to the CPU so
# the script does not crash on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# ResNet-34 with pretrained weights. The final fully connected layer is
# replaced by a new 196-way classifier (presumably the Stanford Cars class
# count -- confirm against the dataset).
model_ft = models.resnet34(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 196)
# Move the model once, after the head has been replaced, so the new layer
# lands on the same device as the rest of the network.
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_ft.parameters(), lr=0.01, momentum=0.9)
# The scheduler is stepped with the test accuracy, hence mode='max'.
# NOTE(review): with the default relative threshold mode, threshold=0.9 looks
# like it requires the metric to almost double before an epoch counts as an
# improvement -- confirm this value is intended.
lrscheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=3, threshold=0.9)
def eval_model(model_ft, data_loader=None, target_device=None):
    """Compute top-1 accuracy (in percent) of ``model_ft`` over a data loader.

    Args:
        model_ft: The network to evaluate. It is switched to eval mode for the
            duration of the call and restored to its previous mode afterwards.
        data_loader: Iterable yielding ``(images, labels)`` batches. Defaults
            to the module-level ``testloader``.
        target_device: Device the batches are moved to. Defaults to the
            module-level ``device``.

    Returns:
        The accuracy as a float percentage, or ``0.0`` when the loader yields
        no samples.
    """
    if data_loader is None:
        data_loader = testloader
    if target_device is None:
        target_device = device

    # Remember the caller's mode so evaluation never leaves the model in a
    # different state than it found it.
    was_training = model_ft.training
    model_ft.eval()

    correct = 0.0
    total = 0.0
    try:
        with torch.no_grad():
            for images, labels in data_loader:
                #images = images.to(target_device).half() # uncomment for half precision model
                images = images.to(target_device)
                labels = labels.to(target_device)
                outputs = model_ft(images)
                # Index of the largest score per row is the predicted class.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model_ft.train(was_training)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    test_acc = 100.0 * correct / total if total else 0.0
    print('Accuracy of the network on the test images: %d %%' % (
        test_acc))
    return test_acc
def train_model(model, criterion, optimizer, scheduler, n_epochs = 5):
    """Train ``model`` on the module-level ``trainloader`` for ``n_epochs``.

    After each epoch the model is evaluated with ``eval_model``, a summary
    line is printed, and ``scheduler.step`` is called with the test accuracy.

    Args:
        model: The network to train; batches are moved to the module-level
            ``device`` before the forward pass.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the updates.
        scheduler: Scheduler whose ``step`` accepts the test accuracy.
        n_epochs: Number of passes over ``trainloader``.
    """
    model.train()
    for epoch in range(n_epochs):
        since = time.time()
        running_loss = 0.0
        running_correct = 0.0
        samples_seen = 0
        batches_seen = 0
        for inputs, labels in trainloader:
            # get the inputs and assign them to cuda
            #inputs = inputs.to(device).half() # uncomment for half precision model
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Predictions are only used for bookkeeping, so detach them from
            # the autograd graph.
            _, predicted = torch.max(outputs.detach(), 1)
            running_loss += loss.item()
            running_correct += (labels == predicted).sum().item()
            samples_seen += labels.size(0)
            batches_seen += 1

        epoch_duration = time.time() - since
        # Mean of the per-batch losses; max(..., 1) avoids dividing by zero
        # when the loader is empty.
        epoch_loss = running_loss / max(batches_seen, 1)
        # Divide by the number of samples actually seen: the last batch may be
        # smaller than batch_size, so batch_size * len(trainloader) would
        # overstate the denominator.
        epoch_acc = 100.0 * running_correct / max(samples_seen, 1)

        # switch the model to eval mode to evaluate on test data
        model.eval()
        test_acc = eval_model(model)
        print("Epoch %s, duration: %d s, loss: %.4f, train_acc: %.4f, test_acc: %.4f" % (epoch+1, epoch_duration, epoch_loss, epoch_acc, test_acc))

        # re-set the model to train mode after validating
        model.train()
        scheduler.step(test_acc)
    print('Finished Training')
# Only start training when run as a script. The data loaders use worker
# processes (num_workers=4); on platforms that start workers by re-importing
# the main module, an unguarded call here would be executed again in every
# worker.
if __name__ == "__main__":
    train_model(model_ft, criterion, optimizer, lrscheduler, n_epochs=10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment