Skip to content

Instantly share code, notes, and snippets.

@abap34
Created April 30, 2020 02:10
Show Gist options
  • Save abap34/dd98ba582212f54ba1f898119e0c0504 to your computer and use it in GitHub Desktop.
Save abap34/dd98ba582212f54ba1f898119e0c0504 to your computer and use it in GitHub Desktop.
"""
    Variable(data, grad, creator)

A node of the computation graph.

- `data`: the value held by this node.
- `grad`: gradient stored on this node; `set_var` defaults it to `nothing`
  and `backward!` assigns it.
- `creator`: the `Func` whose `run` produced this variable, or `nothing`
  for a user-created leaf.
"""
mutable struct Variable
    data::Any
    grad::Any
    creator::Any
end
"""
    Func(input, output, forward, backward, name)

A differentiable operation node.

`run` records the most recent argument in `input` and the result in `output`;
`backward!` later evaluates `backward` at `input.data` to propagate gradients.
"""
mutable struct Func
    input::Variable      # last argument given to `run`
    output::Variable     # last result produced by `run`
    forward::Function    # value function, applied to `input.data`
    backward::Function   # derivative function, also applied to `input.data`
    name::Any            # optional label (may be `nothing`) shown when printing
end
"""
    println(x::Variable, debug=false)

Print `x` as `Variable(data:…)` followed by a newline.

When `debug` is `true`, the gradient and the creator's `name` are also printed
on their own lines, each only when it is set (not `nothing`).
"""
function Base.println(x::Variable, debug=false)
    print("Variable(")
    print("data:", x.data)
    if debug
        println()
        # Identity comparison (`!==`) never dispatches to `==`; `==` can return
        # `missing` (which would throw inside `&&`) or be costly for array data.
        x.grad !== nothing && println("grad:", x.grad)
        x.creator !== nothing && println("creator:", x.creator.name)
    end
    println(")")
    return nothing
end
"""
    println(x::Func)

Print a multi-line summary of `x`: its recorded input and output variables
(via `println(::Variable)`) and, when set, its `name`.
"""
function Base.println(x::Func)
    println("Func[")
    print("input: ")
    println(x.input)
    print("output: ")
    println(x.output)
    # `!==` is the idiomatic, dispatch-free test for `nothing`.
    if x.name !== nothing
        println("function:", x.name)
    end
    println("]")
    return nothing
end
"""
    set_func(functions, name=nothing)

Build a `Func` from `functions = [forward, backward]` with an optional `name`.

`input` and `output` start as empty placeholder `Variable`s and are
overwritten by `run`.

`functions` may be any array whose element type is a subtype of `Function`.
Array types are invariant, so `[sin, sin]` (a `Vector{typeof(sin)}`) was
rejected by the old `Array{Function}` signature.
"""
function set_func(functions::AbstractArray{<:Function}, name=nothing)
    forward, backward = functions[1], functions[2]
    # Separate placeholders: a single shared object would alias input and
    # output, so mutating one before `run` would silently change the other.
    input_placeholder = Variable(nothing, nothing, nothing)
    output_placeholder = Variable(nothing, nothing, nothing)
    return Func(input_placeholder, output_placeholder, forward, backward, name)
end
"""
    set_var(data, grad=nothing, creator=nothing)

Convenience constructor for `Variable`; gradient and creator are optional.
"""
set_var(data, grad=nothing, creator=nothing) = Variable(data, grad, creator)
# Each `*_backward` is the derivative of the matching `*_forward`.
function square_forward(x)
    return x^2
end

function square_backward(x)
    return 2 * x
end

function exp_forward(x)
    return exp(x)
end

function exp_backward(x)
    return exp(x)
end

# `[forward, backward]` pairs consumed by `set_func`.
Square = Function[square_forward, square_backward]
Exp = Function[exp_forward, exp_backward]
"""
    run(f::Func, input::Variable)

Apply `f.forward` to `input.data` and return a new `Variable` whose
`creator` is `f`.

Both `input` and the returned variable are stored on `f` so that `backward!`
can walk back through it.

NOTE(review): calling `run` twice on the same `Func` overwrites the earlier
recorded input/output, so each graph node needs its own `Func` instance.
"""
function Base.run(f::Func, input::Variable)
    f.input = input
    result = Variable(f.forward(input.data), nothing, f)
    f.output = result
    return result
end
"""
    backward!(var::Variable)

Propagate gradients from `var` back along its chain of creators. At each step
the input's `grad` is set to `f.backward(input.data) * output.grad`.

If `var` has a creator but no gradient yet, it is seeded with `one.(var.data)`
(d var / d var = 1). A leaf (no creator) is left untouched. Returns `nothing`.
"""
function backward!(var::Variable)
    var.creator === nothing && return nothing
    # Seeding replaces the former `* nothing` MethodError for an unset root gradient.
    if var.grad === nothing
        var.grad = one.(var.data)
    end
    current = var
    # Iterative walk: the old recursive form used one stack frame per node.
    while current.creator !== nothing
        f = current.creator
        x = f.input
        x.grad = f.backward(x.data) * current.grad
        current = x
    end
    return nothing
end
# Build y = C(B(A(a))) = (exp(a^2))^2 and differentiate it at a = 0.5.
a = set_var(0.5)
A, B, C = set_func(Square, "A"), set_func(Exp, "B"), set_func(Square, "C")

# Forward pass; each node gets its own Func instance.
b = run(A, a)
c = run(B, b)
y = run(C, c)

# Seed dy/dy = 1, then propagate back to `a`.
y.grad = 1
backward!(y)
println(a.grad)  # dy/da = 4a * exp(2a^2) = 2exp(0.5) ≈ 3.2974
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment