Created
October 20, 2022 15:25
-
-
Save luisquintanilla/1a48ee82b9936995bee5a28d4e69d0b6 to your computer and use it in GitHub Desktop.
Throwaway AutoML Regression experiment sample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Install packages" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div><div><strong>Restore sources</strong><ul><li><span>https://pkgs.dev.azure.com/dnceng/public/_packaging/MachineLearning/nuget/v3/index.json</span></li></ul></div><div></div><div><strong>Installed Packages</strong><ul><li><span>Microsoft.ML.AutoML, 0.20.0-preview.22514.1</span></li></ul></div></div>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/markdown": [ | |
"Loading extensions from `Microsoft.ML.AutoML.Interactive.dll`" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"// Register the dnceng MachineLearning feed as an extra restore source, then install\n", | |
"// a pinned preview build of Microsoft.ML.AutoML.\n", | |
"#i \"nuget:https://pkgs.dev.azure.com/dnceng/public/_packaging/MachineLearning/nuget/v3/index.json\"\n", | |
"\n", | |
"#r \"nuget: Microsoft.ML.AutoML, 0.20.0-preview.22514.1\"" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Add using statements" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"using System;\n", | |
"using System.Linq;\n", | |
"using Microsoft.ML;\n", | |
"using Microsoft.ML.AutoML;\n", | |
"using Microsoft.ML.Data;" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Define schema classes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"// Input row. Each ColumnName attribute sets the IDataView column the property maps to.\n", | |
"public class DataPoint\n", | |
"{\n", | |
"    // Regression target.\n", | |
"    [ColumnName(\"Label\")] public float y { get; set; }\n", | |
"\n", | |
"    // Categorical text feature.\n", | |
"    [ColumnName(\"catFeature\")] public string str { get; set; }\n", | |
"\n", | |
"    // Numeric feature.\n", | |
"    [ColumnName(\"smth\")] public float smth { get; set; }\n", | |
"}\n", | |
"\n", | |
"// Scored row read back from the model output: the original label next to the prediction.\n", | |
"public class MLOutput\n", | |
"{\n", | |
"    [ColumnName(\"Label\")] public float y { get; set; }\n", | |
"\n", | |
"    [ColumnName(\"Score\")] public float score { get; set; }\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Randomly generate data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"// Fixed seed so every Restart & Run All produces the same dataset.\n", | |
"var r = new Random(0);\n", | |
"\n", | |
"var categories = new [] {\"Cat\",\"Dog\",\"Bird\",\"Tree\"};\n", | |
"\n", | |
"// Label is a linear function of smth (y = 1.25 * smth); the category is drawn independently.\n", | |
"// ToList() materializes the rows once: a lazy Select over Random would yield new values on\n", | |
"// every enumeration, and the IDataView built from `data` may be enumerated several times\n", | |
"// (preview, split, AutoML trials, prediction).\n", | |
"var data =\n", | |
"    Enumerable.Range(0,100)\n", | |
"        .Select(_ => {\n", | |
"            var categoryIdx = r.Next(categories.Length);\n", | |
"            var s = r.NextSingle();\n", | |
"            return new DataPoint\n", | |
"            {\n", | |
"                y = s*1.25f,\n", | |
"                str = categories[categoryIdx],\n", | |
"                smth = s\n", | |
"            };\n", | |
"        })\n", | |
"        .ToList();" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table><thead><tr><th><i>index</i></th><th>y</th><th>str</th><th>smth</th></tr></thead><tbody><tr><td>0</td><td><div class=\"dni-plaintext\">1.0415233</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.8332187</div></td></tr><tr><td>1</td><td><div class=\"dni-plaintext\">0.50924116</div></td><td>Bird</td><td><div class=\"dni-plaintext\">0.40739292</div></td></tr><tr><td>2</td><td><div class=\"dni-plaintext\">1.1102912</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.888233</div></td></tr><tr><td>3</td><td><div class=\"dni-plaintext\">0.110690445</div></td><td>Dog</td><td><div class=\"dni-plaintext\">0.088552356</div></td></tr><tr><td>4</td><td><div class=\"dni-plaintext\">0.51553434</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.41242748</div></td></tr><tr><td>5</td><td><div class=\"dni-plaintext\">0.8606141</div></td><td>Tree</td><td><div class=\"dni-plaintext\">0.6884913</div></td></tr><tr><td>6</td><td><div class=\"dni-plaintext\">1.126381</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.90110487</div></td></tr><tr><td>7</td><td><div class=\"dni-plaintext\">1.155774</div></td><td>Bird</td><td><div class=\"dni-plaintext\">0.9246192</div></td></tr><tr><td>8</td><td><div class=\"dni-plaintext\">0.7911687</div></td><td>Tree</td><td><div class=\"dni-plaintext\">0.6329349</div></td></tr><tr><td>9</td><td><div class=\"dni-plaintext\">1.1636137</div></td><td>Dog</td><td><div class=\"dni-plaintext\">0.9308909</div></td></tr><tr><td>10</td><td><div class=\"dni-plaintext\">0.3783915</div></td><td>Tree</td><td><div class=\"dni-plaintext\">0.30271322</div></td></tr><tr><td>11</td><td><div class=\"dni-plaintext\">0.21388233</div></td><td>Bird</td><td><div class=\"dni-plaintext\">0.17110586</div></td></tr><tr><td>12</td><td><div class=\"dni-plaintext\">0.18437214</div></td><td>Dog</td><td><div class=\"dni-plaintext\">0.14749771</div></td></tr><tr><td>13</td><td><div class=\"dni-plaintext\">0.6688373</div></td><td>Bird</td><td><div 
class=\"dni-plaintext\">0.5350698</div></td></tr><tr><td>14</td><td><div class=\"dni-plaintext\">0.9947021</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.7957617</div></td></tr><tr><td>15</td><td><div class=\"dni-plaintext\">1.2341841</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.9873473</div></td></tr><tr><td>16</td><td><div class=\"dni-plaintext\">0.35535997</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.284288</div></td></tr><tr><td>17</td><td><div class=\"dni-plaintext\">0.10461688</div></td><td>Bird</td><td><div class=\"dni-plaintext\">0.083693504</div></td></tr><tr><td>18</td><td><div class=\"dni-plaintext\">0.82547814</div></td><td>Tree</td><td><div class=\"dni-plaintext\">0.6603825</div></td></tr><tr><td>19</td><td><div class=\"dni-plaintext\">0.9158531</div></td><td>Bird</td><td><div class=\"dni-plaintext\">0.73268247</div></td></tr><tr><td colspan=\"4\"><i>... (more)</i></td></tr></tbody></table>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"// Preview the generated rows.\n", | |
"data" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Initialize MLContext" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"// Seeded so the train/test split and other ML.NET randomness repeat across runs.\n", | |
"// NOTE: trials bounded by a wall-clock training budget may still differ between runs.\n", | |
"var ctx = new MLContext(seed: 0);" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Load data to IDataView" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"// Wrap the in-memory rows as an IDataView; columns come from DataPoint's ColumnName attributes.\n", | |
"// NOTE(review): the view reads `data` lazily, so `data` must yield identical rows on every enumeration.\n", | |
"var idv = ctx.Data.LoadFromEnumerable(data);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table><thead><tr><th><i>index</i></th><th>Name</th><th>Index</th><th>IsHidden</th><th>Type</th><th>Annotations</th></tr></thead><tbody><tr><td>0</td><td>Label</td><td><div class=\"dni-plaintext\">0</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Single</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>1</td><td>catFeature</td><td><div class=\"dni-plaintext\">1</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.ReadOnlyMemory<System.Char></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>2</td><td>smth</td><td><div class=\"dni-plaintext\">2</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Single</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr></tbody></table>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"// Inspect the column names and types ML.NET derived from DataPoint.\n", | |
"idv.Schema" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Split data (80% train/20% test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"// Hold out 20% of the rows as the test set; the rest is used for training.\n", | |
"var dataSplit = ctx.Data.TrainTestSplit(idv, testFraction: 0.2);" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Define training pipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"// 1. One-hot encode catFeature into catFeatureEnc.\n", | |
"// 2. Concatenate catFeatureEnc and smth into the Features vector.\n", | |
"// 3. Hand off to AutoML's regression stage, predicting the Label column.\n", | |
"var pipeline =\n", | |
"    ctx.Transforms.Categorical\n", | |
"        .OneHotEncoding(new[] { new InputOutputColumnPair(outputColumnName: \"catFeatureEnc\", inputColumnName: \"catFeature\") })\n", | |
"        .Append(ctx.Transforms.Concatenate(\"Features\", \"catFeatureEnc\", \"smth\"))\n", | |
"        .Append(ctx.Auto().Regression(labelColumnName: \"Label\"));" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Define AutoML experiment" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var experiment = ctx.Auto()\n", | |
"    .CreateExperiment()\n", | |
"    .SetPipeline(pipeline)\n", | |
"    // Training budget, in seconds.\n", | |
"    .SetTrainingTimeInSeconds(60)\n", | |
"    .SetDataset(dataSplit)\n", | |
"    // Rank trials by R-squared computed from the Label and Score columns.\n", | |
"    .SetRegressionMetric(RegressionMetric.RSquared, labelColumn: \"Label\", scoreColumn: \"Score\");" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Run experiment" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"// Run the configured AutoML experiment; `result` holds the best model and its metric.\n", | |
"var result = await experiment.RunAsync();" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Evaluation metric for best model: 0.9704059480652534\r\n" | |
] | |
} | |
], | |
"source": [ | |
"// result.Metric is the score of the metric chosen via SetRegressionMetric.\n", | |
"Console.WriteLine(\"Evaluation metric for best model: \" + result.Metric);" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Make predictions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"// Score the held-out test split with the best model the experiment found.\n", | |
"var bestModel = result.Model;\n", | |
"var predictions = bestModel.Transform(dataSplit.TestSet);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table><thead><tr><th><i>index</i></th><th>y</th><th>score</th></tr></thead><tbody><tr><td>0</td><td><div class=\"dni-plaintext\">0.13543986</div></td><td><div class=\"dni-plaintext\">0.1800263</div></td></tr><tr><td>1</td><td><div class=\"dni-plaintext\">0.14295556</div></td><td><div class=\"dni-plaintext\">0.1800263</div></td></tr><tr><td>2</td><td><div class=\"dni-plaintext\">0.25472894</div></td><td><div class=\"dni-plaintext\">0.1800263</div></td></tr><tr><td>3</td><td><div class=\"dni-plaintext\">0.16971633</div></td><td><div class=\"dni-plaintext\">0.30899113</div></td></tr><tr><td>4</td><td><div class=\"dni-plaintext\">0.47058523</div></td><td><div class=\"dni-plaintext\">0.48151553</div></td></tr><tr><td>5</td><td><div class=\"dni-plaintext\">0.17454945</div></td><td><div class=\"dni-plaintext\">0.1800263</div></td></tr><tr><td>6</td><td><div class=\"dni-plaintext\">1.2487589</div></td><td><div class=\"dni-plaintext\">1.0730814</div></td></tr><tr><td>7</td><td><div class=\"dni-plaintext\">0.4468602</div></td><td><div class=\"dni-plaintext\">0.4299633</div></td></tr><tr><td>8</td><td><div class=\"dni-plaintext\">0.24393983</div></td><td><div class=\"dni-plaintext\">0.23157853</div></td></tr><tr><td>9</td><td><div class=\"dni-plaintext\">0.7186835</div></td><td><div class=\"dni-plaintext\">0.67317843</div></td></tr><tr><td>10</td><td><div class=\"dni-plaintext\">1.2266262</div></td><td><div class=\"dni-plaintext\">1.080613</div></td></tr><tr><td>11</td><td><div class=\"dni-plaintext\">0.60866225</div></td><td><div class=\"dni-plaintext\">0.55892813</div></td></tr><tr><td>12</td><td><div class=\"dni-plaintext\">1.0635303</div></td><td><div class=\"dni-plaintext\">1.080613</div></td></tr><tr><td>13</td><td><div class=\"dni-plaintext\">1.0706856</div></td><td><div class=\"dni-plaintext\">1.080613</div></td></tr><tr><td>14</td><td><div class=\"dni-plaintext\">0.366441</div></td><td><div class=\"dni-plaintext\">0.25971332</div></td></tr></tbody></table>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"// Side-by-side view of the actual Label (y) and the predicted Score for each test row.\n", | |
"ctx.Data.CreateEnumerable<MLOutput>(predictions, reuseRowObject: false)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".NET (C#)", | |
"language": "C#", | |
"name": ".net-csharp" | |
}, | |
"language_info": { | |
"name": "C#" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment