Last active
January 24, 2024 07:52
-
-
Save robertknight/ad54cc02a79d0824e6e576401d3d433e to your computer and use it in GitHub Desktop.
rten-ndarray conversion
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use ndarray::{Array, Array2, ArrayView, Dim, Dimension, Ix, StrideShape}; | |
use rten_tensor::prelude::*; | |
use rten_tensor::{NdTensor, NdTensorView}; | |
/// Convert an N-dimensional ndarray view to an [NdTensorView]. | |
/// | |
/// Returns `None` if the view is not in the standard layout (see | |
/// [ArrayView::is_standard_layout]). | |
fn as_ndtensor_view<'a, T, const N: usize>( | |
view: ArrayView<'a, T, Dim<[Ix; N]>>, | |
) -> Option<NdTensorView<'a, T, N>> | |
where | |
Dim<[Ix; N]>: Dimension, | |
{ | |
view.to_slice().map(|slice| { | |
let shape: [usize; N] = view.shape().try_into().unwrap(); | |
NdTensorView::from_data(shape, slice) | |
}) | |
} | |
/// Convert an N-dimensional [NdTensorView] into an ndarray view. | |
/// | |
/// Returns `None` if the view is not in "standard layout" (see | |
/// [ArrayView::is_standard_layout]). | |
fn as_array_view<'a, T, const N: usize>( | |
view: NdTensorView<'a, T, N>, | |
) -> Option<ArrayView<'a, T, Dim<[Ix; N]>>> | |
where | |
Dim<[Ix; N]>: Dimension, | |
[usize; N]: Into<StrideShape<Dim<[Ix; N]>>>, | |
{ | |
view.data() | |
.map(|data| ArrayView::from_shape(view.shape(), data).unwrap()) | |
} | |
/// Convert an N-dimensional [NdTensor] into an ndarray. | |
fn into_array<T, const N: usize>(tensor: NdTensor<T, N>) -> Array<T, Dim<[Ix; N]>> | |
where | |
Dim<[Ix; N]>: Dimension, | |
[usize; N]: Into<StrideShape<Dim<[Ix; N]>>>, | |
{ | |
Array::from_shape_vec(tensor.shape(), tensor.into_data()).unwrap() | |
} | |
fn main() { | |
// Owned ndarray => NdTensorView | |
let mut array: Array2<f32> = Array2::zeros([2, 2]); | |
array[[0, 0]] = 1.; | |
array[[0, 1]] = 2.; | |
array[[1, 0]] = 3.; | |
array[[1, 1]] = 4.; | |
let view = as_ndtensor_view(array.view()).expect("non-contiguous view"); | |
for (idx, el) in view.indices().zip(view.iter()) { | |
println!("index {:?} element {}", idx, el); | |
} | |
// NdTensor => ArrayView | |
let permuted_owned = view.permuted([1, 0]).to_tensor(); | |
let ndarray_view = as_array_view(permuted_owned.view()).expect("non-contiguous view"); | |
println!("ndarray_view {:?}", ndarray_view); | |
// Ndtensor => Array | |
let ndarray = into_array(permuted_owned); | |
println!("ndarray {:?}", ndarray); | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.