Last active
December 17, 2023 14:26
-
-
Save afiodorov/9002ae0c012e77c4a60b53bed7181673 to your computer and use it in GitHub Desktop.
Advent of Code 2023, Day 17: A* search algorithm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from queue import PriorityQueue | |
from dataclasses import dataclass | |
from functools import cache | |
# Load the puzzle input: each line is read as one string and split into its
# individual digits, giving one grid row per line (row index = y, column = x).
grid = pd.DataFrame(
    pd.read_csv("./data/17.txt", names=[0], dtype='str')[0].map(list).tolist()
).astype(int)
# (dx, dy) unit moves. The list index is the direction id stored in
# NodeState.direction: 0 = +y (next row), 1 = +x, 2 = -x, 3 = -y.
directions = [(0, 1), (1, 0), (-1, 0), (0, -1)]
@dataclass(frozen=True, order=True)
class NodeState:
    """One search state: where we are, which way we last moved, and for how long.

    Frozen, so instances are hashable and can serve as dict keys; ordered
    field-by-field (position, then direction, then step_count) so that
    ``(priority, state)`` tuples in a priority queue compare even when the
    priorities tie.
    """

    position: tuple[int, int]  # (x, y) of the current cell
    direction: int  # index into ``directions`` of the heading
    step_count: int  # blocks travelled consecutively along ``direction``
@cache
def enhanced_heuristic(position, goal):
    """Estimate the remaining cost from ``position`` for the A* priority.

    Considers only the rectangle of ``grid`` from ``position`` (inclusive) to
    the bottom-right corner. Sums the cheapest cell of every column in that
    rectangle, then the cheapest cell of every row. Returns the smaller of the
    two sums.

    ``goal`` does not take part in the computation (it only becomes part of the
    cache key). The estimate implicitly assumes the goal is the bottom-right
    corner of ``grid``.
    """
    col, row = position
    remaining = grid.iloc[row:, col:]
    cheapest_per_column = remaining.min(axis=0).sum()
    cheapest_per_row = remaining.min(axis=1).sum()
    return min(cheapest_per_column, cheapest_per_row)
def line(pos0, pos1):
    """Yield the cells entered when moving in a straight line from pos0 to pos1.

    ``pos0`` itself is excluded and ``pos1`` is included, in travel order, so
    summing grid values over the result gives the cost of the move. Equal
    endpoints yield nothing.

    Raises:
        ValueError: on first iteration, if the endpoints share neither x nor y.
            Silently yielding nothing there would price a non-straight move at
            zero cost.
    """
    x0, y0 = pos0
    x1, y1 = pos1
    if x0 == x1:  # Vertical line
        step = 1 if y1 > y0 else -1
        for y in range(y0 + step, y1 + step, step):
            yield (x0, y)
    elif y0 == y1:  # Horizontal line
        step = 1 if x1 > x0 else -1
        for x in range(x0 + step, x1 + step, step):
            yield (x, y0)
    else:
        raise ValueError(f"line() needs axis-aligned endpoints, got {pos0} -> {pos1}")
def get_neighbors_with_restriction(node_state):
    """Yield part-1 successors of ``node_state``, one block away.

    A move may not reverse the current heading, may not continue straight once
    three consecutive blocks have been travelled that way, and must stay
    inside ``grid``. Continuing straight increments ``step_count``; turning
    resets it to 1.
    """
    heading = node_state.direction
    back_dx, back_dy = directions[heading]
    height, width = grid.shape
    cx, cy = node_state.position
    for idx, (dx, dy) in enumerate(directions):
        going_straight = idx == heading
        if going_straight and node_state.step_count >= 3:
            continue  # straight-run limit reached; must turn
        if (dx, dy) == (-back_dx, -back_dy):
            continue  # no U-turns
        nx, ny = cx + dx, cy + dy
        if not (0 <= nx < width and 0 <= ny < height):
            continue
        yield NodeState(
            position=(nx, ny),
            direction=idx,
            step_count=node_state.step_count + 1 if going_straight else 1,
        )
# In[2]:
def astar(get_neighbors_with_restriction=get_neighbors_with_restriction):
    """Run A* from the top-left cell of ``grid`` to its bottom-right cell.

    Args:
        get_neighbors_with_restriction: callable taking a ``NodeState`` and
            yielding its legal successor ``NodeState``s. Each successor must lie
            on a straight line from the current position. The cost of a move is
            the sum of the ``grid`` values of every cell entered along that line.

    Returns:
        ``(cost, came_from, current)``: the accumulated cost of the first goal
        state popped from the queue, the predecessor map (start states map to
        ``None``), and that goal state.

    Raises:
        ValueError: if the queue is exhausted without ever reaching the goal.
    """
    # Positional numpy access is far cheaper than grid.iloc inside the hot loop.
    values = grid.to_numpy()
    start = (0, 0)
    goal = (grid.shape[1] - 1, grid.shape[0] - 1)
    # Start facing right (1) or down (0). step_count is 0 because no block has
    # been travelled yet. Using 1 here would silently shorten the first
    # straight run by one block.
    start_states = [
        NodeState(position=start, direction=1, step_count=0),
        NodeState(position=start, direction=0, step_count=0),
    ]
    frontier = PriorityQueue()
    for state in start_states:
        frontier.put((0, state))
    cost_so_far = {state: 0 for state in start_states}
    came_from = {state: None for state in start_states}
    while not frontier.empty():
        priority, current = frontier.get()
        current_cost = cost_so_far[current]
        if current.position == goal:
            return current_cost, came_from, current
        # A cheaper route to `current` was recorded after this entry was
        # queued. The fresher entry has a lower priority and is (or was)
        # expanded with that cost, so this one is redundant.
        if priority > current_cost + enhanced_heuristic(current.position, goal):
            continue
        for next_state in get_neighbors_with_restriction(current):
            step_cost = sum(
                values[y, x] for x, y in line(current.position, next_state.position)
            )
            new_cost = current_cost + step_cost
            if next_state not in cost_so_far or new_cost < cost_so_far[next_state]:
                cost_so_far[next_state] = new_cost
                next_priority = new_cost + enhanced_heuristic(next_state.position, goal)
                frontier.put((next_priority, next_state))
                came_from[next_state] = current
    raise ValueError(f"goal {goal} is unreachable under the given movement rules")
# Part 1: search with the default one-block, max-three-straight neighbour rule.
c, came_from, current = astar(get_neighbors_with_restriction)
c  # shown as the cell's output when run in a notebook
# In[3]:
def part2(node_state, min_steps=4, max_steps=10):
    """Yield part-2 successors: straight runs of ``min_steps``..``max_steps`` blocks.

    Movement rules:
    - Any move that starts a new straight run (a 90-degree turn, or the first
      move out of a start state) jumps ``min_steps`` blocks at once. It yields
      a state with ``step_count == min_steps``; ``astar`` charges every cell
      along the jump via ``line``.
    - Continuing straight advances one block at a time until ``step_count``
      reaches ``max_steps``.
    - Reversing is never allowed.

    As a result, every yielded state has already travelled at least
    ``min_steps`` blocks in its direction. ``astar`` can therefore never stop on
    the goal part-way through a too-short run. A state whose ``step_count`` is
    below ``min_steps`` is treated as a fresh start (only ``astar``'s start
    states have that property).

    Args:
        node_state: the ``NodeState`` being expanded.
        min_steps: blocks a run must cover before it may turn or end.
        max_steps: longest allowed straight run.

    Yields:
        Successor ``NodeState``s that lie inside ``grid``.
    """
    height, width = grid.shape
    x0, y0 = node_state.position
    heading = node_state.direction
    back_dx, back_dy = directions[heading]
    run_started = node_state.step_count >= min_steps
    for idx, (dx, dy) in enumerate(directions):
        if (dx, dy) == (-back_dx, -back_dy):
            continue  # no reversing
        if idx == heading and run_started:
            if node_state.step_count >= max_steps:
                continue  # run is at its maximum length; must turn
            distance, new_step_count = 1, node_state.step_count + 1
        else:
            # New run: commit to the minimum run length immediately, so no
            # state with a too-short run (including at the goal) is ever produced.
            distance, new_step_count = min_steps, min_steps
        x, y = x0 + distance * dx, y0 + distance * dy
        if 0 <= x < width and 0 <= y < height:
            yield NodeState(position=(x, y), direction=idx, step_count=new_step_count)
# Part 2: same search, with successors generated by the `part2` movement rule.
c, came_from, current = astar(get_neighbors_with_restriction=part2)
c  # shown as the cell's output when run in a notebook
# In[4]: | |
# path = [current] | |
# while (n := came_from[path[-1]]) is not None: | |
# path.append(n) | |
# path = list(reversed(path)) | |
# g = grid.copy().astype(str) | |
# for f, t in zip(path, path[1:]): | |
# for x, y in line(f.position, t.position): | |
# g.iloc[y, x] = '#' | |
# g |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment