# Audio-to-MIDI batch converter, from a gist by @twobob (created 2024-09-21).
# https://gist.github.com/twobob/73bd328037ec77b7257fd38fc4b9ce45
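"""Batch audio-to-MIDI pipeline.

For each audio file: tag it with YAMNet, separate stems with Demucs
(for tracks under ten minutes), transcribe each stem to MIDI with
Basic Pitch, merge the stem MIDIs, stamp in any BPM found in the
filename, and render a piano-roll PNG of the result.
"""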
import os

# Silence TensorFlow's C++ logging; this must be set before tensorflow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import re
import subprocess

import librosa
import matplotlib.pyplot as plt
import numpy as np
import pretty_midi
import soundfile as sf
import tensorflow as tf
import tensorflow_hub as hub
from basic_pitch import ICASSP_2022_MODEL_PATH
from basic_pitch.inference import predict_and_save, predict
from mido import MidiFile, MetaMessage, bpm2tempo
from tqdm import tqdm

tf.get_logger().setLevel('ERROR')
# Reprocessing flags: flip to True to regenerate outputs that already exist.
REPROCESS_ALL = False
REPROCESS_ALL_MIDI = False
REPROCESS_ALL_TAGS = True

# Load the YAMNet audio-tagging model once, and read its class map
# (a CSV of index, mid, display_name) into a list of display names.
YAMNET_MODEL_HANDLE = 'https://tfhub.dev/google/yamnet/1'
yamnet_model = hub.load(YAMNET_MODEL_HANDLE)
yamnet_classes = []
class_map_path = yamnet_model.class_map_path().numpy().decode('utf-8')
with tf.io.gfile.GFile(class_map_path) as f:
    for line in f.readlines()[1:]:  # skip the CSV header row
        _, _, display_name = line.strip().split(',', 2)
        yamnet_classes.append(display_name)
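
# Pull tempo and key hints out of a filename, e.g. "...140Bpm...=F#m...".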
def extract_info_from_filename(filename):
    bpm_match = re.search(r'(\d+)Bpm', filename)
    key_match = re.search(r'=(\w+[#b]?m?)', filename)
    bpm = int(bpm_match.group(1)) if bpm_match else None
    key_name = key_match.group(1) if key_match else None
    return bpm, key_name
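
# Render a MIDI file as a piano-roll image (note number vs. time, colored by velocity).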
def create_piano_roll(midi_file, output_image):
    midi_data = pretty_midi.PrettyMIDI(midi_file)
    piano_roll = midi_data.get_piano_roll(fs=100)
    plt.figure(figsize=(12, 6))
    plt.imshow(piano_roll, aspect='auto', origin='lower', cmap='Blues')
    plt.title('Piano Roll Representation')
    plt.ylabel('MIDI Note Number')
    plt.xlabel('Time (centiseconds)')
    plt.colorbar(label='Velocity')
    plt.savefig(output_image, dpi=300, bbox_inches='tight')
    plt.close()
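
# Prepend a set_tempo meta message to the first track so players honor the BPM.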
def add_tempo_to_midi(midi_file, bpm):
    mid = MidiFile(midi_file)
    tempo = bpm2tempo(bpm)
    meta = MetaMessage('set_tempo', tempo=tempo, time=0)
    mid.tracks[0].insert(0, meta)
    mid.save(midi_file)
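
# Split a track into stems (drums, bass, vocals, other) with Demucs,
# reusing any stems already on disk unless REPROCESS_ALL is set.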
def separate_audio_with_demucs(audio_file, output_dir):
    model_name = "htdemucs"
    base_name = os.path.splitext(os.path.basename(audio_file))[0]
    separated_dir = os.path.join(output_dir, model_name, base_name)
    if os.path.exists(separated_dir) and any(f.endswith('.wav') for f in os.listdir(separated_dir)) and not REPROCESS_ALL:
        # Stems already exist; skip separation.
        return separated_dir
    # Argument-list form avoids shell quoting problems with odd filenames.
    subprocess.run(["demucs", "--out", output_dir, audio_file], check=True)
    return separated_dir
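
# Transcribe one stem to MIDI with Basic Pitch, passing the tempo when known.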
def process_stem_with_basic_pitch(stem_audio_file, bpm=None):
    if bpm:
        model_output, midi_data, note_events = predict(
            audio_path=stem_audio_file,
            model_or_model_path=ICASSP_2022_MODEL_PATH,
            midi_tempo=bpm
        )
    else:
        model_output, midi_data, note_events = predict(
            audio_path=stem_audio_file,
            model_or_model_path=ICASSP_2022_MODEL_PATH
        )
    return midi_data
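
# Combine per-stem transcriptions into one multi-instrument PrettyMIDI object.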
def merge_midi_files(midi_data_list):
    merged_midi = pretty_midi.PrettyMIDI()
    for midi_data in midi_data_list:
        for instrument in midi_data.instruments:
            merged_midi.instruments.append(instrument)
    return merged_midi
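
# Decide whether an existing output file can be kept, honoring the reprocess flags.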
def should_skip(file_path, file_type='general'):
    if os.path.exists(file_path):
        if file_type == 'midi':
            return not (REPROCESS_ALL or REPROCESS_ALL_MIDI)
        elif file_type == 'tag':
            return not (REPROCESS_ALL or REPROCESS_ALL_TAGS)
        else:
            return not REPROCESS_ALL
    return False
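
# os.rename raises on Windows if the destination exists, so remove it first.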
def overwrite_rename(src, dst):
    if os.path.exists(dst):
        os.remove(dst)
    os.rename(src, dst)
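
# Tag an audio file with YAMNet: write the top-10 class labels and their
# mean scores (averaged over time frames) to a text file.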
def auto_label_audio(audio_file, output_txt_file):
    if os.path.exists(output_txt_file) and not (REPROCESS_ALL or REPROCESS_ALL_TAGS):
        return
    wav_data, sr = sf.read(audio_file)
    if len(wav_data.shape) > 1:
        wav_data = np.mean(wav_data, axis=1)  # downmix to mono
    target_sr = 16000  # YAMNet expects 16 kHz mono float32 input
    if sr != target_sr:
        wav_data = librosa.resample(wav_data, orig_sr=sr, target_sr=target_sr)
    wav_data = wav_data.astype(np.float32)
    scores, embeddings, spectrogram = yamnet_model(wav_data)
    mean_scores = np.mean(scores.numpy(), axis=0)
    top_k = 10
    top_class_indices = np.argsort(mean_scores)[::-1][:top_k]
    top_scores = mean_scores[top_class_indices]
    top_labels = [yamnet_classes[i] for i in top_class_indices]
    with open(output_txt_file, 'w') as f:
        for label, score in zip(top_labels, top_scores):
            f.write(f"{label}: {score:.4f}\n")
def audio_to_midi_with_demucs(audio_file, bpm=None):
    base_name = os.path.splitext(audio_file)[0]
    output_midi_file = f"{base_name}.mid"
    output_txt_file = f"{base_name}.txt"
    auto_label_audio(audio_file, output_txt_file)
    if should_skip(output_midi_file, file_type='midi'):
        return output_midi_file
    separated_dir = separate_audio_with_demucs(audio_file, os.path.dirname(audio_file))
    stem_files = [os.path.join(separated_dir, f) for f in os.listdir(separated_dir) if f.endswith('.wav')]
    midi_data_list = []
    for stem_file in stem_files:
        stem_base = os.path.splitext(stem_file)[0]
        stem_midi_file = f"{stem_base}.mid"
        stem_txt_file = f"{stem_base}.txt"
        auto_label_audio(stem_file, stem_txt_file)
        if should_skip(stem_midi_file, file_type='midi'):
            # Reuse the existing stem MIDI rather than re-transcribing.
            midi_data = pretty_midi.PrettyMIDI(stem_midi_file)
        else:
            midi_data = process_stem_with_basic_pitch(stem_file, bpm)
            midi_data.write(stem_midi_file)
        midi_data_list.append(midi_data)
    merged_midi = merge_midi_files(midi_data_list)
    merged_midi.write(output_midi_file)
    if bpm:
        add_tempo_to_midi(output_midi_file, bpm)
        print(f"Added BPM {bpm} to MIDI file.")
    return output_midi_file
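
# Simpler path: transcribe the whole mix with Basic Pitch, no stem separation.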
def audio_to_midi_basic_pitch(audio_file, bpm=None):
    base_name = os.path.splitext(audio_file)[0]
    output_midi_file = f"{base_name}.mid"
    output_txt_file = f"{base_name}.txt"
    auto_label_audio(audio_file, output_txt_file)
    if should_skip(output_midi_file, file_type='midi'):
        print(f"MIDI file {output_midi_file} already exists. Skipping processing.")
        return output_midi_file
    print(f"Processing with Basic Pitch: {audio_file}")
    predict_and_save(
        audio_path_list=[audio_file],
        output_directory=os.path.dirname(audio_file),
        save_midi=True,
        sonify_midi=False,
        save_model_outputs=False,
        save_notes=False,
        multiple_pitch_bends=False,
        model_or_model_path=ICASSP_2022_MODEL_PATH
    )
    midi_file_generated = f"{base_name}_basic_pitch.mid"
    if bpm:
        add_tempo_to_midi(midi_file_generated, bpm)
        print(f"Added BPM {bpm} to MIDI file.")
    overwrite_rename(midi_file_generated, output_midi_file)
    print(f"MIDI file renamed and saved as: {output_midi_file}")
    return output_midi_file
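
# Entry point for one file: tag it, pick a transcription path by duration,
# and render a piano-roll image of the resulting MIDI.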
def audio_to_midi(audio_file):
    base_name = os.path.splitext(audio_file)[0]
    output_midi_file = f"{base_name}.mid"
    output_image = f"{base_name}.png"
    output_txt_file = f"{base_name}.txt"
    print(f"Processing: {audio_file}")
    auto_label_audio(audio_file, output_txt_file)
    if should_skip(output_midi_file, file_type='midi'):
        return output_midi_file
    bpm, key_name = extract_info_from_filename(os.path.basename(audio_file))
    y, sr = librosa.load(audio_file, sr=None, mono=True)
    duration = librosa.get_duration(y=y, sr=sr)
    # Use Demucs stem separation for shorter tracks; fall back to plain
    # Basic Pitch for anything ten minutes or longer.
    if duration < 600:
        output_midi_file = audio_to_midi_with_demucs(audio_file, bpm)
    else:
        output_midi_file = audio_to_midi_basic_pitch(audio_file, bpm)
    # Render the piano roll when reprocessing is forced or the image is missing.
    if REPROCESS_ALL or REPROCESS_ALL_MIDI or not os.path.exists(output_image):
        create_piano_roll(output_midi_file, output_image)
        print(f"Piano roll visualization saved as: {output_image}")
    return output_midi_file
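
# Transcribe every supported audio file in a folder, with a progress bar.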
def process_folder(folder_path):
    audio_files = [f for f in os.listdir(folder_path) if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a'))]
    for audio_file in tqdm(audio_files, desc="Processing audio files", unit="file"):
        full_path = os.path.join(folder_path, audio_file)
        audio_to_midi(full_path)
if __name__ == "__main__":
    input_path = r"E:\Dubstep_diffusion\tracks"
    if os.path.isdir(input_path):
        process_folder(input_path)
    else:
        output_midi = audio_to_midi(input_path)
        print(f"MIDI file saved as: {output_midi}")