Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save sasagawa888/a3837fa64d375e2ea30e947b3e986618 to your computer and use it in GitHub Desktop.
Save sasagawa888/a3837fa64d375e2ea30e947b3e986618 to your computer and use it in GitHub Desktop.
# Network definition used by the AdaGrad training test (all/2 below).
#
# Built with the `defnetwork` DSL (presumably Deep Pipe's network macro): a
# filter layer f(10,10), flatten, then three dense layers (w = weights,
# b = bias) with relu activations and a final softmax over 10 outputs.
# The `_x` argument is only a placeholder for the pipeline input; all/2
# calls this as init_network4(0).
# NOTE(review): 361 = 19*19, which matches a 10x10 filter with stride 1 and no
# padding over a 28x28 MNIST image — confirm against the f/2 semantics before
# changing the filter size.
defnetwork init_network4(_x) do
_x |> f(10,10) |> flatten
|> w(361,300) |> b(300) |> relu
|> w(300,100) |> b(100) |> relu
|> w(100,10) |> b(10) |> softmax
end
# Entry point: loads MNIST, trains init_network4/1 with AdaGrad and prints the
# final test-set accuracy.
#
#   m - mini-batch size (forwarded to all2/4 via all1/7)
#   n - number of training passes over the 60000 training samples
def all(m, n) do
  IO.puts("preparing data")

  train_images = Ctensor.to_matrex(MNIST.train_image(60000))
  train_labels = MNIST.train_label_onehot(60000)
  initial_net = init_network4(0)
  test_images = Ctensor.to_matrex(MNIST.test_image(10000))
  test_labels = MNIST.test_label(10000)

  IO.puts("ready")

  trained_net = all1(train_images, initial_net, train_labels, m, n, test_images, test_labels)

  # Presumably DP.accuracy returns the number of correct predictions; dividing
  # by the 10000 test samples loaded above gives the rate.
  hits = DP.accuracy(test_images, trained_net, test_labels)
  IO.write("accuracy rate = ")
  IO.puts(hits / 10000)
end
# Trains for `n` epochs and returns the resulting network.
#
# Each epoch runs all2/4 over the whole training set with mini-batch size `m`.
# It then prints the :cross loss on the first 100 training samples and the
# test-set accuracy (divided by 10000, the test-set size loaded by all/2),
# and recurses with n - 1.
#
# Recursion stops only when `n` matches exactly 0. A negative or non-integer
# `n` (0.0 does not match the pattern 0) would never get there. The guard on
# the second clause rejects such values with a FunctionClauseError instead of
# training forever.
def all1(_image, network, _train, _m, 0, _test_image, _test_label), do: network

def all1(image, network, train, m, n, test_image, test_label)
    when is_integer(n) and n > 0 do
  network1 = all2(image, network, train, m)

  # Cheap per-epoch loss readout on a small fixed sample of the training set.
  image1 = Enum.take(image, 100)
  train1 = train |> Enum.take(100) |> Cmatrix.to_matrex()
  y = DP.forward(image1, network1)
  loss = DP.loss(y, train1, :cross)
  DP.print(loss)
  DP.newline()

  correct = DP.accuracy(test_image, network1, test_label)
  IO.write("accuracy rate = ")
  IO.puts(correct / 10000)

  all1(image, network1, train, m, n - 1, test_image, test_label)
end
# Runs one training epoch and returns the updated network.
#
# `image` and `train` are split into consecutive mini-batches of `size`
# samples. Each batch gets one DP.gradient + DP.learning(:adagrad) step, and
# the updated network is threaded into the next batch. One "." is printed per
# batch, with a newline after the last batch. The last batch holds whatever
# is left (at most `size` samples).
#
# The list is split once per step with Enum.split/2. The previous version
# called length/1 on the remaining list at every step; length/1 walks the
# whole list, so that made each epoch quadratic in the number of samples.
# `size` must be a positive integer. With 0, Enum.drop/2 never shrinks the
# list, and a negative size takes batches from the end of the list.
def all2(image, network, train, size) when is_integer(size) and size > 0 do
  case Enum.split(image, size) do
    {_batch, []} ->
      # Final batch: train on everything that remains.
      train1 = Cmatrix.to_matrex(train)
      grad = DP.gradient(image, network, train1)
      network2 = DP.learning(network, grad, :adagrad)
      IO.puts(".")
      network2

    {batch, rest} ->
      {train_batch, train_rest} = Enum.split(train, size)
      train1 = Cmatrix.to_matrex(train_batch)
      grad = DP.gradient(batch, network, train1)
      network2 = DP.learning(network, grad, :adagrad)
      IO.write(".")
      all2(rest, network2, train_rest, size)
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment