Skip to content

Instantly share code, notes, and snippets.

@eospi
Created June 7, 2021 20:42
Show Gist options
  • Save eospi/77b597626ad00958234df4e2398c2582 to your computer and use it in GitHub Desktop.
Save eospi/77b597626ad00958234df4e2398c2582 to your computer and use it in GitHub Desktop.
import MLCompute
/// A single fully connected layer trained with softmax cross-entropy and Adam.
///
/// All of the work happens inside `init`. The initializer builds the graph,
/// compiles it for `device`, and runs the training loop before returning.
///
/// NOTE(review): the training data below is still a placeholder (empty arrays),
/// so the batch loop currently executes zero times. Load real data into
/// `trainingImages` / `trainingLabels` before relying on this type.
struct Model {
    /// Number of samples per training step.
    let batchSize: Int
    /// Number of input features per sample (the flattened image length).
    let imageSize: Int
    /// Number of output features produced by the fully connected layer.
    let outSize: Int
    /// Class count passed to the softmax cross-entropy loss.
    let numberOfClasses: Int
    /// Graph input that receives each batch of images.
    /// NOTE(review): assumed to have shape [batchSize, imageSize, 1, 1] — confirm with the caller.
    let imagesTensor: MLCTensor
    /// Loss-label input that receives each batch of labels.
    let labelsTensor: MLCTensor
    let device = MLCDevice(type: .any)!
    let graph = MLCGraph()

    init(batchSize: Int, imageSize: Int, outSize: Int, numberOfClasses: Int, imagesTensor: MLCTensor, labelsTensor: MLCTensor) {
        self.batchSize = batchSize
        self.imageSize = imageSize
        self.outSize = outSize
        self.numberOfClasses = numberOfClasses
        self.imagesTensor = imagesTensor
        self.labelsTensor = labelsTensor

        // Fully connected weights are stored flattened as [1, in * out, 1, 1].
        let weightsTensor = MLCTensor(descriptor: MLCTensorDescriptor(shape: [1, imageSize * outSize, 1, 1], dataType: .float32)!, randomInitializerType: .glorotUniform)
        let biasTensor = MLCTensor(descriptor: MLCTensorDescriptor(shape: [1, outSize, 1, 1], dataType: .float32)!, randomInitializerType: .glorotUniform)
        guard let denseLayer = MLCFullyConnectedLayer(
            weights: weightsTensor,
            biases: biasTensor,
            descriptor: MLCConvolutionDescriptor(kernelSizes: (height: 1, width: 1), inputFeatureChannelCount: imageSize, outputFeatureChannelCount: outSize)
        ) else {
            fatalError("Failed to create the fully connected layer")
        }
        // The layer must read from `imagesTensor`: that is the tensor registered with
        // `addInputs` and fed through `inputsData` below. A separate, unregistered
        // source tensor would never receive the batch data.
        guard graph.node(with: denseLayer, sources: [imagesTensor]) != nil else {
            fatalError("Failed to add the fully connected layer to the graph")
        }

        // NOTE(review): softmax cross-entropy over `numberOfClasses` presumably needs the
        // dense layer to emit that many logits, i.e. outSize == numberOfClasses — confirm.
        let lossLayer = MLCLossLayer.softmaxCrossEntropy(reductionType: .mean, labelSmoothing: 0.0, classCount: numberOfClasses, weight: 1.0)
        let optimizer = MLCAdamOptimizer(descriptor: MLCOptimizerDescriptor(learningRate: 0.01, gradientRescale: 1.0, regularizationType: .none, regularizationScale: 1.0), beta1: 0.9, beta2: 0.999, epsilon: 1e-7, timeStep: 1)

        let trainingGraph = MLCTrainingGraph(graphObjects: [graph], lossLayer: lossLayer, optimizer: optimizer)
        guard trainingGraph.addInputs([imagesTensor.label: imagesTensor], lossLabels: [labelsTensor.label: labelsTensor]) else {
            fatalError("Failed to add inputs to the training graph")
        }
        let lossTensor = trainingGraph.resultTensors(for: lossLayer)[0]
        guard trainingGraph.addOutputs([lossTensor.label: lossTensor]) else {
            fatalError("Failed to add outputs to the training graph")
        }
        guard trainingGraph.compile(options: [], device: device) else {
            fatalError("Failed to compile the training graph")
        }

        let epochCount = 10

        // TODO: load the training set. Both arrays are flat buffers: `imageSize`
        // floats per image and `numberOfClasses` floats per label (presumably
        // one-hot — confirm against the data source).
        let trainingImages = [Float]()
        let trainingLabels = [Float]()

        let imagesSliceSize = batchSize * imageSize
        let labelsSliceSize = batchSize * numberOfClasses
        // Only whole batches are used. A trailing partial batch is dropped, and
        // zero-sized slices yield no batches instead of dividing by zero.
        let trainingBatchCount = (imagesSliceSize > 0 && labelsSliceSize > 0)
            ? min(trainingImages.count / imagesSliceSize, trainingLabels.count / labelsSliceSize)
            : 0

        // The output buffer for the loss is sized from the loss tensor's own
        // descriptor, so it is never smaller than what the graph writes. The
        // batch loss is read from its first element.
        let lossElementCount = max(1, lossTensor.descriptor.tensorAllocationSizeInBytes / MemoryLayout<Float>.stride)

        // Training loop
        for epoch in 0..<epochCount {
            print("Epoch: \(epoch)")
            var totalLoss: Float = 0
            for batchIndex in 0..<trainingBatchCount {
                // Make a batch of images
                let imagesSliceOffset = batchIndex * imagesSliceSize
                let imagesSlice = trainingImages[imagesSliceOffset ..< imagesSliceOffset + imagesSliceSize]
                // Make a batch of labels
                let labelsSliceOffset = batchIndex * labelsSliceSize
                let labelsSlice = trainingLabels[labelsSliceOffset ..< labelsSliceOffset + labelsSliceSize]
                // Receives the loss written by the graph for this batch.
                var lossValues = [Float](repeating: 0, count: lossElementCount)

                // The no-copy MLCTensorData wrappers are only valid inside these
                // closures. That is sufficient because execution is `.synchronous`.
                imagesSlice.withUnsafeBytes { imagesBuffer in
                    labelsSlice.withUnsafeBytes { labelsBuffer in
                        lossValues.withUnsafeMutableBytes { lossBuffer in
                            let imagesData = MLCTensorData(immutableBytesNoCopy: imagesBuffer.baseAddress!, length: imagesBuffer.count)
                            let labelsData = MLCTensorData(immutableBytesNoCopy: labelsBuffer.baseAddress!, length: labelsBuffer.count)
                            // NOTE(review): relies on the training graph writing the
                            // registered loss output into `outputsData` — confirm.
                            let lossData = MLCTensorData(bytesNoCopy: lossBuffer.baseAddress!, length: lossBuffer.count)
                            autoreleasepool {
                                guard trainingGraph.execute(inputsData: [imagesTensor.label: imagesData], lossLabelsData: [labelsTensor.label: labelsData], lossLabelWeightsData: nil, outputsData: [lossTensor.label: lossData], batchSize: batchSize, options: [.synchronous]) else {
                                    fatalError("Graph execution failed")
                                }
                            }
                        }
                    }
                }
                totalLoss += lossValues[0]
            }
            if trainingBatchCount > 0 {
                print("Epoch \(epoch) mean loss: \(totalLoss / Float(trainingBatchCount))")
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment