Created
May 9, 2020 12:32
-
-
Save avinashselvam/b098cc2b02ca6664c260ea5c880bbbf9 to your computer and use it in GitHub Desktop.
An example of forward-mode and backward-mode automatic differentiation on a computational graph.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Constant():
    """A leaf node holding a fixed numeric value."""

    def __init__(self, value):
        self.value = value
        self.gradient = None  # constants never accumulate a gradient

    def evaluate(self):
        """Return the stored value."""
        return self.value

    def derivative(self, wrt_variable):
        """Forward-mode derivative: a constant's derivative is always 0."""
        return 0

    def backprop(self, prev_gradient):
        """Reverse-mode pass: there is nothing to propagate into a constant."""
        return None
class Variable():
    """A leaf node whose value the graph is differentiated with respect to."""

    count = 0  # class-wide counter used to generate unique variable names

    def __init__(self, value):
        Variable.count += 1
        self.name = "var" + str(Variable.count)
        self.value = value
        self.gradient = 0  # accumulated reverse-mode gradient

    def evaluate(self):
        """Return the current value of this variable."""
        return self.value

    def derivative(self, wrt_variable):
        """Forward-mode: 1 when differentiating w.r.t. this variable, else 0."""
        if wrt_variable == self:
            return 1
        return 0

    def backprop(self, prev_gradient):
        """Reverse-mode: a variable may appear at several graph nodes, so
        every incoming contribution is summed into ``gradient``."""
        self.gradient = self.gradient + prev_gradient
class BinaryOperator():
    """Base class for graph nodes that combine two child nodes ``a`` and ``b``."""

    def __init__(self, a, b):
        self.a, self.b = a, b
        self.cache = None  # memoized result of evaluate(), filled by subclasses
class Add(BinaryOperator):
    """Graph node computing ``a + b``."""

    def evaluate(self):
        """Return a + b, memoizing the result in ``self.cache``.

        Bug fix: the original tested ``if not self.cache``, which treats a
        legitimately falsy sum (0, 0.0) as "not cached" and recomputes the
        subtree on every call. Testing ``is None`` caches all values.
        """
        if self.cache is None:
            self.cache = self.a.evaluate() + self.b.evaluate()
        return self.cache

    def derivative(self, wrt_variable):
        # (f + g)' = f' + g'
        return self.a.derivative(wrt_variable) + self.b.derivative(wrt_variable)

    def backprop(self, prev_gradient):
        # d(a+b)/da = d(a+b)/db = 1, so the incoming gradient passes through
        # to both children unchanged.
        self.a.backprop(prev_gradient)
        self.b.backprop(prev_gradient)
class Multiply(BinaryOperator):
    """Graph node computing ``a * b``."""

    def evaluate(self):
        """Return a * b, memoizing the result in ``self.cache``.

        Bug fix: the original tested ``if not self.cache``, which treats a
        legitimately falsy product (0, 0.0) as "not cached" and recomputes
        the subtree on every call. Testing ``is None`` caches all values.
        """
        if self.cache is None:
            self.cache = self.a.evaluate() * self.b.evaluate()
        return self.cache

    def derivative(self, wrt_variable):
        # product rule: (uv)' = u'v + uv'
        return self.a.derivative(wrt_variable) * self.b.evaluate() + self.a.evaluate() * self.b.derivative(wrt_variable)

    def backprop(self, prev_gradient):
        # d(ab)/da = b and d(ab)/db = a, so each child receives the incoming
        # gradient scaled by the other child's value.
        self.a.backprop(self.b.evaluate() * prev_gradient)
        self.b.backprop(self.a.evaluate() * prev_gradient)
""" | |
z(x) = v(u(x)) | |
forward diff --> z'(x) = u'(x)*v'(u(x)) | |
backward diff --> z'(x) = v'(u(x))*u'(x) | |
Let z = x**2 + xy + 2 | |
""" | |
x = Variable(3.0) | |
y = Variable(5.0) | |
graph = Add(Add(Multiply(x, x), Multiply(x, y)), Constant(2)) | |
graph.backprop(1.0) | |
print(f"for x = 3, y = 5 we get z = {graph.evaluate()}") | |
print(f"by forward diff dzdx = {graph.derivative(x)}, dzdy = {graph.derivative(y)}") | |
print(f"by backward diff dzdx = {x.gradient}, dzdy = {y.gradient}") | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment