Gist: peeush-agarwal/0708916b2cd40f44c75cafc28063f2f8
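Handle multicollinearity by computing the variance inflation factor (VIF) of every feature and repeatedly dropping the column with the highest VIF until no remaining column exceeds the threshold.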
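For reference, the VIF of column j is 1 / (1 - R_j^2), where R_j^2 is the coefficient of determination from regressing column j on all the other columns. The sketch below computes that textbook (centered) VIF with plain NumPy, assuming an intercept is included in each auxiliary regression; the helper name `vif_scores` is hypothetical. statsmodels' `variance_inflation_factor` regresses on the other columns exactly as supplied, so its values should agree with these for the non-constant columns only when X already contains a constant column; without one, statsmodels reports an uncentered R^2 and the numbers can differ.

```python
import numpy as np


def vif_scores(X):
    """Return the VIF of every column of the 2-D array X (hypothetical helper).

    VIF_j = 1 / (1 - R_j^2), where R_j^2 comes from regressing column j
    on all other columns plus an intercept (assumption: centered VIF).
    """
    X = np.asarray(X, dtype=float)
    n, k = X.shape
    scores = np.empty(k)
    for j in range(k):
        y = X[:, j]
        # Design matrix: an intercept column plus every column except j
        others = np.column_stack([np.ones(n), np.delete(X, j, axis=1)])
        coef, *_ = np.linalg.lstsq(others, y, rcond=None)
        resid = y - others @ coef
        # Centered R^2: residual sum of squares relative to variation around the mean
        r2 = 1.0 - (resid @ resid) / np.sum((y - y.mean()) ** 2)
        # Guard against division by zero when the column is perfectly collinear
        scores[j] = np.inf if r2 >= 1.0 else 1.0 / (1.0 - r2)
    return scores


rng = np.random.default_rng(0)
a = rng.normal(size=200)
b = rng.normal(size=200)
c = a + 0.1 * rng.normal(size=200)  # c is nearly a copy of a
print(vif_scores(np.column_stack([a, b, c])))
```

In this synthetic data, `a` and `c` should both get a large VIF (roughly 100) while the independent `b` stays close to 1, which is exactly the pattern the function below reacts to.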
import pandas as pd
from statsmodels.stats.outliers_influence import variance_inflation_factor


def multicollinearity_by_vif(X, vif=5):
    """Drop columns from X, one at a time, while the highest VIF exceeds `vif`.

    Parameters:
        X: pandas DataFrame of numeric features, excluding the target variable
        vif: int or float, the VIF threshold; the column with the highest VIF
             is removed for as long as that VIF is greater than this value

    Note:
        This function modifies X in place (columns are dropped from the
        DataFrame that is passed in) and returns None.
    """
    iteration = 0
    # Stop once one column is left: there is nothing for it to be collinear with
    while X.shape[1] > 1:
        # View which columns are left
        print(f"Columns remaining at iteration {iteration}: {list(X.columns)}")
        # Recalculate the VIF of every remaining column after each drop
        vifs = [variance_inflation_factor(X.values, j) for j in range(X.shape[1])]
        s = pd.Series(data=vifs, index=X.columns).sort_values(ascending=False)
        # If the highest VIF is above the threshold, eliminate that column
        if s.iloc[0] > vif:
            X.drop(columns=s.index[0], inplace=True)
            print(f"Removed: {s.index[0]}, VIF: {s.iloc[0]}")
        else:
            break
        iteration += 1