Created September 20, 2024 19:48
An attempt using basic_pitch, mido, and pretty_midi to figure out multiple layered MIDI from audio.
import os
import re

import matplotlib.pyplot as plt
from tqdm import tqdm
import pretty_midi
from basic_pitch.inference import predict_and_save
from basic_pitch import ICASSP_2022_MODEL_PATH
from mido import MidiFile, MetaMessage, bpm2tempo

# Set this flag to True if you want to reprocess all files, even if MIDI files already exist
REPROCESS_ALL = True

def extract_info_from_filename(filename):
    bpm_match = re.search(r'(\d+)Bpm', filename)
    key_match = re.search(r'=(\w+[#b]?m?)', filename)
    bpm = int(bpm_match.group(1)) if bpm_match else None
    key_name = key_match.group(1) if key_match else None
    return bpm, key_name
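
# Illustrative only (made-up filenames, not from the gist): the regexes above expect
# a "<number>Bpm" token and the key after an "=" sign, for example:
#   extract_info_from_filename("Wobble_140Bpm_=F#m.wav")  ->  (140, 'F#m')
#   extract_info_from_filename("Riddim_150Bpm_=Gm.mp3")   ->  (150, 'Gm')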

def create_piano_roll(midi_file, output_image):
    midi_data = pretty_midi.PrettyMIDI(midi_file)
    piano_roll = midi_data.get_piano_roll(fs=100)

    plt.figure(figsize=(12, 6))
    plt.imshow(piano_roll, aspect='auto', origin='lower', cmap='Blues')
    plt.title('Piano Roll Representation')
    plt.ylabel('MIDI Note Number')
    plt.xlabel('Time (centiseconds)')
    plt.colorbar(label='Velocity')
    plt.savefig(output_image, dpi=300, bbox_inches='tight')
    plt.close()

def add_tempo_to_midi(midi_file, bpm):
    mid = MidiFile(midi_file)
    tempo = bpm2tempo(bpm)
    meta = MetaMessage('set_tempo', tempo=tempo, time=0)
    mid.tracks[0].insert(0, meta)
    mid.save(midi_file)
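
# Side note (illustration, not part of the original gist): mido's bpm2tempo converts
# beats per minute into microseconds per quarter note, the unit a set_tempo message
# stores: 60_000_000 / 120 BPM = 500_000, so bpm2tempo(120) == 500000.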

def audio_to_midi_basic_pitch(audio_file):
    base_name = os.path.splitext(audio_file)[0]
    # Basic Pitch names its output "<input name>_basic_pitch.mid"
    output_midi = f"{base_name}_basic_pitch.mid"
    output_image = f"{base_name}.png"

    if os.path.exists(output_midi) and not REPROCESS_ALL:
        print(f"MIDI file {output_midi} already exists. Skipping processing.")
        return output_midi

    print(f"Processing: {audio_file}")

    # Extract BPM and key from the filename (key_name is currently unused)
    bpm, key_name = extract_info_from_filename(os.path.basename(audio_file))

    # Use Basic Pitch to predict and save the MIDI file
    predict_and_save(
        audio_path_list=[audio_file],
        output_directory=os.path.dirname(audio_file),
        save_midi=True,
        sonify_midi=False,
        save_model_outputs=False,
        save_notes=False,
        model_or_model_path=ICASSP_2022_MODEL_PATH
    )

    # Add tempo information to the MIDI file if a BPM was found in the filename
    if bpm:
        add_tempo_to_midi(output_midi, bpm)
        print(f"Added BPM {bpm} to MIDI file.")

    # Create piano roll visualization
    create_piano_roll(output_midi, output_image)

    # print(f"MIDI file saved as: {output_midi}")
    print(f"Piano roll visualization saved as: {output_image}")
    print("-" * 50)

    return output_midi

def process_folder(folder_path):
    audio_files = [f for f in os.listdir(folder_path)
                   if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a'))]
    for audio_file in tqdm(audio_files, desc="Processing audio files", unit="file"):
        full_path = os.path.join(folder_path, audio_file)
        audio_to_midi_basic_pitch(full_path)

# Usage
input_path = r"E:\Dubstep_diffusion\tracks"

if os.path.isdir(input_path):
    process_folder(input_path)
else:
    output_midi = audio_to_midi_basic_pitch(input_path)
    print(f"MIDI file saved as: {output_midi}")