Skip to content

Instantly share code, notes, and snippets.

@vermorel
Last active April 28, 2017 13:30
Show Gist options
Save vermorel/799f6d9450887ac4edbcbc0cd7b0ba17 to your computer and use it in GitHub Desktop.
Classification with naive Bayes
using System;
using System.Collections.Generic;
using System.Linq;
namespace Lokad
{
    /// <summary>
    /// Naive Bayesian classifier.
    /// </summary>
    /// <remarks>
    /// [vermorel] May 2016. As a rule of thumb, a random forest classifier is superior
    /// in every way to the Naive Bayes classifier except for execution speed. There is
    /// no good statistical reason to ever use this classifier. This code is mostly kept
    /// for archival purposes, and maybe for future micro-benchmarks too.
    ///
    /// See the good discussion at:
    /// https://stackoverflow.com/questions/10059594/a-simple-explanation-of-naive-bayes-classification
    /// </remarks>
    public static class NaiveBayes
    {
        /// <summary>
        /// For each unlabeled input, returns the label classes ordered from the most
        /// probable to the least probable.
        /// </summary>
        /// <param name="inputs">Labeled training inputs. Each input is a feature vector of
        /// non-negative categorical values. <c>inputs[0]</c> defines the dimension, and every
        /// other input must have at least that many features.</param>
        /// <param name="labels">The label of each input: <c>labels[i]</c> labels <c>inputs[i]</c>.
        /// Labels are non-negative class indices.</param>
        /// <param name="unlabeled">New input vectors to be classified, with the same dimension
        /// as the training inputs. Feature values never seen during training are ignored.</param>
        /// <returns>One array per unlabeled input. Each array holds every class index from 0
        /// to the largest label, sorted by decreasing posterior score. Ties keep ascending
        /// class order.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="inputs"/>,
        /// <paramref name="labels"/> or <paramref name="unlabeled"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="inputs"/> is empty,
        /// <paramref name="labels"/> is shorter than <paramref name="inputs"/>, or a feature
        /// vector is null or too short.</exception>
        /// <exception cref="ArgumentOutOfRangeException">A label or a training feature value
        /// is negative.</exception>
        public static int[][] Classify(int[][] inputs, int[] labels, int[][] unlabeled)
        {
            if (inputs == null)
            {
                throw new ArgumentNullException("inputs");
            }
            if (labels == null)
            {
                throw new ArgumentNullException("labels");
            }
            if (unlabeled == null)
            {
                throw new ArgumentNullException("unlabeled");
            }
            if (inputs.Length == 0)
            {
                throw new ArgumentException("At least one labeled input is required.", "inputs");
            }
            if (labels.Length < inputs.Length)
            {
                throw new ArgumentException("Each input must have its own label.", "labels");
            }
            if (inputs[0] == null)
            {
                throw new ArgumentException("Input vectors must not be null.", "inputs");
            }

            var sampleCount = inputs.Length;
            var featureCount = inputs[0].Length; // inputs[0] defines the dimension

            // Validate everything up front, so the counting below can index arrays safely.
            for (int i = 0; i < sampleCount; i++)
            {
                var row = inputs[i];
                if (row == null || row.Length < featureCount)
                {
                    throw new ArgumentException(
                        "Every input must have at least as many features as inputs[0].", "inputs");
                }
                if (labels[i] < 0)
                {
                    throw new ArgumentOutOfRangeException("labels", "Labels must be non-negative.");
                }
                for (int j = 0; j < featureCount; j++)
                {
                    if (row[j] < 0)
                    {
                        throw new ArgumentOutOfRangeException("inputs", "Feature values must be non-negative.");
                    }
                }
            }

            var classCount = labels.Max() + 1;

            // Training counts:
            //   classCounts[k]       number of inputs labeled k
            //   valueCounts[j][v]    number of inputs whose feature j equals v
            //   jointCounts[j, k][v] number of inputs labeled k whose feature j equals v
            var classCounts = new int[classCount];
            var valueCounts = new int[featureCount][];
            var jointCounts = new int[featureCount, classCount][];
            for (int j = 0; j < featureCount; j++)
            {
                var arity = inputs.Max(f => f[j]) + 1;
                valueCounts[j] = new int[arity];
                for (int k = 0; k < classCount; k++)
                {
                    jointCounts[j, k] = new int[arity];
                }
            }
            for (int i = 0; i < sampleCount; i++)
            {
                var label = labels[i];
                var row = inputs[i];
                classCounts[label]++;
                for (int j = 0; j < featureCount; j++)
                {
                    valueCounts[j][row[j]]++;
                    jointCounts[j, label][row[j]]++;
                }
            }

            // log P(A = k): pessimistic class frequencies, normalized to sum to one.
            var priors = new double[classCount];
            var priorTotal = 0.0;
            for (int k = 0; k < classCount; k++)
            {
                priors[k] = ConfidenceScore(classCounts[k], sampleCount - classCounts[k]);
                priorTotal += priors[k];
            }
            var logPriors = new double[classCount];
            for (int k = 0; k < classCount; k++)
            {
                logPriors[k] = Math.Log(priors[k] / priorTotal);
            }

            // Prediction per se, for each unlabeled input.
            var final = new int[unlabeled.Length][];
            for (int i = 0; i < unlabeled.Length; i++)
            {
                final[i] = RankClasses(unlabeled[i], logPriors, classCounts, valueCounts, jointCounts);
            }
            return final;
        }

        /// <summary>
        /// Scores every class for a single feature vector and returns the class indices
        /// sorted by decreasing score (stable: ties keep ascending class order).
        /// </summary>
        static int[] RankClasses(int[] x, double[] logPriors, int[] classCounts,
            int[][] valueCounts, int[,][] jointCounts)
        {
            var classCount = logPriors.Length;
            var featureCount = valueCounts.Length;
            if (x == null || x.Length < featureCount)
            {
                throw new ArgumentException(
                    "Each unlabeled input must have as many features as the labeled inputs.", "unlabeled");
            }

            // Generalized Bayes theorem, assuming that B1 .. Bn are independent given A:
            //
            //                     P(B1 | A) .. P(Bn | A)
            //   P(A | B1 .. Bn) = ---------------------- P(A)
            //                       P(B1) .. P(Bn)
            //
            // The denominator is identical for every class, so it cannot change the ranking
            // and is omitted. Scores stay in log space. Exponentiating them would underflow or
            // overflow a double once there are many features, which turns distinct classes
            // into ties.
            var scores = (double[])logPriors.Clone();
            var likelihoods = new double[classCount];
            for (int j = 0; j < featureCount; j++)
            {
                var v = x[j];
                // Special case: a value never observed for this feature during training
                // (out of range, or in range but absent) carries no information. Skip the feature.
                if (v < 0 || v >= valueCounts[j].Length || valueCounts[j][v] == 0)
                {
                    continue;
                }

                // P(B = v | A = k), normalized over the classes. The total is positive because
                // at least one class has a non-zero joint count for an observed value.
                var total = 0.0;
                for (int k = 0; k < classCount; k++)
                {
                    var joint = jointCounts[j, k][v];
                    likelihoods[k] = ConfidenceScore(joint, classCounts[k] - joint);
                    total += likelihoods[k];
                }
                for (int k = 0; k < classCount; k++)
                {
                    // A zero likelihood gives -Infinity, which ranks the class last.
                    scores[k] += Math.Log(likelihoods[k] / total);
                }
            }

            return Enumerable.Range(0, classCount).OrderByDescending(k => scores[k]).ToArray();
        }

        /// <summary>
        /// Pessimistic estimate of the proportion <c>positive / (positive + negative)</c>:
        /// the lower bound of the 95% Wilson score interval, clamped to be non-negative.
        /// Returns 0.001 when there are no observations at all.
        /// </summary>
        static double ConfidenceScore(int positive, int negative)
        {
            // When limited observations are available, we should not put too much faith in the observations.
            // Instead, we factor the number of observations, as it could be done for ratings on the web:
            // http://www.evanmiller.org/how-not-to-sort-by-average-rating.html
            // See also, the Wilson score https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval
            //
            // Everything is computed in floating point. positive * negative can overflow an int,
            // and integer division would truncate the variance term of the interval.
            double total = (double)positive + negative;
            if (total == 0)
            {
                return 0.001; // HACK: we may have no observations in some conditions
            }

            // Lower bound on the 95% confidence interval, with z = 1.96:
            // 1.9208 = z^2 / 2, 0.9604 = z^2 / 4, 3.8416 = z^2.
            var lower = ((positive + 1.9208) / total
                - 1.96 * Math.Sqrt(positive * (double)negative / total + 0.9604) / total)
                / (1 + 3.8416 / total);

            // With zero positives the exact bound is 0, but rounding can make it slightly
            // negative. Math.Log would then return NaN downstream.
            return Math.Max(0.0, lower);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment