Created April 10, 2018 22:49
Save and load a Scikit-learn GMM to file using np.save
import numpy as np
from sklearn import mixture

# make a GMM (example hyperparameters; pick whatever suits your data)
n_components = 4
cv_type = 'full'
gmm = mixture.GaussianMixture(n_components=n_components, covariance_type=cv_type)

# fit the GMM however you'd like, e.g. gmm.fit(X) on your training data

# save the fitted parameters to plain .npy files (no pickling needed)
gmm_name = 'new_gmm'
np.save(gmm_name + '_weights', gmm.weights_, allow_pickle=False)
np.save(gmm_name + '_means', gmm.means_, allow_pickle=False)
np.save(gmm_name + '_covariances', gmm.covariances_, allow_pickle=False)

# reload: rebuild an unfitted GMM and restore its parameters
means = np.load(gmm_name + '_means.npy')
covar = np.load(gmm_name + '_covariances.npy')
loaded_gmm = mixture.GaussianMixture(n_components=len(means), covariance_type='full')
# predict() relies on precisions_cholesky_, so recompute it from the covariances
loaded_gmm.precisions_cholesky_ = np.linalg.cholesky(np.linalg.inv(covar))
loaded_gmm.weights_ = np.load(gmm_name + '_weights.npy')
loaded_gmm.means_ = means
loaded_gmm.covariances_ = covar

# compare predictions on your data (y) between the original and loaded GMMs
cats = gmm.predict(y)
cats2 = loaded_gmm.predict(y)
print(all(cats == cats2))  # or assert
Thanks for this! : )
Great finding for me
Thx for your code
Thanks Man👍🏼
In fact, you don't even need to re-calculate precisions_cholesky_; you can save and load it too, e.g.
# save
np.save(gmm_name + '_precisions_cholesky', gmm.precisions_cholesky_, allow_pickle=False)
# load
loaded_gmm.precisions_cholesky_ = np.load(gmm_name + '_precisions_cholesky.npy')
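Putting the original snippet and this tip together, a rough sketch of reusable helpers might look like the following (the save_gmm / load_gmm names are just illustrative, not part of the gist):

import numpy as np
from sklearn import mixture

def save_gmm(gmm, name):
    # dump every fitted attribute the loaded model will need, as plain .npy files
    np.save(name + '_weights', gmm.weights_, allow_pickle=False)
    np.save(name + '_means', gmm.means_, allow_pickle=False)
    np.save(name + '_covariances', gmm.covariances_, allow_pickle=False)
    np.save(name + '_precisions_cholesky', gmm.precisions_cholesky_, allow_pickle=False)

def load_gmm(name, covariance_type='full'):
    # rebuild an unfitted GaussianMixture and restore its fitted attributes
    means = np.load(name + '_means.npy')
    gmm = mixture.GaussianMixture(n_components=len(means), covariance_type=covariance_type)
    gmm.means_ = means
    gmm.weights_ = np.load(name + '_weights.npy')
    gmm.covariances_ = np.load(name + '_covariances.npy')
    gmm.precisions_cholesky_ = np.load(name + '_precisions_cholesky.npy')
    return gmm

For example, call save_gmm(gmm, 'new_gmm') after fitting in one script, then loaded_gmm = load_gmm('new_gmm') in another; loaded_gmm.predict(...) should give the same results as the original model.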
Fantastic!
I was looking everywhere and you are the only one who shows how to save a GMM so it can be used somewhere else...
Congrats!