Created
September 21, 2024 09:01
-
-
Save twobob/54c6e370c735576635f77031cd7c617f to your computer and use it in GitHub Desktop.
Splits audio into stems, extracts the MIDI from each stem, saves them, merges them into one MIDI per song, and saves that as well.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import re | |
import matplotlib.pyplot as plt | |
from tqdm import tqdm | |
import pretty_midi | |
from basic_pitch.inference import predict_and_save, predict | |
from basic_pitch import ICASSP_2022_MODEL_PATH | |
from mido import MidiFile, MetaMessage, bpm2tempo | |
import demucs.separate | |
import shlex | |
import librosa | |
import subprocess | |
# Set these flags as needed.
# They control whether existing output files (stems/MIDI/images) are reused
# or regenerated; see should_skip() for how they are consulted.
REPROCESS_ALL = False  # Reprocess all files regardless of existing outputs
REPROCESS_ALL_MIDI = True  # Reprocess MIDI files even if REPROCESS_ALL is False
def extract_info_from_filename(filename): | |
bpm_match = re.search(r'(\d+)Bpm', filename) | |
key_match = re.search(r'=(\w+[#b]?m?)', filename) | |
bpm = int(bpm_match.group(1)) if bpm_match else None | |
key_name = key_match.group(1) if key_match else None | |
return bpm, key_name | |
def create_piano_roll(midi_file, output_image):
    """Render *midi_file* as a piano-roll image and save it to *output_image*.

    The roll is sampled at 100 frames/second (hence the centisecond x-axis).
    """
    roll = pretty_midi.PrettyMIDI(midi_file).get_piano_roll(fs=100)
    plt.figure(figsize=(12, 6))
    plt.imshow(roll, aspect='auto', origin='lower', cmap='Blues')
    plt.title('Piano Roll Representation')
    plt.ylabel('MIDI Note Number')
    plt.xlabel('Time (centiseconds)')
    plt.colorbar(label='Velocity')
    plt.savefig(output_image, dpi=300, bbox_inches='tight')
    plt.close()
def add_tempo_to_midi(midi_file, bpm):
    """Prepend a set_tempo meta message derived from *bpm* to the MIDI file.

    The file is rewritten in place; the tempo event is inserted at the very
    start of track 0 with time=0.
    """
    midi = MidiFile(midi_file)
    tempo_event = MetaMessage('set_tempo', tempo=bpm2tempo(bpm), time=0)
    midi.tracks[0].insert(0, tempo_event)
    midi.save(midi_file)
def separate_audio_with_demucs(audio_file, output_dir):
    """Split *audio_file* into stems with Demucs.

    Returns the directory holding the stem WAVs, which Demucs writes to
    <output_dir>/<model_name>/<track_name>. Separation is skipped when that
    directory already contains WAV files, unless REPROCESS_ALL is set.

    Raises subprocess.CalledProcessError if the demucs command fails.
    """
    model_name = "htdemucs"  # Change if you use a different model
    base_name = os.path.splitext(os.path.basename(audio_file))[0]
    separated_dir = os.path.join(output_dir, model_name, base_name)
    # Reuse existing stems unless a full reprocess was requested.
    if (not REPROCESS_ALL
            and os.path.exists(separated_dir)
            and any(f.endswith('.wav') for f in os.listdir(separated_dir))):
        print(f"Stems already exist in {separated_dir}. Skipping separation.")
        return separated_dir
    # Argument-list form with shell=False avoids shell quoting/injection
    # problems with filenames containing spaces, quotes, or metacharacters
    # (the original built an f-string command and ran it with shell=True).
    # check=True surfaces a failed separation instead of continuing silently.
    subprocess.run(["demucs", "--out", output_dir, audio_file], check=True)
    return separated_dir  # Path to the directory with separated audio files
def process_stem_with_basic_pitch(stem_audio_file, bpm=None):
    """Transcribe a single stem to MIDI with Basic Pitch.

    When *bpm* is truthy it is forwarded to predict() as midi_tempo.
    Returns the PrettyMIDI object produced by the model.
    """
    kwargs = {
        "audio_path": stem_audio_file,
        "model_or_model_path": ICASSP_2022_MODEL_PATH,
    }
    if bpm:
        kwargs["midi_tempo"] = bpm
    _model_output, midi_data, _note_events = predict(**kwargs)
    return midi_data
def merge_midi_files(midi_data_list):
    """Combine the instruments of several PrettyMIDI objects into one.

    Instruments are appended as-is; no note merging or re-timing is done.
    """
    merged = pretty_midi.PrettyMIDI()
    merged.instruments.extend(
        instrument
        for midi_data in midi_data_list
        for instrument in midi_data.instruments
    )
    return merged
def should_skip(file_path, is_midi=False):
    """Decide whether an existing output file can be reused.

    Args:
        file_path (str): Path to the file.
        is_midi (bool): Whether the file is a MIDI file.

    Returns:
        bool: True to skip regeneration, False to (re)process. A missing
        file is never skipped. An existing MIDI file is skipped unless
        REPROCESS_ALL or REPROCESS_ALL_MIDI is set; any other existing
        file is skipped unless REPROCESS_ALL is set.
    """
    if not os.path.exists(file_path):
        return False
    if is_midi:
        reprocess = REPROCESS_ALL or REPROCESS_ALL_MIDI
    else:
        reprocess = REPROCESS_ALL
    return not reprocess
def audio_to_midi_with_demucs(audio_file, bpm=None):
    """Transcribe *audio_file* to MIDI via Demucs stem separation.

    Each non-drum stem is transcribed with Basic Pitch and saved as its own
    .mid next to the stem; the per-stem results are then merged into a single
    <track>.mid beside the input audio. If *bpm* is given, a tempo event is
    stamped into the merged file. Returns the merged MIDI path.
    """
    root, _ext = os.path.splitext(audio_file)
    output_midi_file = root + ".mid"
    if should_skip(output_midi_file, is_midi=True):
        print(f"MIDI file {output_midi_file} already exists. Skipping processing.")
        return output_midi_file

    print(f"Processing with Demucs and Basic Pitch: {audio_file}")
    # Separate the audio into stems using Demucs.
    separated_dir = separate_audio_with_demucs(audio_file, os.path.dirname(audio_file))
    print(f"Separated audio stored in {separated_dir}")

    # Collect stems to transcribe; the drum stem is deliberately excluded.
    stems = [
        os.path.join(separated_dir, name)
        for name in os.listdir(separated_dir)
        if name.endswith('.wav') and name != 'drums.wav'
    ]

    midi_parts = []
    for stem in stems:
        stem_midi = os.path.splitext(stem)[0] + '.mid'
        if should_skip(stem_midi, is_midi=True):
            print(f"Stem MIDI file {stem_midi} already exists. Skipping processing.")
            part = pretty_midi.PrettyMIDI(stem_midi)
        else:
            print(f"Processing stem: {stem}")
            part = process_stem_with_basic_pitch(stem, bpm)
            part.write(stem_midi)
            print(f"Stem MIDI file saved as: {stem_midi}")
        midi_parts.append(part)

    # Merge the MIDI data from all stems and save with PrettyMIDI.
    merged = merge_midi_files(midi_parts)
    merged.write(output_midi_file)
    print(f"Merged MIDI file saved as: {output_midi_file}")

    if bpm:
        # Stamp the known tempo so downstream tools read the right BPM.
        add_tempo_to_midi(output_midi_file, bpm)
        print(f"Added BPM {bpm} to MIDI file.")
    return output_midi_file
def overwrite_rename(src, dst):
    """Move *src* to *dst*, replacing *dst* if it already exists.

    Uses os.replace, which overwrites the destination atomically on both
    POSIX and Windows. The original remove-then-rename sequence was not
    atomic (a crash between the two calls loses *dst* without installing
    *src*) and was only needed because plain os.rename fails on Windows
    when the destination exists.
    """
    os.replace(src, dst)
def audio_to_midi_basic_pitch(audio_file, bpm=None):
    """Transcribe *audio_file* directly to MIDI with Basic Pitch (no stems).

    Saves the result as <track>.mid beside the input, optionally stamping a
    tempo event when *bpm* is given. Returns the MIDI path.
    """
    root, _ext = os.path.splitext(audio_file)
    output_midi_file = root + ".mid"
    if should_skip(output_midi_file, is_midi=True):
        print(f"MIDI file {output_midi_file} already exists. Skipping processing.")
        return output_midi_file

    print(f"Processing with Basic Pitch: {audio_file}")
    # Run Basic Pitch end-to-end: it writes the MIDI file itself.
    predict_and_save(
        audio_path_list=[audio_file],
        output_directory=os.path.dirname(audio_file),
        save_midi=True,
        sonify_midi=False,
        save_model_outputs=False,
        save_notes=False,
        multiple_pitch_bends=False,
        model_or_model_path=ICASSP_2022_MODEL_PATH
    )
    # Basic Pitch names its output "<track>_basic_pitch.mid".
    generated = root + "_basic_pitch.mid"
    if bpm:
        add_tempo_to_midi(generated, bpm)
        print(f"Added BPM {bpm} to MIDI file.")
    # Drop the "_basic_pitch" suffix so downstream code finds <track>.mid.
    overwrite_rename(generated, output_midi_file)
    print(f"MIDI file renamed and saved as: {output_midi_file}")
    return output_midi_file
def audio_to_midi(audio_file):
    """Convert one audio file to MIDI plus a piano-roll PNG.

    BPM (and key, currently unused) are parsed from the filename. Files
    shorter than ten minutes go through Demucs stem separation; longer ones
    are transcribed whole with Basic Pitch. Returns the MIDI path.
    """
    root, _ext = os.path.splitext(audio_file)
    output_midi_file = root + ".mid"
    output_image = root + ".png"
    if should_skip(output_midi_file, is_midi=True):
        print(f"MIDI file {output_midi_file} already exists. Skipping processing.")
        return output_midi_file

    print(f"Processing: {audio_file}")
    # Extract BPM and key from filename.
    bpm, key_name = extract_info_from_filename(os.path.basename(audio_file))

    # Duration decides the transcription strategy.
    y, sr = librosa.load(audio_file, sr=None, mono=True)
    duration = librosa.get_duration(y=y, sr=sr)
    print(f"Audio duration: {duration:.2f} seconds")
    if duration < 600:
        output_midi_file = audio_to_midi_with_demucs(audio_file, bpm)
    else:
        output_midi_file = audio_to_midi_basic_pitch(audio_file, bpm)

    # Regenerate the piano-roll image unless the MIDI was reusable (i.e. no
    # reprocess flags are set) AND the image already exists on disk. (In the
    # original nested branches, the reprocess-flag check inside the skip
    # branch could never fire — should_skip returning True already implies
    # both flags are False — so this two-way form is equivalent.)
    if should_skip(output_midi_file, is_midi=True) and os.path.exists(output_image):
        print(f"Piano roll image {output_image} already exists. Skipping visualization.")
    else:
        create_piano_roll(output_midi_file, output_image)
        print(f"Piano roll visualization saved as: {output_image}")
    print("-" * 50)
    return output_midi_file
def process_folder(folder_path):
    """Run audio_to_midi over every supported audio file in *folder_path*."""
    audio_exts = ('.wav', '.mp3', '.flac', '.ogg', '.m4a')
    names = [n for n in os.listdir(folder_path) if n.lower().endswith(audio_exts)]
    for name in tqdm(names, desc="Processing audio files", unit="file"):
        audio_to_midi(os.path.join(folder_path, name))
# Usage: point input_path at a folder of audio files, or at a single file.
input_path = r"E:\Dubstep_diffusion\tracks"

# Guarded entry point so importing this module (e.g. to reuse the helper
# functions) does not immediately kick off processing.
if __name__ == "__main__":
    if os.path.isdir(input_path):
        process_folder(input_path)
    else:
        output_midi = audio_to_midi(input_path)
        print(f"MIDI file saved as: {output_midi}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment