Skip to content

Instantly share code, notes, and snippets.

@juliasilge
Created August 22, 2020 23:28
Show Gist options
  • Save juliasilge/a12bf266b209814e766e61a963996478 to your computer and use it in GitHub Desktop.
Join the outputs of tidy() and augment() to get, for each point, the center of the cluster it was assigned to
library(tidymodels)

# Ground truth for the simulation: one row per cluster giving its id (as a
# factor), how many points to draw for it, and its true center (x1, x2).
centers <-
  tribble(
    ~cluster, ~num_points, ~x1, ~x2,
    1L,       100,          5,  -1,
    2L,       150,          0,   1,
    3L,        50,         -3,  -2
  ) %>%
  mutate(cluster = factor(cluster))

# Simulate labelled points: for every cluster, draw `num_points` values per
# coordinate from rnorm() centred on that cluster's true x1 / x2 (rnorm's
# default sd = 1). All x1 draws happen before any x2 draw, so with this seed
# the data (and the printed output below) are reproducible.
set.seed(27)
labelled_points <- centers %>%
  mutate(
    x1 = map2(num_points, x1, ~ rnorm(n = .x, mean = .y)),
    x2 = map2(num_points, x2, ~ rnorm(n = .x, mean = .y))
  ) %>%
  select(cluster, x1, x2) %>%
  unnest(cols = c(x1, x2))

# Keep only the coordinates: k-means should recover the groups without
# seeing the true `cluster` labels.
points <- select(labelled_points, x1, x2)

points
#> # A tibble: 300 x 2
#>       x1     x2
#>    <dbl>  <dbl>
#>  1  6.91 -2.74 
#>  2  6.14 -2.45 
#>  3  4.24 -0.946
#>  4  3.54  0.287
#>  5  3.91  0.408
#>  6  5.30 -1.58 
#>  7  5.01 -1.77 
#>  8  6.16 -1.68 
#>  9  7.13 -2.17 
#> 10  5.24 -2.42 
#> # … with 290 more rows

kclust <- kmeans(points, centers = 3)

# Attach each point's fitted cluster (`.cluster`, from augment()) and then the
# coordinates of that cluster's center (from tidy()). Both tables carry x1/x2,
# so the suffixes tell the point coordinates apart from the center coordinates.
points_joined <-
  kclust %>%
  augment(points) %>%
  inner_join(
    select(tidy(kclust), x1, x2, cluster),
    by = c(".cluster" = "cluster"),
    suffix = c("_points", "_centers")
  )

points_joined
#> # A tibble: 300 x 5
#>    x1_points x2_points .cluster x1_centers x2_centers
#>        <dbl>     <dbl> <fct>         <dbl>      <dbl>
#>  1      6.91    -2.74  3              5.00      -1.05
#>  2      6.14    -2.45  3              5.00      -1.05
#>  3      4.24    -0.946 3              5.00      -1.05
#>  4      3.54     0.287 3              5.00      -1.05
#>  5      3.91     0.408 3              5.00      -1.05
#>  6      5.30    -1.58  3              5.00      -1.05
#>  7      5.01    -1.77  3              5.00      -1.05
#>  8      6.16    -1.68  3              5.00      -1.05
#>  9      7.13    -2.17  3              5.00      -1.05
#> 10      5.24    -2.42  3              5.00      -1.05
#> # … with 290 more rows

# Show just the assigned cluster and the coordinates of its center.
select(points_joined, .cluster, ends_with("_centers"))
#> # A tibble: 300 x 3
#>    .cluster x1_centers x2_centers
#>    <fct>         <dbl>      <dbl>
#>  1 3              5.00      -1.05
#>  2 3              5.00      -1.05
#>  3 3              5.00      -1.05
#>  4 3              5.00      -1.05
#>  5 3              5.00      -1.05
#>  6 3              5.00      -1.05
#>  7 3              5.00      -1.05
#>  8 3              5.00      -1.05
#>  9 3              5.00      -1.05
#> 10 3              5.00      -1.05
#> # … with 290 more rows

Created on 2020-08-22 by the reprex package (v0.3.0.9001)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment