Created
December 15, 2013 20:17
-
-
Save pckujawa/7977590 to your computer and use it in GitHub Desktop.
Simple audio classifier (speech vs. music) using scikit-learn's Naive Bayes classifier. Made for a Multimedia Processing course.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" usage: | |
a4.py train TRAIN_FEATURE_FILE [--new] [--validate] | |
a4.py classify MUSIC_FEATURE_FILE | |
The files should be CSV with 6 columns, the last of which is the target/label/class (or empty, if classifying), and the first of which is ignored. | |
""" | |
#------------------------------------------------------------------------------- | |
# Name: Pat Kujawa | |
# Purpose: MM audio classification asn 4 | |
#------------------------------------------------------------------------------- | |
from __future__ import division | |
import os, sys | |
import docopt | |
import numpy as np | |
import cPickle as pickle | |
from sklearn.naive_bayes import GaussianNB | |
from sklearn import cross_validation | |
from sklearn.metrics import confusion_matrix | |
from sklearn.metrics import classification_report | |
from sklearn.cross_validation import cross_val_score | |
# File the trained classifier is pickled to and loaded from (relative path).
picklePath = r"classifier.pickle"
# Display names indexed by the boolean "is music" label.
target_names = ['speech', 'music'] # false, true
def preProcess(csvFile, classifying=False):
    """Load file names, features and (optionally) labels from a CSV file.

    Column 0 is read as a string name, columns 1-4 as float features and,
    unless ``classifying`` is set, column 5 as the boolean "is music" label.
    The file is parsed once per column group with ``np.genfromtxt``.

    :param csvFile: path (or file-like object) of the comma-separated file
    :param classifying: when true, the label column is not read
    :returns: tuple ``(data, names, targets)``; ``targets`` is ``None`` when
        ``classifying`` is true, otherwise a bool array (True means music)
    """
    names = np.genfromtxt(csvFile, delimiter=',', usecols=0, dtype=str)
    data = np.genfromtxt(csvFile, delimiter=',', usecols=(1, 2, 3, 4))
    targets = None
    if not classifying:
        # bool ismusic
        targets = np.genfromtxt(csvFile, delimiter=',', usecols=5, dtype=bool)
    return data, names, targets
def _save_classifier(classifier):
    """Pickle ``classifier`` to ``picklePath``; report failures on stderr."""
    try:
        with open(picklePath, 'wb') as f:
            pickle.dump(classifier, f, protocol=pickle.HIGHEST_PROTOCOL)
    except Exception:
        sys.stderr.write("Error persisting classifier to file. Are you in a protected directory\n")


def train(data, names, targets, startNew=False, cv=False):
    """Fit a Gaussian Naive Bayes classifier and serialize it to ``picklePath``.

    Without ``cv`` the classifier is fit on every row. With ``cv`` the rows
    are randomly split (``test_size=0.33``), the classifier is fit on the
    training part, and accuracy, a confusion matrix, a classification report
    and leave-one-out cross-validation scores are printed. In both cases the
    fitted classifier is pickled afterwards.

    :param data: feature rows, one per file
    :param names: file names aligned with ``data``; used only for reporting
    :param targets: boolean labels aligned with ``data`` (True means music)
    :param startNew: create a new classifier if true, else reuse the
        previously pickled one when it can be loaded
    :param cv: do cross-validation with a subset of items
    :returns: an empty string (the caller prints the return value)
    """
    classifier = None
    if not startNew:
        try:
            # NOTE: unpickling can execute arbitrary code; only load a
            # classifier file this script wrote itself.
            with open(picklePath, 'rb') as f:
                classifier = pickle.load(f)
        except Exception:
            sys.stderr.write("Couldn't deserialize classifier. Creating a new one instead \n")

    # From DZone.com refcard: Data Mining - Discovering and Visualizing
    # Patterns with Python by Giuseppe Vettigli
    if classifier is None:
        classifier = GaussianNB()
    # NOTE(review): ``fit`` presumably re-estimates the model from scratch,
    # so a reloaded classifier is replaced rather than extended -- confirm
    # whether ``partial_fit`` was intended for the "add to training" case.

    if not cv:
        classifier.fit(data, targets)  # training
        print('Trained on all files: ' + ','.join(names))
        # Persist here too; otherwise a plain training run leaves nothing
        # for the classify command to load.
        _save_classifier(classifier)
        return ''

    # t_ means target, as in expected/desired classification
    features_train, features_test, t_train, t_test, trainFiles, testFiles = \
        cross_validation.train_test_split(data, targets, names, test_size=0.33)

    # show which files are used for train/test
    print('Training files: ' + ','.join(trainFiles))
    print('Test files: ' + ','.join(testFiles))

    classifier.fit(features_train, t_train)  # train
    print('Prior probabilities (n={}):'.format(len(trainFiles)))
    for cls, prob in zip(classifier.classes_, classifier.class_prior_):
        # int() because the class labels are numpy booleans
        print('%s %s' % (target_names[int(cls)], prob))

    print("Accuracy for 2/3 training, 1/3 test:")
    print(classifier.score(features_test, t_test))  # test

    # The metrics take (y_true, y_pred); swapping them transposes the matrix
    # and exchanges precision with recall.
    predicted = classifier.predict(features_test)
    print("Confusion matrix for 2/3 training, 1/3 test:")
    print(confusion_matrix(t_test, predicted))
    print('Classification report for 2/3 training, 1/3 test:')
    print(classification_report(t_test, predicted, target_names=target_names))

    print('leave one out cv')
    # cross validation with leave one out
    # http://stackoverflow.com/questions/17499068/train-scikit-svm-customize-score-assessment
    scores = cross_val_score(classifier, data, targets,
                             cv=cross_validation.LeaveOneOut(len(targets)))
    print('%s %s / %s = %s' % (scores, np.sum(scores), len(scores), np.mean(scores)))

    _save_classifier(classifier)
    return ''
def classify(data):
    """Predict the class of one feature row with the pickled classifier.

    :param data: 1-D array holding the features of a single file
    :returns: the matching entry of ``target_names``, or ``None`` when the
        classifier cannot be loaded from ``picklePath``
    """
    assert data.ndim == 1, "classify expects the features of exactly one file"
    try:
        # NOTE: unpickling can execute arbitrary code; only load a classifier
        # file this script wrote itself.
        with open(picklePath, 'rb') as f:
            classifier = pickle.load(f)
    except Exception:
        sys.stderr.write("Error: no classifier found. Need to train first.\n")
        return None
    # Present the single sample as one row of shape (1, n_features).
    result = classifier.predict(data.reshape(1, -1))
    # int() because the prediction is a numpy boolean, not a plain index
    return target_names[int(result[0])]
def main():
    """Parse the command line with docopt and run the requested sub-command.

    ``train`` loads the feature file and trains the classifier; ``classify``
    loads the feature file and prints the predicted class.
    """
    args = docopt.docopt(__doc__, options_first=False)
    if args['train']:
        data, names, targets = preProcess(args['TRAIN_FEATURE_FILE'])
        print(train(data, names, targets,
                    startNew=args['--new'], cv=args['--validate']))
    elif args['classify']:
        # The label column may be empty when classifying, so do not parse it.
        data = preProcess(args['MUSIC_FEATURE_FILE'], classifying=True)[0]
        print(classify(data))
# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
sp4.wav,0.106326512992382,0.1419375,600.216506027543,1000, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
mu1.wav,0.0135809620842338,0.245192307692308,4090.59341438305,8000,True | |
mu2.wav,0.00968006905168295,0.0965171330802044,3416.33785202295,7781.25,True | |
mu3.wav,0.00785261858254671,0.0928920764386943,1335.16906731851,2781.25,True | |
mu4.wav,0.00693985680118203,0.114189284207566,1938.03139461655,3218.75,True | |
mu5.wav,0.000804243725724518,0.100921875,2862.56607992312,7593.75,True | |
mu6.wav,0.0110080037266016,0.0568290129533274,1752.79206161524,3968.75,True | |
mu7.wav,0.000471108447527513,0.128861388459195,1398.09072062083,4593.75,True | |
mu8.wav,0.000866658810991794,0.206387362637363,3057.29499579522,5218.75,True | |
mu9.wav,0.00222460692748427,0.0922671078921079,2056.32778945909,6687.5,True | |
mu10.wav,0.00612919591367245,0.124267566680729,749.351732128472,750,True | |
mu11.wav,0.00639878120273352,0.0616570929070929,1104.58139795342,2312.5,True | |
mu12.wav,0.0103937992826104,0.235546875,2952.09057466559,6875,True | |
mu13.wav,0.0168704781681299,0.0681885654463351,1508.61129664485,6093.75,True | |
mu14.wav,0.0366614460945129,0.0399224987890436,609.916803729763,1437.5,True | |
mu15.wav,0.00200785440392792,0.111,967.358939286494,2125,True | |
mu16.wav,0.0172653328627348,0.13803125,2137.19411839309,4750,True | |
mu17.wav,0.0767187625169754,0.09871875,3026.17988619235,7812.5,True | |
mu18.wav,0.00394412688910961,0.13559375,2234.4520504372,5500,True | |
mu19.wav,0.00181142438668758,0.16890625,1762.25736668424,4000,True | |
mu20.wav,0.000672354304697365,0.12184375,2408.25157117664,5687.5,True | |
sp1.wav,0.000528156640939415,0.15546875,1030.89509309957,1656.25,False | |
sp2.wav,0.000454288354376331,0.1216875,1226.68622924213,4093.75,False | |
sp3.wav,0.000502355920616537,0.13159375,1491.06108518104,3062.5,False | |
sp4.wav,0.106326512992382,0.1419375,600.216506027543,1000,False | |
sp5.wav,0.000719425734132528,0.1834375,1203.41159032112,1906.25,False | |
sp6.wav,0.000317843368975446,0.0578484015984016,615.617688734926,1343.75,False | |
sp7.wav,0.00830729119479656,0.08121875,1311.42313711464,3281.25,False | |
sp8.wav,0.0164563357830048,0.073848026973027,433.017009278009,937.5,False | |
sp9.wav,0.150312662124634,0.103109375,534.844251956482,468.75,False | |
sp10.wav,0.139020338654518,0.143109375,612.676843957936,1062.5,False | |
sp11.wav,0.000554115045815706,0.135296875,1498.60175536014,2062.5,False | |
sp12.wav,0.00710451928898692,0.172734375,4233.96991563503,7968.75,False | |
sp13.wav,0.177054643630981,0.09575,4740.70116489936,8000,False | |
sp14.wav,0.00660737184807658,0.129464285714286,861.62212786018,4031.25,False | |
sp15.wav,0.00381537573412061,0.1440625,1136.02801116168,2250,False | |
sp16.wav,0.035685483366251,0.203984375,2065.26130383625,5375,False | |
sp17.wav,0.061883706599474,0.20421875,4407.30136680859,7656.25,False | |
sp18.wav,0.00359855033457279,0.0605176073926074,944.371410860124,3187.5,False | |
sp19.wav,0.00341002829372883,0.115993381618382,1400.15423227314,1875,False | |
sp20.wav,0.0158364344388247,0.0796391108891109,625.714886130722,1312.5,False |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment