@remi-or
Created January 17, 2022 18:34
from torch.nn import Module

from transformers.models.roberta.modeling_roberta import RobertaEncoder, RobertaModel


def distill_roberta_weights(
    teacher: Module,
    student: Module,
) -> None:
    """
    Recursively copies the weights of the teacher to the student.
    This function is meant to be called on a RobertaFor... model first, and is then
    called recursively on every child of that model.
    The encoder is the only part that is not copied in full: only every other layer is kept.
    """
    # If the module is a full RoBERTa model or a RobertaFor... task model, recurse into its children
    if isinstance(teacher, RobertaModel) or type(teacher).__name__.startswith('RobertaFor'):
        for teacher_part, student_part in zip(teacher.children(), student.children()):
            distill_roberta_weights(teacher_part, student_part)
    # If the module is the encoder, copy every other teacher layer into the student
    elif isinstance(teacher, RobertaEncoder):
        teacher_encoding_layers = list(next(teacher.children()))
        student_encoding_layers = list(next(student.children()))
        for i in range(len(student_encoding_layers)):
            student_encoding_layers[i].load_state_dict(teacher_encoding_layers[2 * i].state_dict())
    # Otherwise the module is an embedding layer, a head, etc.: copy its state_dict as-is
    else:
        student.load_state_dict(teacher.state_dict())
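
A minimal usage sketch (not part of the original gist), assuming a `roberta-base` teacher and a student configured with half as many hidden layers; the variable names here are illustrative only:

from transformers import RobertaConfig, RobertaForMaskedLM

# Load the full 12-layer teacher.
teacher = RobertaForMaskedLM.from_pretrained("roberta-base")

# Build an untrained student with 6 encoder layers but the same architecture otherwise.
student_config = RobertaConfig.from_pretrained("roberta-base", num_hidden_layers=6)
student = RobertaForMaskedLM(student_config)

# Copy the teacher's weights into the student, keeping every other encoder layer.
distill_roberta_weights(teacher, student)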