Created
November 2, 2022 01:58
Binary Classification Cross Validation with Platt Calibrator
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Install packages" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div><div></div><div></div><div><strong>Installed Packages</strong><ul><li><span>Microsoft.Data.Analysis, 0.20.0-preview.22551.1</span></li><li><span>Microsoft.ML.AutoML, 0.20.0-preview.22551.1</span></li></ul></div></div>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/markdown": [ | |
"Loading extensions from `SkiaSharp.DotNet.Interactive.dll`" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/markdown": [ | |
"Loading extensions from `Microsoft.Data.Analysis.Interactive.dll`" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/markdown": [ | |
"Loading extensions from `Microsoft.ML.AutoML.Interactive.dll`" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"#r \"nuget: Microsoft.ML.AutoML, 0.20.0-preview.22551.1\"\n", | |
"#r \"nuget: Microsoft.Data.Analysis, 0.20.0-preview.22551.1\"" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Add using statements" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"using System;\n", | |
"using Microsoft.ML;\n", | |
"using Microsoft.Data.Analysis;\n", | |
"using Microsoft.ML.Data;\n", | |
"using Microsoft.ML.Transforms;" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Define data path" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var dataPath = @\"C:\\Datasets\\yelp_labelled.txt\";" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Initialize MLContext" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var ctx = new MLContext();" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Load data into DataFrame" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var data = DataFrame.LoadCsv(dataPath,separator:'\\t', header:false, columnNames:new [] {\"Review\",\"Label\"});" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Preview first 5 rows" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table id=\"table_638029365145485375\"><thead><tr><th><i>index</i></th><th>Review</th><th>Label</th></tr></thead><tbody><tr><td><i><div class=\"dni-plaintext\">0</div></i></td><td>Wow... Loved this place.</td><td><div class=\"dni-plaintext\">1</div></td></tr><tr><td><i><div class=\"dni-plaintext\">1</div></i></td><td>Crust is not good.</td><td><div class=\"dni-plaintext\">0</div></td></tr><tr><td><i><div class=\"dni-plaintext\">2</div></i></td><td>Not tasty and the texture was just nasty.</td><td><div class=\"dni-plaintext\">0</div></td></tr><tr><td><i><div class=\"dni-plaintext\">3</div></i></td><td>Stopped by during the late May bank holiday off Rick Steve recommendation and loved it.</td><td><div class=\"dni-plaintext\">1</div></td></tr><tr><td><i><div class=\"dni-plaintext\">4</div></i></td><td>The selection on the menu was great and so were the prices.</td><td><div class=\"dni-plaintext\">1</div></td></tr></tbody></table>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"data.Head(5)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Define training pipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var pipeline = \n", | |
" ctx.Transforms.Text.FeaturizeText(outputColumnName: \"Features\", inputColumnName: \"Review\")\n", | |
" .Append(ctx.Transforms.Conversion.ConvertType(outputColumnName:\"Label\",inputColumnName:\"Label\",outputKind:DataKind.Boolean))\n", | |
" .Append(ctx.BinaryClassification.Trainers.FastTree())\n", | |
" .Append(ctx.BinaryClassification.Calibrators.Platt());" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Train with cross validation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var cvResults = ctx.BinaryClassification.CrossValidate(data, pipeline, numberOfFolds: 10);" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Get model for fold with top accuracy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var topModel = \n", | |
" cvResults\n", | |
" .OrderByDescending(fold => fold.Metrics.Accuracy)\n", | |
" .Select(fold => fold.Model)\n", | |
" .First();" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Use model to make predictions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var predictions = topModel.Transform(data);" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Display prediction IDataView schema" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table><thead><tr><th><i>index</i></th><th>Name</th><th>Index</th><th>IsHidden</th><th>Type</th><th>Annotations</th></tr></thead><tbody><tr><td>0</td><td>Review</td><td><div class=\"dni-plaintext\">0</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.ReadOnlyMemory<System.Char></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>1</td><td>Label</td><td><div class=\"dni-plaintext\">1</div></td><td><div class=\"dni-plaintext\">True</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Single</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>2</td><td>Label</td><td><div class=\"dni-plaintext\">2</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Boolean</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>3</td><td>Features</td><td><div class=\"dni-plaintext\">3</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 13307 ]</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">{ Single: RawType: System.Single }</div></td><td><div class=\"dni-plaintext\">13307</div></td><td><div 
class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer<System.Single></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ { IsNormalized: Boolean: Name: IsNormalized, Index: 0, IsHidden: False, Type: { Boolean: RawType: System.Boolean }, Annotations: { : Schema: [ ] } }, { SlotNames: Vector<String, 13307>: Name: SlotNames, Index: 1, IsHidden: False, Type: { Vector<String, 13307>: Dimensions: [ 13307 ], IsKnownSize: True, ItemType: { String: RawType: System.ReadOnlyMemory<System.Char> }, Size: 13307, RawType: Microsoft.ML.Data.VBuffer<System.ReadOnlyMemory<System.Char>> }, Annotations: { : Schema: [ ] } } ]</div></td></tr></tbody></table></td></tr><tr><td>4</td><td>PredictedLabel</td><td><div class=\"dni-plaintext\">4</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Boolean</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ { ScoreColumnKind: String: Name: ScoreColumnKind, Index: 0, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory<System.Char> }, Annotations: { : Schema: [ ] } }, { ScoreValueKind: String: Name: ScoreValueKind, Index: 1, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory<System.Char> }, Annotations: { : Schema: [ ] } }, { ScoreColumnSetId: Key<UInt32, 0-2147483646>: Name: ScoreColumnSetId, Index: 2, IsHidden: False, Type: { Key<UInt32, 0-2147483646>: Count: 2147483647, RawType: System.UInt32 }, Annotations: { : Schema: [ ] } } ]</div></td></tr></tbody></table></td></tr><tr><td>5</td><td>Score</td><td><div class=\"dni-plaintext\">5</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div 
class=\"dni-plaintext\">System.Single</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ { ScoreColumnKind: String: Name: ScoreColumnKind, Index: 0, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory<System.Char> }, Annotations: { : Schema: [ ] } }, { ScoreColumnSetId: Key<UInt32, 0-2147483646>: Name: ScoreColumnSetId, Index: 1, IsHidden: False, Type: { Key<UInt32, 0-2147483646>: Count: 2147483647, RawType: System.UInt32 }, Annotations: { : Schema: [ ] } }, { ScoreValueKind: String: Name: ScoreValueKind, Index: 2, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory<System.Char> }, Annotations: { : Schema: [ ] } } ]</div></td></tr></tbody></table></td></tr><tr><td>6</td><td>Probability</td><td><div class=\"dni-plaintext\">6</div></td><td><div class=\"dni-plaintext\">True</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Single</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ { ScoreColumnKind: String: Name: ScoreColumnKind, Index: 0, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory<System.Char> }, Annotations: { : Schema: [ ] } }, { ScoreColumnSetId: Key<UInt32, 0-2147483646>: Name: ScoreColumnSetId, Index: 1, IsHidden: False, Type: { Key<UInt32, 0-2147483646>: Count: 2147483647, RawType: System.UInt32 }, Annotations: { : Schema: [ ] } }, { IsNormalized: Boolean: Name: IsNormalized, Index: 2, IsHidden: False, Type: { Boolean: RawType: System.Boolean }, Annotations: { : Schema: [ ] } }, { ScoreValueKind: String: Name: ScoreValueKind, Index: 3, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory<System.Char> }, Annotations: { : Schema: [ ] } } ]</div></td></tr></tbody></table></td></tr><tr><td>7</td><td>Probability</td><td><div class=\"dni-plaintext\">7</div></td><td><div 
class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Single</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ { ScoreColumnSetId: Key<UInt32, 0-2147483646>: Name: ScoreColumnSetId, Index: 0, IsHidden: False, Type: { Key<UInt32, 0-2147483646>: Count: 2147483647, RawType: System.UInt32 }, Annotations: { : Schema: [ ] } }, { ScoreColumnKind: String: Name: ScoreColumnKind, Index: 1, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory<System.Char> }, Annotations: { : Schema: [ ] } }, { ScoreValueKind: String: Name: ScoreValueKind, Index: 2, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory<System.Char> }, Annotations: { : Schema: [ ] } }, { IsNormalized: Boolean: Name: IsNormalized, Index: 3, IsHidden: False, Type: { Boolean: RawType: System.Boolean }, Annotations: { : Schema: [ ] } } ]</div></td></tr></tbody></table></td></tr></tbody></table>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"predictions.Schema" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Preview predictions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var columnsToOutput = \n", | |
" predictions.Schema\n", | |
" .Where(x => x.Name != \"Features\")\n", | |
" .Select(x => x.Name)\n", | |
" .ToArray();\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var predictionDf = predictions.ToDataFrame(selectColumns:columnsToOutput);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table id=\"table_638029367486754124\"><thead><tr><th><i>index</i></th><th>Review</th><th>Label</th><th>PredictedLabel</th><th>Score</th><th>Probability</th></tr></thead><tbody><tr><td><i><div class=\"dni-plaintext\">0</div></i></td><td>Wow... Loved this place.</td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">12.9196615</div></td><td><div class=\"dni-plaintext\">0.9976944</div></td></tr><tr><td><i><div class=\"dni-plaintext\">1</div></i></td><td>Crust is not good.</td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">-17.585255</div></td><td><div class=\"dni-plaintext\">0.00026196413</div></td></tr><tr><td><i><div class=\"dni-plaintext\">2</div></i></td><td>Not tasty and the texture was just nasty.</td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">-12.287056</div></td><td><div class=\"dni-plaintext\">0.0031399904</div></td></tr><tr><td><i><div class=\"dni-plaintext\">3</div></i></td><td>Stopped by during the late May bank holiday off Rick Steve recommendation and loved it.</td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">15.230707</div></td><td><div class=\"dni-plaintext\">0.9992195</div></td></tr><tr><td><i><div class=\"dni-plaintext\">4</div></i></td><td>The selection on the menu was great and so were the prices.</td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">24.48196</div></td><td><div class=\"dni-plaintext\">0.9999898</div></td></tr></tbody></table>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"predictionDf.Head(5)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".NET (C#)", | |
"language": "C#", | |
"name": ".net-csharp" | |
}, | |
"language_info": { | |
"name": "C#" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
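The pipeline in the notebook appends `Calibrators.Platt()` after the FastTree trainer, and the prediction preview shows both a raw `Score` and a `Probability` column. As a rough illustration of how a Platt calibrator relates the two, the sketch below assumes the usual sigmoid form `1 / (1 + exp(slope * score + offset))`. The function name `PlattProbability` and the values `assumedSlope` and `assumedOffset` are hypothetical and chosen only for illustration; they are not read from the trained model.

```csharp
using System;

// Platt scaling: squash a raw classifier score into (0, 1) with a fitted sigmoid.
// In practice slope and offset are learned from scored examples; here they are passed in.
double PlattProbability(double score, double slope, double offset)
    => 1.0 / (1.0 + Math.Exp(slope * score + offset));

// Hypothetical parameters (assumptions for this sketch, not taken from the notebook's model).
const double assumedSlope = -0.47;
const double assumedOffset = 0.0;

// A negative slope means larger scores map to probabilities closer to 1.
foreach (var score in new[] { -17.6, 0.0, 12.9 })
{
    var probability = PlattProbability(score, assumedSlope, assumedOffset);
    Console.WriteLine($"score={score,6:F1}  probability={probability:F4}");
}
```

With a negative slope, strongly positive scores land near 1 and strongly negative scores land near 0, which matches the direction of the `Score` and `Probability` values in the preview table. The actual parameters fitted during cross validation will differ from the placeholders used here.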