Created
October 15, 2012 20:48
-
-
Save matchu/3895352 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from autograder import * | |
import samples | |
import recognizer | |
class Homework3(ProblemSet): | |
class Problem1(ProblemSet): | |
class Problem1a(Problem): | |
""" Test for train_bayes.""" | |
def test_train_bayes(self): | |
train_bayes = self.subject.train_bayes | |
im1 = samples.dImage([['+','+','+','+'],[' ','#','+','#'],['#','+','#',' '],['+','#',' ',' ']],4,4) | |
im2 = samples.dImage([['+','+','+','+'],[' ',' ','#','+'],[' ','#','+','#'],['#','+','#',' ']],4,4) | |
im3 = samples.dImage([['+','+','+',' '],[' ','#','+','#'],['#','+','#',' '],['+','#',' ',' ']],4,4) | |
im4 = samples.dImage([['#','+','#',' '],['#','+','#',' '],['#','+','#',' '],['#','+','#',' ']],4,4) | |
im5 = samples.dImage([[' ','#','+','#'],[' ','#','+','#'],[' ','#','+','#'],[' ','#','+','#']],4,4) | |
im6 = samples.dImage([[' ','#','+','#'],[' ','#','+','#'],['#','+','#',' '],['#','+','#',' ']],4,4) | |
training_images = [recognizer.basic_feature_extractor(x,4,4) for x in [im1,im2,im3,im4,im5,im6]] | |
training_labels = [7,7,7,1,1,1] | |
legal_labels = [1,7] | |
correct_priors = {1:0.5, 7:0.5} | |
correct_cprob = {(0, 1): {1: {0: 0.5, 1: 0.16666666666666666, 2: 0.3333333333333333}, | |
7: {0: 0.6666666666666666, 1: 0.16666666666666666, 2: 0.16666666666666666}}, | |
(1, 2): {1: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}, | |
7: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}}, | |
(3, 2): {1: {0: 0.5, 1: 0.16666666666666666, 2: 0.3333333333333333}, | |
7: {0: 0.5, 1: 0.16666666666666666, 2: 0.3333333333333333}}, | |
(0, 0): {1: {0: 0.5, 1: 0.16666666666666666, 2: 0.3333333333333333}, | |
7: {0: 0.16666666666666666, 1: 0.6666666666666666, 2: 0.16666666666666666}}, | |
(3, 3): {1: {0: 0.5, 1: 0.16666666666666666, 2: 0.3333333333333333}, | |
7: {0: 0.6666666666666666, 1: 0.16666666666666666, 2: 0.16666666666666666}}, | |
(3, 0): {1: {0: 0.3333333333333333, 1: 0.16666666666666666, 2: 0.5}, | |
7: {0: 0.3333333333333333, 1: 0.5, 2: 0.16666666666666666}}, | |
(3, 1): {1: {0: 0.3333333333333333, 1: 0.16666666666666666, 2: 0.5}, | |
7: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}}, | |
(2, 1): {1: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}, | |
7: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}}, | |
(1, 1): {1: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}, | |
7: {0: 0.3333333333333333, 1: 0.16666666666666666, 2: 0.5}}, | |
(2, 0): {1: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}, | |
7: {0: 0.16666666666666666, 1: 0.6666666666666666, 2: 0.16666666666666666}}, | |
(1, 3): {1: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}, | |
7: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}}, | |
(2, 3): {1: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}, | |
7: {0: 0.5, 1: 0.16666666666666666, 2: 0.3333333333333333}}, | |
(2, 2): {1: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}, | |
7: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}}, | |
(1, 0): {1: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}, | |
7: {0: 0.16666666666666666, 1: 0.6666666666666666, 2: 0.16666666666666666}}, | |
(0, 3): {1: {0: 0.3333333333333333, 1: 0.16666666666666666, 2: 0.5}, | |
7: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}}, | |
(0, 2): {1: {0: 0.3333333333333333, 1: 0.16666666666666666, 2: 0.5}, | |
7: {0: 0.3333333333333333, 1: 0.16666666666666666, 2: 0.5}}} | |
self.assertCall(train_bayes,(training_images,training_labels,legal_labels,1),(correct_priors,correct_cprob)) | |
class Problem1b(Problem): | |
""" Tests for calculate_log_posterior_probabilities.""" | |
def test_calculate_log_posterior_probabilities(self): | |
calculate_log_posterior_probabilities = self.subject.calculate_log_posterior_probabilities | |
priors = {1:0.5, 7:0.5} | |
cprob = {(0, 1): {1: {0: 0.5, 1: 0.16666666666666666, 2: 0.3333333333333333}, | |
7: {0: 0.6666666666666666, 1: 0.16666666666666666, 2: 0.16666666666666666}}, | |
(1, 2): {1: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}, | |
7: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}}, | |
(3, 2): {1: {0: 0.5, 1: 0.16666666666666666, 2: 0.3333333333333333}, | |
7: {0: 0.5, 1: 0.16666666666666666, 2: 0.3333333333333333}}, | |
(0, 0): {1: {0: 0.5, 1: 0.16666666666666666, 2: 0.3333333333333333}, | |
7: {0: 0.16666666666666666, 1: 0.6666666666666666, 2: 0.16666666666666666}}, | |
(3, 3): {1: {0: 0.5, 1: 0.16666666666666666, 2: 0.3333333333333333}, | |
7: {0: 0.6666666666666666, 1: 0.16666666666666666, 2: 0.16666666666666666}}, | |
(3, 0): {1: {0: 0.3333333333333333, 1: 0.16666666666666666, 2: 0.5}, | |
7: {0: 0.3333333333333333, 1: 0.5, 2: 0.16666666666666666}}, | |
(3, 1): {1: {0: 0.3333333333333333, 1: 0.16666666666666666, 2: 0.5}, | |
7: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}}, | |
(2, 1): {1: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}, | |
7: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}}, | |
(1, 1): {1: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}, | |
7: {0: 0.3333333333333333, 1: 0.16666666666666666, 2: 0.5}}, | |
(2, 0): {1: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}, | |
7: {0: 0.16666666666666666, 1: 0.6666666666666666, 2: 0.16666666666666666}}, | |
(1, 3): {1: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}, | |
7: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}}, | |
(2, 3): {1: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}, | |
7: {0: 0.5, 1: 0.16666666666666666, 2: 0.3333333333333333}}, | |
(2, 2): {1: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}, | |
7: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}}, | |
(1, 0): {1: {0: 0.16666666666666666, 1: 0.3333333333333333, 2: 0.5}, | |
7: {0: 0.16666666666666666, 1: 0.6666666666666666, 2: 0.16666666666666666}}, | |
(0, 3): {1: {0: 0.3333333333333333, 1: 0.16666666666666666, 2: 0.5}, | |
7: {0: 0.16666666666666666, 1: 0.5, 2: 0.3333333333333333}}, | |
(0, 2): {1: {0: 0.3333333333333333, 1: 0.16666666666666666, 2: 0.5}, | |
7: {0: 0.3333333333333333, 1: 0.16666666666666666, 2: 0.5}}} | |
im7 = samples.dImage([['+','+','+','+'],[' ',' ','#','+'],[' ',' ','#','+'],[' ',' ','#','+']],4,4) | |
test_image = recognizer.basic_feature_extractor(im7,4,4) | |
test_labels = [7] | |
legal_labels = [1,7] | |
correct_posterior = {1: -22.1942608112966, 7: -18.153160763593316} | |
given_posterior = calculate_log_posterior_probabilities(test_image,legal_labels,priors,cprob) | |
self.assertWithin(given_posterior[1], -22.194261, -22.194260, "posterior[1]") | |
self.assertWithin(given_posterior[7], -18.153161, -18.153160, "posterior[7]") | |
def assertWithin(self, given, expected_min, expected_max, name): | |
msg = "expected {0} to be within {1} and {2}, but instead was {3}".format( | |
name, expected_min, expected_max, given) | |
self.assertTrue(expected_min < given < expected_max, msg) | |
if __name__ == '__main__': | |
main(Homework3) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment