import MLCompute

struct Model {
    let batchSize: Int
    let imageSize: Int
    let outSize: Int
    let numberOfClasses: Int
    let imagesTensor: MLCTensor
    let labelsTensor: MLCTensor
    let device = MLCDevice(type: .any)!
    let graph = MLCGraph()
    init(batchSize: Int, imageSize: Int, outSize: Int, numberOfClasses: Int, imagesTensor: MLCTensor, labelsTensor: MLCTensor) {
        self.batchSize = batchSize
        self.imageSize = imageSize
        self.outSize = outSize
        self.numberOfClasses = numberOfClasses
        self.imagesTensor = imagesTensor
        self.labelsTensor = labelsTensor

        // Graph input and trainable parameters for a single fully connected layer
        let inputTensor = MLCTensor(descriptor: MLCTensorDescriptor(shape: [batchSize, imageSize, 1, 1], dataType: .float32)!)
        let weightsTensor = MLCTensor(descriptor: MLCTensorDescriptor(shape: [1, imageSize * outSize, 1, 1], dataType: .float32)!, randomInitializerType: .glorotUniform)
        let biasTensor = MLCTensor(descriptor: MLCTensorDescriptor(shape: [1, outSize, 1, 1], dataType: .float32)!, randomInitializerType: .glorotUniform)
        // Note: lossLabelTensor is unused below; the labels are fed through labelsTensor instead.
        let lossLabelTensor = MLCTensor(descriptor: MLCTensorDescriptor(shape: [batchSize, numberOfClasses], dataType: .float32)!)

        // Dense layer, softmax cross-entropy loss, and Adam optimizer
        let outDense = graph.node(with: MLCFullyConnectedLayer(weights: weightsTensor, biases: biasTensor, descriptor: MLCConvolutionDescriptor(kernelSizes: (height: 1, width: 1), inputFeatureChannelCount: imageSize, outputFeatureChannelCount: outSize))!, sources: [inputTensor])
        let lossLayer = MLCLossLayer.softmaxCrossEntropy(reductionType: .mean, labelSmoothing: 0.0, classCount: numberOfClasses, weight: 1.0)
        let optimizer = MLCAdamOptimizer(descriptor: MLCOptimizerDescriptor(learningRate: 0.01, gradientRescale: 1.0, regularizationType: .none, regularizationScale: 1.0), beta1: 0.9, beta2: 0.999, epsilon: 1e-7, timeStep: 1)

        // Build and compile the training graph
        let trainingGraph = MLCTrainingGraph(graphObjects: [graph], lossLayer: lossLayer, optimizer: optimizer)
        trainingGraph.addInputs([imagesTensor.label: imagesTensor], lossLabels: [labelsTensor.label: labelsTensor])
        let lossTensor = trainingGraph.resultTensors(for: lossLayer)[0]
        trainingGraph.addOutputs([lossTensor.label: lossTensor])

        guard trainingGraph.compile(options: [], device: device) else {
            fatalError("Failed to compile the training graph")
        }
        let epochCount = 10

        // Placeholder training data: replace with real loading code. The assumption
        // is that trainingImages holds flattened Float pixel values and trainingLabels
        // holds one-hot Float label vectors, with imageWidth * imageHeight == imageSize.
        let trainingBatchCount: Int = 0
        let imageWidth: Int = 0
        let imageHeight: Int = 0
        let trainingImages = [Float]()
        let trainingLabels = [Float]()
        // Training loop
        for epoch in 0..<epochCount {
            print("Epoch: \(epoch)")
            var totalLoss: Float = 0
            for batchIndex in 0..<trainingBatchCount {
                // Make a batch of images
                let imagesSliceSize = batchSize * imageWidth * imageHeight
                let imagesSliceOffset = batchIndex * imagesSliceSize
                let imagesSlice = trainingImages[imagesSliceOffset ..< imagesSliceOffset + imagesSliceSize]

                // Make a batch of labels
                let labelsSliceSize = batchSize * numberOfClasses
                let labelsSliceOffset = batchIndex * labelsSliceSize
                let labelsSlice = trainingLabels[labelsSliceOffset ..< labelsSliceOffset + labelsSliceSize]
                // Wrap the raw batch bytes for MLCompute without copying
                imagesSlice.withUnsafeBytes { imageBuffer in
                    labelsSlice.withUnsafeBytes { labelsBuffer in
                        let imagesData = MLCTensorData(immutableBytesNoCopy: imageBuffer.baseAddress!, length: imageBuffer.count)
                        let labelsData = MLCTensorData(immutableBytesNoCopy: labelsBuffer.baseAddress!, length: labelsBuffer.count)
                        autoreleasepool {
                            // outputsData is left nil here and the loss is read back from
                            // lossTensor after the synchronous step (see below). An alternative
                            // (untested assumption) is to pass [lossTensor.label: lossData] with
                            // a writable MLCTensorData buffer for the framework to fill.
                            guard trainingGraph.execute(inputsData: [imagesTensor.label: imagesData], lossLabelsData: [labelsTensor.label: labelsData], lossLabelWeightsData: nil, outputsData: nil, batchSize: batchSize, options: [.synchronous]) else {
                                fatalError("Graph execution failed")
                            }
                        }
                    }
                }
                // Read the scalar loss back from the loss tensor after the synchronous
                // step. This is one plausible approach (an assumption); on a GPU device,
                // synchronizeWithDevice may need to be true.
                var loss: Float = 0
                _ = lossTensor.copyDataFromDeviceMemory(toBytes: &loss, length: MemoryLayout<Float>.size, synchronizeWithDevice: false)
                totalLoss += loss
            }
        }
    }
}
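
// A minimal usage sketch (an assumption, not part of the original gist):
// MNIST-style dimensions with 28x28 grayscale images, 10 classes, and one-hot
// Float labels. The two tensors only describe shapes; the per-batch pixel and
// label buffers are bound inside the training loop above.
let mnistBatchSize = 32
let mnistImageSize = 28 * 28
let mnistClassCount = 10
let imagesTensor = MLCTensor(descriptor: MLCTensorDescriptor(shape: [mnistBatchSize, mnistImageSize, 1, 1], dataType: .float32)!)
let labelsTensor = MLCTensor(descriptor: MLCTensorDescriptor(shape: [mnistBatchSize, mnistClassCount], dataType: .float32)!)
_ = Model(batchSize: mnistBatchSize,
          imageSize: mnistImageSize,
          outSize: mnistClassCount,
          numberOfClasses: mnistClassCount,
          imagesTensor: imagesTensor,
          labelsTensor: labelsTensor)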