from transformers import EncodecModel, AutoProcessor
from datetime import datetime
from tqdm import tqdm
import soundfile as sf
import librosa
import torch
# Run on the GPU when one is available; every tensor below is moved to `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Single source of truth for the checkpoint so the model and processor always match.
model_id = "facebook/encodec_24khz"
model = EncodecModel.from_pretrained(model_id).to(device)
processor = AutoProcessor.from_pretrained(model_id)

# Load the audio file, resampled to the model's sampling rate and downmixed to mono.
audio_file = "4708.mp3"
audio, sr = librosa.load(audio_file, sr=model.config.sampling_rate, mono=True)

# Optional truncation (in seconds) for quicker experiments, e.g. 30.
# None keeps the whole file.
max_duration_s = None
if max_duration_s is not None:
    audio = audio[: int(max_duration_s * sr)]
# Each entry is (chunk_length_s, overlap); (None, None) disables chunking.
# The first entry is a "warmup" run so that the recorded runtime is accurate
# for the subsequent runs; its result is collected but not reported.
configs = [(None, None), (None, None), (1., 0.01), (1., 0.4)]
results = []
# Output files are named after the input file, e.g. "4708.mp3" -> "4708_1.0_0.4.wav".
output_stem = audio_file.rsplit(".", 1)[0]
for chunk_length_s, overlap in tqdm(configs):
    # The model and the processor are both configured with the same chunking settings.
    model.config.chunk_length_s = chunk_length_s
    processor.chunk_length_s = chunk_length_s
    model.config.overlap = overlap
    processor.overlap = overlap

    # Prepare the model inputs (input_values + padding_mask) on the target device.
    inputs = processor(raw_audio=audio, sampling_rate=sr, return_tensors="pt").to(device)

    start = datetime.now()
    with torch.no_grad():
        # Encode to discrete codes at 6 kbps, then decode back to a waveform.
        encoder_outputs = model.encode(**inputs, bandwidth=6.0)
        decoder_outputs = model.decode(
            encoder_outputs.audio_codes,
            encoder_outputs.audio_scales,
            inputs["padding_mask"],
        )
    output_audio = decoder_outputs[0][0, 0]
    # CUDA kernels are launched asynchronously; wait for them to finish so the
    # elapsed time measures the actual encode/decode work.
    if device == "cuda":
        torch.cuda.synchronize()
    total_time = datetime.now() - start

    # Compare the reconstruction with the waveform that was fed to the model.
    reference = inputs.input_values[0, 0]
    total_error = torch.nn.functional.l1_loss(reference, output_audio, reduction="sum").item()
    avg_error = torch.nn.functional.l1_loss(reference, output_audio, reduction="mean").item()
    results.append((total_error, avg_error, len(output_audio), total_time))
    sf.write(f"{output_stem}_{chunk_length_s}_{overlap}.wav", output_audio.cpu().numpy(), sr)
# Report every run except the first, which was only a warmup.
for config, result in zip(configs[1:], results[1:]):
    chunk_length_s, overlap = config
    total_error, avg_error, wav_length, total_time = result
    print(f"chunk_length_s: {chunk_length_s}; overlap: {overlap}")
    print(f"Total error: {total_error}")
    print(f"Average error: {avg_error}")
    print(f"Waveform length with padding (samples): {wav_length}")
    print(f"Waveform length with padding (seconds): {(wav_length / sr):.2f}")
    # total_time is a datetime.timedelta (difference of two datetime.now() calls).
    print(f"Processing time (seconds): {total_time.total_seconds():.2f}")
    # Blank line between runs.
    print()
# Example output (the warmup run is not reported):
#
# chunk_length_s: None; overlap: None
# Total error: 322234.03125
# Average error: 0.007459419313818216
# Waveform length with padding (samples): 43198272
# Waveform length with padding (seconds): 1799.93
# Processing time (seconds): 5.74
#
# chunk_length_s: 1.0; overlap: 0.01
# Total error: 322975.4375
# Average error: 0.007472878787666559
# Waveform length with padding (samples): 43219680
# Waveform length with padding (seconds): 1800.82
# Processing time (seconds): 18.00
#
# chunk_length_s: 1.0; overlap: 0.4
# Total error: 311178.75
# Average error: 0.007201611530035734
# Waveform length with padding (samples): 43209600
# Waveform length with padding (seconds): 1800.40
# Processing time (seconds): 29.51

