Last active
November 8, 2018 19:22
-
-
Save jaye-ross/38ffdda57a17305610c803943d382572 to your computer and use it in GitHub Desktop.
kmeans||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(tidyverse) | |
# Euclidean distance between two equal-length numeric vectors.
# Note: this definition masks stats::dist for the rest of the script.
dist <- function(v1, v2) {
  delta <- v1 - v2
  sqrt(sum(delta * delta))
}
# Squared Euclidean distance between `point` and `mu`.
# Both arguments are numeric vectors of the same length.
sum_of_squares <- function(point, mu) {
  residual <- point - mu
  sum(residual * residual)
}
# Total squared distance from every row of `points` to the single centre `mu`.
# `points` is a data.frame (one observation per row); `mu` is passed through
# unchanged to sum_of_squares() for each row.
cost <- function(points, mu) {
  per_point <- apply(points, MARGIN = 1, FUN = sum_of_squares, mu)
  sum(per_point)
}
# Count, for each candidate centre, how many rows of `data` have it as their
# nearest centre.
#
# data:    data.frame of numeric coordinates, one point per row.
# centers: row indices into `data` identifying the candidate centres.
#
# Returns a 1-d table with exactly one entry per element of `centers`, in the
# same order. Centres that attract no points (e.g. a centre duplicating an
# earlier one, because ties go to the first) are reported with a count of 0
# instead of being dropped, so the result always lines up with data[centers, ].
weight_centers <- function(data, centers) {
  coords <- as.matrix(data)
  # drop = FALSE keeps a matrix even when there is only one centre.
  center_coords <- coords[centers, , drop = FALSE]

  nearest <- vapply(seq_len(nrow(coords)), function(i) {
    # Squared distance has the same arg-min as the distance itself.
    sq_dists <- rowSums(sweep(center_coords, 2, coords[i, ])^2)
    unname(which.min(sq_dists))
  }, integer(1))

  # Fixed factor levels guarantee an entry for every centre, including zeros.
  table(factor(nearest, levels = seq_along(centers)))
}
# Oversampling phase of k-means|| initialisation.
#
# data: data.frame of numeric coordinates, one point per row.
# k:    target number of clusters (not used by this phase; kept so the
#       signature matches the callers).
# l:    oversampling factor; each round selects about `l` new centres in
#       expectation.
#
# Starting from one uniformly chosen point, runs roughly log(initial cost)
# rounds. In each round every point is picked independently with probability
# l * d^2(x, C) / cost(C), where d(x, C) is the distance to the nearest centre
# chosen so far.
#
# Returns the unique row indices of the selected centres.
oversample_init_centers <- function(data, k, l) {
  coords <- as.matrix(data)
  n_points <- nrow(coords)
  stopifnot(n_points >= 1)

  # Squared distance from every point to the point in row `idx`.
  sq_dist_to <- function(idx) {
    rowSums(sweep(coords, 2, coords[idx, ])^2)
  }

  centroids <- sample.int(n_points, 1)
  # Squared distance from each point to its closest centre so far.
  min_sq_dist <- sq_dist_to(centroids)

  # The initial cost is measured against the first centre's coordinates
  # (not its row index). Guard against log(0) and non-positive round counts.
  initial_cost <- sum(min_sq_dist)
  n_rounds <- if (initial_cost > 0) max(1, round(log(initial_cost))) else 0

  for (round_idx in seq_len(n_rounds)) {
    current_cost <- sum(min_sq_dist)
    if (current_cost == 0) {
      # Every point already coincides with a centre; nothing left to sample.
      break
    }
    cutoffs <- l * min_sq_dist / current_cost
    new_centroids <- which(runif(n_points) < cutoffs)
    for (idx in new_centroids) {
      min_sq_dist <- pmin(min_sq_dist, sq_dist_to(idx))
    }
    centroids <- c(centroids, new_centroids)
  }
  unique(centroids)
}
# k-means|| candidate selection: oversample centres, then weight each one by
# the number of points closest to it.
#
# data: data.frame of numeric coordinates, one point per row.
# k, l: forwarded to oversample_init_centers().
#
# Returns the selected rows of `data` with an added numeric `weights` column.
k_means_par <- function(data, k, l) {
  centers <- oversample_init_centers(data, k, l)
  counts <- weight_centers(data, centers)

  # Align the counts with `centers` by position label rather than assuming the
  # table has one entry per centre: a centre that attracted no points may be
  # absent from the table and must get weight 0. Storing a plain numeric
  # vector also keeps table attributes out of the data frame column.
  by_label <- setNames(as.vector(counts), names(counts))
  weights <- unname(by_label[as.character(seq_along(centers))])
  weights[is.na(weights)] <- 0

  selected <- data[centers, , drop = FALSE]
  selected$weights <- weights
  selected
}
# k-means++ seeding with optional per-point weights.
#
# data: data.frame of numeric coordinates, one point per row. An optional
#       `weights` column gives each point's weight; when absent, every point
#       gets weight 1 / nrow(data).
# k:    number of centres to choose (k >= 1).
#
# The first centre is drawn uniformly at random. Each later centre is drawn
# with probability proportional to weight * D(x)^2, where D(x) is the distance
# from x to the nearest centre already chosen. Distances use the coordinate
# columns only; the `weights` column is never treated as a coordinate.
#
# Returns the chosen rows of `data` (including the `weights` column).
# Stops with an error if fewer than `k` points are eligible to become centres.
kpp <- function(data, k) {
  stopifnot(k >= 1, nrow(data) >= 1)
  if (!("weights" %in% colnames(data))) {
    data$weights <- 1 / nrow(data)
  }

  coord_cols <- setdiff(colnames(data), "weights")
  coords <- as.matrix(data[, coord_cols, drop = FALSE])
  point_weights <- as.numeric(data$weights)

  # Squared distance from every point to the point in row `idx`.
  sq_dist_to <- function(idx) {
    rowSums(sweep(coords, 2, coords[idx, ])^2)
  }

  centroids <- integer(k)
  # Assign the first centroid at random from the points.
  centroids[1] <- sample.int(nrow(data), 1)
  # D(x)^2 for every point, updated incrementally as centres are added.
  min_sq_dist <- sq_dist_to(centroids[1])

  # seq_len(k)[-1] is empty when k == 1, unlike 2:k which would count down.
  for (i in seq_len(k)[-1]) {
    probs <- point_weights * min_sq_dist
    eligible <- which(probs > 0)
    if (length(eligible) == 0) {
      stop("kpp: fewer than k distinct points with positive weight",
           call. = FALSE)
    }
    # Index through sample.int so a single eligible point is returned as-is;
    # sample(x, 1) with a length-one x would sample from 1:x instead.
    pick <- eligible[sample.int(length(eligible), 1, prob = probs[eligible])]
    centroids[i] <- pick
    min_sq_dist <- pmin(min_sq_dist, sq_dist_to(pick))
  }
  data[centroids, ]
}
# Run k-means|| seeding followed by weighted k-means++ on the candidates, and
# plot the result next to plain k-means++ on the full data.
#
# data: data.frame with numeric columns `x` and `y`.
# k:    number of final centres.
# l:    oversampling factor for the k-means|| phase.
#
# Plot colours: black = data, blue = k-means|| candidate centres,
# red = final centres chosen from the candidates, green = k-means++ centres
# chosen directly from all points.
#
# Returns the final centres (rows of `data` plus a `weights` column).
kmp <- function(data, k, l) {
  # Use the caller's k and l rather than hard-coded values.
  centers <- k_means_par(data, k, l)
  centers_kpp <- kpp(centers, k)
  kpp_all <- kpp(data, k)

  # plot results
  g <- ggplot(data) +
    geom_point(aes(x, y)) +
    geom_point(data = centers, aes(x, y), colour = "blue") +
    geom_point(data = centers_kpp, aes(x, y), colour = "red") +
    geom_point(data = kpp_all, aes(x, y), colour = "green") +
    theme(legend.position = "none")
  print(g)

  centers_kpp
}
# ---- Demo ----
# Run the seeding comparison on the built-in `faithful` data set, using
# eruption duration as x and waiting time as y.
data <- rename(faithful, x = eruptions, y = waiting)
kmp(data, 2, 4)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment