-
-
Save p3nGu1nZz/22b2e4d85d017e39f5cff61b51d1ca4e to your computer and use it in GitHub Desktop.
Elastic Weight Consolidation (EWC) for fine-tuning models without catastrophically forgetting previously learned tasks.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
# Example model
class SimpleNN(nn.Module):
    """Two-layer fully connected network with a ReLU in between."""

    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Hidden representation, then the linear output head (raw logits).
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
# Compute Fisher Information Matrix (diagonal approximation)
def compute_fisher(model, dataloader, criterion):
    """Estimate the diagonal Fisher information for each trainable parameter.

    Averages the squared gradients of ``criterion`` over every batch in
    ``dataloader``.  The original version *overwrote* each entry per batch,
    so only the last batch contributed to the estimate — here the squared
    gradients are accumulated and divided by the number of batches.

    Returns a dict mapping parameter name -> tensor of the parameter's shape.
    An empty dataloader yields an empty dict.
    """
    fisher_matrix = {}
    model.eval()
    num_batches = 0
    for data, target in dataloader:
        model.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        num_batches += 1
        for name, param in model.named_parameters():
            if param.grad is not None:
                # detach() so the Fisher entries carry no autograd history.
                sq_grad = param.grad.detach() ** 2
                if name in fisher_matrix:
                    fisher_matrix[name] += sq_grad
                else:
                    fisher_matrix[name] = sq_grad
    # Average over batches to get the expected squared gradient.
    for name in fisher_matrix:
        fisher_matrix[name] = fisher_matrix[name] / num_batches
    return fisher_matrix
# EWC Loss
def ewc_loss(model, fisher_matrix, prev_params, lambda_=0.4):
    """Quadratic EWC penalty: lambda * sum_i F_i * (theta_i - theta*_i)^2.

    Parameters without a Fisher entry contribute nothing to the penalty.
    """
    penalty = sum(
        (fisher_matrix[name] * (param - prev_params[name]) ** 2).sum()
        for name, param in model.named_parameters()
        if name in fisher_matrix
    )
    return lambda_ * penalty
# Main training loop
def train_model(model, dataloader, criterion, optimizer, epochs=10, fisher_matrix=None, prev_params=None):
    """Run standard SGD training; add the EWC penalty when both a Fisher
    matrix and the previous-task parameters are supplied."""
    model.train()
    use_ewc = fisher_matrix is not None and prev_params is not None
    for _ in range(epochs):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            predictions = model(inputs)
            total_loss = criterion(predictions, labels)
            if use_ewc:
                total_loss = total_loss + ewc_loss(model, fisher_matrix, prev_params)
            total_loss.backward()
            optimizer.step()
# Initialize and train
model = SimpleNN(input_size=784, hidden_size=256, output_size=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

train_dataloader = ...  # Define your DataLoader

# Train on first task
train_model(model, train_dataloader, criterion, optimizer)

# Snapshot the task-1 parameters and estimate the Fisher information.
# detach() is essential here: a bare .clone() stays attached to the autograd
# graph of the live parameters, so during task-2 backward the gradient of the
# EWC penalty would also flow through prev_params back into the same
# parameters and cancel the anchoring term.
prev_params = {
    name: param.detach().clone()
    for name, param in model.named_parameters()
}
fisher_matrix = compute_fisher(model, train_dataloader, criterion)

# Train on the new task with the EWC penalty anchoring to the task-1 weights
new_train_dataloader = ...  # Define your new task DataLoader
train_model(model, new_train_dataloader, criterion, optimizer,
            fisher_matrix=fisher_matrix, prev_params=prev_params)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment