Implementation: Using Low-Rank Adaptation (LoRA) fine-tuning for LLMs, as described in https://arxiv.org/pdf/2106.09685
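In brief, LoRA keeps each pretrained weight matrix frozen and learns a small low-rank update alongside it: a wrapped linear layer computes h = W·x + (alpha / rank) · (x · A · B), where A has shape (in_dim, rank), B has shape (rank, out_dim), and only A and B are trained. The snippet below applies this to the attention projections of a DistilBERT classifier.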
# Imports
import torch
from functools import partial
from transformers import AutoModelForSequenceClassification
# Classes
# Define the LoRA layer
class LoRA(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha) -> None:
        """
        Args:
            in_dim: int
                Input dimension of the LoRA layer
            out_dim: int
                Output dimension of the LoRA layer
            rank: int
                Rank of the LoRA layer
            alpha: int
                Scaling hyperparameter that controls how strongly the
                low-rank ("new") update is weighted
        """
        super().__init__()
        self.A = torch.nn.Parameter(
            torch.randn(in_dim, rank)
        )  # A.shape => (in_dim, rank)
        self.B = torch.nn.Parameter(
            torch.zeros(rank, out_dim)
        )  # B.shape => (rank, out_dim)
        self.alpha = alpha
        self.rank = rank

    def forward(self, x):
        """
        Forward propagation of the LoRA layer
        """
        return (self.alpha / self.rank) * (x @ self.A @ self.B)
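# Quick sanity check (illustrative; the dimensions assume DistilBERT's hidden
# size of 768 and the rank/alpha used further below): because B starts at zero,
# a freshly initialised LoRA layer is a zero function with the expected shape.
_lora = LoRA(in_dim=768, out_dim=768, rank=16, alpha=16)
_x = torch.randn(4, 768)
assert _lora(_x).shape == (4, 768)
assert torch.allclose(_lora(_x), torch.zeros(4, 768))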
# Define the LinearLoRA layer
class LinearLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha) -> None:
        """
        Args:
            linear: torch.nn.Linear
                Linear layer to which the LoRA layer is to be added
            rank: int
                Rank of the LoRA layer
            alpha: int
                Scaling hyperparameter that controls how strongly the
                low-rank ("new") update is weighted
        """
        super().__init__()
        self.linear = linear
        self.lora = LoRA(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        """
        Forward propagation of the LinearLoRA layer:
        the frozen linear output plus the low-rank LoRA update
        """
        return self.linear(x) + self.lora(x)
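# Likewise, wrapping a plain Linear layer with LinearLoRA leaves its initial
# output unchanged, since the LoRA branch contributes zero until A and B are
# trained. The layer below is a throwaway example, not part of the model built later.
_base = torch.nn.Linear(768, 768)
_wrapped = LinearLoRA(_base, rank=16, alpha=16)
_x = torch.randn(2, 768)
assert torch.allclose(_wrapped(_x), _base(_x))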
###################
model_uri = "distilbert/distilbert-base-uncased"
num_classes = 2
lora_rank = 16
lora_alpha = 16

# Initialise the model
model = AutoModelForSequenceClassification.from_pretrained(
    model_uri, num_labels=num_classes
)

# Freeze all the layers
for param in model.parameters():
    param.requires_grad = False
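# At this point nothing is trainable; the LoRA adapters added below are the
# only parameters that will receive gradients.
assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0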
""" | |
Applies the LinearLoRA layer to certain Linear layers | |
within the network | |
Only adapt the linear layers in the Attention block as per the paper, | |
keep layers in the MLP frozen | |
""" | |
linear_lora = partial(LinearLoRA, rank=lora_rank, alpha=lora_alpha) | |
# Replace only the Attention layers within the TransformerBlock with LinearLoRA | |
# As specified in the paper | |
for block in model.distilbert.transformer.layer: | |
## Transformer Block: Multi-head Self-Attention block | |
block.attention.q_lin = linear_lora(block.attention.q_lin) | |
block.attention.k_lin = linear_lora(block.attention.k_lin) | |
block.attention.v_lin = linear_lora(block.attention.v_lin) | |
block.attention.out_lin = linear_lora(block.attention.out_lin) |
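# Rough check of the outcome: only the LoRA matrices A and B require gradients
# (the randomly initialised classification head also stays frozen, since the
# freeze above covered every parameter). For distilbert-base-uncased
# (6 layers, hidden size 768, 4 adapted projections per layer, rank 16) this
# comes to 6 * 4 * 2 * 768 * 16 = 589,824 trainable parameters out of roughly
# 67M in total; the exact totals depend on the model config and number of labels.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")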