Transfer learning by tuning the head of an ImageNet-1K pre-trained model
# Original code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import gc
import time
import random
from filelock import FileLock
import torch
from torch import optim
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import datasets
import torchvision.transforms as transforms
import optuna
from optuna.trial import TrialState
from timm.models import create_model
import wandb
wandb.init(project="optuna_resnet50")
BATCH_SIZE = 2048 # or 64, 128
def get_dataset():
    # transform the data into tensors using the 'transforms' module
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to 224x224 (height x width)
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    with FileLock(os.path.expanduser("~/.data.lock")):
        # download training and test datasets
        train_dataset = datasets.CIFAR10(root='./data', train=True,
                                         transform=transform, download=True)
        test_dataset = datasets.CIFAR10(root='./data', train=False,
                                        transform=transform, download=True)
    return train_dataset, test_dataset
def get_data_loaders():
    train_dataset, test_dataset = get_dataset()
    random.seed(0)
    torch.manual_seed(0)
    # Feed data in batches into deep-learning models
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE,
                              num_workers=0, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                             num_workers=0, shuffle=False)
    return train_loader, test_loader
def get_model(model_name="poolformer"):
    if model_name == "poolformer":
        # poolformer_s12 backbone, frozen except for a new classification head
        model = create_model("poolformer_s12", pretrained=True)
        for param in model.parameters():
            param.requires_grad = False
        num_classes = 10
        model.head = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )
        model.head.requires_grad_(True)  # train only the new head
    else:
        # ResNet50 backbone, frozen except for a new classification head
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        for param in model.parameters():
            param.requires_grad = False
        num_classes = 10
        model.fc = nn.Sequential(
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )
        model.fc.requires_grad_(True)  # train only the new head
    return model
# Custom accuracy/loss evaluation
def validate(model, data_loader, device):
    model.eval()
    num_correct, total_loss, num_examples = 0, 0., 0
    with torch.no_grad():  # no gradients needed during evaluation
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)
            logits = model(features)
            _, predicted_labels = torch.max(logits, 1)
            loss = F.cross_entropy(logits, targets, reduction='sum')
            num_examples += targets.size(0)
            num_correct += (predicted_labels == targets).sum()
            total_loss += loss.item()
    return num_correct.float() / num_examples * 100, total_loss / num_examples
def train(model, optimizer, train_loader, device):
    start_time = time.time()
    num_epochs = 5
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            features = features.to(device)
            targets = targets.to(device)
            logits = model(features)
            loss = criterion(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if not batch_idx % 50:
                print('Epoch: %03d/%03d | Batch %04d/%04d | Loss: %.4f'
                      % (epoch + 1, num_epochs, batch_idx,
                         len(train_loader), loss.item()))
    print('Total Training Time: %.2f min' % ((time.time() - start_time) / 60))
def objective(trial):
    gc.collect()
    torch.cuda.empty_cache()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = get_data_loaders()
    model = get_model("resnet50").to(device)
    # Generate the optimizer and learning rate for this trial.
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
    train(model, optimizer, train_loader, device=device)
    acc, _ = validate(model, test_loader, device)
    acc = acc.item()  # plain float for wandb and Optuna
    wandb.log({"optimizer": optimizer_name, "lr": lr, "accuracy": acc})
    return acc
if __name__ == "__main__":
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=100)
    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))
    print("Best trial:")
    trial = study.best_trial
    print("  Value: ", trial.value)
    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))
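# --- Optional: enabling Optuna pruning (sketch, not part of the original gist) ---
# The script imports TrialState and counts pruned trials, but objective() never
# reports intermediate values, so no trial can actually be pruned. The sketch below
# shows one way pruning could be wired in, assuming a per-epoch validation pass;
# objective_with_pruning and its epoch loop are illustrative names, not the
# author's code.
def objective_with_pruning(trial):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = get_data_loaders()
    model = get_model("resnet50").to(device)
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(5):
        model.train()
        for features, targets in train_loader:
            features, targets = features.to(device), targets.to(device)
            loss = criterion(model(features), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        acc, _ = validate(model, test_loader, device)
        trial.report(float(acc), step=epoch)  # report intermediate accuracy
        if trial.should_prune():              # let the pruner stop bad trials early
            raise optuna.exceptions.TrialPruned()
    return float(acc)
# Example usage with a median pruner:
# study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
# study.optimize(objective_with_pruning, n_trials=100)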