Skip to content

Instantly share code, notes, and snippets.

@r9y9
Created October 8, 2016 16:25
Show Gist options
  • Select an option

  • Save r9y9/9d2ebe95aa7cd47ea57c64b3be4eb4c1 to your computer and use it in GitHub Desktop.

Select an option

Save r9y9/9d2ebe95aa7cd47ea57c64b3be4eb4c1 to your computer and use it in GitHub Desktop.
kotlin入門二日目
/**
* Created by ryuyamamoto on 2016/10/07.
*/
import java.util.Random
/**
 * Returns the Euclidean (L2) distance between [x1] and [x2].
 *
 * Only the indices of [x1] are visited: trailing elements of a longer [x2] are
 * ignored, and a shorter [x2] makes the element access fail with an
 * out-of-bounds exception.
 *
 * @param x1 first point; its length decides how many components are compared.
 * @param x2 second point.
 * @return square root of the summed squared component differences (0.0 when [x1] is empty).
 */
fun computeDistance(x1: DoubleArray, x2: DoubleArray): Double {
    // Accumulate left to right, starting from 0.0.
    val squaredSum = x1.indices.fold(0.0) { acc, i ->
        val diff = x1[i] - x2[i]
        acc + diff * diff
    }
    return Math.sqrt(squaredSum)
}
/**
 * Minimal k-means clusterer (Lloyd-style alternation of assignment and centroid update)
 * over points stored as `DoubleArray`s.
 *
 * @property K number of clusters.
 * @property D number of components stored per centroid.
 * @property clusters current centroids, one array per cluster; all zeros until
 *   [fit] or [updateClusters] has been called.
 */
class KMeans(val K: Int = 2, val D: Int = 2) {
    var clusters: Array<DoubleArray> = Array(K) { DoubleArray(D) }
    private val random = Random()

    /**
     * Clusters [X] and returns the resulting centroids (the same array as [clusters]).
     *
     * The centroids are first seeded from a random partition of the data, then
     * exactly [nIter] assignment/update passes are run. A non-positive [nIter]
     * returns the seeded centroids unchanged.
     *
     * @param X data points; each row is expected to have at least as many components
     *   as a centroid.
     * @param nIter number of assignment/update passes.
     * @return the centroids after the last pass.
     * @throws IllegalArgumentException if [K] is not positive while [X] is non-empty.
     */
    fun fit(X: Array<DoubleArray>, nIter: Int = 20): Array<DoubleArray> {
        require(K > 0 || X.isEmpty()) { "K must be positive to cluster a non-empty data set" }
        // The initial random partition must actually feed the centroids; otherwise the
        // first assignment runs against all-zero centroids and every point ties into
        // cluster 0.
        updateClusters(X, randomAssignment(X.size))
        repeat(nIter) {
            updateClusters(X, updateAssignment(X))
        }
        return clusters
    }

    /**
     * Assigns every point of [X] to its nearest centroid.
     *
     * Ties are resolved in favour of the lowest cluster index.
     *
     * @return an array holding, for each row of [X], the index of the chosen cluster.
     */
    fun updateAssignment(X: Array<DoubleArray>): IntArray =
        IntArray(X.size) { n -> nearestCluster(X[n]) }

    /**
     * Recomputes every centroid as the mean of the points of [X] labelled with it in [which].
     *
     * A cluster without any point is reset to all zeros. Labels outside `0 until K`
     * are ignored.
     *
     * @param X data points.
     * @param which cluster index for each row of [X].
     */
    fun updateClusters(X: Array<DoubleArray>, which: IntArray) {
        val sums = Array(K) { k -> DoubleArray(clusters[k].size) }
        val counts = IntArray(K)
        // Single pass over the data; per-component sums are still accumulated in row order.
        for (n in X.indices) {
            val k = which[n]
            if (k < 0 || k >= K) continue
            counts[k]++
            val sum = sums[k]
            for (d in sum.indices) {
                sum[d] += X[n][d]
            }
        }
        for (k in 0 until K) {
            val centroid = clusters[k]
            for (d in centroid.indices) {
                centroid[d] = if (counts[k] > 0) sums[k][d] / counts[k] else 0.0
            }
        }
    }

    /**
     * Builds a random labelling of [n] points: labels are dealt round-robin and then
     * shuffled, so every cluster receives at least one point whenever `n >= K`.
     */
    private fun randomAssignment(n: Int): IntArray {
        val labels = IntArray(n) { it % K }
        // Fisher-Yates shuffle driven by the instance's random generator.
        for (i in n - 1 downTo 1) {
            val j = random.nextInt(i + 1)
            val tmp = labels[i]
            labels[i] = labels[j]
            labels[j] = tmp
        }
        return labels
    }

    /** Index of the centroid closest to [point]; the first one wins on ties. */
    private fun nearestCluster(point: DoubleArray): Int {
        var best = 0
        var bestDistance = squaredDistance(point, clusters[0])
        for (k in 1 until K) {
            val distance = squaredDistance(point, clusters[k])
            if (distance < bestDistance) {
                best = k
                bestDistance = distance
            }
        }
        return best
    }

    /**
     * Squared Euclidean distance over the indices of [a]. The square root is skipped
     * because it does not change which centroid is nearest.
     */
    private fun squaredDistance(a: DoubleArray, b: DoubleArray): Double {
        var total = 0.0
        for (d in a.indices) {
            val diff = a[d] - b[d]
            total += diff * diff
        }
        return total
    }
}
/**
 * Returns a fixed sample data set of 40 two-dimensional points.
 *
 * According to the original author's notes the values were generated with Julia:
 * rows 0-19 with `rand(MvNormal([0.1,0.1], 0.01*diagm(ones(2))), 20)'` and
 * rows 20-39 with `rand(MvNormal([0.7,0.7], 0.01*diagm(ones(2))), 20)'`
 * (both `using Distributions`).
 *
 * @return a freshly allocated array of 40 rows with 2 components each.
 */
fun testData(): Array<DoubleArray> = arrayOf(
    // Rows 0-19: first generated group.
    doubleArrayOf(0.270645451904348, 0.09905485660153546),
    doubleArrayOf(0.2104697439241483, 0.16746008309261867),
    doubleArrayOf(0.053233805374506664, 0.019565142995862733),
    doubleArrayOf(0.23914042222182508, 0.15154760542951273),
    doubleArrayOf(0.15086637491060237, 0.10696803039238002),
    doubleArrayOf(0.14667883579929916, 0.32023563331239036),
    doubleArrayOf(0.061757222008084284, 0.29952357572276933),
    doubleArrayOf(0.0033625433674295824, 0.2505102692861),
    doubleArrayOf(0.1292375140527806, 0.11038326403355014),
    doubleArrayOf(0.07580825553736137, 0.17796575592847452),
    doubleArrayOf(0.08216783649857799, -0.05122766295330777),
    doubleArrayOf(0.07297415944149035, 0.07728231179802257),
    doubleArrayOf(0.03555578224194941, 0.19324403074621316),
    doubleArrayOf(0.02946701952300422, 0.25274127350355263),
    doubleArrayOf(0.20820693183274155, -0.029277424938451313),
    doubleArrayOf(0.17671736644333297, 0.06321603653433891),
    doubleArrayOf(0.1961268628983909, 0.16643321756888052),
    doubleArrayOf(0.18448503279998468, 0.01492330736414664),
    doubleArrayOf(-0.06606722384471866, 0.09628369943617718),
    doubleArrayOf(0.1546217088548203, 0.13177746616761643),
    // Rows 20-39: second generated group.
    doubleArrayOf(0.7062974855486346, 0.7849391731810692),
    doubleArrayOf(0.6294004718034438, 0.6267958484486116),
    doubleArrayOf(0.8930307321940418, 0.7784391568542095),
    doubleArrayOf(0.8132990505641942, 0.7271174530990139),
    doubleArrayOf(0.7441097143741126, 0.6425141324100608),
    doubleArrayOf(0.7344774202499925, 0.6514385410929594),
    doubleArrayOf(0.6615820715998231, 0.7192607187447675),
    doubleArrayOf(0.7674882060540668, 0.6152064954947855),
    doubleArrayOf(0.7508106501879178, 0.5936021314001545),
    doubleArrayOf(0.7227883314353798, 0.7203576123836959),
    doubleArrayOf(0.7313356536618718, 0.6980369810204339),
    doubleArrayOf(0.5583637898393485, 0.7857518116490675),
    doubleArrayOf(0.7086118413038008, 0.7537309570691416),
    doubleArrayOf(0.7911614658292188, 0.6860840993177828),
    doubleArrayOf(0.6989695736881008, 0.6338877715486645),
    doubleArrayOf(0.7209966578931128, 0.790136317251175),
    doubleArrayOf(0.735507692852464, 0.7080924211231643),
    doubleArrayOf(0.6877643598885169, 0.6831662385191429),
    doubleArrayOf(0.6672221952385339, 0.831361584200396),
    doubleArrayOf(0.7293546576410929, 0.793395336562019)
)
/**
 * Demo entry point: clusters the sample data into two groups and prints one line
 * per resulting centroid.
 *
 * @param args command-line arguments (not used).
 */
fun main(args: Array<String>) {
    val K = 2
    val D = 2
    val model = KMeans(K, D)
    // Run the clustering on the bundled sample points.
    val centroids = model.fit(testData(), nIter = 20)
    centroids.forEachIndexed { index, centroid ->
        // Dropping NaN components also turns the array into a List, which prints readably.
        val printable = centroid.filter { !it.isNaN() }
        println("Cluster #${index}: ${printable}")
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment