# Created June 23, 2020 14:48
# Gist khalidmeister/f52151554d3b0d73173947b45f01baa5: save it to your computer and use it in GitHub Desktop.
# File: train-model.py
# GitHub notice: this file contains bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that reveals hidden
# Unicode characters. Learn more about bidirectional Unicode characters.
# Load the three ImageNet-pretrained backbones. The weights are downloaded
# only the first time each model is loaded.
model_resnet, model_vgg16, model_alexnet = (
    models.resnet18(pretrained=True),
    models.vgg16(pretrained=True),
    models.alexnet(pretrained=True),
)
# Define the function to train the model
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, *,
                data_loaders=None, dataset_sizes=None, target_device=None):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a 'train' phase and then a 'val' phase.
    The 'train' phase has gradients enabled and steps the optimizer once per
    batch; the scheduler is stepped once at the end of the phase.
    The 'val' phase has gradients disabled.
    A deep copy of the weights that reach the highest validation accuracy is
    kept and loaded back into ``model`` before it is returned.

    Args:
        model: The ``torch.nn.Module`` to train; it is modified in place.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the trainable parameters of ``model``.
        scheduler: Learning-rate scheduler; ``step()`` is called once per epoch.
        num_epochs: Number of train/val epochs to run.
        data_loaders: Mapping with 'train' and 'val' iterables of
            ``(inputs, labels)`` batches. Defaults to the module-level
            ``dataloaders``.
        dataset_sizes: Mapping with the sample counts for 'train' and 'val',
            used to average loss and accuracy. Defaults to the module-level
            ``data_size``.
        target_device: Device each batch is moved to. Defaults to the
            module-level ``device``.

    Returns:
        ``model`` (the same object) with the best-validation weights loaded.
    """
    # Fall back to the script's globals so existing calls keep working. The
    # conditional expression only evaluates the global when no argument is given.
    loaders = dataloaders if data_loaders is None else data_loaders
    sizes = data_size if dataset_sizes is None else dataset_sizes
    dev = device if target_device is None else target_device

    since = time.perf_counter()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ('train', 'val'):
            is_train = phase == 'train'
            loss_sum, n_correct = _run_phase(
                model, criterion, optimizer, loaders[phase], dev, is_train)
            if is_train:
                # Epoch-based schedulers such as StepLR expect one step per
                # epoch, after that epoch's optimizer steps.
                scheduler.step()

            epoch_loss = loss_sum / sizes[phase]
            epoch_acc = n_correct / sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # Snapshot the weights whenever validation accuracy improves.
            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()

    time_elapsed = time.perf_counter() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    return model


def _run_phase(model, criterion, optimizer, loader, dev, is_train):
    """Run one pass over ``loader``; return (sample-weighted loss sum, #correct)."""
    model.train(is_train)  # train(False) is equivalent to eval()
    loss_sum = 0.0
    n_correct = 0
    for inputs, labels in loader:
        inputs = inputs.to(dev)
        labels = labels.to(dev)
        if is_train:
            optimizer.zero_grad()
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()
        preds = outputs.argmax(dim=1)
        # Assumes criterion returns the batch mean (CrossEntropyLoss default),
        # so scale by the batch size to get a per-sample sum.
        loss_sum += loss.item() * inputs.size(0)
        n_correct += (preds == labels).sum().item()
    return loss_sum, n_correct
# Freeze every pretrained parameter. With requires_grad set to False, autograd
# does not compute gradients for these tensors.
for pretrained_net in (model_vgg16, model_resnet, model_alexnet):
    for weight in pretrained_net.parameters():
        weight.requires_grad = False
# Replace each network's final fully connected layer with a new nn.Linear.
# It keeps the replaced layer's input width and outputs one logit per class
# in our dataset. Layers are created in the order VGG16, ResNet, AlexNet.
n_classes = len(class_names)
model_vgg16.classifier[6] = nn.Linear(model_vgg16.classifier[6].in_features, n_classes)
model_resnet.fc = nn.Linear(model_resnet.fc.in_features, n_classes)
num_ftrs = model_alexnet.classifier[6].in_features
model_alexnet.classifier[6] = nn.Linear(num_ftrs, n_classes)
# Move every model to `device`. That is presumably the GPU when one is
# available; `device` is defined elsewhere.
model_vgg16, model_resnet, model_alexnet = (
    net.to(device) for net in (model_vgg16, model_resnet, model_alexnet)
)
# Classification loss shared by all three fine-tuning runs.
criterion = nn.CrossEntropyLoss()


def _fit_new_head(model, head):
    """Fine-tune ``model`` by optimizing only the parameters of ``head``.

    Uses SGD with lr=0.001 and momentum=0.9. A StepLR scheduler multiplies
    the learning rate by 0.1 every 7 scheduler steps. Training runs for 25
    epochs through ``train_model``, and the trained model is returned.
    """
    optimizer = optim.SGD(head.parameters(), lr=0.001, momentum=0.9)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    return train_model(model, criterion, optimizer, scheduler, num_epochs=25)


# Train each model's new final layer in turn: VGG16, then ResNet, then AlexNet.
model_vgg16 = _fit_new_head(model_vgg16, model_vgg16.classifier[6])
model_resnet = _fit_new_head(model_resnet, model_resnet.fc)
model_alexnet = _fit_new_head(model_alexnet, model_alexnet.classifier[6])
# GitHub: Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.