Skip to content

Instantly share code, notes, and snippets.

@ucalyptus2
Created July 25, 2024 04:41
Show Gist options
  • Save ucalyptus2/1647bae342e7214ccd0babef6db8d69d to your computer and use it in GitHub Desktop.
Save ucalyptus2/1647bae342e7214ccd0babef6db8d69d to your computer and use it in GitHub Desktop.
import torch
import numpy as np
from typing import List, Tuple, Dict
class Spectrum:
    """Rank a model's ``torch.nn.Linear`` layers by signal-to-noise ratio.

    Each Linear weight matrix is scored from its singular values: values above
    a Marchenko-Pastur-style noise threshold count as signal, the rest as
    noise.  The highest-scoring fraction of layers can then be left trainable
    while every other parameter of the model is frozen.
    """

    def __init__(self, model: torch.nn.Module):
        """Store ``model`` and start with an empty SNR cache."""
        self.model = model
        # Maps module name (as yielded by ``named_modules``) -> normalized SNR.
        self.layer_snrs = {}

    def compute_snr(self, weight_matrix: torch.Tensor) -> float:
        """Return the normalized signal-to-noise ratio of a 2-D weight matrix.

        The noise threshold is the upper Marchenko-Pastur edge
        ``sigma * (1 + sqrt(q))``, where ``sigma`` is the median singular
        value (a robust scale estimate) and ``q = min(m, n) / max(m, n)``.
        Singular values above the threshold are summed as signal, the others
        as noise, and ``signal / noise`` is divided by the largest singular
        value.

        Args:
            weight_matrix: A 2-D tensor of any floating dtype, on any device.

        Returns:
            The normalized SNR as a Python ``float``.  ``inf`` when no
            singular value falls in the noise band with a non-zero value;
            ``0.0`` for an empty or all-zero matrix.

        Raises:
            ValueError: If ``weight_matrix`` is not 2-D.
        """
        if weight_matrix.dim() != 2:
            raise ValueError(
                f"expected a 2-D weight matrix, got shape {tuple(weight_matrix.shape)}"
            )

        tensor = weight_matrix.detach().cpu()
        # NumPy has no bfloat16 and its LAPACK bindings reject float16, so
        # half-precision weights are upcast before leaving torch.
        if tensor.dtype not in (torch.float32, torch.float64):
            tensor = tensor.to(torch.float32)
        W = tensor.numpy()

        m, n = W.shape
        if m == 0 or n == 0:
            return 0.0

        # Only the singular values are needed; skipping U and Vt saves the
        # bulk of the decomposition cost.  Values come back in descending order.
        s = np.linalg.svd(W, compute_uv=False)
        largest = float(s[0])
        if largest == 0.0:
            # All-zero matrix: there is nothing to call signal.
            return 0.0

        # W and W.T share their singular values, so the aspect ratio is
        # folded into (0, 1] to keep the score independent of orientation.
        q = min(m, n) / max(m, n)
        sigma = float(np.median(s))

        # Upper edge of the Marchenko-Pastur bulk.  The lower edge collapses
        # to 0 for square matrices, which would leave nothing classified as
        # noise and make every square layer score ``inf``.
        epsilon = sigma * (1.0 + q ** 0.5)

        signal = float(np.sum(s[s > epsilon]))
        noise = float(np.sum(s[s <= epsilon]))
        if noise == 0.0:
            return float("inf")

        # Dividing by the top singular value makes layers of different
        # overall scale comparable.
        return (signal / noise) / largest

    def compute_layer_snrs(self):
        """Score every ``torch.nn.Linear`` in the model into ``layer_snrs``."""
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                self.layer_snrs[name] = self.compute_snr(module.weight)

    def select_layers_for_training(self, percentage: float = 0.25) -> List[str]:
        """Return the names of the highest-SNR layers.

        Args:
            percentage: Fraction of scored layers to keep, in ``[0, 1]``.

        Returns:
            Layer names in descending SNR order.  The count is
            ``int(num_layers * percentage)``, i.e. truncated, so a small
            fraction of a few layers can legitimately select nothing.

        Raises:
            ValueError: If ``percentage`` is outside ``[0, 1]`` (a negative
                value would otherwise be read as a negative slice bound and
                silently select "all but the last k" layers).
        """
        if not 0.0 <= percentage <= 1.0:
            raise ValueError(f"percentage must be within [0, 1], got {percentage!r}")

        # SNRs are computed lazily and cached on first use.
        if not self.layer_snrs:
            self.compute_layer_snrs()

        ranked = sorted(self.layer_snrs.items(), key=lambda item: item[1], reverse=True)
        num_layers_to_train = int(len(ranked) * percentage)
        return [name for name, _ in ranked[:num_layers_to_train]]

    def prepare_model_for_training(self, selected_layers: List[str]):
        """Set ``requires_grad`` so only parameters of selected layers train.

        A parameter is trainable when its own name, or any dotted prefix of
        it, is listed in ``selected_layers``.  Matching on whole dotted
        components keeps layer ``"1"`` from also unfreezing ``"10.weight"``.
        The empty name (the root module) selects every parameter.

        Args:
            selected_layers: Module names as produced by ``named_modules``.
        """
        selected = set(selected_layers)
        train_everything = "" in selected
        for name, param in self.model.named_parameters():
            param.requires_grad = train_everything or self._is_selected(name, selected)

    @staticmethod
    def _is_selected(param_name: str, selected: set) -> bool:
        """Return True if ``param_name`` or one of its dotted prefixes is selected."""
        prefix = param_name
        while True:
            if prefix in selected:
                return True
            prefix, dot, _ = prefix.rpartition(".")
            if not dot:
                return False
def apply_spectrum(model: torch.nn.Module, percentage: float = 0.25) -> Tuple[torch.nn.Module, List[str]]:
    """Select layers with ``Spectrum`` and freeze the rest of ``model``.

    Builds a ``Spectrum`` around ``model``, asks it which layers to train for
    the given ``percentage``, and passes that selection to
    ``prepare_model_for_training``.

    Args:
        model: The module to modify; it is changed in place.
        percentage: Forwarded unchanged to ``select_layers_for_training``.

    Returns:
        A ``(model, selected_layers)`` pair: the same ``model`` object that was
        passed in, and the list of layer names that were selected.
    """
    selector = Spectrum(model)
    layers_to_train = selector.select_layers_for_training(percentage)
    selector.prepare_model_for_training(layers_to_train)
    return model, layers_to_train
# Example usage
if __name__ == "__main__":
    # A small three-layer MLP stands in for a pre-trained model
    model = torch.nn.Sequential(
        torch.nn.Linear(100, 50),
        torch.nn.ReLU(),
        torch.nn.Linear(50, 25),
        torch.nn.ReLU(),
        torch.nn.Linear(25, 10),
    )

    # Run the Spectrum selection with percentage=0.5 and report the result
    model, selected_layers = apply_spectrum(model, percentage=0.5)
    print("Selected layers for training:", selected_layers)

    # List every parameter that is still trainable after the selection
    print("Trainable parameters:")
    trainable_names = [
        param_name
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    for param_name in trainable_names:
        print(param_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment