Created
October 14, 2024 17:18
-
-
Save albertbuchard/84ecb3e77d499aaf31a984aaac8bfb67 to your computer and use it in GitHub Desktop.
Partial Correlation Algorithms Evaluation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import networkx as nx | |
from causallearn.utils.KCI.KCI import KCI_CInd, KCI_UInd | |
from npeet_plus import mi, mi_pvalue | |
from tqdm import tqdm | |
from src.causal_discovery.static_causal_discovery import ( | |
run_causal_discovery, | |
visualize_causal_graph, | |
) | |
def compute_correlations(df):
    """
    Return the Pearson correlation matrix of a DataFrame as a NumPy array.

    Parameters:
    - df (pd.DataFrame): The input data, one variable per column.

    Returns:
    - np.ndarray: Square matrix of pairwise Pearson correlations, with rows
      and columns ordered like ``df.columns``.
    """
    pearson_table = df.corr(method="pearson")
    return pearson_table.to_numpy()
import numpy as np | |
import pandas as pd | |
from sklearn.covariance import GraphicalLassoCV | |
from itertools import combinations | |
from scipy import stats | |
from scipy.stats import spearmanr | |
_COLLIDER_CORRECTION_METHODS = (
    "pearson",
    "spearman",
    "conditional_mi_kci",
    "conditional_mi_nonparametric",
)


def _prepare_correction_column(column, correction_method):
    """Convert a DataFrame column into the array layout the chosen test expects."""
    values = column.to_numpy()
    # The causal-learn KCI tests take 2-D (n_samples, n_features) inputs.
    if correction_method == "conditional_mi_kci":
        return values.reshape(-1, 1)
    return values


def _marginal_dependence_test(x, y, correction_method):
    """
    Run the unconditional (zero-order) dependence test between x and y.

    Returns:
    - (statistic, p_value): The test statistic and its p-value.

    Raises:
    - ValueError: If ``correction_method`` is not supported.
    """
    if correction_method == "pearson":
        statistic, p_value = stats.pearsonr(x, y)
    elif correction_method == "spearman":
        statistic, p_value = spearmanr(x, y)
    elif correction_method == "conditional_mi_kci":
        kci = KCI_UInd(null_ss=1000, approx=True)
        # compute_pvalue returns (p_value, test_statistic), in that order.
        p_value, statistic = kci.compute_pvalue(data_x=x, data_y=y)
    elif correction_method == "conditional_mi_nonparametric":
        statistic, p_value, _ = mi_pvalue(x, y, alpha=0.25)
    else:
        raise ValueError(
            "Invalid correlation method. Choose 'pearson', 'spearman', "
            "'conditional_mi_kci', or 'conditional_mi_nonparametric'."
        )
    return statistic, p_value


def compute_partial_correlations_corrected(
    df,
    method="glasso",
    alpha=0.05,
    correction_method="pearson",
    verbose=False,
    epsilon=1e-6,
    **kwargs,
):
    """
    Compute the partial correlation matrix from a pandas DataFrame,
    detecting and removing edges attributed to collider bias.

    For every pair (X, Y) with a non-negligible partial correlation, the edge
    is removed when X and Y are marginally independent (p > alpha) but become
    dependent (p < alpha) after conditioning on some third variable Z.

    Parameters:
    - df (pd.DataFrame): The input data.
    - method (str): Method passed to ``compute_partial_correlations`` to build
      the initial matrix (e.g. 'glasso', 'pearson', 'spearman',
      'conditional_mi_kci', 'conditional_mi_nonparametric').
    - alpha (float): Significance level for the marginal and conditional tests.
    - correction_method (str): Test used to detect colliders ('pearson',
      'spearman', 'conditional_mi_kci', 'conditional_mi_nonparametric').
    - verbose (bool): Whether to print progress messages and show a progress bar.
    - epsilon (float): Entries whose absolute value is below this are treated
      as zero (and set to exactly 0).
    - **kwargs: Additional keyword arguments forwarded to
      ``compute_partial_correlations``.

    Returns:
    - corrected_partial_corr_matrix (np.ndarray): The partial correlation
      matrix with collider-induced edges set to 0.

    Raises:
    - ValueError: If ``correction_method`` is not supported.
    """
    # Validate up front so a bad option fails before the expensive computation.
    if correction_method not in _COLLIDER_CORRECTION_METHODS:
        raise ValueError(
            "Invalid correlation method. Choose 'pearson', 'spearman', "
            "'conditional_mi_kci', or 'conditional_mi_nonparametric'."
        )

    # Step 1: Compute the partial correlation matrix
    if verbose:
        print("Computing partial correlations...")
    partial_corr_matrix = compute_partial_correlations(
        df, method, verbose=verbose, **kwargs
    )
    corrected_partial_corr_matrix = partial_corr_matrix.copy()

    # Step 2: Detect and correct for colliders
    if verbose:
        print("Correcting for collider bias...")
    n_vars = df.shape[1]
    variables = df.columns.tolist()
    columns = [
        _prepare_correction_column(df.iloc[:, idx], correction_method)
        for idx in range(n_vars)
    ]

    progress = tqdm(
        list(combinations(range(n_vars), 2)),
        desc="Detecting colliders",
        disable=not verbose,
    )
    for i, j in progress:
        # abs(): strongly negative partial correlations are real edges too;
        # only near-zero entries are snapped to exactly 0.
        if abs(corrected_partial_corr_matrix[i, j]) < epsilon:
            corrected_partial_corr_matrix[i, j] = 0
            corrected_partial_corr_matrix[j, i] = 0
            continue
        if verbose:
            progress.set_postfix({"Edge": f"{variables[i]} - {variables[j]}"})

        x, y = columns[i], columns[j]
        # The marginal test depends only on the pair, so run it once per pair.
        _, p_xy = _marginal_dependence_test(x, y, correction_method)
        if p_xy <= alpha:
            # Marginally dependent pair: the collider pattern cannot apply.
            continue

        # Marginally independent: check every other variable as a candidate
        # collider Z, not just those with a higher column index.
        for k in range(n_vars):
            if k in (i, j):
                continue
            _, p_xy_z = partial_corr_test(x, y, columns[k], correction_method)
            if p_xy_z < alpha:
                if verbose:
                    print(
                        f"Collider detected: {variables[i]}, {variables[j]} | {variables[k]}"
                    )
                # Spurious association due to a collider at Z: drop the edge.
                corrected_partial_corr_matrix[i, j] = 0
                corrected_partial_corr_matrix[j, i] = 0
                break
    return corrected_partial_corr_matrix
def _partial_corr_pvalue(r, dof):
    """
    Two-sided p-value for a (partial) correlation coefficient ``r`` with
    ``dof`` residual degrees of freedom, via the Student-t distribution.

    Returns NaN when ``r`` is not finite or ``dof`` is not positive.
    """
    if not np.isfinite(r) or dof <= 0:
        return float("nan")
    if abs(r) >= 1.0:
        return 0.0
    t_stat = r * np.sqrt(dof / (1.0 - r * r))
    return float(2.0 * stats.t.sf(abs(t_stat), dof))


def _as_column_matrix(values):
    """Return ``values`` as a 2-D array, turning a 1-D vector into one column."""
    values = np.asarray(values)
    return values.reshape(-1, 1) if values.ndim == 1 else values


def partial_corr_test(X, Y, Z, correlation_method="pearson"):
    """
    Compute the partial correlation between X and Y, controlling for Z,
    and perform a statistical test.

    For 'pearson' and 'spearman', X and Y are each regressed by OLS on Z plus
    an intercept, and the residuals are correlated. The p-value uses a
    Student-t test with n - 2 - k degrees of freedom, where k is the number
    of columns in Z.

    Parameters:
    - X, Y: pd.Series or np.ndarray. The two variables being tested.
    - Z: pd.Series, np.ndarray or pd.DataFrame. The conditioning variable(s);
      a 1-D input is treated as a single column.
    - correlation_method (str): 'pearson', 'spearman', 'conditional_mi_kci'
      or 'conditional_mi_nonparametric'.

    Returns:
    - r (float): The partial correlation coefficient for 'pearson'/'spearman'.
      For 'conditional_mi_kci' it is the KCI test statistic. For
      'conditional_mi_nonparametric' it is the first value returned by
      ``mi_pvalue``.
    - p_value (float): Two-tailed p-value.

    Raises:
    - ValueError: If ``correlation_method`` is not supported.
    """
    if isinstance(X, pd.Series):
        X = X.values
    if isinstance(Y, pd.Series):
        Y = Y.values
    if isinstance(Z, pd.Series):
        Z = Z.values

    if correlation_method in ("pearson", "spearman"):
        x = np.asarray(X, dtype=float).ravel()
        y = np.asarray(Y, dtype=float).ravel()
        z = np.asarray(Z, dtype=float)
        if z.ndim == 1:
            z = z.reshape(-1, 1)
        n_samples = x.shape[0]
        # The intercept column is essential: without it residuals are not
        # centred, and their correlation is not the partial correlation.
        design = np.column_stack([np.ones(n_samples), z])
        residuals_x = x - design @ np.linalg.lstsq(design, x, rcond=None)[0]
        residuals_y = y - design @ np.linalg.lstsq(design, y, rcond=None)[0]
        if correlation_method == "pearson":
            r, _ = stats.pearsonr(residuals_x, residuals_y)
        else:
            r, _ = spearmanr(residuals_x, residuals_y)
        r = float(r)
        # Each conditioning variable costs one extra degree of freedom.
        p_value = _partial_corr_pvalue(r, n_samples - 2 - z.shape[1])
    elif correlation_method == "conditional_mi_kci":
        kci_test = KCI_CInd(nullss=5000, approx=True)
        # compute_pvalue returns (p_value, test_statistic) and takes 2-D inputs.
        p_value, r = kci_test.compute_pvalue(
            data_x=_as_column_matrix(X),
            data_y=_as_column_matrix(Y),
            data_z=_as_column_matrix(Z),
        )
    elif correlation_method == "conditional_mi_nonparametric":
        r, p_value, _ = mi_pvalue(X, Y, Z, alpha=0.25)
    else:
        raise ValueError(
            "Invalid correlation method. Choose 'pearson', 'spearman', "
            "'conditional_mi_kci', or 'conditional_mi_nonparametric'."
        )
    return r, p_value
def compute_partial_correlation_kci(df, null_ss=5000, alpha=0.05, verbose=False):
    """
    Compute the non-linear partial correlation matrix using KCI tests.

    For each pair of columns (i, j), a KCI_CInd test is run conditioning on
    every other column. When the frame has only two columns there is nothing
    to condition on, so the unconditional KCI_UInd test is used instead.

    Parameters:
    - df: pandas DataFrame with continuous variables.
    - null_ss: Number of samples in simulating the null distribution (default=5000).
    - alpha: Significance level for the hypothesis test (default=0.05).
    - verbose: Whether to display a progress bar over variables (default=False).

    Returns:
    - stat_matrix: numpy array of test statistics (symmetric, diagonal 0).
    - pvalue_matrix: numpy array of p-values (symmetric; untested diagonal set to 0).
    - significance_matrix: numpy array with 1 where pvalue < alpha, 0 otherwise.
    """
    n_vars = df.shape[1]
    stat_matrix = np.zeros((n_vars, n_vars))
    pvalue_matrix = np.ones((n_vars, n_vars))
    significance_matrix = np.zeros((n_vars, n_vars))
    # The diagonal is never tested; its p-values are fixed at 0.
    np.fill_diagonal(pvalue_matrix, 0.0)

    iterator = tqdm(range(n_vars), desc="Variables") if verbose else range(n_vars)
    for i in iterator:
        X = df.iloc[:, i].to_numpy().reshape(-1, 1)
        for j in range(i + 1, n_vars):
            Y = df.iloc[:, j].to_numpy().reshape(-1, 1)
            # Select conditioning columns by position so duplicate column
            # names cannot pull X or Y into their own conditioning set.
            z_positions = [k for k in range(n_vars) if k not in (i, j)]
            if z_positions:
                Z = df.iloc[:, z_positions].to_numpy()
                kci_test = KCI_CInd(nullss=null_ss, approx=True)
                # compute_pvalue returns (p_value, test_statistic).
                pvalue, test_stat = kci_test.compute_pvalue(
                    data_x=X, data_y=Y, data_z=Z
                )
            else:
                # KCI_CInd needs a conditioning set; fall back to the
                # unconditional test when only two variables exist.
                kci_test = KCI_UInd(null_ss=null_ss, approx=True)
                pvalue, test_stat = kci_test.compute_pvalue(data_x=X, data_y=Y)

            stat_matrix[i, j] = stat_matrix[j, i] = test_stat
            pvalue_matrix[i, j] = pvalue_matrix[j, i] = pvalue
            if pvalue < alpha:
                significance_matrix[i, j] = significance_matrix[j, i] = 1
    return stat_matrix, pvalue_matrix, significance_matrix
def _glasso_partial_correlations(df, graphical_lasso_args):
    """Partial correlations from the precision matrix fitted by GraphicalLassoCV."""
    model = GraphicalLassoCV(**(graphical_lasso_args or {}))
    model.fit(df)
    precision_matrix = model.precision_
    # rho_ij = -P_ij / sqrt(P_ii * P_jj)
    d = np.sqrt(np.diag(precision_matrix))
    partial_corr_matrix = -precision_matrix / np.outer(d, d)
    np.fill_diagonal(partial_corr_matrix, 1)
    return partial_corr_matrix


def _residualize(target, controls):
    """Return residuals of an OLS fit of ``target`` on ``controls`` plus an intercept."""
    design = np.column_stack([np.ones(target.shape[0]), controls])
    coefficients = np.linalg.lstsq(design, target, rcond=None)[0]
    return target - design @ coefficients


def _regression_partial_correlations(df, method):
    """
    Partial correlations by definition: for each pair (i, j), residualize
    both on all *other* columns and correlate the residuals
    ('pearson' or 'spearman').
    """
    values = df.to_numpy(dtype=float)
    n_vars = values.shape[1]
    partial_corr_matrix = np.eye(n_vars)
    for i, j in combinations(range(n_vars), 2):
        rest = [k for k in range(n_vars) if k not in (i, j)]
        controls = values[:, rest]
        residuals_i = _residualize(values[:, i], controls)
        residuals_j = _residualize(values[:, j], controls)
        if method == "pearson":
            r, _ = stats.pearsonr(residuals_i, residuals_j)
        else:
            r, _ = spearmanr(residuals_i, residuals_j)
        partial_corr_matrix[i, j] = partial_corr_matrix[j, i] = r
    return partial_corr_matrix


def _conditional_mi_matrix(df):
    """Conditional mutual information of each pair given all other columns, clipped at 0."""
    n_vars = df.shape[1]
    partial_corr_matrix = np.zeros((n_vars, n_vars))
    np.fill_diagonal(partial_corr_matrix, 1.0)
    for i, j in combinations(range(n_vars), 2):
        rest = [k for k in range(n_vars) if k not in (i, j)]
        x = df.iloc[:, i].to_numpy()
        y = df.iloc[:, j].to_numpy()
        z = df.iloc[:, rest].to_numpy()
        cmi = mi(x, y, z, k=3, alpha=0.25)
        # The estimator can go slightly negative; clip to 0.
        if cmi < 0:
            cmi = 0
        partial_corr_matrix[i, j] = partial_corr_matrix[j, i] = cmi
    return partial_corr_matrix


def compute_partial_correlations(
    df, method="glasso", graphical_lasso_args=None, verbose=False, n_permutations=10
):
    """
    Compute the partial correlation matrix from a pandas DataFrame.

    Parameters:
    - df (pd.DataFrame): The input data.
    - method (str): The method to compute the partial correlation matrix
      ('glasso', 'pearson', 'spearman', 'conditional_mi_kci',
      'conditional_mi_nonparametric').
    - graphical_lasso_args (dict): Additional arguments for GraphicalLassoCV
      (only used by 'glasso').
    - verbose (bool): Whether to show progress (only used by 'conditional_mi_kci').
    - n_permutations (int): Currently unused; kept for backward compatibility.

    Returns:
    - partial_corr_matrix (np.ndarray): Square matrix ordered like ``df.columns``.
      - 'glasso', 'pearson' and 'spearman' give partial correlations with a
        diagonal of 1.
      - 'conditional_mi_nonparametric' gives conditional mutual information
        clipped at 0, with a diagonal of 1.
      - 'conditional_mi_kci' gives the KCI statistic where significant and 0
        elsewhere (diagonal 0).

    Raises:
    - NotImplementedError: For method 'nonlinear'.
    - ValueError: For any other unknown method.
    """
    if method == "glasso":
        return _glasso_partial_correlations(df, graphical_lasso_args)
    if method in ("pearson", "spearman"):
        return _regression_partial_correlations(df, method)
    if method == "conditional_mi_nonparametric":
        return _conditional_mi_matrix(df)
    if method == "conditional_mi_kci":
        stat_matrix, _, significance_matrix = compute_partial_correlation_kci(
            df, verbose=verbose
        )
        return significance_matrix * stat_matrix
    if method == "nonlinear":
        # TODO: a permutation-tested residual-MI approach (XGBoost residuals +
        # mutual_info_regression) was prototyped here; see version history.
        raise NotImplementedError(
            "The 'nonlinear' partial correlation method is not implemented."
        )
    raise ValueError(
        "Invalid method: choose 'glasso', 'pearson', 'spearman', "
        "'conditional_mi_kci', or 'conditional_mi_nonparametric'."
    )
def plot_correlation_graph(
    corr_matrix,
    labels=None,
    threshold=0.1,
    layout="spring",
    node_size=3000,
    node_color="skyblue",
    font_size=12,
    edge_cmap="coolwarm",
    edge_vmin=-1,
    edge_vmax=1,
    with_edge_labels=True,
    figsize=(10, 10),
    min_edge_width=1,
    max_edge_width=5,
    title=None,
    node_order=None,
    auto_order=False,
    edge_width=None,
    node_kwargs=None,
    edge_kwargs=None,
):
    """
    Plot a correlation graph using NetworkX.

    Each variable becomes a node. Variables i and j are joined by an
    undirected edge when ``abs(corr_matrix[i, j]) >= threshold``. Only the
    upper triangle is read, so the matrix is assumed symmetric.

    Parameters:
    - corr_matrix (np.ndarray, pd.DataFrame or array-like): Square correlation
      or partial correlation matrix. For a DataFrame with ``labels=None``,
      its columns are used as labels.
    - labels (list): Variable names corresponding to the matrix; defaults to
      "Var0", "Var1", ...
    - threshold (float or None): Minimum absolute value for edges to be
      included; None keeps every pair.
    - layout (str): Layout for the graph ('spring', 'circular', 'hierarchical', 'grid').
    - node_size (int): Size of the nodes.
    - node_color (str or list): Color of the nodes.
    - font_size (int): Font size for labels.
    - edge_cmap (str): Colormap for the edges.
    - edge_vmin (float): Minimum value for edge colormap.
    - edge_vmax (float): Maximum value for edge colormap.
    - with_edge_labels (bool): Whether to display edge labels.
    - figsize (tuple): Figure size.
    - min_edge_width (float): Minimum edge width.
    - max_edge_width (float): Maximum edge width.
    - title (str): Title of the plot.
    - node_order (list): Predefined ordering of nodes (used in 'hierarchical' and 'grid' layouts).
    - auto_order (bool): Automatically order nodes by degree (edges over threshold).
    - edge_width (float or list): Custom edge width or list of widths.
    - node_kwargs (dict): Additional keyword arguments for drawing nodes.
    - edge_kwargs (dict): Additional keyword arguments for drawing edges.

    Raises:
    - ValueError: If the matrix is not square and 2-D, if ``labels`` does not
      match its size, or if ``layout`` is unknown.
    """
    node_kwargs = {} if node_kwargs is None else node_kwargs
    edge_kwargs = {} if edge_kwargs is None else edge_kwargs

    if isinstance(corr_matrix, pd.DataFrame):
        if labels is None:
            labels = corr_matrix.columns
        corr_matrix = corr_matrix.values
    corr_matrix = np.asarray(corr_matrix)
    if corr_matrix.ndim != 2 or corr_matrix.shape[0] != corr_matrix.shape[1]:
        raise ValueError("Correlation matrix must be square.")
    n_vars = corr_matrix.shape[0]

    if labels is None:
        labels = [f"Var{i}" for i in range(n_vars)]
    elif len(labels) != n_vars:
        raise ValueError(
            f"Expected {n_vars} labels to match the correlation matrix, got {len(labels)}."
        )
    if threshold is None:
        threshold = -np.inf

    G = nx.Graph()
    G.add_nodes_from(labels)
    # Add edges with weights above the threshold (upper triangle only)
    for i in range(n_vars):
        for j in range(i + 1, n_vars):
            weight = corr_matrix[i, j]
            if abs(weight) >= threshold:
                G.add_edge(labels[i], labels[j], weight=weight)

    if auto_order:
        # Sort nodes by degree (number of edges over threshold), descending.
        degrees = dict(G.degree())
        node_order = sorted(degrees, key=lambda x: degrees[x], reverse=True)

    # Choose layout
    if layout == "spring":
        pos = nx.spring_layout(G, seed=42)
    elif layout == "circular":
        pos = nx.circular_layout(G)
    elif layout == "hierarchical":
        if node_order is None or len(node_order) == 0:
            # Root at the highest-degree node; an empty graph has no root.
            degrees = dict(G.degree())
            root_node = max(degrees, key=degrees.get) if degrees else None
        else:
            root_node = node_order[0]
        try:
            pos = nx.nx_agraph.graphviz_layout(G, prog="dot", root=root_node)
        except (ImportError, nx.NetworkXException):
            # Fallback to spring layout if graphviz is not available
            print("Graphviz layout not available, using spring layout")
            pos = nx.spring_layout(G, seed=42)
    elif layout == "grid":
        if node_order is None:
            node_order = labels
        sqrt_n = int(np.ceil(np.sqrt(len(node_order))))
        pos = {
            node: (idx % sqrt_n, -(idx // sqrt_n))
            for idx, node in enumerate(node_order)
        }
    else:
        raise ValueError(
            "Invalid layout. Choose 'spring', 'circular', 'hierarchical', or 'grid'."
        )

    edges = G.edges(data=True)
    edge_weights = [data.get("weight", 0) for _, _, data in edges]

    if edge_width is None:
        # Scale widths linearly in |weight| between min and max edge width.
        abs_weights = [abs(w) for w in edge_weights]
        min_weight = min(abs_weights) if abs_weights else 0
        max_weight = max(abs_weights) if abs_weights else 1
        if max_weight == min_weight:
            widths = [max_edge_width] * len(abs_weights)
        else:
            widths = [
                min_edge_width
                + (w - min_weight)
                / (max_weight - min_weight)
                * (max_edge_width - min_edge_width)
                for w in abs_weights
            ]
    else:
        widths = edge_width

    plt.figure(figsize=figsize)
    nx.draw_networkx_nodes(
        G, pos, node_size=node_size, node_color=node_color, **node_kwargs
    )
    nx.draw_networkx_labels(G, pos, font_size=font_size, font_weight="bold")
    nx.draw_networkx_edges(
        G,
        pos,
        edge_color=edge_weights,
        edge_cmap=plt.get_cmap(edge_cmap),
        edge_vmin=edge_vmin,
        edge_vmax=edge_vmax,
        width=widths,
        **edge_kwargs,
    )
    if with_edge_labels:
        edge_labels = {
            (u, v): f"{data.get('weight', 0):.2f}" for u, v, data in edges
        }
        nx.draw_networkx_edge_labels(
            G, pos, edge_labels=edge_labels, font_size=font_size - 2
        )
    plt.axis("off")
    if title:
        plt.title(title, fontsize=font_size + 2)
    plt.show()
def plot_correlation_graph_bak(
    corr_matrix,
    labels=None,
    threshold=0.1,
    layout="spring",
    node_size=3000,
    node_color="skyblue",
    font_size=12,
    edge_cmap="coolwarm",
    edge_vmin=-1,
    edge_vmax=1,
    with_edge_labels=True,
    figsize=(10, 10),
    min_edge_width=1,
    max_edge_width=5,
    title=None,
    node_order=None,
    auto_order=False,
    edge_width=None,
    **kwargs,
):
    """
    Legacy version of the correlation-graph plotter, kept for reference.

    Builds an undirected graph with one node per variable. An edge (i, j),
    taken from the upper triangle of ``corr_matrix``, is added when
    ``abs(corr_matrix[i, j]) >= threshold``. The graph is then drawn with
    matplotlib.

    Parameters:
    - corr_matrix (np.ndarray): Square correlation or partial correlation
      matrix, indexed as ``corr_matrix[i, j]``.
    - labels (list): Variable names; defaults to "Var0", "Var1", ...
    - threshold (float): Minimum absolute value for an edge to be included.
    - layout (str): 'spring', 'circular', 'tree' (one multipartite layer per
      node, following ``node_order``) or 'grid'.
    - node_size (int): Size of the nodes.
    - node_color (str or list): Color of the nodes.
    - font_size (int): Font size for node labels (edge labels use font_size - 2).
    - edge_cmap (str): Colormap name for edge colors.
    - edge_vmin (float): Minimum value for the edge colormap.
    - edge_vmax (float): Maximum value for the edge colormap.
    - with_edge_labels (bool): Whether to annotate edges with their weight.
    - figsize (tuple): Figure size.
    - min_edge_width (float): Width of the weakest edge when widths are scaled.
    - max_edge_width (float): Width of the strongest edge when widths are scaled.
    - title (str): Title of the plot.
    - node_order (list): Node ordering for the 'tree' and 'grid' layouts.
    - auto_order (bool): Order nodes by descending degree.
    - edge_width (float or list): Fixed edge width(s); when None, widths scale
      linearly with |weight|.
    - **kwargs: Extra keyword arguments forwarded to ``nx.draw_networkx_nodes`` only.

    Raises:
    - ValueError: If ``layout`` is not recognised.
    """
    graph = nx.Graph()
    if labels is None:
        labels = [f"Var{idx}" for idx in range(corr_matrix.shape[0])]
    graph.add_nodes_from(labels)

    # Upper-triangle pairs in row-major order, same as a nested i < j loop.
    for row, col in combinations(range(len(labels)), 2):
        value = corr_matrix[row, col]
        if abs(value) >= threshold:
            graph.add_edge(labels[row], labels[col], weight=value)

    if auto_order:
        degree_of = dict(graph.degree())
        node_order = sorted(degree_of, key=degree_of.__getitem__, reverse=True)

    if layout == "spring":
        pos = nx.spring_layout(graph, seed=42)
    elif layout == "circular":
        pos = nx.circular_layout(graph)
    elif layout in ("tree", "grid"):
        ordering = labels if node_order is None else node_order
        if layout == "tree":
            # Each node gets its own layer, in the order given by `ordering`.
            layer_of = {node: depth for depth, node in enumerate(ordering)}
            nx.set_node_attributes(graph, layer_of, "layer")
            pos = nx.multipartite_layout(graph, subset_key="layer")
        else:
            side = int(np.ceil(np.sqrt(len(ordering))))
            pos = {
                node: (position % side, -(position // side))
                for position, node in enumerate(ordering)
            }
    else:
        raise ValueError(
            "Invalid layout. Choose 'spring', 'circular', 'tree', or 'grid'."
        )

    edge_view = graph.edges(data=True)
    edge_weights = [attrs["weight"] for _, _, attrs in edge_view]

    if edge_width is not None:
        widths = edge_width
    else:
        magnitudes = [abs(value) for value in edge_weights]
        lowest = min(magnitudes) if magnitudes else 0
        highest = max(magnitudes) if magnitudes else 1
        if highest == lowest:
            # All edges equally strong (or no edges): use the max width.
            widths = [max_edge_width] * len(magnitudes)
        else:
            widths = [
                min_edge_width
                + (abs(m) - lowest)
                / (highest - lowest)
                * (max_edge_width - min_edge_width)
                for m in magnitudes
            ]

    plt.figure(figsize=figsize)
    nx.draw_networkx_nodes(
        graph, pos, node_size=node_size, node_color=node_color, **kwargs
    )
    nx.draw_networkx_edges(
        graph,
        pos,
        edge_color=edge_weights,
        edge_cmap=plt.get_cmap(edge_cmap),
        edge_vmin=edge_vmin,
        edge_vmax=edge_vmax,
        width=widths,
    )
    nx.draw_networkx_labels(graph, pos, font_size=font_size, font_weight="bold")
    if with_edge_labels:
        weight_text = {
            (u, v): f"{attrs['weight']:.2f}" for u, v, attrs in edge_view
        }
        nx.draw_networkx_edge_labels(
            graph, pos, edge_labels=weight_text, font_size=font_size - 2
        )
    plt.axis("off")
    if title:
        plt.title(title, fontsize=font_size + 2)
    plt.show()
if __name__ == "__main__":
    # --- Report which Graphviz backend (if any) NetworkX can use. ---
    try:
        from networkx.drawing.nx_agraph import graphviz_layout

        backend = "pygraphviz"
    except ImportError:
        try:
            from networkx.drawing.nx_pydot import graphviz_layout

            backend = "pydot"
        except ImportError:
            graphviz_layout = None

    if graphviz_layout is not None:
        print(f"graphviz_layout is available using {backend}.")
        # Smoke-test the layout on a small complete graph.
        graphviz_layout(nx.complete_graph(5), prog="dot")
        print("graphviz_layout is working.")
    else:
        print("graphviz_layout is not available.")

    # --- Demo 1: pure noise, glasso partial correlations. ---
    np.random.seed(42)
    noise_samples, noise_features = 100, 5
    noise_df = pd.DataFrame(
        np.random.randn(noise_samples, noise_features),
        columns=[f"Var{i}" for i in range(noise_features)],
    )
    plot_correlation_graph(
        compute_partial_correlations(noise_df, method="glasso"),
        labels=noise_df.columns,
        threshold=0.1,
        layout="spring",
        node_size=3000,
        node_color="skyblue",
        font_size=12,
        edge_cmap="coolwarm",
        edge_vmin=-1,
        edge_vmax=1,
        with_edge_labels=True,
        figsize=(10, 10),
        edge_width=2,
    )

    # --- Demo 2: Var1 -> Var3 -> Var5 <- Var4 <- Var2 (Var5 is a collider). ---
    n_samples = 200
    var1 = np.random.normal(0, 1, n_samples)
    var2 = np.random.normal(0, 1, n_samples)
    var3 = 2 * var1 + np.random.normal(0, 1, n_samples)
    var4 = 0.5 * var2 + np.random.normal(0, 1, n_samples)
    var5 = var3 + var4 + np.random.normal(0, 1, n_samples)
    df = pd.DataFrame(
        {"Var1": var1, "Var2": var2, "Var3": var3, "Var4": var4, "Var5": var5}
    )

    corr_matrix = compute_correlations(df)
    partial_corr_matrix = compute_partial_correlations(df)

    spring_settings = {
        "labels": df.columns,
        "threshold": 0.2,
        "layout": "spring",
        "node_size": 2500,
        "font_size": 12,
        "edge_cmap": "bwr",
        "edge_vmin": -1,
        "edge_vmax": 1,
        "with_edge_labels": True,
        "figsize": (8, 8),
        "edge_width": 2.5,
    }

    def _plot_hierarchical(matrix, title, threshold=0.2):
        """Plot ``matrix`` for ``df`` with the shared hierarchical-layout settings."""
        plot_correlation_graph(
            matrix,
            labels=df.columns,
            threshold=threshold,
            layout="hierarchical",
            auto_order=True,
            node_size=2500,
            node_color="lightblue",
            font_size=12,
            edge_cmap="bwr",
            edge_vmin=-1,
            edge_vmax=1,
            min_edge_width=1,
            max_edge_width=5,
            title=title,
        )

    plot_correlation_graph(
        corr_matrix, title="Pearson Correlation Graph", **spring_settings
    )
    _plot_hierarchical(
        corr_matrix, "Pearson Correlation Graph with Automatic Node Ordering"
    )
    plot_correlation_graph(
        partial_corr_matrix,
        title="Partial Correlation Graph",
        auto_order=True,
        **spring_settings,
    )
    _plot_hierarchical(
        partial_corr_matrix, "Partial Correlation Graph with Automatic Node Ordering"
    )

    _plot_hierarchical(
        compute_partial_correlations_corrected(df, method="glasso"),
        "Partial Correlation Graph Corrected for Collider Bias",
    )
    _plot_hierarchical(
        compute_partial_correlations(df, method="pearson"),
        "Pearson Partial Correlation Graph",
    )
    _plot_hierarchical(
        compute_partial_correlations(df, method="spearman"),
        "Spearman Partial Correlation Graph",
    )
    _plot_hierarchical(
        compute_partial_correlations(df, method="conditional_mi_kci"),
        "Non-linear Partial Correlation Graph",
        threshold=0.01,
    )

    # --- Causal discovery for comparison. ---
    causal_method = "fci"
    result = run_causal_discovery(df, method=causal_method, verbose=True)
    visualize_causal_graph(
        result, title=f"{causal_method.upper()} Result, no cross-validation"
    )

    _plot_hierarchical(
        compute_partial_correlations_corrected(
            df,
            method="conditional_mi_kci",
            correction_method="conditional_mi_kci",
            verbose=True,
        ),
        "Non-linear Partial Correlation Graph Corrected for Collider Bias Non-linearly",
    )
    _plot_hierarchical(
        compute_partial_correlations(df, method="conditional_mi_nonparametric"),
        "MI Non-parametric Partial Correlation Graph",
        threshold=0.01,
    )
    _plot_hierarchical(
        compute_partial_correlations_corrected(
            df,
            method="conditional_mi_nonparametric",
            correction_method="conditional_mi_nonparametric",
            verbose=True,
        ),
        "MI Non-parametric Partial Correlation Graph Corrected for Collider Bias Non-linearly",
        threshold=0.01,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment