from transformers.models.roberta.modeling_roberta import RobertaEncoder, RobertaModel
from torch.nn import Module


def distill_roberta_weights(
    teacher: Module,
    student: Module,
) -> None:
    """
    Recursively copies the weights of the teacher into the student.
    This function is meant to be called first on a RobertaFor... model, and it then calls itself recursively on every child of that model.
    The only part that is not fully copied is the encoder, of which only every other layer is copied.
    """
    # If the module is an entire RoBERTa model or a RobertaFor..., unpack it and recurse over its children
    if isinstance(teacher, RobertaModel) or type(teacher).__name__.startswith('RobertaFor'):
        for teacher_part, student_part in zip(teacher.children(), student.children()):
            distill_roberta_weights(teacher_part, student_part)
    # Else if the module is the encoder, copy one out of every two teacher layers into the student
    elif isinstance(teacher, RobertaEncoder):
        teacher_encoding_layers = [layer for layer in next(teacher.children())]
        student_encoding_layers = [layer for layer in next(student.children())]
        for i in range(len(student_encoding_layers)):
            student_encoding_layers[i].load_state_dict(teacher_encoding_layers[2 * i].state_dict())
    # Else the module is a head, an embedding, or a pooler: copy its full state_dict
    else:
        student.load_state_dict(teacher.state_dict())