Skip to content

Instantly share code, notes, and snippets.

@luciotorre
Created May 25, 2022 03:36
Show Gist options
  • Save luciotorre/bbb14c3d03337b5fb62efd503aec0c98 to your computer and use it in GitHub Desktop.
Save luciotorre/bbb14c3d03337b5fb62efd503aec0c98 to your computer and use it in GitHub Desktop.
import ast
import copy
import inspect
import math
from dataclasses import dataclass
import numpy as np
_d = {math.sin: math.cos}
@dataclass
class Dual:
    """Dual number pairing a value with its derivative.

    Arithmetic on ``Dual`` instances updates ``diff`` alongside ``value``
    using the usual differentiation rules, which gives forward-mode
    automatic differentiation.
    """

    value: float
    diff: float = 0  # a bare constant has derivative zero

    @classmethod
    def force(cls, value):
        """Return ``value`` unchanged if it is already a ``Dual``, else wrap it as a constant."""
        if isinstance(value, cls):
            return value
        return cls(value)

    def __pow__(self, other):
        """Raise to a constant power using the power rule.

        Raises:
            NotImplementedError: if the exponent has a non-zero derivative.
        """
        other = Dual.force(other)
        if other.diff != 0:
            # Raising a bare string is itself a TypeError in Python 3.
            raise NotImplementedError("Not yet")
        if other.value == 0:
            # x ** 0 is constant, so its derivative is 0; this also avoids
            # evaluating 0 ** -1 in the general formula when value is 0.
            return Dual(self.value ** other.value, 0)
        return Dual(
            self.value ** other.value,
            other.value * (self.value ** (other.value - 1)) * self.diff,
        )

    def __add__(self, other):
        """Add a ``Dual`` or a plain number; derivatives add component-wise."""
        other = Dual.force(other)
        return Dual(self.value + other.value, self.diff + other.diff)

    def __radd__(self, other):
        """Support ``number + Dual``; addition is commutative."""
        return self + other

    def __gt__(self, other):
        """Compare by value only; returns a plain bool."""
        other = Dual.force(other)
        return self.value > other.value
def chain_rule(f, expr):
    """Evaluate ``f`` at ``expr`` and propagate the derivative by the chain rule.

    The derivative of ``f`` is looked up in the module-level ``_d`` table;
    ``expr`` may be a ``Dual`` or a plain number.
    """
    derivative = _d[f]
    inner = Dual.force(expr)
    point = inner.value
    # d/dx f(g(x)) = f'(g(x)) * g'(x)
    return Dual(value=f(point), diff=derivative(point) * inner.diff)
class DifferentiateAST(ast.NodeTransformer):
    """Rewrite every call ``g(a, ...)`` into ``chain_rule(g, a, ...)``."""

    def visit_Call(self, node):
        """Return a ``chain_rule`` call wrapping the original callee and arguments.

        Positional arguments are visited recursively so nested calls are
        rewritten too; keywords are passed through unchanged.
        """
        # Give the synthesized nodes an explicit Load context and source
        # positions so the tree can be compiled directly, not only unparsed.
        target = ast.copy_location(
            ast.Name(id="chain_rule", ctx=ast.Load()), node.func
        )
        rewritten = ast.Call(
            func=target,
            args=[node.func] + [self.visit(n) for n in node.args],
            keywords=node.keywords,
        )
        return ast.copy_location(rewritten, node)
def preprocess(f):
    """Return a re-compiled copy of ``f`` whose calls go through ``chain_rule``.

    The source of ``f`` is fetched with ``inspect.getsource``, rewritten by
    ``DifferentiateAST``, and executed in a scope built from the globals and
    closure variables of ``f``.

    NOTE(review): this exec()s the source of ``f`` again (including any
    decorators on it); only use it on trusted functions.
    """
    import textwrap  # local import: only this function needs it

    # getsource keeps the original indentation, which ast.parse rejects for
    # functions defined inside another function or a class body.
    source = textwrap.dedent(inspect.getsource(f))
    dast = DifferentiateAST().visit(ast.parse(source))
    new_code = ast.unparse(dast)
    co = compile(new_code, "<dual>", 'exec')
    scope = dict(f.__globals__, **inspect.getclosurevars(f).nonlocals)
    # The rewritten source refers to chain_rule by name; make sure it
    # resolves even when f was defined in a different module.
    scope["chain_rule"] = chain_rule
    exec(co, scope, scope)
    return scope[f.__name__]
def differentiate(f):
    """Build a function that evaluates the derivative of ``f`` at a point."""
    rewritten = preprocess(f)

    def derivative(x):
        # Seed the input with derivative 1 (dx/dx) and read the propagated
        # derivative off the result; a plain-number result is wrapped as a
        # constant, whose derivative is 0.
        seeded = Dual(x, 1)
        return Dual.force(rewritten(seeded)).diff

    return derivative
def ReLU(x):
    """Rectified linear unit: ``x`` when it is positive, otherwise ``0``."""
    return x if x > 0 else 0
# ReLU's derivative is 1 for positive inputs; at 0 and below the constant
# branch is taken, so the derivative is 0.
dReLU = differentiate(ReLU)
assert [dReLU(point) for point in (-1, 0, 1, 2)] == [0, 0, 1, 1]
def chirp(x):
    """Return ``sin(x^2)``."""
    squared = x ** 2
    return math.sin(squared)
# d/dx sin(x^2) = 2x * cos(x^2), which vanishes where x^2 = (2i - 1) * pi / 2.
dchirp = differentiate(chirp)
_sqrt_half_pi = math.sqrt(math.pi / 2)
for i in range(1, 100):
    zero = _sqrt_half_pi * math.sqrt(2 * i - 1)
    assert np.isclose(dchirp(zero), 0)
def f(x):
    """Return ``x^2 + 1``."""
    squared = x ** 2
    return squared + 1
# d/dx (x^2 + 1) = 2x, so the slope at x = 2 is 4.
_slope_at_two = differentiate(f)(2)
assert _slope_at_two == 4
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment