@tomrockdsouza
Last active June 25, 2023 18:04
{
"cells": [
{
"cell_type": "markdown",
"id": "f0b98428-986b-453a-8bd6-80fdd825a5c5",
"metadata": {},
"source": [
"## Question:\r\n",
"\r\n",
"Given a set of two-dimensional points P (e.g. [(1.1, 2.5), (3.4, 1.9), ...]; the set may contain hundreds of points), write a function that calculates simple K-means. The function is expected to return\r\n",
"\r\n",
"1. the cluster ID that each point belongs to, and\r\n",
"2. the coordinates of the centroids at the end of the iterations\r\n",
"\r\n",
"Please use Python with standard libraries like numpy or pandas, but do not use Scikit-learn's k-means or any other k-means library; the idea is for you to implement k-means from scratch. Feel free to research and look up any information you need, but please note that plagiarism will not be tolerated. You may spend as much time as needed, but as a frame of reference, an hour would be the maximum time frame. If more time is required, please send over the intermediate code at the one-hour mark.\r\n",
"\r\n",
"You may start the assignment whenever you are ready. Once you have completed this task, get back to us along with the code and how long it took you. Please feel free to get in touch with me if you encounter any questions or problems.\r\n",
"\r\n",
"**Requirements**:\r\n",
"\r\n",
"Minimum: implementation of the k-means function/class\r\n",
"\r\n",
"Expected:\r\n",
"\r\n",
"- Implement an interface similar to Sklearn (subset is fine)\r\n",
"- Test code\r\n",
"- Visualisation\r\n",
"\r\n",
"**Deliverable**: Notebook with explanation and HTML output. "
]
},
{
"cell_type": "markdown",
"id": "3765cf32-53d7-45d5-8cd4-d69cc539bbdf",
"metadata": {},
"source": [
"### Installing Libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "65203708-6021-4a47-b31a-3235fc1cbc0f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: pandas in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (2.0.2)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from pandas) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: numpy>=1.21.0 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from pandas) (1.24.3)\n",
"Requirement already satisfied: six>=1.5 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
"Requirement already satisfied: matplotlib in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (3.7.1)\n",
"Requirement already satisfied: contourpy>=1.0.1 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from matplotlib) (1.0.7)\n",
"Requirement already satisfied: cycler>=0.10 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from matplotlib) (0.11.0)\n",
"Requirement already satisfied: fonttools>=4.22.0 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from matplotlib) (4.39.4)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from matplotlib) (1.4.4)\n",
"Requirement already satisfied: numpy>=1.20 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from matplotlib) (1.24.3)\n",
"Requirement already satisfied: packaging>=20.0 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from matplotlib) (23.1)\n",
"Requirement already satisfied: pillow>=6.2.0 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from matplotlib) (9.5.0)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from matplotlib) (3.0.9)\n",
"Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from matplotlib) (2.8.2)\n",
"Requirement already satisfied: six>=1.5 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n",
"Requirement already satisfied: pydantic in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (1.10.8)\n",
"Requirement already satisfied: typing-extensions>=4.2.0 in c:\\users\\tomrock\\desktop\\test\\test\\lib\\site-packages (from pydantic) (4.6.3)\n"
]
}
],
"source": [
"!pip install pandas\n",
"!pip install matplotlib\n",
"!pip install pydantic"
]
},
{
"cell_type": "markdown",
"id": "ccd980d7-371d-4e99-9e40-78184e8100fd",
"metadata": {},
"source": [
"### Importing Libraries"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c1b4ef05-5d42-4748-8eed-0459fd2f1e06",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.cm as cm\n",
"from typing import List, TypeVar, Tuple\n",
"from pydantic import validate_arguments, ValidationError\n",
"\n",
"PandasDataFrame = TypeVar('pandas.core.frame.DataFrame')"
]
},
{
"cell_type": "markdown",
"id": "5aa54ca8-2385-4305-8a71-39ae1e173de2",
"metadata": {},
"source": [
"### Loading Helper Functions"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "14322f2b-29eb-41bb-8f86-8d2ee46a22e8",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"@validate_arguments\n",
"def generate_random_array(low: float, high: float, size: int) -> List[float]:\n",
"    # bounds are floats so that data-derived limits are not truncated to int\n",
"    return [random.uniform(low, high) for _ in range(size)]\n",
"\n",
"\n",
"@validate_arguments\n",
"def generate_dataframe(low: int, high: int, size: int):\n",
"    return pd.DataFrame(\n",
"        {\n",
"            'x': generate_random_array(low, high, size),\n",
"            'y': generate_random_array(low, high, size),\n",
"        }\n",
"    )\n",
"\n",
"\n",
"@validate_arguments\n",
"def validate_dataframe(df: PandasDataFrame) -> bool:\n",
"    # both coordinate columns must be present\n",
"    try:\n",
"        points = df[['x', 'y']]\n",
"    except KeyError:\n",
"        raise Exception('The DataFrame must contain the columns x and y')\n",
"    # both columns must be convertible to float\n",
"    try:\n",
"        points = points.astype(float)\n",
"    except (ValueError, TypeError):\n",
"        raise Exception('The DataFrame should have real float numbers as input')\n",
"    # every value must be a real number\n",
"    if not points.applymap(np.isreal).all().all():\n",
"        raise Exception('The DataFrame should have real float numbers as input')\n",
"    return True\n",
"\n",
"\n",
"@validate_arguments\n",
"def generate_random_centroids(df: PandasDataFrame, num_centroids: int) -> List[List[float]]:\n",
"    # draw the initial centroids uniformly inside the bounding box of the data\n",
"    x_min = df['x'].min()\n",
"    x_max = df['x'].max()\n",
"    y_min = df['y'].min()\n",
"    y_max = df['y'].max()\n",
"    centroids_x = generate_random_array(x_min, x_max, num_centroids)\n",
"    centroids_y = generate_random_array(y_min, y_max, num_centroids)\n",
"    return [list(row) for row in zip(*[centroids_x, centroids_y])]\n",
"\n",
"\n",
"@validate_arguments\n",
"def get_new_cluster_index(centroids: List[List[float]], df: PandasDataFrame) -> List[int]:\n",
"    df = df[['x', 'y']]\n",
"    centroid_arrays = []\n",
"    # Euclidean distance from every point to each centroid\n",
"    for ed_2_centroid in centroids:\n",
"        ed_square = np.square(ed_2_centroid[0] - df['x']) + np.square(ed_2_centroid[1] - df['y'])\n",
"        centroid_arrays.append(np.sqrt(ed_square))\n",
"    np_centroid_arrays = np.array(centroid_arrays)\n",
"    # for each point, pick the index of the nearest centroid\n",
"    return [np.argmin(sub_arr) for sub_arr in np_centroid_arrays.transpose()]\n",
"\n",
"\n",
"@validate_arguments\n",
"def visualize_kmeans(centroids: List[List[float]], df: PandasDataFrame) -> None:\n",
"    centroids_x, centroids_y = [list(row) for row in zip(*centroids)]\n",
"    # generating color for each cluster\n",
"    colors = cm.rainbow(np.linspace(0, 1, len(centroids)))\n",
"    # looping over each generated cluster\n",
"    for idx in range(len(centroids)):\n",
"        plt.scatter(\n",
"            df[df['centroid_index'] == idx]['x'].values, df[df['centroid_index'] == idx]['y'].values,\n",
"            # color= (not c=) avoids matplotlib's single-RGBA-sequence warning\n",
"            s=50, color=colors[idx],\n",
"            marker='o', edgecolor='black',\n",
"            label=f'cluster {idx}'\n",
"        )\n",
"    # plot the centroids\n",
"    plt.scatter(\n",
"        centroids_x, centroids_y,\n",
"        s=250, marker='*',\n",
"        c='yellow', edgecolor='red',\n",
"        label='centroids'\n",
"    )\n",
"    plt.legend(scatterpoints=1, loc='upper left', bbox_to_anchor=(1.01, 1))\n",
"    plt.grid()\n",
"    plt.show()\n",
"\n",
"\n",
"@validate_arguments\n",
"def fit_predict(\n",
"    df: PandasDataFrame = generate_dataframe(0, 500, 150),\n",
"    num_centroids: int = 3,\n",
"    max_iter: int = 300\n",
") -> List[List[float]]:\n",
"    # returns the centroids; cluster ids are stored in df['centroid_index']\n",
"    validate_dataframe(df)\n",
"    centroids = generate_random_centroids(df, num_centroids)\n",
"    # one cluster index per point; -1 never matches a real index\n",
"    old_centroid_relation: List[int] = [-1] * len(df)\n",
"    for iter_len in range(max_iter):\n",
"        new_centroid_relation = get_new_cluster_index(centroids, df)\n",
"        df['centroid_index'] = new_centroid_relation\n",
"        # comparing old cluster relations with new\n",
"        if new_centroid_relation == old_centroid_relation:\n",
"            print(f'Total Iterations ({iter_len})')\n",
"            visualize_kmeans(centroids, df)\n",
"            return centroids\n",
"        else:\n",
"            # setting new centroid values\n",
"            for idx, centroid in enumerate(centroids):\n",
"                new_centroid = [\n",
"                    np.mean(df[df['centroid_index'] == idx]['x']),\n",
"                    np.mean(df[df['centroid_index'] == idx]['y']),\n",
"                ]\n",
"                # an empty cluster has NaN means; keep its previous centroid\n",
"                if not np.isnan(new_centroid).any():\n",
"                    centroids[idx] = new_centroid\n",
"            old_centroid_relation = new_centroid_relation\n",
"    print(f'Max Iteration Reached ({max_iter})!')\n",
"    visualize_kmeans(centroids, df)\n",
"    return centroids\n",
"\n",
"\n",
"@validate_arguments\n",
"def predict_cluster(\n",
"    centroids: List[List[float]],\n",
"    x: List[float] = generate_random_array(low=0, high=500, size=2),\n",
"    y: List[float] = generate_random_array(low=0, high=500, size=2)\n",
") -> List[int]:\n",
"    if len(centroids) == 0:\n",
"        raise Exception('No centroids were provided')\n",
"    return get_new_cluster_index(centroids, pd.DataFrame({'x': x, 'y': y}))\n"
]
},
{
"cell_type": "markdown",
"id": "71175194-94a1-4c7b-b666-ae1566656343",
"metadata": {},
"source": [
"### Testing and Visualizations"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "771c1fc0-12f4-4b8b-85d6-078fb28d50fc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total Iterations (11)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\TOMROCK\\AppData\\Local\\Temp\\ipykernel_6120\\3167390455.py:61: UserWarning: *c* argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with *x* & *y*. Please use the *color* keyword-argument or provide a 2D array with a single row if you intend to specify the same RGB or RGBA value for all points.\n",
" plt.scatter(\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample Number:1, Cluster ID: 6\n",
"Sample Number:2, Cluster ID: 1\n",
"Sample Number:3, Cluster ID: 0\n",
"Sample Number:4, Cluster ID: 4\n",
"Sample Number:5, Cluster ID: 6\n"
]
}
],
"source": [
"try:\n",
"    # Test Data Seed Values\n",
"    lower_bound_number = 0\n",
"    higher_bound_number = 500\n",
"    train_size = 150\n",
"    test_size = 5\n",
"\n",
"    # train and generate centroids/weights\n",
"    centroids = fit_predict(\n",
"        df=generate_dataframe(low=lower_bound_number, high=higher_bound_number, size=train_size),\n",
"        num_centroids=8,\n",
"        max_iter=500\n",
"    )\n",
"\n",
"    # predict cluster index using centroids/weights\n",
"    for idx, cluster_index in enumerate(\n",
"        predict_cluster(\n",
"            centroids=centroids,\n",
"            x=generate_random_array(low=lower_bound_number, high=higher_bound_number, size=test_size),\n",
"            y=generate_random_array(low=lower_bound_number, high=higher_bound_number, size=test_size)\n",
"        )\n",
"    ):\n",
"        print(f'Sample Number:{idx + 1}, Cluster ID: {cluster_index}')\n",
"\n",
"except ValidationError as exc:\n",
"    print(exc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1eada685-13ce-433a-ac54-ecd38d010851",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}