Implementation: Low-rank Adaptation (LoRA) fine-tuning for LLMs, as described in https://arxiv.org/pdf/2106.09685
# Imports
import torch
from functools import partial
from transformers import AutoModelForSequenceClassification
# Classes
# Define the LoRA layer
class LoRA(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha) -> None:
        """
        Args:
            in_dim: int
                Input dimension of the LoRA layer
            out_dim: int
                Output dimension of the LoRA layer
            rank: int
                Rank of the low-rank decomposition
            alpha: int
                Scaling hyperparameter controlling how strongly the "new"
                knowledge learned by the LoRA matrices is weighted
        """
        super().__init__()
        self.A = torch.nn.Parameter(
            torch.randn(in_dim, rank)
        )  # A.shape => (in_dim, rank)
        self.B = torch.nn.Parameter(
            torch.zeros(rank, out_dim)
        )  # B.shape => (rank, out_dim)
        self.alpha = alpha
        self.rank = rank

    def forward(self, x):
        """
        Forward propagation of the LoRA layer:
        the scaled low-rank update (alpha / rank) * x @ A @ B
        """
        return (self.alpha / self.rank) * (x @ self.A @ self.B)
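
# Sanity check (a minimal sketch with assumed dimensions, not part of the
# paper's recipe): because B is initialised to zeros, a fresh LoRA layer
# outputs zeros, so attaching it to a frozen layer leaves the model's initial
# behaviour unchanged.
_example_lora = LoRA(in_dim=768, out_dim=768, rank=16, alpha=16)
_example_input = torch.randn(4, 768)
assert torch.allclose(_example_lora(_example_input), torch.zeros(4, 768))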
# Define the LinearLoRA layer: wraps a frozen Linear layer and adds a LoRA path
class LinearLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha) -> None:
        """
        Args:
            linear: torch.nn.Linear
                Linear layer to which the LoRA layer is to be added
            rank: int
                Rank of the low-rank decomposition
            alpha: int
                Scaling hyperparameter controlling how strongly the "new"
                knowledge learned by the LoRA matrices is weighted
        """
        super().__init__()
        self.linear = linear
        self.lora = LoRA(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        """
        Forward propagation of the LinearLoRA layer:
        the frozen linear projection plus the trainable low-rank update
        """
        return self.linear(x) + self.lora(x)
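
# Usage sketch (dimensions assumed for the example): wrapping a plain Linear
# layer with LinearLoRA preserves its input/output shapes while adding the
# trainable low-rank path on top of the frozen projection.
_example_linear = torch.nn.Linear(768, 768)
_example_wrapped = LinearLoRA(_example_linear, rank=8, alpha=8)
assert _example_wrapped(torch.randn(2, 768)).shape == (2, 768)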
###################
model_uri = "distilbert/distilbert-base-uncased"
num_classes = 2
lora_rank = 16
lora_alpha = 16
# Initialise the model
model = AutoModelForSequenceClassification.from_pretrained(
    model_uri, num_labels=num_classes
)

# Freeze all the layers
for param in model.parameters():
    param.requires_grad = False

"""
Apply the LinearLoRA layer to selected Linear layers within the network.
As per the paper, only adapt the linear projections in the attention block;
keep the layers in the MLP frozen.
"""
linear_lora = partial(LinearLoRA, rank=lora_rank, alpha=lora_alpha)
# Replace only the Attention layers within the TransformerBlock with LinearLoRA,
# as specified in the paper
for block in model.distilbert.transformer.layer:
    ## Transformer Block: Multi-head Self-Attention block
    block.attention.q_lin = linear_lora(block.attention.q_lin)
    block.attention.k_lin = linear_lora(block.attention.k_lin)
    block.attention.v_lin = linear_lora(block.attention.v_lin)
    block.attention.out_lin = linear_lora(block.attention.out_lin)
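
# Parameter count (a minimal check): after freezing the backbone and swapping
# in the LinearLoRA wrappers, only the LoRA A and B matrices remain trainable;
# the classification head stays frozen as written above.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")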