Skip to content

Instantly share code, notes, and snippets.

@Vindaar
Created December 17, 2020 10:20
Show Gist options
  • Save Vindaar/226054207d7a0a65ea1e56f857a1e44f to your computer and use it in GitHub Desktop.
Save Vindaar/226054207d7a0a65ea1e56f857a1e44f to your computer and use it in GitHub Desktop.
Direct Nim port of @karpathy's micrograd
import hashes, sets, math, strformat, algorithm
type
  # Backward-pass callback stored on each node. `backward` invokes it as
  # `w.backFn(w)`, i.e. `r` is the node that owns the callback.
  BackFn = proc(r: Value)
  # One scalar node of the expression graph.
  Value = ref object
    data: float ## value produced by the forward computation
    grad: float ## accumulated gradient; starts at 0.0, set to 1 on the root by `backward`
    backFn: BackFn ## adds this node's gradient contribution to its operands
    prev: HashSet[Value] ## operand nodes this node was computed from
    op: string ## label of the producing operation ("+", "*", "relu", ...); "" by default
proc hash(v: Value): Hash =
  ## Hashes a node by identity, i.e. by the address of the referenced object.
  ##
  ## `==` on a `ref object` compares identity, so the hash has to be based on
  ## identity as well. A hash built from the fields changes as soon as `grad`
  ## is updated, after which the node can no longer be found in any `HashSet`
  ## it was stored in; hashing `prev` additionally re-walks the whole
  ## sub-graph on every single call.
  result = hash(cast[pointer](v))
proc `$`(v: Value): string =
  ## Textual form of a node showing its forward value and its gradient.
  &"Value(data={v.data}, grad={v.grad})"
proc initValue[T: SomeNumber](x: T, children = initHashSet[Value](),
                              op = ""): Value =
  ## Creates a graph node holding `x` (converted to `float`).
  ##
  ## `children` are the operand nodes the value was computed from and `op`
  ## is the label of the producing operation. The gradient starts at zero
  ## and the backward callback is a no-op until an operator replaces it.
  new(result)
  result.data = float(x)
  result.grad = 0.0
  result.prev = children
  result.op = op
  result.backFn = proc(r: Value) =
    discard
proc toSet(args: varargs[Value]): HashSet[Value] =
  ## Collects the given nodes into a `HashSet`; a node passed more than once
  ## is stored only once.
  var collected = initHashSet[Value]()
  for i in 0 ..< args.len:
    collected.incl args[i]
  collected
proc `+`[T](v: Value, w: T): Value =
  ## Adds `w` to `v` and returns the sum as a new node.
  ##
  ## A plain number on the right-hand side is wrapped into a fresh node
  ## first; a `Value` is used as is. The backward callback passes the
  ## gradient of the sum through unchanged to both operands.
  when T is SomeNumber:
    let rhs = initValue(w)
  elif T is Value:
    template rhs: untyped = w
  let total = v.data + rhs.data
  result = initValue(total, toSet(v, rhs), "+")
  result.backFn = proc(r: Value) =
    v.grad += r.grad
    rhs.grad += r.grad
proc `+`[T: not Value](v: T, w: Value): Value =
  ## Addition with the non-`Value` operand on the left; forwards to the
  ## `Value`-first overload with the operands swapped.
  result = w + v
proc `*`[T](v: Value, w: T): Value =
  ## Multiplies `v` by `w` and returns the product as a new node.
  ##
  ## A plain number on the right-hand side is wrapped into a fresh node
  ## first; anything else is used as is. In the backward callback each
  ## operand receives the product's gradient scaled by the other operand's
  ## forward value.
  when T is SomeNumber:
    let rhs = initValue(w)
  else:
    template rhs: untyped = w
  let product = v.data * rhs.data
  result = initValue(product, toSet(v, rhs), "*")
  result.backFn = proc(r: Value) =
    v.grad += rhs.data * r.grad
    rhs.grad += v.data * r.grad
proc `*`[T: not Value](v: T, w: Value): Value =
  ## Multiplication with the non-`Value` operand on the left; forwards to
  ## the `Value`-first overload with the operands swapped.
  result = w * v
proc `**`[T: SomeNumber; U: SomeNumber](v: T, val: U): float =
  ## Power operator for plain numbers: both operands are converted to
  ## `float` and `v` is raised to `val`.
  let
    base = float(v)
    exponent = float(val)
  pow(base, exponent)
proc `**`[T: SomeNumber](v: Value, val: T): Value =
  ## Raises the node `v` to the constant power `val` and returns the result
  ## as a new node labelled `"**" & $val`.
  ##
  ## The backward callback applies the power rule
  ## `d(v^n)/dv = n * v^(n - 1)` and adds the scaled gradient to `v.grad`.
  # Convert the exponent to float once, up front. `n - 1` is then computed
  # in floating point; computing it in `T` would wrap around for an
  # unsigned exponent of zero.
  let exponent = val.float
  result = initValue(pow(v.data, exponent), toSet(v), "**" & $val)
  result.backFn = proc(r: Value) =
    v.grad += (exponent * pow(v.data, exponent - 1.0)) * r.grad
proc relu(v: Value): Value =
  ## Rectified linear unit: the new node holds 0.0 when `v.data` is
  ## negative and `v.data` otherwise.
  ##
  ## The backward callback forwards the gradient to `v` only when the
  ## output value is strictly positive.
  let activated =
    if v.data < 0: 0.0
    else: v.data
  result = initValue(activated, toSet(v), "relu")
  result.backFn = proc(r: Value) =
    if r.data > 0:
      v.grad += r.grad
proc tanh(v: Value): Value =
  ## Hyperbolic tangent of a node.
  ##
  ## The backward callback uses the derivative `1 / cosh(x)^2`, evaluated
  ## at the input's forward value.
  let squashed = tanh(v.data)
  result = initValue(squashed, toSet(v), "tanh")
  result.backFn = proc(r: Value) =
    let coshOfInput = cosh(v.data)
    v.grad += (1.0 / pow(coshOfInput, 2.0)) * r.grad
proc backward(v: Value) =
  ## Runs the backward pass starting at `v`.
  ##
  ## The nodes reachable from `v` through `prev` are first ordered so that
  ## every node comes after all of its operands. The gradient of `v` is then
  ## set to 1 and the backward callbacks are invoked walking that order from
  ## the end to the start, so a node's callback runs before those of its
  ## operands.
  var
    order: seq[Value] = @[]
    seen = initHashSet[Value]()

  proc visit(node: Value) =
    # Depth-first post-order walk: operands are appended before the node.
    if node in seen:
      return
    seen.incl node
    for operand in node.prev:
      visit(operand)
    order.add node

  visit(v)
  v.grad = 1.0
  for i in countdown(order.high, 0):
    let node = order[i]
    # The callback takes the node it belongs to as its argument.
    node.backFn(node)
proc `-`(v: Value): Value =
  ## Unary minus, expressed as a multiplication by -1.
  v * (-1)
proc `-`[T](v: Value, w: T): Value =
  ## Subtraction with a `Value` on the left, expressed as the addition of
  ## the negated right-hand side.
  let negated = -w
  v + negated
proc `-`[T: not Value](v: T, w: Value): Value =
  ## Subtraction with the non-`Value` operand on the left: negates `w` and
  ## adds `v` to it.
  let negated = -w
  negated + v
proc `/`[T](v: Value, w: T): Value =
  ## Division with a `Value` on the left, expressed as a multiplication by
  ## `w` raised to the power -1.
  let reciprocal = w ** (-1.0)
  v * reciprocal
proc `/`[T: not Value](v: T, w: Value): Value =
  ## Division with the non-`Value` operand on the left: multiplies `w`
  ## raised to the power -1 by `v`.
  ##
  ## `T` is constrained to `not Value`, like the commuted `+`, `*` and `-`
  ## overloads. Without the constraint a `Value / Value` call matches this
  ## overload and the `Value`-first one equally well and is rejected as
  ## ambiguous.
  result = (w ** (-1.0)) * v
proc `+=`[T](v: var Value, w: T) =
  ## Rebinds `v` to the new node produced by `v + w`; the node `v` pointed
  ## to before is not modified.
  let sum = v + w
  v = sum
when isMainModule:
  # Demo that runs only when this file is compiled as the main module: builds
  # an expression graph with the operators defined above, prints the forward
  # result and then the gradients obtained from the backward pass.

  # Leaf nodes of the graph.
  let a = initValue(-4.0)
  let b = initValue(2.0)
  # Forward pass. Every operator call creates a new node; `+=` rebinds the
  # variable to the node holding the sum.
  var c = a + b
  var d = a * b + b**3
  c += c + 1
  c += 1 + c + (-a)
  d += d * 2 + (b + a).relu()
  d += 3 * d + (b - a).relu()
  let e = c - d
  let f = e**2
  var g = f / 2.0
  g += 10.0 / f
  echo &"{g.data:.4f}" # prints 24.7041, the outcome of this forward pass
  # Backward pass from `g`: fills in `grad` on the nodes `g` depends on.
  g.backward()
  echo &"{a.grad:.4f}" # prints 138.8338, i.e. the numerical value of dg/da
  echo &"{b.grad:.4f}" # prints 645.5773, i.e. the numerical value of dg/db
  # Second check: derivative of tanh at x = 0.5 from the backward pass,
  # printed next to the closed form 1 / cosh²(x) evaluated on plain floats.
  let arg = initValue(0.5)
  let x = tanh(arg)
  x.backward()
  echo "AD: ∂(tanh(x))/∂x|_{x=0.5} = ", &"{arg.grad:.4f}"
  echo "∂(tanh(x))/∂x|_{x=0.5} = [1 / cosh²(x)]_{x=0.5} = ", &"{1.0 / (cosh(0.5)**2):.4f}"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment