from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import FileResponse
from transformers import CLIPProcessor, CLIPModel
import torchaudio
import torch
from concurrent.futures import ThreadPoolExecutor
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pennylane as qml
from sklearn.svm import SVC
import joblib
from sklearn.preprocessing import StandardScaler
from bark.generation import load_codec_model, generate_text_semantic
from encodec.utils import convert_audio
from hubert.hubert_manager import HuBERTManager
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
from bark.api import generate_audio
from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine
import math

app = FastAPI()

# Initialize models and processors
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
svm_model = joblib.load('path/to/pretrained_svm_model.pkl')
svm_scaler = joblib.load('path/to/pretrained_svm_scaler.pkl')

# Quantum Neural Network setup
def quantum_neural_network(num_qubits=4, num_layers=3):
    dev = qml.device("default.qubit", wires=num_qubits)

    @qml.qnode(dev)
    def qnn_circuit(inputs, weights):
        qml.templates.AngleEmbedding(inputs, wires=range(num_qubits))
        qml.templates.StronglyEntanglingLayers(weights, wires=range(num_qubits))
        return [qml.expval(qml.PauliZ(i)) for i in range(num_qubits)]

    weight_shapes = {"weights": (num_layers, num_qubits, 3)}
    qlayer = qml.qnn.TorchLayer(qnn_circuit, weight_shapes)
    return qlayer

qnn = quantum_neural_network()
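
# Hypothetical usage note (assumption, not from the original gist): the quantum
# layer maps a 4-dimensional feature vector to 4 Pauli-Z expectation values and
# could be applied to per-chunk audio features before SVM classification, e.g.:
#   features = torch.rand(1, 4)        # placeholder features
#   quantum_features = qnn(features)   # shape (1, 4), values in [-1, 1]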

# Bark and HuBERT models setup
model = load_codec_model(use_gpu=True)
hubert_manager = HuBERTManager()
hubert_manager.make_sure_hubert_installed()
hubert_manager.make_sure_tokenizer_installed()
hubert_model_0 = CustomHubert(checkpoint_path='path/to/hubert_0.pt').to('cuda')
tokenizer_0 = CustomTokenizer.load_from_checkpoint('path/to/tokenizer_0.pth').to('cuda')
hubert_model_1 = CustomHubert(checkpoint_path='path/to/hubert_1.pt').to('cuda')
tokenizer_1 = CustomTokenizer.load_from_checkpoint('path/to/tokenizer_1.pth').to('cuda')

# Function to clone voice: encode the reference audio into Bark prompt tensors
# and save them as an .npz voice prompt (clip_model is accepted but not used here)
def voice_cloning(audio_filepath, clip_model, hubert_model, tokenizer):
    wav, sr = torchaudio.load(audio_filepath)
    wav = convert_audio(wav, sr, model.sample_rate, model.channels)
    wav = wav.to('cuda')
    # Semantic tokens from HuBERT + custom tokenizer
    semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)
    semantic_tokens = tokenizer.get_token(semantic_vectors)
    # Acoustic codes from the Bark/EnCodec codec model
    with torch.no_grad():
        encoded_frames = model.encode(wav.unsqueeze(0))
    codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze()
    codes = codes.cpu().numpy()
    semantic_tokens = semantic_tokens.cpu().numpy()
    voice_name = 'output'
    output_path = f'bark/assets/prompts/{voice_name}.npz'
    np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
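
# Hypothetical follow-up (assumption, not from the original gist): the saved
# prompt could be fed back into Bark to synthesize speech in the cloned voice:
#   preload_models()
#   audio_array = generate_audio("Hello world", history_prompt='bark/assets/prompts/output.npz')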

# Audio dataset class: splits an audio file into overlapping fixed-size chunks
class AudioDataset(Dataset):
    def __init__(self, audio_path, chunk_size, overlap, sample_rate):
        wav, sr = torchaudio.load(audio_path, normalize=True)
        # Resample once so chunk_size and overlap are expressed in target-rate samples
        self.wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)(wav)
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        self.num_chunks = math.ceil((self.wav.size(1) - overlap) / (chunk_size - overlap))

    def __len__(self):
        return self.num_chunks

    def __getitem__(self, idx):
        start = idx * (self.chunk_size - self.overlap)
        end = min(start + self.chunk_size, self.wav.size(1))
        return self.wav[:, start:end]

# Silence detection: a chunk is silent if it contains a run of consecutive
# low-amplitude samples lasting at least min_silence_duration seconds
def is_silent(audio_chunk, sample_rate, silence_threshold=1e-4, min_silence_duration=0.5):
    amplitude = audio_chunk.abs().mean(dim=0)  # average over channels
    min_silence_samples = int(min_silence_duration * sample_rate)
    silent_run = 0
    for sample_is_silent in (amplitude < silence_threshold):
        if sample_is_silent:
            silent_run += 1
            if silent_run >= min_silence_samples:
                return True
        else:
            silent_run = 0
    return False

# Dynamic prompt generation using CLIP: treat the spectrogram as an image and
# pick the text prompt whose embedding matches it best
def generate_dynamic_prompts(spectrogram):
    categories = ["music", "speech", "environmental sound", "silence"]
    prompts = [f"This is a {category} spectrogram" for category in categories]
    # Collapse channels, normalize to [0, 1], and expand to 3 channels so the CLIP image processor accepts it
    image = spectrogram.mean(dim=0, keepdim=True)
    image = (image - image.min()) / (image.max() - image.min() + 1e-8)
    image = image.repeat(3, 1, 1).numpy()
    inputs = clip_processor(text=prompts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = clip_model(**inputs)
    # logits_per_image has shape (num_images, num_prompts); pick the best-matching prompt
    best_prompt_index = outputs.logits_per_image.argmax(dim=-1).item()
    return prompts[best_prompt_index]

# Processing each audio chunk
def process_chunk(chunk, sample_rate, qnn, clip_model, clip_processor, svm_model, svm_scaler, model, hubert_model, tokenizer):
    chunk = chunk.squeeze(0)  # drop the batch dimension added by the DataLoader
    if is_silent(chunk, sample_rate):
        return 'Silence'
    spectrogram = torchaudio.transforms.Spectrogram()(chunk)
    dynamic_prompt = generate_dynamic_prompts(spectrogram)
    # Additional processing steps (QNN features, SVM classification, Bark encoding) can be added here
    return dynamic_prompt

# Asynchronous audio chunk processing: alternate between the two HuBERT model/tokenizer pairs across a thread pool
def process_audio_chunks_async(audio_path, chunk_size, overlap, sample_rate, num_workers=2):
    dataset = AudioDataset(audio_path, chunk_size, overlap, sample_rate)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    def init_models(device_id):
        # Both pipelines share the Bark codec model; the HuBERT model and tokenizer alternate per chunk
        if device_id == 0:
            return model, hubert_model_0, tokenizer_0
        return model, hubert_model_1, tokenizer_1

    futures = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for i, chunk in enumerate(loader):
            device_id = i % 2
            codec_model, hubert_model, tokenizer = init_models(device_id)
            futures.append(executor.submit(
                process_chunk, chunk, sample_rate, qnn, clip_model, clip_processor,
                svm_model, svm_scaler, codec_model, hubert_model, tokenizer))
        results = [future.result() for future in futures]
    return results

# FastAPI endpoint: save the uploaded audio and build a cloned-voice prompt from it
@app.post("/process-and-clone/")
async def process_and_clone(
    file: UploadFile = File(...),
    text_prompt: str = Form(...),
    history_prompt: str = Form(...),
):
    try:
        audio_filepath = file.filename
        with open(audio_filepath, "wb") as f:
            f.write(await file.read())
        voice_cloning(audio_filepath, clip_model, hubert_model_0, tokenizer_0)
        return {"message": "Voice cloning and audio processing successful"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing and cloning voice: {e}")

# FastAPI endpoint for serving the cloned voice file
@app.get("/get-cloned-voice/")
async def get_cloned_voice():
    try:
        voice_name = 'output'
        output_path = f'bark/assets/prompts/{voice_name}.npz'
        return FileResponse(output_path, media_type="application/octet-stream", filename=f"{voice_name}.npz")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting cloned voice: {e}")

# FastAPI endpoint for processing audio chunks asynchronously
@app.post("/process-audio-chunks/")
async def process_audio_chunks(
    file: UploadFile = File(...),
    chunk_size: int = Form(...),
    overlap: int = Form(...),
    sample_rate: int = Form(...),
    num_workers: int = Form(...),
):
    try:
        audio_filepath = file.filename
        with open(audio_filepath, "wb") as f:
            f.write(await file.read())
        results = process_audio_chunks_async(audio_filepath, chunk_size, overlap, sample_rate, num_workers)
        return {"results": results}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio chunks: {e}")