Skip to content

Instantly share code, notes, and snippets.

@trvswgnr
Last active August 27, 2023 16:26
Show Gist options
  • Save trvswgnr/4d3222272169a25bcc8cc7ad1b1c6246 to your computer and use it in GitHub Desktop.
Save trvswgnr/4d3222272169a25bcc8cc7ad1b1c6246 to your computer and use it in GitHub Desktop.
Prim's minimum-spanning-tree algorithm in Rust and TypeScript
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// Sentinel weight meaning "no edge" in the adjacency matrix.
///
/// `Graph::new` fills every cell with this value and `Graph::prim` skips cells equal
/// to it, so a real edge of weight `u32::MAX` cannot be represented.
const INF: u32 = u32::MAX;
/// A candidate entry in Prim's priority queue: reaching `vertex` from the
/// partially built tree costs `weight`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
struct Edge {
    weight: u32,
    vertex: usize,
}

/// Orders edges so that `BinaryHeap` (a max-heap) pops the *lightest* edge first.
impl Ord for Edge {
    fn cmp(&self, other: &Self) -> Ordering {
        // Compare (weight, vertex) with the operands swapped: reversing turns the
        // max-heap into a min-heap (lightest edge, then lowest vertex, pops first).
        // Including `vertex` makes `cmp == Equal` agree with the derived `Eq`, as the
        // `Ord` contract requires; comparing weight alone would not.
        (other.weight, other.vertex).cmp(&(self.weight, self.vertex))
    }
}

impl PartialOrd for Edge {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// An undirected weighted graph stored as a dense adjacency matrix.
///
/// `adj_matrix[u][v]` holds the weight of the edge between `u` and `v`, or `INF`
/// (`u32::MAX`) when there is no such edge.
#[derive(Debug)]
pub struct Graph {
    adj_matrix: Vec<Vec<u32>>,
}

impl Graph {
    /// Creates a graph with vertices `0..n` and no edges.
    pub fn new(n: usize) -> Self {
        Self {
            adj_matrix: vec![vec![INF; n]; n],
        }
    }

    /// Sets the weight of the undirected edge `u`–`v` to `w`, replacing any previous
    /// weight. Passing `w == u32::MAX` (the "no edge" sentinel) removes the edge.
    ///
    /// # Panics
    ///
    /// Panics if `u` or `v` is not less than the number of vertices.
    pub fn add_edge(&mut self, u: usize, v: usize, w: u32) {
        let n = self.adj_matrix.len();
        assert!(
            u < n && v < n,
            "edge ({u}, {v}) out of range for graph with {n} vertices"
        );
        self.adj_matrix[u][v] = w;
        self.adj_matrix[v][u] = w;
    }

    /// Returns the total weight of a minimum spanning tree, computed with a lazy
    /// (heap-based) variant of Prim's algorithm rooted at vertex 0.
    ///
    /// Returns `None` if the graph is disconnected (some vertex is unreachable from
    /// vertex 0) or if the total weight does not fit in a `u32`. A graph with no
    /// vertices has a spanning tree of weight 0.
    pub fn prim(&self) -> Option<u32> {
        let n = self.adj_matrix.len();
        if n == 0 {
            // Without this guard, `visited[0]` below would index an empty Vec.
            return Some(0);
        }

        let mut visited = vec![false; n];
        let mut total_weight: u32 = 0;
        let mut min_heap = BinaryHeap::new();
        min_heap.push(Edge { weight: 0, vertex: 0 });

        while let Some(Edge { weight, vertex }) = min_heap.pop() {
            // Stale entry: the vertex already joined the tree via a lighter edge.
            if visited[vertex] {
                continue;
            }
            visited[vertex] = true;
            // A wrapped sum would silently report a wrong weight in release builds.
            total_weight = total_weight.checked_add(weight)?;

            for (neighbour, &edge_weight) in self.adj_matrix[vertex].iter().enumerate() {
                if !visited[neighbour] && edge_weight != INF {
                    min_heap.push(Edge {
                        weight: edge_weight,
                        vertex: neighbour,
                    });
                }
            }
        }

        visited.iter().all(|&v| v).then_some(total_weight)
    }
}
/**
 * A binary heap ordered by a user-supplied comparator.
 *
 * `compare(a, b) > 0` means `a` has higher priority than `b`; `pop()` always
 * returns the highest-priority element. For a min-heap on numbers, pass
 * `(a, b) => b - a`.
 */
class BinaryHeap<T> {
  private heap: T[] = [];

  constructor(private compare: (a: T, b: T) => number) {}

  /** Inserts `value` in O(log n). */
  push(value: T) {
    this.heap.push(value);
    this.bubbleUp(this.heap.length - 1);
  }

  /**
   * Removes and returns the highest-priority element, or `undefined` when the
   * heap is empty. Runs in O(log n).
   */
  pop(): T | undefined {
    if (this.heap.length === 0) return undefined;
    const result = this.heap[0];
    // Non-empty, so this is a real element (which may itself be `undefined` if T allows it).
    const end = this.heap.pop() as T;
    if (this.heap.length > 0) {
      // Move the last leaf to the root, then restore the heap property.
      this.heap[0] = end;
      this.sinkDown(0);
    }
    return result;
  }

  /** Moves the element at index `n` up while it outranks its parent. */
  private bubbleUp(n: number) {
    const element = this.heap[n];
    while (n > 0) {
      const parentN = Math.floor((n - 1) / 2);
      const parent = this.heap[parentN];
      if (this.compare(element, parent) <= 0) break;
      this.heap[parentN] = element;
      this.heap[n] = parent;
      n = parentN;
    }
  }

  /** Moves the element at index `n` down while some child outranks it. */
  private sinkDown(n: number) {
    const length = this.heap.length;
    const element = this.heap[n];
    while (true) {
      const left = 2 * n + 1;
      const right = left + 1;
      // Highest-priority index among n, left and right. The current position wins
      // ties, so equal elements are never swapped.
      let best = n;
      if (left < length && this.compare(this.heap[left], this.heap[best]) > 0) best = left;
      if (right < length && this.compare(this.heap[right], this.heap[best]) > 0) best = right;
      if (best === n) break;
      this.heap[n] = this.heap[best];
      this.heap[best] = element;
      n = best;
    }
  }
}
// Sentinel weight meaning "no edge" in Graph's adjacency matrix: the constructor
// fills every cell with it and prim() skips cells equal to it, so a real edge of
// exactly this weight cannot be represented.
const INF = Number.MAX_SAFE_INTEGER;
/** A priority-queue entry pairing a target `vertex` with the `weight` of reaching it. */
class Edge {
  public weight: number;
  public vertex: number;

  constructor(weight: number, vertex: number) {
    this.weight = weight;
    this.vertex = vertex;
  }
}
/** An undirected weighted graph stored as a dense adjacency matrix. */
class Graph {
  // adjMatrix[u][v] is the weight of edge u–v, or INF when there is no such edge.
  private adjMatrix: number[][];

  /** Creates a graph with vertices `0..n-1` and no edges. */
  constructor(n: number) {
    this.adjMatrix = Array.from({ length: n }, () => Array(n).fill(INF));
  }

  /**
   * Sets the weight of the undirected edge `u`–`v` to `w`, replacing any previous
   * weight. Passing `INF` removes the edge.
   *
   * @throws RangeError if `u` or `v` is not an integer in `[0, n)`. Validation
   *   happens before any write, so the matrix is never left half-updated.
   */
  addEdge(u: number, v: number, w: number) {
    this.checkVertex(u);
    this.checkVertex(v);
    this.adjMatrix[u][v] = w;
    this.adjMatrix[v][u] = w;
  }

  /**
   * Returns the total weight of a minimum spanning tree, computed with a lazy
   * (heap-based) Prim's algorithm rooted at vertex 0. Returns `null` if the graph
   * is disconnected. A graph with no vertices has weight 0.
   */
  prim(): number | null {
    const n = this.adjMatrix.length;
    if (n === 0) return 0;

    const visited: boolean[] = Array(n).fill(false);
    const minHeap = new BinaryHeap<Edge>((a, b) => b.weight - a.weight);
    let totalWeight = 0;
    minHeap.push(new Edge(0, 0));

    for (let edge = minHeap.pop(); edge !== undefined; edge = minHeap.pop()) {
      // Stale entry: the vertex already joined the tree through a lighter edge.
      if (visited[edge.vertex]) continue;
      visited[edge.vertex] = true;
      totalWeight += edge.weight;

      const row = this.adjMatrix[edge.vertex];
      for (let neighbour = 0; neighbour < n; neighbour++) {
        const weight = row[neighbour];
        if (!visited[neighbour] && weight !== INF) {
          minHeap.push(new Edge(weight, neighbour));
        }
      }
    }

    return visited.every(v => v) ? totalWeight : null;
  }

  /** Throws a RangeError unless `x` is a valid vertex index. */
  private checkVertex(x: number) {
    const n = this.adjMatrix.length;
    if (!Number.isInteger(x) || x < 0 || x >= n) {
      throw new RangeError(`vertex ${x} out of range for graph with ${n} vertices`);
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment