@thekaranacharya
Last active July 3, 2024 22:54
Implementation: simple fine-tuning for LLMs by freezing every layer except the final two Linear layers (pre_classifier and classifier)
# Imports
from transformers import AutoModelForSequenceClassification
###################
model_uri = "distilbert/distilbert-base-uncased"
num_classes = 2
# Initialise the model
model = AutoModelForSequenceClassification.from_pretrained(
    model_uri, num_labels=num_classes
)
# Freeze all the layers
for param in model.parameters():
    param.requires_grad = False
# Unfreeze pre-classifier (penultimate layer)
for param in model.pre_classifier.parameters():
    param.requires_grad = True
# Unfreeze classifier (final layer)
for param in model.classifier.parameters():
    param.requires_grad = True
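A self-contained sketch of the same freezing pattern follows. It does not download DistilBERT. Instead it uses a hypothetical toy model, ToyClassifier, with a frozen "body" and the same two head attributes the gist unfreezes (pre_classifier and classifier). It then checks two properties: how many parameters remain trainable, and that one optimizer step updates only the head. The helper names freeze_all_but_head and count_params are illustrative assumptions, not part of transformers. Passing only trainable parameters to the optimizer is a common convention rather than a requirement, because PyTorch optimizers skip parameters whose .grad is None.

```python
import torch
from torch import nn


# Hypothetical stand-in for DistilBertForSequenceClassification:
# a "body" followed by the same two head layers the gist unfreezes.
class ToyClassifier(nn.Module):
    def __init__(self, dim: int = 16, num_labels: int = 2):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.pre_classifier = nn.Linear(dim, dim)
        self.classifier = nn.Linear(dim, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.body(x)
        h = torch.relu(self.pre_classifier(h))
        return self.classifier(h)


def freeze_all_but_head(model: nn.Module) -> None:
    # Freeze everything, then re-enable gradients for the two head layers.
    for param in model.parameters():
        param.requires_grad = False
    for layer in (model.pre_classifier, model.classifier):
        for param in layer.parameters():
            param.requires_grad = True


def count_params(model: nn.Module) -> tuple[int, int]:
    # Returns (trainable, total) parameter counts.
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable, total


torch.manual_seed(0)
model = ToyClassifier()
freeze_all_but_head(model)
trainable, total = count_params(model)
print(f"trainable: {trainable} / {total}")  # trainable: 306 / 850

# Give the optimizer only the parameters that still require gradients.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)

# Snapshot weights so we can confirm what one training step changes.
body_before = [p.detach().clone() for p in model.body.parameters()]
head_before = model.classifier.weight.detach().clone()

x = torch.randn(8, 16)
y = torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

body_unchanged = all(
    torch.equal(a, b) for a, b in zip(body_before, model.body.parameters())
)
head_changed = not torch.equal(head_before, model.classifier.weight)
print(body_unchanged, head_changed)  # True True
```

The same count_params helper can be called on the DistilBERT model above after the freezing loops to confirm that only the two head layers remain trainable.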