Created
December 5, 2021 01:30
-
-
Save seanmor5/7cfd9f283528454d62e79841d1a2a525 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
defmodule Day4 do
  @moduledoc """
  Bingo solver written with Nx numerical definitions (`defn`).

  The input file is expected to hold a comma-separated list of drawn numbers,
  followed by blank-line-separated boards of 25 whitespace-separated integers
  each. Boards are stacked into one tensor whose trailing axes are named
  `:rows` and `:columns`; a second tensor of the same shape (the "mask")
  records which cells have been drawn so far.

  A board wins when every cell along one row or one column is marked.
  Diagonals do not count. A board's score is the sum of its unmarked cells
  multiplied by the number drawn last.
  """

  import Nx.Defn

  # Path read by part1/0 and part2/0 when no explicit path is given.
  @default_input "aoc/4.txt"

  # Side length of a (square) board.
  @board_side 5

  @doc """
  Returns the score of the first board to win, as a scalar tensor.

  `path` is the input file to read; it defaults to `"aoc/4.txt"`.
  Raises `File.Error` if the file cannot be read.
  """
  @spec part1(Path.t()) :: Nx.Tensor.t()
  def part1(path \\ @default_input) do
    path
    |> File.read!()
    |> parse_input()
    |> play_bingo()
    |> find_winning_board()
  end

  @doc """
  Returns the score of the last board to win, as a scalar tensor.

  `path` is the input file to read; it defaults to `"aoc/4.txt"`.
  Raises `File.Error` if the file cannot be read.
  """
  @spec part2(Path.t()) :: Nx.Tensor.t()
  def part2(path \\ @default_input) do
    path
    |> File.read!()
    |> parse_input()
    |> play_bingo_until_last()
    |> compute_last_score()
  end

  # Splits the raw text into `{draws, boards}`: `draws` is a rank-1 tensor of
  # the drawn numbers, `boards` is the stack of all board matrices.
  defp parse_input(input) do
    [draws | boards] =
      input
      |> String.replace("\r", "")
      |> String.split("\n\n")

    draws =
      draws
      |> String.trim()
      |> String.split(",")
      |> Enum.map(&String.to_integer/1)
      |> Nx.tensor()

    {draws, to_matrices(boards)}
  end

  # Stacks every board into one tensor with the board index on the leading
  # axis. Whitespace-only chunks (e.g. produced by trailing blank lines) are
  # skipped so they are not parsed as empty boards.
  defp to_matrices(boards) do
    boards
    |> Enum.reject(&(String.trim(&1) == ""))
    |> Enum.map(&to_matrix/1)
    |> Nx.stack()
  end

  # Parses a single board. Splitting on whitespace copes with the padded
  # single-digit numbers and the line breaks without relying on fixed-width
  # cells.
  defp to_matrix(board) do
    board
    |> String.split()
    |> Enum.map(&String.to_integer/1)
    |> Nx.tensor()
    |> Nx.reshape({@board_side, @board_side}, names: [:rows, :columns])
  end

  # Draws numbers until at least one board has a complete line.
  #
  # Returns `{last_drawn, mask, boards}`. The loop only stops once some board
  # wins, so the draws are assumed to produce a winner.
  defnp play_bingo({draws, boards}) do
    # index of the next number to draw
    current = Nx.tensor(0)

    # nothing is marked on any board yet
    mask = Nx.broadcast(Nx.tensor(0, type: {:u, 8}), boards)

    # draws and boards are threaded through the loop state so that the
    # condition and the body can read them
    {current, bingo_mask, _, boards} =
      while {current, mask, draws, boards}, Nx.logical_not(bingo?(mask)) do
        next_draw = Nx.squeeze(draws[current])
        values_to_fill = Nx.equal(boards, next_draw)
        update_mask = Nx.logical_or(values_to_fill, mask)
        {current + 1, update_mask, draws, boards}
      end

    # `current` already points past the draw that produced the bingo
    {Nx.squeeze(Nx.slice_axis(draws, current - 1, 1, 0)), bingo_mask, boards}
  end

  # One flag per board (leading axis): 1 when the board has a fully marked
  # row or column, 0 otherwise.
  defnp board_wins(mask) do
    # summing over :rows counts the marked cells of each column, so a column
    # is complete when that count equals the number of rows
    full_columns =
      mask
      |> Nx.sum(axes: [:rows])
      |> Nx.equal(Nx.axis_size(mask, :rows))
      |> Nx.reduce_max(axes: [:columns])

    # and the other way round for complete rows
    full_rows =
      mask
      |> Nx.sum(axes: [:columns])
      |> Nx.equal(Nx.axis_size(mask, :columns))
      |> Nx.reduce_max(axes: [:rows])

    Nx.logical_or(full_columns, full_rows)
  end

  # True (1) when any board in `mask` has a complete row or column.
  defnp bingo?(mask) do
    mask
    |> board_wins()
    |> Nx.any?()
  end

  # Scores the winning board: sum of its unmarked cells times `last_drawn`.
  #
  # The winner is the first board flagged by board_wins/1; if no board is
  # flagged, argmax falls back to board 0.
  defnp find_winning_board({last_drawn, mask, boards}) do
    winning_board_index =
      mask
      |> board_wins()
      |> Nx.argmax()

    not_drawn =
      mask
      |> Nx.slice_axis(winning_board_index, 1, 0)
      |> Nx.logical_not()

    winning_board = Nx.slice_axis(boards, winning_board_index, 1, 0)

    not_drawn
    |> Nx.select(winning_board, 0)
    |> Nx.sum()
    |> Nx.multiply(last_drawn)
  end

  # Draws numbers until at most one board is still without a bingo.
  #
  # Returns `{current, mask, draws, boards}` where `current` is the index of
  # the next number to draw.
  defnp play_bingo_until_last({draws, boards}) do
    # index of the next number to draw
    current = Nx.tensor(0)

    # number of boards that have won so far; u64 matches the type that
    # count_bingos/1 produces inside the loop
    num_bingos = Nx.tensor(0, type: {:u, 64})

    # nothing is marked on any board yet
    mask = Nx.broadcast(Nx.tensor(0, type: {:u, 8}), boards)

    {current, bingo_mask, _, draws, boards} =
      while {current, mask, num_bingos, draws, boards},
            # keep drawing while more than one board has not won yet
            Nx.less(num_bingos + 1, Nx.axis_size(boards, 0)) do
        next_draw = Nx.squeeze(draws[current])
        values_to_fill = Nx.equal(boards, next_draw)
        update_mask = Nx.logical_or(values_to_fill, mask)
        num_bingos = count_bingos(update_mask)
        {current + 1, update_mask, num_bingos, draws, boards}
      end

    {current, bingo_mask, draws, boards}
  end

  # Number of boards with at least one complete line. A board with several
  # complete lines is counted once.
  defnp count_bingos(mask) do
    mask
    |> board_wins()
    |> Nx.sum()
  end

  # Plays the one board that has not won yet until it does, then scores it:
  # sum of its unmarked cells times the number that completed it.
  #
  # Assumes exactly one board is still without a bingo. If none is left
  # (e.g. the last two boards won on the same draw), argmax falls back to
  # board 0 and the result is not meaningful.
  defnp compute_last_score({current, mask, draws, boards}) do
    loser_idx =
      mask
      |> board_wins()
      |> Nx.logical_not()
      |> Nx.argmax()

    loser_mask = Nx.slice_axis(mask, loser_idx, 1, 0)
    loser_board = Nx.slice_axis(boards, loser_idx, 1, 0)

    {current, loser_mask, _, loser_board} =
      while {current, loser_mask, draws, loser_board},
            Nx.logical_not(bingo?(loser_mask)) do
        next_draw = Nx.squeeze(draws[current])
        values_to_fill = Nx.equal(loser_board, next_draw)
        update_mask = Nx.logical_or(loser_mask, values_to_fill)
        {current + 1, update_mask, draws, loser_board}
      end

    not_drawn = Nx.logical_not(loser_mask)

    # `current` already points past the draw that completed the board
    last_drawn = Nx.squeeze(Nx.slice_axis(draws, current - 1, 1, 0))

    not_drawn
    |> Nx.select(loser_board, 0)
    |> Nx.sum()
    |> Nx.multiply(last_drawn)
  end
end
# Run both parts against the default input file and print each result.
IO.inspect(Day4.part1())
IO.inspect(Day4.part2())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment