Last active
October 5, 2023 03:49
-
-
Save itolosa/6008ac6cd1257c964bb8a5d078b4b542 to your computer and use it in GitHub Desktop.
CS50AI - Lecture 2 - Bayesian Network updated source code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch

from model import model

# Each node of the network with its possible values, listed in the same order
# the nodes were added to the model. That order fixes the column order of X and
# the order of the per-node predictions returned below.
states = (
    ("rain", ["none", "light", "heavy"]),
    ("maintenance", ["yes", "no"]),
    ("train", ["on time", "delayed"]),
    ("appointment", ["attend", "miss"]),
)

# Sentinel index for an unobserved node; such entries are masked out below.
MISSING = -1

# Evidence to condition on: the train is delayed; every other node is unknown.
evidence = {"train": "delayed"}

# One example (row) with one column per node: the observed value's index, or
# MISSING when the node is not part of the evidence. Yields [[-1, -1, 1, -1]].
X = torch.tensor(
    [
        [
            values.index(evidence[name]) if name in evidence else MISSING
            for name, values in states
        ]
    ]
)
X_masked = torch.masked.MaskedTensor(X, mask=(X != MISSING))

# Calculate predictions: the posterior distribution of every node given the
# evidence. NOTE(review): in pomegranate 1.x this is a list with one
# probability tensor per node, observed nodes included. The old
# "prediction is a str" special case from earlier pomegranate versions was
# therefore unreachable and has been dropped.
predictions = model.predict_proba(X_masked)

# Print predictions for each node.
for (node_name, values), prediction in zip(states, predictions):
    print(node_name)
    # prediction has one row per example in X; there is exactly one here.
    for value, probability in zip(values, prediction[0]):
        print(f" {value}: {probability:.4f}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy
import torch

from model import model

# Possible values of each node; a value's position in its list is the integer
# code the model uses for it.
rain_values = ["none", "light", "heavy"]
maintenance_values = ["yes", "no"]
train_values = ["on time", "delayed"]
appointment_values = ["attend", "miss"]

# One fully specified observation (a single row). Columns follow the node order
# used to build the model: rain, maintenance, train, appointment.
observation = torch.as_tensor(
    [
        [
            rain_values.index("none"),
            maintenance_values.index("no"),
            train_values.index("on time"),
            appointment_values.index("attend"),
        ]
    ]
)

# Joint probability of that assignment. With the tables in model.py this should
# be 0.7 * 0.6 * 0.9 * 0.9 = 0.3402; re-check if the tables change.
probability = model.probability(observation)
print(probability)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np  # noqa: F401 -- currently unused
from pomegranate.bayesian_network import BayesianNetwork
from pomegranate.distributions import Categorical, ConditionalCategorical

# CS50AI Lecture 2 Bayesian network:
#   rain -> maintenance, rain -> train, maintenance -> train, train -> appointment
#
# Integer value codes used by every table below (and by the scripts that query
# `model`):
#   rain:        0 = none,    1 = light, 2 = heavy
#   maintenance: 0 = yes,     1 = no
#   train:       0 = on time, 1 = delayed
#   appointment: 0 = attend,  1 = miss

# P(rain): unconditional, given as one row of probabilities (a single feature).
rain = Categorical(
    [
        [0.7, 0.2, 0.1],
    ]
)

# P(maintenance | rain): one row per rain value, columns = (yes, no).
maintenance = ConditionalCategorical(
    [
        [
            [0.4, 0.6],  # rain = none
            [0.2, 0.8],  # rain = light
            [0.1, 0.9],  # rain = heavy
        ],
    ]
)

# P(train | rain, maintenance): indexed [rain][maintenance], columns =
# (on time, delayed).
# NOTE(review): pomegranate is assumed to order a node's parents by the order
# the edges are added below (rain before maintenance). Keep the two in sync.
train = ConditionalCategorical(
    [
        [
            [  # rain = none
                [0.8, 0.2],  # maintenance = yes
                [0.9, 0.1],  # maintenance = no
            ],
            [  # rain = light
                [0.6, 0.4],  # maintenance = yes
                [0.7, 0.3],  # maintenance = no
            ],
            [  # rain = heavy
                [0.4, 0.6],  # maintenance = yes
                [0.5, 0.5],  # maintenance = no
            ],
        ]
    ]
)

# P(appointment | train): one row per train value, columns = (attend, miss).
appointment = ConditionalCategorical(
    [
        [
            [0.9, 0.1],  # train = on time
            [0.6, 0.4],  # train = delayed
        ],
    ]
)

# Create a Bayesian Network and add states. The order here is the column order
# of the data passed to model.probability / predict_proba and returned by
# model.sample.
model = BayesianNetwork()
model.add_distributions([rain, maintenance, train, appointment])

# Add edges connecting nodes (parent, child).
model.add_edge(rain, maintenance)
model.add_edge(rain, train)
model.add_edge(maintenance, train)
model.add_edge(train, appointment)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import Counter

from pomegranate.distributions import ConditionalCategorical  # noqa: F401 -- unused
from model import model

# Rejection sampling:
# estimate P(Appointment | Train = delayed) by drawing samples from the joint
# distribution and keeping only those in which the train is delayed.
N = 10000

# Column positions follow the node order used to build the model
# (rain, maintenance, train, appointment); value codes are list positions.
TRAIN_COLUMN = 2
APPOINTMENT_COLUMN = 3
DELAYED = 1
ATTEND = 0

# Draw all N samples in one call instead of N separate model.sample(1) calls.
samples = model.sample(N)

# Reject every sample in which the train is not delayed.
accepted = samples[samples[:, TRAIN_COLUMN] == DELAYED]

data = [
    "attend" if value == ATTEND else "miss"
    for value in accepted[:, APPOINTMENT_COLUMN].tolist()
]
print(Counter(data))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment