Skip to content

Instantly share code, notes, and snippets.

@yangchenyun
Last active March 13, 2023 15:48
Show Gist options
  • Save yangchenyun/6c801e151441258b83d2b17fa45d2b3d to your computer and use it in GitHub Desktop.
Save yangchenyun/6c801e151441258b83d2b17fa45d2b3d to your computer and use it in GitHub Desktop.
Symbolic differentiator
"""
The following program implements a numeric system with primitive data
structures and algorithms that can calculate derivatives for arbitrary
arithmetic expressions.
The core data structure is called `Dual`, which represents a number with its
numeric value and the operation performed in the expression AST. This means that
`Dual` remembers how the value is computed.
All primitive operations (including binary and unary) have the signature
`[]Duals -> Dual`. Additionally, primitive operations calculate the derivative
with respect to their parameters using differential calculus.
Composed operations use the chain rule to accumulate the result while traversing
the chain.
TODO:
- Merge the data type with python's number typing system (int, float)
- Implement other number protocols (__iadd__, __radd__, etc.)
- Support vector calculus
"""
from collections import defaultdict
import math
def treemap(fn, tree):
    """Apply `fn` to every leaf of a nested list and keep the nesting shape.

    A leaf is any value that is neither `None` nor a `list`. `None` is
    returned unchanged, and lists are rebuilt (the input is not mutated).

    treemap(add1, [1, [2, [3, 4]]]) => [2, [3, [4, 5]]]
    """
    if tree is None:
        return None
    if not isinstance(tree, list):
        # Anything that is not a list (tuples included) is treated as a leaf.
        return fn(tree)
    return list(map(lambda branch: treemap(fn, branch), tree))
# Import-time sanity checks for treemap: None passthrough, empty list,
# bare leaf, flat list and nested list.
assert treemap(lambda x: x + 1, None) is None  # identity check, not `==`
assert treemap(lambda x: x + 1, []) == []
assert treemap(lambda x: x + 1, 1) == 2
assert treemap(lambda x: x + 1, [1]) == [2]
assert treemap(lambda x: x + 1, [1, [2, [3, 4]]]) == [2, [3, [4, 5]]]
def linkOperator(dual, accumulated, store):
    """Document the signature shared by every link function (notation only).

    A link function receives the dual it is attached to, the gradient value
    accumulated so far along the chain, and the store that collects
    gradients. The concrete link functions in this file (`endOperator` and
    the `dop` closures built by `prim1` / `prim2`) return the store.

    This placeholder performs no work and returns None.
    """
def endOperator(dual, accumulated, store):
    """Terminal link: add `accumulated` to the entry for `dual` in `store`.

    The same store object is returned after it has been updated in place.
    The store must already hold (or be able to default) an entry for `dual`.
    """
    running_total = store[dual]
    running_total += accumulated
    store[dual] = running_total
    return store
def prim2(opName, valueFn, gradientFn):
    """Build a differentiable binary operator method (+, -, *, /, ^).

    Args:
        opName: Symbol used when composing the display name of the result.
        valueFn: `(a, b) -> value`; computes the numeric result from
            `self.value` and `other.value`.
        gradientFn: `(a, b, z) -> (ga, gb)`; given both operand values and
            the gradient `z` accumulated so far, returns the gradients to
            hand to the link of `self` and of `other` respectively.

    Returns:
        A function `binaryOp(self, other)` meant to be installed as a method.
        It returns a new instance of `self.__class__` whose link propagates
        gradients into both operands, or `NotImplemented` when `other` is
        neither a plain number nor an object exposing `value` and `link`.
    """

    def binaryOp(self, other):
        if isinstance(other, (int, float)):
            # Promote plain numbers with the same class used for the result,
            # so the operator does not depend on one hard-coded class name.
            other = self.__class__(other)
        elif not (hasattr(other, "value") and hasattr(other, "link")):
            # Let Python try the reflected operation / raise a TypeError
            # instead of failing later with an AttributeError.
            return NotImplemented

        def dop(dual, accumulated, store):
            ga, gb = gradientFn(self.value, other.value, accumulated)
            # Thread the store through the left operand, then the right one.
            store = self.link(self, ga, store)
            return other.link(other, gb, store)

        name = f"{self.value} {opName} {other.value}"
        return self.__class__(valueFn(self.value, other.value), dop, name=name)

    return binaryOp
def prim1(opName, valueFn, gradientFn):
    """Build a differentiable unary operator method (e.g. log, exp, sqrt).

    Args:
        opName: Name used when composing the display name of the result.
        valueFn: `(a) -> value`; computes the numeric result from
            `self.value`.
        gradientFn: `(a, z) -> ga`; given the operand value and the gradient
            `z` accumulated so far, returns the gradient to hand to the
            operand's link. The result is expected to already include `z`.

    Returns:
        A function `unaryOp(self)` meant to be installed as a method. It
        returns a new instance of `self.__class__` whose link propagates the
        gradient into the operand.
    """

    def unaryOp(self):
        def dop(dual, accumulated, store):
            # gradientFn already folds `accumulated` into its result, so it
            # is forwarded as-is. Multiplying by `accumulated` a second time
            # would square the upstream gradient whenever it is not 1.0.
            return self.link(self, gradientFn(self.value, accumulated), store)

        name = f"{opName}({self.value})"
        return self.__class__(valueFn(self.value), dop, name=name)

    return unaryOp
def _pow_gradient(a, b, z):
    """Return the gradients of a**b w.r.t. (a, b), scaled by upstream `z`.

    d(a^b)/da = b * a^(b-1)
    d(a^b)/db = a^b * ln(a)

    ln(a) is only defined for a > 0. The exponent is frequently a constant
    (e.g. `x ** 2`), so a non-positive base must not abort the whole
    backward pass: for a == 0 the limit of a^b * ln(a) (b > 0) is used, and
    for a < 0 the exponent gradient is reported as NaN.
    """
    d_base = z * (b * (a ** (b - 1)))
    if a > 0:
        d_exponent = z * (a**b * math.log(a))
    elif a == 0:
        d_exponent = 0.0
    else:
        d_exponent = math.nan
    return d_base, d_exponent


class Dual:
    """A numeric value that remembers how it was computed.

    Attributes are set in `__init__`:
      value -- the numeric value.
      link  -- function `(dual, accumulated, store) -> store` that pushes an
               accumulated gradient back to the operands this dual was
               computed from; leaves use `endOperator`.
      name  -- display name; defaults to `str(value)`.

    Operators are generated by `prim2` (binary) and `prim1` (unary). For the
    reflected operators (`__radd__`, `__rsub__`, ...) `self` is the
    right-hand operand, so their lambdas receive `(self.value, other.value)`
    as `(b, a)`.
    """

    def __init__(self, value, link=endOperator, name=None):
        self.value = value
        # link is a function which captures the operator's gradient
        # calculation on the chain
        self.link = link
        self.name = str(value) if name is None else name

    def truncate(self):
        """Reset the link to `endOperator`, making this dual a leaf; returns self."""
        self.link = endOperator
        return self

    def grad(self, store):
        """Run this dual's link with an initial gradient of 1.0; returns its result."""
        return self.link(self, 1.0, store)

    # Binary operators
    __add__ = prim2("+", lambda a, b: a + b, lambda a, b, z: (z, z))
    __radd__ = prim2("+", lambda b, a: a + b, lambda b, a, z: (z, z))
    __sub__ = prim2("-", lambda a, b: a - b, lambda a, b, z: (z, -z))
    # Reflected subtraction evaluates `other - self`, i.e. a - b.
    __rsub__ = prim2("-", lambda b, a: a - b, lambda b, a, z: (-z, z))
    __mul__ = prim2("*", lambda a, b: a * b, lambda a, b, z: ((b * z), (a * z)))
    __rmul__ = prim2("*", lambda b, a: a * b, lambda b, a, z: ((a * z), (b * z)))
    __truediv__ = prim2(
        "/", lambda a, b: a / b, lambda a, b, z: ((z * (1.0 / b)), (z * (-a / (b * b))))
    )
    __rtruediv__ = prim2(
        "/", lambda b, a: a / b, lambda b, a, z: ((z * (-a / (b * b))), (z * (1.0 / b)))
    )
    # dx^y/dx = y * x^(y-1)
    # dx^y/dy = x^y * ln(x)
    __pow__ = prim2("^", lambda a, b: a**b, _pow_gradient)

    # Unary operators
    log = prim1("log", lambda a: math.log(a), lambda a, z: z * (1.0 / a))
    exp = prim1("exp", lambda a: math.exp(a), lambda a, z: z * math.exp(a))
    sqrt = prim1("sqrt", lambda a: math.sqrt(a), lambda a, z: z / (2 * math.sqrt(a)))

    def __str__(self):
        return f"[Dual(name={self.name}, val:{self.value})]"
def Del(fn, theta):
    """Evaluate the gradient of `fn` at the point `theta`.

    `theta` is a (possibly nested) list of numbers. Every leaf is wrapped in
    a `Dual`, `fn` is called with the top-level entries as positional
    arguments, and the result's `grad` fills a store keyed by those duals.

    Returns the collected gradients arranged in the same shape as `theta`.
    """
    duals = treemap(Dual, theta)
    gradients = defaultdict(float)  # float() == 0.0 for untouched parameters
    output = fn(*duals)
    output.grad(gradients)
    # Look each parameter up again so the answer mirrors theta's shape.
    return treemap(lambda leaf: gradients[leaf], duals)
if __name__ == "__main__":
    # Demo: each case pairs a function with the point at which its gradient
    # is evaluated; all gradients are printed together as one list.
    cases = [
        # Basic binary operators
        (lambda x, y: x + y, [3.0, 2.0]),  # => [1.0, 1.0]
        (lambda x, y: x * y, [2.0, 3.0]),  # => [3.0, 2.0]
        (lambda x, y: x - y, [2.0, 3.0]),  # => [1.0, -1.0]
        (lambda x, y: x / y, [2.0, 3.0]),  # => [0.3333333, -0.2222]
        (lambda x, y: x**y, [2.0, 3.0]),  # => [12.0, 5.54517]
        # Unary operators
        (lambda x: x.exp(), [2.0]),  # => e^2
        (lambda x: x.log(), [2.0]),  # => 0.5
        (lambda x: x.sqrt(), [2.0]),  # => 0.3535
        # Composed function
        (lambda x, y: (x * x) + x + y, [3.0, 2.0]),  # => [7.0, 1.0]
        # fn = (e^x * log(y) + sqrt(x)) / y
        (lambda x, y: (x.exp() * y.log() + x.sqrt()) / y, [3.0, 2.0]),
        # Plain Python int and float operands
        (lambda x, b: (x * x * 2.0) + x * 1 + b, [3.0, 2.0]),
        # Reflected (right-hand) operators
        (lambda x: 2.0 + x, [3.0]),
        (lambda x: 2.0 - x, [3.0]),
        (lambda x: 2.0 * x, [3.0]),
        (lambda x: 2.0 / x, [3.0]),
        (lambda x, b: (2.0 * x * x) + 1 * x + b, [3.0, 2.0]),
    ]
    print([Del(fn, theta) for fn, theta in cases])
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 41,
"id": "fba27f8c",
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"import math"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "dd5a5a4a",
"metadata": {},
"outputs": [],
"source": [
"def treemap(fn, tree):\n",
" \"\"\"Treemap works similar of `map' but on a nested list instead of list.\n",
" It modifes the leaf nodes on a tree.\n",
" treemap(add1, [1, [2, [3, 4]]) => [2, [3, [4, 5]]\n",
" \"\"\"\n",
" if tree is None:\n",
" return tree\n",
" elif isinstance(tree, list):\n",
" return [treemap(fn, elem) for elem in tree]\n",
" else:\n",
" return fn(tree)\n",
"\n",
"\n",
"assert treemap(lambda x: x + 1, None) == None\n",
"assert treemap(lambda x: x + 1, []) == []\n",
"assert treemap(lambda x: x + 1, 1) == 2\n",
"assert treemap(lambda x: x + 1, [1]) == [2]\n",
"assert treemap(lambda x: x + 1, [1, [2, [3, 4]]]) == [2, [3, [4, 5]]]"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "eecbe813",
"metadata": {},
"outputs": [],
"source": [
"def linkOperator(dual, accumulated, store):\n",
" \"\"\"LinkOperator takes on a dual, accumulated gradient value, and a store.\n",
" It returns a new store representing the new gradient collected.\n",
"\n",
" This is just a function notation, the real implementation could be found below.\n",
" \"\"\"\n",
" pass\n",
"\n",
"\n",
"def endOperator(dual, accumulated, store):\n",
" \"\"\"Special LinkOperator to save the accumulated gradient value in a given store.\"\"\"\n",
" store[dual] += accumulated\n",
" return store\n",
"\n",
"\n",
"class Dual:\n",
" def __init__(self, value, link=endOperator, name=None):\n",
" self.value = value # numeric value, int / float\n",
"\n",
" # link is a function which captures the operator's gradient calculation\n",
" # on the chain\n",
" self.link = link \n",
"\n",
" if name is None:\n",
" self.name = str(value)\n",
" else:\n",
" self.name = name\n",
"\n",
" def truncate(self):\n",
" \"\"\"Truncate reset the link information for the dual.\"\"\"\n",
" self.link = endOperator\n",
" return self\n",
"\n",
" def grad(self, store):\n",
" return self.link(self, 1.0, store)\n",
"\n",
"\n",
" def __add__(self, other): # dual, dual -> dual\n",
" def dadd(dual, accumulated, store):\n",
" a = self\n",
" b = other\n",
" newStore = a.link(a, (accumulated * 1.0), store)\n",
" newStore = b.link(b, (accumulated * 1.0), newStore)\n",
" return newStore\n",
" \n",
" name = f\"{self.value} + {other.value}\"\n",
" return self.__class__(self.value + other.value, dadd, name=name)\n",
" \n",
" def __str__(self):\n",
" return f\"[Dual(name={self.name}, val:{self.value})]\"\n"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "b8d6ccdd",
"metadata": {},
"outputs": [],
"source": [
"def Del(fn, theta):\n",
" \"\"\"\n",
" Del is an operator, which takes in a function and a list of parameters\n",
" and return corresponding gradients.\n",
" \"\"\"\n",
" theta = treemap(lambda v: Dual(v), theta)\n",
" store = defaultdict(lambda: 0.0)\n",
" \n",
" result = fn(*theta)\n",
" # print(result)\n",
" \n",
" result.grad(store)\n",
"\n",
" # Return in respect to theta's shape\n",
" return treemap(lambda t: store[t], theta)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "8fd0be3a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[1.0, 1.0], [2.0, 1.0]]"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[\n",
" Del(lambda x, y: x + y, [3.0, 2.0]), # => [1.0, 1.0]\n",
" Del(lambda x, y: (x + x) + y, [3.0, 2.0]), # => [2.0, 1.0]\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "454fcc0f",
"metadata": {},
"outputs": [],
"source": [
"def prim2(opName, valueFn, gradientFn):\n",
" \"\"\"Returns derivitive link function for binary operator: +,-,*,/,exp.\n",
"\n",
" valueFn returns the numeric value\n",
" gradientFn returns the corresponding derivitive value according to formula\n",
"\n",
" \"\"\"\n",
" def binaryOp(self, other):\n",
"\n",
" if isinstance(other, int) or isinstance(other, float):\n",
" other = Dual(other)\n",
"\n",
" assert (type(self) is Dual)\n",
" assert (type(other) is Dual)\n",
" \n",
" def dop(dual, accumulated, store):\n",
" a = self\n",
" b = other\n",
" \n",
" ga, gb = gradientFn(a.value, b.value, accumulated)\n",
"\n",
" newStore = a.link(a, ga, store)\n",
" newStore = b.link(b, gb, newStore)\n",
" return newStore\n",
"\n",
" name = f\"{self.value} {opName} {other.value}\"\n",
"\n",
" return self.__class__(valueFn(self.value, other.value), dop, name=name)\n",
"\n",
" return binaryOp\n",
"\n",
"def prim1(opName, valueFn, gradientFn):\n",
" \"\"\"Returns derivitive link function for binary operator: +,-,*,/,exp.\n",
"\n",
" valueFn returns the numeric value\n",
" gradientFn returns the corresponding derivitive value according to formula\n",
"\n",
" \"\"\"\n",
"\n",
" def unaryOp(self):\n",
" def dop(dual, accumulated, store):\n",
" a = self\n",
" ga = gradientFn(a.value, accumulated)\n",
" return a.link(a, (accumulated * ga), store)\n",
"\n",
" name = f\"{opName}({self.value})\"\n",
"\n",
" return self.__class__(valueFn(self.value), dop, name=name)\n",
"\n",
" return unaryOp"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "fefb265e",
"metadata": {},
"outputs": [],
"source": [
"class Dual:\n",
" def __init__(self, value, link=endOperator, name=None):\n",
" self.value = value\n",
"\n",
" # link is a function which captures the operator's gradient calculation\n",
" # on the chain\n",
" self.link = link\n",
"\n",
" if name is None:\n",
" self.name = str(value)\n",
" else:\n",
" self.name = name\n",
"\n",
" def truncate(self):\n",
" \"\"\"Truncate reset the link information for the dual.\"\"\"\n",
" self.link = endOperator\n",
" return self\n",
"\n",
" def grad(self, store):\n",
" return self.link(self, 1.0, store)\n",
"\n",
" # Binary operator\n",
" __add__ = prim2(\"+\", lambda a, b: a + b, lambda a, b, z: (z, z))\n",
" __radd__ = prim2(\"+\", lambda b, a: a + b, lambda b, a, z: (z, z))\n",
"\n",
" __sub__ = prim2(\"-\", lambda a, b: a - b, lambda a, b, z: (z, -z))\n",
" __rsub__ = prim2(\"-\", lambda b, a: b - a, lambda b, a, z: (-z, z))\n",
"\n",
" __mul__ = prim2(\"*\", lambda a, b: a * b, lambda a, b, z: ((b * z), (a * z)))\n",
" __rmul__ = prim2(\"*\", lambda b, a: a * b, lambda b, a, z: ((a * z), (b * z)))\n",
"\n",
" __truediv__ = prim2(\n",
" \"/\", lambda a, b: a / b, lambda a, b, z: ((z * (1.0 / b)), (z * (-a / (b * b))))\n",
" )\n",
" __rtruediv__ = prim2(\n",
" \"/\", lambda b, a: a / b, lambda b, a, z: ((z * (-a / (b * b))), (z * (1.0 / b)))\n",
" )\n",
"\n",
" # dx^y/dx = y * x^(y-1)\n",
" # dx^y/dy = x^y * ln(x)\n",
" __pow__ = prim2(\n",
" \"^\",\n",
" lambda a, b: a**b,\n",
" lambda a, b, z: ((z * (b * (a ** (b - 1)))), (z * (a**b * math.log(a)))),\n",
" )\n",
" \n",
" # Unary operator\n",
" log = prim1(\"log\", lambda a: math.log(a), lambda a, z: z * (1.0 / a))\n",
" exp = prim1(\"exp\", lambda a: math.exp(a), lambda a, z: z * math.exp(a))\n",
" sqrt = prim1(\"sqrt\", lambda a: math.sqrt(a), lambda a, z: z / (2 * math.sqrt(a)))\n",
"\n",
"\n",
" def __str__(self):\n",
" return f\"[Dual(name={self.name}, val:{self.value})]\""
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "6e525ae9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[1.0, 1.0],\n",
" [2.0, 1.0],\n",
" [3.0, 2.0],\n",
" [1.0, -1.0],\n",
" [0.3333333333333333, -0.2222222222222222],\n",
" [12.0, 5.545177444479562],\n",
" [7.38905609893065],\n",
" [0.5],\n",
" [0.35355339059327373],\n",
" [7.0, 1.0],\n",
" [2.4847079713764115, 46.51502816261462]]"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[\n",
" # Basic binary operators\n",
" Del(lambda x, y: x + y, [3.0, 2.0]), # => [1.0, 1.0]\n",
" Del(lambda x, y: (x + x) + y, [3.0, 2.0]), # => [2.0, 1.0]\n",
" Del(lambda x, y: x * y, [2.0, 3.0]), # => [3.0, 2.0]\n",
" Del(lambda x, y: x - y, [2.0, 3.0]), # => [1.0, -1.0]\n",
" Del(lambda x, y: x / y, [2.0, 3.0]), # => [0.3333333, -0.2222]\n",
" Del(lambda x, y: x ** y, [2.0, 3.0]), # => [12.0, 5.54517]\n",
" \n",
" # Unary operator\n",
" Del(lambda x: x.exp(), [2.0]), # => e^2\n",
" Del(lambda x: x.log(), [2.0]), # => 0.5\n",
" Del(lambda x: x.sqrt(), [2.0]), # => 0.3535\n",
" \n",
" # Composed function\n",
" Del(lambda x, y: (x * x) + x + y, [3.0, 2.0]), # => [7.0, 1.0]\n",
" \n",
" # fn = (e^x + log(y) + sqrt(x))/y\n",
" Del(lambda x, y: (x.exp() * y.log() + x.sqrt())/y, [3.0, 2.0]),\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "ccb9a9e3",
"metadata": {},
"outputs": [],
"source": [
"# What's going from there?\n",
"# - [x] support basic scalar; merge type of number and dual (medium)\n",
"# Del(lambda x, b: (2.0 * x) + 1.0 * x + b, [3.0, 2.0]), it would blow up now\n",
"# - [x] right hand operator\n",
"# - other triangle operator (easy), sin, csin\n",
"# - support matrix calculus (hard, maybe)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "a6aaab01",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[13.0, 1.0], [1.0], [-1.0], [2.0], [-0.2222222222222222], [13.0, 1.0]]"
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# y = 2x^2 + x + b\n",
"[\n",
" # support python int and float\n",
" Del(lambda x, b: (x * x * 2.0) + x * 1 + b, [3.0, 2.0]),\n",
"\n",
" # right hand operator\n",
" Del(lambda x: 2.0 + x, [3.0]),\n",
" Del(lambda x: 2.0 - x, [3.0]),\n",
" Del(lambda x: 2.0 * x, [3.0]),\n",
" Del(lambda x: 2.0 / x, [3.0]),\n",
" Del(lambda x, b: (2.0 * x * x) + 1 * x + b, [3.0, 2.0]),\n",
"]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment