Last active
November 4, 2016 07:46
-
-
Save MichaelSnowden/b56fb04401f2705a0b38098aca372e41 to your computer and use it in GitHub Desktop.
Naive Bayes classifier for MNIST digits. Download the dataset files from http://yann.lecun.com/exdb/mnist/ and place them, uncompressed, in the working directory before running this.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from operator import add, mul | |
| from functools import reduce | |
| from math import log | |
# Image dimensions used when rendering a sample and when sizing the model.
NUM_ROWS = 28
NUM_COLS = 28
# One binary feature per pixel; derived from the dimensions so the three
# constants cannot drift apart.
NUM_PIXELS = NUM_ROWS * NUM_COLS
def display(x_X, num_rows=None, num_cols=None):
    """Print a binarized image to stdout as block characters.

    Args:
        x_X: Flat, row-major sequence of pixel values. A pixel equal to 1 is
            drawn as a full block (U+2588); anything else as a space.
        num_rows: Number of rows to print. Defaults to ``NUM_ROWS``.
        num_cols: Number of pixels per row. Defaults to ``NUM_COLS``.

    Raises:
        IndexError: If ``x_X`` holds fewer than ``num_rows * num_cols`` items.
    """
    # Resolve defaults at call time so the module constants are only needed
    # when the caller does not supply explicit dimensions.
    if num_rows is None:
        num_rows = NUM_ROWS
    if num_cols is None:
        num_cols = NUM_COLS
    for i in range(num_rows):
        # Row-major layout: row i starts at i * (pixels per row), i.e. the
        # stride is the column count, not the row count.
        start = i * num_cols
        print(''.join(
            u'\u2588' if x_X[start + j] == 1 else ' '
            for j in range(num_cols)
        ))
def test_naive_bayes(predict, test_data, num_test_samples, label_domain_size):
    """Evaluate a classifier on labelled samples and report its error rate.

    Every misclassified sample is rendered with ``display`` together with the
    actual and predicted labels, then the overall error rate is printed.

    Args:
        predict: Callable mapping a feature vector to a predicted label.
        test_data: Iterable of ``(x_X, y)`` pairs (features, true label).
        num_test_samples: Denominator used for the error rate.
        label_domain_size: Unused; kept so existing callers keep working.

    Returns:
        The error rate ``errors / num_test_samples``, or ``0.0`` when
        ``num_test_samples`` is 0 (instead of raising ZeroDivisionError).
    """
    error = 0
    for x_X, y in test_data:
        y_hat = predict(x_X)
        if y != y_hat:
            display(x_X)
            print("actual =", y, "guess =", y_hat)
            error += 1
    # Guard the empty evaluation set: there is nothing to get wrong.
    error_rate = error / num_test_samples if num_test_samples else 0.0
    print("error rate =", error_rate)
    return error_rate
if __name__ == '__main__':
    import sys
    from struct import unpack

    from naive_bayes import train

    # Magic numbers that open IDX image (3-D ubyte) and label (1-D ubyte) files.
    IMAGE_MAGIC = 0x00000803
    LABEL_MAGIC = 0x00000801

    def read_32_int(file):
        """Read one big-endian unsigned 32-bit integer from *file*."""
        return unpack('>I', file.read(4))[0]

    def read_8_int(file):
        """Read one unsigned byte from *file* as an int."""
        return int(unpack('>B', file.read(1))[0])

    def read(file_name, t):
        """Lazily yield ``(pixels, label)`` for the first *t* samples.

        Opens ``<file_name>-images.idx3-ubyte`` and
        ``<file_name>-labels.idx1-ubyte``. Each image is binarized: a pixel
        below 128 becomes 0, otherwise 1.

        Raises:
            ValueError: If a header is not the expected IDX header, the image
                size differs from ``NUM_PIXELS``, or fewer than *t* samples
                are available.
        """
        image_path = file_name + '-images.idx3-ubyte'
        label_path = file_name + '-labels.idx1-ubyte'
        with open(image_path, 'rb') as image_file, \
                open(label_path, 'rb') as label_file:
            # Image header (16 bytes): magic, image count, rows, columns.
            image_magic = read_32_int(image_file)
            num_images = read_32_int(image_file)
            num_rows = read_32_int(image_file)
            num_cols = read_32_int(image_file)
            # Label header (8 bytes): magic, label count.
            label_magic = read_32_int(label_file)
            num_labels = read_32_int(label_file)

            if image_magic != IMAGE_MAGIC:
                raise ValueError(image_path + ' is not an IDX image file')
            if label_magic != LABEL_MAGIC:
                raise ValueError(label_path + ' is not an IDX label file')
            if num_rows * num_cols != NUM_PIXELS:
                raise ValueError(
                    'expected %d pixels per image, file has %dx%d'
                    % (NUM_PIXELS, num_rows, num_cols))
            if t > min(num_images, num_labels):
                raise ValueError(
                    'requested %d samples but only %d are available'
                    % (t, min(num_images, num_labels)))

            for _ in range(t):
                # One read per image instead of one read per pixel.
                raw = image_file.read(NUM_PIXELS)
                if len(raw) != NUM_PIXELS:
                    raise ValueError(image_path + ' is truncated')
                x_X = [0 if pixel < 128 else 1 for pixel in raw]
                y = read_8_int(label_file)
                yield x_X, y

    # Optional overrides: argv[1] = training samples, argv[2] = test samples.
    num_training_samples = int(sys.argv[1]) if len(sys.argv) > 1 else 60000
    num_test_samples = int(sys.argv[2]) if len(sys.argv) > 2 else 200

    training_data = read('train', num_training_samples)
    test_data = read('t10k', num_test_samples)

    # Features are binary pixels (2 values); labels are the 10 digits.
    predict = train(training_data, NUM_PIXELS, 2, 10)
    test_naive_bayes(predict, test_data, num_test_samples, 10)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(training_data, X_size, x_size, y_size):
    """Fit a categorical Naive Bayes model and return its prediction function.

    Args:
        training_data: Iterable of ``(x_X, y)`` pairs, consumed once. ``x_X``
            is a sequence of ``X_size`` feature values, each in
            ``range(x_size)``; ``y`` is a label in ``range(y_size)``.
        X_size: Number of features per sample.
        x_size: Number of distinct values a feature can take.
        y_size: Number of distinct labels.

    Returns:
        A function ``predict(x_X)`` returning the label with the highest
        log-posterior; ties resolve to the smallest label.
    """
    from math import log

    # The class prior is uniform: every label gets probability 1 / y_size.
    log_P_y = [-log(y_size) for _ in range(y_size)]

    # N_Xxy[X][x][y]: samples of class y whose feature X has value x.
    N_Xxy = [[[0] * y_size for _ in range(x_size)] for _ in range(X_size)]
    # N_y[y]: samples of class y.
    N_y = [0] * y_size
    for x_X, y in training_data:
        N_y[y] += 1
        for X, x in enumerate(x_X):
            N_Xxy[X][x][y] += 1

    # Add-one (Laplace) smoothing. The denominator adds x_size, one pseudo
    # count per feature value, so the probabilities of a feature sum to 1 for
    # each class. Adding only 1 would give a class with no training samples
    # likelihood 1 for every feature, making it beat every seen class.
    log_P_Xxy = [
        [
            [
                log(N_Xxy[X][x][y] + 1) - log(N_y[y] + x_size)
                for y in range(y_size)
            ]
            for x in range(x_size)
        ]
        for X in range(X_size)
    ]

    def predict(x_X):
        """Return the most probable label for feature vector ``x_X``."""
        def log_posterior(y):
            return log_P_y[y] + sum(
                log_P_Xxy[X][x][y] for X, x in enumerate(x_X)
            )
        return max(range(y_size), key=log_posterior)

    return predict
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment