Skip to content

Instantly share code, notes, and snippets.

@astro
Created July 12, 2009 00:53
Show Gist options
  • Save astro/145463 to your computer and use it in GitHub Desktop.
Save astro/145463 to your computer and use it in GitHub Desktop.
# http://www.cs.toronto.edu/~roweis/data/mnist_{train,test}{0..9}.jpg
require 'RMagick'
include Magick
require 'ai4r'
W, H = 14, 14
IW, IH = 28, 28
# A labelled digit sample, downsampled from IW x IH source pixels to a
# W x H grid by averaging each (IW / W) x (IH / H) cell.
class Picture
  attr_reader :digit

  # digit  - the label of this sample; #desired_output uses it as the index
  #          of the single 1.0 entry.
  # pixels - flat, row-major collection of IW * IH values; every element is
  #          converted with to_f before averaging.
  def initialize(digit, pixels)
    @digit = digit
    dx = IW / W
    dy = IH / H
    cell_area = dx * dy
    # Cells are zero-based. Starting at 1 would skip the first row and column
    # of cells, read past the end of `pixels` for the last row, and wrap the
    # last column into the following pixel row.
    @pixels = (0...H).map { |y|
      (0...W).map { |x|
        sum = 0
        dy.times { |y1|
          row_start = (y * dy + y1) * IW
          dx.times { |x1|
            sum += pixels[row_start + x * dx + x1].to_f
          }
        }
        sum / cell_area
      }
    }.flatten
  end

  # Returns the flat, row-major array of W * H averaged cell values.
  def to_input
    @pixels
  end

  # Returns a 10-element array that is 1.0 at index `digit` and 0.0 elsewhere.
  def desired_output
    [0.0] * @digit + [1.0] + [0.0] * (9 - @digit)
  end

  # Writes an ASCII rendering of the W x H grid to STDOUT, one text line per
  # grid row. Brighter cells get heavier glyphs; anything outside
  # 0.001..1.0 is rendered as a space.
  def print
    H.times { |y|
      # Rows are W cells wide, so the row offset is y * W.
      row = to_input[y * W, W]
      STDOUT.puts row.map { |value| glyph_for(value) }.join
    }
  end

  private

  # Maps one cell value to the character used by #print.
  def glyph_for(value)
    case value
    when (0.001..0.1)
      '.'
    when (0.1..0.5)
      'x'
    when (0.5..1.0)
      'X'
    else
      ' '
    end
  end
end
# Cuts the image at `filename` into IW x IH tiles and wraps each non-blank
# tile in a Picture labelled `digit`.
#
# Tiles are visited column by column, top to bottom within each column.
# A tile is kept only when the sum of its normalised pixel intensities is at
# least 5.0. Progress is written to standard output.
#
# digit    - label passed to Picture.new for every tile.
# filename - path handed to Image.read; only the first frame is used.
# n        - maximum number of pictures to return.
#
# Returns an Array of at most `n` Picture objects.
def load_nist_pictures(digit, filename, n=3000)
  res = []
  # GC is switched off while tiles are extracted; remember the previous state
  # so it is restored even when reading the image raises.
  gc_was_disabled = GC.disable
  begin
    print "Loading #{n} NIST samples from #{filename}"; STDOUT.flush
    img = Image.read(filename).first
    # Inclusive upper bounds: a tile starting at columns - IW (rows - IH)
    # still fits entirely inside the image.
    0.step(img.columns - IW, IW) do |x|
      0.step(img.rows - IH, IH) do |y|
        break if res.length >= n
        # NOTE(review): 65535.0 assumes 16-bit quantum intensities - confirm
        # against the installed ImageMagick build.
        pixels = img.get_pixels(x, y, IW, IH).map { |pixel| pixel.intensity / 65535.0 }
        next unless pixels.inject(0.0, &:+) >= 5.0 # Skip (almost) blank tiles
        res << Picture.new(digit, pixels)
        print '.'; STDOUT.flush
      end
      break if res.length >= n
    end
    puts "#{res.length}"
  ensure
    GC.enable unless gc_was_disabled
  end
  res
end
# Command line: <nist-path> is mandatory, the saved-network file is optional.
NIST_PATH = ARGV.shift || raise("Usage: #{$0} <nist-path> [_last_network.rb]")
SAVEGAME_FILE = ARGV.shift

# One array of training pictures per digit, indexed by the digit itself.
training_examples_by_digit = (0..9).map do |digit|
  load_nist_pictures(digit, "#{NIST_PATH}/mnist_train#{digit}.jpg")
end
# Round-robin over the per-digit arrays so consecutive training samples cycle
# through the digits. Each round removes the first picture from every array
# that still has one; the per-digit arrays are empty afterwards.
training_examples_interleaved = []
loop do
  round = training_examples_by_digit.map(&:shift).compact
  break if round.empty?
  training_examples_interleaved.concat(round)
end
# Network layout: one input per downsampled cell, 32 hidden neurons and one
# output per digit.
net = Ai4r::NeuralNetwork::Backpropagation.new([W * H, 32, 10])
if SAVEGAME_FILE
  net.init_network
  # Read the file as one string. Array#to_s on the result of readlines is
  # Array#inspect on Ruby >= 1.9, which would make eval return an array of
  # strings instead of the weights that were saved with `inspect`.
  # SECURITY: eval executes the file as Ruby code - only load files you trust.
  net.weights = eval(File.read(SAVEGAME_FILE))
end
# Fixed hyper-parameters, left disabled; the learning rate is scheduled below.
#net.learning_rate = 0.4
#net.momentum = 0.3
ITERATIONS = 2
t_start = Time.now
sample_count = training_examples_interleaved.length
total_steps = sample_count * ITERATIONS
ITERATIONS.times do |iteration|
  training_examples_interleaved.each_with_index do |img, img_i|
    finished = iteration * sample_count + img_i
    # Learning rate falls linearly from 0.9 towards 0.1 over the whole run.
    net.learning_rate = 0.9 - (finished * 0.8 / total_steps)
    e = net.train(img.to_input, img.desired_output)
    # Show the samples the network currently gets badly wrong.
    img.print if e > 0.8
    e_bar = "*" * (Math.log(e + 1) * 100).to_i
    step = finished + 1
    # Remaining steps multiplied by the average time per completed step.
    t_eta = ((total_steps - step) * (Time.now - t_start) / step.to_f).to_i
    eta_text = t_eta.to_s.rjust(4)
    step_text = step.to_s.ljust(3)
    rate_text = format('%.6f', net.learning_rate)
    error_text = format('%.2f', e)
    puts "[ETA #{eta_text}s] Trained #{step_text}/#{total_steps} (#{img.digit}): rate=#{rate_text} error=#{error_text} #{e_bar}"
  end
end
# Drop the references to the training data before the test set is loaded.
training_examples_by_digit, training_examples_interleaved = nil, nil
correct, wrong = 0, 0
(0..9).each do |digit|
  file = "#{NIST_PATH}/mnist_test#{digit}.jpg"
  imgs = load_nist_pictures(digit, file, 2000)
  imgs.each do |img|
    GC.disable
    output = net.eval(img.to_input)
    GC.enable
    # The predicted digit is the index of the strongest output.
    classified_digit = output.index(output.max)
    if classified_digit == digit
      correct += 1
    else
      wrong += 1
      puts "#{classified_digit} != #{digit}: " + (output.map { |o| format '%.3f', o }.join(' '))
      img.print
    end
  end
  GC.start
end
puts "Correct: #{correct}"
puts "Wrong: #{wrong}"
total = correct + wrong
if total.zero?
  # Avoid a ZeroDivisionError here: it would abort the script before the
  # trained network is written out.
  puts "Ratio: n/a (no test samples evaluated)"
else
  puts "Ratio: #{correct * 100 / total}%"
end
# Persist the trained weights as a Ruby literal. The block form closes the
# file even if writing raises.
File.open('_last_network.rb', 'w') do |savegame|
  savegame.puts net.weights.inspect
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment