Created
September 23, 2017 01:08
-
-
Save kevmal/9661bd1f32beb0785649cbc5ab619b3b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
open System | |
open System.Collections.Generic | |
open System.IO | |
open CNTK | |
// Configuration for the CNTK sequence-classification example.
// Data layout matches CNTK's EndToEndTests SequenceClassification sample.
let dataFolder = @"E:\Temp\CNTK\Tests\EndToEndTests\Text\SequenceClassification\Data"
let inputDim = 2000          // vocabulary size of the one-hot text input
let cellDim = 25             // LSTM cell (memory) dimension
let hiddenDim = 25           // LSTM output dimension
let embeddingDim = 50        // word-embedding dimension
let numOutputClasses = 5     // number of target classes
let device = DeviceDescriptor.CPUDevice
let useSparseLabels = true   // passed to the labels InputVariable below
let featuresName = "features"
/// A value that is either a CNTK Variable or a CNTK Function, so network
/// expressions can be composed uniformly with operators.
type VarOrFunc =
    | Var of Variable
    | Fun of Function

    /// View as a Variable (a Function is wrapped in a Variable).
    member x.Variable =
        match x with
        | Fun f -> new Variable(f)
        | Var v -> v

    /// View as a Function; the Var case has no Function and fails.
    member x.Function =
        match x with
        | Fun f -> f
        | Var _ -> failwith "var"

    /// Element-wise product (CNTK ElementTimes).
    static member ( *. ) (l : VarOrFunc, r : VarOrFunc) =
        Fun(CNTKLib.ElementTimes(l.Variable, r.Variable))

    /// Matrix product (CNTK Times).
    static member ( * ) (l : VarOrFunc, r : VarOrFunc) =
        Fun(CNTKLib.Times(l.Variable, r.Variable))

    /// Element-wise sum.
    static member ( + ) (l : VarOrFunc, r : VarOrFunc) =
        Fun(l.Variable + r.Variable)
/// Thin wrappers over CNTK unary primitives that accept and return VarOrFunc.
module C =
    let log (x : VarOrFunc) = Fun(CNTKLib.Log(x.Variable))
    let exp (x : VarOrFunc) = Fun(CNTKLib.Exp(x.Variable))
    let sigmoid (x : VarOrFunc) = Fun(CNTKLib.Sigmoid(x.Variable))
    let tanh (x : VarOrFunc) = Fun(CNTKLib.Tanh(x.Variable))
/// Prints the previous minibatch's loss and evaluation metric every
/// `outputFrequencyInMinibatches` steps (skipping the very first, empty step).
let printTrainingProgress (trainer : Trainer) minibatchIdx outputFrequencyInMinibatches =
    let isReportingStep = minibatchIdx % outputFrequencyInMinibatches = 0
    if isReportingStep && trainer.PreviousMinibatchSampleCount() <> 0u then
        let loss = float (trainer.PreviousMinibatchLossAverage())
        let metric = float (trainer.PreviousMinibatchEvaluationAverage())
        printfn "Minibatch: %d CrossEntropyLoss = %f, EvaluationCriterion = %f" minibatchIdx loss metric
let features = Variable.InputVariable(NDShape.CreateNDShape([|inputDim|]), DataType.Float, featuresName, null, true) | |
/// Embedding layer: multiplies `input` by a learned
/// (embeddingDim x inputDim) matrix, Glorot-uniform initialized.
/// BUG FIX: the original multiplied by the module-level `features`
/// variable instead of the `input` parameter, silently ignoring its argument.
let embedding (input : Variable) embeddingDim device =
    let inputDim = input.Shape.[0]
    let embeddingParameters =
        new Parameter(NDShape.CreateNDShape [|embeddingDim; inputDim|], DataType.Float,
                      CNTKLib.GlorotUniformInitializer(), device)
    Var(embeddingParameters) * Var(input)
/// Self-stabilizer (Droppo-style): scales x by a learned positive factor
/// beta = (1/f) * log(1 + exp(f * p)) with sharpness f = 4 and a learned
/// scalar parameter p initialized so beta starts near 1.
/// NOTE(review): in F# `typeof<float>` is System.Double, so 'a = float
/// selects the single-precision (4.0f) branch — verify the intended mapping
/// of 'a to the graph's element type.
let stabilize<'a> (x : Variable) device =
    // Sharpness constant and its reciprocal, in the graph's element type.
    let sharpness, sharpnessInv =
        if typeof<'a> = typeof<float> then
            let s = Constant.Scalar(4.0f, device)
            s, Constant.Scalar(s.DataType, 1.0 / 4.0)
        else
            let s = Constant.Scalar(4.0, device)
            s, Constant.Scalar(s.DataType, 1.0 / 4.0)
    let one = Constant.Scalar(sharpness.DataType, 1.0) :> Variable |> Var
    // Learned scalar; softplus of it (scaled) yields the stabilizer weight.
    let p = Var(new Parameter(new NDShape(), sharpness.DataType, 0.99537863, device))
    let beta = Var(sharpnessInv) *. C.log(one + C.exp(Var(sharpness) *. p))
    beta *. Var(x)
/// One LSTM step with self-stabilization, mirroring the CNTK C#
/// LSTMSequenceClassifier example. Returns the pair (h, c) as VarOrFunc.
/// BUG FIX: the output `ht` must be the ELEMENT-WISE product of the output
/// gate and tanh(ct); the original used `*` (matrix Times) instead of `*.`
/// (ElementTimes), which is wrong for the LSTM output equation
/// h_t = o_t o tanh(c_t).
let lstmPCellWithSelfStabilization<'a> (input : Variable) (prevOutput : Variable) (prevCellState : Variable) device =
    let outputDim = prevOutput.Shape.[0]
    let cellDim = prevCellState.Shape.[0]
    let isFloatType = typeof<'a> = typeof<float>
    let dataType = if isFloatType then DataType.Float else DataType.Double
    // Bias parameters initialized to a small constant in the element type.
    let createBiasParam =
        if isFloatType then
            fun (dim : int) -> new Parameter(NDShape.CreateNDShape [|dim|], 0.01f, device, "") :> Variable |> Var
        else
            fun (dim : int) -> new Parameter(NDShape.CreateNDShape [|dim|], 0.01, device, "") :> Variable |> Var
    // Each projection / diagonal weight gets a distinct RNG seed.
    let mutable seed2 = 0u
    // Projection with the input dimension inferred at graph-build time.
    let createProjectionParam oDim =
        seed2 <- seed2 + 1u
        new Parameter(NDShape.CreateNDShape [|oDim; NDShape.InferredDimension|], dataType, CNTKLib.GlorotUniformInitializer(1.0, 1, 0, seed2), device) :> Variable |> Var
    // Diagonal ("peephole") weight vector.
    let createDiagWeightParam (dim : int) =
        seed2 <- seed2 + 1u
        new Parameter(NDShape.CreateNDShape [|dim|], dataType, CNTKLib.GlorotUniformInitializer(1.0, 1, 0, seed2), device) :> Variable |> Var
    let stabilizedPrevOutput = stabilize<'a> prevOutput device
    let stabilizedPrevCellState = stabilize<'a> prevCellState device
    // Fresh bias + input projection per gate (as in the CNTK C# example).
    let projectInput () = createBiasParam cellDim + (createProjectionParam cellDim * Var input)
    // Input gate
    let it = C.sigmoid (projectInput () + (createProjectionParam cellDim * stabilizedPrevOutput) + (createDiagWeightParam cellDim *. stabilizedPrevCellState))
    let bit = it *. C.tanh (projectInput () + (createProjectionParam cellDim * stabilizedPrevOutput))
    // Forget-me-not gate
    let ft = C.sigmoid (projectInput () + (createProjectionParam cellDim * stabilizedPrevOutput) + (createDiagWeightParam cellDim *. stabilizedPrevCellState))
    let bft = ft *. Var prevCellState
    // New cell state: forgotten old state plus gated candidate.
    let ct = bft + bit
    // Output gate peeps at the stabilized new cell state.
    let ot = C.sigmoid (projectInput () + (createProjectionParam cellDim * stabilizedPrevOutput) + (createDiagWeightParam cellDim *. stabilize<'a> ct.Variable device))
    // FIXED: element-wise product (was matrix Times in the original).
    let ht = ot *. C.tanh ct
    // Project h to outputDim when it differs from the cell dimension.
    let h =
        if outputDim <> cellDim then
            createProjectionParam outputDim * stabilize<'a> ht.Variable device
        else ht
    h, ct
/// Builds a recurrent LSTM layer over `input` by creating placeholder
/// variables for the previous output/cell state, running one cell step, and
/// closing the loop through the recurrence hooks (e.g. PastValue).
/// FIXES: the placeholders now use the `outputShape`/`cellShape` parameters —
/// the original ignored both and used the module-level hiddenDim/cellDim —
/// and they take the dynamic axes from `input` rather than from the global
/// `features` variable, so the component works for any sequence input.
let lstmPComponentWithSelfStabilization<'a> (input : Variable) (outputShape : NDShape) (cellShape : NDShape) recurrenceHookH recurrenceHookC device =
    let dh = Variable.PlaceholderVariable(outputShape, input.DynamicAxes)
    let dc = Variable.PlaceholderVariable(cellShape, input.DynamicAxes)
    let lstmCell = lstmPCellWithSelfStabilization<'a> input dh dc device
    let actualDh = recurrenceHookH (fst lstmCell)
    let actualDc = recurrenceHookC (snd lstmCell)
    // Form the recurrence loop by replacing the dh and dc placeholders with
    // the delayed actualDh and actualDc.
    Fun((fst lstmCell).Function.ReplacePlaceholders(dict [dh, new Variable(actualDh); dc, new Variable(actualDc)])), snd lstmCell
/// Dense layer: W * input + b, where W is Glorot-uniform initialized
/// (outputDim x inputDim) and b is a zero bias of length outputDim.
let fullyConnectedLinearLayer (input : Variable) outputDim device outputName =
    let inputDim = input.Shape.[0]
    let weightShape = NDShape.CreateNDShape [|outputDim; inputDim|]
    let timesParam =
        new Parameter(weightShape, DataType.Float,
                      CNTKLib.GlorotUniformInitializer(
                          float CNTKLib.DefaultParamInitScale,
                          CNTKLib.SentinelValueForInferParamInitRank,
                          CNTKLib.SentinelValueForInferParamInitRank, 1u),
                      device, "timesParam")
    let timesFunction = CNTKLib.Times(timesParam, input, "times")
    let biasShape = NDShape.CreateNDShape [|outputDim|]
    let plusParam = new Parameter(biasShape, 0.0f, device, "plusParam")
    CNTKLib.Plus(plusParam, new Variable(timesFunction), outputName)
/// Full classifier: embedding -> recurrent LSTM -> last sequence step ->
/// dense output layer named `outputName`.
let lstmSequenceClassifierNet input numOutputClasses embeddingDim (lstmDim : int) (cellDim : int) device outputName =
    let embedded = embedding input embeddingDim device
    // Recurrence hook: feed back the previous time step's value.
    let pastValueRecurrenceHook (x : VarOrFunc) = CNTKLib.PastValue(x.Variable)
    let lstmFunction, _ =
        lstmPComponentWithSelfStabilization<float>
            embedded.Variable
            (NDShape.CreateNDShape [|lstmDim|])
            (NDShape.CreateNDShape [|cellDim|])
            pastValueRecurrenceHook
            pastValueRecurrenceHook
            device
    // Keep only the final step of the sequence as the "thought vector".
    let thoughtVector = CNTKLib.SequenceLast(lstmFunction.Variable)
    fullyConnectedLinearLayer (new Variable(thoughtVector)) numOutputClasses device outputName
// Build the classifier graph over the sparse `features` input.
let classifierOutput = lstmSequenceClassifierNet features numOutputClasses embeddingDim hiddenDim cellDim device "classifierOutput"
let labelsName = "labels"
// One-hot labels over the output classes; `useSparseLabels` presumably sets
// sparse storage — confirm against the CNTK InputVariable overload.
let labels = Variable.InputVariable(NDShape.CreateNDShape [|numOutputClasses|], DataType.Float, labelsName, ResizeArray([Axis.DefaultBatchAxis()]), useSparseLabels)
// Training criterion and evaluation metric.
let trainingLoss = CNTKLib.CrossEntropyWithSoftmax(new Variable(classifierOutput), labels, "lossFunction")
let prediction = CNTKLib.ClassificationError(new Variable(classifierOutput), labels, "classificationError")
// CTF text-format streams: features (alias "x") and labels (alias "y");
// the boolean presumably marks the stream as sparse — TODO confirm.
let streamConfigurations =
    ResizeArray [
        new StreamConfiguration(featuresName, inputDim, true, "x")
        new StreamConfiguration(labelsName, numOutputClasses, false, "y")
    ]
// Momentum SGD with a fixed per-sample learning rate.
let learningRatePerSample = new TrainingParameterScheduleDouble(0.0005,1u)
let momentumTimeConstant = CNTKLib.MomentumAsTimeConstantSchedule(256.0)
let parameterLearners = ResizeArray [Learner.MomentumSGDLearner(classifierOutput.Parameters(), learningRatePerSample, momentumTimeConstant, true)]
let trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction, parameterLearners)
// Infinitely repeating minibatch source over the training CTF file.
let minibatchSource = MinibatchSource.TextFormatMinibatchSource(Path.Combine(dataFolder, "Train.ctf"), streamConfigurations, MinibatchSource.InfinitelyRepeat, true)
let featureStreamInfo = minibatchSource.StreamInfo(featuresName)
let labelStreamInfo = minibatchSource.StreamInfo(labelsName)
// Training-loop knobs.
let minibatchSize = 200u
let outputFrequencyInMinibatches = 20
let mutable miniBatchCount = 0
let mutable numEpochs = 5
// Train until the minibatch source has completed `numEpochs` sweeps of the data.
while numEpochs > 0 do
    let minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device)
    let arguments =
        dict [
            features, minibatchData.[featureStreamInfo]
            labels, minibatchData.[labelStreamInfo]
        ]
    trainer.TrainMinibatch(arguments, device) |> ignore
    printTrainingProgress trainer miniBatchCount outputFrequencyInMinibatches
    miniBatchCount <- miniBatchCount + 1
    // A sweepEnd flag on any stream's data marks the end of one pass (epoch).
    let sweepEnded = minibatchData.Values |> Seq.exists (fun v -> v.sweepEnd)
    if sweepEnded then numEpochs <- numEpochs - 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment