Created
December 18, 2024 17:55
-
-
Save Flecart/dc04ca800cf2a9b152e28c93493a6bfe to your computer and use it in GitHub Desktop.
BALD (Bayesian Active Learning by Disagreement) toy example to play with. Generated by ChatGPT.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.distributions import Categorical | |
import numpy as np | |
# Define a simple Bayesian Neural Network (BNN) with MC Dropout
class BayesianNN(nn.Module):
    """Two-layer MLP classifier whose dropout layer is reused for MC Dropout.

    ``forward`` returns per-class log-probabilities (``log_softmax`` over the
    last dimension), so it pairs with ``nn.NLLLoss``.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=50):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # Dropout makes repeated forward passes stochastic (MC Dropout).
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Return per-class log-probabilities for ``x``."""
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return torch.log_softmax(x, dim=-1)

    def mc_forward(self, x, num_samples=10):
        """Perform MC Dropout by sampling multiple stochastic forward passes.

        The dropout layer is switched on for the duration of the call, even if
        the model is in ``eval()`` mode. Otherwise every pass would be
        identical and the samples would carry no uncertainty information.
        The dropout layer's previous mode is restored afterwards, even if a
        forward pass raises.

        Args:
            x: Input batch passed to ``forward``.
            num_samples: Number of stochastic forward passes.

        Returns:
            Log-probabilities stacked along a new leading dimension, i.e.
            ``[num_samples, *forward(x).shape]``. For a 2-D ``x`` this is
            ``[num_samples, batch_size, num_classes]``.
        """
        dropout_was_training = self.dropout.training
        self.dropout.train()
        try:
            outputs = [self.forward(x) for _ in range(num_samples)]
        finally:
            self.dropout.train(dropout_was_training)
        return torch.stack(outputs)
# BALD acquisition function
def bald_acquisition(bnn, data, num_samples=10):
    """
    Compute the BALD acquisition score for each input.

    The score is the predictive entropy of the mean MC distribution minus the
    mean entropy of the individual MC samples.

    - `bnn`: Model exposing ``mc_forward(data, num_samples)``, which returns
      stacked log-probabilities with the MC samples on dim 0 and classes on
      the last dim.
    - `data`: Unlabeled data (tensor)
    - `num_samples`: Number of stochastic forward passes (MC samples)

    Returns one score per input (shape [batch_size] for 2-D `data`).
    """
    mc_log_probs = bnn.mc_forward(data, num_samples)  # [num_samples, batch_size, num_classes]
    mc_probs = mc_log_probs.exp()

    # Mean probability over MC samples (expected predictive distribution)
    mean_probs = mc_probs.mean(dim=0)  # [batch_size, num_classes]

    # xlogy(p, p) is defined as 0 where p == 0. Classes whose probability
    # underflowed to zero therefore contribute 0 instead of 0 * log(0) = nan.
    entropy = -torch.xlogy(mean_probs, mean_probs).sum(dim=-1)  # [batch_size]
    expected_entropy = -torch.xlogy(mc_probs, mc_probs).sum(dim=-1).mean(dim=0)  # [batch_size]

    # Mutual information (BALD score)
    return entropy - expected_entropy
# Generate toy dataset
def generate_toy_data(num_samples=100):
    """Return evenly spaced inputs on [-5, 5] and noisy sine targets.

    Both tensors have shape ``[num_samples, 1]``. The targets are ``sin(x)``
    plus Gaussian noise with standard deviation 0.1.
    """
    grid = torch.linspace(-5, 5, num_samples)[:, None]
    noise = torch.randn_like(grid) * 0.1
    return grid, torch.sin(grid) + noise
import matplotlib.pyplot as plt | |
# Modified main function with visualizations
if __name__ == "__main__":
    # Parameters
    input_dim = 1
    output_dim = 2  # Binary classification: class 1 where the noisy target y > 0
    num_samples = 100
    num_initial = 10
    num_acquisition_steps = 10
    num_epochs = 100

    # Fixed seed so the data noise, the initial split and the dropout masks
    # are reproducible between runs.
    torch.manual_seed(0)

    # Generate toy data and derive class labels from the sign of the noisy target.
    # squeeze(1) (not squeeze()) keeps a 1-D label tensor even for tiny pools.
    x_pool, y_pool = generate_toy_data(num_samples)
    labels = (y_pool > 0).long().squeeze(1)

    # Labeled / unlabeled points are tracked as indices into the pool, so the
    # plot can draw each point at its own (x, y) position. The initial labeled
    # set is drawn at random because the pool is sorted by x: its first 10
    # points lie in x in [-5, -4.09], where sin(x) > 0.8, so they would all
    # share class 1.
    perm = torch.randperm(num_samples)
    labeled_idx, unlabeled_idx = perm[:num_initial], perm[num_initial:]

    # Initialize the Bayesian Neural Network
    bnn = BayesianNN(input_dim, output_dim)
    optimizer = optim.Adam(bnn.parameters(), lr=0.01)
    loss_fn = nn.NLLLoss()  # BayesianNN.forward returns log-probabilities

    # Set up the figure
    plt.figure(figsize=(10, 6))
    # Never try to acquire more points than the unlabeled pool holds.
    for step in range(min(num_acquisition_steps, len(unlabeled_idx))):
        x_train, y_train = x_pool[labeled_idx], labels[labeled_idx]
        x_unlabeled = x_pool[unlabeled_idx]

        # Train the BNN on labeled data
        bnn.train()
        for _ in range(num_epochs):
            optimizer.zero_grad()
            loss = loss_fn(bnn(x_train), y_train)
            loss.backward()
            optimizer.step()

        # Perform BALD acquisition on the unlabeled data. The scores are only
        # ranked, so no autograd graph is needed.
        bnn.eval()
        with torch.no_grad():
            bald_scores = bald_acquisition(bnn, x_unlabeled)

        # Move the top-scoring point from the unlabeled pool to the labeled set
        top = int(torch.argmax(bald_scores))
        new_i = int(unlabeled_idx[top])
        labeled_idx = torch.cat([labeled_idx, unlabeled_idx[top:top + 1]])
        unlabeled_idx = torch.cat([unlabeled_idx[:top], unlabeled_idx[top + 1:]])

        # Visualization: points are drawn at their noisy (x, y) values, so
        # points above y = 0 are class 1 and the rest are class 0.
        plt.clf()
        plt.plot(x_pool.numpy(), torch.sin(x_pool).numpy(), label="True Function", c="gray", alpha=0.5)
        plt.scatter(x_pool[labeled_idx].numpy(), y_pool[labeled_idx].numpy(), label="Labeled Data", c="blue")
        plt.scatter(
            x_pool[unlabeled_idx].numpy(),
            y_pool[unlabeled_idx].numpy(),
            label="Unlabeled Data",
            c="orange",
            alpha=0.6,
        )
        new_x, new_y = x_pool[new_i], y_pool[new_i]
        plt.scatter(new_x.numpy(), new_y.numpy(), label="Newly Acquired", c="red", edgecolor="black", s=100)
        plt.title(f"Step {step + 1}: Acquired sample at x = {new_x.item():.2f}")
        plt.xlabel("x")
        plt.ylabel("y")
        plt.legend()
        plt.pause(0.5)

    plt.show()
    print("Active learning finished!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
https://chatgpt.com/share/67631008-5954-8009-b866-c6bde1c20e74