Skip to content

Instantly share code, notes, and snippets.

@naoyat
Created April 21, 2012 02:19
Show Gist options
  • Select an option

  • Save naoyat/2433302 to your computer and use it in GitHub Desktop.

Select an option

Save naoyat/2433302 to your computer and use it in GitHub Desktop.
AWKでK-means法
#
# k-means.awk
#
# Validate the cluster count K (supplied with -v K=...) and set the tuning
# knobs read by the END rule.
BEGIN {
    if (K < 2) {
        # ARGV[0] holds the interpreter name, so it goes in front of "-f";
        # the script name comes from this file's header comment.
        printf("usage: %s -f k-means.awk -v K=nnn datafile\n", ARGV[0]) > "/dev/stderr"
        exit 1
    }
    move_threshold = 0   # stop once at most this many points change cluster in one step
    step_max = 100       # upper bound on the number of iterations
    step_output = 0      # non-zero: dump per-step assignments to stepNNN.dat
}
# Per-record rule: load one data point per non-empty record.
# Each field becomes one coordinate; per-column sums are accumulated so the
# END rule can compute column means. Blank records are skipped so they do not
# turn into all-zero data points.
NF > 0 {
    nr++                        # number of data points stored so far
    if (NF > dim) dim = NF      # dimensionality = widest record seen
    for (d = 1; d <= NF; d++) {
        x[nr, d] = $d           # indexed by our own counter, which END iterates over
        sum[d] += $d
    }
}
# Standardize the loaded data column by column, run k-means with K clusters,
# and print every standardized point followed by its 1-based cluster index.
# Progress (per-column statistics, centroids of every step) goes to stderr.
END {
    # awk runs END rules even after exit in BEGIN, so re-check K here;
    # the usage message has already been printed in that case.
    if (K < 2) exit
    if (nr < 1) {
        printf("no input data\n") > "/dev/stderr"
        exit 1
    }

    # Per-column mean and population standard deviation, then z-score the data.
    # Note: var[d] holds the standard deviation (square root of the variance).
    for (d = 1; d <= dim; d++) {
        avr[d] = sum[d] / nr
        s = 0
        for (n = 1; n <= nr; n++) {
            diff = x[n,d] - avr[d]
            s += diff * diff
        }
        s /= nr
        var[d] = sqrt(s)
        printf("%d) avg:%g var:%g\n", d, avr[d], var[d]) > "/dev/stderr"
        for (n = 1; n <= nr; n++) {
            # A constant column has zero spread; map it to 0 instead of
            # dividing by zero.
            x0[n,d] = (var[d] > 0) ? (x[n,d] - avr[d]) / var[d] : 0
        }
    }

    # Initial centroids: spread deterministically around a circle of radius
    # 1.667, alternating cos/sin across the dimensions.
    for (k = 1; k <= K; k++) {
        th = 3.141592653589793 * (2 * (k - 1) / K - 1 / 5)
        for (d = 1; d <= dim; d++) {
            mu[k,d] = 1.667 * (d % 2 ? cos(th) : sin(th))
        }
    }

    # cls[n] is the cluster currently assigned to point n (0 = none yet).
    for (n = 1; n <= nr; n++) cls[n] = 0

    for (st = 1; st <= step_max; st++) {
        printf("STEP %d...\n", st) > "/dev/stderr"
        for (k = 1; k <= K; k++) {
            printf("μ%d:", k) > "/dev/stderr"
            for (d = 1; d <= dim; d++) {
                printf(" %g", mu[k,d]) > "/dev/stderr"
            }
            printf("\n") > "/dev/stderr"
        }

        # Assignment step: attach every point to its nearest centroid
        # (squared Euclidean distance; ties go to the lowest cluster index).
        for (k = 1; k <= K; k++) kn[k] = 0
        mov = 0
        for (n = 1; n <= nr; n++) {
            best = 0
            for (k = 1; k <= K; k++) {
                d2 = 0
                for (d = 1; d <= dim; d++) {
                    d_ = x0[n,d] - mu[k,d]
                    d2 += d_ * d_
                }
                # Seed the minimum from the first cluster rather than a
                # magic "large" constant.
                if (best == 0 || d2 < d2_min) {
                    d2_min = d2
                    best = k
                }
            }
            if (cls[n] != best) mov++    # count points that changed cluster
            cls[n] = best
            kn[best]++
        }

        # Converged: few enough points moved during this step.
        if (mov <= move_threshold) break

        if (step_output) {
            # Per-step dump (for plotting). Use ">" so a rerun overwrites
            # stale files instead of appending to them.
            outfile = sprintf("step%03d.dat", st)
            for (n = 1; n <= nr; n++) {
                for (d = 1; d <= dim; d++) {
                    printf("%.10f ", x0[n,d]) > outfile
                }
                printf("%d\n", cls[n]) > outfile
            }
            close(outfile)
        }

        # Update step: move each centroid to the mean of its members.
        for (k = 1; k <= K; k++) {
            for (d = 1; d <= dim; d++) acc[k,d] = 0
        }
        for (n = 1; n <= nr; n++) {
            k = cls[n]
            for (d = 1; d <= dim; d++) acc[k,d] += x0[n,d]
        }
        for (k = 1; k <= K; k++) {
            # An empty cluster keeps its previous centroid; dividing by a
            # zero member count would abort the script.
            if (kn[k] == 0) continue
            for (d = 1; d <= dim; d++) mu[k,d] = acc[k,d] / kn[k]
        }
    }

    # Final result: standardized coordinates followed by the cluster index.
    for (n = 1; n <= nr; n++) {
        for (d = 1; d <= dim; d++) {
            printf("%.10f ", x0[n,d])
        }
        printf("%d\n", cls[n])
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment