Forked from parulnith/Music_genre_classification.ipynb
Created November 2, 2019 13:25
Untitled9.ipynb
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Untitled9.ipynb", | |
"version": "0.3.2", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "TPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/parulnith/7f8c174e6ac099e86f0495d3d9a4c01e/untitled9.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "cNnM2w-HCeb1", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"# Music genre classification notebook" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "2l3sppZMCydR", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Importing Libraries" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "Gt3fyg6dCNvX", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"# feature extractoring and preprocessing data\n", | |
"import librosa\n", | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline\n", | |
"import os\n", | |
"from PIL import Image\n", | |
"import pathlib\n", | |
"import csv\n", | |
"\n", | |
"# Preprocessing\n", | |
"from sklearn.model_selection import train_test_split\n", | |
"from sklearn.preprocessing import LabelEncoder, StandardScaler\n", | |
"\n", | |
"#Keras\n", | |
"import keras\n", | |
"\n", | |
"import warnings\n", | |
"warnings.filterwarnings('ignore')" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "DPe_ebYuDqr5", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Extracting music and features\n", | |
"\n", | |
"### Dataset\n", | |
"\n", | |
"We use [GTZAN genre collection](http://marsyasweb.appspot.com/download/data_sets/) dataset for classification. \n", | |
"<br>\n", | |
"<br>\n", | |
"The dataset consists of 10 genres i.e\n", | |
" * Blues\n", | |
" * Classical\n", | |
" * Country\n", | |
" * Disco\n", | |
" * Hiphop\n", | |
" * Jazz\n", | |
" * Metal\n", | |
" * Pop\n", | |
" * Reggae\n", | |
" * Rock\n", | |
" \n", | |
"Each genre contains 100 songs. Total dataset: 1000 songs" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "neqMS0VoDpN5", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "AfBSVfRCD3PE", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Extracting the Spectrogram for every Audio" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "BHh3pTEVDdrT", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"cmap = plt.get_cmap('inferno')\n", | |
"\n", | |
"plt.figure(figsize=(10,10))\n", | |
"genres = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()\n", | |
"for g in genres:\n", | |
" pathlib.Path(f'img_data/{g}').mkdir(parents=True, exist_ok=True) \n", | |
" for filename in os.listdir(f'./MIR/genres/{g}'):\n", | |
" songname = f'./MIR/genres/{g}/{filename}'\n", | |
" y, sr = librosa.load(songname, mono=True, duration=5)\n", | |
" plt.specgram(y, NFFT=2048, Fs=2, Fc=0, noverlap=128, cmap=cmap, sides='default', mode='default', scale='dB');\n", | |
" plt.axis('off');\n", | |
" plt.savefig(f'img_data/{g}/{filename[:-3].replace(\".\", \"\")}.png')\n", | |
" plt.clf()\n", | |
" " | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "SszVgjYnFNX9", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"All the audio files get converted into their respective spectrograms .WE can noe easily extract features from them." | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "3Nw9HpSdFRsW", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "piwUwgP5Eef9", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Extracting features from Spectrogram\n", | |
"\n", | |
"\n", | |
"We will extract\n", | |
"\n", | |
"* Mel-frequency cepstral coefficients (MFCC)(20 in number)\n", | |
"* Spectral Centroid,\n", | |
"* Zero Crossing Rate\n", | |
"* Chroma Frequencies\n", | |
"* Spectral Roll-off." | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "__g8tX8pDeIL", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"header = 'filename chroma_stft rmse spectral_centroid spectral_bandwidth rolloff zero_crossing_rate'\n", | |
"for i in range(1, 21):\n", | |
" header += f' mfcc{i}'\n", | |
"header += ' label'\n", | |
"header = header.split()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "TBlT448pEqR9", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Writing data to csv file\n", | |
"\n", | |
"We write the data to a csv file " | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "ZsSQmB0PE3Iu", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"file = open('data.csv', 'w', newline='')\n", | |
"with file:\n", | |
" writer = csv.writer(file)\n", | |
" writer.writerow(header)\n", | |
"genres = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()\n", | |
"for g in genres:\n", | |
" for filename in os.listdir(f'./MIR/genres/{g}'):\n", | |
" songname = f'./MIR/genres/{g}/{filename}'\n", | |
" y, sr = librosa.load(songname, mono=True, duration=30)\n", | |
" chroma_stft = librosa.feature.chroma_stft(y=y, sr=sr)\n", | |
" spec_cent = librosa.feature.spectral_centroid(y=y, sr=sr)\n", | |
" spec_bw = librosa.feature.spectral_bandwidth(y=y, sr=sr)\n", | |
" rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)\n", | |
" zcr = librosa.feature.zero_crossing_rate(y)\n", | |
" mfcc = librosa.feature.mfcc(y=y, sr=sr)\n", | |
" to_append = f'{filename} {np.mean(chroma_stft)} {np.mean(rmse)} {np.mean(spec_cent)} {np.mean(spec_bw)} {np.mean(rolloff)} {np.mean(zcr)}' \n", | |
" for e in mfcc:\n", | |
" to_append += f' {np.mean(e)}'\n", | |
" to_append += f' {g}'\n", | |
" file = open('data.csv', 'a', newline='')\n", | |
" with file:\n", | |
" writer = csv.writer(file)\n", | |
" writer.writerow(to_append.split())" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "0yfdo1cj6V7d", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"The data has been extracted into a [data.csv](https://github.com/parulnith/Music-Genre-Classification-with-Python/blob/master/data.csv) file." | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "fgeCZSKQEp1A", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"# Analysing the Data in Pandas" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "Kr5_EdpD9dyh", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 253 | |
}, | |
"outputId": "81fd4a29-93fa-44f8-bf90-2f99981f761a" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"data = pd.read_csv('data.csv')\n", | |
"data.head()" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>filename</th>\n", | |
" <th>chroma_stft</th>\n", | |
" <th>rmse</th>\n", | |
" <th>spectral_centroid</th>\n", | |
" <th>spectral_bandwidth</th>\n", | |
" <th>rolloff</th>\n", | |
" <th>zero_crossing_rate</th>\n", | |
" <th>mfcc1</th>\n", | |
" <th>mfcc2</th>\n", | |
" <th>mfcc3</th>\n", | |
" <th>...</th>\n", | |
" <th>mfcc12</th>\n", | |
" <th>mfcc13</th>\n", | |
" <th>mfcc14</th>\n", | |
" <th>mfcc15</th>\n", | |
" <th>mfcc16</th>\n", | |
" <th>mfcc17</th>\n", | |
" <th>mfcc18</th>\n", | |
" <th>mfcc19</th>\n", | |
" <th>mfcc20</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>blues.00081.au</td>\n", | |
" <td>0.380260</td>\n", | |
" <td>0.248262</td>\n", | |
" <td>2116.942959</td>\n", | |
" <td>1956.611056</td>\n", | |
" <td>4196.107960</td>\n", | |
" <td>0.127272</td>\n", | |
" <td>-26.929785</td>\n", | |
" <td>107.334008</td>\n", | |
" <td>-46.809993</td>\n", | |
" <td>...</td>\n", | |
" <td>14.336612</td>\n", | |
" <td>-13.821769</td>\n", | |
" <td>7.562789</td>\n", | |
" <td>-6.181372</td>\n", | |
" <td>0.330165</td>\n", | |
" <td>-6.829571</td>\n", | |
" <td>0.965922</td>\n", | |
" <td>-7.570825</td>\n", | |
" <td>2.918987</td>\n", | |
" <td>blues</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>blues.00022.au</td>\n", | |
" <td>0.306451</td>\n", | |
" <td>0.113475</td>\n", | |
" <td>1156.070496</td>\n", | |
" <td>1497.668176</td>\n", | |
" <td>2170.053545</td>\n", | |
" <td>0.058613</td>\n", | |
" <td>-233.860772</td>\n", | |
" <td>136.170239</td>\n", | |
" <td>3.289490</td>\n", | |
" <td>...</td>\n", | |
" <td>-2.250578</td>\n", | |
" <td>3.959198</td>\n", | |
" <td>5.322555</td>\n", | |
" <td>0.812028</td>\n", | |
" <td>-1.107202</td>\n", | |
" <td>-4.556555</td>\n", | |
" <td>-2.436490</td>\n", | |
" <td>3.316913</td>\n", | |
" <td>-0.608485</td>\n", | |
" <td>blues</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>blues.00031.au</td>\n", | |
" <td>0.253487</td>\n", | |
" <td>0.151571</td>\n", | |
" <td>1331.073970</td>\n", | |
" <td>1973.643437</td>\n", | |
" <td>2900.174130</td>\n", | |
" <td>0.042967</td>\n", | |
" <td>-221.802549</td>\n", | |
" <td>110.843071</td>\n", | |
" <td>18.620984</td>\n", | |
" <td>...</td>\n", | |
" <td>-13.037723</td>\n", | |
" <td>-12.652228</td>\n", | |
" <td>-1.821905</td>\n", | |
" <td>-7.260097</td>\n", | |
" <td>-6.660252</td>\n", | |
" <td>-14.682694</td>\n", | |
" <td>-11.719264</td>\n", | |
" <td>-11.025216</td>\n", | |
" <td>-13.387260</td>\n", | |
" <td>blues</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>blues.00012.au</td>\n", | |
" <td>0.269320</td>\n", | |
" <td>0.119072</td>\n", | |
" <td>1361.045467</td>\n", | |
" <td>1567.804596</td>\n", | |
" <td>2739.625101</td>\n", | |
" <td>0.069124</td>\n", | |
" <td>-207.208080</td>\n", | |
" <td>132.799175</td>\n", | |
" <td>-15.438986</td>\n", | |
" <td>...</td>\n", | |
" <td>-0.613248</td>\n", | |
" <td>0.384877</td>\n", | |
" <td>2.605128</td>\n", | |
" <td>-5.188924</td>\n", | |
" <td>-9.527455</td>\n", | |
" <td>-9.244394</td>\n", | |
" <td>-2.848274</td>\n", | |
" <td>-1.418707</td>\n", | |
" <td>-5.932607</td>\n", | |
" <td>blues</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>blues.00056.au</td>\n", | |
" <td>0.391059</td>\n", | |
" <td>0.137728</td>\n", | |
" <td>1811.076084</td>\n", | |
" <td>2052.332563</td>\n", | |
" <td>3927.809582</td>\n", | |
" <td>0.075480</td>\n", | |
" <td>-145.434568</td>\n", | |
" <td>102.829023</td>\n", | |
" <td>-12.517677</td>\n", | |
" <td>...</td>\n", | |
" <td>7.457218</td>\n", | |
" <td>-10.470444</td>\n", | |
" <td>-2.360483</td>\n", | |
" <td>-6.783624</td>\n", | |
" <td>2.671134</td>\n", | |
" <td>-4.760879</td>\n", | |
" <td>-0.949005</td>\n", | |
" <td>0.024832</td>\n", | |
" <td>-2.005315</td>\n", | |
" <td>blues</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>5 rows × 28 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" filename chroma_stft rmse spectral_centroid \\\n", | |
"0 blues.00081.au 0.380260 0.248262 2116.942959 \n", | |
"1 blues.00022.au 0.306451 0.113475 1156.070496 \n", | |
"2 blues.00031.au 0.253487 0.151571 1331.073970 \n", | |
"3 blues.00012.au 0.269320 0.119072 1361.045467 \n", | |
"4 blues.00056.au 0.391059 0.137728 1811.076084 \n", | |
"\n", | |
" spectral_bandwidth rolloff zero_crossing_rate mfcc1 \\\n", | |
"0 1956.611056 4196.107960 0.127272 -26.929785 \n", | |
"1 1497.668176 2170.053545 0.058613 -233.860772 \n", | |
"2 1973.643437 2900.174130 0.042967 -221.802549 \n", | |
"3 1567.804596 2739.625101 0.069124 -207.208080 \n", | |
"4 2052.332563 3927.809582 0.075480 -145.434568 \n", | |
"\n", | |
" mfcc2 mfcc3 ... mfcc12 mfcc13 mfcc14 mfcc15 \\\n", | |
"0 107.334008 -46.809993 ... 14.336612 -13.821769 7.562789 -6.181372 \n", | |
"1 136.170239 3.289490 ... -2.250578 3.959198 5.322555 0.812028 \n", | |
"2 110.843071 18.620984 ... -13.037723 -12.652228 -1.821905 -7.260097 \n", | |
"3 132.799175 -15.438986 ... -0.613248 0.384877 2.605128 -5.188924 \n", | |
"4 102.829023 -12.517677 ... 7.457218 -10.470444 -2.360483 -6.783624 \n", | |
"\n", | |
" mfcc16 mfcc17 mfcc18 mfcc19 mfcc20 label \n", | |
"0 0.330165 -6.829571 0.965922 -7.570825 2.918987 blues \n", | |
"1 -1.107202 -4.556555 -2.436490 3.316913 -0.608485 blues \n", | |
"2 -6.660252 -14.682694 -11.719264 -11.025216 -13.387260 blues \n", | |
"3 -9.527455 -9.244394 -2.848274 -1.418707 -5.932607 blues \n", | |
"4 2.671134 -4.760879 -0.949005 0.024832 -2.005315 blues \n", | |
"\n", | |
"[5 rows x 28 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 6 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "iHrDHCaR9gKR", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "7d32943a-1ad5-4a59-c13a-beebeb36e4c2" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"data.shape" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(1000, 28)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 7 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "veD5BgX49hZa", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"# Dropping unneccesary columns\n", | |
"data = data.drop(['filename'],axis=1)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "Nyr0aAAsGXjZ", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Encoding the Labels" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "frI5HH4q-1HS", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"genre_list = data.iloc[:, -1]\n", | |
"encoder = LabelEncoder()\n", | |
"y = encoder.fit_transform(genre_list)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "Slm8W0-iGVhI", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "_2n8a02zGfvP", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Scaling the Feature columns" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "uqcqn-nyAofk", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"scaler = StandardScaler()\n", | |
"X = scaler.fit_transform(np.array(data.iloc[:, :-1], dtype = float))" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "e3VZvbwpGo9R", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Dividing data into training and Testing set" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "F1GW3VvQA7Rj", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "upuczQ-KBHJ5", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "1431a28b-e8b6-4db2-e505-7e149e37c0d7" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"len(y_train)" | |
], | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"800" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 12 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "LtoE_FqqBzM8", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "76555a2b-2030-48e1-b52d-d71b4ebae38e" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"len(y_test)" | |
], | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"200" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 13 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "ir9XaWgQB0lq", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 119 | |
}, | |
"outputId": "2ec90814-19d8-4f27-934a-1ce54406d4ea" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"X_train[10]" | |
], | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([-0.9149113 , 0.18294103, -1.10587131, -1.3875197 , -1.14640873,\n", | |
" -0.97232926, -0.29174214, 1.20078936, -0.68458101, -0.55849017,\n", | |
" -1.27056582, -0.88176926, -0.74844069, -0.40970382, 0.49685952,\n", | |
" -1.12666045, 0.59501437, -0.39783853, 0.29327275, -0.72916871,\n", | |
" 0.63015786, -0.91149976, 0.7743942 , -0.64790051, 0.42229852,\n", | |
" -1.01449461])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 14 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "Vp2yc5FWG04e", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"# Classification with Keras\n", | |
"\n", | |
"## Building our Network" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "Qj3sc2uFEUMt", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"from keras import models\n", | |
"from keras import layers\n", | |
"\n", | |
"model = models.Sequential()\n", | |
"model.add(layers.Dense(256, activation='relu', input_shape=(X_train.shape[1],)))\n", | |
"\n", | |
"model.add(layers.Dense(128, activation='relu'))\n", | |
"\n", | |
"model.add(layers.Dense(64, activation='relu'))\n", | |
"\n", | |
"model.add(layers.Dense(10, activation='softmax'))" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "7yrsmpI6EjJ2", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"model.compile(optimizer='adam',\n", | |
" loss='sparse_categorical_crossentropy',\n", | |
" metrics=['accuracy'])" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "bP0hVm4aElS7", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 697 | |
}, | |
"outputId": "aacf234d-d0a9-4de4-91be-5fd45a33b279" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"history = model.fit(X_train,\n", | |
" y_train,\n", | |
" epochs=20,\n", | |
" batch_size=128)\n", | |
" " | |
], | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1/20\n", | |
"800/800 [==============================] - 1s 811us/step - loss: 2.1289 - acc: 0.2400\n", | |
"Epoch 2/20\n", | |
"800/800 [==============================] - 0s 39us/step - loss: 1.7940 - acc: 0.4088\n", | |
"Epoch 3/20\n", | |
"800/800 [==============================] - 0s 37us/step - loss: 1.5437 - acc: 0.4450\n", | |
"Epoch 4/20\n", | |
"800/800 [==============================] - 0s 38us/step - loss: 1.3584 - acc: 0.5413\n", | |
"Epoch 5/20\n", | |
"800/800 [==============================] - 0s 38us/step - loss: 1.2220 - acc: 0.5750\n", | |
"Epoch 6/20\n", | |
"800/800 [==============================] - 0s 41us/step - loss: 1.1187 - acc: 0.6288\n", | |
"Epoch 7/20\n", | |
"800/800 [==============================] - 0s 37us/step - loss: 1.0326 - acc: 0.6550\n", | |
"Epoch 8/20\n", | |
"800/800 [==============================] - 0s 44us/step - loss: 0.9631 - acc: 0.6713\n", | |
"Epoch 9/20\n", | |
"800/800 [==============================] - 0s 47us/step - loss: 0.9143 - acc: 0.6913\n", | |
"Epoch 10/20\n", | |
"800/800 [==============================] - 0s 37us/step - loss: 0.8630 - acc: 0.7125\n", | |
"Epoch 11/20\n", | |
"800/800 [==============================] - 0s 36us/step - loss: 0.8095 - acc: 0.7263\n", | |
"Epoch 12/20\n", | |
"800/800 [==============================] - 0s 37us/step - loss: 0.7728 - acc: 0.7700\n", | |
"Epoch 13/20\n", | |
"800/800 [==============================] - 0s 36us/step - loss: 0.7433 - acc: 0.7563\n", | |
"Epoch 14/20\n", | |
"800/800 [==============================] - 0s 45us/step - loss: 0.7066 - acc: 0.7825\n", | |
"Epoch 15/20\n", | |
"800/800 [==============================] - 0s 43us/step - loss: 0.6718 - acc: 0.7787\n", | |
"Epoch 16/20\n", | |
"800/800 [==============================] - 0s 36us/step - loss: 0.6601 - acc: 0.7913\n", | |
"Epoch 17/20\n", | |
"800/800 [==============================] - 0s 36us/step - loss: 0.6242 - acc: 0.7963\n", | |
"Epoch 18/20\n", | |
"800/800 [==============================] - 0s 44us/step - loss: 0.5994 - acc: 0.8038\n", | |
"Epoch 19/20\n", | |
"800/800 [==============================] - 0s 42us/step - loss: 0.5715 - acc: 0.8125\n", | |
"Epoch 20/20\n", | |
"800/800 [==============================] - 0s 39us/step - loss: 0.5437 - acc: 0.8250\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "0m1J0_wUFK4C", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "ffd3bf36-29ea-437a-987c-9aa600b9dae6" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"test_loss, test_acc = model.evaluate(X_test,y_test)" | |
], | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"200/200 [==============================] - 0s 244us/step\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "f6HrjXeUF0Ko", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "ea282dbd-6f9e-48c7-de2d-dc9afde8949e" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"print('test_acc: ',test_acc)" | |
], | |
"execution_count": 21, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"test_acc: 0.68\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "3yQmP_f5Kq0w", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"Tes accuracy is less than training dataa accuracy. This hints at Overfitting" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "-U2qzRJoHV9O", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Validating our approach\n", | |
"Let's set apart 200 samples in our training data to use as a validation set:" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "xJNbvYZoF7ZT", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"x_val = X_train[:200]\n", | |
"partial_x_train = X_train[200:]\n", | |
"\n", | |
"y_val = y_train[:200]\n", | |
"partial_y_train = y_train[200:]" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "L1EkG59EHeEV", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"Now let's train our network for 20 epochs:" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "Dp3G4P3aP4k2", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1071 | |
}, | |
"outputId": "25e1a389-1ac2-425b-bd5f-05736b6e9b96" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"model = models.Sequential()\n", | |
"model.add(layers.Dense(512, activation='relu', input_shape=(X_train.shape[1],)))\n", | |
"model.add(layers.Dense(256, activation='relu'))\n", | |
"model.add(layers.Dense(128, activation='relu'))\n", | |
"model.add(layers.Dense(64, activation='relu'))\n", | |
"model.add(layers.Dense(10, activation='softmax'))\n", | |
"\n", | |
"model.compile(optimizer='adam',\n", | |
" loss='sparse_categorical_crossentropy',\n", | |
" metrics=['accuracy'])\n", | |
"\n", | |
"model.fit(partial_x_train,\n", | |
" partial_y_train,\n", | |
" epochs=30,\n", | |
" batch_size=512,\n", | |
" validation_data=(x_val, y_val))\n", | |
"results = model.evaluate(X_test, y_test)" | |
], | |
"execution_count": 37, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Train on 600 samples, validate on 200 samples\n", | |
"Epoch 1/30\n", | |
"600/600 [==============================] - 1s 1ms/step - loss: 2.3074 - acc: 0.0950 - val_loss: 2.1857 - val_acc: 0.2850\n", | |
"Epoch 2/30\n", | |
"600/600 [==============================] - 0s 65us/step - loss: 2.1126 - acc: 0.3783 - val_loss: 2.0936 - val_acc: 0.2400\n", | |
"Epoch 3/30\n", | |
"600/600 [==============================] - 0s 59us/step - loss: 1.9535 - acc: 0.3633 - val_loss: 1.9966 - val_acc: 0.2600\n", | |
"Epoch 4/30\n", | |
"600/600 [==============================] - 0s 58us/step - loss: 1.8082 - acc: 0.3833 - val_loss: 1.8713 - val_acc: 0.3250\n", | |
"Epoch 5/30\n", | |
"600/600 [==============================] - 0s 59us/step - loss: 1.6663 - acc: 0.4083 - val_loss: 1.7302 - val_acc: 0.3450\n", | |
"Epoch 6/30\n", | |
"600/600 [==============================] - 0s 52us/step - loss: 1.5329 - acc: 0.4550 - val_loss: 1.6233 - val_acc: 0.3700\n", | |
"Epoch 7/30\n", | |
"600/600 [==============================] - 0s 62us/step - loss: 1.4236 - acc: 0.4850 - val_loss: 1.5402 - val_acc: 0.3950\n", | |
"Epoch 8/30\n", | |
"600/600 [==============================] - 0s 57us/step - loss: 1.3250 - acc: 0.5117 - val_loss: 1.4655 - val_acc: 0.3800\n", | |
"Epoch 9/30\n", | |
"600/600 [==============================] - 0s 52us/step - loss: 1.2338 - acc: 0.5633 - val_loss: 1.3927 - val_acc: 0.4650\n", | |
"Epoch 10/30\n", | |
"600/600 [==============================] - 0s 61us/step - loss: 1.1577 - acc: 0.5983 - val_loss: 1.3338 - val_acc: 0.5500\n", | |
"Epoch 11/30\n", | |
"600/600 [==============================] - 0s 64us/step - loss: 1.0981 - acc: 0.6317 - val_loss: 1.3111 - val_acc: 0.5550\n", | |
"Epoch 12/30\n", | |
"600/600 [==============================] - 0s 52us/step - loss: 1.0529 - acc: 0.6517 - val_loss: 1.2696 - val_acc: 0.5400\n", | |
"Epoch 13/30\n", | |
"600/600 [==============================] - 0s 52us/step - loss: 0.9994 - acc: 0.6567 - val_loss: 1.2480 - val_acc: 0.5400\n", | |
"Epoch 14/30\n", | |
"600/600 [==============================] - 0s 65us/step - loss: 0.9673 - acc: 0.6633 - val_loss: 1.2384 - val_acc: 0.5700\n", | |
"Epoch 15/30\n", | |
"600/600 [==============================] - 0s 58us/step - loss: 0.9286 - acc: 0.6633 - val_loss: 1.1953 - val_acc: 0.5800\n", | |
"Epoch 16/30\n", | |
"600/600 [==============================] - 0s 59us/step - loss: 0.8849 - acc: 0.6783 - val_loss: 1.2000 - val_acc: 0.5550\n", | |
"Epoch 17/30\n", | |
"600/600 [==============================] - 0s 61us/step - loss: 0.8621 - acc: 0.6850 - val_loss: 1.1743 - val_acc: 0.5850\n", | |
"Epoch 18/30\n", | |
"600/600 [==============================] - 0s 61us/step - loss: 0.8195 - acc: 0.7150 - val_loss: 1.1609 - val_acc: 0.5750\n", | |
"Epoch 19/30\n", | |
"600/600 [==============================] - 0s 62us/step - loss: 0.7976 - acc: 0.7283 - val_loss: 1.1238 - val_acc: 0.6150\n", | |
"Epoch 20/30\n", | |
"600/600 [==============================] - 0s 63us/step - loss: 0.7660 - acc: 0.7650 - val_loss: 1.1604 - val_acc: 0.5850\n", | |
"Epoch 21/30\n", | |
"600/600 [==============================] - 0s 65us/step - loss: 0.7465 - acc: 0.7650 - val_loss: 1.1888 - val_acc: 0.5700\n", | |
"Epoch 22/30\n", | |
"600/600 [==============================] - 0s 65us/step - loss: 0.7099 - acc: 0.7517 - val_loss: 1.1563 - val_acc: 0.6050\n", | |
"Epoch 23/30\n", | |
"600/600 [==============================] - 0s 68us/step - loss: 0.6857 - acc: 0.7683 - val_loss: 1.0900 - val_acc: 0.6200\n", | |
"Epoch 24/30\n", | |
"600/600 [==============================] - 0s 67us/step - loss: 0.6597 - acc: 0.7850 - val_loss: 1.0872 - val_acc: 0.6300\n", | |
"Epoch 25/30\n", | |
"600/600 [==============================] - 0s 67us/step - loss: 0.6377 - acc: 0.7967 - val_loss: 1.1148 - val_acc: 0.6200\n", | |
"Epoch 26/30\n", | |
"600/600 [==============================] - 0s 64us/step - loss: 0.6070 - acc: 0.8200 - val_loss: 1.1397 - val_acc: 0.6150\n", | |
"Epoch 27/30\n", | |
"600/600 [==============================] - 0s 66us/step - loss: 0.5991 - acc: 0.8167 - val_loss: 1.1255 - val_acc: 0.6300\n", | |
"Epoch 28/30\n", | |
"600/600 [==============================] - 0s 62us/step - loss: 0.5656 - acc: 0.8333 - val_loss: 1.0955 - val_acc: 0.6350\n", | |
"Epoch 29/30\n", | |
"600/600 [==============================] - 0s 66us/step - loss: 0.5513 - acc: 0.8300 - val_loss: 1.1030 - val_acc: 0.6050\n", | |
"Epoch 30/30\n", | |
"600/600 [==============================] - 0s 56us/step - loss: 0.5498 - acc: 0.8233 - val_loss: 1.0869 - val_acc: 0.6250\n", | |
"200/200 [==============================] - 0s 65us/step\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "dljqHfDPI6lH", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "Mvi9it1SI4aR", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "98b01ef2-3935-442b-82d6-45f56e036d39" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"results" | |
], | |
"execution_count": 38, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[1.2261371064186095, 0.65]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 38 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "r3hb8s1l4rBA", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Predictions on Test Data" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "gudBAhIXJIi2", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"predictions = model.predict(X_test)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "Xb7bVPSwJQF0", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "aca09c75-1d21-4847-bdd9-a0521dc8d948" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"predictions[0].shape" | |
], | |
"execution_count": 26, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(10,)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 26 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "llusRQV0JRy9", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "a856289d-883a-47cb-c0fb-ec148330a60a" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"np.sum(predictions[0])" | |
], | |
"execution_count": 27, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"1.0" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 27 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "0eoEuSZqJTdU", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "94c17d00-dd7f-40a1-84d2-78d1ebde6103" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"np.argmax(predictions[0])" | |
], | |
"execution_count": 28, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"8" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 28 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "Utgt1bXfJVRN", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
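The last cells of the notebook return a class index (`8` for the first test sample) rather than a genre name. `LabelEncoder` assigns integer codes in sorted order of the class names, so the index can be mapped back to a genre. The following is a minimal, library-free sketch of that mapping and is not part of the original notebook: `CLASSES`, `decode_prediction`, and the example probabilities are hypothetical names and made-up values used only for illustration. In the notebook itself, `encoder.inverse_transform([np.argmax(predictions[0])])` should give the same label.

```python
# Genre names as listed in the notebook. LabelEncoder assigns integer codes
# in sorted (alphabetical) order of the class names, which sorted() mirrors.
GENRES = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()
CLASSES = sorted(GENRES)


def decode_prediction(probabilities):
    """Return (genre, probability) for the highest-scoring class."""
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return CLASSES[best], probabilities[best]


# Hypothetical softmax output for one clip (made-up numbers, not notebook results).
example = [0.02, 0.01, 0.05, 0.03, 0.04, 0.02, 0.01, 0.02, 0.75, 0.05]
genre, prob = decode_prediction(example)
print(genre, prob)
```

With the genre list above, index 8 corresponds to `reggae`, which is how the `np.argmax(predictions[0])` result in the notebook can be read.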