@zilunpeng
Created March 23, 2021 21:25
Initialize the student model by taking alternating layers. Code below is part of student_wav2vec2.py (https://git.io/JYeXX)
# Pick every `step`-th transformer layer of the initialization model.
# transformer_parts, student_init_model_trans_layer_prefix and both state dicts
# are defined elsewhere in student_wav2vec2.py.
step = num_trans_layer_student_init_model // num_trans_layer_student_model
student_init_model_selected_transformer_layers = list(range(0, num_trans_layer_student_init_model, step))
student_model_trans_layer_prefix = "encoder.layers."
student_model_transformer_layers = list(range(num_trans_layer_student_model))
for student_layer_i, init_layer_i in zip(student_model_transformer_layers, student_init_model_selected_transformer_layers):
    for transformer_part in transformer_parts:
        layer_name = student_model_trans_layer_prefix + str(student_layer_i) + transformer_part
        param = student_init_model_state[student_init_model_trans_layer_prefix + str(init_layer_i) + transformer_part]
        # Copy the selected layer's weights into the student in place.
        student_model_state[layer_name].copy_(param)
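
To make the layer selection easier to check by hand, here is a minimal pure-Python sketch of the index arithmetic used above. It is not part of student_wav2vec2.py: the helper names select_init_layers and build_key_mapping, the default prefixes, and the example part name ".fc1.weight" are all hypothetical and chosen only for illustration. The explicit truncation mirrors what zip does in the original loop when the layer counts do not divide evenly.

```python
def select_init_layers(num_init_layers, num_student_layers):
    """Return, for each student layer, the init-model layer index to copy from."""
    if not 0 < num_student_layers <= num_init_layers:
        raise ValueError("student needs between 1 and num_init_layers layers")
    step = num_init_layers // num_student_layers
    # range() may yield extra indices when the division is not exact;
    # slicing keeps one source layer per student layer, as zip() would.
    return list(range(0, num_init_layers, step))[:num_student_layers]


def build_key_mapping(num_init_layers, num_student_layers, parts,
                      student_prefix="encoder.layers.",
                      init_prefix="encoder.layers."):
    """Map each student state-dict key to the init-model key it is copied from.

    The prefixes and part names are assumptions for this sketch.
    """
    mapping = {}
    selected = select_init_layers(num_init_layers, num_student_layers)
    for student_i, init_i in enumerate(selected):
        for part in parts:
            student_key = f"{student_prefix}{student_i}{part}"
            mapping[student_key] = f"{init_prefix}{init_i}{part}"
    return mapping


if __name__ == "__main__":
    print(select_init_layers(12, 6))  # [0, 2, 4, 6, 8, 10]
    print(select_init_layers(12, 4))  # [0, 3, 6, 9]
    print(build_key_mapping(12, 6, [".fc1.weight"]))
```

With 12 initialization layers and 6 student layers the step is 2, which is the "alternating layers" case from the description; other ratios simply use a larger stride.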