Last active
September 6, 2023 02:01
-
-
Save automata/a95828a38b4fee78b77e853e7d5dc2ea to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from random import random_float64 | |
from math import tanh | |
@value
@register_passable("trivial")
struct Value:
    """A scalar node in a tiny reverse-mode autograd graph.

    Result nodes (built by `new`) copy their operands by value into freshly
    allocated heap slots and keep type-erased pointers to them in `l` / `r`.
    Each node therefore owns private snapshots of its children, and the graph
    reachable from a node is a tree. Gradients computed by `backward` are
    written into those heap copies, not into the variables the caller still
    holds.

    NOTE(review): the slots allocated in `new` are never freed.
    NOTE(review): because operands are copied, using one Value twice
    (e.g. `a * a`) yields two independent copies whose gradients are not summed.
    """

    # Field order matters: the synthesized memberwise __init__ takes
    # (r, l, data, grad, op, _id) in exactly this order.
    var r: Pointer[Int]  # right operand (really a Pointer[Value]); null if absent
    var l: Pointer[Int]  # left operand (really a Pointer[Value]); null for leaves
    var data: Float64  # forward value
    var grad: Float64  # accumulated gradient of the backward root w.r.t. this node
    var op: StringLiteral  # "" for leaves, otherwise "+", "*" or "tanh"
    var _id: Float64  # random identity tag; copies of one node share it

    fn __init__(data: Float64) -> Value:
        """Create a leaf node holding `data`, with zero gradient and no operands."""
        return Value(Pointer[Int].get_null(), Pointer[Int].get_null(), data, 0.0, "", random_float64())

    fn __eq__(self, other: Value) -> Bool:
        """Identity comparison: true when both values carry the same random `_id`."""
        # For now using a random_float64 value as the identity tag.
        return self._id == other._id

    # Add
    fn __add__(self, other: Value) -> Value:
        """Return a "+" node whose operands are copies of `self` and `other`."""
        return self.new(self.data + other.data, other, "+")

    fn __radd__(self, other: Value) -> Value:
        return self + other

    fn __add__(self, other: Float64) -> Value:
        """Add a constant by wrapping it in a leaf Value."""
        return self + Value(other)

    fn __radd__(self, other: Float64) -> Value:
        return self + Value(other)

    @staticmethod
    fn backward_add(inout node: Value):
        """Propagate `node.grad` through `node = l + r` (both local derivatives are 1)."""
        var left = node.l.bitcast[Value]().load(0)
        var right = node.r.bitcast[Value]().load(0)
        left.grad += node.grad
        right.grad += node.grad
        # Each updated operand goes back into its OWN slot.
        node.l.bitcast[Value]().store(0, left)
        node.r.bitcast[Value]().store(0, right)
        Value._backward(left)
        Value._backward(right)

    # Mul
    fn __mul__(self, other: Value) -> Value:
        """Return a "*" node whose operands are copies of `self` and `other`."""
        return self.new(self.data * other.data, other, "*")

    fn __rmul__(self, other: Value) -> Value:
        return self * other

    fn __mul__(self, other: Float64) -> Value:
        """Multiply by a constant by wrapping it in a leaf Value."""
        return self * Value(other)

    fn __rmul__(self, other: Float64) -> Value:
        return self * Value(other)

    @staticmethod
    fn backward_mul(inout node: Value):
        """Propagate `node.grad` through `node = l * r` (d/dl = r, d/dr = l)."""
        var left = node.l.bitcast[Value]().load(0)
        var right = node.r.bitcast[Value]().load(0)
        left.grad += right.data * node.grad
        right.grad += left.data * node.grad
        node.l.bitcast[Value]().store(0, left)
        node.r.bitcast[Value]().store(0, right)
        Value._backward(left)
        Value._backward(right)

    # Neg
    fn __neg__(self) -> Value:
        """Negate by multiplying with the constant -1."""
        return self * -1

    # Sub
    fn __sub__(self, other: Value) -> Value:
        """Subtract as `self + (-other)`."""
        return self + (-other)

    fn __sub__(self, other: Float64) -> Value:
        return self + (-Value(other))

    # Tanh
    fn tanh(self) -> Value:
        """Return a unary "tanh" node whose single operand (in `l`) is a copy of `self`."""
        return self.new(tanh(self.data), "tanh")

    @staticmethod
    fn backward_tanh(inout node: Value):
        """Propagate `node.grad` through `node = tanh(l)`; d/dl = 1 - tanh(l)^2."""
        var left = node.l.bitcast[Value]().load(0)
        # node.data already holds tanh(left.data), so reuse it instead of recomputing.
        left.grad += (1.0 - node.data * node.data) * node.grad
        node.l.bitcast[Value]().store(0, left)
        Value._backward(left)

    # Value alloc
    fn new(self, data: Float64, op: StringLiteral) -> Value:
        """Create a unary result node; `self` is copied to the heap as its left operand."""
        let l = Pointer[Value].alloc(1)
        l.store(self)
        # Memberwise order is (r, l, ...): no right operand, operand goes in `l`.
        return Value(Pointer[Int].get_null(), l.bitcast[Int](), data, 0.0, op, random_float64())

    fn new(self, data: Float64, right: Value, op: StringLiteral) -> Value:
        """Create a binary result node with heap copies of `self` (left) and `right`."""
        let l = Pointer[Value].alloc(1)
        l.store(self)
        let r = Pointer[Value].alloc(1)
        r.store(right)
        # Memberwise order is (r, l, ...): right-operand pointer first.
        return Value(r.bitcast[Int](), l.bitcast[Int](), data, 0.0, op, random_float64())

    # Autograd
    @staticmethod
    fn _backward(inout node: Value):
        """Dispatch to the local backward rule for `node.op`; leaves ("") stop the recursion."""
        if node.op == "":
            return
        elif node.op == "+":
            Value.backward_add(node)
        elif node.op == "*":
            Value.backward_mul(node)
        elif node.op == "tanh":
            Value.backward_tanh(node)

    fn backward(inout self):
        """Seed this node's gradient with 1.0 and backpropagate into all operand copies."""
        # The graph under `self` is a tree of by-value heap copies, and each
        # backward_* rule recurses into its operands. One recursive pass from the
        # root therefore reaches every node exactly once. Do not drive this from a
        # topological list of copies: those entries are snapshots whose `grad` is
        # stale, and re-running the recursive rules on them double-counts.
        self.grad = 1.0
        Value._backward(self)

    fn build_topo(inout self, v: Value, inout visited: DynamicVector[Value], inout topo: DynamicVector[Value]):
        """Append copies of `v` and its operands to `topo` in post-order, skipping ids already in `visited`."""
        var is_in_visited = False
        let size = len(visited)
        for i in range(size):
            if v == visited[i]:
                is_in_visited = True
        if not is_in_visited:
            visited.push_back(v)
            # It's pushing back, so visit in reverse, first right then left
            if v.r.bitcast[Int]() != Pointer[Int].get_null():
                self.build_topo(v.r.bitcast[Value]().load(0), visited, topo)
            if v.l.bitcast[Int]() != Pointer[Int].get_null():
                self.build_topo(v.l.bitcast[Value]().load(0), visited, topo)
            topo.push_back(v)

    @staticmethod
    fn reverse(vec: DynamicVector[Value]) -> DynamicVector[Value]:
        """Return a new vector holding the elements of `vec` in reverse order."""
        var reversed: DynamicVector[Value] = DynamicVector[Value](len(vec))
        for i in range(len(vec) - 1, -1, -1):
            reversed.push_back(vec[i])
        return reversed

    fn show(self, label: StringLiteral):
        """Print a one-line summary of this node, tagged with `label`."""
        print("<Value", label, "::", "data:", self.data, "grad:", self.grad, "op:", self.op, ">")

    @staticmethod
    fn print_backward(node: Value):
        """Print each non-leaf node as `operands op = result`, then recurse into its operands."""
        if node.l and node.r:
            let left = node.l.bitcast[Value]().load(0)
            let right = node.r.bitcast[Value]().load(0)
            print(left.data, "(", left.grad, ")", node.op, right.data, "(", right.grad, ")", "=", node.data)
        elif node.l:
            let left = node.l.bitcast[Value]().load(0)
            print(left.data, "(", left.grad, ")", node.op, "=", node.data)
        if node.l:
            let left = node.l.bitcast[Value]().load(0)
            Value.print_backward(left)
        if node.r:
            let right = node.r.bitcast[Value]().load(0)
            Value.print_backward(right)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Demo: build s2 = (a + b) * c, backpropagate from s2, then print the graph.
# Leaf inputs.
var a = Value(1.0)
var b = Value(2.0)
var c = Value(7.0)
# Forward pass.
var s1 = a + b
var s2 = s1 * c
# Backward pass, then dump each node with its operands' gradients.
s2.backward()
Value.print_backward(s2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment