Created
February 3, 2022 03:10
-
-
Save mattjj/52914908ac22d9ad57b76b685d19acb8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

from contextlib import contextmanager
from typing import Any, Callable, NamedTuple, Optional, Sequence

import numpy as np
# Loose alias used to annotate the value carried by a Tracer; being `Any`,
# it does not enforce a concrete array type.
Array = Any
class Node(NamedTuple):
    """One vertex of the reverse-mode computation graph.

    A node whose ``vjp`` is ``None`` and whose ``parents`` is empty is a
    leaf: the backward pass does not propagate through it (see
    ``parentless_node``).
    """

    # Callable taking the cotangent of this node's output and returning one
    # cotangent per parent; None for parentless nodes.
    vjp: Optional[Callable]
    # Nodes of the inputs this value was computed from.  Callers in this file
    # store either a list or a tuple here, hence Sequence.  The previous
    # `List[Node]` annotation referred to a name that was never imported and
    # broke `typing.get_type_hints(Node)`.
    parents: Sequence[Node]


def parentless_node():
    """Return a new leaf Node (no vjp, empty parent list).

    A distinct object is created on every call; the backward pass keys its
    cotangent table by ``id(node)``, so leaves must not be shared.
    """
    return Node(None, [])
class Tracer(NamedTuple):
    """A value being traced at a particular level, paired with its graph node.

    The arithmetic dunder methods (``__add__``, ``__mul__``, ``__neg__`` and
    their reflected forms) are attached to this class after the primitives
    are defined further down the file.
    """
    # Trace level this tracer belongs to; `vjp` obtains it from
    # `new_trace_level()` and `find_top_level` takes the max over arguments.
    level: int
    # The wrapped value.  `lift` may box a Tracer of a different level here,
    # which is how nested differentiation is represented.
    val: Array
    # Graph node recording how `val` was produced.
    node: Node
def primitive(f):
    """Wrap ``f`` so calls on Tracer arguments are recorded in the graph.

    The returned function takes positional arguments only.  When none of
    them is a Tracer (top level is 0) it simply calls ``f``.  Otherwise every
    argument is lifted to the highest level present and the call is handed
    to ``process`` together with ``f``, which is used as the key into
    ``vjp_rules``.
    """
    def wrapped(*args):
        top = find_top_level(args)
        if top:
            boxed = [lift(top, arg) for arg in args]
            return process(top, f, boxed)
        # No tracer involved: plain evaluation.
        return f(*args)

    return wrapped
# Traceable versions of the NumPy functions that have a rule in vjp_rules.
sin = primitive(np.sin)
cos = primitive(np.cos)

# These primitives also serve as Tracer operator overloads, so `a + b`,
# `a * b` and `-a` on tracers go through the same tracing path.  The
# reflected variants reuse the same function object.
add = primitive(np.add)
Tracer.__add__ = add
Tracer.__radd__ = add

mul = primitive(np.multiply)
Tracer.__mul__ = mul
Tracer.__rmul__ = mul

neg = primitive(np.negative)
Tracer.__neg__ = neg
def find_top_level(args):
    """Return the highest ``level`` among the Tracer items of ``args``.

    Non-Tracer items are ignored; 0 is returned when there is no Tracer.
    """
    tracer_levels = [arg.level for arg in args if isinstance(arg, Tracer)]
    if not tracer_levels:
        return 0
    return max(tracer_levels)
def lift(level, x):
    """Return ``x`` as a Tracer at ``level``.

    A Tracer that is already at ``level`` is returned unchanged.  Anything
    else (a plain value, or a Tracer of a different level) is boxed as the
    ``val`` of a new Tracer attached to a fresh parentless node.
    """
    already_at_level = isinstance(x, Tracer) and x.level == level
    return x if already_at_level else Tracer(
        level=level, val=x, node=parentless_node())
def process(level, prim, tracers):
    """Apply primitive ``prim`` to ``tracers`` and record the result.

    The VJP rule registered for ``prim`` is called on the unboxed values.
    It returns the output value and a pullback, which is stored on a new
    Node whose parents are the input tracers' nodes.  Returns the output
    wrapped in a Tracer at ``level``.
    """
    pairs = [(tracer.val, tracer.node) for tracer in tracers]
    in_vals, in_nodes = zip(*pairs)
    rule = vjp_rules[prim]
    out_val, pullback = rule(*in_vals)
    return Tracer(
        level=level,
        val=out_val,
        node=Node(vjp=pullback, parents=in_nodes),
    )
def vjp(f, *args):
    """Evaluate ``f(*args)`` and return ``(output, f_vjp)``.

    ``f`` is run on fresh Tracers at a new trace level.  ``f_vjp(g)`` runs
    the backward pass from the output node with cotangent ``g`` and returns
    a list with one entry per positional argument.
    """
    with new_trace_level() as level:
        in_tracers = []
        for arg in args:
            in_tracers.append(
                Tracer(level=level, val=arg, node=parentless_node()))
        # The result may be a plain value or a tracer of another level, so
        # it is lifted to this level before being unpacked.
        out_tracer = lift(level, f(*in_tracers))

    out_val, out_node = out_tracer.val, out_tracer.node
    in_nodes = [tracer.node for tracer in in_tracers]

    def f_vjp(g):
        return backward_pass(in_nodes, out_node, g)

    return out_val, f_vjp
# Current nesting depth of active traces; 0 means no trace is active.
trace_level = 0


@contextmanager
def new_trace_level():
    """Context manager that enters one deeper trace level.

    Increments the module-level ``trace_level`` and yields the new value.
    The counter is decremented again on exit, including when the body
    raises.
    """
    global trace_level
    trace_level = trace_level + 1
    entered = trace_level
    try:
        yield entered
    finally:
        trace_level = trace_level - 1
def backward_pass(in_nodes, out_node, g):
    """Propagate cotangent ``g`` from ``out_node`` back to ``in_nodes``.

    Nodes are visited in the order given by ``toposort(out_node)``.  Each
    node's accumulated cotangent is removed from the table, passed through
    the node's ``vjp``, and the results are added into its parents'
    entries.  Returns a list aligned with ``in_nodes``; an entry is ``None``
    when no cotangent reached that input.
    """
    # Keyed by id(node): nodes are tracked by object identity.
    cotangents = {id(out_node): g}
    for current in toposort(out_node):
        parent_cts = current.vjp(cotangents.pop(id(current)))
        for parent_ct, parent in zip(parent_cts, current.parents):
            key = id(parent)
            cotangents[key] = add_grads(cotangents.get(key), parent_ct)
    return [cotangents.get(id(n)) for n in in_nodes]
def add_grads(g1, g2):
    """Accumulate two cotangents, treating ``g1 is None`` as "nothing yet".

    Returns ``g2`` unchanged when ``g1`` is ``None``; otherwise ``g1 + g2``.
    """
    if g1 is None:
        return g2
    return g1 + g2
def toposort(end_node): | |
return reversed([n for n in _toposort(set(), end_node) if n.parents]) | |
def _toposort(seen, node): | |
if id(node) not in seen: | |
seen.add(id(node)) | |
for p in node.parents: | |
yield from _toposort(seen, p) | |
yield node | |
# VJP rule table, keyed by the NumPy function a primitive wraps.  Each rule
# takes the primal inputs and returns (output, pullback), where pullback maps
# the output cotangent to a list with one cotangent per input.  The rules are
# written with the traceable sin/cos and the overloaded operators so they
# also work when the inputs are themselves Tracers.
vjp_rules = {
    np.sin: lambda x: (sin(x), lambda ct: [cos(x) * ct]),
    np.cos: lambda x: (cos(x), lambda ct: [-sin(x) * ct]),
    np.add: lambda x, y: (x + y, lambda ct: [ct, ct]),
    np.multiply: lambda x, y: (x * y, lambda ct: [ct * y, x * ct]),
    np.negative: lambda x: (-x, lambda ct: [-ct]),
}
def grad(f):
    """Return a function computing the derivative of ``f`` w.r.t. its first argument.

    The returned function evaluates ``vjp(f, *args)``, seeds the pullback
    with the cotangent ``1.`` and returns the first entry of the result.
    """
    def f_grad(*args):
        _, pullback = vjp(f, *args)
        input_cotangents = pullback(1.)
        return input_cotangents[0]

    return f_grad
### | |
def f(x):
    """Demo function: sin(sin(x)) + x."""
    inner = sin(x)
    return sin(inner) + x


# Demo output: the value, first derivative and second derivative of f at 3.
df = grad(f)
print(f(3.))
print(df(3.))
print(grad(df)(3.))
# Nested differentiation where the inner function closes over the outer
# variable: d/dx of (d/dy of x*y evaluated at y=1), evaluated at x=1.
print(grad(lambda x: grad(lambda y: x * y)(1.))(1.))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment