# E2-TTS inference example.
# Source: GitHub gist lucasnewman/3bd2a448078bb4d5ab809d2757f00a22 (created August 30, 2024 15:34).
import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torchaudio
from einops import rearrange
from IPython.display import Audio, display
from torchaudio.transforms import MelSpectrogram
from vocos import Vocos

from e2_tts_pytorch.e2_tts import E2TTS
# Pretrained mel-spectrogram vocoder. Its mel configuration is assumed to match the
# E2TTS mel settings below (24 kHz, 100 mel channels) — confirm against the Vocos config.
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

# Inputs: a trained E2-TTS checkpoint plus one LibriTTS-R utterance (audio + transcript).
checkpoint_path = "e2tts_124000.pt"
audio_path = "~/data/LibriTTS_R/dev-clean/174/50561/174_50561_000001_000002.flac"
text_path = "~/data/LibriTTS_R/dev-clean/174/50561/174_50561_000001_000002.txt"

# Build the model. These hyperparameters presumably mirror the training run; a mismatch
# makes load_state_dict (strict by default) fail on missing/unexpected/mis-shaped keys.
e2tts = E2TTS(
    tokenizer = 'phoneme_en',
    cond_drop_prob = 0.2,
    transformer = dict(
        dim = 384,
        depth = 12,
        heads = 8,
        max_seq_len = 1024,
        skip_connect_type = 'concat'
    ),
    mel_spec_kwargs = dict(
        filter_length = 1024,
        hop_length = 256,
        win_length = 1024,
        n_mel_channels = 100,
        sampling_rate = 24000,
    ),
    frac_lengths_mask = (0.7, 0.9)
)

# Load weights on the CPU first; the model is moved to an accelerator later.
# weights_only=True refuses to unpickle arbitrary objects from the checkpoint file.
checkpoint = torch.load(Path(checkpoint_path).expanduser(), map_location='cpu', weights_only = True)
e2tts.load_state_dict(checkpoint['model_state_dict'])
# Inference only: disable training-time behavior (e.g. dropout) in any submodule.
e2tts.eval()
# Load the reference utterance. torchaudio returns a (channels, samples) waveform.
audio, sr = torchaudio.load(Path(audio_path).expanduser())
# The mel settings and the vocoder both assume 24 kHz audio, and the `.squeeze(0)` below
# only yields a 2-D spectrogram for a single channel. Normalize the input instead of
# silently producing a mis-scaled or wrongly-shaped spectrogram.
if audio.shape[0] > 1:
    audio = audio.mean(dim=0, keepdim=True)
if sr != 24_000:
    audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=24_000)
    sr = 24_000
# Drop the leading channel axis; later code treats this as (mel bins, frames) ('d n').
original_mel_spec = e2tts.mel_spec(audio).squeeze(0)

# Visualize the reference mel spectrogram (mel bins on the y-axis, frames on the x-axis).
plt.figure(figsize=(6, 4))
plt.imshow(original_mel_spec.numpy(), origin='lower', aspect='auto')
plt.colorbar()
plt.show()
# Build the conditioning prompt: keep the first half of the frames and mask off the
# rest, then reshape from (mel bins, frames) to a batch of one (1, frames, mel bins).
mask_length = original_mel_spec.shape[-1] // 2
mel_spec = rearrange(original_mel_spec[:, :mask_length], 'd n -> 1 n d')
lens = torch.tensor([mel_spec.shape[1]], dtype=torch.long)

# Transcript for the utterance, with surrounding whitespace removed.
text = (
    Path(text_path)
    .expanduser()
    .read_text()
    .strip()
)
print(f"Text: {text}")
# Move the model and inputs to an accelerator when one is available. Prefer CUDA, then
# Apple MPS, and fall back to the CPU, instead of hard-coding 'mps' (which crashes on
# machines without MPS support).
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
e2tts = e2tts.to(device)
mel_spec = mel_spec.to(device)
lens = lens.to(device)
# Generate mel frames for the full original duration, conditioned on the prompt half
# and the transcript, and report how long sampling took.
start_date = datetime.datetime.now()
with torch.inference_mode():
    generated = e2tts.sample(
        cond = mel_spec,
        text = [text],
        duration = original_mel_spec.shape[1],
        steps = 32,
        cfg_strength = 1 # presumably classifier-free guidance; only useful if trained with cond drop
    )
# Accelerator kernels run asynchronously. Wait for them to finish so the elapsed time
# covers the sampling work itself, not just the kernel launches.
if device.type == 'cuda':
    torch.cuda.synchronize()
elif device.type == 'mps':
    torch.mps.synchronize()
print(f"Generated: {generated.shape} in {datetime.datetime.now() - start_date}")
# Back to (batch, mel bins, frames) for plotting and vocoding.
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
def _show_mel(spec):
    """Plot a 2-D mel spectrogram array (rows on the y-axis, frames on the x-axis) with a colorbar."""
    plt.figure(figsize=(6, 4))
    plt.imshow(spec, origin='lower', aspect='auto')
    plt.colorbar()
    plt.show()


# Generated spectrogram first (first batch item, moved off the accelerator), then the
# reference spectrogram for side-by-side comparison.
_show_mel(generated_mel_spec[0].cpu().numpy())
_show_mel(original_mel_spec.numpy())
# Vocode both spectrograms back into waveforms. Both inputs are moved to the CPU and
# reshaped to (batch, mel bins, frames).
# NOTE(review): `mel_spec` is only the first-half prompt, so the "original" clip below
# is the conditioning prompt, not the full reference utterance.
prompt_mel = rearrange(mel_spec.cpu(), '1 n d -> 1 d n')
wave = vocos.decode(prompt_mel)
print(f"wave: {wave.shape}")
wave2 = vocos.decode(generated_mel_spec.cpu())
print(f"wave2: {wave2.shape}")

# Save each clip as a 24 kHz WAV file and embed an inline audio player for it.
for clip_label, clip_path, clip_wave in (
    ("Original:", "original.wav", wave),
    ("Generated:", "generated.wav", wave2),
):
    print(clip_label)
    torchaudio.save(clip_path, clip_wave, 24_000)
    display(Audio(clip_path))
# End of gist. Discussion and comments live on the GitHub gist page.