Created
October 8, 2021 13:52
-
-
Save philschmid/928b5d1c6b01431c69b09be31eb60172 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import os | |
| from transformers import Wav2Vec2Processor, HubertForCTC | |
| from transformers.pipelines.automatic_speech_recognition import ffmpeg_read | |
def get_file_size(file_path: str) -> str:
    """Return the on-disk size of ``file_path`` formatted in megabytes.

    The size is reported in decimal megabytes (1 MB = 1000 * 1000 bytes),
    rounded to two decimal places and suffixed with ``" MB"``.

    Args:
        file_path: Path of an existing file.

    Returns:
        A string such as ``"1.23 MB"``.

    Raises:
        OSError: Propagated from ``os.path.getsize`` when the file does not
            exist or is inaccessible.
    """
    size = os.path.getsize(file_path)
    return f"{round(size / 1000 / 1000, 2)} MB"
# Set to False to skip the dynamic-quantization pass at the end of the script.
quantize = True

# load model
# Single source of truth for the checkpoint shared by processor and model.
model_id = "facebook/hubert-large-ls960-ft"
processor = Wav2Vec2Processor.from_pretrained(model_id)
# torchscript=True is the transformers flag intended for models that will be
# exported with torch.jit (see the transformers TorchScript guide).
model = HubertForCTC.from_pretrained(model_id, torchscript=True)
# Be explicit about inference mode before tracing so that dropout and similar
# train-only layers cannot be baked into the traced graph.
model.eval()

# prepare input / read audio file
sampling_rate = processor.feature_extractor.sampling_rate
with open("audio.wav", "rb") as f:
    # ffmpeg_read receives the raw file bytes and the target sampling rate.
    inputs = ffmpeg_read(f.read(), sampling_rate)
input_values = processor(inputs, return_tensors="pt", sampling_rate=sampling_rate)
# torch.jit.trace takes positional example inputs, so flatten the processor
# output into a tuple (in the order the processor produced the entries).
sample_input = tuple(input_values.values())

# trace model
traced_model = torch.jit.trace(model, sample_input)
torch.jit.save(traced_model, "traced_model.pt")
print(f'model size is: {get_file_size("traced_model.pt")}')

# test model
# Call the traced module positionally, exactly as it was traced; this avoids
# depending on keyword-argument name resolution in the TorchScript module.
with torch.no_grad():
    logits = traced_model(*sample_input)[0]
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0])
print(transcription)

if quantize:
    # quantize
    # inplace=True rewrites `model` itself (no second copy of the weights is
    # kept); from here on `model` and `quantized_model` are the same object.
    # NOTE(review): dynamic quantization is documented for Linear/RNN-style
    # layers; Conv1d is presumably left untouched -- confirm against the
    # PyTorch version in use.
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear, torch.nn.Conv1d},
        dtype=torch.qint8,
        inplace=True,
    )
    quantized_traced_model = torch.jit.trace(quantized_model, sample_input)
    torch.jit.save(quantized_traced_model, "quantized_traced_model.pt")

    # test model
    with torch.no_grad():
        logits = quantized_traced_model(*sample_input)[0]
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    print(f'model size is: {get_file_size("quantized_traced_model.pt")}')
    print(transcription)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment