Last active
April 19, 2022 19:22
-
-
Save seanmor5/8a27ff8048040e22ae012983981f97b7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
defmodule Day3 do
  # Advent of Code day 3, solved with Nx as much as possible. Nx has no
  # string/char manipulation, so parsing the puzzle text is plain Elixir;
  # everything after parsing runs inside `defn`.
  import Nx.Defn

  @doc """
  Reads `aoc/3.txt` and returns the power consumption (gamma rate times
  epsilon rate) as a scalar tensor.
  """
  def part1 do
    File.read!("aoc/3.txt")
    |> parse_input()
    |> power_consumption()
  end

  @doc """
  Reads `aoc/3.txt` and returns the life support rating (oxygen generator
  rating times CO2 scrubber rating) as a scalar tensor.
  """
  def part2 do
    File.read!("aoc/3.txt")
    |> parse_input()
    |> compute_ratings()
  end

  # Turns the raw puzzle text into a {rows, bitwidth} tensor of 0s and 1s.
  # Assumes every non-blank line consists only of ASCII "0"/"1" characters and
  # that all lines have the same length (Nx.stack/1 requires equal shapes).
  defp parse_input(file) do
    file
    # tolerate Windows (CRLF) line endings
    |> String.replace("\r", "")
    # trim: true drops blank lines such as the one produced by a trailing
    # newline; those would become {0}-shaped tensors and break Nx.stack/1
    |> String.split("\n", trim: true)
    # bytes are ASCII codes ("0" = 48, "1" = 49), so subtracting ?0 yields the
    # bit values 0 and 1
    |> Enum.map(&Nx.subtract(Nx.from_binary(&1, {:u, 8}), Nx.tensor(?0)))
    |> Nx.stack()
  end

  # Gamma takes, per bit position, the bit that is more common across all
  # rows (a tie yields 0 because the comparison is strict); epsilon is its
  # complement. Returns gamma * epsilon as a scalar tensor.
  defnp power_consumption(bytes) do
    # both counts have shape {bitwidth}: one count per bit position
    count_ones = count_value(bytes, 1, axis: 0)
    count_zeros = count_value(bytes, 0, axis: 0)

    gamma = Nx.greater(count_ones, count_zeros)
    epsilon = Nx.logical_not(gamma)

    Nx.multiply(bin2dec(gamma), bin2dec(epsilon))
  end

  # Oxygen generator rating: repeatedly keep the rows holding the most common
  # bit in the current position (ties keep 1, via ones >= zeros).
  # CO2 scrubber rating: keep the least common bit (ties keep 0, via
  # ones < zeros). Returns the product of both ratings as a scalar tensor.
  # NOTE(review): with &Nx.less/2, a column where every surviving row has the
  # same bit selects the value no row has, so the mask empties and the CO2
  # rating becomes 0. Presumably puzzle inputs never hit this; confirm.
  defnp compute_ratings(bytes) do
    oxygen_rating = masked_value(bytes, build_mask(bytes, &Nx.greater_equal/2))
    co2_rating = masked_value(bytes, build_mask(bytes, &Nx.less/2))

    Nx.multiply(oxygen_rating, co2_rating)
  end

  # Collapses the rows kept by `mask` into one bit vector and converts it to an
  # integer. Rows outside the mask are replaced with 0 before summing along the
  # row axis, so when exactly one row survives the sum is that row's bits.
  defnp masked_value(bytes, mask) do
    mask
    |> Nx.select(bytes, 0)
    |> Nx.sum(axes: [0])
    |> bin2dec()
  end

  # Converts a 1-D tensor of bits ordered MSB first into an integer: the dot
  # product with the place values [2^(n-1), ..., 2, 1], built from iota
  # (0..n-1), raised as powers of 2 and then reversed.
  # `:bitwidth` defaults to the length of `x`, so inputs of any width work.
  defnp bin2dec(x, opts \\ []) do
    opts = keyword!(opts, bitwidth: Nx.axis_size(x, 0))

    2
    |> Nx.power(Nx.iota({opts[:bitwidth]}))
    |> Nx.reverse()
    |> Nx.dot(x)
  end

  # Counts how many entries of `x` equal `val` along `:axis` (default 0).
  # The elementwise equality (the scalar `val` is broadcast) is a tensor of
  # 1s where `val` occurs and 0s elsewhere, so summing it gives the count.
  defnp count_value(x, val, opts \\ []) do
    opts = keyword!(opts, axis: 0)
    Nx.sum(Nx.equal(x, val), axes: [opts[:axis]])
  end

  # Iteratively narrows a mask shaped like `bytes` ({rows, bitwidth}) down to
  # the rows that pass `condition`. At bit position i,
  # `condition.(count_ones, count_zeros)` picks the bit value (0 or 1) to keep
  # among the rows still in the mask, and rows whose bit i differs are dropped.
  # The loop stops after the last bit position or once exactly one row remains.
  # Every row of the returned u8 mask is either all 1s (kept) or all 0s.
  # `:bitwidth` defaults to the number of columns in `bytes`.
  defnp build_mask(bytes, condition, opts \\ []) do
    opts = keyword!(opts, bitwidth: Nx.axis_size(bytes, 1))

    # Start with every row kept. Each kept row contributes `bitwidth` ones to
    # sum(mask), so sum(mask) == bitwidth means exactly one row is left. The
    # i < bitwidth guard stops slicing past the last column.
    {_, mask, _} =
      while {i = Nx.tensor(0), mask = Nx.broadcast(Nx.tensor(1, type: {:u, 8}), bytes), bytes},
            Nx.logical_and(
              Nx.less(i, opts[:bitwidth]),
              Nx.not_equal(Nx.sum(mask), opts[:bitwidth])
            ) do
        # Column i as {rows, 1}. Rows already dropped are replaced with -1, so
        # they count as neither 0 nor 1.
        bytes_slice =
          mask
          |> Nx.select(bytes, -1)
          |> Nx.slice_axis(i, 1, 1)

        # Squeeze only the column axis ({rows, 1} -> {rows}). A bare squeeze
        # would also remove the row axis when there is a single row, leaving a
        # scalar that count_value cannot sum along axis 0.
        column = Nx.squeeze(bytes_slice, axes: [1])
        count_zeros = count_value(column, 0)
        count_ones = count_value(column, 1)

        value = condition.(count_ones, count_zeros)

        # The {rows, 1} comparison broadcasts across the bitwidth of the mask,
        # so each row is kept only if it was still kept AND bit i matches.
        updated_mask =
          bytes_slice
          |> Nx.equal(value)
          |> Nx.logical_and(mask)

        {Nx.add(i, 1), updated_mask, bytes}
      end

    mask
  end
end
# Print the answer to each half of the puzzle, part 1 first.
IO.inspect(Day3.part1())
IO.inspect(Day3.part2())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment