Skip to content

Instantly share code, notes, and snippets.

@matthewjberger
Last active November 3, 2023 02:20
Show Gist options
  • Save matthewjberger/1333a2a10ad98d9cf72556c42f4a2800 to your computer and use it in GitHub Desktop.
Save matthewjberger/1333a2a10ad98d9cf72556c42f4a2800 to your computer and use it in GitHub Desktop.
use petgraph::{graph::NodeIndex, prelude::*};
use std::ops::{Index, IndexMut};
/// A value that can be combined with the corresponding value of its parent.
///
/// Implementors must also be `Clone`, `PartialEq` and `Debug`; these are
/// supertraits, so generic code bounded on `Aggregatable` may rely on them.
pub trait Aggregatable
where
    Self: Clone + PartialEq + std::fmt::Debug,
{
    /// Returns `self` combined with `parent`, producing a new value.
    fn aggregate(&self, parent: &Self) -> Self;
}
// The SceneGraph is now generic over the type of the transform
#[derive(Default)]
pub struct SceneGraph<T: Aggregatable>(pub Graph<T, ()>);
impl<T: Aggregatable> SceneGraph<T> {
    /// Creates an empty scene graph.
    pub fn new() -> Self {
        Self(Graph::<T, ()>::new())
    }

    /// Adds a node holding `transform` and returns its index.
    pub fn add_node(&mut self, transform: T) -> NodeIndex {
        self.0.add_node(transform)
    }

    /// Adds an edge that makes `parent_node` a parent of `child_node`.
    ///
    /// No validation is performed. Callers are responsible for keeping the
    /// parent links acyclic: a cycle makes [`Self::global_transform`] panic.
    /// A node with several parents only aggregates through one of them (see
    /// [`Self::parent_of`]).
    pub fn add_edge(&mut self, parent_node: NodeIndex, child_node: NodeIndex) {
        self.0.add_edge(parent_node, child_node, ());
    }

    /// Visits nodes depth-first from every root, calling `action` with each
    /// node's index and transform.
    ///
    /// A root is a node with no parents. Each root gets its own traversal, so a
    /// node reachable from several roots is visited once per root. Nodes not
    /// reachable from any root (e.g. those only on a cycle) are not visited.
    pub fn walk(&self, mut action: impl FnMut(NodeIndex, &T)) {
        let roots = self
            .0
            .node_indices()
            .filter(|&node| !self.has_parents(node));
        for root in roots {
            let mut dfs = Dfs::new(&self.0, root);
            while let Some(node) = dfs.next(&self.0) {
                action(node, &self.0[node]);
            }
        }
    }

    /// Returns `true` if `index` has at least one incoming (parent) edge.
    pub fn has_parents(&self, index: NodeIndex) -> bool {
        self.0.neighbors_directed(index, Incoming).next().is_some()
    }

    /// Computes the accumulated transform of `index` by combining it with all
    /// of its ancestors.
    ///
    /// The walk starts from the node's own transform. It then repeatedly calls
    /// [`Aggregatable::aggregate`] with the next ancestor's transform, going
    /// from the immediate parent up to the root. Ancestors are followed via
    /// [`Self::parent_of`].
    ///
    /// # Panics
    ///
    /// Panics if the chain of parents starting at `index` contains a cycle
    /// (for example a self-loop created with `add_edge(a, a)`).
    pub fn global_transform(&self, index: NodeIndex) -> T {
        let mut global_transform = self[index].clone();
        let mut current = index;
        // An acyclic ancestor chain has at most `node_count - 1` links, so
        // running out of this budget means the parent links form a cycle.
        let mut remaining_steps = self.0.node_count();
        while let Some(parent) = self.parent_of(current) {
            assert!(
                remaining_steps > 0,
                "cycle detected in the ancestry of scene graph node {:?}",
                index
            );
            remaining_steps -= 1;
            // Borrow the parent's transform directly; `aggregate` only needs `&T`.
            global_transform = global_transform.aggregate(&self[parent]);
            current = parent;
        }
        global_transform
    }

    /// Returns a parent of `index`, or `None` if it has no incoming edges.
    ///
    /// If the node has several parents, only the first one yielded by
    /// petgraph's incoming-neighbor iterator is returned.
    pub fn parent_of(&self, index: NodeIndex) -> Option<NodeIndex> {
        self.0.neighbors_directed(index, Incoming).next()
    }
}
impl<T: Aggregatable> Index<NodeIndex> for SceneGraph<T> {
type Output = T;
fn index(&self, index: NodeIndex) -> &Self::Output {
&self.0[index]
}
}
impl<T: Aggregatable> IndexMut<NodeIndex> for SceneGraph<T> {
fn index_mut(&mut self, index: NodeIndex) -> &mut Self::Output {
&mut self.0[index]
}
}
/// Example transform type used for testing: a pair of `x`/`y` components.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SimpleTransform {
    /// The `x` component.
    pub x: f32,
    /// The `y` component.
    pub y: f32,
}
impl Aggregatable for SimpleTransform {
fn aggregate(&self, parent: &Self) -> Self {
Self {
x: self.x + parent.x,
y: self.y + parent.y,
}
}
}
// Unit tests for `SceneGraph`, using `SimpleTransform` as the node payload.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SimpleTransform` from its two components.
    fn xform(x: f32, y: f32) -> SimpleTransform {
        SimpleTransform { x, y }
    }

    #[test]
    fn test_add_node() {
        let mut scene = SceneGraph::new();
        let stored = xform(5.0, 10.0);
        let id = scene.add_node(stored);
        assert_eq!(scene[id], stored);
    }

    #[test]
    fn test_add_edge_and_parent_of() {
        let mut scene = SceneGraph::new();
        let parent = scene.add_node(xform(1.0, 2.0));
        let child = scene.add_node(xform(3.0, 4.0));
        scene.add_edge(parent, child);
        assert_eq!(scene.parent_of(child), Some(parent));
    }

    #[test]
    fn test_global_transform() {
        let mut scene = SceneGraph::new();
        let parent = scene.add_node(xform(1.0, 2.0));
        let child = scene.add_node(xform(3.0, 4.0));
        scene.add_edge(parent, child);
        assert_eq!(scene.global_transform(child), xform(4.0, 6.0));
    }

    #[test]
    fn test_walk() {
        let mut scene = SceneGraph::new();
        let root = scene.add_node(xform(1.0, 2.0));
        let leaf = scene.add_node(xform(3.0, 4.0));
        scene.add_edge(root, leaf);

        let mut seen = Vec::new();
        scene.walk(|id, t| seen.push((id, *t)));

        // Root first, then its child, in depth-first order.
        let payloads: Vec<SimpleTransform> = seen.iter().map(|&(_, t)| t).collect();
        assert_eq!(payloads, vec![xform(1.0, 2.0), xform(3.0, 4.0)]);
    }

    #[test]
    fn test_has_parents() {
        let mut scene = SceneGraph::new();
        let first = scene.add_node(xform(1.0, 2.0));
        let second = scene.add_node(xform(3.0, 4.0));

        // Freshly added nodes are unconnected.
        assert!(!scene.has_parents(first));
        assert!(!scene.has_parents(second));

        // Only the child side of an edge gains a parent.
        scene.add_edge(first, second);
        assert!(!scene.has_parents(first));
        assert!(scene.has_parents(second));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment