@MasonProtter
Last active March 6, 2018 02:53
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Naïve Symbolic Automatic Differentiation in Julia\n",
"\n",
"In this document I'll show a quick naïve way of implementing a very basic computer algebra system in Julia which can do symbolic derivatives using automatic differentiation. I'll assume a basic knowldege of working with Julia expressions.\n",
"\n",
"Before we discuss derivatives, lets quickly build a way to do some basic symbolic math. \n",
"\n",
"First off, we need a data type for symbols and symbolic expressions. The easiest choice is to use the built in `Symbol` and `Expr` types, but if we were to to turn this code into a package people would yell that defining new methods on functions from `Base` on types from `Base` is heresy. \n",
"\n",
"Nonetheless, I don't feel like building my own versions of `Symbol` and `Expr` right now so I'll just make aliases for those types called `Sym` and `SymExpr` and make all my methods in terms of those so that later if I want to avoid type heresy, I just have to change the definitions of `Sym` and `SymExpr` and don't have to worry about all the methods I define.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"Sym = Symbol\n",
"SymExpr = Expr\n",
"\n",
"symExpr(x) = x\n",
"mathy = Union{Sym,SymExpr,Number};"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Okay, now we have `mathy` which is just the union of `Sym`, `SymExpr` and `Number` so we can define methods for some of the standard mathematical functions"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"function Base.:+(x::mathy, y::mathy)\n",
" if x == y\n",
" symExpr(:(2*$x))\n",
" elseif x == -y\n",
" 0\n",
" elseif x == 0\n",
" y\n",
" elseif y == 0\n",
" x\n",
" else\n",
" symExpr(:($x+$y))\n",
" end\n",
"end\n",
"\n",
"Base.:+(x::mathy) = x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The with these methods on the `+` operator, we can do things like \n",
"```julia\n",
"julia> :x + :y\n",
":(x + y)\n",
"\n",
"julia> :x - :y\n",
":(x - y)\n",
"\n",
"julia> :x + :x\n",
":(2x)\n",
"\n",
"julia> +:x\n",
":x\n",
"```\n",
"\n",
"Notice that in the final `else` statement, I use `symExpr(:($x+$y))` which at this point just returns `:($x+$y)`, ie. `symExpr` is just an identiy function on `Expr` types. Later, if we make our own `SymExpr` type to avoide type piracy, then we just need to modify `symExpr` to construct an object of type `SymExpr` instead of `Expr` and we won't need to modify any of our arithmetic functions!\n",
"\n",
"Now we can go and do the same for the subtration operator"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"function Base.:-(x::mathy, y::mathy)\n",
" if x == y\n",
" 0\n",
" elseif x == -y\n",
" symExpr(:(2$x))\n",
" elseif x == 0\n",
" -y\n",
" elseif y == 0\n",
" x\n",
" else\n",
" symExpr(:($x-$y))\n",
" end\n",
"end\n",
"\n",
"function Base.:-(x::Sym)\n",
" symExpr(:(-$x))\n",
"end\n",
"\n",
"isUnaryOperation(ex::SymExpr) = length(ex.args) == 2\n",
"car(x::SymExpr) = x.args[1]\n",
"\n",
"function Base.:-(x::SymExpr)\n",
" if (car(x) == :-) && (x |> isUnaryOperation)\n",
" symExpr(x.args[2])\n",
" else\n",
" symExpr(:(-$x))\n",
" end\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first method defined above tells us how to do basic subtration\n",
"```julia\n",
"julia> :x - :y\n",
":(x - y)\n",
"\n",
"julia> :x - :x\n",
"0\n",
"```\n",
"\n",
"and the second two tell us how to negate a single argument, ie.\n",
"```julia\n",
"julia> -:x\n",
":(-1x)\n",
"\n",
"julia> -(-(:x + 1))\n",
":(x + 1)\n",
"```\n",
"\n",
"Now we can go and do similar things for multiplication, division, exponentitation and logarithms\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"function Base.:*(x::mathy,y::mathy)\n",
" if x == y\n",
" symExpr(:($x^2))\n",
" elseif x == -y\n",
" symExpr(:(-$x^2))\n",
" elseif x == 1\n",
" y\n",
" elseif y == 1\n",
" x\n",
" elseif (x == 0) || (y == 0)\n",
" 0\n",
" else\n",
" symExpr(:($x*$y))\n",
" end\n",
"end\n",
"\n",
"function Base.:/(x::mathy, y::mathy)\n",
" if x == y\n",
" 1\n",
" elseif x == -y\n",
" -1\n",
" elseif y == 1\n",
" x\n",
" else\n",
" symExpr(:($x/$y))\n",
" end\n",
"end\n",
"\n",
"function Base.:^(x::mathy, y::mathy)\n",
" symExpr(:($x^$y))\n",
"end\n",
"\n",
"Base.:^(x::mathy, y::Int) = y == 1 ? x : symExpr(:($x^$y))\n",
"\n",
"Base.log(x::mathy) = symExpr(:(log($x)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can write down all sorts of complicated symbolic expressions\n",
"```julia\n",
"julia> log(:x^4 + x^4)/:y + :y^(5(:x))\n",
":(log(2 * x ^ 4) / y + y ^ (5x))\n",
"```\n",
"\n",
"#### Exercise:\n",
"\n",
"Try to follow the conventions used above to define symbolic versions of the trigonemtric functions\n",
"\n",
"\n",
"## Automatic Differentiation\n",
"Now we've built up an basic computer algebra system. A common use for such systems is the computation of derivatives. The conventional way to do this is by defining a set of rules for transforming an expression into its symbolic derivative. This technique is usually known simply as 'symbolic differentiation'. A more interesting technique for achieving the same end is called 'automatic differentiation'.\n",
"\n",
"Recall from first year calculus that the derivative of a function $f(x)$ is defined as\n",
"\n",
"$$\n",
"f'(x) \\equiv \\lim_{\\Delta x \\rightarrow 0} \\frac{f(x+\\Delta x) - f(x)}{\\Delta x}\n",
"$$\n",
"\n",
"We can also recall Taylor's theorem which states that for a well-behaved function $f(x)$, \n",
"\n",
"$$\n",
"f(x + \\Delta x) = f(x) + f'(x) ~ \\Delta x + \\frac{1}{2}~ f''(x) ~ (\\Delta x)^2 + \\frac{1}{6}~ f'''(x) ~ (\\Delta x)^3 + \\mathcal{O}(\\Delta x ^4)\n",
"$$\n",
"\n",
"So now, lets imagine a $\\Delta x$ which is so small that $\\Delta x^2$ is $0$ and we'll call this a 'differential' $dx$.\n",
"\n",
"So now Taylor's theorm simply states that \n",
"\n",
"$$\n",
"f(x + dx) = f(x) + f'(x) ~ dx\n",
"$$\n",
"\n",
"where we can imagine $dx$ as an infinitesimally small number. In standard numerical differentiation, one uses a small floating point number not far above the minimum precision of the computer as their value for $dx$. Alternatively, we can make use of Julia's type system and define a new type which represents quantities of the form $x + dx$.\n",
"\n",
"The idea here is actually very similar to the construction of complex numbers. With complex numbers, one simply says suppose there was such a number $i$ such that $i^2 = -1$ and then defines a type (or in mathematical language, an algebra) `Complex` ($\\mathbb{C}$) and defines what basic functions like `+, -, *, /, ^, log` do when acting on objects of type `Complex` (elements of $\\mathbb{C}$).\n",
"\n",
"Similarly, we want to define a new number $\\epsilon$ such that $\\epsilon^2 = 0$. Hence, we can make a type `Differential` and then define methods on the `+, -, *, /, ^, log` functions which will give them mathematically correct results.\n",
"\n",
"So first we make a new type:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"x + yϵ"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type Differential\n",
" finite::mathy\n",
" differential::mathy\n",
"end\n",
"\n",
"differential(x,y) = y == 0 ? x : Differential(x,y) \n",
"\n",
"function Base.show(io::IO, x::Differential)\n",
" if (x.finite==0) && (x.differential==0)\n",
" print(io,\"0\")\n",
" else\n",
" finiteStr = (x.finite == 0 ? \"\" : \"$(x.finite) + \")\n",
" diffStr = (x.differential == 0 ? \"\" : \n",
" x.differential == 1 ? \"\" : \n",
" x.differential isa Expr ? \"($(x.differential))\" : \"$(x.differential)\")\n",
" print(io,finiteStr*diffStr*\"ϵ\")\n",
" end\n",
"end\n",
"\n",
"differential(:x,:y)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The last function shown above just tells Julia how we'd like it to display functions of type `Differential`.\n",
"```julia\n",
"julia> differential(:x, :y)\n",
"x + yϵ\n",
"\n",
"julia> differential(0, (:x^2 + :y))\n",
"(x^2 + y)ϵ\n",
"\n",
"julia> differential(:x, 1)\n",
"x + ϵ\n",
"```\n",
"\n",
"Now we can define a new union type that includes `Differential`s and some helper functions for extracting out the `finite` and infinitesimal parts of a quantity."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"mathyDiff = Union{mathy,Differential}\n",
"\n",
"finitePart(x::mathyDiff) = x isa Differential ? x.finite : x\n",
"diffPart(x::mathyDiff) = x isa Differential ? x.differential : 0;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see these in action\n",
"```julia\n",
"julia> z = differential(:x, :y);\n",
"julia> finitePart(z)\n",
":x\n",
"\n",
"julia> diffPart(z)\n",
":y\n",
"\n",
"julia> finitePart(1.05)\n",
"1.05\n",
"\n",
"julia> diffPart(1.05)\n",
"0\n",
"```\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now recall that \n",
"\n",
"$$f'(x) = {(f(x+ϵ) - f(x)) \\over ϵ}~~.$$\n",
"\n",
"We can then make a Julian derivative function"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"D(f::Function) = x -> diffPart(f(x + ϵ));"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now once Julia has methods for the standard mathematical functions so that they know how to deal with `Differential`s, our function `D` will perform automatic differentiation!\n",
"\n",
"Addition and aubtration are straightforward, you simply add (subtract) the finite parts and the differential parts together separately:"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"function Base.:+(x::mathyDiff,y::mathyDiff)\n",
" differential(finitePart(x) + finitePart(y), diffPart(x) + diffPart(y))\n",
"end\n",
"\n",
"function Base.:-(x::mathyDiff,y::mathyDiff)\n",
" differential(finitePart(x) - finitePart(y), diffPart(x) - diffPart(y))\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now for multiplication. This is pretty strightforward, we simply use the rule that $\\epsilon^2 = 0$. So lets suppose we have two `Differential` numbers, $u = x + y~ \\epsilon$ and $v = z + w~ \\epsilon$.\n",
"\n",
"\\begin{align}\n",
"u * v &= (x + y~\\epsilon)(z + w~\\epsilon)\\\\\n",
"&= xz + (xw + yz)\\epsilon + yw \\epsilon^2\\\\\n",
"&= xz + (xw + yz)\\epsilon\n",
"\\end{align}\n",
"\n",
"In Julia we can express this proceedure as "
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"function Base.:*(x::mathyDiff, y::mathyDiff)\n",
" differential(finitePart(x)*finitePart(y), finitePart(x)*diffPart(y) + diffPart(x)*finitePart(y))\n",
"end\n",
"\n",
"ϵ = differential(0,1);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With that definition, Julia now knows how to take derivatives of functions involving multiplication and we get the product rule for free!\n",
"\n",
"```julia\n",
"julia> f(x) = 2*x;\n",
"julia> g(x) = 4*x;\n",
"julia> h(x) = f(x)*g(x);\n",
"\n",
"julia> f(:x+ϵ)\n",
"2*x + 2ϵ\n",
"\n",
"julia> D(f)(:x)\n",
"2\n",
"\n",
"julia> g(:x + ϵ)\n",
"4*x + ϵ\n",
"\n",
"julia> D(g)(:x)\n",
"4\n",
"\n",
"julia> h(:x + ϵ)\n",
"(2x) * (4x) + ((2x) * 4 + 2 * (4x))ϵ\n",
"\n",
"julia> D(h)(:x)\n",
":((2x) * 4 + 2 * (4x))\n",
"```\n",
"\n",
"We can see here that our automatic derivative function has correctly performed the chain rule (though the simplification of the result left a little to be desired).\n",
"\n",
"We can teach Julia how to use the quotient rule as follows:"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"function Base.:/(x::mathyDiff, y::mathyDiff)\n",
" (finitePart(x)/finitePart(y) + ϵ*(finitePart(x)*diffPart(y)/finitePart(y)^2)) + ϵ*(diffPart(x)/finitePart(y))\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This simply says that\n",
"$$\n",
"\\frac{x}{y + \\epsilon} = {x\\over y}+ \\frac{x}{y^2}\\epsilon\n",
"$$\n",
"\n",
"Now we can do \n",
"```julia\n",
"julia> D(x -> 1/x)(:x)\n",
"1/:x^2\n",
"```\n",
"\n",
"Likewise, we can teach Julia how to deal with exponents"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
"function Base.:^(x::Differential, y::mathy)\n",
" finitePart(x)^y + y*finitePart(x)^(y-1)*diffPart(x)*ϵ\n",
"end\n",
"\n",
"function Base.:^(x::Differential, y::Int)\n",
" finitePart(x)^y + y*finitePart(x)^(y-1)*diffPart(x)*ϵ\n",
"end\n",
" \n",
"function Base.:^(x::mathy, y::Differential)\n",
" x^finitePart(y) + log(x)*x^finitePart(y)*diffPart(y)*ϵ\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can take derivatives of exponential functions! \n",
"```julia\n",
"julia> D(x -> x^2)(:x)\n",
":(2x)\n",
"\n",
"\n",
"julia> D(x -> 4*x^5)(:x)\n",
":(4 * (5 * x ^ 4))\n",
"\n",
"julia> D(x -> (1 + x)*x^-4 )(:x)\n",
":((1 + x) * (-4 * x ^ -5) + x ^ -4)\n",
"\n",
"D(x -> 2^x)(:x)\n",
":(0.6931471805599453 * 2 ^ x)\n",
"\n",
"D(x -> x*2^x)(:x)\n",
":(x * (0.6931471805599453 * 2 ^ x) + 2 ^ x)\n",
"```\n",
"\n",
"### Exercise:\n",
"You may notice that I have left out a method allowing Julia to take derivatives of the form \n",
"```julia\n",
"julia> D(x -> x^x)(:x)\n",
"\n",
"MethodError: no method matching ^(::Differential, ::Differential)\n",
"Closest candidates are:\n",
" ^(::Differential, ::Int64) at In[91]:6\n",
" ^(::Differential, ::Union{Expr, Number, Symbol}) at In[91]:2\n",
" ^(::Union{Expr, Number, Symbol}, ::Differential) at In[91]:10\n",
" ...\n",
"\n",
"Stacktrace:\n",
" [1] (::##7#8{##23#24})(::Symbol) at ./In[73]:1\n",
" [2] include_string(::String, ::String) at ./loading.jl:522\n",
"```\n",
"\n",
"To make such a derivative work, one needs to generalize the definition of `^` to take two `Differential` arguments. This isn't difficult but it is messy so its left as an exercise to the reader. \n",
"\n",
"________"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, knowing that $\\frac{d~log(x)}{dx} = {1 \\over x}$, we can teach `log` to take `Differential aguments`:"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
"function Base.log(x::Differential)\n",
" log(finitePart(x)) + diffPart(x)/finitePart(x)*ϵ\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and we can check if this gives us the correct behaviour:\n",
"```julia\n",
"D(x -> log(x))(:x)\n",
":(1 / x)\n",
"\n",
"D(x -> log(x)^2 + 4*x)(:x)\n",
":((2 * log(x)) * (1 / x) + 4)\n",
"```\n",
"\n",
"Great! With a little know-how and perserverence we've made a rudiementary symbolic math system and taught it how to take derivatives in Julian style!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Exercise:\n",
"If you did the earlier exerise of definiting `mathy` methods for the standard trigonometric functions, I suggest you see if you can define `Differential` methods for them so that you can take their derivatives."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 0.6.1",
"language": "julia",
"name": "julia-0.6"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "0.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}