Skip to content

Instantly share code, notes, and snippets.

@th3terrorist
Last active November 14, 2021 10:27
Show Gist options
  • Save th3terrorist/39c1fe5af005b7bea666f946dfc87c11 to your computer and use it in GitHub Desktop.
Save th3terrorist/39c1fe5af005b7bea666f946dfc87c11 to your computer and use it in GitHub Desktop.
A Rust implementation of Prim's minimum spanning tree (MST) algorithm.
use random::Source;
/// Graph primitives used by the Prim demo: weighted edges, vertices and an
/// edge-list graph.
mod structures {
    /// An undirected weighted edge (branch) between the vertices with ids `x` and `y`.
    #[derive(Debug, Eq, PartialEq, Clone, Copy)]
    pub struct Branch {
        /// Id of one endpoint.
        pub x: u32,
        /// Id of the other endpoint.
        pub y: u32,
        /// Weight (cost) of the edge.
        pub w: u32,
    }

    impl Branch {
        /// Creates a branch between `x` and `y` with weight `w`.
        pub fn new(x: u32, y: u32, w: u32) -> Self {
            Self { x, y, w }
        }

        /// Returns `true` if this branch joins `a` and `b`, in either orientation.
        pub fn connects(&self, a: u32, b: u32) -> bool {
            (self.x == a && self.y == b) || (self.x == b && self.y == a)
        }
    }

    /// A vertex plus the bookkeeping Prim's algorithm keeps for it.
    #[derive(Debug, Clone, Copy)]
    pub struct Node {
        /// Identifier referenced by [`Branch`] endpoints.
        pub id: u32,
        /// Best known connection weight used by Prim's algorithm.
        pub key: u32,
        /// Id of the parent vertex in the spanning tree, or `-1` when there is none.
        pub parent: i32,
    }

    impl Node {
        /// Creates a vertex with the given id, key and parent.
        pub fn new(id: u32, key: u32, parent: i32) -> Self {
            Self { id, key, parent }
        }
    }

    /// An undirected graph stored as a vertex list plus an edge list.
    #[derive(Debug, Clone)]
    pub struct Graph {
        pub verts: Vec<Node>,
        pub edges: Vec<Branch>,
    }

    impl Graph {
        /// Creates a graph over `nodes` with no edges.
        pub fn new(nodes: Vec<Node>) -> Self {
            Self {
                verts: nodes,
                edges: Vec::new(),
            }
        }

        /// Returns the ids of all vertices sharing an edge with `node`.
        ///
        /// One entry is produced per incident edge end, so parallel edges yield
        /// repeated ids and a self-loop yields `node` twice.
        pub fn adj(&self, node: u32) -> Vec<u32> {
            let mut adjv = Vec::new();
            for e in &self.edges {
                if e.x == node {
                    adjv.push(e.y);
                }
                if e.y == node {
                    adjv.push(e.x);
                }
            }
            adjv
        }

        /// Adds an undirected edge of weight `cost` between `id_n1` and `id_n2`.
        ///
        /// # Panics
        /// Panics if either id does not belong to a vertex of this graph.
        pub fn insert(&mut self, id_n1: u32, id_n2: u32, cost: u32) {
            let exists = |id: u32| self.verts.iter().any(|n| n.id == id);
            if !exists(id_n1) || !exists(id_n2) {
                panic!("One of the given ids doesn't exist");
            }
            self.edges.push(Branch::new(id_n1, id_n2, cost));
        }

        /// Returns the edge joining `node1` and `node2` (in either orientation),
        /// or `None` if no such edge exists.
        ///
        /// With parallel edges, the most recently inserted one is returned.
        pub fn get_branch(&self, node1: u32, node2: u32) -> Option<&Branch> {
            // Searching from the back preserves the "last match wins" rule; an
            // empty edge list or a missing edge now yields `None` instead of a
            // panic or an unrelated edge.
            self.edges.iter().rev().find(|e| e.connects(node1, node2))
        }
    }
}
use structures::{ Graph, Branch, Node };
/// Removes from `queue` and returns the id of the queued vertex reachable by the
/// cheapest edge from the current tree.
///
/// The tree is every vertex of `orig` whose id is not in `queue`. Ties go to the
/// first candidate found: tree vertices are scanned in `orig.verts` order, then
/// neighbours in `adj` order.
///
/// If no edge crosses from the tree into `queue` (the graph is disconnected), the
/// first remaining vertex of `queue` is returned instead. This starts a new component
/// rather than removing an unrelated vertex or panicking.
///
/// # Panics
/// Panics if `queue` is empty.
fn prim_extract_min(orig: &mut Graph, queue: &mut Vec<Node>) -> u32 {
    let in_queue = |id: u32| queue.iter().any(|n| n.id == id);

    // Cheapest crossing edge as (queued vertex id, weight). Using `Option` instead
    // of a `u32::MAX` sentinel means edges of weight `u32::MAX` remain selectable.
    let mut best: Option<(u32, u32)> = None;
    for tree_node in orig.verts.iter().filter(|n| !in_queue(n.id)) {
        for other in orig.adj(tree_node.id) {
            if !in_queue(other) {
                continue;
            }
            let weight = orig
                .get_branch(tree_node.id, other)
                .expect("adj() only returns endpoints of existing edges")
                .w;
            // Strict `<` keeps the first minimum, matching the original tie-break.
            if best.map_or(true, |(_, best_w)| weight < best_w) {
                best = Some((other, weight));
            }
        }
    }

    let cand = match best {
        Some((id, _)) => id,
        None => {
            queue
                .first()
                .expect("prim_extract_min called with an empty queue")
                .id
        }
    };
    let index = queue
        .iter()
        .position(|n| n.id == cand)
        .expect("candidate is always taken from the queue");
    queue.remove(index);
    cand
}
fn prim(mut graph: Graph) {
let mut rand = random::default().seed([49, 62]).read_u64() as usize;
rand %= graph.verts.len();
graph.verts[rand].key = 0;
println!("Chosen root: {}", graph.verts[rand].id);
let mut q = graph.verts.clone(); //every vertex out of the tree
q.remove(rand);
while q.len() > 0 {
let u = prim_extract_min(&mut graph, &mut q);
for v in graph.adj(u) {
let branch = graph.get_branch(u, v).unwrap();
if branch.w < graph.verts[v as usize].key {
let cost = branch.w;
let id = graph.verts[v as usize].id;
graph.verts[v as usize].key = branch.w;
graph.verts[v as usize].parent = u as i32;
println!("Found a better connection at ({}) for cost: {}", id, cost);
}
}
}
}
/// Builds a fixed 10-vertex sample graph, prints its edge list, then runs `prim` on it.
fn test_prim() {
    // Ids 0..10, each starting with an "infinite" key and no parent (-1).
    let nodes: Vec<Node> = (0..10)
        .map(|id| Node::new(id, std::u32::MAX, -1))
        .collect();
    let mut graph = Graph::new(nodes);

    // (endpoint, endpoint, weight) for every undirected branch, in insertion order.
    let branches: [(u32, u32, u32); 12] = [
        (0, 1, 2),
        (0, 2, 1),
        (1, 3, 3),
        (1, 4, 7),
        (2, 4, 2),
        (3, 6, 4),
        (4, 5, 5),
        (5, 7, 3),
        (6, 7, 11),
        (7, 8, 4),
        (7, 9, 8),
        (8, 9, 7),
    ];
    for &(a, b, weight) in branches.iter() {
        graph.insert(a, b, weight);
    }

    println!("initial graph:");
    graph.edges.iter().for_each(|edge| {
        println!("branch: ({})---[{}]---({})", edge.x, edge.w, edge.y);
    });

    prim(graph);
}
/// Program entry point: runs the Prim demo on the built-in sample graph.
fn main() {
    test_prim()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment