Untitled9.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Untitled9.ipynb",
"version": "0.3.2",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "TPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/parulnith/7f8c174e6ac099e86f0495d3d9a4c01e/untitled9.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"metadata": {
"id": "cNnM2w-HCeb1",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Music genre classification notebook"
]
},
{
"metadata": {
"id": "2l3sppZMCydR",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Importing Libraries"
]
},
{
"metadata": {
"id": "Gt3fyg6dCNvX",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# feature extractoring and preprocessing data\n",
"import librosa\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import os\n",
"from PIL import Image\n",
"import pathlib\n",
"import csv\n",
"\n",
"# Preprocessing\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
"\n",
"#Keras\n",
"import keras\n",
"\n",
"import warnings\n",
"warnings.filterwarnings('ignore')"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "DPe_ebYuDqr5",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Extracting music and features\n",
"\n",
"### Dataset\n",
"\n",
"We use [GTZAN genre collection](http://marsyasweb.appspot.com/download/data_sets/) dataset for classification. \n",
"<br>\n",
"<br>\n",
"The dataset consists of 10 genres i.e\n",
" * Blues\n",
" * Classical\n",
" * Country\n",
" * Disco\n",
" * Hiphop\n",
" * Jazz\n",
" * Metal\n",
" * Pop\n",
" * Reggae\n",
" * Rock\n",
" \n",
"Each genre contains 100 songs. Total dataset: 1000 songs"
]
},
{
"metadata": {
"id": "neqMS0VoDpN5",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
""
]
},
{
"metadata": {
"id": "AfBSVfRCD3PE",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Extracting the Spectrogram for every Audio"
]
},
{
"metadata": {
"id": "BHh3pTEVDdrT",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"cmap = plt.get_cmap('inferno')\n",
"\n",
"plt.figure(figsize=(10,10))\n",
"genres = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()\n",
"for g in genres:\n",
" pathlib.Path(f'img_data/{g}').mkdir(parents=True, exist_ok=True) \n",
" for filename in os.listdir(f'./MIR/genres/{g}'):\n",
" songname = f'./MIR/genres/{g}/{filename}'\n",
" y, sr = librosa.load(songname, mono=True, duration=5)\n",
" plt.specgram(y, NFFT=2048, Fs=2, Fc=0, noverlap=128, cmap=cmap, sides='default', mode='default', scale='dB');\n",
" plt.axis('off');\n",
" plt.savefig(f'img_data/{g}/{filename[:-3].replace(\".\", \"\")}.png')\n",
" plt.clf()\n",
" "
],
"execution_count": 0,
"outputs": []
},
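For reference, here is a rough NumPy sketch of what `plt.specgram` computes with the settings above: the signal is cut into frames of `NFFT=2048` samples that overlap by `noverlap=128` samples (so consecutive frames start 1920 samples apart), each frame is multiplied by a Hann window, and the squared FFT magnitude is converted to decibels. The helper name `power_spectrogram_db` is hypothetical, and the sketch omits matplotlib's PSD normalization, so absolute dB values will differ while the image structure is the same.

```python
import numpy as np

def power_spectrogram_db(y, nfft=2048, noverlap=128, eps=1e-10):
    """Hann-windowed power spectrogram in dB, one column per frame (hypothetical helper)."""
    hop = nfft - noverlap                       # 2048 - 128 = 1920 samples between frame starts
    n_frames = 1 + (len(y) - nfft) // hop       # full frames only, no padding
    window = np.hanning(nfft)
    frames = np.stack([y[i * hop : i * hop + nfft] * window for i in range(n_frames)], axis=1)
    power = np.abs(np.fft.rfft(frames, axis=0)) ** 2  # shape: (nfft // 2 + 1, n_frames)
    return 10 * np.log10(power + eps)           # eps avoids log10(0) on silent frames

# Usage: a 5-second 440 Hz tone at 22050 Hz, librosa.load's default sample rate
sr = 22050
t = np.arange(5 * sr) / sr
spec = power_spectrogram_db(np.sin(2 * np.pi * 440 * t))
print(spec.shape)  # (1025, 57)
```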
{
"metadata": {
"id": "SszVgjYnFNX9",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"All the audio files get converted into their respective spectrograms .WE can noe easily extract features from them."
]
},
{
"metadata": {
"id": "3Nw9HpSdFRsW",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
""
]
},
{
"metadata": {
"id": "piwUwgP5Eef9",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Extracting features from Spectrogram\n",
"\n",
"\n",
"We will extract\n",
"\n",
"* Mel-frequency cepstral coefficients (MFCC)(20 in number)\n",
"* Spectral Centroid,\n",
"* Zero Crossing Rate\n",
"* Chroma Frequencies\n",
"* Spectral Roll-off."
]
},
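Two of these features are simple enough to write out by hand. The spectral centroid of a frame is the magnitude-weighted mean frequency, sum(f_k * |X_k|) / sum(|X_k|), and the zero-crossing rate is the fraction of adjacent sample pairs whose sign changes. The NumPy sketch below computes both for a single frame; the helper names are hypothetical, and librosa's own implementations work frame-by-frame over a whole signal and differ in details such as padding and thresholds.

```python
import numpy as np

def spectral_centroid(frame, sr):
    """Magnitude-weighted mean frequency of one frame, in Hz (hypothetical helper)."""
    # Hann window to limit spectral leakage (librosa's STFT also defaults to a Hann window)
    mag = np.abs(np.fft.rfft(frame * np.hanning(len(frame))))
    freqs = np.fft.rfftfreq(len(frame), d=1.0 / sr)
    return float(np.sum(freqs * mag) / np.sum(mag))

def zero_crossing_rate(frame):
    """Fraction of adjacent sample pairs whose sign differs (hypothetical helper)."""
    signs = np.signbit(frame)
    return float(np.mean(signs[1:] != signs[:-1]))

sr = 22050
t = np.arange(2048) / sr
tone = np.sin(2 * np.pi * 1000 * t)
print(spectral_centroid(tone, sr))                            # close to 1000 Hz for a pure tone
print(zero_crossing_rate(np.array([1.0, -1.0, 1.0, -1.0])))  # 1.0
```

Bright, noisy genres such as metal tend to push the centroid and zero-crossing rate up, which is why these scalars carry genre information.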
{
"metadata": {
"id": "__g8tX8pDeIL",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"header = 'filename chroma_stft rmse spectral_centroid spectral_bandwidth rolloff zero_crossing_rate'\n",
"for i in range(1, 21):\n",
" header += f' mfcc{i}'\n",
"header += ' label'\n",
"header = header.split()"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "TBlT448pEqR9",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Writing data to csv file\n",
"\n",
"We write the data to a csv file "
]
},
{
"metadata": {
"id": "ZsSQmB0PE3Iu",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"file = open('data.csv', 'w', newline='')\n",
"with file:\n",
" writer = csv.writer(file)\n",
" writer.writerow(header)\n",
"genres = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()\n",
"for g in genres:\n",
" for filename in os.listdir(f'./MIR/genres/{g}'):\n",
" songname = f'./MIR/genres/{g}/{filename}'\n",
" y, sr = librosa.load(songname, mono=True, duration=30)\n",
" chroma_stft = librosa.feature.chroma_stft(y=y, sr=sr)\n",
" spec_cent = librosa.feature.spectral_centroid(y=y, sr=sr)\n",
" spec_bw = librosa.feature.spectral_bandwidth(y=y, sr=sr)\n",
" rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)\n",
" zcr = librosa.feature.zero_crossing_rate(y)\n",
" mfcc = librosa.feature.mfcc(y=y, sr=sr)\n",
" to_append = f'{filename} {np.mean(chroma_stft)} {np.mean(rmse)} {np.mean(spec_cent)} {np.mean(spec_bw)} {np.mean(rolloff)} {np.mean(zcr)}' \n",
" for e in mfcc:\n",
" to_append += f' {np.mean(e)}'\n",
" to_append += f' {g}'\n",
" file = open('data.csv', 'a', newline='')\n",
" with file:\n",
" writer = csv.writer(file)\n",
" writer.writerow(to_append.split())"
],
"execution_count": 0,
"outputs": []
},
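Each librosa feature comes back as a 2-D array of shape (n_rows, n_frames), and the loop above collapses it with `np.mean`. The sketch below replays that reduction on random stand-in arrays with the same layout (the frame count and the filename `blues.00000.au` are made up) to show why every row has 28 fields: 1 filename, 6 scalar features, 20 MFCC means and 1 label.

```python
import numpy as np

rng = np.random.default_rng(0)
n_frames = 1000  # arbitrary; real clips give whatever librosa's framing produces

# Random stand-ins for librosa outputs, shaped (n_rows, n_frames) like the real features
chroma_stft = rng.random((12, n_frames))
rmse, spec_cent, spec_bw, rolloff, zcr = (rng.random((1, n_frames)) for _ in range(5))
mfcc = rng.normal(size=(20, n_frames))

row = ['blues.00000.au', np.mean(chroma_stft), np.mean(rmse), np.mean(spec_cent),
       np.mean(spec_bw), np.mean(rolloff), np.mean(zcr)]
row += [np.mean(e) for e in mfcc]  # iterating a 2-D array yields one coefficient row at a time
row.append('blues')

print(len(row))  # 28, matching data.shape[1]
```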
{
"metadata": {
"id": "0yfdo1cj6V7d",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"The data has been extracted into a [data.csv](https://github.com/parulnith/Music-Genre-Classification-with-Python/blob/master/data.csv) file."
]
},
{
"metadata": {
"id": "fgeCZSKQEp1A",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Analysing the Data in Pandas"
]
},
{
"metadata": {
"id": "Kr5_EdpD9dyh",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 253
},
"outputId": "81fd4a29-93fa-44f8-bf90-2f99981f761a"
},
"cell_type": "code",
"source": [
"data = pd.read_csv('data.csv')\n",
"data.head()"
],
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>filename</th>\n",
" <th>chroma_stft</th>\n",
" <th>rmse</th>\n",
" <th>spectral_centroid</th>\n",
" <th>spectral_bandwidth</th>\n",
" <th>rolloff</th>\n",
" <th>zero_crossing_rate</th>\n",
" <th>mfcc1</th>\n",
" <th>mfcc2</th>\n",
" <th>mfcc3</th>\n",
" <th>...</th>\n",
" <th>mfcc12</th>\n",
" <th>mfcc13</th>\n",
" <th>mfcc14</th>\n",
" <th>mfcc15</th>\n",
" <th>mfcc16</th>\n",
" <th>mfcc17</th>\n",
" <th>mfcc18</th>\n",
" <th>mfcc19</th>\n",
" <th>mfcc20</th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>blues.00081.au</td>\n",
" <td>0.380260</td>\n",
" <td>0.248262</td>\n",
" <td>2116.942959</td>\n",
" <td>1956.611056</td>\n",
" <td>4196.107960</td>\n",
" <td>0.127272</td>\n",
" <td>-26.929785</td>\n",
" <td>107.334008</td>\n",
" <td>-46.809993</td>\n",
" <td>...</td>\n",
" <td>14.336612</td>\n",
" <td>-13.821769</td>\n",
" <td>7.562789</td>\n",
" <td>-6.181372</td>\n",
" <td>0.330165</td>\n",
" <td>-6.829571</td>\n",
" <td>0.965922</td>\n",
" <td>-7.570825</td>\n",
" <td>2.918987</td>\n",
" <td>blues</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>blues.00022.au</td>\n",
" <td>0.306451</td>\n",
" <td>0.113475</td>\n",
" <td>1156.070496</td>\n",
" <td>1497.668176</td>\n",
" <td>2170.053545</td>\n",
" <td>0.058613</td>\n",
" <td>-233.860772</td>\n",
" <td>136.170239</td>\n",
" <td>3.289490</td>\n",
" <td>...</td>\n",
" <td>-2.250578</td>\n",
" <td>3.959198</td>\n",
" <td>5.322555</td>\n",
" <td>0.812028</td>\n",
" <td>-1.107202</td>\n",
" <td>-4.556555</td>\n",
" <td>-2.436490</td>\n",
" <td>3.316913</td>\n",
" <td>-0.608485</td>\n",
" <td>blues</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>blues.00031.au</td>\n",
" <td>0.253487</td>\n",
" <td>0.151571</td>\n",
" <td>1331.073970</td>\n",
" <td>1973.643437</td>\n",
" <td>2900.174130</td>\n",
" <td>0.042967</td>\n",
" <td>-221.802549</td>\n",
" <td>110.843071</td>\n",
" <td>18.620984</td>\n",
" <td>...</td>\n",
" <td>-13.037723</td>\n",
" <td>-12.652228</td>\n",
" <td>-1.821905</td>\n",
" <td>-7.260097</td>\n",
" <td>-6.660252</td>\n",
" <td>-14.682694</td>\n",
" <td>-11.719264</td>\n",
" <td>-11.025216</td>\n",
" <td>-13.387260</td>\n",
" <td>blues</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>blues.00012.au</td>\n",
" <td>0.269320</td>\n",
" <td>0.119072</td>\n",
" <td>1361.045467</td>\n",
" <td>1567.804596</td>\n",
" <td>2739.625101</td>\n",
" <td>0.069124</td>\n",
" <td>-207.208080</td>\n",
" <td>132.799175</td>\n",
" <td>-15.438986</td>\n",
" <td>...</td>\n",
" <td>-0.613248</td>\n",
" <td>0.384877</td>\n",
" <td>2.605128</td>\n",
" <td>-5.188924</td>\n",
" <td>-9.527455</td>\n",
" <td>-9.244394</td>\n",
" <td>-2.848274</td>\n",
" <td>-1.418707</td>\n",
" <td>-5.932607</td>\n",
" <td>blues</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>blues.00056.au</td>\n",
" <td>0.391059</td>\n",
" <td>0.137728</td>\n",
" <td>1811.076084</td>\n",
" <td>2052.332563</td>\n",
" <td>3927.809582</td>\n",
" <td>0.075480</td>\n",
" <td>-145.434568</td>\n",
" <td>102.829023</td>\n",
" <td>-12.517677</td>\n",
" <td>...</td>\n",
" <td>7.457218</td>\n",
" <td>-10.470444</td>\n",
" <td>-2.360483</td>\n",
" <td>-6.783624</td>\n",
" <td>2.671134</td>\n",
" <td>-4.760879</td>\n",
" <td>-0.949005</td>\n",
" <td>0.024832</td>\n",
" <td>-2.005315</td>\n",
" <td>blues</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 28 columns</p>\n",
"</div>"
],
"text/plain": [
" filename chroma_stft rmse spectral_centroid \\\n",
"0 blues.00081.au 0.380260 0.248262 2116.942959 \n",
"1 blues.00022.au 0.306451 0.113475 1156.070496 \n",
"2 blues.00031.au 0.253487 0.151571 1331.073970 \n",
"3 blues.00012.au 0.269320 0.119072 1361.045467 \n",
"4 blues.00056.au 0.391059 0.137728 1811.076084 \n",
"\n",
" spectral_bandwidth rolloff zero_crossing_rate mfcc1 \\\n",
"0 1956.611056 4196.107960 0.127272 -26.929785 \n",
"1 1497.668176 2170.053545 0.058613 -233.860772 \n",
"2 1973.643437 2900.174130 0.042967 -221.802549 \n",
"3 1567.804596 2739.625101 0.069124 -207.208080 \n",
"4 2052.332563 3927.809582 0.075480 -145.434568 \n",
"\n",
" mfcc2 mfcc3 ... mfcc12 mfcc13 mfcc14 mfcc15 \\\n",
"0 107.334008 -46.809993 ... 14.336612 -13.821769 7.562789 -6.181372 \n",
"1 136.170239 3.289490 ... -2.250578 3.959198 5.322555 0.812028 \n",
"2 110.843071 18.620984 ... -13.037723 -12.652228 -1.821905 -7.260097 \n",
"3 132.799175 -15.438986 ... -0.613248 0.384877 2.605128 -5.188924 \n",
"4 102.829023 -12.517677 ... 7.457218 -10.470444 -2.360483 -6.783624 \n",
"\n",
" mfcc16 mfcc17 mfcc18 mfcc19 mfcc20 label \n",
"0 0.330165 -6.829571 0.965922 -7.570825 2.918987 blues \n",
"1 -1.107202 -4.556555 -2.436490 3.316913 -0.608485 blues \n",
"2 -6.660252 -14.682694 -11.719264 -11.025216 -13.387260 blues \n",
"3 -9.527455 -9.244394 -2.848274 -1.418707 -5.932607 blues \n",
"4 2.671134 -4.760879 -0.949005 0.024832 -2.005315 blues \n",
"\n",
"[5 rows x 28 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"metadata": {
"id": "iHrDHCaR9gKR",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "7d32943a-1ad5-4a59-c13a-beebeb36e4c2"
},
"cell_type": "code",
"source": [
"data.shape"
],
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(1000, 28)"
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
]
},
{
"metadata": {
"id": "veD5BgX49hZa",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Dropping unneccesary columns\n",
"data = data.drop(['filename'],axis=1)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Nyr0aAAsGXjZ",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Encoding the Labels"
]
},
{
"metadata": {
"id": "frI5HH4q-1HS",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"genre_list = data.iloc[:, -1]\n",
"encoder = LabelEncoder()\n",
"y = encoder.fit_transform(genre_list)"
],
"execution_count": 0,
"outputs": []
},
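`LabelEncoder` assigns integers to the genre names in sorted (alphabetical) order, which here coincides with the order of the `genres` list used during extraction. The same mapping can be reproduced with `np.unique(..., return_inverse=True)`, a convenient way to check which integer stands for which genre. The toy labels below are made up for illustration.

```python
import numpy as np

genre_list = ['rock', 'blues', 'jazz', 'blues', 'classical']  # toy labels
classes, y = np.unique(genre_list, return_inverse=True)

print(classes.tolist())  # ['blues', 'classical', 'jazz', 'rock']
print(y.tolist())        # [3, 0, 2, 0, 1]
print(classes[y].tolist() == genre_list)  # True: indexing with the codes inverts the encoding
```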
{
"metadata": {
"id": "Slm8W0-iGVhI",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
""
]
},
{
"metadata": {
"id": "_2n8a02zGfvP",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Scaling the Feature columns"
]
},
{
"metadata": {
"id": "uqcqn-nyAofk",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"scaler = StandardScaler()\n",
"X = scaler.fit_transform(np.array(data.iloc[:, :-1], dtype = float))"
],
"execution_count": 0,
"outputs": []
},
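`StandardScaler` rescales every column to zero mean and unit variance, z = (x - mu) / sigma, where sigma is the population standard deviation (NumPy's default `ddof=0`). A NumPy equivalent on a toy matrix:

```python
import numpy as np

X_raw = np.array([[1.0, 200.0],
                  [2.0, 400.0],
                  [3.0, 600.0]])

mu = X_raw.mean(axis=0)     # per-column mean
sigma = X_raw.std(axis=0)   # per-column population std (ddof=0)
X_std = (X_raw - mu) / sigma

print(X_std.mean(axis=0))   # approximately [0, 0]
print(X_std.std(axis=0))    # [1, 1]
```

Note that in the cell above the scaler is fit on all 1000 rows before the train/test split, so test-set statistics leak into the training features. Fitting `mu` and `sigma` on the training rows only, and applying those same values to the test rows, avoids this.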
{
"metadata": {
"id": "e3VZvbwpGo9R",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Dividing data into training and Testing set"
]
},
{
"metadata": {
"id": "F1GW3VvQA7Rj",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)"
],
"execution_count": 0,
"outputs": []
},
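`train_test_split` as called above shuffles the rows at random, so the 200 test tracks are not guaranteed to contain exactly 20 per genre. Passing `stratify=y` to `train_test_split` keeps the class proportions equal in both parts. The NumPy sketch below illustrates the idea by sampling 20% of the indices within each class; the helper name `stratified_split_indices` is hypothetical.

```python
import numpy as np

def stratified_split_indices(y, test_size=0.2, seed=0):
    """Return (train_idx, test_idx) with the same class proportions in both parts."""
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for label in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == label))
        n_test = int(round(test_size * len(idx)))
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return np.array(train_idx), np.array(test_idx)

y = np.repeat(np.arange(10), 100)  # 10 genres x 100 tracks, as in GTZAN
train_idx, test_idx = stratified_split_indices(y)
print(len(train_idx), len(test_idx))      # 800 200
print(np.bincount(y[test_idx]).tolist())  # [20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
```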
{
"metadata": {
"id": "upuczQ-KBHJ5",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "1431a28b-e8b6-4db2-e505-7e149e37c0d7"
},
"cell_type": "code",
"source": [
"len(y_train)"
],
"execution_count": 12,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"800"
]
},
"metadata": {
"tags": []
},
"execution_count": 12
}
]
},
{
"metadata": {
"id": "LtoE_FqqBzM8",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "76555a2b-2030-48e1-b52d-d71b4ebae38e"
},
"cell_type": "code",
"source": [
"len(y_test)"
],
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"200"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
{
"metadata": {
"id": "ir9XaWgQB0lq",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 119
},
"outputId": "2ec90814-19d8-4f27-934a-1ce54406d4ea"
},
"cell_type": "code",
"source": [
"X_train[10]"
],
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([-0.9149113 , 0.18294103, -1.10587131, -1.3875197 , -1.14640873,\n",
" -0.97232926, -0.29174214, 1.20078936, -0.68458101, -0.55849017,\n",
" -1.27056582, -0.88176926, -0.74844069, -0.40970382, 0.49685952,\n",
" -1.12666045, 0.59501437, -0.39783853, 0.29327275, -0.72916871,\n",
" 0.63015786, -0.91149976, 0.7743942 , -0.64790051, 0.42229852,\n",
" -1.01449461])"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
}
]
},
{
"metadata": {
"id": "Vp2yc5FWG04e",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Classification with Keras\n",
"\n",
"## Building our Network"
]
},
{
"metadata": {
"id": "Qj3sc2uFEUMt",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"from keras import models\n",
"from keras import layers\n",
"\n",
"model = models.Sequential()\n",
"model.add(layers.Dense(256, activation='relu', input_shape=(X_train.shape[1],)))\n",
"\n",
"model.add(layers.Dense(128, activation='relu'))\n",
"\n",
"model.add(layers.Dense(64, activation='relu'))\n",
"\n",
"model.add(layers.Dense(10, activation='softmax'))"
],
"execution_count": 0,
"outputs": []
},
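Each `Dense` layer computes `activation(x @ W + b)`. With 26 input features (as `X_train[10]` above shows), this network has 26·256+256 + 256·128+128 + 128·64+64 + 64·10+10 = 48,714 trainable parameters, which is the total `model.summary()` should report. The NumPy sketch below runs the same forward pass with random, untrained weights, purely to show the shapes and the softmax output.

```python
import numpy as np

rng = np.random.default_rng(0)
sizes = [26, 256, 128, 64, 10]  # input features, then the four Dense layers
params = [(rng.normal(0, 0.1, (m, n)), np.zeros(n)) for m, n in zip(sizes[:-1], sizes[1:])]

def relu(z):
    return np.maximum(z, 0)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def forward(x):
    for W, b in params[:-1]:
        x = relu(x @ W + b)
    W, b = params[-1]
    return softmax(x @ W + b)

n_params = sum(W.size + b.size for W, b in params)
probs = forward(rng.normal(size=(5, 26)))
print(n_params)           # 48714
print(probs.shape)        # (5, 10)
print(probs.sum(axis=1))  # each row sums to 1
```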
{
"metadata": {
"id": "7yrsmpI6EjJ2",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"model.compile(optimizer='adam',\n",
" loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])"
],
"execution_count": 0,
"outputs": []
},
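`sparse_categorical_crossentropy` takes integer labels (the output of `LabelEncoder`) instead of one-hot vectors and, for each sample, takes the negative log of the probability the model assigns to the true class. The NumPy version below ignores numerical-stability details such as clipping. It also explains where training starts: a model that spreads probability evenly over the 10 genres has a loss of ln(10), about 2.303, close to the second model's first-epoch loss of 2.3074 logged further down.

```python
import numpy as np

def sparse_categorical_crossentropy(y_true, probs):
    """Mean of -log(probability assigned to the true class); y_true holds integer labels."""
    return float(-np.mean(np.log(probs[np.arange(len(y_true)), y_true])))

y_true = np.array([0, 3, 9])
uniform = np.full((3, 10), 0.1)         # no preference among the 10 genres
confident = np.full((3, 10), 0.01)
confident[np.arange(3), y_true] = 0.91  # 0.91 + 9 * 0.01 = 1.0 per row

print(sparse_categorical_crossentropy(y_true, uniform))    # ln(10), about 2.3026
print(sparse_categorical_crossentropy(y_true, confident))  # -ln(0.91), about 0.0943
```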
{
"metadata": {
"id": "bP0hVm4aElS7",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 697
},
"outputId": "aacf234d-d0a9-4de4-91be-5fd45a33b279"
},
"cell_type": "code",
"source": [
"history = model.fit(X_train,\n",
" y_train,\n",
" epochs=20,\n",
" batch_size=128)\n",
" "
],
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"800/800 [==============================] - 1s 811us/step - loss: 2.1289 - acc: 0.2400\n",
"Epoch 2/20\n",
"800/800 [==============================] - 0s 39us/step - loss: 1.7940 - acc: 0.4088\n",
"Epoch 3/20\n",
"800/800 [==============================] - 0s 37us/step - loss: 1.5437 - acc: 0.4450\n",
"Epoch 4/20\n",
"800/800 [==============================] - 0s 38us/step - loss: 1.3584 - acc: 0.5413\n",
"Epoch 5/20\n",
"800/800 [==============================] - 0s 38us/step - loss: 1.2220 - acc: 0.5750\n",
"Epoch 6/20\n",
"800/800 [==============================] - 0s 41us/step - loss: 1.1187 - acc: 0.6288\n",
"Epoch 7/20\n",
"800/800 [==============================] - 0s 37us/step - loss: 1.0326 - acc: 0.6550\n",
"Epoch 8/20\n",
"800/800 [==============================] - 0s 44us/step - loss: 0.9631 - acc: 0.6713\n",
"Epoch 9/20\n",
"800/800 [==============================] - 0s 47us/step - loss: 0.9143 - acc: 0.6913\n",
"Epoch 10/20\n",
"800/800 [==============================] - 0s 37us/step - loss: 0.8630 - acc: 0.7125\n",
"Epoch 11/20\n",
"800/800 [==============================] - 0s 36us/step - loss: 0.8095 - acc: 0.7263\n",
"Epoch 12/20\n",
"800/800 [==============================] - 0s 37us/step - loss: 0.7728 - acc: 0.7700\n",
"Epoch 13/20\n",
"800/800 [==============================] - 0s 36us/step - loss: 0.7433 - acc: 0.7563\n",
"Epoch 14/20\n",
"800/800 [==============================] - 0s 45us/step - loss: 0.7066 - acc: 0.7825\n",
"Epoch 15/20\n",
"800/800 [==============================] - 0s 43us/step - loss: 0.6718 - acc: 0.7787\n",
"Epoch 16/20\n",
"800/800 [==============================] - 0s 36us/step - loss: 0.6601 - acc: 0.7913\n",
"Epoch 17/20\n",
"800/800 [==============================] - 0s 36us/step - loss: 0.6242 - acc: 0.7963\n",
"Epoch 18/20\n",
"800/800 [==============================] - 0s 44us/step - loss: 0.5994 - acc: 0.8038\n",
"Epoch 19/20\n",
"800/800 [==============================] - 0s 42us/step - loss: 0.5715 - acc: 0.8125\n",
"Epoch 20/20\n",
"800/800 [==============================] - 0s 39us/step - loss: 0.5437 - acc: 0.8250\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "0m1J0_wUFK4C",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "ffd3bf36-29ea-437a-987c-9aa600b9dae6"
},
"cell_type": "code",
"source": [
"test_loss, test_acc = model.evaluate(X_test,y_test)"
],
"execution_count": 20,
"outputs": [
{
"output_type": "stream",
"text": [
"200/200 [==============================] - 0s 244us/step\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "f6HrjXeUF0Ko",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "ea282dbd-6f9e-48c7-de2d-dc9afde8949e"
},
"cell_type": "code",
"source": [
"print('test_acc: ',test_acc)"
],
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
"text": [
"test_acc: 0.68\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "3yQmP_f5Kq0w",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Tes accuracy is less than training dataa accuracy. This hints at Overfitting"
]
},
{
"metadata": {
"id": "-U2qzRJoHV9O",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Validating our approach\n",
"Let's set apart 200 samples in our training data to use as a validation set:"
]
},
{
"metadata": {
"id": "xJNbvYZoF7ZT",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"x_val = X_train[:200]\n",
"partial_x_train = X_train[200:]\n",
"\n",
"y_val = y_train[:200]\n",
"partial_y_train = y_train[200:]"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "L1EkG59EHeEV",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Now let's train our network for 20 epochs:"
]
},
{
"metadata": {
"id": "Dp3G4P3aP4k2",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1071
},
"outputId": "25e1a389-1ac2-425b-bd5f-05736b6e9b96"
},
"cell_type": "code",
"source": [
"model = models.Sequential()\n",
"model.add(layers.Dense(512, activation='relu', input_shape=(X_train.shape[1],)))\n",
"model.add(layers.Dense(256, activation='relu'))\n",
"model.add(layers.Dense(128, activation='relu'))\n",
"model.add(layers.Dense(64, activation='relu'))\n",
"model.add(layers.Dense(10, activation='softmax'))\n",
"\n",
"model.compile(optimizer='adam',\n",
" loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])\n",
"\n",
"model.fit(partial_x_train,\n",
" partial_y_train,\n",
" epochs=30,\n",
" batch_size=512,\n",
" validation_data=(x_val, y_val))\n",
"results = model.evaluate(X_test, y_test)"
],
"execution_count": 37,
"outputs": [
{
"output_type": "stream",
"text": [
"Train on 600 samples, validate on 200 samples\n",
"Epoch 1/30\n",
"600/600 [==============================] - 1s 1ms/step - loss: 2.3074 - acc: 0.0950 - val_loss: 2.1857 - val_acc: 0.2850\n",
"Epoch 2/30\n",
"600/600 [==============================] - 0s 65us/step - loss: 2.1126 - acc: 0.3783 - val_loss: 2.0936 - val_acc: 0.2400\n",
"Epoch 3/30\n",
"600/600 [==============================] - 0s 59us/step - loss: 1.9535 - acc: 0.3633 - val_loss: 1.9966 - val_acc: 0.2600\n",
"Epoch 4/30\n",
"600/600 [==============================] - 0s 58us/step - loss: 1.8082 - acc: 0.3833 - val_loss: 1.8713 - val_acc: 0.3250\n",
"Epoch 5/30\n",
"600/600 [==============================] - 0s 59us/step - loss: 1.6663 - acc: 0.4083 - val_loss: 1.7302 - val_acc: 0.3450\n",
"Epoch 6/30\n",
"600/600 [==============================] - 0s 52us/step - loss: 1.5329 - acc: 0.4550 - val_loss: 1.6233 - val_acc: 0.3700\n",
"Epoch 7/30\n",
"600/600 [==============================] - 0s 62us/step - loss: 1.4236 - acc: 0.4850 - val_loss: 1.5402 - val_acc: 0.3950\n",
"Epoch 8/30\n",
"600/600 [==============================] - 0s 57us/step - loss: 1.3250 - acc: 0.5117 - val_loss: 1.4655 - val_acc: 0.3800\n",
"Epoch 9/30\n",
"600/600 [==============================] - 0s 52us/step - loss: 1.2338 - acc: 0.5633 - val_loss: 1.3927 - val_acc: 0.4650\n",
"Epoch 10/30\n",
"600/600 [==============================] - 0s 61us/step - loss: 1.1577 - acc: 0.5983 - val_loss: 1.3338 - val_acc: 0.5500\n",
"Epoch 11/30\n",
"600/600 [==============================] - 0s 64us/step - loss: 1.0981 - acc: 0.6317 - val_loss: 1.3111 - val_acc: 0.5550\n",
"Epoch 12/30\n",
"600/600 [==============================] - 0s 52us/step - loss: 1.0529 - acc: 0.6517 - val_loss: 1.2696 - val_acc: 0.5400\n",
"Epoch 13/30\n",
"600/600 [==============================] - 0s 52us/step - loss: 0.9994 - acc: 0.6567 - val_loss: 1.2480 - val_acc: 0.5400\n",
"Epoch 14/30\n",
"600/600 [==============================] - 0s 65us/step - loss: 0.9673 - acc: 0.6633 - val_loss: 1.2384 - val_acc: 0.5700\n",
"Epoch 15/30\n",
"600/600 [==============================] - 0s 58us/step - loss: 0.9286 - acc: 0.6633 - val_loss: 1.1953 - val_acc: 0.5800\n",
"Epoch 16/30\n",
"600/600 [==============================] - 0s 59us/step - loss: 0.8849 - acc: 0.6783 - val_loss: 1.2000 - val_acc: 0.5550\n",
"Epoch 17/30\n",
"600/600 [==============================] - 0s 61us/step - loss: 0.8621 - acc: 0.6850 - val_loss: 1.1743 - val_acc: 0.5850\n",
"Epoch 18/30\n",
"600/600 [==============================] - 0s 61us/step - loss: 0.8195 - acc: 0.7150 - val_loss: 1.1609 - val_acc: 0.5750\n",
"Epoch 19/30\n",
"600/600 [==============================] - 0s 62us/step - loss: 0.7976 - acc: 0.7283 - val_loss: 1.1238 - val_acc: 0.6150\n",
"Epoch 20/30\n",
"600/600 [==============================] - 0s 63us/step - loss: 0.7660 - acc: 0.7650 - val_loss: 1.1604 - val_acc: 0.5850\n",
"Epoch 21/30\n",
"600/600 [==============================] - 0s 65us/step - loss: 0.7465 - acc: 0.7650 - val_loss: 1.1888 - val_acc: 0.5700\n",
"Epoch 22/30\n",
"600/600 [==============================] - 0s 65us/step - loss: 0.7099 - acc: 0.7517 - val_loss: 1.1563 - val_acc: 0.6050\n",
"Epoch 23/30\n",
"600/600 [==============================] - 0s 68us/step - loss: 0.6857 - acc: 0.7683 - val_loss: 1.0900 - val_acc: 0.6200\n",
"Epoch 24/30\n",
"600/600 [==============================] - 0s 67us/step - loss: 0.6597 - acc: 0.7850 - val_loss: 1.0872 - val_acc: 0.6300\n",
"Epoch 25/30\n",
"600/600 [==============================] - 0s 67us/step - loss: 0.6377 - acc: 0.7967 - val_loss: 1.1148 - val_acc: 0.6200\n",
"Epoch 26/30\n",
"600/600 [==============================] - 0s 64us/step - loss: 0.6070 - acc: 0.8200 - val_loss: 1.1397 - val_acc: 0.6150\n",
"Epoch 27/30\n",
"600/600 [==============================] - 0s 66us/step - loss: 0.5991 - acc: 0.8167 - val_loss: 1.1255 - val_acc: 0.6300\n",
"Epoch 28/30\n",
"600/600 [==============================] - 0s 62us/step - loss: 0.5656 - acc: 0.8333 - val_loss: 1.0955 - val_acc: 0.6350\n",
"Epoch 29/30\n",
"600/600 [==============================] - 0s 66us/step - loss: 0.5513 - acc: 0.8300 - val_loss: 1.1030 - val_acc: 0.6050\n",
"Epoch 30/30\n",
"600/600 [==============================] - 0s 56us/step - loss: 0.5498 - acc: 0.8233 - val_loss: 1.0869 - val_acc: 0.6250\n",
"200/200 [==============================] - 0s 65us/step\n"
],
"name": "stdout"
}
]
},
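The validation loss flattens out well before epoch 30: it reaches 1.1238 at epoch 19 and does not beat that value until epoch 23. Keras offers a `keras.callbacks.EarlyStopping` callback for this situation (for example `EarlyStopping(monitor='val_loss', patience=3)`, passed to `model.fit` via `callbacks=[...]`). The pure-Python sketch below mimics its basic rule (stop once the monitored value has failed to improve for `patience` consecutive epochs) and replays it on the `val_loss` values from the log above. It is an illustration of the rule, not Keras's implementation, and the helper name `early_stopping_epoch` is hypothetical.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return (stop_epoch, best_epoch), 1-based, for a 'no improvement for `patience` epochs' rule."""
    best, best_epoch, wait = float('inf'), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch  # rule never triggered

# val_loss per epoch, copied from the training log above
val_loss = [2.1857, 2.0936, 1.9966, 1.8713, 1.7302, 1.6233, 1.5402, 1.4655, 1.3927, 1.3338,
            1.3111, 1.2696, 1.2480, 1.2384, 1.1953, 1.2000, 1.1743, 1.1609, 1.1238, 1.1604,
            1.1888, 1.1563, 1.0900, 1.0872, 1.1148, 1.1397, 1.1255, 1.0955, 1.1030, 1.0869]

print(early_stopping_epoch(val_loss, patience=3))   # (22, 19)
print(early_stopping_epoch(val_loss, patience=10))  # (30, 30): never triggered
```

With `patience=3` training would stop after epoch 22, even though `val_loss` later dips to 1.0869 at epoch 30; on a 200-sample validation set the loss is noisy, so the patience value matters.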
{
"metadata": {
"id": "dljqHfDPI6lH",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
""
]
},
{
"metadata": {
"id": "Mvi9it1SI4aR",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "98b01ef2-3935-442b-82d6-45f56e036d39"
},
"cell_type": "code",
"source": [
"results"
],
"execution_count": 38,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[1.2261371064186095, 0.65]"
]
},
"metadata": {
"tags": []
},
"execution_count": 38
}
]
},
{
"metadata": {
"id": "r3hb8s1l4rBA",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Predictions on Test Data"
]
},
{
"metadata": {
"id": "gudBAhIXJIi2",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"predictions = model.predict(X_test)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Xb7bVPSwJQF0",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "aca09c75-1d21-4847-bdd9-a0521dc8d948"
},
"cell_type": "code",
"source": [
"predictions[0].shape"
],
"execution_count": 26,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(10,)"
]
},
"metadata": {
"tags": []
},
"execution_count": 26
}
]
},
{
"metadata": {
"id": "llusRQV0JRy9",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "a856289d-883a-47cb-c0fb-ec148330a60a"
},
"cell_type": "code",
"source": [
"np.sum(predictions[0])"
],
"execution_count": 27,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1.0"
]
},
"metadata": {
"tags": []
},
"execution_count": 27
}
]
},
{
"metadata": {
"id": "0eoEuSZqJTdU",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "94c17d00-dd7f-40a1-84d2-78d1ebde6103"
},
"cell_type": "code",
"source": [
"np.argmax(predictions[0])"
],
"execution_count": 28,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"8"
]
},
"metadata": {
"tags": []
},
"execution_count": 28
}
]
},
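`np.argmax` returns a class index, not a genre name. Because `LabelEncoder` numbers the classes alphabetically, index 8 corresponds to `reggae` in the 10-genre list; in the notebook, `encoder.inverse_transform(np.argmax(predictions, axis=1))` would decode all test predictions at once. The sketch below does the same decoding with a plain array of names and also extracts the top-3 genres per track, using made-up probability rows.

```python
import numpy as np

# Alphabetical order, i.e. the order LabelEncoder assigns to the 10 GTZAN genres
classes = np.array(['blues', 'classical', 'country', 'disco', 'hiphop',
                    'jazz', 'metal', 'pop', 'reggae', 'rock'])

# Made-up softmax outputs for two tracks (each row sums to 1)
predictions = np.array([
    [0.02, 0.01, 0.05, 0.03, 0.10, 0.02, 0.01, 0.07, 0.60, 0.09],
    [0.05, 0.70, 0.03, 0.01, 0.01, 0.15, 0.01, 0.02, 0.01, 0.01],
])

top1 = classes[np.argmax(predictions, axis=1)]
top3 = classes[np.argsort(predictions, axis=1)[:, ::-1][:, :3]]  # highest probability first

print(top1.tolist())  # ['reggae', 'classical']
print(top3.tolist())  # [['reggae', 'hiphop', 'rock'], ['classical', 'jazz', 'blues']]
```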
{
"metadata": {
"id": "Utgt1bXfJVRN",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}