@twobob
Created September 20, 2024 19:48
basic_pitch, mido, and pretty_midi: an attempt to figure out multiple layered MIDI
import os
import re
import matplotlib.pyplot as plt
from tqdm import tqdm
import pretty_midi
from basic_pitch.inference import predict_and_save
from basic_pitch import ICASSP_2022_MODEL_PATH
from mido import MidiFile, MetaMessage, bpm2tempo
# Set this flag to True if you want to reprocess all files, even if MIDI files already exist
REPROCESS_ALL = True
def extract_info_from_filename(filename):
    # Expects filenames containing fragments like "140Bpm" and "=Fm".
    bpm_match = re.search(r'(\d+)Bpm', filename)
    key_match = re.search(r'=(\w+[#b]?m?)', filename)
    bpm = int(bpm_match.group(1)) if bpm_match else None
    key_name = key_match.group(1) if key_match else None
    return bpm, key_name
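# Example (hypothetical filename): extract_info_from_filename("track_140Bpm_=Fm.wav")
# would return (140, 'Fm'), based on the two regexes above.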
def create_piano_roll(midi_file, output_image):
    midi_data = pretty_midi.PrettyMIDI(midi_file)
    piano_roll = midi_data.get_piano_roll(fs=100)
    plt.figure(figsize=(12, 6))
    plt.imshow(piano_roll, aspect='auto', origin='lower', cmap='Blues')
    plt.title('Piano Roll Representation')
    plt.ylabel('MIDI Note Number')
    plt.xlabel('Time (centiseconds)')
    plt.colorbar(label='Velocity')
    plt.savefig(output_image, dpi=300, bbox_inches='tight')
    plt.close()
def add_tempo_to_midi(midi_file, bpm):
    # Insert a set_tempo meta message at the start of the first track.
    mid = MidiFile(midi_file)
    tempo = bpm2tempo(bpm)
    meta = MetaMessage('set_tempo', tempo=tempo, time=0)
    mid.tracks[0].insert(0, meta)
    mid.save(midi_file)
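# The key name parsed from the filename is currently unused. Below is a minimal
# sketch (not called anywhere in this script) for writing it into the MIDI as a
# key_signature meta message via mido, assuming the parsed string already matches
# one of mido's accepted spellings (e.g. 'Fm', 'F#m', 'Bb'); other spellings would
# need to be mapped first.
def add_key_to_midi(midi_file, key_name):
    mid = MidiFile(midi_file)
    mid.tracks[0].insert(0, MetaMessage('key_signature', key=key_name, time=0))
    mid.save(midi_file)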
def audio_to_midi_basic_pitch(audio_file):
    base_name = os.path.splitext(audio_file)[0]
    output_midi = f"{base_name}_basic_pitch.mid"
    output_image = f"{base_name}.png"
    if os.path.exists(output_midi) and not REPROCESS_ALL:
        print(f"MIDI file {output_midi} already exists. Skipping processing.")
        return output_midi
    print(f"Processing: {audio_file}")
    # Extract BPM and key from the filename
    bpm, key_name = extract_info_from_filename(os.path.basename(audio_file))
    # Use Basic Pitch to predict and save the MIDI file
    predict_and_save(
        audio_path_list=[audio_file],
        output_directory=os.path.dirname(audio_file),
        save_midi=True,
        sonify_midi=False,
        save_model_outputs=False,
        save_notes=False,
        model_or_model_path=ICASSP_2022_MODEL_PATH
    )
    # Add tempo information to the MIDI file if a BPM was found in the filename
    if bpm:
        add_tempo_to_midi(output_midi, bpm)
        print(f"Added BPM {bpm} to MIDI file.")
    # Create piano roll visualization
    create_piano_roll(output_midi, output_image)
    # print(f"MIDI file saved as: {output_midi}")
    print(f"Piano roll visualization saved as: {output_image}")
    print("-" * 50)
    return output_midi
def process_folder(folder_path):
    audio_files = [f for f in os.listdir(folder_path) if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a'))]
    for audio_file in tqdm(audio_files, desc="Processing audio files", unit="file"):
        full_path = os.path.join(folder_path, audio_file)
        audio_to_midi_basic_pitch(full_path)
# Usage
input_path = r"E:\Dubstep_diffusion\tracks"
if os.path.isdir(input_path):
    process_folder(input_path)
else:
    output_midi = audio_to_midi_basic_pitch(input_path)
    print(f"MIDI file saved as: {output_midi}")