Skip to content

Instantly share code, notes, and snippets.

@monk1337
Created August 22, 2021 12:21
Show Gist options
  • Save monk1337/df510c601f61ff48383d7834d3391a4d to your computer and use it in GitHub Desktop.
Save monk1337/df510c601f61ff48383d7834d3391a4d to your computer and use it in GitHub Desktop.
def get_adj_matrix(df, ml_format=True):
    """Build a label co-occurrence (adjacency) matrix from one-hot labels.

    Args:
        df: pandas DataFrame holding one-hot encoded labels, one column per
            label (presumably one row per sample -- confirm against caller).
        ml_format: when true, return a dict bundling the adjacency matrix
            with the per-label frequencies; otherwise return only the matrix.

    Returns:
        If ``ml_format`` is true, a dict with two entries:
            ``'adj'``  -- square ``numpy`` array with one row/column per
                          column of ``df``; entry ``[i, j]`` is the dot
                          product of columns ``i`` and ``j`` and the
                          diagonal is forced to zero.
            ``'nums'`` -- 1-D ``numpy`` array with, for each column, the
                          number of entries equal to 1.
        Otherwise just the ``'adj'`` array.
    """
    n_labels = df.shape[1]

    # df.T @ df gives pairwise column dot products; the diagonal holds each
    # label's product with itself, which is masked out so that only
    # cross-label co-occurrences remain.
    off_diagonal = ~np.eye(n_labels, dtype=bool)
    adj_m = df.T.dot(df).to_numpy() * off_diagonal

    if not ml_format:
        return adj_m

    # Count the 1-entries per column directly. Unlike value_counts()[1],
    # this yields 0 (instead of raising KeyError) for a label that never
    # occurs in the data.
    label_freq = (df == 1).sum(axis=0).to_numpy()
    return {'adj': adj_m, 'nums': label_freq}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment