@Flecart
Created December 18, 2024 17:55
BALD demo to play with. Generated by ChatGPT.
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# Define a simple Bayesian Neural Network (BNN) with MC Dropout
class BayesianNN(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=50):
        super(BayesianNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.2)  # Dropout simulates Bayesian uncertainty

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return torch.log_softmax(x, dim=-1)

    def mc_forward(self, x, num_samples=10):
        """Perform MC Dropout by sampling multiple stochastic forward passes.

        Dropout must stay active here: in eval mode every pass would be
        identical and the BALD score would collapse to zero.
        """
        self.train()  # keep dropout stochastic during MC sampling
        outputs = [self.forward(x) for _ in range(num_samples)]
        return torch.stack(outputs)  # Shape: [num_samples, batch_size, num_classes]
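# A hedged sanity check (not part of the original gist): mc_forward should
# return log-probabilities of shape [num_samples, batch_size, num_classes],
# each row normalizing to 1 after exponentiation.
def _mc_shape_check():
    bnn = BayesianNN(input_dim=1, output_dim=2)
    log_p = bnn.mc_forward(torch.randn(4, 1), num_samples=5)
    assert log_p.shape == (5, 4, 2)
    assert torch.allclose(log_p.exp().sum(dim=-1), torch.ones(5, 4))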
# BALD acquisition function
def bald_acquisition(bnn, data, num_samples=10):
    """
    Compute the BALD acquisition score.
    - `bnn`: the Bayesian neural network
    - `data`: unlabeled data (tensor)
    - `num_samples`: number of stochastic forward passes (MC samples)
    """
    mc_log_probs = bnn.mc_forward(data, num_samples)  # Shape: [num_samples, batch_size, num_classes]
    mc_probs = mc_log_probs.exp()  # Convert log-probabilities to probabilities
    eps = 1e-12  # guards against log(0)

    # Mean probability over MC samples (expected predictive distribution)
    mean_probs = mc_probs.mean(dim=0)  # Shape: [batch_size, num_classes]

    # Predictive entropy of the mean distribution
    entropy = -(mean_probs * (mean_probs + eps).log()).sum(dim=-1)  # Shape: [batch_size]

    # Expected entropy of the individual MC predictions
    expected_entropy = -(mc_probs * (mc_probs + eps).log()).sum(dim=-1).mean(dim=0)  # Shape: [batch_size]

    # Mutual information (BALD score)
    bald_score = entropy - expected_entropy  # Shape: [batch_size]
    return bald_score
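# For reference, BALD is the mutual information between the label y and the
# model parameters w (here, the random dropout masks):
#   I[y; w | x, D] = H[ E_w p(y | x, w) ] - E_w[ H[p(y | x, w)] ]
# The first term is the predictive entropy computed above, the second the
# expected entropy. The score is large when individual MC samples are
# confident but disagree with each other, i.e. when the uncertainty is
# epistemic (reducible by labeling) rather than aleatoric (inherent noise).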
# Generate a toy dataset: a noisy sine wave
def generate_toy_data(num_samples=100):
    x = torch.linspace(-5, 5, num_samples).unsqueeze(1)
    y = torch.sin(x) + 0.1 * torch.randn_like(x)
    return x, y
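# Example usage (a sketch mirroring what the main block below does): draw the
# pool and threshold y at zero to obtain binary labels.
#   x, y = generate_toy_data(100)
#   labels = (y > 0).long().squeeze()  # 1 where sin(x) + noise > 0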
# Main script: the active-learning loop, with visualizations
if __name__ == "__main__":
    # Parameters
    input_dim = 1
    output_dim = 2  # Assume a binary classification task for simplicity
    num_samples = 100
    num_acquisition_steps = 10

    # Generate toy data; thresholding y at 0 turns the noisy sine into a
    # binary classification problem
    x_pool, y_pool = generate_toy_data(num_samples)
    x_train, y_train = x_pool[:10], (y_pool[:10] > 0).long().squeeze()
    x_unlabeled = x_pool[10:]
    y_unlabeled = (y_pool[10:] > 0).long().squeeze()

    # Initialize the Bayesian Neural Network
    bnn = BayesianNN(input_dim, output_dim)
    optimizer = optim.Adam(bnn.parameters(), lr=0.01)
    loss_fn = nn.NLLLoss()  # pairs with the model's log_softmax output

    # Set up the figure
    plt.figure(figsize=(10, 6))

    for step in range(num_acquisition_steps):
        # Train the BNN on the labeled data
        bnn.train()
        for epoch in range(100):
            optimizer.zero_grad()
            log_probs = bnn(x_train)
            loss = loss_fn(log_probs, y_train)
            loss.backward()
            optimizer.step()

        # Perform BALD acquisition on the unlabeled pool. Note that mc_forward
        # keeps dropout active; calling bnn.eval() before scoring would make
        # every MC pass identical and zero out the scores.
        bald_scores = bald_acquisition(bnn, x_unlabeled).detach().numpy()

        # Select the top-scoring sample
        top_idx = np.argmax(bald_scores)
        new_x, new_y = x_unlabeled[top_idx].unsqueeze(0), y_unlabeled[top_idx].unsqueeze(0)

        # Add the new sample to the training set
        x_train = torch.cat([x_train, new_x], dim=0)
        y_train = torch.cat([y_train, new_y], dim=0)

        # Remove the selected sample from the unlabeled pool
        x_unlabeled = torch.cat([x_unlabeled[:top_idx], x_unlabeled[top_idx + 1:]], dim=0)
        y_unlabeled = torch.cat([y_unlabeled[:top_idx], y_unlabeled[top_idx + 1:]], dim=0)

        # Visualization (the y-axis mixes raw y values for the pool with
        # 0/1 labels for the labeled points)
        plt.clf()
        plt.scatter(x_pool.numpy(), y_pool.numpy(), label="True Function", c="gray", alpha=0.3)
        plt.scatter(x_train.numpy(), y_train.numpy(), label="Labeled Data", c="blue")
        plt.scatter(
            x_unlabeled.numpy(),
            y_unlabeled.numpy(),
            label="Unlabeled Data",
            c="orange",
            alpha=0.6,
        )
        plt.scatter(new_x.numpy(), new_y.numpy(), label="Newly Acquired", c="red", edgecolor="black", s=100)
        plt.title(f"Step {step + 1}: Acquired sample at x = {new_x.numpy()[0][0]:.2f}")
        plt.xlabel("x")
        plt.ylabel("y")
        plt.legend()
        plt.pause(0.5)

    plt.show()
    print("Active learning finished!")