This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Evaluate the trained network on the test split.
# Accuracy is computed as (total correct) / (total samples). Averaging the
# per-batch accuracies instead would over-weight a smaller final batch.
# eval() switches off training-only layer behaviour (dropout, batch-norm
# statistics updates) if the model has any. Call trained_model.train() before
# training it again.
trained_model.eval()
# Send each batch to whatever device the model's parameters live on. This
# works both with and without CUDA, instead of hard-coding .cuda().
device = next(trained_model.parameters()).device
correct_total = 0
sample_total = 0
with torch.no_grad():
    for samples, labels in loaders['test']:
        samples, labels = samples.to(device), labels.to(device)
        output = trained_model(samples)
        # calculate accuracy: the predicted class is the index of the largest output
        pred = torch.argmax(output, dim=1)
        correct_total += pred.eq(labels).sum().item()
        sample_total += labels.size(0)
# Guard against an empty test loader instead of dividing by zero.
test_acc = correct_total / sample_total if sample_total else 0.0
print('Accuracy of the network on {} test images: {}%'.format(len(testset), round(test_acc * 100.0, 2)))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train the model and keep the result for evaluation.
# np.inf (lowercase) is used because the np.Inf alias was removed in NumPy 2.0.
trained_model = train(
    start_epochs=1,
    n_epochs=3,
    # Initial "best" validation loss is infinity, so any finite loss compares lower.
    valid_loss_min_input=np.inf,
    loaders=loaders,
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    use_cuda=use_cuda,
    checkpoint_path="./checkpoint/current_checkpoint.pt",
    best_model_path="./best_model/best_model.pt",
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(start_epochs, n_epochs, valid_loss_min_input, loaders, model, optimizer, criterion, use_cuda, checkpoint_path, best_model_path): | |
    """Train `model`; the visible caller assigns the result to `trained_model`, so it presumably returns the model. | |
    Keyword arguments: | |
    start_epochs -- first epoch number (the visible caller passes 1) | |
    n_epochs -- last epoch number or number of epochs (the visible caller passes 3) -- TODO confirm which | |
    valid_loss_min_input -- initial best validation loss (the visible caller passes infinity) | |
    loaders -- dict of DataLoaders keyed by split name ('test' is used elsewhere; presumably also 'train'/'valid' -- confirm) | |
    model -- the nn.Module to train | |
    optimizer -- optimizer over the model's parameters (Adam in the visible setup) | |
    criterion -- loss function (nn.NLLLoss in the visible setup) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Loss: negative log-likelihood. nn.NLLLoss expects log-probabilities, so the
# model's forward pass presumably ends in log_softmax -- confirm.
criterion = nn.NLLLoss()
# Adam over all of the model's parameters with learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Instantiate the classifier; when CUDA is available, move its parameters to the GPU.
model = FashionClassifier()
model = model.cuda() if use_cuda else model
# Show the layer structure.
print(model)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Define your network (simple example): a stack of five fully connected layers. | |
class FashionClassifier(nn.Module): | |
    """Fully connected classifier built from five nn.Linear layers.

    Layer widths: 784 -> 512 -> 256 -> 128 -> 64 -> 10. The forward pass is not
    part of this excerpt; presumably it flattens each image to 784 features and
    applies an activation between layers -- confirm.
    """
    def __init__(self): | |
        super().__init__() | |
        # 784 presumably = 28 * 28 pixels of a flattened Fashion-MNIST image (confirm)
        input_size = 784 | |
        self.fc1 = nn.Linear(input_size, 512) | |
        self.fc2 = nn.Linear(512, 256) | |
        self.fc3 = nn.Linear(256, 128) | |
        self.fc4 = nn.Linear(128, 64) | |
        # 10 outputs, presumably one score per Fashion-MNIST class (confirm)
        self.fc5 = nn.Linear(64,10) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Define a transform that converts images to tensors and normalizes them.
# Fashion-MNIST images are single-channel (grayscale), so Normalize takes a
# one-element mean/std. The previous 3-element mean/std does not match the
# 1-channel tensor and makes torchvision raise a shape/broadcast error.
transform = transforms.Compose([
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps ToTensor's [0, 1] pixel range to [-1, 1]
    transforms.Normalize((0.5,), (0.5,)),
])
# Download and load the training data
trainset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=True, transform=transform)
# Download and load the test data
testset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=False, transform=transform)
loaders = { |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_ckp(checkpoint_fpath, model, optimizer): | |
    """ | |
    Load a checkpoint file and restore the saved parameters into the given objects. | |
    checkpoint_fpath: path of the checkpoint file to load (presumably one written by save_ckp) | |
    model: model that we want to load checkpoint parameters into | |
    optimizer: optimizer we defined in previous training | |
    """ | |
    # load check point | |
    # NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
    # NOTE(review): a checkpoint holding CUDA tensors cannot be loaded on a machine
    # without CUDA unless map_location (e.g. 'cpu') is passed -- consider adding it.
    checkpoint = torch.load(checkpoint_fpath) | |
    # initialize state_dict from checkpoint to model | |
    # The checkpoint is expected to be a mapping with a 'state_dict' entry.
    model.load_state_dict(checkpoint['state_dict']) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def save_ckp(state, is_best, checkpoint_path, best_model_path): | |
    """ | |
    Save a training checkpoint to disk. | |
    state: checkpoint we want to save (passed unchanged to torch.save) | |
    is_best: whether this checkpoint has the lowest validation loss so far | |
    checkpoint_path: path to save checkpoint | |
    best_model_path: path to save best model (presumably used below when is_best is true -- not shown) | |
    """ | |
    f_path = checkpoint_path | |
    # save checkpoint data to the path given, checkpoint_path | |
    # NOTE(review): as far as I know torch.save does not create missing parent
    # directories (e.g. "./checkpoint/"); make sure they exist beforehand -- verify.
    torch.save(state, f_path) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Decide once whether a CUDA device can be used; later code branches on this flag.
use_cuda = bool(torch.cuda.is_available())