Created July 12, 2009 00:53
-
-
Save astro/145463 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# http://www.cs.toronto.edu/~roweis/data/mnist_{train,test}{0..9}.jpg | |
require 'RMagick' | |
include Magick | |
require 'ai4r' | |
# Width and height of the downsampled grid fed to the network (W * H inputs).
W = 14
H = 14
# Width and height of one digit tile cropped from the source images.
IW = 28
IH = 28
# One labelled digit sample. The IW x IH source tile is downsampled to a
# W x H grid by averaging non-overlapping (IW / W) x (IH / H) pixel blocks.
class Picture
  attr_reader :digit

  # @param digit [Integer] label of the sample, expected in 0..9
  # @param pixels [Array<Numeric>] IW * IH source intensities in row-major
  #   order (index = row * IW + column)
  def initialize(digit, pixels)
    @digit = digit
    dx = (IW / W).floor
    dy = (IH / H).floor
    # Block coordinates are 0-based so block (0, 0) starts at pixel (0, 0)
    # and the last block ends exactly at the tile edge; 1-based ranges would
    # skip the first block row/column and read past the end of each row.
    @pixels = (0...H).map { |y|
      (0...W).map { |x|
        sum = 0.0
        dy.times do |y1|
          dx.times do |x1|
            sum += pixels[(y * dy + y1) * IW + x * dx + x1].to_f
          end
        end
        sum / (dx * dy)
      }
    }.flatten
  end

  # @return [Array<Float>] the W * H block averages in row-major order
  def to_input
    @pixels
  end

  # @return [Array<Float>] one-hot target of length 10 with 1.0 at index +digit+
  def desired_output
    [0.0] * @digit + [1.0] + [0.0] * (9 - @digit)
  end

  # Render the downsampled grid as ASCII art, one line per row.
  #
  # @param io [IO] destination stream (defaults to STDOUT)
  def print(io = STDOUT)
    H.times do |y|
      io.puts((0...W).map { |x| glyph_for(to_input[y * W + x]) }.join)
    end
  end

  private

  # Map a block intensity to a display character; values below 0.001 or
  # above 1.0 render as a blank.
  def glyph_for(value)
    case value
    when 0.001..0.1 then '.'
    when 0.1..0.5   then 'x'
    when 0.5..1.0   then 'X'
    else ' '
    end
  end
end
# Load up to +n+ labelled samples by cutting an image into IW x IH tiles,
# scanning column by column (top to bottom within each column). Tiles whose
# total intensity is below 5.0 are skipped.
#
# @param digit [Integer] label assigned to every sample from this file
# @param filename [String] path of the image to read
# @param n [Integer] maximum number of samples to return
# @return [Array<Picture>] the accepted samples, in scan order
def load_nist_pictures(digit, filename, n=3000)
  res = []
  # GC is suspended while loading (presumably to speed up the many small
  # allocations). GC.disable returns whether GC was already off; restore
  # exactly that state even if reading the image raises.
  gc_was_disabled = GC.disable
  begin
    print "Loading #{n} NIST samples from #{filename}"; STDOUT.flush
    img = Image.read(filename).first
    x, i = 0, 0
    # <= so the last full column/row of tiles, ending exactly at the image
    # edge, is scanned too.
    while i < n && x <= img.columns - IW
      y = 0
      while i < n && y <= img.rows - IH
        # NOTE(review): 65535.0 assumes a 16-bit-quantum ImageMagick build;
        # Magick::QuantumRange would adapt to other builds — confirm.
        pixels = img.get_pixels(x, y, IW, IH).map { |pixel| pixel.intensity / 65535.0 }
        # Skip near-blank tiles: total intensity must reach 5.0, i.e. at
        # least five fully-white pixels' worth.
        if pixels.inject(0.0, &:+) >= 5.0
          res << Picture.new(digit, pixels)
          i += 1
          print '.'; STDOUT.flush
        end
        y += IH
      end
      x += IW
    end
    puts i
  ensure
    GC.enable unless gc_was_disabled
  end
  res
end
# Command line: the directory holding the mnist_train*/mnist_test* images,
# and optionally a weights file written by a previous run.
NIST_PATH = ARGV.shift || raise("Usage: #{$0} <nist-path> [_last_network.rb]")
SAVEGAME_FILE = ARGV.shift
# One array of training Pictures per digit 0..9, indexed by digit. Built
# with map so every digit owns its own array ([[]] * 10 would alias one
# shared array across all ten slots).
training_examples_by_digit = (0..9).map do |digit|
  load_nist_pictures(digit, "#{NIST_PATH}/mnist_train#{digit}.jpg")
end
# Merge the per-digit lists round-robin (one sample of digit 0, then 1, ...,
# then 9, then 0 again), so consecutive training steps cycle through the
# digits. Lists that run out are skipped; stop once a whole round yields
# nothing. This drains (shifts) the per-digit arrays.
training_examples_interleaved = []
loop do
  round = training_examples_by_digit.map(&:shift).select { |sample| sample }
  break if round.empty?
  training_examples_interleaved.concat(round)
end
# Layer sizes: W * H inputs, 32 hidden units, 10 outputs (one per digit).
net = Ai4r::NeuralNetwork::Backpropagation.new([W * H, 32, 10])
if SAVEGAME_FILE
  net.init_network
  # Resume from weights saved as an Array#inspect literal. Read the whole
  # file with File.read: IO.readlines(...).to_s returns the *inspected*
  # array on Ruby 1.9+, so eval would produce an Array of Strings.
  # SECURITY: eval runs arbitrary Ruby — only load weight files you trust.
  net.weights = eval(File.read(SAVEGAME_FILE))
end
# The training loop overwrites learning_rate on every step.
#net.learning_rate = 0.4
#net.momentum = 0.3
# Online training: present every interleaved example ITERATIONS times,
# decaying the learning rate linearly from 0.9 towards 0.1 over all steps,
# and log progress with an ETA after each step.
ITERATIONS = 2
examples_per_pass = training_examples_interleaved.length
steps = examples_per_pass * ITERATIONS
t_start = Time.now
ITERATIONS.times do |pass|
  training_examples_interleaved.each_with_index do |img, idx|
    step = pass * examples_per_pass + idx + 1 # 1-based global step
    net.learning_rate = 0.9 - ((step - 1) * 0.8 / steps)

    error = net.train(img.to_input, img.desired_output)
    # Dump the sample whenever its training error exceeds 0.8.
    img.print if error > 0.8

    error_bar = '*' * (Math.log(error + 1) * 100).to_i
    elapsed = Time.now - t_start
    t_eta = ((steps - step) * elapsed / step.to_f).to_i
    puts "[ETA #{t_eta.to_s.rjust(4)}s] Trained #{step.to_s.ljust(3)}/#{steps} " \
         "(#{img.digit}): rate=#{format('%.6f', net.learning_rate)} " \
         "error=#{format('%.2f', error)} #{error_bar}"
  end
end
# Drop references to the training set before loading the test images.
training_examples_by_digit, training_examples_interleaved = nil, nil
correct, wrong = 0, 0
(0..9).each do |digit|
  imgs = load_nist_pictures(digit, "#{NIST_PATH}/mnist_test#{digit}.jpg", 2000)
  imgs.each do |img|
    # NOTE(review): GC is paused around each forward pass; the reason is
    # not visible in this file.
    GC.disable
    output = net.eval(img.to_input)
    GC.enable
    # The predicted digit is the index of the strongest output.
    classified_digit = output.index(output.max)
    if classified_digit == digit
      correct += 1
    else
      wrong += 1
      puts "#{classified_digit} != #{digit}: " + output.map { |o| format '%.3f', o }.join(' ')
      img.print
    end
  end
  GC.start
end
puts "Correct: #{correct}"
puts "Wrong: #{wrong}"
total = correct + wrong
# Avoid ZeroDivisionError when no test samples were loaded. The ratio is an
# integer (truncated) percentage.
if total.zero?
  puts "Ratio: n/a"
else
  puts "Ratio: #{correct * 100 / total}%"
end
# Persist the trained weights as a Ruby literal (Array#inspect) so a later
# run can pass this file as its second command-line argument. The block
# form of File.open closes the file even if writing raises.
File.open('_last_network.rb', 'w') do |savegame|
  savegame.puts net.weights.inspect
end
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment