Skip to content

Instantly share code, notes, and snippets.

@AbrahamSanders
Created June 17, 2024 21:51
Show Gist options
  • Save AbrahamSanders/6157f1fcf5b1b9b020483a3c87470cae to your computer and use it in GitHub Desktop.
Save AbrahamSanders/6157f1fcf5b1b9b020483a3c87470cae to your computer and use it in GitHub Desktop.
from transformers import EncodecModel, AutoProcessor
from datetime import datetime
from tqdm import tqdm
import soundfile as sf
import librosa
import torch
# Benchmark EnCodec (24 kHz) encode/decode round trips on one audio file under
# several chunking configurations, reporting reconstruction error, output
# length and wall-clock processing time for each configuration.

device = "cuda" if torch.cuda.is_available() else "cpu"
model = EncodecModel.from_pretrained("facebook/encodec_24khz").to(device)
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")

# Load the audio file, resampled to the model's sampling rate and mixed to mono.
audio_file = "4708.mp3"
audio, sr = librosa.load(audio_file, sr=model.config.sampling_rate, mono=True)

# Optional: uncomment to cut the audio to its first 30 seconds.
#audio = audio[:30*sr]

# Each entry is (chunk_length_s, overlap). The first entry is a "warmup" run so
# that the recorded runtime of the subsequent runs is accurate; its result is
# skipped when reporting below.
configs = [(None, None), (None, None), (1., 0.01), (1., 0.4)]
results = []

for chunk_length_s, overlap in tqdm(configs):
    # The chunking settings are applied to both the model config and the
    # processor so that the two agree for this run.
    model.config.chunk_length_s = chunk_length_s
    processor.chunk_length_s = chunk_length_s
    model.config.overlap = overlap
    processor.overlap = overlap

    # Prepare the model inputs (not included in the timed region).
    inputs = processor(raw_audio=audio, sampling_rate=sr, return_tensors="pt").to(device)

    # CUDA kernels are launched asynchronously: drain any pending work before
    # starting the clock so earlier work is not billed to this run.
    if device == "cuda":
        torch.cuda.synchronize()
    start = datetime.now()
    with torch.no_grad():
        # encode
        encoder_outputs = model.encode(**inputs, bandwidth=6.0)
        # decode
        decoder_outputs = model.decode(
            encoder_outputs.audio_codes,
            encoder_outputs.audio_scales,
            inputs["padding_mask"],
        )
        output_audio = decoder_outputs[0][0, 0]
    # Wait for the GPU to actually finish before stopping the clock; without
    # this the measurement may only cover kernel launch time.
    if device == "cuda":
        torch.cuda.synchronize()
    total_time = datetime.now() - start

    # Compare with the original (processor-prepared) audio.
    reference = inputs.input_values[0, 0]
    total_error = torch.nn.functional.l1_loss(reference, output_audio, reduction="sum")
    avg_error = torch.nn.functional.l1_loss(reference, output_audio, reduction="mean")

    # Store plain Python floats rather than device tensors so the results list
    # does not keep tensors alive across runs; the printed values are the same.
    results.append((total_error.item(), avg_error.item(), len(output_audio), total_time))
    sf.write(f"4708_{chunk_length_s}_{overlap}.wav", output_audio.cpu().numpy(), sr)

# Report every run except the first (warmup) one.
for config, result in list(zip(configs, results))[1:]:
    chunk_length_s, overlap = config
    total_error, avg_error, wav_length, total_time = result
    print(f"chunk_length_s: {chunk_length_s}; overlap: {overlap}")
    print(f"Total error: {total_error}")
    print(f"Average error: {avg_error}")
    print(f"Waveform length with padding (samples): {wav_length}")
    print(f"Waveform length with padding (seconds): {(wav_length / sr):.2f}")
    print(f"Processing time (seconds): {total_time.total_seconds():.2f}")
    print()
print("Done")
chunk_length_s: None; overlap: None
Total error: 322234.03125
Average error: 0.007459419313818216
Waveform length with padding (samples): 43198272
Waveform length with padding (seconds): 1799.93
Processing time (seconds): 5.74

chunk_length_s: 1.0; overlap: 0.01
Total error: 322975.4375
Average error: 0.007472878787666559
Waveform length with padding (samples): 43219680
Waveform length with padding (seconds): 1800.82
Processing time (seconds): 18.00

chunk_length_s: 1.0; overlap: 0.4
Total error: 311178.75
Average error: 0.007201611530035734
Waveform length with padding (samples): 43209600
Waveform length with padding (seconds): 1800.40
Processing time (seconds): 29.51

Done
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment