Created July 17, 2024 18:45
-
-
Save vsocrates/2ba5caa42ae27a92c30c9cdb020d5630 to your computer and use it in GitHub Desktop.
Mutual Information in pgmpy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def mutual_info_with_percents(infer, node):
    """Tabulate mutual-information statistics between *node* and every model node.

    Args:
        infer: a pgmpy Inference object over a BayesianNetwork
            (e.g. VariableElimination); its ``model.nodes()`` supplies the rows.
        node: string identifier of one of the nodes in the network.

    Returns:
        A pandas DataFrame indexed by node name with two columns:
        ``"MutualInfo"`` (from ``mi``) and ``"MIPercentChange"``
        (from ``mi_percentchange``). The row for *node* itself holds whatever
        those helpers return for a node paired with itself.
    """
    all_nodes = list(infer.model.nodes())
    # One (mi, percent) pair per node; the comprehension evaluates ``mi`` then
    # ``mi_percentchange`` for each node in turn, the same call order as a loop.
    pairs = [
        (mi(infer, node, other), mi_percentchange(infer, node, other))
        for other in all_nodes
    ]
    return pd.DataFrame(
        {
            "MutualInfo": [pair[0] for pair in pairs],
            "MIPercentChange": [pair[1] for pair in pairs],
        },
        index=all_nodes,
    )
def mi(infer, node1, node2):
    """Return the mutual information I(node1; node2) in nats.

    Args:
        infer: a pgmpy Inference object (e.g. VariableElimination); only its
            ``query(variables, joint=True)`` method is used.
        node1, node2: node identifiers in the inferred model.

    Returns:
        The mutual information as a numpy float, or ``np.nan`` when
        ``node1 == node2``.
    """
    if node1 == node2:
        return np.nan
    joint = np.asarray(infer.query([node1, node2], joint=True).values, dtype=float)
    # Both marginals come from the joint's own axes. This needs one query
    # instead of three. It also stays correct if the returned factor orders its
    # axes as (node2, node1), because MI is symmetric and the marginals always
    # line up with the joint they were summed from.
    p_x = joint.sum(axis=1, keepdims=True)
    p_y = joint.sum(axis=0, keepdims=True)
    # Treat 0 * log(0) as 0. Skipping zero cells avoids NaN from 0 * -inf.
    # Where the joint is > 0, both marginals are > 0 too, so the log is finite.
    mask = joint > 0
    p = joint[mask]
    return np.sum(p * np.log(p / (p_x * p_y)[mask]))
def mi_percentchange(infer, node1, node2):
    """Return the mutual information of node1 and node2 as a percent of their joint entropy.

    The value computed is 100 * I(X;Y) / H(X,Y). This is algebraically the
    same as the earlier formulation
    ``(1 - sum(p*log(p_x*p_y)) / sum(p*log(p))) * -100``, since
    sum(p*log(p_x*p_y)) = -(H(X) + H(Y)) and sum(p*log(p)) = -H(X,Y).

    Args:
        infer: a pgmpy Inference object (e.g. VariableElimination); only its
            ``query(variables, joint=True)`` method is used.
        node1, node2: node identifiers in the inferred model.

    Returns:
        A numpy float percentage. Returns ``np.nan`` when ``node1 == node2``,
        or when the joint entropy is 0, because the ratio is then 0/0.
    """
    if node1 == node2:
        return np.nan
    joint = np.asarray(infer.query([node1, node2], joint=True).values, dtype=float)
    # Take marginals from the joint itself. One query suffices, and the result
    # does not depend on the axis order of the returned factor.
    p_x = joint.sum(axis=1, keepdims=True)
    p_y = joint.sum(axis=0, keepdims=True)
    # Skip zero cells so that 0 * log(0) counts as 0 instead of producing NaN.
    mask = joint > 0
    p = joint[mask]
    joint_entropy = -np.sum(p * np.log(p))
    if joint_entropy <= 0:
        # A point-mass joint: MI and H(X,Y) are both 0, so the ratio is undefined.
        return np.nan
    mutual_info = np.sum(p * np.log(p / (p_x * p_y)[mask]))
    return 100.0 * mutual_info / joint_entropy
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment