Created
July 14, 2017 10:17
-
-
Save naotokui/1ee163d22c94740568b45390245ab539 to your computer and use it in GitHub Desktop.
GAN-based drumloop generation (work in progress)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# GAN-808 - Drumloop generator\n", | |
"\n", | |
"You can listen to generated loops here: \n", | |
"https://soundcloud.com/user-890879658\n", | |
"\n", | |
"\n", | |
"# Constants" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"\n", | |
"# Make only the first GPU visible to TensorFlow (only relevant to the author's environment)\n", | |
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Length of each training drum loop in 16th-note steps (16 x 16th note = 1 bar)\n", | |
"len_seq = 16\n", | |
"\n", | |
"# Drum instruments are mapped to notes 35-81; max_drum_note is used as an inclusive upper bound\n", | |
"min_drum_note = 35\n", | |
"max_drum_note = 82\n", | |
"# Number of pitch columns kept per step (= 48), i.e. min_drum_note..max_drum_note inclusive\n", | |
"nb_notes = max_drum_note - min_drum_note + 1\n", | |
"\n", | |
"# Size of the generator's latent (noise) input vector\n", | |
"len_input = 100" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Preprocessing\n", | |
"\n", | |
"Convert each MIDI file into piano-roll-like matrices (unit timestep = 16th note).\n", | |
"\n", | |
"Source data: Lakh MIDI Dataset (`clean_midi` subset)\n", | |
"http://colinraffel.com/projects/lmd/" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import random\n", | |
"import numpy as np\n", | |
"import os\n", | |
"import joblib\n", | |
"import pretty_midi\n", | |
"import glob\n", | |
"import bisect\n", | |
"\n", | |
"# unit timestep = 16th note\n", | |
"def get_pianomatrices_of_drums(midi_file, is_drum=True):\n", | |
"    \"\"\"Convert a MIDI file into piano-roll matrices, one per selected track.\n", | |
"\n", | |
"    Args:\n", | |
"        midi_file: path of the MIDI file to load.\n", | |
"        is_drum: keep only instruments whose `is_drum` flag equals this value.\n", | |
"\n", | |
"    Returns:\n", | |
"        bool array of shape (n_tracks, 3000, 128): rows are 16th-note steps,\n", | |
"        columns are MIDI pitches (so `note.pitch` is the column index).\n", | |
"        None if the file cannot be loaded or has fewer than 256 beats.\n", | |
"    \"\"\"\n", | |
"    # load midi file; unreadable files are skipped (best effort), but the reason is reported\n", | |
"    try:\n", | |
"        pm = pretty_midi.PrettyMIDI(midi_file)\n", | |
"    except Exception as e:\n", | |
"        print \"Failed to load midi: \", midi_file, e\n", | |
"        return None\n", | |
"\n", | |
"    # get timing of quarter notes\n", | |
"    beats = pm.get_beats()\n", | |
"    if len(beats) < 256:\n", | |
"        return None # too short\n", | |
"\n", | |
"    # subdivide every beat into 16th notes and convert their start times to ticks\n", | |
"    resolution = 4 # 16th notes per quarter note\n", | |
"    beats_ticks = []\n", | |
"    for beat, next_beat in zip(beats[:-1], beats[1:]):\n", | |
"        note_time = (next_beat - beat) / float(resolution)\n", | |
"        for j in range(resolution):\n", | |
"            beats_ticks.append(pm.time_to_tick(beat + j * note_time))\n", | |
"\n", | |
"    # limit the maximum timesteps i.e number of 16th notes, to (max_notes16)\n", | |
"    max_notes16 = 3000\n", | |
"    # full MIDI pitch range: note.pitch indexes the columns directly.\n", | |
"    # get_sequences() later keeps only columns min_drum_note..max_drum_note.\n", | |
"    nb_pitches = 128\n", | |
"\n", | |
"    # convert a tick to the index of the 16th-note step that contains it\n", | |
"    def find_note_index(tick, beats_ticks):\n", | |
"        # largest i with beats_ticks[i] <= tick (binary search; beats_ticks is increasing).\n", | |
"        # Ticks before the first step map to 0, ticks after the last step to the last step.\n", | |
"        return max(bisect.bisect_right(beats_ticks, tick) - 1, 0)\n", | |
"\n", | |
"    # create pianoroll matrix (resolution: 16th note)\n", | |
"    pianorolls = []\n", | |
"    for instrument in pm.instruments:\n", | |
"        if instrument.is_drum != is_drum:\n", | |
"            continue # use drum tracks only (when is_drum=True)\n", | |
"        pianoroll = np.zeros((max_notes16, nb_pitches), dtype='bool')\n", | |
"        for note in instrument.notes:\n", | |
"            idx_start = find_note_index(pm.time_to_tick(note.start), beats_ticks)\n", | |
"            idx_end = find_note_index(pm.time_to_tick(note.end), beats_ticks)\n", | |
"            if idx_start >= max_notes16:\n", | |
"                continue # beyond the kept range; notes may be unsorted, so don't stop the track here\n", | |
"            # drum hits are often shorter than one 16th note (start and end in the same step):\n", | |
"            # always mark at least the onset step, and clip at max_notes16\n", | |
"            idx_end = min(max(idx_end, idx_start + 1), max_notes16)\n", | |
"            pianoroll[idx_start:idx_end, note.pitch] = True\n", | |
"        pianorolls.append(pianoroll)\n", | |
"    return np.array(pianorolls, dtype='bool')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Convert every MIDI file of the clean_midi subset into drum piano rolls, in parallel.\n", | |
"# glob() returns files in arbitrary order; sorting makes the dataset order reproducible.\n", | |
"midi_files = sorted(glob.glob(os.path.join('data', 'clean_midi', '*', '*.mid')))\n", | |
"matrices_drums = joblib.Parallel(n_jobs=10, verbose=5)(\n", | |
"    joblib.delayed(get_pianomatrices_of_drums)(midi_file, True)\n", | |
"    for midi_file in midi_files)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from tqdm import tqdm\n", | |
"\n", | |
"# minimum number of distinct drum instruments (pitches) a loop must contain\n", | |
"min_nb_inst = 3\n", | |
"\n", | |
"def get_sequences(_matrices, threshold=0.1):\n", | |
"    \"\"\"Cut drum piano rolls into non-overlapping 1-bar loops for training.\n", | |
"\n", | |
"    Args:\n", | |
"        _matrices: one entry per MIDI file, either None or an array of piano\n", | |
"            rolls (as returned by get_pianomatrices_of_drums).\n", | |
"        threshold: loops whose fraction of non-empty 16th-note steps is not\n", | |
"            above this value are discarded as too sparse.\n", | |
"\n", | |
"    Returns:\n", | |
"        np.array of shape (n_loops, len_seq, nb_notes) restricted to the drum\n", | |
"        pitches min_drum_note..max_drum_note.\n", | |
"    \"\"\"\n", | |
"    sequences = []\n", | |
"    print len(_matrices)\n", | |
"    for tracks in tqdm(_matrices):\n", | |
"        if tracks is None:\n", | |
"            continue # file could not be loaded or was too short\n", | |
"        for track in tracks:\n", | |
"            # \"+ 1\" so that the last full window is included as well\n", | |
"            for start in range(0, track.shape[0] - len_seq + 1, len_seq):\n", | |
"                # restrict to the drum pitch range first, so the filters below\n", | |
"                # judge exactly the data that is stored\n", | |
"                seq = track[start:start + len_seq, min_drum_note:max_drum_note + 1]\n", | |
"\n", | |
"                # number of distinct instruments (pitches) played in the loop\n", | |
"                if np.count_nonzero(np.sum(seq, axis=0)) < min_nb_inst:\n", | |
"                    continue\n", | |
"\n", | |
"                # ignore too sparse loops (fraction of steps with at least one hit)\n", | |
"                nb_non_zero = np.count_nonzero(np.sum(seq, axis=1))\n", | |
"                if nb_non_zero / float(len_seq) > threshold:\n", | |
"                    sequences.append(seq)\n", | |
"    print \"# of total sequences\", len(sequences)\n", | |
"    return np.array(sequences)\n", | |
"\n", | |
"if not os.path.isdir(\"tmp\"):\n", | |
"    os.makedirs(\"tmp\") # np.savez does not create missing directories\n", | |
"drum_seqs = get_sequences(np.array(matrices_drums), threshold=0.25)\n", | |
"np.savez(\"tmp/drum_sequences.npz\", drum_seqs=drum_seqs)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Discriminator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using TensorFlow backend.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"input_1 (InputLayer) (None, 16, 48, 1) 0 \n", | |
"_________________________________________________________________\n", | |
"conv2d_1 (Conv2D) (None, 14, 23, 256) 2560 \n", | |
"_________________________________________________________________\n", | |
"leaky_re_lu_1 (LeakyReLU) (None, 14, 23, 256) 0 \n", | |
"_________________________________________________________________\n", | |
"dropout_1 (Dropout) (None, 14, 23, 256) 0 \n", | |
"_________________________________________________________________\n", | |
"conv2d_2 (Conv2D) (None, 6, 11, 256) 590080 \n", | |
"_________________________________________________________________\n", | |
"leaky_re_lu_2 (LeakyReLU) (None, 6, 11, 256) 0 \n", | |
"_________________________________________________________________\n", | |
"dropout_2 (Dropout) (None, 6, 11, 256) 0 \n", | |
"_________________________________________________________________\n", | |
"conv2d_3 (Conv2D) (None, 2, 5, 512) 1180160 \n", | |
"_________________________________________________________________\n", | |
"leaky_re_lu_3 (LeakyReLU) (None, 2, 5, 512) 0 \n", | |
"_________________________________________________________________\n", | |
"dropout_3 (Dropout) (None, 2, 5, 512) 0 \n", | |
"_________________________________________________________________\n", | |
"flatten_1 (Flatten) (None, 5120) 0 \n", | |
"_________________________________________________________________\n", | |
"dense_1 (Dense) (None, 1024) 5243904 \n", | |
"_________________________________________________________________\n", | |
"dropout_4 (Dropout) (None, 1024) 0 \n", | |
"_________________________________________________________________\n", | |
"dense_2 (Dense) (None, 1) 1025 \n", | |
"=================================================================\n", | |
"Total params: 7,017,729\n", | |
"Trainable params: 7,017,729\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"from keras.layers import Input, Dense, Flatten, Dropout\n", | |
"from keras.layers.convolutional import Conv2D\n", | |
"from keras.layers.normalization import BatchNormalization\n", | |
"from keras.layers.advanced_activations import LeakyReLU\n", | |
"from keras.optimizers import RMSprop, Adam\n", | |
"import keras.backend as K\n", | |
"from keras.models import Model\n", | |
"\n", | |
"dropout_rate = 0.4\n", | |
"\n", | |
"# Discriminator: drum loop (len_seq, nb_notes, 1) -> probability that the loop is real\n", | |
"inputs = Input(shape=(len_seq, nb_notes, 1)) # channels-last (TensorFlow) order\n", | |
"x = inputs\n", | |
"# three strided conv blocks (conv -> LeakyReLU -> dropout); the first one only halves the pitch axis\n", | |
"for n_filters, conv_strides in [(256, (1, 2)), (256, (2, 2)), (512, (2, 2))]:\n", | |
"    x = Conv2D(n_filters, (3, 3), padding='valid', strides=conv_strides)(x)\n", | |
"    x = LeakyReLU(alpha=0.2)(x)\n", | |
"    x = Dropout(dropout_rate)(x)\n", | |
"\n", | |
"# fully connected head ending in a single sigmoid unit\n", | |
"x = Flatten()(x)\n", | |
"x = Dense(1024, activation='relu')(x)\n", | |
"x = Dropout(dropout_rate)(x)\n", | |
"x = Dense(1, activation='sigmoid')(x)\n", | |
"\n", | |
"discriminator = Model(inputs, x)\n", | |
"discriminator.summary()\n", | |
"\n", | |
"# D gets a higher learning rate (5e-4) than the adversarial model\n", | |
"optimizer = Adam(lr=0.0005)\n", | |
"discriminator.compile(optimizer=optimizer,\n", | |
"                      loss='binary_crossentropy',\n", | |
"                      metrics=['binary_accuracy'])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Generator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"input_2 (InputLayer) (None, 100) 0 \n", | |
"_________________________________________________________________\n", | |
"dense_3 (Dense) (None, 12288) 1241088 \n", | |
"_________________________________________________________________\n", | |
"leaky_re_lu_4 (LeakyReLU) (None, 12288) 0 \n", | |
"_________________________________________________________________\n", | |
"batch_normalization_1 (Batch (None, 12288) 49152 \n", | |
"_________________________________________________________________\n", | |
"reshape_1 (Reshape) (None, 2, 48, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"dropout_5 (Dropout) (None, 2, 48, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"conv2d_transpose_1 (Conv2DTr (None, 4, 48, 64) 41024 \n", | |
"_________________________________________________________________\n", | |
"leaky_re_lu_5 (LeakyReLU) (None, 4, 48, 64) 0 \n", | |
"_________________________________________________________________\n", | |
"batch_normalization_2 (Batch (None, 4, 48, 64) 256 \n", | |
"_________________________________________________________________\n", | |
"conv2d_transpose_2 (Conv2DTr (None, 8, 48, 32) 10272 \n", | |
"_________________________________________________________________\n", | |
"leaky_re_lu_6 (LeakyReLU) (None, 8, 48, 32) 0 \n", | |
"_________________________________________________________________\n", | |
"batch_normalization_3 (Batch (None, 8, 48, 32) 128 \n", | |
"_________________________________________________________________\n", | |
"conv2d_transpose_3 (Conv2DTr (None, 16, 48, 1) 513 \n", | |
"_________________________________________________________________\n", | |
"activation_1 (Activation) (None, 16, 48, 1) 0 \n", | |
"=================================================================\n", | |
"Total params: 1,342,433\n", | |
"Trainable params: 1,317,665\n", | |
"Non-trainable params: 24,768\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"from keras.layers import Reshape, Conv2DTranspose, RepeatVector, Activation\n", | |
"from keras.layers.normalization import BatchNormalization\n", | |
"from keras.layers.convolutional import UpSampling2D\n", | |
"\n", | |
"# Generator: latent vector (len_input,) -> drum loop (16, 48, 1).\n", | |
"# Only the time axis is upsampled (2 -> 4 -> 8 -> 16 steps); the 48 pitch\n", | |
"# columns come straight from the first Dense layer.\n", | |
"inputs2 = Input(shape=(len_input,))\n", | |
"x = Dense(2 * 48 * 128)(inputs2)\n", | |
"x = LeakyReLU(alpha=0.2)(x)\n", | |
"x = BatchNormalization(momentum=0.9)(x)\n", | |
"x = Reshape((2, 48, 128))(x)\n", | |
"x = Dropout(dropout_rate)(x)\n", | |
"\n", | |
"# two upsampling blocks, each doubling the number of time steps\n", | |
"for n_filters in (64, 32):\n", | |
"    x = Conv2DTranspose(n_filters, (5, 1), padding='same', strides=(2, 1))(x)\n", | |
"    x = LeakyReLU(alpha=0.2)(x)\n", | |
"    x = BatchNormalization(momentum=0.9, axis=-1)(x)\n", | |
"\n", | |
"x = Conv2DTranspose(1, (16, 1), padding='same', strides=(2, 1))(x)\n", | |
"# NOTE(review): tanh outputs lie in [-1, 1] while the real loops are 0/1;\n", | |
"# consider scaling the data to [-1, 1] or using a sigmoid output instead.\n", | |
"x = Activation('tanh')(x)\n", | |
"\n", | |
"generator = Model(inputs2, x)\n", | |
"generator.summary()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Adversarial model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"model_2 (Model) (None, 16, 48, 1) 1342433 \n", | |
"_________________________________________________________________\n", | |
"model_1 (Model) (None, 1) 7017729 \n", | |
"=================================================================\n", | |
"Total params: 8,360,162\n", | |
"Trainable params: 1,317,665\n", | |
"Non-trainable params: 7,042,497\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"from keras.models import Sequential\n", | |
"\n", | |
"# Freeze D inside the stacked model. Keras captures `trainable` at compile time, so the\n", | |
"# already-compiled `discriminator` still learns via its own train_on_batch, while\n", | |
"# training `gan` only updates the generator's weights.\n", | |
"discriminator.trainable = False\n", | |
"\n", | |
"optimizer = Adam(lr=0.0003)\n", | |
"gan = Sequential([generator, discriminator])\n", | |
"gan.compile(optimizer=optimizer, loss='binary_crossentropy')\n", | |
"\n", | |
"gan.summary()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Training" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Create Z for generator\n", | |
"def get_noise(batch_size, len_input):\n", | |
"    \"\"\"Sample a (batch_size, len_input) batch of latent vectors from N(0, 0.5^2).\n", | |
"\n", | |
"    A Gaussian (\"spherical\") Z is used instead of uniform noise in [-1, 1],\n", | |
"    as recommended by https://github.com/soumith/ganhacks\n", | |
"    \"\"\"\n", | |
"    return np.random.normal(loc=0.0, scale=0.50, size=(batch_size, len_input))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Training data: loops saved by the preprocessing step, shape (n_loops, len_seq, nb_notes)\n", | |
"import numpy as np\n", | |
"with np.load('tmp/drum_sequences.npz') as archive: # closes the underlying file when done\n", | |
"    drum_seqs = archive['drum_seqs']\n", | |
"assert drum_seqs.shape[1:] == (len_seq, nb_notes), drum_seqs.shape\n", | |
"print drum_seqs.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from IPython.display import clear_output\n", | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline\n", | |
"\n", | |
"d_losses = [] # history of averaged D losses (one entry per report)\n", | |
"a_losses = [] # history of averaged adversarial (G) losses\n", | |
"\n", | |
"# running sums since the last report, and how many updates they contain\n", | |
"m_a_loss = 0.0\n", | |
"m_d_loss = 0.0\n", | |
"nb_a_steps = 0\n", | |
"nb_d_steps = 0\n", | |
"\n", | |
"batch_size = 500\n", | |
"nb_epochs = 10000\n", | |
"nb_samples = drum_seqs.shape[0]\n", | |
"nb_batches = max(nb_samples // batch_size, 1) # at least one iteration per epoch\n", | |
"log_every = 100 # report and store the averaged losses every log_every iterations\n", | |
"\n", | |
"K_unrolled = 4 # D updates per G update; all but the first are rolled back afterwards\n", | |
"\n", | |
"for epoch in range(nb_epochs):\n", | |
"    for repeat in range(nb_batches):\n", | |
"\n", | |
"        for j in range(K_unrolled):\n", | |
"            # real samples: random minibatch of training loops plus a channel axis\n", | |
"            drum_train = drum_seqs[np.random.randint(0, nb_samples, size=batch_size), :, :]\n", | |
"            drum_train = np.expand_dims(drum_train, axis=3)\n", | |
"\n", | |
"            # fake samples from the current generator\n", | |
"            noise = get_noise(batch_size, len_input)\n", | |
"            drum_fake = generator.predict(noise)\n", | |
"\n", | |
"            # training D. label 0: fake, 1: real (softened to 0.9, one-sided label smoothing)\n", | |
"            x = np.concatenate([drum_train, drum_fake])\n", | |
"            y = np.zeros([2 * batch_size, 1])\n", | |
"            y[:batch_size, :] = 0.9\n", | |
"\n", | |
"            d_loss = discriminator.train_on_batch(x, y)\n", | |
"            m_d_loss += d_loss[0]\n", | |
"            nb_d_steps += 1\n", | |
"\n", | |
"            # snapshot D after its first update; the remaining K_unrolled - 1 updates\n", | |
"            # are only a look-ahead for the G step and are undone below\n", | |
"            if j == 0:\n", | |
"                cache_weights = discriminator.get_weights()\n", | |
"\n", | |
"        # training G: fakes are labelled real (1) so G learns to fool D\n", | |
"        y = np.ones([batch_size, 1])\n", | |
"        noise = get_noise(batch_size, len_input)\n", | |
"        a_loss = gan.train_on_batch(noise, y)\n", | |
"        m_a_loss += a_loss\n", | |
"        nb_a_steps += 1\n", | |
"\n", | |
"        # roll D back to its state after the first (non look-ahead) update\n", | |
"        discriminator.set_weights(cache_weights)\n", | |
"\n", | |
"        if repeat % log_every == 0:\n", | |
"            # average over the updates actually accumulated since the last report\n", | |
"            mean_d_loss = m_d_loss / nb_d_steps\n", | |
"            mean_a_loss = m_a_loss / nb_a_steps\n", | |
"            print \"epoch\", epoch, repeat\n", | |
"            print \"d_loss\", mean_d_loss, \"a_loss\", mean_a_loss\n", | |
"\n", | |
"            # store history\n", | |
"            d_losses.append(mean_d_loss)\n", | |
"            a_losses.append(mean_a_loss)\n", | |
"            m_a_loss = 0.0\n", | |
"            m_d_loss = 0.0\n", | |
"            nb_a_steps = 0\n", | |
"            nb_d_steps = 0" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"collapsed": true | |
}, | |
"source": [ | |
"# MIDI Playback" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# midi playback\n", | |
"def note_matrix_to_sequence(mat, threshold = 0.5):\n", | |
"    \"\"\"Convert a (steps, notes) activation matrix into a list of per-step note lists.\n", | |
"\n", | |
"    Each step's list holds [column_index, value] for every cell above `threshold`.\n", | |
"    \"\"\"\n", | |
"    return [[[i, r] for i, r in enumerate(row) if r > threshold] for row in mat]\n", | |
"\n", | |
"import os\n", | |
"import OSC\n", | |
"max_poly = 6 # maximum number of instruments played at the same time\n", | |
"\n", | |
"# OSC receiver that plays the loops; override with the OSC_HOST / OSC_PORT env variables\n", | |
"osc_host = os.environ.get('OSC_HOST', '10.0.1.14')\n", | |
"osc_port = int(os.environ.get('OSC_PORT', '2014'))\n", | |
"\n", | |
"client = OSC.OSCClient()\n", | |
"client.connect( (osc_host, osc_port) )\n", | |
"\n", | |
"def send_sequence_via_osc(seq):\n", | |
"    \"\"\"Send a note sequence as one \"/seq\" OSC message.\n", | |
"\n", | |
"    The message starts with max_poly * 2 (presumably the number of values per\n", | |
"    step; confirm against the receiver), followed by exactly max_poly\n", | |
"    [MIDI note, value] entries per step. Steps with fewer notes are padded with\n", | |
"    [0, 0]; notes beyond max_poly in a step are dropped.\n", | |
"    \"\"\"\n", | |
"    msg = OSC.OSCMessage()\n", | |
"    msg.setAddress(\"/seq\")\n", | |
"    msg.append(max_poly * 2)\n", | |
"\n", | |
"    for notes in seq:\n", | |
"        for i in range(max_poly):\n", | |
"            if i < len(notes):\n", | |
"                # column index -> MIDI note number\n", | |
"                msg.append([notes[i][0] + min_drum_note, notes[i][1]])\n", | |
"            else:\n", | |
"                msg.append([0, 0])\n", | |
"    client.send(msg)\n", | |
"\n", | |
"def playback_seq_via_osc(mat, threshold=0.2):\n", | |
"    \"\"\"Threshold a generated (steps, notes) matrix and send it to the sequencer.\"\"\"\n", | |
"    send_sequence_via_osc(note_matrix_to_sequence(mat, threshold))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import time\n", | |
"\n", | |
"# Number of latent-space interpolation steps per loop (1 = play each random loop once)\n", | |
"repeat = 1\n", | |
"\n", | |
"for j in range(100):\n", | |
"    # endpoints of the interpolation in latent space\n", | |
"    noise_from = get_noise(1, len_input)\n", | |
"    noise_to = get_noise(1, len_input)\n", | |
"\n", | |
"    for i in range(repeat):\n", | |
"        noise = noise_from * (1.0 - i / float(repeat)) + noise_to * i / float(repeat)\n", | |
"        # render the loop, send it to the sequencer, wait 4 s before the next one\n", | |
"        mat = np.squeeze(generator.predict(noise))\n", | |
"        playback_seq_via_osc(mat)\n", | |
"        time.sleep(4.0)\n", | |
"        clear_output(wait=True)" | |
] | |
} | |
], | |
"metadata": { | |
"anaconda-cloud": {}, | |
"kernelspec": { | |
"display_name": "Python [default]", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment