Skip to content

Instantly share code, notes, and snippets.

@ChristopheBelpaire
Last active March 6, 2024 17:41
Show Gist options
  • Save ChristopheBelpaire/80f46604df1da56aeb3c26e8a563605a to your computer and use it in GitHub Desktop.
Save ChristopheBelpaire/80f46604df1da56aeb3c26e8a563605a to your computer and use it in GitHub Desktop.
Mix.install([{:axon, "~> 0.6"}, {:nx, "~> 0.7"}, {:exla, "~> 0.7"}, {:stb_image, "0.6.6"}, {:kino, "~> 0.8"}])
defmodule CatsAndDogs do
  @moduledoc """
  Lazy input pipeline for a binary cat-vs-dog image classifier.

  Images are read with `StbImage`, resized, converted to `Nx` tensors and
  grouped into fixed-size batches. The label is `0` when the file name
  contains `"cat"` and `1` otherwise.
  """

  @doc """
  Builds a lazy stream of `{images, labels}` batches from image `paths`.

  The paths are shuffled once, up front. Each file is then decoded in a
  separate task; files that fail to decode, or whose decoded shape does not
  have `3` as its third dimension, are silently dropped. The remaining images
  are resized to `target_height` x `target_width`, divided by 255, and stacked
  into batches of exactly `batch_size` items (a trailing partial batch is
  discarded).

  Returns a `Stream`; nothing is read from disk until it is enumerated.
  """
  @spec pipeline([Path.t()], pos_integer(), pos_integer(), pos_integer()) :: Enumerable.t()
  def pipeline(paths, batch_size, target_height, target_width) do
    paths
    |> Enum.shuffle()
    |> Task.async_stream(&parse_image/1)
    |> Stream.filter(fn
      # Task.async_stream wraps every result in {:ok, _}; keep only results
      # that carry a successfully decoded 3-channel image.
      {:ok, {%StbImage{shape: {_, _, 3}}, _}} -> true
      _ -> false
    end)
    |> Stream.map(&to_tensors(&1, target_height, target_width))
    |> Stream.chunk_every(batch_size, batch_size, :discard)
    |> Stream.map(fn chunks ->
      {img_chunk, label_chunk} = Enum.unzip(chunks)
      {Nx.stack(img_chunk), Nx.stack(label_chunk)}
    end)
  end

  # Reads one image and pairs it with its integer label (0 = cat, 1 = anything
  # else). Returns `:error` when the file cannot be decoded.
  #
  # The label is taken from the file name only, not the full path: matching on
  # the whole path would label every image as a cat as soon as any parent
  # directory happened to contain "cat" (e.g. ".../cats-vs-dogs/dog.1.jpg").
  # NOTE(review): assumes a flat layout with names like "cat.0.jpg" /
  # "dog.0.jpg" - confirm against the dataset.
  defp parse_image(path) do
    label =
      if path |> Path.basename() |> String.contains?("cat"), do: 0, else: 1

    case StbImage.read_file(path) do
      {:ok, img} -> {img, label}
      _error -> :error
    end
  end

  # Converts one `{:ok, {image, label}}` element (the shape produced by
  # Task.async_stream around parse_image/1) into an `{image_tensor,
  # label_tensor}` pair. The image is resized and divided by 255 (presumably
  # mapping 8-bit pixel values into the 0..1 range); the label becomes a
  # one-element tensor.
  defp to_tensors({:ok, {img, label}}, target_height, target_width) do
    img_tensor =
      img
      |> StbImage.resize(target_height, target_width)
      |> StbImage.to_nx()
      |> Nx.divide(255)

    label_tensor = Nx.tensor([label])
    {img_tensor, label_tensor}
  end
end
# Directory holding the training JPEGs. Set CATS_AND_DOGS_DIR to point the
# script at another location; when the variable is unset the original
# hard-coded directory is used, so existing behavior is unchanged.
data_dir =
  System.get_env(
    "CATS_AND_DOGS_DIR",
    "/Users/christophebelpaire/perso/machine-learning-in-elixir/train-2"
  )

image_paths = Path.wildcard(Path.join(data_dir, "*.jpg"))

# Fail early with a clear message instead of handing empty pipelines to the
# training loop.
if image_paths == [] do
  raise "no *.jpg files found in #{data_dir}"
end

# Hold out the first 1000 shuffled files for testing; the rest are used for
# training. With 1000 files or fewer, `train_paths` ends up empty.
{test_paths, train_paths} =
  image_paths
  |> Enum.shuffle()
  |> Enum.split(1000)

# Size every image is resized to, and the number of images per batch.
target_height = 96
target_width = 96
batch_size = 32
# Both splits share the same batch size and target image size, so build them
# through one captured function. The test pipeline is constructed but not
# consumed anywhere in this script, hence the underscore prefix.
build_pipeline = &CatsAndDogs.pipeline(&1, batch_size, target_height, target_width)

train_pipeline = build_pipeline.(train_paths)
_test_pipeline = build_pipeline.(test_paths)
# Convolutional classifier: three conv + max-pool stages (32, 64 and 128
# filters), then a 128-unit dense layer and a single sigmoid output unit.
#
# The input shape is derived from target_height / target_width rather than
# repeating the literal 96, so the model always matches the size the input
# pipeline resizes images to. The channel count 3 matches the pipeline's
# filter on 3-channel images.
cnn_model =
  Axon.input("images", shape: {nil, target_height, target_width, 3})
  |> Axon.conv(32, kernel_size: {3, 3}, activation: :relu, padding: :same)
  |> Axon.max_pool(kernel_size: {2, 2}, strides: [2, 2])
  |> Axon.conv(64, kernel_size: {3, 3}, activation: :relu, padding: :same)
  |> Axon.max_pool(kernel_size: {2, 2}, strides: [2, 2])
  |> Axon.conv(128, kernel_size: {3, 3}, activation: :relu, padding: :same)
  |> Axon.max_pool(kernel_size: {2, 2}, strides: [2, 2])
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)
# Supervised training loop: binary cross-entropy loss, Adam optimizer, with
# accuracy reported as a metric.
training_loop =
  cnn_model
  |> Axon.Loop.trainer(:binary_cross_entropy, :adam)
  |> Axon.Loop.metric(:accuracy)

# Run 5 epochs over the training batches, starting from an empty initial
# state and compiling with EXLA. The resulting state is not used further in
# this script, hence the underscore prefix.
_cnn_trained_model_state =
  Axon.Loop.run(training_loop, train_pipeline, %{}, epochs: 5, compiler: EXLA)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment