Created
October 8, 2016 16:25
-
-
Save r9y9/9d2ebe95aa7cd47ea57c64b3be4eb4c1 to your computer and use it in GitHub Desktop.
kotlin入門二日目
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /** | |
| * Created by ryuyamamoto on 2016/10/07. | |
| */ | |
| import java.util.Random | |
/**
 * Returns the Euclidean (L2) distance between two points.
 *
 * The components are taken over the indices of [x1]; [x2] must provide at least
 * as many components (any extra trailing components of [x2] are ignored).
 */
fun computeDistance(x1: DoubleArray, x2: DoubleArray): Double {
    // Accumulate the squared per-dimension differences, left to right.
    val squaredSum = x1.foldIndexed(0.0) { d, acc, a ->
        val delta = a - x2[d]
        acc + delta * delta
    }
    return Math.sqrt(squaredSum)
}
/**
 * Minimal k-means clustering (Lloyd's algorithm) over points stored as `DoubleArray`s.
 *
 * @property K number of clusters.
 * @property D number of dimensions of every point and centroid.
 * @param random source of randomness for the initial assignment; pass a seeded
 * instance to make [fit] reproducible.
 */
class KMeans(val K: Int = 2, val D: Int = 2, private val random: Random = Random()) {
    /** Current centroids: [K] arrays of length [D], all zeros until [fit] or [updateClusters] runs. */
    var clusters: Array<DoubleArray> = Array(K) { DoubleArray(D) }

    /**
     * Clusters [X] and returns the resulting centroids (the same array as [clusters]).
     *
     * Every point is first assigned to a random cluster and the centroids are computed
     * from that partition; then exactly [nIter] rounds of reassign-and-recompute are run.
     *
     * @param X points to cluster; each must have exactly [D] components.
     * @param nIter number of refinement rounds; with a value <= 0 the centroids of the
     * random initial partition are returned as-is.
     * @return the centroids after the last round.
     * @throws IllegalArgumentException if a point does not have [D] components.
     */
    fun fit(X: Array<DoubleArray>, nIter: Int = 20): Array<DoubleArray> {
        require(X.all { it.size == D }) { "every point must have exactly $D dimensions" }

        // Seed the centroids from a random partition. Without this step the first
        // reassignment would run against all-zero centroids and the random
        // assignment would be thrown away unused.
        updateClusters(X, IntArray(X.size) { random.nextInt(K) })

        // repeat() runs exactly nIter times (an inclusive 0..nIter range would run one extra round).
        repeat(nIter) {
            updateClusters(X, updateAssignment(X))
        }
        return clusters
    }

    /**
     * Assigns every point of [X] to its nearest centroid in [clusters].
     *
     * Ties go to the lowest cluster index because only a strictly smaller
     * distance replaces the current best.
     *
     * @return an array with one cluster index per point, in the order of [X].
     */
    fun updateAssignment(X: Array<DoubleArray>): IntArray {
        return IntArray(X.size) { n ->
            var nearest = 0
            var nearestDistance = computeDistance(X[n], clusters[0])
            for (k in 1 until K) {
                val distance = computeDistance(X[n], clusters[k])
                if (distance < nearestDistance) {
                    nearest = k
                    nearestDistance = distance
                }
            }
            nearest
        }
    }

    /**
     * Recomputes every centroid in [clusters] as the per-dimension mean of the
     * points of [X] that [which] assigns to it.
     *
     * @param X the points.
     * @param which cluster index of each point, parallel to [X].
     */
    fun updateClusters(X: Array<DoubleArray>, which: IntArray) {
        for (k in 0 until K) {
            val centroid = clusters[k]
            for (d in centroid.indices) {
                var sum = 0.0
                var count = 0
                for (n in X.indices) {
                    if (which[n] == k) {
                        sum += X[n][d]
                        count++
                    }
                }
                // A cluster without members has no mean; it is reset to 0.0 in every dimension.
                centroid[d] = if (count > 0) sum / count else 0.0
            }
        }
    }
}
/**
 * Returns the fixed 40-point, 2-dimensional demo data set as a freshly allocated array.
 *
 * The data was generated with Julia:
 *
 * ```julia
 * using Distributions
 * rand(MvNormal([0.1,0.1], 0.01*diagm(ones(2))), 20)'   # points 0..19
 * rand(MvNormal([0.7,0.7], 0.01*diagm(ones(2))), 20)'   # points 20..39
 * ```
 */
fun testData(): Array<DoubleArray> {
    return arrayOf(
        // Points 0..19: sampled around (0.1, 0.1).
        doubleArrayOf(0.270645451904348, 0.09905485660153546),
        doubleArrayOf(0.2104697439241483, 0.16746008309261867),
        doubleArrayOf(0.053233805374506664, 0.019565142995862733),
        doubleArrayOf(0.23914042222182508, 0.15154760542951273),
        doubleArrayOf(0.15086637491060237, 0.10696803039238002),
        doubleArrayOf(0.14667883579929916, 0.32023563331239036),
        doubleArrayOf(0.061757222008084284, 0.29952357572276933),
        doubleArrayOf(0.0033625433674295824, 0.2505102692861),
        doubleArrayOf(0.1292375140527806, 0.11038326403355014),
        doubleArrayOf(0.07580825553736137, 0.17796575592847452),
        doubleArrayOf(0.08216783649857799, -0.05122766295330777),
        doubleArrayOf(0.07297415944149035, 0.07728231179802257),
        doubleArrayOf(0.03555578224194941, 0.19324403074621316),
        doubleArrayOf(0.02946701952300422, 0.25274127350355263),
        doubleArrayOf(0.20820693183274155, -0.029277424938451313),
        doubleArrayOf(0.17671736644333297, 0.06321603653433891),
        doubleArrayOf(0.1961268628983909, 0.16643321756888052),
        doubleArrayOf(0.18448503279998468, 0.01492330736414664),
        doubleArrayOf(-0.06606722384471866, 0.09628369943617718),
        doubleArrayOf(0.1546217088548203, 0.13177746616761643),
        // Points 20..39: sampled around (0.7, 0.7).
        doubleArrayOf(0.7062974855486346, 0.7849391731810692),
        doubleArrayOf(0.6294004718034438, 0.6267958484486116),
        doubleArrayOf(0.8930307321940418, 0.7784391568542095),
        doubleArrayOf(0.8132990505641942, 0.7271174530990139),
        doubleArrayOf(0.7441097143741126, 0.6425141324100608),
        doubleArrayOf(0.7344774202499925, 0.6514385410929594),
        doubleArrayOf(0.6615820715998231, 0.7192607187447675),
        doubleArrayOf(0.7674882060540668, 0.6152064954947855),
        doubleArrayOf(0.7508106501879178, 0.5936021314001545),
        doubleArrayOf(0.7227883314353798, 0.7203576123836959),
        doubleArrayOf(0.7313356536618718, 0.6980369810204339),
        doubleArrayOf(0.5583637898393485, 0.7857518116490675),
        doubleArrayOf(0.7086118413038008, 0.7537309570691416),
        doubleArrayOf(0.7911614658292188, 0.6860840993177828),
        doubleArrayOf(0.6989695736881008, 0.6338877715486645),
        doubleArrayOf(0.7209966578931128, 0.790136317251175),
        doubleArrayOf(0.735507692852464, 0.7080924211231643),
        doubleArrayOf(0.6877643598885169, 0.6831662385191429),
        doubleArrayOf(0.6672221952385339, 0.831361584200396),
        doubleArrayOf(0.7293546576410929, 0.793395336562019)
    )
}
/**
 * Demo entry point: fits a 2-cluster, 2-dimensional [KMeans] model on [testData]
 * and prints one line per resulting centroid.
 *
 * @param args command-line arguments (unused).
 */
fun main(args: Array<String>) {
    val K = 2
    val D = 2
    val data = testData()
    val km = KMeans(K, D)
    // Run the clustering.
    val clusters = km.fit(data, nIter = 20)
    clusters.forEachIndexed { i, centroid ->
        // toList() renders the coordinates as "[x, y]"; a DoubleArray's own toString()
        // does not print its contents. Unlike a `filter { it == it }` conversion, this
        // keeps every coordinate, so a NaN would be shown instead of silently dropped.
        println("Cluster #$i: ${centroid.toList()}")
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment