Created
April 7, 2023 09:59
-
-
Save AlexApps99/974d45d71f94a9187209cf4e31780a21 to your computer and use it in GitHub Desktop.
Convert a .WAV music file into MIDI using a bunch of numbers and stuff
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
Converts a .WAV music file into MIDI using a bunch of numbers and stuff | |
The gist of the program is: | |
- Load a mono WAV file | |
- Move a rolling window along the audio | |
- Multiply that window by a Hann ("hanning") taper window (optional; it weights the samples with a bell-shaped curve so the middle of the window has the most influence)
- Run an FFT on that window | |
- Set the velocity of each piano note to the average magnitude of its closest frequencies | |
- Copy the piano notes from each window into a MIDI file, holding down or skipping notes where needed | |
To run it you need numpy and matplotlib, plus mido (a Python MIDI library), which can be installed with `pip install mido`.
Place a mono 16-bit WAV file at "in.wav"; the resulting MIDI file is saved to "out.mid".
''' | |
import wave | |
import numpy as np | |
from mido import MetaMessage, Message, MidiFile, MidiTrack | |
def hz2mid(hz):
    '''
    Maps a frequency in Hz to a (fractional) MIDI note number, with 440 Hz -> note 69.

    A tiny epsilon (1e-300) is added inside the logarithm so that 0 Hz gives a very
    large negative but finite note number instead of -inf.
    Works on scalars and NumPy arrays alike.
    '''
    octaves_from_a440 = np.log2(hz / 440.0 + 1e-300)
    return octaves_from_a440 * 12.0 + 69.0
def get_audio_data(path):
    '''
    Loads a mono 16-bit PCM WAV file.

    Parameters:
        path: filesystem path of the WAV file.

    Returns:
        (samples, framerate): samples is a float64 array normalised to [-1, 1)
        (int16 value / 32768); framerate is the sample rate in Hz.

    Raises:
        ValueError: if the file is not mono or not 16-bit.
    '''
    with wave.open(path, 'rb') as f:
        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        if f.getnchannels() != 1:
            raise ValueError("WAV must be mono")
        if f.getsampwidth() != 2:
            raise ValueError("WAV must be 16-bit")
        framerate = f.getframerate()
        raw = f.readframes(f.getnframes())
    # Interpret the bytes as little-endian int16 and normalise to [-1, 1)
    samples = np.frombuffer(raw, dtype='<i2').astype(np.float64) / 32768
    return samples, framerate
def fft_rolling_windows(frames, framerate, steps_per_sec, wins_per_sec):
    '''
    Computes the real FFT of Hann-tapered windows slid along the audio.

    The audio is zero-padded by about half a window on each side, so window k is
    roughly centred on sample k * step_size of the original audio.

    Parameters:
        frames: 1-D array of audio samples.
        framerate: sample rate in Hz.
        steps_per_sec: window start positions per second
            (hop = framerate // steps_per_sec samples).
        wins_per_sec: inverse window duration
            (window length = framerate // wins_per_sec samples).

    Returns:
        (rffts, freq): rffts is a 2-D complex array with one row per window and one
        column per frequency bin; freq holds each bin's frequency in Hz.

    Raises:
        ValueError: if framerate is too low for the requested step/window rates
            (the hop or window length would be zero samples).
    '''
    win_size = framerate // wins_per_sec
    step_size = framerate // steps_per_sec
    if win_size <= 0 or step_size <= 0:
        raise ValueError("framerate must be at least steps_per_sec and wins_per_sec")
    half_win_size = win_size // 2
    taper = np.hanning(win_size)
    # Pad ceil(win_size / 2) zeros on each side so the edges of the audio get centred windows.
    frames = np.pad(frames, max(half_win_size, win_size - half_win_size))
    # Every full-length window, keeping every step_size-th start position. Unlike a
    # hand-rolled as_strided view, sliding_window_view only yields windows lying
    # entirely inside the buffer, so no memory past the end is ever read.
    windows = np.lib.stride_tricks.sliding_window_view(frames, win_size)[::step_size]
    rffts = np.fft.rfft(windows * taper)
    freq = np.fft.rfftfreq(win_size, 1/framerate)
    return rffts, freq
def note_velocities(rffts, freq):
    '''
    Lumps the FFT bin magnitudes into one velocity per MIDI note.

    Each FFT bin is assigned to the MIDI note nearest its frequency (bins whose
    nearest note lies outside 0-127 are dropped). A note's value is the mean
    magnitude of the bins assigned to it, or 0 if no bin maps to it.

    Parameters:
        rffts: 2-D array of complex FFT results, one row per window, one column per bin.
        freq: 1-D array of bin frequencies in Hz, one per rffts column.

    Returns:
        Array of shape (number of windows, 128) of per-note average magnitudes.

    Side effect:
        Stores the bin-to-note weighting matrix (len(freq) x 128) in the module-level
        global ``midi_output_matrix`` (the __main__ block plots it).
    '''
    global midi_output_matrix
    # Nearest MIDI note per bin; np.rint rounds half to even, like Python's round().
    nearest = np.rint(hz2mid(np.asarray(freq, dtype=np.float64)))
    # NaN/inf and out-of-range notes fail this test and are dropped.
    in_range = (nearest >= 0) & (nearest <= 127)
    # Row = FFT bin, column = MIDI note; a 1 marks which note each bin feeds.
    midi_output_matrix = np.zeros((len(freq), 128), dtype=np.float64)
    midi_output_matrix[np.flatnonzero(in_range), nearest[in_range].astype(int)] = 1
    # Weight each bin by 1 / (bins feeding that note) so the product is an average.
    # `out` must be pre-zeroed: with `where=`, unselected entries are otherwise
    # left uninitialised and may hold NaN/inf garbage.
    column_sum = midi_output_matrix.sum(axis=0)
    weights = np.reciprocal(column_sum, out=np.zeros_like(column_sum), where=column_sum != 0)
    midi_output_matrix *= weights
    return np.matmul(np.abs(rffts), midi_output_matrix)
def make_midi(notes_list, steps_per_sec, vol_multiply=1, max_note=None):
    '''
    Creates a mido MidiFile object with the provided note data.

    Each row of ``notes_list`` is one time step, lasting one MIDI tick. The tempo is
    set so that a tick lasts 1 / steps_per_sec seconds.

    Velocity scaling:
        Values are scaled so the largest value in the whole input maps to 127.
        The result is multiplied by ``vol_multiply``, rounded, and capped at 127.

    Note triggering:
        A note is (re)triggered only when its new velocity differs by at least 8 from
        the one currently recorded for it. Velocities of 4 or less are treated as
        silence. All held notes are released at the end.

    Parameters:
        notes_list: 2-D array-like, one row per time step; the column index is the
            MIDI note number.
        steps_per_sec: number of rows per second of audio.
        vol_multiply: extra gain applied after normalisation.
        max_note: if given, only notes with a number below it are emitted.

    Returns:
        A type-0 mido MidiFile with a single track.
    '''
    rows = np.asarray(notes_list, dtype=np.float64)
    max_vel = rows.max() if rows.size else 0.0
    print("Max vel:", max_vel)
    if max_vel > 0:
        # outliers seem to make a lot of things too quiet, hence vol_multiply.
        # Same operation order as (v / max_vel) * 127 * vol_multiply per element.
        scaled = np.minimum(np.rint(rows / max_vel * 127 * vol_multiply), 127)
    else:
        # Silent or empty input: avoid 0/0 (NaN) and emit no notes.
        scaled = np.zeros_like(rows)

    mid = MidiFile(type=0, ticks_per_beat=1)
    track = mid.add_track()
    track.append(MetaMessage('set_tempo', tempo=1000000//steps_per_sec, time=0))

    # TODO hold note until velocity is too different
    note_vel = [0] * 128  # velocity currently sounding for each note (0 = off)
    delta = 0  # ticks elapsed since the last message written to the track
    for step in scaled:
        for i, value in enumerate(step):
            if max_note is not None and i >= max_note:
                break
            vel = int(value)
            # Skip notes whose velocity hasn't changed enough to be worth retriggering
            if abs(note_vel[i] - vel) < 8:
                continue
            # clip quiet notes
            if vel <= 4:
                vel = 0
            if note_vel[i] != 0:
                track.append(Message('note_off', note=i, velocity=note_vel[i], time=delta))
                delta = 0
            if vel != 0:
                track.append(Message('note_on', note=i, velocity=vel, time=delta))
                delta = 0
            note_vel[i] = vel
        # Advance one tick per step via delta time (no placeholder messages needed).
        delta += 1

    # Release only the notes still held at the end of the song
    for i, v in enumerate(note_vel):
        if v != 0:
            track.append(Message('note_off', note=i, velocity=v, time=delta))
            delta = 0
    # Carry any trailing silence so the file keeps the full duration
    track.append(MetaMessage('end_of_track', time=delta))
    return mid
if __name__ == "__main__":
    # Script entry point: convert in.wav to out.mid, then show diagnostic plots.
    import matplotlib.pyplot as plt

    STEPS_PER_SEC = 10
    WINS_PER_SEC = 5
    VOL_MULTIPLY = 2.0
    MAX_NOTE = None  # was 84

    frames, framerate = get_audio_data("in.wav")
    print(f"Loaded WAV at {framerate} Hz sample rate ({frames.size} samples)")
    rffts, freq = fft_rolling_windows(frames, framerate, steps_per_sec=STEPS_PER_SEC, wins_per_sec=WINS_PER_SEC)
    print("Calculated FFT")
    notes_list = note_velocities(rffts, freq)
    print("Generated note velocities")
    make_midi(notes_list, steps_per_sec=STEPS_PER_SEC, vol_multiply=VOL_MULTIPLY, max_note=MAX_NOTE).save('out.mid')
    print("Saved MIDI")

    fig, (plot_a, plot_b, plot_c) = plt.subplots(3)

    plot_a.title.set_text("FFT")
    m1 = plot_a.matshow(np.abs(rffts.transpose()), aspect='auto')
    plot_a.set_xlabel('steps')
    plot_a.set_ylabel('frequencies (TODO make the key clearer)')
    fig.colorbar(m1, ax=plot_a)

    plot_b.title.set_text("MIDI note velocity")
    m2 = plot_b.matshow(notes_list.transpose(), aspect='auto')
    plot_b.set_xlabel('steps')
    plot_b.set_ylabel('MIDI notes')
    fig.colorbar(m2, ax=plot_b)

    plot_c.title.set_text("Translation table")
    # midi_output_matrix is the module-level global set by note_velocities():
    # rows are FFT bins, columns are MIDI notes.
    m3 = plot_c.matshow(midi_output_matrix, aspect='auto')
    plot_c.set_xlabel('MIDI notes')
    plot_c.set_ylabel('FFT bins')
    fig.colorbar(m3, ax=plot_c)

    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment