Created
December 17, 2020 10:20
-
-
Save Vindaar/226054207d7a0a65ea1e56f857a1e44f to your computer and use it in GitHub Desktop.
Direct Nim port of @karpathy's micrograd
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import hashes, sets, math, strformat, algorithm | |
type
  # Backward rule stored on each node: called with the node itself and
  # pushes that node's gradient into the nodes it was computed from.
  BackFn = proc(r: Value) {.closure.}
  Value = ref object
    ## One scalar node of the expression graph built by the operators below.
    data, grad: float     ## forward value and accumulated gradient
    backFn: BackFn        ## local backward rule (a no-op for leaf nodes)
    prev: HashSet[Value]  ## operand nodes this node was computed from
    op: string            ## tag of the producing operation ("" for leaves)
proc hash(v: Value): Hash =
  ## Identity hash for graph nodes.
  ##
  ## `HashSet[Value]` compares elements with `==`, which for a `ref object`
  ## is reference identity, so the hash must depend on identity only.
  ## Hashing the node's address keeps it stable while `grad` is mutated by
  ## the backward rules, and costs O(1) per insertion.
  # Hashing the contents (grad, prev, ...) would change a node's hash after
  # `backward` ran, corrupting every set that holds it. Recursing into
  # `prev` would also rehash the entire ancestor graph on each `incl`.
  result = hash(cast[pointer](v))
proc `$`(v: Value): string =
  ## Renders a node as `Value(data=<data>, grad=<grad>)`; the op tag,
  ## operands and backward rule are not shown.
  "Value(data=" & $v.data & ", grad=" & $v.grad & ")"
proc initValue[T: SomeNumber](x: T, children = initHashSet[Value](),
                              op = ""): Value =
  ## Creates a node holding `x` converted to `float`, with zero gradient,
  ## the given operand set and op tag, and a backward rule that does
  ## nothing. Operators that build inner nodes replace that rule afterwards.
  new(result)
  result.data = float(x)
  result.grad = 0.0
  result.prev = children
  result.op = op
  result.backFn = (proc (r: Value) = discard)
proc toSet(args: varargs[Value]): HashSet[Value] =
  ## Collects the operand nodes into a fresh set. The same node passed
  ## twice (e.g. for `x + x`) is stored once.
  var members = initHashSet[Value]()
  for i in 0 ..< args.len:
    members.incl args[i]
  members
proc `+`[T](v: Value, w: T): Value =
  ## Adds `w` (a `Value` or a plain number) to the node `v` and records the
  ## sum in the graph with op tag "+".
  ##
  ## Backward rule: both partial derivatives of a sum are 1, so each operand
  ## receives the output's gradient unchanged.
  when T is SomeNumber:
    # Wrap the constant in its own node so both operands share one code path.
    let val = initValue(w)
  elif T is Value:
    template val: untyped = w
  else:
    {.error: "`+` expects a Value or a SomeNumber right operand".}
  result = initValue(v.data + val.data, toSet(v, val), "+")
  result.backFn = (
    proc(r: Value) =
      v.grad += r.grad
      val.grad += r.grad
  )
proc `+`[T: not Value](v: T, w: Value): Value =
  ## Number-on-the-left addition (`1 + x`). Addition commutes, so this
  ## forwards to the `Value + T` overload with the operands swapped.
  result = `+`(w, v)
proc `*`[T](v: Value, w: T): Value =
  ## Multiplies the node `v` by `w` (a `Value` or a plain number) and
  ## records the product in the graph with op tag "*".
  when T is SomeNumber:
    # A constant factor becomes its own leaf node so the backward rule can
    # treat both operands the same way.
    let other = initValue(w)
  else:
    template other: untyped = w
  let product = v.data * other.data
  result = initValue(product, toSet(v, other), "*")
  proc propagate(r: Value) =
    # Product rule: each operand's gradient is the output gradient scaled
    # by the value of the other operand.
    v.grad += other.data * r.grad
    other.grad += v.data * r.grad
  result.backFn = propagate
proc `*`[T: not Value](v: T, w: Value): Value =
  ## Number-on-the-left multiplication (`3 * x`). Multiplication commutes,
  ## so this forwards to the `Value * T` overload with the operands swapped.
  result = `*`(w, v)
proc `**`[T: SomeNumber; U: SomeNumber](v: T, val: U): float =
  ## Power of two plain numbers, always computed in `float`
  ## (so `2 ** 3` yields `8.0`).
  pow(float(v), float(val))
proc `**`[T: SomeNumber](v: Value, val: T): Value =
  ## Raises the node `v` to the constant power `val`. The op tag embeds
  ## the exponent (e.g. "**2").
  let exponent = val.float
  result = initValue(pow(v.data, exponent), toSet(v), "**" & $val)
  result.backFn = (
    proc(r: Value) =
      # Power rule: d(v^n)/dv = n * v^(n - 1).
      let slope = exponent * pow(v.data, (val - 1).float)
      v.grad += slope * r.grad
  )
proc relu(v: Value): Value =
  ## Rectified linear unit: a negative input is clamped to 0.0, any other
  ## input passes through unchanged.
  let clamped = if v.data < 0: 0.0 else: v.data
  result = initValue(clamped, toSet(v), "relu")
  result.backFn = (
    proc(r: Value) =
      # The local slope is 1 where the output is positive and 0 elsewhere,
      # so the gradient only flows back in the positive case.
      if not (r.data > 0):
        return
      v.grad += r.grad
  )
proc tanh(v: Value): Value =
  ## Hyperbolic tangent of a node.
  let t = tanh(v.data)
  result = initValue(t, toSet(v), "tanh")
  result.backFn = (
    proc(r: Value) =
      # d/dx tanh(x) = sech^2(x) = 1 / cosh^2(x), taken at the input value.
      let sech2 = 1.0 / pow(cosh(v.data), 2.0)
      v.grad += sech2 * r.grad
  )
proc backward(v: Value) =
  ## Back-propagates from `v` through the graph recorded in `prev`.
  ##
  ## The steps are:
  ## 1. A depth-first search orders the reachable nodes so that every node
  ##    comes after its operands.
  ## 2. `v.grad` is set to 1.
  ## 3. Each node's backward rule runs from `v` back towards the leaves.
  ##
  ## Gradients are accumulated with `+=` and are not reset here.
  var order: seq[Value]
  var seen = initHashSet[Value]()
  proc visit(node: Value) =
    if seen.containsOrIncl(node):
      return
    for child in node.prev:
      visit(child)
    order.add node
  visit(v)
  v.grad = 1
  for i in countdown(order.high, 0):
    let node = order[i]
    # The closures built by the operators capture only their operands, not
    # the node they are stored on, so that node is passed in explicitly.
    node.backFn(node)
proc `-`(v: Value): Value =
  ## Negation, expressed as multiplication by the integer constant -1 so
  ## it reuses the `*` backward rule.
  result = `*`(v, -1)
proc `-`[T](v: Value, w: T): Value =
  ## Subtraction of a number or node, expressed as adding the negated
  ## right operand.
  let negW = -w
  result = v + negW
proc `-`[T: not Value](v: T, w: Value): Value =
  ## Number-minus-node (`3 - x`), computed as the negated node plus the
  ## number.
  let negW = -w
  result = negW + v
proc `/`[T](v: Value, w: T): Value =
  ## Division of a node by `w` (a number or a node), computed as `v`
  ## times `w` raised to the power -1.0.
  let reciprocal = w ** (-1.0)
  result = v * reciprocal
proc `/`[T: not Value](v: T, w: Value): Value =
  ## Number-divided-by-node (`10.0 / x`), computed as `w` raised to the
  ## power -1.0, times `v`.
  # `T` excludes `Value`, matching the reversed `+`, `*` and `-` overloads.
  # With an unconstrained `T`, `Value / Value` matched this overload and
  # `/`[T](v: Value, w: T) equally well, which is an ambiguous call.
  result = (w ** (-1.0)) * v
proc `+=`[T](v: var Value, w: T) =
  ## Rebinds `v` to a new node holding `v + w`. The node previously held
  ## by `v` is not modified; it becomes an operand of the new node.
  let sum = v + w
  v = sum
when isMainModule:
  # Worked example: build an expression graph from two leaves, run the
  # forward pass, then back-propagate to the leaves.
  block expressionExample:
    let
      a = initValue(-4.0)
      b = initValue(2.0)
    var
      c = a + b
      d = a * b + b**3
    c += c + 1
    c += 1 + c + (-a)
    d += d * 2 + (b + a).relu()
    d += 3 * d + (b - a).relu()
    let
      e = c - d
      f = e**2
    var g = f / 2.0
    g += 10.0 / f
    echo &"{g.data:.4f}" # prints 24.7041, the outcome of this forward pass
    g.backward()
    echo &"{a.grad:.4f}" # prints 138.8338, i.e. the numerical value of dg/da
    echo &"{b.grad:.4f}" # prints 645.5773, i.e. the numerical value of dg/db

  # Compare the automatic derivative of tanh with its closed form 1/cosh^2.
  block tanhDerivative:
    let arg = initValue(0.5)
    tanh(arg).backward()
    echo "AD: ∂(tanh(x))/∂x|_{x=0.5} = ", &"{arg.grad:.4f}"
    let exact = 1.0 / (cosh(0.5)**2)
    echo "∂(tanh(x))/∂x|_{x=0.5} = [1 / cosh²(x)]_{x=0.5} = ", &"{exact:.4f}"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment