Created
November 13, 2022 10:10
-
-
Save shivammehta25/c2bc3a5a875c268e538edb774733f7e8 to your computer and use it in GitHub Desktop.
Hosting the models on the server using FastAPI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import sys | |
sys.path.append('src/model') | |
sys.path.insert(0, './hifigan') | |
import logging | |
import os | |
from pathlib import Path | |
from uuid import uuid4 | |
import numpy as np | |
import soundfile as sf | |
import torch | |
import uvicorn | |
from fastapi import BackgroundTasks, FastAPI | |
from fastapi.responses import FileResponse | |
from nltk import word_tokenize | |
from hifigan.env import AttrDict | |
from hifigan.models import Generator | |
from hifigandenoiser import Denoiser | |
from src.hparams import create_hparams | |
from src.training_module import TrainingModule | |
from src.utilities.text import phonetise_text, text_to_sequence | |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
app = FastAPI()
# Append DEBUG-and-above records to log.txt. The format string previously
# ended with a stray apostrophe, which was emitted at the end of every line.
logging.basicConfig(
    filename="log.txt",
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s: %(message)s",
    filemode="a",
)
print("[+] Loading Models..")
# Hyper-parameters; text_to_seq reads hparams.cmu_phonetiser from this object.
hparams = create_hparams()
def load_model(checkpoint_path, speaker):
    """Restore a ``TrainingModule`` from ``checkpoint_path`` for inference.

    The restored module is moved to the module-level ``device``, put into
    eval mode and converted to half precision before being returned.

    Args:
        checkpoint_path: Path handed to ``TrainingModule.load_from_checkpoint``.
        speaker: Label used only in the progress message printed after loading.

    Returns:
        The restored module.
    """
    restored = TrainingModule.load_from_checkpoint(checkpoint_path)
    restored.to(device).eval().half()
    print(f"[+] Model Loaded: {speaker}")
    return restored
# Checkpoint file for each supported voice.
checkpoint_mandarin = "checkpoints/MandrinRun_Male/checkpoint_105000.ckpt"
checkpoint_arabic = "checkpoints/ArabRun_Male/checkpoint_110000.ckpt"
checkpoint_british = "checkpoints/BritishRun_male/checkpoint_105500.ckpt"
checkpoint_african = "checkpoints/Nigerian2Run_male/checkpoint_111500.ckpt"

# All voices are loaded eagerly at import time, so the request handler only
# has to pick one of these module-level objects.
model_mandarin = load_model(checkpoint_path=checkpoint_mandarin, speaker="mandarin")
model_arabic = load_model(checkpoint_path=checkpoint_arabic, speaker="arabic")
model_british = load_model(checkpoint_path=checkpoint_british, speaker="british")
model_african = load_model(checkpoint_path=checkpoint_african, speaker="african")
print("[+] Models Loaded..")

print("[+] Loading HiFi-GAN..")
# Locate the HiFi-GAN config; the generator checkpoint path is relative to
# the working directory, not to hifigan_loc.
hifigan_loc = 'hifigan/'
config_file = hifigan_loc + 'config_v1.json'
hifi_checkpoint_file = 'g_02500000'

# Read the raw config text first, then parse it into a plain dict.
with open(config_file) as f:
    data = f.read()
json_config = json.loads(data)
def load_checkpoint(filepath, device):
    """Load a serialised checkpoint from ``filepath`` with ``torch.load``.

    Args:
        filepath: Path of the checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        Whatever object ``torch.load`` deserialises from the file.

    Raises:
        FileNotFoundError: If ``filepath`` is not an existing regular file.
    """
    print(filepath)
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``
    # and an AssertionError carries no hint about which path was missing.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint file not found: {filepath}")
    print("Loading '{}'".format(filepath))
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
# Wrap the parsed config so its entries can be read as attributes (h.seed).
h = AttrDict(json_config)
torch.manual_seed(h.seed)

# Build the vocoder on the target device and restore its weights from the
# 'generator' entry of the checkpoint.
generator = Generator(h)
generator = generator.to(device)
state_dict_g = load_checkpoint(filepath=hifi_checkpoint_file, device=device)
generator.load_state_dict(state_dict_g['generator'])

# Inference setup: eval mode, half precision, then strip weight norm.
generator.eval()
generator.half()
generator.remove_weight_norm()

# The denoiser is constructed from the fully prepared generator.
denoiser = Denoiser(generator, mode='zeros')
print("[+] HiFi-GAN Loaded..")
def text_to_seq(text):
    """Convert raw ``text`` into a tensor of symbol ids on ``device``.

    The text is phonetised with ``hparams.cmu_phonetiser`` (tokenised by
    ``word_tokenize``), mapped to ids using the ``english_cleaners`` pipeline,
    given a leading axis of size one, and returned as a ``long`` tensor.
    """
    phonetised = phonetise_text(hparams.cmu_phonetiser, text, word_tokenize)
    symbol_ids = np.array(text_to_sequence(phonetised, ['english_cleaners']))
    # [None, :] prepends a leading axis before the conversion to a tensor.
    batched = symbol_ids[None, :]
    return torch.from_numpy(batched).to(device).long()
# Scratch directory for generated WAV files; created on import if missing.
savepath = Path('temp')
savepath.mkdir(parents=True, exist_ok=True)
def del_file(file):
    """Delete ``file`` from disk.

    Scheduled as a background task by the request handler to remove the
    temporary WAV file. Raises ``FileNotFoundError`` if the path is missing.
    """
    os.unlink(file)
def log(text, speaker):
    """Write an INFO record noting that ``text`` was served for ``speaker``."""
    message = f"\tSent successfully {speaker}: {text}"
    logging.info(message)
@app.get("/speak/")
async def get_audio_from_text(text: str, speaker: str, bg_tasks: BackgroundTasks, speed: float = 0.55):
    """Synthesise ``text`` with the requested voice and return it as a WAV file.

    Args:
        text: Text to speak; a trailing "." is appended before phonetisation.
        speaker: One of "mandarin", "arabic", "british" or "african".
        bg_tasks: FastAPI background-task queue, used to delete the temporary
            WAV file and to write the log entry.
        speed: Value written to ``duration_quantile_threshold`` for this
            request; it is restored to 0.55 afterwards.

    Returns:
        A ``FileResponse`` with media type ``audio/wav`` for the generated file.

    Raises:
        HTTPException: 400 when ``speaker`` is not one of the known voices.
    """
    # Function-scope import: fastapi is already a dependency of this file.
    from fastapi import HTTPException

    models = {
        "mandarin": model_mandarin,
        "arabic": model_arabic,
        "british": model_british,
        "african": model_african,
    }
    model = models.get(speaker)
    if model is None:
        # Previously an unknown speaker left `model` unbound and surfaced as
        # an opaque 500 (UnboundLocalError).
        raise HTTPException(
            status_code=400,
            detail=f"Unknown speaker '{speaker}'. Choose one of: {', '.join(models)}",
        )

    hmm = model.model.hmm
    hmm_hparams = hmm.hparams
    hmm_hparams.max_sampling_time = 10000
    hmm_hparams.duration_quantile_threshold = speed
    hmm_hparams.deterministic_transition = True
    hmm_hparams.predict_means = False
    hmm_hparams.prenet_dropout_while_eval = True
    hmm.prenet.prenet_dropout = 0.5

    try:
        text += "."
        sequence = text_to_seq(text)
        # `no_grad() and inference_mode()` evaluated to the second context
        # manager only, so entering inference_mode alone is what actually ran.
        with torch.inference_mode():
            mel_output, _, _, _ = model.sample(sequence.squeeze(0), sampling_temp=0.334)
            mel_output = mel_output.transpose(1, 2)
            audio = generator(mel_output)
            audio = denoiser(audio[:, 0], strength=0.004)[:, 0]

        filename = savepath / f"{uuid4()}.wav"
        # NOTE(review): 22500 Hz looks like a typo for 22050 — confirm against
        # the vocoder config's sampling rate before changing.
        sf.write(filename, audio.data.squeeze().cpu().numpy(),
                 22500, 'PCM_24')
    finally:
        # The models are shared module-level objects: always restore the
        # default threshold, even when synthesis fails part-way.
        hmm_hparams.duration_quantile_threshold = 0.55

    # Clean up the temporary file and log the request as background tasks.
    bg_tasks.add_task(del_file, filename)
    bg_tasks.add_task(log, text, speaker)
    return FileResponse(filename, media_type="audio/wav")
if __name__ == "__main__":
    # Start the server on all network interfaces, port 8020, when run as a script.
    uvicorn.run(app, port=8020, host="0.0.0.0")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment