Skip to content

Instantly share code, notes, and snippets.

@albertbuchard
Created October 14, 2024 17:18
Show Gist options
  • Save albertbuchard/84ecb3e77d499aaf31a984aaac8bfb67 to your computer and use it in GitHub Desktop.
Save albertbuchard/84ecb3e77d499aaf31a984aaac8bfb67 to your computer and use it in GitHub Desktop.
Partial Correlation Algorithms Evaluation
import matplotlib.pyplot as plt
import networkx as nx
from causallearn.utils.KCI.KCI import KCI_CInd, KCI_UInd
from npeet_plus import mi, mi_pvalue
from tqdm import tqdm
from src.causal_discovery.static_causal_discovery import (
run_causal_discovery,
visualize_causal_graph,
)
def compute_correlations(df):
    """
    Return the pairwise Pearson correlation matrix of the DataFrame's columns.

    Parameters:
    - df (pd.DataFrame): The input data, one variable per column.

    Returns:
    - corr_matrix (np.ndarray): Square array of Pearson correlations whose rows
      and columns follow the order of ``df.columns``.
    """
    pearson_frame = df.corr(method="pearson")
    return pearson_frame.to_numpy()
import numpy as np
import pandas as pd
from sklearn.covariance import GraphicalLassoCV
from itertools import combinations
from scipy import stats
from scipy.stats import spearmanr
_CORRECTION_METHODS = (
    "pearson",
    "spearman",
    "conditional_mi_kci",
    "conditional_mi_nonparametric",
)


def _as_test_input(column, correction_method):
    """Convert a DataFrame column to the array layout the chosen test expects."""
    values = column.to_numpy()
    # KCI works on (n_samples, n_features) matrices; the other tests take 1-D data.
    if correction_method == "conditional_mi_kci":
        return values.reshape(-1, 1)
    return values


def _marginal_dependence_test(x, y, correction_method):
    """Return ``(statistic, p_value)`` of the unconditional X-Y dependence test."""
    if correction_method == "pearson":
        r_xy, p_xy = stats.pearsonr(x, y)
    elif correction_method == "spearman":
        r_xy, p_xy = spearmanr(x, y)
    elif correction_method == "conditional_mi_kci":
        kci = KCI_UInd(null_ss=1000, approx=True)
        p_xy, r_xy = kci.compute_pvalue(data_x=x, data_y=y)
    else:  # "conditional_mi_nonparametric" (validated by the caller)
        r_xy, p_xy, _ = mi_pvalue(x, y, alpha=0.25)
    return r_xy, p_xy


def compute_partial_correlations_corrected(
    df,
    method="glasso",
    alpha=0.05,
    correction_method="pearson",
    verbose=False,
    epsilon=1e-6,
    **kwargs,
):
    """
    Compute the partial correlation matrix from a pandas DataFrame,
    detecting and correcting for edges due to collider biases.

    For every pair of variables (X, Y) with a non-negligible partial
    correlation, the pair is first tested for marginal dependence. If X and Y
    are marginally independent (p > alpha), every other variable Z is tried
    as a candidate collider. If X and Y become dependent given Z (p < alpha),
    the edge is attributed to collider bias and removed.

    Parameters:
    - df (pd.DataFrame): The input data.
    - method (str): Method passed to ``compute_partial_correlations`` to
      obtain the initial matrix.
    - alpha (float): Significance level for statistical tests.
    - correction_method (str): Test used to detect colliders ('pearson',
      'spearman', 'conditional_mi_kci', 'conditional_mi_nonparametric').
    - verbose (bool): Whether to print progress messages and show the progress bar.
    - epsilon (float): Entries with an absolute value below this are treated
      as absent edges and set to 0.
    - **kwargs: Additional keyword arguments for ``compute_partial_correlations``.

    Returns:
    - corrected_partial_corr_matrix (np.ndarray): The partial correlation
      matrix corrected for collider bias.

    Raises:
    - ValueError: if ``correction_method`` is not one of the supported names.
    """
    # Validate up front so a typo does not cost a full partial-correlation fit.
    if correction_method not in _CORRECTION_METHODS:
        raise ValueError(
            "Invalid correlation method. Choose 'pearson', 'spearman', "
            "'conditional_mi_kci', or 'conditional_mi_nonparametric'."
        )

    # Step 1: Compute the partial correlation matrix
    if verbose:
        print("Computing partial correlations...")
    partial_corr_matrix = compute_partial_correlations(
        df, method, verbose=verbose, **kwargs
    )

    # Step 2: Detect and correct for colliders
    if verbose:
        print("Correcting for collider bias...")
    corrected = partial_corr_matrix.copy()
    n_vars = df.shape[1]
    variables = df.columns.tolist()

    pairs = tqdm(
        list(combinations(range(n_vars), 2)),
        desc="Detecting colliders",
        disable=not verbose,
    )
    for i, j in pairs:
        # Compare the magnitude: negative partial correlations are real edges too.
        if abs(corrected[i, j]) < epsilon:
            corrected[i, j] = corrected[j, i] = 0
            continue

        x = _as_test_input(df.iloc[:, i], correction_method)
        y = _as_test_input(df.iloc[:, j], correction_method)

        # The marginal test does not depend on the candidate collider, so run
        # it once per pair. A collider pattern requires marginal independence;
        # written as `not (p > alpha)` so a NaN p-value never flags a collider.
        _, p_xy = _marginal_dependence_test(x, y, correction_method)
        if not (p_xy > alpha):
            continue

        # Every other variable (lower or higher index) is a candidate collider.
        for k in range(n_vars):
            if k in (i, j):
                continue
            pairs.set_postfix(
                {"Edge": f"{variables[i]} - {variables[j]} | {variables[k]}"}
            )
            z = _as_test_input(df.iloc[:, k], correction_method)
            _, p_xy_z = partial_corr_test(x, y, z, correction_method)
            if p_xy_z < alpha:
                if verbose:
                    print(
                        f"Collider detected: {variables[i]}, {variables[j]} | {variables[k]}"
                    )
                # Spurious association due to a collider at Z: remove the edge.
                corrected[i, j] = corrected[j, i] = 0
                break
    return corrected
def _partial_corr_pvalue(r, dof):
    """Two-tailed p-value for correlation ``r`` under a t-test with ``dof`` degrees of freedom."""
    if dof <= 0 or np.isnan(r):
        return float("nan")
    r_abs = min(abs(r), 1.0)
    if r_abs == 1.0:
        return 0.0
    t_stat = r_abs * np.sqrt(dof / (1.0 - r_abs**2))
    return float(2.0 * stats.t.sf(t_stat, dof))


def partial_corr_test(X, Y, Z, correlation_method="pearson"):
    """
    Compute the partial correlation between X and Y, controlling for Z,
    and perform a statistical test.

    For 'pearson' and 'spearman', X and Y are each regressed on Z by ordinary
    least squares *with an intercept*, and the residuals are correlated with
    the requested coefficient. Z may contain one or several conditioning
    variables. The two-tailed p-value uses a t-distribution with
    ``n - 2 - k`` degrees of freedom (``k`` = number of conditioning
    variables), which is the standard test for a partial correlation.

    For 'conditional_mi_kci' the kernel conditional independence test
    ``KCI_CInd`` is run, and ``r`` is its test statistic. For
    'conditional_mi_nonparametric', ``r`` and ``p_value`` are the first two
    values returned by ``mi_pvalue``.

    Parameters:
    - X, Y: pd.Series or np.ndarray, one value per sample. A single column of
      shape (n, 1) is also accepted.
    - Z: pd.Series, pd.DataFrame or np.ndarray. Use 1-D for one conditioning
      variable, or 2-D with one column per conditioning variable.
    - correlation_method (str): 'pearson', 'spearman', 'conditional_mi_kci'
      or 'conditional_mi_nonparametric'.

    Returns:
    - r (float): Partial correlation coefficient (the test statistic for the
      non-linear methods).
    - p_value (float): Two-tailed p-value. For the linear methods it is NaN
      when there are too few samples for the test.

    Raises:
    - ValueError: if ``correlation_method`` is not recognised.
    """
    if isinstance(X, (pd.Series, pd.DataFrame)):
        X = X.values
    if isinstance(Y, (pd.Series, pd.DataFrame)):
        Y = Y.values
    if isinstance(Z, (pd.Series, pd.DataFrame)):
        Z = Z.values

    if correlation_method in ("pearson", "spearman"):
        x = np.asarray(X, dtype=float).ravel()
        y = np.asarray(Y, dtype=float).ravel()
        z = np.asarray(Z, dtype=float)
        if z.ndim == 1:
            z = z.reshape(-1, 1)
        n_samples, n_conditioning = z.shape
        # The intercept column matters: a regression through the origin leaves
        # part of Z's effect in the residuals whenever the data are not centred.
        design = np.column_stack([np.ones(n_samples), z])
        targets = np.column_stack([x, y])
        coef, *_ = np.linalg.lstsq(design, targets, rcond=None)
        residuals = targets - design @ coef
        if correlation_method == "pearson":
            r, _ = stats.pearsonr(residuals[:, 0], residuals[:, 1])
        else:
            r, _ = spearmanr(residuals[:, 0], residuals[:, 1])
        r = float(r)
        # pearsonr/spearmanr on residuals would use n - 2 dof; conditioning on
        # k variables removes k more.
        p_value = _partial_corr_pvalue(r, n_samples - 2 - n_conditioning)
    elif correlation_method == "conditional_mi_kci":
        # KCI operates on (n_samples, n_features) matrices.
        x, y, z = (np.asarray(a) for a in (X, Y, Z))
        x, y, z = (a.reshape(a.shape[0], -1) for a in (x, y, z))
        kci_test = KCI_CInd(nullss=5000, approx=True)
        p_value, r = kci_test.compute_pvalue(data_x=x, data_y=y, data_z=z)
    elif correlation_method == "conditional_mi_nonparametric":
        r, p_value, _ = mi_pvalue(X, Y, Z, alpha=0.25)
    else:
        raise ValueError(
            "Invalid correlation method. Choose 'pearson', 'spearman', "
            "'conditional_mi_kci', or 'conditional_mi_nonparametric'."
        )
    return r, p_value
def compute_partial_correlation_kci(df, null_ss=5000, alpha=0.05, verbose=False):
    """
    Compute the non-linear partial correlation matrix using kernel
    conditional independence (KCI) tests.

    For every pair of columns (i, j), ``KCI_CInd`` tests X_i against X_j
    given all remaining columns. When the DataFrame has only two columns
    there is nothing to condition on, so the unconditional ``KCI_UInd`` test
    is used instead.

    Parameters:
    - df: pandas DataFrame with continuous variables.
    - null_ss: Number of samples in simulating the null distribution (default=5000).
    - alpha: Significance level for the hypothesis test (default=0.05).
    - verbose: Whether to display a progress bar over variables (default=False).

    Returns:
    - stat_matrix: symmetric numpy array of test statistics (0 on the diagonal).
    - pvalue_matrix: symmetric numpy array of p-values (0 on the diagonal).
    - significance_matrix: numpy array with 1 where the p-value is below
      ``alpha`` and 0 otherwise (0 on the diagonal).
    """
    n_vars = df.shape[1]
    values = df.to_numpy()
    stat_matrix = np.zeros((n_vars, n_vars))
    pvalue_matrix = np.ones((n_vars, n_vars))
    significance_matrix = np.zeros((n_vars, n_vars))
    # Diagonal convention: p-value 0, statistic and significance 0.
    np.fill_diagonal(pvalue_matrix, 0.0)

    rows = tqdm(range(n_vars), desc="Variables") if verbose else range(n_vars)
    for i in rows:
        X = values[:, [i]]  # (n_samples, 1) column, as KCI expects
        for j in range(i + 1, n_vars):
            Y = values[:, [j]]
            # Select by position so duplicate column names cannot leak X or Y into Z.
            others = [k for k in range(n_vars) if k not in (i, j)]
            if others:
                kci_test = KCI_CInd(nullss=null_ss, approx=True)
                pvalue, test_stat = kci_test.compute_pvalue(
                    data_x=X, data_y=Y, data_z=values[:, others]
                )
            else:
                # With two variables the partial test reduces to the marginal
                # one; KCI_CInd is not given data_z=None.
                kci_test = KCI_UInd(null_ss=null_ss, approx=True)
                pvalue, test_stat = kci_test.compute_pvalue(data_x=X, data_y=Y)

            stat_matrix[i, j] = stat_matrix[j, i] = test_stat
            pvalue_matrix[i, j] = pvalue_matrix[j, i] = pvalue
            if pvalue < alpha:
                significance_matrix[i, j] = significance_matrix[j, i] = 1
    return stat_matrix, pvalue_matrix, significance_matrix
def _residual_partial_correlations(df, method):
    """Partial correlation of each column pair given all remaining columns.

    Both variables of a pair are regressed (OLS with an intercept) on the
    *remaining* columns, excluding each other, and the two residual series are
    correlated with ``method`` ('pearson' or 'spearman'). For 'pearson' this
    equals the sample partial correlation ``-P_ij / sqrt(P_ii * P_jj)`` with
    ``P`` the inverse sample covariance matrix.
    """
    values = df.to_numpy(dtype=float)
    n_samples, n_vars = values.shape
    partial_corr_matrix = np.eye(n_vars)
    for i, j in combinations(range(n_vars), 2):
        others = [k for k in range(n_vars) if k not in (i, j)]
        design = np.column_stack([np.ones(n_samples), values[:, others]])
        targets = values[:, [i, j]]
        coef, *_ = np.linalg.lstsq(design, targets, rcond=None)
        resid = targets - design @ coef
        if method == "pearson":
            r = np.corrcoef(resid[:, 0], resid[:, 1])[0, 1]
        else:
            r = spearmanr(resid[:, 0], resid[:, 1])[0]
        partial_corr_matrix[i, j] = partial_corr_matrix[j, i] = r
    return partial_corr_matrix


def compute_partial_correlations(
    df, method="glasso", graphical_lasso_args=None, verbose=False, n_permutations=10
):
    """
    Compute the partial correlation matrix from a pandas DataFrame.

    Parameters:
    - df (pd.DataFrame): The input data.
    - method (str): The method to compute the partial correlation matrix
      ('glasso', 'pearson', 'spearman', 'conditional_mi_kci',
      'conditional_mi_nonparametric').
        * 'glasso': partial correlations from the sparse precision matrix
          estimated by ``GraphicalLassoCV``.
        * 'pearson' / 'spearman': each pair is regressed on the remaining
          variables and the residuals are correlated.
        * 'conditional_mi_nonparametric': conditional mutual information from
          ``mi`` (negative estimates clipped to 0); the diagonal is 1.
        * 'conditional_mi_kci': KCI test statistic where the test is
          significant, 0 elsewhere.
    - graphical_lasso_args (dict): Additional arguments for GraphicalLassoCV.
    - verbose (bool): Whether to show progress (used by 'conditional_mi_kci').
    - n_permutations (int): Currently unused; kept for backward compatibility.
      It belonged to the unfinished 'nonlinear' method.

    Returns:
    - partial_corr_matrix (np.ndarray): The partial correlation matrix.

    Raises:
    - NotImplementedError: for ``method='nonlinear'``.
    - ValueError: for an unknown method.
    """
    if method == "glasso":
        model = GraphicalLassoCV(**(graphical_lasso_args or {}))
        model.fit(df)
        precision_matrix = model.precision_
        # Convert precision matrix to partial correlations
        d = np.sqrt(np.diag(precision_matrix))
        partial_corr_matrix = -precision_matrix / np.outer(d, d)
        np.fill_diagonal(partial_corr_matrix, 1)
    elif method in ("pearson", "spearman"):
        # Regressing each variable on *all* others (its partner included) and
        # correlating those residuals yields the negated partial correlation,
        # so each pair is regressed only on the remaining variables.
        partial_corr_matrix = _residual_partial_correlations(df, method)
    elif method == "conditional_mi_nonparametric":
        n_vars = df.shape[1]
        values = df.to_numpy()
        partial_corr_matrix = np.zeros((n_vars, n_vars))
        np.fill_diagonal(partial_corr_matrix, 1.0)
        for i, j in combinations(range(n_vars), 2):
            others = [k for k in range(n_vars) if k not in (i, j)]
            cmi = mi(values[:, i], values[:, j], values[:, others], k=3, alpha=0.25)
            # k-NN MI estimators can go slightly negative; clip to 0.
            cmi = max(cmi, 0)
            partial_corr_matrix[i, j] = partial_corr_matrix[j, i] = cmi
    elif method == "conditional_mi_kci":
        stat_matrix, _pvalue_matrix, significance_matrix = (
            compute_partial_correlation_kci(df, verbose=verbose)
        )
        partial_corr_matrix = significance_matrix * stat_matrix
    elif method == "nonlinear":
        # An earlier draft regressed each pair on the remaining variables with
        # XGBoost and thresholded the mutual information of the residuals with
        # a permutation test; it was never finished.
        raise NotImplementedError("method='nonlinear' is not implemented.")
    else:
        raise ValueError(
            "Invalid method: choose 'glasso', 'pearson', 'spearman', "
            "'conditional_mi_kci', or 'conditional_mi_nonparametric'."
        )
    return partial_corr_matrix
def plot_correlation_graph(
    corr_matrix,
    labels=None,
    threshold=0.1,
    layout="spring",
    node_size=3000,
    node_color="skyblue",
    font_size=12,
    edge_cmap="coolwarm",
    edge_vmin=-1,
    edge_vmax=1,
    with_edge_labels=True,
    figsize=(10, 10),
    min_edge_width=1,
    max_edge_width=5,
    title=None,
    node_order=None,
    auto_order=False,
    edge_width=None,
    node_kwargs=None,
    edge_kwargs=None,
):
    """
    Plot a correlation graph using NetworkX and show it with matplotlib.

    One node is drawn per variable, and one undirected edge per pair whose
    absolute value is >= ``threshold``. Edges are coloured by their signed
    weight. Unless ``edge_width`` is given, edge widths scale linearly with
    |weight| between ``min_edge_width`` and ``max_edge_width``.

    Parameters:
    - corr_matrix (np.ndarray, pd.DataFrame or nested list): The square
      correlation or partial correlation matrix. If it is a DataFrame and
      ``labels`` is None, its columns are used as labels.
    - labels (list): Variable names, one per matrix row (default Var0..VarN-1).
    - threshold (float or None): Minimum absolute value for edges to be
      included. None keeps every pair.
    - layout (str): Layout for the graph ('spring', 'circular', 'hierarchical', 'grid').
      'hierarchical' uses graphviz 'dot' and falls back to a spring layout if
      graphviz raises ImportError or NetworkXException.
    - node_size (int): Size of the nodes.
    - node_color (str or list): Color of the nodes.
    - font_size (int): Font size for labels.
    - edge_cmap (str): Colormap for the edges.
    - edge_vmin (float): Minimum value for edge colormap.
    - edge_vmax (float): Maximum value for edge colormap.
    - with_edge_labels (bool): Whether to display edge labels.
    - figsize (tuple): Figure size.
    - min_edge_width (float): Minimum edge width.
    - max_edge_width (float): Maximum edge width.
    - title (str): Title of the plot.
    - node_order (list): Predefined ordering of nodes. Its first node is the
      root of the 'hierarchical' layout; for 'grid' it sets the fill order.
    - auto_order (bool): Order nodes by number of edges over the threshold
      (descending), overriding ``node_order``.
    - edge_width (float or list): Custom edge width or list of widths.
    - node_kwargs (dict): Additional keyword arguments for drawing nodes.
    - edge_kwargs (dict): Additional keyword arguments for drawing edges.

    Raises:
    - ValueError: if the matrix is not square, the number of labels does not
      match its size, or ``layout`` is unknown.
    """
    node_kwargs = {} if node_kwargs is None else node_kwargs
    edge_kwargs = {} if edge_kwargs is None else edge_kwargs

    if isinstance(corr_matrix, pd.DataFrame):
        if labels is None:
            labels = corr_matrix.columns
        corr_matrix = corr_matrix.values
    corr_matrix = np.asarray(corr_matrix)
    if corr_matrix.ndim != 2 or corr_matrix.shape[0] != corr_matrix.shape[1]:
        raise ValueError("Correlation matrix must be square.")
    n_vars = corr_matrix.shape[0]
    labels = [f"Var{i}" for i in range(n_vars)] if labels is None else list(labels)
    # Without this check, surplus labels raise IndexError and missing labels
    # silently drop variables from the plot.
    if len(labels) != n_vars:
        raise ValueError(
            f"Expected {n_vars} labels for a {n_vars}x{n_vars} matrix, got {len(labels)}."
        )

    # Build the graph: every variable is a node; edges above the threshold.
    G = nx.Graph()
    G.add_nodes_from(labels)
    if threshold is None:
        threshold = -np.inf
    for i, j in combinations(range(n_vars), 2):
        weight = corr_matrix[i, j]
        if abs(weight) >= threshold:
            G.add_edge(labels[i], labels[j], weight=weight)

    if auto_order:
        degrees = dict(G.degree())
        node_order = sorted(degrees, key=degrees.get, reverse=True)

    # Choose layout
    if layout == "spring":
        pos = nx.spring_layout(G, seed=42)
    elif layout == "circular":
        pos = nx.circular_layout(G)
    elif layout == "hierarchical":
        if node_order is None or len(node_order) == 0:
            # Use the node with the highest degree as root (None for an empty graph).
            degrees = dict(G.degree())
            root_node = max(degrees, key=degrees.get) if degrees else None
        else:
            root_node = node_order[0]
        try:
            pos = nx.nx_agraph.graphviz_layout(G, prog="dot", root=root_node)
        except (ImportError, nx.NetworkXException):
            print("Graphviz layout not available, using spring layout")
            pos = nx.spring_layout(G, seed=42)
    elif layout == "grid":
        if node_order is None:
            node_order = labels
        n_cols = int(np.ceil(np.sqrt(len(node_order))))
        pos = {
            node: (idx % n_cols, -(idx // n_cols)) for idx, node in enumerate(node_order)
        }
    else:
        raise ValueError(
            "Invalid layout. Choose 'spring', 'circular', 'hierarchical', or 'grid'."
        )

    # Edge weights drive both colour and (by default) width.
    edges = list(G.edges(data=True))
    edge_weights = [data.get("weight", 0) for _, _, data in edges]
    if edge_width is None:
        widths = _scale_edge_widths(edge_weights, min_edge_width, max_edge_width)
    else:
        widths = edge_width

    # Draw the graph
    plt.figure(figsize=figsize)
    nx.draw_networkx_nodes(
        G, pos, node_size=node_size, node_color=node_color, **node_kwargs
    )
    nx.draw_networkx_labels(G, pos, font_size=font_size, font_weight="bold")
    nx.draw_networkx_edges(
        G,
        pos,
        edge_color=edge_weights,
        edge_cmap=plt.get_cmap(edge_cmap),
        edge_vmin=edge_vmin,
        edge_vmax=edge_vmax,
        width=widths,
        **edge_kwargs,
    )
    if with_edge_labels:
        edge_labels = {(u, v): f"{data.get('weight', 0):.2f}" for u, v, data in edges}
        nx.draw_networkx_edge_labels(
            G, pos, edge_labels=edge_labels, font_size=font_size - 2
        )
    plt.axis("off")
    if title:
        plt.title(title, fontsize=font_size + 2)
    plt.show()


def _scale_edge_widths(weights, min_width, max_width):
    """Map |weight| linearly onto [min_width, max_width]; equal magnitudes all get max_width."""
    magnitudes = [abs(w) for w in weights]
    if not magnitudes:
        return []
    low, high = min(magnitudes), max(magnitudes)
    if high == low:
        return [max_width] * len(magnitudes)
    span = high - low
    return [min_width + (m - low) / span * (max_width - min_width) for m in magnitudes]
def plot_correlation_graph_bak(
    corr_matrix,
    labels=None,
    threshold=0.1,
    layout="spring",
    node_size=3000,
    node_color="skyblue",
    font_size=12,
    edge_cmap="coolwarm",
    edge_vmin=-1,
    edge_vmax=1,
    with_edge_labels=True,
    figsize=(10, 10),
    min_edge_width=1,
    max_edge_width=5,
    title=None,
    node_order=None,
    auto_order=False,
    edge_width=None,
    **kwargs,
):
    """
    Plot a correlation graph using NetworkX (earlier variant of ``plot_correlation_graph``).

    Parameters:
    - corr_matrix (np.ndarray): The correlation or partial correlation matrix.
    - labels (list): Variable names corresponding to the matrix.
    - threshold (float or None): Minimum absolute value for edges to be
      included. None keeps every pair.
    - layout (str): Layout for the graph ('spring', 'circular', 'tree', 'grid').
      'tree' places each node of ``node_order`` in its own layer of a
      multipartite layout.
    - node_size (int): Size of the nodes.
    - node_color (str or list): Color of the nodes.
    - font_size (int): Font size for labels.
    - edge_cmap (str): Colormap for the edges.
    - edge_vmin (float): Minimum value for edge colormap.
    - edge_vmax (float): Maximum value for edge colormap.
    - with_edge_labels (bool): Whether to display edge labels.
    - figsize (tuple): Figure size.
    - min_edge_width (float): Minimum edge width.
    - max_edge_width (float): Maximum edge width.
    - title (str): Title of the plot.
    - node_order (list): Predefined ordering of nodes (used in 'tree' and 'grid' layouts).
    - auto_order (bool): Automatically order nodes based on number of edges over threshold.
    - edge_width (float or list): Custom edge width or list of widths; when
      None, widths are scaled with |weight|.
    - **kwargs: Additional keyword arguments passed to ``nx.draw_networkx_nodes`` only.

    Raises:
    - ValueError: if ``layout`` is unknown.
    """
    G = nx.Graph()
    if labels is None:
        labels = [f"Var{i}" for i in range(corr_matrix.shape[0])]
    G.add_nodes_from(labels)

    # Same convention as plot_correlation_graph: None means "keep every pair".
    if threshold is None:
        threshold = -np.inf
    for i in range(len(labels)):
        for j in range(i + 1, len(labels)):
            weight = corr_matrix[i, j]
            if abs(weight) >= threshold:
                G.add_edge(labels[i], labels[j], weight=weight)

    if auto_order:
        # Sort nodes by degree (number of edges over threshold), descending.
        degrees = dict(G.degree())
        node_order = sorted(degrees, key=degrees.get, reverse=True)

    if layout == "spring":
        pos = nx.spring_layout(G, seed=42)
    elif layout == "circular":
        pos = nx.circular_layout(G)
    elif layout == "tree":
        if node_order is None:
            node_order = labels
        layers = {node: idx for idx, node in enumerate(node_order)}
        nx.set_node_attributes(G, layers, "layer")
        pos = nx.multipartite_layout(G, subset_key="layer")
    elif layout == "grid":
        if node_order is None:
            node_order = labels
        n_cols = int(np.ceil(np.sqrt(len(node_order))))
        pos = {
            node: (idx % n_cols, -(idx // n_cols)) for idx, node in enumerate(node_order)
        }
    else:
        raise ValueError(
            "Invalid layout. Choose 'spring', 'circular', 'tree', or 'grid'."
        )

    edges = list(G.edges(data=True))
    edge_weights = [data["weight"] for _, _, data in edges]
    if edge_width is None:
        magnitudes = [abs(w) for w in edge_weights]
        low = min(magnitudes) if magnitudes else 0
        high = max(magnitudes) if magnitudes else 1  # Avoid division by zero
        if high == low:
            widths = [max_edge_width] * len(magnitudes)
        else:
            widths = [
                min_edge_width
                + (m - low) / (high - low) * (max_edge_width - min_edge_width)
                for m in magnitudes
            ]
    else:
        widths = edge_width

    plt.figure(figsize=figsize)
    nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color=node_color, **kwargs)
    nx.draw_networkx_edges(
        G,
        pos,
        edge_color=edge_weights,
        edge_cmap=plt.get_cmap(edge_cmap),
        edge_vmin=edge_vmin,
        edge_vmax=edge_vmax,
        width=widths,
    )
    nx.draw_networkx_labels(G, pos, font_size=font_size, font_weight="bold")
    if with_edge_labels:
        edge_labels = {(u, v): f"{d['weight']:.2f}" for u, v, d in edges}
        nx.draw_networkx_edge_labels(
            G, pos, edge_labels=edge_labels, font_size=font_size - 2
        )
    plt.axis("off")
    if title:
        plt.title(title, fontsize=font_size + 2)
    plt.show()
if __name__ == "__main__":
    # --- 1. Report which graphviz backend networkx can use for layouts -----
    try:
        from networkx.drawing.nx_agraph import graphviz_layout

        method = "pygraphviz"
    except ImportError:
        try:
            from networkx.drawing.nx_pydot import graphviz_layout

            method = "pydot"
        except ImportError:
            graphviz_layout = None

    if graphviz_layout is None:
        print("graphviz_layout is not available.")
    else:
        print(f"graphviz_layout is available using {method}.")
        # Smoke-test on a tiny graph. The Python binding can import even when
        # the Graphviz executables are missing, so report instead of aborting
        # the rest of the demo.
        try:
            graphviz_layout(nx.complete_graph(5), prog="dot")
        except Exception as err:  # diagnostic boundary: report and continue
            print(f"graphviz_layout failed: {err}")
        else:
            print("graphviz_layout is working.")

    # --- 2. Pure noise: glasso should find (almost) no edges ---------------
    np.random.seed(42)
    n_samples, n_features = 100, 5
    data = np.random.randn(n_samples, n_features)
    df = pd.DataFrame(data, columns=[f"Var{i}" for i in range(n_features)])
    partial_corr_matrix = compute_partial_correlations(df, method="glasso")
    plot_correlation_graph(
        partial_corr_matrix,
        labels=df.columns,
        threshold=0.1,
        layout="spring",
        node_size=3000,
        node_color="skyblue",
        font_size=12,
        edge_cmap="coolwarm",
        edge_vmin=-1,
        edge_vmax=1,
        with_edge_labels=True,
        figsize=(10, 10),
        edge_width=2,
    )

    # --- 3. Known structure: Var1 -> Var3 -> Var5 <- Var4 <- Var2 ----------
    # Var5 is a collider of Var3 and Var4.
    n_samples = 200
    Var1 = np.random.normal(0, 1, n_samples)
    Var2 = np.random.normal(0, 1, n_samples)
    Var3 = 2 * Var1 + np.random.normal(0, 1, n_samples)
    Var4 = 0.5 * Var2 + np.random.normal(0, 1, n_samples)
    Var5 = Var3 + Var4 + np.random.normal(0, 1, n_samples)
    df = pd.DataFrame(
        {"Var1": Var1, "Var2": Var2, "Var3": Var3, "Var4": Var4, "Var5": Var5}
    )

    corr_matrix = compute_correlations(df)
    partial_corr_matrix = compute_partial_correlations(df)

    # Spring-layout settings for the first overview plots.
    plot_settings = {
        "labels": df.columns,
        "threshold": 0.2,
        "layout": "spring",
        "node_size": 2500,
        "font_size": 12,
        "edge_cmap": "bwr",
        "edge_vmin": -1,
        "edge_vmax": 1,
        "with_edge_labels": True,
        "figsize": (8, 8),
        "edge_width": 2.5,
    }
    # Shared settings for every hierarchical plot below.
    hierarchical_settings = {
        "labels": df.columns,
        "threshold": 0.2,
        "layout": "hierarchical",
        "auto_order": True,
        "node_size": 2500,
        "node_color": "lightblue",
        "font_size": 12,
        "edge_cmap": "bwr",
        "edge_vmin": -1,
        "edge_vmax": 1,
        "min_edge_width": 1,
        "max_edge_width": 5,
    }

    def plot_hierarchical(matrix, title, threshold=0.2):
        """Plot ``matrix`` with the shared hierarchical settings."""
        plot_correlation_graph(
            matrix, title=title, **{**hierarchical_settings, "threshold": threshold}
        )

    plot_correlation_graph(
        corr_matrix, title="Pearson Correlation Graph", **plot_settings
    )
    plot_hierarchical(
        corr_matrix, "Pearson Correlation Graph with Automatic Node Ordering"
    )
    plot_correlation_graph(
        partial_corr_matrix,
        title="Partial Correlation Graph",
        auto_order=True,
        **plot_settings,
    )
    plot_hierarchical(
        partial_corr_matrix, "Partial Correlation Graph with Automatic Node Ordering"
    )

    partial_corr_matrix_corrected = compute_partial_correlations_corrected(
        df, method="glasso"
    )
    plot_hierarchical(
        partial_corr_matrix_corrected,
        "Partial Correlation Graph Corrected for Collider Bias",
    )

    partial_corr_matrix_pearson = compute_partial_correlations(df, method="pearson")
    plot_hierarchical(partial_corr_matrix_pearson, "Pearson Partial Correlation Graph")

    partial_corr_matrix_spearman = compute_partial_correlations(df, method="spearman")
    plot_hierarchical(
        partial_corr_matrix_spearman, "Spearman Partial Correlation Graph"
    )

    partial_corr_matrix_nonlinear = compute_partial_correlations(
        df, method="conditional_mi_kci"
    )
    plot_hierarchical(
        partial_corr_matrix_nonlinear,
        "Non-linear Partial Correlation Graph",
        threshold=0.01,
    )

    # --- 4. Compare with a constraint-based causal discovery algorithm -----
    causal_method = "fci"
    result = run_causal_discovery(df, method=causal_method, verbose=True)
    visualize_causal_graph(
        result, title=f"{causal_method.upper()} Result, no cross-validation"
    )

    partial_corr_matrix_nonlinear_corrected = compute_partial_correlations_corrected(
        df,
        method="conditional_mi_kci",
        correction_method="conditional_mi_kci",
        verbose=True,
    )
    plot_hierarchical(
        partial_corr_matrix_nonlinear_corrected,
        "Non-linear Partial Correlation Graph Corrected for Collider Bias Non-linearly",
    )

    partial_corr_matrix_nonparametric = compute_partial_correlations(
        df, method="conditional_mi_nonparametric"
    )
    plot_hierarchical(
        partial_corr_matrix_nonparametric,
        "MI Non-parametric Partial Correlation Graph",
        threshold=0.01,
    )

    partial_corr_matrix_nonparametric_corrected = (
        compute_partial_correlations_corrected(
            df,
            method="conditional_mi_nonparametric",
            correction_method="conditional_mi_nonparametric",
            verbose=True,
        )
    )
    plot_hierarchical(
        partial_corr_matrix_nonparametric_corrected,
        "MI Non-parametric Partial Correlation Graph Corrected for Collider Bias Non-linearly",
        threshold=0.01,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment