Last active
August 29, 2015 14:21
-
-
Save hnagata/549b68a1b6e2a1060c5e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## grid / ggplot ---- | |
library(ggplot2) | |
library(grid) | |
make.grid <- function(row, col) {
  # Clear the current device page and push a viewport whose layout is a
  # `row` x `col` grid; print.at() then draws into individual cells and
  # end.grid() pops the viewport again.
  grid.newpage()
  pushViewport(viewport(layout = grid.layout(row, col)))
}
print.at <- function(o, i, j) {
  # Print object `o` (e.g. a ggplot) into cell (i, j) of the layout pushed
  # by make.grid().
  # NOTE(review): the dotted name reads like an S3 print() method for a
  # class "at"; a snake_case name would avoid accidental dispatch.
  cell <- viewport(layout.pos.row = i, layout.pos.col = j)
  print(o, vp = cell)
}
end.grid <- function() {
  # Pop the single viewport pushed by make.grid().
  popViewport(n = 1)
}
## Load data ----
dat <- read.csv("user.csv", fileEncoding = "utf-8")
# The rest of the script reads levels() of `user` and `requestNo` and indexes
# matrices by their integer codes, so both must be factors. read.csv() stopped
# converting strings to factors by default in R 4.0, so convert explicitly
# (a no-op when they are already factors).
dat$user <- factor(dat$user)
dat$requestNo <- factor(dat$requestNo)
elemname <- c(
  "chartInterval", "chartStability", "chartExpressiveness",
  "chartVibratoLongtone", "chartRhythm"
)
p <- length(elemname)
lv.user <- levels(dat$user)
lv.reqno <- levels(dat$requestNo)
n.user <- length(lv.user)
n.reqno <- length(lv.reqno)
# Keep only the row with the highest totalPoint for each (song, user) pair.
# NOTE(review): the pair key is pasted without a separator, so two distinct
# pairs could in principle map to the same key (e.g. "12" + "3x" and
# "1" + "23x"). It is left unchanged because the resulting row order feeds
# the seeded train/test split below.
filt.idx <- tapply(
  seq_len(nrow(dat)),
  factor(paste0(dat$requestNo, dat$user)),
  function(idx) idx[which.max(dat[idx, "totalPoint"])]
)
dat <- dat[filt.idx, ]
# Convert pitch names to numbers
replace <- function(reptb, x) {
  # Apply every (pattern, replacement) row of `reptb` to the character
  # vector `x`, in row order, as a literal substring substitution.
  #
  # Args:
  #   reptb: data frame; column 1 holds the strings to find, column 2 the
  #     values to substitute (coerced to character by gsub()).
  #   x: vector to rewrite.
  # Returns the rewritten character vector. Earlier rows take precedence, so
  # a name must be listed before any shorter name it contains.
  #
  # NOTE: this masks base::replace() for the rest of the script.
  for (i in seq_len(nrow(reptb))) {
    # fixed = TRUE: the patterns are literal names, not regular expressions.
    x <- gsub(reptb[i, 1], reptb[i, 2], x, fixed = TRUE)
  }
  x
}
# Normalise the flat sign ("♭" -> "b") in both pitch columns.
dat$highPitch <- gsub("♭", "b", dat$highPitch)
dat$lowPitch <- gsub("♭", "b", dat$lowPitch)
# Lookup table: pitch name -> number. Names are built as octave prefix +
# note, each octave numbered from its own starting value. replace() applies
# rows in order, so the longer prefixes ("hihi") come before the names they
# contain ("hi", bare note names).
reptb.pitch <- local({
  notes <- c("Ab", "A", "Bb", "B", "C", "Db", "D", "Eb", "E", "F", "Gb", "G")
  prefixes <- c("low", "m1", "m2", "hihi", "hi", "")
  starts <- c(32L, 44L, 56L, 80L, 68L, 68L)
  pitch.name <- paste0(rep(prefixes, each = 12), notes)
  pitch.num <- rep(starts, each = 12) + 0:11
  # NOTE(review): these two entries are spelled with a "~" unlike their
  # neighbours; confirm this matches the raw data and is not a typo.
  pitch.name[pitch.name == "lowF"] <- "~lowF"
  pitch.name[pitch.name == "hihiB"] <- "hihiB~"
  data.frame(pitch.name, pitch.num)
})
dat$highPitch <- as.numeric(replace(reptb.pitch, dat$highPitch))
dat$lowPitch <- as.numeric(replace(reptb.pitch, dat$lowPitch))
# Build the song table: one row per request number, taking each attribute
# from the first retained row of that song.
songs <- local({
  first.by.song <- function(v) tapply(v, dat$requestNo, function(z) z[1])
  data.frame(
    reqno = lv.reqno,
    artist = factor(first.by.song(as.character(dat$artist))),
    contents = factor(first.by.song(as.character(dat$contents))),
    highPitch = first.by.song(dat$highPitch),
    lowPitch = first.by.song(dat$lowPitch)
  )
})
# Vocal range of the song, counting both end points.
songs$diffPitch <- with(songs, highPitch - lowPitch + 1)
## Build training / test data ----
# Hold out 5% of all rows for testing (fixed seed for reproducibility).
set.seed(0)
index.test <- sample(seq_len(nrow(dat)), nrow(dat) * 0.05)
# Use a logical mask: with fewer than 20 rows the sample is empty, and
# dat[-integer(0), ] would silently return an EMPTY training set.
is.test <- seq_len(nrow(dat)) %in% index.test
dat.test <- dat[is.test, ]
dat.train <- dat[!is.test, ]
# Drop test rows whose song never appears in the training data
# (original author's note: this choice still needs review).
dat.test <- dat.test[table(dat.train$requestNo)[dat.test$requestNo] > 0, ]
# Temporary variables reused by the models below.
y <- as.matrix(dat.train[, elemname])
user.test <- as.numeric(dat.test$user)
reqno.test <- as.numeric(dat.test$requestNo)
y.test <- as.matrix(dat.test[, elemname])
# Sanity check: number of training and test rows.
c(train = nrow(dat.train), test = nrow(dat.test))
## Baseline: predict each score by the user's own mean ----
mse <- function(true.y, pred.y) {
  # Mean squared error of `pred.y` against `true.y`. Works on vectors and
  # matrices alike: every cell is averaged together, and the divisor is the
  # number of cells in `true.y`. NA in either input propagates to the result.
  sq.err <- (pred.y - true.y)^2
  n.cells <- length(true.y)
  sum(sq.err) / n.cells
}
# For every score column, average the training scores per user, then look up
# the row of each test user.
pred.y.by.mean <- apply(y, 2, function(column) {
  tapply(column, dat.train$user, mean)
})[user.test, ]
mse(y.test, pred.y.by.mean)
## diaglm 実装 ---- | |
library(parallel) | |
diaglm <- function(dat, threshold = 4, weighted = FALSE, verbose = TRUE,
                   cl = NULL, tol = 1e-07, max.iter = Inf) {
  # Fit, separately for every score column k in the global `elemname`, the
  # multiplicative model
  #   y[r, k] ~ a[requestNo[r], k] * x[user[r], k]
  # by alternating (weighted) least squares: solve every song factor with
  # the user factors fixed, then every user factor with the song factors
  # fixed, until the change between sweeps falls to `tol`.
  #
  # Args:
  #   dat: data frame with factor columns `user` and `requestNo` plus the
  #     numeric score columns named in `elemname`.
  #   threshold: songs with fewer than this many rows in `dat` are dropped
  #     before fitting; their rows of `a` keep the initial value 1.
  #   weighted: if TRUE, each observation is weighted by how often its score
  #     value occurs in its column; otherwise all weights are 1.
  #   verbose: print the convergence measure after every sweep.
  #   cl: optional `parallel` cluster; when given, the per-song and per-user
  #     solves run through parSapply().
  #   tol: stop once the combined mean squared change of `a` and
  #     0.01 * `x` between sweeps is at most this value.
  #   max.iter: maximum number of sweeps (default: unlimited); a warning is
  #     issued if it is hit before convergence.
  #
  # Returns a list with
  #   a: nlevels(requestNo) x length(elemname) song factors,
  #   x: nlevels(user) x length(elemname) user factors; NA for users left
  #      with no rows after the threshold filter,
  #   resid: fitted minus observed scores for the rows used in the fit,
  #   r2: per-column coefficient of determination, 1 - SSR / SST.
  stopifnot(is.factor(dat$requestNo), is.factor(dat$user))
  if (!is.null(cl)) sapply <- function(...) parSapply(cl, ...)

  # Train only on songs that occur at least `threshold` times. Indexing the
  # table by integer codes keeps it aligned with the factor levels.
  song.count <- table(dat$requestNo)
  keep <- as.vector(song.count[as.integer(dat$requestNo)]) >= threshold
  dat <- dat[keep, , drop = FALSE]
  reqno <- as.integer(dat$requestNo)
  user <- as.integer(dat$user)
  y <- as.matrix(dat[, elemname, drop = FALSE])
  n.reqno <- nlevels(dat$requestNo)
  n.user <- nlevels(dat$user)
  p <- length(elemname)
  # Turn sapply() output for m length-p results into an m x p matrix; also
  # covers m = 0 and p = 1, where sapply() does not return a p x m matrix.
  as.rows <- function(v) matrix(as.numeric(v), ncol = p, byrow = TRUE)

  # Initial values: song factors 1, user factors = user's mean score (NA
  # from tapply() for users without rows).
  a <- matrix(1, n.reqno, p)
  x <- matrix(apply(y, 2, function(v) tapply(v, dat$user, mean)),
              nrow = n.user)
  if (weighted) {
    # Weight = frequency of the observed value within its column.
    w <- matrix(apply(y, 2, function(v) as.vector(table(v)[as.character(v)])),
                nrow = nrow(y))
  } else {
    w <- matrix(1, nrow(y), p)
  }

  # Only songs/users that still have rows are updated; these sets do not
  # change between sweeps.
  song.obs <- which(table(dat$requestNo) > 0)
  user.obs <- which(table(dat$user) > 0)

  # Alternating optimisation.
  iter <- 0
  err <- Inf
  while (err > tol && iter < max.iter) {
    iter <- iter + 1
    a0 <- a
    x0 <- x
    # Song step: weighted LS slope per column over the song's rows,
    #   a[j, k] = sum(w * x * y) / sum(w * x^2).
    a[song.obs, ] <- as.rows(sapply(song.obs, function(j, y, xx, w, reqno, user) {
      rows <- reqno == j
      sub.x <- xx[user[rows], , drop = FALSE]
      sub.w <- w[rows, , drop = FALSE]
      colSums(sub.w * sub.x * y[rows, , drop = FALSE]) / colSums(sub.w * sub.x^2)
    }, y = y, xx = x, w = w, reqno = reqno, user = user))
    # User step: the same update with the roles of a and x swapped.
    x[user.obs, ] <- as.rows(sapply(user.obs, function(i, y, a, w, reqno, user) {
      rows <- user == i
      sub.a <- a[reqno[rows], , drop = FALSE]
      sub.w <- w[rows, , drop = FALSE]
      colSums(sub.w * sub.a * y[rows, , drop = FALSE]) / colSums(sub.w * sub.a^2)
    }, y = y, a = a, w = w, reqno = reqno, user = user))
    # NA rows (users without data) are skipped so they cannot turn `err`
    # into NA and abort the loop condition.
    err <- sum((a - a0)^2, na.rm = TRUE) / (n.reqno * p) +
      sum(((x - x0) * 0.01)^2, na.rm = TRUE) / (n.user * p)
    if (verbose) {
      cat(paste0("#", iter, ": ", round(err, digits=8), "\n"))
    }
  }
  if (isTRUE(err > tol)) {
    warning("diaglm: stopped after ", iter, " iterations without converging",
            call. = FALSE)
  }

  resid <- a[reqno, , drop = FALSE] * x[user, , drop = FALSE] - y
  # Coefficient of determination with matching denominators on both sums.
  centered <- sweep(y, 2, colMeans(y))
  r2 <- unname(1 - colSums(resid^2) / colSums(centered^2))
  colnames(a) <- elemname
  colnames(x) <- elemname
  list(a = a, x = x, resid = resid, r2 = r2)
}
## Estimation ----
cl <- makeCluster(4)
# Unweighted diaglm, fitted on songs with at least 4 training rows.
diaglm.std <- diaglm(dat.train, threshold = 4, weighted = FALSE, cl = cl)
pred.y.by.diaglm.std <- diaglm.std$a[reqno.test, ] * diaglm.std$x[user.test, ]
c(mse = mse(y.test, pred.y.by.diaglm.std), r2 = diaglm.std$r2)
# Weighted diaglm, fitted on every song (threshold 0).
diaglm.w <- diaglm(dat.train, threshold = 0, weighted = TRUE, cl = cl)
pred.y.by.diaglm.w <- diaglm.w$a[reqno.test, ] * diaglm.w$x[user.test, ]
c(mse = mse(y.test, pred.y.by.diaglm.w), r2 = diaglm.w$r2)
# Threshold sweep: refit for threshold = 1..6 and score each fit on the test
# set. The open cluster is passed here as well; previously these six refits
# ran serially while the cluster sat idle.
summary.diaglm.t <- t(sapply(1:6, function(threshold) {
  diaglm.t <- diaglm(dat.train, threshold = threshold, cl = cl)
  pred <- diaglm.t$a[reqno.test, ] * diaglm.t$x[user.test, ]
  c(mse = mse(y.test, pred), r2 = diaglm.t$r2)
}))
summary.diaglm.t
stopCluster(cl)
## Compare mean baseline and diaglm on chartInterval only
int.df <- data.frame(
  true = y.test[, "chartInterval"],
  pred.mean = pred.y.by.mean[, "chartInterval"],
  pred.diaglm = pred.y.by.diaglm.std[, "chartInterval"]
)
# Predicted-vs-true scatter plots on a common 50-100 scale; the y = x line
# marks perfect predictions.
g1 <- ggplot(int.df, aes(x = true, y = pred.diaglm)) +
  geom_point() +
  geom_abline(slope = 1) +
  xlim(50, 100) +
  ylim(50, 100) +
  xlab("True int") +
  ylab("Predicted int (proposed)")
g2 <- ggplot(int.df, aes(x = true, y = pred.mean)) +
  geom_point() +
  geom_abline(slope = 1) +
  xlim(50, 100) +
  ylim(50, 100) +
  xlab("True int") +
  ylab("Predicted int (baseline)")
# Draw the two plots side by side into interval.svg.
svg("interval.svg", width = 8, height = 4)
make.grid(1, 2)
print.at(g1, 1, 1)
print.at(g2, 1, 2)
end.grid()
dev.off()
## Look at the samples with the largest residuals ----
int.df$resid.diaglm <- int.df$pred.diaglm - int.df$true
int.df$resid.mean <- int.df$pred.mean - int.df$true
int.df <- int.df[order(abs(int.df$resid.diaglm), decreasing = TRUE), ]
# head() instead of [1:20, ]: with fewer than 20 test rows the latter pads the
# output with all-NA rows.
head(int.df[, c("true", "resid.diaglm", "resid.mean")], 20)
## Which songs make pitch accuracy hardest? ----
full.threshold <- 4
cl <- makeCluster(4)
diaglm.full <- diaglm(dat, threshold = full.threshold, weighted = FALSE, cl = cl)
stopCluster(cl)
df <- data.frame(
  a = diaglm.full$a[, "chartInterval"],
  songs[, c("diffPitch", "contents", "artist")]
)
# diaglm() leaves `a` at its initial value 1 for songs below the threshold,
# so rank only the songs it actually fitted (rows of `a`, `songs` and the
# table below all follow the levels of dat$requestNo).
df.fitted <- df[as.vector(table(dat$requestNo) >= full.threshold), ]
# Lowest / highest song factors; head() avoids NA padding rows when fewer
# than 10 songs were fitted.
head(df.fitted[order(df.fitted$a), ], 10)
head(df.fitted[order(df.fitted$a, decreasing = TRUE), ], 10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment