@luisquintanilla
Created October 20, 2022 15:25
Throwaway AutoML Regression experiment sample
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Install packages"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [
{
"data": {
"text/html": [
"<div><div><strong>Restore sources</strong><ul><li><span>https://pkgs.dev.azure.com/dnceng/public/_packaging/MachineLearning/nuget/v3/index.json</span></li></ul></div><div></div><div><strong>Installed Packages</strong><ul><li><span>Microsoft.ML.AutoML, 0.20.0-preview.22514.1</span></li></ul></div></div>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"Loading extensions from `Microsoft.ML.AutoML.Interactive.dll`"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#i \"nuget:https://pkgs.dev.azure.com/dnceng/public/_packaging/MachineLearning/nuget/v3/index.json\"\n",
"\n",
"#r \"nuget: Microsoft.ML.AutoML, 0.20.0-preview.22514.1\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Add using statements"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"using System;\n",
"using System.Linq;\n",
"using Microsoft.ML;\n",
"using Microsoft.ML.Data;\n",
"using Microsoft.ML.AutoML;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define schema classes"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"public class DataPoint\n",
"{\n",
"    [ColumnName(\"Label\")]\n",
"    public float y { get; set; }\n",
"\n",
"    [ColumnName(\"catFeature\")]\n",
"    public string str { get; set; }\n",
"\n",
"    [ColumnName(\"smth\")]\n",
"    public float smth { get; set; }\n",
"}\n",
"\n",
"\n",
"public class MLOutput\n",
"{\n",
"    [ColumnName(\"Label\")]\n",
"    public float y { get; set; }\n",
"\n",
"    [ColumnName(\"Score\")]\n",
"    public float score { get; set; }\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Randomly generate data"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"var r = new Random();\n",
"\n",
"var categories = new [] {\"Cat\",\"Dog\",\"Bird\",\"Tree\"};\n",
"\n",
"// Materialize with ToArray() so every enumeration sees the same random rows;\n",
"// a lazy Select would regenerate new values each time it is enumerated.\n",
"var data =\n",
"    Enumerable.Range(0,100)\n",
"        .Select(x => {\n",
"            var categoryIdx = r.Next(categories.Length);\n",
"            var s = r.NextSingle();\n",
"            return new DataPoint\n",
"            {\n",
"                y = s*1.25f,\n",
"                str = categories[categoryIdx],\n",
"                smth = s\n",
"            };\n",
"        })\n",
"        .ToArray();"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [
{
"data": {
"text/html": [
"<table><thead><tr><th><i>index</i></th><th>y</th><th>str</th><th>smth</th></tr></thead><tbody><tr><td>0</td><td><div class=\"dni-plaintext\">1.0415233</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.8332187</div></td></tr><tr><td>1</td><td><div class=\"dni-plaintext\">0.50924116</div></td><td>Bird</td><td><div class=\"dni-plaintext\">0.40739292</div></td></tr><tr><td>2</td><td><div class=\"dni-plaintext\">1.1102912</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.888233</div></td></tr><tr><td>3</td><td><div class=\"dni-plaintext\">0.110690445</div></td><td>Dog</td><td><div class=\"dni-plaintext\">0.088552356</div></td></tr><tr><td>4</td><td><div class=\"dni-plaintext\">0.51553434</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.41242748</div></td></tr><tr><td>5</td><td><div class=\"dni-plaintext\">0.8606141</div></td><td>Tree</td><td><div class=\"dni-plaintext\">0.6884913</div></td></tr><tr><td>6</td><td><div class=\"dni-plaintext\">1.126381</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.90110487</div></td></tr><tr><td>7</td><td><div class=\"dni-plaintext\">1.155774</div></td><td>Bird</td><td><div class=\"dni-plaintext\">0.9246192</div></td></tr><tr><td>8</td><td><div class=\"dni-plaintext\">0.7911687</div></td><td>Tree</td><td><div class=\"dni-plaintext\">0.6329349</div></td></tr><tr><td>9</td><td><div class=\"dni-plaintext\">1.1636137</div></td><td>Dog</td><td><div class=\"dni-plaintext\">0.9308909</div></td></tr><tr><td>10</td><td><div class=\"dni-plaintext\">0.3783915</div></td><td>Tree</td><td><div class=\"dni-plaintext\">0.30271322</div></td></tr><tr><td>11</td><td><div class=\"dni-plaintext\">0.21388233</div></td><td>Bird</td><td><div class=\"dni-plaintext\">0.17110586</div></td></tr><tr><td>12</td><td><div class=\"dni-plaintext\">0.18437214</div></td><td>Dog</td><td><div class=\"dni-plaintext\">0.14749771</div></td></tr><tr><td>13</td><td><div class=\"dni-plaintext\">0.6688373</div></td><td>Bird</td><td><div class=\"dni-plaintext\">0.5350698</div></td></tr><tr><td>14</td><td><div class=\"dni-plaintext\">0.9947021</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.7957617</div></td></tr><tr><td>15</td><td><div class=\"dni-plaintext\">1.2341841</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.9873473</div></td></tr><tr><td>16</td><td><div class=\"dni-plaintext\">0.35535997</div></td><td>Cat</td><td><div class=\"dni-plaintext\">0.284288</div></td></tr><tr><td>17</td><td><div class=\"dni-plaintext\">0.10461688</div></td><td>Bird</td><td><div class=\"dni-plaintext\">0.083693504</div></td></tr><tr><td>18</td><td><div class=\"dni-plaintext\">0.82547814</div></td><td>Tree</td><td><div class=\"dni-plaintext\">0.6603825</div></td></tr><tr><td>19</td><td><div class=\"dni-plaintext\">0.9158531</div></td><td>Bird</td><td><div class=\"dni-plaintext\">0.73268247</div></td></tr><tr><td colspan=\"4\"><i>... (more)</i></td></tr></tbody></table>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialize MLContext"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"var ctx = new MLContext();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data to IDataView"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"var idv = ctx.Data.LoadFromEnumerable(data);"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [
{
"data": {
"text/html": [
"<table><thead><tr><th><i>index</i></th><th>Name</th><th>Index</th><th>IsHidden</th><th>Type</th><th>Annotations</th></tr></thead><tbody><tr><td>0</td><td>Label</td><td><div class=\"dni-plaintext\">0</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Single</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>1</td><td>catFeature</td><td><div class=\"dni-plaintext\">1</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.ReadOnlyMemory&lt;System.Char&gt;</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>2</td><td>smth</td><td><div class=\"dni-plaintext\">2</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.Single</div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr></tbody></table>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"idv.Schema"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Split data (80% train/20% test)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"var dataSplit = ctx.Data.TrainTestSplit(idv,testFraction:0.2);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define training pipeline"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"var pipeline =\n",
"    ctx.Transforms.Categorical.OneHotEncoding(new [] {new InputOutputColumnPair(\"catFeatureEnc\",\"catFeature\")})\n",
"        .Append(ctx.Transforms.Concatenate(\"Features\",\"catFeatureEnc\",\"smth\"))\n",
"        .Append(ctx.Auto().Regression(labelColumnName: \"Label\"));"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define AutoML experiment"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"var experiment =\n",
"    ctx.Auto().CreateExperiment()\n",
"        .SetPipeline(pipeline)\n",
"        .SetTrainingTimeInSeconds(60)\n",
"        .SetDataset(dataSplit)\n",
"        .SetRegressionMetric(RegressionMetric.RSquared, labelColumn: \"Label\", scoreColumn: \"Score\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run experiment"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"var result = await experiment.RunAsync();"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluation metric for best model: 0.9704059480652534\r\n"
]
}
],
"source": [
"Console.WriteLine($\"Evaluation metric for best model: {result.Metric}\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Make predictions"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [],
"source": [
"var bestModel = result.Model;\n",
"var predictions = bestModel.Transform(dataSplit.TestSet);"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"vscode": {
"languageId": "dotnet-interactive.csharp"
}
},
"outputs": [
{
"data": {
"text/html": [
"<table><thead><tr><th><i>index</i></th><th>y</th><th>score</th></tr></thead><tbody><tr><td>0</td><td><div class=\"dni-plaintext\">0.13543986</div></td><td><div class=\"dni-plaintext\">0.1800263</div></td></tr><tr><td>1</td><td><div class=\"dni-plaintext\">0.14295556</div></td><td><div class=\"dni-plaintext\">0.1800263</div></td></tr><tr><td>2</td><td><div class=\"dni-plaintext\">0.25472894</div></td><td><div class=\"dni-plaintext\">0.1800263</div></td></tr><tr><td>3</td><td><div class=\"dni-plaintext\">0.16971633</div></td><td><div class=\"dni-plaintext\">0.30899113</div></td></tr><tr><td>4</td><td><div class=\"dni-plaintext\">0.47058523</div></td><td><div class=\"dni-plaintext\">0.48151553</div></td></tr><tr><td>5</td><td><div class=\"dni-plaintext\">0.17454945</div></td><td><div class=\"dni-plaintext\">0.1800263</div></td></tr><tr><td>6</td><td><div class=\"dni-plaintext\">1.2487589</div></td><td><div class=\"dni-plaintext\">1.0730814</div></td></tr><tr><td>7</td><td><div class=\"dni-plaintext\">0.4468602</div></td><td><div class=\"dni-plaintext\">0.4299633</div></td></tr><tr><td>8</td><td><div class=\"dni-plaintext\">0.24393983</div></td><td><div class=\"dni-plaintext\">0.23157853</div></td></tr><tr><td>9</td><td><div class=\"dni-plaintext\">0.7186835</div></td><td><div class=\"dni-plaintext\">0.67317843</div></td></tr><tr><td>10</td><td><div class=\"dni-plaintext\">1.2266262</div></td><td><div class=\"dni-plaintext\">1.080613</div></td></tr><tr><td>11</td><td><div class=\"dni-plaintext\">0.60866225</div></td><td><div class=\"dni-plaintext\">0.55892813</div></td></tr><tr><td>12</td><td><div class=\"dni-plaintext\">1.0635303</div></td><td><div class=\"dni-plaintext\">1.080613</div></td></tr><tr><td>13</td><td><div class=\"dni-plaintext\">1.0706856</div></td><td><div class=\"dni-plaintext\">1.080613</div></td></tr><tr><td>14</td><td><div class=\"dni-plaintext\">0.366441</div></td><td><div class=\"dni-plaintext\">0.25971332</div></td></tr></tbody></table>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ctx.Data.CreateEnumerable<MLOutput>(predictions,reuseRowObject:false)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".NET (C#)",
"language": "C#",
"name": ".net-csharp"
},
"language_info": {
"name": "C#"
}
},
"nbformat": 4,
"nbformat_minor": 2
}