Simple implementation of automatic differentiation and SVM in python
"""Simple implementation of atomatic differentiation
"""
import math
import inspect
import numpy as np
def apply(fn, *args, **kwargs):
    """Call `fn` immediately (used below as a class decorator to build a singleton)"""
    return fn(*args, **kwargs)

def debug(fn):
    """Decorator: drop into pdb post-mortem if `fn` raises"""
    def fn_debug(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            import sys
            import pdb
            sys.stderr.write(f"{e}\n")
            pdb.post_mortem()
    return fn_debug
# ------------------------------------------------------------------------------
# Expression
# ------------------------------------------------------------------------------
def lift(value):
    """Lift value to expression domain
    """
    if isinstance(value, Expr):
        return value
    elif isinstance(value, (float, int, np.ndarray)):
        return Expr("val", value)
    else:
        raise TypeError(f"value of type {type(value)} is not supported")
@apply
class Var:
    """Singleton helper: both `Var("x")` and `Var.x` build a variable expression"""
    def __call__(self, name):
        return Expr("var", name)
    def __getattr__(self, name):
        return Expr("var", name)
    def __repr__(self):
        return "Var"

def Val(val):
    return Expr("val", val)
class Expr:
    __slots__ = ("name", "args")
    def __init__(self, name, *args):
        self.name = name
        self.args = args
    @classmethod
    def from_fn(cls, fn):
        args = inspect.getfullargspec(fn).args
        return fn(**{arg: Var(arg) for arg in args})
    def __repr__(self):
        if self.name == "var":
            return str(self.args[0])
        return "({}{})".format(self.name, "".join(f" {arg}" for arg in self.args))
    def __add__(self, other):
        return Expr("+", self, lift(other))
    def __radd__(self, other):
        return Expr("+", lift(other), self)
    def __sub__(self, other):
        return Expr("-", self, lift(other))
    def __rsub__(self, other):
        return Expr("-", lift(other), self)
    def __mul__(self, other):
        return Expr("*", self, lift(other))
    def __rmul__(self, other):
        return Expr("*", lift(other), self)
    def __truediv__(self, other):
        return Expr("//", self, lift(other))
    def __rtruediv__(self, other):
        return Expr("//", lift(other), self)
    def __pow__(self, other):
        return Expr("**", self, lift(other))
    def __rpow__(self, other):
        return Expr("**", lift(other), self)
    def __matmul__(self, other):
        return Expr("@", self, lift(other))
    def __rmatmul__(self, other):
        return Expr("@", lift(other), self)
    def __getitem__(self, selector):
        return Expr("getitem", self, Expr("val", selector))
    def clip(self, min=None, max=None):
        if min is None and max is None:
            raise ValueError("One of max or min must be given")
        return Expr("clip", self, Val(min), Val(max))
    def sum(self, axis=None):
        return Expr("sum", self, Val(axis))
    sin = lambda self: Expr("sin", self)
    cos = lambda self: Expr("cos", self)
    max = lambda self: Expr("max", self)
    mean = lambda self: Expr("mean", self)
    log = lambda self: Expr("log", self)
    dot = lambda self, other: self @ other
    exp = lambda self: Expr("**", Val(math.e), self)
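# Illustrative sketch of the expression layer: operators on `Expr` (and on
# `Var.<name>`) only record an expression tree, no numeric work happens until
# the tree is evaluated by `backward` below. For example:
#
#   >>> Var.x * Var.y + Var.x.sin()
#   (+ (* x y) (sin x))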
# ------------------------------------------------------------------------------
# Backward AD mode
# ------------------------------------------------------------------------------
def backward(expr, **values):
    """Find value and derivatives using backward method
    """
    class Var:
        def __init__(self, tape, value, index):
            self.tape = tape
            self.value = value
            self.index = index
        def grad(self):
            derivs = [0] * (self.index + 1)
            derivs[-1] = 1
            nodes = self.tape.nodes
            for index in range(self.index, -1, -1):
                deriv = derivs[index]
                for dep, weight in nodes[index]:
                    derivs[dep] += weight(deriv)
            return Grad(derivs)
        def __repr__(self):
            return repr(self.value)

    class Grad:
        def __init__(self, derivs):
            self.derivs = derivs
        def wrt(self, var):
            return self.derivs[var.index]

    class Tape:
        def __init__(self):
            self.nodes = []
        def push(self, *dep_weights):
            assert len(dep_weights) % 2 == 0
            deps = dep_weights[::2]
            weights = dep_weights[1::2]
            self.nodes.append(tuple(zip(deps, weights)))
            return len(self.nodes) - 1
        def var(self, value, *index_vjp):
            return Var(self, value, self.push(*index_vjp))
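    # Sketch of the tape mechanism: every operation records, for each of its
    # inputs, the input's node index together with a vector-Jacobian function,
    # e.g. for z = x * y the recorded pairs are roughly (x.index, g -> y * g)
    # and (y.index, g -> x * g). `Var.grad` then walks the nodes from the
    # output back to the start, accumulating derivatives into `derivs`.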
    def op_clip(tape, val, min, max):
        min, max = min.value, max.value
        ans = val.value.clip(min, max)
        return tape.var(
            ans, val.index, lambda g: np.logical_and(ans != min, ans != max) * g
        )

    def op_getitem(tape, val, selector):
        def vjp(g):
            out = np.zeros_like(val.value)
            out[selector.value] = g
            return out
        return tape.var(val.value[selector.value], val.index, vjp)

    def op_max(tape, vs):
        ans = np.max(vs.value)
        loc = vs.value == ans
        return tape.var(ans, vs.index, lambda g: loc / loc.sum() * g)

    def op_sum(tape, val, axis):
        ans = np.sum(val.value, axis=axis.value)
        return tape.var(
            ans, val.index, lambda g: broadcast_to_match(g, val.value, axis.value)
        )

    def op_dot(tape, left, right):
        def adj_l(m):
            if np.ndim(m) == 1:
                m = m[:, np.newaxis]
            def vjp(g):
                if np.ndim(g) == 1:
                    g = g[:, np.newaxis]
                return np.squeeze(np.dot(g, np.transpose(m)))
            return vjp
        def adj_r(m):
            if np.ndim(m) == 1:
                m = m[:, np.newaxis]
            def vjp(g):
                if np.ndim(g) == 1:
                    g = g[:, np.newaxis]
                return np.squeeze(np.dot(np.transpose(m), g))
            return vjp
        return tape.var(
            np.dot(left.value, right.value),
            left.index,
            adj_l(right.value),
            right.index,
            adj_r(left.value),
        )
    def broadcast_to_match(a, b, axis):
        """broadcast `a` along `axis` to match `b` shape
        """
        shape = np.array(b.shape)
        shape[list(axis) if isinstance(axis, tuple) else axis] = 1
        return np.broadcast_to(np.reshape(a, shape), b.shape)
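    # Sketch of how the `sum` VJP uses broadcast_to_match: summing a (4, 3)
    # array over axis=0 yields shape (3,), so the incoming gradient also has
    # shape (3,); it is reshaped to (1, 3) and broadcast back to (4, 3),
    # spreading the same gradient to every summed-out element.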
    def unbroadcast_vjp(target, vjp_base):
        def vjp(g):
            out = vjp_base(g)
            while np.ndim(out) > np.ndim(target):
                out = np.sum(out, axis=0)
            for axis, size in enumerate(np.shape(target)):
                if size == 1:
                    out = np.sum(out, axis=axis, keepdims=True)
            return out
        return vjp
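    # Why unbroadcast_vjp is needed, with a concrete shape example: for
    # `w + b` where w has shape (3,) and b is a scalar, numpy broadcasts b,
    # so the forward result (and hence the incoming gradient g) has shape
    # (3,). The gradient with respect to b must be summed back down to a
    # scalar, which is what the extra-dimension/size-1 sums above do.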
    OPS = {
        "+": lambda tape, left, right: tape.var(
            left.value + right.value,
            left.index,
            unbroadcast_vjp(left.value, lambda g: g),
            right.index,
            unbroadcast_vjp(right.value, lambda g: g),
        ),
        "-": lambda tape, left, right: tape.var(
            left.value - right.value,
            left.index,
            unbroadcast_vjp(left.value, lambda g: g),
            right.index,
            unbroadcast_vjp(right.value, lambda g: -g),
        ),
        "*": lambda tape, left, right: tape.var(
            left.value * right.value,
            left.index,
            unbroadcast_vjp(left.value, lambda g: right.value * g),
            right.index,
            unbroadcast_vjp(right.value, lambda g: left.value * g),
        ),
        "//": lambda tape, left, right: tape.var(
            left.value / right.value,
            left.index,
            unbroadcast_vjp(left.value, lambda g: 1 / right.value * g),
            right.index,
            unbroadcast_vjp(
                right.value, lambda g: -left.value / (right.value ** 2) * g
            ),
        ),
        "**": lambda tape, base, exp: tape.var(
            base.value ** exp.value,
            base.index,
            lambda g: exp.value * (base.value ** (exp.value - 1)) * g,
            exp.index,
            lambda g: math.log(base.value) * (base.value ** exp.value) * g,
        ),
        "@": op_dot,
        "sum": op_sum,
        "sin": lambda tape, value: tape.var(
            np.sin(value.value), value.index, lambda g: np.cos(value.value) * g
        ),
        "cos": lambda tape, value: tape.var(
            np.cos(value.value), value.index, lambda g: -np.sin(value.value) * g
        ),
        "log": lambda tape, val: tape.var(
            np.log(val.value), val.index, lambda g: 1 / val.value * g
        ),
        "clip": op_clip,
        "getitem": op_getitem,
        "max": op_max,
        "mean": lambda tape, val: tape.var(
            np.mean(val.value),
            val.index,
            lambda g: np.full_like(val.value, 1 / val.value.size) * g,
        ),
    }
    def evaluate(expr, values):
        def evaluate_rec(expr):
            if expr.name == "var":
                name = expr.args[0]
                var = scope.get(name)
                if var is None:
                    raise ValueError(f"value for `{name}` was not specified")
                return var
            elif expr.name == "val":
                return tape.var(expr.args[0])
            else:
                op = OPS.get(expr.name)
                if op is None:
                    raise ValueError(f"unsupported function `{expr.name}`")
                r = op(tape, *(evaluate_rec(arg) for arg in expr.args))
                return r
        tape = Tape()
        scope = {k: tape.var(v) for k, v in values.items()}
        result = evaluate_rec(expr)
        grad = result.grad()
        return result.value, {k: grad.wrt(v) for k, v in scope.items()}

    return evaluate(expr, values)
def grad(fn):
    spec = inspect.getfullargspec(fn).args
    expr = fn(**{name: Var(name) for name in spec})
    return lambda *args: backward(expr, **dict(zip(spec, args)))
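# Minimal usage sketch of `grad` (mirrors `test_simple` in the unit tests
# below): it returns both the value and a dict of derivatives keyed by
# argument name.
#
#   f = grad(lambda x, y: x * y + x.sin())
#   value, grads = f(0.5, 4.2)
#   # value ≈ 2.5794, grads["x"] == 4.2 + cos(0.5), grads["y"] == 0.5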
# ------------------------------------------------------------------------------
# Examples
# ------------------------------------------------------------------------------
def hinge_loss(weights, bias, margin, input):
    """[1 / n * ∑max(0, 1 - yᵢ(ω * xᵢ - b))] + λ * ||ω||²"""
    yi = input[:, 0]
    xi = input[:, 1:]
    S = (yi * (xi @ weights - bias) * (-1) + 1).clip(0.0, None).mean()
    M = weights @ weights * margin
    return S + M
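# Design note: `grad(hinge_loss)` returns derivatives with respect to all
# four arguments (including `margin` and `input`), keyed by argument name;
# `svm` below only consumes grads["weights"] and grads["bias"] for its update.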
@debug
def svm(input, rate=0.001, margin=0.1, generations=10000, error=0.001):
    """Linear Support Vector Machine

    First column of `input` contains the class yᵢ in {1, -1}; the rest of the
    row contains xᵢ. Returns `(weights, bias)` which define the
    `weights @ x - bias` dividing hyperplane.
    """
    assert input.ndim == 2, "unexpected input format"
    loss_fn = grad(hinge_loss)
    # initialize
    bias = np.random.randn(1)[0]
    weights = np.random.randn(input.shape[1] - 1)
    # gradient descent on the full dataset
    for generation in range(generations):
        loss, grads = loss_fn(weights, bias, margin, input)
        if loss < error:
            break
        weights -= rate * grads["weights"]
        bias -= rate * grads["bias"]
    return (weights, bias), loss, generation
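# Usage sketch (mirrors `test_svm` below): each row of the input is
# (yᵢ, xᵢ...) with the label in the first column.
#
#   data = np.array(
#       [[1.0, 1.0, 2.0], [1.0, 4.0, 1.0], [-1.0, 0.0, 10.0], [-1.0, -3.0, 5.0]]
#   )
#   (weights, bias), loss, generation = svm(data)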
# ------------------------------------------------------------------------------
# Unittests
# ------------------------------------------------------------------------------
import unittest  # noqa: E402

class BackwardDiffTest(unittest.TestCase):
    def test_simple(self):
        expr = Var.x * Var.y + Var.x.sin()
        x, y = 0.5, 4.2
        result, grad = backward(expr, x=x, y=y)
        self.assertAlmostEqual(result, 2.579_425_538_604_20)
        self.assertAlmostEqual(grad["x"], y + math.cos(x))
        self.assertAlmostEqual(grad["y"], x)

    def test_matmul(self):
        a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
        b = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32).T
        c = np.array([[1, 2], [3, 4]], dtype=np.float32)
        fn = grad(lambda a, b, c: (a @ b * c).sum())
        result, grads = fn(a, b, c)
        self.assertTrue(np.allclose(result, 1220))
        self.assertTrue(
            np.allclose(grads["a"], np.array([[27.0, 30.0, 33.0], [61.0, 68.0, 75.0]]))
        )
        self.assertTrue(
            np.allclose(
                grads["b"], np.array([[13.0, 18.0], [17.0, 24.0], [21.0, 30.0]])
            )
        )
        self.assertTrue(
            np.allclose(grads["c"], np.array([[50.0, 68.0], [122.0, 167.0]]))
        )

    def test_hinge_loss(self):
        weights = np.array([3, 10], dtype=float)
        bias = 5.0
        margin = 0.1
        input = np.array(
            [[1.0, 1.0, 2.0], [1.0, 4.0, 1.0], [-1.0, 0.0, 10.0], [-1.0, -3.0, 5.0]],
            dtype=float,
        )
        out, grads = grad(hinge_loss)(weights, bias, margin, input)
        out_expected = hinge_loss(weights, bias, margin, input)
        self.assertTrue(np.allclose(out, out_expected), f"{out} != {out_expected}")
        grads_expected = {
            "weights": np.array([-0.15, 5.75]),
            "bias": -0.5,
            "margin": 109.0,
            "input": np.array(
                [
                    [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                    [-23.75, 0.75, 2.5],
                    [-9.0, 0.75, 2.5],
                ]
            ),
        }
        for arg, arg_grad in grads.items():
            arg_grad_expected = grads_expected[arg]
            self.assertTrue(
                np.allclose(arg_grad, arg_grad_expected),
                f"{arg}: {arg_grad} != {arg_grad_expected}",
            )

    def test_svm(self):
        input = np.array(
            [[1.0, 1.0, 2.0], [1.0, 4.0, 1.0], [-1.0, 0.0, 10.0], [-1.0, -3.0, 5.0]]
        )
        (weights, bias), loss, gen = svm(input)
        self.assertLess(loss, 0.1)

    def test_sum_logistic(self):
        @grad
        def sum_logistic(x):
            return (1 / (1 + 1 / x.exp())).sum()

        out, grads = sum_logistic(np.arange(3.0))
        self.assertAlmostEqual(out, 2.111_855_656_607_887)
        self.assertTrue(
            np.allclose(grads["x"], np.array([0.25, 0.196_611_93, 0.104_993_59]))
        )

if __name__ == "__main__":
    unittest.main()