# Spectrum: SNR-based selective fine-tuning (PyTorch).
# Source: GitHub Gist ucalyptus2/794350462c6e7d2316ca8f4ddf74dfc3
# Author: @ucalyptus2, created July 25, 2024.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
class Spectrum:
    """Selective fine-tuning that trains only the highest-SNR weight matrices.

    Each eligible weight matrix is scored by the ratio of the sum of its
    singular values above a threshold ("signal") to the sum of those at or
    below it ("noise"). The top fraction of matrices by this score stay
    trainable; every other parameter is frozen before training.

    Args:
        model: The ``nn.Module`` to fine-tune; moved to ``device``.
        train_loader: Iterable of ``(inputs, labels)`` batches for training.
        val_loader: Iterable of ``(inputs, labels)`` batches for validation.
        device: Device passed to ``.to()`` for the model and every batch.
        snr_threshold: Fixed singular-value cut-off shared by all matrices.
            ``None`` (default) derives a per-matrix cut-off instead (see
            ``compute_snr``). May also be assigned after construction.
    """

    def __init__(self, model, train_loader, val_loader, device='cuda', snr_threshold=None):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.snr_threshold = snr_threshold

    @staticmethod
    def _default_threshold(matrix):
        """Return a Marchenko-Pastur-style noise edge for a 2-D ``matrix``.

        For an ``m x n`` matrix of i.i.d. entries with standard deviation
        ``sigma``, the largest singular value concentrates near
        ``sigma * (sqrt(m) + sqrt(n))``. Using the matrix's own entry std as
        ``sigma`` gives a scale-aware cut-off; it is a heuristic, not a fit.
        """
        rows, cols = matrix.shape
        sigma = (matrix - matrix.mean()).pow(2).mean().sqrt()
        return sigma * (rows ** 0.5 + cols ** 0.5)

    def compute_snr(self, weight_matrix):
        """Return the signal-to-noise ratio of ``weight_matrix``'s singular values.

        Args:
            weight_matrix: Tensor with at least 2 dims. Tensors with more than
                2 dims (e.g. conv kernels) are flattened to
                ``(shape[0], -1)`` before the SVD.

        Returns:
            0-d tensor: sum of singular values above the threshold divided by
            the sum of those at or below it (plus ``1e-5`` so an empty noise
            set does not divide by zero).

        Raises:
            ValueError: if ``weight_matrix`` has fewer than 2 dimensions.
        """
        if weight_matrix.dim() < 2:
            raise ValueError(
                f'compute_snr needs a tensor with >= 2 dims, got shape {tuple(weight_matrix.shape)}'
            )
        # torch.linalg SVD routines accept only float/double/complex inputs,
        # and gradients are irrelevant for scoring.
        matrix = weight_matrix.detach().float()
        if matrix.dim() > 2:
            matrix = matrix.reshape(matrix.shape[0], -1)
        singular_values = torch.linalg.svdvals(matrix)  # U/V are never needed
        threshold = self.snr_threshold
        if threshold is None:
            threshold = self._default_threshold(matrix)
        signal = singular_values[singular_values > threshold]
        noise = singular_values[singular_values <= threshold]
        return signal.sum() / (noise.sum() + 1e-5)

    def select_layers(self, top_fraction=0.25):
        """Return names of the highest-SNR trainable weight matrices.

        Only trainable parameters whose name contains ``'weight'`` and that
        have at least 2 dims are scored; 1-D weights (e.g. normalization
        scales) have no meaningful SVD and are skipped.

        Args:
            top_fraction: Fraction of scored matrices to keep. The count is
                rounded down but never drops below one.

        Returns:
            List of parameter names, highest SNR first.
        """
        snr_values = {
            name: float(self.compute_snr(param))
            for name, param in self.model.named_parameters()
            if 'weight' in name and param.requires_grad and param.dim() >= 2
        }
        ranked = sorted(snr_values, key=snr_values.get, reverse=True)
        # Keep at least one matrix so the optimizer is never left without parameters.
        keep = max(1, int(top_fraction * len(ranked)))
        return ranked[:keep]

    def freeze_layers(self, top_layers):
        """Set ``requires_grad = False`` on every parameter not named in ``top_layers``.

        Parameters named in ``top_layers`` are left untouched.
        """
        keep = set(top_layers)  # O(1) membership per parameter
        for name, param in self.model.named_parameters():
            if name not in keep:
                param.requires_grad = False

    def train(self, num_epochs, learning_rate):
        """Fine-tune the trainable parameters with Adam and cross-entropy loss.

        Prints the mean batch loss after every epoch.
        """
        criterion = nn.CrossEntropyLoss()
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = optim.Adam(trainable, lr=learning_rate)
        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            num_batches = 0
            for inputs, labels in self.train_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                loss = criterion(self.model(inputs), labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                num_batches += 1
            # Count batches rather than len(loader): no ZeroDivisionError on an
            # empty loader, and works for iterables without __len__.
            print(f'Epoch {epoch + 1}, Loss: {running_loss / max(num_batches, 1)}')

    def validate(self):
        """Print and return top-1 accuracy, in percent, over ``val_loader``.

        Returns:
            Accuracy as a float in ``[0, 100]``, or ``float('nan')`` when the
            validation loader yields no samples.
        """
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in self.val_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                predicted = self.model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        if total == 0:
            print('Accuracy: n/a (empty validation set)')
            return float('nan')
        accuracy = 100 * correct / total
        print(f'Accuracy: {accuracy}%')
        return accuracy
def run_spectrum(model, train_loader, val_loader, num_epochs=10, learning_rate=1e-5, device='cuda'):
    """Run the full Spectrum pipeline: score, freeze, fine-tune, validate.

    Args:
        model: The ``nn.Module`` to fine-tune.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        num_epochs: Number of training epochs.
        learning_rate: Adam learning rate.
        device: Device the model and batches are moved to.

    Returns:
        The ``Spectrum`` instance used for the run; its ``model`` attribute
        holds the fine-tuned model.
    """
    spectrum = Spectrum(model, train_loader, val_loader, device=device)
    top_layers = spectrum.select_layers()
    spectrum.freeze_layers(top_layers)
    spectrum.train(num_epochs=num_epochs, learning_rate=learning_rate)
    spectrum.validate()
    return spectrum


# Usage example (kept as comments so importing this module has no side effects):
# model = ...  # Define your model
# train_loader = DataLoader(...)  # Define your training data loader
# val_loader = DataLoader(...)  # Define your validation data loader
# run_spectrum(model, train_loader, val_loader, num_epochs=10, learning_rate=1e-5)
# (end of gist)