
@danielhavir
Created November 25, 2022 04:25
import torch
import torch.nn as nn
import transformers


class LoRAConv1DWrapper(nn.Module):
    """Simple wrapper class that implements LoRA: Low-Rank Adaptation of Large Language Models.
    Arxiv link: https://arxiv.org/abs/2106.09685"""

    def __init__(self, conv1dmodule: transformers.pytorch_utils.Conv1D, lora_rank: int):
        super().__init__()
        self.base_module = conv1dmodule
        # transformers.Conv1D stores its weight as (in_features, out_features).
        d1, d2 = self.base_module.weight.size()
        # Low-rank factors: A projects the input into the rank-r subspace, B projects back out.
        self.A = nn.Parameter(
            torch.empty(
                d1, lora_rank, dtype=self.base_module.weight.dtype, device=self.base_module.weight.device
            )
        )
        self.B = nn.Parameter(
            torch.empty(
                d2, lora_rank, dtype=self.base_module.weight.dtype, device=self.base_module.weight.device
            )
        )
        # A is initialized randomly and B to zeros, so the low-rank update A @ B.T starts at zero
        # and the wrapped module initially behaves exactly like the base module.
        nn.init.kaiming_normal_(self.A)
        nn.init.zeros_(self.B)

    def forward(self, x):
        bs, seq_len, fs = x.size()
        # Flatten the batch and sequence dimensions for the matrix multiplications.
        x = x.view(-1, fs)
        W = self.base_module.weight
        bias = self.base_module.bias
        # Base projection: x @ W.
        W_out = torch.matmul(x, W)
        # Low-rank update: (x @ A) @ B.T, i.e. the trainable delta-W is A @ B.T.
        A_out = torch.matmul(x, self.A)
        B_out = torch.matmul(A_out, self.B.T)
        out = W_out + B_out
        out = out + bias
        return out.view(bs, seq_len, out.size(-1))
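
A minimal usage sketch, not part of the original gist: it wraps the c_attn projection of the first GPT-2 block with LoRAConv1DWrapper. The model name "gpt2", the rank of 8, and the choice to adapt only c_attn are illustrative assumptions; any transformers Conv1D module can be wrapped the same way.

import torch
import transformers

# Load a pretrained GPT-2 and freeze all of its weights; only the LoRA factors A and B will train.
model = transformers.GPT2Model.from_pretrained("gpt2")
for p in model.parameters():
    p.requires_grad = False

# Swap the attention input projection of the first block for the LoRA-wrapped version.
block = model.h[0]
block.attn.c_attn = LoRAConv1DWrapper(block.attn.c_attn, lora_rank=8)

# The forward pass works exactly as before; gradients now flow only into A and B.
input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
hidden = model(input_ids=input_ids).last_hidden_state
print(hidden.shape)  # expected: torch.Size([1, 16, 768])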