# Gist larsbratholm/1d83d873827e675a7eab1ee520e4b0ea (last active January 23, 2020 09:25).
# Example of 1D kernel density estimation (KDE) using scipy and sklearn.
#
# GitHub notice: this file may contain bidirectional Unicode text that is
# interpreted or compiled differently from how it appears. To review it, open
# the file in an editor that reveals hidden Unicode characters.
import warnings

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
import sklearn.model_selection
import sklearn.neighbors
def kde_scipy(x, x_grid, bw_method=None):
    """Evaluate a Gaussian KDE of 1D samples using ``scipy.stats.gaussian_kde``.

    The bandwidth comes from a rule-of-thumb heuristic rather than from
    cross-validation. With the default ``bw_method=None``, scipy applies
    Scott's rule.

    Parameters
    ----------
    x : array_like, shape (n_samples,)
        Observed 1D samples that the density is estimated from.
    x_grid : array_like, shape (n_points,)
        Points at which the estimated density is evaluated.
    bw_method : str, scalar or callable, optional
        Forwarded unchanged to ``gaussian_kde``: "scott", "silverman", a
        scalar bandwidth factor, or a callable. ``None`` keeps scipy's
        default (Scott's rule), which matches the previous behavior.

    Returns
    -------
    numpy.ndarray, shape (n_points,)
        Estimated probability density at each point of ``x_grid``.
    """
    kde = scipy.stats.gaussian_kde(x, bw_method=bw_method)
    return kde.evaluate(x_grid)
def kde_sklearn(x, x_grid, bandwidths=None, n_splits=5, random_state=None):
    """Evaluate a Gaussian KDE of 1D samples with a cross-validated bandwidth.

    A grid search over candidate bandwidths fits
    ``sklearn.neighbors.KernelDensity`` (Gaussian kernel by default) under
    shuffled K-fold cross-validation. It keeps the bandwidth with the best
    held-out score, which for ``KernelDensity`` is the log-likelihood.

    Parameters
    ----------
    x : array_like, shape (n_samples,)
        Observed 1D samples that the density is estimated from.
    x_grid : array_like, shape (n_points,)
        Points at which the estimated density is evaluated.
    bandwidths : array_like, optional
        Candidate bandwidths to search. Defaults to 100 log-spaced values
        in [1e-2, 1e2], the original hard-coded grid.
    n_splits : int, optional
        Number of cross-validation folds (default 5).
    random_state : int, numpy.random.RandomState or None, optional
        Seed for the fold shuffling. Pass an int for a reproducible
        bandwidth choice.

    Returns
    -------
    numpy.ndarray, shape (n_points,)
        Estimated probability density at each point of ``x_grid``.

    Warns
    -----
    UserWarning
        If the selected bandwidth is the smallest or largest candidate. The
        true optimum may then lie outside the searched range.
    """
    if bandwidths is None:
        bandwidths = 10 ** np.linspace(-2, 2, 100)
    bandwidths = np.asarray(bandwidths, dtype=float).ravel()

    # sklearn estimators expect 2D (n_samples, n_features) input; reshape
    # also accepts plain lists, unlike x[:, None].
    samples = np.asarray(x, dtype=float).reshape(-1, 1)
    grid = np.asarray(x_grid, dtype=float).reshape(-1, 1)

    params = {"bandwidth": bandwidths}
    # Shuffle before splitting. Without it, KFold uses contiguous blocks,
    # which biases the score when the samples are ordered (e.g. sorted).
    cv = sklearn.model_selection.KFold(
        n_splits=n_splits, shuffle=True, random_state=random_state
    )
    model = sklearn.model_selection.GridSearchCV(
        sklearn.neighbors.KernelDensity(), params, cv=cv
    )
    model.fit(samples)

    # If the optimum is on the edge of the candidate grid, the search range
    # was probably too narrow.
    best = model.best_params_["bandwidth"]
    lo, hi = bandwidths.min(), bandwidths.max()
    if best <= lo or best >= hi:
        warnings.warn(
            f"Best bandwidth ({best:.3f}) is at the edge of the CV search "
            f"range ({lo:.3f} - {hi:.3f}); consider widening the range.",
            stacklevel=2,
        )

    # score_samples returns log-densities; exponentiate to get the density.
    log_pdf = model.score_samples(grid)
    return np.exp(log_pdf)
if __name__ == "__main__":
    # Make fake data: 100 draws from U(0, 1). The generator is seeded so that
    # reruns produce the same figure.
    rng = np.random.default_rng(0)
    x = rng.random(100)

    # Evaluation grid: the data range padded by 10% on each side so the KDE
    # tails beyond the outermost samples are visible.
    x_min, x_max = x.min(), x.max()
    pad = (x_max - x_min) / 10
    x_grid = np.linspace(x_min - pad, x_max + pad, 1000)

    # NOTE: kde_sklearn(x, x_grid) can be plotted on the same axes to compare
    # the CV-selected bandwidth with scipy's heuristic.
    plt.plot(x_grid, kde_scipy(x, x_grid), label="scipy gaussian_kde")
    plt.xlabel("x")
    plt.ylabel("estimated density")
    plt.legend()
    # Without show(), running this file as a script never displays the plot.
    plt.show()
# GitHub page footer: sign up for free to join this conversation on GitHub,
# or sign in to comment if you already have an account.