Mix.install(
  [
    {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"},
    {:exla, "~> 0.2.0", github: "elixir-nx/nx", sparse: "exla", override: true},
    {:nx, "~> 0.2.0", github: "elixir-nx/nx", sparse: "nx", override: true},
    {:scidata, "~> 0.1.5"}
  ],
  # Environment variables that select the CUDA build of XLA/EXLA
  system_env: [
    {"XLA_TARGET", "cuda111"},
    {"EXLA_TARGET", "cuda"}
  ]
)
An autoencoder is a deep learning model that consists of two parts: an encoder and a decoder. The encoder compresses high-dimensional data into a low-dimensional representation and feeds it to the decoder. The decoder tries to recreate the original data from the low-dimensional representation. Autoencoders can be used for the following problems:
- Dimensionality reduction
- Noise reduction
- Generative models
- Data augmentation
Let's walk through a basic autoencoder implementation in Axon to get a better understanding of how they work in practice.
First, we have to import some essential packages. Axon will be the primary tool to build and train the autoencoder. We also need to import EXLA for hardware acceleration, Nx for some basic tensor operations, and Scidata for downloading training data.
EXLA provides JIT compilation to hardware accelerators. We'll start by telling EXLA the order in which to try accelerators. If you have a specific target in mind, such as :cuda or :rocm, you can adjust the precedence to reflect that. The order of precedence here is tpu > cuda > rocm > host.
EXLA.set_preferred_defn_options([:tpu, :cuda, :rocm, :host])
To train and test our model, we use one of the most popular data sets: Fashion MNIST. It consists of small grayscale images of clothing. Loading this data set is very simple: just call Scidata.FashionMNIST.download().
The next step is image preprocessing. First, we load the training data into a tensor using Nx.from_binary and reshape it with Nx.reshape. Then we normalize the data with Nx.divide and split the dataset into a list of batches with Nx.to_batched_list.
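The same sequence of steps can be mimicked on a tiny binary without any dependencies, which makes each stage easy to inspect. This is only a sketch on plain lists: the module PreprocessSketch and its function are hypothetical names introduced here, and the real notebook relies on the Nx calls named above.

```elixir
defmodule PreprocessSketch do
  # Hypothetical, dependency-free stand-in for the Nx pipeline:
  # bytes -> floats in [0, 1] -> one list per image -> batches of images.
  def normalize_and_batch(bin, pixels_per_image, batch_size) when is_binary(bin) do
    bin
    # roughly what Nx.from_binary does for unsigned 8-bit data
    |> :binary.bin_to_list()
    # roughly what Nx.divide(255.0) does
    |> Enum.map(&(&1 / 255))
    # roughly what Nx.reshape does: one row per image
    |> Enum.chunk_every(pixels_per_image)
    # roughly what Nx.to_batched_list does
    |> Enum.chunk_every(batch_size)
  end
end

# Four "images" of two pixels each, grouped into batches of two images.
batches = PreprocessSketch.normalize_and_batch(<<0, 255, 51, 102, 255, 0, 0, 0>>, 2, 2)
IO.inspect(batches)
```

Every value ends up between 0.0 and 1.0, and each batch holds batch_size images, which is the shape of data the training loop consumes.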
First we need to define the encoder and decoder. Both are one-layer neural nets.
In the encoder, we start by flattening the input using Axon.flatten because initially, the input shape is {batch_size, 1, 28, 28} and we want to pass the input into a dense layer with Axon.dense. Our dense layer has only latent_dim neurons. The latent_dim is the size of the latent space, the compressed representation of the data. Remember, we want our encoder to compress the input data into a lower-dimensional representation, so we choose a latent_dim which is less than the dimensionality of the input.
Next, we pass the output of the encoder to the decoder and try to reconstruct the original data from its compressed form. Since our original input had a dimensionality of 784 (28 x 28), we use an Axon.dense layer with 784 neurons. Because our original data was normalized to have pixel values between 0 and 1, we use a :sigmoid activation in our dense layer to squeeze output values between 0 and 1. Our original input shape was 28x28, so we use Axon.reshape to convert the flattened outputs into an image with the correct width and height.
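To make the role of the sigmoid and of the final reshape concrete, here is a small dependency-free sketch. The module SigmoidSketch is a hypothetical name used only for illustration; Axon applies its own implementation of the activation internally.

```elixir
defmodule SigmoidSketch do
  # Logistic sigmoid: squeezes a real number into the interval (0, 1).
  # Note: :math.exp raises on overflow for very large negative x,
  # which is acceptable for this small illustration.
  def sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))
end

for x <- [-6.0, 0.0, 6.0] do
  IO.puts("sigmoid(#{x}) = #{SigmoidSketch.sigmoid(x)}")
end

# A flat list of 784 outputs regrouped into 28 rows of 28 pixels,
# analogous to reshaping the dense layer's output back into an image.
rows = List.duplicate(0.5, 784) |> Enum.chunk_every(28)
IO.inspect({length(rows), length(hd(rows))})
```

Large negative inputs land near 0, large positive inputs land near 1, and 0 maps to exactly 0.5, so the decoder's outputs live in the same range as the normalized pixels.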
If we just bind the encoder and decoder into a sequential model, we'll get the desired model. This was pretty smooth, wasn't it?
Finally, we can train the model. We'll use the :adam optimizer and the :mean_squared_error loss with Axon.Loop.trainer. Our loss function will measure the aggregate error between the pixels of the original images and the model's reconstructed images. We'll also track :mean_absolute_error using Axon.Loop.metric. Axon.Loop.run trains the model with the given training data.
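For reference, the standard textbook definitions of the two quantities are given below, for a prediction and a target with n elements each. The symbols n, y_i and \hat{y}_i are notation introduced here for illustration, not something taken from Axon.

```latex
\mathrm{MSE}(\hat{y}, y) = \frac{1}{n} \sum_{i=1}^{n} \left( \hat{y}_i - y_i \right)^2
\qquad
\mathrm{MAE}(\hat{y}, y) = \frac{1}{n} \sum_{i=1}^{n} \left\lvert \hat{y}_i - y_i \right\rvert
```

For a single 28x28 image, n = 784. Because the differences are squared, MSE penalizes a few large per-pixel errors more heavily than MAE does.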
To better understand what mean absolute error and mean squared error are, let's go through an example.
defmodule Example do
  # Load raw bytes into a tensor, reshape to {n, 1, 28, 28}, scale to [0, 1], and batch.
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), 1, 28, 28})
    |> Nx.divide(255.0)
    |> Nx.to_batched_list(32)
  end

  # Squared error, averaged over the last axis and summed over the rest.
  def calculate_mse(y_pred, y) do
    y_pred
    |> Nx.subtract(y)
    |> Nx.power(2)
    |> Nx.mean(axes: [-1])
    |> Nx.sum()
  end

  # Absolute error, averaged over the last axis and summed over the rest.
  def calculate_mae(y_pred, y) do
    y_pred
    |> Nx.subtract(y)
    |> Nx.abs()
    |> Nx.mean(axes: [-1])
    |> Nx.sum()
  end

  def run_example do
    {images, _} = Scidata.FashionMNIST.download()
    train_images = transform_images(images)

    # First image of the first batch: the reference shoe.
    shoe_image =
      train_images
      |> hd()
      |> Nx.slice_axis(0, 1, 0)
      |> Nx.reshape({1, 1, 28, 28})

    # The same shoe with a little Gaussian noise added.
    noised_shoe_image =
      train_images
      |> hd()
      |> Nx.slice_axis(0, 1, 0)
      |> Nx.reshape({1, 1, 28, 28})
      |> Nx.add(Nx.random_normal({1, 1, 28, 28}, 0.0, 0.05))

    # First image of the second batch: a different picture.
    random_image =
      train_images
      |> Enum.at(1)
      |> Nx.slice_axis(0, 1, 0)
      |> Nx.reshape({1, 1, 28, 28})

    same_images_mse_error = calculate_mse(shoe_image, shoe_image)
    similar_images_mse_error = calculate_mse(shoe_image, noised_shoe_image)
    different_images_mse_error = calculate_mse(shoe_image, random_image)
    same_images_mae_error = calculate_mae(shoe_image, shoe_image)
    similar_images_mae_error = calculate_mae(shoe_image, noised_shoe_image)
    different_images_mae_error = calculate_mae(shoe_image, random_image)

    [
      same: {[mse: same_images_mse_error], [mae: same_images_mae_error]},
      similar: {[mse: similar_images_mse_error], [mae: similar_images_mae_error]},
      different: {[mse: different_images_mse_error], [mae: different_images_mae_error]}
    ]
  end
end
We choose an image of a shoe as a reference. As we can see, the error for the same picture is 0 (both MAE and MSE). Indeed, when we have two exact copies, no pixel differs, so the error is equal to zero. The noised image has a non-zero MSE and MAE, but both are much smaller than the errors between two completely different pictures. The size of the error therefore indicates the level of similarity between images: a small error implies a good reconstruction, while a large loss suggests poor-quality predictions.
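The same comparison can be reproduced on plain lists, without downloading any data. This is a minimal sketch under stated assumptions: the module ErrorSketch and the three tiny "images" are hypothetical, and the noise is a fixed perturbation rather than a random one so the numbers are reproducible. Unlike the helpers in the example above, which average over the last axis and then sum over the remaining ones, this sketch averages over every element, so the absolute values differ by a constant factor while the ordering stays the same.

```elixir
defmodule ErrorSketch do
  # Mean squared error over two equally sized lists of numbers.
  def mse(y_pred, y) do
    squared = Enum.zip_with(y_pred, y, fn a, b -> (a - b) * (a - b) end)
    Enum.sum(squared) / length(squared)
  end

  # Mean absolute error over two equally sized lists of numbers.
  def mae(y_pred, y) do
    absolute = Enum.zip_with(y_pred, y, fn a, b -> abs(a - b) end)
    Enum.sum(absolute) / length(absolute)
  end
end

# Hypothetical four-pixel "images" with values in [0, 1].
reference = [0.0, 0.5, 1.0, 0.5]
noised = [0.125, 0.5, 0.875, 0.5]
different = [1.0, 0.0, 0.0, 1.0]

results = [
  same: {ErrorSketch.mse(reference, reference), ErrorSketch.mae(reference, reference)},
  similar: {ErrorSketch.mse(reference, noised), ErrorSketch.mae(reference, noised)},
  different: {ErrorSketch.mse(reference, different), ErrorSketch.mae(reference, different)}
]

IO.inspect(results)
```

Identical inputs give zero for both measures, the lightly perturbed copy gives small values, and the unrelated input gives the largest ones, mirroring the pattern described above.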
defmodule Fashionmist do
  require Axon

  # Load raw bytes into a tensor, reshape to {n, 1, 28, 28}, scale to [0, 1], and batch.
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), 1, 28, 28})
    |> Nx.divide(255.0)
    |> Nx.to_batched_list(256)
  end

  # Flatten the image and compress it to latent_dim values.
  defp encoder(x, latent_dim) do
    x
    |> Axon.flatten()
    |> Axon.dense(latent_dim, activation: :relu)
  end

  # Expand the latent vector back to 784 pixels and restore the image shape.
  defp decoder(x) do
    x
    |> Axon.dense(784, activation: :sigmoid)
    |> Axon.reshape({1, 28, 28})
  end

  def build_model(input_shape, latent_dim) do
    Axon.input(input_shape)
    |> encoder(latent_dim)
    |> decoder()
  end

  # The images serve as both inputs and targets, hence the zip with itself.
  defp train_model(model, train_images, epochs) do
    model
    |> Axon.Loop.trainer(:mean_squared_error, :adam)
    |> Axon.Loop.metric(:mean_absolute_error, "Error")
    |> Axon.Loop.run(Stream.zip(train_images, train_images), epochs: epochs, compiler: EXLA)
  end

  def run do
    {images, _} = Scidata.FashionMNIST.download()
    train_images = transform_images(images)

    model = build_model({nil, 1, 28, 28}, 64) |> IO.inspect()
    model_state = train_model(model, train_images, 5)

    sample_image =
      train_images
      |> hd()
      |> Nx.slice_axis(0, 1, 0)
      |> Nx.reshape({1, 1, 28, 28})

    # Show the original image, then the model's reconstruction of it.
    sample_image |> Nx.to_heatmap() |> IO.inspect()

    model
    |> Axon.predict(model_state, sample_image, compiler: EXLA)
    |> Nx.to_heatmap()
    |> IO.inspect()
  end
end
As we can see, the generated image is similar to the input image. The only difference between them is the absence of a sign in the middle of the second shoe. The model treated the sign as noise and blended it into the plain shoe. To get the outputs, run this code.