
@Steboss89
Created August 12, 2022 16:13
Convert an array to tch::Tensor
```rust
use ndarray::Array3;
use tch::Tensor;

/// Convert a raw image byte buffer (e.g. MNIST pixels) into a 2-D `f32` tensor
/// of shape `[dim1, dim2 * dim3]`, with pixel values scaled into [0, 1).
pub fn image_to_tensor(data: Vec<u8>, dim1: usize, dim2: usize, dim3: usize) -> Tensor {
    // Build a (dim1, dim2, dim3) array and normalize each byte by dividing by 256.
    let inp_data: Array3<f32> = Array3::from_shape_vec((dim1, dim2, dim3), data)
        .expect("Error converting data to 3D array")
        .map(|x| *x as f32 / 256.0);
    // Copy the contiguous array into a 1-D tensor
    // (newer tch releases name this constructor `Tensor::from_slice`).
    let inp_tensor = Tensor::of_slice(inp_data.as_slice().expect("array is not contiguous"));
    // Reshape to (dim1, dim2 * dim3): one flattened image per row.
    let ax1 = dim1 as i64;
    let ax2 = (dim2 as i64) * (dim3 as i64);
    let shape: Vec<i64> = vec![ax1, ax2];
    let output_data = inp_tensor.reshape(&shape);
    println!("Output image tensor size {:?}", shape);
    output_data
}
```
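The normalization and reshape in `image_to_tensor` do not depend on libtorch, so the arithmetic can be checked in plain Rust. The sketch below uses a hypothetical std-only helper, `normalize_and_flatten`, and assumes the bytes are stored row-major as `(dim1, dim2, dim3)` (e.g. number of images × height × width). It returns nested `Vec`s in place of a `Tensor`; each inner `Vec` corresponds to one row of the `[dim1, dim2 * dim3]` tensor.

```rust
/// Hypothetical std-only helper mirroring `image_to_tensor`: scales each byte
/// by 1/256 and groups the row-major (dim1, dim2, dim3) buffer into dim1 rows
/// of dim2 * dim3 values each, the same layout as the reshaped tensor.
fn normalize_and_flatten(data: &[u8], dim1: usize, dim2: usize, dim3: usize) -> Vec<Vec<f32>> {
    let row_len = dim2 * dim3;
    assert!(row_len > 0, "dim2 and dim3 must be non-zero");
    assert_eq!(data.len(), dim1 * row_len, "buffer size does not match shape");
    data.chunks(row_len)
        .map(|row| row.iter().map(|&x| x as f32 / 256.0).collect())
        .collect()
}

fn main() {
    // Two hypothetical 2x2 "images" stored back to back.
    let data: Vec<u8> = vec![0, 64, 128, 255, 32, 32, 32, 32];
    let rows = normalize_and_flatten(&data, 2, 2, 2);
    // prints: shape: [2, 4]
    println!("shape: [{}, {}]", rows.len(), rows[0].len());
    println!("{:?}", rows);
}
```

Because `chunks` walks the buffer in memory order, row `i` holds the pixels of image `i`, which is the same grouping `reshape` produces for a contiguous tensor. Dividing by 256 keeps the largest pixel value (255) just below 1.0; dividing by 255 instead would map it to exactly 1.0.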