Skip to content

Instantly share code, notes, and snippets.

@graylan0
Created December 7, 2023 04:12
Show Gist options
  • Save graylan0/9cea4eb107ebc54432bf662d881f8f15 to your computer and use it in GitHub Desktop.
Save graylan0/9cea4eb107ebc54432bf662d881f8f15 to your computer and use it in GitHub Desktop.
import os
import math
import numpy as np
import torch
import torchaudio
from torch.utils.data import DataLoader, Dataset
from scipy.io.wavfile import write as write_wav
from concurrent.futures import ThreadPoolExecutor
import pennylane as qml
from bark.generation import load_codec_model, generate_text_semantic
from encodec.utils import convert_audio
from hubert.hubert_manager import HuBERTManager
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
from bark.api import generate_audio
from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine
def init_models(device_id):
    """Load the full model stack (codec, HuBERT, tokenizer) onto one CUDA device.

    Args:
        device_id: integer CUDA device index to bind the models to.

    Returns:
        Tuple of (codec_model, hubert_model, tokenizer), with the HuBERT model
        and tokenizer moved to ``cuda:<device_id>``.
    """
    torch.cuda.set_device(device_id)
    device = f'cuda:{device_id}'

    codec_model = load_codec_model(use_gpu=True)

    # Ensure the HuBERT checkpoint and tokenizer files exist locally before loading.
    manager = HuBERTManager()
    manager.make_sure_hubert_installed()
    manager.make_sure_tokenizer_installed()

    hubert = CustomHubert(checkpoint_path='data/models/hubert/hubert.pt').to(device)
    tok = CustomTokenizer.load_from_checkpoint('data/models/hubert/tokenizer.pth').to(device)
    return codec_model, hubert, tok
class AudioDataset(Dataset):
    """Serve overlapping, fixed-size chunks of an audio file at a target sample rate.

    Bug fix vs. the original: chunk offsets were computed in *target-rate*
    samples but applied to ``self.wav``, which still held the waveform at its
    *original* rate — so any file whose native rate differed from
    ``sample_rate`` was sliced at the wrong positions. We now resample the
    entire waveform once up front and slice the resampled signal directly.

    Args:
        audio_path: path to an audio file readable by ``torchaudio.load``.
        chunk_size: chunk length in samples (at ``sample_rate``).
        overlap: number of samples shared between consecutive chunks.
        sample_rate: target sample rate every chunk is returned at.
    """

    def __init__(self, audio_path, chunk_size, overlap, sample_rate):
        wav, orig_sr = torchaudio.load(audio_path, normalize=True)
        self.resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=sample_rate)
        # Resample once so indexing below is consistently in target-rate samples.
        self.wav = self.resampler(wav)  # shape: (channels, num_samples)
        self.sr = orig_sr
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        step = chunk_size - overlap
        # At least one chunk even for clips shorter than chunk_size.
        self.num_chunks = max(1, math.ceil((self.wav.size(1) - overlap) / step))

    def __len__(self):
        return self.num_chunks

    def __getitem__(self, idx):
        start = idx * (self.chunk_size - self.overlap)
        # Final chunk is clamped to the end of the waveform (may be short).
        end = min(start + self.chunk_size, self.wav.size(1))
        return self.wav[:, start:end]
def quantum_processor(chunk):
    """Run a fixed 4-qubit circuit driven by the first four values of ``chunk``.

    Every qubit receives a constant RY(pi/4) rotation, the qubits are entangled
    in a CNOT ring (0->1->2->3->0), and the input values drive per-qubit RX
    rotations (cycling through ``inputs`` if fewer than four are given).

    Returns:
        numpy array of the four Pauli-Z expectation values.
    """
    num_qubits = 4
    dev = qml.device("default.qubit", wires=num_qubits)

    @qml.qnode(dev)
    def circuit(inputs):
        for wire in range(num_qubits):
            qml.RY(np.pi / 4, wires=wire)
        # Ring entanglement: chain of CNOTs closed back to qubit 0.
        for wire in range(num_qubits - 1):
            qml.CNOT(wires=[wire, wire + 1])
        qml.CNOT(wires=[num_qubits - 1, 0])
        for wire in range(num_qubits):
            qml.RX(inputs[wire % len(inputs)], wires=wire)
        return [qml.expval(qml.PauliZ(wire)) for wire in range(num_qubits)]

    return np.array(circuit(chunk[:4]))
def quantum_meter(processed_chunk):
    """Re-measure an already-processed chunk through an entangling circuit.

    One qubit per input value: each value sets its qubit's RY angle, the
    qubits are entangled in a CNOT ring, and Pauli-Z expectations are read out.

    Returns:
        numpy array with one expectation value per element of ``processed_chunk``.
    """
    num_qubits = len(processed_chunk)
    dev = qml.device("default.qubit", wires=num_qubits)

    @qml.qnode(dev)
    def meter_circuit(inputs):
        # State preparation: encode each input as a rotation angle.
        for wire, angle in enumerate(inputs):
            qml.RY(angle, wires=wire)
        # Ring entanglement across all qubits.
        for wire in range(num_qubits - 1):
            qml.CNOT(wires=[wire, wire + 1])
        qml.CNOT(wires=[num_qubits - 1, 0])
        return [qml.expval(qml.PauliZ(wire)) for wire in range(num_qubits)]

    return np.array(meter_circuit(processed_chunk))
def process_chunk(chunk, device_id, model, hubert_model, tokenizer):
    """Encode one audio chunk into codec codes, semantic tokens, and quantum metrics.

    NOTE(review): encodec's ``convert_audio`` normally takes
    (wav, sr, target_sr, target_channels); only three arguments are passed
    here — confirm against the ``convert_audio`` implementation in use.

    Returns:
        Tuple of (codes, semantic_tokens, quantum_meter_result) as numpy arrays.
    """
    torch.cuda.set_device(device_id)
    wav = convert_audio(chunk, model.sample_rate, model.channels)

    # HuBERT features -> discrete semantic tokens.
    semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)
    semantic_tokens = tokenizer.get_token(semantic_vectors)

    # EnCodec compression codes (no gradients needed for inference).
    with torch.no_grad():
        encoded_frames = model.encode(wav.unsqueeze(0))
    codes = torch.cat([frame[0] for frame in encoded_frames], dim=-1).squeeze()
    raw_codes = codes.cpu().numpy()

    # Feed the codes through both quantum stages.
    meter_result = quantum_meter(quantum_processor(raw_codes))
    return raw_codes, semantic_tokens.cpu().numpy(), meter_result
def process_audio_chunks_async(audio_path, chunk_size, overlap, sample_rate, num_workers=2):
    """Chunk an audio file and process the chunks across two GPUs in parallel.

    Chunks alternate between ``cuda:0`` and ``cuda:1``. Results come back in
    submission order as a list of (codes, semantic_tokens, meter) tuples.

    Args:
        audio_path: path to the input audio file.
        chunk_size: chunk length in samples.
        overlap: samples shared between consecutive chunks.
        sample_rate: target sample rate for every chunk.
        num_workers: thread-pool size used to overlap the per-chunk work.
    """
    dataset = AudioDataset(audio_path, chunk_size, overlap, sample_rate)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    # One complete model stack per GPU so the two streams never share weights.
    stacks = [init_models(0), init_models(1)]

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [
            executor.submit(process_chunk, chunk, i % 2, *stacks[i % 2])
            for i, chunk in enumerate(loader)
        ]
    return [future.result() for future in futures]
if __name__ == "__main__":
    # Guarded entry point: the original ran on import, which would trigger GPU
    # work (and crash on the placeholder path) whenever this module was imported.
    audio_path = 'path/to/your/audio.wav'
    chunk_size = 16000 * 30  # 30-second chunks at 16 kHz, in samples
    overlap = 16000 * 5      # 5 seconds of overlap between consecutive chunks
    sample_rate = 16000
    processed_data = process_audio_chunks_async(audio_path, chunk_size, overlap, sample_rate)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment