Skip to content

Instantly share code, notes, and snippets.

@p3nGu1nZz
Created December 2, 2024 14:14
Show Gist options
  • Save p3nGu1nZz/22b2e4d85d017e39f5cff61b51d1ca4e to your computer and use it in GitHub Desktop.
Save p3nGu1nZz/22b2e4d85d017e39f5cff61b51d1ca4e to your computer and use it in GitHub Desktop.
Elastic Weight Consolidation (EWC) for fine-tuning models on a new task while limiting forgetting of a previous one.
import torch
import torch.nn as nn
import torch.optim as optim
# Example model
class SimpleNN(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features in each input sample.
        hidden_size: Width of the hidden layer.
        output_size: Number of values produced per sample.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the output of the final linear layer for ``x``.

        No activation is applied after ``fc2``; the caller decides how to
        interpret the raw outputs.
        """
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# Compute Fisher Information Matrix
def compute_fisher(model, dataloader, criterion):
    """Estimate the diagonal of the Fisher information for ``model``.

    For every batch yielded by ``dataloader`` the loss gradient is computed
    and squared element-wise; the squared gradients are then averaged over
    all batches. (Gradients are taken per batch, not per sample.)

    Args:
        model: Module whose named parameters are scored.
        dataloader: Iterable of ``(data, target)`` batches.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.

    Returns:
        Dict mapping parameter name to a tensor with the parameter's shape.
        Parameters that never received a gradient are omitted; an empty
        ``dataloader`` yields an empty dict.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    fisher_matrix = {}
    num_batches = 0
    model.eval()
    for data, target in dataloader:
        model.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        num_batches += 1
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            # ``**`` allocates a fresh tensor, so the in-place add below
            # never aliases ``param.grad``.
            squared_grad = param.grad.detach() ** 2
            if name in fisher_matrix:
                # Accumulate across batches instead of keeping only the
                # last batch's value.
                fisher_matrix[name] += squared_grad
            else:
                fisher_matrix[name] = squared_grad
    # ``fisher_matrix`` is only non-empty when at least one batch was seen,
    # so the division is safe.
    return {name: total / num_batches for name, total in fisher_matrix.items()}
# EWC Loss
def ewc_loss(model, fisher_matrix, prev_params, lambda_=0.4):
    """Return the EWC penalty that anchors ``model`` to ``prev_params``.

    The penalty is ``lambda_ * sum(F * (theta - theta_prev) ** 2)`` taken
    over every named parameter that has an entry in ``fisher_matrix``.

    Args:
        model: Module whose current parameters are penalized.
        fisher_matrix: Dict of parameter name -> importance tensor.
        prev_params: Dict of parameter name -> reference tensor. Must have
            an entry for every name present in ``fisher_matrix``.
        lambda_: Scale applied to the summed penalty.

    Returns:
        A scalar tensor, or the number ``0`` scaled by ``lambda_`` when no
        parameter name matches ``fisher_matrix``.

    Raises:
        KeyError: If a name in ``fisher_matrix`` is missing from
            ``prev_params``.
    """
    penalty = 0
    for name, param in model.named_parameters():
        if name not in fisher_matrix:
            continue
        # The reference values and importances are constants. Snapshots made
        # with ``param.clone()`` are still attached to ``param`` in the
        # autograd graph; without detach() the gradient through the snapshot
        # cancels the gradient through ``param`` and the penalty has no
        # effect on training.
        fisher = fisher_matrix[name].detach()
        anchor = prev_params[name].detach()
        penalty = penalty + (fisher * (param - anchor) ** 2).sum()
    return lambda_ * penalty
# Main training loop
def train_model(model, dataloader, criterion, optimizer, epochs=10,
                fisher_matrix=None, prev_params=None, ewc_lambda=0.4):
    """Train ``model`` in place, optionally adding an EWC penalty.

    Args:
        model: Module to optimize; it is put into training mode.
        dataloader: Iterable of ``(data, target)`` batches, iterated once
            per epoch.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.
        optimizer: Optimizer bound to ``model``'s parameters.
        epochs: Number of passes over ``dataloader``.
        fisher_matrix: Optional dict of parameter name -> importance tensor.
        prev_params: Optional dict of parameter name -> reference tensor.
        ewc_lambda: Strength of the EWC penalty, forwarded to ``ewc_loss``.

    The penalty is applied only when both ``fisher_matrix`` and
    ``prev_params`` are given; if either is ``None`` training uses the plain
    task loss.
    """
    # The decision does not change between batches, so make it once.
    use_ewc = fisher_matrix is not None and prev_params is not None
    model.train()
    for _ in range(epochs):
        for data, target in dataloader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            if use_ewc:
                # Out-of-place add: leaves the criterion's output tensor
                # untouched.
                loss = loss + ewc_loss(
                    model, fisher_matrix, prev_params, lambda_=ewc_lambda
                )
            loss.backward()
            optimizer.step()
# Initialize and train
model = SimpleNN(input_size=784, hidden_size=256, output_size=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
train_dataloader = ...  # Define your DataLoader

# Train on first task
train_model(model, train_dataloader, criterion, optimizer)

# Save parameters and compute Fisher Matrix.
# detach() before clone() so the snapshot is a plain constant: a bare
# ``param.clone()`` stays connected to ``param`` in the autograd graph, which
# makes the gradient of any penalty on ``param - snapshot`` cancel to zero.
prev_params = {
    name: param.detach().clone() for name, param in model.named_parameters()
}
fisher_matrix = compute_fisher(model, train_dataloader, criterion)

# Train on new task with EWC
new_train_dataloader = ...  # Define your new task DataLoader
train_model(
    model,
    new_train_dataloader,
    criterion,
    optimizer,
    fisher_matrix=fisher_matrix,
    prev_params=prev_params,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment