from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse
from transformers import CLIPProcessor, CLIPModel
import torchaudio
import torch
from concurrent.futures import ThreadPoolExecutor
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pennylane as qml
from sklearn.svm import SVC
import joblib
from sklearn.preprocessing import StandardScaler
from bark.generation import load_codec_model, generate_text_semantic
from encodec.utils import convert_audio
from hubert.hubert_manager import HuBERTManager
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
from bark.api import generate_audio
from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine
import math
app = FastAPI()
# Initialize models and processors
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
svm_model = joblib.load('path/to/pretrained_svm_model.pkl')
svm_scaler = joblib.load('path/to/pretrained_svm_scaler.pkl')
# Quantum Neural Network setup
def quantum_neural_network(num_qubits=4, num_layers=3):
    dev = qml.device("default.qubit", wires=num_qubits)

    @qml.qnode(dev)
    def qnn_circuit(inputs, weights):
        qml.templates.AngleEmbedding(inputs, wires=range(num_qubits))
        qml.templates.StronglyEntanglingLayers(weights, wires=range(num_qubits))
        return [qml.expval(qml.PauliZ(i)) for i in range(num_qubits)]

    weight_shapes = {"weights": (num_layers, num_qubits, 3)}
    qlayer = qml.qnn.TorchLayer(qnn_circuit, weight_shapes)
    return qlayer
qnn = quantum_neural_network()
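# Illustrative helper (a sketch only; it assumes the pretrained SVM and scaler
# were fit on the four Pauli-Z expectation values returned by the quantum layer,
# which this gist does not state). It is not wired into any endpoint below.
def classify_features_with_qnn(features):
    # features: torch tensor of shape (batch, 4) matching the 4-qubit AngleEmbedding
    qnn_out = qnn(features).detach().cpu().numpy()
    scaled = svm_scaler.transform(qnn_out)
    return svm_model.predict(scaled)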
# Bark and HuBERT models setup
model = load_codec_model(use_gpu=True)
hubert_manager = HuBERTManager()
hubert_manager.make_sure_hubert_installed()
hubert_manager.make_sure_tokenizer_installed()
hubert_model_0 = CustomHubert(checkpoint_path='path/to/hubert_0.pt').to('cuda')
tokenizer_0 = CustomTokenizer.load_from_checkpoint('path/to/tokenizer_0.pth').to('cuda')
hubert_model_1 = CustomHubert(checkpoint_path='path/to/hubert_1.pt').to('cuda')
tokenizer_1 = CustomTokenizer.load_from_checkpoint('path/to/tokenizer_1.pth').to('cuda')
# Function to clone voice
def voice_cloning(audio_filepath, clip_model, hubert_model, tokenizer):
    wav, sr = torchaudio.load(audio_filepath)
    wav = convert_audio(wav, sr, model.sample_rate, model.channels)
    wav = wav.to('cuda')
    semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)
    semantic_tokens = tokenizer.get_token(semantic_vectors)
    with torch.no_grad():
        encoded_frames = model.encode(wav.unsqueeze(0))
    codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze()
    codes = codes.cpu().numpy()
    semantic_tokens = semantic_tokens.cpu().numpy()
    voice_name = 'output'
    output_path = f'bark/assets/prompts/{voice_name}.npz'
    np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
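# Illustrative follow-up (a sketch, not called by any endpoint here): once the
# prompt .npz has been written, Bark's generate_audio can synthesize new speech
# in the cloned voice. This assumes the standard bark.api.generate_audio
# signature and that preload_models() has downloaded the Bark checkpoints.
def synthesize_with_cloned_voice(text, voice_name='output'):
    preload_models()
    history_prompt = f'bark/assets/prompts/{voice_name}.npz'
    # Returns a mono waveform (numpy array) sampled at bark.generation.SAMPLE_RATE.
    return generate_audio(text, history_prompt=history_prompt)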
# Audio dataset class
class AudioDataset(Dataset):
    def __init__(self, audio_path, chunk_size, overlap, sample_rate):
        wav, sr = torchaudio.load(audio_path, normalize=True)
        # Resample the whole waveform up front so chunk_size and overlap are
        # expressed in samples at the target rate.
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        self.wav = resampler(wav)
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        self.num_chunks = math.ceil((self.wav.size(1) - overlap) / (chunk_size - overlap))

    def __len__(self):
        return self.num_chunks

    def __getitem__(self, idx):
        start = idx * (self.chunk_size - self.overlap)
        end = min(start + self.chunk_size, self.wav.size(1))
        return self.wav[:, start:end]
# Silence detection
def is_silent(audio_chunk, sample_rate, silence_threshold=1e-4, min_silence_duration=0.5):
    """Return True if the chunk contains a run of low-amplitude samples lasting
    at least min_silence_duration seconds."""
    samples = audio_chunk.reshape(-1).abs()
    min_silent_samples = int(min_silence_duration * sample_rate)
    silent_run = 0
    for value in samples.tolist():
        if value < silence_threshold:
            silent_run += 1
            if silent_run >= min_silent_samples:
                return True
        else:
            silent_run = 0
    return False
# Dynamic prompt generation using CLIP
def generate_dynamic_prompts(spectrogram):
    categories = ["music", "speech", "environmental sound", "silence"]
    prompts = [f"This is a {category} spectrogram" for category in categories]
    inputs = clip_processor(text=prompts, images=spectrogram.unsqueeze(0), return_tensors="pt", padding=True)
    outputs = clip_model(**inputs)
    # logits_per_image has shape (num_images, num_prompts); pick the best prompt for this image.
    similarities = outputs.logits_per_image.cpu().detach().numpy()
    best_prompt_index = int(np.argmax(similarities, axis=1)[0])
    return prompts[best_prompt_index]
# Processing each audio chunk
def process_chunk(chunk, sample_rate, qnn, clip_model, clip_processor, svm_model, svm_scaler, model, hubert_model, tokenizer):
    if is_silent(chunk, sample_rate):
        return 'Silence'
    spectrogram = torchaudio.transforms.Spectrogram()(chunk)
    dynamic_prompt = generate_dynamic_prompts(spectrogram)
    # Additional processing steps can be added here
    return dynamic_prompt
# Asynchronous audio chunk processing
def process_audio_chunks_async(audio_path, chunk_size, overlap, sample_rate, num_workers=2):
    dataset = AudioDataset(audio_path, chunk_size, overlap, sample_rate)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    def init_models(device_id):
        # Only one codec model is loaded; alternate between the two HuBERT/tokenizer pairs.
        if device_id == 0:
            return model, hubert_model_0, tokenizer_0
        return model, hubert_model_1, tokenizer_1

    futures = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for i, chunk in enumerate(loader):
            device_id = i % 2
            codec_model, hubert_model, tokenizer = init_models(device_id)
            futures.append(executor.submit(process_chunk, chunk, sample_rate, qnn, clip_model, clip_processor, svm_model, svm_scaler, codec_model, hubert_model, tokenizer))
        results = [future.result() for future in futures]
    return results
# FastAPI endpoint for uploading an audio file and cloning the voice
@app.post("/process-and-clone/")
async def process_and_clone(
    file: UploadFile = File(...),
    text_prompt: str = Form(...),
    history_prompt: str = Form(...),
):
    try:
        with open(file.filename, "wb") as f:
            f.write(await file.read())
        audio_filepath = file.filename
        voice_cloning(audio_filepath, clip_model, hubert_model_0, tokenizer_0)
        return {"message": "Voice cloning and audio processing successful"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing and cloning voice: {e}")
# FastAPI endpoint for serving the cloned voice file
@app.get("/get-cloned-voice/")
async def get_cloned_voice():
try:
voice_name = 'output'
output_path = f'bark/assets/prompts/{voice_name}.npz'
return FileResponse(output_path, media_type="application/octet-stream", filename=f"{voice_name}.npz")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting cloned voice: {e}")
# FastAPI endpoint for processing audio chunks asynchronously
@app.post("/process-audio-chunks/")
async def process_audio_chunks(
file: UploadFile = Form(...),
chunk_size: int = Form(...),
overlap: int = Form(...),
sample_rate: int = Form(...),
num_workers: int = Form(...),
):
try:
with open(file.filename, "wb") as f:
f.write(file.file.read())
audio_filepath = file.filename
results = process_audio_chunks_async(audio_filepath, chunk_size, overlap, sample_rate, num_workers)
return {"results": results}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error processing audio chunks: {e}")