Last active
November 13, 2022 02:11
-
-
Save luisquintanilla/164176ec414e465246d6323aa62b38df to your computer and use it in GitHub Desktop.
Bidirectional Attention Flow (BiDAF) ONNX ML.NET Sample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Bidirectional Attention Flow (BiDAF) ONNX ML.NET\n", | |
"\n", | |
"https://github.com/onnx/models/tree/main/text/machine_comprehension/bidirectional_attention_flow" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Install packages" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div><div><strong>Restore sources</strong><ul><li><span>https://pkgs.dev.azure.com/dnceng/public/_packaging/MachineLearning/nuget/v3/index.json</span></li></ul></div><div></div><div><strong>Installed Packages</strong><ul><li><span>Microsoft.ML, 2.0.0-preview.22514.1</span></li><li><span>Microsoft.ML.OnnxTransformer, 2.0.0-preview.22514.1</span></li></ul></div></div>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"#i \"nuget:https://pkgs.dev.azure.com/dnceng/public/_packaging/MachineLearning/nuget/v3/index.json\"\n", | |
"\n", | |
"#r \"nuget:Microsoft.ML,2.0.0-preview.22514.1\"\n", | |
"#r \"nuget:Microsoft.ML.OnnxTransformer,2.0.0-preview.22514.1\"" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Add using statements" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"using System.Collections.Generic;\n", | |
"using System.Linq;\n", | |
"using Microsoft.ML;\n", | |
"using Microsoft.ML.Data;" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Initialize MLContext" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var ctx = new MLContext();" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Define ONNX model path\n", | |
"\n", | |
"[Download bidaf ONNX model](https://github.com/onnx/models/blob/main/text/machine_comprehension/bidirectional_attention_flow/model/bidaf-9.onnx) and update path." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var onnxModelFilePath = @\"C:\\Dev\\ONNXModels\\bidaf-9.onnx\";" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Define ONNX model settings" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"struct BiDAFSettings\n", | |
"{\n", | |
" public const int SeqLength = 50; // ML.NET expects a vector of known size. Change to a larger length if needed.\n", | |
" public const int ListLength = 16;\n", | |
"} " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var inputColumns = new [] {\"context_word\",\"context_char\",\"query_word\",\"query_char\"};\n", | |
"var outputColumns = new[] {\"start_pos\",\"end_pos\"};" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var shape = new Dictionary<string,int[]>()\n", | |
"{\n", | |
" {inputColumns[0],new int[] {BiDAFSettings.SeqLength,1}},\n", | |
" {inputColumns[1], new int[] {BiDAFSettings.SeqLength,1,1,BiDAFSettings.ListLength}},\n", | |
" {inputColumns[2], new int[] {BiDAFSettings.SeqLength,1}},\n", | |
" {inputColumns[3], new int[] {BiDAFSettings.SeqLength,1,1,BiDAFSettings.ListLength}}\n", | |
"};" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Define ONNX model pipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var onnxPipeline = \n", | |
" ctx.Transforms.ApplyOnnxModel(\n", | |
" modelFile: onnxModelFilePath,\n", | |
" inputColumnNames:inputColumns,\n", | |
" outputColumnNames:outputColumns,\n", | |
" shapeDictionary:shape,gpuDeviceId:null,fallbackToCpu:true);" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Load data into IDataView" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var context = \"A quick brown fox jumps over the lazy dog.\";\n", | |
"var query = \"What color is the fox?\";" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var dv = ctx.Data.LoadFromEnumerable(new [] {new {Context=context,Query=query}}); " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Define tokenization pipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var preprocessingPipeline = \n", | |
" ctx.Transforms.Text.NormalizeText(outputColumnName:\"NormC\",inputColumnName:\"Context\")\n", | |
" .Append(ctx.Transforms.Text.NormalizeText(outputColumnName:\"NormQ\",inputColumnName:\"Query\"))\n", | |
" .Append(ctx.Transforms.Text.TokenizeIntoWords(outputColumnName:\"WordCTokens\",inputColumnName:\"NormC\"))\n", | |
" .Append(ctx.Transforms.Text.TokenizeIntoWords(outputColumnName:\"WordQTokens\",inputColumnName:\"NormQ\"))\n", | |
" .Append(ctx.Transforms.Text.TokenizeIntoCharactersAsKeys(outputColumnName:\"CharCTokens\",inputColumnName:\"NormC\"))\n", | |
" .Append(ctx.Transforms.Text.TokenizeIntoCharactersAsKeys(outputColumnName:\"CharQTokens\",inputColumnName:\"NormQ\"))\n", | |
" .Append(ctx.Transforms.Conversion.MapKeyToValue(\"CharCTokens\"))\n", | |
" .Append(ctx.Transforms.Conversion.MapKeyToValue(\"CharQTokens\"));" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Tokenize data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var tokenizedDv = preprocessingPipeline.Fit(dv).Transform(dv);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table><thead><tr><th><i>index</i></th><th>Name</th><th>Index</th><th>IsHidden</th><th>Type</th><th>Annotations</th></tr></thead><tbody><tr><td>0</td><td>Context</td><td><div class=\"dni-plaintext\">0</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.ReadOnlyMemory<System.Char></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>1</td><td>Query</td><td><div class=\"dni-plaintext\">1</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.ReadOnlyMemory<System.Char></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>2</td><td>NormC</td><td><div class=\"dni-plaintext\">2</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.ReadOnlyMemory<System.Char></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>3</td><td>NormQ</td><td><div class=\"dni-plaintext\">3</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.ReadOnlyMemory<System.Char></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>4</td><td>WordCTokens</td><td><div class=\"dni-plaintext\">4</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 0 ]</div></td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">{ String: RawType: System.ReadOnlyMemory<System.Char> }</div></td><td><div class=\"dni-plaintext\">0</div></td><td><div class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer<System.ReadOnlyMemory<System.Char>></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>5</td><td>WordQTokens</td><td><div class=\"dni-plaintext\">5</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 0 ]</div></td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">{ String: RawType: System.ReadOnlyMemory<System.Char> }</div></td><td><div class=\"dni-plaintext\">0</div></td><td><div class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer<System.ReadOnlyMemory<System.Char>></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>6</td><td>CharCTokens</td><td><div class=\"dni-plaintext\">6</div></td><td><div class=\"dni-plaintext\">True</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 0 ]</div></td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">{ Key<UInt16, 0-65534>: Count: 65535, RawType: System.UInt16 }</div></td><td><div class=\"dni-plaintext\">0</div></td><td><div class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer<System.UInt16></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ { KeyValues: Vector<String, 65535>: Name: KeyValues, Index: 0, IsHidden: False, Type: { Vector<String, 65535>: Dimensions: [ 65535 ], IsKnownSize: True, ItemType: { String: RawType: System.ReadOnlyMemory<System.Char> }, Size: 65535, RawType: Microsoft.ML.Data.VBuffer<System.ReadOnlyMemory<System.Char>> }, Annotations: { : Schema: [ ] } } ]</div></td></tr></tbody></table></td></tr><tr><td>7</td><td>CharCTokens</td><td><div class=\"dni-plaintext\">7</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 0 ]</div></td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">{ String: RawType: System.ReadOnlyMemory<System.Char> }</div></td><td><div class=\"dni-plaintext\">0</div></td><td><div class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer<System.ReadOnlyMemory<System.Char>></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>8</td><td>CharQTokens</td><td><div class=\"dni-plaintext\">8</div></td><td><div class=\"dni-plaintext\">True</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 0 ]</div></td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">{ Key<UInt16, 0-65534>: Count: 65535, RawType: System.UInt16 }</div></td><td><div class=\"dni-plaintext\">0</div></td><td><div class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer<System.UInt16></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ { KeyValues: Vector<String, 65535>: Name: KeyValues, Index: 0, IsHidden: False, Type: { Vector<String, 65535>: Dimensions: [ 65535 ], IsKnownSize: True, ItemType: { String: RawType: System.ReadOnlyMemory<System.Char> }, Size: 65535, RawType: Microsoft.ML.Data.VBuffer<System.ReadOnlyMemory<System.Char>> }, Annotations: { : Schema: [ ] } } ]</div></td></tr></tbody></table></td></tr><tr><td>9</td><td>CharQTokens</td><td><div class=\"dni-plaintext\">9</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 0 ]</div></td><td><div class=\"dni-plaintext\">False</div></td><td><div class=\"dni-plaintext\">{ String: RawType: System.ReadOnlyMemory<System.Char> }</div></td><td><div class=\"dni-plaintext\">0</div></td><td><div class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer<System.ReadOnlyMemory<System.Char>></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr></tbody></table>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"tokenizedDv.Schema" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"public class TokenizedData \n", | |
"{\n", | |
" public string Context {get;set;}\n", | |
" public string Query {get;set;}\n", | |
" public string[] WordCTokens {get;set;}\n", | |
" public string[] WordQTokens {get;set;}\n", | |
" public string[] CharCTokens {get;set;}\n", | |
" public string[] CharQTokens {get;set;}\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Resize input columns to defined sizes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var tokenizedData = ctx.Data.CreateEnumerable<TokenizedData>(tokenizedDv,reuseRowObject:false);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table><thead><tr><th><i>index</i></th><th>Context</th><th>Query</th><th>WordCTokens</th><th>WordQTokens</th><th>CharCTokens</th><th>CharQTokens</th></tr></thead><tbody><tr><td>0</td><td>A quick brown fox jumps over the lazy dog.</td><td>What color is the fox?</td><td><div class=\"dni-plaintext\">[ a, quick, brown, fox, jumps, over, the, lazy, dog. ]</div></td><td><div class=\"dni-plaintext\">[ what, color, is, the, fox? ]</div></td><td><div class=\"dni-plaintext\">[ <␂>, a, <␠>, q, u, i, c, k, <␠>, b, r, o, w, n, <␠>, f, o, x, <␠>, j ... (24 more) ]</div></td><td><div class=\"dni-plaintext\">[ <␂>, w, h, a, t, <␠>, c, o, l, o, r, <␠>, i, s, <␠>, t, h, e, <␠>, f ... (4 more) ]</div></td></tr></tbody></table>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"tokenizedData" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"public class PreprocessedData \n", | |
"{\n", | |
" public string Context {get;set;}\n", | |
" public string Query {get;set;}\n", | |
" \n", | |
" [VectorType(BiDAFSettings.SeqLength,1)]\n", | |
" [ColumnName(\"context_word\")]\n", | |
" public string[] WordCTokens {get;set;}\n", | |
" \n", | |
" [VectorType(BiDAFSettings.SeqLength,1)]\n", | |
" [ColumnName(\"query_word\")]\n", | |
" public string[] WordQTokens {get;set;}\n", | |
" \n", | |
" [VectorType(BiDAFSettings.SeqLength,1,1,BiDAFSettings.ListLength)]\n", | |
" [ColumnName(\"context_char\")]\n", | |
" public string[] CharCTokens {get;set;}\n", | |
" \n", | |
" [VectorType(BiDAFSettings.SeqLength,1,1,BiDAFSettings.ListLength)]\n", | |
" [ColumnName(\"query_char\")]\n", | |
" public string[] CharQTokens {get;set;}\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var preprocessedData = \n", | |
" tokenizedData\n", | |
" .Select(x => {\n", | |
"\n", | |
" var wordCLen = Math.Max(0,BiDAFSettings.SeqLength-x.WordCTokens.Count());\n", | |
" var wordQLen = Math.Max(0,BiDAFSettings.SeqLength-x.WordQTokens.Count());\n", | |
" var charCLen = Math.Max(0,(BiDAFSettings.SeqLength*BiDAFSettings.ListLength)-x.CharCTokens.Count());\n", | |
" var charQLen = Math.Max(0,(BiDAFSettings.SeqLength*BiDAFSettings.ListLength)-x.CharQTokens.Count());\n", | |
"\n", | |
" var data = new PreprocessedData\n", | |
" {\n", | |
" Context = x.Context,\n", | |
" Query = x.Query,\n", | |
" WordCTokens = \n", | |
" x.WordCTokens\n", | |
" .Concat(Enumerable.Repeat(String.Empty,wordCLen)) // Add empty values if needed\n", | |
" .Take(BiDAFSettings.SeqLength) // Clip to max length\n", | |
" .ToArray(),\n", | |
" WordQTokens = \n", | |
" x.WordQTokens\n", | |
" .Concat(Enumerable.Repeat(String.Empty,wordQLen))\n", | |
" .Take(BiDAFSettings.SeqLength)\n", | |
" .ToArray(),\n", | |
" CharCTokens = \n", | |
" x.CharCTokens\n", | |
" .Concat(Enumerable.Repeat(String.Empty,charCLen))\n", | |
" .Take(BiDAFSettings.SeqLength*BiDAFSettings.ListLength)\n", | |
" .ToArray(),\n", | |
" CharQTokens = \n", | |
" x.CharQTokens.Concat(Enumerable.Repeat(String.Empty,charQLen))\n", | |
" .Take(BiDAFSettings.SeqLength*BiDAFSettings.ListLength)\n", | |
" .ToArray()\n", | |
" };\n", | |
"\n", | |
"\n", | |
" return data;\n", | |
" });" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Load preprocessed data to IDataView" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var preprocessedDv = ctx.Data.LoadFromEnumerable(preprocessedData);" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Apply ONNX pipeline transforms to preprocessed data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var predictionsDv = onnxPipeline.Fit(preprocessedDv).Transform(preprocessedDv);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table><thead><tr><th><i>index</i></th><th>Name</th><th>Index</th><th>IsHidden</th><th>Type</th><th>Annotations</th></tr></thead><tbody><tr><td>0</td><td>Context</td><td><div class=\"dni-plaintext\">0</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.ReadOnlyMemory<System.Char></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>1</td><td>Query</td><td><div class=\"dni-plaintext\">1</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">System.ReadOnlyMemory<System.Char></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>2</td><td>context_word</td><td><div class=\"dni-plaintext\">2</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 50, 1 ]</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">{ String: RawType: System.ReadOnlyMemory<System.Char> }</div></td><td><div class=\"dni-plaintext\">50</div></td><td><div class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer<System.ReadOnlyMemory<System.Char>></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>3</td><td>query_word</td><td><div class=\"dni-plaintext\">3</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 50, 1 ]</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">{ String: RawType: System.ReadOnlyMemory<System.Char> }</div></td><td><div class=\"dni-plaintext\">50</div></td><td><div class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer<System.ReadOnlyMemory<System.Char>></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>4</td><td>context_char</td><td><div class=\"dni-plaintext\">4</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 50, 1, 1, 16 ]</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">{ String: RawType: System.ReadOnlyMemory<System.Char> }</div></td><td><div class=\"dni-plaintext\">800</div></td><td><div class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer<System.ReadOnlyMemory<System.Char>></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>5</td><td>query_char</td><td><div class=\"dni-plaintext\">5</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 50, 1, 1, 16 ]</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">{ String: RawType: System.ReadOnlyMemory<System.Char> }</div></td><td><div class=\"dni-plaintext\">800</div></td><td><div class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer<System.ReadOnlyMemory<System.Char>></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>6</td><td>start_pos</td><td><div class=\"dni-plaintext\">6</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 1 ]</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">{ Int32: RawType: System.Int32 }</div></td><td><div class=\"dni-plaintext\">1</div></td><td><div class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer<System.Int32></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr><tr><td>7</td><td>end_pos</td><td><div class=\"dni-plaintext\">7</div></td><td><div class=\"dni-plaintext\">False</div></td><td><table><thead><tr><th>Dimensions</th><th>IsKnownSize</th><th>ItemType</th><th>Size</th><th>RawType</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ 1 ]</div></td><td><div class=\"dni-plaintext\">True</div></td><td><div class=\"dni-plaintext\">{ Int32: RawType: System.Int32 }</div></td><td><div class=\"dni-plaintext\">1</div></td><td><div class=\"dni-plaintext\">Microsoft.ML.Data.VBuffer<System.Int32></div></td></tr></tbody></table></td><td><table><thead><tr><th>Schema</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">[ ]</div></td></tr></tbody></table></td></tr></tbody></table>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"predictionsDv.Schema" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## View model output" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"public class Output \n", | |
"{\n", | |
" public string Context {get;set;}\n", | |
" public string Query {get;set;}\n", | |
" \n", | |
" [ColumnName(\"start_pos\")]\n", | |
" [VectorType(1)]\n", | |
" public int[] Start {get;set;}\n", | |
" \n", | |
" [ColumnName(\"end_pos\")]\n", | |
" [VectorType(1)]\n", | |
" public int[] End {get;set;}\n", | |
"\n", | |
" public string GetAnswer () => this.Context.Substring(this.Start[0],this.End[0]-1);\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"var modelOutput = ctx.Data.CreateEnumerable<Output>(predictionsDv,reuseRowObject:false);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table><thead><tr><th><i>index</i></th><th>Context</th><th>Query</th><th>Start</th><th>End</th></tr></thead><tbody><tr><td>0</td><td>A quick brown fox jumps over the lazy dog.</td><td>What color is the fox?</td><td><div class=\"dni-plaintext\">[ 7 ]</div></td><td><div class=\"dni-plaintext\">[ 8 ]</div></td></tr></tbody></table>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"modelOutput" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": { | |
"dotnet_interactive": { | |
"language": "csharp" | |
}, | |
"vscode": { | |
"languageId": "dotnet-interactive.csharp" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table><thead><tr><th><i>index</i></th><th>value</th></tr></thead><tbody><tr><td>0</td><td> brown </td></tr></tbody></table>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"modelOutput.Select(x => x.GetAnswer())" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".NET (C#)", | |
"language": "C#", | |
"name": ".net-csharp" | |
}, | |
"language_info": { | |
"name": "C#" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment