You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Open `src/main.rs` in a text editor and replace its contents with the following:
extern crate collenchyma as co;
extern crate collenchyma_blas as blas;
use co::prelude::*;
use blas::plugin::Dot;
/// Copies `data` into the front of a native (host) memory buffer.
///
/// Only `MemoryType::Native` memory is written. For any other variant this
/// function silently does nothing.
///
/// # Panics
///
/// Panics if `data` holds more elements than the buffer can store.
fn write_to_memory<T: Copy>(mem: &mut MemoryType, data: &[T]) {
    if let &mut MemoryType::Native(ref mut flat) = mem {
        // `copy_from_slice` requires equal lengths, so target only the first
        // `data.len()` slots; this panics exactly when `data` would overflow.
        flat.as_mut_slice::<T>()[..data.len()].copy_from_slice(data);
    }
}
fn main() {
    // Initialize the backend that runs the computation. This example uses the
    // native (CPU) backend.
    let backend = Backend::<Native>::default().expect("failed to initialize the native backend");

    // Two 2x2 input tensors and a 2x2 output tensor, all allocated on the
    // compute backend's device.
    let mut x = SharedTensor::<f32>::new(backend.device(), &(2, 2)).expect("failed to allocate `x`");
    let mut y = SharedTensor::<f32>::new(backend.device(), &(2, 2)).expect("failed to allocate `y`");
    let mut result =
        SharedTensor::<f32>::new(backend.device(), &(2, 2)).expect("failed to allocate `result`");

    // Input data: `x` is filled with 2s and `y` with 3s.
    let payload_x = vec![2f32; x.capacity()];
    let payload_y = vec![3f32; y.capacity()];

    // Backend whose device stands for host memory. The compute backend above is
    // also native, so the data stays on the host here; the sync calls mirror the
    // pattern used when computing on a separate (e.g. GPU) device.
    let native = Backend::<Native>::default().expect("failed to initialize the host backend");

    // Make host memory current, then write the payloads into it.
    x.sync(native.device()).expect("failed to sync `x` to host memory");
    y.sync(native.device()).expect("failed to sync `y` to host memory");
    write_to_memory(x.get_mut(native.device()).expect("`x` has no host memory"), &payload_x);
    write_to_memory(y.get_mut(native.device()).expect("`y` has no host memory"), &payload_y);

    // Sync the inputs to the compute backend's device.
    x.sync(backend.device()).expect("failed to sync `x` to the compute device");
    y.sync(backend.device()).expect("failed to sync `y` to the compute device");

    // Dot product of `x` and `y` via the BLAS plugin: 4 * (2 * 3) = 24.
    backend.dot(&mut x, &mut y, &mut result).expect("dot product failed");

    // Bring the result back to host memory and print it.
    result.sync(native.device()).expect("failed to sync `result` to host memory");
    let host_result = result
        .get(native.device())
        .expect("`result` has no host memory")
        .as_native()
        .expect("`result` host memory is not native");
    println!("{:?}", host_result.as_slice::<f32>());
}
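Run `cargo run` and you should see the output `[24.0, 0.0, 0.0, 0.0]`. The dot product of the four 2s in `x` and the four 3s in `y` is 4 × 2 × 3 = 24, which lands in the first element of `result`.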