Created
May 12, 2024 07:48
-
-
Save peeush-agarwal/0708916b2cd40f44c75cafc28063f2f8 to your computer and use it in GitHub Desktop.
Handle multicollinearity using VIF by iteratively dropping the column with the highest VIF value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def multicollinearity_by_vif(X, vif=5):
    """Drop columns from X in place until no column's VIF exceeds ``vif``.

    On each pass the variance inflation factor (VIF) of every remaining
    column is computed, and the single column with the highest VIF is
    dropped if that value is above the threshold. The loop stops as soon as
    the highest VIF is within the threshold or fewer than two columns remain.

    Parameters:
        X: pandas DataFrame of predictors, excluding the target variable.
            Its values are handed to statsmodels unchanged; this function
            does not add an intercept column.
        vif: int or float, the largest VIF a column is allowed to keep.

    Returns:
        None. X is modified in place.

    Note:
        Progress (remaining columns and each removal) is printed to stdout.
    """
    # A VIF regresses one column on the others, so with fewer than two
    # columns there is nothing to compare; return before the heavier imports.
    if X.shape[1] < 2:
        return

    # Imported locally so the function does not depend on a module-level
    # ``pd`` that the surrounding file may not define.
    import pandas as pd
    from statsmodels.stats.outliers_influence import variance_inflation_factor

    # At most one column is removed per pass, so the starting column count
    # bounds the number of passes.
    for iteration in range(X.shape[1]):
        if X.shape[1] < 2:
            break

        # View which columns are left
        print(f"Columns remaining at iteration {iteration}: {X.columns}")

        # Materialise the array once per pass instead of once per column.
        values = X.values
        scores = pd.Series(
            [variance_inflation_factor(values, col) for col in range(values.shape[1])],
            index=X.columns,
        ).sort_values(ascending=False)

        worst_name = scores.index[0]
        worst_vif = scores.iloc[0]

        # Written as ``>`` with an else-branch (not ``<=`` + break) so that a
        # NaN top score stops the loop instead of dropping a column.
        if worst_vif > vif:
            X.drop(worst_name, axis=1, inplace=True)
            print('Removed: ', worst_name,', VIF: ', worst_vif)
        else:
            break
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment