Skip to content

Instantly share code, notes, and snippets.

@abap34
Created August 1, 2020 04:19
Show Gist options
  • Save abap34/cf1dcb90fd072dad187d4c0f360bd7df to your computer and use it in GitHub Desktop.
Save abap34/cf1dcb90fd072dad187d4c0f360bd7df to your computer and use it in GitHub Desktop.
import Base
using DeepShiba
using BenchmarkTools
"""
    rosenbrock(a, b)

Evaluate the two-dimensional Rosenbrock function `100 (b - a²)² + (a - 1)²`.

The arguments may be plain numbers or any type supporting `-`, `*`, `+` and `^`
(in this script they are DeepShiba variables, so the result is differentiable).
"""
function rosenbrock(a, b)
    valley = b - a^2     # distance from the parabola b = a²
    offset = a - 1       # distance from the minimiser's a-coordinate
    return 100 * valley^2 + offset^2
end
"""
    String(result::BenchmarkTools.Trial)

Return the multi-line `text/plain` representation of a benchmark `result` as a
`String`. It is rendered into a plain in-memory buffer, with no `IOContext`
properties such as colour or display size.

NOTE(review): this adds a method to `Base.String` for a type owned by
BenchmarkTools, i.e. type piracy. It is tolerable in a throwaway script, but
prefer a locally named helper if this code moves into a package.
"""
function Base.String(result::BenchmarkTools.Trial)
    # `sprint` allocates the IOBuffer, runs `show` into it and takes its bytes.
    return sprint(show, MIME"text/plain"(), result)
end
"""
    optim(iter; init=(0., 2.,), lr=1e-4, log_interval=iter ÷ 20, f=rosenbrock)

Minimise the two-argument objective `f` by plain gradient descent. Gradients
come from DeepShiba's `backward!`. The search starts at `init` and runs `iter`
steps of size `lr`.

# Arguments
- `iter`: number of gradient-descent steps.

# Keywords
- `init`: starting point; only `init[1]` (for `a`) and `init[2]` (for `b`) are used.
- `lr`: learning rate multiplied into each gradient update.
- `log_interval`: print the current parameters before steps `1`,
  `1 + log_interval`, `1 + 2log_interval`, … Any value `≤ 0` disables logging.
  The default `iter ÷ 20` is `0`, and therefore silent, when `iter < 20`.
- `f`: objective called as `f(a, b)` on the two variables. It must return
  something `backward!` accepts. Defaults to `rosenbrock`.

# Returns
The tuple `(a, b)` of variables holding the final parameters (read them via
`.data`).
"""
function optim(iter; init=(0., 2.,), lr=1e-4, log_interval=iter ÷ 20, f=rosenbrock)
    a = variable(init[1], name="a")
    b = variable(init[2], name="b")
    for i in 1:iter
        # Log before the update, so the first line shows the initial point.
        if log_interval > 0 && (i - 1) % log_interval == 0
            println("$(i) a: $(a.data) b: $(b.data)")
        end
        y = f(a, b)
        # Clear gradients from the previous step before back-propagating.
        # Presumably DeepShiba accumulates into `.grad` otherwise (TODO confirm).
        cleargrad!(a)
        cleargrad!(b)
        backward!(y)
        a.data -= lr * a.grad.data
        b.data -= lr * b.grad.data
    end
    # Hand the optimised parameters back instead of discarding them.
    return a, b
end
# Number of gradient-descent steps performed by each benchmarked `optim` call.
const iter = 10000

# Time the full optimisation with logging switched off (a negative interval
# skips every `println` inside `optim`), then print the trial's text report.
result = @benchmark optim(iter; log_interval=-1)
println(stdout, String(result))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment