Playing with the ctranslate2 backend
"""Transcription of an audio file with faster-whisper (the CTranslate2 backend for Whisper)."""
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal

from faster_whisper import WhisperModel

model_size = "large-v3"

# workaround for the duplicate OpenMP runtime error when the OpenAI Whisper stack is also installed via transformers
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Run on GPU with FP16
# model = WhisperModel(model_size, device="cuda", compute_type="float16")
# or run on GPU with INT8
# model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
# or run on CPU with INT8
model = WhisperModel(model_size, device="cpu", compute_type="int8")

path = Path(__file__).parent.parent / 'sample.wav'


@dataclass
class Segment:
    start: float
    end: float
    text: str


@dataclass
class Transcriber:
    model: WhisperModel
    audio: Path
    _text: str = field(init=False, default='')
    _segments: list[Segment] = field(default_factory=list, init=False)

    def transcribe(self) -> None:
        segments, info = self.model.transcribe(self.audio.as_posix(), beam_size=5)
        print(f"Detected language '{info.language}' with probability {info.language_probability:f}")
        for segment in segments:
            self._text += segment.text
            self._segments.append(Segment(segment.start, segment.end, segment.text))
        print(self._segments[0])

    @staticmethod
    def _format_timestamp(seconds: float, separator: str = ',') -> str:
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = seconds % 60
        return f"{hours:02}:{minutes:02}:{secs:06.3f}".replace('.', separator)

    @property
    def text(self) -> str:
        return self._text

    def _get_writable_path(self, filename: str | None, suffix: Literal['vtt', 'srt']) -> Path:
        if filename is None:
            return self.audio.resolve().with_suffix(f'.{suffix}')
        path = Path(filename)
        # check the file itself if it already exists, otherwise its parent directory
        if not os.access(path if path.exists() else path.parent, os.W_OK):
            raise PermissionError(f'{filename} is not writable')
        return path

    def create_vtt_file(self, filename: str | None = None) -> None:
        vtt_file = self._get_writable_path(filename, 'vtt')
        with vtt_file.open('w') as f:
            f.write('WEBVTT\n\n')
            for segment in self._segments:
                # WebVTT timestamps use a dot as decimal separator
                start_time = self._format_timestamp(segment.start, separator='.')
                end_time = self._format_timestamp(segment.end, separator='.')
                f.write(f'{start_time} --> {end_time}\n')
                f.write(f'{segment.text}\n\n')

    def create_srt_file(self, filename: str | None = None) -> None:
        srt_file = self._get_writable_path(filename, 'srt')
        with srt_file.open('w') as f:
            for index, segment in enumerate(self._segments, start=1):
                # SubRip timestamps use a comma as decimal separator
                start_time = self._format_timestamp(segment.start)
                end_time = self._format_timestamp(segment.end)
                f.write(f'{index}\n')
                f.write(f'{start_time} --> {end_time}\n')
                f.write(f'{segment.text}\n\n')


transcriber = Transcriber(model=model, audio=path)
start = time.perf_counter()
transcriber.transcribe()
print(f'duration: {time.perf_counter() - start:.2f}s')
print(transcriber.text)
transcriber.create_vtt_file()
transcriber.create_srt_file()
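For reference, the two writers above emit the standard cue layouts: SubRip uses a comma as the decimal separator in timestamps while WebVTT uses a dot, which is why _format_timestamp takes a separator argument. An illustrative first cue from each output file (timestamps and text made up here):

sample.srt:
1
00:00:00,000 --> 00:00:03,480
 Hello and welcome.

sample.vtt:
WEBVTT

00:00:00.000 --> 00:00:03.480
 Hello and welcome.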
"""Transcription with word-level alignment using whisperx (which also runs on faster-whisper/CTranslate2)."""
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal

import psutil
import torch
import whisperx
from whisperx.asr import FasterWhisperPipeline

# faster-whisper and torch expect the device name "cuda", not "gpu"
device = "cuda" if torch.cuda.is_available() else "cpu"
audio_file = Path(__file__).parent.parent / 'sample.wav'
compute_type = "float16" if device == "cuda" else "int8"
model = whisperx.load_model(
    "large-v3", device, compute_type=compute_type, threads=psutil.cpu_count(logical=False),
    asr_options={'hotwords': None}
)

@dataclass
class Segment:
    start: float
    end: float
    text: str


@dataclass
class Transcriber:
    model: FasterWhisperPipeline
    audio: Path
    batch_size: int
    device: Literal['cpu', 'cuda'] = 'cpu'
    _text: str = field(init=False, default='')
    _segments: list[Segment] = field(default_factory=list, init=False)

    def transcribe(self) -> None:
        audio = whisperx.load_audio(self.audio.as_posix())
        result = self.model.transcribe(audio, batch_size=self.batch_size)
        self._text = ''.join(segment['text'] for segment in result['segments'])
        # align the Whisper output to get word-level timestamps
        model_a, metadata = whisperx.load_align_model(language_code=result['language'], device=self.device)
        result = whisperx.align(result['segments'], model_a, metadata, audio, self.device, return_char_alignments=False)
        for segment in result['segments']:
            for word in segment['words']:
                self._segments.append(Segment(start=word['start'], end=word['end'], text=word['word']))

    @property
    def text(self) -> str:
        return self._text

    @staticmethod
    def _format_timestamp(seconds: float, separator: str = ',') -> str:
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = seconds % 60
        return f"{hours:02}:{minutes:02}:{secs:06.3f}".replace('.', separator)

    def _get_writable_path(self, filename: str | None, suffix: Literal['vtt', 'srt']) -> Path:
        if filename is None:
            return self.audio.resolve().with_suffix(f'.{suffix}')
        path = Path(filename)
        # check the file itself if it already exists, otherwise its parent directory
        if not os.access(path if path.exists() else path.parent, os.W_OK):
            raise PermissionError(f'{filename} is not writable')
        return path

    def create_vtt_file(self, filename: str | None = None) -> None:
        vtt_file = self._get_writable_path(filename, 'vtt')
        with vtt_file.open('w') as f:
            f.write('WEBVTT\n\n')
            for segment in self._segments:
                # WebVTT timestamps use a dot as decimal separator
                start_time = self._format_timestamp(segment.start, separator='.')
                end_time = self._format_timestamp(segment.end, separator='.')
                f.write(f'{start_time} --> {end_time}\n')
                f.write(f'{segment.text}\n\n')

    def create_srt_file(self, filename: str | None = None) -> None:
        srt_file = self._get_writable_path(filename, 'srt')
        with srt_file.open('w') as f:
            for index, segment in enumerate(self._segments, start=1):
                # SubRip timestamps use a comma as decimal separator
                start_time = self._format_timestamp(segment.start)
                end_time = self._format_timestamp(segment.end)
                f.write(f'{index}\n')
                f.write(f'{start_time} --> {end_time}\n')
                f.write(f'{segment.text}\n\n')


transcriber = Transcriber(
    model=model,
    audio=audio_file,
    batch_size=16,  # reduce if low on GPU memory
    device=device,  # type: ignore
)
start = time.perf_counter()
transcriber.transcribe()
print(f'duration: {time.perf_counter() - start:.2f}s')
print(transcriber.text)
transcriber.create_vtt_file()
transcriber.create_srt_file()
"""Text translation with CTranslate2 using a converted M2M-100 (1.2B) model."""
import os
import time
from pathlib import Path

import ctranslate2
import transformers

# workaround when there are conflicts with transformers installations (duplicate OpenMP runtime)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

model_path = Path.home() / '.cache' / 'ctranslate2' / 'm2m100_1.2B'
translator = ctranslate2.Translator(model_path.as_posix())


def translate(text: str, source_lang: str, target_lang: str) -> str:
    tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/m2m100_1.2B")
    tokenizer.src_lang = source_lang
    source = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
    # the decoder is primed with the target language token
    target_prefix = [tokenizer.lang_code_to_token[target_lang]]
    results = translator.translate_batch([source], target_prefix=[target_prefix])
    # drop the language token from the best hypothesis before decoding
    target = results[0].hypotheses[0][1:]
    return tokenizer.decode(tokenizer.convert_tokens_to_ids(target))


start_time = time.perf_counter()
# Hindi to French
print(translate("जीवन एक चॉकलेट बॉक्स की तरह है।", "hi", "fr"))
print(f'took {time.perf_counter() - start_time:.2f} seconds')

start_time = time.perf_counter()
# Chinese to English
print(translate("生活就像一盒巧克力。", "zh", "en"))
print(f'took {time.perf_counter() - start_time:.2f} seconds')
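Note: the script above assumes the facebook/m2m100_1.2B checkpoint has already been converted to the CTranslate2 format and saved under ~/.cache/ctranslate2/m2m100_1.2B (the output directory is an arbitrary choice made here). If that directory does not exist yet, a conversion along these lines, with the converter CLI installed alongside ctranslate2, should produce it:

ct2-transformers-converter --model facebook/m2m100_1.2B --output_dir ~/.cache/ctranslate2/m2m100_1.2B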