import datetime
from pathlib import Path

import torch
import torchaudio
from einops import rearrange
from vocos import Vocos
from e2_tts_pytorch.e2_tts import E2TTS

import matplotlib.pyplot as plt
from IPython.display import Audio, display

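# pretrained Vocos vocoder, used to convert mel spectrograms back into waveforms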
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
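
# trained E2-TTS checkpoint and a LibriTTS-R utterance to prompt with -- adjust these to your local setup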
checkpoint_path = "e2tts_124000.pt"
audio_path = "~/data/LibriTTS_R/dev-clean/174/50561/174_50561_000001_000002.flac"
text_path = "~/data/LibriTTS_R/dev-clean/174/50561/174_50561_000001_000002.txt"
# construct the model; the configuration must match the one used for training
e2tts = E2TTS(
    tokenizer = 'phoneme_en',
    cond_drop_prob = 0.2,
    transformer = dict(
        dim = 384,
        depth = 12,
        heads = 8,
        max_seq_len = 1024,
        skip_connect_type = 'concat'
    ),
    mel_spec_kwargs = dict(
        filter_length = 1024,
        hop_length = 256,
        win_length = 1024,
        n_mel_channels = 100,
        sampling_rate = 24000,
    ),
    frac_lengths_mask = (0.7, 0.9)
)
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only = True)
e2tts.load_state_dict(checkpoint['model_state_dict'])
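# optional: switch to eval mode to disable dropout during sampling
# (sample() may already do this internally, in which case this is harmless)
e2tts.eval()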
# load a sample audio file
audio, sr = torchaudio.load(Path(audio_path).expanduser())
original_mel_spec = e2tts.mel_spec(audio).squeeze(0)
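# original_mel_spec has shape (n_mel_channels, n_frames)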
# visualize the mel spectrogram
plt.figure(figsize=(6, 4))
plt.imshow(original_mel_spec.numpy(), origin='lower', aspect='auto')
plt.colorbar()
plt.show()
# keep the first half of the utterance as the conditioning prompt;
# the masked-off second half will be generated by the model
mask_length = original_mel_spec.shape[1] // 2
mel_spec = rearrange(original_mel_spec, 'd n -> 1 n d')[:, :mask_length, :]
lens = torch.LongTensor([mel_spec.shape[1]])  # number of conditioning frames
text = Path(text_path).expanduser().read_text().strip()
print(f"Text: {text}")
# move the model and inputs to an accelerator if you have one, e.g. 'cuda' or 'mps'
device = torch.device('mps')
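# a device-agnostic alternative, if you don't have an Apple GPU:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')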
e2tts = e2tts.to(device)
mel_spec = mel_spec.to(device)
lens = lens.to(device)
start_date = datetime.datetime.now()
with torch.inference_mode():
    generated = e2tts.sample(
        cond = mel_spec,
        text = [text],
        duration = original_mel_spec.shape[1],  # generate out to the original utterance length (in mel frames)
        steps = 32,
        cfg_strength = 1  # classifier-free guidance; relevant since the model was trained with cond_drop_prob > 0
    )
print(f"Generated: {generated.shape} in {datetime.datetime.now() - start_date}")
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
# visualize the generated mel spectrogram
plt.figure(figsize=(6, 4))
plt.imshow(generated_mel_spec[0].cpu().numpy(), origin='lower', aspect='auto')
plt.colorbar()
plt.show()
# visualize the original mel spectrogram
plt.figure(figsize=(6, 4))
plt.imshow(original_mel_spec.numpy(), origin='lower', aspect='auto')
plt.colorbar()
plt.show()
# vocode the mel spectrograms back into audio; vocos expects (batch, n_mels, n_frames)
wave = vocos.decode(rearrange(mel_spec.cpu(), '1 n d -> 1 d n'))  # conditioning prompt (first half)
print(f"wave: {wave.shape}")
wave2 = vocos.decode(generated_mel_spec.cpu())  # generated mel spectrogram
print(f"wave2: {wave2.shape}")
# show previews of the conditioning prompt and the generated audio
print("Prompt (first half of the original):")
torchaudio.save("original.wav", wave, 24_000)
display(Audio("original.wav"))
print("Generated:")
torchaudio.save("generated.wav", wave2, 24_000)
display(Audio("generated.wav"))