Created
January 11, 2020 03:56
-
-
Save crackcomm/5cdf2ce8e5a1a5d05847e1a4fd825edc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate libc;
use libc::size_t;
#[link(name = "onnx", kind = "static")] | |
extern "C" { | |
pub fn onnx_proto_shape_inference(buffer: *const u8, size: size_t, out: *mut u8) -> size_t; | |
} | |
#[cfg(test)]
mod tests {
    use super::*;

    /// Headroom factor for the output buffer. The C API is never told the output
    /// capacity, so the test has to guess an upper bound on the result's size.
    /// Ten times the input size is a heuristic, not a guarantee.
    const OUTPUT_SIZE_FACTOR: usize = 10;

    /// Runs shape inference on `tests/model.onnx`. The serialized result is then
    /// compared byte-for-byte with the reference file `tests/model-inferred.onnx`.
    #[test]
    fn read_proto() {
        let buffer = read_buf("tests/model.onnx");
        let inferred = read_buf("tests/model-inferred.onnx");

        // Zero-initialise the output buffer. `Vec::with_capacity` + `set_len`
        // would expose uninitialised bytes, which is undefined behaviour.
        let capacity = buffer.len() * OUTPUT_SIZE_FACTOR;
        let mut output = vec![0u8; capacity];

        // SAFETY: `buffer` is a live allocation readable for `buffer.len()` bytes.
        // `output` is a live, initialised allocation writable for `capacity` bytes.
        // NOTE(review): the C function is not told `capacity`. Soundness relies on it
        // writing at most `capacity` bytes — confirm against the C implementation.
        let out_size = unsafe {
            onnx_proto_shape_inference(buffer.as_ptr(), buffer.len(), output.as_mut_ptr())
        };

        // `truncate` silently ignores a length beyond the current length. That would
        // hide an overrun of the output buffer, so fail loudly instead.
        assert!(
            out_size <= capacity,
            "shape inference reported {} bytes but only {} were allocated",
            out_size,
            capacity
        );
        output.truncate(out_size);
        assert_eq!(output, inferred);
    }

    /// Reads the whole file at `path` into memory.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be read. The panic message includes the path.
    fn read_buf<P: AsRef<std::path::Path>>(path: P) -> Vec<u8> {
        let path = path.as_ref();
        std::fs::read(path)
            .unwrap_or_else(|err| panic!("failed to read {}: {}", path.display(), err))
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment