Created February 27, 2017 11:24
Save amorgun/6df4b04638430911b50b0c684bed7924 to your computer and use it in GitHub Desktop.
SVM hack
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from scipy.sparse import csr_matrix, issparse
def make_sparse(clf):
    """Make an SVC trained on dense data accept sparse features without refitting it.

    The fitted attributes are rewritten in place:

    * ``support_vectors_`` becomes a CSR matrix.
    * ``dual_coef_`` becomes a CSR matrix of shape ``(n_class, n_SV)``. Every
      entry is stored explicitly, including zeros.
    * ``_sparse`` is set to ``True``.

    This is a hack that relies on private scikit-learn internals (``_sparse``).
    It mirrors the construction used by scikit-learn's own sparse fit at the
    revision linked below.
    NOTE(review): other scikit-learn versions may keep extra private copies of
    these attributes; verify against the installed version.

    Parameters
    ----------
    clf : fitted libsvm-based estimator (e.g. ``sklearn.svm.SVC``)
        Must expose ``support_vectors_`` with one row per support vector.
        ``dual_coef_`` must be a 2-D array of shape ``(n_class, n_SV)``, either
        dense or already sparse; an already-converted ``clf`` is accepted, so
        calling this twice is safe.

    Returns
    -------
    The same ``clf`` object, modified in place.

    Raises
    ------
    ValueError
        If the number of support vectors disagrees with ``dual_coef_.shape[1]``.
    """
    # Densify first, so an estimator that was already converted can be passed again.
    dense_dual = clf.dual_coef_
    if issparse(dense_dual):
        dense_dual = dense_dual.toarray()
    dense_dual = np.asarray(dense_dual)

    # Read the dimensions from dual_coef_ itself instead of the private
    # ``_label`` attribute, which is not available on every sklearn version.
    n_class, n_SV = dense_dual.shape

    support_vectors = csr_matrix(clf.support_vectors_)
    if support_vectors.shape[0] != n_SV:
        raise ValueError(
            "support_vectors_ has %d rows but dual_coef_ has %d columns"
            % (support_vectors.shape[0], n_SV)
        )

    # Mirrors sklearn's sparse fit:
    # https://github.com/scikit-learn/scikit-learn/blob/14031f65d144e3966113d3daec836e443c6d7a5b/sklearn/svm/base.py#L288-L293
    # Each of the n_class rows stores all n_SV columns explicitly.
    dual_coef_indices = np.tile(np.arange(n_SV), n_class)
    # Integer row pointers 0, n_SV, 2*n_SV, ...  The original used
    # ``size / n_class`` as the arange step, which is a float on Python 3.
    # Multiplying avoids that, and also avoids a zero step when n_SV == 0.
    dual_coef_indptr = np.arange(n_class + 1) * n_SV
    dual_coef = csr_matrix(
        (dense_dual.ravel(), dual_coef_indices, dual_coef_indptr),
        shape=(n_class, n_SV),
    )

    # Assign only after every conversion has succeeded, so a failure above
    # leaves clf untouched.
    clf.support_vectors_ = support_vectors
    clf.dual_coef_ = dual_coef
    clf._sparse = True
    return clf
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.