MINE: Mutual Information Neural Estimation | Minimal Working Example
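The script trains a statistics network T(x, y) to tighten the Donsker-Varadhan lower bound on mutual information, I(X; Y) >= E_{P(X,Y)}[T(X, Y)] - log E_{P(X)P(Y)}[exp(T(X, Y))], approximating the product of marginals by shuffling y within the batch, and reports the estimate in bits.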
import math

import torch
import torch.optim as optim
from torch import nn
class MineWrapper(nn.Module):
    def __init__(self, stat_model, moving_average_rate=0.1, unbiased=False):
        super(MineWrapper, self).__init__()
        self.stat_model = stat_model
        self.unbiased = unbiased
        LogWithMovingAverageGrad.alpha = moving_average_rate

    def get_t_exp_t(self, x, y):
        # Shuffle y across the batch to approximate samples from the product of marginals
        y_resampled = y[torch.randperm(y.shape[0])]
        t = self.stat_model(x, mine_y=y).mean()
        exp_t = torch.exp(self.stat_model(x, mine_y=y_resampled)).mean()
        return t, exp_t

    def get_loss(self, x, y):
        t, exp_t = self.get_t_exp_t(x, y)
        if self.unbiased:
            lower_bound = t - LogWithMovingAverage(exp_t)
        else:
            lower_bound = t - torch.log(exp_t)
        # Maximizing the lower bound is done by minimizing its negative
        return -1.0 * lower_bound

    def get_mutual_information(self, x, y):
        # Evaluation only: no gradient graph is needed here
        with torch.no_grad():
            t, exp_t = self.get_t_exp_t(x, y)
        # Convert the Donsker-Varadhan bound from nats to bits
        mi = (t - torch.log(exp_t)).item() / math.log(2)
        return mi
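# Why a custom log below: when the bound is optimized on minibatches, the gradient of
# log E[exp(T)] divides by the minibatch estimate of E[exp(T)], which biases the gradient.
# The remedy sketched in LogWithMovingAverageGrad replaces that denominator with an
# exponential moving average of exp(T) across steps; only the backward pass changes.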
class LogWithMovingAverageGrad(torch.autograd.Function):
    # Static variables: exponential moving average of the input and its update rate
    moving_avg_input = None
    alpha = 0.01

    @staticmethod
    def forward(ctx, input):
        # Compute the log and save the input for the backward pass
        output = input.log()
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Update the moving average of the input (detached so the autograd graph is not retained)
        if LogWithMovingAverageGrad.moving_avg_input is None:
            LogWithMovingAverageGrad.moving_avg_input = input.detach()
        else:
            LogWithMovingAverageGrad.moving_avg_input = (
                LogWithMovingAverageGrad.alpha * input.detach() +
                (1 - LogWithMovingAverageGrad.alpha) * LogWithMovingAverageGrad.moving_avg_input
            )
        # Normalize grad_output by the moving average of the input instead of the input itself
        grad_input = grad_output / LogWithMovingAverageGrad.moving_avg_input
        return grad_input


LogWithMovingAverage = LogWithMovingAverageGrad.apply
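# Hypothetical usage of the bias-corrected variant (train() below keeps the default, biased loss):
#   mine = MineWrapper(stat_model=StatModel(dim), moving_average_rate=0.1, unbiased=True)
#   loss = mine.get_loss(x, y)  # the exp term's gradient is normalized by the moving average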
class StatModel(nn.Module):
    # Statistics network T(x, y): maps a pair of samples to a scalar score
    def __init__(self, dim):
        super(StatModel, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, 100),
            nn.ReLU(),
            nn.Linear(100, 1)
        )

    def forward(self, x, mine_y):
        # Concatenate x and y along the feature dimension
        x_y = torch.cat([x, mine_y], dim=1)
        out = self.layers(x_y)
        return out
def train(x, y, num_epochs=100):
    dim = x.shape[1] + y.shape[1]
    stat_model = StatModel(dim)
    # Create an instance of MineWrapper
    mine = MineWrapper(stat_model=stat_model)
    # Set up the optimizer
    optimizer = optim.AdamW(mine.parameters(), lr=0.001)
    # Training loop
    mi = None
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        loss = mine.get_loss(x, y)
        loss.backward()
        optimizer.step()
        mi = mine.get_mutual_information(x, y)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Mutual Information: {mi}")
    return mi
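# Note: train() uses the full dataset as a single batch at every step; with smaller minibatches
# the biased/unbiased distinction above matters more, since the batch estimate of E[exp(T)] is noisier.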
if __name__ == "__main__":
    n = 10000
    dim = 10

    # Independent variables: Gaussian x and unrelated random bits y
    x = torch.randn(n, dim)
    y = torch.randint(0, 2, size=(n, dim)).float()
    independent_mi = train(x, y)

    # Dependent variables: y is a noisy thresholding of x
    x = torch.randn(n, dim)
    y = (x + torch.normal(0, 2, size=(n, dim)) > 0).float()
    dependent_mi = train(x, y)

    print(f"Independent MI: {independent_mi}, Dependent MI: {dependent_mi}")