Created
January 10, 2020 02:29
-
-
Save JacopoMangiavacchi/13c3f9d6354cb121dbfd2e7a75767c88 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation | |
import CoreML | |
/// Generates a synthetic linear-regression dataset `y = slope * x + intercept + noise`.
///
/// - Parameters:
///   - sampleSize: Number of points to generate. The `x` values are evenly spaced as
///     `i / sampleSize` for `i` in `0..<sampleSize`, i.e. in `[0, 1)`. Values `<= 0` yield
///     empty arrays.
///   - slope: Slope of the underlying line (default `2.0`).
///   - intercept: Intercept of the underlying line (default `1.5`).
///   - noiseAmplitude: Width of the uniform noise band added to each `y`. For a non-negative
///     amplitude, noise is drawn from `[-noiseAmplitude / 2, noiseAmplitude / 2)`. Pass `0`
///     for a deterministic, noise-free dataset (default `0.1`).
/// - Returns: A tuple `(xs, ys)` of equal length `max(sampleSize, 0)`.
func generateData(sampleSize: Int = 100,
                  slope: Float = 2.0,
                  intercept: Float = 1.5,
                  noiseAmplitude: Float = 0.1) -> ([Float], [Float]) {
    // `0..<sampleSize` traps at runtime when sampleSize is negative, so bail out early.
    guard sampleSize > 0 else { return ([], []) }

    var xs: [Float] = []
    var ys: [Float] = []
    xs.reserveCapacity(sampleSize)
    ys.reserveCapacity(sampleSize)

    for i in 0..<sampleSize {
        let x = Float(i) / Float(sampleSize)
        // Uniform sample in [0, 1), shifted to be centered on zero, then scaled.
        let noise = (Float.random(in: 0..<1) - 0.5) * noiseAmplitude
        xs.append(x)
        ys.append((slope * x + intercept) + noise)
    }
    return (xs, ys)
}
/// Builds the Core ML training batch from synthetic linear data produced by `generateData`.
///
/// Each sample becomes an `MLDictionaryFeatureProvider` with two 1-element `.double`
/// multi-arrays: the input `x` under `"dense_input"` and the target `y` under `"output_true"`.
/// NOTE(review): these names presumably must match the updatable model's input and
/// training-target feature names, which are defined outside this file — confirm.
///
/// - Parameter sampleSize: Number of synthetic samples to generate. Defaults to 100, the same
///   count `generateData()` uses by default.
/// - Returns: A batch provider with one feature provider per sample that could be built.
func prepareTrainingBatch(sampleSize: Int = 100) -> MLBatchProvider {
    let inputName = "dense_input"
    let outputName = "output_true"

    /// Wraps one scalar in its own freshly allocated 1-element multi-array.
    func makeScalarArray(_ value: Float) -> MLMultiArray {
        // A [1]-shaped `.double` array should always be constructible; treat failure as fatal,
        // but with a descriptive message rather than a bare `try!`.
        guard let array = try? MLMultiArray(shape: [1], dataType: .double) else {
            fatalError("Unable to allocate a 1-element MLMultiArray")
        }
        array[0] = NSNumber(value: value)
        return array
    }

    let (xs, ys) = generateData(sampleSize: sampleSize)
    var featureProviders: [MLFeatureProvider] = []
    featureProviders.reserveCapacity(xs.count)

    for (x, y) in zip(xs, ys) {
        // Input and target need separate arrays. MLMultiArray is a reference type, and
        // `MLFeatureValue(multiArray:)` is not documented to copy it. Reusing one array and
        // overwriting it with `y` can therefore leave the input feature holding the target.
        let features: [String: MLFeatureValue] = [
            inputName: MLFeatureValue(multiArray: makeScalarArray(x)),
            outputName: MLFeatureValue(multiArray: makeScalarArray(y)),
        ]
        do {
            let provider = try MLDictionaryFeatureProvider(dictionary: features)
            featureProviders.append(provider)
        } catch {
            // Still skip the bad sample, as before, but make the drop visible.
            print("Skipping training sample (x: \(x), y: \(y)): \(error)")
        }
    }
    return MLArrayBatchProvider(array: featureProviders)
}
/// Starts on-device training of the updatable Core ML model at `url`. On success, it writes the
/// updated model to `retrainedCoreMLFilePath`.
///
/// Training is asynchronous: this function only creates and resumes an `MLUpdateTask`, then
/// returns. Progress and the final outcome are printed from the task's handlers.
///
/// - Parameters:
///   - url: Location of the compiled, updatable model.
///   - epochs: Number of training epochs (default `100`).
func train(url: URL, epochs: Int = 100) {
    let configuration = MLModelConfiguration()
    configuration.computeUnits = .all
    configuration.parameters = [.epochs: epochs]

    /// Formats a metric value without force casts. An absent or unexpectedly typed metric
    /// prints "n/a" instead of crashing inside a Core ML callback.
    func metric<T>(_ key: MLMetricKey, in context: MLUpdateContext, as type: T.Type) -> String {
        guard let value = context.metrics[key] as? T else { return "n/a" }
        return "\(value)"
    }

    let progressHandler = { (context: MLUpdateContext) -> Void in
        switch context.event {
        case .trainingBegin:
            print("Training begin")
        case .miniBatchEnd:
            let batchIndex = metric(.miniBatchIndex, in: context, as: Int.self)
            let batchLoss = metric(.lossValue, in: context, as: Double.self)
            print("Mini batch \(batchIndex), loss: \(batchLoss)")
        case .epochEnd:
            let epochIndex = metric(.epochIndex, in: context, as: Int.self)
            let trainLoss = metric(.lossValue, in: context, as: Double.self)
            print("Epoch \(epochIndex) end with loss \(trainLoss)")
        default:
            // MLUpdateProgressEvent is an option set, so a default branch is required.
            print("Unknown event")
        }
    }

    let completionHandler = { (context: MLUpdateContext) -> Void in
        // The top-level script blocks on readLine(), so prompt on every exit path,
        // including failures.
        defer { print("Press return to continue..") }

        print("Training completed with state \(context.task.state.rawValue)")
        if let error = context.task.error {
            print("CoreML Error: \(error)")
        }
        guard context.task.state == .completed else {
            print("Failed")
            return
        }
        print("Final loss: \(metric(.lossValue, in: context, as: Double.self))")

        let updatedModelURL = URL(fileURLWithPath: retrainedCoreMLFilePath)
        do {
            try context.model.write(to: updatedModelURL)
            print("Model Trained!")
        } catch {
            print("Failed to save the trained model to \(updatedModelURL.path): \(error)")
        }
    }

    let handlers = MLUpdateProgressHandlers(
        forEvents: [.trainingBegin, .miniBatchEnd, .epochEnd],
        progressHandler: progressHandler,
        completionHandler: completionHandler)

    do {
        let updateTask = try MLUpdateTask(forModelAt: url,
                                          trainingData: prepareTrainingBatch(),
                                          configuration: configuration,
                                          progressHandlers: handlers)
        updateTask.resume()
    } catch {
        // Report instead of crashing. The caller still waits on readLine(), so prompt here too.
        print("Unable to start training: \(error)")
        print("Press return to continue..")
    }
}
// Kick off asynchronous training of the compiled model.
train(url: compiledModelUrl)

// Simple way to wait for the asynchronous training task to complete: block until the user
// presses return, which they should do once training output has finished.
_ = readLine()

// Load the model written by the training completion handler. If training or saving failed,
// the file may be missing, so exit with a message instead of crashing via `try!`.
let retrainedModelURL = URL(fileURLWithPath: retrainedCoreMLFilePath)
guard let retrainedModel = try? MLModel(contentsOf: retrainedModelURL) else {
    print("Unable to load the retrained model at \(retrainedModelURL.path)")
    exit(1)
}

let prediction = inferenceCoreML(model: retrainedModel, x: 1.0)
print(prediction)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment