Created
July 25, 2024 04:41
-
-
Save ucalyptus2/1647bae342e7214ccd0babef6db8d69d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
from typing import List, Tuple, Dict | |
class Spectrum: | |
def __init__(self, model: torch.nn.Module): | |
self.model = model | |
self.layer_snrs = {} | |
def compute_snr(self, weight_matrix: torch.Tensor) -> float: | |
# Convert to numpy for easier manipulation | |
W = weight_matrix.detach().cpu().numpy() | |
# Compute SVD | |
U, s, Vt = np.linalg.svd(W, full_matrices=False) | |
# Compute matrix dimensions | |
m, n = W.shape | |
r = min(m, n) | |
# Compute bounds using Marchenko-Pastur distribution | |
q = m / n | |
sigma = np.median(s) # Using median instead of std for robustness | |
lambda_plus = sigma**2 * (1 + np.sqrt(q))**2 | |
lambda_minus = sigma**2 * (1 - np.sqrt(q))**2 | |
# Compute epsilon (noise threshold) | |
epsilon = np.sqrt(lambda_minus) | |
# Compute SNR | |
signal = np.sum(s[s > epsilon]) | |
noise = np.sum(s[s <= epsilon]) | |
snr = signal / noise if noise != 0 else float('inf') | |
# Normalize SNR by largest singular value | |
normalized_snr = snr / s[0] | |
return normalized_snr | |
def compute_layer_snrs(self): | |
for name, module in self.model.named_modules(): | |
if isinstance(module, torch.nn.Linear): | |
snr = self.compute_snr(module.weight) | |
self.layer_snrs[name] = snr | |
def select_layers_for_training(self, percentage: float = 0.25) -> List[str]: | |
if not self.layer_snrs: | |
self.compute_layer_snrs() | |
# Sort layers by SNR in descending order | |
sorted_layers = sorted(self.layer_snrs.items(), key=lambda x: x[1], reverse=True) | |
# Select top percentage of layers | |
num_layers_to_train = int(len(sorted_layers) * percentage) | |
selected_layers = [layer[0] for layer in sorted_layers[:num_layers_to_train]] | |
return selected_layers | |
def prepare_model_for_training(self, selected_layers: List[str]): | |
for name, param in self.model.named_parameters(): | |
if any(layer in name for layer in selected_layers): | |
param.requires_grad = True | |
else: | |
param.requires_grad = False | |
def apply_spectrum(model: torch.nn.Module, percentage: float = 0.25) -> Tuple[torch.nn.Module, List[str]]:
    """Freeze *model* in place except for its top-scoring linear layers.

    Args:
        model: The module whose parameters' ``requires_grad`` flags are set.
        percentage: Fraction of linear layers to keep trainable, forwarded to
            ``Spectrum.select_layers_for_training``.

    Returns:
        A ``(model, names)`` pair: the same model object that was passed in
        and the names of the layers chosen for training.
    """
    analyzer = Spectrum(model)
    chosen = analyzer.select_layers_for_training(percentage)
    analyzer.prepare_model_for_training(chosen)
    return model, chosen
# Example usage
if __name__ == "__main__":
    # A small three-layer MLP stands in for a pre-trained model.
    model = torch.nn.Sequential(
        torch.nn.Linear(100, 50),
        torch.nn.ReLU(),
        torch.nn.Linear(50, 25),
        torch.nn.ReLU(),
        torch.nn.Linear(25, 10),
    )

    # Keep the top half of the linear layers trainable and freeze the rest.
    model, selected_layers = apply_spectrum(model, percentage=0.5)
    print("Selected layers for training:", selected_layers)

    # List every parameter that still receives gradients.
    print("Trainable parameters:")
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        print(name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment