Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save SametSahin10/22052c07855c51930b55855b1de043dd to your computer and use it in GitHub Desktop.
Save SametSahin10/22052c07855c51930b55855b1de043dd to your computer and use it in GitHub Desktop.
composer_instrument_classification.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMT2oIzww37DW4ujWdLf8HH",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/SametSahin10/22052c07855c51930b55855b1de043dd/composer_instrument_classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive/')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "t6OBJv6rfD6U",
"outputId": "d4378f86-8ab3-4d50-f698-f5b53a5c6bd1"
},
"execution_count": 76,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount(\"/content/drive/\", force_remount=True).\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!pip install pretty_midi"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AFSSTwuJgJ2M",
"outputId": "ec54cae1-f533-45e4-f5b5-984db8db6c63"
},
"execution_count": 77,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: pretty_midi in /usr/local/lib/python3.10/dist-packages (0.2.10)\n",
"Requirement already satisfied: mido>=1.1.16 in /usr/local/lib/python3.10/dist-packages (from pretty_midi) (1.2.10)\n",
"Requirement already satisfied: numpy>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from pretty_midi) (1.22.4)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from pretty_midi) (1.16.0)\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dNWUTrTme1lG",
"outputId": "b931e5a0-771d-46ac-e590-0f626708bb87"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Got an error while processing: /content/drive/MyDrive/collab/musicnet_midis/Bach/2310_prelude15.mid\n",
"Got an error while processing: /content/drive/MyDrive/collab/musicnet_midis/Bach/2230_prelude20.mid\n"
]
}
],
"source": [
"import pretty_midi\n",
"import numpy as np\n",
"import os\n",
"import tensorflow as tf\n",
"import pandas as pd\n",
"\n",
"def extractNotesAndVelocities(path):\n",
" # Load MIDI file\n",
" midi_data = pretty_midi.PrettyMIDI(path)\n",
" # Extract notes and velocities\n",
" notes = []\n",
" velocities = []\n",
" for instrument in midi_data.instruments:\n",
" for note in instrument.notes:\n",
" notes.append(note.pitch)\n",
" velocities.append(note.velocity)\n",
"\n",
" # Convert notes and velocities to Numpy arrays\n",
" notes = np.array(notes)\n",
" velocities = np.array(velocities)\n",
" return notes, velocities\n",
"\n",
"def createComposersDictionary(root_path):\n",
" composers = {}\n",
"\n",
" for folder in os.listdir(root_path):\n",
" for file in os.listdir(os.path.join(root_path, folder)):\n",
" path = f'{root_path}/{folder}/{file}'\n",
" try:\n",
" notes, velocities = extractNotesAndVelocities(path)\n",
" except:\n",
" print(f\"Got an error while processing: {path}\")\n",
" continue\n",
"\n",
" data = [notes, velocities]\n",
" dataAsNumpyArray = np.array(data)\n",
" \n",
" if folder in composers: \n",
" composers[folder].append(dataAsNumpyArray)\n",
" else:\n",
" composers[folder] = [dataAsNumpyArray]\n",
"\n",
" return composers\n",
"\n",
"composers = createComposersDictionary('/content/drive/MyDrive/collab/musicnet_midis')"
]
},
{
"cell_type": "code",
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Input(shape=(2, 400)),\n",
" tf.keras.layers.Flatten(),\n",
" tf.keras.layers.Dense(128, activation='relu'),\n",
" tf.keras.layers.Dense(128, activation='relu'),\n",
" tf.keras.layers.Dense(128, activation='relu'),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])\n",
"\n",
"model.compile(\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
" optimizer=tf.keras.optimizers.Adam(),\n",
" metrics=[\"accuracy\"]\n",
")"
],
"metadata": {
"id": "pxHkGhK6s76u"
},
"execution_count": 79,
"outputs": []
},
{
"cell_type": "code",
"source": [
"label_to_int = {\n",
" composer:index for index, composer in enumerate(os.listdir('/content/drive/MyDrive/collab/musicnet_midis'))\n",
"}"
],
"metadata": {
"id": "5x24_F1CjJwi"
},
"execution_count": 80,
"outputs": []
},
{
"cell_type": "code",
"source": [
"label_to_int"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VMeNPV8Sjt3g",
"outputId": "3f53625d-2da1-44d7-8f46-6d54eb3af7bb"
},
"execution_count": 81,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'Mozart': 0,\n",
" 'Brahms': 1,\n",
" 'Dvorak': 2,\n",
" 'Haydn': 3,\n",
" 'Schubert': 4,\n",
" 'Bach': 5,\n",
" 'Beethoven': 6,\n",
" 'Faure': 7,\n",
" 'Ravel': 8,\n",
" 'Cambini': 9}"
]
},
"metadata": {},
"execution_count": 81
}
]
},
{
"cell_type": "code",
"source": [
"labels = []\n",
"features = []\n",
"for k, v in composers.items():\n",
" for i in range(len(v)):\n",
" if v[i][0].shape[0] < 400: continue\n",
" n = v[i][0].shape[0] // 400\n",
" increment = 0\n",
" for index in range(n):\n",
" notes = v[i][0][increment:increment+400]\n",
" velocities = v[i][1][increment:increment+400]\n",
" labels.append(label_to_int[k])\n",
"\n",
" data = [notes, velocities]\n",
" dataAsNumpyArray = np.array(data)\n",
"\n",
" features.append(dataAsNumpyArray)\n",
" increment+=400"
],
"metadata": {
"id": "I6Ec_RHHh15W"
},
"execution_count": 82,
"outputs": []
},
{
"cell_type": "code",
"source": [
"len(features), len(labels)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PGcqNUJDjBeP",
"outputId": "b87983ef-66af-45dc-dbb6-983b46363636"
},
"execution_count": 83,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(2589, 2589)"
]
},
"metadata": {},
"execution_count": 83
}
]
},
{
"cell_type": "code",
"source": [
"features = np.array(features)\n",
"labels = np.array(labels)"
],
"metadata": {
"id": "xHSo0wVsmLWO"
},
"execution_count": 84,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sklearn.model_selection import train_test_split\n",
"x_train, x_test, y_train, y_test = train_test_split(features, labels, test_size=0.2)"
],
"metadata": {
"id": "SgomezGyk9J6"
},
"execution_count": 85,
"outputs": []
},
{
"cell_type": "code",
"source": [
"x_train.shape, x_test.shape, y_train.shape, y_test.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uHEsr-IctFg1",
"outputId": "f8472338-900f-465b-8a1d-84871ace0ebc"
},
"execution_count": 86,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((2071, 2, 400), (518, 2, 400), (2071,), (518,))"
]
},
"metadata": {},
"execution_count": 86
}
]
},
{
"cell_type": "code",
"source": [
"x_train[0]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "puBn-Oz_wCGB",
"outputId": "7c5ac3d5-aac2-4fda-f88c-8b938dff2d9a"
},
"execution_count": 87,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[ 48, 53, 53, 48, 53, 50, 53, 58, 53, 55, 51, 48, 53,\n",
" 53, 57, 53, 58, 53, 53, 60, 53, 58, 53, 53, 58, 54,\n",
" 58, 54, 54, 55, 49, 53, 48, 53, 48, 57, 53, 53, 55,\n",
" 57, 58, 57, 58, 62, 60, 57, 53, 57, 58, 62, 60, 57,\n",
" 53, 65, 63, 65, 63, 62, 61, 63, 61, 60, 58, 57, 60,\n",
" 57, 58, 60, 61, 65, 69, 70, 72, 70, 69, 70, 72, 68,\n",
" 60, 60, 61, 63, 65, 68, 61, 66, 68, 66, 65, 66, 68,\n",
" 69, 70, 66, 63, 61, 60, 65, 65, 65, 75, 74, 70, 68,\n",
" 77, 67, 75, 68, 67, 68, 67, 63, 60, 58, 60, 62, 53,\n",
" 57, 58, 48, 60, 48, 48, 60, 48, 60, 48, 49, 61, 50,\n",
" 62, 50, 53, 65, 48, 60, 60, 60, 60, 57, 59, 61, 52,\n",
" 57, 59, 61, 52, 61, 62, 64, 65, 57, 60, 58, 57, 57,\n",
" 59, 61, 52, 57, 59, 61, 52, 61, 62, 64, 57, 66, 61,\n",
" 62, 64, 57, 66, 57, 59, 61, 60, 53, 54, 55, 60, 69,\n",
" 81, 80, 79, 76, 75, 72, 69, 65, 63, 60, 62, 74, 72,\n",
" 70, 67, 72, 70, 69, 60, 69, 69, 60, 69, 72, 72, 70,\n",
" 67, 72, 60, 69, 77, 77, 76, 67, 79, 69, 77, 81, 81,\n",
" 79, 81, 81, 79, 74, 77, 65, 65, 67, 67, 64, 64, 65,\n",
" 65, 60, 60, 48, 48, 60, 60, 48, 48, 60, 60, 48, 69,\n",
" 67, 60, 65, 65, 60, 60, 67, 67, 60, 60, 69, 69, 60,\n",
" 60, 70, 70, 60, 60, 64, 70, 64, 67, 62, 61, 62, 65,\n",
" 63, 60, 57, 63, 60, 57, 58, 57, 55, 63, 60, 57, 65,\n",
" 63, 62, 60, 59, 52, 53, 62, 65, 62, 59, 64, 52, 53,\n",
" 62, 65, 62, 59, 64, 65, 62, 64, 64, 60, 57, 64, 60,\n",
" 57, 65, 62, 59, 65, 62, 59, 67, 64, 60, 67, 64, 58,\n",
" 67, 64, 61, 67, 64, 61, 69, 65, 62, 69, 65, 62, 72,\n",
" 69, 66, 72, 67, 72, 72, 69, 63, 63, 72, 60, 63, 63,\n",
" 72, 54, 54, 54, 55, 55, 57, 57, 57, 58, 58, 60, 60,\n",
" 60, 61, 61, 62, 50, 61, 61, 62, 65, 64, 65, 68, 67,\n",
" 63, 67, 65, 62, 61, 62, 65, 63, 60, 63, 62, 56, 55,\n",
" 56, 62, 60, 51, 60, 59, 50, 49, 50, 62],\n",
" [ 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,\n",
" 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,\n",
" 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 80, 80,\n",
" 80, 80, 60, 80, 100, 80, 80, 60, 60, 80, 100, 80, 60,\n",
" 60, 100, 100, 100, 80, 80, 80, 87, 88, 89, 94, 100, 100,\n",
" 80, 80, 80, 60, 60, 60, 60, 61, 61, 61, 62, 63, 64,\n",
" 65, 66, 67, 68, 69, 70, 71, 72, 73, 73, 73, 74, 76,\n",
" 82, 90, 92, 94, 96, 98, 100, 80, 80, 100, 100, 100, 100,\n",
" 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 80,\n",
" 60, 40, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,\n",
" 68, 76, 84, 92, 100, 100, 80, 60, 40, 40, 40, 40, 40,\n",
" 40, 40, 40, 40, 40, 60, 66, 80, 80, 74, 72, 64, 40,\n",
" 40, 40, 40, 40, 40, 40, 40, 40, 46, 48, 52, 52, 64,\n",
" 70, 72, 76, 76, 88, 94, 96, 100, 100, 100, 80, 60, 60,\n",
" 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 74,\n",
" 76, 78, 80, 60, 40, 100, 100, 100, 90, 87, 84, 81, 71,\n",
" 68, 60, 60, 61, 62, 63, 66, 67, 69, 72, 73, 74, 75,\n",
" 78, 79, 81, 84, 85, 87, 90, 91, 92, 93, 96, 97, 98,\n",
" 100, 60, 60, 60, 60, 60, 60, 60, 60, 60, 66, 73, 80,\n",
" 80, 60, 60, 60, 60, 60, 60, 63, 66, 69, 78, 81, 84,\n",
" 88, 100, 80, 80, 60, 40, 40, 40, 40, 60, 60, 80, 80,\n",
" 60, 60, 40, 80, 60, 40, 80, 60, 40, 80, 60, 40, 100,\n",
" 100, 87, 74, 60, 60, 68, 76, 84, 92, 100, 100, 60, 60,\n",
" 70, 80, 90, 100, 100, 60, 60, 60, 60, 60, 60, 60, 60,\n",
" 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,\n",
" 60, 60, 60, 60, 61, 63, 65, 67, 69, 70, 72, 74, 76,\n",
" 78, 80, 82, 84, 86, 88, 90, 92, 94, 94, 96, 96, 98,\n",
" 98, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,\n",
" 100, 100, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80,\n",
" 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80,\n",
" 80, 80, 80, 81, 82, 83, 84, 86, 87, 88]])"
]
},
"metadata": {},
"execution_count": 87
}
]
},
{
"cell_type": "code",
"source": [
"history = model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "A-P_oIp_l1d-",
"outputId": "aa8be3eb-82f0-4dc3-9ca8-15a2d6518a4e"
},
"execution_count": 88,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/10\n",
"65/65 [==============================] - 2s 12ms/step - loss: 14.2796 - accuracy: 0.3419 - val_loss: 4.4266 - val_accuracy: 0.3977\n",
"Epoch 2/10\n",
"65/65 [==============================] - 0s 5ms/step - loss: 3.9253 - accuracy: 0.3670 - val_loss: 2.4873 - val_accuracy: 0.4517\n",
"Epoch 3/10\n",
"65/65 [==============================] - 0s 5ms/step - loss: 2.7020 - accuracy: 0.3974 - val_loss: 2.2042 - val_accuracy: 0.4961\n",
"Epoch 4/10\n",
"65/65 [==============================] - 0s 5ms/step - loss: 2.2060 - accuracy: 0.4384 - val_loss: 2.8763 - val_accuracy: 0.1467\n",
"Epoch 5/10\n",
"65/65 [==============================] - 0s 7ms/step - loss: 1.8630 - accuracy: 0.4437 - val_loss: 1.6927 - val_accuracy: 0.4691\n",
"Epoch 6/10\n",
"65/65 [==============================] - 0s 6ms/step - loss: 1.6717 - accuracy: 0.4780 - val_loss: 1.6900 - val_accuracy: 0.4846\n",
"Epoch 7/10\n",
"65/65 [==============================] - 0s 6ms/step - loss: 1.5670 - accuracy: 0.5089 - val_loss: 1.8285 - val_accuracy: 0.4961\n",
"Epoch 8/10\n",
"65/65 [==============================] - 0s 5ms/step - loss: 1.6075 - accuracy: 0.5036 - val_loss: 1.6497 - val_accuracy: 0.5077\n",
"Epoch 9/10\n",
"65/65 [==============================] - 0s 5ms/step - loss: 1.4927 - accuracy: 0.5340 - val_loss: 1.4372 - val_accuracy: 0.5714\n",
"Epoch 10/10\n",
"65/65 [==============================] - 0s 5ms/step - loss: 1.4137 - accuracy: 0.5350 - val_loss: 1.5872 - val_accuracy: 0.4846\n"
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment