-
-
Save tarvos21/bd09403f12f4b51ab68d6e40b03f81b1 to your computer and use it in GitHub Desktop.
Overfit neural network to find prime numbers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Weights of a small 8-9-1 feed-forward network, overfit so that it classifies
# every integer in 0..255 as prime / not prime.
# Each row is one neuron: its input weights followed by a final bias weight.
#   layer 0 (hidden): 9 neurons x (8 input bits + bias)
#   layer 1 (output): 1 neuron  x (9 hidden activations + bias)
EIGHT_BIT_PRIME_NET = [
  [[-8.8,  3.0,  9.2,  5.0, -4.8,  5.6, -5.8,  6.1, -6.1],
   [-1.8,  2.3, -2.6, -5.6,  0.1,  6.0, -4.7, -5.7, -3.0],
   [-3.6, -4.1,  6.5, -0.6, -2.8, -2.6,  2.4, -2.3,  1.0],
   [ 2.4, -9.1, -3.1,  7.8, -3.7, -8.9, -2.8,  5.6,  6.3],
   [-2.0,  4.0, 11.0,  3.3, -6.0,  0.7, -7.0,  1.6, -3.0],
   [ 0.7, -7.2,  2.8,  4.5, -3.6, -1.5,  2.7, -0.1, -3.5],
   [-8.9,  5.5,  4.8, -4.1,  5.6,  4.8,  5.5, -4.0, -6.4],
   [-4.0, -6.7, -3.6,  5.5,  2.5, -6.9,  7.8, -4.1, -0.4],
   [ 1.1, -3.5, -6.7, -2.8,  2.3, -4.7,  6.3, -4.6, -2.2]],
  [[-14.6, -13.9, -9.4, -10.8, 13.7, 10.6, -10.7, -11.0, 12.7, 5.2]],
].freeze

# Classifies +n+ as prime by feeding its 8 low-order bits (bit 0 first) through
# EIGHT_BIT_PRIME_NET and checking whether the sigmoid output exceeds 0.5.
#
# The network only ever sees the low 8 bits, so it is only meaningful on
# 0..255. Numbers below 2 are never prime and short-circuit to false; this also
# stops negative numbers, whose two's-complement bits would alias onto 0..255
# (e.g. -5 -> 251), from being misclassified. Larger numbers are rejected
# instead of being silently aliased (e.g. 257 -> 1).
#
# @param n [Integer] number to classify, expected in 0..255
# @return [Boolean] true when the network classifies +n+ as prime
# @raise [ArgumentError] if +n+ is greater than 255
def is_prime?(n)
  return false if n < 2
  raise ArgumentError, "is_prime? only supports 0..255, got #{n}" if n > 255

  bits = Array.new(8) { |i| n[i] }
  outputs = EIGHT_BIT_PRIME_NET.inject(bits) do |activations, layer|
    layer.map do |weights|
      # The trailing constant 1.0 input turns each row's last weight into a bias.
      total_input = [*activations, 1.0].zip(weights).map { |a, w| a * w }.inject(0, :+)
      1 / (1 + Math::E**-total_input) # https://en.wikipedia.org/wiki/Sigmoid_function
    end
  end
  outputs.first > 0.5
end
require 'prime' | |
# Sanity check: the distilled network agrees with the standard library on
# every 8-bit input (the trailing annotation records the evaluated result).
(0..255).all? { |num| is_prime?(num) == Prime.prime?(num) } # => true
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'ai4r/neural_network/backpropagation' | |
require 'prime' | |
# srand 1 | |
# Encodes +n+ as an array of its low-order bits, least significant first.
# This is the input layout fed to the network, which has one input node per bit.
#
# @param n [Integer] number to encode
# @param width [Integer] number of low-order bits to take (defaults to 8, the
#   width of the [8, 9, 1] network trained below)
# @return [Array<Integer>] +width+ entries, each 0 or 1
def bits_for(n, width = 8)
  Array.new(width) { |i| n[i] }
end
# Runs one training step on +net+ for the number +n+. The input is the bit
# encoding from bits_for, and the target is [1] for a prime or [0] otherwise.
# Returns whatever net.train returns.
def train(net, n)
  target = n.prime? ? [1] : [0]
  net.train(bits_for(n), target)
end
# Inputs cover 0..MAX_INPUT, i.e. every 8-bit number.
MAX_INPUT = 2**8 - 1
# A network output strictly greater than this counts as "prime". Retraining
# selection and the test pass share it, matching the distilled is_prime?
# (`> 0.5`), so a borderline output of exactly 0.5 is classified the same way
# everywhere.
PRIME_THRESHOLD = 0.5

# Forever: train a fresh [8, 9, 1] network, report its progress, then print
# every misclassified input and a summary. On a perfect score, drop into a
# pry session (requires the pry gem), presumably to inspect the weights.
loop do
  puts "Training the network, please wait."
  net = Ai4r::NeuralNetwork::Backpropagation.new([8, 9, 1])

  2001.times do |i|
    # One epoch: train once on every input in random order, keeping
    # [number, value returned by train] pairs.
    errors = 0.upto(MAX_INPUT).to_a.shuffle.map { |n| [n, train(net, n)] }
    next unless (i % 200).zero?

    average = errors.map(&:last).inject(:+) / errors.length
    puts "Error for run #{i}: #{average.inspect}"
    next if i < 600

    # After the warm-up epochs, repeatedly retrain on whatever is still
    # misclassified.
    # retrain = errors.select { |n, err| err > 0.2 }
    retrain = 0.upto(MAX_INPUT).reject do |n|
      (net.eval(bits_for(n)).first > PRIME_THRESHOLD) == n.prime?
    end
    puts " retraining on #{retrain.inspect}"
    500.times { retrain.each { |n| train(net, n) } }
  end

  puts "Test data"
  num_correct = 0
  num_attempted = 0
  0.upto(MAX_INPUT) do |n|
    num_attempted += 1
    result = net.eval(bits_for(n)).first
    actual = result > PRIME_THRESHOLD
    expected = n.prime?
    if expected == actual
      num_correct += 1
    else
      printf("%3d \e[31mreal:%-5s ai:%-5s\e[0m (%f)\n", n, expected, actual, result)
    end
  end
  puts
  puts "SUMMARY: #{num_correct}/#{num_attempted} (#{100.0 * num_correct / num_attempted}%)"
  if num_attempted == num_correct
    require "pry"
    binding.pry
  end
  puts "-----------------------------"
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# require 'pp' | |
# puts '[' << synapseses_rhs_major.map { |s| s.map { |s| '%4.1f' % s }.join(', ') }.join("],\n") | |
# 8 bits | |
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6], n[7]] | |
# [[[-3.2, 3.6, -0.7, 2.8, 0.9, 3.2, -0.1, 0.4, -2.3], | |
# [-0.2, -3.0, 0.9, -3.8, 1.3, 3.1, 1.0, -2.4, 0.2], | |
# [-1.0, 1.4, -0.0, -3.5, 1.6, 2.4, -2.9, 4.0, -1.0], | |
# [-0.6, 3.0, 1.9, 2.4, 0.4, 0.4, 2.7, 0.6, -0.9], | |
# [-4.1, -1.6, 2.8, -1.6, 0.6, -1.3, 3.3, 1.2, -0.5], | |
# [ 0.0, 3.8, -5.0, 4.1, -5.2, -1.1, -5.9, 4.6, 1.3], | |
# [-0.2, 2.0, -1.9, -1.0, -2.5, 6.2, -2.6, -0.3, -1.3], | |
# [-0.3, -0.7, 0.6, -0.8, 1.3, -0.6, 1.9, -2.2, -0.2], | |
# [-2.3, -0.3, 3.4, -2.6, 2.6, -1.2, 2.5, -2.1, -3.2], | |
# [ 0.8, -0.9, -4.4, 3.4, -1.6, 2.5, -0.6, -2.1, -0.9], | |
# [ 3.2, 3.7, 1.6, -3.8, -1.7, 4.5, -4.6, 3.2, -0.3], | |
# [-0.3, 0.7, 1.1, -0.3, -0.8, -0.5, -0.3, -0.8, -0.4], | |
# [-1.5, 1.7, -2.5, 0.3, -1.3, 2.0, -2.5, 0.4, 0.3], | |
# [ 2.9, -2.6, -1.7, 0.0, 1.4, -2.1, -1.7, 1.7, 0.9], | |
# [ 0.1, -2.5, 3.2, -1.7, 1.4, -2.5, 3.1, -1.6, -0.1], | |
# [-3.6, -2.1, -0.5, 4.1, -1.6, 0.6, -1.3, 4.4, -0.8], | |
# ], | |
# [[ 0.7, 1.1, -0.6, -0.6, 0.3, -0.7, 0.2, 1.0, -0.2, -0.7, -1.0, -0.9, 0.4, -0.9, -0.8, -0.5, -1.1], | |
# [-0.6, -0.9, -0.3, -0.9, 0.2, -0.1, 0.0, 0.3, 0.8, -0.7, 0.5, -0.4, -0.3, -0.8, -0.3, -0.5, -0.4], | |
# [-1.6, -0.5, -1.1, 0.2, 0.2, -1.5, 0.5, -0.7, 0.2, -0.9, 0.6, -0.3, 0.3, 0.0, 0.4, -0.1, -1.4], | |
# [ 2.0, 3.3, -0.1, -1.3, 3.2, 7.0, -6.3, 0.5, 4.6, -3.8, -5.1, 1.3, -1.0, 1.5, 0.3, -0.7, -0.8], | |
# [-1.5, 0.5, -0.5, -0.1, -0.9, -1.5, -0.6, 0.8, -0.6, -1.1, 1.3, -0.3, -0.2, 0.2, 0.8, -1.4, -0.5], | |
# [-2.1, 1.2, 1.2, -3.7, 0.3, 2.0, -1.9, 1.1, 1.6, 0.8, -0.8, 0.1, -0.0, -0.3, 1.0, -0.4, -0.2], | |
# [ 1.1, 2.4, 4.2, -1.1, 2.2, -7.7, 2.3, 1.1, 2.0, 3.5, -3.4, 0.5, 3.1, -3.7, 0.8, 5.4, 0.2], | |
# [-5.1, 1.4, -1.5, 1.4, -0.6, -5.4, -2.1, 2.1, 1.1, -1.5, 1.0, 0.1, -2.6, 0.3, 5.1, -2.7, -1.4], | |
# ], | |
# [[-0.7, 0.2, 0.9, -12.0, 1.4, -4.3, -11.1, 9.8, 5.4]], | |
# ] | |
# 7 bits | |
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6]] | |
# [[[-5.8, 3.6, 1.9, 1.1, 3.5, 2.1, 3.3, -4.9], | |
# [-3.0, 5.4, 6.8, 2.7, -5.6, 4.6, -6.4, -1.6], | |
# [-5.1, 0.2, 4.5, 2.3, -1.9, -2.3, 0.0, -1.0], | |
# [ 0.2, 2.2, -2.7, 4.8, 0.1, -3.0, -1.7, -0.7], | |
# [-2.6, 4.3, -4.8, 3.6, -3.2, 3.3, -6.0, 2.6], | |
# [-3.0, 2.5, 1.1, -4.1, 1.5, 6.5, -3.3, -4.1], | |
# [-1.7, -5.8, 2.2, 6.2, -0.5, -6.8, 3.7, 2.4]], | |
# [[-8.0, 9.6, -7.6, 8.6, -12.3, -9.4, -9.6, 5.4]], | |
# ] | |
# 8 bits | |
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6], n[7]] # ~> NameError: undefined local variable or method `n' for main:Object | |
# [[[-8.8, 3.0, 9.2, 5.0, -4.8, 5.6, -5.8, 6.1, -6.1], | |
# [-1.8, 2.3, -2.6, -5.6, 0.1, 6.0, -4.7, -5.7, -3.0], | |
# [-3.6, -4.1, 6.5, -0.6, -2.8, -2.6, 2.4, -2.3, 1.0], | |
# [ 2.4, -9.1, -3.1, 7.8, -3.7, -8.9, -2.8, 5.6, 6.3], | |
# [-2.0, 4.0, 11.0, 3.3, -6.0, 0.7, -7.0, 1.6, -3.0], | |
# [ 0.7, -7.2, 2.8, 4.5, -3.6, -1.5, 2.7, -0.1, -3.5], | |
# [-8.9, 5.5, 4.8, -4.1, 5.6, 4.8, 5.5, -4.0, -6.4], | |
# [-4.0, -6.7, -3.6, 5.5, 2.5, -6.9, 7.8, -4.1, -0.4], | |
# [ 1.1, -3.5, -6.7, -2.8, 2.3, -4.7, 6.3, -4.6, -2.2] | |
# ], | |
# [[-14.6, -13.9, -9.4, -10.8, 13.7, 10.6, -10.7, -11.0, 12.7, 5.2]], | |
# ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Compact variant of the 8-bit prime classifier. It uses a coarser,
# mostly-integer 8-9-1 weight table (each row: input weights, then a bias
# weight) and returns true when the sigmoid output for the 8 low-order bits
# of +n+ is greater than 0.5.
def is_prime?(n)
  layers = [
    [[-9,  3,  9,  5, -5,  6, -6,  6, -6],
     [-2,  2, -3, -6,  0,  6, -5, -6, -3],
     [-4, -4,  7,  0, -3, -3,  2, -2,  1],
     [ 2, -9, -3,  8, -4, -9, -3,  6,  6],
     [-2,  4, 11,  3, -6,  1, -7,  2, -3],
     [ 1, -7,  3,  4, -4, -2,  3,  0, -3.5],
     [-9,  6,  5, -4,  6,  5,  5, -4, -6],
     [-4, -7, -4,  5,  2, -7,  8, -4,  0],
     [ 1, -3, -7, -3,  2, -5,  6, -5, -2]],
    [[-15, -14, -9, -11, 14, 11, -11, -11, 11, 5]],
  ]

  activations = Array.new(8) { |bit| n[bit] }
  layers.each do |neurons|
    activations = neurons.map do |weights|
      # Accumulate the negated weighted sum (with an integer 1 as the bias
      # input) so it can be used directly as the sigmoid's exponent.
      negated_sum = [*activations, 1].zip(weights).inject(0) { |acc, (input, weight)| acc - input * weight }
      1 / (1 + Math::E**negated_sum)
    end
  end
  activations[0] > 0.5
end
require 'prime' | |
# Sanity check: the compact network agrees with Integer#prime? on every 8-bit
# input (the trailing annotation records the evaluated result).
0.upto(255).all? { |num| is_prime?(num) == num.prime? } # => true
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment