Skip to content

Instantly share code, notes, and snippets.

@mschauer
Last active December 4, 2024 11:55
Show Gist options
  • Save mschauer/715a3555873fc2e054f519f2d40ba7b8 to your computer and use it in GitHub Desktop.
Save mschauer/715a3555873fc2e054f519f2d40ba7b8 to your computer and use it in GitHub Desktop.
Comparing CausalInference.pcalg with Associations.infer_graph
using CausalInference
using Associations: CorrTest, PC, Associations
using Test
using Graphs: SimpleDiGraph, Graphs, complete_graph
using StableRNGs
using LinearAlgebra, Random, Distributions
using CausalInference: pcalg, gausscitest, CausalInference
using Combinatorics
using Tables: table, istable
# Stable RNG handle.
# NOTE(review): `rng` is not passed to any call visible in this script; the random draws
# below come from the global RNG seeded by `Random.seed!(8)`.
rng = StableRNG(123)
# Significance level handed to the PC algorithm and to the Gaussian CI test.
α = 0.01
# PC algorithm from Associations, constructed with two `CorrTest()` arguments
# (presumably the unconditional and the conditional test; confirm in the Associations docs).
alg = PC(CorrTest(), CorrTest(); α)
# Number of variables (vertices of the DAG).
d = 8
# Number of samples (columns of X).
n = 20_000
# Probability that a given below-diagonal entry of E is 1.
beta = 0.3
# Make a random d-dimensional problem where I know the true DAG and generate some data
Random.seed!(8)
# 0/1 matrix that is nonzero only below the diagonal; rand() is only evaluated when i > j.
E = LowerTriangular([i > j ? 1 * (rand() < beta) : 0 for i in 1:d, j in 1:d])
# Keep the sparsity pattern of E and scale each entry by a weight in [0.8, 1.0).
L = E .* (0.8 .+ 0.2rand(d, d))
println("\nVertices: $d, Edges: ", sum(E))
# Solve (I - L) * X = Z, where Z is a d×n matrix of standard normal draws.
X = (I - L) \ randn(d, n)
# Directed graph built from the adjoint of E; its CPDAG is the reference result used below.
dag_true = SimpleDiGraph(E')
dg_true = CausalInference.cpdag(dag_true)
# Some statistics
# Product of a matrix with its own adjoint.
function qu(x)
    return x * x'
end
# Since X = (I - L) \ Z with standard normal Z, the implied covariance is
# inv((I - L)' * (I - L)).
Σtrue = inv(qu((I - L)'))
# Standard deviations, used to rescale the covariance into a correlation matrix.
di = sqrt.(diag(Σtrue))
Ctrue = Σtrue ./ (di * di')
# Empirical correlation between the rows of X.
C = cor(X; dims = 2)
# I can test for independence using the data or by an oracle which looks up independences
# in the true DAG
itest = let crit = quantile(Normal(), 1 - α / 2)
    # Two-sided standard normal critical value at level α, bundled with (C, n).
    CausalInference.IClosure(gausscitest, ((C, n), crit))
end
ioracle = CausalInference.IClosure(dseporacle, (dag_true,))
# Let's first compare the skeleton implementations, using the oracle as ground truth
g_ci_exact, S_ci_exact = skeleton(complete_graph(d), ioracle)
println("CausalInference.skeleton")
@time g_ci, S_ci = skeleton(complete_graph(d), itest)
println("Associations.skeleton")
@time g_ct, S_ct = Associations.skeleton(alg, eachrow(X); verbose = false)
# Sanity tests: every reported separating set must d-separate its two vertices in the
# true DAG. The testsets carry distinct names so their summaries can be told apart.
@testset "sep CausalInference" begin
    for (edge, set) in S_ci
        @test dsep(dag_true, edge.src, edge.dst, set)
    end
end
@testset "sep Associations" begin
    for (edge, set) in S_ct
        @test dsep(dag_true, edge.src, edge.dst, set)
    end
end
# The data-driven skeleton must match the oracle (ground truth) skeleton ...
@test g_ci == g_ci_exact
# ... and both packages must agree on the skeleton.
@test g_ci == Graphs.SimpleGraph(g_ct)
# testing CausalInference
# now run the oracle pcalg as a check
dg_ci_oracle = CausalInference.pcalg(d, dseporacle, dag_true)
@test dg_true == dg_ci_oracle
# now test with data, at the same significance level α that `itest` and `alg` use
println("CausalInference.pcalg")
@time dg_ci = CausalInference.pcalg(d, gausscitest, (C, n), quantile(Normal(), 1 - α / 2))
# Show which vertex pairs differ (if any) before asserting equality, so a failure is
# accompanied by the offending pairs.
@show symdiff(vpairs(dg_ci), vpairs(dg_true))
@test dg_true == dg_ci
# testing Associations
# test with data
println("Associations.infer_graph")
# Run `infer_graph` with the PC configuration `alg` on the rows of X (one row per variable),
# timing the call.
@time dg_ct = Associations.infer_graph(alg, eachrow(X); verbose = false)
# Print the symmetric difference between the `vpairs` of the inferred graph and of `dg_true`.
@show symdiff(vpairs(dg_ct), vpairs(dg_true))
# Require exact equality with `dg_true`, the CPDAG of the true DAG.
@test dg_true == dg_ct
[deps]
Associations = "614afb3a-e278-4863-8805-9959372b9ec2"
CausalInference = "8e462317-f959-576b-b3c1-403f26cec956"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
@mschauer
Copy link
Author

mschauer commented Dec 3, 2024

Vertices: 8, Edges: 10
CausalInference.skeleton
  0.000555 seconds (2.85 k allocations: 373.438 KiB)
Associations.skeleton
 18.507618 seconds (166.23 M allocations: 5.169 GiB, 3.90% gc time)
Test Summary: | Pass  Total  Time
sep           |   18     18  0.0s
Test Summary: | Pass  Total  Time
sep           |   10     10  0.0s
CausalInference.pcalg
  0.000713 seconds (3.29 k allocations: 418.078 KiB)
symdiff(vpairs(dg_ci), vpairs(dg_true)) = Pair{Int64, Int64}[]
Associations.infer_graph
 18.501315 seconds (166.23 M allocations: 5.169 GiB, 3.72% gc time)
symdiff(vpairs(dg_ct), vpairs(dg_true)) = [5 => 1, 5 => 2, 5 => 3, 7 => 1, 7 => 2, 7 => 4, 7 => 5]
Test Failed at /Users/smoritz/.julia/dev/causalcomp/comp.jl:73
  Expression: dg_true == dg_ct
   Evaluated: SimpleDiGraph{Int64}(13, [[3, 5, 6, 7], [5, 7], [1, 5], [7, 8], [7], [1], Int64[], [4]], [[3, 6], Int64[], [1], [8], [1, 2, 3], [1], [1, 2, 4, 5], [4]]) == SimpleDiGraph{Int64}(20, [[3, 5, 6, 7], [5, 7], [1, 5], [7, 8], [1, 2, 3, 7], [1], [1, 2, 4, 5], [4]], [[3, 5, 6, 7], [5, 7], [1, 5], [7, 8], [1, 2, 3, 7], [1], [1, 2, 4, 5], [4]])

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment