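# Experimental audio-analysis pipeline (gist sketch):
# splits a WAV file into overlapping chunks, labels each chunk with a CLIP-scored
# spectrogram prompt, and loads a PennyLane quantum layer, a pretrained SVM, and the
# Bark/HuBERT voice-cloning stack for additional (not yet implemented) processing.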
import numpy as np
import torch
import torchaudio
import math
from transformers import CLIPProcessor, CLIPModel
import pennylane as qml
from sklearn.svm import SVC
import joblib
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
from concurrent.futures import ThreadPoolExecutor
from bark.generation import load_codec_model, generate_text_semantic
from encodec.utils import convert_audio
from hubert.hubert_manager import HuBERTManager
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
from bark.api import generate_audio
from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine
# Pretrained CLIP model/processor for prompt matching, plus a pretrained SVM and its
# feature scaler (the .pkl paths are placeholders)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
svm_model = joblib.load('path/to/pretrained_svm_model.pkl')
svm_scaler = joblib.load('path/to/pretrained_svm_scaler.pkl')
def quantum_neural_network(num_qubits=4, num_layers=3):
    dev = qml.device("default.qubit", wires=num_qubits)

    @qml.qnode(dev)
    def qnn_circuit(inputs, weights):
        qml.templates.AngleEmbedding(inputs, wires=range(num_qubits))
        qml.templates.StronglyEntanglingLayers(weights, wires=range(num_qubits))
        return [qml.expval(qml.PauliZ(i)) for i in range(num_qubits)]

    weight_shapes = {"weights": (num_layers, num_qubits, 3)}
    qlayer = qml.qnn.TorchLayer(qnn_circuit, weight_shapes)
    return qlayer
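# Sanity-check sketch (assumed usage, not part of the original gist): the TorchLayer
# maps a 4-dimensional feature vector to 4 Pauli-Z expectation values in [-1, 1].
#   qnn_demo = quantum_neural_network()
#   print(qnn_demo(torch.rand(4)))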
qnn = quantum_neural_network()

model = load_codec_model(use_gpu=True)
hubert_manager = HuBERTManager()
hubert_manager.make_sure_hubert_installed()
hubert_manager.make_sure_tokenizer_installed()
hubert_model = CustomHubert(checkpoint_path='data/models/hubert/hubert.pt').to('cuda')
tokenizer = CustomTokenizer.load_from_checkpoint('data/models/hubert/tokenizer.pth').to('cuda')
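# Note: the Bark codec model, HuBERT encoder, and tokenizer loaded above are not yet
# used in process_chunk below; they appear to be intended for the voice-cloning steps
# hinted at by the remaining bark imports (generate_text_semantic, generate_audio, ...).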
class AudioDataset(Dataset):
    """Yields overlapping chunks of a WAV file, resampled to the target sample rate.
    chunk_size and overlap are given in samples at the target rate."""

    def __init__(self, audio_path, chunk_size, overlap, sample_rate):
        self.wav, self.sr = torchaudio.load(audio_path, normalize=True)
        self.resampler = torchaudio.transforms.Resample(orig_freq=self.sr, new_freq=sample_rate)
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        # Length of the recording expressed in samples at the target rate
        self.total_samples = self.wav.size(1) / self.sr * self.sample_rate
        self.num_chunks = math.ceil((self.total_samples - self.overlap) / (self.chunk_size - self.overlap))

    def __len__(self):
        return self.num_chunks

    def __getitem__(self, idx):
        # Chunk boundaries are defined at the target rate, mapped back to the source
        # rate for slicing, and the slice is resampled on the way out.
        start = idx * (self.chunk_size - self.overlap)
        end = min(start + self.chunk_size, self.total_samples)
        src_start = int(start * self.sr / self.sample_rate)
        src_end = int(end * self.sr / self.sample_rate)
        chunk = self.wav[:, src_start:src_end]
        return self.resampler(chunk)
def is_silent(audio_chunk, sample_rate, silence_threshold=1e-4, min_silence_duration=0.5):
    # Flags a chunk as silent if it contains a run of near-zero samples lasting at
    # least min_silence_duration seconds. Tensors do not carry a sample rate, so it
    # is passed in explicitly.
    samples = audio_chunk.reshape(-1)
    silent_duration = 0.0
    for sample in samples:
        if torch.abs(sample) < silence_threshold:
            silent_duration += 1.0 / sample_rate
            if silent_duration >= min_silence_duration:
                return True
        else:
            silent_duration = 0.0
    return False
def generate_dynamic_prompts(spectrogram):
    # Scores the chunk's spectrogram (treated as an image) against a set of candidate
    # text prompts with CLIP and returns the best-matching prompt.
    categories = ["music", "speech", "environmental sound", "silence"]
    prompts = [f"This is a {category} spectrogram" for category in categories]
    inputs = clip_processor(text=prompts, images=spectrogram.unsqueeze(0), return_tensors="pt", padding=True)
    outputs = clip_model(**inputs)
    # logits_per_image has shape (num_images, num_prompts); pick the best prompt for the single image
    similarities = outputs.logits_per_image.cpu().detach().numpy()
    best_prompt_index = np.argmax(similarities, axis=1)
    return prompts[best_prompt_index[0]]
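# Caveat (assumption, not verified against the original gist): CLIPProcessor expects
# image-like input (PIL images or HxWxC arrays), so the raw magnitude spectrogram
# tensor may need conversion, e.g. to a 3-channel uint8 array, before this call
# succeeds as written.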
def process_chunk(chunk, sample_rate, qnn, clip_model, clip_processor, svm_model, svm_scaler, model, hubert_model, tokenizer):
    if is_silent(chunk, sample_rate):
        return 'Silence'
    spectrogram = torchaudio.transforms.Spectrogram()(chunk)
    dynamic_prompt = generate_dynamic_prompts(spectrogram)
    # Additional processing steps (QNN, SVM, Bark/HuBERT voice cloning) can be added here
    return dynamic_prompt
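def init_models(device_id):
    # The original gist calls init_models() below without defining it; this is a
    # minimal sketch of what it presumably does, assuming one copy of the codec,
    # HuBERT, and tokenizer models per CUDA device.
    device = f'cuda:{device_id}'
    codec_model = load_codec_model(use_gpu=True)
    hubert = CustomHubert(checkpoint_path='data/models/hubert/hubert.pt').to(device)
    tok = CustomTokenizer.load_from_checkpoint('data/models/hubert/tokenizer.pth').to(device)
    return codec_model, hubert, tok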
def process_audio_chunks_async(audio_path, chunk_size, overlap, sample_rate, num_workers=2):
    dataset = AudioDataset(audio_path, chunk_size, overlap, sample_rate)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    # One set of models per GPU; chunks alternate between the two devices
    model_0, hubert_model_0, tokenizer_0 = init_models(0)
    model_1, hubert_model_1, tokenizer_1 = init_models(1)
    futures = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for i, chunk in enumerate(loader):
            device_id = i % 2
            model = model_0 if device_id == 0 else model_1
            hubert_model = hubert_model_0 if device_id == 0 else hubert_model_1
            tokenizer = tokenizer_0 if device_id == 0 else tokenizer_1
            futures.append(executor.submit(process_chunk, chunk, sample_rate, qnn, clip_model, clip_processor, svm_model, svm_scaler, model, hubert_model, tokenizer))
    results = [future.result() for future in futures]
    return results
audio_path = 'path/to/your/audio.wav'
chunk_size = 16000 * 30  # 30-second chunks at 16 kHz
overlap = 16000 * 5      # 5 seconds of overlap between consecutive chunks
sample_rate = 16000
processed_data = process_audio_chunks_async(audio_path, chunk_size, overlap, sample_rate)
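# Each entry in processed_data is either 'Silence' or the best-matching CLIP prompt
# for that chunk; print them for a quick look at the labels.
for i, label in enumerate(processed_data):
    print(f"chunk {i}: {label}")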