Created
January 31, 2015 04:13
-
-
Save mathias-brandewinder/0685734f8f47d811356e to your computer and use it in GitHub Desktop.
Accord.NET multiclass SVM on the digit recognizer problem
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Uses Accord.NET version 2.14.0 | |
#r @"..\packages\Accord.2.14.0\lib\net40\Accord.dll" | |
#r @"..\packages\Accord.Math.2.14.0\lib\net40\Accord.Math.dll" | |
#r @"..\packages\Accord.Statistics.2.14.0\lib\net40\Accord.Statistics.dll" | |
#r @"..\packages\Accord.MachineLearning.2.14.0\lib\net40\Accord.MachineLearning.dll" | |
open System | |
open System.IO | |
open Accord.MachineLearning | |
open Accord.MachineLearning.VectorMachines | |
open Accord.MachineLearning.VectorMachines.Learning | |
open Accord.Statistics.Kernels | |
(*
The dataset used here is a subset of the Kaggle digit recognizer data.
Download both files below first, into the same folder as this script,
or adjust the paths that follow accordingly.
Training set of 5,000 examples:
http://brandewinder.blob.core.windows.net/public/trainingsample.csv
Validation set of 500 examples, to test your model:
http://brandewinder.blob.core.windows.net/public/validationsample.csv
*)
// The data files are expected in the same directory as this script.
let root = __SOURCE_DIRECTORY__
// Path.Combine inserts the platform's directory separator instead of
// hard-coding "/" through string concatenation.
let training = Path.Combine(root, "trainingsample.csv")
let validation = Path.Combine(root, "validationsample.csv")
/// Reads a comma-separated digit file and splits it into labels and features.
/// The first line is treated as a header and skipped. Blank lines are ignored.
/// On every other line, the first field is parsed as the integer label and the
/// remaining fields as the float feature values of that example.
/// Returns a pair (labels, observations) of arrays with matching indices.
/// Raises FormatException if a field is not a valid number.
let readData filePath =
    let lines = File.ReadAllLines filePath
    // Drop the header row; slicing an empty array yields an empty array.
    lines.[1..]
    // A blank line (e.g. a trailing empty line) would otherwise fail to parse.
    |> Array.filter (fun line -> not (String.IsNullOrWhiteSpace line))
    |> Array.map (fun line ->
        let fields = line.Split(',')
        // F#'s int/float parse culture-invariantly, unlike Convert.ToXxx.
        int fields.[0], Array.map float fields.[1..])
    |> Array.unzip
// Load the training set; these two arrays are used to train the SVM below.
let (labels, observations) = training |> readData
(*
Note that even though this dataset is small, loading the data
is an expensive part of the process.
*)
/// Length of each input vector: one value per cell of a 28 x 28 image.
let features = 784
/// Number of target classes (digits 0 through 9).
let classes = 10
/// Learning-strategy factory handed to the multiclass learner. It is called
/// with a binary machine, its training inputs and outputs, and two class
/// indices, which are not needed here. It returns a Sequential Minimal
/// Optimization learner whose complexity is fixed at 0.9.
let algorithm (svm: KernelSupportVectorMachine)
              (classInputs: float[][])
              (classOutputs: int[])
              (_class1: int)
              (_class2: int) =
    let smo =
        SequentialMinimalOptimization(svm, classInputs, classOutputs, Complexity = 0.9)
    smo :> ISupportVectorMachineLearning
// Multiclass SVM over all pixel features, using a linear kernel.
let kernel = Linear()
let svm = new MulticlassSupportVectorMachine(features, kernel, classes)
// Wrap the strategy factory in the delegate type the learner expects.
let config = SupportVectorMachineLearningConfigurationFunction(algorithm)
let learner = MulticlassSupportVectorLearning(svm, observations, labels, Algorithm = config)
// Run() performs the training; its result is reported as the error.
let error = learner.Run()
printfn "Error: %f" error
(*
Are we done yet? Not quite.
The real test of the model is how it handles data
it has never seen before; that is what the validation set is for.
*)
let validationLabels, validationObservations = readData validation

/// Fraction (0.0 to 1.0) of validation examples whose predicted class equals
/// the true label, i.e. the accuracy on unseen data.
/// Array.average raises ArgumentException if the validation set is empty.
let correct =
    (validationLabels, validationObservations)
    ||> Array.map2 (fun label observation ->
        if svm.Compute(observation) = label then 1.0 else 0.0)
    |> Array.average

// Report held-out accuracy; without this the result is lost when the
// script is run non-interactively.
printfn "Correct: %f" correct
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment