Skip to content

Instantly share code, notes, and snippets.

@hashlash
Created March 18, 2020 09:46
Show Gist options
  • Save hashlash/c71ae298508427199d8527bc6023a516 to your computer and use it in GitHub Desktop.
Save hashlash/c71ae298508427199d8527bc6023a516 to your computer and use it in GitHub Desktop.
import numpy as np
from sklearn.naive_bayes import _BaseNB
class MultiNB(_BaseNB):
    """Naive Bayes classifier that combines several NB models over column subsets.

    Each entry of ``models_dict`` maps a column selector (anything usable as
    the second index of ``X[:, column]``) to a naive Bayes estimator. Every
    sub-model is trained on its own slice of ``X`` against the same ``y``.
    At prediction time their joint log-likelihoods are summed, with the class
    log-prior counted only once. This lets different feature groups use
    different likelihood models.

    NOTE(review): relies on the private ``sklearn.naive_bayes._BaseNB`` and on
    each sub-model's private ``_joint_log_likelihood``; both may change between
    scikit-learn releases.
    """

    def __init__(self, models_dict):
        # Stored verbatim under the parameter name so BaseEstimator's
        # get_params/set_params can find it.
        self.models_dict = models_dict

    def fit(self, X, y, **kwargs):
        """Fit every sub-model on its column slice of ``X``.

        Parameters
        ----------
        X : array-like supporting ``X[:, column]`` indexing (e.g. a NumPy array).
        y : class labels, shared by all sub-models.
        **kwargs : forwarded unchanged to each sub-model's ``fit``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``models_dict`` is empty.
        """
        if not self.models_dict:
            raise ValueError("models_dict must contain at least one model")
        for column, model in self.models_dict.items():
            model.fit(X[:, column], y, **kwargs)
        # All sub-models were fit on the same y, so their classes_ are expected
        # to agree. _BaseNB.predict maps argmax indices back through
        # self.classes_, so it must be set here.
        self.classes_ = next(iter(self.models_dict.values())).classes_
        self.fitted_ = True
        return self

    def _check_X(self, X):
        """Return ``X`` unchanged before prediction.

        Recent scikit-learn versions declare ``_check_X`` abstract on
        ``_BaseNB`` and call it from the ``predict*`` methods. Without an
        override the class cannot be instantiated there.
        NOTE(review): no validation is done here; confirm whether per-slice
        checks are needed.
        """
        return X

    @staticmethod
    def _class_log_prior(model):
        """Return a fitted sub-model's per-class log prior.

        Discrete NB models expose ``class_log_prior_``. ``GaussianNB`` only
        exposes ``class_prior_``, so fall back to its logarithm.
        """
        log_prior = getattr(model, "class_log_prior_", None)
        if log_prior is None:
            return np.log(model.class_prior_)
        return log_prior

    def _joint_log_likelihood(self, X):
        """Sum the sub-models' joint log-likelihoods, counting the prior once.

        Each sub-model's ``_joint_log_likelihood`` is assumed to return
        ``log P(c) + log P(x_cols | c)``. Subtracting each model's own log
        prior and adding a single shared copy yields
        ``log P(c) + sum over models of log P(x_cols | c)``.

        Returns an array of shape ``(X.shape[0], n_classes)``.
        """
        sample_model = next(iter(self.models_dict.values()))
        jll = np.zeros((X.shape[0], sample_model.class_count_.shape[0]))
        jll += self._class_log_prior(sample_model)
        for column, model in self.models_dict.items():
            jll += model._joint_log_likelihood(X[:, column])
            # Remove this model's prior so it is not counted once per model.
            jll -= self._class_log_prior(model)
        return jll
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment