Skip to content

Instantly share code, notes, and snippets.

@icub3d
Created January 4, 2024 00:56
Show Gist options
  • Save icub3d/7f4caa23be23a0c0ec33770125eb704c to your computer and use it in GitHub Desktop.
Save icub3d/7f4caa23be23a0c0ec33770125eb704c to your computer and use it in GitHub Desktop.
A* // LeetCode - Open the Lock
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;
/// A node in the search frontier paired with the cost of the path that reached it.
///
/// The `Ord` implementation compares only `cost`, and in reverse, so that
/// `std::collections::BinaryHeap` (a max-heap) pops the cheapest state first.
/// Because `node` is ignored, two states with equal cost compare `Equal` even
/// when they are not `==`; that is harmless for use as a priority-queue entry.
#[derive(Clone, Debug, Eq, PartialEq)]
struct State<N>
where
    N: Clone + Eq + PartialEq,
{
    node: N,
    cost: usize,
}

impl<N> State<N>
where
    N: Clone + Eq + PartialEq,
{
    fn new(node: N, cost: usize) -> Self {
        Self { node, cost }
    }
}

impl<N> Ord for State<N>
where
    N: Clone + Eq + PartialEq,
{
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Reversed so the max-heap behaves as a min-heap on `cost`.
        other.cost.cmp(&self.cost)
    }
}

impl<N> PartialOrd for State<N>
where
    N: Clone + Eq + PartialEq,
{
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Finds the cheapest path cost from `start` to the first node accepted by
/// `complete_fn`, using Dijkstra's algorithm over a binary min-heap.
///
/// * `neighbor_fn` returns the outgoing edges of a node as `(neighbor, edge_cost)` pairs.
/// * `complete_fn` returns `true` for a goal node.
///
/// Returns `Some((cost, visited))`, where `cost` is the total cost of the path
/// to the first goal popped and `visited` is the number of nodes expanded
/// (stale heap entries are not counted). Returns `None` if the frontier is
/// exhausted without reaching a goal.
fn dijkstra<N, F, C>(start: &N, neighbor_fn: F, complete_fn: C) -> Option<(usize, usize)>
where
    N: Clone + Eq + PartialEq + Hash,
    F: Fn(&N) -> Vec<(N, usize)>,
    C: Fn(&N) -> bool,
{
    let mut frontier = BinaryHeap::new();
    frontier.push(State::new(start.clone(), 0));
    // Best known cost from `start` to every node discovered so far.
    let mut distances = HashMap::new();
    distances.insert(start.clone(), 0);
    let mut counter = 0;
    while let Some(State { node, cost }) = frontier.pop() {
        // A node is re-queued whenever a cheaper path to it is found, so the heap
        // can still hold older, more expensive entries. Those are stale: expanding
        // them again can never improve anything, so skip them.
        if let Some(&best_cost) = distances.get(&node) {
            if cost > best_cost {
                continue;
            }
        }
        counter += 1;
        if complete_fn(&node) {
            return Some((cost, counter));
        }
        for (neighbor, neighbor_cost) in neighbor_fn(&node) {
            let new_cost = cost + neighbor_cost;
            if let Some(&best_cost) = distances.get(&neighbor) {
                if best_cost <= new_cost {
                    continue;
                }
            }
            distances.insert(neighbor.clone(), new_cost);
            frontier.push(State::new(neighbor, new_cost));
        }
    }
    None
}
#[derive(Clone, Debug, Eq, PartialEq)]
struct StateWithHeuristic<N>
where
N: Clone + Eq + PartialEq,
{
heuristic: usize,
state: State<N>,
}
impl<N> StateWithHeuristic<N>
where
N: Clone + Eq + PartialEq,
{
fn new(node: N, cost: usize, heuristic: usize) -> Self {
Self {
heuristic,
state: State::new(node, cost),
}
}
}
impl<N> Ord for StateWithHeuristic<N>
where
N: Clone + Eq + PartialEq,
{
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// We want to sort by the total cost of the state, which is
// the distance from start and the heuristic which is a quick
// estimate of the distance to the goal.
(other.state.cost + other.heuristic).cmp(&(self.state.cost + self.heuristic))
}
}
impl<N> PartialOrd for StateWithHeuristic<N>
where
N: Clone + Eq + PartialEq,
{
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
fn astar<N, F, C, H>(
start: &N,
neighbor_fn: F,
complete_fn: C,
heuristic_fn: H,
) -> Option<(usize, usize)>
where
N: Clone + Eq + PartialEq + Hash,
F: Fn(&N) -> Vec<(N, usize)>,
C: Fn(&N) -> bool,
H: Fn(&N) -> usize,
{
// This will look a lot like Dijkstra's algorithm, with one
// notable change, we now use a heuristic function to estimate the
// distance to the goal and include that when comparing states for
// priority.
let mut queue = BinaryHeap::new();
queue.push(StateWithHeuristic::new(
start.clone(),
0,
heuristic_fn(start),
));
let mut distances = HashMap::new();
distances.insert(start.clone(), 0);
let mut counter = 0;
while let Some(StateWithHeuristic {
state: State { node, cost },
..
}) = queue.pop()
{
counter += 1;
if complete_fn(&node) {
return Some((cost, counter));
}
for (neighbor, neighbor_cost) in neighbor_fn(&node) {
let new_cost = cost + neighbor_cost;
if let Some(&best_cost) = distances.get(&neighbor) {
if best_cost <= new_cost {
continue;
}
}
// Add to our queue but now we include the heuristic.
queue.push(StateWithHeuristic::new(
neighbor.clone(),
new_cost,
heuristic_fn(&neighbor),
));
distances.insert(neighbor.clone(), new_cost);
}
}
None
}
/// A position on the 2-D map, used to estimate distances between graph nodes.
#[derive(Clone, Debug, Eq, PartialEq)]
struct Point {
    x: usize,
    y: usize,
}

impl Point {
    fn new(x: usize, y: usize) -> Self {
        Self { x, y }
    }

    /// Manhattan (taxicab) distance to `other`: `|dx| + |dy|`.
    ///
    /// `usize::abs_diff` is used instead of casting to `isize` and subtracting,
    /// which wraps or overflows for coordinates above `isize::MAX`.
    fn distance(&self, other: &Self) -> usize {
        self.x.abs_diff(other.x) + self.y.abs_diff(other.y)
    }
}
fn main() {
// This is used to track relatives positions on the map.
let grid = [
('A', Point::new(0, 0)),
('J', Point::new(50, 0)),
('S', Point::new(0, 5)),
('B', Point::new(0, 8)),
('C', Point::new(2, 8)),
('D', Point::new(2, 13)),
('L', Point::new(7, 8)),
('N', Point::new(12, 8)),
('O', Point::new(17, 8)),
('R', Point::new(22, 8)),
('K', Point::new(7, 3)),
('M', Point::new(12, 3)),
('P', Point::new(17, 3)),
('Q', Point::new(22, 3)),
('F', Point::new(22, 13)),
('G', Point::new(22, 43)),
('H', Point::new(27, 43)),
('E', Point::new(37, 43)),
]
.iter()
.cloned()
.collect::<HashMap<_, _>>();
// This is used to track the neighbors of each node the cost of
// the edge (time between them in our case).
let neighbors: HashMap<char, Vec<(char, usize)>> = vec![
('S', 'A', 5),
('S', 'B', 3),
('A', 'J', 30),
('B', 'C', 2),
('C', 'D', 5),
('C', 'L', 5),
('L', 'K', 5),
('L', 'N', 5),
('K', 'M', 5),
('N', 'O', 5),
('N', 'M', 5),
('M', 'P', 5),
('O', 'P', 5),
('O', 'R', 5),
('P', 'Q', 5),
('Q', 'R', 5),
('R', 'F', 15),
('D', 'F', 20),
('F', 'G', 10),
('G', 'H', 5),
('J', 'H', 50),
('H', 'E', 10),
]
.into_iter()
.fold(HashMap::new(), |mut acc, (a, b, c)| {
acc.entry(a).or_default().push((b, c));
acc
});
let start = 'S';
let end = 'E';
// Get the neighbors of a node.
let neighbors_fn = |node: &char| {
neighbors
.get(node)
.unwrap_or(&vec![])
.iter()
.map(|(n, s)| (*n, *s))
.collect::<Vec<_>>()
};
// Test if we have reached the end.
let complete_fn = |node: &char| *node == end;
// Calculate the estimated distance to the end.
let heuristic_fn = |node: &char| grid.get(&end).unwrap().distance(grid.get(node).unwrap());
// Run Dijkstra's algorithm and A-star.
let (cost, counter) = dijkstra(&start, neighbors_fn, complete_fn).unwrap();
println!("Dijkstra: {} cost, {} nodes visited", cost, counter);
// Run A-star.
let (cost, counter) = astar(&start, neighbors_fn, complete_fn, heuristic_fn).unwrap();
println!("A-star: {} cost, {} nodes visited", cost, counter);
}
/// Holder type for the "Open the Lock" solution methods.
struct Solution;

/// A lock combination (one digit per wheel) and the number of moves (`cost`)
/// used to reach it.
///
/// `Ord` compares only `cost`, reversed, so `std::collections::BinaryHeap`
/// pops the cheapest entry first.
#[derive(Clone, Debug, Eq, PartialEq)]
struct State {
    state: Vec<i32>,
    cost: i32,
}

impl State {
    fn new(state: Vec<i32>, cost: i32) -> Self {
        Self { state, cost }
    }
}

impl Ord for State {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Reversed for a min-heap.
        other.cost.cmp(&self.cost)
    }
}

impl PartialOrd for State {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// A `State` plus an estimate (`heuristic`) of the moves still needed to reach
/// the target.
#[derive(Clone, Debug, Eq, PartialEq)]
struct StateWithHeuristic {
    state: State,
    heuristic: i32,
}

impl StateWithHeuristic {
    fn new(state: Vec<i32>, cost: i32, heuristic: i32) -> Self {
        Self {
            state: State::new(state, cost),
            heuristic,
        }
    }
}

impl Ord for StateWithHeuristic {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // A-star orders by f = g + h: moves already made plus the estimate of
        // moves remaining. Ordering by `heuristic` alone would make this a greedy
        // best-first search, which can return a non-minimal move count.
        // Reversed for a min-heap.
        (other.state.cost + other.heuristic).cmp(&(self.state.cost + self.heuristic))
    }
}

impl PartialOrd for StateWithHeuristic {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Solution {
    /// Parses a combination such as `"0202"` into one digit per wheel.
    ///
    /// # Panics
    /// Panics if `s` contains a character that is not a decimal digit.
    fn parse_combination(s: &str) -> Vec<i32> {
        s.chars()
            .map(|c| c.to_digit(10).expect("combination must contain only decimal digits") as i32)
            .collect()
    }

    /// Parses the dead-end list into a set for O(1) membership checks.
    fn parse_deadends(deadends: &[String]) -> std::collections::HashSet<Vec<i32>> {
        deadends
            .iter()
            .map(|deadend| Solution::parse_combination(deadend))
            .collect()
    }

    /// Every combination reachable in one move (one wheel turned one notch
    /// forward or backward), each paired with its move cost of 1.
    fn neighbors(state: &[i32]) -> Vec<(Vec<i32>, i32)> {
        let mut neighbors = Vec::with_capacity(state.len() * 2);
        for i in 0..state.len() {
            // Roll the wheel one position forward (9 wraps to 0).
            let mut neighbor = state.to_vec();
            neighbor[i] = (neighbor[i] + 1) % 10;
            neighbors.push((neighbor, 1));
            // Roll the wheel one position backward (0 wraps to 9).
            let mut neighbor = state.to_vec();
            neighbor[i] = (neighbor[i] + 9) % 10;
            neighbors.push((neighbor, 1));
        }
        neighbors
    }

    /// Lower bound on the moves from `state` to `target`: for each wheel, the
    /// shorter of turning forward or backward (wrapping between 0 and 9).
    fn heuristic(state: &[i32], target: &[i32]) -> i32 {
        state
            .iter()
            .zip(target.iter())
            .map(|(a, b)| {
                let diff = (a - b).abs();
                std::cmp::min(diff, 10 - diff)
            })
            .sum()
    }

    /// Minimum number of moves from all-zeros to `target` without ever landing
    /// on a combination in `deadends`, using A-star search. Returns -1 if the
    /// target is unreachable (including when the start or target is a dead end).
    /// Prints the number of expanded combinations when the target is found.
    pub fn open_lock(deadends: Vec<String>, target: String) -> i32 {
        let target = Solution::parse_combination(&target);
        // Every wheel starts at 0; the number of wheels follows the target.
        let start = vec![0; target.len()];
        let deadends = Solution::parse_deadends(&deadends);
        // The lock is stuck immediately if the starting position is a dead end.
        if deadends.contains(&start) {
            return -1;
        }
        let mut queue = std::collections::BinaryHeap::new();
        queue.push(StateWithHeuristic::new(
            start.clone(),
            0,
            Solution::heuristic(&start, &target),
        ));
        // Fewest known moves to each discovered combination.
        let mut distances = std::collections::HashMap::new();
        distances.insert(start, 0);
        let mut counter = 0;
        while let Some(StateWithHeuristic {
            state: State { state, cost },
            ..
        }) = queue.pop()
        {
            // Skip entries superseded by a cheaper path found after they were queued.
            if let Some(&best_cost) = distances.get(&state) {
                if cost > best_cost {
                    continue;
                }
            }
            counter += 1;
            if state == target {
                println!("counter: {}", counter);
                return cost;
            }
            for (neighbor, neighbor_cost) in Solution::neighbors(&state) {
                // Dead ends can never be entered, so they are never queued
                // (this also makes a dead-end target unreachable).
                if deadends.contains(&neighbor) {
                    continue;
                }
                let new_cost = cost + neighbor_cost;
                if let Some(&best_cost) = distances.get(&neighbor) {
                    if best_cost <= new_cost {
                        continue;
                    }
                }
                let estimate = Solution::heuristic(&neighbor, &target);
                distances.insert(neighbor.clone(), new_cost);
                queue.push(StateWithHeuristic::new(neighbor, new_cost, estimate));
            }
        }
        -1
    }

    /// Same contract as `open_lock`, but uses plain Dijkstra's algorithm
    /// (no heuristic) so the number of expanded combinations can be compared.
    pub fn open_lock_dijktra(deadends: Vec<String>, target: String) -> i32 {
        let target = Solution::parse_combination(&target);
        let start = vec![0; target.len()];
        let deadends = Solution::parse_deadends(&deadends);
        if deadends.contains(&start) {
            return -1;
        }
        let mut queue = std::collections::BinaryHeap::new();
        queue.push(State::new(start.clone(), 0));
        let mut distances = std::collections::HashMap::new();
        distances.insert(start, 0);
        let mut counter = 0;
        while let Some(State { state, cost }) = queue.pop() {
            if let Some(&best_cost) = distances.get(&state) {
                if cost > best_cost {
                    continue;
                }
            }
            counter += 1;
            if state == target {
                println!("counter: {}", counter);
                return cost;
            }
            for (neighbor, neighbor_cost) in Solution::neighbors(&state) {
                if deadends.contains(&neighbor) {
                    continue;
                }
                let new_cost = cost + neighbor_cost;
                if let Some(&best_cost) = distances.get(&neighbor) {
                    if best_cost <= new_cost {
                        continue;
                    }
                }
                // Plain Dijkstra: priority is the move count only.
                distances.insert(neighbor.clone(), new_cost);
                queue.push(State::new(neighbor, new_cost));
            }
        }
        -1
    }
}
fn main() {
    // Sample input for the Open the Lock problem, shared by both solvers.
    let deadends: Vec<String> = ["0201", "0101", "0102", "1212", "2002"]
        .iter()
        .map(|code| code.to_string())
        .collect();
    let target = String::from("0202");

    // Each solver prints its own expansion counter before its result line.
    let a_star_moves = Solution::open_lock(deadends.clone(), target.clone());
    println!("a-star: {}", a_star_moves);
    println!();
    let dijkstra_moves = Solution::open_lock_dijktra(deadends, target);
    println!("dijkstra: {}", dijkstra_moves);
}
@icub3d
Copy link
Author

icub3d commented Jan 4, 2024

example-graph

Here is the example graph shown in the video and used in graph.rs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment