Simple implementation of automatic differentiation and SVM in Python
"""Simple implementation of atomatic differentiation | |
""" | |
import math | |
import inspect | |
import numpy as np | |
def apply(fn, *args, **kwargs):
    return fn(*args, **kwargs)

def debug(fn):
    def fn_debug(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            import sys
            import pdb

            sys.stderr.write(f"{e}\n")
            pdb.post_mortem()

    return fn_debug

# ------------------------------------------------------------------------------
# Expression
# ------------------------------------------------------------------------------
def lift(value):
    """Lift value to expression domain
    """
    if isinstance(value, Expr):
        return value
    elif isinstance(value, (float, int, np.ndarray)):
        return Expr("val", value)
    else:
        raise TypeError(f"value of type {type(value)} is not supported")

@apply
class Var:
    def __call__(self, name):
        return Expr("var", name)

    def __getattr__(self, name):
        return Expr("var", name)

    def __repr__(self):
        # `self.name` would be intercepted by `__getattr__` and yield an
        # `Expr` rather than a string, so return a literal name instead.
        return "Var"


def Val(val):
    return Expr("val", val)

class Expr:
    __slots__ = ("name", "args")

    def __init__(self, name, *args):
        self.name = name
        self.args = args

    @classmethod
    def from_fn(cls, fn):
        args = inspect.getfullargspec(fn).args
        return fn(**{arg: Var(arg) for arg in args})

    def __repr__(self):
        if self.name == "var":
            return str(self.args[0])
        return "({}{})".format(self.name, "".join(f" {arg}" for arg in self.args))

    def __add__(self, other):
        return Expr("+", self, lift(other))

    def __radd__(self, other):
        return Expr("+", lift(other), self)

    def __sub__(self, other):
        return Expr("-", self, lift(other))

    def __rsub__(self, other):
        return Expr("-", lift(other), self)

    def __mul__(self, other):
        return Expr("*", self, lift(other))

    def __rmul__(self, other):
        return Expr("*", lift(other), self)

    def __truediv__(self, other):
        # "//" is just the tag under which true division is registered in OPS
        return Expr("//", self, lift(other))

    def __rtruediv__(self, other):
        return Expr("//", lift(other), self)

    def __pow__(self, other):
        return Expr("**", self, lift(other))

    def __rpow__(self, other):
        return Expr("**", lift(other), self)

    def __matmul__(self, other):
        return Expr("@", self, lift(other))

    def __rmatmul__(self, other):
        return Expr("@", lift(other), self)

    def __getitem__(self, selector):
        return Expr("getitem", self, Expr("val", selector))

    def clip(self, min=None, max=None):
        if min is None and max is None:
            raise ValueError("One of max or min must be given")
        return Expr("clip", self, Val(min), Val(max))

    def sum(self, axis=None):
        return Expr("sum", self, Val(axis))

    sin = lambda self: Expr("sin", self)
    cos = lambda self: Expr("cos", self)
    max = lambda self: Expr("max", self)
    mean = lambda self: Expr("mean", self)
    log = lambda self: Expr("log", self)
    dot = lambda self, other: self @ other
    exp = lambda self: Expr("**", Val(math.e), self)

# ------------------------------------------------------------------------------
# Backward AD mode
# ------------------------------------------------------------------------------
def backward(expr, **values):
    """Find value and derivatives using the backward (reverse-mode) method
    """
    class Var:
        def __init__(self, tape, value, index):
            self.tape = tape
            self.value = value
            self.index = index

        def grad(self):
            derivs = [0] * (self.index + 1)
            derivs[-1] = 1
            nodes = self.tape.nodes
            for index in range(self.index, -1, -1):
                deriv = derivs[index]
                for dep, weight in nodes[index]:
                    derivs[dep] += weight(deriv)
            return Grad(derivs)

        def __repr__(self):
            return repr(self.value)

    class Grad:
        def __init__(self, derivs):
            self.derivs = derivs

        def wrt(self, var):
            return self.derivs[var.index]

    class Tape:
        def __init__(self):
            self.nodes = []

        def push(self, *dep_weights):
            assert len(dep_weights) % 2 == 0
            deps = dep_weights[::2]
            weights = dep_weights[1::2]
            self.nodes.append(tuple(zip(deps, weights)))
            return len(self.nodes) - 1

        def var(self, value, *index_vjp):
            return Var(self, value, self.push(*index_vjp))

    def op_clip(tape, val, min, max):
        min, max = min.value, max.value
        ans = val.value.clip(min, max)
        return tape.var(
            ans, val.index, lambda g: np.logical_and(ans != min, ans != max) * g
        )

    def op_getitem(tape, val, selector):
        def vjp(g):
            out = np.zeros_like(val.value)
            out[selector.value] = g
            return out

        return tape.var(val.value[selector.value], val.index, vjp)

    def op_max(tape, vs):
        ans = np.max(vs.value)
        loc = vs.value == ans
        return tape.var(ans, vs.index, lambda g: loc / loc.sum() * g)

    def op_sum(tape, val, axis):
        ans = np.sum(val.value, axis=axis.value)
        return tape.var(
            ans, val.index, lambda g: broadcast_to_match(g, val.value, axis.value)
        )

    def op_dot(tape, left, right):
        def adj_l(m):
            if np.ndim(m) == 1:
                m = m[:, np.newaxis]

            def vjp(g):
                if np.ndim(g) == 1:
                    g = g[:, np.newaxis]
                return np.squeeze(np.dot(g, np.transpose(m)))

            return vjp

        def adj_r(m):
            if np.ndim(m) == 1:
                m = m[:, np.newaxis]

            def vjp(g):
                if np.ndim(g) == 1:
                    g = g[:, np.newaxis]
                return np.squeeze(np.dot(np.transpose(m), g))

            return vjp

        return tape.var(
            np.dot(left.value, right.value),
            left.index,
            adj_l(right.value),
            right.index,
            adj_r(left.value),
        )

    def broadcast_to_match(a, b, axis):
        """broadcast `a` along `axis` to match `b` shape
        """
        shape = np.array(b.shape)
        shape[list(axis) if isinstance(axis, tuple) else axis] = 1
        return np.broadcast_to(np.reshape(a, shape), b.shape)
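    # e.g. for `b` of shape (3, 4) summed over axis=1, the gradient `a` has
    # shape (3,); it is reshaped to (3, 1) and broadcast back to (3, 4).
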
    def unbroadcast_vjp(target, vjp_base):
        def vjp(g):
            out = vjp_base(g)
            while np.ndim(out) > np.ndim(target):
                out = np.sum(out, axis=0)
            for axis, size in enumerate(np.shape(target)):
                if size == 1:
                    out = np.sum(out, axis=axis, keepdims=True)
            return out

        return vjp
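    # `unbroadcast_vjp` sums the gradient over dimensions that NumPy
    # broadcasting added or stretched, so the result always matches the
    # shape of `target`.
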
    OPS = {
        "+": lambda tape, left, right: tape.var(
            left.value + right.value,
            left.index,
            unbroadcast_vjp(left.value, lambda g: g),
            right.index,
            unbroadcast_vjp(right.value, lambda g: g),
        ),
        "-": lambda tape, left, right: tape.var(
            left.value - right.value,
            left.index,
            unbroadcast_vjp(left.value, lambda g: g),
            right.index,
            unbroadcast_vjp(right.value, lambda g: -g),
        ),
        "*": lambda tape, left, right: tape.var(
            left.value * right.value,
            left.index,
            unbroadcast_vjp(left.value, lambda g: right.value * g),
            right.index,
            unbroadcast_vjp(right.value, lambda g: left.value * g),
        ),
        "//": lambda tape, left, right: tape.var(
            left.value / right.value,
            left.index,
            unbroadcast_vjp(left.value, lambda g: 1 / right.value * g),
            right.index,
            unbroadcast_vjp(
                right.value, lambda g: -left.value / (right.value ** 2) * g
            ),
        ),
        "**": lambda tape, base, exp: tape.var(
            base.value ** exp.value,
            base.index,
            lambda g: exp.value * (base.value ** (exp.value - 1)) * g,
            exp.index,
            lambda g: math.log(base.value) * (base.value ** exp.value) * g,
        ),
        "@": op_dot,
        "sum": op_sum,
        "sin": lambda tape, value: tape.var(
            np.sin(value.value), value.index, lambda g: np.cos(value.value) * g
        ),
        "cos": lambda tape, value: tape.var(
            np.cos(value.value), value.index, lambda g: -np.sin(value.value) * g
        ),
        "log": lambda tape, val: tape.var(
            np.log(val.value), val.index, lambda g: 1 / val.value * g
        ),
        "clip": op_clip,
        "getitem": op_getitem,
        "max": op_max,
        "mean": lambda tape, val: tape.var(
            np.mean(val.value),
            val.index,
            lambda g: np.full_like(val.value, 1 / val.value.size) * g,
        ),
    }

    def evaluate(expr, values):
        def evaluate_rec(expr):
            if expr.name == "var":
                name = expr.args[0]
                var = scope.get(name)
                if var is None:
                    raise ValueError(f"value for `{name}` was not specified")
                return var
            elif expr.name == "val":
                return tape.var(expr.args[0])
            else:
                op = OPS.get(expr.name)
                if op is None:
                    raise ValueError(f"unsupported function `{expr.name}`")
                r = op(tape, *(evaluate_rec(arg) for arg in expr.args))
                return r

        tape = Tape()
        scope = {k: tape.var(v) for k, v in values.items()}
        result = evaluate_rec(expr)
        grad = result.grad()
        return result.value, {k: grad.wrt(v) for k, v in scope.items()}

    return evaluate(expr, values)

def grad(fn):
    spec = inspect.getfullargspec(fn).args
    expr = fn(**{name: Var(name) for name in spec})
    return lambda *args: backward(expr, **dict(zip(spec, args)))

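# A minimal usage sketch (numbers are illustrative; `test_simple` below covers
# the same expression):
#
#     value, grads = backward(Var.x * Var.y + Var.x.sin(), x=0.5, y=4.2)
#     # value == x * y + sin(x), grads == {"x": y + cos(x), "y": x}
#
#     square_sum = grad(lambda x: (x * x).sum())
#     value, grads = square_sum(np.arange(3.0))  # value == 5.0, grads["x"] == 2 * x
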
# ------------------------------------------------------------------------------
# Examples
# ------------------------------------------------------------------------------
def hinge_loss(weights, bias, margin, input):
    """[1 / n * ∑max(0, 1 - yᵢ(ω * xᵢ - b))] + λ * ||ω||²"""
    yi = input[:, 0]
    xi = input[:, 1:]
    S = (yi * (xi @ weights - bias) * (-1) + 1).clip(0.0, None).mean()
    M = weights @ weights * margin
    return S + M
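# In `hinge_loss` above, `S` is the averaged hinge term
# 1/n * ∑ max(0, 1 - yᵢ(ω·xᵢ - b)) and `M` is the regularizer λ‖ω‖², with the
# `margin` argument playing the role of λ.
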
@debug
def svm(input, rate=0.001, margin=0.1, generations=10000, error=0.001):
    """Linear Support Vector Machine

    First column of `input` contains the class yᵢ in (1, -1), the rest of the
    row contains xᵢ. Returns `(weights, bias)` which defines the
    `weights * x - bias` dividing hyperplane.
    """
    assert input.ndim == 2, "unexpected input format"
    loss_fn = grad(hinge_loss)
    # initialize
    bias = np.random.randn(1)[0]
    weights = np.random.randn(input.shape[1] - 1)
    # full-batch gradient descent
    for generation in range(generations):
        loss, grads = loss_fn(weights, bias, margin, input)
        if loss < error:
            break
        weights -= rate * grads["weights"]
        bias -= rate * grads["bias"]
    return (weights, bias), loss, generation
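# A minimal usage sketch (the data is illustrative and mirrors `test_svm`
# below); points are classified by the sign of `x @ weights - bias`:
#
#     data = np.array([[1.0, 1.0, 2.0], [1.0, 4.0, 1.0],
#                      [-1.0, 0.0, 10.0], [-1.0, -3.0, 5.0]])
#     (weights, bias), loss, gen = svm(data)
#     labels = np.sign(data[:, 1:] @ weights - bias)
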
# ------------------------------------------------------------------------------
# Unittests
# ------------------------------------------------------------------------------
import unittest  # noqa: E402


class BackwardDiffTest(unittest.TestCase):
    def test_simple(self):
        expr = Var.x * Var.y + Var.x.sin()
        x, y = 0.5, 4.2
        result, grad = backward(expr, x=x, y=y)
        self.assertAlmostEqual(result, 2.579_425_538_604_20)
        self.assertAlmostEqual(grad["x"], y + math.cos(x))
        self.assertAlmostEqual(grad["y"], x)

    def test_matmul(self):
        a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
        b = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32).T
        c = np.array([[1, 2], [3, 4]], dtype=np.float32)
        fn = grad(lambda a, b, c: (a @ b * c).sum())
        result, grads = fn(a, b, c)
        self.assertTrue(np.allclose(result, 1220))
        self.assertTrue(
            np.allclose(grads["a"], np.array([[27.0, 30.0, 33.0], [61.0, 68.0, 75.0]]))
        )
        self.assertTrue(
            np.allclose(
                grads["b"], np.array([[13.0, 18.0], [17.0, 24.0], [21.0, 30.0]])
            )
        )
        self.assertTrue(
            np.allclose(grads["c"], np.array([[50.0, 68.0], [122.0, 167.0]]))
        )

    def test_hinge_loss(self):
        weights = np.array([3, 10], dtype=float)
        bias = 5.0
        margin = 0.1
        input = np.array(
            [[1.0, 1.0, 2.0], [1.0, 4.0, 1.0], [-1.0, 0.0, 10.0], [-1.0, -3.0, 5.0]],
            dtype=float,
        )
        out, grads = grad(hinge_loss)(weights, bias, margin, input)
        out_expected = hinge_loss(weights, bias, margin, input)
        self.assertTrue(np.allclose(out, out_expected), f"{out} != {out_expected}")
        grads_expected = {
            "weights": np.array([-0.15, 5.75]),
            "bias": -0.5,
            "margin": 109.0,
            "input": np.array(
                [
                    [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                    [-23.75, 0.75, 2.5],
                    [-9.0, 0.75, 2.5],
                ]
            ),
        }
        for arg, arg_grad in grads.items():
            arg_grad_expected = grads_expected[arg]
            self.assertTrue(
                np.allclose(arg_grad, arg_grad_expected),
                f"{arg}: {arg_grad} != {arg_grad_expected}",
            )

    def test_svm(self):
        input = np.array(
            [[1.0, 1.0, 2.0], [1.0, 4.0, 1.0], [-1.0, 0.0, 10.0], [-1.0, -3.0, 5.0]]
        )
        (weights, bias), loss, gen = svm(input)
        self.assertLess(loss, 0.1)

    def test_sum_logistic(self):
        @grad
        def sum_logistic(x):
            return (1 / (1 + 1 / x.exp())).sum()

        out, grads = sum_logistic(np.arange(3.0))
        self.assertAlmostEqual(out, 2.111_855_656_607_887)
        self.assertTrue(
            np.allclose(grads["x"], np.array([0.25, 0.196_611_93, 0.104_993_59]))
        )


if __name__ == "__main__":
    unittest.main()