Last active
September 5, 2019 12:36
-
-
Save kmdupr33/8f34c8e1b8c256cde508b19ecbed0816 to your computer and use it in GitHub Desktop.
Gradient descent in Kotlin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Outline of the fitting entry point: takes an array of `Pair<Double, Double>`
 * samples and returns the pair `(m, b)`.
 *
 * The body is intentionally elided in this snippet (`//...`), so `m` and `b`
 * are not defined here.
 * NOTE(review): presumably `m` and `b` are the slope and intercept of a fitted
 * line — confirm against the complete implementation.
 */
fun findLearningParameters(videoGameData: Array<Pair<Double, Double>>): Pair<Double, Double> { | |
//... | |
return Pair(m, b) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.squareup.moshi.JsonAdapter | |
import com.squareup.moshi.Moshi | |
import com.squareup.moshi.Types | |
import org.junit.Test | |
import kotlin.math.pow | |
/**
 * Fits a line `y = m * x + b` to (x, y) samples using batch gradient descent
 * on a halved mean-squared-error cost, and exercises the fit from a JUnit test
 * that loads its samples from the `/data.json` classpath resource.
 */
class Tests {

    /**
     * Performs one gradient-descent step.
     *
     * The gradients are the partial derivatives of [cost] with respect to `m`
     * and `b`: the mean of `residual * x` and the mean of `residual`, where
     * `residual = (m * x + b) - y`.
     *
     * @param guess current `(m, b)` estimate.
     * @param learningRate factor applied to each gradient before subtracting it.
     * @param videoGameData the (x, y) samples; an empty array yields `NaN`
     *   components because the sums are divided by a size of zero.
     * @return the updated `(m, b)` estimate.
     */
    fun updateGuess(
        guess: Pair<Double, Double>,
        learningRate: Double,
        videoGameData: Array<Pair<Double, Double>>
    ): Pair<Double, Double> {
        val (m, b) = guess
        var mSum = 0.0
        var bSum = 0.0
        // Single pass: both gradients share the same residual per sample.
        for ((x, y) in videoGameData) {
            val residual = ((m * x) + b) - y
            mSum += residual * x
            bSum += residual
        }
        val mGradient = mSum / videoGameData.size
        val bGradient = bSum / videoGameData.size
        return Pair(m - (learningRate * mGradient), b - (learningRate * bGradient))
    }

    /**
     * Computes the halved mean squared error of the line described by
     * [learningParameters] over [videoGameData].
     *
     * @param videoGameData the (x, y) samples.
     * @param learningParameters the `(m, b)` pair to evaluate.
     * @return the mean over all samples of `((m * x + b) - y)^2 / 2`.
     */
    fun cost(videoGameData: Array<Pair<Double, Double>>, learningParameters: Pair<Double, Double>): Double {
        // The parameters do not change per sample, so destructure them once.
        val (m, b) = learningParameters
        val total = videoGameData.fold(0.0) { acc, (x, y) ->
            acc + ((((m * x) + b) - y).pow(2) / 2)
        }
        return total / videoGameData.size
    }

    /**
     * Runs gradient descent from the initial guess `(0.0, 0.0)` and returns the
     * final `(m, b)` estimate, printing the current guess and its cost every
     * [logEvery] iterations.
     *
     * @param videoGameData the (x, y) samples; must not be empty.
     * @param learningRate step size passed to [updateGuess].
     * @param iterations number of update steps to perform.
     * @param logEvery progress is printed whenever the iteration index is a
     *   multiple of this value; must be positive.
     * @return the `(m, b)` estimate after the last iteration.
     * @throws IllegalArgumentException if [videoGameData] is empty or
     *   [logEvery] is not positive.
     */
    fun findLearningParameters(
        videoGameData: Array<Pair<Double, Double>>,
        learningRate: Double = .0003,
        iterations: Int = 10000000,
        logEvery: Int = 1000
    ): Pair<Double, Double> {
        // Without samples every gradient is NaN; fail fast instead of iterating on NaN.
        require(videoGameData.isNotEmpty()) { "videoGameData must contain at least one sample" }
        require(logEvery > 0) { "logEvery must be positive, was $logEvery" }

        var guess = Pair(0.0, 0.0)
        for (i in 1..iterations) {
            guess = updateGuess(guess, learningRate, videoGameData)
            if (i % logEvery == 0) {
                println("Guess: $guess")
                println("Cost: ${cost(videoGameData, guess)}")
            }
        }
        return guess
    }

    /** One (x, y) sample as it is deserialized from the JSON resource. */
    class Data(val x: Double, val y: Double)

    /**
     * Loads the samples from `/data.json`, fits the line, and prints the result.
     *
     * NOTE(review): this test makes no assertions on the fitted parameters.
     */
    @Test
    fun name() {
        // getResource returns null when the resource is missing; report that explicitly.
        val resource = checkNotNull(Tests::class.java.getResource("/data.json")) {
            "Test resource /data.json was not found on the classpath"
        }
        val fileContent = resource.readText()
        val moshi = Moshi.Builder().build()
        val type = Types.newParameterizedType(List::class.java, Data::class.java)
        val adapter: JsonAdapter<List<Data>> = moshi.adapter(type)
        val samples = checkNotNull(adapter.fromJson(fileContent)) {
            "/data.json did not deserialize to a list of samples"
        }
        val data = samples.map { Pair(it.x, it.y) }.toTypedArray()
        val learningParameters = findLearningParameters(data)
        println(learningParameters)
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Performs one gradient-descent step for the line `y = m * x + b`.
 *
 * For every sample the residual `(m * x + b) - y` is computed; the `m`
 * gradient is the mean of `residual * x` and the `b` gradient is the mean of
 * the residual. Each parameter is then moved against its gradient, scaled by
 * [learningRate].
 *
 * @param guess current `(m, b)` estimate.
 * @param learningRate factor applied to each gradient before subtracting it.
 * @param videoGameData the (x, y) samples; an empty array yields `NaN`
 *   components because the sums are divided by a size of zero.
 * @return the updated `(m, b)` estimate.
 */
fun updateGuess(
    guess: Pair<Double, Double>,
    learningRate: Double,
    videoGameData: Array<Pair<Double, Double>>
): Pair<Double, Double> {
    val slope = guess.first
    val intercept = guess.second

    var slopeSum = 0.0
    var interceptSum = 0.0
    for (sample in videoGameData) {
        val x = sample.first
        val y = sample.second
        val residual = ((slope * x) + intercept) - y
        slopeSum += residual * x
        interceptSum += residual
    }

    val sampleCount = videoGameData.size
    val slopeStep = learningRate * (slopeSum / sampleCount)
    val interceptStep = learningRate * (interceptSum / sampleCount)
    return Pair(slope - slopeStep, intercept - interceptStep)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment