Skip to content

Instantly share code, notes, and snippets.

@kmdupr33
Last active September 5, 2019 12:36
Show Gist options
  • Save kmdupr33/8f34c8e1b8c256cde508b19ecbed0816 to your computer and use it in GitHub Desktop.
Save kmdupr33/8f34c8e1b8c256cde508b19ecbed0816 to your computer and use it in GitHub Desktop.
Gradient descent in Kotlin
/**
 * Signature preview of `findLearningParameters`: takes `(x, y)` pairs and returns a `Pair` built from `m` and `b`.
 *
 * NOTE(review): the body is elided (`//...`) in this excerpt, and `m` and `b` are not declared anywhere in these
 * lines, so this snippet does not compile as shown. It appears to be a preview of the fuller
 * `Tests.findLearningParameters` (presumably `m` = slope and `b` = intercept of a fitted line) — confirm before
 * relying on it.
 *
 * @param videoGameData data points as `(x, y)` pairs.
 * @return `Pair(m, b)`.
 */
fun findLearningParameters(videoGameData: Array<Pair<Double, Double>>): Pair<Double, Double> {
//...
return Pair(m, b)
}
import com.squareup.moshi.JsonAdapter
import com.squareup.moshi.Moshi
import com.squareup.moshi.Types
import org.junit.Test
import kotlin.math.pow
class Tests {
    /**
     * Performs one step of batch gradient descent for the line `y = m * x + b`.
     *
     * The gradients are those of [cost] (mean of half squared residuals):
     * `dm = mean((m*x + b - y) * x)` and `db = mean(m*x + b - y)`.
     *
     * @param guess current `(m, b)`.
     * @param learningRate step size applied to both gradients.
     * @param videoGameData `(x, y)` points; must not be empty.
     * @return the updated `(m, b)`.
     * @throws IllegalArgumentException if [videoGameData] is empty (the mean gradient is undefined;
     *   without this guard the result would silently be `(NaN, NaN)`).
     */
    fun updateGuess(
        guess: Pair<Double, Double>,
        learningRate: Double,
        videoGameData: Array<Pair<Double, Double>>
    ): Pair<Double, Double> {
        require(videoGameData.isNotEmpty()) { "videoGameData must contain at least one point" }
        val (m, b) = guess
        // Accumulate both gradient sums in a single pass over the data.
        var mGradientSum = 0.0
        var bGradientSum = 0.0
        for ((x, y) in videoGameData) {
            val residual = (m * x + b) - y
            mGradientSum += residual * x
            bGradientSum += residual
        }
        val mGradient = mGradientSum / videoGameData.size
        val bGradient = bGradientSum / videoGameData.size
        return Pair(m - learningRate * mGradient, b - learningRate * bGradient)
    }

    /**
     * Mean over [videoGameData] of `(m*x + b - y)^2 / 2`, where `(m, b) = learningParameters`.
     *
     * @throws IllegalArgumentException if [videoGameData] is empty.
     */
    fun cost(videoGameData: Array<Pair<Double, Double>>, learningParameters: Pair<Double, Double>): Double {
        require(videoGameData.isNotEmpty()) { "videoGameData must contain at least one point" }
        // Destructure once instead of on every element of the fold.
        val (m, b) = learningParameters
        val total = videoGameData.fold(0.0) { acc, (x, y) -> acc + ((m * x + b) - y).pow(2) / 2 }
        return total / videoGameData.size
    }

    /**
     * Runs [iterations] steps of [updateGuess] starting from `(0.0, 0.0)` and returns the final `(m, b)`.
     *
     * Every [logEvery] iterations the current guess and its [cost] are printed.
     *
     * @param videoGameData `(x, y)` points; must not be empty when [iterations] > 0.
     * @param learningRate step size passed to [updateGuess].
     * @param iterations number of descent steps; must be non-negative.
     * @param logEvery progress-print interval in iterations; must be positive.
     */
    fun findLearningParameters(
        videoGameData: Array<Pair<Double, Double>>,
        learningRate: Double = 0.0003,
        iterations: Int = 10_000_000,
        logEvery: Int = 1000
    ): Pair<Double, Double> {
        require(iterations >= 0) { "iterations must be non-negative, was $iterations" }
        require(logEvery > 0) { "logEvery must be positive, was $logEvery" }
        var guess = Pair(0.0, 0.0)
        for (i in 1..iterations) {
            guess = updateGuess(guess, learningRate, videoGameData)
            if (i % logEvery == 0) {
                println("Guess: $guess")
                println("Cost: ${cost(videoGameData, guess)}")
            }
        }
        return guess
    }

    /** One `(x, y)` record decoded by Moshi from `data.json`. */
    class Data(val x: Double, val y: Double)

    @Test
    fun name() {
        val resource = requireNotNull(Tests::class.java.getResource("/data.json")) {
            "data.json not found on the test classpath"
        }
        val fileContent = resource.readText()
        val moshi = Moshi.Builder().build()
        val type = Types.newParameterizedType(List::class.java, Data::class.java)
        val adapter: JsonAdapter<List<Data>> = moshi.adapter(type)
        val points = checkNotNull(adapter.fromJson(fileContent)) { "data.json decoded to null" }
        val data = points.map { Pair(it.x, it.y) }.toTypedArray()
        val learningParameters = findLearningParameters(data)
        println(learningParameters)
        // Guard against silent divergence: a NaN/Infinite result makes this comparison false.
        check(cost(data, learningParameters) <= cost(data, Pair(0.0, 0.0))) {
            "gradient descent did not reduce the cost: $learningParameters"
        }
    }
}
/**
 * Performs one step of batch gradient descent for the line `y = m * x + b`.
 *
 * The gradients are those of the mean half-squared-residual cost:
 * `dm = mean((m*x + b - y) * x)` and `db = mean(m*x + b - y)`.
 *
 * @param guess current `(m, b)`.
 * @param learningRate step size applied to both gradients.
 * @param videoGameData `(x, y)` points; must not be empty.
 * @return the updated `(m, b)`.
 * @throws IllegalArgumentException if [videoGameData] is empty (the mean gradient is undefined;
 *   without this guard the result would silently be `(NaN, NaN)`).
 */
fun updateGuess(
    guess: Pair<Double, Double>,
    learningRate: Double,
    videoGameData: Array<Pair<Double, Double>>
): Pair<Double, Double> {
    require(videoGameData.isNotEmpty()) { "videoGameData must contain at least one point" }
    val (m, b) = guess
    // Accumulate both gradient sums in a single pass over the data.
    var mGradientSum = 0.0
    var bGradientSum = 0.0
    for ((x, y) in videoGameData) {
        val residual = (m * x + b) - y
        mGradientSum += residual * x
        bGradientSum += residual
    }
    val mGradient = mGradientSum / videoGameData.size
    val bGradient = bGradientSum / videoGameData.size
    return Pair(m - learningRate * mGradient, b - learningRate * bGradient)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment