Created
July 19, 2018 17:48
-
-
Save praeclarum/7b5029656962864936d7667ae2f4a624 to your computer and use it in GitHub Desktop.
Predicts the next C# tokens given a history of previous tokens, using Core ML on iOS with F#
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Given previous tokens, predict the next token (and runners up).
///
/// Runs the cached Core ML model once per element of `previousKinds`, in order,
/// feeding the LSTM state (`lstm_1_h_out` / `lstm_1_c_out`) returned by each run
/// into the next run. The `nextTokenProbabilities` of the final run are turned
/// into predictions sorted by descending probability; entries with probability
/// <= 1.0e-4 are dropped.
///
/// Returns an empty array when `ios11` is false or `previousKinds` is empty.
/// Raises (via `failwithf`) if the model input cannot be built or the model
/// reports an error; the message includes the NSError description.
let predictNextToken (previousKinds : SyntaxKind[]) : Prediction[] =
    if ios11 then
        let model : MLModel = model.Value // Load the cached model
        // RNNs require external memory: the LSTM state from one run is fed into the next
        let mutable lstm_1_h : MLMultiArray = null
        let mutable lstm_1_c : MLMultiArray = null
        // Output of the most recent run; only the last one is used for the predictions
        let mutable lastOutput = None
        let inputKeys1 = [| s_prevVectorizedToken |]
        let inputKeys3 = [| s_prevVectorizedToken; s_lstm_1_h_in; s_lstm_1_c_in |]
        // Run the model for each previous token
        for prevKind in previousKinds do
            // Convert the token to a vector for the model
            let vectorizedToken = CSharpPredictor.kindToVector prevKind
            // The first run doesn't include the memory
            let inputKeys, inputValues =
                if lstm_1_h <> null then
                    inputKeys3, [| vectorizedToken :> NSObject; lstm_1_h :> NSObject; lstm_1_c :> NSObject |]
                else
                    inputKeys1, [| vectorizedToken :> NSObject |]
            let inputDict = NSDictionary<NSString, NSObject>.FromObjectsAndKeys (inputValues, inputKeys, System.nint inputKeys.Length)
            let mutable inputError : NSError = null
            // `use` releases the per-step provider deterministically rather than leaving it to the finalizer
            use inputFeatures = new MLDictionaryFeatureProvider (inputDict, &inputError)
            // Don't hand a provider that failed to build to the model
            if inputError <> null then
                Debug.WriteLine (inputError)
                failwithf "Prediction failed: %s" inputError.LocalizedDescription
            // Run the prediction
            match model.GetPrediction (inputFeatures) with
            | _, error when error <> null ->
                Debug.WriteLine (error)
                failwithf "Prediction failed: %s" error.LocalizedDescription
            | output, _ ->
                lstm_1_h <- output.GetFeatureValue("lstm_1_h_out").MultiArrayValue
                lstm_1_c <- output.GetFeatureValue("lstm_1_c_out").MultiArrayValue
                lastOutput <- Some output
        // Only the final run's probabilities describe the next token
        match lastOutput with
        | None -> [| |]
        | Some output ->
            output.GetFeatureValue("nextTokenProbabilities").DictionaryValue
            :> Collections.Generic.IDictionary<NSObject, NSNumber>
            |> Seq.map (fun (KeyValue (key, probability)) -> string key, probability.DoubleValue)
            |> Seq.filter (fun (_, p) -> p > 1.0e-4)
            |> Seq.sortByDescending snd
            |> Seq.map (fun (tokenName, p) ->
                let kind = CSharpPredictor.stringToSyntaxKind tokenName
                let insertText, formatText = CSharpPredictor.kindToCompletion kind
                kind, insertText, formatText, p)
            |> Array.ofSeq
    else
        [| |]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment