Last active
December 4, 2024 11:55
-
-
Save mschauer/715a3555873fc2e054f519f2d40ba7b8 to your computer and use it in GitHub Desktop.
Comparing CausalInference.pcalg with Associations.infer_graph
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using CausalInference | |
using Associations: CorrTest, PC, Associations | |
using Test | |
using Graphs: SimpleDiGraph, Graphs, complete_graph | |
using StableRNGs | |
using LinearAlgebra, Random, Distributions | |
using CausalInference: pcalg, gausscitest, CausalInference | |
using Combinatorics | |
using Tables: table, istable | |
# --- Problem parameters --------------------------------------------------------
d = 8          # number of variables (vertices)
n = 20_000     # number of samples
beta = 0.3     # probability that a strictly-lower-triangular entry becomes an edge

# Significance level shared by the data-driven independence tests.
α = 0.01
alg = PC(CorrTest(), CorrTest(); α)
# NOTE(review): `rng` is not used anywhere below; all randomness comes from the
# global RNG seeded with `Random.seed!(8)`.
rng = StableRNG(123)

# --- Random d-dimensional problem with a known true DAG, plus sampled data ------
Random.seed!(8)
# E[i, j] == 1 encodes the edge j → i. Only i > j is ever set, so the graph is acyclic.
# Entries are visited column-major (i varies fastest); this fixes the order of the
# rand() draws and hence which DAG seed 8 produces.
E = let adj = zeros(Int, d, d)
    for j in 1:d, i in (j + 1):d
        adj[i, j] = rand() < beta
    end
    LowerTriangular(adj)
end
L = E .* (0.8 .+ 0.2 .* rand(d, d))   # edge weights drawn from [0.8, 1.0)
println("\nVertices: $d, Edges: ", sum(E))
# Linear structural equations X = L X + ε with standard normal ε, solved for X (d × n).
X = (I - L) \ randn(d, n)
dag_true = SimpleDiGraph(E')
dg_true = CausalInference.cpdag(dag_true)

# --- Population and sample statistics ------------------------------------------
qu(x) = x * x'
# Cov(X) = (I - L)⁻¹ (I - L)⁻ᵀ = ((I - L)ᵀ (I - L))⁻¹
Σtrue = inv(qu((I - L)'))
di = sqrt.(diag(Σtrue))
Ctrue = Σtrue ./ (di * di')   # true correlation matrix (NOTE(review): unused below)
C = cor(X; dims = 2)          # sample correlation between the d variables

# --- Independence tests ---------------------------------------------------------
# Data-driven: `gausscitest` on the sample correlations, using the two-sided normal
# critical value at level α. Oracle: answers independence queries from the true DAG.
itest = let zcrit = quantile(Normal(), 1 - α / 2)
    CausalInference.IClosure(gausscitest, ((C, n), zcrit))
end
ioracle = CausalInference.IClosure(dseporacle, (dag_true,))
# Compare the two skeleton implementations. Running the d-separation oracle through
# CausalInference.skeleton gives the ground-truth skeleton.
g_ci_exact, S_ci_exact = skeleton(complete_graph(d), ioracle)
println("CausalInference.skeleton")
@time g_ci, S_ci = skeleton(complete_graph(d), itest)
println("Associations.skeleton")
@time g_ct, S_ct = Associations.skeleton(alg, eachrow(X); verbose = false)

# Sanity checks: every separating set found must d-separate its vertex pair in the
# true DAG. The testsets are named per package so failures can be told apart.
@testset "sep (CausalInference)" begin
    for (edge, set) in S_ci
        @test dsep(dag_true, edge.src, edge.dst, set)
    end
end
@testset "sep (Associations)" begin
    for (edge, set) in S_ct
        @test dsep(dag_true, edge.src, edge.dst, set)
    end
end

@testset "skeleton" begin
    # Both data-driven skeletons should agree with each other ...
    @test g_ci == Graphs.SimpleGraph(g_ct)
    # ... and with the oracle skeleton, which was computed above for exactly this purpose.
    @test g_ci == g_ci_exact
end
# --- CausalInference.pcalg -------------------------------------------------------
# With the d-separation oracle as its independence test, pcalg should recover the
# true CPDAG exactly. This checks the algorithm independently of sampling noise.
dg_ci_oracle = CausalInference.pcalg(d, dseporacle, dag_true)
@test dg_true == dg_ci_oracle

# Now run it on the data, at the same significance level `α` used by `itest` and `alg`.
println("CausalInference.pcalg")
@time dg_ci = CausalInference.pcalg(d, gausscitest, (C, n), quantile(Normal(), 1 - α / 2))
# Show where estimate and truth differ *before* testing. Outside a `@testset` a failing
# `@test` throws, which would otherwise suppress this diagnostic.
@show symdiff(vpairs(dg_ci), vpairs(dg_true))
@test dg_true == dg_ci

# --- Associations.infer_graph ----------------------------------------------------
# Same data-driven check for Associations' PC implementation (`alg`, level `α`).
println("Associations.infer_graph")
@time dg_ct = Associations.infer_graph(alg, eachrow(X); verbose = false)
@show symdiff(vpairs(dg_ct), vpairs(dg_true))
@test dg_true == dg_ct
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[deps] | |
Associations = "614afb3a-e278-4863-8805-9959372b9ec2" | |
CausalInference = "8e462317-f959-576b-b3c1-403f26cec956" | |
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" | |
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | |
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" | |
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" | |
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" |
Author
mschauer
commented
Dec 3, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment