Created
January 6, 2024 02:58
-
-
Save ichabodcole/1c0d19ef4c33b7b5705b0860c7c27f7b to your computer and use it in GitHub Desktop.
Combining XTTS speaker embeddings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
from typing import List

import torch
from pydub import AudioSegment
from torch import Tensor
from TTS.api import TTS
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

import utils
from utils import CombineMethod
# Directory holding the downloaded XTTS v2 model files (config, checkpoint, speaker table).
checkpoint_dir = './tts/tts_models--multilingual--multi-dataset--xtts_v2'

config = XttsConfig()
config.load_json(f'{checkpoint_dir}/config.json')
model = Xtts.init_from_config(config)

use_cuda = torch.cuda.is_available()
# DeepSpeed inference requires a CUDA device, so only request it when one is present.
model.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_cuda)
if use_cuda:
    model.cuda()
else:
    model.cpu()

# Pre-computed conditioning data for XTTS's built-in speakers, keyed by speaker name.
speaker_data = torch.load(f'{checkpoint_dir}/speakers_xtts.pth')

# NOTE(review): each entry is turned into a list by dict insertion order.
# utils.average_latents_and_embeddings reads item[0] as gpt_cond_latents and
# item[1] as speaker_embedding. Confirm the stored order, or index by key instead.
daisy = list(speaker_data['Daisy Studious'].values())
henriette = list(speaker_data['Henriette Usha'].values())
baldur = list(speaker_data['Baldur Sanjin'].values())
speakers = [daisy, henriette, baldur]
DEFAULT_TEXT = (
    "This ascent represents your connection to the universe, "
    "your consciousness expanding to embrace the infinite."
)


def generate_speech(
    speakers,
    combine_method: CombineMethod,
    speaker_weights: List | None = None,
    *,
    text: str = DEFAULT_TEXT,
    language: str = "en",
    output_dir: str = "./output",
) -> str:
    """
    Synthesize ``text`` with a voice blended from several XTTS speakers and save it as MP3.

    Args:
        speakers: List of per-speaker entries. ``utils.average_latents_and_embeddings``
            reads each one as ``(gpt_cond_latents, speaker_embedding)``.
        combine_method: How the speakers' latents/embeddings are merged.
        speaker_weights: Optional per-speaker weights. ``None`` leaves the choice to
            ``utils.average_latents_and_embeddings``.
        text: Text to synthesize.
        language: Language code passed to ``model.inference``.
        output_dir: Directory the MP3 is written to; created if missing.

    Returns:
        Path of the exported MP3 file.
    """
    avg_gpt_cond_latents, avg_speaker_embedding = utils.average_latents_and_embeddings(
        speakers, combine_method, speaker_weights
    )
    out = model.inference(
        text=text,
        language=language,
        gpt_cond_latent=avg_gpt_cond_latents,
        speaker_embedding=avg_speaker_embedding,
        temperature=0.7,
        speed=1.0,
    )
    numpy_wav = utils.tensor_to_numpy_array(Tensor(out['wav']))
    audio_segment = AudioSegment(
        data=numpy_wav.tobytes(),  # raw PCM bytes
        sample_width=2,  # 2 bytes (16 bits) per sample
        frame_rate=24000,  # NOTE(review): assumes XTTS v2 outputs 24 kHz audio; confirm against config
        channels=1,  # mono audio
    )
    # With no explicit weights there is nothing to join, so the file name
    # gets an "equal" label instead of crashing on None.
    if speaker_weights is not None:
        weights = "-".join(map(str, speaker_weights))
    else:
        weights = "equal"
    os.makedirs(output_dir, exist_ok=True)  # export() cannot create missing directories
    output_path = os.path.join(
        output_dir, f'speaker-mix_method_{combine_method.value}_weight_{weights}.mp3'
    )
    audio_segment.export(output_path, format="mp3")
    return output_path
# Blend the three stock speakers by weighted sum, giving the last two twice the weight of the first.
generate_speech(speakers, combine_method=CombineMethod.SUM, speaker_weights=[1, 2, 2])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import Tensor | |
import numpy as np | |
from enum import Enum | |
from typing import List | |
class CombineMethod(Enum):
    """Reduction used by ``combine_embeddings`` to merge several speakers' tensors into one.

    Each member's value is a short string identifier.
    """
    MEAN = 'mean'  # element-wise (weighted) mean
    SUM = 'sum'  # element-wise sum of the weighted embeddings
    MEDIAN = 'median'  # element-wise median (torch.median ... .values)
    MAX = 'max'  # element-wise maximum
    MIN = 'min'  # element-wise minimum
    NORMALIZED_SUM = 'normalized_sum'  # sum of embeddings rescaled by their norms (torch.norm)
def tensor_to_numpy_array(tensor: Tensor) -> np.ndarray:
    """
    Convert a floating-point audio tensor to 16-bit PCM samples.

    Samples are clipped to [-1.0, 1.0], scaled by the int16 maximum (32767),
    and truncated to int16. Without the clip, any sample slightly outside the
    range is out of int16 range on the cast and comes out as a garbage
    (typically sign-flipped) value, which is heard as a loud click.

    Args:
        tensor: Audio samples, presumably floats nominally in [-1, 1]. The tensor
            may be on any device and may require grad.

    Returns:
        np.ndarray of dtype int16 with the same shape as ``tensor``.
    """
    samples = tensor.detach().cpu().numpy()
    samples = np.clip(samples, -1.0, 1.0)
    return (samples * np.iinfo(np.int16).max).astype(np.int16)
def average_latents_and_embeddings(
    latent_embedding_pairs,
    combine_method: CombineMethod = CombineMethod.MEAN,
    speaker_weights: List | None = None,
):
    """
    Merge several speakers' conditioning tensors into a single speaker.

    Despite the name, the reduction is whatever ``combine_method`` selects
    (mean by default). It is applied separately to the GPT conditioning
    latents and to the speaker embeddings, using the same weights for both.

    Args:
        latent_embedding_pairs: Sequence whose items are indexable as
            ``item[0]`` (gpt_cond_latents) and ``item[1]`` (speaker_embedding).
        combine_method: Reduction forwarded to ``combine_embeddings``.
        speaker_weights: Optional per-speaker weights forwarded to
            ``combine_embeddings``.

    Returns:
        tuple: ``(combined_gpt_cond_latents, combined_speaker_embedding)``.
    """
    gpt_latents = [pair[0] for pair in latent_embedding_pairs]
    speaker_embeddings = [pair[1] for pair in latent_embedding_pairs]
    return (
        combine_embeddings(gpt_latents, combine_method, speaker_weights),
        combine_embeddings(speaker_embeddings, combine_method, speaker_weights),
    )
def combine_embeddings(embeddings, method, weights: List | None = None): | |
if weights == None: | |
weights = [1 for _ in embeddings] | |
if len(weights) != len(embeddings): | |
raise ValueError("Weights match the number of embeddings for weighted average.") | |
weighted_embeddings = [embedding * weight for embedding, weight in zip(embeddings, weights)] | |
if method == CombineMethod.MEAN: | |
return torch.mean(torch.stack(weighted_embeddings), dim=0) | |
elif method == CombineMethod.SUM: | |
return torch.sum(torch.stack(weighted_embeddings), dim=0) | |
elif method == CombineMethod.MEDIAN: | |
return torch.median(torch.stack(weighted_embeddings), dim=0).values | |
elif method == CombineMethod.MAX: | |
return torch.max(torch.stack(weighted_embeddings), dim=0).values | |
elif method == CombineMethod.MIN: | |
return torch.min(torch.stack(weighted_embeddings), dim=0).values | |
elif method == CombineMethod.NORMALIZED_SUM: | |
normalized = [embedding / torch.norm(embedding) for embedding in weighted_embeddings] | |
return torch.sum(torch.stack(normalized), dim=0) | |
else: | |
raise ValueError("Invalid combine method specified.") | |
def normalize_weights(weights):
    """
    Rescale ``weights`` so that they sum to 1.

    Args:
        weights: Iterable of numeric weights.

    Returns:
        list: Each weight divided by the total of all weights.

    Raises:
        ValueError: If the weights sum to zero.
    """
    weight_sum = sum(weights)
    if weight_sum != 0:
        return [weight / weight_sum for weight in weights]
    raise ValueError("Sum of weights cannot be zero.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment