Skip to content

Instantly share code, notes, and snippets.

@xmnlab
Created January 28, 2020 21:17
Show Gist options
  • Save xmnlab/2c1f93df1a6c6bde4e32c8579117e9cc to your computer and use it in GitHub Desktop.
Save xmnlab/2c1f93df1a6c6bde4e32c8579117e9cc to your computer and use it in GitHub Desktop.
NTile function in python
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def ntile(x: pd.core.groupby.generic.SeriesGroupBy, bucket: int):\n",
" \"\"\"\n",
" Divide a given data set into a number of buckets (like SQL NTILE).\n",
" \n",
" It takes an ordered and optionally grouped data set and assigns the\n",
" appropriate bucket number to each row, based on the row's position\n",
" within its partition.\n",
" \n",
" Bucket numbers range from 0 to `bucket - 1`, dividing each partition\n",
" as equally as possible; when the rows do not divide evenly, the\n",
" lower-numbered buckets receive one extra row each.\n",
" \n",
" Parameters\n",
" ----------\n",
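" x : pandas.core.groupby.generic.SeriesGroupBy or\n",
" pandas.core.groupby.generic.DataFrameGroupBy or\n",
" pandas.Series or pandas.DataFrame\n",
" Data to bucket. Rows are bucketed in the order given (no sorting\n",
" is done here), and each group is bucketed independently.\n",
" bucket : int\n",
" Number of buckets to divide each partition into.\n",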
" \n",
" Returns\n",
" -------\n",
" pandas.Series\n",
" Bucket number of each row, carrying the index of the input rows.\n",
" \"\"\"\n",
" # internal ntile function\n",
" def _ntile(x: pd.Series, bucket: int):\n",
" n = x.shape[0]\n",
" sub_n = n // bucket\n",
" diff = n - (sub_n * bucket)\n",
"\n",
" result = []\n",
" for i in range(bucket):\n",
" sub_result = [i] * (sub_n + (1 if diff else 0))\n",
" result.extend(sub_result)\n",
" if diff > 0:\n",
" diff -= 1\n",
" return pd.Series(result, index=x.index)\n",
" \n",
" \n",
" result = []\n",
" # bucket each partition (group) separately; a plain Series/DataFrame is a single partition\n",
" if isinstance(\n",
" x, pd.core.groupby.generic.SeriesGroupBy\n",
" ):\n",
" for name, group in x:\n",
" result.append(_ntile(group, bucket))\n",
" elif isinstance(\n",
" x, pd.core.groupby.generic.DataFrameGroupBy\n",
" ):\n",
" for group_id, group in x:\n",
" result.append(_ntile(group.iloc[:, 0], bucket))\n",
" elif isinstance(x, pd.Series):\n",
" result.append(_ntile(x, bucket))\n",
" elif isinstance(x, pd.DataFrame):\n",
" result.append(_ntile(x.iloc[:, 0], bucket))\n",
" else:\n",
" raise TypeError(\n",
" '`x` should be `pandas.Series` or `pandas.DataFrame` or '\n",
" '`pd.core.groupby.generic.SeriesGroupBy` or '\n",
" '`pd.core.groupby.generic.DataFrameGroupBy`, '\n",
" 'not {}.'.format(\n",
" type(x)\n",
" )\n",
" )\n",
" return pd.concat(result)\n",
"\n",
"\"\"\"\n",
"LAST_NAME SALARY QUARTILE\n",
"------------------------- ---------- ----------\n",
"Greenberg 12000 0\n",
"Faviet 9000 0\n",
"Chen 8200 1\n",
"Urman 7800 1\n",
"Sciarra 7700 2\n",
"Popp 6900 3\n",
"\"\"\"\n",
"# data frame and series unit test\n",
"result_s = ntile(pd.Series([12000, 9000, 8200, 7800, 7700, 6900]), 4)\n",
"result_df = ntile(pd.DataFrame({'x': [12000, 9000, 8200, 7800, 7700, 6900]}), 4)\n",
"assert np.all(result_s == result_df)\n",
"assert np.all(result_s.values == [0, 0, 1, 1, 2, 3])\n",
"\n",
"# data frame group and series group unit test\n",
"df = pd.DataFrame({\n",
" 'id': [1000, 1001, 1002, 1003, 1004, 1005],\n",
" 'x': [12000, 9000, 8200, 7800, 7700, 6900],\n",
" 'cat': ['A', 'A', 'B', 'B', 'B', 'C']\n",
"})\n",
"\n",
"gdf = df.sort_values('x').groupby('cat')\n",
"\n",
"result_gs = ntile(gdf.id, 4)\n",
"result_gdf = ntile(gdf, 4)\n",
"\n",
"# assert np.all(result_s == result_df)\n",
"# print('=' * 40)\n",
"# print(result_gs)\n",
"# print('=' * 40)\n",
"# print(result_gdf)\n",
"assert np.all(result_gs == result_gdf)\n",
"assert np.all(result_gs.values == [0, 1, 0, 1, 2, 0])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>year</th>\n",
" <th>name</th>\n",
" <th>amount</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2018</td>\n",
" <td>Jack Daniel</td>\n",
" <td>150000.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2018</td>\n",
" <td>Jane Johnson</td>\n",
" <td>110000.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2018</td>\n",
" <td>John Doe</td>\n",
" <td>120000.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2018</td>\n",
" <td>Stephane Heady</td>\n",
" <td>200000.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2018</td>\n",
" <td>Yin Yang</td>\n",
" <td>30000.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>2019</td>\n",
" <td>Jack Daniel</td>\n",
" <td>180000.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>2019</td>\n",
" <td>Jane Johnson</td>\n",
" <td>130000.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>2019</td>\n",
" <td>John Doe</td>\n",
" <td>150000.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>2019</td>\n",
" <td>Stephane Heady</td>\n",
" <td>270000.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>2019</td>\n",
" <td>Yin Yang</td>\n",
" <td>25000.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" year name amount\n",
"0 2018 Jack Daniel 150000.0\n",
"1 2018 Jane Johnson 110000.0\n",
"2 2018 John Doe 120000.0\n",
"3 2018 Stephane Heady 200000.0\n",
"4 2018 Yin Yang 30000.0\n",
"5 2019 Jack Daniel 180000.0\n",
"6 2019 Jane Johnson 130000.0\n",
"7 2019 John Doe 150000.0\n",
"8 2019 Stephane Heady 270000.0\n",
"9 2019 Yin Yang 25000.0"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# https://www.postgresqltutorial.com/postgresql-ntile-function/\n",
"df = pd.DataFrame([\n",
" {'year': 2018, 'name': 'Jack Daniel', 'amount': 150000.0},\n",
" {'year': 2018, 'name': 'Jane Johnson', 'amount': 110000.0},\n",
" {'year': 2018, 'name': 'John Doe', 'amount': 120000.0},\n",
" {'year': 2018, 'name': 'Stephane Heady', 'amount': 200000.0},\n",
" {'year': 2018, 'name': 'Yin Yang', 'amount': 30000.0},\n",
" {'year': 2019, 'name': 'Jack Daniel', 'amount': 180000.0},\n",
" {'year': 2019, 'name': 'Jane Johnson', 'amount': 130000.0},\n",
" {'year': 2019, 'name': 'John Doe', 'amount': 150000.0},\n",
" {'year': 2019, 'name': 'Stephane Heady', 'amount': 270000.0},\n",
" {'year': 2019, 'name': 'Yin Yang', 'amount': 25000.0},\n",
"])\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"SELECT \n",
" name,\n",
" amount,\n",
" NTILE(3) OVER(\n",
" ORDER BY amount\n",
" )\n",
"FROM\n",
" sales_stats\n",
"WHERE\n",
" year = 2019;\n",
"\"\"\"\n",
"\n",
"# Buckets are numbered from 0 here, whereas PostgreSQL's NTILE numbers them from 1.\n",
"df_expected = pd.DataFrame([\n",
" {'name': 'Yin Yang', 'amount': 25000.0, 'ntile': 0},\n",
" {'name': 'Jane Johnson', 'amount': 130000.0, 'ntile': 0},\n",
" {'name': 'John Doe', 'amount': 150000.0, 'ntile': 1},\n",
" {'name': 'Jack Daniel', 'amount': 180000.0, 'ntile': 1},\n",
" {'name': 'Stephane Heady', 'amount': 270000.0, 'ntile': 2},\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" year name amount ntile\n",
"4 2018 Yin Yang 30000.0 NaN\n",
"1 2018 Jane Johnson 110000.0 NaN\n",
"2 2018 John Doe 120000.0 NaN\n",
"0 2018 Jack Daniel 150000.0 NaN\n",
"3 2018 Stephane Heady 200000.0 NaN\n",
"9 2019 Yin Yang 25000.0 0.0\n",
"6 2019 Jane Johnson 130000.0 0.0\n",
"7 2019 John Doe 150000.0 1.0\n",
"5 2019 Jack Daniel 180000.0 2.0\n",
"8 2019 Stephane Heady 270000.0 2.0\n"
]
}
],
"source": [
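"# Baseline for comparison: pd.qcut bins by value quantiles, not by row position,\n",
"# so its buckets can differ from SQL NTILE (see the 2019 rows printed below).\n",
"# Signature: pd.qcut(x, q, labels=None, retbins=False, precision=3, duplicates='raise')\n",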
"result = df.copy()\n",
"x = result[result.year==2019].sort_values('amount').amount\n",
"result['ntile'] = pd.qcut(x, q=3, labels=False)\n",
"result.sort_values(['year', 'amount'], inplace=True)\n",
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" year name amount ntile\n",
"4 2018 Yin Yang 30000.0 NaN\n",
"1 2018 Jane Johnson 110000.0 NaN\n",
"2 2018 John Doe 120000.0 NaN\n",
"0 2018 Jack Daniel 150000.0 NaN\n",
"3 2018 Stephane Heady 200000.0 NaN\n",
"9 2019 Yin Yang 25000.0 0.0\n",
"6 2019 Jane Johnson 130000.0 0.0\n",
"7 2019 John Doe 150000.0 1.0\n",
"5 2019 Jack Daniel 180000.0 1.0\n",
"8 2019 Stephane Heady 270000.0 2.0\n"
]
}
],
"source": [
"result2 = df.copy()\n",
"x = result2[result2.year == 2019].sort_values('amount').amount\n",
"result2['ntile'] = ntile(x, 3)\n",
"result2.sort_values(['year', 'amount'], inplace=True)\n",
"print(result2)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment