Last active
September 5, 2024 00:32
-
-
Save elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d to your computer and use it in GitHub Desktop.
Automatic Differentiation with Kotlin
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * Implementation of backward-mode (reverse-mode) automatic differentiation.
 */
/**
 * Differentiable variable: a value [x] together with the derivative [d] of the result of
 * differentiation (see [grad]) with respect to this variable.
 *
 * @property x current value of the variable.
 * @property d derivative accumulator; starts at `0.0` and is increased/decreased in place by the
 *   derivative callbacks registered by the operations in [AD].
 */
data class D(var x: Double, var d: Double = 0.0) {
    /** Convenience constructor for integer values; the derivative starts at `0.0`. */
    constructor(x: Int) : this(x = x.toDouble(), d = 0.0)
}
/**
 * Runs differentiation: evaluates [body] inside a fresh [AD] context, seeds the derivative of the
 * result with `1.0` and then runs the backward pass, which updates the `d` of every [D] that took
 * part in the differentiable operations performed by [body].
 *
 * Example:
 * ```
 * val x = D(2) // define variable(s) and their values
 * val y = grad { sqr(x) + 5 * x + 3 } // write the formula in grad context
 * assertEquals(17.0, y.x) // the value of result (y)
 * assertEquals(9.0, x.d) // dy/dx
 * ```
 *
 * Derivatives are accumulated into `d` (the operations use `+=` / `-=`), so variables reused across
 * several calls keep their previous `d` unless it is reset.
 *
 * @return the value produced by [body]; its `d` is `1.0` (dy/dy) when the backward pass starts.
 */
fun grad(body: AD.() -> D): D {
    val context = ADImpl()
    val output = context.body()
    // seed the backward pass: the derivative of the result with respect to itself is 1
    output.d = 1.0
    context.runBackwardPass()
    return output
}
/**
 * Automatic Differentiation context class.
 *
 * Every operation defined here computes its result value immediately and calls [derive] to register
 * how the derivative of that result is propagated back to its operands (chain rule).
 */
abstract class AD {
    /**
     * Performs update of derivative after the rest of the formula in the back-pass.
     *
     * [block] receives [value] and is expected to push `value.d` back into the operands that
     * produced it; the operations below use the returned value as their result.
     *
     * For example, implementation of `sin` function is:
     *
     * ```
     * fun AD.sin(x: D): D = derive(D(sin(x.x))) { z -> // call derive with function result
     *     x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
     * }
     * ```
     */
    abstract fun <R> derive(value: R, block: (R) -> Unit): R

    // Unary minus: z = -x, so dz/dx = -1
    operator fun D.unaryMinus(): D = derive(D(-this.x)) { z ->
        this.d -= z.d
    }

    // Basic math (+, -, *, /)
    operator fun D.plus(that: D): D = derive(D(this.x + that.x)) { z ->
        this.d += z.d
        that.d += z.d
    }

    operator fun D.minus(that: D): D = derive(D(this.x - that.x)) { z ->
        this.d += z.d
        that.d -= z.d
    }

    // product rule: d(a*b) = b*da + a*db
    operator fun D.times(that: D): D = derive(D(this.x * that.x)) { z ->
        this.d += z.d * that.x
        that.d += z.d * this.x
    }

    // quotient rule: d(a/b)/da = 1/b, d(a/b)/db = -a/b^2
    operator fun D.div(that: D): D = derive(D(this.x / that.x)) { z ->
        this.d += z.d / that.x
        that.d -= z.d * this.x / (that.x * that.x)
    }

    // Overloads for Double constants (the constant operand has no derivative to update)
    operator fun Double.plus(that: D): D = derive(D(this + that.x)) { z ->
        that.d += z.d
    }

    operator fun D.plus(b: Double): D = b.plus(this)

    operator fun Double.minus(that: D): D = derive(D(this - that.x)) { z ->
        that.d -= z.d
    }

    operator fun D.minus(that: Double): D = derive(D(this.x - that)) { z ->
        this.d += z.d
    }

    operator fun Double.times(that: D): D = derive(D(this * that.x)) { z ->
        that.d += z.d * this
    }

    operator fun D.times(b: Double): D = b.times(this)

    // d(c/x)/dx = -c/x^2
    operator fun Double.div(that: D): D = derive(D(this / that.x)) { z ->
        that.d -= z.d * this / (that.x * that.x)
    }

    operator fun D.div(that: Double): D = derive(D(this.x / that)) { z ->
        this.d += z.d / that
    }

    // Overloads for Int constants: widen to Double and reuse the Double overloads
    operator fun Int.plus(b: D): D = toDouble().plus(b)
    operator fun D.plus(b: Int): D = plus(b.toDouble())
    operator fun Int.minus(b: D): D = toDouble().minus(b)
    operator fun D.minus(b: Int): D = minus(b.toDouble())
    operator fun Int.times(b: D): D = toDouble().times(b)
    operator fun D.times(b: Int): D = times(b.toDouble())
    operator fun Int.div(b: D): D = toDouble().div(b)
    operator fun D.div(b: Int): D = div(b.toDouble())
}
// ---------------------------------------- ENGINE IMPLEMENTATION ----------------------------------------
/**
 * Private tape-based implementation of [AD], instantiated by [grad].
 *
 * [derive] appends a (block, value) pair to a flat array; [runBackwardPass] pops the pairs newest
 * first and applies each block to the value it was registered with.
 */
private class ADImpl : AD() {
    // Flat tape laid out as [block0, value0, block1, value1, ...]; starts with capacity 8 and doubles
    // whenever the next pair would not fit.
    private var tape = arrayOfNulls<Any?>(8)
    private var top = 0 // number of occupied slots; always even because entries come in pairs

    override fun <R> derive(value: R, block: (R) -> Unit): R {
        val newTop = top + 2
        if (newTop > tape.size) tape = tape.copyOf(tape.size * 2)
        tape[top] = block
        tape[top + 1] = value
        top = newTop
        return value
    }

    /** Replays recorded steps in reverse registration order until the tape is exhausted. */
    @Suppress("UNCHECKED_CAST")
    fun runBackwardPass() {
        while (top > 0) {
            top -= 2
            val step = tape[top] as (Any?) -> Unit
            step(tape[top + 1])
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import kotlin.math.* | |
// Extensions for differentiation of various basic mathematical functions

/** Square, `x ^ 2`; derivative `2 * x`. */
fun AD.sqr(x: D): D {
    val square = D(x.x * x.x)
    return derive(square) { out ->
        x.d += out.d * 2 * x.x
    }
}
/** Square root, `x ^ 1/2`; derivative `1 / (2 * sqrt(x))`, computed from the forward result. */
fun AD.sqrt(x: D): D {
    val root = D(sqrt(x.x))
    return derive(root) { out ->
        x.d += out.d * 0.5 / out.x
    }
}
/**
 * Power with a constant exponent, `x ^ y`; derivative `y * x ^ (y - 1)`.
 *
 * When `y == 0.0` the result is the constant `1.0`, so there is nothing to propagate. The update is
 * skipped in that case because `0 * x ^ (-1)` would otherwise be `0 * Infinity = NaN` at `x == 0`.
 */
fun AD.pow(x: D, y: Double): D = derive(D(x.x.pow(y))) { z ->
    if (y != 0.0) {
        x.d += z.d * y * x.x.pow(y - 1)
    }
}
/** Integer-exponent convenience overload: widens [y] to `Double` and delegates to `pow(D, Double)`. */
fun AD.pow(x: D, y: Int): D {
    val exponent = y.toDouble()
    return pow(x, exponent)
}
/** Natural exponent, `exp(x)`; its derivative equals its own value, so the forward result is reused. */
fun AD.exp(x: D): D {
    val power = D(exp(x.x))
    return derive(power) { out ->
        x.d += out.d * out.x
    }
}
/** Natural logarithm, `ln(x)`; derivative `1 / x`. */
fun AD.ln(x: D): D {
    val logarithm = D(ln(x.x))
    return derive(logarithm) { out ->
        x.d += out.d / x.x
    }
}
/**
 * Power with a differentiable exponent, `x ^ y`.
 *
 * The value is computed directly with `Double.pow`, consistent with `pow(D, Double)`. The previous
 * `exp(y * ln(x))` formulation gave `NaN` for negative bases even with integer exponents, and was
 * less precise (`2 ^ 3` evaluated to `7.999999999999998`).
 *
 * Partial derivatives:
 *  - dz/dx = `y * x ^ (y - 1)`. Skipped when `y == 0`, where z is constant in x; this avoids
 *    `0 * Infinity = NaN` at `x == 0`.
 *  - dz/dy = `x ^ y * ln(x)`. Skipped when `z == 0`, i.e. `x == 0` with `y > 0`, where the derivative
 *    is 0; this avoids `0 * -Infinity = NaN`. For negative `x` it is still `NaN`, since `ln(x)` is
 *    undefined there.
 */
fun AD.pow(x: D, y: D): D = derive(D(x.x.pow(y.x))) { z ->
    if (y.x != 0.0) {
        x.d += z.d * y.x * x.x.pow(y.x - 1)
    }
    if (z.x != 0.0) {
        y.d += z.d * z.x * ln(x.x)
    }
}
/** Sine, `sin(x)`; derivative `cos(x)`. */
fun AD.sin(x: D): D {
    val sine = D(sin(x.x))
    return derive(sine) { out ->
        x.d += out.d * cos(x.x)
    }
}
/** Cosine, `cos(x)`; derivative `-sin(x)`. */
fun AD.cos(x: D): D {
    val cosine = D(cos(x.x))
    return derive(cosine) { out ->
        x.d -= out.d * sin(x.x)
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.junit.* | |
import kotlin.math.* | |
import kotlin.test.* | |
/** Unit tests for [grad] and the differentiable operations of [AD]. */
class ADTest {
    @Test
    fun testPlusX2() {
        val x = D(3) // diff w.r.t this x at 3
        val y = grad { x + x }
        assertEquals(6.0, y.x) // y = x + x = 6
        assertEquals(2.0, x.d) // dy/dx = 2
    }

    @Test
    fun testPlus() {
        // two variables
        val x = D(2)
        val y = D(3)
        val z = grad { x + y }
        assertEquals(5.0, z.x) // z = x + y = 5
        assertEquals(1.0, x.d) // dz/dx = 1
        assertEquals(1.0, y.d) // dz/dy = 1
    }

    @Test
    fun testMinus() {
        // two variables
        val x = D(7)
        val y = D(3)
        val z = grad { x - y }
        assertEquals(4.0, z.x) // z = x - y = 4
        assertEquals(1.0, x.d) // dz/dx = 1
        assertEquals(-1.0, y.d) // dz/dy = -1
    }

    @Test
    fun testMulX2() {
        val x = D(3) // diff w.r.t this x at 3
        val y = grad { x * x }
        assertEquals(9.0, y.x) // y = x * x = 9
        assertEquals(6.0, x.d) // dy/dx = 2 * x = 6
    }

    @Test
    fun testSqr() {
        val x = D(3)
        val y = grad { sqr(x) }
        assertEquals(9.0, y.x) // y = x ^ 2 = 9
        assertEquals(6.0, x.d) // dy/dx = 2 * x = 6
    }

    @Test
    fun testSqrSqr() {
        val x = D(2)
        val y = grad { sqr(sqr(x)) }
        assertEquals(16.0, y.x) // y = x ^ 4 = 16
        assertEquals(32.0, x.d) // dy/dx = 4 * x^3 = 32
    }

    @Test
    fun testX3() {
        val x = D(2) // diff w.r.t this x at 2
        val y = grad { x * x * x }
        assertEquals(8.0, y.x) // y = x * x * x = 8
        assertEquals(12.0, x.d) // dy/dx = 3 * x * x = 12
    }

    @Test
    fun testDiv() {
        val x = D(5)
        val y = D(2)
        val z = grad { x / y }
        assertEquals(2.5, z.x) // z = x / y = 2.5
        assertEquals(0.5, x.d) // dz/dx = 1 / y = 0.5
        assertEquals(-1.25, y.d) // dz/dy = -x / y^2 = -1.25
    }

    @Test
    fun testPow3() {
        val x = D(2) // diff w.r.t this x at 2
        val y = grad { pow(x, 3) }
        assertEquals(8.0, y.x) // y = x ^ 3 = 8
        assertEquals(12.0, x.d) // dy/dx = 3 * x ^ 2 = 12
    }

    @Test
    fun testPowFull() {
        val x = D(2)
        val y = D(3)
        val z = grad { pow(x, y) }
        assertApprox(8.0, z.x) // z = x ^ y = 8
        assertApprox(12.0, x.d) // dz/dx = y * x ^ (y - 1) = 12
        assertApprox(8.0 * ln(2.0), y.d) // dz/dy = x ^ y * ln(x)
    }

    @Test
    fun testFromPaper() {
        val x = D(3)
        val y = grad { 2 * x + x * x * x }
        assertEquals(33.0, y.x) // y = 2 * x + x * x * x = 33
        assertEquals(29.0, x.d) // dy/dx = 2 + 3 * x * x = 29
    }

    @Test
    fun testLongChain() {
        val n = 10_000
        val x = D(1)
        val y = grad {
            var pow = D(1)
            for (i in 1..n) pow *= x
            pow
        }
        assertEquals(1.0, y.x) // y = x ^ n = 1
        assertEquals(n.toDouble(), x.d) // dy/dx = n * x ^ (n - 1) = n
    }

    @Test
    fun testExample() {
        val x = D(2)
        val y = grad { sqr(x) + 5 * x + 3 }
        assertEquals(17.0, y.x) // the value of result (y)
        assertEquals(9.0, x.d) // dy/dx
    }

    @Test
    fun testSqrt() {
        val x = D(16)
        val y = grad { sqrt(x) }
        assertEquals(4.0, y.x) // y = x ^ 1/2 = 4
        assertEquals(1.0 / 8, x.d) // dy/dx = 1 / (2 * x ^ 1/2) = 1/8
    }

    @Test
    fun testSin() {
        val x = D(PI / 6)
        val y = grad { sin(x) }
        assertApprox(0.5, y.x) // y = sin(PI/6) = 0.5
        assertApprox(sqrt(3.0) / 2, x.d) // dy/dx = cos(PI/6) = sqrt(3)/2
    }

    @Test
    fun testCos() {
        val x = D(PI / 6)
        val y = grad { cos(x) }
        assertApprox(sqrt(3.0) / 2, y.x) // y = cos(PI/6) = sqrt(3)/2
        assertApprox(-0.5, x.d) // dy/dx = -sin(PI/6) = -0.5
    }

    /**
     * Asserts that [b] is within `1e-10` of [a], in either direction.
     *
     * The previous check `(a - b) > 1e-10` had no `abs`, so any [b] larger than [a] passed. A NaN [b]
     * also passed. Now a NaN difference fails the `<=` comparison and falls through to [assertEquals],
     * which reports both values.
     */
    private fun assertApprox(a: Double, b: Double) {
        if (!(abs(a - b) <= 1e-10)) assertEquals(a, b)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment