zilunpeng / calc_kd_loss.py
Created March 23, 2021 21:18
Calculate the knowledge distillation loss. Code below is part of the knowledge distillation toolkit (https://git.io/JYePf).
# KL divergence between the student's log-probabilities and the teacher's
# probabilities, scaled by the squared softmax temperature (standard KD scaling).
kd_loss = torch.nn.functional.kl_div(student_log_prob, teacher_prob, reduction='batchmean') * (self.temperature**2)
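For context, a minimal self-contained sketch (not from the toolkit) of how the two inputs above are typically produced: both sets of logits are softened by the same temperature before the KL divergence is taken. The tensor names and shapes here are hypothetical.

import torch
import torch.nn.functional as F

temperature = 2.0
student_logits = torch.randn(8, 32)  # (batch, num_classes), hypothetical
teacher_logits = torch.randn(8, 32)

# Soften both distributions with the same temperature.
student_log_prob = F.log_softmax(student_logits / temperature, dim=-1)
teacher_prob = F.softmax(teacher_logits / temperature, dim=-1)

kd_loss = F.kl_div(student_log_prob, teacher_prob, reduction='batchmean') * (temperature ** 2)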
zilunpeng / calc_feat_pen.py
Created March 23, 2021 21:20
Calculate the feature penalty. Code below is part of the knowledge distillation toolkit (https://git.io/JYePf).
# L2 penalty on intermediate features: the mean of the squared activations.
features_pen = features.float().pow(2).mean()
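For concreteness, a tiny runnable illustration of the same penalty; the feature shape below is a hypothetical stand-in for the wav2vec 2.0 feature-extractor output.

import torch

features = torch.randn(4, 100, 512)            # (batch, time, dim), hypothetical
features_pen = features.float().pow(2).mean()  # mean squared activation, a scalar
print(features_pen)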
zilunpeng / set_kd_opt_scheduler.py
Created March 23, 2021 21:22
Set optimizer and learning rate scheduler. Code below is part of the knowledge distillation toolkit (https://git.io/JYePf).
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

def lr_lambda(current_epoch):
    # Linear warm-up toward the base learning rate for the first
    # num_lr_warm_up_epoch epochs...
    if current_epoch < self.num_lr_warm_up_epoch:
        return float(current_epoch + 1) / float(max(1, self.num_lr_warm_up_epoch))
    # ...then linear decay to zero at max_epoch.
    return max(0.0, float(self.max_epoch - current_epoch) / float(max(1, self.max_epoch - self.num_lr_warm_up_epoch)))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
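A self-contained sketch of the same warm-up-then-decay schedule on a dummy model; the epoch counts and the once-per-epoch placement of scheduler.step() are assumptions, not taken from the toolkit.

import torch

model = torch.nn.Linear(10, 2)  # dummy model, hypothetical
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_lr_warm_up_epoch, max_epoch = 5, 50  # hypothetical schedule lengths

def lr_lambda(current_epoch):
    if current_epoch < num_lr_warm_up_epoch:
        return float(current_epoch + 1) / float(max(1, num_lr_warm_up_epoch))
    return max(0.0, float(max_epoch - current_epoch) / float(max(1, max_epoch - num_lr_warm_up_epoch)))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for epoch in range(max_epoch):
    optimizer.step()   # stand-in for a real training epoch
    scheduler.step()   # multiplier ramps up for 5 epochs, then decays linearly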
zilunpeng / init_student_wav2vec2.py
Created March 23, 2021 21:25
Initialize the student model by taking alternating layers. Code below is part of student_wav2vec2.py (https://git.io/JYeXX).
step = num_trans_layer_student_init_model // num_trans_layer_student_model
student_init_model_selected_transformer_layers = [i for i in range(0, num_trans_layer_student_init_model, step)]
student_model_trans_layer_prefix = "encoder.layers."
student_model_transformer_layers = [i for i in range(num_trans_layer_student_model)]
for student_layer_i, init_layer_i in zip(student_model_transformer_layers, student_init_model_selected_transformer_layers):
    for transformer_part in transformer_parts:
        # Copy each parameter of the selected init-model layer into the
        # corresponding student layer.
        layer_name = student_model_trans_layer_prefix + str(student_layer_i) + transformer_part
        param = student_init_model_state[student_init_model_trans_layer_prefix + str(init_layer_i) + transformer_part]
        student_model_state[layer_name].copy_(param)
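To make the selection arithmetic concrete, a small sketch with hypothetical layer counts: a 12-layer init model and a 4-layer student pick every third transformer layer.

num_trans_layer_student_init_model = 12  # hypothetical
num_trans_layer_student_model = 4        # hypothetical

step = num_trans_layer_student_init_model // num_trans_layer_student_model
selected = list(range(0, num_trans_layer_student_init_model, step))
print(selected)  # [0, 3, 6, 9] -> student layers 0..3 copy init layers 0, 3, 6, 9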
zilunpeng / prepare_quantized_wav2vec2_for_inf.py
Created March 23, 2021 21:28
Prepare wav2vec 2.0 for inference after quantization. Code is part of wav2vec2.py (https://git.io/JYe1Y).
def prepare_for_inference_after_quantization(self):
    # Dynamic quantization replaces nn.Linear modules with quantized ones whose
    # parameters are exposed through weight()/bias() methods. Cache plain-tensor
    # copies on each attention module so the forward pass can use them directly.
    dequantizer = torch.nn.quantized.DeQuantize()
    for trans_layer in self.encoder.layers:
        trans_layer.self_attn.q_proj_bias = trans_layer.self_attn.q_proj.bias()
        trans_layer.self_attn.k_proj_bias = trans_layer.self_attn.k_proj.bias()
        trans_layer.self_attn.v_proj_bias = trans_layer.self_attn.v_proj.bias()
        trans_layer.self_attn.in_proj_bias = torch.cat((trans_layer.self_attn.q_proj_bias,
                                                        trans_layer.self_attn.k_proj_bias,
                                                        trans_layer.self_attn.v_proj_bias))
        trans_layer.self_attn.out_proj_bias = trans_layer.self_attn.out_proj.bias()
        trans_layer.self_attn.out_proj_weight = dequantizer(trans_layer.self_attn.out_proj.weight())
        trans_layer.self_attn.q_proj_weight = dequantizer(trans_layer.self_attn.q_proj.weight())
zilunpeng / quantize_wav2vec2.py
Created March 23, 2021 21:31
Quantize wav2vec 2.0. Code below is part of quantized wav2vec 2.0 demo notebook (https://git.io/JYe1o).
# Dynamically quantize every nn.Linear in the model to int8 weights.
quantized_model = torch.quantization.quantize_dynamic(pt_wav2vec2, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_model.prepare_for_inference_after_quantization()
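A self-contained sketch of what quantize_dynamic does, run on a toy model rather than the wav2vec 2.0 checkpoint: each nn.Linear is swapped for a dynamically quantized equivalent.

import torch

toy_model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 32),
)
quantized_toy = torch.quantization.quantize_dynamic(toy_model, {torch.nn.Linear}, dtype=torch.qint8)
print(quantized_toy)  # Linear layers are now dynamically quantized modules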
zilunpeng / call_wav2vec2_decoder.py
Created March 23, 2021 22:03
Get the decoder output. Code below is part of the wav2vec 2.0 inference notebook (https://git.io/JYeKX).
# emissions: frame-level label scores from the acoustic model, shaped
# (batch, time, vocab); decode() returns hypotheses for each utterance.
decoder_out = decoder.decode(emissions)
zilunpeng / import_wav2letter.py
Created March 23, 2021 22:08
Import from wav2letter. Code below is part of utils.py (https://git.io/JYeHy).
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
zilunpeng / call_viterbi_decode.py
Created March 23, 2021 22:13
Make calls to the C++ method for Viterbi decoding. Code below is part of utils.py (https://git.io/JYeHy).
def decode(self, emissions):
    B, T, N = emissions.size()  # batch, time steps, vocabulary size
    hypos = list()
    if self.asg_transitions is None:
        transitions = torch.FloatTensor(N, N).zero_()
    else:
        transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
    viterbi_path = torch.IntTensor(B, T)
    workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
    # Hand raw data pointers to wav2letter's C++ Viterbi routine; the best
    # label sequence for each utterance is written into viterbi_path in place.
    CpuViterbiPath.compute(B, T, N, get_data_ptr_as_bytes(emissions), get_data_ptr_as_bytes(transitions),
                           get_data_ptr_as_bytes(viterbi_path), get_data_ptr_as_bytes(workspace))
zilunpeng / import_init_ray.py
Created March 23, 2021 22:16
Import and initialize Ray. Code below is part of the distributed inference notebook (https://git.io/JYeQQ).
import ray
ray.init()
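A minimal sketch (not from the notebook) of the Ray pattern that distributed inference builds on: the workload is split into shards and each shard is handled by a remote task. The transcribe function and file_shards list are hypothetical names.

import ray

ray.init()

@ray.remote
def transcribe(file_shard):
    # A real task would load the model and run wav2vec 2.0 inference here.
    return [f"transcript for {path}" for path in file_shard]

file_shards = [["a.wav", "b.wav"], ["c.wav", "d.wav"]]  # hypothetical shards
results = ray.get([transcribe.remote(shard) for shard in file_shards])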