Skip to content

Instantly share code, notes, and snippets.

@Swalloow
Created January 3, 2018 08:51
Show Gist options
  • Save Swalloow/ecb8f637d6966115d9d9870415fc5623 to your computer and use it in GitHub Desktop.
Save Swalloow/ecb8f637d6966115d9d9870415fc5623 to your computer and use it in GitHub Desktop.
Imperative auto differentiation
class array(object):
    """Scalar wrapper that supports imperative reverse-mode autodiff.

    Each instance stores a ``value`` and a ``grad`` callable.  Calling
    ``grad(g)`` with an upstream gradient ``g`` returns a dict mapping
    the names of the named leaf arrays that contributed to this value
    to their gradient contributions.
    """

    def __init__(self, value, name=None):
        """Wrap *value*; a non-None *name* makes this a differentiable leaf."""
        self.value = value
        if name is not None:
            self.grad = lambda g: {name: g}
        else:
            # Unnamed arrays act as constants: they contribute no gradient,
            # but still need a ``grad`` so that products and sums involving
            # them can be back-propagated without an AttributeError.
            self.grad = lambda g: {}

    @staticmethod
    def _accumulate(left, right):
        """Return a new dict that sums the contributions of both inputs."""
        total = dict(left)
        for key, contribution in right.items():
            if key in total:
                total[key] = total[key] + contribution
            else:
                total[key] = contribution
        return total

    def __add__(self, other):
        """Return ``self + other`` where *other* is an array or int/float."""
        if isinstance(other, array):
            ret = array(self.value + other.value)
            # d(x + y) = dx + dy: both operands receive g unchanged.
            ret.grad = lambda g: self._accumulate(self.grad(g), other.grad(g))
            return ret
        if isinstance(other, (int, float)):
            ret = array(self.value + other)
            # Adding a constant leaves the gradient untouched.
            ret.grad = lambda g: self.grad(g)
            return ret
        # Let Python try the reflected operation or raise TypeError.
        return NotImplemented

    # Addition is commutative, so ``1 + a`` can reuse ``a + 1``.
    __radd__ = __add__

    def __mul__(self, other):
        """Return ``self * other`` where *other* is an array or int/float."""
        if isinstance(other, array):
            ret = array(self.value * other.value)

            def grad(g):
                # Product rule: d(x * y) = y * dx + x * dy.  Contributions
                # are summed (not overwritten) so that a variable appearing
                # on both sides, e.g. ``a * a``, gets its full gradient.
                return self._accumulate(
                    self.grad(g * other.value),
                    other.grad(g * self.value),
                )

            ret.grad = grad
            return ret
        if isinstance(other, (int, float)):
            ret = array(self.value * other)
            # Scaling by a constant scales the upstream gradient.
            ret.grad = lambda g: self.grad(g * other)
            return ret
        # Let Python try the reflected operation or raise TypeError.
        return NotImplemented

    # Scalar multiplication is commutative, so ``2 * a`` can reuse ``a * 2``.
    __rmul__ = __mul__
# Example: build d = b * a + 1 and differentiate it with respect to a and b.
a = array(1, 'a')
b = array(2, 'b')
c = b * a
d = c + 1
# print() with a single argument behaves the same on Python 2 and Python 3.
print(d.value)
print(d.grad(1))
# Expected output:
# 3
# {'a': 2, 'b': 1}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment