Last active
November 24, 2021 08:59
-
-
Save adriangb/9d4561fa9ac04eb59dafc324b0550a60 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::cmp; | |
use std::hash; | |
use std::fmt; | |
use pyo3::basic::CompareOp; | |
use pyo3::prelude::*; | |
// We can't put a Py<PyAny> directly into a HashMap key | |
// So to be able to hold references to arbitrary Python objects in HashMap as keys | |
// we wrap them in a struct that gets the hash() when it receives the object from Python | |
// and then just echoes back that hash when called Rust needs to hash it | |
#[derive(Clone)] | |
pub struct HashedAny { | |
pub o: Py<PyAny>, | |
pub hash: isize, | |
} | |
// Use the result of calling repr() on the Python object as the debug string value | |
impl fmt::Debug for HashedAny { | |
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | |
Python::with_gil(|py| -> PyResult<fmt::Result> { | |
let obj = self.o.as_ref(py); | |
let pystr = obj.repr()?; | |
let ruststr = pystr.to_str()?; | |
Ok(write!(f, "{}", ruststr)) | |
}).unwrap() | |
} | |
} | |
impl <'source>FromPyObject<'source> for HashedAny | |
{ | |
fn extract(ob: &'source PyAny) -> PyResult<Self> { | |
Ok(HashedAny{ o: ob.into(), hash: ob.hash()? }) | |
} | |
} | |
impl hash::Hash for HashedAny { | |
fn hash<H: hash::Hasher>(&self, state: &mut H) { | |
self.hash.hash(state) | |
} | |
} | |
impl cmp::PartialEq for HashedAny { | |
fn eq(&self, other: &Self) -> bool { | |
Python::with_gil(|py| -> PyResult<bool> { | |
let this_ref = self.o.as_ref(py); | |
let other_ref = other.o.as_ref(py); | |
if this_ref.eq(other_ref) { | |
Ok(true) | |
} | |
else { | |
Ok(this_ref.rich_compare(other_ref, CompareOp::Eq)?.is_true()?) | |
} | |
}).unwrap() | |
} | |
} | |
impl cmp::Eq for HashedAny {} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::HashMap; | |
use pyo3::prelude::*; | |
use pyo3::exceptions; | |
use pyo3::types::PyTuple; | |
mod hashedany; | |
use crate::hashedany::HashedAny; | |
#[pyclass(module="di_lib",subclass)] | |
#[derive(Debug,Clone)] | |
struct Graph { | |
children: HashMap<HashedAny, Vec<HashedAny>>, | |
parents: HashMap<HashedAny, Vec<HashedAny>>, | |
child_counts: HashMap<HashedAny, usize>, | |
ready_nodes: Vec<Py<PyAny>>, | |
not_done_count: usize, | |
} | |
impl Graph { | |
fn remove_node(&mut self, node: &HashedAny, to_remove: &mut Vec<HashedAny>) -> () { | |
match self.child_counts.remove(&node) { | |
Some(_) => (), | |
// This node was already removed | |
// This happens if parents and children are passed in the nodes argument | |
None => return, | |
} | |
// Find all parents and reduce their dependency count by one | |
match self.parents.remove(&node) { | |
Some(parents) => { | |
for parent in parents { | |
match self.child_counts.get_mut(&parent) { | |
Some(v) => { | |
*v -= 1; | |
}, | |
// This node was already removed | |
// This happens if parents and children are passed in the nodes argument | |
None => continue, | |
} | |
} | |
}, | |
// this node was already removed | |
None => return, | |
}; | |
// Push all children onto the stack for removal | |
match self.children.remove(&node) { | |
Some(children) => { | |
for child in children { | |
to_remove.push(child); | |
}; | |
}, | |
None => () | |
}; | |
} | |
} | |
#[pymethods] | |
impl Graph { | |
#[new] | |
fn new(graph: HashMap<HashedAny, Vec<HashedAny>>) -> Self { | |
let mut child_counts: HashMap<HashedAny, usize> = HashMap::new(); | |
let mut parents: HashMap<HashedAny, Vec<HashedAny>> = HashMap::new(); | |
let mut ready_nodes: Vec<Py<PyAny>> = Vec::new(); | |
let mut child_count: usize; | |
for (node, children) in &graph { | |
parents.entry(node.clone()).or_insert_with(Vec::new); | |
child_count = (*children).len(); | |
child_counts.insert(node.clone(), child_count); | |
if child_count == 0 { | |
ready_nodes.push(node.o.clone()); | |
} | |
for child in children { | |
parents.entry(child.clone()).or_insert_with(Vec::new).push(node.clone()); | |
} | |
} | |
Graph { | |
children: graph.clone(), | |
parents: parents, | |
child_counts: child_counts, | |
ready_nodes: ready_nodes, | |
not_done_count: graph.len(), | |
} | |
} | |
/// Returns string representation of the graph | |
fn __str__(&self) -> PyResult<String> { | |
Ok(format!("Graph({:?})", self.children)) | |
} | |
fn __repr__(&self) -> PyResult<String> { | |
self.__str__() | |
} | |
/// Returns a deep copy of this graph | |
fn copy(&self) -> Graph { | |
self.clone() | |
} | |
/// Returns any nodes with no dependencies after marking `node` as done | |
/// # Arguments | |
/// | |
/// * `node` - A node in the graph | |
#[args(args="*")] | |
fn done(&mut self, args: &PyTuple) -> PyResult<()> { | |
let mut node: HashedAny; | |
let mut v: usize; | |
for obj in args { | |
node = HashedAny::extract(obj)?; | |
// Check that this node is ready to be marked as done and mark it | |
v = *self.child_counts.get(&node).unwrap(); | |
if v != 0 { | |
return Err(exceptions::PyException::new_err("Node still has children")); | |
} | |
self.not_done_count -= 1; | |
// Find all parents and reduce their dependency count by one, | |
// returning all parents w/o any further dependencies | |
for parent in self.parents.get(&node).unwrap() { | |
match self.child_counts.get_mut(parent) { | |
Some(v) => { | |
*v -= 1; | |
if *v == 0 { | |
self.ready_nodes.push(parent.o.clone()); | |
} | |
}, | |
None => return Err(exceptions::PyKeyError::new_err(format!("Parent node {:?} not found", parent))) | |
} | |
} | |
} | |
Ok(()) | |
} | |
fn is_active(&self) -> bool { | |
self.not_done_count != 0 || !self.ready_nodes.is_empty() | |
} | |
/// Removes nodes from the graph and cleans up newly created disconnected components | |
/// # Arguments | |
/// | |
/// * `nodes` - Nodes to be removed from the graph | |
fn remove(&mut self, nodes: &PyTuple) -> PyResult<()> { | |
let mut to_remove: Vec<HashedAny> = Vec::new(); | |
for node in nodes { | |
self.remove_node(&HashedAny::extract(node)?, &mut to_remove); | |
} | |
let mut node: HashedAny; | |
loop { | |
node = match to_remove.pop() { | |
Some(v) => v, | |
None => return Ok(()) | |
}; | |
self.remove_node(&node, &mut to_remove); | |
} | |
} | |
/// Returns all nodes with no dependencies | |
fn get_ready(&mut self) -> Vec<Py<PyAny>> { | |
let ret = self.ready_nodes.clone(); | |
self.ready_nodes.clear(); | |
ret | |
} | |
} | |
#[pymodule] | |
fn di_lib(_py: Python, m: &PyModule) -> PyResult<()> { | |
m.add_class::<Graph>()?; | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment