Skip to content

Instantly share code, notes, and snippets.

@rj00a
Created May 2, 2022 10:06
Show Gist options
  • Save rj00a/d63bc00db277bfda122f1a6524d40070 to your computer and use it in GitHub Desktop.
Save rj00a/d63bc00db277bfda122f1a6524d40070 to your computer and use it in GitHub Desktop.
wave function collapse
// WARNING: Old code. Wrote this a while ago.
//! An interface for solving Constraint Satisfaction Problems (CSPs)
use bitvec::prelude::*;
use indexmap::IndexSet;
use log::warn;
use ndarray::prelude::*;
use num_enum::IntoPrimitive;
use rand::{distributions::WeightedError, prelude::*};
use std::hash::Hash;
/// A simple directed graph collection which represents the space a constraint solver will operate on.
///
/// Each node in the graph is associated with a value.
/// The actual structure of the graph cannot be modified with this trait, but the value associated with each node can be.
///
/// Additionally, each edge has a color (a `usize` index).
/// This is useful if you need to encode directionality constraints.
/// For instance, a square grid might use four colors for the cardinal directions.
pub trait CspGraph {
    /// The actual container type.
    /// Represents `forall A. Self<A>` since we're faking HKTs.
    type Graph<E>: CspGraph;
    /// Uniquely identifies the location of a node in the graph.
    type NodeIdx: Copy + Eq + Hash;
    /// The type describing edge colors.
    /// Edge colors must convert into a `usize` which is less than `NUM_EDGE_COLORS`.
    type EdgeColor: Copy + Eq + Hash + Into<usize>;
    /// An upper bound on the number of edge colors that can appear in the graph.
    /// Color indices must be less than this value.
    const NUM_EDGE_COLORS: usize;
    // TODO: use Iterator when impl trait in traits is available.
    /// Writes every out-neighbor of `idx` into `out`, each paired with the color of the edge
    /// used to get there, in an arbitrary order.
    ///
    /// Implementations must clear `out` before writing. Callers reuse one buffer across many
    /// calls and treat its whole contents as the result.
    fn outgoing_neighbors<T>(
        c: &Self::Graph<T>,
        idx: Self::NodeIdx,
        out: &mut Vec<(Self::NodeIdx, Self::EdgeColor)>,
    );
    // TODO: use Iterator when impl trait in traits is available.
    /// Writes the index of every node in the graph into `out`, replacing its previous contents.
    ///
    /// The order that the nodes are given can affect the performance of the algorithm.
    /// For instance, scanline order is good. Random order is really bad.
    fn all_nodes<T>(c: &Self::Graph<T>, out: &mut Vec<Self::NodeIdx>);
    /// Get an immutable reference to the value at the provided node index.
    fn get<T>(c: &Self::Graph<T>, idx: Self::NodeIdx) -> &T;
    /// Get a mutable reference to the value at the provided node index.
    fn get_mut<T>(c: &mut Self::Graph<T>, idx: Self::NodeIdx) -> &mut T;
    /// Create a new graph by mapping every element while retaining structure.
    /// In other words, the graph is a functor.
    fn map<T, U>(c: &Self::Graph<T>, f: impl FnMut(&T) -> U) -> Self::Graph<U>;
}
/// A two-dimensional grid of values, indexed as `[y, x]`.
///
/// Each cell's neighbors are the cells one step away along the y and x axes (see
/// [`SquareEdge`]). When `periodic_y` / `periodic_x` is set, that axis wraps around so cells on
/// opposite edges are neighbors; otherwise the grid simply ends there.
#[derive(Clone, Debug)]
pub struct SquareGrid<T> {
    /// Cell values; axis 0 is `y`, axis 1 is `x`.
    pub array: Array2<T>,
    /// Whether the `y` axis wraps around.
    pub periodic_y: bool,
    /// Whether the `x` axis wraps around.
    pub periodic_x: bool,
}
/// The edge colors of a [`SquareGrid`]: the direction of a single step to a neighboring cell.
///
/// Converts into its `usize` discriminant (`0..4`) through the `IntoPrimitive` derive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, IntoPrimitive)]
#[repr(usize)]
pub enum SquareEdge {
    /// Step to `y + 1`.
    PositiveY = 0,
    /// Step to `y - 1`.
    NegativeY = 1,
    /// Step to `x + 1`.
    PositiveX = 2,
    /// Step to `x - 1`.
    NegativeX = 3,
}
impl<A> CspGraph for SquareGrid<A> {
    type Graph<E> = SquareGrid<E>;
    type NodeIdx = [usize; 2];
    type EdgeColor = SquareEdge;
    const NUM_EDGE_COLORS: usize = 4;

    /// Writes the cells one step away from `[y, x]` along each axis into `out`, in the order
    /// `PositiveY`, `NegativeY`, `PositiveX`, `NegativeX`.
    ///
    /// On a periodic axis the step wraps around the grid; on a non-periodic axis a step that
    /// leaves the grid produces no neighbor.
    fn outgoing_neighbors<T>(
        c: &Self::Graph<T>,
        [y, x]: Self::NodeIdx,
        out: &mut Vec<(Self::NodeIdx, SquareEdge)>,
    ) {
        out.clear();
        let (height, width) = c.array.raw_dim().into_pattern();
        let (height, width) = (height as isize, width as isize);

        // Brings a stepped coordinate back into `0..len` (wrapping first if `periodic`),
        // or yields `None` when it lies outside the grid.
        let resolve = |coord: isize, len: isize, periodic: bool| -> Option<usize> {
            let coord = if periodic { coord.rem_euclid(len) } else { coord };
            if (0..len).contains(&coord) {
                Some(coord as usize)
            } else {
                None
            }
        };

        let (y, x) = (y as isize, x as isize);
        let steps = [
            (SquareEdge::PositiveY, y + 1, x),
            (SquareEdge::NegativeY, y - 1, x),
            (SquareEdge::PositiveX, y, x + 1),
            (SquareEdge::NegativeX, y, x - 1),
        ];
        for (edge, step_y, step_x) in steps {
            // Both axes are resolved before checking either, matching the original order.
            let wrapped_y = resolve(step_y, height, c.periodic_y);
            let wrapped_x = resolve(step_x, width, c.periodic_x);
            if let (Some(ny), Some(nx)) = (wrapped_y, wrapped_x) {
                out.push(([ny, nx], edge));
            }
        }
    }

    /// Writes every `[y, x]` index into `out` in row-major (scanline) order.
    fn all_nodes<T>(c: &Self::Graph<T>, out: &mut Vec<Self::NodeIdx>) {
        out.clear();
        let (height, width) = c.array.raw_dim().into_pattern();
        out.extend((0..height).flat_map(|y| (0..width).map(move |x| [y, x])));
    }

    fn get<T>(c: &Self::Graph<T>, idx: Self::NodeIdx) -> &T {
        let cell = &c.array[idx];
        cell
    }

    fn get_mut<T>(c: &mut Self::Graph<T>, idx: Self::NodeIdx) -> &mut T {
        let cell = &mut c.array[idx];
        cell
    }

    /// Maps every cell while keeping the grid's shape and periodicity.
    fn map<T, U>(c: &Self::Graph<T>, f: impl FnMut(&T) -> U) -> Self::Graph<U> {
        let mapped = c.array.map(f);
        SquareGrid {
            periodic_y: c.periodic_y,
            periodic_x: c.periodic_x,
            array: mapped,
        }
    }
}
/// Identifies an element (e.g. a tile) that can be assigned to a graph node.
///
/// Element IDs are dense: a problem with `n` elements uses the IDs `0..n`.
pub type ElementId = u32;
/// Statistics gathered for one element while building an [`ExampleModel`].
#[derive(Clone, Debug)]
struct TableEntry {
    /// Number of nodes holding this element across all examples; used as its selection weight.
    node_weight: f64,
    /// `by_edge_color[c][other]` is set if some example has an edge of color `c` from a node
    /// holding this element to a node holding `other`.
    // TODO: use G::NUM_EDGE_COLORS in an array to avoid an extra indirection.
    // (blocked by generic_const_exprs)
    by_edge_color: Box<[BitBox]>,
}

/// A constraint model learned from example graphs.
///
/// Built with [`ExampleModel::new`] and used to fill new graphs with
/// [`ExampleModel::solve_csp_simple`].
#[derive(Clone, Debug)]
pub struct ExampleModel {
    /// Contains node weights and the valid connections, indexed by `ElementId`.
    table: Box<[TableEntry]>,
}
impl ExampleModel {
    /// Builds a model by scanning example graphs.
    ///
    /// An element's weight is the number of nodes holding it across all examples. Element
    /// `other` becomes allowed along edge color `c` from element `this` when some example has
    /// an edge of color `c` from a node holding `this` to a node holding `other`.
    ///
    /// Every element ID stored in the examples must be less than `num_elements`; this is
    /// checked with `debug_assert!` in debug builds.
    pub fn new<G, I>(examples: I, num_elements: ElementId) -> Self
    where
        G: CspGraph,
        I: IntoIterator<Item = G::Graph<ElementId>>,
    {
        let element_count = num_elements as usize;
        let blank_entry = TableEntry {
            node_weight: 0.0,
            by_edge_color: vec![bitbox![0; element_count]; G::NUM_EDGE_COLORS]
                .into_boxed_slice(),
        };
        let mut table = vec![blank_entry; element_count].into_boxed_slice();

        // Scratch buffers reused for every example and every node.
        let mut nodes = Vec::new();
        let mut neighbors = Vec::new();
        for example in examples {
            G::all_nodes(&example, &mut nodes);
            for &node in &nodes {
                let element = *G::get(&example, node);
                debug_assert!(element < num_elements);
                let entry = &mut table[element as usize];
                entry.node_weight += 1.0;

                G::outgoing_neighbors(&example, node, &mut neighbors);
                for &(neighbor, color) in &neighbors {
                    let neighbor_element = *G::get(&example, neighbor);
                    debug_assert!(neighbor_element < num_elements);
                    debug_assert!(color.into() < G::NUM_EDGE_COLORS);
                    entry.by_edge_color[color.into()].set(neighbor_element as usize, true);
                }
            }
        }

        Self { table }
    }

    /// Fills `res` using the learned constraints, making up to `retries` attempts.
    ///
    /// Each attempt calls the free function `solve_csp_simple` with a fresh clone of
    /// `initial_nodes`, and a warning is logged after every failed attempt. Returns `true` on
    /// the first success and `false` if every attempt fails, including when `retries` is zero.
    pub fn solve_csp_simple<G, I, R>(
        &self,
        res: &mut G::Graph<ElementId>,
        initial_nodes: I,
        rng: &mut R,
        retries: u32,
    ) -> bool
    where
        G: CspGraph,
        I: IntoIterator<Item = (G::NodeIdx, ElementId)> + Clone,
        R: Rng + ?Sized,
    {
        // Both closures only borrow `self`, so they are `Copy` and can be passed to every attempt.
        let allowed = |this: ElementId, other: ElementId, color: G::EdgeColor| {
            self.table[this as usize].by_edge_color[color.into()][other as usize]
        };
        let weight = |id: ElementId| self.table[id as usize].node_weight;

        for attempt in 1..=retries {
            let solved = solve_csp_simple::<G, I, R, _, _>(
                res,
                initial_nodes.clone(),
                self.table.len(),
                rng,
                allowed,
                weight,
            );
            if solved {
                return true;
            }
            warn!("Synthesis failed: {}/{}", attempt, retries);
        }
        false
    }
}
/// Attempt to solve the CSP by randomly exploring the possibility space + local consistency.
///
/// Every node starts with all `num_elems` elements as candidates, or just the given element
/// for nodes listed in `initial_nodes`. The candidates are first pruned to arc consistency
/// (AC-3). Nodes are then visited in [`CspGraph::all_nodes`] order. Each one picks one of its
/// remaining candidates at random, weighted by `get_element_weight`, and the choice is
/// propagated to the rest of the graph.
///
/// This function doesn't use backtracking for efficiency.
/// Because of this, we might not find a solution even if one exists.
/// This is often acceptable for simple problems, but unacceptable for others.
///
/// `allowed(this, other, color)` must return whether a node holding `this` may have a node
/// holding `other` at the end of an outgoing edge of color `color`.
///
/// Returns `true` if every node of `res` was assigned an element such that `allowed` holds
/// along every edge. Returns `false` as soon as any node runs out of candidates; `res` may
/// then have been partially overwritten.
///
/// # Panics
///
/// Panics if an element ID in `initial_nodes` is not less than `num_elems`. Also panics if
/// `choose_weighted` rejects the weights of a node's remaining candidates (an invalid weight,
/// all weights zero, or too many candidates).
pub fn solve_csp_simple<G, I, R, A, W>(
    res: &mut G::Graph<ElementId>,
    initial_nodes: I,
    num_elems: usize,
    rng: &mut R,
    allowed: A,
    get_element_weight: W,
) -> bool
where
    G: CspGraph,
    I: IntoIterator<Item = (G::NodeIdx, ElementId)>,
    R: Rng + ?Sized,
    A: Fn(ElementId, ElementId, G::EdgeColor) -> bool,
    W: Fn(ElementId) -> f64,
{
    let mut candidate_set = G::map(res, |_| bitvec![1; num_elems].into_boxed_bitslice());

    let mut all_indices = Vec::new();
    G::all_nodes(res, &mut all_indices);

    // AC-3 starts with every arc in its work list. Each popped node revises its outgoing arcs,
    // so seeding the queue with every node prunes candidates that have no support in some
    // neighbor before any random choice is made. Without this, the initial state is never
    // made consistent, and arcs between nodes whose domains never shrink are never checked.
    let mut needs_update: IndexSet<G::NodeIdx> = all_indices.iter().copied().collect();

    for (idx, id) in initial_nodes {
        assert!(
            (id as usize) < num_elems,
            "element ID must be less than the number of elements"
        );
        let bits = G::get_mut(&mut candidate_set, idx);
        // only this ID should be set.
        bits.set_all(false);
        bits.set(id as usize, true);
        needs_update.insert(idx);
    }

    // We have to use a temporary bit buffer because we can't have mutable
    // references to the current node and the adjacent node at the same time.
    let mut bit_buffer = Vec::new();
    let mut neighborhood_buffer = Vec::new();

    // Uses AC3. Returns `false` as soon as some node's candidate set becomes empty (a
    // wipeout). Continuing after a wipeout could leave an already-assigned node inconsistent
    // without the scan below ever noticing.
    let mut establish_arc_consistency =
        |candidate_set: &mut G::Graph<BitBox<_, _>>,
         needs_update: &mut IndexSet<G::NodeIdx>|
         -> bool {
            // Propagate the updated constraints globally
            while let Some(this_idx) = needs_update.pop() {
                G::outgoing_neighbors(candidate_set, this_idx, &mut neighborhood_buffer);
                for &(adj_idx, edge_color) in &neighborhood_buffer {
                    let this_bits = G::get(candidate_set, this_idx);
                    let adj_bits = G::get(candidate_set, adj_idx);
                    // Collect the adjacent candidates that no candidate of `this` supports.
                    bit_buffer.extend(adj_bits.iter_ones().filter(|&adj_id| {
                        !this_bits.iter_ones().any(|this_id| {
                            allowed(this_id as ElementId, adj_id as ElementId, edge_color)
                        })
                    }));
                    if bit_buffer.is_empty() {
                        continue;
                    }

                    let adj_bits = G::get_mut(candidate_set, adj_idx);
                    for &id in &bit_buffer {
                        adj_bits.set(id, false);
                    }
                    bit_buffer.clear();
                    if adj_bits.not_any() {
                        return false;
                    }
                    // The neighbor's domain shrank, so its own outgoing arcs must be revised.
                    needs_update.insert(adj_idx);
                }
            }
            true
        };

    if !establish_arc_consistency(&mut candidate_set, &mut needs_update) {
        return false;
    }

    // Reused across iterations to avoid allocating a new candidate list per node.
    let mut choices: Vec<ElementId> = Vec::new();
    for idx in all_indices {
        let bits = G::get_mut(&mut candidate_set, idx);
        choices.clear();
        choices.extend(bits.iter_ones().map(|id| id as ElementId));
        match choices.choose_weighted(rng, |&id| get_element_weight(id)) {
            Ok(&id) => {
                // A node with a single candidate is already consistent with its neighbors;
                // only a real choice needs to be propagated.
                if choices.len() > 1 {
                    bits.set_all(false);
                    bits.set(id as usize, true);
                    needs_update.insert(idx);
                    if !establish_arc_consistency(&mut candidate_set, &mut needs_update) {
                        // Since we don't do any backtracking, a wipeout means we give up.
                        return false;
                    }
                }
                *G::get_mut(res, idx) = id;
            }
            Err(WeightedError::NoItem) => {
                // There is no valid tile at this index.
                // Since we don't do any backtracking, we have no choice but to give up.
                return false;
            }
            Err(WeightedError::InvalidWeight) => panic!("Invalid weight"),
            Err(WeightedError::AllWeightsZero) => panic!("All element weights are zero (elements should never have a weight of zero to begin with)"),
            Err(WeightedError::TooMany) => panic!("Too many element weights"),
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiles used by the test map, indexed by their `ElementId`.
    const TILES: [char; 6] = ['🌲', '🚖', '🌊', '🚗', '🌉', '🚦'];

    /// Maps a tile character to its `ElementId`, panicking on unknown tiles.
    fn tile_id(c: char) -> ElementId {
        TILES
            .iter()
            .position(|&t| t == c)
            .unwrap_or_else(|| panic!("unknown tile {c:?}")) as ElementId
    }

    #[test]
    fn square_grid_from_example() {
        let example = [
            "🌲🚖🌲🌊🌲🌲🌲🌊🌲🚖🌊🚖🌲🌊🚖",
            "🌲🚖🌲🌊🌲🌲🌲🌊🌲🚖🌊🚖🌲🌊🚖",
            "🌲🚖🌲🌊🌲🌲🌲🌊🌲🚖🌊🚖🌲🌊🚖",
            "🚗🚦🌲🌊🌲🌲🌲🌊🌲🚖🌊🚖🌲🌊🚖",
            "🌲🌲🌲🌊🌲🌲🌲🌊🌲🚖🌊🚖🌲🌊🚖",
            "🌲🌲🌲🌊🌲🌲🌲🌊🌲🚖🌊🚖🌲🌊🚖",
            "🌲🌲🌲🌊🌲🌲🌲🌊🌲🚖🌊🚖🌲🌊🚖",
            "🚗🚗🚗🌉🚗🚗🚗🌉🚗🚦🌉🚦🌲🌊🚖",
            "🌲🌲🌲🌊🌲🌲🌲🌊🌲🌲🌊🌲🌲🌊🚖",
            "🚗🚗🚗🌉🚗🚗🚗🌉🚗🚗🌉🚗🚗🌉🚦",
            "🌲🌲🌲🌊🌲🌲🌲🌊🌲🌲🌊🌲🌲🌊🌲",
            "🌲🌲🌲🌊🌲🌲🌲🌊🌲🌲🌊🌲🌲🌊🌲",
        ];
        let shape = [example.len(), example[0].chars().count()];
        let array = Array2::from_shape_vec(
            shape,
            example
                .iter()
                .flat_map(|row| row.chars().map(tile_id))
                .collect(),
        )
        .unwrap();
        let grid = SquareGrid {
            array,
            periodic_y: false,
            periodic_x: false,
        };
        let model = ExampleModel::new::<SquareGrid<()>, _>(
            std::iter::once(grid),
            TILES.len() as ElementId,
        );
        let mut res = SquareGrid {
            array: Array2::from_elem([20, 20], 0),
            periodic_x: false,
            periodic_y: false,
        };
        assert!(model.solve_csp_simple::<SquareGrid<()>, _, _>(
            &mut res,
            std::iter::empty(),
            &mut thread_rng(),
            10
        ));

        // Every adjacency in the output must have been observed in the example.
        let mut nodes = Vec::new();
        let mut neighbors = Vec::new();
        <SquareGrid<()> as CspGraph>::all_nodes(&res, &mut nodes);
        for &idx in &nodes {
            let this = res.array[idx];
            <SquareGrid<()> as CspGraph>::outgoing_neighbors(&res, idx, &mut neighbors);
            for &(other_idx, edge) in &neighbors {
                let other = res.array[other_idx];
                let color: usize = edge.into();
                assert!(
                    model.table[this as usize].by_edge_color[color][other as usize],
                    "tile {} at {:?} may not have {} at {:?}",
                    TILES[this as usize],
                    idx,
                    TILES[other as usize],
                    other_idx
                );
            }
        }

        for row in res.array.rows() {
            let line: String = row.iter().map(|&n| TILES[n as usize]).collect();
            println!("{line}");
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment