Classification with naive Bayes
using System;
using System.Collections.Generic;
using System.Linq;

namespace Lokad
{
    /// <summary>
    /// Naive Bayesian classifier.
    /// </summary>
    /// <remarks>
    /// [vermorel] May 2016. As a rule of thumb, a random forest classifier is superior
    /// in every way to the Naive Bayes classifier except for execution speed. There is
    /// no good statistical reason to ever use this classifier. This code is mostly kept
    /// for archival purposes, and maybe for future micro-benchmarks too.
    ///
    /// See the good discussion at:
    /// https://stackoverflow.com/questions/10059594/a-simple-explanation-of-naive-bayes-classification
    /// </remarks>
    public static class NaiveBayes
    {
        /// <summary>
        /// Returns the label classes from the most probable to the least probable.
        /// </summary>
        /// <param name="inputs">Each input is a feature vector of non-negative integers.</param>
        /// <param name="labels">The label of each input, encoded as a non-negative integer.</param>
        /// <param name="unlabeled">The new input vectors to be classified.</param>
        /// <returns>For each unlabeled input, the label classes ordered from the most
        /// probable to the least probable.</returns>
        public static int[][] Classify(int[][] inputs, int[] labels, int[][] unlabeled)
        {
            var N = inputs.Length;          // number of labeled observations
            var n = inputs[0].Length;       // number of features; we assume that all inputs have the same dimension
            var m = labels.Max(l => l) + 1; // number of label classes

            var o = new int[n];       // o[j]: number of distinct values taken by feature 'j'
            var a = new int[m];       // a[k]: count of observations with label 'k'
            var b = new int[n][];     // b[j][v]: count of observations where feature 'j' takes value 'v'
            var bc = new int[n, m][]; // bc[j, k][v]: count of observations with label 'k' where feature 'j' takes value 'v'
            // allocate the counting tables, sized by the cardinality of each feature
            for (int j = 0; j < n; j++)
            {
                o[j] = inputs.Max(f => f[j]) + 1;
                b[j] = new int[o[j]];
                for (var k = 0; k < m; k++)
                {
                    bc[j, k] = new int[o[j]];
                }
            }
            // counting pass over the labeled observations
            for (int i = 0; i < N; i++)
            {
                var labeli = labels[i];
                a[labeli] += 1;

                var fi = inputs[i];
                for (int j = 0; j < n; j++)
                {
                    b[j][fi[j]] += 1;
                    bc[j, labeli][fi[j]] += 1;
                }
            }
            // class priors P(A = k), normalized to sum to 1
            var pa = new double[m];
            var tpa = 0.0;
            for (var k = 0; k < m; k++)
            {
                pa[k] = ConfidenceScore(a[k], N - a[k]);
                tpa += pa[k];
            }
            for (var k = 0; k < m; k++)
            {
                pa[k] /= tpa;
            }
            var pb = new double[n];     // P(B = j)
            var pba = new double[n, m]; // P(B = j | A = k)

            // prediction per se, for each unlabeled input
            var final = new int[unlabeled.Length][];
            for (int i = 0; i < unlabeled.Length; i++)
            {
                var unlabeledi = unlabeled[i];

                var tpb = 0.0;
                for (int j = 0; j < n; j++)
                {
                    // special case: feature value never observed before, skip the feature
                    if (unlabeledi[j] >= b[j].Length) continue;

                    pb[j] = ConfidenceScore(b[j][unlabeledi[j]], N - b[j][unlabeledi[j]]);
                    tpb += pb[j];

                    var tpba = 0.0;
                    for (var k = 0; k < m; k++)
                    {
                        pba[j, k] = ConfidenceScore(bc[j, k][unlabeledi[j]], a[k] - bc[j, k][unlabeledi[j]]);
                        tpba += pba[j, k];
                    }
                    for (var k = 0; k < m; k++)
                    {
                        pba[j, k] /= tpba;
                    }
                }
                for (int j = 0; j < n; j++)
                {
                    pb[j] /= tpb;
                }
                // Generalized Bayes theorem, assuming that B1 .. Bn are independent:
                //
                //                     P(B1 | A) .. P(Bn | A)
                //  P(A | B1 .. Bn) = ------------------------ P(A)
                //                        P(B1) .. P(Bn)
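                //
                // The product is evaluated in log-space to avoid floating-point
                // underflow when the number of features is large.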
                var p = new List<Tuple<int, double>>(m);
                for (var k = 0; k < m; k++)
                {
                    var s = Math.Log(pa[k]);
                    for (int j = 0; j < n; j++)
                    {
                        // special case: feature value never observed before, skip the feature
                        if (unlabeledi[j] >= b[j].Length) continue;
                        s += Math.Log(pba[j, k]) - Math.Log(pb[j]);
                    }
                    p.Add(new Tuple<int, double>(k, Math.Exp(s)));
                }

                final[i] = p.OrderByDescending(tu => tu.Item2).Select(tu => tu.Item1).ToArray();
            }

            return final;
        }
        static double ConfidenceScore(int positive, int negative)
        {
            // When limited observations are available, we should not put too much faith
            // in the raw frequencies. Instead, we factor in the number of observations,
            // as is done for ratings on the web:
            // http://www.evanmiller.org/how-not-to-sort-by-average-rating.html
            // See also the Wilson score:
            // https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval
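            //
            // The magic constants below correspond to z = 1.96, the 97.5th percentile
            // of the standard normal distribution (a 95% two-sided confidence level):
            // z^2 = 3.8416, z^2/2 = 1.9208 and z^2/4 = 0.9604. With t = positive + negative
            // and p = positive / t, the returned value is the Wilson lower bound:
            //
            //   (p + z^2/(2t) - z * sqrt(p*(1-p)/t + z^2/(4t^2))) / (1 + z^2/t)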
            if (positive + negative == 0)
            {
                return 0.001; // HACK: we may have no observations in some conditions
            }
            // lower bound of the 95% Wilson score interval; the (double) cast avoids
            // an integer division inside the square root
            return ((positive + 1.9208) / (positive + negative) -
                    1.96 * Math.Sqrt((double)positive * negative / (positive + negative) + 0.9604) /
                    (positive + negative)) / (1 + 3.8416 / (positive + negative));
        }
    }
}
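
Below is a minimal usage sketch, not part of the original gist: the toy dataset and the Program class are made up for illustration, assuming the file above is compiled into the same project. With this data, label 0 should come out first for the input (0, 0), and label 1 for (1, 1).

using System;
using Lokad;

public static class Program
{
    public static void Main()
    {
        // Hypothetical toy dataset: two features, two label classes, chosen so
        // that every (feature value, label) combination is observed at least once.
        var inputs = new[]
        {
            new[] { 0, 0 }, new[] { 0, 1 }, new[] { 1, 0 }, // label 0
            new[] { 1, 1 }, new[] { 1, 0 }, new[] { 0, 1 }, // label 1
        };
        var labels = new[] { 0, 0, 0, 1, 1, 1 };

        var unlabeled = new[] { new[] { 0, 0 }, new[] { 1, 1 } };

        // For each unlabeled input, the classifier returns the label classes
        // ordered from the most probable to the least probable.
        int[][] ranked = NaiveBayes.Classify(inputs, labels, unlabeled);

        for (int i = 0; i < ranked.Length; i++)
        {
            Console.WriteLine("Unlabeled input {0} -> most probable label: {1}", i, ranked[i][0]);
        }
    }
}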