Last active
January 8, 2024 05:19
-
-
Save mathias-brandewinder/6443302 to your computer and use it in GitHub Desktop.
Experimenting with Accord SVM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#r @"..\packages\Accord.2.8.1.0\lib\Accord.dll"
#r @"..\packages\Accord.Math.2.8.1.0\lib\Accord.Math.dll"
#r @"..\packages\Accord.Statistics.2.8.1.0\lib\Accord.Statistics.dll"
#r @"..\packages\Accord.MachineLearning.2.8.1.0\lib\Accord.MachineLearning.dll"

open System
open System.Globalization
open System.IO
open Accord.MachineLearning
open Accord.MachineLearning.VectorMachines
open Accord.MachineLearning.VectorMachines.Learning
open Accord.Statistics.Kernels
(*
The dataset used here is a subset of the Kaggle digit recognizer data;
download it to your machine first, and adjust the file paths below accordingly.
Training set of 5,000 examples:
http://brandewinder.blob.core.windows.net/public/trainingsample.csv
Validation set of 500 examples, to test your model:
http://brandewinder.blob.core.windows.net/public/validationsample.csv
*)
/// Local path of the training CSV file; edit to match where the file was saved.
let training : string =
    @"C:/users/mathias/desktop/dojosample/trainingsample.csv"

/// Local path of the validation CSV file; edit to match where the file was saved.
let validation : string =
    @"C:/users/mathias/desktop/dojosample/validationsample.csv"
/// Reads a comma-separated data file and splits it into labels and feature rows.
///
/// The first line of the file is treated as a header and dropped. Blank lines
/// are skipped. In each remaining line, the first field is parsed as the
/// integer label and every following field as a float feature.
///
/// Numbers are parsed with the invariant culture, so the result does not
/// depend on the regional settings of the machine running the script.
///
/// Returns a pair of arrays (labels, observations), both in file order.
/// Raises FormatException when a field is not a valid number; exceptions
/// raised by File.ReadAllLines propagate unchanged.
let readData filePath =
    // Turns one non-blank data line into (label, features).
    let parseLine (line: string) =
        let fields = line.Split(',')
        let label = Int32.Parse(fields.[0], CultureInfo.InvariantCulture)
        let values =
            fields.[1..]
            |> Array.map (fun field -> Double.Parse(field, CultureInfo.InvariantCulture))
        label, values
    File.ReadAllLines filePath
    |> fun lines -> lines.[1..]
    // A stray empty line (e.g. an extra newline at the end) carries no data.
    |> Array.filter (fun line -> not (String.IsNullOrWhiteSpace line))
    |> Array.map parseLine
    |> Array.unzip
/// Number of inputs per observation (28 * 28; presumably one per pixel of a
/// 28 x 28 digit image - confirm against the data file).
let features = 28 * 28

/// Number of classes the machine separates (presumably the digits 0 to 9).
let classes = 10

// Training set: integer labels and the matching feature vectors.
let labels, observations = training |> readData
/// Builds the learning algorithm used for one binary sub-machine.
///
/// Wrapped in a SupportVectorMachineLearningConfigurationFunction and assigned
/// to the multiclass learner further down. Takes the binary machine to train
/// plus the inputs and outputs of its sub-problem. The trailing `i` and `j`
/// arguments are part of the expected shape but are not used here
/// (presumably the indices of the two classes being separated).
let algorithm (svm: KernelSupportVectorMachine)
              (classInputs: float[][])
              (classOutputs: int[])
              (i: int)
              (j: int) : ISupportVectorMachineLearning =
    let smo = SequentialMinimalOptimization(svm, classInputs, classOutputs)
    // Expose the concrete strategy through the interface type.
    smo :> ISupportVectorMachineLearning
// Kernel and multiclass machine sized for `features` inputs and `classes` classes.
let kernel = Linear()
let svm = new MulticlassSupportVectorMachine(features, kernel, classes)

// Learner bound to the training observations and their labels; each
// sub-machine is trained with the strategy produced by `algorithm`.
let learner = MulticlassSupportVectorLearning(svm, observations, labels)
let config = SupportVectorMachineLearningConfigurationFunction(algorithm)
learner.Algorithm <- config

// Run the training and report the value the learner returns.
let error = learner.Run()
error |> printfn "Error: %f"
// Validation set, read with the same layout as the training set.
let validationLabels, validationObservations = validation |> readData

/// Fraction of validation examples for which the machine's output equals the
/// true label (1.0 means every example was classified correctly).
let correct =
    Array.zip validationLabels validationObservations
    |> Array.averageBy (fun (label, input) ->
        if label = svm.Compute(input) then 1. else 0.)
/// Prints the true label next to the machine's prediction for the first
/// validation examples (at most 21), as a quick visual check.
/// Evaluated once when the script runs; `view` itself is unit.
let view =
    Array.zip validationLabels validationObservations
    // Seq.truncate yields at most 21 items and never fails on a shorter
    // validation set, unlike a fixed slice x.[..20], which can raise on
    // compilers without tolerant slicing.
    |> Seq.truncate 21
    |> Seq.iter (fun (label, input) ->
        printfn "Real: %i, predicted: %i" label (svm.Compute(input)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment