Created
December 7, 2023 04:12
-
-
Save graylan0/9cea4eb107ebc54432bf662d881f8f15 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import math | |
import numpy as np | |
import torch | |
import torchaudio | |
from torch.utils.data import DataLoader, Dataset | |
from scipy.io.wavfile import write as write_wav | |
from concurrent.futures import ThreadPoolExecutor | |
import pennylane as qml | |
from bark.generation import load_codec_model, generate_text_semantic | |
from encodec.utils import convert_audio | |
from hubert.hubert_manager import HuBERTManager | |
from hubert.pre_kmeans_hubert import CustomHubert | |
from hubert.customtokenizer import CustomTokenizer | |
from bark.api import generate_audio | |
from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine | |
def init_models(device_id):
    """Load the full model stack (codec, HuBERT, tokenizer) onto one GPU.

    Args:
        device_id: CUDA device index the models are placed on.

    Returns:
        Tuple of ``(codec_model, hubert_model, tokenizer)`` ready for use
        on ``cuda:{device_id}``.
    """
    torch.cuda.set_device(device_id)
    codec_model = load_codec_model(use_gpu=True)

    # Ensure the HuBERT checkpoint and its tokenizer are present on disk
    # (downloads them on first run).
    manager = HuBERTManager()
    manager.make_sure_hubert_installed()
    manager.make_sure_tokenizer_installed()

    target_device = f'cuda:{device_id}'
    hubert = CustomHubert(checkpoint_path='data/models/hubert/hubert.pt').to(target_device)
    tokenizer = CustomTokenizer.load_from_checkpoint('data/models/hubert/tokenizer.pth').to(target_device)
    return codec_model, hubert, tokenizer
class AudioDataset(Dataset):
    """Dataset yielding overlapping fixed-length chunks of a resampled waveform.

    Args:
        audio_path: Path to the audio file to load.
        chunk_size: Chunk length in samples *at the target sample rate*.
        overlap: Overlap between consecutive chunks, in target-rate samples.
        sample_rate: Target sample rate the waveform is resampled to.
    """

    def __init__(self, audio_path, chunk_size, overlap, sample_rate):
        wav, orig_sr = torchaudio.load(audio_path, normalize=True)
        # BUGFIX: resample the whole waveform once up front. The original
        # code computed chunk offsets in target-rate samples but sliced the
        # *unresampled* tensor, so chunk boundaries (and durations) were
        # wrong whenever the file's native rate differed from sample_rate.
        resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=sample_rate)
        self.wav = resampler(wav)
        self.sr = orig_sr  # kept for reference/debugging
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        total_samples = self.wav.size(1)
        # At least one chunk even for audio shorter than chunk_size.
        self.num_chunks = max(1, math.ceil((total_samples - overlap) / (chunk_size - overlap)))

    def __len__(self):
        return self.num_chunks

    def __getitem__(self, idx):
        # Offsets are now directly valid indices into the resampled tensor.
        start = idx * (self.chunk_size - self.overlap)
        end = min(start + self.chunk_size, self.wav.size(1))
        return self.wav[:, start:end]
def quantum_processor(chunk):
    """Run a fixed 4-qubit entangling circuit seeded by the start of *chunk*.

    Only the first four values of *chunk* are fed into the circuit (cycled
    across wires if fewer are available).

    Returns:
        numpy array of the four Pauli-Z expectation values.
    """
    n_wires = 4
    device = qml.device("default.qubit", wires=n_wires)

    @qml.qnode(device)
    def circuit(inputs):
        # Uniform RY rotation on every wire.
        for wire in range(n_wires):
            qml.RY(np.pi / 4, wires=wire)
        # Ring of CNOTs: 0 -> 1 -> 2 -> 3 -> 0.
        for wire in range(n_wires - 1):
            qml.CNOT(wires=[wire, wire + 1])
        qml.CNOT(wires=[n_wires - 1, 0])
        # Encode the input samples as RX angles.
        for wire in range(n_wires):
            qml.RX(inputs[wire % len(inputs)], wires=wire)
        return [qml.expval(qml.PauliZ(w)) for w in range(n_wires)]

    return np.array(circuit(chunk[:4]))
def quantum_meter(processed_chunk):
    """Re-measure *processed_chunk* through an entangled qubit register.

    One qubit per input value: each value becomes an RY rotation angle, the
    register is entangled with a ring of CNOTs, and the Pauli-Z expectation
    of every wire is returned as a numpy array.
    """
    n_wires = len(processed_chunk)
    device = qml.device("default.qubit", wires=n_wires)

    @qml.qnode(device)
    def meter_circuit(values):
        # State preparation: one RY rotation per input value.
        for wire, angle in enumerate(values):
            qml.RY(angle, wires=wire)
        # Entangle neighbouring wires, closing the ring at the end.
        for wire in range(n_wires - 1):
            qml.CNOT(wires=[wire, wire + 1])
        qml.CNOT(wires=[n_wires - 1, 0])
        return [qml.expval(qml.PauliZ(w)) for w in range(n_wires)]

    return np.array(meter_circuit(processed_chunk))
def process_chunk(chunk, device_id, model, hubert_model, tokenizer):
    """Encode one audio chunk on the given GPU and post-process it.

    Args:
        chunk: Audio tensor for one dataset chunk.
        device_id: CUDA device index to run on.
        model: Codec model (provides ``sample_rate``, ``channels``, ``encode``).
        hubert_model: HuBERT feature extractor.
        tokenizer: Semantic tokenizer applied to the HuBERT vectors.

    Returns:
        Tuple of ``(codec_codes, semantic_tokens, quantum_meter_result)``
        as numpy arrays.
    """
    torch.cuda.set_device(device_id)
    audio = convert_audio(chunk, model.sample_rate, model.channels)

    vectors = hubert_model.forward(audio, input_sample_hz=model.sample_rate)
    semantic_tokens = tokenizer.get_token(vectors)

    with torch.no_grad():
        encoded_frames = model.encode(audio.unsqueeze(0))
    codes = torch.cat([frame[0] for frame in encoded_frames], dim=-1).squeeze()

    code_array = codes.cpu().numpy()
    # Feed the codec codes through both quantum stages.
    meter_result = quantum_meter(quantum_processor(code_array))
    return code_array, semantic_tokens.cpu().numpy(), meter_result
def process_audio_chunks_async(audio_path, chunk_size, overlap, sample_rate, num_workers=2):
    """Chunk an audio file and process the chunks concurrently on two GPUs.

    Args:
        audio_path: Path of the audio file to process.
        chunk_size: Chunk length in target-rate samples.
        overlap: Overlap between consecutive chunks, in target-rate samples.
        sample_rate: Target sample rate for resampling.
        num_workers: Thread-pool size for concurrent chunk processing.

    Returns:
        List of per-chunk results, in chunk order.
    """
    loader = DataLoader(
        AudioDataset(audio_path, chunk_size, overlap, sample_rate),
        batch_size=1,
        shuffle=False,
    )

    # One full model stack per GPU; chunks alternate between the two.
    model_stacks = [init_models(0), init_models(1)]

    futures = []
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        for idx, chunk in enumerate(loader):
            gpu = idx % 2
            codec, hubert, tokenizer = model_stacks[gpu]
            futures.append(pool.submit(process_chunk, chunk, gpu, codec, hubert, tokenizer))
        results = [future.result() for future in futures]
    return results
# Example driver: process a local WAV file in 30-second chunks with
# 5 seconds of overlap at a 16 kHz target rate, alternating across GPUs.
audio_path = 'path/to/your/audio.wav'
chunk_size = 16000 * 30  # 30 s of audio at 16 kHz, in samples
overlap = 16000 * 5  # 5 s overlap between consecutive chunks
sample_rate = 16000  # target sample rate (Hz)
processed_data = process_audio_chunks_async(audio_path, chunk_size, overlap, sample_rate)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment