@Swarchal
Last active January 30, 2022 04:34
remove redundant columns in pandas dataframe
import pandas as pd
import numpy as np


def find_correlation(data, threshold=0.9, remove_negative=False):
    """
    Given a numeric pd.DataFrame, this will find highly correlated features,
    and return a list of features to remove.

    Parameters
    ----------
    data : pandas DataFrame
        numeric DataFrame; each column is a feature
    threshold : float
        correlation threshold, will remove one of pairs of features with a
        correlation greater than this value.
    remove_negative : bool
        If true then features which are highly negatively correlated will
        also be returned for removal.

    Returns
    -------
    select_flat : list
        list of column names to be removed
    """
    corr_mat = data.corr()
    if remove_negative:
        corr_mat = np.abs(corr_mat)
    # keep only the strictly lower triangle so each pair is considered once
    corr_mat.loc[:, :] = np.tril(corr_mat, k=-1)
    already_in = set()
    result = []
    for col in corr_mat:
        # row labels whose correlation with `col` exceeds the threshold
        perfect_corr = corr_mat[col][corr_mat[col] > threshold].index.tolist()
        if perfect_corr and col not in already_in:
            already_in.update(set(perfect_corr))
            perfect_corr.append(col)
            result.append(perfect_corr)
    # keep the first feature of each correlated group, flag the rest
    select_nested = [f[1:] for f in result]
    select_flat = [i for j in select_nested for i in j]
    return select_flat
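To make the lower-triangle step concrete, here is a minimal sketch on a toy DataFrame. The column names `a`, `b`, `c` and their values are made up for illustration, and the `pairs` list is a hypothetical helper, not part of the gist. It only shows which (row, column) entries of the masked correlation matrix exceed the threshold.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: b is an exact multiple of a, c is only weakly related
df = pd.DataFrame({
    "a": [1.0, 2.0, 3.0, 4.0, 5.0],
    "b": [2.0, 4.0, 6.0, 8.0, 10.0],
    "c": [5.0, 1.0, 4.0, 2.0, 3.0],
})

threshold = 0.9
corr = df.corr()

# Zero out the diagonal and upper triangle so each pair appears only once
lower = pd.DataFrame(np.tril(corr, k=-1), index=corr.index, columns=corr.columns)

# Collect (row, column) labels whose correlation exceeds the threshold
pairs = [
    (row, col)
    for col in lower.columns
    for row in lower.index
    if lower.loc[row, col] > threshold
]
print(pairs)
```

Only the `b`/`a` pair passes the threshold here. Calling `find_correlation(df)` on the same frame should flag `a` for removal, and the result could then be passed to `df.drop(columns=...)`.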
@ziqueiros
I think this code has a severe bug. The sample from elvinaqa in this same comment thread works better.
