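# MNIST digit classification with a small convolutional network, written
# directly against Nx.Defn and run as a script. Every defn in this module is
# compiled with EXLA; run_options: [keep_on_device: true] keeps tensors on
# the accelerator between calls, so results must be explicitly transferred
# back before reading them as Elixir values.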
defmodule MNIST do
  import Nx.Defn

  @default_defn_compiler {EXLA, run_options: [keep_on_device: true]}
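
  # Draw all weights and biases from N(0.0, 0.1). Convolution kernels use the
  # {output_channels, input_channels, height, width} layout; biases are shaped
  # to broadcast over NCHW activations.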
  defn init_random_params do
    w1 = Nx.random_normal({32, 1, 7, 7}, 0.0, 0.1)
    b1 = Nx.random_normal({1, 32, 1, 1}, 0.0, 0.1)
    w2 = Nx.random_normal({64, 32, 4, 4}, 0.0, 0.1)
    b2 = Nx.random_normal({1, 64, 1, 1}, 0.0, 0.1)
    w3 = Nx.random_normal({5184, 10}, 0.0, 0.1)
    b3 = Nx.random_normal({10}, 0.0, 0.1, names: [:output])
    {w1, b1, w2, b2, w3, b3}
  end
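
  # Collapse everything except the batch dimension, e.g. {n, 64, 9, 9} ->
  # {n, 5184}. transform/2 escapes the defn tracer so the target shape is
  # computed with ordinary Elixir at compile time.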
  defn flatten(x) do
    new_shape =
      transform(Nx.shape(x), fn shape ->
        {elem(shape, 0), div(Nx.size(shape), elem(shape, 0))}
      end)

    Nx.reshape(x, new_shape)
  end
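
  # Softmax over the named :output axis. Note this is the textbook form with
  # no max-subtraction, so very large logits could overflow Nx.exp/1.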
  defn softmax(logits) do
    Nx.exp(logits) / Nx.sum(Nx.exp(logits), axes: [:output], keep_axes: true)
  end
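
  # Forward pass. With 28x28 MNIST inputs the shapes evolve as:
  #   {n, 1, 28, 28} --conv 7x7--> {n, 32, 22, 22} --conv 4x4--> {n, 64, 19, 19}
  #   --3x3 mean pool, stride 2--> {n, 64, 9, 9} --flatten--> {n, 5184} --dense--> {n, 10}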
  defn predict({w1, b1, w2, b2, w3, b3}, batch) do
    batch
    |> Nx.conv(w1)
    |> Nx.add(b1)
    |> Nx.logistic()
    |> Nx.conv(w2)
    |> Nx.add(b2)
    |> Nx.logistic()
    |> Nx.window_mean({1, 1, 3, 3}, strides: [1, 1, 2, 2])
    |> flatten()
    |> Nx.dot(w3)
    |> Nx.add(b3)
    |> softmax()
  end
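
  # Fraction of the batch where the predicted class (argmax over :output)
  # matches the one-hot label.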
  defn accuracy({w1, b1, w2, b2, w3, b3}, batch_images, batch_labels) do
    Nx.mean(
      Nx.equal(
        Nx.argmax(batch_labels, axis: :output),
        Nx.argmax(predict({w1, b1, w2, b2, w3, b3}, batch_images), axis: :output)
      )
    )
  end
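
  # Categorical cross-entropy against the one-hot labels, averaged over the
  # :output axis and summed over the batch.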
  defn loss({w1, b1, w2, b2, w3, b3}, batch_images, batch_labels) do
    preds = predict({w1, b1, w2, b2, w3, b3}, batch_images)
    -Nx.sum(Nx.mean(Nx.log(preds) * batch_labels, axes: [:output]))
  end
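
  # One step of plain SGD. grad/2 here uses the expression-based form Nx
  # supported when this gist was written: it differentiates the loss with
  # respect to the whole parameter tuple and returns a matching gradient tuple.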
  defn update({w1, b1, w2, b2, w3, b3} = params, batch_images, batch_labels, step) do
    {grad_w1, grad_b1, grad_w2, grad_b2, grad_w3, grad_b3} =
      grad(params, loss(params, batch_images, batch_labels))

    {
      w1 - grad_w1 * step,
      b1 - grad_b1 * step,
      w2 - grad_w2 * step,
      b2 - grad_b2 * step,
      w3 - grad_w3 * step,
      b3 - grad_b3 * step
    }
  end
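
  # Run one SGD step (fixed learning rate 0.01) while accumulating running
  # averages of loss and accuracy; each batch contributes 1/total.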
  defn update_with_averages({_, _, _, _, _, _} = cur_params, imgs, tar, avg_loss, avg_accuracy, total) do
    batch_loss = loss(cur_params, imgs, tar)
    batch_accuracy = accuracy(cur_params, imgs, tar)
    avg_loss = avg_loss + batch_loss / total
    avg_accuracy = avg_accuracy + batch_accuracy / total
    {update(cur_params, imgs, tar, 0.01), avg_loss, avg_accuracy}
  end
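
  # Fetch a gzipped MNIST archive with Erlang's built-in :httpc, caching it
  # under tmp/, and gunzip it in memory.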
  defp unzip_cache_or_download(zip) do
    base_url = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
    path = Path.join("tmp", zip)

    data =
      if File.exists?(path) do
        IO.puts("Using #{zip} from tmp/\n")
        File.read!(path)
      else
        IO.puts("Fetching #{zip} from https://storage.googleapis.com/cvdf-datasets/mnist/\n")
        :inets.start()
        :ssl.start()
        {:ok, {_status_line, _headers, data}} = :httpc.request(:get, {base_url ++ zip, []}, [], [])
        File.mkdir_p!("tmp")
        File.write!(path, data)
        data
      end

    :zlib.gunzip(data)
  end
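
  # Parse the IDX format: a 32-bit magic number, then each dimension as a
  # big-endian 32-bit integer, then raw bytes. Pixels are scaled to [0, 1] and
  # split into batches of 32 (60_000 divides evenly); labels are one-hot
  # encoded by comparing each digit against 0..9 with Nx.equal/2.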
  def download(images, labels) do
    <<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> =
      unzip_cache_or_download(images)

    train_images =
      images
      |> Nx.from_binary({:u, 8})
      |> Nx.reshape({n_images, 1, n_rows, n_cols}, names: [:batch, :channels, :height, :width])
      |> Nx.divide(255)
      |> Nx.to_batched_list(32)

    IO.puts("#{n_images} #{n_rows}x#{n_cols} images\n")

    <<_::32, n_labels::32, labels::binary>> = unzip_cache_or_download(labels)

    train_labels =
      labels
      |> Nx.from_binary({:u, 8})
      |> Nx.reshape({n_labels, 1}, names: [:batch, :output])
      |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
      |> Nx.to_batched_list(32)

    IO.puts("#{n_labels} labels\n")

    {train_images, train_labels}
  end
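
  # Fold over all {images, labels} batches, threading the parameters and the
  # running loss/accuracy averages through the reduction.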
  def train_epoch(cur_params, imgs, labels) do
    total_batches = Enum.count(imgs)

    imgs
    |> Enum.zip(labels)
    |> Enum.reduce({cur_params, Nx.tensor(0.0), Nx.tensor(0.0)}, fn
      {imgs, tar}, {cur_params, avg_loss, avg_accuracy} ->
        update_with_averages(cur_params, imgs, tar, avg_loss, avg_accuracy, total_batches)
    end)
  end
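
  # Train for the requested number of epochs (default 5), timing each epoch
  # and printing its average loss and accuracy. The averages live on the
  # device, so they are transferred back before conversion to scalars.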
  def train(imgs, labels, params, opts \\ []) do
    epochs = opts[:epochs] || 5

    for epoch <- 1..epochs, reduce: params do
      cur_params ->
        {time, {new_params, epoch_avg_loss, epoch_avg_acc}} =
          :timer.tc(__MODULE__, :train_epoch, [cur_params, imgs, labels])

        epoch_avg_loss =
          epoch_avg_loss
          |> Nx.backend_transfer()
          |> Nx.to_scalar()

        epoch_avg_acc =
          epoch_avg_acc
          |> Nx.backend_transfer()
          |> Nx.to_scalar()

        IO.puts("Epoch #{epoch} Time: #{time / 1_000_000}s")
        IO.puts("Epoch #{epoch} average loss: #{inspect(epoch_avg_loss)}")
        IO.puts("Epoch #{epoch} average accuracy: #{inspect(epoch_avg_acc)}")
        IO.puts("\n")

        new_params
    end
  end
end
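
# Script: download the training set, initialize parameters, and train.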
{train_images, train_labels} =
  MNIST.download('train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz')

IO.puts("Initializing parameters...\n")
params = MNIST.init_random_params()

IO.puts("Training MNIST for 10 epochs...\n\n")
final_params = MNIST.train(train_images, train_labels, params, epochs: 10)

IO.puts("Bring the parameters back from the device and print them")
final_params = Nx.backend_transfer(final_params)
IO.inspect(final_params)
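
# A possible follow-up, not part of the original gist: report accuracy over
# the full training set with the trained parameters, reusing MNIST.accuracy/3.
# Averaging per-batch accuracies is exact here because all 1_875 batches have
# the same size. (The t10k-* test archives could be fetched the same way, but
# download/2 hardcodes a batch size of 32, which does not divide the
# 10_000-image test set evenly.)
final_acc =
  train_images
  |> Enum.zip(train_labels)
  |> Enum.map(fn {imgs, labels} -> MNIST.accuracy(final_params, imgs, labels) end)
  |> Enum.reduce(&Nx.add/2)
  |> Nx.divide(Enum.count(train_images))

IO.puts("Training accuracy: #{inspect(Nx.to_scalar(Nx.backend_transfer(final_acc)))}")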