Last active
July 23, 2019 07:39
-
-
Save t-ae/3cd33e4f0535b98c2df9bfeef49645e5 to your computer and use it in GitHub Desktop.
VAE on Swift for TensorFlow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// VAE by modifying official autoencoder code
// https://github.com/tensorflow/swift-models/blob/2fa11ba1d28ef09454af9da77e22b585cf3e5b7b/Autoencoder/main.swift
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
import TensorFlow
import Python

// Import Python modules.
let matplotlib = Python.import("matplotlib")
// Select the non-interactive Agg backend *before* importing pyplot:
// `matplotlib.use` has no reliable effect once pyplot has been imported,
// and the default GUI backend can fail to load on a headless server/Linux box.
matplotlib.use("Agg")
let np = Python.import("numpy")
let plt = Python.import("matplotlib.pyplot")
// Hyperparameters and output configuration.
let epochCount = 50             // number of passes over the training set
let batchSize = 128             // examples per optimizer step
let outputFolder = "./output/"  // where sample plots are written
let imageHeight = 28            // MNIST image height in pixels
let imageWidth = 28             // MNIST image width in pixels
/// Renders a flattened grayscale image and saves it as a PNG in `outputFolder`.
///
/// - Parameters:
///   - image: Pixel intensities in row-major order; assumed to contain
///     `imageHeight * imageWidth` values — the reshape below fails otherwise.
///   - name: Base file name (without extension) for the saved figure.
func plot(image: [Float], name: String) {
    // Create figure and reshape the flat buffer into a 2-D pixel grid.
    let ax = plt.gca()
    let array = np.array([image])
    let pixels = array.reshape([imageHeight, imageWidth])
    // Create the output directory on first use. `withIntermediateDirectories: true`
    // creates missing parent directories and does not throw when the directory
    // already exists, so this cannot crash if another process creates it between
    // the existence check and the call (the original `false` could).
    if !FileManager.default.fileExists(atPath: outputFolder) {
        try! FileManager.default.createDirectory(atPath: outputFolder,
                                                 withIntermediateDirectories: true,
                                                 attributes: nil)
    }
    ax.imshow(pixels, cmap: "gray")
    plt.savefig("\(outputFolder)\(name).png", dpi: 300)
    plt.close()
}
/// Reads a file into an array of bytes.
///
/// Searches `possibleFolders` in order and returns the contents of the first
/// existing match. Terminates the process with a diagnostic if the file cannot
/// be found (or cannot be read) in any folder.
func readFile(_ filename: String) -> [UInt8] {
    let possibleFolders = [".", "Resources", "Autoencoder/Resources"]
    for folder in possibleFolders {
        let parent = URL(fileURLWithPath: folder)
        let fileURL = parent.appendingPathComponent(filename)
        guard FileManager.default.fileExists(atPath: fileURL.path) else {
            continue
        }
        // Read directly with Foundation instead of round-tripping the bytes
        // through Python `open` and `np.frombuffer`; the result is identical
        // for an existing file and avoids the interop copies.
        guard let data = try? Data(contentsOf: fileURL) else {
            continue
        }
        return [UInt8](data)
    }
    print("Failed to find file with name \(filename) in the following folders: \(possibleFolders).")
    exit(-1)
}
/// Reads MNIST images and labels from specified file paths.
///
/// - Returns: Images as a `[count, imageHeight * imageWidth]` float tensor
///   scaled into `[0, 1]`, and the matching labels as an `Int32` tensor.
func readMNIST(imagesFile: String, labelsFile: String) -> (images: Tensor<Float>,
                                                           labels: Tensor<Int32>) {
    print("Reading data.")
    // Skip the IDX headers: 16 bytes for images, 8 bytes for labels.
    let imageScalars = readFile(imagesFile).dropFirst(16).map { Float($0) }
    let labelScalars = readFile(labelsFile).dropFirst(8).map { Int32($0) }
    let exampleCount = labelScalars.count
    print("Constructing data tensors.")
    let imageTensor = Tensor(shape: [exampleCount, imageHeight * imageWidth],
                             scalars: imageScalars) / 255.0
    return (images: imageTensor, labels: Tensor(labelScalars))
}
/// Maps a flattened image to the parameters of a diagonal Gaussian
/// posterior over the 4-dimensional latent space.
struct Encoder: Layer {
    typealias Input = Tensor<Float>
    typealias Output = Encoded

    var encoder1 = Dense<Float>(inputSize: imageHeight * imageWidth, outputSize: 128,
                                activation: relu)
    var encoder2 = Dense<Float>(inputSize: 128, outputSize: 64, activation: relu)
    var encoder3 = Dense<Float>(inputSize: 64, outputSize: 12, activation: relu)
    // Two linear heads on the shared trunk: one for the mean, one for the
    // log-variance of the latent distribution.
    var encoderMean = Dense<Float>(inputSize: 12, outputSize: 4, activation: identity)
    var encoderLogVar = Dense<Float>(inputSize: 12, outputSize: 4, activation: identity)

    @differentiable
    func callAsFunction(_ input: Input) -> Output {
        // Run the shared trunk, then branch into the two heads.
        let features = encoder3(encoder2(encoder1(input)))
        return Encoded(mean: encoderMean(features), logVar: encoderLogVar(features))
    }
}
/// The output of `Encoder`: parameters of a diagonal Gaussian posterior
/// over the latent code.
struct Encoded: Differentiable {
    var mean: Tensor<Float>    // per-dimension mean of the latent distribution
    var logVar: Tensor<Float>  // per-dimension log-variance (log of sigma squared)
}
/// Maps a 4-dimensional latent code back to a flattened image.
struct Decoder: Layer {
    typealias Input = Tensor<Float>
    typealias Output = Tensor<Float>

    var decoder1 = Dense<Float>(inputSize: 4, outputSize: 12, activation: relu)
    var decoder2 = Dense<Float>(inputSize: 12, outputSize: 64, activation: relu)
    var decoder3 = Dense<Float>(inputSize: 64, outputSize: 128, activation: relu)
    // NOTE(review): `tanh` produces pixels in [-1, 1] while the training data
    // is scaled to [0, 1]; a sigmoid output would match the data range — confirm
    // before changing, since it alters trained behavior.
    var decoder4 = Dense<Float>(inputSize: 128, outputSize: imageHeight * imageWidth,
                                activation: tanh)

    @differentiable
    func callAsFunction(_ input: Input) -> Output {
        return decoder4(decoder3(decoder2(decoder1(input))))
    }
}
/// The full variational autoencoder: encode, sample a latent code via the
/// reparameterization trick, then decode.
struct VAE: Layer {
    typealias Input = Tensor<Float>
    typealias Output = VAEResult

    var encoder = Encoder()
    var decoder = Decoder()

    @differentiable
    func callAsFunction(_ input: Input) -> Output {
        let posterior = encoder(input)
        // Reparameterization trick: z = mean + sigma * epsilon with
        // epsilon ~ N(0, I), keeping the sampling step differentiable
        // with respect to the posterior parameters.
        let epsilon = Tensor<Float>(randomNormal: posterior.mean.shape)
        let sigma = exp(posterior.logVar / 2)
        let reconstruction = decoder(epsilon * sigma + posterior.mean)
        return VAEResult(image: reconstruction,
                         mean: posterior.mean,
                         logVar: posterior.logVar)
    }
}
/// Everything the loss needs from one forward pass of `VAE`.
struct VAEResult: Differentiable {
    var image: Tensor<Float>   // decoder reconstruction of the input batch
    var mean: Tensor<Float>    // posterior mean from the encoder
    var logVar: Tensor<Float>  // posterior log-variance from the encoder
}
/// The VAE objective: squared-error reconstruction loss plus the KL
/// divergence from the posterior to the standard-normal prior, averaged
/// over the batch.
@differentiable
func loss(result: VAEResult, original: Tensor<Float>) -> Tensor<Float> {
    // Per-example squared reconstruction error, summed over pixels.
    let reconstructionLoss = (result.image - original).squared().sum(alongAxes: 1)
    // Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior:
    // -0.5 * sum(1 + logVar - mean^2 - exp(logVar)).
    let klTerm = 1 + result.logVar - result.mean.squared() - exp(result.logVar)
    let klLoss = klTerm.sum(alongAxes: 1) * -0.5
    return (reconstructionLoss + klLoss).mean()
}
// MNIST data logic

/// Returns the `index`-th contiguous minibatch of `x` along its first axis.
///
/// - Parameters:
///   - x: The tensor to slice; its first axis indexes examples.
///   - index: Zero-based batch index. The caller must ensure the slice
///     `[index * size, (index + 1) * size)` lies within bounds.
///   - size: Number of examples per batch; defaults to the global `batchSize`,
///     so existing `minibatch(in:at:)` call sites are unchanged.
func minibatch<Scalar>(in x: Tensor<Scalar>, at index: Int,
                       size: Int = batchSize) -> Tensor<Scalar> {
    let start = index * size
    return x[start..<start+size]
}
// Load the MNIST training set; images are flattened and scaled to [0, 1].
let (images, numericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte",
                                        labelsFile: "train-labels-idx1-ubyte")
// One-hot labels are constructed, but only their count is used below.
let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)
var vae = VAE()
let optimizer = Adam(for: vae)
// Training loop
for epoch in 1...epochCount {
    // Before training this epoch, plot one training image and its current
    // reconstruction (note: `epoch` doubles as the sample's row index, so a
    // different sample image is shown each epoch).
    let sampleImage = Tensor(shape: [1, imageHeight * imageWidth], scalars: images[epoch].scalars)
    let testResult = vae(sampleImage)
    plot(image: sampleImage.scalars, name: "epoch-\(epoch)-input")
    plot(image: testResult.image.scalars, name: "epoch-\(epoch)-output")
    let sampleLoss = loss(result: testResult, original: sampleImage)
    print("[Epoch: \(epoch)] Loss: \(sampleLoss)")
    // Iterate over full batches only; a trailing partial batch is dropped.
    for i in 0 ..< Int(labels.shape[0]) / batchSize {
        let x = minibatch(in: images, at: i)
        // Differentiate the batch loss with respect to all model parameters.
        let 𝛁model = vae.gradient { vae -> Tensor<Float> in
            let result = vae(x)
            return loss(result: result, original: x)
        }
        optimizer.update(&vae.allDifferentiableVariables, along: 𝛁model)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment