Last active
March 13, 2023 15:48
-
-
Save yangchenyun/6c801e151441258b83d2b17fa45d2b3d to your computer and use it in GitHub Desktop.
Symbolic differentiator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
The following program implements a numeric system with primitive data | |
structures and algorithms that can calculate derivatives for arbitrary | |
arithmetic expressions. | |
The core data structure is called `Dual`, which represents a number with its | |
numeric value and the operation performed in the expression AST. This means that | |
`Dual` remembers how the value is computed. | |
All primitive operations (including binary and unary) have the signature | |
`[]Duals -> Dual`. Additionally, primitive operations calculate the derivative | |
with respect to their parameters using differential calculus. | |
Composed operations use the chain rule to accumulate the result while traversing | |
the chain. | |
TODO: | |
- Merge the data type with python's number typing system (int, float) | |
- Implement other number protocols (__iadd__, __radd__, etc.) | |
- Support vector calculus | |
""" | |
from collections import defaultdict | |
import math | |
def treemap(fn, tree):
    """Apply `fn` to every leaf of a nested list, keeping the nesting intact.

    A list is rebuilt element by element (recursively) into a new list,
    `None` is passed through untouched, and any other value is treated as a
    leaf and replaced by `fn(leaf)`.

    treemap(add1, [1, [2, [3, 4]]]) => [2, [3, [4, 5]]]
    """
    if isinstance(tree, list):
        # Interior node: build a fresh list from the transformed children.
        return list(map(lambda child: treemap(fn, child), tree))
    # `None` is returned as-is and never handed to `fn`.
    return tree if tree is None else fn(tree)
# Import-time sanity checks for treemap: None, empty list, bare leaf,
# flat list and nested list.
assert treemap(lambda x: x + 1, None) is None  # identity check, not `==`
assert treemap(lambda x: x + 1, []) == []
assert treemap(lambda x: x + 1, 1) == 2
assert treemap(lambda x: x + 1, [1]) == [2]
assert treemap(lambda x: x + 1, [1, [2, [3, 4]]]) == [2, [3, [4, 5]]]
def linkOperator(dual, accumulated, store):
    """Document the call signature shared by every link function.

    A link function receives the dual being differentiated, the gradient
    accumulated so far along the chain, and the store collecting gradients.
    The concrete links in this file (`endOperator` and the `dop` closures
    built by `prim1` / `prim2`) return the store after updating it.

    This stub exists only as notation: it does nothing and returns None.
    """
    return None
def endOperator(dual, accumulated, store):
    """Terminal link: add `accumulated` to the gradient recorded for `dual`.

    This is the default link of a leaf dual, so it is where a gradient
    travelling down the chain finally lands.

    Args:
        dual: the leaf whose gradient is being recorded (used as the key).
        accumulated: gradient value that reached this leaf.
        store: mutable mapping from dual to its summed gradient.

    Returns:
        The same `store` object, updated in place.
    """
    try:
        # Subscript (rather than .get) so a defaultdict still applies its own
        # default factory exactly as before.
        current = store[dual]
    except KeyError:
        # Plain dict with no entry yet: a gradient sum starts from zero.
        current = 0.0
    store[dual] = current + accumulated
    return store
def prim2(opName, valueFn, gradientFn):
    """Build a differentiable binary-operator method (+, -, *, /, **).

    Args:
        opName: symbol used when naming the resulting dual, e.g. "+".
        valueFn: `(a, b) -> value`; computes the numeric result from the
            receiver's value `a` and the other operand's value `b`.
        gradientFn: `(a, b, z) -> (ga, gb)`; `z` is the gradient accumulated
            so far, and the returned pair (with `z` already folded in) is
            handed on to the receiver and the other operand respectively.

    Returns:
        A method `binaryOp(self, other)` producing a new dual of the
        receiver's class whose link propagates gradients to both operands.
    """

    def binaryOp(self, other):
        if isinstance(other, (int, float)):
            # Promote plain numbers with the receiver's own class so the
            # operator does not depend on a particular global class name.
            other = self.__class__(other)
        elif not (hasattr(other, "value") and hasattr(other, "link")):
            # Not dual-like: let Python try the reflected operation or raise
            # its usual TypeError instead of failing on `other.value`.
            return NotImplemented

        def dop(dual, accumulated, store):
            # `dual` is part of the link signature but the operands come from
            # the closure, not from the argument.
            a = self
            b = other
            ga, gb = gradientFn(a.value, b.value, accumulated)
            # Thread the store through both operands in turn.
            newStore = a.link(a, ga, store)
            newStore = b.link(b, gb, newStore)
            return newStore

        name = f"{self.value} {opName} {other.value}"
        return self.__class__(valueFn(self.value, other.value), dop, name=name)

    return binaryOp
def prim1(opName, valueFn, gradientFn):
    """Build a differentiable unary-operator method (log, exp, sqrt, ...).

    Args:
        opName: label used when naming the resulting dual, e.g. "log".
        valueFn: `(a) -> value`; computes the numeric result from the
            receiver's value.
        gradientFn: `(a, z) -> ga`; `z` is the gradient accumulated so far
            and the result must already include it (chain rule applied),
            e.g. `lambda a, z: z * (1.0 / a)` for log.

    Returns:
        A method `unaryOp(self)` producing a new dual of the receiver's
        class whose link propagates the gradient to the operand.
    """

    def unaryOp(self):
        def dop(dual, accumulated, store):
            a = self
            # gradientFn has already multiplied by `accumulated`; pass its
            # result on unchanged. Multiplying again would square the
            # upstream gradient whenever it is not 1.
            ga = gradientFn(a.value, accumulated)
            return a.link(a, ga, store)

        name = f"{opName}({self.value})"
        return self.__class__(valueFn(self.value), dop, name=name)

    return unaryOp
def _pow_gradient(a, b, z):
    """Return (d/da, d/db) of a**b, each scaled by the incoming gradient z."""
    # d(a^b)/da = b * a^(b-1)
    grad_base = z * (b * (a ** (b - 1)))
    # d(a^b)/db = a^b * ln(a). ln(a) is undefined for a <= 0, so report NaN
    # for that partial instead of raising; this keeps the base gradient
    # usable for cases such as (-3.0) ** 2 with a constant exponent.
    log_a = math.log(a) if a > 0 else float("nan")
    grad_exponent = z * (a**b * log_a)
    return grad_base, grad_exponent


class Dual:
    """A number that remembers how it was computed.

    `value` holds the numeric result; `link` is the function that pushes an
    accumulated gradient back to the operands this value was built from.
    A freshly created dual uses `endOperator`, which records the gradient in
    the store under the dual itself.
    """

    def __init__(self, value, link=endOperator, name=None):
        self.value = value
        # link is a function which captures the operator's gradient
        # calculation on the chain
        self.link = link
        # Fall back to the printed value when no explicit name is given.
        self.name = str(value) if name is None else name

    def truncate(self):
        """Reset the link so this dual acts as a leaf again; returns self."""
        self.link = endOperator
        return self

    def grad(self, store):
        """Run the link chain from this dual with a seed gradient of 1.0.

        Returns whatever the link returns (the updated store for the links
        defined in this file).
        """
        return self.link(self, 1.0, store)

    # Binary operators. In the reflected (__r*__) variants the receiver is
    # the RIGHT operand, so the lambdas name it `b` and the other operand
    # `a`; the gradient pair is still ordered (receiver, other).
    __add__ = prim2("+", lambda a, b: a + b, lambda a, b, z: (z, z))
    __radd__ = prim2("+", lambda b, a: a + b, lambda b, a, z: (z, z))

    __sub__ = prim2("-", lambda a, b: a - b, lambda a, b, z: (z, -z))
    # `a - b` with the receiver as b; computing `b - a` here would flip the
    # sign of the value of e.g. `2.0 - x`.
    __rsub__ = prim2("-", lambda b, a: a - b, lambda b, a, z: (-z, z))

    __mul__ = prim2("*", lambda a, b: a * b, lambda a, b, z: ((b * z), (a * z)))
    __rmul__ = prim2("*", lambda b, a: a * b, lambda b, a, z: ((a * z), (b * z)))

    __truediv__ = prim2(
        "/", lambda a, b: a / b, lambda a, b, z: ((z * (1.0 / b)), (z * (-a / (b * b))))
    )
    __rtruediv__ = prim2(
        "/", lambda b, a: a / b, lambda b, a, z: ((z * (-a / (b * b))), (z * (1.0 / b)))
    )

    # dx^y/dx = y * x^(y-1)
    # dx^y/dy = x^y * ln(x)
    __pow__ = prim2("^", lambda a, b: a**b, _pow_gradient)

    # Unary operators
    log = prim1("log", lambda a: math.log(a), lambda a, z: z * (1.0 / a))
    exp = prim1("exp", lambda a: math.exp(a), lambda a, z: z * math.exp(a))
    sqrt = prim1("sqrt", lambda a: math.sqrt(a), lambda a, z: z / (2 * math.sqrt(a)))

    def __str__(self):
        return f"[Dual(name={self.name}, val:{self.value})]"
def Del(fn, theta):
    """Return the gradient of `fn` with respect to each entry of `theta`.

    `theta` is a (possibly nested) list of numbers. Every number is wrapped
    in a `Dual`, `fn` is called with the top-level entries as positional
    arguments, and the result's `grad` fills a store with one summed
    gradient per leaf. Parameters that `fn` never touches read back as 0.0.

    The returned structure has the same nesting as `theta`.
    """
    leaves = treemap(Dual, theta)
    gradients = defaultdict(float)  # float() == 0.0 for untouched leaves
    output = fn(*leaves)
    output.grad(gradients)
    # Look each leaf up again so the answer mirrors theta's shape.
    return treemap(lambda leaf: gradients[leaf], leaves)
if __name__ == "__main__":
    # Smoke demo: each entry is the gradient of a small function evaluated
    # at the given parameter values.
    demo_gradients = [
        # Basic binary operators
        Del(lambda x, y: x + y, [3.0, 2.0]),  # => [1.0, 1.0]
        Del(lambda x, y: x * y, [2.0, 3.0]),  # => [3.0, 2.0]
        Del(lambda x, y: x - y, [2.0, 3.0]),  # => [1.0, -1.0]
        Del(lambda x, y: x / y, [2.0, 3.0]),  # => [0.3333333, -0.2222]
        Del(lambda x, y: x**y, [2.0, 3.0]),  # => [12.0, 5.54517]
        # Unary operators
        Del(lambda x: x.exp(), [2.0]),  # => e^2
        Del(lambda x: x.log(), [2.0]),  # => 0.5
        Del(lambda x: x.sqrt(), [2.0]),  # => 0.3535
        # Composed function
        Del(lambda x, y: (x * x) + x + y, [3.0, 2.0]),  # => [7.0, 1.0]
        # fn = (e^x * log(y) + sqrt(x)) / y
        Del(lambda x, y: (x.exp() * y.log() + x.sqrt()) / y, [3.0, 2.0]),
        # Mixing duals with python int and float: y = 2x^2 + x + b
        Del(lambda x, b: (x * x * 2.0) + x * 1 + b, [3.0, 2.0]),  # => [13.0, 1.0]
        # Plain number on the left-hand side (reflected operators)
        Del(lambda x: 2.0 + x, [3.0]),  # => [1.0]
        Del(lambda x: 2.0 - x, [3.0]),  # => [-1.0]
        Del(lambda x: 2.0 * x, [3.0]),  # => [2.0]
        Del(lambda x: 2.0 / x, [3.0]),  # => [-0.2222]
        Del(lambda x, b: (2.0 * x * x) + 1 * x + b, [3.0, 2.0]),  # => [13.0, 1.0]
    ]
    print(demo_gradients)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"id": "fba27f8c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from collections import defaultdict\n", | |
"import math" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"id": "dd5a5a4a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def treemap(fn, tree):\n", | |
" \"\"\"Treemap works similar of `map' but on a nested list instead of list.\n", | |
" It modifes the leaf nodes on a tree.\n", | |
" treemap(add1, [1, [2, [3, 4]]) => [2, [3, [4, 5]]\n", | |
" \"\"\"\n", | |
" if tree is None:\n", | |
" return tree\n", | |
" elif isinstance(tree, list):\n", | |
" return [treemap(fn, elem) for elem in tree]\n", | |
" else:\n", | |
" return fn(tree)\n", | |
"\n", | |
"\n", | |
"assert treemap(lambda x: x + 1, None) == None\n", | |
"assert treemap(lambda x: x + 1, []) == []\n", | |
"assert treemap(lambda x: x + 1, 1) == 2\n", | |
"assert treemap(lambda x: x + 1, [1]) == [2]\n", | |
"assert treemap(lambda x: x + 1, [1, [2, [3, 4]]]) == [2, [3, [4, 5]]]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"id": "eecbe813", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def linkOperator(dual, accumulated, store):\n", | |
" \"\"\"LinkOperator takes on a dual, accumulated gradient value, and a store.\n", | |
" It returns a new store representing the new gradient collected.\n", | |
"\n", | |
" This is just a function notation, the real implementation could be found below.\n", | |
" \"\"\"\n", | |
" pass\n", | |
"\n", | |
"\n", | |
"def endOperator(dual, accumulated, store):\n", | |
" \"\"\"Special LinkOperator to save the accumulated gradient value in a given store.\"\"\"\n", | |
" store[dual] += accumulated\n", | |
" return store\n", | |
"\n", | |
"\n", | |
"class Dual:\n", | |
" def __init__(self, value, link=endOperator, name=None):\n", | |
" self.value = value # numeric value, int / float\n", | |
"\n", | |
" # link is a function which captures the operator's gradient calculation\n", | |
" # on the chain\n", | |
" self.link = link \n", | |
"\n", | |
" if name is None:\n", | |
" self.name = str(value)\n", | |
" else:\n", | |
" self.name = name\n", | |
"\n", | |
" def truncate(self):\n", | |
" \"\"\"Truncate reset the link information for the dual.\"\"\"\n", | |
" self.link = endOperator\n", | |
" return self\n", | |
"\n", | |
" def grad(self, store):\n", | |
" return self.link(self, 1.0, store)\n", | |
"\n", | |
"\n", | |
" def __add__(self, other): # dual, dual -> dual\n", | |
" def dadd(dual, accumulated, store):\n", | |
" a = self\n", | |
" b = other\n", | |
" newStore = a.link(a, (accumulated * 1.0), store)\n", | |
" newStore = b.link(b, (accumulated * 1.0), newStore)\n", | |
" return newStore\n", | |
" \n", | |
" name = f\"{self.value} + {other.value}\"\n", | |
" return self.__class__(self.value + other.value, dadd, name=name)\n", | |
" \n", | |
" def __str__(self):\n", | |
" return f\"[Dual(name={self.name}, val:{self.value})]\"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 44, | |
"id": "b8d6ccdd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def Del(fn, theta):\n", | |
" \"\"\"\n", | |
" Del is an operator, which takes in a function and a list of parameters\n", | |
" and return corresponding gradients.\n", | |
" \"\"\"\n", | |
" theta = treemap(lambda v: Dual(v), theta)\n", | |
" store = defaultdict(lambda: 0.0)\n", | |
" \n", | |
" result = fn(*theta)\n", | |
" # print(result)\n", | |
" \n", | |
" result.grad(store)\n", | |
"\n", | |
" # Return in respect to theta's shape\n", | |
" return treemap(lambda t: store[t], theta)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"id": "8fd0be3a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[[1.0, 1.0], [2.0, 1.0]]" | |
] | |
}, | |
"execution_count": 45, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"[\n", | |
" Del(lambda x, y: x + y, [3.0, 2.0]), # => [1.0, 1.0]\n", | |
" Del(lambda x, y: (x + x) + y, [3.0, 2.0]), # => [2.0, 1.0]\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 46, | |
"id": "454fcc0f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def prim2(opName, valueFn, gradientFn):\n", | |
" \"\"\"Returns derivitive link function for binary operator: +,-,*,/,exp.\n", | |
"\n", | |
" valueFn returns the numeric value\n", | |
" gradientFn returns the corresponding derivitive value according to formula\n", | |
"\n", | |
" \"\"\"\n", | |
" def binaryOp(self, other):\n", | |
"\n", | |
" if isinstance(other, int) or isinstance(other, float):\n", | |
" other = Dual(other)\n", | |
"\n", | |
" assert (type(self) is Dual)\n", | |
" assert (type(other) is Dual)\n", | |
" \n", | |
" def dop(dual, accumulated, store):\n", | |
" a = self\n", | |
" b = other\n", | |
" \n", | |
" ga, gb = gradientFn(a.value, b.value, accumulated)\n", | |
"\n", | |
" newStore = a.link(a, ga, store)\n", | |
" newStore = b.link(b, gb, newStore)\n", | |
" return newStore\n", | |
"\n", | |
" name = f\"{self.value} {opName} {other.value}\"\n", | |
"\n", | |
" return self.__class__(valueFn(self.value, other.value), dop, name=name)\n", | |
"\n", | |
" return binaryOp\n", | |
"\n", | |
"def prim1(opName, valueFn, gradientFn):\n", | |
" \"\"\"Returns derivitive link function for binary operator: +,-,*,/,exp.\n", | |
"\n", | |
" valueFn returns the numeric value\n", | |
" gradientFn returns the corresponding derivitive value according to formula\n", | |
"\n", | |
" \"\"\"\n", | |
"\n", | |
" def unaryOp(self):\n", | |
" def dop(dual, accumulated, store):\n", | |
" a = self\n", | |
" ga = gradientFn(a.value, accumulated)\n", | |
" return a.link(a, (accumulated * ga), store)\n", | |
"\n", | |
" name = f\"{opName}({self.value})\"\n", | |
"\n", | |
" return self.__class__(valueFn(self.value), dop, name=name)\n", | |
"\n", | |
" return unaryOp" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 66, | |
"id": "fefb265e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Dual:\n", | |
" def __init__(self, value, link=endOperator, name=None):\n", | |
" self.value = value\n", | |
"\n", | |
" # link is a function which captures the operator's gradient calculation\n", | |
" # on the chain\n", | |
" self.link = link\n", | |
"\n", | |
" if name is None:\n", | |
" self.name = str(value)\n", | |
" else:\n", | |
" self.name = name\n", | |
"\n", | |
" def truncate(self):\n", | |
" \"\"\"Truncate reset the link information for the dual.\"\"\"\n", | |
" self.link = endOperator\n", | |
" return self\n", | |
"\n", | |
" def grad(self, store):\n", | |
" return self.link(self, 1.0, store)\n", | |
"\n", | |
" # Binary operator\n", | |
" __add__ = prim2(\"+\", lambda a, b: a + b, lambda a, b, z: (z, z))\n", | |
" __radd__ = prim2(\"+\", lambda b, a: a + b, lambda b, a, z: (z, z))\n", | |
"\n", | |
" __sub__ = prim2(\"-\", lambda a, b: a - b, lambda a, b, z: (z, -z))\n", | |
" __rsub__ = prim2(\"-\", lambda b, a: b - a, lambda b, a, z: (-z, z))\n", | |
"\n", | |
" __mul__ = prim2(\"*\", lambda a, b: a * b, lambda a, b, z: ((b * z), (a * z)))\n", | |
" __rmul__ = prim2(\"*\", lambda b, a: a * b, lambda b, a, z: ((a * z), (b * z)))\n", | |
"\n", | |
" __truediv__ = prim2(\n", | |
" \"/\", lambda a, b: a / b, lambda a, b, z: ((z * (1.0 / b)), (z * (-a / (b * b))))\n", | |
" )\n", | |
" __rtruediv__ = prim2(\n", | |
" \"/\", lambda b, a: a / b, lambda b, a, z: ((z * (-a / (b * b))), (z * (1.0 / b)))\n", | |
" )\n", | |
"\n", | |
" # dx^y/dx = y * x^(y-1)\n", | |
" # dx^y/dy = x^y * ln(x)\n", | |
" __pow__ = prim2(\n", | |
" \"^\",\n", | |
" lambda a, b: a**b,\n", | |
" lambda a, b, z: ((z * (b * (a ** (b - 1)))), (z * (a**b * math.log(a)))),\n", | |
" )\n", | |
" \n", | |
" # Unary operator\n", | |
" log = prim1(\"log\", lambda a: math.log(a), lambda a, z: z * (1.0 / a))\n", | |
" exp = prim1(\"exp\", lambda a: math.exp(a), lambda a, z: z * math.exp(a))\n", | |
" sqrt = prim1(\"sqrt\", lambda a: math.sqrt(a), lambda a, z: z / (2 * math.sqrt(a)))\n", | |
"\n", | |
"\n", | |
" def __str__(self):\n", | |
" return f\"[Dual(name={self.name}, val:{self.value})]\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 48, | |
"id": "6e525ae9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[[1.0, 1.0],\n", | |
" [2.0, 1.0],\n", | |
" [3.0, 2.0],\n", | |
" [1.0, -1.0],\n", | |
" [0.3333333333333333, -0.2222222222222222],\n", | |
" [12.0, 5.545177444479562],\n", | |
" [7.38905609893065],\n", | |
" [0.5],\n", | |
" [0.35355339059327373],\n", | |
" [7.0, 1.0],\n", | |
" [2.4847079713764115, 46.51502816261462]]" | |
] | |
}, | |
"execution_count": 48, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"[\n", | |
" # Basic binary operators\n", | |
" Del(lambda x, y: x + y, [3.0, 2.0]), # => [1.0, 1.0]\n", | |
" Del(lambda x, y: (x + x) + y, [3.0, 2.0]), # => [2.0, 1.0]\n", | |
" Del(lambda x, y: x * y, [2.0, 3.0]), # => [3.0, 2.0]\n", | |
" Del(lambda x, y: x - y, [2.0, 3.0]), # => [1.0, -1.0]\n", | |
" Del(lambda x, y: x / y, [2.0, 3.0]), # => [0.3333333, -0.2222]\n", | |
" Del(lambda x, y: x ** y, [2.0, 3.0]), # => [12.0, 5.54517]\n", | |
" \n", | |
" # Unary operator\n", | |
" Del(lambda x: x.exp(), [2.0]), # => e^2\n", | |
" Del(lambda x: x.log(), [2.0]), # => 0.5\n", | |
" Del(lambda x: x.sqrt(), [2.0]), # => 0.3535\n", | |
" \n", | |
" # Composed function\n", | |
" Del(lambda x, y: (x * x) + x + y, [3.0, 2.0]), # => [7.0, 1.0]\n", | |
" \n", | |
" # fn = (e^x + log(y) + sqrt(x))/y\n", | |
" Del(lambda x, y: (x.exp() * y.log() + x.sqrt())/y, [3.0, 2.0]),\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 49, | |
"id": "ccb9a9e3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# What's going from there?\n", | |
"# - [x] support basic scalar; merge type of number and dual (medium)\n", | |
"# Del(lambda x, b: (2.0 * x) + 1.0 * x + b, [3.0, 2.0]), it would blow up now\n", | |
"# - [x] right hand operator\n", | |
"# - other triangle operator (easy), sin, csin\n", | |
"# - support matrix calculus (hard, maybe)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 70, | |
"id": "a6aaab01", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[[13.0, 1.0], [1.0], [-1.0], [2.0], [-0.2222222222222222], [13.0, 1.0]]" | |
] | |
}, | |
"execution_count": 70, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# y = 2x^2 + x + b\n", | |
"[\n", | |
" # support python int and float\n", | |
" Del(lambda x, b: (x * x * 2.0) + x * 1 + b, [3.0, 2.0]),\n", | |
"\n", | |
" # right hand operator\n", | |
" Del(lambda x: 2.0 + x, [3.0]),\n", | |
" Del(lambda x: 2.0 - x, [3.0]),\n", | |
" Del(lambda x: 2.0 * x, [3.0]),\n", | |
" Del(lambda x: 2.0 / x, [3.0]),\n", | |
" Del(lambda x, b: (2.0 * x * x) + 1 * x + b, [3.0, 2.0]),\n", | |
"]" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.16" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment