@aymanosman
Created November 29, 2024 20:15
emlx bug
Mix.install([
  {:emlx, github: "elixir-nx/emlx"},
  {:axon, github: "elixir-nx/axon"},
  {:scidata, "~> 0.1.11"}
])

# The second call replaces the first, so the GPU device is the setting that takes effect.
Nx.default_backend(EMLX.Backend)
Nx.default_backend({EMLX.Backend, device: :gpu})

{images, labels} = Scidata.MNIST.download()
{image_data, image_type, image_shape} = images
{label_data, label_type, label_shape} = labels

# Scale pixels to [0, 1] and flatten each 28x28 image into one 784-element row.
images =
  image_data
  |> Nx.from_binary(image_type)
  |> Nx.divide(255)
  |> Nx.reshape({60000, :auto})

# One-hot encode the labels: compare each digit against 0..9, giving a {60000, 10} tensor.
labels =
  label_data
  |> Nx.from_binary(label_type)
  |> Nx.reshape(label_shape)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 10}))

# The first 50,000 examples are used for training, the remaining 10,000 for testing.
train_range = 0..49_999//1
test_range = 50_000..-1//1
train_images = images[train_range]
train_labels = labels[train_range]
test_images = images[test_range]
test_labels = labels[test_range]

batch_size = 64

train_data =
  train_images
  |> Nx.to_batched(batch_size)
  |> Stream.zip(Nx.to_batched(train_labels, batch_size))

test_data =
  test_images
  |> Nx.to_batched(batch_size)
  |> Stream.zip(Nx.to_batched(test_labels, batch_size))

model =
  Axon.input("images", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

trained_model_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(train_data, %{}, epochs: 10, compiler: EMLX)