Created April 21, 2020 19:10
Save JacopoMangiavacchi/fc0ebf6074a0b4ddd2c8ee1993e36245 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Loads the bundled `mnist_train.csv` into a Core ML batch provider.
///
/// Each CSV row is expected to be `label,p0,p1,…,p783`: an integer class label followed by
/// 28 × 28 = 784 pixel values. Each pixel value is divided by 255 and stored in a `[1, 28, 28]`
/// float32 `MLMultiArray` under the `"image"` feature. The label is stored in a `[1]` int32
/// `MLMultiArray` under the `"output_true"` feature.
///
/// Rows that cannot be parsed (wrong field count, non-numeric values, blank lines, a header
/// row) are skipped instead of crashing. If the file cannot be opened, `"error opening file"`
/// is printed and an empty provider is returned.
///
/// - Returns: An `MLArrayBatchProvider` containing one feature provider per valid row.
func prepareBatchProvider() -> MLBatchProvider {
    // A missing bundled resource is a packaging error, so crashing here is intentional.
    let trainFileURL = Bundle.main.url(forResource: "mnist_train", withExtension: "csv")!

    // Open the file directly instead of `freopen`-ing stdin. The old approach redirected the
    // process-wide stdin, never closed it, and on failure fell through to reading real stdin.
    guard let file = fopen(trainFileURL.path, "r") else {
        print("error opening file")
        return MLArrayBatchProvider(array: [])
    }
    defer { fclose(file) }

    var featureProviders = [MLFeatureProvider]()
    var lineBuffer: UnsafeMutablePointer<CChar>? = nil
    var lineCapacity = 0
    // getline allocates and grows this buffer with malloc; release it once when done.
    defer { free(lineBuffer) }

    while getline(&lineBuffer, &lineCapacity, file) > 0, let rawLine = lineBuffer {
        // getline keeps the line terminator; strip it (and any "\r") before splitting.
        let fields = String(cString: rawLine)
            .trimmingCharacters(in: .whitespacesAndNewlines)
            .split(separator: ",")
        if let provider = makeMNISTFeatureProvider(from: fields) {
            featureProviders.append(provider)
        }
    }
    return MLArrayBatchProvider(array: featureProviders)
}

/// Builds one training example from a split CSV row, or returns `nil` if the row is malformed.
private func makeMNISTFeatureProvider(from fields: [Substring]) -> MLFeatureProvider? {
    let pixelCount = 28 * 28
    // One label column followed by one column per pixel; anything else is rejected.
    guard fields.count == pixelCount + 1, let label = Int32(fields[0]) else { return nil }

    guard let imageMultiArr = try? MLMultiArray(shape: [1, 28, 28], dataType: .float32),
          let outputMultiArr = try? MLMultiArray(shape: [1], dataType: .int32) else {
        return nil
    }

    for pixel in 0..<pixelCount {
        guard let intensity = Float(fields[pixel + 1]) else { return nil }
        // Linear subscript walks the [1, 28, 28] array as a flat sequence of 784 elements.
        imageMultiArr[pixel] = NSNumber(value: intensity / Float(255.0))
    }
    outputMultiArr[0] = NSNumber(value: label)

    let dataPointFeatures: [String: MLFeatureValue] = [
        "image": MLFeatureValue(multiArray: imageMultiArr),
        "output_true": MLFeatureValue(multiArray: outputMultiArr),
    ]
    return try? MLDictionaryFeatureProvider(dictionary: dataPointFeatures)
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.