Created
September 21, 2024 17:30
-
-
Save twobob/73bd328037ec77b7257fd38fc4b9ce45 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv
import os
import re
import subprocess

import librosa
import matplotlib.pyplot as plt
import numpy as np
import pretty_midi
import soundfile as sf
import tensorflow as tf
import tensorflow_hub as hub
from basic_pitch import ICASSP_2022_MODEL_PATH
from basic_pitch.inference import predict_and_save, predict
from mido import MidiFile, MetaMessage, bpm2tempo
from tqdm import tqdm
import csv

# Silence TensorFlow's C++ and Python loggers before any model loading.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.get_logger().setLevel('ERROR')

# Reprocessing switches: when False, existing output files are reused.
REPROCESS_ALL = False        # force redo of everything (stems, MIDI, tags)
REPROCESS_ALL_MIDI = False   # force redo of MIDI transcription only
REPROCESS_ALL_TAGS = True    # force redo of YAMNet label files only

# Load the YAMNet audio-tagging model from TF Hub (downloads on first run).
YAMNET_MODEL_HANDLE = 'https://tfhub.dev/google/yamnet/1'
yamnet_model = hub.load(YAMNET_MODEL_HANDLE)

# Read class display names from YAMNet's bundled class-map CSV.
# Use csv.reader rather than str.split: display names that contain commas
# (e.g. "Speech, child") are quoted in the CSV, and a naive split(',', 2)
# would keep the quote characters and truncate the name.
yamnet_classes = []
class_map_path = yamnet_model.class_map_path().numpy().decode('utf-8')
with tf.io.gfile.GFile(class_map_path) as f:
    reader = csv.reader(f)
    next(reader)  # skip the header row
    for _, _, display_name in reader:
        yamnet_classes.append(display_name)
def extract_info_from_filename(filename):
    """Extract tempo and key hints embedded in a filename.

    Looks for a "<digits>Bpm" token (any capitalization, e.g. "140Bpm",
    "128BPM", "90bpm") and an "=<key>" token (e.g. "=Cm", "=F#m", "=Db").

    Args:
        filename: Base name of the audio file.

    Returns:
        Tuple (bpm, key_name); each element is None when not found.
    """
    # IGNORECASE generalizes the original, which only matched "Bpm" exactly.
    bpm_match = re.search(r'(\d+)bpm', filename, re.IGNORECASE)
    key_match = re.search(r'=(\w+[#b]?m?)', filename)
    bpm = int(bpm_match.group(1)) if bpm_match else None
    key_name = key_match.group(1) if key_match else None
    return bpm, key_name
def create_piano_roll(midi_file, output_image):
    """Render a MIDI file as a piano-roll image and save it as a PNG.

    Args:
        midi_file: Path to the MIDI file to visualize.
        output_image: Path of the PNG file to write.
    """
    # 100 frames per second -> x axis is in centiseconds.
    roll = pretty_midi.PrettyMIDI(midi_file).get_piano_roll(fs=100)
    plt.figure(figsize=(12, 6))
    plt.imshow(roll, aspect='auto', origin='lower', cmap='Blues')
    plt.title('Piano Roll Representation')
    plt.ylabel('MIDI Note Number')
    plt.xlabel('Time (centiseconds)')
    plt.colorbar(label='Velocity')
    plt.savefig(output_image, dpi=300, bbox_inches='tight')
    plt.close()
def add_tempo_to_midi(midi_file, bpm):
    """Insert a set_tempo meta message at the start of a MIDI file, in place.

    Args:
        midi_file: Path to the MIDI file (overwritten on save).
        bpm: Tempo in beats per minute to embed.
    """
    midi = MidiFile(midi_file)
    tempo_message = MetaMessage('set_tempo', tempo=bpm2tempo(bpm), time=0)
    midi.tracks[0].insert(0, tempo_message)
    midi.save(midi_file)
def separate_audio_with_demucs(audio_file, output_dir):
    """Split an audio file into stems with Demucs and return the stem dir.

    Demucs writes stems to <output_dir>/htdemucs/<track_name>/. If that
    directory already contains .wav stems and REPROCESS_ALL is False, the
    separation is skipped.

    Args:
        audio_file: Path to the source audio file.
        output_dir: Directory Demucs should write its model folder into.

    Returns:
        Path of the directory holding the separated stem .wav files.
    """
    model_name = "htdemucs"
    base_name = os.path.splitext(os.path.basename(audio_file))[0]
    separated_dir = os.path.join(output_dir, model_name, base_name)
    if (os.path.exists(separated_dir)
            and any(f.endswith('.wav') for f in os.listdir(separated_dir))
            and not REPROCESS_ALL):
        return separated_dir
    # Pass an argument list with shell=False: the original f-string +
    # shell=True broke on filenames containing quotes and was susceptible
    # to shell injection through the path.
    subprocess.run(["demucs", "--out", output_dir, audio_file])
    return separated_dir
def process_stem_with_basic_pitch(stem_audio_file, bpm=None):
    """Transcribe a single stem to MIDI using Basic Pitch.

    Args:
        stem_audio_file: Path to the stem .wav file.
        bpm: Optional tempo; when truthy it is forwarded as midi_tempo.

    Returns:
        The PrettyMIDI object produced by basic_pitch.inference.predict.
    """
    kwargs = {
        'audio_path': stem_audio_file,
        'model_or_model_path': ICASSP_2022_MODEL_PATH,
    }
    if bpm:
        kwargs['midi_tempo'] = bpm
    _, midi_data, _ = predict(**kwargs)
    return midi_data
def merge_midi_files(midi_data_list):
    """Combine the instruments of several PrettyMIDI objects into one.

    Args:
        midi_data_list: Iterable of PrettyMIDI objects (one per stem).

    Returns:
        A new PrettyMIDI containing every instrument, in input order.
    """
    merged = pretty_midi.PrettyMIDI()
    merged.instruments.extend(
        instrument
        for midi_data in midi_data_list
        for instrument in midi_data.instruments
    )
    return merged
def should_skip(file_path, file_type='general'):
    """Decide whether an existing output file can be reused.

    Args:
        file_path: Output file to check.
        file_type: 'midi', 'tag', or 'general' — selects which reprocess
            flag(s) apply on top of REPROCESS_ALL.

    Returns:
        True when the file exists and no applicable reprocess flag is set.
    """
    if not os.path.exists(file_path):
        return False
    if file_type == 'midi':
        reprocess = REPROCESS_ALL or REPROCESS_ALL_MIDI
    elif file_type == 'tag':
        reprocess = REPROCESS_ALL or REPROCESS_ALL_TAGS
    else:
        reprocess = REPROCESS_ALL
    return not reprocess
def overwrite_rename(src, dst):
    """Move src to dst, replacing dst if it already exists.

    Uses os.replace, which overwrites atomically and works on Windows —
    the original remove-then-rename pair was non-atomic (a crash between
    the two calls loses dst) and racy.

    Args:
        src: Existing file to move.
        dst: Destination path; any existing file there is replaced.
    """
    os.replace(src, dst)
def auto_label_audio(audio_file, output_txt_file):
    """Tag an audio file with YAMNet and write the top-10 labels to a text file.

    Skips work when the label file exists and no reprocess flag is set.
    Output format: one "label: score" line per class, best first.

    Args:
        audio_file: Path to the audio file to classify.
        output_txt_file: Path of the label file to write.
    """
    if os.path.exists(output_txt_file) and not (REPROCESS_ALL or REPROCESS_ALL_TAGS):
        return
    waveform, sample_rate = sf.read(audio_file)
    # Downmix multi-channel audio to mono by averaging channels.
    if len(waveform.shape) > 1:
        waveform = np.mean(waveform, axis=1)
    # YAMNet expects 16 kHz mono float32 input.
    target_sr = 16000
    if sample_rate != target_sr:
        waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=target_sr)
    waveform = waveform.astype(np.float32)
    scores, _, _ = yamnet_model(waveform)
    # Average frame-level scores over time, then take the 10 best classes.
    mean_scores = np.mean(scores.numpy(), axis=0)
    ranked_indices = np.argsort(mean_scores)[::-1][:10]
    with open(output_txt_file, 'w') as f:
        for index in ranked_indices:
            f.write(f"{yamnet_classes[index]}: {mean_scores[index]:.4f}\n")
def audio_to_midi_with_demucs(audio_file, bpm=None):
    """Convert audio to MIDI via Demucs stem separation plus Basic Pitch.

    Each stem is tagged and transcribed individually (reusing existing
    stem MIDI files when allowed), then all stem MIDIs are merged into a
    single file next to the source audio.

    Args:
        audio_file: Path to the source audio file.
        bpm: Optional tempo embedded in the merged MIDI when provided.

    Returns:
        Path of the merged MIDI file.
    """
    base_name = os.path.splitext(audio_file)[0]
    output_midi_file = f"{base_name}.mid"
    auto_label_audio(audio_file, f"{base_name}.txt")
    if should_skip(output_midi_file, file_type='midi'):
        return output_midi_file
    separated_dir = separate_audio_with_demucs(audio_file, os.path.dirname(audio_file))
    stem_paths = [
        os.path.join(separated_dir, name)
        for name in os.listdir(separated_dir)
        if name.endswith('.wav')
    ]
    midi_parts = []
    for stem_path in stem_paths:
        stem_base = os.path.splitext(stem_path)[0]
        stem_midi_file = f"{stem_base}.mid"
        auto_label_audio(stem_path, f"{stem_base}.txt")
        if should_skip(stem_midi_file, file_type='midi'):
            # Reuse the previously transcribed stem MIDI.
            stem_midi = pretty_midi.PrettyMIDI(stem_midi_file)
        else:
            stem_midi = process_stem_with_basic_pitch(stem_path, bpm)
            stem_midi.write(stem_midi_file)
        midi_parts.append(stem_midi)
    merged = merge_midi_files(midi_parts)
    merged.write(output_midi_file)
    if bpm:
        add_tempo_to_midi(output_midi_file, bpm)
        print(f"Added BPM {bpm} to MIDI file.")
    return output_midi_file
def audio_to_midi_basic_pitch(audio_file, bpm=None):
    """Convert a whole audio file to MIDI with Basic Pitch (no stem split).

    Args:
        audio_file: Path to the source audio file.
        bpm: Optional tempo embedded in the output MIDI when provided.

    Returns:
        Path of the MIDI file written next to the source audio.
    """
    base_name = os.path.splitext(audio_file)[0]
    output_midi_file = f"{base_name}.mid"
    auto_label_audio(audio_file, f"{base_name}.txt")
    if should_skip(output_midi_file, file_type='midi'):
        print(f"MIDI file {output_midi_file} already exists. Skipping processing.")
        return output_midi_file
    print(f"Processing with Basic Pitch: {audio_file}")
    predict_and_save(
        audio_path_list=[audio_file],
        output_directory=os.path.dirname(audio_file),
        save_midi=True,
        sonify_midi=False,
        save_model_outputs=False,
        save_notes=False,
        multiple_pitch_bends=False,
        model_or_model_path=ICASSP_2022_MODEL_PATH
    )
    # Basic Pitch writes "<name>_basic_pitch.mid"; normalize it to "<name>.mid".
    generated_midi = f"{base_name}_basic_pitch.mid"
    if bpm:
        add_tempo_to_midi(generated_midi, bpm)
        print(f"Added BPM {bpm} to MIDI file.")
    overwrite_rename(generated_midi, output_midi_file)
    print(f"MIDI file renamed and saved as: {output_midi_file}")
    return output_midi_file
def audio_to_midi(audio_file):
    """Convert one audio file to MIDI and render its piano-roll image.

    Tags the audio with YAMNet, picks a conversion backend by duration
    (Demucs stem separation for tracks under 10 minutes, plain Basic
    Pitch otherwise), and renders a piano-roll PNG when the MIDI was
    (re)generated or the image is missing.

    Args:
        audio_file: Path to the source audio file.

    Returns:
        Path of the resulting MIDI file.
    """
    base_name = os.path.splitext(audio_file)[0]
    output_midi_file = f"{base_name}.mid"
    output_image = f"{base_name}.png"
    output_txt_file = f"{base_name}.txt"
    print(f"Processing: {audio_file}")
    auto_label_audio(audio_file, output_txt_file)
    if should_skip(output_midi_file, file_type='midi'):
        return output_midi_file
    # key_name is extracted but unused by the conversion backends.
    bpm, _ = extract_info_from_filename(os.path.basename(audio_file))
    y, sr = librosa.load(audio_file, sr=None, mono=True)
    duration = librosa.get_duration(y=y, sr=sr)
    # Demucs separation is expensive; reserve it for tracks under 10 minutes.
    if duration < 600:
        output_midi_file = audio_to_midi_with_demucs(audio_file, bpm)
    else:
        output_midi_file = audio_to_midi_basic_pitch(audio_file, bpm)
    # Render when the MIDI was freshly produced or the image doesn't exist.
    # The original also had an `if REPROCESS_ALL or REPROCESS_ALL_MIDI`
    # branch inside the skip case, but should_skip() returns False whenever
    # either flag is set, so that branch was unreachable and is removed.
    if not should_skip(output_midi_file, file_type='midi') or not os.path.exists(output_image):
        create_piano_roll(output_midi_file, output_image)
    return output_midi_file
def process_folder(folder_path):
    """Run audio_to_midi on every supported audio file in a folder.

    Args:
        folder_path: Directory scanned (non-recursively) for audio files.
    """
    supported = ('.wav', '.mp3', '.flac', '.ogg', '.m4a')
    audio_files = [
        name for name in os.listdir(folder_path)
        if name.lower().endswith(supported)
    ]
    for name in tqdm(audio_files, desc="Processing audio files", unit="file"):
        audio_to_midi(os.path.join(folder_path, name))
def main():
    """Script entry point: process a folder of tracks or a single file."""
    # Hard-coded input location; point this at a folder or a single audio file.
    input_path = r"E:\Dubstep_diffusion\tracks"
    if os.path.isdir(input_path):
        process_folder(input_path)
    else:
        output_midi = audio_to_midi(input_path)
        print(f"MIDI file saved as: {output_midi}")


# Guarded entry point so importing this module doesn't kick off processing.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment