Last active
May 30, 2020 15:14
-
-
Save ericmjl/43d14da2d2cde5e632e78cdcd816d5c0 to your computer and use it in GitHub Desktop.
Proposed change to d-separation tests based on pytest functions and fixtures.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@pytest.fixture
def path_graph():
    """Build a frozen three-node directed path graph named "path"."""
    graph = nx.path_graph(3, create_using=nx.DiGraph)
    graph.graph["name"] = "path"
    # nx.freeze hands back the same graph object after making it read-only.
    return nx.freeze(graph)
@pytest.fixture
def fork_graph():
    """Build a frozen fork: node 0 is the common parent of nodes 1 and 2."""
    fork = nx.DiGraph(name="fork")
    fork.add_edges_from((0, child) for child in (1, 2))
    # nx.freeze hands back the same graph object after making it read-only.
    return nx.freeze(fork)
@pytest.fixture
def collider_graph():
    """Build a frozen collider: nodes 0 and 1 both point into node 2."""
    collider = nx.DiGraph(name="collider")
    collider.add_edges_from((parent, 2) for parent in (0, 1))
    # nx.freeze hands back the same graph object after making it read-only.
    return nx.freeze(collider)
@pytest.fixture
def naive_bayes_graph():
    """Build a frozen Naive Bayes graph: node 0 points to nodes 1 through 4."""
    model = nx.DiGraph(name="naive_bayes")
    model.add_edges_from((0, feature) for feature in range(1, 5))
    # nx.freeze hands back the same graph object after making it read-only.
    return nx.freeze(model)
@pytest.fixture
def asia_graph():
    """Build the frozen eight-node "asia" directed graph used by the tests."""
    # Edge order is kept as-is so node insertion order does not change.
    edges = [
        ("asia", "tuberculosis"),
        ("smoking", "cancer"),
        ("smoking", "bronchitis"),
        ("tuberculosis", "either"),
        ("cancer", "either"),
        ("either", "xray"),
        ("either", "dyspnea"),
        ("bronchitis", "dyspnea"),
    ]
    asia = nx.DiGraph(name="asia")
    asia.add_edges_from(edges)
    # nx.freeze hands back the same graph object after making it read-only.
    return nx.freeze(asia)
@pytest.mark.parametrize(
    "graph",
    [
        "path_graph",
        "fork_graph",
        "collider_graph",
        "naive_bayes_graph",
        "asia_graph",
    ],
)
def test_markov_condition(graph, request):
    """Test that the Markov condition holds for each PGM graph.

    For every node, the node must be d-separated from its non-descendants
    (excluding itself and its parents) given its parents.

    ``graph`` arrives as the *name* of a fixture and is resolved through
    ``request.getfixturevalue``: pytest does not allow fixture functions to
    be called directly, so they cannot be invoked inside ``parametrize``.
    """
    graph = request.getfixturevalue(graph)
    for node in graph.nodes:
        parents = set(graph.predecessors(node))
        non_descendants = (
            set(graph.nodes) - nx.descendants(graph, node) - {node} - parents
        )
        assert nx.d_separated(graph, {node}, non_descendants, parents)
def test_path_graph_dsep(path_graph):
    """Example-based test of d-separation for path_graph.

    Conditioning on the middle node 1 separates the endpoints 0 and 2;
    with an empty conditioning set they are not separated.
    """
    assert nx.d_separated(path_graph, {0}, {2}, {1})
    # set() rather than {}: the latter is an empty dict, not an empty set.
    assert not nx.d_separated(path_graph, {0}, {2}, set())
def test_fork_graph_dsep(fork_graph):
    """Example-based test of d-separation for fork_graph.

    Conditioning on the common parent 0 separates the children 1 and 2;
    with an empty conditioning set they are not separated.
    """
    assert nx.d_separated(fork_graph, {1}, {2}, {0})
    # set() rather than {}: the latter is an empty dict, not an empty set.
    assert not nx.d_separated(fork_graph, {1}, {2}, set())
def test_collider_graph_dsep(collider_graph):
    """Example-based test of d-separation for collider_graph.

    The two parents 0 and 1 are separated when nothing is conditioned on,
    and become connected once the common child 2 is conditioned on.
    """
    # set() rather than {}: the latter is an empty dict, not an empty set.
    assert nx.d_separated(collider_graph, {0}, {1}, set())
    assert not nx.d_separated(collider_graph, {0}, {1}, {2})
def test_naive_bayes_dsep(naive_bayes_graph):
    """Example-based test of d-separation for naive_bayes_graph.

    Every pair of distinct nodes drawn from 1..4 is separated given node 0
    and not separated given the empty set.
    """
    for u, v in combinations(range(1, 5), 2):
        assert nx.d_separated(naive_bayes_graph, {u}, {v}, {0})
        # set() rather than {}: the latter is an empty dict, not an empty set.
        assert not nx.d_separated(naive_bayes_graph, {u}, {v}, set())
def test_asia_graph_dsep(asia_graph):
    """Check two hand-picked d-separation statements on the asia graph."""
    # Case 1: asia/smoking vs. dyspnea/xray, given bronchitis and either.
    sources = {"asia", "smoking"}
    targets = {"dyspnea", "xray"}
    given = {"bronchitis", "either"}
    assert nx.d_separated(asia_graph, sources, targets, given)

    # Case 2: tuberculosis/cancer vs. bronchitis, given smoking and xray.
    sources = {"tuberculosis", "cancer"}
    targets = {"bronchitis"}
    given = {"smoking", "xray"}
    assert nx.d_separated(asia_graph, sources, targets, given)
def test_undirected_graphs_are_not_supported(path_graph):
    """
    Test that undirected graphs are not supported.

    d-separation does not apply in the case of undirected graphs.
    """
    # The path_graph fixture is a DiGraph, so passing it unchanged would never
    # exercise the undirected-graph rejection; build an undirected copy first.
    undirected = path_graph.to_undirected()
    with pytest.raises(nx.NetworkXNotImplemented):
        nx.d_separated(undirected, {0}, {1}, {2})
def test_cyclic_graphs_raise_error():
    """
    Test that cycle graphs should cause erroring.

    This is because PGMs assume a directed acyclic graph.
    """
    # Module-level test function: no ``self`` parameter, otherwise pytest
    # would look for a fixture with that name.
    # Build the graph outside the ``raises`` block so that only the
    # d_separated call is allowed to satisfy the expected exception.
    g = nx.cycle_graph(3, create_using=nx.DiGraph)
    with pytest.raises(nx.NetworkXError):
        nx.d_separated(g, {0}, {1}, {2})
def test_invalid_nodes_raise_error(asia_graph):
    """
    Test that graphs that have invalid nodes passed in raise errors.

    The asia graph only has string-named nodes, so the integer nodes
    0, 1 and 2 used here are not present in it.
    """
    # Uses the asia_graph fixture; this module-level function has no ``self``
    # from which a graph attribute could be read.
    with pytest.raises(nx.NodeNotFound):
        nx.d_separated(asia_graph, {0}, {1}, {2})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment