Created
May 6, 2023 18:44
-
-
Save SametSahin10/22052c07855c51930b55855b1de043dd to your computer and use it in GitHub Desktop.
composer_instrument_classification.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyMT2oIzww37DW4ujWdLf8HH", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/SametSahin10/22052c07855c51930b55855b1de043dd/composer_instrument_classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from google.colab import drive\n", | |
"drive.mount('/content/drive/')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "t6OBJv6rfD6U", | |
"outputId": "d4378f86-8ab3-4d50-f698-f5b53a5c6bd1" | |
}, | |
"execution_count": 76, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount(\"/content/drive/\", force_remount=True).\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install pretty_midi" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "AFSSTwuJgJ2M", | |
"outputId": "ec54cae1-f533-45e4-f5b5-984db8db6c63" | |
}, | |
"execution_count": 77, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Requirement already satisfied: pretty_midi in /usr/local/lib/python3.10/dist-packages (0.2.10)\n", | |
"Requirement already satisfied: mido>=1.1.16 in /usr/local/lib/python3.10/dist-packages (from pretty_midi) (1.2.10)\n", | |
"Requirement already satisfied: numpy>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from pretty_midi) (1.22.4)\n", | |
"Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from pretty_midi) (1.16.0)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 78, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "dNWUTrTme1lG", | |
"outputId": "b931e5a0-771d-46ac-e590-0f626708bb87" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Got an error while processing: /content/drive/MyDrive/collab/musicnet_midis/Bach/2310_prelude15.mid\n", | |
"Got an error while processing: /content/drive/MyDrive/collab/musicnet_midis/Bach/2230_prelude20.mid\n" | |
] | |
} | |
], | |
"source": [ | |
"import pretty_midi\n", | |
"import numpy as np\n", | |
"import os\n", | |
"import tensorflow as tf\n", | |
"import pandas as pd\n", | |
"\n", | |
"def extractNotesAndVelocities(path):\n", | |
" # Load MIDI file\n", | |
" midi_data = pretty_midi.PrettyMIDI(path)\n", | |
" # Extract notes and velocities\n", | |
" notes = []\n", | |
" velocities = []\n", | |
" for instrument in midi_data.instruments:\n", | |
" for note in instrument.notes:\n", | |
" notes.append(note.pitch)\n", | |
" velocities.append(note.velocity)\n", | |
"\n", | |
" # Convert notes and velocities to Numpy arrays\n", | |
" notes = np.array(notes)\n", | |
" velocities = np.array(velocities)\n", | |
" return notes, velocities\n", | |
"\n", | |
"def createComposersDictionary(root_path):\n", | |
" composers = {}\n", | |
"\n", | |
" for folder in os.listdir(root_path):\n", | |
" for file in os.listdir(os.path.join(root_path, folder)):\n", | |
" path = f'{root_path}/{folder}/{file}'\n", | |
" try:\n", | |
" notes, velocities = extractNotesAndVelocities(path)\n", | |
" except:\n", | |
" print(f\"Got an error while processing: {path}\")\n", | |
" continue\n", | |
"\n", | |
" data = [notes, velocities]\n", | |
" dataAsNumpyArray = np.array(data)\n", | |
" \n", | |
" if folder in composers: \n", | |
" composers[folder].append(dataAsNumpyArray)\n", | |
" else:\n", | |
" composers[folder] = [dataAsNumpyArray]\n", | |
"\n", | |
" return composers\n", | |
"\n", | |
"composers = createComposersDictionary('/content/drive/MyDrive/collab/musicnet_midis')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model = tf.keras.Sequential([\n", | |
" tf.keras.layers.Input(shape=(2, 400)),\n", | |
" tf.keras.layers.Flatten(),\n", | |
" tf.keras.layers.Dense(128, activation='relu'),\n", | |
" tf.keras.layers.Dense(128, activation='relu'),\n", | |
" tf.keras.layers.Dense(128, activation='relu'),\n", | |
" tf.keras.layers.Dense(10, activation=\"softmax\")\n", | |
"])\n", | |
"\n", | |
"model.compile(\n", | |
" loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n", | |
" optimizer=tf.keras.optimizers.Adam(),\n", | |
" metrics=[\"accuracy\"]\n", | |
")" | |
], | |
"metadata": { | |
"id": "pxHkGhK6s76u" | |
}, | |
"execution_count": 79, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"label_to_int = {\n", | |
" composer:index for index, composer in enumerate(os.listdir('/content/drive/MyDrive/collab/musicnet_midis'))\n", | |
"}" | |
], | |
"metadata": { | |
"id": "5x24_F1CjJwi" | |
}, | |
"execution_count": 80, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"label_to_int" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "VMeNPV8Sjt3g", | |
"outputId": "3f53625d-2da1-44d7-8f46-6d54eb3af7bb" | |
}, | |
"execution_count": 81, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{'Mozart': 0,\n", | |
" 'Brahms': 1,\n", | |
" 'Dvorak': 2,\n", | |
" 'Haydn': 3,\n", | |
" 'Schubert': 4,\n", | |
" 'Bach': 5,\n", | |
" 'Beethoven': 6,\n", | |
" 'Faure': 7,\n", | |
" 'Ravel': 8,\n", | |
" 'Cambini': 9}" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 81 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"labels = []\n", | |
"features = []\n", | |
"for k, v in composers.items():\n", | |
" for i in range(len(v)):\n", | |
" if v[i][0].shape[0] < 400: continue\n", | |
" n = v[i][0].shape[0] // 400\n", | |
" increment = 0\n", | |
" for index in range(n):\n", | |
" notes = v[i][0][increment:increment+400]\n", | |
" velocities = v[i][1][increment:increment+400]\n", | |
" labels.append(label_to_int[k])\n", | |
"\n", | |
" data = [notes, velocities]\n", | |
" dataAsNumpyArray = np.array(data)\n", | |
"\n", | |
" features.append(dataAsNumpyArray)\n", | |
" increment+=400" | |
], | |
"metadata": { | |
"id": "I6Ec_RHHh15W" | |
}, | |
"execution_count": 82, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"len(features), len(labels)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "PGcqNUJDjBeP", | |
"outputId": "b87983ef-66af-45dc-dbb6-983b46363636" | |
}, | |
"execution_count": 83, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(2589, 2589)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 83 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"features = np.array(features)\n", | |
"labels = np.array(labels)" | |
], | |
"metadata": { | |
"id": "xHSo0wVsmLWO" | |
}, | |
"execution_count": 84, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sklearn.model_selection import train_test_split\n", | |
"x_train, x_test, y_train, y_test = train_test_split(features, labels, test_size=0.2)" | |
], | |
"metadata": { | |
"id": "SgomezGyk9J6" | |
}, | |
"execution_count": 85, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"x_train.shape, x_test.shape, y_train.shape, y_test.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "uHEsr-IctFg1", | |
"outputId": "f8472338-900f-465b-8a1d-84871ace0ebc" | |
}, | |
"execution_count": 86, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"((2071, 2, 400), (518, 2, 400), (2071,), (518,))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 86 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"x_train[0]" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "puBn-Oz_wCGB", | |
"outputId": "7c5ac3d5-aac2-4fda-f88c-8b938dff2d9a" | |
}, | |
"execution_count": 87, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[ 48, 53, 53, 48, 53, 50, 53, 58, 53, 55, 51, 48, 53,\n", | |
" 53, 57, 53, 58, 53, 53, 60, 53, 58, 53, 53, 58, 54,\n", | |
" 58, 54, 54, 55, 49, 53, 48, 53, 48, 57, 53, 53, 55,\n", | |
" 57, 58, 57, 58, 62, 60, 57, 53, 57, 58, 62, 60, 57,\n", | |
" 53, 65, 63, 65, 63, 62, 61, 63, 61, 60, 58, 57, 60,\n", | |
" 57, 58, 60, 61, 65, 69, 70, 72, 70, 69, 70, 72, 68,\n", | |
" 60, 60, 61, 63, 65, 68, 61, 66, 68, 66, 65, 66, 68,\n", | |
" 69, 70, 66, 63, 61, 60, 65, 65, 65, 75, 74, 70, 68,\n", | |
" 77, 67, 75, 68, 67, 68, 67, 63, 60, 58, 60, 62, 53,\n", | |
" 57, 58, 48, 60, 48, 48, 60, 48, 60, 48, 49, 61, 50,\n", | |
" 62, 50, 53, 65, 48, 60, 60, 60, 60, 57, 59, 61, 52,\n", | |
" 57, 59, 61, 52, 61, 62, 64, 65, 57, 60, 58, 57, 57,\n", | |
" 59, 61, 52, 57, 59, 61, 52, 61, 62, 64, 57, 66, 61,\n", | |
" 62, 64, 57, 66, 57, 59, 61, 60, 53, 54, 55, 60, 69,\n", | |
" 81, 80, 79, 76, 75, 72, 69, 65, 63, 60, 62, 74, 72,\n", | |
" 70, 67, 72, 70, 69, 60, 69, 69, 60, 69, 72, 72, 70,\n", | |
" 67, 72, 60, 69, 77, 77, 76, 67, 79, 69, 77, 81, 81,\n", | |
" 79, 81, 81, 79, 74, 77, 65, 65, 67, 67, 64, 64, 65,\n", | |
" 65, 60, 60, 48, 48, 60, 60, 48, 48, 60, 60, 48, 69,\n", | |
" 67, 60, 65, 65, 60, 60, 67, 67, 60, 60, 69, 69, 60,\n", | |
" 60, 70, 70, 60, 60, 64, 70, 64, 67, 62, 61, 62, 65,\n", | |
" 63, 60, 57, 63, 60, 57, 58, 57, 55, 63, 60, 57, 65,\n", | |
" 63, 62, 60, 59, 52, 53, 62, 65, 62, 59, 64, 52, 53,\n", | |
" 62, 65, 62, 59, 64, 65, 62, 64, 64, 60, 57, 64, 60,\n", | |
" 57, 65, 62, 59, 65, 62, 59, 67, 64, 60, 67, 64, 58,\n", | |
" 67, 64, 61, 67, 64, 61, 69, 65, 62, 69, 65, 62, 72,\n", | |
" 69, 66, 72, 67, 72, 72, 69, 63, 63, 72, 60, 63, 63,\n", | |
" 72, 54, 54, 54, 55, 55, 57, 57, 57, 58, 58, 60, 60,\n", | |
" 60, 61, 61, 62, 50, 61, 61, 62, 65, 64, 65, 68, 67,\n", | |
" 63, 67, 65, 62, 61, 62, 65, 63, 60, 63, 62, 56, 55,\n", | |
" 56, 62, 60, 51, 60, 59, 50, 49, 50, 62],\n", | |
" [ 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,\n", | |
" 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,\n", | |
" 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 80, 80,\n", | |
" 80, 80, 60, 80, 100, 80, 80, 60, 60, 80, 100, 80, 60,\n", | |
" 60, 100, 100, 100, 80, 80, 80, 87, 88, 89, 94, 100, 100,\n", | |
" 80, 80, 80, 60, 60, 60, 60, 61, 61, 61, 62, 63, 64,\n", | |
" 65, 66, 67, 68, 69, 70, 71, 72, 73, 73, 73, 74, 76,\n", | |
" 82, 90, 92, 94, 96, 98, 100, 80, 80, 100, 100, 100, 100,\n", | |
" 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 80,\n", | |
" 60, 40, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,\n", | |
" 68, 76, 84, 92, 100, 100, 80, 60, 40, 40, 40, 40, 40,\n", | |
" 40, 40, 40, 40, 40, 60, 66, 80, 80, 74, 72, 64, 40,\n", | |
" 40, 40, 40, 40, 40, 40, 40, 40, 46, 48, 52, 52, 64,\n", | |
" 70, 72, 76, 76, 88, 94, 96, 100, 100, 100, 80, 60, 60,\n", | |
" 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 74,\n", | |
" 76, 78, 80, 60, 40, 100, 100, 100, 90, 87, 84, 81, 71,\n", | |
" 68, 60, 60, 61, 62, 63, 66, 67, 69, 72, 73, 74, 75,\n", | |
" 78, 79, 81, 84, 85, 87, 90, 91, 92, 93, 96, 97, 98,\n", | |
" 100, 60, 60, 60, 60, 60, 60, 60, 60, 60, 66, 73, 80,\n", | |
" 80, 60, 60, 60, 60, 60, 60, 63, 66, 69, 78, 81, 84,\n", | |
" 88, 100, 80, 80, 60, 40, 40, 40, 40, 60, 60, 80, 80,\n", | |
" 60, 60, 40, 80, 60, 40, 80, 60, 40, 80, 60, 40, 100,\n", | |
" 100, 87, 74, 60, 60, 68, 76, 84, 92, 100, 100, 60, 60,\n", | |
" 70, 80, 90, 100, 100, 60, 60, 60, 60, 60, 60, 60, 60,\n", | |
" 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,\n", | |
" 60, 60, 60, 60, 61, 63, 65, 67, 69, 70, 72, 74, 76,\n", | |
" 78, 80, 82, 84, 86, 88, 90, 92, 94, 94, 96, 96, 98,\n", | |
" 98, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,\n", | |
" 100, 100, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80,\n", | |
" 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80,\n", | |
" 80, 80, 80, 81, 82, 83, 84, 86, 87, 88]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 87 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"history = model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "A-P_oIp_l1d-", | |
"outputId": "aa8be3eb-82f0-4dc3-9ca8-15a2d6518a4e" | |
}, | |
"execution_count": 88, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Epoch 1/10\n", | |
"65/65 [==============================] - 2s 12ms/step - loss: 14.2796 - accuracy: 0.3419 - val_loss: 4.4266 - val_accuracy: 0.3977\n", | |
"Epoch 2/10\n", | |
"65/65 [==============================] - 0s 5ms/step - loss: 3.9253 - accuracy: 0.3670 - val_loss: 2.4873 - val_accuracy: 0.4517\n", | |
"Epoch 3/10\n", | |
"65/65 [==============================] - 0s 5ms/step - loss: 2.7020 - accuracy: 0.3974 - val_loss: 2.2042 - val_accuracy: 0.4961\n", | |
"Epoch 4/10\n", | |
"65/65 [==============================] - 0s 5ms/step - loss: 2.2060 - accuracy: 0.4384 - val_loss: 2.8763 - val_accuracy: 0.1467\n", | |
"Epoch 5/10\n", | |
"65/65 [==============================] - 0s 7ms/step - loss: 1.8630 - accuracy: 0.4437 - val_loss: 1.6927 - val_accuracy: 0.4691\n", | |
"Epoch 6/10\n", | |
"65/65 [==============================] - 0s 6ms/step - loss: 1.6717 - accuracy: 0.4780 - val_loss: 1.6900 - val_accuracy: 0.4846\n", | |
"Epoch 7/10\n", | |
"65/65 [==============================] - 0s 6ms/step - loss: 1.5670 - accuracy: 0.5089 - val_loss: 1.8285 - val_accuracy: 0.4961\n", | |
"Epoch 8/10\n", | |
"65/65 [==============================] - 0s 5ms/step - loss: 1.6075 - accuracy: 0.5036 - val_loss: 1.6497 - val_accuracy: 0.5077\n", | |
"Epoch 9/10\n", | |
"65/65 [==============================] - 0s 5ms/step - loss: 1.4927 - accuracy: 0.5340 - val_loss: 1.4372 - val_accuracy: 0.5714\n", | |
"Epoch 10/10\n", | |
"65/65 [==============================] - 0s 5ms/step - loss: 1.4137 - accuracy: 0.5350 - val_loss: 1.5872 - val_accuracy: 0.4846\n" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment