i have (x, y) pairs generated from two smooth curves, but i don't know which curve each point comes from
— alex hayes (@alexpghayes) June 28, 2022
is there a way to recover the original curves? https://t.co/h9kL2JEeEV pic.twitter.com/tzy4O1VnrT
library(tidyverse)
# Simulate the hidden ground truth: two smooth curves evaluated on a shared
# grid of 100 x-values. A fair coin flip at every grid point decides which
# curve's value lands in column y1 and which in y2, so the column a value sits
# in carries no information about the curve it came from.
x <- seq(from = 0, to = 10, length.out = 100)
unobserved <- tibble(x = x) |>
  mutate(
    true_curve1 = sin(x),
    true_curve2 = tanh(x) - 0.5,
    # TRUE: y1 takes curve 1 and y2 takes curve 2; FALSE: the other way round
    coin = rbinom(n = length(x), size = 1, prob = 0.5) == 1,
    y1 = ifelse(coin, true_curve1, true_curve2),
    y2 = ifelse(coin, true_curve2, true_curve1)
  )
# What the analyst actually gets to see: the shuffled y1/y2 columns stacked
# into a single y column, with every hint about their origin discarded.
observed <- unobserved |>
  pivot_longer(
    cols = c(y1, y2),
    names_to = "source_column",
    values_to = "y"
  ) |>
  select(all_of(c("x", "y")))
# Scatter plot of the unlabeled observations. The point count in the subtitle
# is taken from the data (one row per observed point) instead of being
# hard-coded, so it cannot disagree with what is actually plotted.
ggplot(observed, aes(x = x, y = y)) +
  geom_point() +
  labs(
    title = "How to identify which of the two curves each point belongs to?",
    subtitle = sprintf(
      paste(
        "%d data points have been generated from two smooth curves,",
        "but we only see (x, y) pairs"
      ),
      nrow(observed)
    )
  ) +
  theme_minimal()
Here, roughness means the sum of squared second derivatives, approximated with finite differences.
#' Roughness of three-point candidate paths
#'
#' Scores each candidate path by the squared change in slope between its two
#' segments, divided by the squared total width of the three design points.
#'
#' @param x Numeric vector holding three consecutive design points.
#' @param y Matrix with three columns; each row is one candidate path and
#'   column `k` holds its value at `x[k]`.
#' @return One roughness value per row of `y` (a single value when both
#'   x-intervals have zero width).
path_roughness <- function(x, y) {
  # Slope of every path between design points `lo` and `hi`; a zero-width
  # interval (e.g. a replicated design point) contributes a slope of 0.
  segment_slope <- function(lo, hi) {
    if (x[lo] != x[hi]) {
      (y[, hi] - y[, lo]) / (x[hi] - x[lo])
    } else {
      0
    }
  }

  slope_change <- segment_slope(2, 3) - segment_slope(1, 2)
  total_width <- x[3] - x[1]
  slope_change^2 / total_width^2
}
#' Greedily assign unlabeled observations to smooth curves
#'
#' Walks through the design points in the order given (presumably sorted by
#' `x` -- confirm with the caller). At each step every pairing of "curve so
#' far" with "candidate value at the next design point" is scored with
#' `path_roughness()`, and the values are matched to curves by solving the
#' resulting assignment problem with `RcppHungarian::HungarianSolver()`.
#'
#' @param x Design points, one per row of `y`; anything `unlist()` can flatten
#'   to a vector (e.g. a one-column data frame).
#' @param y Matrix-like object with one row per design point and one column
#'   per curve. The order of the values within a row is arbitrary.
#' @return A matrix with `nrow(y)` rows and columns `x`, `y1`, ..., `yd`, where
#'   column `yk` holds the values assigned to curve `k`.
assign_curves <- function(x, y) {
  x <- unname(unlist(x))
  y <- as.matrix(y)
  n <- nrow(y) # number of design points
  d <- ncol(y) # number of curves
  stopifnot(length(x) == n)

  col_names <- c("x", paste0("y", seq_len(d)))
  if (n == 0) {
    # nothing to assign: return an empty result of the right shape
    return(matrix(numeric(0), 0, d + 1, dimnames = list(NULL, col_names)))
  }

  # trick: first design point is replicated to allow for computing 2nd
  # derivative (which really is just 1st deriv there)
  curves <- matrix(NA, n + 1, d + 1)
  curves[, 1] <- c(x[1], x)
  curves[1, -1] <- curves[2, -1] <- y[1, ]

  # seq_len() keeps the loop empty for a single design point, whereas
  # seq(2, 1) would count down and index past the end of the data
  for (i in seq_len(n - 1) + 1) {
    # construct possible paths: every current curve end point (varying
    # fastest) combined with every candidate value at design point i
    y_poss <- expand.grid(
      curves[i, -1],
      y[i, ]
    )
    # prepend the value each curve had one step earlier (recycled so that it
    # lines up with the curve index of the second column)
    y_poss <- cbind(curves[i - 1, -1], as.matrix(y_poss))

    # roughness of each path serves as weight for the assignment problem;
    # rows index the existing curves, columns the candidate values
    roughness <- path_roughness(curves[i + seq(-1, 1), 1], y_poss)
    sol <- RcppHungarian::HungarianSolver(matrix(roughness, d, d))

    # write each matched candidate value into the column of its curve
    curves[i + 1, 1 + sol$pairs[, 1]] <- matrix(y_poss[, 3], d, d)[sol$pairs]
  }

  colnames(curves) <- col_names
  # remove duplicated first design point; drop = FALSE keeps the result a
  # matrix even when only one row remains
  curves[-1, , drop = FALSE]
}
# convert to wide format: one row per x, with the values observed at that x
# spread over columns V1, V2, ... in the order they appear in the data
observed <- observed |>
  group_by(x) |>
  mutate(id = paste0("V", seq_along(y))) |>
  # the grouping is only needed to number the values within each x; drop it so
  # the result is a plain tibble for everything downstream
  ungroup() |>
  pivot_wider(values_from = y, names_from = id)
# Recover the curve assignment (first column holds the design points, the
# remaining columns the unlabeled values), then plot the recovered curves
# with one colour per curve.
res <- assign_curves(x = observed[, 1], y = observed[, -1])
ggplot(
  pivot_longer(as_tibble(res), cols = -x),
  aes(x = x, y = value, color = name)
) +
  geom_point() +
  theme_minimal()
Results: