Skip to content

Instantly share code, notes, and snippets.

@smourier
Created May 15, 2023 13:31
Show Gist options
  • Save smourier/40495257014a6b68ada7b814a066f4a8 to your computer and use it in GitHub Desktop.
Save smourier/40495257014a6b68ada7b814a066f4a8 to your computer and use it in GitHub Desktop.
A* implementation in C# (.NET 6+)
/// <summary>
/// Generic A* shortest-path search over any node type that can enumerate its own neighbours.
/// </summary>
public static class AStar
{
    /// <summary>
    /// Implemented by node types so that <see cref="FindPath{T}"/> can discover adjacent nodes.
    /// </summary>
    /// <typeparam name="T">The node type (normally the implementing type itself).</typeparam>
    public interface IHasNeighbours<T>
    {
        /// <summary>Gets the nodes directly reachable from this node.</summary>
        IEnumerable<T> Neighbours { get; }
    }

    /// <summary>
    /// Finds a path from <paramref name="start"/> to <paramref name="destination"/> using A*,
    /// expanding the candidate path with the lowest <c>cost so far + estimate</c> first.
    /// </summary>
    /// <typeparam name="T">Node type; nodes are tracked in a <see cref="HashSet{T}"/>, so its equality/hash must be consistent.</typeparam>
    /// <param name="start">Node the path starts from.</param>
    /// <param name="destination">Target node, matched with <see cref="object.Equals(object)"/>.</param>
    /// <param name="distance">Cost of stepping from a node to one of its neighbours.</param>
    /// <param name="estimate">
    /// Heuristic estimate of the remaining cost from a node to <paramref name="destination"/>.
    /// When omitted, <c>distance(node, destination)</c> is used. That is only sensible when
    /// <paramref name="distance"/> is meaningful for any pair of nodes (e.g. a grid metric), not only for neighbours.
    /// For the returned path to be a cheapest one, the estimate must never overestimate the true remaining cost.
    /// </param>
    /// <returns>The nodes from <paramref name="start"/> to <paramref name="destination"/> inclusive, or an empty sequence when no path exists.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="start"/> or <paramref name="distance"/> is null.</exception>
    public static IEnumerable<T> FindPath<T>(T start, T destination, Func<T, T, double> distance, Func<T, double>? estimate = null) where T : IHasNeighbours<T>
    {
        ArgumentNullException.ThrowIfNull(start);
        ArgumentNullException.ThrowIfNull(distance);

        estimate ??= node => distance(node, destination);

        var closed = new HashSet<T>();
        var frontier = new PriorityQueue<Path<T>, double>();
        frontier.Enqueue(new Path<T>(start), 0);

        while (frontier.TryDequeue(out var path, out _))
        {
            var current = path.LastStep;

            // A node can be queued several times through different paths; only its first dequeue expands it.
            // HashSet.Add returns false when the node was already present, avoiding a separate Contains lookup.
            if (!closed.Add(current))
                continue;

            if (current.Equals(destination))
                return path.Reverse(); // Path enumerates last step first; callers want start first.

            foreach (var neighbour in current.Neighbours)
            {
                // A closed neighbour would be discarded on dequeue anyway; skipping it keeps the queue small.
                if (closed.Contains(neighbour))
                    continue;

                var next = path.AddStep(neighbour, distance(current, neighbour));
                frontier.Enqueue(next, next.TotalCost + estimate(neighbour));
            }
        }

        return Enumerable.Empty<T>();
    }

    /// <summary>
    /// Immutable path stored as a linked list that shares its prefix with the path it was extended from.
    /// Enumeration yields nodes from the last step back to the first.
    /// </summary>
    private sealed class Path<TNode> : IEnumerable<TNode>
    {
        /// <summary>Creates a single-node, zero-cost path.</summary>
        public Path(TNode start)
            : this(start, null, 0)
        {
        }

        private Path(TNode lastStep, Path<TNode>? previousSteps, double totalCost)
        {
            LastStep = lastStep;
            PreviousSteps = previousSteps;
            TotalCost = totalCost;
        }

        /// <summary>Gets the most recently added node.</summary>
        public TNode LastStep { get; }

        /// <summary>Gets the path without its last step, or null for a single-node path.</summary>
        public Path<TNode>? PreviousSteps { get; }

        /// <summary>Gets the accumulated step cost of the whole path.</summary>
        public double TotalCost { get; }

        /// <summary>Returns a new path extended by <paramref name="step"/>; this instance is not modified.</summary>
        public Path<TNode> AddStep(TNode step, double stepCost) => new(step, this, TotalCost + stepCost);

        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

        /// <summary>Yields the nodes from the last step back to the start.</summary>
        public IEnumerator<TNode> GetEnumerator()
        {
            for (var p = this; p != null; p = p.PreviousSteps)
            {
                yield return p.LastStep;
            }
        }
    }
}
@smourier
Copy link
Author

Sample usage (random grid maze; Manhattan distance serves as both the step cost and the heuristic):

static AStarNode[,] _grid;
const int _size = 512;

/// <summary>
/// Fills <c>_grid</c> with a random _size x _size maze and prints the A* path from (0, 0)
/// to (_size - 2, _size - 2), or "No path found." when the maze happens to be unsolvable.
/// </summary>
/// <remarks>
/// Cells with an odd Y coordinate are walls except where <c>Next(0, 10)</c> returns 8 (about one in ten),
/// which leaves gaps to pass through; cells with an even Y coordinate are always open.
/// </remarks>
static void TestAStar()
{
    // Random.Shared (.NET 6+) reuses one generator instead of allocating and seeding a new one per call.
    var rnd = Random.Shared;
    _grid = new AStarNode[_size, _size];
    for (var x = 0; x < _size; x++)
    {
        for (var y = 0; y < _size; y++)
        {
            var isWall = (y % 2 != 0) && (rnd.Next(0, 10) != 8);
            _grid[x, y] = new AStarNode(x, y, isWall);
        }
    }

    // An odd-Y line may, rarely, get no gap at all; report that instead of printing nothing.
    var found = false;
    foreach (var node in AStar.FindPath(_grid[0, 0], _grid[_size - 2, _size - 2], AStarNode.GetDistance))
    {
        Console.WriteLine(node);
        found = true;
    }

    if (!found)
    {
        Console.WriteLine("No path found.");
    }
}


/// <summary>
/// A cell of the sample maze held in the static <c>_grid</c> array. Its neighbours are the
/// open (non-wall) cells orthogonally adjacent to it.
/// </summary>
public class AStarNode : AStar.IHasNeighbours<AStarNode>
{
    /// <summary>Creates the cell at (<paramref name="x"/>, <paramref name="y"/>).</summary>
    /// <param name="x">First index of the cell in <c>_grid</c>.</param>
    /// <param name="y">Second index of the cell in <c>_grid</c>.</param>
    /// <param name="isWall">Whether the cell is impassable.</param>
    public AStarNode(int x, int y, bool isWall)
    {
        X = x;
        Y = y;
        IsWall = isWall;
    }

    public int X { get; }
    public int Y { get; }
    public bool IsWall { get; }

    public override string ToString() => $"{X},{Y} {IsWall}";

    /// <summary>
    /// Gets the non-wall cells at (X-1, Y), (X, Y-1), (X+1, Y) and (X, Y+1), in that order, that lie inside the grid.
    /// </summary>
    /// <remarks>
    /// Bounds come from the actual dimensions of <c>_grid</c> rather than the <c>_size</c> constant,
    /// so the lookup stays correct even if the grid is built with another size or is not square.
    /// </remarks>
    public IEnumerable<AStarNode> Neighbours
    {
        get
        {
            var width = _grid.GetLength(0);
            var height = _grid.GetLength(1);

            if (X > 0 && !_grid[X - 1, Y].IsWall)
                yield return _grid[X - 1, Y];

            if (Y > 0 && !_grid[X, Y - 1].IsWall)
                yield return _grid[X, Y - 1];

            if (X < width - 1 && !_grid[X + 1, Y].IsWall)
                yield return _grid[X + 1, Y];

            if (Y < height - 1 && !_grid[X, Y + 1].IsWall)
                yield return _grid[X, Y + 1];
        }
    }

    /// <summary>
    /// Manhattan distance between two cells. It is 1 for orthogonal neighbours, so it works both as the
    /// step cost and as a heuristic that never overestimates on this 4-connected grid.
    /// </summary>
    public static double GetDistance(AStarNode from, AStarNode to) => Math.Abs(from.X - to.X) + Math.Abs(from.Y - to.Y);
}

@smourier
Copy link
Author

Another sample usage (an undirected weighted graph with no coordinates):

// Create the 14 graph vertices, indexed 0..13.
var edges = new List<Edge>();
for (var i = 0; i < 14; i++)
{
    edges.Add(new Edge(i));
}

// Connects two vertices in both directions with the same cost, i.e. an undirected edge.
void AddUndirectedEdge(int from, int to, double cost)
{
    edges[from].AddEdge(edges[to], cost);
    edges[to].AddEdge(edges[from], cost);
}

// http://graphonline.ru/en/?graph=xiloEvjFkRNQbQKv
AddUndirectedEdge(0, 1, 3);
AddUndirectedEdge(0, 2, 6);
AddUndirectedEdge(0, 3, 5);
AddUndirectedEdge(1, 4, 9);
AddUndirectedEdge(1, 5, 8);
AddUndirectedEdge(2, 6, 12);
AddUndirectedEdge(2, 7, 14);
AddUndirectedEdge(3, 8, 7);
AddUndirectedEdge(8, 9, 5);
AddUndirectedEdge(8, 10, 6);
AddUndirectedEdge(9, 11, 1);
AddUndirectedEdge(9, 12, 10);
AddUndirectedEdge(9, 13, 2);

// Edge.GetDistance only knows costs between adjacent vertices (it returns int.MinValue otherwise),
// so it cannot double as the A* heuristic, and a plain graph has no geometry to estimate from.
// A zero heuristic turns A* into Dijkstra's algorithm, which yields a cheapest path for non-negative costs.
foreach (var n in AStar.FindPath(edges[0], edges[9], Edge.GetDistance, estimate: _ => 0))
{
    Console.WriteLine(n);
}

/// <summary>
/// A vertex of the sample weighted graph. Despite its name, each instance is a node identified by
/// <see cref="Index"/>; its outgoing connections and their costs are kept in <see cref="EdgesCosts"/>.
/// </summary>
public class Edge : AStar.IHasNeighbours<Edge>
{
    public Edge(int index)
    {
        EdgesCosts = new List<Tuple<Edge, double>>();
        Index = index;
    }

    /// <summary>Gets the identifier of this vertex.</summary>
    public int Index { get; }

    /// <summary>Gets the outgoing connections as (target, cost) pairs, in insertion order.</summary>
    public List<Tuple<Edge, double>> EdgesCosts { get; }

    public override string ToString() => Index.ToString();

    /// <summary>Appends a one-way connection to <paramref name="to"/> costing <paramref name="cost"/>.</summary>
    public void AddEdge(Edge to, double cost) => EdgesCosts.Add(Tuple.Create(to, cost));

    /// <summary>Gets the targets of the outgoing connections, in insertion order.</summary>
    public IEnumerable<Edge> Neighbours
    {
        get
        {
            foreach (var link in EdgesCosts)
            {
                yield return link.Item1;
            }
        }
    }

    /// <summary>
    /// Returns the cost of the first connection from <paramref name="from"/> to <paramref name="to"/>
    /// (matched by reference), or <see cref="int.MinValue"/> converted to double when there is none.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this is a step-cost function for adjacent vertices. If it is also used as the A*
    /// heuristic (FindPath's default estimate is distance(node, destination)), the int.MinValue sentinel
    /// produces an inconsistent estimate for non-adjacent vertices; prefer passing an explicit estimate.
    /// </remarks>
    public static double GetDistance(Edge from, Edge to)
    {
        foreach (var link in from.EdgesCosts)
        {
            if (link.Item1 == to)
            {
                return link.Item2;
            }
        }

        return int.MinValue;
    }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment