-
-
Save Tachyon5/a0b89df5f9293e09564c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Bayesian Generative Classifier | |
------------------------------ | |
""" | |
# Author: Jake Vanderplas <[email protected]> | |
import numpy as np

from sklearn.base import BaseEstimator, clone
from sklearn.mixture import GMM
from sklearn.naive_bayes import BaseNB
from sklearn.neighbors.kde import KernelDensity
from sklearn.utils import array2d, check_random_state
class NormalApproximation(BaseEstimator):
    """Normal Approximation Density Estimator

    Models the data as a multivariate normal distribution with a diagonal
    covariance: each feature gets its own mean and variance, and features
    are treated as independent.

    Attributes
    ----------
    mean : ndarray, shape (n_features,)
        Per-feature mean of the training data, set by ``fit``.
    var : ndarray, shape (n_features,)
        Per-feature variance of the training data plus a floor of 1e-9,
        set by ``fit``.
    """

    def __init__(self):
        pass

    def fit(self, X):
        """Fit the Normal Approximation to data

        Parameters
        ----------
        X : array_like, shape (n_samples, n_features)
            List of n_features-dimensional data points. Each row
            corresponds to a single data point. A 1-D input is treated
            as a single data point.

        Returns
        -------
        self : NormalApproximation
            The fitted estimator.

        Raises
        ------
        ValueError
            If X is empty.
        """
        X = np.atleast_2d(X)
        if X.size == 0:
            raise ValueError("cannot fit NormalApproximation to empty data")
        # Variance floor: keeps a constant feature from producing a zero
        # variance, which would divide by zero in ``eval``.
        epsilon = 1e-9
        self.mean = X.mean(0)
        self.var = X.var(0) + epsilon
        return self

    def eval(self, X):
        """Evaluate the log density of the model at the query points

        Parameters
        ----------
        X : array_like
            An array of points to query. Last dimension should match
            dimension of training data (n_features). A 1-D input is
            treated as a single point.

        Returns
        -------
        log_density : ndarray
            Natural log of the normal density at each point. This has
            shape X.shape[:-1] (i.e. (n_samples,) for 2-D input).

        Raises
        ------
        ValueError
            If the last dimension of X differs from the training data.
        """
        X = np.atleast_2d(X)
        if X.shape[-1] != self.mean.shape[0]:
            raise ValueError("dimension of X must match that of training data")
        n_features = X.shape[-1]
        # log N(x | mean, diag(var))
        #   = -0.5 * (d*log(2*pi) + sum(log(var)) + sum((x - mean)**2 / var))
        # Evaluated directly in log space: exp() of the quadratic term
        # underflows to 0 far from the mean, and log(0) would give -inf.
        log_norm = -0.5 * (n_features * np.log(2 * np.pi)
                           + np.sum(np.log(self.var)))
        sq_dist = ((X - self.mean) ** 2 / self.var).sum(-1)
        return log_norm - 0.5 * sq_dist

    def score(self, X):
        """Compute the total log-likelihood of X under the model.

        Parameters
        ----------
        X : array_like, shape (n_samples, n_features)
            List of n_features-dimensional data points. Each row
            corresponds to a single data point.

        Returns
        -------
        logprob : float
            Sum over all data points of the log density.
        """
        # ``eval`` already returns log densities, so they are summed as-is.
        return np.sum(self.eval(X))

    def sample(self, n_samples=1, random_state=None):
        """Generate random samples from the model.

        Parameters
        ----------
        n_samples : int or tuple of ints, optional
            Number of samples to generate. Defaults to 1. A tuple gives
            the leading shape of the output.
        random_state : None, int or RandomState, optional
            Seed or random number generator, passed to
            ``check_random_state``. Defaults to None.

        Returns
        -------
        X : ndarray, shape (n_samples, n_features)
            Samples drawn from the fitted normal distribution.
        """
        rng = check_random_state(random_state)
        # Output shape is the requested leading shape followed by the
        # feature dimension, so loc/scale broadcast along the last axis.
        try:
            shape = tuple(n_samples) + self.mean.shape
        except TypeError:
            shape = (n_samples,) + self.mean.shape
        return rng.normal(self.mean, np.sqrt(self.var), size=shape)
# Registry of the string names accepted by GenerativeBayes, each mapped to
# the density-estimator class it constructs.
DENSITY_ESTIMATORS = dict(
    norm_approx=NormalApproximation,
    gmm=GMM,
    kde=KernelDensity,
)
class GenerativeBayes(BaseNB):
    """Generative Bayes Classifier

    Fits one density estimator per class and scores a point x for class c
    with the joint log-likelihood ``log P(c) + log p(x | c)``, where P(c)
    is the frequency of c in the training labels and p(x | c) is the
    density estimated from the training points of class c.

    Parameters
    ----------
    density_estimator : str, class or estimator instance
        A key of ``DENSITY_ESTIMATORS`` (e.g. 'norm_approx', 'gmm',
        'kde'), an estimator class, or an estimator instance. It is
        cloned once per class in ``fit``.
    **kwargs
        Constructor arguments used when ``density_estimator`` is a string
        or a class; ignored when it is an instance.
    """
    # note: interface is essentially the same as that of GaussianNB,
    # and if density_estimator is `NormalApproximation`, it should
    # give the same results.

    def __init__(self, density_estimator, **kwargs):
        if isinstance(density_estimator, str):
            dclass = DENSITY_ESTIMATORS.get(density_estimator)
            if dclass is None:
                # Without this check a lookup miss surfaces as an opaque
                # "'NoneType' object is not callable" TypeError.
                raise ValueError(
                    "unknown density_estimator %r; valid names are %s"
                    % (density_estimator, sorted(DENSITY_ESTIMATORS)))
            self.density_estimator = dclass(**kwargs)
        elif isinstance(density_estimator, type):
            self.density_estimator = density_estimator(**kwargs)
        else:
            self.density_estimator = density_estimator

    def fit(self, X, y):
        """Fit one clone of the density estimator per class.

        Parameters
        ----------
        X : array_like, shape (n_samples, n_features)
            Training data.
        y : array_like, shape (n_samples,)
            Class label of each row of X.

        Returns
        -------
        self : GenerativeBayes
            The fitted classifier.

        Raises
        ------
        ValueError
            If X is empty or X and y have different numbers of samples.
        """
        X = np.atleast_2d(X)
        y = np.asarray(y)
        n_samples = X.shape[0]
        if n_samples == 0:
            raise ValueError("cannot fit GenerativeBayes to empty data")
        if y.shape[0] != n_samples:
            raise ValueError("X and y have inconsistent numbers of samples: "
                             "%d != %d" % (n_samples, y.shape[0]))
        # np.unique returns the labels already sorted, with their counts.
        self.classes_, counts = np.unique(y, return_counts=True)
        self.class_prior_ = counts / float(n_samples)
        self.estimators_ = [clone(self.density_estimator).fit(X[y == c])
                            for c in self.classes_]
        return self

    @staticmethod
    def _log_density(estimator, X):
        """Return the per-sample log density of X under a fitted estimator."""
        # Prefer scikit-learn's ``score_samples``; fall back to ``eval``
        # for estimators (such as NormalApproximation) that only have it.
        score_samples = getattr(estimator, 'score_samples', None)
        if score_samples is not None:
            logprob = score_samples(X)
        else:
            logprob = estimator.eval(X)
        # The legacy sklearn GMM returns (logprob, responsibilities).
        if isinstance(logprob, tuple):
            logprob = logprob[0]
        return np.asarray(logprob)

    def _joint_log_likelihood(self, X):
        """Return log P(c) + log p(x | c), shape (n_samples, n_classes)."""
        X = np.atleast_2d(X)
        jll = np.array([np.log(prior) + self._log_density(dens, X)
                        for prior, dens in zip(self.class_prior_,
                                               self.estimators_)]).T
        return jll
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment