Skip to content

Instantly share code, notes, and snippets.

@kevmal
Created September 23, 2017 01:08
Show Gist options
  • Save kevmal/9661bd1f32beb0785649cbc5ab619b3b to your computer and use it in GitHub Desktop.
Save kevmal/9661bd1f32beb0785649cbc5ab619b3b to your computer and use it in GitHub Desktop.
open System
open System.Collections.Generic
open System.IO
open CNTK
// ---------------------------------------------------------------------------
// Experiment configuration
// ---------------------------------------------------------------------------

/// Folder that is combined with the training file name when the minibatch
/// source is created.
let dataFolder = @"E:\Temp\CNTK\Tests\EndToEndTests\Text\SequenceClassification\Data"

/// Name shared by the feature input variable and its reader stream.
let featuresName = "features"

/// When true the label input variable is declared as sparse.
let useSparseLabels = true

/// Device handed to the network builders, the reader and the trainer.
let device = DeviceDescriptor.CPUDevice

// Network dimensions.
let inputDim = 2000
let embeddingDim = 50
let hiddenDim = 25
let cellDim = 25
let numOutputClasses = 5
/// Wraps either a CNTK `Variable` or a CNTK `Function` so both can be combined
/// with the same arithmetic operators.
type VarOrFunc =
    | Var of Variable
    | Fun of Function

    /// The wrapped value as a `Variable`; a `Fun` case is wrapped in a new
    /// `Variable` built from the function.
    member this.Variable =
        match this with
        | Fun fn -> new Variable(fn)
        | Var variable -> variable

    /// The wrapped `Function`. Fails with the message "var" when the value is
    /// a `Var` case.
    member this.Function =
        match this with
        | Fun fn -> fn
        | Var _ -> failwith "var"

    /// `a *. b` builds `CNTKLib.ElementTimes` over both operands.
    static member ( *. ) (a : VarOrFunc, b : VarOrFunc) =
        Fun(CNTKLib.ElementTimes(a.Variable, b.Variable))

    /// `a * b` builds `CNTKLib.Times` over both operands.
    static member ( * ) (a : VarOrFunc, b : VarOrFunc) =
        Fun(CNTKLib.Times(a.Variable, b.Variable))

    /// `a + b` uses the `+` operator defined on `Variable`.
    static member ( + ) (a : VarOrFunc, b : VarOrFunc) =
        Fun(a.Variable + b.Variable)
/// Unary CNTK operations lifted to `VarOrFunc`; each result is a `Fun` case.
module C =
    /// Applies `CNTKLib.Log` to the operand.
    let log (x : VarOrFunc) = Fun(CNTKLib.Log(x.Variable))

    /// Applies `CNTKLib.Exp` to the operand.
    let exp (x : VarOrFunc) = Fun(CNTKLib.Exp(x.Variable))

    /// Applies `CNTKLib.Sigmoid` to the operand.
    let sigmoid (x : VarOrFunc) = Fun(CNTKLib.Sigmoid(x.Variable))

    /// Applies `CNTKLib.Tanh` to the operand.
    let tanh (x : VarOrFunc) = Fun(CNTKLib.Tanh(x.Variable))
/// Prints the loss and evaluation averages of the previous minibatch.
///
/// Output is produced only when `minibatchIdx` is a multiple of
/// `outputFrequencyInMinibatches` and the trainer reports a non-zero sample
/// count for the previous minibatch.
let printTrainingProgress (trainer : Trainer) minibatchIdx outputFrequencyInMinibatches =
    let isReportingStep = minibatchIdx % outputFrequencyInMinibatches = 0
    // `&&` short-circuits, so the trainer is only queried on reporting steps.
    if isReportingStep && trainer.PreviousMinibatchSampleCount() <> 0u then
        let loss = float (trainer.PreviousMinibatchLossAverage())
        let evaluation = float (trainer.PreviousMinibatchEvaluationAverage())
        printfn "Minibatch: %d CrossEntropyLoss = %f, EvaluationCriterion = %f" minibatchIdx loss evaluation
/// Feature input variable: one axis of size `inputDim`, element type
/// `DataType.Float`, named `featuresName`. The last two arguments (`null`,
/// `true`) are passed to `Variable.InputVariable` unchanged.
let features =
    let featureShape = NDShape.CreateNDShape([| inputDim |])
    Variable.InputVariable(featureShape, DataType.Float, featuresName, null, true)
/// Builds an embedding layer: a Glorot-uniform initialised parameter of shape
/// `[embeddingDim; inputDim]` multiplied (`CNTKLib.Times`) with `input`.
///
/// `input`        - variable to embed; its first dimension gives `inputDim`.
/// `embeddingDim` - size of the embedding output.
/// `device`       - device on which the parameter is created.
///
/// Returns the product wrapped as a `VarOrFunc`.
let embedding (input : Variable) embeddingDim device =
    let inputDim = input.Shape.[0]
    let embeddingParameters =
        new Parameter(NDShape.CreateNDShape [| embeddingDim; inputDim |], DataType.Float, CNTKLib.GlorotUniformInitializer(), device)
    // Multiply with the `input` argument. The previous code used the
    // module-level `features` variable here, silently ignoring `input`.
    Var(embeddingParameters) * Var(input)
/// Scales `x` element-wise by a learnable factor
/// `beta = fInv * log(1 + exp(f * p))`, where `f` is the constant 4,
/// `fInv` the constant 1/4 and `p` a scalar parameter initialised to
/// 0.99537863.
///
/// The type argument only selects whether the constant `f` is created from a
/// 32-bit (`4.0f`) or a 64-bit (`4.0`) literal; every other value reuses
/// `f.DataType`.
/// NOTE(review): in F# `float` is System.Double, so `isFloatType` is true for
/// the 64-bit type and that case picks the 32-bit literal - confirm this
/// mapping is intended.
let stabilize<'a> (x : Variable) device =
    let isFloatType = typeof<'a> = typeof<float>
    let f =
        if isFloatType then Constant.Scalar(4.0f, device)
        else Constant.Scalar(4.0, device)
    let fInv = Constant.Scalar(f.DataType, 1.0 / 4.0)
    let one = Constant.Scalar(f.DataType, 1.0) :> Variable |> Var
    // Learnable scalar that drives the scale factor.
    let steepness = new Parameter(new NDShape(), f.DataType, 0.99537863, device)
    let scaled = Var(f) *. Var(steepness)
    let softplus = C.log (one + C.exp scaled)
    let beta = Var(fInv) *. softplus
    beta *. Var(x)
/// Builds one LSTM cell with diagonal (peephole-style) cell-state weights and
/// self-stabilised recurrent inputs.
///
/// `input`         - input to the cell for the current step.
/// `prevOutput`    - previous output; its first dimension is the output size.
/// `prevCellState` - previous cell state; its first dimension is the cell size.
/// `device`        - device on which all parameters are created.
///
/// Returns `(h, c)`: the cell output and the new cell state, both `VarOrFunc`.
/// When the output size differs from the cell size, `h` is the stabilised
/// output passed through an extra projection parameter.
///
/// NOTE(review): in F# `float` is System.Double, so `isFloatType` is true for
/// the 64-bit type and that case selects `DataType.Float` - confirm this
/// mapping is intended.
let lstmPCellWithSelfStabilization<'a> (input : Variable) (prevOutput : Variable) (prevCellState : Variable) device =
    let outputDim = prevOutput.Shape.[0]
    let cellDim = prevCellState.Shape.[0]
    let isFloatType = typeof<'a> = typeof<float>
    let dataType = if isFloatType then DataType.Float else DataType.Double
    // new Parameter(new NDShape(new uint[] { 1 }), (ElementType)(object)0.0, device, "");
    // TODO, how to use ElementType?
    // Bias parameters are initialised to 0.01 in the selected precision.
    let createBiasParam =
        if isFloatType then
            fun (dim : int) -> new Parameter(NDShape.CreateNDShape [| dim |], 0.01f, device, "") :> Variable |> Var
        else
            fun (dim : int) -> new Parameter(NDShape.CreateNDShape [| dim |], 0.01, device, "") :> Variable |> Var
    // Every weight parameter gets its own initializer seed; the counter is
    // bumped before each use, so the first seed is 1.
    let mutable seed2 = 0u
    let createProjectionParam oDim =
        seed2 <- seed2 + 1u
        new Parameter(NDShape.CreateNDShape [| oDim; NDShape.InferredDimension |], dataType, CNTKLib.GlorotUniformInitializer(1.0, 1, 0, seed2), device) :> Variable |> Var
    let createDiagWeightParam (dim : int) =
        seed2 <- seed2 + 1u
        new Parameter(NDShape.CreateNDShape [| dim |], dataType, CNTKLib.GlorotUniformInitializer(1.0, 1, 0, seed2), device) :> Variable |> Var
    let stabilizedPrevOutput = stabilize<'a> prevOutput device
    let stabilizedPrevCellState = stabilize<'a> prevCellState device
    // Each call creates a fresh bias and projection, so every gate owns its
    // own input weights.
    let projectInput () = (createBiasParam cellDim) + ((createProjectionParam cellDim) * Var input)
    // Input gate
    let it = C.sigmoid ((projectInput () + (createProjectionParam cellDim * stabilizedPrevOutput)) + (createDiagWeightParam cellDim *. stabilizedPrevCellState))
    let bit = it *. C.tanh (projectInput () + (createProjectionParam cellDim * stabilizedPrevOutput))
    // Forget-me-not gate
    let ft = C.sigmoid (projectInput () + (createProjectionParam cellDim * stabilizedPrevOutput) + (createDiagWeightParam cellDim *. stabilizedPrevCellState))
    let bft = ft *. Var prevCellState
    let ct = bft + bit
    // Output gate
    let ot = C.sigmoid ((projectInput () + (createProjectionParam cellDim * stabilizedPrevOutput)) + (createDiagWeightParam cellDim *. stabilize<'a> ct.Variable device))
    // The output gate scales tanh(ct) element-wise. The previous code used `*`
    // (CNTKLib.Times, a matrix product) here instead of `*.` (ElementTimes).
    let ht = ot *. C.tanh ct
    let c = ct
    let h =
        if outputDim <> cellDim then
            createProjectionParam outputDim * stabilize<'a> ht.Variable device
        else
            ht
    h, c
/// Builds a recurrent LSTM layer by creating a cell over placeholder state and
/// then closing the loop through the supplied recurrence hooks.
///
/// `input`           - input to the layer.
/// `outputShape`     - shape of the placeholder standing in for the previous output.
/// `cellShape`       - shape of the placeholder standing in for the previous cell state.
/// `recurrenceHookH` - maps the cell output to the function that replaces the output placeholder.
/// `recurrenceHookC` - maps the cell state to the function that replaces the cell-state placeholder.
/// `device`          - device on which parameters are created.
///
/// Returns the layer output (after placeholder replacement) and the cell state.
let lstmPComponentWithSelfStabilization<'a> (input : Variable) (outputShape : NDShape) (cellShape : NDShape) recurrenceHookH recurrenceHookC device =
    // The placeholders take their shapes from the arguments and their dynamic
    // axes from `input`. The previous code ignored `outputShape`/`cellShape`
    // and used the module-level `hiddenDim`, `cellDim` and `features` instead.
    let dh = Variable.PlaceholderVariable(outputShape, input.DynamicAxes)
    let dc = Variable.PlaceholderVariable(cellShape, input.DynamicAxes)
    let lstmCell = lstmPCellWithSelfStabilization<'a> input dh dc device
    let actualDh = recurrenceHookH (fst lstmCell)
    let actualDc = recurrenceHookC (snd lstmCell)
    // Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
    Fun((fst lstmCell).Function.ReplacePlaceholders(dict [dh, new Variable(actualDh); dc, new Variable(actualDc)])), snd lstmCell
/// Builds a linear layer `bias + weights * input`.
///
/// `input`      - variable to transform; its first dimension is the input size.
/// `outputDim`  - size of the layer output.
/// `device`     - device on which both parameters are created.
/// `outputName` - name given to the final `Plus` function.
///
/// The weight parameter ("timesParam") has shape `[outputDim; inputDim]` and a
/// Glorot-uniform initializer with seed 1; the bias ("plusParam") has shape
/// `[outputDim]` and starts at 0.
let fullyConnectedLinearLayer (input : Variable) outputDim device outputName =
    let inputDim = input.Shape.[0]
    let weightShape = NDShape.CreateNDShape [| outputDim; inputDim |]
    let weightInit =
        CNTKLib.GlorotUniformInitializer(
            float CNTKLib.DefaultParamInitScale,
            CNTKLib.SentinelValueForInferParamInitRank,
            CNTKLib.SentinelValueForInferParamInitRank,
            1u)
    let weights = new Parameter(weightShape, DataType.Float, weightInit, device, "timesParam")
    let projection = CNTKLib.Times(weights, input, "times")
    let biasShape = NDShape.CreateNDShape [| outputDim |]
    let bias = new Parameter(biasShape, 0.0f, device, "plusParam")
    CNTKLib.Plus(bias, new Variable(projection), outputName)
/// Assembles the sequence classifier: embedding -> recurrent LSTM layer ->
/// `CNTKLib.SequenceLast` -> fully connected linear layer.
///
/// `input`            - feature variable fed to the embedding.
/// `numOutputClasses` - output size of the final linear layer.
/// `embeddingDim`     - output size of the embedding.
/// `lstmDim`          - size used for the LSTM output shape.
/// `cellDim`          - size used for the LSTM cell-state shape.
/// `device`           - device on which parameters are created.
/// `outputName`       - name given to the final function.
let lstmSequenceClassifierNet input numOutputClasses embeddingDim (lstmDim : int) (cellDim : int) device outputName =
    let embedded = embedding input embeddingDim device
    // Both recurrent connections are closed through CNTKLib.PastValue.
    let pastValueHook (x : VarOrFunc) = CNTKLib.PastValue(x.Variable)
    let lstmOutput =
        lstmPComponentWithSelfStabilization<float>
            embedded.Variable
            (NDShape.CreateNDShape [| lstmDim |])
            (NDShape.CreateNDShape [| cellDim |])
            pastValueHook
            pastValueHook
            device
        |> fst
    let thoughtVector = CNTKLib.SequenceLast(lstmOutput.Variable)
    fullyConnectedLinearLayer (new Variable(thoughtVector)) numOutputClasses device outputName
/// The classifier network built over the `features` input.
let classifierOutput =
    lstmSequenceClassifierNet features numOutputClasses embeddingDim hiddenDim cellDim device "classifierOutput"

/// Name shared by the label input variable and its reader stream.
let labelsName = "labels"

/// Label input variable: one axis of size `numOutputClasses`, element type
/// `DataType.Float`, with the default batch axis as its only dynamic axis.
let labels =
    let labelShape = NDShape.CreateNDShape [| numOutputClasses |]
    let dynamicAxes = ResizeArray([ Axis.DefaultBatchAxis() ])
    Variable.InputVariable(labelShape, DataType.Float, labelsName, dynamicAxes, useSparseLabels)

/// Training criterion: `CNTKLib.CrossEntropyWithSoftmax` of the classifier
/// output against the labels.
let trainingLoss =
    CNTKLib.CrossEntropyWithSoftmax(new Variable(classifierOutput), labels, "lossFunction")

/// Evaluation criterion: `CNTKLib.ClassificationError` of the classifier
/// output against the labels.
let prediction =
    CNTKLib.ClassificationError(new Variable(classifierOutput), labels, "classificationError")
/// Reader stream descriptions handed to the text-format minibatch source:
/// the feature stream (alias "x") and the label stream (alias "y").
let streamConfigurations =
    let featureStream = new StreamConfiguration(featuresName, inputDim, true, "x")
    let labelStream = new StreamConfiguration(labelsName, numOutputClasses, false, "y")
    ResizeArray [ featureStream; labelStream ]
// ---------------------------------------------------------------------------
// Learner, trainer and data source
// ---------------------------------------------------------------------------

let learningRatePerSample = new TrainingParameterScheduleDouble(0.0005, 1u)

let momentumTimeConstant = CNTKLib.MomentumAsTimeConstantSchedule(256.0)

/// A single momentum-SGD learner over all parameters of the classifier.
let parameterLearners =
    let momentumSgd =
        Learner.MomentumSGDLearner(classifierOutput.Parameters(), learningRatePerSample, momentumTimeConstant, true)
    ResizeArray [ momentumSgd ]

let trainer =
    Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction, parameterLearners)

/// Text-format reader over "Train.ctf" in `dataFolder`, created with
/// `MinibatchSource.InfinitelyRepeat`.
let minibatchSource =
    let trainingFile = Path.Combine(dataFolder, "Train.ctf")
    MinibatchSource.TextFormatMinibatchSource(trainingFile, streamConfigurations, MinibatchSource.InfinitelyRepeat, true)

// Stream handles used to look up each minibatch's feature and label data.
let featureStreamInfo = minibatchSource.StreamInfo(featuresName)
let labelStreamInfo = minibatchSource.StreamInfo(labelsName)

/// Size requested from the reader for every minibatch.
let minibatchSize = 200u

/// Progress is printed every this many minibatches.
let outputFrequencyInMinibatches = 20
// ---------------------------------------------------------------------------
// Training loop
// ---------------------------------------------------------------------------

/// Number of minibatches trained so far.
let mutable miniBatchCount = 0

/// Remaining epochs; decremented whenever a minibatch reports `sweepEnd`.
let mutable numEpochs = 5

while numEpochs > 0 do
    let minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device)
    // Bind each input variable to its stream's data for this step.
    let featureBatch = minibatchData.[featureStreamInfo]
    let labelBatch = minibatchData.[labelStreamInfo]
    let arguments = dict [ features, featureBatch; labels, labelBatch ]
    trainer.TrainMinibatch(arguments, device) |> ignore
    // Progress is reported with the index of the minibatch just trained.
    printTrainingProgress trainer miniBatchCount outputFrequencyInMinibatches
    miniBatchCount <- miniBatchCount + 1
    let sweepFinished = minibatchData.Values |> Seq.exists (fun batch -> batch.sweepEnd)
    if sweepFinished then
        numEpochs <- numEpochs - 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment