Created November 14, 2014 13:36
-
-
Save sinhrks/2bf59a17ac43ecb82822 to your computer and use it in GitHub Desktop.
Dynamic Time Warping
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(dplyr) | |
library(tidyr) | |
library(ggplot2) | |
library(gridExtra) | |
library(animation) | |
# Draw one animation frame of the DTW computation.
#
# Top panel: both series as lines, plus a dashed segment joining ts_a[i] and
# ts_b[j] (the cell currently being filled). Bottom panels: heat maps of the
# cost matrix and the accumulated-distance matrix; cells not filled yet are NA.
#
# Args:
#   ts_a, ts_b: numeric vectors; their lengths may differ.
#   i, j:       indices into ts_a and ts_b of the cell being filled.
#   cost, dist: matrices with length(ts_a) rows and length(ts_b) columns.
#
# Returns:
#   Whatever gridExtra::grid.arrange() returns; the frame is drawn as a
#   side effect.
plot_dtw_matrix <- function(ts_a, ts_b, i, j, cost, dist) {
  # Heat map of a matrix: x = row index (ts_a), y = column index (ts_b).
  # Indices stay integers so the axis is ordered numerically, not as text
  # ("1", "10", "11", ..., "2").
  .plot_matrix <- function(m, title, low, high) {
    d <- data.frame(
      index = rep(seq_len(nrow(m)), times = ncol(m)),
      variable = rep(seq_len(ncol(m)), each = nrow(m)),
      value = as.vector(m)  # column-major, matching the two index columns
    )
    ggplot(d, aes(index, variable, fill = value)) +
      geom_tile() +
      scale_fill_gradient(low = low, high = high) +
      xlab("") + ylab("") + ggtitle(title)
  }

  # Long format built per series, so ts_a and ts_b may differ in length.
  d_ts <- rbind(
    data.frame(index = seq_along(ts_a), variable = "ts_a", value = ts_a),
    data.frame(index = seq_along(ts_b), variable = "ts_b", value = ts_b)
  )
  p_ts <- ggplot(d_ts, aes(x = index, y = value)) +
    geom_line(aes(colour = variable)) +
    # annotate() draws the segment once instead of once per data row.
    annotate("segment", x = i, y = ts_a[i], xend = j, yend = ts_b[j],
             linetype = "dashed") +
    xlab("") + ylab("")
  p_cost <- .plot_matrix(cost, title = "Cost Matrix",
                         low = "lightgreen", high = "red")
  p_dist <- .plot_matrix(dist, title = "Distance Matrix",
                         low = "blue", high = "red")
  grid.arrange(p_ts, arrangeGrob(p_cost, p_dist, ncol = 2), nrow = 2)
}
# Compute the dynamic time warping (DTW) distance between two series.
#
# Fills a cost matrix (pointwise distances) and an accumulated-distance
# matrix cell by cell, with the same step pattern as dtw::symmetric1:
#   dist[i, j] = cost[i, j] + min(dist[i-1, j], dist[i, j-1], dist[i-1, j-1])
# Only cells with |i - j| <= window are evaluated. Cells outside that band
# stay NA and are never used as predecessors.
#
# Args:
#   ts_a, ts_b: numeric vectors, each of length >= 1 (lengths may differ).
#   d:          pointwise distance between two scalars.
#   window:     maximum allowed index lag |i - j|. The default covers the
#               whole matrix, i.e. no constraint.
#   animate:    if TRUE, call plot_dtw_matrix() after every cell update so
#               the fill can be recorded as an animation.
#
# Returns:
#   The accumulated distance at (length(ts_a), length(ts_b)), or NA when
#   the window is too narrow for any warping path to reach that cell.
dtw_distance <- function(ts_a, ts_b, d = function(x, y) abs(x - y),
                         window = max(length(ts_a), length(ts_b)),
                         animate = TRUE) {
  ts_a_len <- length(ts_a)
  ts_b_len <- length(ts_b)
  stopifnot(ts_a_len > 0, ts_b_len > 0)

  # Ascending integer range, empty when from > to. `from:to` would count
  # down instead, e.g. 2:1 for a length-1 series.
  span <- function(from, to) {
    from <- ceiling(from)
    to <- floor(to)
    if (from <= to) seq.int(from, to) else integer(0)
  }

  # Cost matrix: distance between single points of ts_a and ts_b.
  cost <- matrix(NA, nrow = ts_a_len, ncol = ts_b_len)
  # Distance matrix: minimal accumulated distance from (1, 1) to each cell.
  dist <- matrix(NA, nrow = ts_a_len, ncol = ts_b_len)

  cost[1, 1] <- d(ts_a[1], ts_b[1])
  dist[1, 1] <- cost[1, 1]

  # First column and first row have a single predecessor each. They are
  # restricted to the same band as the interior cells.
  for (i in span(2, min(ts_a_len, 1 + window))) {
    cost[i, 1] <- d(ts_a[i], ts_b[1])
    dist[i, 1] <- dist[i - 1, 1] + cost[i, 1]
    if (animate) plot_dtw_matrix(ts_a, ts_b, i, 1, cost, dist)
  }
  for (j in span(2, min(ts_b_len, 1 + window))) {
    cost[1, j] <- d(ts_a[1], ts_b[j])
    dist[1, j] <- dist[1, j - 1] + cost[1, j]
    if (animate) plot_dtw_matrix(ts_a, ts_b, 1, j, cost, dist)
  }

  for (i in span(2, ts_a_len)) {
    # Columns inside the band for this row (window = maximum lag).
    for (j in span(max(2, i - window), min(ts_b_len, i + window))) {
      # Same step pattern as dtw::symmetric1.
      choices <- c(dist[i - 1, j], dist[i, j - 1], dist[i - 1, j - 1])
      # Skip predecessors outside the band (they are unfilled NA). The
      # diagonal one is always inside, so at least one candidate remains.
      # NA from the data itself still propagates.
      in_band <- c(abs(i - 1 - j) <= window, abs(i - j + 1) <= window, TRUE)
      cost[i, j] <- d(ts_a[i], ts_b[j])
      dist[i, j] <- min(choices[in_band]) + cost[i, j]
      if (animate) plot_dtw_matrix(ts_a, ts_b, i, j, cost, dist)
    }
  }
  dist[ts_a_len, ts_b_len]
}
# Demo: animate the DTW fill on two overlapping windows of AirPassengers,
# then compare the hand-rolled result with the dtw package.
ts_a <- AirPassengers[31:45]
ts_b <- AirPassengers[41:55]

# Each cell update inside dtw_distance() draws one frame of the GIF.
# NOTE(review): this relies on saveGIF() evaluating the block in this
# environment, as the original script did, so dtw_manual is kept here.
saveGIF(
  {
    dtw_manual <- dtw_distance(ts_a, ts_b)
  },
  interval = 0.1, movie.name = "dtw.gif",
  ani.width = 600, ani.height = 400
)

library(dtw)
# Reference implementation with the same step pattern (symmetric1).
d <- dtw::dtw(ts_a, ts_b, step.pattern = symmetric1)
# Keep both results and show whether they agree, instead of overwriting
# the hand-rolled distance.
print(c(manual = dtw_manual, dtw = d$distance))
print(all.equal(dtw_manual, d$distance))
d$distance
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.