A simple implementation of reverse accumulation (back-propagation)
""" | |
Impelmentation of reverse accumulation (backpropagation), discussed in | |
https://youtu.be/BcCk8I6YAqw | |
Most of this script is exposition code. The reverse accumulation code still runs if you delete everything except the following 5 classes | |
* ValAndDVal, ExprBaseRevAcc, VarExprRevAcc, PlusExprRevAcc, MultExprRevAcc | |
validated using the following 2 functions: | |
* p, run_VarExprRevAcc_gradient_decent_demo | |
For motivation, see Hotz's "tinygrad" autograd/tensor library: | |
* https://github.com/tinygrad/tinygrad#quick-example-comparing-to-pytorch | |
- Itself motivated by micrograd by Karpathy, linked therein | |
- For pytorch, see also https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html | |
- https://www.google.com/search?q=autograd&rlz=1C5CHFA_enAT878AT878&oq=autograd&gs_lcrp=EgZjaHJvbWUyBggAEEUYOTIHCAEQABiABDIHCAIQABiABDIHCAMQABiABDINCAQQLhjHARjRAxiABDIHCAUQABiABDIHCAYQABiABDIGCAcQRRg90gEIMTI4OWowajeoAgCwAgA&sourceid=chrome&ie=UTF-8 | |
See also | |
* https://en.wikipedia.org/wiki/Backpropagation | |
* https://en.wikipedia.org/wiki/Automatic_differentiation | |
- That article ends with a concise implementation using dual numbers (not covered in this script) | |
This script: | |
* Implementation of | |
- ExprBase, for forward accumulation for computing of polynomial derivatives | |
- ExprBaseRevAcc, for backwards accumulation for computing of polynomial derivatives (see Wikipedia for a similar example) | |
* Use of ExprBaseRevAcc for gradient decent exmaple | |
* Start out with motivational LayerStack class, explaining forwards and backwards motion of data. | |
- Linear computational model, as opposed to tree structure found in the polynomial expression | |
""" | |
def p(d, x, y):
    """
    This is the polynomial whose derivatives we will compute at various points.
    See
    https://www.wolframalpha.com/input?i=z+%3D+x+*+%28%28x+%2B+y%29+%2B+2%29+%2B+y+*+y
    Mathematica code:
        d = 2;
        p[x_, y_] := x (x + y + d) + y^2
        pos = {2, 3}
        p[2, 3] == 2 (2 + 3 + 2) + 3 3 == 23
        (D[p[x, y], x] /. {x -> 2, y -> 3}) == 9
        (D[p[x, y], y] /. {x -> 2, y -> 3}) == 8
        D[p[x[t], y[t]], t] == (2 + 2 x[t] + y[t]) x'[t] + (x[t] + 2 y[t]) y'[t] // Simplify
    """
    return x * (x + y + d) + y * y
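
# Hand check of the values quoted in the docstring above, at (x, y) = (2, 3) with d = 2:
#   z     = 2 * (2 + 3 + 2) + 3 * 3 = 23
#   ∂z/∂x = 2x + y + d = 4 + 3 + 2  = 9
#   ∂z/∂y = x + 2y     = 2 + 6      = 8
#   ∂z/∂d = x                       = 2
assert p(2, 2, 3) == 23  # quick sanity check; note the argument order is (d, x, y)
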
class ExprTorch:
    """
    Tensor wrapper enabling the notations 'f * g' and 'f + g',
    so that p() above can also be evaluated on (1x1) torch tensors.
    """
    def __init__(self, tensor):
        self.tensor = tensor
    def __add__(self, other):
        return ExprTorch(self.tensor + other.tensor)
    def __mul__(self, other):
        return ExprTorch(self.tensor.matmul(other.tensor))
def torch_example():
    def torch_scalar(scalar):
        matrix_1x1 = [float(scalar)]
        import torch  # Note: If you don't have pytorch installed, just don't run torch_example
        return torch.tensor([matrix_1x1], requires_grad=True)
    def torch_p(d, x, y):
        return p(ExprTorch(d), ExprTorch(x), ExprTorch(y)).tensor
    x = torch_scalar(2)
    y = torch_scalar(3)
    d = torch_scalar(2)
    z = torch_p(d, x, y)
    z.backward()
    print(f"[torch example] z = {float(z)}")
    print(f"[torch example] ∂z/∂x = {float(x.grad)} at pos=(2,3)")
    print(f"[torch example] ∂z/∂y = {float(y.grad)} at pos=(2,3)")
    print(f"[torch example] ∂z/∂d = {float(d.grad)}")
class Layer:
    def __init__(self, layer_name, func, prev_layer):
        self.layer_name = layer_name
        self.data = None
        self.func = func
        self.prev_layer = prev_layer
        print(f"[{self.layer_name}] Initialized.")
    def set_data(self, data):
        self.data = data
        print(f"[{self.layer_name}] Set data {data}.")
    def eval_f(self, x):
        self.data = self.func(x)
        print(f"[{self.layer_name}] Got input x={x},\tSetting data and returning layer.data = layer.func(x) = {self.data}")
        return self.data
    def push_down_signal(self, signal):
        new_signal = signal + "|" + str(self.data)
        # Note: If .forward() is not called before .push_down_signal (i.e. when .data is not set), the logs will show a lot of 'None'
        if self.prev_layer is None:
            print(f"[{self.layer_name}] Got signal '{signal}'.\tBut prev_layer is None and so ending on signal '{new_signal}'.")
        else:
            print(f"[{self.layer_name}] Got signal '{signal}'.\tNow pushing down '{new_signal}' to previous layer '{self.prev_layer.layer_name}'.")
            self.prev_layer.push_down_signal(new_signal)
            print(f"[{self.layer_name}] Done pushing down.")
class LayerStack:
    def __init__(self):
        print("[LayerStack START] Initializing stack.")
        self.layer_0 = Layer("layer 0", None, None)  # .func of layer 0 will not be used
        self.layer_1 = Layer("layer 1", lambda x: x + 4 * 10 ** 1, self.layer_0)
        self.layer_2 = Layer("layer 2", lambda x: x + 5 * 10 ** 2, self.layer_1)
        self.layer_3 = Layer("layer 3", lambda x: x + 6 * 10 ** 3, self.layer_2)
        self.layer_4 = Layer("layer 4", lambda x: x + 7 * 10 ** 4, self.layer_3)
        print("[LayerStack END] Initialized stack.\n")
    def set_layer_0_data(self, data):
        print(f"[LayerStack START] Setting stack input_value (layer_0_data) to {data}.")
        self.layer_0.set_data(data)
        print(f"[LayerStack END] Set stack input_value (layer_0_data) to {data}.\n")
    def forward(self):
        r0 = self.layer_0.data
        print(f"[LayerStack START] Forwarding {r0} through all layers.")
        r1 = self.layer_1.eval_f(r0)
        r2 = self.layer_2.eval_f(r1)
        r3 = self.layer_3.eval_f(r2)
        r4 = self.layer_4.eval_f(r3)
        print(f"[LayerStack END] Forwarded {r0} through all layers and arrived with {r4}.\n")
        return r4
    def push_down_signal(self, signal):
        print(f"[LayerStack START] Pushing down signal={signal} using last layer {self.layer_4.layer_name}")
        self.layer_4.push_down_signal(signal)
        print(f"[LayerStack END] Pushed down signal={signal} using last layer {self.layer_4.layer_name}")
def run_stack_demo():
    INPUT = 3
    ls = LayerStack()
    ls.set_layer_0_data(INPUT)
    _r = ls.forward()
    ls.push_down_signal("foo")
class ValAndDVal:
    def __init__(self, val):
        self.__val = val
    def get_val(self):
        return self.__val
    def get_dval(self):
        return self.__dval
    def set_dval(self, dval):  # Only in the ExprBaseRevAcc classes is this not just called right after __init__
        self.__dval = dval
class ExprBase:
    """
    Enables the notations 'f * g' and 'f + g'.
    (Technical note: Is recursively defined, in the sense that it uses PlusExpr(ExprBase))
    """
    def __add__(self, other):
        return PlusExpr(self, other)
    def __mul__(self, other):
        return MultExpr(self, other)
    def make_val_and_dval(self, var):
        # Note: Both (!!) val and dval will be (recursively) forwarded in this step
        raise NotImplementedError  # Combined eval and getter function must be implemented by subclasses
    def get_val(self):  # Auxiliary getter for printing
        VAR = None  # Not interested in the derivative expression here, so passing an auxiliary value
        return self.make_val_and_dval(VAR).get_val()
class VarExpr(ExprBase):
    def __init__(self, val):
        self.__val = val
    def make_val_and_dval(self, var):  # Evaluate val and first-order derivative (Note: there are no higher ones for primitive variables)
        # Note: var == None (or passing any other var that's not a VarExpr) is allowed, to access val when not caring about dval
        res = ValAndDVal(self.__val)
        res.set_dval(int(self == var))  # kronecker_delta(self, var)
        return res
class PlusExpr(ExprBase):  # Note: Does not have any numerical value members
    def __init__(self, expr_l, expr_r):
        self.expr_l = expr_l
        self.expr_r = expr_r
    def make_val_and_dval(self, var):  # Evaluate and get val and first-order derivative val
        l_dl = self.expr_l.make_val_and_dval(var)
        r_dr = self.expr_r.make_val_and_dval(var)
        res = ValAndDVal(l_dl.get_val() + r_dr.get_val())
        res.set_dval(l_dl.get_dval() + r_dr.get_dval())  # The derivative distributes over addition
        return res
class MultExpr(ExprBase):  # Note: Does not have any numerical value members
    def __init__(self, expr_l, expr_r):
        self.expr_l = expr_l
        self.expr_r = expr_r
    def make_val_and_dval(self, var):  # Evaluate and get val and first-order derivative val
        # Compute instances that are used several times in the expression (same 3 lines as in PlusExpr.make_val_and_dval)
        # See also https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py#L28
        l_dl = self.expr_l.make_val_and_dval(var)
        r_dr = self.expr_r.make_val_and_dval(var)
        l_at_v = l_dl.get_val()
        r_at_v = r_dr.get_val()
        res = ValAndDVal(l_at_v * r_at_v)
        res.set_dval(r_at_v * l_dl.get_dval() + l_at_v * r_dr.get_dval())  # Product rule for multiplication, (a*b)' = a'*b + a*b'
        return res
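
# Worked example for the forward-accumulation classes above: for e = x * x with x = VarExpr(3)
# and var = x, both children of the MultExpr return (val, dval) = (3, 1), so the product rule
# gives e.make_val_and_dval(x) = (9, 3*1 + 3*1) = (9, 6), i.e. d(x^2)/dx = 2x = 6 at x = 3.
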
class ExprBaseRevAcc:
    # Similar to ExprBase above, but the primitive subexpressions will accumulate dval (partial derivative) values
    def __add__(self, other):
        return PlusExprRevAcc(self, other)
    def __mul__(self, other):
        return MultExprRevAcc(self, other)
    def eval(self):
        self.forward()
        DT_OVER_DT = 1  # ∂t/∂t = kronecker_delta(foo, foo) = 1
        self.accumulate_derivative(DT_OVER_DT)
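
# Usage sketch for the RevAcc classes (mirrors run_comparison_demo below): build an expression
# tree out of VarExprRevAcc leaves, call .eval() once (forward pass, then back-propagation
# seeded with ∂t/∂t = 1), and read the accumulated partial derivatives off the leaves:
#   a, b = VarExprRevAcc(1.0), VarExprRevAcc(2.0)
#   e = a * b + b
#   e.eval()
#   a.get_dval(), b.get_dval()  # -> (2.0, 2.0), since ∂(ab+b)/∂a = b and ∂(ab+b)/∂b = a + 1
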
class VarExprRevAcc(ExprBaseRevAcc):
    # Note: This var will also accumulate a dval. Compare with the ExprBase variant above, which didn't have any dval member field!
    def __init__(self, val):
        self.val_and_dval = ValAndDVal(val)
        self.val_and_dval.set_dval(0)
    def forward(self):
        pass  # No forward needed, since the value was already set in __init__(val)
    def get_val(self):
        return self.val_and_dval.get_val()
    def get_dval(self):
        return self.val_and_dval.get_dval()
    def accumulate_derivative(self, dval):  # Accumulation of the derivative values coming in
        s = self.val_and_dval.get_dval()
        self.val_and_dval.set_dval(s + dval)
    def step(self, step_size):  # One gradient descent step; also resets the accumulated dval to 0
        correction = -step_size * self.get_dval()  # -grad
        corrected_val = self.get_val() + correction
        self.val_and_dval = ValAndDVal(corrected_val)
        self.val_and_dval.set_dval(0)
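
# Worked gradient step (matches the start of the gradient descent demo below): at
# (x, y) = (-3, -5) with d = 2 the accumulated partials are ∂z/∂x = 2x + y + d = -9 and
# ∂z/∂y = x + 2y = -13, so step(1e-3) moves x to -3 + 0.009 = -2.991 and y to -5 + 0.013 = -4.987.
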
class PlusExprRevAcc(ExprBaseRevAcc):
    def __init__(self, expr_l, expr_r):
        self.expr_l = expr_l
        self.expr_r = expr_r
        self.val = None  # This expr class keeps a record of its own val, but no dval
    def get_val(self):
        return self.val
    def forward(self):  # Also a setter function w.r.t. whatever vals are set in the expression
        self.expr_l.forward()
        self.expr_r.forward()
        self.val = self.expr_l.get_val() + self.expr_r.get_val()
    def accumulate_derivative(self, dval):
        # Note: No addition '+' is performed here, since both .expr's push back into the same accumulating variables.
        #
        # Linearity of the derivative:
        #   e := e1 + e2
        #   de/dt * dT = de1/dt * (1.0 * dT) + de2/dt * (1.0 * dT)
        self.expr_l.accumulate_derivative(1.0 * dval)
        self.expr_r.accumulate_derivative(1.0 * dval)
class MultExprRevAcc(ExprBaseRevAcc):
    # See the comments made about PlusExprRevAcc above. MultExprRevAcc is similar, if slightly more complicated
    def __init__(self, expr_l, expr_r):
        self.expr_l = expr_l
        self.expr_r = expr_r
        self.val = None
    def get_val(self):
        return self.val
    def forward(self):
        self.expr_l.forward()
        self.expr_r.forward()
        self.val = self.expr_l.get_val() * self.expr_r.get_val()
    def accumulate_derivative(self, dval):
        # Note: Makes use of the subexpressions' vals, so forward() needs to have run before pushing back!
        #
        # Product rule (multiplication and expression switcheroo):
        #   e := e1 * e2
        #   de/dt * dT = de1/dt * (e2 * dT) + de2/dt * (e1 * dT)
        self.expr_l.accumulate_derivative(self.expr_r.get_val() * dval)
        self.expr_r.accumulate_derivative(self.expr_l.get_val() * dval)
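
# Worked reverse-accumulation trace for z = p(d, x, y) = x * ((x + y) + d) + y * y at
# (x, y, d) = (2, 3, 2), writing w := (x + y) + d = 7 after the forward pass:
#   - the top PlusExprRevAcc passes the seed 1 to both x * w and y * y
#   - x * w sends w * 1 = 7 to x and x * 1 = 2 to w, which in turn sends 2 to each of x, y, d
#   - y * y sends y * 1 = 3 to y twice
# so the leaves accumulate ∂z/∂x = 7 + 2 = 9, ∂z/∂y = 2 + 3 + 3 = 8, ∂z/∂d = 2, matching the torch example.
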
def run_comparison_demo():
    # Example using VarExpr: finding the dvals of x * ((x + y) + d) + y * y at (x, y) = (2, 3)
    x = VarExpr(2)
    y = VarExpr(3)
    d = VarExpr(2)  # Const.
    z = p(d, x, y)
    dz_dx = z.make_val_and_dval(x)
    dz_dy = z.make_val_and_dval(y)
    dz_dd = z.make_val_and_dval(d)
    z_VarExpr_val = z.get_val()
    print("[VarExpr example] z =", z_VarExpr_val)
    print("[VarExpr example] ∂z/∂x =", dz_dx.get_dval())
    print("[VarExpr example] ∂z/∂y =", dz_dy.get_dval())
    print("[VarExpr example] ∂z/∂d =", dz_dd.get_dval())
    print()
    # Example using VarExprRevAcc
    x = VarExprRevAcc(2)
    y = VarExprRevAcc(3)
    d = VarExprRevAcc(2)
    z = p(d, x, y)
    z.eval()
    z_VarExprRevAcc_val = z.get_val()
    print("[VarExprRevAcc example] z =", z_VarExprRevAcc_val)
    print("[VarExprRevAcc example] ∂z/∂x =", x.get_dval())
    print("[VarExprRevAcc example] ∂z/∂y =", y.get_dval())
    print("[VarExprRevAcc example] ∂z/∂d =", d.get_dval())
    print()
    assert z_VarExprRevAcc_val == z_VarExpr_val
def run_VarExprRevAcc_gradient_descent_demo():
    class Config:
        ITERATIONS = 5000
        EPS = 1e-3
        MU = EPS
    x = VarExprRevAcc(-3)
    y = VarExprRevAcc(-5)
    d = VarExprRevAcc(2)  # Const.
    for idx in range(Config.ITERATIONS):
        z = p(d, x, y)
        z.eval()
        x.step(Config.MU)
        y.step(Config.MU)
        if (idx < 1000 and idx % 100 == 0) or idx % 1000 == 0:
            print(f"[gradient_descent_demo] idx #{idx}, z = {z.get_val()}")
    # Validation
    MIN_Z_GT = -4/3  # Minimum of p
    assert abs(z.get_val() - MIN_Z_GT) < Config.EPS
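
# Where the minimum comes from: setting the gradient of p to zero, 2x + y + 2 = 0 and
# x + 2y = 0, gives (x, y) = (-4/3, 2/3), at which p = -4/3 (the MIN_Z_GT used above).
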
if __name__ == '__main__':
    torch_example()
    print()
    run_stack_demo()
    print()
    run_comparison_demo()
    print()
    run_VarExprRevAcc_gradient_descent_demo()
    print()