Last active
January 8, 2020 18:06
-
-
Save jwhendy/085f74b346a508178dbb78dab7e5a3e6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import pandas as pd
def data_weighted_kmeans(df, x_val, y_val, wt_val=None, k=1, centers=None, max_iters=100,
                         random_state=1):
    """Cluster 2-D points with k-means, weighting each point's pull on its cluster center.

    Revised version of the code found here, allowing the use of a data frame instead of a list of dicts.
      - https://github.com/leapingllamas/medium_posts/observation_weighted_kmeans
    Re-citing original code referenced by leapingllamas github above:
      - http://people.sc.fsu.edu/~jburkardt/m_src/kmeans/kmeans.html
      - http://people.sc.fsu.edu/~jburkardt/m_src/kmeans/kmeans_w_03.m

    Each pass assigns every point to its nearest center (squared Euclidean distance;
    ties go to the lowest cluster ID). It then moves each center to the weighted mean
    of its member points. Iteration stops when an assignment pass reproduces the
    previous one, or after ``max_iters`` passes.

    Args:
        df (pd.DataFrame): coordinates and corresponding weights, one coordinate per row.
        x_val (str): the column name holding x locations.
        y_val (str): the column name holding y locations.
        wt_val (str or None): the column holding the weights. If None, every point has weight 1.
        k (int): the number of clusters to use.
        centers ([[float, float]]): initial [x, y] cluster centers, len(centers) == k.
            The list is copied and never modified. If None (or empty), k rows of df
            are sampled as starting centers.
        max_iters (int): the maximum number of assign/update passes to run.
        random_state (int): seed passed to DataFrame.sample when centers are sampled.

    Returns:
        clusters (list): cluster ID (int) for each row of df, in df's row order.
            All entries are None if max_iters < 1.
        summary (pd.DataFrame): one row per cluster with columns 'clust', x_val, y_val,
            'wt_total', 'n_total'. These hold the cluster ID, center x, center y,
            weight sum of member points, and count of member points.

    Raises:
        ValueError: if centers does not hold exactly k [x, y] pairs, or if centers
            must be sampled and k exceeds len(df).
    """
    # Work on positional arrays. .to_numpy() discards df's index, so weights stay
    # aligned with coordinates even when df's index is not 0..n-1.
    xs = df[x_val].to_numpy(dtype=float)
    ys = df[y_val].to_numpy(dtype=float)
    n = len(xs)
    if wt_val is not None:
        wts = df[wt_val].to_numpy()
    else:
        wts = np.ones(n, dtype=int)

    # Initialize centers as a private (k, 2) float array; the caller's list is never mutated.
    if centers is None or len(centers) == 0:
        seeds = df[[x_val, y_val]].reset_index(drop=True).sample(n=k, random_state=random_state)
        ctr = seeds.to_numpy(dtype=float)
    else:
        ctr = np.array(centers, dtype=float)
    if ctr.shape != (k, 2):
        raise ValueError(f"centers must contain exactly k={k} [x, y] pairs, got shape {ctr.shape}")

    labels = None
    wt_totals = [None] * k
    n_totals = [None] * k
    for _ in range(max_iters):
        # (n, k) squared distances; argmin returns the first minimum, i.e. the lowest cluster ID on ties.
        d2 = (xs[:, None] - ctr[:, 0]) ** 2 + (ys[:, None] - ctr[:, 1]) ** 2
        new_labels = d2.argmin(axis=1)
        if labels is not None and np.array_equal(new_labels, labels):
            break  # converged: assignments did not change
        labels = new_labels

        # Move each center to the weighted mean of its members.
        for c in range(k):
            members = labels == c
            w_sum = wts[members].sum()
            if w_sum != 0:
                ctr[c, 0] = (xs[members] * wts[members]).sum() / w_sum
                ctr[c, 1] = (ys[members] * wts[members]).sum() / w_sum
            # Otherwise the cluster is empty (or weightless): keep its previous center
            # rather than dividing by zero and producing NaN.
            wt_totals[c] = w_sum
            n_totals[c] = int(members.sum())

    summary = pd.DataFrame({
        'clust': list(range(k)),
        x_val: ctr[:, 0],
        y_val: ctr[:, 1],
        'wt_total': wt_totals,
        'n_total': n_totals,
    })
    clusters = [None] * n if labels is None else labels.tolist()
    return clusters, summary
### example call
# clusters, summary = data_weighted_kmeans(df, x_val='lng', y_val='lat', wt_val='foo', k=2, centers=[[0, 1], [2, 3]])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment