Skip to content

Instantly share code, notes, and snippets.

@twobob
Created September 21, 2024 09:01
Show Gist options
  • Save twobob/54c6e370c735576635f77031cd7c617f to your computer and use it in GitHub Desktop.
Save twobob/54c6e370c735576635f77031cd7c617f to your computer and use it in GitHub Desktop.
splits audio into stems, extracts the MIDI from the stems, saves them, merges them into one MIDI per song, and saves that as well
import os
import re
import matplotlib.pyplot as plt
from tqdm import tqdm
import pretty_midi
from basic_pitch.inference import predict_and_save, predict
from basic_pitch import ICASSP_2022_MODEL_PATH
from mido import MidiFile, MetaMessage, bpm2tempo
import demucs.separate
import shlex
import librosa
import subprocess
# Reprocessing flags: set these as needed before running the script.
REPROCESS_ALL = False # Reprocess all files regardless of existing outputs
REPROCESS_ALL_MIDI = True # Reprocess MIDI files even if REPROCESS_ALL is False
def extract_info_from_filename(filename):
    """Parse tempo and key hints embedded in a filename.

    Expects patterns like ``140Bpm`` for tempo and ``=Am`` / ``=F#m`` for key.

    Args:
        filename: bare filename (no directory) to scan.

    Returns:
        tuple: (bpm as int or None, key name as str or None).
    """
    bpm = None
    match = re.search(r'(\d+)Bpm', filename)
    if match:
        bpm = int(match.group(1))

    key_name = None
    match = re.search(r'=(\w+[#b]?m?)', filename)
    if match:
        key_name = match.group(1)

    return bpm, key_name
def create_piano_roll(midi_file, output_image):
    """Render *midi_file* as a piano-roll image and save it to *output_image*.

    Samples the roll at 100 Hz (so the x-axis unit is centiseconds) and
    colors cells by note velocity.
    """
    roll = pretty_midi.PrettyMIDI(midi_file).get_piano_roll(fs=100)
    plt.figure(figsize=(12, 6))
    plt.imshow(roll, aspect='auto', origin='lower', cmap='Blues')
    plt.title('Piano Roll Representation')
    plt.ylabel('MIDI Note Number')
    plt.xlabel('Time (centiseconds)')
    plt.colorbar(label='Velocity')
    plt.savefig(output_image, dpi=300, bbox_inches='tight')
    plt.close()
def add_tempo_to_midi(midi_file, bpm):
    """Write a tempo event for *bpm* at the start of track 0, saving in place."""
    midi = MidiFile(midi_file)
    tempo_event = MetaMessage('set_tempo', tempo=bpm2tempo(bpm), time=0)
    midi.tracks[0].insert(0, tempo_event)
    midi.save(midi_file)
def separate_audio_with_demucs(audio_file, output_dir):
    """Split *audio_file* into stems with the Demucs CLI.

    Demucs writes stems to ``<output_dir>/htdemucs/<track-name>/``. If that
    directory already contains .wav stems and REPROCESS_ALL is off, the
    separation step is skipped.

    Args:
        audio_file: path to the source audio file.
        output_dir: directory Demucs writes its model/track tree under.

    Returns:
        str: path to the directory holding the separated stem .wav files.

    Raises:
        subprocess.CalledProcessError: if the demucs command fails.
    """
    model_name = "htdemucs"  # Change if you use a different Demucs model.
    base_name = os.path.splitext(os.path.basename(audio_file))[0]
    separated_dir = os.path.join(output_dir, model_name, base_name)

    # Reuse existing stems unless a full reprocess was requested.
    if (not REPROCESS_ALL
            and os.path.exists(separated_dir)
            and any(f.endswith('.wav') for f in os.listdir(separated_dir))):
        print(f"Stems already exist in {separated_dir}. Skipping separation.")
        return separated_dir

    # Argument-list form (no shell) is immune to quoting/injection problems
    # with unusual filenames; check=True surfaces demucs failures immediately
    # instead of crashing later on a missing stems directory.
    subprocess.run(["demucs", "--out", output_dir, audio_file], check=True)
    return separated_dir
def process_stem_with_basic_pitch(stem_audio_file, bpm=None):
    """Transcribe a single stem to MIDI using Basic Pitch.

    Args:
        stem_audio_file: path to the stem .wav file.
        bpm: optional tempo; when truthy it is forwarded to the model as
            ``midi_tempo`` (unchanged from the original truthiness check).

    Returns:
        pretty_midi.PrettyMIDI: the transcribed MIDI data.
    """
    # Build the optional kwarg once instead of duplicating the predict() call.
    extra = {"midi_tempo": bpm} if bpm else {}
    _model_output, midi_data, _note_events = predict(
        audio_path=stem_audio_file,
        model_or_model_path=ICASSP_2022_MODEL_PATH,
        **extra,
    )
    return midi_data
def merge_midi_files(midi_data_list):
    """Pool the instruments of several PrettyMIDI objects into one.

    Args:
        midi_data_list: iterable of pretty_midi.PrettyMIDI objects.

    Returns:
        pretty_midi.PrettyMIDI: a new object containing every instrument.
    """
    combined = pretty_midi.PrettyMIDI()
    for source in midi_data_list:
        combined.instruments.extend(source.instruments)
    return combined
def should_skip(file_path, is_midi=False):
    """Decide whether an existing output file can be reused.

    Args:
        file_path (str): Path to the output file.
        is_midi (bool): Whether the file is a MIDI file (the extra
            REPROCESS_ALL_MIDI flag then also applies).

    Returns:
        bool: True when the file exists and no reprocess flag forces a redo.
    """
    if not os.path.exists(file_path):
        return False
    forced = REPROCESS_ALL or (is_midi and REPROCESS_ALL_MIDI)
    return not forced
def audio_to_midi_with_demucs(audio_file, bpm=None):
    """Stem-split *audio_file* with Demucs, transcribe each stem, merge to MIDI.

    Per-stem .mid files are written next to the stems; the merged result is
    written next to the input audio. The drums stem is excluded (unpitched).

    Args:
        audio_file: path to the source audio.
        bpm: optional tempo, forwarded to transcription and stamped into the
            merged MIDI.

    Returns:
        str: path of the merged .mid file.
    """
    root, _ext = os.path.splitext(audio_file)
    output_midi_file = f"{root}.mid"
    if should_skip(output_midi_file, is_midi=True):
        print(f"MIDI file {output_midi_file} already exists. Skipping processing.")
        return output_midi_file

    print(f"Processing with Demucs and Basic Pitch: {audio_file}")
    separated_dir = separate_audio_with_demucs(audio_file, os.path.dirname(audio_file))
    print(f"Separated audio stored in {separated_dir}")

    # Transcribe every pitched stem; drums have no meaningful pitch content.
    stem_files = [
        os.path.join(separated_dir, name)
        for name in os.listdir(separated_dir)
        if name.endswith('.wav') and name != 'drums.wav'
    ]
    midi_data_list = []
    for stem_file in stem_files:
        stem_midi_file = os.path.splitext(stem_file)[0] + '.mid'
        if should_skip(stem_midi_file, is_midi=True):
            print(f"Stem MIDI file {stem_midi_file} already exists. Skipping processing.")
            midi_data = pretty_midi.PrettyMIDI(stem_midi_file)
        else:
            print(f"Processing stem: {stem_file}")
            midi_data = process_stem_with_basic_pitch(stem_file, bpm)
            midi_data.write(stem_midi_file)
            print(f"Stem MIDI file saved as: {stem_midi_file}")
        midi_data_list.append(midi_data)

    merged = merge_midi_files(midi_data_list)
    merged.write(output_midi_file)
    print(f"Merged MIDI file saved as: {output_midi_file}")

    if bpm:
        # PrettyMIDI's write does not carry our tempo, so patch it in via Mido.
        add_tempo_to_midi(output_midi_file, bpm)
        print(f"Added BPM {bpm} to MIDI file.")
    return output_midi_file
def overwrite_rename(src, dst):
    """Move *src* to *dst*, replacing *dst* if it already exists.

    Uses os.replace, which overwrites atomically on every platform. The old
    remove-then-rename sequence had a race window (and os.rename alone fails
    on Windows when the destination exists).

    Args:
        src: path of the file to move.
        dst: destination path (clobbered if present).
    """
    os.replace(src, dst)
def audio_to_midi_basic_pitch(audio_file, bpm=None):
    """Transcribe *audio_file* straight to MIDI with Basic Pitch (no stems).

    Args:
        audio_file: path to the source audio.
        bpm: optional tempo to stamp into the resulting MIDI.

    Returns:
        str: path of the resulting .mid file, written next to the audio.
    """
    root, _ext = os.path.splitext(audio_file)
    output_midi_file = f"{root}.mid"
    if should_skip(output_midi_file, is_midi=True):
        print(f"MIDI file {output_midi_file} already exists. Skipping processing.")
        return output_midi_file

    print(f"Processing with Basic Pitch: {audio_file}")
    predict_and_save(
        audio_path_list=[audio_file],
        output_directory=os.path.dirname(audio_file),
        save_midi=True,
        sonify_midi=False,
        save_model_outputs=False,
        save_notes=False,
        multiple_pitch_bends=False,
        model_or_model_path=ICASSP_2022_MODEL_PATH,
    )
    # Basic Pitch names its output "<name>_basic_pitch.mid".
    generated = f"{root}_basic_pitch.mid"
    if bpm:
        add_tempo_to_midi(generated, bpm)
        print(f"Added BPM {bpm} to MIDI file.")
    # Strip the "_basic_pitch" suffix from the final filename.
    overwrite_rename(generated, output_midi_file)
    print(f"MIDI file renamed and saved as: {output_midi_file}")
    return output_midi_file
def audio_to_midi(audio_file):
    """Convert one audio file to a MIDI file plus a piano-roll PNG.

    Tracks shorter than 10 minutes are stem-split with Demucs first; longer
    ones go straight through Basic Pitch. BPM parsed from the filename (if
    present) is embedded in the MIDI output.

    Args:
        audio_file: path to the source audio file.

    Returns:
        str: path of the resulting .mid file.
    """
    base_name = os.path.splitext(audio_file)[0]
    output_midi_file = f"{base_name}.mid"
    output_image = f"{base_name}.png"
    if should_skip(output_midi_file, is_midi=True):
        print(f"MIDI file {output_midi_file} already exists. Skipping processing.")
        return output_midi_file

    print(f"Processing: {audio_file}")
    # key_name is parsed but currently unused downstream.
    bpm, key_name = extract_info_from_filename(os.path.basename(audio_file))

    # Duration picks the pipeline: Demucs separation is expensive on long files.
    y, sr = librosa.load(audio_file, sr=None, mono=True)
    duration = librosa.get_duration(y=y, sr=sr)
    print(f"Audio duration: {duration:.2f} seconds")

    if duration < 600:
        output_midi_file = audio_to_midi_with_demucs(audio_file, bpm)
    else:
        output_midi_file = audio_to_midi_basic_pitch(audio_file, bpm)

    # Regenerate the piano roll whenever the MIDI was (re)built or the image
    # is missing. (The original also re-checked the REPROCESS flags here, but
    # that branch was unreachable: should_skip() returns True only when both
    # flags are off, so the flag re-check could never fire.)
    if not should_skip(output_midi_file, is_midi=True) or not os.path.exists(output_image):
        create_piano_roll(output_midi_file, output_image)
        print(f"Piano roll visualization saved as: {output_image}")
    else:
        print(f"Piano roll image {output_image} already exists. Skipping visualization.")
    print("-" * 50)
    return output_midi_file
def process_folder(folder_path):
    """Run audio_to_midi over every supported audio file in *folder_path*."""
    supported = ('.wav', '.mp3', '.flac', '.ogg', '.m4a')
    names = [n for n in os.listdir(folder_path) if n.lower().endswith(supported)]
    for name in tqdm(names, desc="Processing audio files", unit="file"):
        audio_to_midi(os.path.join(folder_path, name))
# Usage: point input_path at a folder of tracks, or at a single audio file.
input_path = r"E:\Dubstep_diffusion\tracks"

# Guard execution so importing this module does not kick off processing.
if __name__ == "__main__":
    if os.path.isdir(input_path):
        process_folder(input_path)
    else:
        output_midi = audio_to_midi(input_path)
        print(f"MIDI file saved as: {output_midi}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment