Last active
January 4, 2018 14:28
-
-
Save JoshCheek/a11fbb71e5d81ee79d4e to your computer and use it in GitHub Desktop.
Overfit neural network to find prime numbers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Weights of the 8-9-1 sigmoid network that was trained (and deliberately
# overfit) to recognize primes in 0..255. Each row is one neuron: its weights
# for every node of the previous layer, followed by a bias weight that is
# paired with a constant 1.0 input in is_prime?. Frozen and hoisted so the
# literal is not rebuilt on every call.
PRIME_NET_LAYERS = [
  # Hidden layer: 9 neurons x (8 input bits + bias).
  [[-8.8,  3.0,  9.2,  5.0, -4.8,  5.6, -5.8,  6.1, -6.1],
   [-1.8,  2.3, -2.6, -5.6,  0.1,  6.0, -4.7, -5.7, -3.0],
   [-3.6, -4.1,  6.5, -0.6, -2.8, -2.6,  2.4, -2.3,  1.0],
   [ 2.4, -9.1, -3.1,  7.8, -3.7, -8.9, -2.8,  5.6,  6.3],
   [-2.0,  4.0, 11.0,  3.3, -6.0,  0.7, -7.0,  1.6, -3.0],
   [ 0.7, -7.2,  2.8,  4.5, -3.6, -1.5,  2.7, -0.1, -3.5],
   [-8.9,  5.5,  4.8, -4.1,  5.6,  4.8,  5.5, -4.0, -6.4],
   [-4.0, -6.7, -3.6,  5.5,  2.5, -6.9,  7.8, -4.1, -0.4],
   [ 1.1, -3.5, -6.7, -2.8,  2.3, -4.7,  6.3, -4.6, -2.2]].freeze,
  # Output layer: 1 neuron x (9 hidden activations + bias).
  [[-14.6, -13.9, -9.4, -10.8, 13.7, 10.6, -10.7, -11.0, 12.7, 5.2]].freeze
].freeze

# Classifies n as prime or composite by running it through PRIME_NET_LAYERS.
#
# @param n [Integer] the number to classify. It must be in 0..255, because
#   the network's input is only the 8 low-order bits of n.
# @return [Boolean] true when the output neuron's activation exceeds 0.5
# @raise [ArgumentError] if n is not an Integer in 0..255. Larger (or
#   negative) inputs would otherwise be silently classified by their low
#   8 bits only; for example, 257 would be read as 1.
def is_prime?(n)
  unless n.is_a?(Integer) && n.between?(0, 255)
    raise ArgumentError, "is_prime? expects an Integer in 0..255, got #{n.inspect}"
  end

  bits = 8.times.map { |i| n[i] } # Integer#[] => bit i, least-significant first

  output = PRIME_NET_LAYERS.inject(bits) do |activations, neurons|
    neurons.map do |weights|
      # Append the constant bias input so it lines up with each row's last weight.
      total_input = [*activations, 1.0].zip(weights).map { |a, w| a * w }.inject(0, :+)
      1 / (1 + Math.exp(-total_input)) # https://en.wikipedia.org/wiki/Sigmoid_function
    end
  end

  output.first > 0.5
end
require 'prime'

# Sanity check: the network must agree with Ruby's Prime library on every
# number in its domain (0..255). Fail loudly rather than computing a boolean
# that nothing reads.
mismatches = (0..255).reject { |n| is_prime?(n) == Prime.prime?(n) }
raise "is_prime? disagrees with Prime.prime? for #{mismatches.inspect}" unless mismatches.empty?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'ai4r/neural_network/backpropagation' | |
require 'prime' | |
# srand 1 | |
# Encodes n as the network's input vector: its low-order bits, least
# significant first (element i is bit i of n, via Integer#[]).
#
# @param n [Integer] the number to encode
# @param width [Integer] how many low-order bits to return. Defaults to 8,
#   which matches the input-layer size this script builds.
# @return [Array<Integer>] `width` entries, each 0 or 1
def bits_for(n, width = 8)
  Array.new(width) { |i| n[i] }
end
# Runs one training step that teaches `net` whether n is prime.
#
# @param net [#train] the network being trained (in this script, an
#   Ai4r::NeuralNetwork::Backpropagation)
# @param n [Integer] the example. Its bits (from bits_for) are the input;
#   the target output is [1] for primes and [0] otherwise.
# @return whatever net.train returns; the training loop averages it as
#   the error for a run
def train(net, n)
  target = n.prime? ? 1 : 0
  net.train(bits_for(n), [target])
end
# Inputs are 8 bits wide, so the whole training/test domain is 0..max.
max = (2**8) - 1

# A network output strictly above this counts as "prime". Retraining, the
# test report, and the extracted is_prime? (`> 0.5`) must all use the same
# comparison. Otherwise a network reported as perfect here could disagree
# with them on an output of exactly 0.5.
threshold = 0.5

# Repeatedly trains fresh 8-9-1 networks. When one classifies every number
# in 0..max correctly, opens a pry session so its weights can be inspected.
# Runs until interrupted.
loop do
  puts "Training the network, please wait."
  net = Ai4r::NeuralNetwork::Backpropagation.new([8, 9, 1])

  2001.times do |i|
    # One pass over the whole domain in random order, keeping [n, error] pairs.
    errors = (0..max).to_a.shuffle.map { |n| [n, train(net, n)] }
    next unless (i % 200).zero?

    average = errors.map(&:last).inject(:+) / errors.length
    puts "Error for run #{i}: #{average.inspect}"
    next if i < 600

    # From run 600 on, drill the numbers the network currently gets wrong.
    # Alternative tried: retrain = errors.select { |n, err| err > 0.2 }
    retrain = (0..max).reject { |n| (net.eval(bits_for(n)).first > threshold) == n.prime? }
    puts " retraining on #{retrain.inspect}"
    500.times { retrain.each { |n| train(net, n) } }
  end

  puts "Test data"
  num_correct = 0
  num_attempted = 0
  (0..max).each do |n|
    num_attempted += 1
    result = net.eval(bits_for(n)).first
    actual = result > threshold
    expected = n.prime?
    if expected == actual
      num_correct += 1
    else
      printf("%3d \e[31mreal:%-5s ai:%-5s\e[0m (%f)\n", n, expected, actual, result)
    end
  end
  puts
  puts "SUMMARY: #{num_correct}/#{num_attempted} (#{100.0 * num_correct / num_attempted}%)"

  if num_attempted == num_correct
    require "pry" # loaded lazily: only needed once a perfect network exists
    binding.pry   # inspect the weights of the perfect network interactively
  end
  puts "-----------------------------"
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# require 'pp' | |
# puts '[' << synapseses_rhs_major.map { |s| s.map { |s| '%4.1f' % s }.join(', ') }.join("],\n") | |
# 8 bits | |
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6], n[7]] | |
# [[[-3.2, 3.6, -0.7, 2.8, 0.9, 3.2, -0.1, 0.4, -2.3], | |
# [-0.2, -3.0, 0.9, -3.8, 1.3, 3.1, 1.0, -2.4, 0.2], | |
# [-1.0, 1.4, -0.0, -3.5, 1.6, 2.4, -2.9, 4.0, -1.0], | |
# [-0.6, 3.0, 1.9, 2.4, 0.4, 0.4, 2.7, 0.6, -0.9], | |
# [-4.1, -1.6, 2.8, -1.6, 0.6, -1.3, 3.3, 1.2, -0.5], | |
# [ 0.0, 3.8, -5.0, 4.1, -5.2, -1.1, -5.9, 4.6, 1.3], | |
# [-0.2, 2.0, -1.9, -1.0, -2.5, 6.2, -2.6, -0.3, -1.3], | |
# [-0.3, -0.7, 0.6, -0.8, 1.3, -0.6, 1.9, -2.2, -0.2], | |
# [-2.3, -0.3, 3.4, -2.6, 2.6, -1.2, 2.5, -2.1, -3.2], | |
# [ 0.8, -0.9, -4.4, 3.4, -1.6, 2.5, -0.6, -2.1, -0.9], | |
# [ 3.2, 3.7, 1.6, -3.8, -1.7, 4.5, -4.6, 3.2, -0.3], | |
# [-0.3, 0.7, 1.1, -0.3, -0.8, -0.5, -0.3, -0.8, -0.4], | |
# [-1.5, 1.7, -2.5, 0.3, -1.3, 2.0, -2.5, 0.4, 0.3], | |
# [ 2.9, -2.6, -1.7, 0.0, 1.4, -2.1, -1.7, 1.7, 0.9], | |
# [ 0.1, -2.5, 3.2, -1.7, 1.4, -2.5, 3.1, -1.6, -0.1], | |
# [-3.6, -2.1, -0.5, 4.1, -1.6, 0.6, -1.3, 4.4, -0.8], | |
# ], | |
# [[ 0.7, 1.1, -0.6, -0.6, 0.3, -0.7, 0.2, 1.0, -0.2, -0.7, -1.0, -0.9, 0.4, -0.9, -0.8, -0.5, -1.1], | |
# [-0.6, -0.9, -0.3, -0.9, 0.2, -0.1, 0.0, 0.3, 0.8, -0.7, 0.5, -0.4, -0.3, -0.8, -0.3, -0.5, -0.4], | |
# [-1.6, -0.5, -1.1, 0.2, 0.2, -1.5, 0.5, -0.7, 0.2, -0.9, 0.6, -0.3, 0.3, 0.0, 0.4, -0.1, -1.4], | |
# [ 2.0, 3.3, -0.1, -1.3, 3.2, 7.0, -6.3, 0.5, 4.6, -3.8, -5.1, 1.3, -1.0, 1.5, 0.3, -0.7, -0.8], | |
# [-1.5, 0.5, -0.5, -0.1, -0.9, -1.5, -0.6, 0.8, -0.6, -1.1, 1.3, -0.3, -0.2, 0.2, 0.8, -1.4, -0.5], | |
# [-2.1, 1.2, 1.2, -3.7, 0.3, 2.0, -1.9, 1.1, 1.6, 0.8, -0.8, 0.1, -0.0, -0.3, 1.0, -0.4, -0.2], | |
# [ 1.1, 2.4, 4.2, -1.1, 2.2, -7.7, 2.3, 1.1, 2.0, 3.5, -3.4, 0.5, 3.1, -3.7, 0.8, 5.4, 0.2], | |
# [-5.1, 1.4, -1.5, 1.4, -0.6, -5.4, -2.1, 2.1, 1.1, -1.5, 1.0, 0.1, -2.6, 0.3, 5.1, -2.7, -1.4], | |
# ], | |
# [[-0.7, 0.2, 0.9, -12.0, 1.4, -4.3, -11.1, 9.8, 5.4]], | |
# ] | |
# 7 bits | |
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6]] | |
# [[[-5.8, 3.6, 1.9, 1.1, 3.5, 2.1, 3.3, -4.9], | |
# [-3.0, 5.4, 6.8, 2.7, -5.6, 4.6, -6.4, -1.6], | |
# [-5.1, 0.2, 4.5, 2.3, -1.9, -2.3, 0.0, -1.0], | |
# [ 0.2, 2.2, -2.7, 4.8, 0.1, -3.0, -1.7, -0.7], | |
# [-2.6, 4.3, -4.8, 3.6, -3.2, 3.3, -6.0, 2.6], | |
# [-3.0, 2.5, 1.1, -4.1, 1.5, 6.5, -3.3, -4.1], | |
# [-1.7, -5.8, 2.2, 6.2, -0.5, -6.8, 3.7, 2.4]], | |
# [[-8.0, 9.6, -7.6, 8.6, -12.3, -9.4, -9.6, 5.4]], | |
# ] | |
# 8 bits | |
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6], n[7]] # ~> NameError: undefined local variable or method `n' for main:Object | |
# [[[-8.8, 3.0, 9.2, 5.0, -4.8, 5.6, -5.8, 6.1, -6.1], | |
# [-1.8, 2.3, -2.6, -5.6, 0.1, 6.0, -4.7, -5.7, -3.0], | |
# [-3.6, -4.1, 6.5, -0.6, -2.8, -2.6, 2.4, -2.3, 1.0], | |
# [ 2.4, -9.1, -3.1, 7.8, -3.7, -8.9, -2.8, 5.6, 6.3], | |
# [-2.0, 4.0, 11.0, 3.3, -6.0, 0.7, -7.0, 1.6, -3.0], | |
# [ 0.7, -7.2, 2.8, 4.5, -3.6, -1.5, 2.7, -0.1, -3.5], | |
# [-8.9, 5.5, 4.8, -4.1, 5.6, 4.8, 5.5, -4.0, -6.4], | |
# [-4.0, -6.7, -3.6, 5.5, 2.5, -6.9, 7.8, -4.1, -0.4], | |
# [ 1.1, -3.5, -6.7, -2.8, 2.3, -4.7, 6.3, -4.6, -2.2] | |
# ], | |
# [[-14.6, -13.9, -9.4, -10.8, 13.7, 10.6, -10.7, -11.0, 12.7, 5.2]], | |
# ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Compact variant of the prime-detecting network: the same 8-9-1 sigmoid
# architecture, with weights rounded (mostly to integers). Each row is one
# neuron's weights for the previous layer, followed by a bias weight paired
# with a constant input of 1.
#
# @param n [Integer] the number to classify. Only 0..255 is supported,
#   because the network reads just the 8 low-order bits of n.
# @return [Boolean] whether the output neuron's activation exceeds 0.5
# @raise [ArgumentError] if n is not an Integer in 0..255
def is_prime?(n)
  unless n.is_a?(Integer) && n.between?(0, 255)
    raise ArgumentError, "is_prime? expects an Integer in 0..255, got #{n.inspect}"
  end

  [[[-9,  3,  9,  5, -5,  6, -6,  6, -6],
    [-2,  2, -3, -6,  0,  6, -5, -6, -3],
    [-4, -4,  7,  0, -3, -3,  2, -2,  1],
    [ 2, -9, -3,  8, -4, -9, -3,  6,  6],
    [-2,  4, 11,  3, -6,  1, -7,  2, -3],
    [ 1, -7,  3,  4, -4, -2,  3,  0, -3.5],
    [-9,  6,  5, -4,  6,  5,  5, -4, -6],
    [-4, -7, -4,  5,  2, -7,  8, -4,  0],
    [ 1, -3, -7, -3,  2, -5,  6, -5, -2]],
   [[-15, -14, -9, -11, 14, 11, -11, -11, 11, 5]]]
    .inject(8.times.map { |i| n[i] }) { |inputs, layer|
      layer.map { |weights|
        total = [*inputs, 1].zip(weights).inject(0) { |sum, (x, w)| sum + x * w }
        1 / (1 + Math.exp(-total)) # sigmoid
      }
    }.first > 0.5
end
require 'prime'

# Sanity check: the compact network must agree with Integer#prime? on every
# number in its domain (0..255). Fail loudly rather than computing a boolean
# that nothing reads.
mismatches = (0...256).reject { |n| is_prime?(n) == n.prime? }
raise "is_prime? disagrees with Integer#prime? for #{mismatches.inspect}" unless mismatches.empty?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
How large a number can the algorithm handle?