# Gist by @seanmor5, created March 2, 2021.
# https://gist.github.com/seanmor5/25ebb0cadd9af2a66af17fb1d0a5cefb
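# MNIST digit classification with a small convolutional network, built directly
# on Nx with the EXLA compiler. Written against the Nx APIs as of early 2021
# (@default_defn_compiler, two-argument grad/2, Nx.to_scalar/1, and
# Nx.to_batched_list/2), several of which have changed in later releases.
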
defmodule MNIST do
  import Nx.Defn

  # Compile every defn in this module with EXLA and keep results on the
  # device between calls (early-2021 Nx API).
  @default_defn_compiler {EXLA, run_options: [keep_on_device: true]}

  defn init_random_params do
    # Conv kernels are {out_channels, in_channels, height, width}; the biases
    # are shaped to broadcast along the channel axis.
    w1 = Nx.random_normal({32, 1, 7, 7}, 0.0, 0.1)
    b1 = Nx.random_normal({1, 32, 1, 1}, 0.0, 0.1)
    w2 = Nx.random_normal({64, 32, 4, 4}, 0.0, 0.1)
    b2 = Nx.random_normal({1, 64, 1, 1}, 0.0, 0.1)
    w3 = Nx.random_normal({5184, 10}, 0.0, 0.1)
    b3 = Nx.random_normal({10}, 0.0, 0.1, names: [:output])
    {w1, b1, w2, b2, w3, b3}
  end

  defn flatten(x) do
    # Collapse all non-batch axes into one: {batch, ...} -> {batch, rest}.
    new_shape =
      transform(Nx.shape(x), fn shape ->
        {elem(shape, 0), div(Nx.size(shape), elem(shape, 0))}
      end)

    Nx.reshape(x, new_shape)
  end

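  # For this network, the {n, 64, 9, 9} pooled activations from predict/2
  # flatten to {n, 5184}, which is why w3 has 5184 rows.
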
  defn softmax(logits) do
    # Normalize along the named :output axis so each row sums to 1.
    Nx.exp(logits) / Nx.sum(Nx.exp(logits), axes: [:output], keep_axes: true)
  end

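  # Note: this softmax is not max-stabilized; with larger logits one would
  # normally subtract Nx.reduce_max(logits) before exponentiating.
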
  defn predict({w1, b1, w2, b2, w3, b3}, batch) do
    # Shapes, assuming 28x28 MNIST inputs:
    # {n, 1, 28, 28} -> conv 7x7 -> {n, 32, 22, 22}
    #                -> conv 4x4 -> {n, 64, 19, 19}
    #                -> 3x3 mean pool, stride 2 -> {n, 64, 9, 9}
    #                -> flatten -> {n, 5184} -> dense -> {n, 10}
    batch
    |> Nx.conv(w1)
    |> Nx.add(b1)
    |> Nx.logistic()
    |> Nx.conv(w2)
    |> Nx.add(b2)
    |> Nx.logistic()
    |> Nx.window_mean({1, 1, 3, 3}, strides: [1, 1, 2, 2])
    |> flatten()
    |> Nx.dot(w3)
    |> Nx.add(b3)
    |> softmax()
  end

  defn accuracy({w1, b1, w2, b2, w3, b3}, batch_images, batch_labels) do
    Nx.mean(
      Nx.equal(
        Nx.argmax(batch_labels, axis: :output),
        Nx.argmax(predict({w1, b1, w2, b2, w3, b3}, batch_images), axis: :output)
      )
    )
  end

  defn loss({w1, b1, w2, b2, w3, b3}, batch_images, batch_labels) do
    # Categorical cross-entropy against the one-hot labels.
    preds = predict({w1, b1, w2, b2, w3, b3}, batch_images)
    -Nx.sum(Nx.mean(Nx.log(preds) * batch_labels, axes: [:output]))
  end

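  # Note: grad(vars, expr) as used below is the early-2021 Nx API; newer Nx
  # releases expect a function instead, e.g. grad(params, &loss(&1, imgs, tar)).
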
  defn update({w1, b1, w2, b2, w3, b3} = params, batch_images, batch_labels, step) do
    # Plain SGD: take one step of size `step` against the loss gradient.
    {grad_w1, grad_b1, grad_w2, grad_b2, grad_w3, grad_b3} =
      grad(params, loss(params, batch_images, batch_labels))

    {
      w1 - grad_w1 * step,
      b1 - grad_b1 * step,
      w2 - grad_w2 * step,
      b2 - grad_b2 * step,
      w3 - grad_w3 * step,
      b3 - grad_b3 * step
    }
  end

  defn update_with_averages({_, _, _, _, _, _} = cur_params, imgs, tar, avg_loss, avg_accuracy, total) do
    # Accumulate running epoch averages while taking an SGD step (fixed lr 0.01).
    batch_loss = loss(cur_params, imgs, tar)
    batch_accuracy = accuracy(cur_params, imgs, tar)
    avg_loss = avg_loss + batch_loss / total
    avg_accuracy = avg_accuracy + batch_accuracy / total
    {update(cur_params, imgs, tar, 0.01), avg_loss, avg_accuracy}
  end

  defp unzip_cache_or_download(zip) do
    base_url = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
    path = Path.join("tmp", zip)

    data =
      if File.exists?(path) do
        IO.puts("Using #{zip} from tmp/\n")
        File.read!(path)
      else
        IO.puts("Fetching #{zip} from #{base_url}\n")
        :inets.start()
        :ssl.start()
        {:ok, {_status, _response, data}} = :httpc.request(:get, {base_url ++ zip, []}, [], [])
        File.mkdir_p!("tmp")
        File.write!(path, data)
        data
      end

    :zlib.gunzip(data)
  end

  def download(images, labels) do
    # IDX format: a 32-bit magic number, big-endian u32 counts/dimensions,
    # then the raw pixel or label bytes.
    <<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> =
      unzip_cache_or_download(images)

    train_images =
      images
      |> Nx.from_binary({:u, 8})
      |> Nx.reshape({n_images, 1, n_rows, n_cols}, names: [:batch, :channels, :height, :width])
      |> Nx.divide(255)
      |> Nx.to_batched_list(32)

    IO.puts("#{n_images} #{n_rows}x#{n_cols} images\n")

    <<_::32, n_labels::32, labels::binary>> = unzip_cache_or_download(labels)

    train_labels =
      labels
      |> Nx.from_binary({:u, 8})
      |> Nx.reshape({n_labels, 1}, names: [:batch, :output])
      # Comparing each label against 0..9 one-hot encodes it.
      |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
      |> Nx.to_batched_list(32)

    IO.puts("#{n_labels} labels\n")

    {train_images, train_labels}
  end

  def train_epoch(cur_params, imgs, labels) do
    total_batches = Enum.count(imgs)

    imgs
    |> Enum.zip(labels)
    |> Enum.reduce({cur_params, Nx.tensor(0.0), Nx.tensor(0.0)}, fn
      {imgs, tar}, {cur_params, avg_loss, avg_accuracy} ->
        update_with_averages(cur_params, imgs, tar, avg_loss, avg_accuracy, total_batches)
    end)
  end

  def train(imgs, labels, params, opts \\ []) do
    epochs = opts[:epochs] || 5

    for epoch <- 1..epochs, reduce: params do
      cur_params ->
        {time, {new_params, epoch_avg_loss, epoch_avg_acc}} =
          :timer.tc(__MODULE__, :train_epoch, [cur_params, imgs, labels])

        epoch_avg_loss =
          epoch_avg_loss
          |> Nx.backend_transfer()
          |> Nx.to_scalar()

        epoch_avg_acc =
          epoch_avg_acc
          |> Nx.backend_transfer()
          |> Nx.to_scalar()

        IO.puts("Epoch #{epoch} Time: #{time / 1_000_000}s")
        IO.puts("Epoch #{epoch} average loss: #{inspect(epoch_avg_loss)}")
        IO.puts("Epoch #{epoch} average accuracy: #{inspect(epoch_avg_acc)}")
        IO.puts("\n")

        new_params
    end
  end
end
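
# Script entry point: download and batch the data, initialize the parameters,
# train for 10 epochs, then transfer the final parameters off the device.
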
{train_images, train_labels} =
  MNIST.download('train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz')

IO.puts("Initializing parameters...\n")
params = MNIST.init_random_params()

IO.puts("Training MNIST for 10 epochs...\n\n")
final_params = MNIST.train(train_images, train_labels, params, epochs: 10)

IO.puts("Bringing the parameters back from the device...\n")
final_params = Nx.backend_transfer(final_params)
IO.inspect(final_params)
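
# A possible sanity check after training (hypothetical usage, not part of the
# original gist): evaluate accuracy on the first 32-image batch.
#
#   [batch | _] = train_images
#   [labels | _] = train_labels
#
#   MNIST.accuracy(final_params, batch, labels)
#   |> Nx.backend_transfer()
#   |> Nx.to_scalar()
#   |> IO.inspect(label: "accuracy on first batch")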