@flcong
Last active March 26, 2024 02:24
Fast Split-Apply-Combine using Numba
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "N8gBDmMgQUK7"
},
"source": [
"# Fast Split-Apply-Combine in Pandas using Numba\n",
"\n",
"Split-Apply-Combine is a common pattern in data analysis. Pandas implements it with `groupby()` followed by `apply()`, which lets users apply a custom function to each group. However, as the data grow, the performance of this approach becomes unsatisfactory.\n",
"\n",
"In this short notebook, I explore step by step how to improve the performance of this operation. We start from the naive approach, try several ways to speed it up, and finally arrive at numba.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LGkerJJw-rHR"
},
"source": [
"\n",
"## Generate test data set\n",
"\n",
"First, let's generate a test data set for our subsequent analysis:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 423
},
"id": "-WLZrqMKRPti",
"outputId": "f146866a-2169-4b6e-f134-483647931aec"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>firm</th>\n",
" <th>bond</th>\n",
" <th>date</th>\n",
" <th>e</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.196340</td>\n",
" <td>0.393604</td>\n",
" <td>9.132385</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>-0.170759</td>\n",
" <td>0.135544</td>\n",
" <td>6.184681</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>0.361699</td>\n",
" <td>0.974308</td>\n",
" <td>15.104778</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0.005268</td>\n",
" <td>0.307555</td>\n",
" <td>8.080814</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>-1.307258</td>\n",
" <td>0.422590</td>\n",
" <td>7.918642</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>149995</th>\n",
" <td>49</td>\n",
" <td>29</td>\n",
" <td>95</td>\n",
" <td>-0.477026</td>\n",
" <td>0.471655</td>\n",
" <td>9.239529</td>\n",
" </tr>\n",
" <tr>\n",
" <th>149996</th>\n",
" <td>49</td>\n",
" <td>29</td>\n",
" <td>96</td>\n",
" <td>-0.527001</td>\n",
" <td>0.431524</td>\n",
" <td>8.788244</td>\n",
" </tr>\n",
" <tr>\n",
" <th>149997</th>\n",
" <td>49</td>\n",
" <td>29</td>\n",
" <td>97</td>\n",
" <td>-1.330590</td>\n",
" <td>0.857821</td>\n",
" <td>12.247621</td>\n",
" </tr>\n",
" <tr>\n",
" <th>149998</th>\n",
" <td>49</td>\n",
" <td>29</td>\n",
" <td>98</td>\n",
" <td>-0.531367</td>\n",
" <td>0.178954</td>\n",
" <td>6.258172</td>\n",
" </tr>\n",
" <tr>\n",
" <th>149999</th>\n",
" <td>49</td>\n",
" <td>29</td>\n",
" <td>99</td>\n",
" <td>-0.591134</td>\n",
" <td>0.046371</td>\n",
" <td>4.872573</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>149000 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" firm bond date e x y\n",
"0 0 0 0 0.196340 0.393604 9.132385\n",
"1 0 0 1 -0.170759 0.135544 6.184681\n",
"2 0 0 2 0.361699 0.974308 15.104778\n",
"3 0 0 3 0.005268 0.307555 8.080814\n",
"4 0 0 4 -1.307258 0.422590 7.918642\n",
"... ... ... ... ... ... ...\n",
"149995 49 29 95 -0.477026 0.471655 9.239529\n",
"149996 49 29 96 -0.527001 0.431524 8.788244\n",
"149997 49 29 97 -1.330590 0.857821 12.247621\n",
"149998 49 29 98 -0.531367 0.178954 6.258172\n",
"149999 49 29 99 -0.591134 0.046371 4.872573\n",
"\n",
"[149000 rows x 6 columns]"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from numba import njit\n",
"import statsmodels.api as sm\n",
"\n",
"# Create the test data set\n",
"nfirms = 50\n",
"nbonds = 30\n",
"ndates = 100\n",
"df = pd.DataFrame(\n",
" {'e': np.random.randn(nfirms*nbonds*ndates),\n",
" 'x': np.random.rand(nfirms*nbonds*ndates)}, \n",
" pd.MultiIndex.from_product(\n",
" [np.arange(nfirms), np.arange(nbonds), np.arange(ndates)],\n",
" names=['firm', 'bond', 'date']\n",
" )\n",
").reset_index(drop=False)\n",
"df['y'] = 5 + df['x'] * 10 + df['e']\n",
"np.random.seed(0)\n",
"# Randomly drop observations\n",
"dropmask = np.random.choice(df.index, size=1000, replace=False)\n",
"df = df.drop(index=dropmask)\n",
"# Randomly set x as NA\n",
"xnamask = np.random.choice(df.index, size=1000, replace=False)\n",
"df.loc[xnamask, 'x'] = np.nan\n",
"# Randomly set y as NA\n",
"ynamask = np.random.choice(df.index, size=1000, replace=False)\n",
"df.loc[ynamask, 'y'] = np.nan\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aYbJGSna9zv_"
},
"source": [
"The columns `firm`, `bond`, and `date` uniquely identify each row (primary keys). You can think of the data set as corporate bond market data, such as bond prices. To be as realistic as possible, I randomly drop some rows to make the panel unbalanced; for example, some firms may then have more bonds than others on a given date. I also randomly set some values of columns `x` and `y` to `np.nan`, since bond prices may be missing in practice.\n",
"\n",
"Our goal is to estimate an OLS regression (with constant) of column `y` on column `x` for each firm on each date and obtain the intercept and coefficient on `x`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MXV1wepP-zku"
},
"source": [
"## Naive approach\n",
"\n",
"We first consider the naive approach using the `groupby()` and `apply()` methods from pandas. To estimate the OLS regression, we first use the statsmodels package.\n",
"\n",
"We define the following function, which accepts the partitioned DataFrame (one group) as input. Before running the regression, we drop any row in which `y` or `x` is missing. We require at least 10 observations to estimate the regression."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "TCJSKx4Bz_gp"
},
"outputs": [],
"source": [
"# Use statsmodels (version 1)\n",
"def ols_sm_v1(data):\n",
"    # Remove NA\n",
"    datatmp = data[['y', 'x']].dropna(how='any').assign(const=1)\n",
"    N = datatmp.shape[0]\n",
"    if N >= 10:\n",
"        beta = sm.OLS(datatmp['y'], datatmp[['const', 'x']]).fit().params\n",
"        return beta\n",
"    else:\n",
"        return pd.Series([np.nan, np.nan], index=['const', 'x'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's see the performance of this approach:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 10.8 s\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>const</th>\n",
" <th>x</th>\n",
" </tr>\n",
" <tr>\n",
" <th>firm</th>\n",
" <th>date</th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th rowspan=\"5\" valign=\"top\">0</th>\n",
" <th>0</th>\n",
" <td>5.421995</td>\n",
" <td>9.379899</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5.157152</td>\n",
" <td>9.828529</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.691264</td>\n",
" <td>10.590534</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.470283</td>\n",
" <td>10.442855</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4.453084</td>\n",
" <td>10.911452</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th rowspan=\"5\" valign=\"top\">49</th>\n",
" <th>95</th>\n",
" <td>5.410526</td>\n",
" <td>8.602204</td>\n",
" </tr>\n",
" <tr>\n",
" <th>96</th>\n",
" <td>5.159938</td>\n",
" <td>10.251619</td>\n",
" </tr>\n",
" <tr>\n",
" <th>97</th>\n",
" <td>4.859723</td>\n",
" <td>9.689384</td>\n",
" </tr>\n",
" <tr>\n",
" <th>98</th>\n",
" <td>4.643058</td>\n",
" <td>10.934716</td>\n",
" </tr>\n",
" <tr>\n",
" <th>99</th>\n",
" <td>5.197897</td>\n",
" <td>9.376050</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5000 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" const x\n",
"firm date \n",
"0 0 5.421995 9.379899\n",
" 1 5.157152 9.828529\n",
" 2 4.691264 10.590534\n",
" 3 4.470283 10.442855\n",
" 4 4.453084 10.911452\n",
"... ... ...\n",
"49 95 5.410526 8.602204\n",
" 96 5.159938 10.251619\n",
" 97 4.859723 9.689384\n",
" 98 4.643058 10.934716\n",
" 99 5.197897 9.376050\n",
"\n",
"[5000 rows x 2 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time res_sm_v1 = df.groupby(['firm', 'date']).apply(ols_sm_v1)\n",
"res_sm_v1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**About 11 sec!** This seems unacceptably slow!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"## First attempt to improve: Use numpy array\n",
"\n",
"Let's think about how to improve the method without substantial changes. In `ols_sm_v1`, the input is a pandas DataFrame and the output is a pandas Series containing the regression coefficients. **Maybe we can rewrite the function to accept a numpy array as input and return a numpy array, so that the function does not use pandas at all!**\n",
"\n",
"Following this idea, we rewrite the function as follows. The input is a 2d numpy array with two columns: the first is `y` and the second is `x`. We drop rows with missing values and use statsmodels to estimate the regression. Finally, we return the coefficients as a 1d numpy array of size 2."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Use statsmodels (version 2)\n",
"def ols_sm_v2(data):\n",
"    \"\"\"Single-variable OLS regression with constant using statsmodels.\n",
"    The first column is y and the second column is x.\"\"\"\n",
"    # Remove NA\n",
"    nnanmask = ~np.isnan(data).any(axis=1)\n",
"    ytmp = data[nnanmask, 0]\n",
"    xtmp = data[nnanmask, 1]\n",
"    N = ytmp.shape[0]\n",
"    if N >= 10:\n",
"        X = np.column_stack((np.ones(N), xtmp))\n",
"        beta = sm.OLS(ytmp, X).fit().params\n",
"        return beta\n",
"    else:\n",
"        return np.nan * np.zeros(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also need to change the groupby-apply code: only the numpy array containing `y` and `x` is passed to the function. In addition, since the function returns a numpy array, its elements are not automatically expanded into separate columns, so we have to do this ourselves."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 2.86 s\n"
]
}
],
"source": [
"%%time \n",
"res_sm_v2 = df.groupby(['firm', 'date']).apply(lambda x: ols_sm_v2(x[['y', 'x']].to_numpy()))\n",
"res_sm_v2 = pd.DataFrame(\n",
" [x for x in res_sm_v2],\n",
" index=res_sm_v2.index,\n",
" columns=['const', 'x']\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**About 2.9 sec! A substantial improvement!** "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we check if the results from the above two methods are indeed the same (very important to check):"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"pd.testing.assert_frame_equal(res_sm_v1, res_sm_v2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Second attempt: Write OLS function\n",
"\n",
"Now let's see if we can improve further. The `OLS` function from statsmodels gives us not only the regression coefficients but also other estimates, such as standard errors and R-squared. In the current situation, **since we only want the OLS coefficients, using statsmodels may add extra overhead**. Hence, instead of using the package, we can compute the coefficients ourselves with linear algebra. Recall the OLS estimate:\n",
"$$\n",
" \\beta=(X'X)^{-1}X'y\n",
"$$\n",
"\n",
"Hence, we rewrite the function as follows:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def ols_la(data):\n",
"    \"\"\"Single-variable OLS regression with constant using numpy.\n",
"    The first column is y and the second column is x.\"\"\"\n",
"    # Remove NA\n",
"    nnanmask = ~np.isnan(data).any(axis=1)\n",
"    ytmp = data[nnanmask, 0]\n",
"    xtmp = data[nnanmask, 1]\n",
"    N = ytmp.shape[0]\n",
"    if N >= 10:\n",
"        X = np.column_stack((np.ones(N), xtmp))\n",
"        XX = X.T @ X\n",
"        beta = np.linalg.inv(XX) @ X.T @ ytmp\n",
"        return beta\n",
"    else:\n",
"        return np.nan * np.zeros(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we run it with the same groupby-apply code as before:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 2.02 s\n"
]
}
],
"source": [
"%%time \n",
"res_la = df.groupby(['firm', 'date']).apply(lambda x: ols_la(x[['y', 'x']].to_numpy()))\n",
"res_la = pd.DataFrame(\n",
" [x for x in res_la],\n",
" index=res_la.index,\n",
" columns=['const', 'x']\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**2 sec!** Not a substantial improvement, but still faster. We also check if the result is correct:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"pd.testing.assert_frame_equal(res_sm_v1, res_la)"
]
},
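{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Aside (not part of the timing comparison):* explicitly forming $(X'X)^{-1}$ is generally discouraged for numerical reasons. A common alternative is to solve the normal equations $X'X\\beta = X'y$ directly with `np.linalg.solve`. The sketch below uses a hypothetical function `ols_solve` that mirrors `ols_la` (same missing-value handling and 10-observation rule). It checks `ols_solve` against `np.linalg.lstsq` on one synthetic group drawn like the test data ($y = 5 + 10x + e$), and the last line should evaluate to `True`. Numba's documentation lists `np.linalg.solve` among the supported numpy functions, so the same change could presumably be made in the numba version below, although its speed is not compared here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def ols_solve(data):\n",
"    # Hypothetical variant of ols_la: same NA handling and minimum sample size,\n",
"    # but solves (X'X) beta = X'y instead of inverting X'X\n",
"    nnanmask = ~np.isnan(data).any(axis=1)\n",
"    ytmp = data[nnanmask, 0]\n",
"    xtmp = data[nnanmask, 1]\n",
"    N = ytmp.shape[0]\n",
"    if N >= 10:\n",
"        X = np.column_stack((np.ones(N), xtmp))\n",
"        return np.linalg.solve(X.T @ X, X.T @ ytmp)\n",
"    else:\n",
"        return np.full(2, np.nan)\n",
"\n",
"# Synthetic group with y = 5 + 10x + e (30 bonds, as in the test data)\n",
"rng = np.random.default_rng(0)\n",
"xs = rng.random(30)\n",
"ys = 5 + 10 * xs + rng.standard_normal(30)\n",
"Xs = np.column_stack((np.ones(30), xs))\n",
"beta_lstsq = np.linalg.lstsq(Xs, ys, rcond=None)[0]\n",
"np.allclose(ols_solve(np.column_stack((ys, xs))), beta_lstsq)"
]
},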
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Third attempt: Use numba to speed up\n",
"\n",
"You may have heard that python is slow but that numba can speed it up. I will not cover how numba works (I am not an expert on this anyway), but you can find more information on its [website](https://numba.readthedocs.io/en/stable/index.html). From my perspective, it is enough to know the following and refer to the documentation when necessary:\n",
"\n",
"* Use `from numba import njit` to import the `njit` decorator.\n",
"* Put `@njit` before the function definition.\n",
"* The function is compiled on its first call, so the first run takes longer, but subsequent calls are much faster.\n",
"* Numba supports only a subset of python and numpy (**yes, pandas is not supported**). For the list of supported features, refer to the [website](https://numba.readthedocs.io/en/stable/index.html). This is the main drawback of numba.\n",
"* Still, numba supports the core functions of numpy, so you don't have to reinvent the wheel (an advantage over cython).\n",
"* When the function raises errors related to unknown/mismatched types, mismatched function signatures, etc., it is highly likely that you used features not supported by numba.\n",
"* To write numba functions, forget about vectorization and other performance tricks for python or numpy. Just write \"low-level\" code as in C/C++: don't hesitate to use loops.\n",
"\n",
"Now, we rewrite the function as follows. Since numba does not support the `axis` argument of numpy's `any()`, we have to rewrite the missing-value check."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NNK4RGFq4Mgp",
"outputId": "33d774a5-567e-4c1a-e850-e61bce205110"
},
"outputs": [],
"source": [
"@njit\n",
"def ols_njit(data):\n",
"    \"\"\"Single-variable OLS regression with constant.\n",
"    The first column is y and the second column is x.\"\"\"\n",
"    # Remove NA\n",
"    nnanmask = (~np.isnan(data[:,0])) & (~np.isnan(data[:,1]))\n",
"    ytmp = data[nnanmask, 0]\n",
"    xtmp = data[nnanmask, 1]\n",
"    N = ytmp.shape[0]\n",
"    if N >= 10:\n",
"        X = np.column_stack((np.ones(N), xtmp))\n",
"        XX = X.T @ X\n",
"        beta = np.linalg.inv(XX) @ X.T @ ytmp\n",
"        return beta\n",
"    else:\n",
"        return np.nan * np.zeros(2)\n",
"# First run, to compile\n",
"res_njit = df.groupby(['firm', 'date']).apply(lambda x: ols_njit(x[['y', 'x']].to_numpy()))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 1.8 s\n"
]
}
],
"source": [
"%%time \n",
"res_njit = df.groupby(['firm', 'date']).apply(lambda x: ols_njit(x[['y', 'x']].to_numpy()))\n",
"res_njit = pd.DataFrame(\n",
" [x for x in res_njit],\n",
" index=res_njit.index,\n",
" columns=['const', 'x']\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**About 1.8 sec!** A slight improvement. We also check if the result is correct:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"pd.testing.assert_frame_equal(res_sm_v1, res_njit)"
]
},
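{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the compile-on-first-call behavior from the numba list above in isolation, here is a small sketch with a hypothetical loop-based function `sum_of_squares`, which is not used elsewhere in this notebook. Exact timings depend on the machine, but the first call should be noticeably slower because it includes compilation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import numpy as np\n",
"from numba import njit\n",
"\n",
"@njit\n",
"def sum_of_squares(a):\n",
"    # A plain C-style loop: slow in pure python, fast once compiled\n",
"    total = 0.0\n",
"    for i in range(a.shape[0]):\n",
"        total += a[i] * a[i]\n",
"    return total\n",
"\n",
"a = np.random.rand(1_000_000)\n",
"t0 = time.perf_counter()\n",
"sum_of_squares(a)  # first call: compiles for float64 arrays, then runs\n",
"t1 = time.perf_counter()\n",
"sum_of_squares(a)  # second call: reuses the compiled version\n",
"t2 = time.perf_counter()\n",
"print(f'first call: {t1 - t0:.4f} s, second call: {t2 - t1:.4f} s')"
]
},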
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Final attempt: Use numba to groupby and apply\n",
"\n",
"Although we have improved a lot from the naive approach, I still have a hunch that there is more room to improve. Since we have found evidence that pandas is much slower than numpy, the performance could be improved further if we rewrite groupby and apply using numpy, powered by numba.\n",
"\n",
"To do that, we first sort the data set based on `firm` and `date`. Then, we use the `ngroup()` method to generate integers (starting from 0) that identify each group:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>firm</th>\n",
" <th>bond</th>\n",
" <th>date</th>\n",
" <th>e</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>grpid</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.196340</td>\n",
" <td>0.393604</td>\n",
" <td>9.132385</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>100</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>-0.265015</td>\n",
" <td>0.221148</td>\n",
" <td>6.946461</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>200</th>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0.206135</td>\n",
" <td>0.699376</td>\n",
" <td>12.199894</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>300</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>1.239965</td>\n",
" <td>0.933079</td>\n",
" <td>15.570760</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>400</th>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>-0.781574</td>\n",
" <td>0.645731</td>\n",
" <td>10.675739</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>149599</th>\n",
" <td>49</td>\n",
" <td>25</td>\n",
" <td>99</td>\n",
" <td>1.848271</td>\n",
" <td>0.805096</td>\n",
" <td>14.899233</td>\n",
" <td>4999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>149699</th>\n",
" <td>49</td>\n",
" <td>26</td>\n",
" <td>99</td>\n",
" <td>-0.292031</td>\n",
" <td>0.830390</td>\n",
" <td>13.011868</td>\n",
" <td>4999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>149799</th>\n",
" <td>49</td>\n",
" <td>27</td>\n",
" <td>99</td>\n",
" <td>-0.619243</td>\n",
" <td>0.410581</td>\n",
" <td>8.486571</td>\n",
" <td>4999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>149899</th>\n",
" <td>49</td>\n",
" <td>28</td>\n",
" <td>99</td>\n",
" <td>0.386612</td>\n",
" <td>0.068111</td>\n",
" <td>6.067720</td>\n",
" <td>4999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>149999</th>\n",
" <td>49</td>\n",
" <td>29</td>\n",
" <td>99</td>\n",
" <td>-0.591134</td>\n",
" <td>0.046371</td>\n",
" <td>4.872573</td>\n",
" <td>4999</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>149000 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" firm bond date e x y grpid\n",
"0 0 0 0 0.196340 0.393604 9.132385 0\n",
"100 0 1 0 -0.265015 0.221148 6.946461 0\n",
"200 0 2 0 0.206135 0.699376 12.199894 0\n",
"300 0 3 0 1.239965 0.933079 15.570760 0\n",
"400 0 4 0 -0.781574 0.645731 10.675739 0\n",
"... ... ... ... ... ... ... ...\n",
"149599 49 25 99 1.848271 0.805096 14.899233 4999\n",
"149699 49 26 99 -0.292031 0.830390 13.011868 4999\n",
"149799 49 27 99 -0.619243 0.410581 8.486571 4999\n",
"149899 49 28 99 0.386612 0.068111 6.067720 4999\n",
"149999 49 29 99 -0.591134 0.046371 4.872573 4999\n",
"\n",
"[149000 rows x 7 columns]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = df.sort_values(['firm', 'date'])\n",
"# Integer to identify each group\n",
"df['grpid'] = df.groupby(['firm', 'date']).ngroup()\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we define the following numba function to apply a function `func` on a 2d numpy array by group."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"@njit\n",
"def groupby_apply_njit(data, func):\n",
"    \"\"\"The first column of data is group id. The second is y and the third is x.\"\"\"\n",
"    ngroups = int(data[-1,0])+1 # Number of groups\n",
"    nrows = data.shape[0] # Number of rows\n",
"    reslist = []\n",
"    istart = 0\n",
"    for k in range(ngroups):\n",
"        # Find start and end rows of the group\n",
"        # (istart points to the first row and iend-1 to the last row)\n",
"        iend = istart + 1\n",
"        while iend < nrows and data[iend-1,0] == data[iend,0]:\n",
"            iend += 1\n",
"        # Apply the function to the numpy array in the group\n",
"        res = func(data[istart:iend,1:])\n",
"        reslist.append(np.hstack((np.array([k]), res)))\n",
"        # Move to the next group\n",
"        istart = iend\n",
"    return reslist\n",
"# First run to compile\n",
"resarr = groupby_apply_njit(df[['grpid', 'y', 'x']].to_numpy(), ols_njit)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first argument is a 2d numpy array whose first column contains the group ids and whose second and third columns contain `y` and `x`, respectively. Because the data are sorted by `firm` and `date`, the ids produced by `ngroup()` run from 0 to the number of groups minus 1 in non-decreasing order. The number of groups is therefore the last element of the first column plus 1. We then loop over the groups. In the loop, we find the start and end rows of each group by checking where the group id changes: `istart` points to the first row of the group and `iend` points to the row just below its last row. Finally, we extract the rows of that group, excluding the first column, and pass them to the function `func`.\n",
"\n",
"In this case, the output from `groupby_apply_njit` is a list of 1d numpy arrays, one per group. In each array, the first element is the group id and the second and third are the regression coefficients. `pd.DataFrame` stacks this list into a table, and a bit more code converts it into the same nicely indexed DataFrame as before."
]
},
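{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `while` loop above finds group boundaries by scanning for the rows where the group id changes. For intuition only, the sketch below computes the same `(istart, iend)` pairs with vectorized numpy on a small, hypothetical sorted id column. It is an illustration of what the loop finds, not a replacement for it. It should print one line per group: `0 0 3`, `1 3 5`, and `2 5 9`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# A small, sorted group-id column, as produced by sort_values + ngroup()\n",
"grpid = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])\n",
"# A group starts at row 0 and wherever the id differs from the previous row\n",
"starts = np.concatenate(([0], np.flatnonzero(np.diff(grpid)) + 1))\n",
"# Each group ends (exclusive) where the next one starts, or after the last row\n",
"ends = np.append(starts[1:], grpid.size)\n",
"for k, (istart, iend) in enumerate(zip(starts, ends)):\n",
"    print(k, istart, iend)"
]
},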
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 35.5 ms\n"
]
}
],
"source": [
"%%time \n",
"resarr = groupby_apply_njit(df[['grpid', 'y', 'x']].to_numpy(), ols_njit)\n",
"res_ga_njit = df[['firm', 'date', 'grpid']].drop_duplicates().merge(\n",
" pd.DataFrame(resarr, columns=['grpid', 'const', 'x']), \n",
" on='grpid', \n",
" how='left'\n",
").set_index(['firm', 'date']).drop(columns=['grpid'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Only about 0.036 sec** (excluding the one-time compilation in the previous cell)! This is a big improvement. We also check if the result is correct:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"pd.testing.assert_frame_equal(res_sm_v1, res_ga_njit)"
]
},
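{
"cell_type": "markdown",
"metadata": {},
"source": [
"The steps of this final approach (sort, number the groups, apply `groupby_apply_njit`, and merge the keys back) can be bundled into one reusable function, along the lines suggested in the Extension section below. Here is a minimal sketch. It uses a hypothetical name `groupby_ols_njit` and assumes that `by` is a list of key columns without missing values, that `y` and `x` are numeric, and that `groupby_apply_njit` and `ols_njit` are defined as above. If it reproduces the steps above faithfully, the final assertion should pass."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def groupby_ols_njit(df, by, ycol='y', xcol='x'):\n",
"    # Hypothetical wrapper: sort by the keys so each group occupies consecutive rows\n",
"    dfs = df.sort_values(by).copy()\n",
"    # ngroup() numbers the groups 0, 1, ... in sorted key order\n",
"    dfs['grpid'] = dfs.groupby(by).ngroup()\n",
"    resarr = groupby_apply_njit(dfs[['grpid', ycol, xcol]].to_numpy(), ols_njit)\n",
"    # Map group ids back to the key columns, as in the cell above\n",
"    return dfs[by + ['grpid']].drop_duplicates().merge(\n",
"        pd.DataFrame(resarr, columns=['grpid', 'const', xcol]),\n",
"        on='grpid',\n",
"        how='left'\n",
"    ).set_index(by).drop(columns=['grpid'])\n",
"\n",
"res_wrapped = groupby_ols_njit(df, ['firm', 'date'])\n",
"pd.testing.assert_frame_equal(res_sm_v1, res_wrapped)"
]
},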
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Takeaways\n",
"\n",
"The takeaways from the above exercise are as follows:\n",
"\n",
"* For simple calculations, using third-party packages may introduce extra overhead.\n",
"* Applying a function to a GroupBy object is faster if the function does not use pandas at all (only numpy).\n",
"* Numba is faster and especially suitable for numerical calculations, but its coding style differs from typical python (it is closer to C/C++), and it supports only a subset of python and numpy.\n",
"* Writing our own numba function to implement groupby-apply gives much better performance.\n",
"\n",
"## Extension\n",
"\n",
"Some extensions for further exploration:\n",
"* In the above example, the applied function takes a numpy array of `float` as input and returns a numpy array of `float` as well. How can we change the code if the function is more general, e.g., if it takes strings as input? Does that affect the performance?\n",
"* You may want to wrap the above code into a function that can be re-used."
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Untitled0.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@Moofeng

Moofeng commented Mar 7, 2023

Thank you very much. It's just what I need! Looking forward to the extensions

@flcong
Author

flcong commented Mar 11, 2023

I've made public some functions that I frequently use in the pyempfin package, where I implement a function for fast groupby-apply (https://github.com/flcong/pyempfin/blob/72482bb912bca8017749d2cc47b2647f8cf27e8d/pyempfin/datautils.py#L483). However, the package is not yet fully tested or well documented, so treat it mainly as a reference.
