Skip to content

Instantly share code, notes, and snippets.

@sshleifer
Last active June 5, 2020 15:41
Show Gist options
  • Save sshleifer/8860c44260ebd94dbce271d68ccebab3 to your computer and use it in GitHub Desktop.
Save sshleifer/8860c44260ebd94dbce271d68ccebab3 to your computer and use it in GitHub Desktop.
Copy a subset of BART layers into a smaller student model.
from transformers import BartConfig, BartForConditionalGeneration, BartTokenizer
from torch import nn
from typing import List
# Student stack depth -> indices of the teacher layers whose weights fill
# that stack, listed in student-layer order. All indices fall in 0..11, so the
# table presumably targets a 12-layer teacher.
layers_to_copy = {
    6: [0, 2, 4, 7, 9, 11],
    1: [11],
    3: [0, 6, 11],
    2: [0, 11],
    4: [0, 4, 8, 11],
    9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
    12: [*range(12)],  # same depth as the teacher: copy every layer
}
def init_student(student, teacher):
    """Copy every teacher weight whose state-dict key also exists in the student.

    Teacher keys the student does not have (e.g. layers beyond the student's
    depth) are ignored. Every student key must be present in the teacher.

    Args:
        student: module to initialise in place.
        teacher: module whose ``state_dict()`` supplies the weights.

    Returns:
        ``(student, info)`` where ``info`` is the result of
        ``student.load_state_dict`` (its ``missing_keys`` / ``unexpected_keys``).

    Raises:
        RuntimeError: if some student key has no counterpart in the teacher.
    """
    teacher_state_dict = teacher.state_dict()
    # strict=False tolerates the teacher's extra keys; missing keys are checked
    # explicitly below.
    info = student.load_state_dict(teacher_state_dict, strict=False)
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    # A missing key would otherwise leave that student weight at its random init.
    if info.missing_keys:
        raise RuntimeError(
            f"student keys missing from teacher state dict: {info.missing_keys}"
        )
    return student, info
def copy_layers(teacher_layers, student_layers, l2copy: List):
    """Load selected teacher layers' weights into the student layers, in order.

    Student layer ``j`` receives the weights of teacher layer ``l2copy[j]``.
    Indices are honoured in the order given and may repeat.

    Args:
        teacher_layers: indexable sequence of teacher layers (e.g. ``nn.ModuleList``).
        student_layers: ``nn.ModuleList`` of student layers, updated in place.
        l2copy: teacher layer indices, one per student layer.

    Raises:
        ValueError: if ``len(student_layers) != len(l2copy)`` or an index is
            outside ``[0, len(teacher_layers))``.
    """
    # Validate before doing any work. Explicit raises survive `python -O`.
    if len(student_layers) != len(l2copy):
        raise ValueError(
            f"student has {len(student_layers)} layers but {len(l2copy)} "
            f"teacher indices were given: {len(student_layers)} != {len(l2copy)}"
        )
    n_teacher = len(teacher_layers)
    out_of_range = [i for i in l2copy if not 0 <= i < n_teacher]
    if out_of_range:
        raise ValueError(
            f"teacher layer indices {out_of_range} out of range for {n_teacher} teacher layers"
        )
    # Index directly so order and duplicates in l2copy are respected. A
    # filter over enumerate() would silently sort and deduplicate them.
    selected = nn.ModuleList([teacher_layers[i] for i in l2copy])
    student_layers.load_state_dict(selected.state_dict())
def _teacher_layers_for(n_student, n_teacher):
    """Return the teacher-layer indices to copy into an ``n_student``-layer stack."""
    if n_student == n_teacher:
        # Same depth: copy one-to-one. This also covers teachers whose depth
        # the layers_to_copy table was not written for.
        return list(range(n_teacher))
    try:
        return layers_to_copy[n_student]
    except KeyError:
        raise KeyError(
            f"no teacher-layer mapping for a {n_student}-layer student; "
            f"supported depths: {sorted(layers_to_copy)}"
        ) from None


def make_student(teacher, student_updates):
    """Build a BART student from ``teacher``'s config and weights.

    Args:
        teacher: ``BartForConditionalGeneration`` whose config and weights seed
            the student.
        student_updates: config overrides applied on top of the teacher's
            config, e.g. ``{"encoder_layers": 6, "decoder_layers": 6}``. A layer
            count may be omitted, in which case that stack keeps the teacher's
            depth and all of its layers are copied.

    Returns:
        A new ``BartForConditionalGeneration`` initialised from the teacher.

    Raises:
        KeyError: if a student layer count differs from the teacher's and has
            no entry in ``layers_to_copy``.
    """
    kw = teacher.config.to_diff_dict()
    kw.update(student_updates)
    student_cfg = BartConfig(**kw)

    # Resolve layer maps before allocating the student model, so an unsupported
    # depth fails fast. Depths are read from the merged config, which is what
    # lets callers omit a key.
    e_layers_to_copy = _teacher_layers_for(
        student_cfg.encoder_layers, len(teacher.model.encoder.layers)
    )
    d_layers_to_copy = _teacher_layers_for(
        student_cfg.decoder_layers, len(teacher.model.decoder.layers)
    )

    # Copy weights: first every teacher tensor whose key the student shares,
    # then overwrite each layer stack with the selected teacher layers.
    student = BartForConditionalGeneration(student_cfg)
    student, _ = init_student(student, teacher)
    copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
    copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
    return student
# Example: shrink facebook/bart-large-cnn to 6 encoder and 6 decoder layers.
TEACHER_NAME = "facebook/bart-large-cnn"
teacher = BartForConditionalGeneration.from_pretrained(TEACHER_NAME)
student_updates = dict(decoder_layers=6, encoder_layers=6)
student = make_student(teacher, student_updates)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment