Created May 13, 2015 15:57
Save beomjunshin-ben/733f03a6bcf7151dbb67 to your computer and use it in GitHub Desktop.
RBM
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.datasets import fetch_mldata | |
import numpy as np | |
import matplotlib.pyplot as plt | |
class RBM:
    """Bernoulli-Bernoulli restricted Boltzmann machine trained with CD-1.

    Attributes:
        num_visible: number of visible units.
        num_hidden: number of hidden units.
        learning_rate: step size for the contrastive-divergence updates.
        w: (num_visible, num_hidden) weight matrix.
        b: (num_hidden,) hidden-unit biases.
        c: (num_visible,) visible-unit biases.
    """

    def __init__(self, num_visible, num_hidden, learning_rate=0.1):
        self.num_hidden = num_hidden
        self.num_visible = num_visible
        self.learning_rate = learning_rate
        # Small Gaussian initialisation (std 0.1) for weights and both biases.
        self.w = 0.1 * np.random.randn(self.num_visible, self.num_hidden)
        self.b = 0.1 * np.random.randn(self.num_hidden)
        self.c = 0.1 * np.random.randn(self.num_visible)

    def train(self, data_all, batch_size=100, max_epochs=1000,
              filter_shape=(28, 28)):
        """Fit w, b and c with one-step contrastive divergence (CD-1).

        Args:
            data_all: (num_data, num_visible) array of visible vectors. Values
                are expected to lie in [0, 1]; integer input is converted to
                float per batch.
            batch_size: number of rows per mini-batch; the final batch may be
                smaller so that every row is used each epoch.
            max_epochs: number of full passes over ``data_all``.
            filter_shape: (width, height) forwarded to ``visualize`` to save the
                first hidden unit's weights once per epoch; ``None`` disables
                the image output.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
        data_all = np.asarray(data_all)
        num_data = data_all.shape[0]
        for epoch in range(max_epochs):
            error = 0.0
            for start in range(0, num_data, batch_size):
                # Cast per batch: integer (e.g. uint8) data would wrap around
                # when the binary reconstruction is subtracted from it.
                data = data_all[start:start + batch_size].astype(float)
                n = data.shape[0]
                # Positive phase: hidden probabilities driven by the data.
                pos_hidden_probs = self.sigmoid(self.b + np.dot(data, self.w))
                pos_phase_w = np.dot(data.T, pos_hidden_probs)
                # Negative phase: one Gibbs step. The reconstruction is driven
                # by hidden probabilities and then sampled to binary states.
                neg_visible_probs = self.sigmoid(
                    self.c + np.dot(pos_hidden_probs, self.w.T))
                neg_visible_states = (
                    neg_visible_probs > np.random.rand(n, self.num_visible)
                ).astype(float)
                neg_hidden_probs = self.sigmoid(
                    self.b + np.dot(neg_visible_states, self.w))
                neg_phase_w = np.dot(neg_visible_states.T, neg_hidden_probs)
                # CD-1 updates for weights and both bias vectors, batch-averaged.
                step = self.learning_rate / n
                self.w += step * (pos_phase_w - neg_phase_w)
                self.b += step * (pos_hidden_probs - neg_hidden_probs).sum(axis=0)
                self.c += step * (data - neg_visible_states).sum(axis=0)
                # Squared reconstruction error, summed over the whole epoch.
                error += np.sum((data - neg_visible_states) ** 2)
            if filter_shape is not None:
                # Save an image of the first hidden unit's incoming weights.
                visualize(self.w[:, 0], width=filter_shape[0],
                          height=filter_shape[1])
            print("Epoch %s: error is %s" % (epoch, error))

    def sigmoid(self, x):
        """Return the logistic function 1 / (1 + exp(-x)), element-wise.

        Evaluated as exp(-log(1 + exp(-x))) with ``np.logaddexp`` so that large
        negative inputs do not overflow ``np.exp``.
        """
        return np.exp(-np.logaddexp(0, -x))
def visualize(x, width=-1, height=-1, pad=5, title='Filter.png'):
    """Save one flattened filter vector as a greyscale image file.

    Args:
        x: array with ``width * height`` elements.
        width: image width in pixels; -1 lets it be inferred.
        height: image height in pixels; -1 lets it be inferred. When both are
            -1 the image is assumed square.
        pad: currently unused (reserved for tiled output, see TODO).
        title: path of the image file to write.
    """
    # TODO: save tiled images of all filters (``pad`` is reserved for that).
    x = np.asarray(x)
    if width == -1 and height == -1:
        # reshape accepts only one unknown dimension; assume a square image.
        width = height = int(round(np.sqrt(x.size)))
    # Image arrays are (rows, cols) = (height, width).
    x = x.reshape((height, width))
    f = plt.figure()
    try:
        ax = f.add_subplot(1, 1, 1)
        ax.imshow(x, cmap='Greys')
        f.savefig(title)
    finally:
        # Release the figure; pyplot keeps every open figure alive otherwise.
        plt.close(f)
if __name__ == '__main__':
    # fetch_mldata and the mldata.org service behind it are gone (removed in
    # scikit-learn 0.22); OpenML serves the same 70000 x 784 MNIST images.
    # NOTE(review): the module-level ``fetch_mldata`` import also fails on
    # scikit-learn >= 0.22 and should be dropped for this script to import.
    from sklearn.datasets import fetch_openml
    mnist = fetch_openml('mnist_784', version=1, data_home='./', as_frame=False)
    # MNIST pixels are integer intensities 0..255; a Bernoulli RBM expects
    # visible values in [0, 1].
    X = np.asarray(mnist.data, dtype=float) / 255.0
    r = RBM(num_visible=28*28, num_hidden=10)
    r.train(X, batch_size=100, max_epochs=5000)
    np.savetxt('weights.out', r.w)
    np.savetxt('bias_b.out', r.b)
    np.savetxt('bias_c.out', r.c)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.