Last active
August 29, 2015 13:58
-
-
Save LFY/10175240 to your computer and use it in GitHub Desktop.
Contrastive divergence
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from math import * | |
from random import * | |
# Basic MCMC-------------------------------------------------------------------- | |
def make_kernel(propscore):
    """Bind a combined proposal/scoring function into an MH transition kernel.

    propscore: state -> (proposed_state, proposed_score)
    Returns a kernel (state, score) -> (next_state, next_score).
    """
    def kernel(state, score):
        # Delegate the accept/reject logic to the generic MH step.
        return mh_step(state, score, propscore)
    return kernel
def make_propscore(prop, scorer):
    """Combine a proposal function and a scorer into one propose-and-score step.

    prop:   state -> proposed next state
    scorer: state -> log-score of a state

    Returns a function state -> (next_state, next_score), where next_score
    is the log-score of the *proposed* state, as required by the acceptance
    test in mh_step.
    """
    def call(state):
        next_state = prop(state)
        # Bug fix: score the proposed state, not the current one.
        # The original scored `state`, so lp_new always equaled lp_old in
        # mh_step and every proposal was accepted unconditionally.
        next_score = scorer(next_state)
        return next_state, next_score
    return call
def log_flip(lp):
    """Flip a coin that lands True with probability exp(lp).

    lp is a log-probability; differences of log-scores are passed in by
    the MH acceptance step.
    """
    draw = uniform(0, 1)
    return draw < exp(lp)
def acc_step(curr, prop, lp_new, lp_old):
    """Metropolis accept/reject: keep `prop` with probability
    exp(lp_new - lp_old), otherwise keep `curr`."""
    if log_flip(lp_new - lp_old):
        return prop
    return curr
def mh_step(state, score, propscore):
    """Perform one Metropolis-Hastings transition.

    state, score: current state and its log-score.
    propscore:    state -> (proposed_state, proposed_score).

    Returns (next_state, next_score) for the state that was actually kept.
    Bug fix: the original returned the proposal's score even when the
    proposal was rejected, so a chained kernel would carry the wrong
    lp_old into the next acceptance test.
    """
    next_state, next_score = propscore(state)
    if log_flip(next_score - score):
        return next_state, next_score
    return state, score
# Contrastive divergence-------------------------------------------------------- | |
# Assumes multiplicative weights; i.e., | |
# f(x; theta) = theta * g(x), therefore | |
# grad_theta f(x; theta) = g(x)
# So we just need a vector of factor functions | |
# [f_i ... f_n], | |
# each f_i is a function of the state | |
# P(state) \propto \prod_i exp(theta_i f_i(state)), i.e. ln P(state) = \sum_i theta_i f_i(state) + const
# prop: a proposal function: state -> some random state | |
def make_cd_step(step, initial_state, prop, factors):
    """Build a one-step contrastive-divergence (CD-1) update for the weights.

    Assumes an exponential-family model with multiplicative weights:
        ln P(state) = sum_i theta_i * f_i(state) + const
    so the gradient of the log-score w.r.t. theta_i is simply f_i(state).

    step:          learning rate.
    initial_state: the observed data point the Markov chain starts from.
    prop:          proposal function, state -> random next state.
    factors:       list of factor functions f_i(state) -> number.

    Returns a function thetas -> updated thetas (a list).
    """
    def grad(state):
        # grad_theta ln P(state) = [f_1(state), ..., f_n(state)]
        return [f(state) for f in factors]

    # The data point never changes, so its gradient is loop-invariant:
    # compute it once instead of on every update call.
    grad_initial = grad(initial_state)

    def call(thetas):
        # Log-score under the current weights: sum_i theta_i * f_i(state).
        def scorer(state):
            return sum(t * f(state) for t, f in zip(thetas, factors))
        propscore = make_propscore(prop, scorer)
        # One MCMC step away from the data point (the "reconstruction").
        next_state, _ = mh_step(initial_state, scorer(initial_state), propscore)
        # CD-1 update: move thetas toward the data, away from the sample.
        # c.f. page 3 eq (16) of http://www.robots.ox.ac.uk/~ojw/files/NotesOnCD.pdf
        return [t + step * (gi - gn)
                for t, gi, gn in zip(thetas, grad_initial, grad(next_state))]

    return call
# Example: learning properties of lists of numbers------------------------------
# Our state: a list of numbers.
data_point = [1, 2, 4, 8, 16]

# Factors map a state to +1 (property holds) or -1 (it does not).
def to_01(truth):
    """Encode a truth value in the +1/-1 convention used by the factors."""
    return 1 if truth else -1

def has_even(state):
    """+1 if the state contains at least one even number, else -1."""
    return to_01(any(i % 2 == 0 for i in state))

def has_odd(state):
    """+1 if the state contains at least one odd number, else -1."""
    return to_01(any(i % 2 == 1 for i in state))

def has_n(n):
    """Build a factor testing whether the number n appears in the state."""
    def factor(state):
        return to_01(n in state)
    return factor

# Has the number 3.
has_3 = has_n(3)
# Has the number 2.
has_2 = has_n(2)
# A proposal function | |
def prop(state):
    """Propose a new state by adding a random integer in [-5, 5] to one
    randomly chosen element; every other element is kept unchanged.

    Fixes: the original used a Python-2-only tuple-parameter lambda
    (`lambda (i, x): ...`, removed by PEP 3113) and raised ValueError on an
    empty state; an empty state now proposes an empty state.
    """
    if not state:
        return []
    shift_idx = randint(0, len(state) - 1)
    return [x + randint(-5, 5) if i == shift_idx else x
            for i, x in enumerate(state)]
# Factor functions and their initial (uniform) weights
factors = [has_even, has_odd, has_3, has_2]
thetas = [1.0, 1.0, 1.0, 1.0]
# Learning rate for the CD weight update
step = 0.2
# CD-1 step closure anchored at the single observed data point
cd_step = make_cd_step(step, data_point, prop, factors)
# Run some iterations of contrastive divergence
for i in range(100):
    thetas = cd_step(thetas)
# Expect the weight for has_3 to be low (3 never occurs in data_point)
# and the weight for has_2 to be high (2 does occur).
print "learned weights:",thetas
# More data points--------------------------------------------------------------
data_points = [data_point, [42], [0, 1, 2, 3], [2, 4, 6 ,8], [9999999, 2, 33]]
# Build one CD-1 step closure per data point
cd_steps = map(lambda d: make_cd_step(step, d, prop, factors), data_points)
# Re-initialize the weights
thetas = [1.0, 1.0, 1.0, 1.0]
# Stochastic training: each iteration updates against a randomly chosen data point
for i in range(100):
    thetas = cd_steps[randint(0, len(cd_steps) - 1)](thetas)
print "learned weights (more data points):",thetas
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment