Created
July 25, 2020 05:40
-
-
Save mohcinemadkour/4b52f8d05b7cc05672347737f3cf8850 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MLPClassifier:
    """Scikit-learn-style wrapper that trains an ``MLP`` with SGD and cross-entropy loss.

    After every training epoch the model is evaluated on a held-out set; the
    per-epoch history is kept in ``loss_`` (mean training batch loss),
    ``test_accuracy`` and ``test_error`` (number of misclassified samples).
    """

    def __init__(self, hidden_layers=None, droprates=None, batch_size=128, max_epoch=10,
                 lr=0.1, momentum=0, device="cuda"):
        """Build the network, the loss and the optimizer.

        Args:
            hidden_layers: hidden layer sizes forwarded to ``MLP``;
                defaults to ``[800, 800]``.
            droprates: dropout rates forwarded to ``MLP``; defaults to ``[0, 0]``.
            batch_size: mini-batch size used by ``fit``.
            max_epoch: number of passes over the training set in ``fit``.
            lr: SGD learning rate.
            momentum: SGD momentum.
            device: device the model and batches are moved to. The default
                ``"cuda"`` keeps the original GPU behaviour; pass ``"cpu"`` to
                run without a GPU.
        """
        # None sentinels instead of list defaults: a list default is created
        # once at definition time and shared by every instance that uses it.
        if hidden_layers is None:
            hidden_layers = [800, 800]
        if droprates is None:
            droprates = [0, 0]
        self.hidden_layers = hidden_layers
        self.droprates = droprates
        self.batch_size = batch_size
        self.max_epoch = max_epoch
        self.device = torch.device(device)
        self.model = MLP(hidden_layers=hidden_layers, droprates=droprates)
        # The torch.optim docs recommend moving a model to its device before
        # constructing the optimizer for it.
        self.model.to(self.device)
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.optimizer = optim.SGD(self.model.parameters(), lr=lr, momentum=momentum)
        self.loss_ = []
        self.test_accuracy = []
        self.test_error = []

    def fit(self, trainset, testset, verbose=True):
        """Train for ``max_epoch`` epochs, evaluating on ``testset`` after each one.

        Args:
            trainset: dataset of ``(input, label)`` pairs used for training.
            testset: dataset of the same form; it is loaded as a single batch.
            verbose: print the per-epoch loss and test metrics when True.

        Returns:
            ``self``, so calls can be chained.
        """
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=self.batch_size, shuffle=True)
        # The whole test set is read as one batch once and reused every epoch.
        testloader = torch.utils.data.DataLoader(testset, batch_size=len(testset), shuffle=False)
        # next(iter(...)) rather than the Python-2 style ``.next()`` method,
        # which current DataLoader iterators do not provide.
        X_test, y_test = next(iter(testloader))
        X_test = X_test.to(self.device)
        for epoch in range(self.max_epoch):
            self.loss_.append(self._train_epoch(trainloader))
            if verbose:
                print('Epoch {} loss: {}'.format(epoch+1, self.loss_[-1]))
            accuracy, n_errors = self._evaluate(X_test, y_test)
            self.test_accuracy.append(accuracy)
            self.test_error.append(n_errors)
            if verbose:
                print('Test error: {}; test accuracy: {}'.format(self.test_error[-1], self.test_accuracy[-1]))
        return self

    def _train_epoch(self, trainloader):
        """Run one SGD pass over ``trainloader`` and return the mean batch loss."""
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(inputs), labels)
            loss.backward()
            self.optimizer.step()
            # .item() extracts the Python scalar; indexing the 0-dim loss
            # tensor with [0] raises on current PyTorch releases.
            running_loss += loss.item()
        return running_loss / len(trainloader)

    def _evaluate(self, X_test, y_test):
        """Return ``(accuracy, number_of_errors)`` of the model on one labelled batch.

        Predictions are brought to the CPU and compared element-wise with
        ``y_test``.
        """
        y_test_pred = self.predict(X_test).cpu()
        # Count mistakes directly instead of int(len * (1 - accuracy)), whose
        # float truncation can be off by one.
        n_errors = int((y_test_pred != y_test).sum().item())
        accuracy = (y_test_pred == y_test).float().mean().item()
        return accuracy, n_errors

    def predict(self, x):
        """Return the index of the highest-scoring output for each row of ``x``.

        ``x`` is moved to ``self.device``; the returned indices stay there.
        The model runs in eval mode (so layers such as dropout act in
        inference mode) and is restored afterwards to whichever train/eval
        mode it was in before the call.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            # Inference needs no autograd graph; without no_grad one would be
            # built for the entire input batch.
            with torch.no_grad():
                outputs = self.model(x.to(self.device))
        finally:
            self.model.train(was_training)
        _, pred = torch.max(outputs, 1)
        return pred

    def __str__(self):
        return 'Hidden layers: {}; dropout rates: {}'.format(self.hidden_layers, self.droprates)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment