Created
August 12, 2022 16:13
-
-
Save Steboss89/1a06440446d42e29e891e4de70b249d0 to your computer and use it in GitHub Desktop.
Convert a u8 image byte array into a normalized 2-D tch::Tensor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::result::Result; | |
use std::error::Error; | |
use mnist::*; | |
use tch::{kind, no_grad, Kind, Tensor}; | |
use ndarray::{Array3, Array2}; | |
pub fn image_to_tensor(data:Vec<u8>, dim1:usize, dim2:usize, dim3:usize)-> Tensor{ | |
// normalize the image as well | |
let inp_data: Array3<f32> = Array3::from_shape_vec((dim1, dim2, dim3), data) | |
.expect("Error converting data to 3D array") | |
.map(|x| *x as f32/256.0); | |
// convert to tensor | |
let inp_tensor = Tensor::of_slice(inp_data.as_slice().unwrap()); | |
// reshape so we'll have dim1, dim2*dim3 shape array | |
let ax1 = dim1 as i64; | |
let ax2 = (dim2 as i64)*(dim3 as i64); | |
let shape: Vec<i64> = vec![ ax1, ax2 ]; | |
let output_data = inp_tensor.reshape(&shape); | |
println!("Output image tensor size {:?}", shape); | |
output_data | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment