Skip to content

Instantly share code, notes, and snippets.

@botev
Last active December 15, 2016 02:40
Show Gist options
  • Save botev/bac770e32f7df341ce18562f5333e5e5 to your computer and use it in GitHub Desktop.
Save botev/bac770e32f7df341ce18562f5333e5e5 to your computer and use it in GitHub Desktop.
Demonstrates how symbolic integers can be used for automatic shape inference and verification.
use std::collections::HashMap;
extern crate symints;
use symints::*;
/// A symbolic integer from the `symints` crate, used for tensor dimensions.
/// NOTE(review): presumably `Polynomial<variable name, coefficient, exponent>` — confirm
/// the roles of `String`, `i64` and `u8` against the `symints` documentation.
type SymInt = Polynomial<String, i64, u8>;
/// A four-dimensional tensor shape whose dimensions are symbolic integers.
/// The functions in this file treat a shape as 2D when its last two dimensions equal 1.
type Shape = (SymInt, SymInt, SymInt, SymInt);
/// Padding scheme used when inferring the output shape of a 2D convolution.
///
/// For a kernel of size `k` along a dimension, each variant selects the padding `p`
/// that is added to that dimension (counted twice, once per side).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConvolutionMode {
    /// No padding: `p = 0`.
    Valid,
    /// Half padding: `p = floor(k / 2)`.
    Half,
    /// Full padding: `p = k - 1`.
    Full,
}
/// Reports whether `shape` is a 2D (matrix-like) shape: both trailing
/// dimensions must be exactly one.
fn is_2d(shape: &Shape) -> bool {
    let (_, _, ref third, ref fourth) = *shape;
    *third == 1 && *fourth == 1
}
fn matrix_mul_shape(left: &Shape, right: &Shape) -> Option<Shape> {
// Check we have a 2D tensors with matching middle dimension
if left.1 != right.0 || left.2 != 1 || left.3 != 1 || right.2 != 1 || right.3 != 1 {
None
} else {
Some((left.0.clone(), right.1.clone(), 1.into(), 1.into()))
}
}
/// Infers the shape of an element-wise operation between `left` and `right`.
///
/// No broadcasting is performed: the two shapes must be equal in every dimension,
/// in which case that common shape is returned; otherwise the result is `None`.
fn element_wise_shape(left: &Shape, right: &Shape) -> Option<Shape> {
    if left == right {
        Some(left.clone())
    } else {
        None
    }
}
/// Infers the output shape of a 2D convolution of `image` with `kernel`.
///
/// `image`, `kernel` and `stride` must all be 2D shapes (see `is_2d`); otherwise `None`
/// is returned. Only the first two dimensions of `stride` are used, as per-axis strides.
///
/// For each of the two spatial dimensions, with input size `i`, kernel size `k`,
/// stride `s` and padding `p` chosen by `mode`, the output size is
/// `floor((i - k + 2p) / s) + 1`. The padding is:
/// * `Valid` — `p = 0`;
/// * `Half`  — `p = floor(k / 2)`;
/// * `Full`  — `p = k - 1`.
///
/// The trailing two dimensions of the result are 1.
fn convolution_2d_shape(image: &Shape, kernel: &Shape, stride: &Shape, mode: ConvolutionMode) -> Option<Shape> {
    if !(is_2d(image) && is_2d(kernel) && is_2d(stride)) {
        return None;
    }
    let (padding0, padding1): (SymInt, SymInt) = match mode {
        ConvolutionMode::Valid => (0.into(), 0.into()),
        ConvolutionMode::Half => (floor(&kernel.0, &2.into()), floor(&kernel.1, &2.into())),
        ConvolutionMode::Full => (&kernel.0 - 1, &kernel.1 - 1),
    };
    let one: SymInt = 1.into();
    // The kernel's first placement starts at offset 0 and its last at offset (i - k + 2p);
    // placements are `s` apart, so there are floor((i - k + 2p) / s) + 1 of them.
    // (A bare ceil((i - k + 2p) / s) undercounts by one whenever `s` divides the numerator.)
    let out0 = &floor(&(&(&image.0 - &kernel.0) + &(2 * &padding0)), &stride.0) + &one;
    let out1 = &floor(&(&(&image.1 - &kernel.1) + &(2 * &padding1)), &stride.1) + &one;
    Some((out0, out1, 1.into(), 1.into()))
}
/// Substitutes the variable bindings in `values` into every dimension of `shape`.
///
/// Dimensions are evaluated in order; the first evaluation error is returned as-is
/// and the remaining dimensions are not evaluated.
fn eval_shape(shape: &Shape, values: &HashMap<String, i64>) -> Result<(i64, i64, i64, i64), String> {
    let first = shape.0.evaluate(values)?;
    let second = shape.1.evaluate(values)?;
    let third = shape.2.evaluate(values)?;
    let fourth = shape.3.evaluate(values)?;
    Ok((first, second, third, fourth))
}
/// Demo: builds symbolic shapes, runs them through the shape-inference functions,
/// and prints each inferred shape both symbolically and evaluated with concrete values.
fn main() {
    // Prints a shape in symbolic form, then the result of evaluating it under `values`.
    fn print_shape(shape: &Shape, values: &HashMap<String, i64>) {
        println!("({}, {}, {}, {})", shape.0, shape.1, shape.2, shape.3);
        println!("{:?}", eval_shape(shape, values));
    }

    let a = primitive("a".into());
    let b = primitive("b".into());
    let c = primitive("c".into());
    let d = primitive("d".into());

    // Concrete values substituted for the symbolic variables during evaluation.
    let values: HashMap<String, i64> = [("a", 20), ("b", 7), ("c", 10), ("d", 3)]
        .iter()
        .map(|&(name, value)| (name.to_string(), value))
        .collect();

    let s1: Shape = (a.clone(), b.clone(), 1.into(), 1.into());
    let s2: Shape = (b.clone(), c.clone(), 1.into(), 1.into());
    let s3: Shape = (a.clone(), b.clone(), 1.into(), 1.into());
    let im: Shape = (c.clone(), c.clone(), 1.into(), 1.into());
    let ker: Shape = (d.clone(), d.clone(), 1.into(), 1.into());
    let st: Shape = (2.into(), 2.into(), 1.into(), 1.into());

    // (a, b) x (b, c): the inner dimensions agree.
    let product = matrix_mul_shape(&s1, &s2).expect("(a, b) x (b, c) has matching inner dimensions");
    print_shape(&product, &values);
    // (a, b) x (a, b): inner dimensions `b` and `a` differ.
    println!("{:?}", matrix_mul_shape(&s1, &s1));

    // s1 and s3 are built from the same symbols.
    let sum = element_wise_shape(&s1, &s3).expect("s1 and s3 are the same shape");
    print_shape(&sum, &values);
    // (a, b) vs (b, c): shapes differ.
    println!("{:?}", element_wise_shape(&s1, &s2));

    for mode in vec![ConvolutionMode::Valid, ConvolutionMode::Half, ConvolutionMode::Full] {
        let conv = convolution_2d_shape(&im, &ker, &st, mode)
            .expect("image, kernel and stride are all 2D");
        print_shape(&conv, &values);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment