Transfer Learning by tuning the head of ImageNet-1K pre-trained model
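The recipe, in brief: load an ImageNet-1K pre-trained backbone, freeze all of its weights, replace the classification head with a small trainable one for CIFAR-10, and let Optuna search over the optimizer and learning rate. A minimal sketch of the freeze-and-replace step, assuming a torchvision ResNet-50 (the full script below adds the data pipeline, timm's PoolFormer option, and the Optuna search):

import torch.nn as nn
from torchvision import models

# Sketch only: freeze the pre-trained backbone, then attach a fresh head.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
for param in model.parameters():
    param.requires_grad = False                 # backbone stays fixed
model.fc = nn.Linear(model.fc.in_features, 10)  # new, trainable 10-class head

Only the new head's parameters receive gradient updates, so training is fast and the pre-trained features are preserved.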
# Original code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import gc
import time
import random
from filelock import FileLock
import torch
from torch import optim
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import datasets
import torchvision.transforms as transforms
import optuna
from optuna.trial import TrialState
from timm.models import create_model
import wandb

wandb.init(project="optuna_resnet50")

BATCH_SIZE = 2048  # or 64, 128
def get_dataset():
    # Resize to the 224x224 ImageNet input size, convert to tensors, and
    # normalize with the ImageNet channel statistics the backbone was trained on.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # The file lock keeps concurrent processes from downloading
    # CIFAR-10 into ./data at the same time.
    with FileLock(os.path.expanduser("~/.data.lock")):
        train_dataset = datasets.CIFAR10(root='./data', train=True,
                                         transform=transform, download=True)
        test_dataset = datasets.CIFAR10(root='./data', train=False,
                                        transform=transform, download=True)
    return train_dataset, test_dataset
def get_data_loaders():
    train_dataset, test_dataset = get_dataset()
    # Fix the seeds so every trial sees the same shuffling order.
    random.seed(0)
    torch.manual_seed(0)
    # Wrap the datasets in DataLoaders for batched iteration.
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE,
                              num_workers=0, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                             num_workers=0, shuffle=False)
    return train_loader, test_loader
def get_model(model_name="poolformer"):
    num_classes = 10
    if model_name == "poolformer":
        # PoolFormer-S12, pre-trained on ImageNet-1K, via timm.
        model = create_model("poolformer_s12", pretrained=True)
        # Freeze the backbone; only the new head will be trained.
        for param in model.parameters():
            param.requires_grad = False
        model.head = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )
        # Parameters of the freshly created head require gradients by default;
        # the loop below just makes that explicit. (Assigning .requires_grad on
        # the Module itself, as the original did, is a no-op.)
        for param in model.head.parameters():
            param.requires_grad = True
    else:
        # ResNet-50 with ImageNet-1K (V2) weights.
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        for param in model.parameters():
            param.requires_grad = False
        model.fc = nn.Sequential(
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )
        for param in model.fc.parameters():
            param.requires_grad = True
    return model
# Evaluate mean accuracy (%) and mean cross-entropy loss over a data loader.
@torch.no_grad()  # no gradients needed during evaluation
def validate(model, data_loader, device):
    model.eval()
    num_correct, total_loss, num_examples = 0, 0.0, 0
    for features, targets in data_loader:
        features = features.to(device)
        targets = targets.to(device)
        logits = model(features)
        _, predicted_labels = torch.max(logits, 1)
        loss = F.cross_entropy(logits, targets, reduction='sum')
        num_examples += targets.size(0)
        num_correct += (predicted_labels == targets).sum().item()
        total_loss += loss.item()
    return num_correct / num_examples * 100, total_loss / num_examples
def train(model, optimizer, train_loader, device):
    start_time = time.time()
    num_epochs = 5
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            features = features.to(device)
            targets = targets.to(device)
            logits = model(features)
            loss = criterion(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if not batch_idx % 50:
                print('Epoch: %03d/%03d | Batch %04d/%04d | Loss: %.4f'
                      % (epoch + 1, num_epochs, batch_idx,
                         len(train_loader), loss.item()))
    print('Total Training Time: %.2f min' % ((time.time() - start_time) / 60))
def objective(trial):
    # Release memory left over from the previous trial.
    gc.collect()
    torch.cuda.empty_cache()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = get_data_loaders()
    model = get_model("resnet50").to(device)
    # Sample this trial's optimizer and learning rate.
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    # Only the head's parameters require gradients; pass just those to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = getattr(optim, optimizer_name)(trainable_params, lr=lr)
    train(model, optimizer, train_loader, device=device)
    acc, _ = validate(model, test_loader, device)
    wandb.log({"optimizer": optimizer_name, "lr": lr, "accuracy": acc})
    return acc
if __name__ == "__main__":
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=100)
    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))
    print("Best trial:")
    trial = study.best_trial
    print("  Value: ", trial.value)
    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))
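A note on the pruned-trial count above: objective never reports intermediate values, so no trial can ever be pruned and that count will always be zero. If early stopping of bad trials is wanted, one possible sketch (assuming the training loop is refactored into a hypothetical per-epoch helper train_one_epoch, which this script does not define) is to report validation accuracy each epoch and pass a pruner such as optuna.pruners.MedianPruner to create_study:

# Sketch only: per-epoch reporting inside objective() so a pruner can act.
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, train_loader, device)  # hypothetical helper
    acc, _ = validate(model, test_loader, device)
    trial.report(acc, epoch)        # expose the intermediate result to Optuna
    if trial.should_prune():
        raise optuna.TrialPruned()  # let the pruner stop this trial early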