# Source: GitHub gist philschmid/928b5d1c6b01431c69b09be31eb60172 by @philschmid, created October 8, 2021.
import torch
import os
from transformers import Wav2Vec2Processor, HubertForCTC
from transformers.pipelines.automatic_speech_recognition import ffmpeg_read
def get_file_size(file_path) -> str:
    """Return the on-disk size of *file_path* formatted in decimal megabytes.

    Args:
        file_path: Path of an existing file; passed to ``os.path.getsize``.

    Returns:
        The byte size divided by 1000 * 1000, rounded to two decimal places,
        followed by " MB" (for example ``"1.5 MB"``).

    Raises:
        OSError: Propagated from ``os.path.getsize`` when the file does not
            exist or cannot be accessed.
    """
    size_in_bytes = os.path.getsize(file_path)
    # Decimal megabytes (10**6 bytes), not mebibytes (2**20 bytes).
    return f"{round(size_in_bytes / 1000 / 1000, 2)} MB"
# Script: load a HuBERT CTC checkpoint, export it with torch.jit.trace, run one
# transcription as a smoke test, and (optionally) repeat the export after
# dynamic int8 quantization so the two file sizes can be compared.

# Set to False to skip the quantization step and only export the float model.
quantize = True

# load model
# The processor and the model are loaded from the same checkpoint id.
model_id = "facebook/hubert-large-ls960-ft"
processor = Wav2Vec2Processor.from_pretrained(model_id)
# torchscript=True configures the model for TorchScript export
# (see the Transformers TorchScript documentation).
model = HubertForCTC.from_pretrained(model_id, torchscript=True)

# prepare input / read audio file
sampling_rate = processor.feature_extractor.sampling_rate
with open("audio.wav", "rb") as f:
    # ffmpeg_read receives the raw file bytes plus the feature extractor's
    # sampling rate.
    inputs = ffmpeg_read(f.read(), sampling_rate)
input_values = processor(inputs, return_tensors="pt", sampling_rate=sampling_rate)
# torch.jit.trace takes example inputs positionally, so flatten the processor
# output into a tuple (in the order the processor returned its entries).
sample_input = tuple(input_values.values())

# trace model
traced_model = torch.jit.trace(model, sample_input)
torch.jit.save(traced_model, "traced_model.pt")
print(f'model size is: {get_file_size("traced_model.pt")}')

# test model
# Inference only: no gradients are needed for the smoke test.
with torch.no_grad():
    logits = traced_model(**input_values)[0]
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0])
print(transcription)

if quantize:
    # quantize
    # inplace=True converts `model` itself instead of a copy, so `model` must
    # not be used as the float model after this point.
    # NOTE(review): Conv1d is listed here, but dynamic quantization may not
    # provide a Conv1d replacement, in which case those layers stay in float;
    # confirm against the PyTorch quantization docs.
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear, torch.nn.Conv1d},
        dtype=torch.qint8,
        inplace=True,
    )
    quantized_traced_model = torch.jit.trace(quantized_model, sample_input)
    torch.jit.save(quantized_traced_model, "quantized_traced_model.pt")

    # test model
    with torch.no_grad():
        logits = quantized_traced_model(**input_values)[0]
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    print(f'model size is: {get_file_size("quantized_traced_model.pt")}')
    print(transcription)
# (End of gist; the GitHub sign-in prompt from the scraped page footer is not part of the script.)