Created
February 14, 2017 09:08
-
-
Save david90/cd4e3288a535424fcb926a5ac91ee7ea to your computer and use it in GitHub Desktop.
Code for training the SVM classifier
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sklearn | |
from sklearn import cross_validation, grid_search | |
from sklearn.metrics import confusion_matrix, classification_report | |
from sklearn.svm import SVC | |
from sklearn.externals import joblib | |
def train_svm_classifer(features, labels, model_output_path):
    """
    Train an SVM classifier, save the best model and report its performance.

    20% of the data is held out for evaluation. A grid search with 10-fold
    cross validation over linear and RBF kernels is run on the remaining
    80%; the best estimator is written to ``model_output_path`` with joblib
    and the best parameters, confusion matrix and classification report for
    the held-out data are printed.

    features: array of input features
    labels: array of labels associated with the input features
    model_output_path: path for storing the trained svm model

    Nothing is returned. If the directory that should contain
    ``model_output_path`` does not exist, the model is not saved and a
    message is printed instead.

    NOTE(review): ``cross_validation``, ``grid_search`` and
    ``sklearn.externals.joblib`` are legacy scikit-learn modules; confirm
    the installed scikit-learn version still provides them.
    """
    # save 20% of data for performance evaluation
    X_train, X_test, y_train, y_test = cross_validation.train_test_split(
        features, labels, test_size=0.2)

    param = [
        {
            "kernel": ["linear"],
            "C": [1, 10, 100, 1000],
        },
        {
            "kernel": ["rbf"],
            "C": [1, 10, 100, 1000],
            "gamma": [1e-2, 1e-3, 1e-4, 1e-5],
        },
    ]

    # request probability estimation
    svm = SVC(probability=True)

    # 10-fold cross validation; use 4 workers, since each fold and each
    # parameter set can be trained in parallel
    clf = grid_search.GridSearchCV(svm, param,
                                   cv=10, n_jobs=4, verbose=3)
    clf.fit(X_train, y_train)

    # The output file normally does not exist yet, so check the directory
    # that will contain it rather than the file itself; otherwise a freshly
    # trained model would never be saved.
    output_dir = os.path.dirname(os.path.abspath(model_output_path))
    if os.path.isdir(output_dir):
        joblib.dump(clf.best_estimator_, model_output_path)
    else:
        print("Cannot save trained svm model to {0}.".format(model_output_path))

    print("\nBest parameters set:")
    print(clf.best_params_)

    y_predict = clf.predict(X_test)

    # Keep the caller's ``labels`` intact; use a separate sorted list of the
    # distinct classes to fix the row/column order of the confusion matrix.
    class_labels = sorted(set(labels))

    print("\nConfusion matrix:")
    # str() each label so non-string classes (e.g. integer ids) can be joined.
    print("Labels: {0}\n".format(",".join(str(label) for label in class_labels)))
    print(confusion_matrix(y_test, y_predict, labels=class_labels))

    print("\nClassification report:")
    print(classification_report(y_test, y_predict))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment