Skip to content

Instantly share code, notes, and snippets.

@mil-ad
Forked from y0ast/train_cifar.py
Created January 10, 2020 19:09
Show Gist options
  • Save mil-ad/e8a82e8f6d4c096d1c873640f5ddae22 to your computer and use it in GitHub Desktop.
Save mil-ad/e8a82e8f6d4c096d1c873640f5ddae22 to your computer and use it in GitHub Desktop.
Getting high accuracy on CIFAR-10 is not straightforward. This self-contained script reaches 94% accuracy with a minimal setup.
import argparse
import os

import torch
import torch.nn.functional as F
from torchvision import models, datasets, transforms
from tqdm import tqdm
def get_CIFAR10(root="./"):
    """Build the CIFAR-10 train and test datasets, downloading them if needed.

    Args:
        root: directory under which the data lives in ``data/CIFAR10``.
            May be a str or path-like, with or without a trailing separator.

    Returns:
        ``(input_size, num_classes, train_dataset, test_dataset)`` where
        ``input_size`` is the image side length (32) and ``num_classes`` is 10.
    """
    input_size = 32
    num_classes = 10

    # Join components instead of concatenating strings: a root given without
    # a trailing slash (e.g. "/tmp") must resolve to "/tmp/data/CIFAR10", not
    # "/tmpdata/CIFAR10", and path-like roots must work too.
    data_dir = os.path.join(root, "data", "CIFAR10")

    # One Normalize shared by both pipelines so train and test cannot drift.
    # NOTE(review): these std values differ from the commonly quoted per-pixel
    # CIFAR-10 std (~0.247, 0.243, 0.261). They are kept as-is because the
    # reported accuracy was obtained with them; confirm before changing.
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )

    # Training-time augmentation: padded random crop plus horizontal flip.
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(input_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )
    train_dataset = datasets.CIFAR10(
        data_dir, train=True, transform=train_transform, download=True
    )

    # Evaluation uses no augmentation, only tensor conversion + normalization.
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])
    test_dataset = datasets.CIFAR10(
        data_dir, train=False, transform=test_transform, download=True
    )

    return input_size, num_classes, train_dataset, test_dataset
class Model(torch.nn.Module):
    """ResNet-18 adapted for 32x32 CIFAR images; ``forward`` returns log-probabilities.

    The stem convolution is replaced by a 3x3, stride-1 conv and the stem
    max-pool is replaced by an identity, so small inputs are not downsampled
    before the first residual stage.
    """

    def __init__(self, num_classes=10):
        """Build the network.

        Args:
            num_classes: number of output classes (default 10, for CIFAR-10).
        """
        super().__init__()
        # The legacy ``pretrained=False`` keyword is omitted on purpose: it is
        # deprecated in newer torchvision (superseded by ``weights``), and both
        # the old and the new API default to a randomly initialised network.
        self.resnet = models.resnet18(num_classes=num_classes)
        self.resnet.conv1 = torch.nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.resnet.maxpool = torch.nn.Identity()

    def forward(self, x):
        """Return per-class log-probabilities for the image batch ``x``.

        Log-probabilities (not raw logits) are returned so that callers can
        apply ``F.nll_loss`` directly.
        """
        return F.log_softmax(self.resnet(x), dim=1)
def train(model, train_loader, optimizer, epoch):
    """Run one training epoch and print the mean mini-batch loss.

    Args:
        model: module returning log-probabilities (the loss is ``F.nll_loss``).
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only for logging.

    Returns:
        The mean of the per-batch losses, or ``nan`` if the loader yielded
        no batches.
    """
    model.train()
    # Send batches to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device

    batch_losses = []
    for data, target in tqdm(train_loader):
        # non_blocking lets the copy overlap compute when the loader pins memory.
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        optimizer.zero_grad()
        prediction = model(data)
        loss = F.nll_loss(prediction, target)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # Guard against an empty loader instead of raising ZeroDivisionError.
    avg_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
    print(f"Epoch: {epoch}:")
    print(f"Train Set: Average Loss: {avg_loss:.2f}")
    return avg_loss
def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print mean loss and accuracy.

    Args:
        model: module returning log-probabilities (scored with ``F.nll_loss``).
        test_loader: loader of ``(data, target)`` batches; its ``dataset``
            length is the denominator for both metrics.

    Returns:
        ``(average_loss, percentage_correct)`` as Python floats. Both are
        ``nan`` when the dataset is empty.
    """
    model.eval()
    # Send batches to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device

    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            prediction = model(data)
            # Sum (not mean) per batch so that dividing by the dataset size
            # weights every sample equally, including the smaller last batch.
            total_loss += F.nll_loss(prediction, target, reduction="sum").item()
            correct += prediction.argmax(dim=1).eq(target).sum().item()

    num_samples = len(test_loader.dataset)
    if num_samples:
        loss = total_loss / num_samples
        percentage_correct = 100.0 * correct / num_samples
    else:
        # Metrics are undefined on an empty dataset; avoid ZeroDivisionError.
        loss = float("nan")
        percentage_correct = float("nan")

    print(
        "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)".format(
            loss, correct, num_samples, percentage_correct
        )
    )
    return loss, percentage_correct


# The name ``test`` matches pytest's default collection pattern; mark it so
# pytest does not try to run it (with missing fixtures) when it is imported
# into a test module.
test.__test__ = False
def main():
    """Parse CLI options, train the CIFAR-10 model, and save its weights.

    Every option defaults to the value this script originally hard-coded
    (45 epochs, lr 0.05, LR x0.1 after epochs 15 and 30, batch size 128,
    test batch size 5000, data under ./data/CIFAR10, weights written to
    cifar_model.pt), so running without flags reproduces the original setup.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--epochs", type=int, default=45, help="number of epochs to train (default: 45)"
    )
    parser.add_argument(
        "--lr", type=float, default=0.05, help="learning rate (default: 0.05)"
    )
    parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)")
    parser.add_argument(
        "--batch-size", type=int, default=128, help="training batch size (default: 128)"
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=5000,
        help="evaluation batch size (default: 5000)",
    )
    # Exposed because fixed milestones do not scale when --epochs is changed.
    parser.add_argument(
        "--milestones",
        type=int,
        nargs="+",
        default=[15, 30],
        help="epochs after which the learning rate is multiplied by 0.1 (default: 15 30)",
    )
    parser.add_argument(
        "--data-root",
        default="./",
        help="directory containing data/CIFAR10 (default: ./)",
    )
    parser.add_argument(
        "--output",
        default="cifar_model.pt",
        help="path for the saved state_dict (default: cifar_model.pt)",
    )
    args = parser.parse_args()
    print(args)

    torch.manual_seed(args.seed)

    _, _, train_dataset, test_dataset = get_CIFAR10(args.data_root)

    loader_kwargs = {"num_workers": 2, "pin_memory": True}
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, **loader_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.test_batch_size, shuffle=False, **loader_kwargs
    )

    model = Model().cuda()

    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4
    )
    # Stepped once per epoch, so milestone k means "after k epochs of training".
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=0.1
    )

    for epoch in range(1, args.epochs + 1):
        train(model, train_loader, optimizer, epoch)
        test(model, test_loader)
        scheduler.step()

    torch.save(model.state_dict(), args.output)
if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment