@luisquintanilla
Created November 2, 2022 01:58
Binary Classification Cross Validation with Platt Calibrator
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Install packages"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [
{
"data": {
"text/html": [
"<div><div></div><div></div><div><strong>Installed Packages</strong><ul><li><span>Microsoft.Data.Analysis, 0.20.0-preview.22551.1</span></li><li><span>Microsoft.ML.AutoML, 0.20.0-preview.22551.1</span></li></ul></div></div>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"Loading extensions from `SkiaSharp.DotNet.Interactive.dll`"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"Loading extensions from `Microsoft.Data.Analysis.Interactive.dll`"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"Loading extensions from `Microsoft.ML.AutoML.Interactive.dll`"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#r \"nuget: Microsoft.ML.AutoML, 0.20.0-preview.22551.1\"\n",
"#r \"nuget: Microsoft.Data.Analysis, 0.20.0-preview.22551.1\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Add using statements"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"using System;\n",
"using System.Linq;\n",
"using Microsoft.ML;\n",
"using Microsoft.Data.Analysis;\n",
"using Microsoft.ML.Data;\n",
"using Microsoft.ML.Transforms;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define data path"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"var dataPath = @\"C:\\Datasets\\yelp_labelled.txt\";"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialize MLContext"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"var ctx = new MLContext();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data into DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"// yelp_labelled.txt is tab-separated with no header row: review text, then a 0/1 label\n",
"var data = DataFrame.LoadCsv(dataPath, separator: '\\t', header: false, columnNames: new[] { \"Review\", \"Label\" });"
]
},
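{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fold models are later ranked by accuracy, which is easier to interpret once you know how balanced the two classes are. A minimal sketch that counts each label value, assuming `DataFrameColumn.ValueCounts()` is available in the installed `Microsoft.Data.Analysis` preview:\n",
"\n",
"```csharp\n",
"using System;\n",
"\n",
"// Count how many reviews carry each label (1 = positive, 0 = negative in this file)\n",
"var labelCounts = data[\"Label\"].ValueCounts();\n",
"Console.WriteLine(labelCounts);\n",
"```"
]
},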
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preview first 5 rows"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [
{
"data": {
"text/html": [
"<table id=\"table_638029365145485375\"><thead><tr><th><i>index</i></th><th>Review</th><th>Label</th></tr></thead><tbody><tr><td><i><div class=\"dni-plaintext\">0</div></i></td><td>Wow... Loved this place.</td><td><div class=\"dni-plaintext\">1</div></td></tr><tr><td><i><div class=\"dni-plaintext\">1</div></i></td><td>Crust is not good.</td><td><div class=\"dni-plaintext\">0</div></td></tr><tr><td><i><div class=\"dni-plaintext\">2</div></i></td><td>Not tasty and the texture was just nasty.</td><td><div class=\"dni-plaintext\">0</div></td></tr><tr><td><i><div class=\"dni-plaintext\">3</div></i></td><td>Stopped by during the late May bank holiday off Rick Steve recommendation and loved it.</td><td><div class=\"dni-plaintext\">1</div></td></tr><tr><td><i><div class=\"dni-plaintext\">4</div></i></td><td>The selection on the menu was great and so were the prices.</td><td><div class=\"dni-plaintext\">1</div></td></tr></tbody></table>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data.Head(5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define training pipeline"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
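"var pipeline = \n",
" // Turn the raw review text into a numeric feature vector\n",
" ctx.Transforms.Text.FeaturizeText(outputColumnName: \"Features\", inputColumnName: \"Review\")\n",
" // Binary classification trainers need a Boolean label; the file stores it as 0/1\n",
" .Append(ctx.Transforms.Conversion.ConvertType(outputColumnName: \"Label\", inputColumnName: \"Label\", outputKind: DataKind.Boolean))\n",
" // Gradient-boosted decision tree trainer\n",
" .Append(ctx.BinaryClassification.Trainers.FastTree())\n",
" // Fit a Platt (sigmoid) calibrator that maps Score to Probability\n",
" .Append(ctx.BinaryClassification.Calibrators.Platt());"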
]
},
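{
"cell_type": "markdown",
"metadata": {},
"source": [
"Platt scaling maps an uncalibrated score `s` to a probability with a two-parameter sigmoid, `p = 1 / (1 + exp(A * s + B))`, where `A` and `B` are fitted on labelled data. A self-contained sketch of that mapping follows; the `PlattProbability` helper and the slope/offset values are hypothetical, chosen for illustration rather than taken from the fitted calibrator:\n",
"\n",
"```csharp\n",
"using System;\n",
"\n",
"// Hypothetical helper: Platt's sigmoid p = 1 / (1 + exp(slope * score + offset)).\n",
"// A real calibrator fits slope and offset from labelled data.\n",
"double PlattProbability(double score, double slope, double offset) =>\n",
"    1.0 / (1.0 + Math.Exp(slope * score + offset));\n",
"\n",
"// Illustrative values only; a negative slope sends large scores toward 1.\n",
"double exampleSlope = -0.5, exampleOffset = 0.0;\n",
"Console.WriteLine(PlattProbability(10.0, exampleSlope, exampleOffset));  // close to 1\n",
"Console.WriteLine(PlattProbability(0.0, exampleSlope, exampleOffset));   // 0.5\n",
"Console.WriteLine(PlattProbability(-10.0, exampleSlope, exampleOffset)); // close to 0\n",
"```\n",
"\n",
"The predictions previewed at the end of this notebook follow the same shape: large positive scores come with probabilities near 1 and large negative scores with probabilities near 0."
]
},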
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train with 10-fold cross-validation"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"var cvResults = ctx.BinaryClassification.CrossValidate(data, pipeline, numberOfFolds: 10);"
]
},
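{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each fold's `Metrics` are computed on the rows held out from that fold, so averaging them across all folds estimates how the whole pipeline generalizes, and the spread shows how much a single fold can vary. A minimal sketch using the `cvResults` variable from the cell above; the `Mean` and `StdDev` helpers are defined here for illustration:\n",
"\n",
"```csharp\n",
"using System;\n",
"using System.Linq;\n",
"\n",
"// Collect the held-out metrics reported for every fold\n",
"var accuracies = cvResults.Select(fold => fold.Metrics.Accuracy).ToArray();\n",
"var aucs = cvResults.Select(fold => fold.Metrics.AreaUnderRocCurve).ToArray();\n",
"var logLosses = cvResults.Select(fold => fold.Metrics.LogLoss).ToArray();\n",
"\n",
"double Mean(double[] values) => values.Average();\n",
"\n",
"// Population standard deviation across folds\n",
"double StdDev(double[] values)\n",
"{\n",
"    var mean = values.Average();\n",
"    return Math.Sqrt(values.Select(v => (v - mean) * (v - mean)).Average());\n",
"}\n",
"\n",
"Console.WriteLine($\"Accuracy: {Mean(accuracies):F4} ± {StdDev(accuracies):F4}\");\n",
"Console.WriteLine($\"AUC:      {Mean(aucs):F4} ± {StdDev(aucs):F4}\");\n",
"Console.WriteLine($\"Log-loss: {Mean(logLosses):F4} ± {StdDev(logLosses):F4}\");\n",
"```\n",
"\n",
"Log-loss is included because it depends on the calibrated probabilities, which accuracy ignores."
]
},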
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get the model from the fold with the highest accuracy"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"// Rank folds by accuracy on their held-out rows and keep the best fold's model\n",
"var topModel = \n",
" cvResults\n",
" .OrderByDescending(fold => fold.Metrics.Accuracy)\n",
" .Select(fold => fold.Model)\n",
" .First();"
]
},
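{
"cell_type": "markdown",
"metadata": {},
"source": [
"Selecting the fold with the highest accuracy keeps a model trained on only 9 of the 10 folds (about 90% of the rows), and its accuracy is the most optimistic of the ten estimates. A common alternative, sketched here and not what this notebook does, is to treat the cross-validation metrics as the performance estimate and then fit the same pipeline on all rows. The file name `model.zip` is a hypothetical example:\n",
"\n",
"```csharp\n",
"using Microsoft.ML;\n",
"\n",
"// Fit the same estimator chain on the full dataset to get the final model\n",
"ITransformer finalModel = pipeline.Fit(data);\n",
"\n",
"// Save the model with the input schema so it can be reloaded later.\n",
"// DataFrame implements IDataView explicitly, hence the cast to reach Schema.\n",
"ctx.Model.Save(finalModel, ((IDataView)data).Schema, \"model.zip\");\n",
"```"
]
},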
{
"cell_type": "markdown",
"metadata": {},
"source": [
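"## Use model to make predictions\n",
"\n",
"`data` is the full dataset, so most of these rows were also used to train the selected fold's model; these predictions are not a held-out evaluation."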
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"var predictions = topModel.Transform(data);"
]
},
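{
"cell_type": "markdown",
"metadata": {},
"source": [
"To score a single new review instead of a whole `IDataView`, a prediction engine can be created from the selected model. A minimal sketch; the `ReviewInput` and `ReviewPrediction` classes and the sample sentence are hypothetical, and their property names must match the pipeline's column names:\n",
"\n",
"```csharp\n",
"using System;\n",
"using Microsoft.ML;\n",
"\n",
"// Hypothetical input type. Review feeds FeaturizeText; Label is included because the\n",
"// pipeline's ConvertType step reads an input column named Label. Its value does not\n",
"// affect the prediction.\n",
"public class ReviewInput\n",
"{\n",
"    public string Review { get; set; }\n",
"    public float Label { get; set; }\n",
"}\n",
"\n",
"// Hypothetical output type mapped onto the visible prediction columns\n",
"public class ReviewPrediction\n",
"{\n",
"    public bool PredictedLabel { get; set; }\n",
"    public float Score { get; set; }\n",
"    public float Probability { get; set; }\n",
"}\n",
"\n",
"var engine = ctx.Model.CreatePredictionEngine<ReviewInput, ReviewPrediction>(topModel);\n",
"var result = engine.Predict(new ReviewInput { Review = \"The food was amazing.\" });\n",
"Console.WriteLine($\"{result.PredictedLabel} (probability {result.Probability:F3})\");\n",
"```"
]
},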
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Display prediction IDataView schema"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [
{
"data": {
"text/html": [
"<table><thead><tr><th><i>index</i></th><th>Name</th><th>Index</th><th>IsHidden</th><th>Type</th><th>Annotations</th></tr></thead><tbody><tr><td>0</td><td>Review</td><td><div class=\"dni-plaintext\">0</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.ReadOnlyMemory&lt;System.Char&gt;</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>1</td><td>Label</td><td><div class=\"dni-plaintext\">1</div></td><td><div class=\"dni-plaintext\">True</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Single</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>2</td><td>Label</td><td><div class=\"dni-plaintext\">2</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Boolean</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>3</td><td>Features</td><td><div class=\"dni-plaintext\">3</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 13307 ]</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">{ Single: RawType: System.Single }</div></td><td><div class=\"dni-plaintext\">13307</div></td><td><div class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer&lt;System.Single&gt;</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ { IsNormalized: Boolean: Name: IsNormalized, Index: 0, IsHidden: False, Type: { Boolean: RawType: System.Boolean }, Annotations: { : Schema: [ ] } }, { SlotNames: Vector&lt;String, 13307&gt;: Name: SlotNames, Index: 1, IsHidden: False, Type: { Vector&lt;String, 13307&gt;: Dimensions: [ 13307 ], IsKnownSize: True, ItemType: { String: RawType: System.ReadOnlyMemory&lt;System.Char&gt; }, Size: 13307, RawType: Microsoft.ML.Data.VBuffer&lt;System.ReadOnlyMemory&lt;System.Char&gt;&gt; }, Annotations: { : Schema: [ ] } } ]</div></td></tr></tbody></table></td></tr><tr><td>4</td><td>PredictedLabel</td><td><div class=\"dni-plaintext\">4</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Boolean</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ { ScoreColumnKind: String: Name: ScoreColumnKind, Index: 0, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory&lt;System.Char&gt; }, Annotations: { : Schema: [ ] } }, { ScoreValueKind: String: Name: ScoreValueKind, Index: 1, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory&lt;System.Char&gt; }, Annotations: { : Schema: [ ] } }, { ScoreColumnSetId: Key&lt;UInt32, 0-2147483646&gt;: Name: ScoreColumnSetId, Index: 2, IsHidden: False, Type: { Key&lt;UInt32, 0-2147483646&gt;: Count: 2147483647, RawType: System.UInt32 }, Annotations: { : Schema: [ ] } } ]</div></td></tr></tbody></table></td></tr><tr><td>5</td><td>Score</td><td><div class=\"dni-plaintext\">5</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Single</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ { ScoreColumnKind: String: Name: ScoreColumnKind, Index: 0, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory&lt;System.Char&gt; }, Annotations: { : Schema: [ ] } }, { ScoreColumnSetId: Key&lt;UInt32, 0-2147483646&gt;: Name: ScoreColumnSetId, Index: 1, IsHidden: False, Type: { Key&lt;UInt32, 0-2147483646&gt;: Count: 2147483647, RawType: System.UInt32 }, Annotations: { : Schema: [ ] } }, { ScoreValueKind: String: Name: ScoreValueKind, Index: 2, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory&lt;System.Char&gt; }, Annotations: { : Schema: [ ] } } ]</div></td></tr></tbody></table></td></tr><tr><td>6</td><td>Probability</td><td><div class=\"dni-plaintext\">6</div></td><td><div class=\"dni-plaintext\">True</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Single</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ { ScoreColumnKind: String: Name: ScoreColumnKind, Index: 0, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory&lt;System.Char&gt; }, Annotations: { : Schema: [ ] } }, { ScoreColumnSetId: Key&lt;UInt32, 0-2147483646&gt;: Name: ScoreColumnSetId, Index: 1, IsHidden: False, Type: { Key&lt;UInt32, 0-2147483646&gt;: Count: 2147483647, RawType: System.UInt32 }, Annotations: { : Schema: [ ] } }, { IsNormalized: Boolean: Name: IsNormalized, Index: 2, IsHidden: False, Type: { Boolean: RawType: System.Boolean }, Annotations: { : Schema: [ ] } }, { ScoreValueKind: String: Name: ScoreValueKind, Index: 3, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory&lt;System.Char&gt; }, Annotations: { : Schema: [ ] } } ]</div></td></tr></tbody></table></td></tr><tr><td>7</td><td>Probability</td><td><div class=\"dni-plaintext\">7</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Single</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ { ScoreColumnSetId: Key&lt;UInt32, 0-2147483646&gt;: Name: ScoreColumnSetId, Index: 0, IsHidden: False, Type: { Key&lt;UInt32, 0-2147483646&gt;: Count: 2147483647, RawType: System.UInt32 }, Annotations: { : Schema: [ ] } }, { ScoreColumnKind: String: Name: ScoreColumnKind, Index: 1, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory&lt;System.Char&gt; }, Annotations: { : Schema: [ ] } }, { ScoreValueKind: String: Name: ScoreValueKind, Index: 2, IsHidden: False, Type: { String: RawType: System.ReadOnlyMemory&lt;System.Char&gt; }, Annotations: { : Schema: [ ] } }, { IsNormalized: Boolean: Name: IsNormalized, Index: 3, IsHidden: False, Type: { Boolean: RawType: System.Boolean }, Annotations: { : Schema: [ ] } } ]</div></td></tr></tbody></table></td></tr></tbody></table>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"predictions.Schema"
]
},
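{
"cell_type": "markdown",
"metadata": {},
"source": [
"The schema above lists `Label` and `Probability` twice. The hidden `Label` (`System.Single`) is the original 0/1 column that `ConvertType` replaced with a `Boolean` one, and the hidden `Probability` is the one the FastTree trainer already produced before the appended Platt calibrator added its own. A column shadowed by a later column of the same name is hidden, so downstream code sees only the last one. A minimal sketch that lists just the visible columns:\n",
"\n",
"```csharp\n",
"using System;\n",
"using System.Linq;\n",
"\n",
"// Hidden columns are shadowed by a later column with the same name; skip them\n",
"foreach (var column in predictions.Schema.Where(c => !c.IsHidden))\n",
"{\n",
"    Console.WriteLine($\"{column.Index}: {column.Name} ({column.Type})\");\n",
"}\n",
"```"
]
},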
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preview predictions"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
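"// Output every visible column except the wide Features vector; hidden duplicates\n",
"// (the original Label and the first Probability) are skipped\n",
"var columnsToOutput = \n",
" predictions.Schema\n",
" .Where(x => !x.IsHidden && x.Name != \"Features\")\n",
" .Select(x => x.Name)\n",
" .ToArray();\n"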
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"var predictionDf = predictions.ToDataFrame(selectColumns:columnsToOutput);"
]
},
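{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at misclassified reviews is a quick way to see where the model struggles. A minimal sketch over `predictionDf`; since most of these rows were seen during training, expect the list to understate errors on new reviews, though the same loop works on any held-out predictions converted to a `DataFrame`:\n",
"\n",
"```csharp\n",
"using System;\n",
"\n",
"// Print the reviews whose predicted label differs from the actual label\n",
"var label = predictionDf[\"Label\"];\n",
"var predicted = predictionDf[\"PredictedLabel\"];\n",
"var review = predictionDf[\"Review\"];\n",
"int mistakeCount = 0;\n",
"for (long i = 0; i < predictionDf.Rows.Count; i++)\n",
"{\n",
"    // Column indexers return boxed values, so compare with object.Equals\n",
"    if (!object.Equals(label[i], predicted[i]))\n",
"    {\n",
"        mistakeCount++;\n",
"        Console.WriteLine(review[i]);\n",
"    }\n",
"}\n",
"Console.WriteLine($\"{mistakeCount} of {predictionDf.Rows.Count} rows misclassified\");\n",
"```"
]
},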
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [
{
"data": {
"text/html": [
"<table id=\"table_638029367486754124\"><thead><tr><th><i>index</i></th><th>Review</th><th>Label</th><th>PredictedLabel</th><th>Score</th><th>Probability</th></tr></thead><tbody><tr><td><i><div class=\"dni-plaintext\">0</div></i></td><td>Wow... Loved this place.</td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">12.9196615</div></td><td><div class=\"dni-plaintext\">0.9976944</div></td></tr><tr><td><i><div class=\"dni-plaintext\">1</div></i></td><td>Crust is not good.</td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">-17.585255</div></td><td><div class=\"dni-plaintext\">0.00026196413</div></td></tr><tr><td><i><div class=\"dni-plaintext\">2</div></i></td><td>Not tasty and the texture was just nasty.</td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">-12.287056</div></td><td><div class=\"dni-plaintext\">0.0031399904</div></td></tr><tr><td><i><div class=\"dni-plaintext\">3</div></i></td><td>Stopped by during the late May bank holiday off Rick Steve recommendation and loved it.</td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">15.230707</div></td><td><div class=\"dni-plaintext\">0.9992195</div></td></tr><tr><td><i><div class=\"dni-plaintext\">4</div></i></td><td>The selection on the menu was great and so were the prices.</td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">24.48196</div></td><td><div class=\"dni-plaintext\">0.9999898</div></td></tr></tbody></table>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"predictionDf.Head(5)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".NET (C#)",
"language": "C#",
"name": ".net-csharp"
},
"language_info": {
"name": "C#"
}
},
"nbformat": 4,
"nbformat_minor": 2
}