Last active
March 13, 2024 12:07
-
-
Save litagin02/0fa2b6d47d5376eae52053cf7708798a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path | |
from typing import Any, Optional | |
import tqdm | |
from torch.utils.data import Dataset | |
# Dataset class needed so that the HF pipeline can be iterated item by item,
# which is what makes per-file progress reporting possible.
class StrListDataset(Dataset[str]):
    """Expose a plain list of strings through the torch ``Dataset`` protocol."""

    def __init__(self, original_list: list[str]) -> None:
        # The caller's list is stored by reference, not copied.
        self.original_list = original_list

    def __len__(self) -> int:
        """Return the number of strings in the wrapped list."""
        item_count = len(self.original_list)
        return item_count

    def __getitem__(self, i: int) -> str:
        """Return the string stored at index ``i`` of the wrapped list."""
        items = self.original_list
        return items[i]
def transcribe_files_with_hf_whisper(
    audio_files: list[Path],
    model_id: str,
    initial_prompt: Optional[str] = None,
    language: str = "ja",
    batch_size: int = 16,
    num_beams: int = 1,
    device: str = "cuda",
    pbar: Optional["tqdm.tqdm"] = None,
) -> list[str]:
    """Transcribe audio files with a Hugging Face Whisper pipeline.

    Args:
        audio_files: Paths of the audio files to transcribe.
        model_id: Model identifier handed to ``pipeline`` (and to
            ``WhisperProcessor.from_pretrained`` when a prompt is given).
        initial_prompt: Optional prompt text. When given it is converted to
            prompt ids and forwarded to generation, and a leading
            ``" {initial_prompt}"`` is stripped from every transcription.
        language: Language forwarded to generation.
        batch_size: Batch size of the pipeline.
        num_beams: Beam count forwarded to generation (sampling is disabled).
        device: Device used for both the pipeline and the prompt ids.
        pbar: Optional progress bar; advanced by one per transcribed file and
            closed when transcription ends, whether it succeeds or fails.

    Returns:
        The transcribed texts, in the order the pipeline yields them.
    """
    # Imported lazily so that importing this module does not pull in the
    # heavy torch / transformers dependencies.
    import torch
    from transformers import WhisperProcessor, pipeline

    generate_kwargs: dict[str, Any] = {
        "language": language,
        "do_sample": False,
        "num_beams": num_beams,
    }

    prompt_prefix: Optional[str] = None
    if initial_prompt is not None:
        # The processor is only needed to tokenize the prompt, so it is only
        # loaded when a prompt was actually supplied.
        processor: WhisperProcessor = WhisperProcessor.from_pretrained(model_id)
        prompt_ids: torch.Tensor = processor.get_prompt_ids(
            initial_prompt, return_tensors="pt"
        )
        generate_kwargs["prompt_ids"] = prompt_ids.to(device)
        prompt_prefix = f" {initial_prompt}"

    # float16 is kept for the default CUDA path. On CPU fall back to float32.
    # NOTE(review): assumes half-precision inference is not reliably
    # supported on CPU for every torch version — confirm for your setup.
    torch_dtype = torch.float32 if device.startswith("cpu") else torch.float16

    pipe = pipeline(
        model=model_id,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=batch_size,
        torch_dtype=torch_dtype,
        # Must match the device the prompt ids were moved to.
        device=device,
        generate_kwargs=generate_kwargs,
    )

    # Wrapping the paths in a Dataset lets the pipeline be iterated one
    # result at a time, so the progress bar can be advanced per file.
    dataset = StrListDataset([str(f) for f in audio_files])

    results: list[str] = []
    try:
        for whisper_result in pipe(dataset):
            text: str = whisper_result["text"]
            # For some reason the text starts with " {initial_prompt}", so it
            # is removed from the beginning of the text.
            # cf. https://github.com/huggingface/transformers/issues/27594
            # Only done when a prompt was given; otherwise a transcription
            # that happens to start with " None" would be truncated.
            if prompt_prefix is not None:
                text = text.removeprefix(prompt_prefix)
            results.append(text)
            if pbar is not None:
                pbar.update(1)
    finally:
        # Close the caller's progress bar even if the pipeline raises.
        if pbar is not None:
            pbar.close()

    return results
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment