Created May 25, 2022 03:36
Save luciotorre/bbb14c3d03337b5fb62efd503aec0c98 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast
import copy
import inspect
import math
import textwrap
from dataclasses import dataclass

import numpy as np
_d = {math.sin: math.cos} | |
@dataclass
class Dual:
    """Dual number ``value + diff * eps`` (with ``eps ** 2 == 0``) for forward-mode AD.

    ``value`` holds the ordinary result and ``diff`` the derivative carried
    alongside it. Plain numbers are promoted to constants (``diff == 0``) by
    :meth:`force`, so mixed expressions such as ``x ** 2 + 1``, ``1 + x`` or
    ``2 * x`` work with the Dual on either side of the operator.
    """

    value: float  # primal (ordinary) part
    diff: float = 0  # derivative part; 0 for constants

    @classmethod
    def force(cls, value):
        """Return *value* if it is already a Dual, else wrap it as a constant."""
        if isinstance(value, cls):
            return value
        return cls(value)

    def __add__(self, other):
        other = Dual.force(other)
        return Dual(self.value + other.value, self.diff + other.diff)

    # Addition is commutative, so ``1 + x`` can reuse __add__.
    __radd__ = __add__

    def __neg__(self):
        return Dual(-self.value, -self.diff)

    def __sub__(self, other):
        other = Dual.force(other)
        return Dual(self.value - other.value, self.diff - other.diff)

    def __rsub__(self, other):
        return Dual.force(other) - self

    def __mul__(self, other):
        # Product rule: (u * v)' = u' * v + u * v'
        other = Dual.force(other)
        return Dual(
            self.value * other.value,
            self.diff * other.value + self.value * other.diff,
        )

    __rmul__ = __mul__

    def __truediv__(self, other):
        # Quotient rule: (u / v)' = (u' * v - u * v') / v ** 2
        other = Dual.force(other)
        return Dual(
            self.value / other.value,
            (self.diff * other.value - self.value * other.diff) / other.value ** 2,
        )

    def __rtruediv__(self, other):
        return Dual.force(other) / self

    def __pow__(self, other):
        """Return ``self ** other`` using the generalized power rule.

        ``(u ** v)' = v * u ** (v - 1) * u' + u ** v * ln(u) * v'``. The second
        term is only evaluated when the exponent carries a derivative; it then
        needs ``u > 0`` (``math.log`` raises ``ValueError`` otherwise), except
        that a zero result contributes nothing (the limit of ``0 * ln(u)``).
        """
        other = Dual.force(other)
        value = self.value ** other.value
        diff = other.value * (self.value ** (other.value - 1)) * self.diff
        if other.diff != 0 and value != 0:
            diff += value * math.log(self.value) * other.diff
        return Dual(value, diff)

    def __rpow__(self, other):
        return Dual.force(other) ** self

    # Comparisons look only at the primal part; the derivative does not
    # influence control flow such as ``if x > 0``.
    def __gt__(self, other):
        return self.value > Dual.force(other).value

    def __lt__(self, other):
        return self.value < Dual.force(other).value

    def __ge__(self, other):
        return self.value >= Dual.force(other).value

    def __le__(self, other):
        return self.value <= Dual.force(other).value
def chain_rule(f, expr, *args, **kwargs):
    """Evaluate ``f(expr, *args, **kwargs)``, propagating the derivative through *f*.

    When *f* has a registered derivative in ``_d`` and is called with exactly
    one positional argument, the result is ``Dual(f(u), f'(u) * u')``, where
    ``u`` is *expr* promoted to a Dual. Every other call (a callable without a
    registered derivative, or extra positional/keyword arguments) is forwarded
    to *f* unchanged. The AST rewrite routes *every* call through here, so
    helpers written purely with Dual-aware operators keep working.
    """
    df = _d.get(f)
    if df is None or args or kwargs:
        return f(expr, *args, **kwargs)
    expr = Dual.force(expr)
    return Dual(
        value=f(expr.value),
        diff=df(expr.value) * expr.diff,
    )
class DifferentiateAST(ast.NodeTransformer):
    """Rewrite each call ``g(a, ..., k=v)`` into ``chain_rule(g, a, ..., k=v)``.

    Positional arguments and keyword values are rewritten recursively, so
    nested calls are routed through ``chain_rule`` as well. Calls without
    positional arguments are left in place (only their sub-expressions are
    visited), because there is no operand to differentiate through.
    """

    def visit_Call(self, node):
        if not node.args:
            # chain_rule needs an operand; keep the call, but still rewrite
            # any calls nested inside it.
            return self.generic_visit(node)
        new_call = ast.Call(
            func=ast.Name(id="chain_rule", ctx=ast.Load()),
            args=[node.func] + [self.visit(arg) for arg in node.args],
            # Visiting a keyword node transforms its value expression.
            keywords=[self.visit(kw) for kw in node.keywords],
        )
        return ast.copy_location(new_call, node)
def preprocess(f):
    """Return a re-compiled copy of *f* whose calls go through ``chain_rule``.

    The source of *f* (which ``inspect.getsource`` must be able to locate) is
    parsed, every call is rewritten by ``DifferentiateAST``, and the result is
    executed in a namespace built from *f*'s globals plus its closure
    variables. *f* must be a ``def`` whose own source binds ``f.__name__``.
    """
    # getsource keeps the original indentation; dedent so that functions
    # defined inside another function or a class still parse.
    source = textwrap.dedent(inspect.getsource(f))
    tree = DifferentiateAST().visit(ast.parse(source))
    code = compile(ast.unparse(tree), "<dual>", "exec")
    # NOTE: the namespace is a snapshot copy, so later rebinding of globals
    # or closure variables is not seen by the rewritten function. Decorators
    # written above the ``def`` are part of the source and run again on exec.
    scope = dict(f.__globals__, **inspect.getclosurevars(f).nonlocals)
    # The rewritten code refers to ``chain_rule`` by name; bind it explicitly
    # so this also works for functions from modules that never imported it.
    scope["chain_rule"] = chain_rule
    exec(code, scope, scope)
    return scope[f.__name__]
def differentiate(f):
    """Build the derivative of the single-variable function *f*.

    *f* is rewritten once via ``preprocess``. The returned callable seeds its
    argument as ``Dual(x, 1)``, evaluates the rewritten function, and returns
    the ``diff`` part of the result. Plain (non-Dual) results are promoted by
    ``Dual.force`` and therefore yield 0.
    """
    rewritten = preprocess(f)

    def derivative(x):
        result = rewritten(Dual(x, 1))
        return Dual.force(result).diff

    return derivative
def ReLU(x):
    """Rectified linear unit: return *x* when ``x > 0``, otherwise the constant 0."""
    return x if x > 0 else 0
dReLU = differentiate(ReLU)
# ReLU's slope is 0 for x <= 0 (this implementation returns the constant 0
# at the kink, since the test is ``x > 0``) and 1 for x > 0.
assert [dReLU(v) for v in (-1, 0, 1, 2)] == [0, 0, 1, 1]
def chirp(x):
    """Return ``sin(x ** 2)``, a chirp whose oscillation speeds up as |x| grows."""
    phase = x ** 2
    return math.sin(phase)
dchirp = differentiate(chirp)

# d/dx sin(x**2) = 2*x*cos(x**2), which vanishes wherever
# x**2 = (2*i - 1) * pi / 2, i.e. at x = sqrt(pi / 2) * sqrt(2*i - 1).
# Check the first 99 of those stationary points.
for i in range(1, 100):
    zero = math.sqrt(math.pi / 2) * math.sqrt(2 * i - 1)
    assert np.isclose(dchirp(zero), 0.0)
def f(x):
    """Return the parabola ``x ** 2 + 1``."""
    square = x ** 2
    return square + 1
# d/dx (x**2 + 1) = 2*x, which equals 4 at x = 2.
assert 4 == differentiate(f)(2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.