Skip to content

Instantly share code, notes, and snippets.

@ashhadulislam
Created July 26, 2022 17:55
Show Gist options
  • Save ashhadulislam/f1d17cecaf0b3cb9bfbafd6cd26977f6 to your computer and use it in GitHub Desktop.
Save ashhadulislam/f1d17cecaf0b3cb9bfbafd6cd26977f6 to your computer and use it in GitHub Desktop.
PATH = './squeezenet1_0.pth'  # where the fine-tuned state_dict is saved and re-loaded from


def _build_squeezenet(num_classes, pretrained=False):
    """Build a SqueezeNet 1.0 whose final 1x1 classifier conv outputs *num_classes* channels.

    Used by both the training and the reload steps so they construct exactly the
    same architecture; ``load_state_dict`` is strict by default and fails if the
    two ever diverge.

    Args:
        num_classes: Number of output channels for the replacement ``classifier[1]``.
        pretrained: If true, start from torchvision's pretrained weights
            (requires a download). If false, build a randomly initialised
            network without downloading anything.

    Returns:
        The SqueezeNet module, not yet moved to ``device``.
    """
    if pretrained:
        net = models.squeezenet1_0(pretrained=True)
    else:
        # No keyword argument: gives a randomly initialised model on both the
        # older ``pretrained=`` and newer ``weights=`` torchvision signatures.
        net = models.squeezenet1_0()
    # Swap the head for a fresh conv mapping 512 input channels to num_classes outputs.
    net.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
    return net


# setup: start from pretrained weights and fine-tune with a new classification head
model_ft = _build_squeezenet(len(classes), pretrained=True)
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# StepLR multiplies the learning rate by gamma=0.1 every 7 scheduler.step() calls
# (presumably stepped once per epoch inside train_model -- confirm).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# train
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=num_epochs)
torch.save(model_ft.state_dict(), PATH)

# load: strict load_state_dict overwrites every saved entry, so there is no
# point downloading pretrained weights here -- build a random skeleton instead.
model_ft3 = _build_squeezenet(len(classes))
model_ft3 = model_ft3.to(device)
model_ft3.load_state_dict(torch.load(PATH, map_location=device))
model_ft3.eval()  # switch to evaluation mode before testing

# test
print(accuracy(model_ft3, test_loader))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment