#!/usr/bin/env python3
# gist-paste url:
# gist-paste -u https://gist.github.com/jaggzh/6e3d4d97fa93aa76320ea7b389140ff0
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset  # only needed for the commented-out dataset examples below
# whdir is the output directory created by run_speech_recognition_seq2seq.py
whdir = "whisper-custom-en"
# NOTE! When loading large datasets such as Common Voice, seq2seq's HF
# Common Voice loader (at least for v11) will try to preprocess the ENTIRE
# set, even if you set your splits to small percentages.
# I modified mine to:
# if training_args.do_train:
#     raw_datasets["train"] = load_dataset(
#         data_args.dataset_name,
#         data_args.dataset_config_name,
#         #split=data_args.train_split_name,
#         # THIS LINE HERE, AND IN THE .do_eval BLOCK RIGHT BELOW THIS ONE:
#         split=f'{data_args.train_split_name}[:1%]',  # load only the first 1%
#         cache_dir=model_args.cache_dir,
#         token=model_args.token,
#         #verification_mode='all_checks',
#     )
# AND, LOWER DOWN, BEFORE prepare_dataset(), slice the dataset (or it'll
# still preprocess everything). These 4 lines, just above this function:
# if data_args.max_train_samples is not None:
#     raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
# if data_args.max_eval_samples is not None:
#     raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
# def prepare_dataset(batch):
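# Standalone sketch of the split-slicing half of that fix (an aside, not part
# of the original gist; it uses the small dummy dataset referenced below). A
# percentage range in `split` keeps load_dataset from pulling the whole split:
# from datasets import load_dataset
# ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean",
#                   split="validation[:50%]")
# print(len(ds))  # roughly half the validation examples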
# Load model and processor. For the stock model, instead use:
# processor = WhisperProcessor.from_pretrained("openai/whisper-large")
# model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
processor = WhisperProcessor.from_pretrained(whdir)
model = WhisperForConditionalGeneration.from_pretrained(whdir)
model.config.forced_decoder_ids = None
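# Note (an aside, not from the original gist): with forced_decoder_ids = None
# the model predicts the language/task tokens itself. To pin them down
# explicitly, one sketch using the processor's helper:
# model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
#     language="english", task="transcribe")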
# Load a dataset and read audio files; some alternatives, commented out:
#ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
#ds = load_dataset(
#    "mozilla-foundation/common_voice_11_0",
#    "en",
#    #split=data_args.train_split_name,
#    split='train[:15%]',  # load only the first 15%
#    #cache_dir=model_args.cache_dir,
#    token=True,
#    #verification_mode='all_checks',
#)
# example = ds[0]["audio"]
# sample = example['array']
# sr = example['sampling_rate']
import librosa
import numpy as np

# Load and resample to 16 kHz, the rate Whisper's feature extractor expects.
aa, sr = librosa.load("/tmp/w.wav", sr=16000)
sample = aa.astype(np.float64)
# import ipdb; ipdb.set_trace(context=16); pass
# Extract log-mel input features for the model.
input_features = processor(sample, sampling_rate=sr, return_tensors="pt").input_features
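# Optional sketch (an addition, not in the original gist): move inference to a
# GPU when one is available.
# import torch
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)
# input_features = input_features.to(device)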
# Generate token ids.
predicted_ids = model.generate(input_features)
# Decode the token ids to text, first keeping the special tokens
# (language/task/timestamp markers), then without them.
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
print(f"Transcription (with special tokens): {transcription}")
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(f"Transcription: {transcription}")