Last active
November 14, 2021 10:27
-
-
Save th3terrorist/39c1fe5af005b7bea666f946dfc87c11 to your computer and use it in GitHub Desktop.
A Rust implementation of Prim's minimum spanning tree (MST) algorithm.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use random::Source; | |
mod structures {
    /// An undirected weighted edge between the vertices whose ids are `x` and `y`.
    #[derive(Debug, Eq, PartialEq, Clone, Copy)]
    pub struct Branch {
        /// Id of one endpoint.
        pub x: u32,
        /// Id of the other endpoint.
        pub y: u32,
        /// Edge weight (cost).
        pub w: u32,
    }

    impl Branch {
        /// Creates an edge between `x` and `y` with weight `w`.
        pub fn new(x: u32, y: u32, w: u32) -> Self {
            Self { x, y, w }
        }
    }

    /// A vertex plus the per-vertex bookkeeping used by Prim's algorithm.
    #[derive(Debug, Clone, Copy)]
    pub struct Node {
        /// Vertex identifier; edges refer to vertices by this id.
        pub id: u32,
        /// Best known cost of attaching this vertex to the spanning tree
        /// (the demo initialises it to `u32::MAX`).
        pub key: u32,
        /// Id of the vertex this one is attached through; the demo uses `-1` for "none".
        pub parent: i32,
    }

    impl Node {
        /// Creates a vertex with the given id, key and parent.
        pub fn new(id: u32, key: u32, parent: i32) -> Self {
            Self { id, key, parent }
        }
    }

    /// An undirected graph stored as a vertex list plus an edge list.
    #[derive(Debug, Clone)]
    pub struct Graph {
        pub verts: Vec<Node>,
        pub edges: Vec<Branch>,
    }

    impl Graph {
        /// Creates a graph containing `nodes` and no edges.
        pub fn new(nodes: Vec<Node>) -> Self {
            Self {
                verts: nodes,
                edges: Vec::new(),
            }
        }

        /// Returns the ids of all vertices sharing an edge with `node`, in edge
        /// insertion order. Edges are undirected, so both endpoints are checked
        /// (a self-loop therefore reports `node` twice).
        pub fn adj(&self, node: u32) -> Vec<u32> {
            let mut adjv = Vec::new();
            for e in &self.edges {
                if e.x == node {
                    adjv.push(e.y);
                }
                if e.y == node {
                    adjv.push(e.x);
                }
            }
            adjv
        }

        /// Adds an undirected edge of weight `cost` between `id_n1` and `id_n2`.
        ///
        /// # Panics
        /// Panics if either id does not belong to a vertex of this graph.
        pub fn insert(&mut self, id_n1: u32, id_n2: u32, cost: u32) {
            let exists = |id: u32| self.verts.iter().any(|n| n.id == id);
            if !exists(id_n1) || !exists(id_n2) {
                panic!("One of the given ids doesn't exist");
            }
            self.edges.push(Branch::new(id_n1, id_n2, cost));
        }

        /// Returns the edge joining `node1` and `node2` (in either orientation),
        /// or `None` if the two vertices are not adjacent.
        ///
        /// If several parallel edges join the pair, the most recently inserted
        /// one is returned.
        pub fn get_branch(&self, node1: u32, node2: u32) -> Option<&Branch> {
            // Search from the back so the latest matching edge wins, as before.
            self.edges.iter().rev().find(|e| {
                (e.x == node1 && e.y == node2) || (e.x == node2 && e.y == node1)
            })
        }
    }
}
use structures::{ Graph, Branch, Node }; | |
fn prim_extract_min(orig: &mut Graph, queue: &mut Vec<Node>) -> u32 { | |
let mut min: u32 = std::u32::MAX; | |
type Collect = (u32, u32); | |
let mut inspect: Vec<Collect> = Vec::new(); | |
let _queue = queue.clone(); | |
let tree = orig.verts | |
.clone() | |
.iter() | |
.filter(|&n| !_queue.iter().any(|&x| x.id == n.id)) | |
.cloned() | |
.collect::<Vec<Node>>(); | |
for n in &tree { | |
for o in orig.adj(n.id) { | |
if !tree.iter().any(|&n| n.id == o) { | |
inspect.push((o, orig.get_branch(n.id, o).unwrap().w)); | |
} | |
} | |
} | |
let mut cand = 0; | |
for i in inspect { | |
if i.1 < min { | |
cand = i.0; | |
min = i.1; | |
} | |
} | |
let index = queue.iter().position(|&x| x.id == cand).unwrap(); | |
queue.remove(index); | |
return cand; | |
} | |
fn prim(mut graph: Graph) { | |
let mut rand = random::default().seed([49, 62]).read_u64() as usize; | |
rand %= graph.verts.len(); | |
graph.verts[rand].key = 0; | |
println!("Chosen root: {}", graph.verts[rand].id); | |
let mut q = graph.verts.clone(); //every vertex out of the tree | |
q.remove(rand); | |
while q.len() > 0 { | |
let u = prim_extract_min(&mut graph, &mut q); | |
for v in graph.adj(u) { | |
let branch = graph.get_branch(u, v).unwrap(); | |
if branch.w < graph.verts[v as usize].key { | |
let cost = branch.w; | |
let id = graph.verts[v as usize].id; | |
graph.verts[v as usize].key = branch.w; | |
graph.verts[v as usize].parent = u as i32; | |
println!("Found a better connection at ({}) for cost: {}", id, cost); | |
} | |
} | |
} | |
} | |
fn test_prim() { | |
let mut nodes = Vec::<Node>::new(); | |
for i in 0..10 { | |
nodes.push(Node::new(i, std::u32::MAX, -1)); | |
} | |
let mut graph = Graph::new(nodes); | |
graph.insert(0, 1, 2); | |
graph.insert(0, 2, 1); | |
graph.insert(1, 3, 3); | |
graph.insert(1, 4, 7); | |
graph.insert(2, 4, 2); | |
graph.insert(3, 6, 4); | |
graph.insert(4, 5, 5); | |
graph.insert(5, 7, 3); | |
graph.insert(6, 7, 11); | |
graph.insert(7, 8, 4); | |
graph.insert(7, 9, 8); | |
graph.insert(8, 9, 7); | |
println!("initial graph:"); | |
for b in &graph.edges { | |
println!("branch: ({})---[{}]---({})", b.x, b.w, b.y); | |
} | |
prim(graph); | |
} | |
fn main() { | |
test_prim(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment